diff --git a/src/bootstrap/mod.rs b/src/bootstrap/mod.rs
index 512ffe47..35b43115 100644
--- a/src/bootstrap/mod.rs
+++ b/src/bootstrap/mod.rs
@@ -43,7 +43,7 @@ impl BootstrapManager {
"Initializing BootstrapManager with mode {:?} and tenant {:?}",
install_mode, tenant
);
- let config = AppConfig::from_env();
+ let config = AppConfig::from_env().expect("Failed to load config from env");
let s3_client = futures::executor::block_on(Self::create_s3_operator(&config));
Self {
install_mode,
@@ -179,11 +179,11 @@ impl BootstrapManager {
if let Err(e) = self.apply_migrations(&mut conn) {
log::warn!("Failed to apply migrations: {}", e);
}
- return Ok(AppConfig::from_database(&mut conn));
+ return Ok(AppConfig::from_database(&mut conn).expect("Failed to load config from DB"));
}
Err(e) => {
log::warn!("Failed to connect to database: {}", e);
- return Ok(AppConfig::from_env());
+ return Ok(AppConfig::from_env()?);
}
}
}
@@ -191,7 +191,7 @@ impl BootstrapManager {
let pm = PackageManager::new(self.install_mode.clone(), self.tenant.clone())?;
let required_components = vec!["tables", "drive", "cache", "llm"];
- let mut config = AppConfig::from_env();
+ let mut config = AppConfig::from_env().expect("Failed to load config from env");
for component in required_components {
if !pm.is_installed(component) {
@@ -282,7 +282,7 @@ impl BootstrapManager {
);
}
- config = AppConfig::from_database(&mut conn);
+ config = AppConfig::from_database(&mut conn).expect("Failed to load config from DB");
}
}
}
@@ -513,9 +513,9 @@ impl BootstrapManager {
// Create fresh connection for final config load
let mut final_conn = establish_pg_connection()?;
- let config = AppConfig::from_database(&mut final_conn);
- info!("Successfully loaded config from CSV with LLM settings");
- Ok(config)
+ let config = AppConfig::from_database(&mut final_conn)?;
+ info!("Successfully loaded config from CSV with LLM settings");
+ Ok(config)
}
Err(e) => {
debug!("No config.csv found: {}", e);
diff --git a/src/bot/mod.rs b/src/bot/mod.rs
index b41186e7..95685749 100644
--- a/src/bot/mod.rs
+++ b/src/bot/mod.rs
@@ -4,6 +4,7 @@ use crate::context::langcache::get_langcache_client;
use crate::drive_monitor::DriveMonitor;
use crate::kb::embeddings::generate_embeddings;
use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint};
+use crate::llm_models;
use crate::shared::models::{BotResponse, Suggestion, UserMessage, UserSession};
use crate::shared::state::AppState;
use actix_web::{web, HttpRequest, HttpResponse, Result};
@@ -680,7 +681,6 @@ impl BotOrchestrator {
e
})?;
-
let session = {
let mut sm = self.state.session_manager.lock().await;
let session_id = Uuid::parse_str(&message.session_id).map_err(|e| {
@@ -795,6 +795,17 @@ impl BotOrchestrator {
let mut chunk_count = 0;
let mut first_word_received = false;
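+ // Resolve the bot's configured LLM model so the matching stream handler can detect analysis markers in the output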
+ let config_manager = ConfigManager::new(Arc::clone(&self.state.conn));
+ let model = config_manager
+ .get_config(
+ &Uuid::parse_str(&message.bot_id).unwrap_or_default(),
+ "llm-model",
+ None,
+ )
+ .unwrap_or_default();
+
+ let handler = llm_models::get_handler(&model);
+
while let Some(chunk) = stream_rx.recv().await {
chunk_count += 1;
@@ -804,63 +815,56 @@ impl BotOrchestrator {
analysis_buffer.push_str(&chunk);
- if (analysis_buffer.contains("**") || analysis_buffer.contains(""))
- && !in_analysis
- {
+ // Check for analysis markers
+ if handler.has_analysis_markers(&analysis_buffer) && !in_analysis {
in_analysis = true;
}
- if in_analysis {
- if analysis_buffer.ends_with("final") || analysis_buffer.ends_with("") {
- info!(
- "Analysis section completed, buffer length: {}",
- analysis_buffer.len()
- );
- in_analysis = false;
- analysis_buffer.clear();
+ // Check if analysis is complete
+ if in_analysis && handler.is_analysis_complete(&analysis_buffer) {
+ info!("Analysis section completed");
+ in_analysis = false;
+ analysis_buffer.clear();
- if message.channel == "web" {
- let orchestrator = BotOrchestrator::new(Arc::clone(&self.state));
- orchestrator
- .send_event(
- &message.user_id,
- &message.bot_id,
- &message.session_id,
- &message.channel,
- "thinking_end",
- serde_json::json!({
- "user_id": message.user_id.clone()
- }),
- )
- .await
- .ok();
- }
- analysis_buffer.clear();
- continue;
+ if message.channel == "web" {
+ let orchestrator = BotOrchestrator::new(Arc::clone(&self.state));
+ orchestrator
+ .send_event(
+ &message.user_id,
+ &message.bot_id,
+ &message.session_id,
+ &message.channel,
+ "thinking_end",
+ serde_json::json!({"user_id": message.user_id.clone()}),
+ )
+ .await
+ .ok();
}
continue;
}
- full_response.push_str(&chunk);
+ if !in_analysis {
+ full_response.push_str(&chunk);
- let partial = BotResponse {
- bot_id: message.bot_id.clone(),
- user_id: message.user_id.clone(),
- session_id: message.session_id.clone(),
- channel: message.channel.clone(),
- content: chunk,
- message_type: 1,
- stream_token: None,
- is_complete: false,
- suggestions: suggestions.clone(),
- context_name: None,
- context_length: 0,
- context_max_length: 0,
- };
+ let partial = BotResponse {
+ bot_id: message.bot_id.clone(),
+ user_id: message.user_id.clone(),
+ session_id: message.session_id.clone(),
+ channel: message.channel.clone(),
+ content: chunk,
+ message_type: 1,
+ stream_token: None,
+ is_complete: false,
+ suggestions: suggestions.clone(),
+ context_name: None,
+ context_length: 0,
+ context_max_length: 0,
+ };
- if response_tx.send(partial).await.is_err() {
- warn!("Response channel closed, stopping stream processing");
- break;
+ if response_tx.send(partial).await.is_err() {
+ warn!("Response channel closed, stopping stream processing");
+ break;
+ }
}
}
@@ -1369,7 +1373,10 @@ async fn websocket_handler(
// Cancel any ongoing LLM jobs for this session
if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
- warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
+ warn!(
+ "Failed to cancel LLM job for session {}: {}",
+ session_id_clone2, e
+ );
}
info!("WebSocket fully closed for session {}", session_id_clone2);
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 9a804573..6eefa013 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -1,14 +1,19 @@
use diesel::prelude::*;
+use diesel::pg::PgConnection;
+use crate::shared::models::schema::bot_configuration;
use diesel::sql_types::Text;
+use uuid::Uuid;
+use diesel::pg::Pg;
+use diesel::sql_types::{Bool, Uuid as DieselUuid};
use log::{info, trace, warn};
-use serde::{Deserialize, Serialize};
+ // removed unused serde import
use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::Write;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
+use crate::shared::utils::establish_pg_connection;
-#[derive(Clone)]
+#[derive(Clone, Default)]
pub struct LLMConfig {
pub url: String,
pub key: String,
@@ -20,7 +25,6 @@ pub struct AppConfig {
pub drive: DriveConfig,
pub server: ServerConfig,
pub database: DatabaseConfig,
- pub database_custom: DatabaseConfig,
pub email: EmailConfig,
pub llm: LLMConfig,
pub embedding: LLMConfig,
@@ -61,17 +65,21 @@ pub struct EmailConfig {
pub password: String,
}
-#[derive(Debug, Clone, Serialize, Deserialize, QueryableByName)]
+#[derive(Debug, Clone, Queryable, Selectable)]
+#[diesel(table_name = bot_configuration)]
+#[diesel(check_for_backend(Pg))]
pub struct ServerConfigRow {
- #[diesel(sql_type = Text)]
- pub id: String,
+ #[diesel(sql_type = DieselUuid)]
+ pub id: Uuid,
+ #[diesel(sql_type = DieselUuid)]
+ pub bot_id: Uuid,
#[diesel(sql_type = Text)]
pub config_key: String,
#[diesel(sql_type = Text)]
pub config_value: String,
#[diesel(sql_type = Text)]
pub config_type: String,
- #[diesel(sql_type = diesel::sql_types::Bool)]
+ #[diesel(sql_type = Bool)]
pub is_encrypted: bool,
}
@@ -87,17 +95,6 @@ impl AppConfig {
)
}
- pub fn database_custom_url(&self) -> String {
- format!(
- "postgres://{}:{}@{}:{}/{}",
- self.database_custom.username,
- self.database_custom.password,
- self.database_custom.server,
- self.database_custom.port,
- self.database_custom.database
- )
- }
-
pub fn component_path(&self, component: &str) -> PathBuf {
self.stack_path.join(component)
}
@@ -117,125 +114,134 @@ impl AppConfig {
pub fn log_path(&self, component: &str) -> PathBuf {
self.stack_path.join("logs").join(component)
}
+}
+
+impl AppConfig {
+ pub fn from_database(conn: &mut PgConnection) -> Result<Self, Box<dyn std::error::Error>> {
+ info!("Loading configuration from database");
+
+ use crate::shared::models::schema::bot_configuration::dsl::*;
+ use diesel::prelude::*;
+
+ let config_map: HashMap<String, ServerConfigRow> = bot_configuration
+ .select(ServerConfigRow::as_select())
+ .load::<ServerConfigRow>(conn)
+ .unwrap_or_default()
+ .into_iter()
+ .map(|row| (row.config_key.clone(), row))
+ .collect();
- pub fn from_database(_conn: &mut PgConnection) -> Self {
- info!("Loading configuration from database");
- // Config is now loaded from bot_configuration table via drive_monitor
- let config_map: HashMap<String, ServerConfigRow> = HashMap::new();
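+ // Query bot_configuration directly for a single key, falling back to the provided default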
+ let mut get_str = |key: &str, default: &str| -> String {
+ bot_configuration
+ .filter(config_key.eq(key))
+ .select(config_value)
+ .first::<String>(conn)
+ .unwrap_or_else(|_| default.to_string())
+ };
- let get_str = |key: &str, default: &str| -> String {
- config_map
- .get(key)
- .map(|v: &ServerConfigRow| v.config_value.clone())
- .unwrap_or_else(|| default.to_string())
- };
+ let get_u32 = |key: &str, default: u32| -> u32 {
+ config_map
+ .get(key)
+ .and_then(|v| v.config_value.parse().ok())
+ .unwrap_or(default)
+ };
- let get_u32 = |key: &str, default: u32| -> u32 {
- config_map
- .get(key)
- .and_then(|v| v.config_value.parse().ok())
- .unwrap_or(default)
- };
+ let get_u16 = |key: &str, default: u16| -> u16 {
+ config_map
+ .get(key)
+ .and_then(|v| v.config_value.parse().ok())
+ .unwrap_or(default)
+ };
- let get_u16 = |key: &str, default: u16| -> u16 {
- config_map
- .get(key)
- .and_then(|v| v.config_value.parse().ok())
- .unwrap_or(default)
- };
+ let get_bool = |key: &str, default: bool| -> bool {
+ config_map
+ .get(key)
+ .map(|v| v.config_value.to_lowercase() == "true")
+ .unwrap_or(default)
+ };
- let get_bool = |key: &str, default: bool| -> bool {
- config_map
- .get(key)
- .map(|v| v.config_value.to_lowercase() == "true")
- .unwrap_or(default)
- };
+ let stack_path = PathBuf::from(get_str("STACK_PATH", "./botserver-stack"));
- let stack_path = PathBuf::from(get_str("STACK_PATH", "./botserver-stack"));
+ let database = DatabaseConfig {
+ username: std::env::var("TABLES_USERNAME")
+ .unwrap_or_else(|_| get_str("TABLES_USERNAME", "gbuser")),
+ password: std::env::var("TABLES_PASSWORD")
+ .unwrap_or_else(|_| get_str("TABLES_PASSWORD", "")),
+ server: std::env::var("TABLES_SERVER")
+ .unwrap_or_else(|_| get_str("TABLES_SERVER", "localhost")),
+ port: std::env::var("TABLES_PORT")
+ .ok()
+ .and_then(|p| p.parse().ok())
+ .unwrap_or_else(|| get_u32("TABLES_PORT", 5432)),
+ database: std::env::var("TABLES_DATABASE")
+ .unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")),
+ };
- let database = DatabaseConfig {
- username: std::env::var("TABLES_USERNAME")
- .unwrap_or_else(|_| get_str("TABLES_USERNAME", "gbuser")),
- password: std::env::var("TABLES_PASSWORD")
- .unwrap_or_else(|_| get_str("TABLES_PASSWORD", "")),
- server: std::env::var("TABLES_SERVER")
- .unwrap_or_else(|_| get_str("TABLES_SERVER", "localhost")),
- port: std::env::var("TABLES_PORT")
- .ok()
- .and_then(|p| p.parse().ok())
- .unwrap_or_else(|| get_u32("TABLES_PORT", 5432)),
- database: std::env::var("TABLES_DATABASE")
- .unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")),
- };
- let database_custom = DatabaseConfig {
- username: std::env::var("CUSTOM_USERNAME")
- .unwrap_or_else(|_| get_str("CUSTOM_USERNAME", "gbuser")),
- password: std::env::var("CUSTOM_PASSWORD")
- .unwrap_or_else(|_| get_str("CUSTOM_PASSWORD", "")),
- server: std::env::var("CUSTOM_SERVER")
- .unwrap_or_else(|_| get_str("CUSTOM_SERVER", "localhost")),
- port: std::env::var("CUSTOM_PORT")
- .ok()
- .and_then(|p| p.parse().ok())
- .unwrap_or_else(|| get_u32("CUSTOM_PORT", 5432)),
- database: std::env::var("CUSTOM_DATABASE")
- .unwrap_or_else(|_| get_str("CUSTOM_DATABASE", "gbuser")),
- };
+ let drive = DriveConfig {
+ server: {
+ let server = get_str("DRIVE_SERVER", "http://localhost:9000");
+ if !server.starts_with("http://") && !server.starts_with("https://") {
+ format!("http://{}", server)
+ } else {
+ server
+ }
+ },
+ access_key: get_str("DRIVE_ACCESSKEY", "minioadmin"),
+ secret_key: get_str("DRIVE_SECRET", "minioadmin"),
+ use_ssl: get_bool("DRIVE_USE_SSL", false),
+ };
- let minio = DriveConfig {
- server: {
- let server = get_str("DRIVE_SERVER", "http://localhost:9000");
- if !server.starts_with("http://") && !server.starts_with("https://") {
- format!("http://{}", server)
- } else {
- server
- }
- },
- access_key: get_str("DRIVE_ACCESSKEY", "minioadmin"),
- secret_key: get_str("DRIVE_SECRET", "minioadmin"),
- use_ssl: get_bool("DRIVE_USE_SSL", false),
- };
+ let email = EmailConfig {
+ from: get_str("EMAIL_FROM", "noreply@example.com"),
+ server: get_str("EMAIL_SERVER", "smtp.example.com"),
+ port: get_u16("EMAIL_PORT", 587),
+ username: get_str("EMAIL_USER", "user"),
+ password: get_str("EMAIL_PASS", "pass"),
+ };
- let email = EmailConfig {
- from: get_str("EMAIL_FROM", "noreply@example.com"),
- server: get_str("EMAIL_SERVER", "smtp.example.com"),
- port: get_u16("EMAIL_PORT", 587),
- username: get_str("EMAIL_USER", "user"),
- password: get_str("EMAIL_PASS", "pass"),
- };
+ // Write drive config to .env file
+ if let Err(e) = write_drive_config_to_env(&drive) {
+ warn!("Failed to write drive config to .env: {}", e);
+ }
- // Write drive config to .env file
- if let Err(e) = write_drive_config_to_env(&minio) {
- warn!("Failed to write drive config to .env: {}", e);
- }
-
- AppConfig {
- drive: minio,
+ Ok(AppConfig {
+ drive,
server: ServerConfig {
host: get_str("SERVER_HOST", "127.0.0.1"),
port: get_u16("SERVER_PORT", 8080),
},
database,
- database_custom,
email,
- llm: LLMConfig {
- url: get_str("LLM_URL", "http://localhost:8081"),
- key: get_str("LLM_KEY", ""),
- model: get_str("LLM_MODEL", "gpt-4"),
+ llm: {
+ // Use a fresh connection for ConfigManager to avoid cloning the mutable reference
+ let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
+ let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn)));
+ LLMConfig {
+ url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?,
+ key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?,
+ model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?,
+ }
},
- embedding: LLMConfig {
- url: get_str("EMBEDDING_URL", "http://localhost:8082"),
- key: get_str("EMBEDDING_KEY", ""),
- model: get_str("EMBEDDING_MODEL", "text-embedding-ada-002"),
+ embedding: {
+ let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
+ let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn)));
+ LLMConfig {
+ url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?,
+ key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?,
+ model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?,
+ }
+ },
+ site_path: {
+ let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
+ ConfigManager::new(Arc::new(Mutex::new(fresh_conn)))
+ .get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?.to_string()
},
- site_path: get_str("SITES_ROOT", "./botserver-stack/sites"),
stack_path,
db_conn: None,
- }
- }
-
- pub fn from_env() -> Self {
+ })
+ }
+
+ pub fn from_env() -> Result<Self, Box<dyn std::error::Error>> {
info!("Loading configuration from environment variables");
let stack_path =
@@ -254,16 +260,6 @@ impl AppConfig {
database: db_name,
};
- let database_custom = DatabaseConfig {
- username: std::env::var("CUSTOM_USERNAME").unwrap_or_else(|_| "gbuser".to_string()),
- password: std::env::var("CUSTOM_PASSWORD").unwrap_or_else(|_| "".to_string()),
- server: std::env::var("CUSTOM_SERVER").unwrap_or_else(|_| "localhost".to_string()),
- port: std::env::var("CUSTOM_PORT")
- .ok()
- .and_then(|p| p.parse().ok())
- .unwrap_or(5432),
- database: std::env::var("CUSTOM_DATABASE").unwrap_or_else(|_| "botserver".to_string()),
- };
let minio = DriveConfig {
server: std::env::var("DRIVE_SERVER")
@@ -288,33 +284,43 @@ impl AppConfig {
password: std::env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()),
};
- AppConfig {
+ Ok(AppConfig {
drive: minio,
server: ServerConfig {
- host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
+ host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
port: std::env::var("SERVER_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(8080),
},
database,
- database_custom,
email,
- llm: LLMConfig {
- url: std::env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()),
- key: std::env::var("LLM_KEY").unwrap_or_else(|_| "".to_string()),
- model: std::env::var("LLM_MODEL").unwrap_or_else(|_| "gpt-4".to_string()),
+ llm: {
+ let conn = PgConnection::establish(&database_url)?;
+ let config = ConfigManager::new(Arc::new(Mutex::new(conn)));
+ LLMConfig {
+ url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?,
+ key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?,
+ model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?,
+ }
},
- embedding: LLMConfig {
- url: std::env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()),
- key: std::env::var("EMBEDDING_KEY").unwrap_or_else(|_| "".to_string()),
- model: std::env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-ada-002".to_string()),
+ embedding: {
+ let conn = PgConnection::establish(&database_url)?;
+ let config = ConfigManager::new(Arc::new(Mutex::new(conn)));
+ LLMConfig {
+ url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?,
+ key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?,
+ model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?,
+ }
+ },
+ site_path: {
+ let conn = PgConnection::establish(&database_url)?;
+ ConfigManager::new(Arc::new(Mutex::new(conn)))
+ .get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?
},
- site_path: std::env::var("SITES_ROOT")
- .unwrap_or_else(|_| "./botserver-stack/sites".to_string()),
stack_path: PathBuf::from(stack_path),
db_conn: None,
- }
+ })
}
pub fn set_config(
diff --git a/src/drive_monitor/mod.rs b/src/drive_monitor/mod.rs
index cfe1b8a8..fcaffb14 100644
--- a/src/drive_monitor/mod.rs
+++ b/src/drive_monitor/mod.rs
@@ -62,10 +62,7 @@ impl DriveMonitor {
self.check_gbdialog_changes(client).await?;
self.check_gbkb_changes(client).await?;
-
- if let Err(e) = self.check_gbot(client).await {
- error!("Error checking default bot config: {}", e);
- }
+ self.check_gbot(client).await?;
Ok(())
}
@@ -257,7 +254,7 @@ impl DriveMonitor {
continue;
}
- debug!("Checking config file at path: {}", path);
+ trace!("Checking config file at path: {}", path);
match client
.head_object()
.bucket(&self.bucket_name)
@@ -276,7 +273,7 @@ impl DriveMonitor {
.key(&path)
.send()
.await?;
- debug!(
+ trace!(
"GetObject successful for {}, content length: {}",
path,
response.content_length().unwrap_or(0)
@@ -295,7 +292,7 @@ impl DriveMonitor {
.collect();
if !llm_lines.is_empty() {
- use crate::llm_legacy::llm_local::ensure_llama_servers_running;
+ use crate::llm::local::ensure_llama_servers_running;
let mut restart_needed = false;
for line in llm_lines {
@@ -338,7 +335,7 @@ impl DriveMonitor {
}
}
Err(e) => {
- debug!("Config file {} not found or inaccessible: {}", path, e);
+ error!("Config file {} not found or inaccessible: {}", path, e);
}
}
}
diff --git a/src/llm/llama.rs b/src/llm/anthropic.rs
similarity index 51%
rename from src/llm/llama.rs
rename to src/llm/anthropic.rs
index 2060979d..fbd1ccdf 100644
--- a/src/llm/llama.rs
+++ b/src/llm/anthropic.rs
@@ -1,62 +1,73 @@
use async_trait::async_trait;
use reqwest::Client;
+use futures_util::StreamExt;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
-use futures_util::StreamExt;
+
use crate::tools::ToolManager;
use super::LLMProvider;
-pub struct LlamaClient {
+pub struct AnthropicClient {
client: Client,
+ api_key: String,
base_url: String,
}
-impl LlamaClient {
- pub fn new(base_url: String) -> Self {
+impl AnthropicClient {
+ pub fn new(api_key: String) -> Self {
Self {
client: Client::new(),
- base_url,
+ api_key,
+ base_url: "https://api.anthropic.com/v1".to_string(),
}
}
}
#[async_trait]
-impl LLMProvider for LlamaClient {
+impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
- config: &Value,
+ _config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
- let response = self.client
- .post(&format!("{}/completion", self.base_url))
+ let response = self
+ .client
+ .post(&format!("{}/messages", self.base_url))
+ .header("x-api-key", &self.api_key)
+ .header("anthropic-version", "2023-06-01")
.json(&serde_json::json!({
- "prompt": prompt,
- "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
- "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
+ "model": "claude-3-sonnet-20240229",
+ "max_tokens": 1000,
+ "messages": [{"role": "user", "content": prompt}]
}))
.send()
.await?;
let result: Value = response.json().await?;
- Ok(result["content"]
+ let content = result["content"][0]["text"]
.as_str()
.unwrap_or("")
- .to_string())
+ .to_string();
+
+ Ok(content)
}
async fn generate_stream(
&self,
prompt: &str,
- config: &Value,
+ _config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- let response = self.client
- .post(&format!("{}/completion", self.base_url))
+ let response = self
+ .client
+ .post(&format!("{}/messages", self.base_url))
+ .header("x-api-key", &self.api_key)
+ .header("anthropic-version", "2023-06-01")
.json(&serde_json::json!({
- "prompt": prompt,
- "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
- "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
+ "model": "claude-3-sonnet-20240229",
+ "max_tokens": 1000,
+ "messages": [{"role": "user", "content": prompt}],
"stream": true
}))
.send()
@@ -66,14 +77,18 @@ impl LLMProvider for LlamaClient {
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
- let chunk = chunk?;
- let chunk_str = String::from_utf8_lossy(&chunk);
+ let chunk_bytes = chunk?;
+ let chunk_str = String::from_utf8_lossy(&chunk_bytes);
for line in chunk_str.lines() {
- if let Ok(data) = serde_json::from_str::<Value>(&line) {
- if let Some(content) = data["content"].as_str() {
- buffer.push_str(content);
- let _ = tx.send(content.to_string()).await;
+ if line.starts_with("data: ") {
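+ // Anthropic streams Server-Sent Events; pull text deltas out of content_block_delta payloads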
+ if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
+ if data["type"] == "content_block_delta" {
+ if let Some(text) = data["delta"]["text"].as_str() {
+ buffer.push_str(text);
+ let _ = tx.send(text.to_string()).await;
+ }
+ }
}
}
}
@@ -85,7 +100,7 @@ impl LLMProvider for LlamaClient {
async fn generate_with_tools(
&self,
prompt: &str,
- config: &Value,
+ _config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
@@ -98,26 +113,14 @@ impl LLMProvider for LlamaClient {
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
- self.generate(&enhanced_prompt, config).await
+ self.generate(&enhanced_prompt, &Value::Null).await
}
async fn cancel_job(
&self,
- session_id: &str,
+ _session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- // llama.cpp cancellation endpoint
- let response = self.client
- .post(&format!("{}/cancel", self.base_url))
- .json(&serde_json::json!({
- "session_id": session_id
- }))
- .send()
- .await?;
-
- if response.status().is_success() {
- Ok(())
- } else {
- Err(format!("Failed to cancel job for session {}", session_id).into())
- }
+ // Anthropic doesn't support job cancellation
+ Ok(())
}
}
diff --git a/src/llm/azure.rs b/src/llm/azure.rs
new file mode 100644
index 00000000..0a7e5b81
--- /dev/null
+++ b/src/llm/azure.rs
@@ -0,0 +1,103 @@
+use async_trait::async_trait;
+use reqwest::Client;
+use serde_json::Value;
+use std::sync::Arc;
+use crate::tools::ToolManager;
+use super::LLMProvider;
+
+pub struct AzureOpenAIClient {
+ endpoint: String,
+ api_key: String,
+ api_version: String,
+ deployment: String,
+ client: Client,
+}
+
+impl AzureOpenAIClient {
+ pub fn new(endpoint: String, api_key: String, api_version: String, deployment: String) -> Self {
+ Self {
+ endpoint,
+ api_key,
+ api_version,
+ deployment,
+ client: Client::new(),
+ }
+ }
+}
+
+#[async_trait]
+impl LLMProvider for AzureOpenAIClient {
+ async fn generate(
+ &self,
+ prompt: &str,
+ _config: &Value,
+ ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+ let url = format!(
+ "{}/openai/deployments/{}/chat/completions?api-version={}",
+ self.endpoint, self.deployment, self.api_version
+ );
+
+ let body = serde_json::json!({
+ "messages": [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt}
+ ],
+ "temperature": 0.7,
+ "max_tokens": 1000,
+ "top_p": 1.0,
+ "frequency_penalty": 0.0,
+ "presence_penalty": 0.0
+ });
+
+ let response = self.client
+ .post(&url)
+ .header("api-key", &self.api_key)
+ .header("Content-Type", "application/json")
+ .json(&body)
+ .send()
+ .await?;
+
+ let result: Value = response.json().await?;
+ if let Some(choice) = result["choices"].get(0) {
+ Ok(choice["message"]["content"].as_str().unwrap_or("").to_string())
+ } else {
+ Err("No response from Azure OpenAI".into())
+ }
+ }
+
+ async fn generate_stream(
+ &self,
+ prompt: &str,
+ _config: &Value,
+ tx: tokio::sync::mpsc::Sender<String>,
+ ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
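+ // Simplified streaming: generate the full completion and send it as a single chunk over the channel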
+ let content = self.generate(prompt, _config).await?;
+ let _ = tx.send(content).await;
+ Ok(())
+ }
+
+ async fn generate_with_tools(
+ &self,
+ prompt: &str,
+ _config: &Value,
+ available_tools: &[String],
+ _tool_manager: Arc<ToolManager>,
+ _session_id: &str,
+ _user_id: &str,
+ ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+ let tools_info = if available_tools.is_empty() {
+ String::new()
+ } else {
+ format!("\n\nAvailable tools: {}.", available_tools.join(", "))
+ };
+ let enhanced_prompt = format!("{}{}", prompt, tools_info);
+ self.generate(&enhanced_prompt, _config).await
+ }
+
+ async fn cancel_job(
+ &self,
+ _session_id: &str,
+ ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+ Ok(())
+ }
+}
diff --git a/src/llm/local.rs b/src/llm/local.rs
new file mode 100644
index 00000000..4fc331d5
--- /dev/null
+++ b/src/llm/local.rs
@@ -0,0 +1,269 @@
+use crate::shared::state::AppState;
+use crate::shared::models::schema::bots::dsl::*;
+use crate::config::ConfigManager;
+use diesel::prelude::*;
+use std::sync::Arc;
+use log::{info, error};
+use tokio;
+use reqwest;
+
+use actix_web::{post, web, HttpResponse, Result};
+
+#[post("/api/chat/completions")]
+pub async fn chat_completions_local(
+ _data: web::Data<AppState>,
+ _payload: web::Json<serde_json::Value>,
+) -> Result<HttpResponse> {
+ // Placeholder implementation
+ Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "chat_completions_local not implemented" })))
+}
+
+#[post("/api/embeddings")]
+pub async fn embeddings_local(
+ _data: web::Data<AppState>,
+ _payload: web::Json<serde_json::Value>,
+) -> Result<HttpResponse> {
+ // Placeholder implementation
+ Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "embeddings_local not implemented" })))
+}
+
+pub async fn ensure_llama_servers_running(
+ app_state: &Arc<AppState>
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+ let conn = app_state.conn.clone();
+ let config_manager = ConfigManager::new(conn.clone());
+
+ let default_bot_id = {
+ let mut conn = conn.lock().unwrap();
+ bots.filter(name.eq("default"))
+ .select(id)
+ .first::<uuid::Uuid>(&mut *conn)
+ .unwrap_or_else(|_| uuid::Uuid::nil())
+ };
+
+ let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?;
+ let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?;
+ let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?;
+ let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?;
+ let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?;
+
+ info!("Starting LLM servers...");
+ info!("Configuration:");
+ info!(" LLM URL: {}", llm_url);
+ info!(" Embedding URL: {}", embedding_url);
+ info!(" LLM Model: {}", llm_model);
+ info!(" Embedding Model: {}", embedding_model);
+ info!(" LLM Server Path: {}", llm_server_path);
+
+ // Restart any existing llama-server processes
+ info!("Restarting any existing llama-server processes...");
+ if let Err(e) = tokio::process::Command::new("sh")
+ .arg("-c")
+ .arg("pkill -f llama-server || true")
+ .spawn()
+ {
+ error!("Failed to execute pkill for llama-server: {}", e);
+ } else {
+ tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
+ info!("Existing llama-server processes terminated (if any)");
+ }
+
+ // Check if servers are already running
+ let llm_running = is_server_running(&llm_url).await;
+ let embedding_running = is_server_running(&embedding_url).await;
+
+ if llm_running && embedding_running {
+ info!("Both LLM and Embedding servers are already running");
+ return Ok(());
+ }
+
+ // Start servers that aren't running
+ let mut tasks = vec![];
+
+ if !llm_running && !llm_model.is_empty() {
+ info!("Starting LLM server...");
+ tasks.push(tokio::spawn(start_llm_server(
+ Arc::clone(app_state),
+ llm_server_path.clone(),
+ llm_model.clone(),
+ llm_url.clone(),
+ )));
+ } else if llm_model.is_empty() {
+ info!("LLM_MODEL not set, skipping LLM server");
+ }
+
+ if !embedding_running && !embedding_model.is_empty() {
+ info!("Starting Embedding server...");
+ tasks.push(tokio::spawn(start_embedding_server(
+ llm_server_path.clone(),
+ embedding_model.clone(),
+ embedding_url.clone(),
+ )));
+ } else if embedding_model.is_empty() {
+ info!("EMBEDDING_MODEL not set, skipping Embedding server");
+ }
+
+ // Wait for all server startup tasks
+ for task in tasks {
+ task.await??;
+ }
+
+ // Wait for servers to be ready with verbose logging
+ info!("Waiting for servers to become ready...");
+
+ let mut llm_ready = llm_running || llm_model.is_empty();
+ let mut embedding_ready = embedding_running || embedding_model.is_empty();
+
+ let mut attempts = 0;
+ let max_attempts = 60; // 2 minutes total
+
+ while attempts < max_attempts && (!llm_ready || !embedding_ready) {
+ tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
+
+ info!("Checking server health (attempt {}/{})...", attempts + 1, max_attempts);
+
+ if !llm_ready && !llm_model.is_empty() {
+ if is_server_running(&llm_url).await {
+ info!("LLM server ready at {}", llm_url);
+ llm_ready = true;
+ } else {
+ info!("LLM server not ready yet");
+ }
+ }
+
+ if !embedding_ready && !embedding_model.is_empty() {
+ if is_server_running(&embedding_url).await {
+ info!("Embedding server ready at {}", embedding_url);
+ embedding_ready = true;
+ } else {
+ info!("Embedding server not ready yet");
+ }
+ }
+
+ attempts += 1;
+
+ if attempts % 10 == 0 {
+ info!("Still waiting for servers... (attempt {}/{})", attempts, max_attempts);
+ }
+ }
+
+ if llm_ready && embedding_ready {
+ info!("All llama.cpp servers are ready and responding!");
+ Ok(())
+ } else {
+ let mut error_msg = "Servers failed to start within timeout:".to_string();
+ if !llm_ready && !llm_model.is_empty() {
+ error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
+ }
+ if !embedding_ready && !embedding_model.is_empty() {
+ error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
+ }
+ Err(error_msg.into())
+ }
+}
+
+pub async fn is_server_running(url: &str) -> bool {
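+ // Probe the llama.cpp /health endpoint; connection errors or non-2xx responses count as not running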
+ let client = reqwest::Client::new();
+ match client.get(&format!("{}/health", url)).send().await {
+ Ok(response) => response.status().is_success(),
+ Err(_) => false,
+ }
+}
+
+pub async fn start_llm_server(
+ app_state: Arc<AppState>,
+ llama_cpp_path: String,
+ model_path: String,
+ url: String,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+ let port = url.split(':').last().unwrap_or("8081");
+
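+ // Constrain OpenMP thread count and placement for llama.cpp CPU inference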
+ std::env::set_var("OMP_NUM_THREADS", "20");
+ std::env::set_var("OMP_PLACES", "cores");
+ std::env::set_var("OMP_PROC_BIND", "close");
+
+ let conn = app_state.conn.clone();
+ let config_manager = ConfigManager::new(conn.clone());
+
+ let default_bot_id = {
+ let mut conn = conn.lock().unwrap();
+ bots.filter(name.eq("default"))
+ .select(id)
+ .first::<uuid::Uuid>(&mut *conn)
+ .unwrap_or_else(|_| uuid::Uuid::nil())
+ };
+
+ let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
+ let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string());
+ let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
+ let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
+ let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
+ let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
+ let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
+
+ // Build command arguments dynamically
+ let mut args = format!(
+ "-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}",
+ model_path, port, ctx_size, gpu_layers
+ );
+
+ if n_moe != "0" {
+ args.push_str(&format!(" --n-cpu-moe {}", n_moe));
+ }
+ if parallel != "1" {
+ args.push_str(&format!(" --parallel {}", parallel));
+ }
+ if cont_batching == "true" {
+ args.push_str(" --cont-batching");
+ }
+ if mlock == "true" {
+ args.push_str(" --mlock");
+ }
+ if no_mmap == "true" {
+ args.push_str(" --no-mmap");
+ }
+
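+ // Launch llama-server detached via the platform shell, redirecting output to the LLM log file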
+ if cfg!(windows) {
+ let mut cmd = tokio::process::Command::new("cmd");
+ cmd.arg("/C").arg(format!(
+ "cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log",
+ llama_cpp_path, args
+ ));
+ cmd.spawn()?;
+ } else {
+ let mut cmd = tokio::process::Command::new("sh");
+ cmd.arg("-c").arg(format!(
+ "cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &",
+ llama_cpp_path, args
+ ));
+ cmd.spawn()?;
+ }
+
+ Ok(())
+}
+
+pub async fn start_embedding_server(
+ llama_cpp_path: String,
+ model_path: String,
+ url: String,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+ let port = url.split(':').last().unwrap_or("8082");
+
+ if cfg!(windows) {
+ let mut cmd = tokio::process::Command::new("cmd");
+ cmd.arg("/c").arg(format!(
+ "cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log",
+ llama_cpp_path, model_path, port
+ ));
+ cmd.spawn()?;
+ } else {
+ let mut cmd = tokio::process::Command::new("sh");
+ cmd.arg("-c").arg(format!(
+ "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &",
+ llama_cpp_path, model_path, port
+ ));
+ cmd.spawn()?;
+ }
+
+ Ok(())
+}
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 1ba6db08..91a38aee 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -6,7 +6,9 @@ use tokio::sync::mpsc;
use crate::tools::ToolManager;
-pub mod llama;
+pub mod azure;
+pub mod local;
+pub mod anthropic;
#[async_trait]
pub trait LLMProvider: Send + Sync {
@@ -161,123 +163,6 @@ impl LLMProvider for OpenAIClient {
}
}
-pub struct AnthropicClient {
- client: reqwest::Client,
- api_key: String,
- base_url: String,
-}
-
-impl AnthropicClient {
- pub fn new(api_key: String) -> Self {
- Self {
- client: reqwest::Client::new(),
- api_key,
- base_url: "https://api.anthropic.com/v1".to_string(),
- }
- }
-}
-
-#[async_trait]
-impl LLMProvider for AnthropicClient {
- async fn generate(
- &self,
- prompt: &str,
- _config: &Value,
- ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
- let response = self
- .client
- .post(&format!("{}/messages", self.base_url))
- .header("x-api-key", &self.api_key)
- .header("anthropic-version", "2023-06-01")
- .json(&serde_json::json!({
- "model": "claude-3-sonnet-20240229",
- "max_tokens": 1000,
- "messages": [{"role": "user", "content": prompt}]
- }))
- .send()
- .await?;
-
- let result: Value = response.json().await?;
- let content = result["content"][0]["text"]
- .as_str()
- .unwrap_or("")
- .to_string();
-
- Ok(content)
- }
-
- async fn generate_stream(
- &self,
- prompt: &str,
- _config: &Value,
- tx: mpsc::Sender<String>,
- ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- let response = self
- .client
- .post(&format!("{}/messages", self.base_url))
- .header("x-api-key", &self.api_key)
- .header("anthropic-version", "2023-06-01")
- .json(&serde_json::json!({
- "model": "claude-3-sonnet-20240229",
- "max_tokens": 1000,
- "messages": [{"role": "user", "content": prompt}],
- "stream": true
- }))
- .send()
- .await?;
-
- let mut stream = response.bytes_stream();
- let mut buffer = String::new();
-
- while let Some(chunk) = stream.next().await {
- let chunk = chunk?;
- let chunk_str = String::from_utf8_lossy(&chunk);
-
- for line in chunk_str.lines() {
- if line.starts_with("data: ") {
- if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
- if data["type"] == "content_block_delta" {
- if let Some(text) = data["delta"]["text"].as_str() {
- buffer.push_str(text);
- let _ = tx.send(text.to_string()).await;
- }
- }
- }
- }
- }
- }
-
- Ok(())
- }
-
- async fn generate_with_tools(
- &self,
- prompt: &str,
- _config: &Value,
- available_tools: &[String],
- _tool_manager: Arc<ToolManager>,
- _session_id: &str,
- _user_id: &str,
- ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
- let tools_info = if available_tools.is_empty() {
- String::new()
- } else {
- format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
- };
-
- let enhanced_prompt = format!("{}{}", prompt, tools_info);
-
- self.generate(&enhanced_prompt, &Value::Null).await
- }
-
- async fn cancel_job(
- &self,
- _session_id: &str,
- ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- // Anthropic doesn't support job cancellation
- Ok(())
- }
-}
pub struct MockLLMProvider;
diff --git a/src/llm_legacy/llm_azure.rs b/src/llm_legacy/llm_azure.rs
deleted file mode 100644
index ce5ac5c0..00000000
--- a/src/llm_legacy/llm_azure.rs
+++ /dev/null
@@ -1,145 +0,0 @@
-use dotenvy::dotenv;
-use log::{error, info};
-use reqwest::Client;
-use serde::{Deserialize, Serialize};
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct AzureOpenAIConfig {
- pub endpoint: String,
- pub api_key: String,
- pub api_version: String,
- pub deployment: String,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct ChatCompletionRequest {
- pub messages: Vec<ChatMessage>,
- pub temperature: f32,
- pub max_tokens: Option<u32>,
- pub top_p: f32,
- pub frequency_penalty: f32,
- pub presence_penalty: f32,
-}
-
-#[derive(Debug, Serialize, Deserialize, Clone)]
-pub struct ChatMessage {
- pub role: String,
- pub content: String,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct ChatCompletionResponse {
- pub id: String,
- pub object: String,
- pub created: u64,
- pub choices: Vec<ChatChoice>,
- pub usage: Usage,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct ChatChoice {
- pub index: u32,
- pub message: ChatMessage,
- pub finish_reason: Option<String>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct Usage {
- pub prompt_tokens: u32,
- pub completion_tokens: u32,
- pub total_tokens: u32,
-}
-
-pub struct AzureOpenAIClient {
- config: AzureOpenAIConfig,
- client: Client,
-}
-
-impl AzureOpenAIClient {
- pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
- dotenv().ok();
-
- let endpoint =
- std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?;
- let api_key =
- std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| "AZURE_OPENAI_API_KEY not set")?;
- let api_version = std::env::var("AZURE_OPENAI_API_VERSION")
- .unwrap_or_else(|_| "2023-12-01-preview".to_string());
- let deployment =
- std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string());
-
- let config = AzureOpenAIConfig {
- endpoint,
- api_key,
- api_version,
- deployment,
- };
-
- Ok(Self {
- config,
- client: Client::new(),
- })
- }
-
- pub async fn chat_completions(
- &self,
- messages: Vec<ChatMessage>,
- temperature: f32,
- max_tokens: Option<u32>,
- ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
- let url = format!(
- "{}/openai/deployments/{}/chat/completions?api-version={}",
- self.config.endpoint, self.config.deployment, self.config.api_version
- );
-
- let request_body = ChatCompletionRequest {
- messages,
- temperature,
- max_tokens,
- top_p: 1.0,
- frequency_penalty: 0.0,
- presence_penalty: 0.0,
- };
-
- info!("Sending request to Azure OpenAI: {}", url);
-
- let response = self
- .client
- .post(&url)
- .header("api-key", &self.config.api_key)
- .header("Content-Type", "application/json")
- .json(&request_body)
- .send()
- .await?;
-
- if !response.status().is_success() {
- let error_text = response.text().await?;
- error!("Azure OpenAI API error: {}", error_text);
- return Err(format!("Azure OpenAI API error: {}", error_text).into());
- }
-
- let completion_response: ChatCompletionResponse = response.json().await?;
- Ok(completion_response)
- }
-
- pub async fn simple_chat(&self, prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
- let messages = vec![
- ChatMessage {
- role: "system".to_string(),
- content: "You are a helpful assistant.".to_string(),
- },
- ChatMessage {
- role: "user".to_string(),
- content: prompt.to_string(),
- },
- ];
-
- let response = self.chat_completions(messages, 0.7, Some(1000)).await?;
-
- if let Some(choice) = response.choices.first() {
- Ok(choice.message.content.clone())
- } else {
- Err("No response from AI".into())
- }
- }
-}
diff --git a/src/llm_legacy/llm_local.rs b/src/llm_legacy/llm_local.rs
deleted file mode 100644
index be3ec1c0..00000000
--- a/src/llm_legacy/llm_local.rs
+++ /dev/null
@@ -1,656 +0,0 @@
-use actix_web::{post, web, HttpRequest, HttpResponse, Result};
-use dotenvy::dotenv;
-use log::{error, info};
-use reqwest::Client;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::sync::Arc;
-use tokio::time::{sleep, Duration};
-
-use crate::config::ConfigManager;
-use crate::shared::models::schema::bots::dsl::*;
-use crate::shared::state::AppState;
-use diesel::prelude::*;
-
-// OpenAI-compatible request/response structures
-#[derive(Debug, Serialize, Deserialize)]
-struct ChatMessage {
- role: String,
- content: String,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct ChatCompletionRequest {
- model: String,
- messages: Vec<ChatMessage>,
- stream: Option<bool>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct ChatCompletionResponse {
- id: String,
- object: String,
- created: u64,
- model: String,
- choices: Vec<Choice>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct Choice {
- message: ChatMessage,
- finish_reason: String,
-}
-
-// Llama.cpp server request/response structures
-#[derive(Debug, Serialize, Deserialize)]
-struct LlamaCppRequest {
- prompt: String,
- n_predict: Option<i32>,
- temperature: Option<f32>,
- top_k: Option<i32>,
- top_p: Option<f32>,
- stream: Option<bool>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct LlamaCppResponse {
- content: String,
- stop: bool,
- generation_settings: Option<serde_json::Value>,
-}
-
-pub async fn ensure_llama_servers_running(
- app_state: &Arc<AppState>,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- let conn = app_state.conn.clone();
- let config_manager = ConfigManager::new(conn.clone());
-
- // Get default bot ID from database
- let default_bot_id = {
- let mut conn = conn.lock().unwrap();
- bots.filter(name.eq("default"))
- .select(id)
- .first::<uuid::Uuid>(&mut *conn)
- .unwrap_or_else(|_| uuid::Uuid::nil())
- };
-
- // Get configuration from config using default bot ID
- let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?;
- let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?;
-
- let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?;
- let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?;
-
- let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?;
-
- info!(" Starting LLM servers...");
- info!(" Configuration:");
- info!(" LLM URL: {}", llm_url);
- info!(" Embedding URL: {}", embedding_url);
- info!(" LLM Model: {}", llm_model);
- info!(" Embedding Model: {}", embedding_model);
- info!(" LLM Server Path: {}", llm_server_path);
-
-
- // Restart any existing llama-server processes before starting new ones
- info!("🔁 Restarting any existing llama-server processes...");
- if let Err(e) = tokio::process::Command::new("sh")
- .arg("-c")
- .arg("pkill -f llama-server || true")
- .spawn()
- {
- error!("Failed to execute pkill for llama-server: {}", e);
- } else {
- sleep(Duration::from_secs(2)).await;
- info!("✅ Existing llama-server processes terminated (if any)");
- }
-
- // Check if servers are already running
- let llm_running = is_server_running(&llm_url).await;
- let embedding_running = is_server_running(&embedding_url).await;
-
- if llm_running && embedding_running {
- info!("✅ Both LLM and Embedding servers are already running");
- return Ok(());
- }
-
-
- // Start servers that aren't running
- let mut tasks = vec![];
-
- if !llm_running && !llm_model.is_empty() {
- info!("🔄 Starting LLM server...");
- tasks.push(tokio::spawn(start_llm_server(
- Arc::clone(app_state),
- llm_server_path.clone(),
- llm_model.clone(),
- llm_url.clone(),
- )));
- } else if llm_model.is_empty() {
- info!("⚠️ LLM_MODEL not set, skipping LLM server");
- }
-
- if !embedding_running && !embedding_model.is_empty() {
- info!("🔄 Starting Embedding server...");
- tasks.push(tokio::spawn(start_embedding_server(
- llm_server_path.clone(),
- embedding_model.clone(),
- embedding_url.clone(),
- )));
- } else if embedding_model.is_empty() {
- info!("⚠️ EMBEDDING_MODEL not set, skipping Embedding server");
- }
-
- // Wait for all server startup tasks
- for task in tasks {
- task.await??;
- }
-
- // Wait for servers to be ready with verbose logging
- info!("⏳ Waiting for servers to become ready...");
-
- let mut llm_ready = llm_running || llm_model.is_empty();
- let mut embedding_ready = embedding_running || embedding_model.is_empty();
-
- let mut attempts = 0;
- let max_attempts = 60; // 2 minutes total
-
- while attempts < max_attempts && (!llm_ready || !embedding_ready) {
- sleep(Duration::from_secs(2)).await;
-
- info!(
- "🔍 Checking server health (attempt {}/{})...",
- attempts + 1,
- max_attempts
- );
-
- if !llm_ready && !llm_model.is_empty() {
- if is_server_running(&llm_url).await {
- info!(" ✅ LLM server ready at {}", llm_url);
- llm_ready = true;
- } else {
- info!(" ❌ LLM server not ready yet");
- }
- }
-
- if !embedding_ready && !embedding_model.is_empty() {
- if is_server_running(&embedding_url).await {
- info!(" ✅ Embedding server ready at {}", embedding_url);
- embedding_ready = true;
- } else {
- info!(" ❌ Embedding server not ready yet");
- }
- }
-
- attempts += 1;
-
- if attempts % 10 == 0 {
- info!(
- "⏰ Still waiting for servers... (attempt {}/{})",
- attempts, max_attempts
- );
- }
- }
-
- if llm_ready && embedding_ready {
- info!("🎉 All llama.cpp servers are ready and responding!");
- Ok(())
- } else {
- let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
- if !llm_ready && !llm_model.is_empty() {
- error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
- }
- if !embedding_ready && !embedding_model.is_empty() {
- error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
- }
- Err(error_msg.into())
- }
-}
-
-async fn start_llm_server(
- app_state: Arc<AppState>,
- llama_cpp_path: String,
- model_path: String,
- url: String,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- let port = url.split(':').last().unwrap_or("8081");
-
- std::env::set_var("OMP_NUM_THREADS", "20");
- std::env::set_var("OMP_PLACES", "cores");
- std::env::set_var("OMP_PROC_BIND", "close");
-
- // "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
- // Read config values with defaults
-
-
-
-
-
- let conn = app_state.conn.clone();
- let config_manager = ConfigManager::new(conn.clone());
-
- let default_bot_id = {
- let mut conn = conn.lock().unwrap();
- bots.filter(name.eq("default"))
- .select(id)
- .first::(&mut *conn)
- .unwrap_or_else(|_| uuid::Uuid::nil())
- };
-
- let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
- let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string());
- let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
- let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
- let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
- let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
- let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
-
- // Build command arguments dynamically
- let mut args = format!(
- "-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}",
- model_path, port, ctx_size, gpu_layers
- );
-
- if n_moe != "0" {
- args.push_str(&format!(" --n-cpu-moe {}", n_moe));
- }
- if parallel != "1" {
- args.push_str(&format!(" --parallel {}", parallel));
- }
- if cont_batching == "true" {
- args.push_str(" --cont-batching");
- }
- if mlock == "true" {
- args.push_str(" --mlock");
- }
- if no_mmap == "true" {
- args.push_str(" --no-mmap");
- }
-
- if cfg!(windows) {
- let mut cmd = tokio::process::Command::new("cmd");
- cmd.arg("/C").arg(format!(
- "cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log",
- llama_cpp_path, args
- ));
- cmd.spawn()?;
- } else {
- let mut cmd = tokio::process::Command::new("sh");
- cmd.arg("-c").arg(format!(
- "cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &",
- llama_cpp_path, args
- ));
- cmd.spawn()?;
- }
-
- Ok(())
-}
-
-async fn start_embedding_server(
- llama_cpp_path: String,
- model_path: String,
- url: String,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
- let port = url.split(':').last().unwrap_or("8082");
-
- if cfg!(windows) {
- let mut cmd = tokio::process::Command::new("cmd");
- cmd.arg("/c").arg(format!(
- "cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log",
- llama_cpp_path, model_path, port
- ));
- cmd.spawn()?;
- } else {
- let mut cmd = tokio::process::Command::new("sh");
- cmd.arg("-c").arg(format!(
- "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &",
- llama_cpp_path, model_path, port
- ));
- cmd.spawn()?;
- }
-
- Ok(())
-}
-
-async fn is_server_running(url: &str) -> bool {
-
-
- let client = reqwest::Client::new();
- match client.get(&format!("{}/health", url)).send().await {
- Ok(response) => response.status().is_success(),
- Err(_) => false,
- }
-}
-
-// Convert OpenAI chat messages to a single prompt
-fn messages_to_prompt(messages: &[ChatMessage]) -> String {
- let mut prompt = String::new();
-
- for message in messages {
- match message.role.as_str() {
- "system" => {
- prompt.push_str(&format!("System: {}\n\n", message.content));
- }
- "user" => {
- prompt.push_str(&format!("User: {}\n\n", message.content));
- }
- "assistant" => {
- prompt.push_str(&format!("Assistant: {}\n\n", message.content));
- }
- _ => {
- prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
- }
- }
- }
-
- prompt.push_str("Assistant: ");
- prompt
-}
-
-// Proxy endpoint
-#[post("/local/v1/chat/completions")]
-pub async fn chat_completions_local(
- req_body: web::Json<ChatCompletionRequest>,
- _req: HttpRequest,
-) -> Result<HttpResponse> {
- dotenv().ok().unwrap();
-
- // Get llama.cpp server URL
- let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
-
- // Convert OpenAI format to llama.cpp format
- let prompt = messages_to_prompt(&req_body.messages);
-
- let llama_request = LlamaCppRequest {
- prompt,
- n_predict: Some(500), // Adjust as needed
- temperature: Some(0.7),
- top_k: Some(40),
- top_p: Some(0.9),
- stream: req_body.stream,
- };
-
- // Send request to llama.cpp server
- let client = Client::builder()
- .timeout(Duration::from_secs(500)) // 2 minute timeout
- .build()
- .map_err(|e| {
- error!("Error creating HTTP client: {}", e);
- actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
- })?;
-
- let response = client
- .post(&format!("{}/v1/completion", llama_url))
- .header("Content-Type", "application/json")
- .json(&llama_request)
- .send()
- .await
- .map_err(|e| {
- error!("Error calling llama.cpp server: {}", e);
- actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
- })?;
-
- let status = response.status();
-
- if status.is_success() {
- let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
- error!("Error parsing llama.cpp response: {}", e);
- actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
- })?;
-
- // Convert llama.cpp response to OpenAI format
- let openai_response = ChatCompletionResponse {
- id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
- object: "chat.completion".to_string(),
- created: std::time::SystemTime::now()
- .duration_since(std::time::UNIX_EPOCH)
- .unwrap()
- .as_secs(),
- model: req_body.model.clone(),
- choices: vec![Choice {
- message: ChatMessage {
- role: "assistant".to_string(),
- content: llama_response.content.trim().to_string(),
- },
- finish_reason: if llama_response.stop {
- "stop".to_string()
- } else {
- "length".to_string()
- },
- }],
- };
-
- Ok(HttpResponse::Ok().json(openai_response))
- } else {
- let error_text = response
- .text()
- .await
- .unwrap_or_else(|_| "Unknown error".to_string());
-
- error!("Llama.cpp server error ({}): {}", status, error_text);
-
- let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
- .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
-
- Ok(HttpResponse::build(actix_status).json(serde_json::json!({
- "error": {
- "message": error_text,
- "type": "server_error"
- }
- })))
- }
-}
-
-// OpenAI Embedding Request - Modified to handle both string and array inputs
-#[derive(Debug, Deserialize)]
-pub struct EmbeddingRequest {
- #[serde(deserialize_with = "deserialize_input")]
- pub input: Vec<String>,
- pub model: String,
- #[serde(default)]
- pub _encoding_format: Option<String>,
-}
-
-// Custom deserializer to handle both string and array inputs
-fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
-where
- D: serde::Deserializer<'de>,
-{
- use serde::de::{self, Visitor};
- use std::fmt;
-
- struct InputVisitor;
-
- impl<'de> Visitor<'de> for InputVisitor {
- type Value = Vec<String>;
-
- fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
- formatter.write_str("a string or an array of strings")
- }
-
- fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
- where
- E: de::Error,
- {
- Ok(vec![value.to_string()])
- }
-
- fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
- where
- E: de::Error,
- {
- Ok(vec![value])
- }
-
- fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
- where
- A: de::SeqAccess<'de>,
- {
- let mut vec = Vec::new();
- while let Some(value) = seq.next_element::<String>()? {
- vec.push(value);
- }
- Ok(vec)
- }
- }
-
- deserializer.deserialize_any(InputVisitor)
-}
-
-// OpenAI Embedding Response
-#[derive(Debug, Serialize)]
-pub struct EmbeddingResponse {
- pub object: String,
- pub data: Vec<EmbeddingData>,
- pub model: String,
- pub usage: Usage,
-}
-
-#[derive(Debug, Serialize)]
-pub struct EmbeddingData {
- pub object: String,
- pub embedding: Vec<f32>,
- pub index: usize,
-}
-
-#[derive(Debug, Serialize)]
-pub struct Usage {
- pub prompt_tokens: u32,
- pub total_tokens: u32,
-}
-
-// Llama.cpp Embedding Request
-#[derive(Debug, Serialize)]
-struct LlamaCppEmbeddingRequest {
- pub content: String,
-}
-
-// FIXED: Handle the stupid nested array format
-#[derive(Debug, Deserialize)]
-struct LlamaCppEmbeddingResponseItem {
- pub index: usize,
- pub embedding: Vec<Vec<f32>>, // This is the awkward part - embedding is an array of arrays
-}
-
-// Proxy endpoint for embeddings
-#[post("/v1/embeddings")]
-pub async fn embeddings_local(
- req_body: web::Json<EmbeddingRequest>,
- _req: HttpRequest,
-) -> Result<HttpResponse> {
- dotenv().ok();
-
- // Get llama.cpp server URL
- let llama_url =
- env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
-
- let client = Client::builder()
- .timeout(Duration::from_secs(120))
- .build()
- .map_err(|e| {
- error!("Error creating HTTP client: {}", e);
- actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
- })?;
-
- // Process each input text and get embeddings
- let mut embeddings_data = Vec::new();
- let mut total_tokens = 0;
-
- for (index, input_text) in req_body.input.iter().enumerate() {
- let llama_request = LlamaCppEmbeddingRequest {
- content: input_text.clone(),
- };
-
- let response = client
- .post(&format!("{}/embedding", llama_url))
- .header("Content-Type", "application/json")
- .json(&llama_request)
- .send()
- .await
- .map_err(|e| {
- error!("Error calling llama.cpp server for embedding: {}", e);
- actix_web::error::ErrorInternalServerError(
- "Failed to call llama.cpp server for embedding",
- )
- })?;
-
- let status = response.status();
-
- if status.is_success() {
- // First, get the raw response text for debugging
- let raw_response = response.text().await.map_err(|e| {
- error!("Error reading response text: {}", e);
- actix_web::error::ErrorInternalServerError("Failed to read response")
- })?;
-
- // Parse the response as a vector of items with nested arrays
-            let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
- serde_json::from_str(&raw_response).map_err(|e| {
- error!("Error parsing llama.cpp embedding response: {}", e);
- error!("Raw response: {}", raw_response);
- actix_web::error::ErrorInternalServerError(
- "Failed to parse llama.cpp embedding response",
- )
- })?;
-
-            // Extract the embedding from the nested array format
-            if let Some(item) = llama_response.get(0) {
-                // The embedding field contains Vec<Vec<f32>>, so we need to flatten it
-                // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
-                let flattened_embedding = if !item.embedding.is_empty() {
-                    item.embedding[0].clone() // Take the first (and typically only) inner array
- } else {
- vec![] // Empty if no embedding data
- };
-
- // Estimate token count
- let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
- total_tokens += estimated_tokens;
-
- embeddings_data.push(EmbeddingData {
- object: "embedding".to_string(),
- embedding: flattened_embedding,
- index,
- });
- } else {
- error!("No embedding data returned for input: {}", input_text);
- return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
- "error": {
- "message": format!("No embedding data returned for input {}", index),
- "type": "server_error"
- }
- })));
- }
- } else {
- let error_text = response
- .text()
- .await
- .unwrap_or_else(|_| "Unknown error".to_string());
-
- error!("Llama.cpp server error ({}): {}", status, error_text);
-
- let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
- .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
-
- return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
- "error": {
- "message": format!("Failed to get embedding for input {}: {}", index, error_text),
- "type": "server_error"
- }
- })));
- }
- }
-
- // Build OpenAI-compatible response
- let openai_response = EmbeddingResponse {
- object: "list".to_string(),
- data: embeddings_data,
- model: req_body.model.clone(),
- usage: Usage {
- prompt_tokens: total_tokens,
- total_tokens,
- },
- };
-
- Ok(HttpResponse::Ok().json(openai_response))
-}
diff --git a/src/llm_legacy/mod.rs b/src/llm_legacy/mod.rs
deleted file mode 100644
index aeb80674..00000000
--- a/src/llm_legacy/mod.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-pub mod llm_azure;
-pub mod llm_local;
diff --git a/src/llm_models/deepseek_r3.rs b/src/llm_models/deepseek_r3.rs
new file mode 100644
index 00000000..686a32a6
--- /dev/null
+++ b/src/llm_models/deepseek_r3.rs
@@ -0,0 +1,17 @@
+use super::ModelHandler;
+
+pub struct DeepseekR3Handler;
+
+impl ModelHandler for DeepseekR3Handler {
+    fn is_analysis_complete(&self, buffer: &str) -> bool {
+        buffer.contains("</think>")
+    }
+
+    fn process_content(&self, content: &str) -> String {
+        content.replace("<think>", "").replace("</think>", "")
+    }
+
+    fn has_analysis_markers(&self, buffer: &str) -> bool {
+        buffer.contains("<think>")
+    }
+}
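+
+// A minimal test sketch for this handler. It assumes the reasoning delimiters
+// are the literal `<think>` / `</think>` tags used by DeepSeek reasoning models;
+// adjust the fixtures if the deployed model emits different markers.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn strips_think_tags_and_detects_completion() {
+        let handler = DeepseekR3Handler;
+        let streamed = "<think>weighing options</think>Final answer";
+
+        assert!(handler.has_analysis_markers(streamed));
+        assert!(handler.is_analysis_complete(streamed));
+        assert_eq!(
+            handler.process_content(streamed),
+            "weighing optionsFinal answer"
+        );
+    }
+}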
diff --git a/src/llm_models/gpt_oss_120b.rs b/src/llm_models/gpt_oss_120b.rs
new file mode 100644
index 00000000..0efc28e6
--- /dev/null
+++ b/src/llm_models/gpt_oss_120b.rs
@@ -0,0 +1,31 @@
+use super::ModelHandler;
+
+pub struct GptOss120bHandler {
+ model_name: String,
+}
+
+impl GptOss120bHandler {
+ pub fn new(model_name: &str) -> Self {
+ Self {
+ model_name: model_name.to_string(),
+ }
+ }
+}
+
+impl ModelHandler for GptOss120bHandler {
+ fn is_analysis_complete(&self, buffer: &str) -> bool {
+        // GPT-OSS 120B uses an explicit end marker
+ buffer.contains("**end**")
+ }
+
+ fn process_content(&self, content: &str) -> String {
+ // Remove both start and end markers from final output
+ content.replace("**start**", "")
+ .replace("**end**", "")
+ }
+
+ fn has_analysis_markers(&self, buffer: &str) -> bool {
+        // GPT-OSS 120B uses an explicit start marker
+ buffer.contains("**start**")
+ }
+}
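+
+// A small illustrative test, assuming the `**start**` / `**end**` markers noted
+// in the comments above are exactly what the model emits.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn detects_and_strips_explicit_markers() {
+        let handler = GptOss120bHandler::new("gpt-oss-120b");
+        let streamed = "**start**planning**end**Answer";
+
+        assert!(handler.has_analysis_markers(streamed));
+        assert!(handler.is_analysis_complete(streamed));
+        assert_eq!(handler.process_content(streamed), "planningAnswer");
+    }
+}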
diff --git a/src/llm_models/gpt_oss_20b.rs b/src/llm_models/gpt_oss_20b.rs
new file mode 100644
index 00000000..70ddea50
--- /dev/null
+++ b/src/llm_models/gpt_oss_20b.rs
@@ -0,0 +1,21 @@
+use super::ModelHandler;
+
+pub struct GptOss20bHandler;
+
+impl ModelHandler for GptOss20bHandler {
+ fn is_analysis_complete(&self, buffer: &str) -> bool {
+ buffer.ends_with("final")
+ }
+
+ fn process_content(&self, content: &str) -> String {
+ if let Some(pos) = content.find("final") {
+ content[..pos].to_string()
+ } else {
+ content.to_string()
+ }
+ }
+
+ fn has_analysis_markers(&self, buffer: &str) -> bool {
+ buffer.contains("**")
+ }
+}
diff --git a/src/llm_models/mod.rs b/src/llm_models/mod.rs
new file mode 100644
index 00000000..860ed44c
--- /dev/null
+++ b/src/llm_models/mod.rs
@@ -0,0 +1,33 @@
+//! Module for handling model-specific behavior and token processing
+
+pub mod gpt_oss_20b;
+pub mod deepseek_r3;
+pub mod gpt_oss_120b;
+
+
+/// Trait for model-specific token processing
+pub trait ModelHandler: Send + Sync {
+ /// Check if the analysis buffer indicates completion
+ fn is_analysis_complete(&self, buffer: &str) -> bool;
+
+ /// Process the content, removing any model-specific tokens
+ fn process_content(&self, content: &str) -> String;
+
+ /// Check if the buffer contains analysis start markers
+ fn has_analysis_markers(&self, buffer: &str) -> bool;
+}
+
+/// Get the appropriate handler based on model path from bot configuration
+pub fn get_handler(model_path: &str) -> Box<dyn ModelHandler> {
+ let path = model_path.to_lowercase();
+
+ if path.contains("deepseek") {
+ Box::new(deepseek_r3::DeepseekR3Handler)
+ } else if path.contains("120b") {
+ Box::new(gpt_oss_120b::GptOss120bHandler::new("default"))
+ } else if path.contains("gpt-oss") || path.contains("gpt") {
+ Box::new(gpt_oss_20b::GptOss20bHandler)
+ } else {
+ Box::new(gpt_oss_20b::GptOss20bHandler)
+ }
+}
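+
+// An illustrative routing test. The model paths below are hypothetical examples;
+// the assertions exercise observable marker behaviour only, since the concrete
+// type behind the returned `Box<dyn ModelHandler>` is not exposed.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn routes_model_paths_to_matching_handlers() {
+        // A DeepSeek-style path should yield the handler keyed on <think> tags.
+        let deepseek = get_handler("models/deepseek-r3-q4.gguf");
+        assert!(deepseek.has_analysis_markers("<think>"));
+
+        // Unknown paths fall back to the GPT-OSS 20B handler, which keys on "**".
+        let fallback = get_handler("some-other-model");
+        assert!(fallback.has_analysis_markers("**analysis"));
+    }
+}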
diff --git a/src/main.rs b/src/main.rs
index 66e4462c..92f38eea 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,7 +25,7 @@ mod ui;
mod file;
mod kb;
mod llm;
-mod llm_legacy;
+mod llm_models;
mod meet;
mod org;
mod package_manager;
@@ -49,7 +49,7 @@ use crate::email::{
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
};
use crate::file::{init_drive, upload_file};
-use crate::llm_legacy::llm_local::{
+use crate::llm::local::{
chat_completions_local, embeddings_local, ensure_llama_servers_running,
};
use crate::meet::{voice_start, voice_stop};
@@ -124,8 +124,8 @@ async fn main() -> std::io::Result<()> {
&std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
) {
- Ok(mut conn) => AppConfig::from_database(&mut conn),
- Err(_) => AppConfig::from_env(),
+ Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
+ Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
}
} else {
match bootstrap.bootstrap().await {
@@ -135,13 +135,13 @@ async fn main() -> std::io::Result<()> {
}
Err(e) => {
log::error!("Bootstrap failed: {}", e);
- match diesel::Connection::establish(
- &std::env::var("DATABASE_URL")
- .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
- ) {
- Ok(mut conn) => AppConfig::from_database(&mut conn),
- Err(_) => AppConfig::from_env(),
- }
+ match diesel::Connection::establish(
+ &std::env::var("DATABASE_URL")
+ .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
+ ) {
+ Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
+ Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
+ }
}
}
};
@@ -158,7 +158,7 @@ async fn main() -> std::io::Result<()> {
// Refresh configuration from environment to ensure latest DATABASE_URL and credentials
dotenv().ok();
- let refreshed_cfg = AppConfig::from_env();
+ let refreshed_cfg = AppConfig::from_env().expect("Failed to load config from env");
let config = std::sync::Arc::new(refreshed_cfg.clone());
let db_pool = match diesel::Connection::establish(&refreshed_cfg.database_url()) {
Ok(conn) => Arc::new(Mutex::new(conn)),
diff --git a/src/shared/models.rs b/src/shared/models.rs
index 75a2e2da..92a1afb7 100644
--- a/src/shared/models.rs
+++ b/src/shared/models.rs
@@ -411,6 +411,9 @@ pub mod schema {
bot_id -> Uuid,
config_key -> Text,
config_value -> Text,
+
+ is_encrypted -> Bool,
+
config_type -> Text,
created_at -> Timestamptz,
updated_at -> Timestamptz,
diff --git a/src/shared/utils.rs b/src/shared/utils.rs
index abb39c01..dfba3f20 100644
--- a/src/shared/utils.rs
+++ b/src/shared/utils.rs
@@ -34,6 +34,19 @@ pub fn extract_zip_recursive(
if !parent.exists() {
std::fs::create_dir_all(&parent)?;
}
+
+use crate::llm::LLMProvider;
+use std::sync::Arc;
+use serde_json::Value;
+
+/// Unified chat utility to interact with any LLM provider
+pub async fn chat_with_llm(
+    provider: Arc<dyn LLMProvider>,
+ prompt: &str,
+ config: &Value,
+) -> Result<String, Box<dyn std::error::Error>> {
+ provider.generate(prompt, config).await
+}
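+
+// Usage sketch (hypothetical call site, not part of this patch): assuming a
+// provider is available as `Arc<dyn LLMProvider>` and `generate` takes a prompt
+// plus a JSON config, a caller could do:
+//
+//     let reply = chat_with_llm(
+//         provider.clone(),
+//         "Summarize the document",
+//         &serde_json::json!({ "temperature": 0.2 }),
+//     )
+//     .await?;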
}
let mut outfile = File::create(&outpath)?;
std::io::copy(&mut file, &mut outfile)?;