botserver/src/core/shared/test_utils.rs

263 lines
7.7 KiB
Rust

use crate::core::bot::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter};
use crate::core::bot_database::BotDatabaseManager;
use crate::core::config::AppConfig;
use crate::core::session::SessionManager;
use crate::core::shared::analytics::MetricsCollector;
use crate::core::shared::state::{AppState, Extensions};
#[cfg(feature = "directory")]
use crate::core::directory::client::ZitadelConfig;
#[cfg(feature = "directory")]
use crate::core::directory::AuthService;
#[cfg(feature = "llm")]
use crate::llm::LLMProvider;
use crate::shared::models::BotResponse;
use crate::shared::utils::{get_database_url_sync, DbPool};
use crate::tasks::TaskEngine;
use async_trait::async_trait;
use diesel::r2d2::{ConnectionManager, Pool};
use diesel::PgConnection;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, Mutex};
#[cfg(feature = "llm")]
#[derive(Debug)]
pub struct MockLLMProvider {
pub response: String,
}
#[cfg(feature = "llm")]
impl MockLLMProvider {
pub fn new() -> Self {
Self {
response: "Mock LLM response".to_string(),
}
}
pub fn with_response(response: &str) -> Self {
Self {
response: response.to_string(),
}
}
}
#[cfg(feature = "llm")]
impl Default for MockLLMProvider {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "llm")]
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
_prompt: &str,
_config: &Value,
_model: &str,
_key: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(self.response.clone())
}
async fn generate_stream(
&self,
_prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
_model: &str,
_key: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
tx.send(self.response.clone()).await?;
Ok(())
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}
#[derive(Debug)]
pub struct MockChannelAdapter {
pub name: String,
pub messages: Arc<Mutex<Vec<BotResponse>>>,
}
impl MockChannelAdapter {
pub fn new(name: &str) -> Self {
Self {
name: name.into(),
messages: Arc::new(Mutex::new(Vec::new())),
}
}
pub async fn get_sent_messages(&self) -> Vec<BotResponse> {
self.messages.lock().await.clone()
}
}
#[async_trait]
impl ChannelAdapter for MockChannelAdapter {
fn name(&self) -> &'static str {
"Mock"
}
fn is_configured(&self) -> bool {
true
}
async fn send_message(
&self,
response: BotResponse,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.messages.lock().await.push(response);
Ok(())
}
async fn receive_message(
&self,
_payload: Value,
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
Ok(Some("mock_message".to_string()))
}
async fn get_user_info(
&self,
user_id: &str,
) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
Ok(serde_json::json!({
"id": user_id,
"platform": self.name,
"name": "Mock User"
}))
}
}
#[derive(Debug)]
pub struct TestAppStateBuilder {
database_url: Option<String>,
bucket_name: String,
config: Option<AppConfig>,
}
impl TestAppStateBuilder {
pub fn new() -> Self {
Self {
database_url: None,
bucket_name: "test-bucket".to_string(),
config: None,
}
}
pub fn with_database_url(mut self, url: &str) -> Self {
self.database_url = Some(url.to_string());
self
}
pub fn with_bucket_name(mut self, name: &str) -> Self {
self.bucket_name = name.to_string();
self
}
pub fn with_config(mut self, config: AppConfig) -> Self {
self.config = Some(config);
self
}
pub fn build(self) -> Result<AppState, Box<dyn std::error::Error + Send + Sync>> {
let database_url = self
.database_url
.or_else(|| get_database_url_sync().ok())
.unwrap_or_else(|| "postgres://test:test@localhost:5432/test".to_string());
let manager = ConnectionManager::<PgConnection>::new(&database_url);
let pool = Pool::builder()
.max_size(1)
.test_on_check_out(false)
.connection_timeout(std::time::Duration::from_secs(5))
.build(manager)?;
let conn = pool.get()?;
let session_manager = SessionManager::new(conn, None);
let (attendant_tx, _) = broadcast::channel(100);
let (task_progress_tx, _) = broadcast::channel(100);
let bot_database_manager = Arc::new(BotDatabaseManager::new(pool.clone(), &database_url));
Ok(AppState {
#[cfg(feature = "drive")]
drive: None,
s3_client: None,
#[cfg(feature = "cache")]
cache: None,
bucket_name: self.bucket_name,
config: self.config,
conn: pool.clone(),
database_url,
bot_database_manager,
session_manager: Arc::new(tokio::sync::Mutex::new(session_manager)),
metrics_collector: MetricsCollector::new(),
task_scheduler: None,
#[cfg(feature = "llm")]
llm_provider: Arc::new(MockLLMProvider::new()),
#[cfg(feature = "directory")]
auth_service: Arc::new(tokio::sync::Mutex::new(create_mock_auth_service())),
channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
web_adapter: Arc::new(WebChannelAdapter::new()),
voice_adapter: Arc::new(VoiceAdapter::new()),
kb_manager: None,
task_engine: Arc::new(TaskEngine::new(pool)),
extensions: Extensions::new(),
attendant_broadcast: Some(attendant_tx),
task_progress_broadcast: Some(task_progress_tx),
task_manifests: Arc::new(std::sync::RwLock::new(HashMap::new())),
project_service: Arc::new(tokio::sync::RwLock::new(crate::project::ProjectService::new())),
legal_service: Arc::new(tokio::sync::RwLock::new(crate::legal::LegalService::new())),
})
}
}
impl Default for TestAppStateBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "directory")]
pub fn create_mock_auth_service() -> AuthService {
let config = ZitadelConfig {
issuer_url: "http://localhost:8080".to_string(),
issuer: "http://localhost:8080".to_string(),
client_id: "mock_client_id".to_string(),
client_secret: "mock_client_secret".to_string(),
redirect_uri: "http://localhost:3000/callback".to_string(),
project_id: "mock_project_id".to_string(),
api_url: "http://localhost:8080".to_string(),
service_account_key: None,
};
AuthService::new(config).expect("Failed to create mock AuthService")
}
pub fn create_test_db_pool() -> Result<DbPool, Box<dyn std::error::Error + Send + Sync>> {
let database_url = get_database_url_sync()
.unwrap_or_else(|_| "postgres://test:test@localhost:5432/test".to_string());
let manager = ConnectionManager::<PgConnection>::new(&database_url);
let pool = Pool::builder()
.max_size(1)
.connection_timeout(std::time::Duration::from_secs(5))
.build(manager)?;
Ok(pool)
}
pub fn create_mock_metrics_collector() -> MetricsCollector {
MetricsCollector::new()
}