feat: enforce config load errors and add dynamic LLM model handling

- Updated `BootstrapManager` to use `AppConfig::from_env().expect(...)` and `AppConfig::from_database(...).expect(...)`, ensuring failures are explicit rather than silently ignored (see the sketch after this list).
- Refactored error propagation in bootstrap flow to use `?` where appropriate, improving reliability of configuration loading.
- Added import of `llm_models` in `bot` module and introduced `ConfigManager` usage to fetch the LLM model identifier at runtime.
- Integrated dynamic LLM model handler selection via `llm_models::get_handler(&model)`.
- Replaced static environment variable retrieval for embedding configuration with runtime lookups through `ConfigManager`.
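
The error-handling change above boils down to making both constructors fallible and letting each call site choose between failing fast and propagating. A minimal sketch of that split, assuming the new `Result`-returning signatures from the diff (the stand-in `AppConfig`, the boxed error type, and the `bootstrap_flow` helper are illustrative only):

```rust
use std::error::Error;

// Stand-in for the real AppConfig. The diff uses anyhow::Error and
// diesel::result::Error; a boxed error keeps this sketch dependency-free.
struct AppConfig;

impl AppConfig {
    fn from_env() -> Result<AppConfig, Box<dyn Error>> {
        // Real loading reads env vars / ConfigManager values; elided here.
        Ok(AppConfig)
    }
}

// Startup path: a missing configuration is fatal, so fail loudly instead
// of silently falling back to defaults.
fn load_at_startup() -> AppConfig {
    AppConfig::from_env().expect("Failed to load config from env")
}

// Bootstrap path: the surrounding function is already fallible, so errors
// propagate with `?` and the caller decides how to recover.
fn bootstrap_flow() -> Result<AppConfig, Box<dyn Error>> {
    let config = AppConfig::from_env()?;
    Ok(config)
}

fn main() {
    let _ = load_at_startup();
    let _ = bootstrap_flow().expect("bootstrap failed");
}
```

The diff below applies exactly this split: `main` and `BootstrapManager::new` use `expect`, while the bootstrap fallback path uses `?`.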
Rodrigo Rodriguez (Pragmatismo) 2025-11-02 18:36:21 -03:00
parent e7fa2f7bdb
commit fea1574518
18 changed files with 766 additions and 1181 deletions

View file

@ -43,7 +43,7 @@ impl BootstrapManager {
"Initializing BootstrapManager with mode {:?} and tenant {:?}",
install_mode, tenant
);
let config = AppConfig::from_env();
let config = AppConfig::from_env().expect("Failed to load config from env");
let s3_client = futures::executor::block_on(Self::create_s3_operator(&config));
Self {
install_mode,
@ -179,11 +179,11 @@ impl BootstrapManager {
if let Err(e) = self.apply_migrations(&mut conn) {
log::warn!("Failed to apply migrations: {}", e);
}
return Ok(AppConfig::from_database(&mut conn));
return Ok(AppConfig::from_database(&mut conn).expect("Failed to load config from DB"));
}
Err(e) => {
log::warn!("Failed to connect to database: {}", e);
return Ok(AppConfig::from_env());
return Ok(AppConfig::from_env()?);
}
}
}
@ -191,7 +191,7 @@ impl BootstrapManager {
let pm = PackageManager::new(self.install_mode.clone(), self.tenant.clone())?;
let required_components = vec!["tables", "drive", "cache", "llm"];
let mut config = AppConfig::from_env();
let mut config = AppConfig::from_env().expect("Failed to load config from env");
for component in required_components {
if !pm.is_installed(component) {
@ -282,7 +282,7 @@ impl BootstrapManager {
);
}
config = AppConfig::from_database(&mut conn);
config = AppConfig::from_database(&mut conn).expect("Failed to load config from DB");
}
}
}
@ -513,9 +513,9 @@ impl BootstrapManager {
// Create fresh connection for final config load
let mut final_conn = establish_pg_connection()?;
let config = AppConfig::from_database(&mut final_conn);
info!("Successfully loaded config from CSV with LLM settings");
Ok(config)
let config = AppConfig::from_database(&mut final_conn)?;
info!("Successfully loaded config from CSV with LLM settings");
Ok(config)
}
Err(e) => {
debug!("No config.csv found: {}", e);

View file

@ -4,6 +4,7 @@ use crate::context::langcache::get_langcache_client;
use crate::drive_monitor::DriveMonitor;
use crate::kb::embeddings::generate_embeddings;
use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint};
use crate::llm_models;
use crate::shared::models::{BotResponse, Suggestion, UserMessage, UserSession};
use crate::shared::state::AppState;
use actix_web::{web, HttpRequest, HttpResponse, Result};
@ -680,7 +681,6 @@ impl BotOrchestrator {
e
})?;
let session = {
let mut sm = self.state.session_manager.lock().await;
let session_id = Uuid::parse_str(&message.session_id).map_err(|e| {
@ -795,6 +795,17 @@ impl BotOrchestrator {
let mut chunk_count = 0;
let mut first_word_received = false;
let config_manager = ConfigManager::new(Arc::clone(&self.state.conn));
let model = config_manager
.get_config(
&Uuid::parse_str(&message.bot_id).unwrap_or_default(),
"llm-model",
None,
)
.unwrap_or_default();
let handler = llm_models::get_handler(&model);
while let Some(chunk) = stream_rx.recv().await {
chunk_count += 1;
@ -804,63 +815,56 @@ impl BotOrchestrator {
analysis_buffer.push_str(&chunk);
if (analysis_buffer.contains("**") || analysis_buffer.contains("<think>"))
&& !in_analysis
{
// Check for analysis markers
if handler.has_analysis_markers(&analysis_buffer) && !in_analysis {
in_analysis = true;
}
if in_analysis {
if analysis_buffer.ends_with("final") || analysis_buffer.ends_with("</think>") {
info!(
"Analysis section completed, buffer length: {}",
analysis_buffer.len()
);
in_analysis = false;
analysis_buffer.clear();
// Check if analysis is complete
if in_analysis && handler.is_analysis_complete(&analysis_buffer) {
info!("Analysis section completed");
in_analysis = false;
analysis_buffer.clear();
if message.channel == "web" {
let orchestrator = BotOrchestrator::new(Arc::clone(&self.state));
orchestrator
.send_event(
&message.user_id,
&message.bot_id,
&message.session_id,
&message.channel,
"thinking_end",
serde_json::json!({
"user_id": message.user_id.clone()
}),
)
.await
.ok();
}
analysis_buffer.clear();
continue;
if message.channel == "web" {
let orchestrator = BotOrchestrator::new(Arc::clone(&self.state));
orchestrator
.send_event(
&message.user_id,
&message.bot_id,
&message.session_id,
&message.channel,
"thinking_end",
serde_json::json!({"user_id": message.user_id.clone()}),
)
.await
.ok();
}
continue;
}
full_response.push_str(&chunk);
if !in_analysis {
full_response.push_str(&chunk);
let partial = BotResponse {
bot_id: message.bot_id.clone(),
user_id: message.user_id.clone(),
session_id: message.session_id.clone(),
channel: message.channel.clone(),
content: chunk,
message_type: 1,
stream_token: None,
is_complete: false,
suggestions: suggestions.clone(),
context_name: None,
context_length: 0,
context_max_length: 0,
};
let partial = BotResponse {
bot_id: message.bot_id.clone(),
user_id: message.user_id.clone(),
session_id: message.session_id.clone(),
channel: message.channel.clone(),
content: chunk,
message_type: 1,
stream_token: None,
is_complete: false,
suggestions: suggestions.clone(),
context_name: None,
context_length: 0,
context_max_length: 0,
};
if response_tx.send(partial).await.is_err() {
warn!("Response channel closed, stopping stream processing");
break;
if response_tx.send(partial).await.is_err() {
warn!("Response channel closed, stopping stream processing");
break;
}
}
}
@ -1369,7 +1373,10 @@ async fn websocket_handler(
// Cancel any ongoing LLM jobs for this session
if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
warn!(
"Failed to cancel LLM job for session {}: {}",
session_id_clone2, e
);
}
info!("WebSocket fully closed for session {}", session_id_clone2);

View file

@ -1,14 +1,19 @@
use diesel::prelude::*;
use diesel::pg::PgConnection;
use crate::shared::models::schema::bot_configuration;
use diesel::sql_types::Text;
use uuid::Uuid;
use diesel::pg::Pg;
use log::{info, trace, warn};
use serde::{Deserialize, Serialize};
// removed unused serde import
use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::Write;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use crate::shared::utils::establish_pg_connection;
#[derive(Clone)]
#[derive(Clone, Default)]
pub struct LLMConfig {
pub url: String,
pub key: String,
@ -20,7 +25,6 @@ pub struct AppConfig {
pub drive: DriveConfig,
pub server: ServerConfig,
pub database: DatabaseConfig,
pub database_custom: DatabaseConfig,
pub email: EmailConfig,
pub llm: LLMConfig,
pub embedding: LLMConfig,
@ -61,17 +65,21 @@ pub struct EmailConfig {
pub password: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, QueryableByName)]
#[derive(Debug, Clone, Queryable, Selectable)]
#[diesel(table_name = bot_configuration)]
#[diesel(check_for_backend(Pg))]
pub struct ServerConfigRow {
#[diesel(sql_type = Text)]
pub id: String,
#[diesel(sql_type = DieselUuid)]
pub id: Uuid,
#[diesel(sql_type = DieselUuid)]
pub bot_id: Uuid,
#[diesel(sql_type = Text)]
pub config_key: String,
#[diesel(sql_type = Text)]
pub config_value: String,
#[diesel(sql_type = Text)]
pub config_type: String,
#[diesel(sql_type = diesel::sql_types::Bool)]
#[diesel(sql_type = Bool)]
pub is_encrypted: bool,
}
@ -87,17 +95,6 @@ impl AppConfig {
)
}
pub fn database_custom_url(&self) -> String {
format!(
"postgres://{}:{}@{}:{}/{}",
self.database_custom.username,
self.database_custom.password,
self.database_custom.server,
self.database_custom.port,
self.database_custom.database
)
}
pub fn component_path(&self, component: &str) -> PathBuf {
self.stack_path.join(component)
}
@ -117,125 +114,134 @@ impl AppConfig {
pub fn log_path(&self, component: &str) -> PathBuf {
self.stack_path.join("logs").join(component)
}
}
pub fn from_database(_conn: &mut PgConnection) -> Self {
info!("Loading configuration from database");
// Config is now loaded from bot_configuration table via drive_monitor
let config_map: HashMap<String, ServerConfigRow> = HashMap::new();
impl AppConfig {
pub fn from_database(conn: &mut PgConnection) -> Result<Self, diesel::result::Error> {
info!("Loading configuration from database");
let get_str = |key: &str, default: &str| -> String {
config_map
.get(key)
.map(|v: &ServerConfigRow| v.config_value.clone())
.unwrap_or_else(|| default.to_string())
};
use crate::shared::models::schema::bot_configuration::dsl::*;
use diesel::prelude::*;
let get_u32 = |key: &str, default: u32| -> u32 {
config_map
.get(key)
.and_then(|v| v.config_value.parse().ok())
.unwrap_or(default)
};
let config_map: HashMap<String, ServerConfigRow> = bot_configuration
.select(ServerConfigRow::as_select()).load::<ServerConfigRow>(conn)
.unwrap_or_default()
.into_iter()
.map(|row| (row.config_key.clone(), row))
.collect();
let get_u16 = |key: &str, default: u16| -> u16 {
config_map
.get(key)
.and_then(|v| v.config_value.parse().ok())
.unwrap_or(default)
};
let mut get_str = |key: &str, default: &str| -> String {
bot_configuration
.filter(config_key.eq(key))
.select(config_value)
.first::<String>(conn)
.unwrap_or_else(|_| default.to_string())
};
let get_bool = |key: &str, default: bool| -> bool {
config_map
.get(key)
.map(|v| v.config_value.to_lowercase() == "true")
.unwrap_or(default)
};
let get_u32 = |key: &str, default: u32| -> u32 {
config_map
.get(key)
.and_then(|v| v.config_value.parse().ok())
.unwrap_or(default)
};
let stack_path = PathBuf::from(get_str("STACK_PATH", "./botserver-stack"));
let get_u16 = |key: &str, default: u16| -> u16 {
config_map
.get(key)
.and_then(|v| v.config_value.parse().ok())
.unwrap_or(default)
};
let database = DatabaseConfig {
username: std::env::var("TABLES_USERNAME")
.unwrap_or_else(|_| get_str("TABLES_USERNAME", "gbuser")),
password: std::env::var("TABLES_PASSWORD")
.unwrap_or_else(|_| get_str("TABLES_PASSWORD", "")),
server: std::env::var("TABLES_SERVER")
.unwrap_or_else(|_| get_str("TABLES_SERVER", "localhost")),
port: std::env::var("TABLES_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or_else(|| get_u32("TABLES_PORT", 5432)),
database: std::env::var("TABLES_DATABASE")
.unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")),
};
let get_bool = |key: &str, default: bool| -> bool {
config_map
.get(key)
.map(|v| v.config_value.to_lowercase() == "true")
.unwrap_or(default)
};
let database_custom = DatabaseConfig {
username: std::env::var("CUSTOM_USERNAME")
.unwrap_or_else(|_| get_str("CUSTOM_USERNAME", "gbuser")),
password: std::env::var("CUSTOM_PASSWORD")
.unwrap_or_else(|_| get_str("CUSTOM_PASSWORD", "")),
server: std::env::var("CUSTOM_SERVER")
.unwrap_or_else(|_| get_str("CUSTOM_SERVER", "localhost")),
port: std::env::var("CUSTOM_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or_else(|| get_u32("CUSTOM_PORT", 5432)),
database: std::env::var("CUSTOM_DATABASE")
.unwrap_or_else(|_| get_str("CUSTOM_DATABASE", "gbuser")),
};
let stack_path = PathBuf::from(get_str("STACK_PATH", "./botserver-stack"));
let minio = DriveConfig {
server: {
let server = get_str("DRIVE_SERVER", "http://localhost:9000");
if !server.starts_with("http://") && !server.starts_with("https://") {
format!("http://{}", server)
} else {
server
}
},
access_key: get_str("DRIVE_ACCESSKEY", "minioadmin"),
secret_key: get_str("DRIVE_SECRET", "minioadmin"),
use_ssl: get_bool("DRIVE_USE_SSL", false),
};
let database = DatabaseConfig {
username: std::env::var("TABLES_USERNAME")
.unwrap_or_else(|_| get_str("TABLES_USERNAME", "gbuser")),
password: std::env::var("TABLES_PASSWORD")
.unwrap_or_else(|_| get_str("TABLES_PASSWORD", "")),
server: std::env::var("TABLES_SERVER")
.unwrap_or_else(|_| get_str("TABLES_SERVER", "localhost")),
port: std::env::var("TABLES_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or_else(|| get_u32("TABLES_PORT", 5432)),
database: std::env::var("TABLES_DATABASE")
.unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")),
};
let email = EmailConfig {
from: get_str("EMAIL_FROM", "noreply@example.com"),
server: get_str("EMAIL_SERVER", "smtp.example.com"),
port: get_u16("EMAIL_PORT", 587),
username: get_str("EMAIL_USER", "user"),
password: get_str("EMAIL_PASS", "pass"),
};
// Write drive config to .env file
if let Err(e) = write_drive_config_to_env(&minio) {
warn!("Failed to write drive config to .env: {}", e);
}
let drive = DriveConfig {
server: {
let server = get_str("DRIVE_SERVER", "http://localhost:9000");
if !server.starts_with("http://") && !server.starts_with("https://") {
format!("http://{}", server)
} else {
server
}
},
access_key: get_str("DRIVE_ACCESSKEY", "minioadmin"),
secret_key: get_str("DRIVE_SECRET", "minioadmin"),
use_ssl: get_bool("DRIVE_USE_SSL", false),
};
AppConfig {
drive: minio,
let email = EmailConfig {
from: get_str("EMAIL_FROM", "noreply@example.com"),
server: get_str("EMAIL_SERVER", "smtp.example.com"),
port: get_u16("EMAIL_PORT", 587),
username: get_str("EMAIL_USER", "user"),
password: get_str("EMAIL_PASS", "pass"),
};
// Write drive config to .env file
if let Err(e) = write_drive_config_to_env(&drive) {
warn!("Failed to write drive config to .env: {}", e);
}
Ok(AppConfig {
drive,
server: ServerConfig {
host: get_str("SERVER_HOST", "127.0.0.1"),
port: get_u16("SERVER_PORT", 8080),
},
database,
database_custom,
email,
llm: LLMConfig {
url: get_str("LLM_URL", "http://localhost:8081"),
key: get_str("LLM_KEY", ""),
model: get_str("LLM_MODEL", "gpt-4"),
llm: {
// Use a fresh connection for ConfigManager to avoid cloning the mutable reference
let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn)));
LLMConfig {
url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?,
key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?,
model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?,
}
},
embedding: LLMConfig {
url: get_str("EMBEDDING_URL", "http://localhost:8082"),
key: get_str("EMBEDDING_KEY", ""),
model: get_str("EMBEDDING_MODEL", "text-embedding-ada-002"),
embedding: {
let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn)));
LLMConfig {
url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?,
key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?,
model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?,
}
},
site_path: {
let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
ConfigManager::new(Arc::new(Mutex::new(fresh_conn)))
.get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?.to_string()
},
site_path: get_str("SITES_ROOT", "./botserver-stack/sites"),
stack_path,
db_conn: None,
}
}
})
}
pub fn from_env() -> Self {
pub fn from_env() -> Result<Self, anyhow::Error> {
info!("Loading configuration from environment variables");
let stack_path =
@ -254,16 +260,6 @@ impl AppConfig {
database: db_name,
};
let database_custom = DatabaseConfig {
username: std::env::var("CUSTOM_USERNAME").unwrap_or_else(|_| "gbuser".to_string()),
password: std::env::var("CUSTOM_PASSWORD").unwrap_or_else(|_| "".to_string()),
server: std::env::var("CUSTOM_SERVER").unwrap_or_else(|_| "localhost".to_string()),
port: std::env::var("CUSTOM_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(5432),
database: std::env::var("CUSTOM_DATABASE").unwrap_or_else(|_| "botserver".to_string()),
};
let minio = DriveConfig {
server: std::env::var("DRIVE_SERVER")
@ -288,33 +284,43 @@ impl AppConfig {
password: std::env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()),
};
AppConfig {
Ok(AppConfig {
drive: minio,
server: ServerConfig {
host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
port: std::env::var("SERVER_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(8080),
},
database,
database_custom,
email,
llm: LLMConfig {
url: std::env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()),
key: std::env::var("LLM_KEY").unwrap_or_else(|_| "".to_string()),
model: std::env::var("LLM_MODEL").unwrap_or_else(|_| "gpt-4".to_string()),
llm: {
let conn = PgConnection::establish(&database_url)?;
let config = ConfigManager::new(Arc::new(Mutex::new(conn)));
LLMConfig {
url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?,
key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?,
model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?,
}
},
embedding: LLMConfig {
url: std::env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()),
key: std::env::var("EMBEDDING_KEY").unwrap_or_else(|_| "".to_string()),
model: std::env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-ada-002".to_string()),
embedding: {
let conn = PgConnection::establish(&database_url)?;
let config = ConfigManager::new(Arc::new(Mutex::new(conn)));
LLMConfig {
url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?,
key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?,
model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?,
}
},
site_path: {
let conn = PgConnection::establish(&database_url)?;
ConfigManager::new(Arc::new(Mutex::new(conn)))
.get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?
},
site_path: std::env::var("SITES_ROOT")
.unwrap_or_else(|_| "./botserver-stack/sites".to_string()),
stack_path: PathBuf::from(stack_path),
db_conn: None,
}
})
}
pub fn set_config(

View file

@ -62,10 +62,7 @@ impl DriveMonitor {
self.check_gbdialog_changes(client).await?;
self.check_gbkb_changes(client).await?;
if let Err(e) = self.check_gbot(client).await {
error!("Error checking default bot config: {}", e);
}
self.check_gbot(client).await?;
Ok(())
}
@ -257,7 +254,7 @@ impl DriveMonitor {
continue;
}
debug!("Checking config file at path: {}", path);
trace!("Checking config file at path: {}", path);
match client
.head_object()
.bucket(&self.bucket_name)
@ -276,7 +273,7 @@ impl DriveMonitor {
.key(&path)
.send()
.await?;
debug!(
trace!(
"GetObject successful for {}, content length: {}",
path,
response.content_length().unwrap_or(0)
@ -295,7 +292,7 @@ impl DriveMonitor {
.collect();
if !llm_lines.is_empty() {
use crate::llm_legacy::llm_local::ensure_llama_servers_running;
use crate::llm::local::ensure_llama_servers_running;
let mut restart_needed = false;
for line in llm_lines {
@ -338,7 +335,7 @@ impl DriveMonitor {
}
}
Err(e) => {
debug!("Config file {} not found or inaccessible: {}", path, e);
error!("Config file {} not found or inaccessible: {}", path, e);
}
}
}

View file

@ -1,62 +1,73 @@
use async_trait::async_trait;
use reqwest::Client;
use futures_util::StreamExt;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
use futures_util::StreamExt;
use crate::tools::ToolManager;
use super::LLMProvider;
pub struct LlamaClient {
pub struct AnthropicClient {
client: Client,
api_key: String,
base_url: String,
}
impl LlamaClient {
pub fn new(base_url: String) -> Self {
impl AnthropicClient {
pub fn new(api_key: String) -> Self {
Self {
client: Client::new(),
base_url,
api_key,
base_url: "https://api.anthropic.com/v1".to_string(),
}
}
}
#[async_trait]
impl LLMProvider for LlamaClient {
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
config: &Value,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let response = self.client
.post(&format!("{}/completion", self.base_url))
let response = self
.client
.post(&format!("{}/messages", self.base_url))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&serde_json::json!({
"prompt": prompt,
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
"model": "claude-3-sonnet-20240229",
"max_tokens": 1000,
"messages": [{"role": "user", "content": prompt}]
}))
.send()
.await?;
let result: Value = response.json().await?;
Ok(result["content"]
let content = result["content"][0]["text"]
.as_str()
.unwrap_or("")
.to_string())
.to_string();
Ok(content)
}
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = self.client
.post(&format!("{}/completion", self.base_url))
let response = self
.client
.post(&format!("{}/messages", self.base_url))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&serde_json::json!({
"prompt": prompt,
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
"model": "claude-3-sonnet-20240229",
"max_tokens": 1000,
"messages": [{"role": "user", "content": prompt}],
"stream": true
}))
.send()
@ -66,14 +77,18 @@ impl LLMProvider for LlamaClient {
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk_str = String::from_utf8_lossy(&chunk);
let chunk_bytes = chunk?;
let chunk_str = String::from_utf8_lossy(&chunk_bytes);
for line in chunk_str.lines() {
if let Ok(data) = serde_json::from_str::<Value>(&line) {
if let Some(content) = data["content"].as_str() {
buffer.push_str(content);
let _ = tx.send(content.to_string()).await;
if line.starts_with("data: ") {
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
if data["type"] == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
buffer.push_str(text);
let _ = tx.send(text.to_string()).await;
}
}
}
}
}
@ -85,7 +100,7 @@ impl LLMProvider for LlamaClient {
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
@ -98,26 +113,14 @@ impl LLMProvider for LlamaClient {
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
self.generate(&enhanced_prompt, config).await
self.generate(&enhanced_prompt, &Value::Null).await
}
async fn cancel_job(
&self,
session_id: &str,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// llama.cpp cancellation endpoint
let response = self.client
.post(&format!("{}/cancel", self.base_url))
.json(&serde_json::json!({
"session_id": session_id
}))
.send()
.await?;
if response.status().is_success() {
Ok(())
} else {
Err(format!("Failed to cancel job for session {}", session_id).into())
}
// Anthropic doesn't support job cancellation
Ok(())
}
}

103
src/llm/azure.rs Normal file
View file

@ -0,0 +1,103 @@
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::sync::Arc;
use crate::tools::ToolManager;
use super::LLMProvider;
pub struct AzureOpenAIClient {
endpoint: String,
api_key: String,
api_version: String,
deployment: String,
client: Client,
}
impl AzureOpenAIClient {
pub fn new(endpoint: String, api_key: String, api_version: String, deployment: String) -> Self {
Self {
endpoint,
api_key,
api_version,
deployment,
client: Client::new(),
}
}
}
#[async_trait]
impl LLMProvider for AzureOpenAIClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint, self.deployment, self.api_version
);
let body = serde_json::json!({
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
],
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 1.0,
"frequency_penalty": 0.0,
"presence_penalty": 0.0
});
let response = self.client
.post(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await?;
let result: Value = response.json().await?;
if let Some(choice) = result["choices"].get(0) {
Ok(choice["message"]["content"].as_str().unwrap_or("").to_string())
} else {
Err("No response from Azure OpenAI".into())
}
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: tokio::sync::mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let content = self.generate(prompt, _config).await?;
let _ = tx.send(content).await;
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
self.generate(&enhanced_prompt, _config).await
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(())
}
}

269
src/llm/local.rs Normal file
View file

@ -0,0 +1,269 @@
use crate::shared::state::AppState;
use crate::shared::models::schema::bots::dsl::*;
use crate::config::ConfigManager;
use diesel::prelude::*;
use std::sync::Arc;
use log::{info, error};
use tokio;
use reqwest;
use actix_web::{post, web, HttpResponse, Result};
#[post("/api/chat/completions")]
pub async fn chat_completions_local(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
// Placeholder implementation
Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "chat_completions_local not implemented" })))
}
#[post("/api/embeddings")]
pub async fn embeddings_local(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
// Placeholder implementation
Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "embeddings_local not implemented" })))
}
pub async fn ensure_llama_servers_running(
app_state: &Arc<AppState>
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let conn = app_state.conn.clone();
let config_manager = ConfigManager::new(conn.clone());
let default_bot_id = {
let mut conn = conn.lock().unwrap();
bots.filter(name.eq("default"))
.select(id)
.first::<uuid::Uuid>(&mut *conn)
.unwrap_or_else(|_| uuid::Uuid::nil())
};
let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?;
let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?;
let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?;
let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?;
let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?;
info!("Starting LLM servers...");
info!("Configuration:");
info!(" LLM URL: {}", llm_url);
info!(" Embedding URL: {}", embedding_url);
info!(" LLM Model: {}", llm_model);
info!(" Embedding Model: {}", embedding_model);
info!(" LLM Server Path: {}", llm_server_path);
// Restart any existing llama-server processes
info!("Restarting any existing llama-server processes...");
if let Err(e) = tokio::process::Command::new("sh")
.arg("-c")
.arg("pkill -f llama-server || true")
.spawn()
{
error!("Failed to execute pkill for llama-server: {}", e);
} else {
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
info!("Existing llama-server processes terminated (if any)");
}
// Check if servers are already running
let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await;
if llm_running && embedding_running {
info!("Both LLM and Embedding servers are already running");
return Ok(());
}
// Start servers that aren't running
let mut tasks = vec![];
if !llm_running && !llm_model.is_empty() {
info!("Starting LLM server...");
tasks.push(tokio::spawn(start_llm_server(
Arc::clone(app_state),
llm_server_path.clone(),
llm_model.clone(),
llm_url.clone(),
)));
} else if llm_model.is_empty() {
info!("LLM_MODEL not set, skipping LLM server");
}
if !embedding_running && !embedding_model.is_empty() {
info!("Starting Embedding server...");
tasks.push(tokio::spawn(start_embedding_server(
llm_server_path.clone(),
embedding_model.clone(),
embedding_url.clone(),
)));
} else if embedding_model.is_empty() {
info!("EMBEDDING_MODEL not set, skipping Embedding server");
}
// Wait for all server startup tasks
for task in tasks {
task.await??;
}
// Wait for servers to be ready with verbose logging
info!("Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model.is_empty();
let mut embedding_ready = embedding_running || embedding_model.is_empty();
let mut attempts = 0;
let max_attempts = 60; // 2 minutes total
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
info!("Checking server health (attempt {}/{})...", attempts + 1, max_attempts);
if !llm_ready && !llm_model.is_empty() {
if is_server_running(&llm_url).await {
info!("LLM server ready at {}", llm_url);
llm_ready = true;
} else {
info!("LLM server not ready yet");
}
}
if !embedding_ready && !embedding_model.is_empty() {
if is_server_running(&embedding_url).await {
info!("Embedding server ready at {}", embedding_url);
embedding_ready = true;
} else {
info!("Embedding server not ready yet");
}
}
attempts += 1;
if attempts % 10 == 0 {
info!("Still waiting for servers... (attempt {}/{})", attempts, max_attempts);
}
}
if llm_ready && embedding_ready {
info!("All llama.cpp servers are ready and responding!");
Ok(())
} else {
let mut error_msg = "Servers failed to start within timeout:".to_string();
if !llm_ready && !llm_model.is_empty() {
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
}
if !embedding_ready && !embedding_model.is_empty() {
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
}
Err(error_msg.into())
}
}
pub async fn is_server_running(url: &str) -> bool {
let client = reqwest::Client::new();
match client.get(&format!("{}/health", url)).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
pub async fn start_llm_server(
app_state: Arc<AppState>,
llama_cpp_path: String,
model_path: String,
url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port = url.split(':').last().unwrap_or("8081");
std::env::set_var("OMP_NUM_THREADS", "20");
std::env::set_var("OMP_PLACES", "cores");
std::env::set_var("OMP_PROC_BIND", "close");
let conn = app_state.conn.clone();
let config_manager = ConfigManager::new(conn.clone());
let default_bot_id = {
let mut conn = conn.lock().unwrap();
bots.filter(name.eq("default"))
.select(id)
.first::<uuid::Uuid>(&mut *conn)
.unwrap_or_else(|_| uuid::Uuid::nil())
};
let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string());
let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
// Build command arguments dynamically
let mut args = format!(
"-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}",
model_path, port, ctx_size, gpu_layers
);
if n_moe != "0" {
args.push_str(&format!(" --n-cpu-moe {}", n_moe));
}
if parallel != "1" {
args.push_str(&format!(" --parallel {}", parallel));
}
if cont_batching == "true" {
args.push_str(" --cont-batching");
}
if mlock == "true" {
args.push_str(" --mlock");
}
if no_mmap == "true" {
args.push_str(" --no-mmap");
}
if cfg!(windows) {
let mut cmd = tokio::process::Command::new("cmd");
cmd.arg("/C").arg(format!(
"cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log",
llama_cpp_path, args
));
cmd.spawn()?;
} else {
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &",
llama_cpp_path, args
));
cmd.spawn()?;
}
Ok(())
}
pub async fn start_embedding_server(
llama_cpp_path: String,
model_path: String,
url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port = url.split(':').last().unwrap_or("8082");
if cfg!(windows) {
let mut cmd = tokio::process::Command::new("cmd");
cmd.arg("/c").arg(format!(
"cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log",
llama_cpp_path, model_path, port
));
cmd.spawn()?;
} else {
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &",
llama_cpp_path, model_path, port
));
cmd.spawn()?;
}
Ok(())
}

View file

@ -6,7 +6,9 @@ use tokio::sync::mpsc;
use crate::tools::ToolManager;
pub mod llama;
pub mod azure;
pub mod local;
pub mod anthropic;
#[async_trait]
pub trait LLMProvider: Send + Sync {
@ -161,123 +163,6 @@ impl LLMProvider for OpenAIClient {
}
}
pub struct AnthropicClient {
client: reqwest::Client,
api_key: String,
base_url: String,
}
impl AnthropicClient {
pub fn new(api_key: String) -> Self {
Self {
client: reqwest::Client::new(),
api_key,
base_url: "https://api.anthropic.com/v1".to_string(),
}
}
}
#[async_trait]
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let response = self
.client
.post(&format!("{}/messages", self.base_url))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&serde_json::json!({
"model": "claude-3-sonnet-20240229",
"max_tokens": 1000,
"messages": [{"role": "user", "content": prompt}]
}))
.send()
.await?;
let result: Value = response.json().await?;
let content = result["content"][0]["text"]
.as_str()
.unwrap_or("")
.to_string();
Ok(content)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = self
.client
.post(&format!("{}/messages", self.base_url))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.json(&serde_json::json!({
"model": "claude-3-sonnet-20240229",
"max_tokens": 1000,
"messages": [{"role": "user", "content": prompt}],
"stream": true
}))
.send()
.await?;
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
let chunk_str = String::from_utf8_lossy(&chunk);
for line in chunk_str.lines() {
if line.starts_with("data: ") {
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
if data["type"] == "content_block_delta" {
if let Some(text) = data["delta"]["text"].as_str() {
buffer.push_str(text);
let _ = tx.send(text.to_string()).await;
}
}
}
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
self.generate(&enhanced_prompt, &Value::Null).await
}
async fn cancel_job(
&self,
_session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Anthropic doesn't support job cancellation
Ok(())
}
}
pub struct MockLLMProvider;

View file

@ -1,145 +0,0 @@
use dotenvy::dotenv;
use log::{error, info};
use reqwest::Client;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct AzureOpenAIConfig {
pub endpoint: String,
pub api_key: String,
pub api_version: String,
pub deployment: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub messages: Vec<ChatMessage>,
pub temperature: f32,
pub max_tokens: Option<u32>,
pub top_p: f32,
pub frequency_penalty: f32,
pub presence_penalty: f32,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub choices: Vec<ChatChoice>,
pub usage: Usage,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
pub struct AzureOpenAIClient {
config: AzureOpenAIConfig,
client: Client,
}
impl AzureOpenAIClient {
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
dotenv().ok();
let endpoint =
std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?;
let api_key =
std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| "AZURE_OPENAI_API_KEY not set")?;
let api_version = std::env::var("AZURE_OPENAI_API_VERSION")
.unwrap_or_else(|_| "2023-12-01-preview".to_string());
let deployment =
std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string());
let config = AzureOpenAIConfig {
endpoint,
api_key,
api_version,
deployment,
};
Ok(Self {
config,
client: Client::new(),
})
}
pub async fn chat_completions(
&self,
messages: Vec<ChatMessage>,
temperature: f32,
max_tokens: Option<u32>,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.config.endpoint, self.config.deployment, self.config.api_version
);
let request_body = ChatCompletionRequest {
messages,
temperature,
max_tokens,
top_p: 1.0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
};
info!("Sending request to Azure OpenAI: {}", url);
let response = self
.client
.post(&url)
.header("api-key", &self.config.api_key)
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await?;
if !response.status().is_success() {
let error_text = response.text().await?;
error!("Azure OpenAI API error: {}", error_text);
return Err(format!("Azure OpenAI API error: {}", error_text).into());
}
let completion_response: ChatCompletionResponse = response.json().await?;
Ok(completion_response)
}
pub async fn simple_chat(&self, prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: "You are a helpful assistant.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: prompt.to_string(),
},
];
let response = self.chat_completions(messages, 0.7, Some(1000)).await?;
if let Some(choice) = response.choices.first() {
Ok(choice.message.content.clone())
} else {
Err("No response from AI".into())
}
}
}

View file

@ -1,656 +0,0 @@
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenvy::dotenv;
use log::{error, info};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use std::sync::Arc;
use tokio::time::{sleep, Duration};
use crate::config::ConfigManager;
use crate::shared::models::schema::bots::dsl::*;
use crate::shared::state::AppState;
use diesel::prelude::*;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
stream: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
message: ChatMessage,
finish_reason: String,
}
// Llama.cpp server request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppRequest {
prompt: String,
n_predict: Option<i32>,
temperature: Option<f32>,
top_k: Option<i32>,
top_p: Option<f32>,
stream: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppResponse {
content: String,
stop: bool,
generation_settings: Option<serde_json::Value>,
}
pub async fn ensure_llama_servers_running(
app_state: &Arc<AppState>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let conn = app_state.conn.clone();
let config_manager = ConfigManager::new(conn.clone());
// Get default bot ID from database
let default_bot_id = {
let mut conn = conn.lock().unwrap();
bots.filter(name.eq("default"))
.select(id)
.first::<uuid::Uuid>(&mut *conn)
.unwrap_or_else(|_| uuid::Uuid::nil())
};
// Get configuration from config using default bot ID
let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?;
let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?;
let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?;
let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?;
let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?;
info!(" Starting LLM servers...");
info!(" Configuration:");
info!(" LLM URL: {}", llm_url);
info!(" Embedding URL: {}", embedding_url);
info!(" LLM Model: {}", llm_model);
info!(" Embedding Model: {}", embedding_model);
info!(" LLM Server Path: {}", llm_server_path);
// Restart any existing llama-server processes before starting new ones
info!("🔁 Restarting any existing llama-server processes...");
if let Err(e) = tokio::process::Command::new("sh")
.arg("-c")
.arg("pkill -f llama-server || true")
.spawn()
{
error!("Failed to execute pkill for llama-server: {}", e);
} else {
sleep(Duration::from_secs(2)).await;
info!("✅ Existing llama-server processes terminated (if any)");
}
// Check if servers are already running
let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await;
if llm_running && embedding_running {
info!("✅ Both LLM and Embedding servers are already running");
return Ok(());
}
// Start servers that aren't running
let mut tasks = vec![];
if !llm_running && !llm_model.is_empty() {
info!("🔄 Starting LLM server...");
tasks.push(tokio::spawn(start_llm_server(
Arc::clone(app_state),
llm_server_path.clone(),
llm_model.clone(),
llm_url.clone(),
)));
} else if llm_model.is_empty() {
info!("⚠️ LLM_MODEL not set, skipping LLM server");
}
if !embedding_running && !embedding_model.is_empty() {
info!("🔄 Starting Embedding server...");
tasks.push(tokio::spawn(start_embedding_server(
llm_server_path.clone(),
embedding_model.clone(),
embedding_url.clone(),
)));
} else if embedding_model.is_empty() {
info!("⚠️ EMBEDDING_MODEL not set, skipping Embedding server");
}
// Wait for all server startup tasks
for task in tasks {
task.await??;
}
// Wait for servers to be ready with verbose logging
info!("⏳ Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model.is_empty();
let mut embedding_ready = embedding_running || embedding_model.is_empty();
let mut attempts = 0;
let max_attempts = 60; // 2 minutes total
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
sleep(Duration::from_secs(2)).await;
info!(
"🔍 Checking server health (attempt {}/{})...",
attempts + 1,
max_attempts
);
if !llm_ready && !llm_model.is_empty() {
if is_server_running(&llm_url).await {
info!(" ✅ LLM server ready at {}", llm_url);
llm_ready = true;
} else {
info!(" ❌ LLM server not ready yet");
}
}
if !embedding_ready && !embedding_model.is_empty() {
if is_server_running(&embedding_url).await {
info!(" ✅ Embedding server ready at {}", embedding_url);
embedding_ready = true;
} else {
info!(" ❌ Embedding server not ready yet");
}
}
attempts += 1;
if attempts % 10 == 0 {
info!(
"⏰ Still waiting for servers... (attempt {}/{})",
attempts, max_attempts
);
}
}
if llm_ready && embedding_ready {
info!("🎉 All llama.cpp servers are ready and responding!");
Ok(())
} else {
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
if !llm_ready && !llm_model.is_empty() {
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
}
if !embedding_ready && !embedding_model.is_empty() {
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
}
Err(error_msg.into())
}
}
async fn start_llm_server(
app_state: Arc<AppState>,
llama_cpp_path: String,
model_path: String,
url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port = url.split(':').last().unwrap_or("8081");
std::env::set_var("OMP_NUM_THREADS", "20");
std::env::set_var("OMP_PLACES", "cores");
std::env::set_var("OMP_PROC_BIND", "close");
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
// Read config values with defaults
let conn = app_state.conn.clone();
let config_manager = ConfigManager::new(conn.clone());
let default_bot_id = {
let mut conn = conn.lock().unwrap();
bots.filter(name.eq("default"))
.select(id)
.first::<uuid::Uuid>(&mut *conn)
.unwrap_or_else(|_| uuid::Uuid::nil())
};
let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string());
let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
// Build command arguments dynamically
let mut args = format!(
"-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}",
model_path, port, ctx_size, gpu_layers
);
if n_moe != "0" {
args.push_str(&format!(" --n-cpu-moe {}", n_moe));
}
if parallel != "1" {
args.push_str(&format!(" --parallel {}", parallel));
}
if cont_batching == "true" {
args.push_str(" --cont-batching");
}
if mlock == "true" {
args.push_str(" --mlock");
}
if no_mmap == "true" {
args.push_str(" --no-mmap");
}
if cfg!(windows) {
let mut cmd = tokio::process::Command::new("cmd");
cmd.arg("/C").arg(format!(
"cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log",
llama_cpp_path, args
));
cmd.spawn()?;
} else {
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &",
llama_cpp_path, args
));
cmd.spawn()?;
}
Ok(())
}
async fn start_embedding_server(
llama_cpp_path: String,
model_path: String,
url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let port = url.split(':').last().unwrap_or("8082");
if cfg!(windows) {
let mut cmd = tokio::process::Command::new("cmd");
cmd.arg("/c").arg(format!(
"cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log",
llama_cpp_path, model_path, port
));
cmd.spawn()?;
} else {
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &",
llama_cpp_path, model_path, port
));
cmd.spawn()?;
}
Ok(())
}
async fn is_server_running(url: &str) -> bool {
let client = reqwest::Client::new();
match client.get(&format!("{}/health", url)).send().await {
Ok(response) => response.status().is_success(),
Err(_) => false,
}
}
// Convert OpenAI chat messages to a single prompt
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
for message in messages {
match message.role.as_str() {
"system" => {
prompt.push_str(&format!("System: {}\n\n", message.content));
}
"user" => {
prompt.push_str(&format!("User: {}\n\n", message.content));
}
"assistant" => {
prompt.push_str(&format!("Assistant: {}\n\n", message.content));
}
_ => {
prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
}
}
}
prompt.push_str("Assistant: ");
prompt
}
// Proxy endpoint
#[post("/local/v1/chat/completions")]
pub async fn chat_completions_local(
req_body: web::Json<ChatCompletionRequest>,
_req: HttpRequest,
) -> Result<HttpResponse> {
dotenv().ok().unwrap();
// Get llama.cpp server URL
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
// Convert OpenAI format to llama.cpp format
let prompt = messages_to_prompt(&req_body.messages);
let llama_request = LlamaCppRequest {
prompt,
n_predict: Some(500), // Adjust as needed
temperature: Some(0.7),
top_k: Some(40),
top_p: Some(0.9),
stream: req_body.stream,
};
// Send request to llama.cpp server
let client = Client::builder()
.timeout(Duration::from_secs(500)) // 2 minute timeout
.build()
.map_err(|e| {
error!("Error creating HTTP client: {}", e);
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?;
let response = client
.post(&format!("{}/v1/completion", llama_url))
.header("Content-Type", "application/json")
.json(&llama_request)
.send()
.await
.map_err(|e| {
error!("Error calling llama.cpp server: {}", e);
actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
})?;
let status = response.status();
if status.is_success() {
let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
error!("Error parsing llama.cpp response: {}", e);
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
})?;
// Convert llama.cpp response to OpenAI format
let openai_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
model: req_body.model.clone(),
choices: vec![Choice {
message: ChatMessage {
role: "assistant".to_string(),
content: llama_response.content.trim().to_string(),
},
finish_reason: if llama_response.stop {
"stop".to_string()
} else {
"length".to_string()
},
}],
};
Ok(HttpResponse::Ok().json(openai_response))
} else {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
error!("Llama.cpp server error ({}): {}", status, error_text);
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status).json(serde_json::json!({
"error": {
"message": error_text,
"type": "server_error"
}
})))
}
}
// OpenAI Embedding Request - Modified to handle both string and array inputs
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
#[serde(deserialize_with = "deserialize_input")]
pub input: Vec<String>,
pub model: String,
#[serde(default)]
pub _encoding_format: Option<String>,
}
// Custom deserializer to handle both string and array inputs
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, Visitor};
use std::fmt;
struct InputVisitor;
impl<'de> Visitor<'de> for InputVisitor {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string or an array of strings")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(vec![value.to_string()])
}
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(vec![value])
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
let mut vec = Vec::new();
while let Some(value) = seq.next_element::<String>()? {
vec.push(value);
}
Ok(vec)
}
}
deserializer.deserialize_any(InputVisitor)
}
// OpenAI Embedding Response
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
pub object: String,
pub data: Vec<EmbeddingData>,
pub model: String,
pub usage: Usage,
}
#[derive(Debug, Serialize)]
pub struct EmbeddingData {
pub object: String,
pub embedding: Vec<f32>,
pub index: usize,
}
#[derive(Debug, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub total_tokens: u32,
}
// Llama.cpp Embedding Request
#[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest {
pub content: String,
}
// FIXED: Handle the stupid nested array format
#[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem {
pub index: usize,
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
}
// Proxy endpoint for embeddings
#[post("/v1/embeddings")]
pub async fn embeddings_local(
req_body: web::Json<EmbeddingRequest>,
_req: HttpRequest,
) -> Result<HttpResponse> {
dotenv().ok();
// Get llama.cpp server URL
let llama_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
let client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.map_err(|e| {
error!("Error creating HTTP client: {}", e);
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?;
// Process each input text and get embeddings
let mut embeddings_data = Vec::new();
let mut total_tokens = 0;
for (index, input_text) in req_body.input.iter().enumerate() {
let llama_request = LlamaCppEmbeddingRequest {
content: input_text.clone(),
};
let response = client
.post(&format!("{}/embedding", llama_url))
.header("Content-Type", "application/json")
.json(&llama_request)
.send()
.await
.map_err(|e| {
error!("Error calling llama.cpp server for embedding: {}", e);
actix_web::error::ErrorInternalServerError(
"Failed to call llama.cpp server for embedding",
)
})?;
let status = response.status();
if status.is_success() {
// First, get the raw response text for debugging
let raw_response = response.text().await.map_err(|e| {
error!("Error reading response text: {}", e);
actix_web::error::ErrorInternalServerError("Failed to read response")
})?;
// Parse the response as a vector of items with nested arrays
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
serde_json::from_str(&raw_response).map_err(|e| {
error!("Error parsing llama.cpp embedding response: {}", e);
error!("Raw response: {}", raw_response);
actix_web::error::ErrorInternalServerError(
"Failed to parse llama.cpp embedding response",
)
})?;
// Extract the embedding from the nested array bullshit
if let Some(item) = llama_response.get(0) {
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
let flattened_embedding = if !item.embedding.is_empty() {
item.embedding[0].clone() // Take the first (and probably only) inner array
} else {
vec![] // Empty if no embedding data
};
// Estimate token count
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
total_tokens += estimated_tokens;
embeddings_data.push(EmbeddingData {
object: "embedding".to_string(),
embedding: flattened_embedding,
index,
});
} else {
error!("No embedding data returned for input: {}", input_text);
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
"error": {
"message": format!("No embedding data returned for input {}", index),
"type": "server_error"
}
})));
}
} else {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
error!("Llama.cpp server error ({}): {}", status, error_text);
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
"error": {
"message": format!("Failed to get embedding for input {}: {}", index, error_text),
"type": "server_error"
}
})));
}
}
// Build OpenAI-compatible response
let openai_response = EmbeddingResponse {
object: "list".to_string(),
data: embeddings_data,
model: req_body.model.clone(),
usage: Usage {
prompt_tokens: total_tokens,
total_tokens,
},
};
Ok(HttpResponse::Ok().json(openai_response))
}

View file

@ -1,2 +0,0 @@
pub mod llm_azure;
pub mod llm_local;

View file

@ -0,0 +1,17 @@
use super::ModelHandler;
pub struct DeepseekR3Handler;
impl ModelHandler for DeepseekR3Handler {
fn is_analysis_complete(&self, buffer: &str) -> bool {
buffer.contains("</think>")
}
fn process_content(&self, content: &str) -> String {
content.replace("<think>", "").replace("</think>", "")
}
fn has_analysis_markers(&self, buffer: &str) -> bool {
buffer.contains("<think>")
}
}

View file

@ -0,0 +1,31 @@
use super::ModelHandler;
pub struct GptOss120bHandler {
model_name: String,
}
impl GptOss120bHandler {
pub fn new(model_name: &str) -> Self {
Self {
model_name: model_name.to_string(),
}
}
}
impl ModelHandler for GptOss120bHandler {
fn is_analysis_complete(&self, buffer: &str) -> bool {
// GPT-120B uses explicit end marker
buffer.contains("**end**")
}
fn process_content(&self, content: &str) -> String {
// Remove both start and end markers from final output
content.replace("**start**", "")
.replace("**end**", "")
}
fn has_analysis_markers(&self, buffer: &str) -> bool {
// GPT-120B uses explicit start marker
buffer.contains("**start**")
}
}

View file

@ -0,0 +1,21 @@
use super::ModelHandler;
pub struct GptOss20bHandler;
impl ModelHandler for GptOss20bHandler {
fn is_analysis_complete(&self, buffer: &str) -> bool {
buffer.ends_with("final")
}
fn process_content(&self, content: &str) -> String {
if let Some(pos) = content.find("final") {
content[..pos].to_string()
} else {
content.to_string()
}
}
fn has_analysis_markers(&self, buffer: &str) -> bool {
buffer.contains("**")
}
}

33
src/llm_models/mod.rs Normal file
View file

@ -0,0 +1,33 @@
//! Module for handling model-specific behavior and token processing
pub mod gpt_oss_20b;
pub mod deepseek_r3;
pub mod gpt_oss_120b;
/// Trait for model-specific token processing
pub trait ModelHandler: Send + Sync {
/// Check if the analysis buffer indicates completion
fn is_analysis_complete(&self, buffer: &str) -> bool;
/// Process the content, removing any model-specific tokens
fn process_content(&self, content: &str) -> String;
/// Check if the buffer contains analysis start markers
fn has_analysis_markers(&self, buffer: &str) -> bool;
}
/// Get the appropriate handler based on model path from bot configuration
pub fn get_handler(model_path: &str) -> Box<dyn ModelHandler> {
let path = model_path.to_lowercase();
if path.contains("deepseek") {
Box::new(deepseek_r3::DeepseekR3Handler)
} else if path.contains("120b") {
Box::new(gpt_oss_120b::GptOss120bHandler::new("default"))
} else if path.contains("gpt-oss") || path.contains("gpt") {
Box::new(gpt_oss_20b::GptOss20bHandler)
} else {
Box::new(gpt_oss_20b::GptOss20bHandler)
}
}
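
For context, here is a hypothetical, condensed caller of `get_handler` mirroring the streaming loop in `bot.rs` above; the model identifier and the simulated chunk list are invented for the example (in the orchestrator the model comes from `ConfigManager` under `llm-model` and the chunks from `stream_rx`):

```rust
// Hypothetical usage sketch; not part of the commit.
use crate::llm_models::get_handler;

fn demo_stream_filtering() -> String {
    // Invented model id; "deepseek" routes to DeepseekR3Handler.
    let handler = get_handler("deepseek-r3-q4.gguf");

    // Simulated chunks standing in for `stream_rx.recv()`.
    let chunks = ["<think>internal reasoning", "</think>", "Hello", " world"];
    let mut analysis_buffer = String::new();
    let mut in_analysis = false;
    let mut full_response = String::new();

    for chunk in chunks {
        analysis_buffer.push_str(chunk);

        // Enter analysis mode when the model-specific start marker appears.
        if handler.has_analysis_markers(&analysis_buffer) && !in_analysis {
            in_analysis = true;
        }
        // Leave analysis mode (dropping buffered "thinking" tokens) once the
        // model-specific end marker is seen.
        if in_analysis && handler.is_analysis_complete(&analysis_buffer) {
            in_analysis = false;
            analysis_buffer.clear();
            continue;
        }
        // Only user-visible chunks reach the response.
        if !in_analysis {
            full_response.push_str(&handler.process_content(chunk));
        }
    }

    full_response // "Hello world"
}
```

The same loop works unchanged for the GPT-OSS handlers, since only marker detection and token cleanup live behind the trait.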

View file

@ -25,7 +25,7 @@ mod ui;
mod file;
mod kb;
mod llm;
mod llm_legacy;
mod llm_models;
mod meet;
mod org;
mod package_manager;
@ -49,7 +49,7 @@ use crate::email::{
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
};
use crate::file::{init_drive, upload_file};
use crate::llm_legacy::llm_local::{
use crate::llm::local::{
chat_completions_local, embeddings_local, ensure_llama_servers_running,
};
use crate::meet::{voice_start, voice_stop};
@ -124,8 +124,8 @@ async fn main() -> std::io::Result<()> {
&std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
) {
Ok(mut conn) => AppConfig::from_database(&mut conn),
Err(_) => AppConfig::from_env(),
Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
}
} else {
match bootstrap.bootstrap().await {
@ -135,13 +135,13 @@ async fn main() -> std::io::Result<()> {
}
Err(e) => {
log::error!("Bootstrap failed: {}", e);
match diesel::Connection::establish(
&std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
) {
Ok(mut conn) => AppConfig::from_database(&mut conn),
Err(_) => AppConfig::from_env(),
}
match diesel::Connection::establish(
&std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
) {
Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
}
}
}
};
@ -158,7 +158,7 @@ async fn main() -> std::io::Result<()> {
// Refresh configuration from environment to ensure latest DATABASE_URL and credentials
dotenv().ok();
let refreshed_cfg = AppConfig::from_env();
let refreshed_cfg = AppConfig::from_env().expect("Failed to load config from env");
let config = std::sync::Arc::new(refreshed_cfg.clone());
let db_pool = match diesel::Connection::establish(&refreshed_cfg.database_url()) {
Ok(conn) => Arc::new(Mutex::new(conn)),

View file

@ -411,6 +411,9 @@ pub mod schema {
bot_id -> Uuid,
config_key -> Text,
config_value -> Text,
is_encrypted -> Bool,
config_type -> Text,
created_at -> Timestamptz,
updated_at -> Timestamptz,

View file

@ -34,6 +34,19 @@ pub fn extract_zip_recursive(
if !parent.exists() {
std::fs::create_dir_all(&parent)?;
}
use crate::llm::LLMProvider;
use std::sync::Arc;
use serde_json::Value;
/// Unified chat utility to interact with any LLM provider
pub async fn chat_with_llm(
provider: Arc<dyn LLMProvider>,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
provider.generate(prompt, config).await
}
}
let mut outfile = File::create(&outpath)?;
std::io::copy(&mut file, &mut outfile)?;