use crate::core::bot::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter}; use crate::core::bot_database::BotDatabaseManager; use crate::core::config::AppConfig; use crate::core::session::SessionManager; use crate::core::shared::analytics::MetricsCollector; use crate::core::shared::state::{AppState, Extensions}; #[cfg(feature = "directory")] use crate::core::directory::client::ZitadelConfig; #[cfg(feature = "directory")] use crate::core::directory::AuthService; #[cfg(feature = "llm")] use crate::llm::LLMProvider; use crate::shared::models::BotResponse; use crate::shared::utils::{get_database_url_sync, DbPool}; use crate::tasks::TaskEngine; use async_trait::async_trait; use diesel::r2d2::{ConnectionManager, Pool}; use diesel::PgConnection; use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{broadcast, mpsc, Mutex}; #[cfg(feature = "llm")] #[derive(Debug)] pub struct MockLLMProvider { pub response: String, } #[cfg(feature = "llm")] impl MockLLMProvider { pub fn new() -> Self { Self { response: "Mock LLM response".to_string(), } } pub fn with_response(response: &str) -> Self { Self { response: response.to_string(), } } } #[cfg(feature = "llm")] impl Default for MockLLMProvider { fn default() -> Self { Self::new() } } #[cfg(feature = "llm")] #[async_trait] impl LLMProvider for MockLLMProvider { async fn generate( &self, _prompt: &str, _config: &Value, _model: &str, _key: &str, ) -> Result> { Ok(self.response.clone()) } async fn generate_stream( &self, _prompt: &str, _config: &Value, tx: mpsc::Sender, _model: &str, _key: &str, ) -> Result<(), Box> { tx.send(self.response.clone()).await?; Ok(()) } async fn cancel_job( &self, _session_id: &str, ) -> Result<(), Box> { Ok(()) } } #[derive(Debug)] pub struct MockChannelAdapter { pub name: String, pub messages: Arc>>, } impl MockChannelAdapter { pub fn new(name: &str) -> Self { Self { name: name.into(), messages: Arc::new(Mutex::new(Vec::new())), } } pub async fn get_sent_messages(&self) -> Vec { self.messages.lock().await.clone() } } #[async_trait] impl ChannelAdapter for MockChannelAdapter { fn name(&self) -> &'static str { "Mock" } fn is_configured(&self) -> bool { true } async fn send_message( &self, response: BotResponse, ) -> Result<(), Box> { self.messages.lock().await.push(response); Ok(()) } async fn receive_message( &self, _payload: Value, ) -> Result, Box> { Ok(Some("mock_message".to_string())) } async fn get_user_info( &self, user_id: &str, ) -> Result> { Ok(serde_json::json!({ "id": user_id, "platform": self.name, "name": "Mock User" })) } } #[derive(Debug)] pub struct TestAppStateBuilder { database_url: Option, bucket_name: String, config: Option, } impl TestAppStateBuilder { pub fn new() -> Self { Self { database_url: None, bucket_name: "test-bucket".to_string(), config: None, } } pub fn with_database_url(mut self, url: &str) -> Self { self.database_url = Some(url.to_string()); self } pub fn with_bucket_name(mut self, name: &str) -> Self { self.bucket_name = name.to_string(); self } pub fn with_config(mut self, config: AppConfig) -> Self { self.config = Some(config); self } pub fn build(self) -> Result> { let database_url = self .database_url .or_else(|| get_database_url_sync().ok()) .unwrap_or_else(|| "postgres://test:test@localhost:5432/test".to_string()); let manager = ConnectionManager::::new(&database_url); let pool = Pool::builder() .max_size(1) .test_on_check_out(false) .connection_timeout(std::time::Duration::from_secs(5)) .build(manager)?; let conn = pool.get()?; let session_manager = SessionManager::new(conn, None); let (attendant_tx, _) = broadcast::channel(100); let (task_progress_tx, _) = broadcast::channel(100); let bot_database_manager = Arc::new(BotDatabaseManager::new(pool.clone(), &database_url)); Ok(AppState { #[cfg(feature = "drive")] drive: None, s3_client: None, #[cfg(feature = "cache")] cache: None, bucket_name: self.bucket_name, config: self.config, conn: pool.clone(), database_url, bot_database_manager, session_manager: Arc::new(tokio::sync::Mutex::new(session_manager)), metrics_collector: MetricsCollector::new(), task_scheduler: None, #[cfg(feature = "llm")] llm_provider: Arc::new(MockLLMProvider::new()), #[cfg(feature = "directory")] auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())), channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), web_adapter: Arc::new(WebChannelAdapter::new()), voice_adapter: Arc::new(VoiceAdapter::new()), kb_manager: None, task_engine: Arc::new(TaskEngine::new(pool)), extensions: Extensions::new(), attendant_broadcast: Some(attendant_tx), task_progress_broadcast: Some(task_progress_tx), task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())), project_service: Arc::new(tokio::sync::RwLock::new(crate::project::ProjectService::new())), legal_service: Arc::new(tokio::sync::RwLock::new(crate::legal::LegalService::new())), }) } } impl Default for TestAppStateBuilder { fn default() -> Self { Self::new() } } #[cfg(feature = "directory")] pub fn create_mock_auth_service() -> AuthService { let config = ZitadelConfig { issuer_url: "http://localhost:8080".to_string(), issuer: "http://localhost:8080".to_string(), client_id: "mock_client_id".to_string(), client_secret: "mock_client_secret".to_string(), redirect_uri: "http://localhost:3000/callback".to_string(), project_id: "mock_project_id".to_string(), api_url: "http://localhost:8080".to_string(), service_account_key: None, }; AuthService::new(config).expect("Failed to create mock AuthService") } pub fn create_test_db_pool() -> Result> { let database_url = get_database_url_sync() .unwrap_or_else(|_| "postgres://test:test@localhost:5432/test".to_string()); let manager = ConnectionManager::::new(&database_url); let pool = Pool::builder() .max_size(1) .connection_timeout(std::time::Duration::from_secs(5)) .build(manager)?; Ok(pool) } pub fn create_mock_metrics_collector() -> MetricsCollector { MetricsCollector::new() }