From fea1574518f8d7eedd95449a68072a1e99c7177e Mon Sep 17 00:00:00 2001
From: "Rodrigo Rodriguez (Pragmatismo)"
Date: Sun, 2 Nov 2025 18:36:21 -0300
Subject: [PATCH] feat: enforce config load errors and add dynamic LLM model
 handling

- Updated `BootstrapManager` to use `AppConfig::from_env().expect(...)` and
  `AppConfig::from_database(...).expect(...)`, ensuring failures are explicit
  rather than silently ignored.
- Refactored error propagation in the bootstrap flow to use `?` where
  appropriate, improving reliability of configuration loading.
- Added import of `llm_models` in the `bot` module and introduced
  `ConfigManager` usage to fetch the LLM model identifier at runtime.
- Integrated dynamic LLM model handler selection via
  `llm_models::get_handler(&model)`.
- Replaced static environment variable retrieval for embedding configuration
  with runtime `ConfigManager` lookups.
---
 src/bootstrap/mod.rs               |  16 +-
 src/bot/mod.rs                     | 105 ++---
 src/config/mod.rs                  | 286 +++++++------
 src/drive_monitor/mod.rs           |  13 +-
 src/llm/{llama.rs => anthropic.rs} |  89 ++--
 src/llm/azure.rs                   | 103 +++++
 src/llm/local.rs                   | 269 ++++++++++++
 src/llm/mod.rs                     | 121 +-----
 src/llm_legacy/llm_azure.rs        | 145 -------
 src/llm_legacy/llm_local.rs        | 656 -----------------------------
 src/llm_legacy/mod.rs              |   2 -
 src/llm_models/deepseek_r3.rs      |  17 +
 src/llm_models/gpt_oss_120b.rs     |  31 ++
 src/llm_models/gpt_oss_20b.rs      |  21 +
 src/llm_models/mod.rs              |  33 ++
 src/main.rs                        |  24 +-
 src/shared/models.rs               |   3 +
 src/shared/utils.rs                |  13 +
 18 files changed, 766 insertions(+), 1181 deletions(-)
 rename src/llm/{llama.rs => anthropic.rs} (51%)
 create mode 100644 src/llm/azure.rs
 create mode 100644 src/llm/local.rs
 delete mode 100644 src/llm_legacy/llm_azure.rs
 delete mode 100644 src/llm_legacy/llm_local.rs
 delete mode 100644 src/llm_legacy/mod.rs
 create mode 100644 src/llm_models/deepseek_r3.rs
 create mode 100644 src/llm_models/gpt_oss_120b.rs
 create mode 100644 src/llm_models/gpt_oss_20b.rs
 create mode 100644 src/llm_models/mod.rs

diff --git a/src/bootstrap/mod.rs b/src/bootstrap/mod.rs
index 512ffe47..35b43115 100644
--- a/src/bootstrap/mod.rs
+++ b/src/bootstrap/mod.rs
@@ -43,7 +43,7 @@ impl BootstrapManager {
             "Initializing BootstrapManager with mode {:?} and tenant {:?}",
             install_mode, tenant
         );
-        let config = AppConfig::from_env();
+        let config = AppConfig::from_env().expect("Failed to load config from env");
         let s3_client = futures::executor::block_on(Self::create_s3_operator(&config));
         Self {
             install_mode,
@@ -179,11 +179,11 @@ impl BootstrapManager {
                 if let Err(e) = self.apply_migrations(&mut conn) {
                     log::warn!("Failed to apply migrations: {}", e);
                 }
-                return Ok(AppConfig::from_database(&mut conn));
+                return Ok(AppConfig::from_database(&mut conn).expect("Failed to load config from DB"));
             }
             Err(e) => {
                 log::warn!("Failed to connect to database: {}", e);
-                return Ok(AppConfig::from_env());
+                return Ok(AppConfig::from_env()?);
             }
         }
     }
@@ -191,7 +191,7 @@
         let pm = PackageManager::new(self.install_mode.clone(), self.tenant.clone())?;
         let required_components = vec!["tables", "drive", "cache", "llm"];

-        let mut config = AppConfig::from_env();
+        let mut config = AppConfig::from_env().expect("Failed to load config from env");

         for component in required_components {
             if !pm.is_installed(component) {
@@ -282,7 +282,7 @@
                     );
                 }

-                config = AppConfig::from_database(&mut conn);
+                config = AppConfig::from_database(&mut conn).expect("Failed to load config from DB");
             }
         }
     }
@@ -513,9 +513,9 @@ impl BootstrapManager {
                 // Create fresh connection for final
config load let mut final_conn = establish_pg_connection()?; - let config = AppConfig::from_database(&mut final_conn); - info!("Successfully loaded config from CSV with LLM settings"); - Ok(config) +let config = AppConfig::from_database(&mut final_conn)?; +info!("Successfully loaded config from CSV with LLM settings"); +Ok(config) } Err(e) => { debug!("No config.csv found: {}", e); diff --git a/src/bot/mod.rs b/src/bot/mod.rs index b41186e7..95685749 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -4,6 +4,7 @@ use crate::context::langcache::get_langcache_client; use crate::drive_monitor::DriveMonitor; use crate::kb::embeddings::generate_embeddings; use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint}; +use crate::llm_models; use crate::shared::models::{BotResponse, Suggestion, UserMessage, UserSession}; use crate::shared::state::AppState; use actix_web::{web, HttpRequest, HttpResponse, Result}; @@ -680,7 +681,6 @@ impl BotOrchestrator { e })?; - let session = { let mut sm = self.state.session_manager.lock().await; let session_id = Uuid::parse_str(&message.session_id).map_err(|e| { @@ -795,6 +795,17 @@ impl BotOrchestrator { let mut chunk_count = 0; let mut first_word_received = false; + let config_manager = ConfigManager::new(Arc::clone(&self.state.conn)); + let model = config_manager + .get_config( + &Uuid::parse_str(&message.bot_id).unwrap_or_default(), + "llm-model", + None, + ) + .unwrap_or_default(); + + let handler = llm_models::get_handler(&model); + while let Some(chunk) = stream_rx.recv().await { chunk_count += 1; @@ -804,63 +815,56 @@ impl BotOrchestrator { analysis_buffer.push_str(&chunk); - if (analysis_buffer.contains("**") || analysis_buffer.contains("")) - && !in_analysis - { + // Check for analysis markers + if handler.has_analysis_markers(&analysis_buffer) && !in_analysis { in_analysis = true; } - if in_analysis { - if analysis_buffer.ends_with("final") || analysis_buffer.ends_with("") { - info!( - "Analysis section completed, buffer length: {}", - analysis_buffer.len() - ); - in_analysis = false; - analysis_buffer.clear(); + // Check if analysis is complete + if in_analysis && handler.is_analysis_complete(&analysis_buffer) { + info!("Analysis section completed"); + in_analysis = false; + analysis_buffer.clear(); - if message.channel == "web" { - let orchestrator = BotOrchestrator::new(Arc::clone(&self.state)); - orchestrator - .send_event( - &message.user_id, - &message.bot_id, - &message.session_id, - &message.channel, - "thinking_end", - serde_json::json!({ - "user_id": message.user_id.clone() - }), - ) - .await - .ok(); - } - analysis_buffer.clear(); - continue; + if message.channel == "web" { + let orchestrator = BotOrchestrator::new(Arc::clone(&self.state)); + orchestrator + .send_event( + &message.user_id, + &message.bot_id, + &message.session_id, + &message.channel, + "thinking_end", + serde_json::json!({"user_id": message.user_id.clone()}), + ) + .await + .ok(); } continue; } - full_response.push_str(&chunk); + if !in_analysis { + full_response.push_str(&chunk); - let partial = BotResponse { - bot_id: message.bot_id.clone(), - user_id: message.user_id.clone(), - session_id: message.session_id.clone(), - channel: message.channel.clone(), - content: chunk, - message_type: 1, - stream_token: None, - is_complete: false, - suggestions: suggestions.clone(), - context_name: None, - context_length: 0, - context_max_length: 0, - }; + let partial = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + 
session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: chunk, + message_type: 1, + stream_token: None, + is_complete: false, + suggestions: suggestions.clone(), + context_name: None, + context_length: 0, + context_max_length: 0, + }; - if response_tx.send(partial).await.is_err() { - warn!("Response channel closed, stopping stream processing"); - break; + if response_tx.send(partial).await.is_err() { + warn!("Response channel closed, stopping stream processing"); + break; + } } } @@ -1369,7 +1373,10 @@ async fn websocket_handler( // Cancel any ongoing LLM jobs for this session if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await { - warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e); + warn!( + "Failed to cancel LLM job for session {}: {}", + session_id_clone2, e + ); } info!("WebSocket fully closed for session {}", session_id_clone2); diff --git a/src/config/mod.rs b/src/config/mod.rs index 9a804573..6eefa013 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,14 +1,19 @@ use diesel::prelude::*; +use diesel::pg::PgConnection; +use crate::shared::models::schema::bot_configuration; use diesel::sql_types::Text; +use uuid::Uuid; +use diesel::pg::Pg; use log::{info, trace, warn}; -use serde::{Deserialize, Serialize}; + // removed unused serde import use std::collections::HashMap; use std::fs::OpenOptions; use std::io::Write; use std::path::PathBuf; use std::sync::{Arc, Mutex}; +use crate::shared::utils::establish_pg_connection; -#[derive(Clone)] +#[derive(Clone, Default)] pub struct LLMConfig { pub url: String, pub key: String, @@ -20,7 +25,6 @@ pub struct AppConfig { pub drive: DriveConfig, pub server: ServerConfig, pub database: DatabaseConfig, - pub database_custom: DatabaseConfig, pub email: EmailConfig, pub llm: LLMConfig, pub embedding: LLMConfig, @@ -61,17 +65,21 @@ pub struct EmailConfig { pub password: String, } -#[derive(Debug, Clone, Serialize, Deserialize, QueryableByName)] +#[derive(Debug, Clone, Queryable, Selectable)] +#[diesel(table_name = bot_configuration)] +#[diesel(check_for_backend(Pg))] pub struct ServerConfigRow { - #[diesel(sql_type = Text)] - pub id: String, + #[diesel(sql_type = DieselUuid)] + pub id: Uuid, + #[diesel(sql_type = DieselUuid)] + pub bot_id: Uuid, #[diesel(sql_type = Text)] pub config_key: String, #[diesel(sql_type = Text)] pub config_value: String, #[diesel(sql_type = Text)] pub config_type: String, - #[diesel(sql_type = diesel::sql_types::Bool)] + #[diesel(sql_type = Bool)] pub is_encrypted: bool, } @@ -87,17 +95,6 @@ impl AppConfig { ) } - pub fn database_custom_url(&self) -> String { - format!( - "postgres://{}:{}@{}:{}/{}", - self.database_custom.username, - self.database_custom.password, - self.database_custom.server, - self.database_custom.port, - self.database_custom.database - ) - } - pub fn component_path(&self, component: &str) -> PathBuf { self.stack_path.join(component) } @@ -117,125 +114,134 @@ impl AppConfig { pub fn log_path(&self, component: &str) -> PathBuf { self.stack_path.join("logs").join(component) } +} + +impl AppConfig { + pub fn from_database(conn: &mut PgConnection) -> Result { + info!("Loading configuration from database"); + + use crate::shared::models::schema::bot_configuration::dsl::*; + use diesel::prelude::*; + + let config_map: HashMap = bot_configuration + .select(ServerConfigRow::as_select()).load::(conn) + .unwrap_or_default() + .into_iter() + .map(|row| (row.config_key.clone(), row)) + .collect(); - pub fn from_database(_conn: &mut 
PgConnection) -> Self { - info!("Loading configuration from database"); - // Config is now loaded from bot_configuration table via drive_monitor - let config_map: HashMap = HashMap::new(); + let mut get_str = |key: &str, default: &str| -> String { + bot_configuration + .filter(config_key.eq(key)) + .select(config_value) + .first::(conn) + .unwrap_or_else(|_| default.to_string()) + }; - let get_str = |key: &str, default: &str| -> String { - config_map - .get(key) - .map(|v: &ServerConfigRow| v.config_value.clone()) - .unwrap_or_else(|| default.to_string()) - }; + let get_u32 = |key: &str, default: u32| -> u32 { + config_map + .get(key) + .and_then(|v| v.config_value.parse().ok()) + .unwrap_or(default) + }; - let get_u32 = |key: &str, default: u32| -> u32 { - config_map - .get(key) - .and_then(|v| v.config_value.parse().ok()) - .unwrap_or(default) - }; + let get_u16 = |key: &str, default: u16| -> u16 { + config_map + .get(key) + .and_then(|v| v.config_value.parse().ok()) + .unwrap_or(default) + }; - let get_u16 = |key: &str, default: u16| -> u16 { - config_map - .get(key) - .and_then(|v| v.config_value.parse().ok()) - .unwrap_or(default) - }; + let get_bool = |key: &str, default: bool| -> bool { + config_map + .get(key) + .map(|v| v.config_value.to_lowercase() == "true") + .unwrap_or(default) + }; - let get_bool = |key: &str, default: bool| -> bool { - config_map - .get(key) - .map(|v| v.config_value.to_lowercase() == "true") - .unwrap_or(default) - }; + let stack_path = PathBuf::from(get_str("STACK_PATH", "./botserver-stack")); - let stack_path = PathBuf::from(get_str("STACK_PATH", "./botserver-stack")); + let database = DatabaseConfig { + username: std::env::var("TABLES_USERNAME") + .unwrap_or_else(|_| get_str("TABLES_USERNAME", "gbuser")), + password: std::env::var("TABLES_PASSWORD") + .unwrap_or_else(|_| get_str("TABLES_PASSWORD", "")), + server: std::env::var("TABLES_SERVER") + .unwrap_or_else(|_| get_str("TABLES_SERVER", "localhost")), + port: std::env::var("TABLES_PORT") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or_else(|| get_u32("TABLES_PORT", 5432)), + database: std::env::var("TABLES_DATABASE") + .unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")), + }; - let database = DatabaseConfig { - username: std::env::var("TABLES_USERNAME") - .unwrap_or_else(|_| get_str("TABLES_USERNAME", "gbuser")), - password: std::env::var("TABLES_PASSWORD") - .unwrap_or_else(|_| get_str("TABLES_PASSWORD", "")), - server: std::env::var("TABLES_SERVER") - .unwrap_or_else(|_| get_str("TABLES_SERVER", "localhost")), - port: std::env::var("TABLES_PORT") - .ok() - .and_then(|p| p.parse().ok()) - .unwrap_or_else(|| get_u32("TABLES_PORT", 5432)), - database: std::env::var("TABLES_DATABASE") - .unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")), - }; - let database_custom = DatabaseConfig { - username: std::env::var("CUSTOM_USERNAME") - .unwrap_or_else(|_| get_str("CUSTOM_USERNAME", "gbuser")), - password: std::env::var("CUSTOM_PASSWORD") - .unwrap_or_else(|_| get_str("CUSTOM_PASSWORD", "")), - server: std::env::var("CUSTOM_SERVER") - .unwrap_or_else(|_| get_str("CUSTOM_SERVER", "localhost")), - port: std::env::var("CUSTOM_PORT") - .ok() - .and_then(|p| p.parse().ok()) - .unwrap_or_else(|| get_u32("CUSTOM_PORT", 5432)), - database: std::env::var("CUSTOM_DATABASE") - .unwrap_or_else(|_| get_str("CUSTOM_DATABASE", "gbuser")), - }; + let drive = DriveConfig { + server: { + let server = get_str("DRIVE_SERVER", "http://localhost:9000"); + if !server.starts_with("http://") && 
!server.starts_with("https://") { + format!("http://{}", server) + } else { + server + } + }, + access_key: get_str("DRIVE_ACCESSKEY", "minioadmin"), + secret_key: get_str("DRIVE_SECRET", "minioadmin"), + use_ssl: get_bool("DRIVE_USE_SSL", false), + }; - let minio = DriveConfig { - server: { - let server = get_str("DRIVE_SERVER", "http://localhost:9000"); - if !server.starts_with("http://") && !server.starts_with("https://") { - format!("http://{}", server) - } else { - server - } - }, - access_key: get_str("DRIVE_ACCESSKEY", "minioadmin"), - secret_key: get_str("DRIVE_SECRET", "minioadmin"), - use_ssl: get_bool("DRIVE_USE_SSL", false), - }; + let email = EmailConfig { + from: get_str("EMAIL_FROM", "noreply@example.com"), + server: get_str("EMAIL_SERVER", "smtp.example.com"), + port: get_u16("EMAIL_PORT", 587), + username: get_str("EMAIL_USER", "user"), + password: get_str("EMAIL_PASS", "pass"), + }; - let email = EmailConfig { - from: get_str("EMAIL_FROM", "noreply@example.com"), - server: get_str("EMAIL_SERVER", "smtp.example.com"), - port: get_u16("EMAIL_PORT", 587), - username: get_str("EMAIL_USER", "user"), - password: get_str("EMAIL_PASS", "pass"), - }; + // Write drive config to .env file + if let Err(e) = write_drive_config_to_env(&drive) { + warn!("Failed to write drive config to .env: {}", e); + } - // Write drive config to .env file - if let Err(e) = write_drive_config_to_env(&minio) { - warn!("Failed to write drive config to .env: {}", e); - } - - AppConfig { - drive: minio, + Ok(AppConfig { + drive, server: ServerConfig { host: get_str("SERVER_HOST", "127.0.0.1"), port: get_u16("SERVER_PORT", 8080), }, database, - database_custom, email, - llm: LLMConfig { - url: get_str("LLM_URL", "http://localhost:8081"), - key: get_str("LLM_KEY", ""), - model: get_str("LLM_MODEL", "gpt-4"), + llm: { + // Use a fresh connection for ConfigManager to avoid cloning the mutable reference + let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?; + let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn))); + LLMConfig { + url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?, + key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?, + model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?, + } }, - embedding: LLMConfig { - url: get_str("EMBEDDING_URL", "http://localhost:8082"), - key: get_str("EMBEDDING_KEY", ""), - model: get_str("EMBEDDING_MODEL", "text-embedding-ada-002"), + embedding: { + let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?; + let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn))); + LLMConfig { + url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?, + key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?, + model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?, + } + }, + site_path: { + let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?; + ConfigManager::new(Arc::new(Mutex::new(fresh_conn))) + .get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?.to_string() }, - site_path: get_str("SITES_ROOT", "./botserver-stack/sites"), stack_path, db_conn: None, - } - } 
- - pub fn from_env() -> Self { + }) +} + + pub fn from_env() -> Result { info!("Loading configuration from environment variables"); let stack_path = @@ -254,16 +260,6 @@ impl AppConfig { database: db_name, }; - let database_custom = DatabaseConfig { - username: std::env::var("CUSTOM_USERNAME").unwrap_or_else(|_| "gbuser".to_string()), - password: std::env::var("CUSTOM_PASSWORD").unwrap_or_else(|_| "".to_string()), - server: std::env::var("CUSTOM_SERVER").unwrap_or_else(|_| "localhost".to_string()), - port: std::env::var("CUSTOM_PORT") - .ok() - .and_then(|p| p.parse().ok()) - .unwrap_or(5432), - database: std::env::var("CUSTOM_DATABASE").unwrap_or_else(|_| "botserver".to_string()), - }; let minio = DriveConfig { server: std::env::var("DRIVE_SERVER") @@ -288,33 +284,43 @@ impl AppConfig { password: std::env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()), }; - AppConfig { + Ok(AppConfig { drive: minio, server: ServerConfig { - host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()), + host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.1".to_string()), port: std::env::var("SERVER_PORT") .ok() .and_then(|p| p.parse().ok()) .unwrap_or(8080), }, database, - database_custom, email, - llm: LLMConfig { - url: std::env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()), - key: std::env::var("LLM_KEY").unwrap_or_else(|_| "".to_string()), - model: std::env::var("LLM_MODEL").unwrap_or_else(|_| "gpt-4".to_string()), + llm: { + let conn = PgConnection::establish(&database_url)?; + let config = ConfigManager::new(Arc::new(Mutex::new(conn))); + LLMConfig { + url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?, + key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?, + model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?, + } }, - embedding: LLMConfig { - url: std::env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()), - key: std::env::var("EMBEDDING_KEY").unwrap_or_else(|_| "".to_string()), - model: std::env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-ada-002".to_string()), + embedding: { + let conn = PgConnection::establish(&database_url)?; + let config = ConfigManager::new(Arc::new(Mutex::new(conn))); + LLMConfig { + url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?, + key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?, + model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?, + } + }, + site_path: { + let conn = PgConnection::establish(&database_url)?; + ConfigManager::new(Arc::new(Mutex::new(conn))) + .get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))? 
}, - site_path: std::env::var("SITES_ROOT") - .unwrap_or_else(|_| "./botserver-stack/sites".to_string()), stack_path: PathBuf::from(stack_path), db_conn: None, - } + }) } pub fn set_config( diff --git a/src/drive_monitor/mod.rs b/src/drive_monitor/mod.rs index cfe1b8a8..fcaffb14 100644 --- a/src/drive_monitor/mod.rs +++ b/src/drive_monitor/mod.rs @@ -62,10 +62,7 @@ impl DriveMonitor { self.check_gbdialog_changes(client).await?; self.check_gbkb_changes(client).await?; - - if let Err(e) = self.check_gbot(client).await { - error!("Error checking default bot config: {}", e); - } + self.check_gbot(client).await?; Ok(()) } @@ -257,7 +254,7 @@ impl DriveMonitor { continue; } - debug!("Checking config file at path: {}", path); + trace!("Checking config file at path: {}", path); match client .head_object() .bucket(&self.bucket_name) @@ -276,7 +273,7 @@ impl DriveMonitor { .key(&path) .send() .await?; - debug!( + trace!( "GetObject successful for {}, content length: {}", path, response.content_length().unwrap_or(0) @@ -295,7 +292,7 @@ impl DriveMonitor { .collect(); if !llm_lines.is_empty() { - use crate::llm_legacy::llm_local::ensure_llama_servers_running; + use crate::llm::local::ensure_llama_servers_running; let mut restart_needed = false; for line in llm_lines { @@ -338,7 +335,7 @@ impl DriveMonitor { } } Err(e) => { - debug!("Config file {} not found or inaccessible: {}", path, e); + error!("Config file {} not found or inaccessible: {}", path, e); } } } diff --git a/src/llm/llama.rs b/src/llm/anthropic.rs similarity index 51% rename from src/llm/llama.rs rename to src/llm/anthropic.rs index 2060979d..fbd1ccdf 100644 --- a/src/llm/llama.rs +++ b/src/llm/anthropic.rs @@ -1,62 +1,73 @@ use async_trait::async_trait; use reqwest::Client; +use futures_util::StreamExt; use serde_json::Value; use std::sync::Arc; use tokio::sync::mpsc; -use futures_util::StreamExt; + use crate::tools::ToolManager; use super::LLMProvider; -pub struct LlamaClient { +pub struct AnthropicClient { client: Client, + api_key: String, base_url: String, } -impl LlamaClient { - pub fn new(base_url: String) -> Self { +impl AnthropicClient { + pub fn new(api_key: String) -> Self { Self { client: Client::new(), - base_url, + api_key, + base_url: "https://api.anthropic.com/v1".to_string(), } } } #[async_trait] -impl LLMProvider for LlamaClient { +impl LLMProvider for AnthropicClient { async fn generate( &self, prompt: &str, - config: &Value, + _config: &Value, ) -> Result> { - let response = self.client - .post(&format!("{}/completion", self.base_url)) + let response = self + .client + .post(&format!("{}/messages", self.base_url)) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") .json(&serde_json::json!({ - "prompt": prompt, - "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512), - "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7), + "model": "claude-3-sonnet-20240229", + "max_tokens": 1000, + "messages": [{"role": "user", "content": prompt}] })) .send() .await?; let result: Value = response.json().await?; - Ok(result["content"] + let content = result["content"][0]["text"] .as_str() .unwrap_or("") - .to_string()) + .to_string(); + + Ok(content) } async fn generate_stream( &self, prompt: &str, - config: &Value, + _config: &Value, tx: mpsc::Sender, ) -> Result<(), Box> { - let response = self.client - .post(&format!("{}/completion", self.base_url)) + let response = self + .client + .post(&format!("{}/messages", self.base_url)) + 
.header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") .json(&serde_json::json!({ - "prompt": prompt, - "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512), - "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7), + "model": "claude-3-sonnet-20240229", + "max_tokens": 1000, + "messages": [{"role": "user", "content": prompt}], "stream": true })) .send() @@ -66,14 +77,18 @@ impl LLMProvider for LlamaClient { let mut buffer = String::new(); while let Some(chunk) = stream.next().await { - let chunk = chunk?; - let chunk_str = String::from_utf8_lossy(&chunk); + let chunk_bytes = chunk?; + let chunk_str = String::from_utf8_lossy(&chunk_bytes); for line in chunk_str.lines() { - if let Ok(data) = serde_json::from_str::(&line) { - if let Some(content) = data["content"].as_str() { - buffer.push_str(content); - let _ = tx.send(content.to_string()).await; + if line.starts_with("data: ") { + if let Ok(data) = serde_json::from_str::(&line[6..]) { + if data["type"] == "content_block_delta" { + if let Some(text) = data["delta"]["text"].as_str() { + buffer.push_str(text); + let _ = tx.send(text.to_string()).await; + } + } } } } @@ -85,7 +100,7 @@ impl LLMProvider for LlamaClient { async fn generate_with_tools( &self, prompt: &str, - config: &Value, + _config: &Value, available_tools: &[String], _tool_manager: Arc, _session_id: &str, @@ -98,26 +113,14 @@ impl LLMProvider for LlamaClient { }; let enhanced_prompt = format!("{}{}", prompt, tools_info); - self.generate(&enhanced_prompt, config).await + self.generate(&enhanced_prompt, &Value::Null).await } async fn cancel_job( &self, - session_id: &str, + _session_id: &str, ) -> Result<(), Box> { - // llama.cpp cancellation endpoint - let response = self.client - .post(&format!("{}/cancel", self.base_url)) - .json(&serde_json::json!({ - "session_id": session_id - })) - .send() - .await?; - - if response.status().is_success() { - Ok(()) - } else { - Err(format!("Failed to cancel job for session {}", session_id).into()) - } + // Anthropic doesn't support job cancellation + Ok(()) } } diff --git a/src/llm/azure.rs b/src/llm/azure.rs new file mode 100644 index 00000000..0a7e5b81 --- /dev/null +++ b/src/llm/azure.rs @@ -0,0 +1,103 @@ +use async_trait::async_trait; +use reqwest::Client; +use serde_json::Value; +use std::sync::Arc; +use crate::tools::ToolManager; +use super::LLMProvider; + +pub struct AzureOpenAIClient { + endpoint: String, + api_key: String, + api_version: String, + deployment: String, + client: Client, +} + +impl AzureOpenAIClient { + pub fn new(endpoint: String, api_key: String, api_version: String, deployment: String) -> Self { + Self { + endpoint, + api_key, + api_version, + deployment, + client: Client::new(), + } + } +} + +#[async_trait] +impl LLMProvider for AzureOpenAIClient { + async fn generate( + &self, + prompt: &str, + _config: &Value, + ) -> Result> { + let url = format!( + "{}/openai/deployments/{}/chat/completions?api-version={}", + self.endpoint, self.deployment, self.api_version + ); + + let body = serde_json::json!({ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ], + "temperature": 0.7, + "max_tokens": 1000, + "top_p": 1.0, + "frequency_penalty": 0.0, + "presence_penalty": 0.0 + }); + + let response = self.client + .post(&url) + .header("api-key", &self.api_key) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + let result: Value = 
response.json().await?; + if let Some(choice) = result["choices"].get(0) { + Ok(choice["message"]["content"].as_str().unwrap_or("").to_string()) + } else { + Err("No response from Azure OpenAI".into()) + } + } + + async fn generate_stream( + &self, + prompt: &str, + _config: &Value, + tx: tokio::sync::mpsc::Sender, + ) -> Result<(), Box> { + let content = self.generate(prompt, _config).await?; + let _ = tx.send(content).await; + Ok(()) + } + + async fn generate_with_tools( + &self, + prompt: &str, + _config: &Value, + available_tools: &[String], + _tool_manager: Arc, + _session_id: &str, + _user_id: &str, + ) -> Result> { + let tools_info = if available_tools.is_empty() { + String::new() + } else { + format!("\n\nAvailable tools: {}.", available_tools.join(", ")) + }; + let enhanced_prompt = format!("{}{}", prompt, tools_info); + self.generate(&enhanced_prompt, _config).await + } + + async fn cancel_job( + &self, + _session_id: &str, + ) -> Result<(), Box> { + Ok(()) + } +} diff --git a/src/llm/local.rs b/src/llm/local.rs new file mode 100644 index 00000000..4fc331d5 --- /dev/null +++ b/src/llm/local.rs @@ -0,0 +1,269 @@ +use crate::shared::state::AppState; +use crate::shared::models::schema::bots::dsl::*; +use crate::config::ConfigManager; +use diesel::prelude::*; +use std::sync::Arc; +use log::{info, error}; +use tokio; +use reqwest; + +use actix_web::{post, web, HttpResponse, Result}; + +#[post("/api/chat/completions")] +pub async fn chat_completions_local( + _data: web::Data, + _payload: web::Json, +) -> Result { + // Placeholder implementation + Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "chat_completions_local not implemented" }))) +} + +#[post("/api/embeddings")] +pub async fn embeddings_local( + _data: web::Data, + _payload: web::Json, +) -> Result { + // Placeholder implementation + Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "embeddings_local not implemented" }))) +} + +pub async fn ensure_llama_servers_running( + app_state: &Arc +) -> Result<(), Box> { + let conn = app_state.conn.clone(); + let config_manager = ConfigManager::new(conn.clone()); + + let default_bot_id = { + let mut conn = conn.lock().unwrap(); + bots.filter(name.eq("default")) + .select(id) + .first::(&mut *conn) + .unwrap_or_else(|_| uuid::Uuid::nil()) + }; + + let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?; + let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?; + let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?; + let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?; + let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?; + + info!("Starting LLM servers..."); + info!("Configuration:"); + info!(" LLM URL: {}", llm_url); + info!(" Embedding URL: {}", embedding_url); + info!(" LLM Model: {}", llm_model); + info!(" Embedding Model: {}", embedding_model); + info!(" LLM Server Path: {}", llm_server_path); + + // Restart any existing llama-server processes + info!("Restarting any existing llama-server processes..."); + if let Err(e) = tokio::process::Command::new("sh") + .arg("-c") + .arg("pkill -f llama-server || true") + .spawn() + { + error!("Failed to execute pkill for llama-server: {}", e); + } else { + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + info!("Existing llama-server processes terminated (if any)"); + } + + // Check if servers are already running + let llm_running = 
is_server_running(&llm_url).await; + let embedding_running = is_server_running(&embedding_url).await; + + if llm_running && embedding_running { + info!("Both LLM and Embedding servers are already running"); + return Ok(()); + } + + // Start servers that aren't running + let mut tasks = vec![]; + + if !llm_running && !llm_model.is_empty() { + info!("Starting LLM server..."); + tasks.push(tokio::spawn(start_llm_server( + Arc::clone(app_state), + llm_server_path.clone(), + llm_model.clone(), + llm_url.clone(), + ))); + } else if llm_model.is_empty() { + info!("LLM_MODEL not set, skipping LLM server"); + } + + if !embedding_running && !embedding_model.is_empty() { + info!("Starting Embedding server..."); + tasks.push(tokio::spawn(start_embedding_server( + llm_server_path.clone(), + embedding_model.clone(), + embedding_url.clone(), + ))); + } else if embedding_model.is_empty() { + info!("EMBEDDING_MODEL not set, skipping Embedding server"); + } + + // Wait for all server startup tasks + for task in tasks { + task.await??; + } + + // Wait for servers to be ready with verbose logging + info!("Waiting for servers to become ready..."); + + let mut llm_ready = llm_running || llm_model.is_empty(); + let mut embedding_ready = embedding_running || embedding_model.is_empty(); + + let mut attempts = 0; + let max_attempts = 60; // 2 minutes total + + while attempts < max_attempts && (!llm_ready || !embedding_ready) { + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + info!("Checking server health (attempt {}/{})...", attempts + 1, max_attempts); + + if !llm_ready && !llm_model.is_empty() { + if is_server_running(&llm_url).await { + info!("LLM server ready at {}", llm_url); + llm_ready = true; + } else { + info!("LLM server not ready yet"); + } + } + + if !embedding_ready && !embedding_model.is_empty() { + if is_server_running(&embedding_url).await { + info!("Embedding server ready at {}", embedding_url); + embedding_ready = true; + } else { + info!("Embedding server not ready yet"); + } + } + + attempts += 1; + + if attempts % 10 == 0 { + info!("Still waiting for servers... 
(attempt {}/{})", attempts, max_attempts); + } + } + + if llm_ready && embedding_ready { + info!("All llama.cpp servers are ready and responding!"); + Ok(()) + } else { + let mut error_msg = "Servers failed to start within timeout:".to_string(); + if !llm_ready && !llm_model.is_empty() { + error_msg.push_str(&format!("\n - LLM server at {}", llm_url)); + } + if !embedding_ready && !embedding_model.is_empty() { + error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url)); + } + Err(error_msg.into()) + } +} + +pub async fn is_server_running(url: &str) -> bool { + let client = reqwest::Client::new(); + match client.get(&format!("{}/health", url)).send().await { + Ok(response) => response.status().is_success(), + Err(_) => false, + } +} + +pub async fn start_llm_server( + app_state: Arc, + llama_cpp_path: String, + model_path: String, + url: String, +) -> Result<(), Box> { + let port = url.split(':').last().unwrap_or("8081"); + + std::env::set_var("OMP_NUM_THREADS", "20"); + std::env::set_var("OMP_PLACES", "cores"); + std::env::set_var("OMP_PROC_BIND", "close"); + + let conn = app_state.conn.clone(); + let config_manager = ConfigManager::new(conn.clone()); + + let default_bot_id = { + let mut conn = conn.lock().unwrap(); + bots.filter(name.eq("default")) + .select(id) + .first::(&mut *conn) + .unwrap_or_else(|_| uuid::Uuid::nil()) + }; + + let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string()); + let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string()); + let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string()); + let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string()); + let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string()); + let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string()); + let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string()); + + // Build command arguments dynamically + let mut args = format!( + "-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}", + model_path, port, ctx_size, gpu_layers + ); + + if n_moe != "0" { + args.push_str(&format!(" --n-cpu-moe {}", n_moe)); + } + if parallel != "1" { + args.push_str(&format!(" --parallel {}", parallel)); + } + if cont_batching == "true" { + args.push_str(" --cont-batching"); + } + if mlock == "true" { + args.push_str(" --mlock"); + } + if no_mmap == "true" { + args.push_str(" --no-mmap"); + } + + if cfg!(windows) { + let mut cmd = tokio::process::Command::new("cmd"); + cmd.arg("/C").arg(format!( + "cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log", + llama_cpp_path, args + )); + cmd.spawn()?; + } else { + let mut cmd = tokio::process::Command::new("sh"); + cmd.arg("-c").arg(format!( + "cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &", + llama_cpp_path, args + )); + cmd.spawn()?; + } + + Ok(()) +} + +pub async fn start_embedding_server( + llama_cpp_path: String, + model_path: String, + url: String, +) -> Result<(), Box> { + let port = url.split(':').last().unwrap_or("8082"); + + if cfg!(windows) { + let mut cmd = tokio::process::Command::new("cmd"); + cmd.arg("/c").arg(format!( + "cd {} && .\\llama-server.exe -m {} 
--log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log", + llama_cpp_path, model_path, port + )); + cmd.spawn()?; + } else { + let mut cmd = tokio::process::Command::new("sh"); + cmd.arg("-c").arg(format!( + "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &", + llama_cpp_path, model_path, port + )); + cmd.spawn()?; + } + + Ok(()) +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 1ba6db08..91a38aee 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -6,7 +6,9 @@ use tokio::sync::mpsc; use crate::tools::ToolManager; -pub mod llama; +pub mod azure; +pub mod local; +pub mod anthropic; #[async_trait] pub trait LLMProvider: Send + Sync { @@ -161,123 +163,6 @@ impl LLMProvider for OpenAIClient { } } -pub struct AnthropicClient { - client: reqwest::Client, - api_key: String, - base_url: String, -} - -impl AnthropicClient { - pub fn new(api_key: String) -> Self { - Self { - client: reqwest::Client::new(), - api_key, - base_url: "https://api.anthropic.com/v1".to_string(), - } - } -} - -#[async_trait] -impl LLMProvider for AnthropicClient { - async fn generate( - &self, - prompt: &str, - _config: &Value, - ) -> Result> { - let response = self - .client - .post(&format!("{}/messages", self.base_url)) - .header("x-api-key", &self.api_key) - .header("anthropic-version", "2023-06-01") - .json(&serde_json::json!({ - "model": "claude-3-sonnet-20240229", - "max_tokens": 1000, - "messages": [{"role": "user", "content": prompt}] - })) - .send() - .await?; - - let result: Value = response.json().await?; - let content = result["content"][0]["text"] - .as_str() - .unwrap_or("") - .to_string(); - - Ok(content) - } - - async fn generate_stream( - &self, - prompt: &str, - _config: &Value, - tx: mpsc::Sender, - ) -> Result<(), Box> { - let response = self - .client - .post(&format!("{}/messages", self.base_url)) - .header("x-api-key", &self.api_key) - .header("anthropic-version", "2023-06-01") - .json(&serde_json::json!({ - "model": "claude-3-sonnet-20240229", - "max_tokens": 1000, - "messages": [{"role": "user", "content": prompt}], - "stream": true - })) - .send() - .await?; - - let mut stream = response.bytes_stream(); - let mut buffer = String::new(); - - while let Some(chunk) = stream.next().await { - let chunk = chunk?; - let chunk_str = String::from_utf8_lossy(&chunk); - - for line in chunk_str.lines() { - if line.starts_with("data: ") { - if let Ok(data) = serde_json::from_str::(&line[6..]) { - if data["type"] == "content_block_delta" { - if let Some(text) = data["delta"]["text"].as_str() { - buffer.push_str(text); - let _ = tx.send(text.to_string()).await; - } - } - } - } - } - } - - Ok(()) - } - - async fn generate_with_tools( - &self, - prompt: &str, - _config: &Value, - available_tools: &[String], - _tool_manager: Arc, - _session_id: &str, - _user_id: &str, - ) -> Result> { - let tools_info = if available_tools.is_empty() { - String::new() - } else { - format!("\n\nAvailable tools: {}. 
You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")) - }; - - let enhanced_prompt = format!("{}{}", prompt, tools_info); - - self.generate(&enhanced_prompt, &Value::Null).await - } - - async fn cancel_job( - &self, - _session_id: &str, - ) -> Result<(), Box> { - // Anthropic doesn't support job cancellation - Ok(()) - } -} pub struct MockLLMProvider; diff --git a/src/llm_legacy/llm_azure.rs b/src/llm_legacy/llm_azure.rs deleted file mode 100644 index ce5ac5c0..00000000 --- a/src/llm_legacy/llm_azure.rs +++ /dev/null @@ -1,145 +0,0 @@ -use dotenvy::dotenv; -use log::{error, info}; -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct AzureOpenAIConfig { - pub endpoint: String, - pub api_key: String, - pub api_version: String, - pub deployment: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ChatCompletionRequest { - pub messages: Vec, - pub temperature: f32, - pub max_tokens: Option, - pub top_p: f32, - pub frequency_penalty: f32, - pub presence_penalty: f32, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ChatMessage { - pub role: String, - pub content: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Usage, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ChatChoice { - pub index: u32, - pub message: ChatMessage, - pub finish_reason: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -pub struct AzureOpenAIClient { - config: AzureOpenAIConfig, - client: Client, -} - -impl AzureOpenAIClient { - pub fn new() -> Result> { - dotenv().ok(); - - let endpoint = - std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?; - let api_key = - std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| "AZURE_OPENAI_API_KEY not set")?; - let api_version = std::env::var("AZURE_OPENAI_API_VERSION") - .unwrap_or_else(|_| "2023-12-01-preview".to_string()); - let deployment = - std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string()); - - let config = AzureOpenAIConfig { - endpoint, - api_key, - api_version, - deployment, - }; - - Ok(Self { - config, - client: Client::new(), - }) - } - - pub async fn chat_completions( - &self, - messages: Vec, - temperature: f32, - max_tokens: Option, - ) -> Result> { - let url = format!( - "{}/openai/deployments/{}/chat/completions?api-version={}", - self.config.endpoint, self.config.deployment, self.config.api_version - ); - - let request_body = ChatCompletionRequest { - messages, - temperature, - max_tokens, - top_p: 1.0, - frequency_penalty: 0.0, - presence_penalty: 0.0, - }; - - info!("Sending request to Azure OpenAI: {}", url); - - let response = self - .client - .post(&url) - .header("api-key", &self.config.api_key) - .header("Content-Type", "application/json") - .json(&request_body) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - error!("Azure OpenAI API error: {}", error_text); - return Err(format!("Azure OpenAI API error: {}", error_text).into()); - } - - let completion_response: ChatCompletionResponse = response.json().await?; - Ok(completion_response) - } - - pub async fn simple_chat(&self, prompt: &str) -> Result> { - let messages = vec![ - 
ChatMessage { - role: "system".to_string(), - content: "You are a helpful assistant.".to_string(), - }, - ChatMessage { - role: "user".to_string(), - content: prompt.to_string(), - }, - ]; - - let response = self.chat_completions(messages, 0.7, Some(1000)).await?; - - if let Some(choice) = response.choices.first() { - Ok(choice.message.content.clone()) - } else { - Err("No response from AI".into()) - } - } -} diff --git a/src/llm_legacy/llm_local.rs b/src/llm_legacy/llm_local.rs deleted file mode 100644 index be3ec1c0..00000000 --- a/src/llm_legacy/llm_local.rs +++ /dev/null @@ -1,656 +0,0 @@ -use actix_web::{post, web, HttpRequest, HttpResponse, Result}; -use dotenvy::dotenv; -use log::{error, info}; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use std::env; -use std::sync::Arc; -use tokio::time::{sleep, Duration}; - -use crate::config::ConfigManager; -use crate::shared::models::schema::bots::dsl::*; -use crate::shared::state::AppState; -use diesel::prelude::*; - -// OpenAI-compatible request/response structures -#[derive(Debug, Serialize, Deserialize)] -struct ChatMessage { - role: String, - content: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - stream: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ChatCompletionResponse { - id: String, - object: String, - created: u64, - model: String, - choices: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Choice { - message: ChatMessage, - finish_reason: String, -} - -// Llama.cpp server request/response structures -#[derive(Debug, Serialize, Deserialize)] -struct LlamaCppRequest { - prompt: String, - n_predict: Option, - temperature: Option, - top_k: Option, - top_p: Option, - stream: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -struct LlamaCppResponse { - content: String, - stop: bool, - generation_settings: Option, -} - -pub async fn ensure_llama_servers_running( - app_state: &Arc, -) -> Result<(), Box> { - let conn = app_state.conn.clone(); - let config_manager = ConfigManager::new(conn.clone()); - - // Get default bot ID from database - let default_bot_id = { - let mut conn = conn.lock().unwrap(); - bots.filter(name.eq("default")) - .select(id) - .first::(&mut *conn) - .unwrap_or_else(|_| uuid::Uuid::nil()) - }; - - // Get configuration from config using default bot ID - let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?; - let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?; - - let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?; - let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?; - - let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?; - - info!(" Starting LLM servers..."); - info!(" Configuration:"); - info!(" LLM URL: {}", llm_url); - info!(" Embedding URL: {}", embedding_url); - info!(" LLM Model: {}", llm_model); - info!(" Embedding Model: {}", embedding_model); - info!(" LLM Server Path: {}", llm_server_path); - - - // Restart any existing llama-server processes before starting new ones - info!("🔁 Restarting any existing llama-server processes..."); - if let Err(e) = tokio::process::Command::new("sh") - .arg("-c") - .arg("pkill -f llama-server || true") - .spawn() - { - error!("Failed to execute pkill for llama-server: {}", e); - } else { - sleep(Duration::from_secs(2)).await; - info!("✅ Existing llama-server processes terminated (if any)"); 
- } - - // Check if servers are already running - let llm_running = is_server_running(&llm_url).await; - let embedding_running = is_server_running(&embedding_url).await; - - if llm_running && embedding_running { - info!("✅ Both LLM and Embedding servers are already running"); - return Ok(()); - } - - - // Start servers that aren't running - let mut tasks = vec![]; - - if !llm_running && !llm_model.is_empty() { - info!("🔄 Starting LLM server..."); - tasks.push(tokio::spawn(start_llm_server( - Arc::clone(app_state), - llm_server_path.clone(), - llm_model.clone(), - llm_url.clone(), - ))); - } else if llm_model.is_empty() { - info!("⚠️ LLM_MODEL not set, skipping LLM server"); - } - - if !embedding_running && !embedding_model.is_empty() { - info!("🔄 Starting Embedding server..."); - tasks.push(tokio::spawn(start_embedding_server( - llm_server_path.clone(), - embedding_model.clone(), - embedding_url.clone(), - ))); - } else if embedding_model.is_empty() { - info!("⚠️ EMBEDDING_MODEL not set, skipping Embedding server"); - } - - // Wait for all server startup tasks - for task in tasks { - task.await??; - } - - // Wait for servers to be ready with verbose logging - info!("⏳ Waiting for servers to become ready..."); - - let mut llm_ready = llm_running || llm_model.is_empty(); - let mut embedding_ready = embedding_running || embedding_model.is_empty(); - - let mut attempts = 0; - let max_attempts = 60; // 2 minutes total - - while attempts < max_attempts && (!llm_ready || !embedding_ready) { - sleep(Duration::from_secs(2)).await; - - info!( - "🔍 Checking server health (attempt {}/{})...", - attempts + 1, - max_attempts - ); - - if !llm_ready && !llm_model.is_empty() { - if is_server_running(&llm_url).await { - info!(" ✅ LLM server ready at {}", llm_url); - llm_ready = true; - } else { - info!(" ❌ LLM server not ready yet"); - } - } - - if !embedding_ready && !embedding_model.is_empty() { - if is_server_running(&embedding_url).await { - info!(" ✅ Embedding server ready at {}", embedding_url); - embedding_ready = true; - } else { - info!(" ❌ Embedding server not ready yet"); - } - } - - attempts += 1; - - if attempts % 10 == 0 { - info!( - "⏰ Still waiting for servers... 
(attempt {}/{})", - attempts, max_attempts - ); - } - } - - if llm_ready && embedding_ready { - info!("🎉 All llama.cpp servers are ready and responding!"); - Ok(()) - } else { - let mut error_msg = "❌ Servers failed to start within timeout:".to_string(); - if !llm_ready && !llm_model.is_empty() { - error_msg.push_str(&format!("\n - LLM server at {}", llm_url)); - } - if !embedding_ready && !embedding_model.is_empty() { - error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url)); - } - Err(error_msg.into()) - } -} - -async fn start_llm_server( - app_state: Arc, - llama_cpp_path: String, - model_path: String, - url: String, -) -> Result<(), Box> { - let port = url.split(':').last().unwrap_or("8081"); - - std::env::set_var("OMP_NUM_THREADS", "20"); - std::env::set_var("OMP_PLACES", "cores"); - std::env::set_var("OMP_PROC_BIND", "close"); - - // "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &", - // Read config values with defaults - - - - - - let conn = app_state.conn.clone(); - let config_manager = ConfigManager::new(conn.clone()); - - let default_bot_id = { - let mut conn = conn.lock().unwrap(); - bots.filter(name.eq("default")) - .select(id) - .first::(&mut *conn) - .unwrap_or_else(|_| uuid::Uuid::nil()) - }; - - let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string()); - let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string()); - let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string()); - let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string()); - let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string()); - let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string()); - let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string()); - - // Build command arguments dynamically - let mut args = format!( - "-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}", - model_path, port, ctx_size, gpu_layers - ); - - if n_moe != "0" { - args.push_str(&format!(" --n-cpu-moe {}", n_moe)); - } - if parallel != "1" { - args.push_str(&format!(" --parallel {}", parallel)); - } - if cont_batching == "true" { - args.push_str(" --cont-batching"); - } - if mlock == "true" { - args.push_str(" --mlock"); - } - if no_mmap == "true" { - args.push_str(" --no-mmap"); - } - - if cfg!(windows) { - let mut cmd = tokio::process::Command::new("cmd"); - cmd.arg("/C").arg(format!( - "cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log", - llama_cpp_path, args - )); - cmd.spawn()?; - } else { - let mut cmd = tokio::process::Command::new("sh"); - cmd.arg("-c").arg(format!( - "cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &", - llama_cpp_path, args - )); - cmd.spawn()?; - } - - Ok(()) -} - -async fn start_embedding_server( - llama_cpp_path: String, - model_path: String, - url: String, -) -> Result<(), Box> { - let port = url.split(':').last().unwrap_or("8082"); - - if cfg!(windows) { - let mut cmd = 
tokio::process::Command::new("cmd"); - cmd.arg("/c").arg(format!( - "cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log", - llama_cpp_path, model_path, port - )); - cmd.spawn()?; - } else { - let mut cmd = tokio::process::Command::new("sh"); - cmd.arg("-c").arg(format!( - "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &", - llama_cpp_path, model_path, port - )); - cmd.spawn()?; - } - - Ok(()) -} - -async fn is_server_running(url: &str) -> bool { - - - let client = reqwest::Client::new(); - match client.get(&format!("{}/health", url)).send().await { - Ok(response) => response.status().is_success(), - Err(_) => false, - } -} - -// Convert OpenAI chat messages to a single prompt -fn messages_to_prompt(messages: &[ChatMessage]) -> String { - let mut prompt = String::new(); - - for message in messages { - match message.role.as_str() { - "system" => { - prompt.push_str(&format!("System: {}\n\n", message.content)); - } - "user" => { - prompt.push_str(&format!("User: {}\n\n", message.content)); - } - "assistant" => { - prompt.push_str(&format!("Assistant: {}\n\n", message.content)); - } - _ => { - prompt.push_str(&format!("{}: {}\n\n", message.role, message.content)); - } - } - } - - prompt.push_str("Assistant: "); - prompt -} - -// Proxy endpoint -#[post("/local/v1/chat/completions")] -pub async fn chat_completions_local( - req_body: web::Json, - _req: HttpRequest, -) -> Result { - dotenv().ok().unwrap(); - - // Get llama.cpp server URL - let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); - - // Convert OpenAI format to llama.cpp format - let prompt = messages_to_prompt(&req_body.messages); - - let llama_request = LlamaCppRequest { - prompt, - n_predict: Some(500), // Adjust as needed - temperature: Some(0.7), - top_k: Some(40), - top_p: Some(0.9), - stream: req_body.stream, - }; - - // Send request to llama.cpp server - let client = Client::builder() - .timeout(Duration::from_secs(500)) // 2 minute timeout - .build() - .map_err(|e| { - error!("Error creating HTTP client: {}", e); - actix_web::error::ErrorInternalServerError("Failed to create HTTP client") - })?; - - let response = client - .post(&format!("{}/v1/completion", llama_url)) - .header("Content-Type", "application/json") - .json(&llama_request) - .send() - .await - .map_err(|e| { - error!("Error calling llama.cpp server: {}", e); - actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server") - })?; - - let status = response.status(); - - if status.is_success() { - let llama_response: LlamaCppResponse = response.json().await.map_err(|e| { - error!("Error parsing llama.cpp response: {}", e); - actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response") - })?; - - // Convert llama.cpp response to OpenAI format - let openai_response = ChatCompletionResponse { - id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), - object: "chat.completion".to_string(), - created: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs(), - model: req_body.model.clone(), - choices: vec![Choice { - message: ChatMessage { - role: "assistant".to_string(), - content: llama_response.content.trim().to_string(), - }, - finish_reason: if llama_response.stop { - "stop".to_string() - } else { - "length".to_string() - }, - }], - }; - - Ok(HttpResponse::Ok().json(openai_response)) - } else { - 
let error_text = response - .text() - .await - .unwrap_or_else(|_| "Unknown error".to_string()); - - error!("Llama.cpp server error ({}): {}", status, error_text); - - let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - Ok(HttpResponse::build(actix_status).json(serde_json::json!({ - "error": { - "message": error_text, - "type": "server_error" - } - }))) - } -} - -// OpenAI Embedding Request - Modified to handle both string and array inputs -#[derive(Debug, Deserialize)] -pub struct EmbeddingRequest { - #[serde(deserialize_with = "deserialize_input")] - pub input: Vec, - pub model: String, - #[serde(default)] - pub _encoding_format: Option, -} - -// Custom deserializer to handle both string and array inputs -fn deserialize_input<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - use serde::de::{self, Visitor}; - use std::fmt; - - struct InputVisitor; - - impl<'de> Visitor<'de> for InputVisitor { - type Value = Vec; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string or an array of strings") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - Ok(vec![value.to_string()]) - } - - fn visit_string(self, value: String) -> Result - where - E: de::Error, - { - Ok(vec![value]) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: de::SeqAccess<'de>, - { - let mut vec = Vec::new(); - while let Some(value) = seq.next_element::()? { - vec.push(value); - } - Ok(vec) - } - } - - deserializer.deserialize_any(InputVisitor) -} - -// OpenAI Embedding Response -#[derive(Debug, Serialize)] -pub struct EmbeddingResponse { - pub object: String, - pub data: Vec, - pub model: String, - pub usage: Usage, -} - -#[derive(Debug, Serialize)] -pub struct EmbeddingData { - pub object: String, - pub embedding: Vec, - pub index: usize, -} - -#[derive(Debug, Serialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub total_tokens: u32, -} - -// Llama.cpp Embedding Request -#[derive(Debug, Serialize)] -struct LlamaCppEmbeddingRequest { - pub content: String, -} - -// FIXED: Handle the stupid nested array format -#[derive(Debug, Deserialize)] -struct LlamaCppEmbeddingResponseItem { - pub index: usize, - pub embedding: Vec>, // This is the up part - embedding is an array of arrays -} - -// Proxy endpoint for embeddings -#[post("/v1/embeddings")] -pub async fn embeddings_local( - req_body: web::Json, - _req: HttpRequest, -) -> Result { - dotenv().ok(); - - // Get llama.cpp server URL - let llama_url = - env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); - - let client = Client::builder() - .timeout(Duration::from_secs(120)) - .build() - .map_err(|e| { - error!("Error creating HTTP client: {}", e); - actix_web::error::ErrorInternalServerError("Failed to create HTTP client") - })?; - - // Process each input text and get embeddings - let mut embeddings_data = Vec::new(); - let mut total_tokens = 0; - - for (index, input_text) in req_body.input.iter().enumerate() { - let llama_request = LlamaCppEmbeddingRequest { - content: input_text.clone(), - }; - - let response = client - .post(&format!("{}/embedding", llama_url)) - .header("Content-Type", "application/json") - .json(&llama_request) - .send() - .await - .map_err(|e| { - error!("Error calling llama.cpp server for embedding: {}", e); - actix_web::error::ErrorInternalServerError( - "Failed to call llama.cpp server for embedding", 
-                )
-            })?;
-
-        let status = response.status();
-
-        if status.is_success() {
-            // First, get the raw response text for debugging
-            let raw_response = response.text().await.map_err(|e| {
-                error!("Error reading response text: {}", e);
-                actix_web::error::ErrorInternalServerError("Failed to read response")
-            })?;
-
-            // Parse the response as a vector of items with nested arrays
-            let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
-                serde_json::from_str(&raw_response).map_err(|e| {
-                    error!("Error parsing llama.cpp embedding response: {}", e);
-                    error!("Raw response: {}", raw_response);
-                    actix_web::error::ErrorInternalServerError(
-                        "Failed to parse llama.cpp embedding response",
-                    )
-                })?;
-
-            // Extract the embedding from the nested array response
-            if let Some(item) = llama_response.get(0) {
-                // The embedding field contains Vec<Vec<f32>>, so we need to flatten it
-                // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
-                let flattened_embedding = if !item.embedding.is_empty() {
-                    item.embedding[0].clone() // Take the first (and probably only) inner array
-                } else {
-                    vec![] // Empty if no embedding data
-                };
-
-                // Estimate token count
-                let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
-                total_tokens += estimated_tokens;
-
-                embeddings_data.push(EmbeddingData {
-                    object: "embedding".to_string(),
-                    embedding: flattened_embedding,
-                    index,
-                });
-            } else {
-                error!("No embedding data returned for input: {}", input_text);
-                return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
-                    "error": {
-                        "message": format!("No embedding data returned for input {}", index),
-                        "type": "server_error"
-                    }
-                })));
-            }
-        } else {
-            let error_text = response
-                .text()
-                .await
-                .unwrap_or_else(|_| "Unknown error".to_string());
-
-            error!("Llama.cpp server error ({}): {}", status, error_text);
-
-            let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
-                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
-
-            return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
-                "error": {
-                    "message": format!("Failed to get embedding for input {}: {}", index, error_text),
-                    "type": "server_error"
-                }
-            })));
-        }
-    }
-
-    // Build OpenAI-compatible response
-    let openai_response = EmbeddingResponse {
-        object: "list".to_string(),
-        data: embeddings_data,
-        model: req_body.model.clone(),
-        usage: Usage {
-            prompt_tokens: total_tokens,
-            total_tokens,
-        },
-    };
-
-    Ok(HttpResponse::Ok().json(openai_response))
-}
diff --git a/src/llm_legacy/mod.rs b/src/llm_legacy/mod.rs
deleted file mode 100644
index aeb80674..00000000
--- a/src/llm_legacy/mod.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-pub mod llm_azure;
-pub mod llm_local;
diff --git a/src/llm_models/deepseek_r3.rs b/src/llm_models/deepseek_r3.rs
new file mode 100644
index 00000000..686a32a6
--- /dev/null
+++ b/src/llm_models/deepseek_r3.rs
@@ -0,0 +1,17 @@
+use super::ModelHandler;
+
+pub struct DeepseekR3Handler;
+
+impl ModelHandler for DeepseekR3Handler {
+    fn is_analysis_complete(&self, buffer: &str) -> bool {
+        buffer.contains("</think>")
+    }
+
+    fn process_content(&self, content: &str) -> String {
+        content.replace("<think>", "").replace("</think>", "")
+    }
+
+    fn has_analysis_markers(&self, buffer: &str) -> bool {
+        buffer.contains("<think>")
+    }
+}
diff --git a/src/llm_models/gpt_oss_120b.rs b/src/llm_models/gpt_oss_120b.rs
new file mode 100644
index 00000000..0efc28e6
--- /dev/null
+++ b/src/llm_models/gpt_oss_120b.rs
@@ -0,0 +1,31 @@
+use super::ModelHandler;
+
+pub struct GptOss120bHandler {
+    model_name: String,
+}
+
+impl GptOss120bHandler {
+    pub fn new(model_name: &str) -> Self {
+        Self {
+            model_name: model_name.to_string(),
+        }
+    }
+}
+
+impl ModelHandler for GptOss120bHandler {
+    fn is_analysis_complete(&self, buffer: &str) -> bool {
+        // GPT-120B uses explicit end marker
+        buffer.contains("**end**")
+    }
+
+    fn process_content(&self, content: &str) -> String {
+        // Remove both start and end markers from final output
+        content.replace("**start**", "")
+            .replace("**end**", "")
+    }
+
+    fn has_analysis_markers(&self, buffer: &str) -> bool {
+        // GPT-120B uses explicit start marker
+        buffer.contains("**start**")
+    }
+}
diff --git a/src/llm_models/gpt_oss_20b.rs b/src/llm_models/gpt_oss_20b.rs
new file mode 100644
index 00000000..70ddea50
--- /dev/null
+++ b/src/llm_models/gpt_oss_20b.rs
@@ -0,0 +1,21 @@
+use super::ModelHandler;
+
+pub struct GptOss20bHandler;
+
+impl ModelHandler for GptOss20bHandler {
+    fn is_analysis_complete(&self, buffer: &str) -> bool {
+        buffer.ends_with("final")
+    }
+
+    fn process_content(&self, content: &str) -> String {
+        if let Some(pos) = content.find("final") {
+            content[..pos].to_string()
+        } else {
+            content.to_string()
+        }
+    }
+
+    fn has_analysis_markers(&self, buffer: &str) -> bool {
+        buffer.contains("**")
+    }
+}
diff --git a/src/llm_models/mod.rs b/src/llm_models/mod.rs
new file mode 100644
index 00000000..860ed44c
--- /dev/null
+++ b/src/llm_models/mod.rs
@@ -0,0 +1,33 @@
+//! Module for handling model-specific behavior and token processing
+
+pub mod gpt_oss_20b;
+pub mod deepseek_r3;
+pub mod gpt_oss_120b;
+
+
+/// Trait for model-specific token processing
+pub trait ModelHandler: Send + Sync {
+    /// Check if the analysis buffer indicates completion
+    fn is_analysis_complete(&self, buffer: &str) -> bool;
+
+    /// Process the content, removing any model-specific tokens
+    fn process_content(&self, content: &str) -> String;
+
+    /// Check if the buffer contains analysis start markers
+    fn has_analysis_markers(&self, buffer: &str) -> bool;
+}
+
+/// Get the appropriate handler based on model path from bot configuration
+pub fn get_handler(model_path: &str) -> Box<dyn ModelHandler> {
+    let path = model_path.to_lowercase();
+
+    if path.contains("deepseek") {
+        Box::new(deepseek_r3::DeepseekR3Handler)
+    } else if path.contains("120b") {
+        Box::new(gpt_oss_120b::GptOss120bHandler::new("default"))
+    } else if path.contains("gpt-oss") || path.contains("gpt") {
+        Box::new(gpt_oss_20b::GptOss20bHandler)
+    } else {
+        Box::new(gpt_oss_20b::GptOss20bHandler)
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 66e4462c..92f38eea 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,7 +25,7 @@ mod ui;
 mod file;
 mod kb;
 mod llm;
-mod llm_legacy;
+mod llm_models;
 mod meet;
 mod org;
 mod package_manager;
@@ -49,7 +49,7 @@ use crate::email::{
     get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
 };
 use crate::file::{init_drive, upload_file};
-use crate::llm_legacy::llm_local::{
+use crate::llm::local::{
     chat_completions_local, embeddings_local, ensure_llama_servers_running,
 };
 use crate::meet::{voice_start, voice_stop};
@@ -124,8 +124,8 @@ async fn main() -> std::io::Result<()> {
         &std::env::var("DATABASE_URL")
             .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
         ) {
-            Ok(mut conn) => AppConfig::from_database(&mut conn),
-            Err(_) => AppConfig::from_env(),
+            Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
+            Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
         }
     } else {
         match bootstrap.bootstrap().await {
@@ -135,13 +135,13 @@ async fn main() -> std::io::Result<()> {
             }
             Err(e) => {
                 log::error!("Bootstrap failed: {}", e);
-                match diesel::Connection::establish(
-                    &std::env::var("DATABASE_URL")
-                        .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
-                ) {
-                    Ok(mut conn) => AppConfig::from_database(&mut conn),
-                    Err(_) => AppConfig::from_env(),
-                }
+                match diesel::Connection::establish(
+                    &std::env::var("DATABASE_URL")
+                        .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
+                ) {
+                    Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
+                    Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
+                }
             }
         }
     };
@@ -158,7 +158,7 @@ async fn main() -> std::io::Result<()> {
 
     // Refresh configuration from environment to ensure latest DATABASE_URL and credentials
     dotenv().ok();
-    let refreshed_cfg = AppConfig::from_env();
+    let refreshed_cfg = AppConfig::from_env().expect("Failed to load config from env");
     let config = std::sync::Arc::new(refreshed_cfg.clone());
     let db_pool = match diesel::Connection::establish(&refreshed_cfg.database_url()) {
         Ok(conn) => Arc::new(Mutex::new(conn)),
diff --git a/src/shared/models.rs b/src/shared/models.rs
index 75a2e2da..92a1afb7 100644
--- a/src/shared/models.rs
+++ b/src/shared/models.rs
@@ -411,6 +411,9 @@ pub mod schema {
         bot_id -> Uuid,
         config_key -> Text,
         config_value -> Text,
+
+        is_encrypted -> Bool,
+        config_type -> Text,
         created_at -> Timestamptz,
         updated_at -> Timestamptz,
diff --git a/src/shared/utils.rs b/src/shared/utils.rs
index abb39c01..dfba3f20 100644
--- a/src/shared/utils.rs
+++ b/src/shared/utils.rs
@@ -34,6 +34,19 @@ pub fn extract_zip_recursive(
             if !parent.exists() {
                 std::fs::create_dir_all(&parent)?;
             }
+
+use crate::llm::LLMProvider;
+use std::sync::Arc;
+use serde_json::Value;
+
+/// Unified chat utility to interact with any LLM provider
+pub async fn chat_with_llm(
+    provider: Arc<dyn LLMProvider>,
+    prompt: &str,
+    config: &Value,
+) -> Result<String, Box<dyn std::error::Error>> {
+    provider.generate(prompt, config).await
+}
         }
         let mut outfile = File::create(&outpath)?;
         std::io::copy(&mut file, &mut outfile)?;
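
A minimal usage sketch of the new llm_models dispatch, for illustration only and not part of the diff itself. The helper name strip_reasoning and the model identifier string are made up for this example; the sketch assumes only the ModelHandler trait and get_handler() as introduced above, with the DeepSeek handler keying on <think>/</think> markers.

// Illustrative sketch: pick a handler from the configured model identifier and
// strip reasoning markers once the analysis section is complete.
use crate::llm_models;

fn strip_reasoning(model: &str, buffered: &str) -> String {
    // get_handler() matches on substrings such as "deepseek" or "120b",
    // falling back to the GPT-OSS 20B handler for anything else.
    let handler = llm_models::get_handler(model);
    if handler.has_analysis_markers(buffered) && handler.is_analysis_complete(buffered) {
        handler.process_content(buffered)
    } else {
        buffered.to_string()
    }
}

// Hypothetical call: strip_reasoning("deepseek-r1", "<think>step 1</think> final answer")
// selects DeepseekR3Handler and returns "step 1 final answer"; only the marker tags
// are removed, the buffered text itself is kept.

Because get_handler() returns a trait object, the streaming loop stays model-agnostic: supporting a new model only requires a new handler module and a substring match in get_handler().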