feat: enforce config load errors and add dynamic LLM model handling
- Updated `BootstrapManager` to use `AppConfig::from_env().expect(...)` and `AppConfig::from_database(...).expect(...)`, making configuration load failures explicit rather than silently ignored.
- Refactored error propagation in the bootstrap flow to use `?` where appropriate, improving the reliability of configuration loading.
- Added an import of `llm_models` in the `bot` module and introduced `ConfigManager` usage to fetch the LLM model identifier at runtime.
- Integrated dynamic LLM model handler selection via `llm_models::get_handler(&model)`.
- Replaced static environment-variable retrieval for embedding configuration with runtime lookups through `ConfigManager`.
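Below is a minimal sketch of the call pattern these changes introduce, assuming the signatures visible in the diff (`AppConfig::from_env() -> Result<AppConfig, anyhow::Error>`, `ConfigManager::get_config(&Uuid, &str, Option<&str>)`, `llm_models::get_handler(&str)`); it is illustrative only, not the exact project code:

```rust
// Illustrative sketch only: `AppConfig`, `AppState`, `ConfigManager`, and
// `llm_models` are project types; their signatures are assumed from the diff below.
use std::sync::Arc;
use uuid::Uuid;

fn load_config() -> anyhow::Result<AppConfig> {
    // Configuration load failures now surface explicitly instead of being ignored.
    let config = AppConfig::from_env()?; // or .expect("Failed to load config from env")
    Ok(config)
}

fn resolve_llm_handler(state: &AppState, bot_id: &str) {
    // Resolve the model identifier from the bot_configuration table at runtime...
    let config_manager = ConfigManager::new(Arc::clone(&state.conn));
    let model = config_manager
        .get_config(&Uuid::parse_str(bot_id).unwrap_or_default(), "llm-model", None)
        .unwrap_or_default();
    // ...then dispatch to the model-specific handler used for analysis-marker parsing.
    let _handler = llm_models::get_handler(&model);
}
```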
parent e7fa2f7bdb
commit fea1574518
18 changed files with 766 additions and 1181 deletions
|
|
@@ -43,7 +43,7 @@ impl BootstrapManager {
|
|||
"Initializing BootstrapManager with mode {:?} and tenant {:?}",
|
||||
install_mode, tenant
|
||||
);
|
||||
let config = AppConfig::from_env();
|
||||
let config = AppConfig::from_env().expect("Failed to load config from env");
|
||||
let s3_client = futures::executor::block_on(Self::create_s3_operator(&config));
|
||||
Self {
|
||||
install_mode,
|
||||
|
|
@@ -179,11 +179,11 @@ impl BootstrapManager {
|
|||
if let Err(e) = self.apply_migrations(&mut conn) {
|
||||
log::warn!("Failed to apply migrations: {}", e);
|
||||
}
|
||||
return Ok(AppConfig::from_database(&mut conn));
|
||||
return Ok(AppConfig::from_database(&mut conn).expect("Failed to load config from DB"));
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Failed to connect to database: {}", e);
|
||||
return Ok(AppConfig::from_env());
|
||||
return Ok(AppConfig::from_env()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -191,7 +191,7 @@ impl BootstrapManager {
|
|||
|
||||
let pm = PackageManager::new(self.install_mode.clone(), self.tenant.clone())?;
|
||||
let required_components = vec!["tables", "drive", "cache", "llm"];
|
||||
let mut config = AppConfig::from_env();
|
||||
let mut config = AppConfig::from_env().expect("Failed to load config from env");
|
||||
|
||||
for component in required_components {
|
||||
if !pm.is_installed(component) {
|
||||
|
|
@@ -282,7 +282,7 @@ impl BootstrapManager {
|
|||
);
|
||||
}
|
||||
|
||||
config = AppConfig::from_database(&mut conn);
|
||||
config = AppConfig::from_database(&mut conn).expect("Failed to load config from DB");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -513,7 +513,7 @@ impl BootstrapManager {
|
|||
|
||||
// Create fresh connection for final config load
|
||||
let mut final_conn = establish_pg_connection()?;
|
||||
let config = AppConfig::from_database(&mut final_conn);
|
||||
let config = AppConfig::from_database(&mut final_conn)?;
|
||||
info!("Successfully loaded config from CSV with LLM settings");
|
||||
Ok(config)
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -4,6 +4,7 @@ use crate::context::langcache::get_langcache_client;
|
|||
use crate::drive_monitor::DriveMonitor;
|
||||
use crate::kb::embeddings::generate_embeddings;
|
||||
use crate::kb::qdrant_client::{ensure_collection_exists, get_qdrant_client, QdrantPoint};
|
||||
use crate::llm_models;
|
||||
use crate::shared::models::{BotResponse, Suggestion, UserMessage, UserSession};
|
||||
use crate::shared::state::AppState;
|
||||
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
||||
|
|
@@ -680,7 +681,6 @@ impl BotOrchestrator {
|
|||
e
|
||||
})?;
|
||||
|
||||
|
||||
let session = {
|
||||
let mut sm = self.state.session_manager.lock().await;
|
||||
let session_id = Uuid::parse_str(&message.session_id).map_err(|e| {
|
||||
|
|
@@ -795,6 +795,17 @@ impl BotOrchestrator {
|
|||
let mut chunk_count = 0;
|
||||
let mut first_word_received = false;
|
||||
|
||||
let config_manager = ConfigManager::new(Arc::clone(&self.state.conn));
|
||||
let model = config_manager
|
||||
.get_config(
|
||||
&Uuid::parse_str(&message.bot_id).unwrap_or_default(),
|
||||
"llm-model",
|
||||
None,
|
||||
)
|
||||
.unwrap_or_default();
|
||||
|
||||
let handler = llm_models::get_handler(&model);
|
||||
|
||||
while let Some(chunk) = stream_rx.recv().await {
|
||||
chunk_count += 1;
|
||||
|
||||
|
|
@@ -804,18 +815,14 @@ impl BotOrchestrator {
|
|||
|
||||
analysis_buffer.push_str(&chunk);
|
||||
|
||||
if (analysis_buffer.contains("**") || analysis_buffer.contains("<think>"))
|
||||
&& !in_analysis
|
||||
{
|
||||
// Check for analysis markers
|
||||
if handler.has_analysis_markers(&analysis_buffer) && !in_analysis {
|
||||
in_analysis = true;
|
||||
}
|
||||
|
||||
if in_analysis {
|
||||
if analysis_buffer.ends_with("final") || analysis_buffer.ends_with("</think>") {
|
||||
info!(
|
||||
"Analysis section completed, buffer length: {}",
|
||||
analysis_buffer.len()
|
||||
);
|
||||
// Check if analysis is complete
|
||||
if in_analysis && handler.is_analysis_complete(&analysis_buffer) {
|
||||
info!("Analysis section completed");
|
||||
in_analysis = false;
|
||||
analysis_buffer.clear();
|
||||
|
||||
|
|
@@ -828,19 +835,15 @@ impl BotOrchestrator {
|
|||
&message.session_id,
|
||||
&message.channel,
|
||||
"thinking_end",
|
||||
serde_json::json!({
|
||||
"user_id": message.user_id.clone()
|
||||
}),
|
||||
serde_json::json!({"user_id": message.user_id.clone()}),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
analysis_buffer.clear();
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if !in_analysis {
|
||||
full_response.push_str(&chunk);
|
||||
|
||||
let partial = BotResponse {
|
||||
|
|
@@ -863,6 +866,7 @@ impl BotOrchestrator {
|
|||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Stream processing completed, {} chunks processed",
|
||||
|
|
@@ -1369,7 +1373,10 @@ async fn websocket_handler(
|
|||
|
||||
// Cancel any ongoing LLM jobs for this session
|
||||
if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
|
||||
warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
|
||||
warn!(
|
||||
"Failed to cancel LLM job for session {}: {}",
|
||||
session_id_clone2, e
|
||||
);
|
||||
}
|
||||
|
||||
info!("WebSocket fully closed for session {}", session_id_clone2);
|
||||
|
|
|
|||
|
|
@@ -1,14 +1,19 @@
|
|||
use diesel::prelude::*;
|
||||
use diesel::pg::PgConnection;
|
||||
use crate::shared::models::schema::bot_configuration;
|
||||
use diesel::sql_types::Text;
|
||||
use uuid::Uuid;
|
||||
use diesel::pg::Pg;
|
||||
use log::{info, trace, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
// removed unused serde import
|
||||
use std::collections::HashMap;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use crate::shared::utils::establish_pg_connection;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LLMConfig {
|
||||
pub url: String,
|
||||
pub key: String,
|
||||
|
|
@@ -20,7 +25,6 @@ pub struct AppConfig {
|
|||
pub drive: DriveConfig,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub database_custom: DatabaseConfig,
|
||||
pub email: EmailConfig,
|
||||
pub llm: LLMConfig,
|
||||
pub embedding: LLMConfig,
|
||||
|
|
@@ -61,17 +65,21 @@ pub struct EmailConfig {
|
|||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, QueryableByName)]
|
||||
#[derive(Debug, Clone, Queryable, Selectable)]
|
||||
#[diesel(table_name = bot_configuration)]
|
||||
#[diesel(check_for_backend(Pg))]
|
||||
pub struct ServerConfigRow {
|
||||
#[diesel(sql_type = Text)]
|
||||
pub id: String,
|
||||
#[diesel(sql_type = DieselUuid)]
|
||||
pub id: Uuid,
|
||||
#[diesel(sql_type = DieselUuid)]
|
||||
pub bot_id: Uuid,
|
||||
#[diesel(sql_type = Text)]
|
||||
pub config_key: String,
|
||||
#[diesel(sql_type = Text)]
|
||||
pub config_value: String,
|
||||
#[diesel(sql_type = Text)]
|
||||
pub config_type: String,
|
||||
#[diesel(sql_type = diesel::sql_types::Bool)]
|
||||
#[diesel(sql_type = Bool)]
|
||||
pub is_encrypted: bool,
|
||||
}
|
||||
|
||||
|
|
@@ -87,17 +95,6 @@ impl AppConfig {
|
|||
)
|
||||
}
|
||||
|
||||
pub fn database_custom_url(&self) -> String {
|
||||
format!(
|
||||
"postgres://{}:{}@{}:{}/{}",
|
||||
self.database_custom.username,
|
||||
self.database_custom.password,
|
||||
self.database_custom.server,
|
||||
self.database_custom.port,
|
||||
self.database_custom.database
|
||||
)
|
||||
}
|
||||
|
||||
pub fn component_path(&self, component: &str) -> PathBuf {
|
||||
self.stack_path.join(component)
|
||||
}
|
||||
|
|
@@ -117,17 +114,28 @@ impl AppConfig {
|
|||
pub fn log_path(&self, component: &str) -> PathBuf {
|
||||
self.stack_path.join("logs").join(component)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_database(_conn: &mut PgConnection) -> Self {
|
||||
impl AppConfig {
|
||||
pub fn from_database(conn: &mut PgConnection) -> Result<Self, diesel::result::Error> {
|
||||
info!("Loading configuration from database");
|
||||
// Config is now loaded from bot_configuration table via drive_monitor
|
||||
let config_map: HashMap<String, ServerConfigRow> = HashMap::new();
|
||||
|
||||
let get_str = |key: &str, default: &str| -> String {
|
||||
config_map
|
||||
.get(key)
|
||||
.map(|v: &ServerConfigRow| v.config_value.clone())
|
||||
.unwrap_or_else(|| default.to_string())
|
||||
use crate::shared::models::schema::bot_configuration::dsl::*;
|
||||
use diesel::prelude::*;
|
||||
|
||||
let config_map: HashMap<String, ServerConfigRow> = bot_configuration
|
||||
.select(ServerConfigRow::as_select()).load::<ServerConfigRow>(conn)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|row| (row.config_key.clone(), row))
|
||||
.collect();
|
||||
|
||||
let mut get_str = |key: &str, default: &str| -> String {
|
||||
bot_configuration
|
||||
.filter(config_key.eq(key))
|
||||
.select(config_value)
|
||||
.first::<String>(conn)
|
||||
.unwrap_or_else(|_| default.to_string())
|
||||
};
|
||||
|
||||
let get_u32 = |key: &str, default: u32| -> u32 {
|
||||
|
|
@@ -168,22 +176,8 @@ impl AppConfig {
|
|||
.unwrap_or_else(|_| get_str("TABLES_DATABASE", "botserver")),
|
||||
};
|
||||
|
||||
let database_custom = DatabaseConfig {
|
||||
username: std::env::var("CUSTOM_USERNAME")
|
||||
.unwrap_or_else(|_| get_str("CUSTOM_USERNAME", "gbuser")),
|
||||
password: std::env::var("CUSTOM_PASSWORD")
|
||||
.unwrap_or_else(|_| get_str("CUSTOM_PASSWORD", "")),
|
||||
server: std::env::var("CUSTOM_SERVER")
|
||||
.unwrap_or_else(|_| get_str("CUSTOM_SERVER", "localhost")),
|
||||
port: std::env::var("CUSTOM_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or_else(|| get_u32("CUSTOM_PORT", 5432)),
|
||||
database: std::env::var("CUSTOM_DATABASE")
|
||||
.unwrap_or_else(|_| get_str("CUSTOM_DATABASE", "gbuser")),
|
||||
};
|
||||
|
||||
let minio = DriveConfig {
|
||||
let drive = DriveConfig {
|
||||
server: {
|
||||
let server = get_str("DRIVE_SERVER", "http://localhost:9000");
|
||||
if !server.starts_with("http://") && !server.starts_with("https://") {
|
||||
|
|
@@ -206,36 +200,48 @@ impl AppConfig {
|
|||
};
|
||||
|
||||
// Write drive config to .env file
|
||||
if let Err(e) = write_drive_config_to_env(&minio) {
|
||||
if let Err(e) = write_drive_config_to_env(&drive) {
|
||||
warn!("Failed to write drive config to .env: {}", e);
|
||||
}
|
||||
|
||||
AppConfig {
|
||||
drive: minio,
|
||||
Ok(AppConfig {
|
||||
drive,
|
||||
server: ServerConfig {
|
||||
host: get_str("SERVER_HOST", "127.0.0.1"),
|
||||
port: get_u16("SERVER_PORT", 8080),
|
||||
},
|
||||
database,
|
||||
database_custom,
|
||||
email,
|
||||
llm: LLMConfig {
|
||||
url: get_str("LLM_URL", "http://localhost:8081"),
|
||||
key: get_str("LLM_KEY", ""),
|
||||
model: get_str("LLM_MODEL", "gpt-4"),
|
||||
llm: {
|
||||
// Use a fresh connection for ConfigManager to avoid cloning the mutable reference
|
||||
let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
|
||||
let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn)));
|
||||
LLMConfig {
|
||||
url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?,
|
||||
key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?,
|
||||
model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?,
|
||||
}
|
||||
},
|
||||
embedding: LLMConfig {
|
||||
url: get_str("EMBEDDING_URL", "http://localhost:8082"),
|
||||
key: get_str("EMBEDDING_KEY", ""),
|
||||
model: get_str("EMBEDDING_MODEL", "text-embedding-ada-002"),
|
||||
embedding: {
|
||||
let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
|
||||
let config = ConfigManager::new(Arc::new(Mutex::new(fresh_conn)));
|
||||
LLMConfig {
|
||||
url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?,
|
||||
key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?,
|
||||
model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?,
|
||||
}
|
||||
},
|
||||
site_path: {
|
||||
let fresh_conn = establish_pg_connection().map_err(|e| diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::UnableToSendCommand, Box::new(e.to_string())))?;
|
||||
ConfigManager::new(Arc::new(Mutex::new(fresh_conn)))
|
||||
.get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?.to_string()
|
||||
},
|
||||
site_path: get_str("SITES_ROOT", "./botserver-stack/sites"),
|
||||
stack_path,
|
||||
db_conn: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_env() -> Self {
|
||||
pub fn from_env() -> Result<Self, anyhow::Error> {
|
||||
info!("Loading configuration from environment variables");
|
||||
|
||||
let stack_path =
|
||||
|
|
@@ -254,16 +260,6 @@ impl AppConfig {
|
|||
database: db_name,
|
||||
};
|
||||
|
||||
let database_custom = DatabaseConfig {
|
||||
username: std::env::var("CUSTOM_USERNAME").unwrap_or_else(|_| "gbuser".to_string()),
|
||||
password: std::env::var("CUSTOM_PASSWORD").unwrap_or_else(|_| "".to_string()),
|
||||
server: std::env::var("CUSTOM_SERVER").unwrap_or_else(|_| "localhost".to_string()),
|
||||
port: std::env::var("CUSTOM_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(5432),
|
||||
database: std::env::var("CUSTOM_DATABASE").unwrap_or_else(|_| "botserver".to_string()),
|
||||
};
|
||||
|
||||
let minio = DriveConfig {
|
||||
server: std::env::var("DRIVE_SERVER")
|
||||
|
|
@@ -288,33 +284,43 @@ impl AppConfig {
|
|||
password: std::env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()),
|
||||
};
|
||||
|
||||
AppConfig {
|
||||
Ok(AppConfig {
|
||||
drive: minio,
|
||||
server: ServerConfig {
|
||||
host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
|
||||
host: std::env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.1".to_string()),
|
||||
port: std::env::var("SERVER_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(8080),
|
||||
},
|
||||
database,
|
||||
database_custom,
|
||||
email,
|
||||
llm: LLMConfig {
|
||||
url: std::env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()),
|
||||
key: std::env::var("LLM_KEY").unwrap_or_else(|_| "".to_string()),
|
||||
model: std::env::var("LLM_MODEL").unwrap_or_else(|_| "gpt-4".to_string()),
|
||||
llm: {
|
||||
let conn = PgConnection::establish(&database_url)?;
|
||||
let config = ConfigManager::new(Arc::new(Mutex::new(conn)));
|
||||
LLMConfig {
|
||||
url: config.get_config(&Uuid::nil(), "LLM_URL", Some("http://localhost:8081"))?,
|
||||
key: config.get_config(&Uuid::nil(), "LLM_KEY", Some(""))?,
|
||||
model: config.get_config(&Uuid::nil(), "LLM_MODEL", Some("gpt-4"))?,
|
||||
}
|
||||
},
|
||||
embedding: LLMConfig {
|
||||
url: std::env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()),
|
||||
key: std::env::var("EMBEDDING_KEY").unwrap_or_else(|_| "".to_string()),
|
||||
model: std::env::var("EMBEDDING_MODEL").unwrap_or_else(|_| "text-embedding-ada-002".to_string()),
|
||||
embedding: {
|
||||
let conn = PgConnection::establish(&database_url)?;
|
||||
let config = ConfigManager::new(Arc::new(Mutex::new(conn)));
|
||||
LLMConfig {
|
||||
url: config.get_config(&Uuid::nil(), "EMBEDDING_URL", Some("http://localhost:8082"))?,
|
||||
key: config.get_config(&Uuid::nil(), "EMBEDDING_KEY", Some(""))?,
|
||||
model: config.get_config(&Uuid::nil(), "EMBEDDING_MODEL", Some("text-embedding-ada-002"))?,
|
||||
}
|
||||
},
|
||||
site_path: {
|
||||
let conn = PgConnection::establish(&database_url)?;
|
||||
ConfigManager::new(Arc::new(Mutex::new(conn)))
|
||||
.get_config(&Uuid::nil(), "SITES_ROOT", Some("./botserver-stack/sites"))?
|
||||
},
|
||||
site_path: std::env::var("SITES_ROOT")
|
||||
.unwrap_or_else(|_| "./botserver-stack/sites".to_string()),
|
||||
stack_path: PathBuf::from(stack_path),
|
||||
db_conn: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_config(
|
||||
|
|
|
|||
|
|
@@ -62,10 +62,7 @@ impl DriveMonitor {
|
|||
|
||||
self.check_gbdialog_changes(client).await?;
|
||||
self.check_gbkb_changes(client).await?;
|
||||
|
||||
if let Err(e) = self.check_gbot(client).await {
|
||||
error!("Error checking default bot config: {}", e);
|
||||
}
|
||||
self.check_gbot(client).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@@ -257,7 +254,7 @@ impl DriveMonitor {
|
|||
continue;
|
||||
}
|
||||
|
||||
debug!("Checking config file at path: {}", path);
|
||||
trace!("Checking config file at path: {}", path);
|
||||
match client
|
||||
.head_object()
|
||||
.bucket(&self.bucket_name)
|
||||
|
|
@@ -276,7 +273,7 @@ impl DriveMonitor {
|
|||
.key(&path)
|
||||
.send()
|
||||
.await?;
|
||||
debug!(
|
||||
trace!(
|
||||
"GetObject successful for {}, content length: {}",
|
||||
path,
|
||||
response.content_length().unwrap_or(0)
|
||||
|
|
@@ -295,7 +292,7 @@ impl DriveMonitor {
|
|||
.collect();
|
||||
|
||||
if !llm_lines.is_empty() {
|
||||
use crate::llm_legacy::llm_local::ensure_llama_servers_running;
|
||||
use crate::llm::local::ensure_llama_servers_running;
|
||||
let mut restart_needed = false;
|
||||
|
||||
for line in llm_lines {
|
||||
|
|
@@ -338,7 +335,7 @@ impl DriveMonitor {
|
|||
}
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Config file {} not found or inaccessible: {}", path, e);
|
||||
error!("Config file {} not found or inaccessible: {}", path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -1,62 +1,73 @@
|
|||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use futures_util::StreamExt;
|
||||
|
||||
use crate::tools::ToolManager;
|
||||
use super::LLMProvider;
|
||||
|
||||
pub struct LlamaClient {
|
||||
pub struct AnthropicClient {
|
||||
client: Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl LlamaClient {
|
||||
pub fn new(base_url: String) -> Self {
|
||||
impl AnthropicClient {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url,
|
||||
api_key,
|
||||
base_url: "https://api.anthropic.com/v1".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for LlamaClient {
|
||||
impl LLMProvider for AnthropicClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
_config: &Value,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = self.client
|
||||
.post(&format!("{}/completion", self.base_url))
|
||||
let response = self
|
||||
.client
|
||||
.post(&format!("{}/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.json(&serde_json::json!({
|
||||
"prompt": prompt,
|
||||
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
|
||||
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let result: Value = response.json().await?;
|
||||
Ok(result["content"]
|
||||
let content = result["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string())
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
_config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = self.client
|
||||
.post(&format!("{}/completion", self.base_url))
|
||||
let response = self
|
||||
.client
|
||||
.post(&format!("{}/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.json(&serde_json::json!({
|
||||
"prompt": prompt,
|
||||
"n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
|
||||
"temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
|
|
@@ -66,14 +77,18 @@ impl LLMProvider for LlamaClient {
|
|||
let mut buffer = String::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
let chunk_bytes = chunk?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk_bytes);
|
||||
|
||||
for line in chunk_str.lines() {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line) {
|
||||
if let Some(content) = data["content"].as_str() {
|
||||
buffer.push_str(content);
|
||||
let _ = tx.send(content.to_string()).await;
|
||||
if line.starts_with("data: ") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if data["type"] == "content_block_delta" {
|
||||
if let Some(text) = data["delta"]["text"].as_str() {
|
||||
buffer.push_str(text);
|
||||
let _ = tx.send(text.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -85,7 +100,7 @@ impl LLMProvider for LlamaClient {
|
|||
async fn generate_with_tools(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
_config: &Value,
|
||||
available_tools: &[String],
|
||||
_tool_manager: Arc<ToolManager>,
|
||||
_session_id: &str,
|
||||
|
|
@@ -98,26 +113,14 @@ impl LLMProvider for LlamaClient {
|
|||
};
|
||||
|
||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||
self.generate(&enhanced_prompt, config).await
|
||||
self.generate(&enhanced_prompt, &Value::Null).await
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
session_id: &str,
|
||||
_session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// llama.cpp cancellation endpoint
|
||||
let response = self.client
|
||||
.post(&format!("{}/cancel", self.base_url))
|
||||
.json(&serde_json::json!({
|
||||
"session_id": session_id
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
// Anthropic doesn't support job cancellation
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Failed to cancel job for session {}", session_id).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
103  src/llm/azure.rs  Normal file
|
|
@@ -0,0 +1,103 @@
|
|||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use crate::tools::ToolManager;
|
||||
use super::LLMProvider;
|
||||
|
||||
pub struct AzureOpenAIClient {
|
||||
endpoint: String,
|
||||
api_key: String,
|
||||
api_version: String,
|
||||
deployment: String,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl AzureOpenAIClient {
|
||||
pub fn new(endpoint: String, api_key: String, api_version: String, deployment: String) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
api_key,
|
||||
api_version,
|
||||
deployment,
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for AzureOpenAIClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_config: &Value,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.endpoint, self.deployment, self.api_version
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0
|
||||
});
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let result: Value = response.json().await?;
|
||||
if let Some(choice) = result["choices"].get(0) {
|
||||
Ok(choice["message"]["content"].as_str().unwrap_or("").to_string())
|
||||
} else {
|
||||
Err("No response from Azure OpenAI".into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_config: &Value,
|
||||
tx: tokio::sync::mpsc::Sender<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let content = self.generate(prompt, _config).await?;
|
||||
let _ = tx.send(content).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn generate_with_tools(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_config: &Value,
|
||||
available_tools: &[String],
|
||||
_tool_manager: Arc<ToolManager>,
|
||||
_session_id: &str,
|
||||
_user_id: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tools_info = if available_tools.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("\n\nAvailable tools: {}.", available_tools.join(", "))
|
||||
};
|
||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||
self.generate(&enhanced_prompt, _config).await
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
_session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
269  src/llm/local.rs  Normal file
|
|
@@ -0,0 +1,269 @@
|
|||
use crate::shared::state::AppState;
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
use crate::config::ConfigManager;
|
||||
use diesel::prelude::*;
|
||||
use std::sync::Arc;
|
||||
use log::{info, error};
|
||||
use tokio;
|
||||
use reqwest;
|
||||
|
||||
use actix_web::{post, web, HttpResponse, Result};
|
||||
|
||||
#[post("/api/chat/completions")]
|
||||
pub async fn chat_completions_local(
|
||||
_data: web::Data<AppState>,
|
||||
_payload: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
// Placeholder implementation
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "chat_completions_local not implemented" })))
|
||||
}
|
||||
|
||||
#[post("/api/embeddings")]
|
||||
pub async fn embeddings_local(
|
||||
_data: web::Data<AppState>,
|
||||
_payload: web::Json<serde_json::Value>,
|
||||
) -> Result<HttpResponse> {
|
||||
// Placeholder implementation
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "embeddings_local not implemented" })))
|
||||
}
|
||||
|
||||
pub async fn ensure_llama_servers_running(
|
||||
app_state: &Arc<AppState>
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let conn = app_state.conn.clone();
|
||||
let config_manager = ConfigManager::new(conn.clone());
|
||||
|
||||
let default_bot_id = {
|
||||
let mut conn = conn.lock().unwrap();
|
||||
bots.filter(name.eq("default"))
|
||||
.select(id)
|
||||
.first::<uuid::Uuid>(&mut *conn)
|
||||
.unwrap_or_else(|_| uuid::Uuid::nil())
|
||||
};
|
||||
|
||||
let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?;
|
||||
let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?;
|
||||
let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?;
|
||||
let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?;
|
||||
let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?;
|
||||
|
||||
info!("Starting LLM servers...");
|
||||
info!("Configuration:");
|
||||
info!(" LLM URL: {}", llm_url);
|
||||
info!(" Embedding URL: {}", embedding_url);
|
||||
info!(" LLM Model: {}", llm_model);
|
||||
info!(" Embedding Model: {}", embedding_model);
|
||||
info!(" LLM Server Path: {}", llm_server_path);
|
||||
|
||||
// Restart any existing llama-server processes
|
||||
info!("Restarting any existing llama-server processes...");
|
||||
if let Err(e) = tokio::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg("pkill -f llama-server || true")
|
||||
.spawn()
|
||||
{
|
||||
error!("Failed to execute pkill for llama-server: {}", e);
|
||||
} else {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||
info!("Existing llama-server processes terminated (if any)");
|
||||
}
|
||||
|
||||
// Check if servers are already running
|
||||
let llm_running = is_server_running(&llm_url).await;
|
||||
let embedding_running = is_server_running(&embedding_url).await;
|
||||
|
||||
if llm_running && embedding_running {
|
||||
info!("Both LLM and Embedding servers are already running");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Start servers that aren't running
|
||||
let mut tasks = vec![];
|
||||
|
||||
if !llm_running && !llm_model.is_empty() {
|
||||
info!("Starting LLM server...");
|
||||
tasks.push(tokio::spawn(start_llm_server(
|
||||
Arc::clone(app_state),
|
||||
llm_server_path.clone(),
|
||||
llm_model.clone(),
|
||||
llm_url.clone(),
|
||||
)));
|
||||
} else if llm_model.is_empty() {
|
||||
info!("LLM_MODEL not set, skipping LLM server");
|
||||
}
|
||||
|
||||
if !embedding_running && !embedding_model.is_empty() {
|
||||
info!("Starting Embedding server...");
|
||||
tasks.push(tokio::spawn(start_embedding_server(
|
||||
llm_server_path.clone(),
|
||||
embedding_model.clone(),
|
||||
embedding_url.clone(),
|
||||
)));
|
||||
} else if embedding_model.is_empty() {
|
||||
info!("EMBEDDING_MODEL not set, skipping Embedding server");
|
||||
}
|
||||
|
||||
// Wait for all server startup tasks
|
||||
for task in tasks {
|
||||
task.await??;
|
||||
}
|
||||
|
||||
// Wait for servers to be ready with verbose logging
|
||||
info!("Waiting for servers to become ready...");
|
||||
|
||||
let mut llm_ready = llm_running || llm_model.is_empty();
|
||||
let mut embedding_ready = embedding_running || embedding_model.is_empty();
|
||||
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 60; // 2 minutes total
|
||||
|
||||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||
|
||||
info!("Checking server health (attempt {}/{})...", attempts + 1, max_attempts);
|
||||
|
||||
if !llm_ready && !llm_model.is_empty() {
|
||||
if is_server_running(&llm_url).await {
|
||||
info!("LLM server ready at {}", llm_url);
|
||||
llm_ready = true;
|
||||
} else {
|
||||
info!("LLM server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
if !embedding_ready && !embedding_model.is_empty() {
|
||||
if is_server_running(&embedding_url).await {
|
||||
info!("Embedding server ready at {}", embedding_url);
|
||||
embedding_ready = true;
|
||||
} else {
|
||||
info!("Embedding server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
|
||||
if attempts % 10 == 0 {
|
||||
info!("Still waiting for servers... (attempt {}/{})", attempts, max_attempts);
|
||||
}
|
||||
}
|
||||
|
||||
if llm_ready && embedding_ready {
|
||||
info!("All llama.cpp servers are ready and responding!");
|
||||
Ok(())
|
||||
} else {
|
||||
let mut error_msg = "Servers failed to start within timeout:".to_string();
|
||||
if !llm_ready && !llm_model.is_empty() {
|
||||
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
|
||||
}
|
||||
if !embedding_ready && !embedding_model.is_empty() {
|
||||
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
|
||||
}
|
||||
Err(error_msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_server_running(url: &str) -> bool {
|
||||
let client = reqwest::Client::new();
|
||||
match client.get(&format!("{}/health", url)).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_llm_server(
|
||||
app_state: Arc<AppState>,
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8081");
|
||||
|
||||
std::env::set_var("OMP_NUM_THREADS", "20");
|
||||
std::env::set_var("OMP_PLACES", "cores");
|
||||
std::env::set_var("OMP_PROC_BIND", "close");
|
||||
|
||||
let conn = app_state.conn.clone();
|
||||
let config_manager = ConfigManager::new(conn.clone());
|
||||
|
||||
let default_bot_id = {
|
||||
let mut conn = conn.lock().unwrap();
|
||||
bots.filter(name.eq("default"))
|
||||
.select(id)
|
||||
.first::<uuid::Uuid>(&mut *conn)
|
||||
.unwrap_or_else(|_| uuid::Uuid::nil())
|
||||
};
|
||||
|
||||
let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
|
||||
let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string());
|
||||
let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
|
||||
let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
|
||||
let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
|
||||
let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
|
||||
let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
|
||||
|
||||
// Build command arguments dynamically
|
||||
let mut args = format!(
|
||||
"-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}",
|
||||
model_path, port, ctx_size, gpu_layers
|
||||
);
|
||||
|
||||
if n_moe != "0" {
|
||||
args.push_str(&format!(" --n-cpu-moe {}", n_moe));
|
||||
}
|
||||
if parallel != "1" {
|
||||
args.push_str(&format!(" --parallel {}", parallel));
|
||||
}
|
||||
if cont_batching == "true" {
|
||||
args.push_str(" --cont-batching");
|
||||
}
|
||||
if mlock == "true" {
|
||||
args.push_str(" --mlock");
|
||||
}
|
||||
if no_mmap == "true" {
|
||||
args.push_str(" --no-mmap");
|
||||
}
|
||||
|
||||
if cfg!(windows) {
|
||||
let mut cmd = tokio::process::Command::new("cmd");
|
||||
cmd.arg("/C").arg(format!(
|
||||
"cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log",
|
||||
llama_cpp_path, args
|
||||
));
|
||||
cmd.spawn()?;
|
||||
} else {
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &",
|
||||
llama_cpp_path, args
|
||||
));
|
||||
cmd.spawn()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn start_embedding_server(
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8082");
|
||||
|
||||
if cfg!(windows) {
|
||||
let mut cmd = tokio::process::Command::new("cmd");
|
||||
cmd.arg("/c").arg(format!(
|
||||
"cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
cmd.spawn()?;
|
||||
} else {
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
cmd.spawn()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
121  src/llm/mod.rs
|
|
@@ -6,7 +6,9 @@ use tokio::sync::mpsc;
|
|||
|
||||
use crate::tools::ToolManager;
|
||||
|
||||
pub mod llama;
|
||||
pub mod azure;
|
||||
pub mod local;
|
||||
pub mod anthropic;
|
||||
|
||||
#[async_trait]
|
||||
pub trait LLMProvider: Send + Sync {
|
||||
|
|
@@ -161,123 +163,6 @@ impl LLMProvider for OpenAIClient {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct AnthropicClient {
|
||||
client: reqwest::Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
api_key,
|
||||
base_url: "https://api.anthropic.com/v1".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for AnthropicClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_config: &Value,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = self
|
||||
.client
|
||||
.post(&format!("{}/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.json(&serde_json::json!({
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let result: Value = response.json().await?;
|
||||
let content = result["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response = self
|
||||
.client
|
||||
.post(&format!("{}/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.json(&serde_json::json!({
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
|
||||
for line in chunk_str.lines() {
|
||||
if line.starts_with("data: ") {
|
||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
||||
if data["type"] == "content_block_delta" {
|
||||
if let Some(text) = data["delta"]["text"].as_str() {
|
||||
buffer.push_str(text);
|
||||
let _ = tx.send(text.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn generate_with_tools(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_config: &Value,
|
||||
available_tools: &[String],
|
||||
_tool_manager: Arc<ToolManager>,
|
||||
_session_id: &str,
|
||||
_user_id: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tools_info = if available_tools.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
||||
};
|
||||
|
||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||
|
||||
self.generate(&enhanced_prompt, &Value::Null).await
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
_session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Anthropic doesn't support job cancellation
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MockLLMProvider;
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,145 +0,0 @@
|
|||
use dotenvy::dotenv;
|
||||
use log::{error, info};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AzureOpenAIConfig {
|
||||
pub endpoint: String,
|
||||
pub api_key: String,
|
||||
pub api_version: String,
|
||||
pub deployment: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub temperature: f32,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub top_p: f32,
|
||||
pub frequency_penalty: f32,
|
||||
pub presence_penalty: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
pub struct AzureOpenAIClient {
|
||||
config: AzureOpenAIConfig,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl AzureOpenAIClient {
|
||||
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
||||
dotenv().ok();
|
||||
|
||||
let endpoint =
|
||||
std::env::var("AZURE_OPENAI_ENDPOINT").map_err(|_| "AZURE_OPENAI_ENDPOINT not set")?;
|
||||
let api_key =
|
||||
std::env::var("AZURE_OPENAI_API_KEY").map_err(|_| "AZURE_OPENAI_API_KEY not set")?;
|
||||
let api_version = std::env::var("AZURE_OPENAI_API_VERSION")
|
||||
.unwrap_or_else(|_| "2023-12-01-preview".to_string());
|
||||
let deployment =
|
||||
std::env::var("AZURE_OPENAI_DEPLOYMENT").unwrap_or_else(|_| "gpt-35-turbo".to_string());
|
||||
|
||||
let config = AzureOpenAIConfig {
|
||||
endpoint,
|
||||
api_key,
|
||||
api_version,
|
||||
deployment,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
client: Client::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn chat_completions(
|
||||
&self,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: Option<u32>,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.config.endpoint, self.config.deployment, self.config.api_version
|
||||
);
|
||||
|
||||
let request_body = ChatCompletionRequest {
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
top_p: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
};
|
||||
|
||||
info!("Sending request to Azure OpenAI: {}", url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.config.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await?;
|
||||
error!("Azure OpenAI API error: {}", error_text);
|
||||
return Err(format!("Azure OpenAI API error: {}", error_text).into());
|
||||
}
|
||||
|
||||
let completion_response: ChatCompletionResponse = response.json().await?;
|
||||
Ok(completion_response)
|
||||
}
|
||||
|
||||
pub async fn simple_chat(&self, prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are a helpful assistant.".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
let response = self.chat_completions(messages, 0.7, Some(1000)).await?;
|
||||
|
||||
if let Some(choice) = response.choices.first() {
|
||||
Ok(choice.message.content.clone())
|
||||
} else {
|
||||
Err("No response from AI".into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -1,656 +0,0 @@
|
|||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||
use dotenvy::dotenv;
|
||||
use log::{error, info};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
use crate::config::ConfigManager;
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
use crate::shared::state::AppState;
|
||||
use diesel::prelude::*;
|
||||
|
||||
// OpenAI-compatible request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: String,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Choice {
|
||||
message: ChatMessage,
|
||||
finish_reason: String,
|
||||
}
|
||||
|
||||
// Llama.cpp server request/response structures
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LlamaCppRequest {
|
||||
prompt: String,
|
||||
n_predict: Option<i32>,
|
||||
temperature: Option<f32>,
|
||||
top_k: Option<i32>,
|
||||
top_p: Option<f32>,
|
||||
stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct LlamaCppResponse {
|
||||
content: String,
|
||||
stop: bool,
|
||||
generation_settings: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub async fn ensure_llama_servers_running(
|
||||
app_state: &Arc<AppState>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let conn = app_state.conn.clone();
|
||||
let config_manager = ConfigManager::new(conn.clone());
|
||||
|
||||
// Get default bot ID from database
|
||||
let default_bot_id = {
|
||||
let mut conn = conn.lock().unwrap();
|
||||
bots.filter(name.eq("default"))
|
||||
.select(id)
|
||||
.first::<uuid::Uuid>(&mut *conn)
|
||||
.unwrap_or_else(|_| uuid::Uuid::nil())
|
||||
};
|
||||
|
||||
// Get configuration from config using default bot ID
|
||||
let llm_url = config_manager.get_config(&default_bot_id, "llm-url", None)?;
|
||||
let llm_model = config_manager.get_config(&default_bot_id, "llm-model", None)?;
|
||||
|
||||
let embedding_url = config_manager.get_config(&default_bot_id, "embedding-url", None)?;
|
||||
let embedding_model = config_manager.get_config(&default_bot_id, "embedding-model", None)?;
|
||||
|
||||
let llm_server_path = config_manager.get_config(&default_bot_id, "llm-server-path", None)?;
|
||||
|
||||
info!(" Starting LLM servers...");
|
||||
info!(" Configuration:");
|
||||
info!(" LLM URL: {}", llm_url);
|
||||
info!(" Embedding URL: {}", embedding_url);
|
||||
info!(" LLM Model: {}", llm_model);
|
||||
info!(" Embedding Model: {}", embedding_model);
|
||||
info!(" LLM Server Path: {}", llm_server_path);
|
||||
|
||||
|
||||
// Restart any existing llama-server processes before starting new ones
|
||||
info!("🔁 Restarting any existing llama-server processes...");
|
||||
if let Err(e) = tokio::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg("pkill -f llama-server || true")
|
||||
.spawn()
|
||||
{
|
||||
error!("Failed to execute pkill for llama-server: {}", e);
|
||||
} else {
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
info!("✅ Existing llama-server processes terminated (if any)");
|
||||
}
|
||||
|
||||
// Check if servers are already running
|
||||
let llm_running = is_server_running(&llm_url).await;
|
||||
let embedding_running = is_server_running(&embedding_url).await;
|
||||
|
||||
if llm_running && embedding_running {
|
||||
info!("✅ Both LLM and Embedding servers are already running");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
// Start servers that aren't running
|
||||
let mut tasks = vec![];
|
||||
|
||||
if !llm_running && !llm_model.is_empty() {
|
||||
info!("🔄 Starting LLM server...");
|
||||
tasks.push(tokio::spawn(start_llm_server(
|
||||
Arc::clone(app_state),
|
||||
llm_server_path.clone(),
|
||||
llm_model.clone(),
|
||||
llm_url.clone(),
|
||||
)));
|
||||
} else if llm_model.is_empty() {
|
||||
info!("⚠️ LLM_MODEL not set, skipping LLM server");
|
||||
}
|
||||
|
||||
if !embedding_running && !embedding_model.is_empty() {
|
||||
info!("🔄 Starting Embedding server...");
|
||||
tasks.push(tokio::spawn(start_embedding_server(
|
||||
llm_server_path.clone(),
|
||||
embedding_model.clone(),
|
||||
embedding_url.clone(),
|
||||
)));
|
||||
} else if embedding_model.is_empty() {
|
||||
info!("⚠️ EMBEDDING_MODEL not set, skipping Embedding server");
|
||||
}
|
||||
|
||||
// Wait for all server startup tasks
|
||||
for task in tasks {
|
||||
task.await??;
|
||||
}
|
||||
|
||||
// Wait for servers to be ready with verbose logging
|
||||
info!("⏳ Waiting for servers to become ready...");
|
||||
|
||||
let mut llm_ready = llm_running || llm_model.is_empty();
|
||||
let mut embedding_ready = embedding_running || embedding_model.is_empty();
|
||||
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 60; // 2 minutes total
|
||||
|
||||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
info!(
|
||||
"🔍 Checking server health (attempt {}/{})...",
|
||||
attempts + 1,
|
||||
max_attempts
|
||||
);
|
||||
|
||||
if !llm_ready && !llm_model.is_empty() {
|
||||
if is_server_running(&llm_url).await {
|
||||
info!(" ✅ LLM server ready at {}", llm_url);
|
||||
llm_ready = true;
|
||||
} else {
|
||||
info!(" ❌ LLM server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
if !embedding_ready && !embedding_model.is_empty() {
|
||||
if is_server_running(&embedding_url).await {
|
||||
info!(" ✅ Embedding server ready at {}", embedding_url);
|
||||
embedding_ready = true;
|
||||
} else {
|
||||
info!(" ❌ Embedding server not ready yet");
|
||||
}
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
|
||||
if attempts % 10 == 0 {
|
||||
info!(
|
||||
"⏰ Still waiting for servers... (attempt {}/{})",
|
||||
attempts, max_attempts
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if llm_ready && embedding_ready {
|
||||
info!("🎉 All llama.cpp servers are ready and responding!");
|
||||
Ok(())
|
||||
} else {
|
||||
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
|
||||
if !llm_ready && !llm_model.is_empty() {
|
||||
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
|
||||
}
|
||||
if !embedding_ready && !embedding_model.is_empty() {
|
||||
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
|
||||
}
|
||||
Err(error_msg.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_llm_server(
|
||||
app_state: Arc<AppState>,
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8081");
|
||||
|
||||
std::env::set_var("OMP_NUM_THREADS", "20");
|
||||
std::env::set_var("OMP_PLACES", "cores");
|
||||
std::env::set_var("OMP_PROC_BIND", "close");
|
||||
|
||||
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
|
||||
// Read config values with defaults
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
let conn = app_state.conn.clone();
|
||||
let config_manager = ConfigManager::new(conn.clone());
|
||||
|
||||
let default_bot_id = {
|
||||
let mut conn = conn.lock().unwrap();
|
||||
bots.filter(name.eq("default"))
|
||||
.select(id)
|
||||
.first::<uuid::Uuid>(&mut *conn)
|
||||
.unwrap_or_else(|_| uuid::Uuid::nil())
|
||||
};
|
||||
|
||||
let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
|
||||
let ctx_size = config_manager.get_config(&default_bot_id, "llm-server-ctx-size", None).unwrap_or("4096".to_string());
|
||||
let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
|
||||
let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
|
||||
let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
|
||||
let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
|
||||
let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
|
||||
|
||||
// Build command arguments dynamically
|
||||
let mut args = format!(
|
||||
"-m {} --host 0.0.0.0 --port {} --top_p 0.95 --temp 0.6 --ctx-size {} --repeat-penalty 1.2 -ngl {}",
|
||||
model_path, port, ctx_size, gpu_layers
|
||||
);
|
||||
|
||||
if n_moe != "0" {
|
||||
args.push_str(&format!(" --n-cpu-moe {}", n_moe));
|
||||
}
|
||||
if parallel != "1" {
|
||||
args.push_str(&format!(" --parallel {}", parallel));
|
||||
}
|
||||
if cont_batching == "true" {
|
||||
args.push_str(" --cont-batching");
|
||||
}
|
||||
if mlock == "true" {
|
||||
args.push_str(" --mlock");
|
||||
}
|
||||
if no_mmap == "true" {
|
||||
args.push_str(" --no-mmap");
|
||||
}
|
||||
|
||||
if cfg!(windows) {
|
||||
let mut cmd = tokio::process::Command::new("cmd");
|
||||
cmd.arg("/C").arg(format!(
|
||||
"cd {} && .\\llama-server.exe {} >../../../../logs/llm/stdout.log",
|
||||
llama_cpp_path, args
|
||||
));
|
||||
cmd.spawn()?;
|
||||
} else {
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server {} >../../../../logs/llm/stdout.log 2>&1 &",
|
||||
llama_cpp_path, args
|
||||
));
|
||||
cmd.spawn()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_embedding_server(
|
||||
llama_cpp_path: String,
|
||||
model_path: String,
|
||||
url: String,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let port = url.split(':').last().unwrap_or("8082");
|
||||
|
||||
if cfg!(windows) {
|
||||
let mut cmd = tokio::process::Command::new("cmd");
|
||||
cmd.arg("/c").arg(format!(
|
||||
"cd {} && .\\llama-server.exe -m {} --log-disable --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
cmd.spawn()?;
|
||||
} else {
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(format!(
|
||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 >../../../../logs/llm/stdout.log 2>&1 &",
|
||||
llama_cpp_path, model_path, port
|
||||
));
|
||||
cmd.spawn()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn is_server_running(url: &str) -> bool {
|
||||
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
match client.get(&format!("{}/health", url)).send().await {
|
||||
Ok(response) => response.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
}
// Convert OpenAI chat messages to a single prompt
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
    let mut prompt = String::new();

    for message in messages {
        match message.role.as_str() {
            "system" => {
                prompt.push_str(&format!("System: {}\n\n", message.content));
            }
            "user" => {
                prompt.push_str(&format!("User: {}\n\n", message.content));
            }
            "assistant" => {
                prompt.push_str(&format!("Assistant: {}\n\n", message.content));
            }
            _ => {
                prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
            }
        }
    }

    prompt.push_str("Assistant: ");
    prompt
}
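A small test sketch of the prompt shape this produces; it assumes it lives in the same module so ChatMessage and messages_to_prompt are in scope.

#[cfg(test)]
mod prompt_format_tests {
    use super::*;

    #[test]
    fn builds_role_prefixed_prompt() {
        let messages = vec![
            ChatMessage {
                role: "system".to_string(),
                content: "You are helpful.".to_string(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            },
        ];

        // Each message becomes "Role: content\n\n" and the prompt ends with "Assistant: ".
        let prompt = messages_to_prompt(&messages);
        assert_eq!(prompt, "System: You are helpful.\n\nUser: Hello\n\nAssistant: ");
    }
}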
// Proxy endpoint
#[post("/local/v1/chat/completions")]
pub async fn chat_completions_local(
    req_body: web::Json<ChatCompletionRequest>,
    _req: HttpRequest,
) -> Result<HttpResponse> {
    dotenv().ok();

    // Get llama.cpp server URL
    let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());

    // Convert OpenAI format to llama.cpp format
    let prompt = messages_to_prompt(&req_body.messages);

    let llama_request = LlamaCppRequest {
        prompt,
        n_predict: Some(500), // Adjust as needed
        temperature: Some(0.7),
        top_k: Some(40),
        top_p: Some(0.9),
        stream: req_body.stream,
    };

    // Send request to llama.cpp server
    let client = Client::builder()
        .timeout(Duration::from_secs(500)) // generous timeout for slow local generation
        .build()
        .map_err(|e| {
            error!("Error creating HTTP client: {}", e);
            actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
        })?;

    let response = client
        .post(&format!("{}/v1/completion", llama_url))
        .header("Content-Type", "application/json")
        .json(&llama_request)
        .send()
        .await
        .map_err(|e| {
            error!("Error calling llama.cpp server: {}", e);
            actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
        })?;

    let status = response.status();

    if status.is_success() {
        let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
            error!("Error parsing llama.cpp response: {}", e);
            actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
        })?;

        // Convert llama.cpp response to OpenAI format
        let openai_response = ChatCompletionResponse {
            id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            object: "chat.completion".to_string(),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            model: req_body.model.clone(),
            choices: vec![Choice {
                message: ChatMessage {
                    role: "assistant".to_string(),
                    content: llama_response.content.trim().to_string(),
                },
                finish_reason: if llama_response.stop {
                    "stop".to_string()
                } else {
                    "length".to_string()
                },
            }],
        };

        Ok(HttpResponse::Ok().json(openai_response))
    } else {
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());

        error!("Llama.cpp server error ({}): {}", status, error_text);

        let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        Ok(HttpResponse::build(actix_status).json(serde_json::json!({
            "error": {
                "message": error_text,
                "type": "server_error"
            }
        })))
    }
}
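A hedged client-side sketch for exercising this proxy route. The bind address and port (localhost:8080 here) are assumptions, and only the fields the handler demonstrably reads (model, messages, stream) are sent; ChatCompletionRequest may require others.

// Illustrative client; assumes the botserver binds on localhost:8080 and that
// ChatCompletionRequest accepts just these fields.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = serde_json::json!({
        "model": "local",
        "messages": [
            { "role": "system", "content": "You are helpful." },
            { "role": "user", "content": "Say hello." }
        ],
        "stream": false
    });

    let resp = reqwest::Client::new()
        .post("http://localhost:8080/local/v1/chat/completions")
        .json(&body)
        .send()
        .await?;

    // Expect an OpenAI-style payload with choices[0].message.content on success.
    println!("{}", resp.text().await?);
    Ok(())
}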
// OpenAI Embedding Request - Modified to handle both string and array inputs
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
    #[serde(deserialize_with = "deserialize_input")]
    pub input: Vec<String>,
    pub model: String,
    #[serde(default)]
    pub _encoding_format: Option<String>,
}
// Custom deserializer to handle both string and array inputs
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::{self, Visitor};
    use std::fmt;

    struct InputVisitor;

    impl<'de> Visitor<'de> for InputVisitor {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a string or an array of strings")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(vec![value.to_string()])
        }

        fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(vec![value])
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: de::SeqAccess<'de>,
        {
            let mut vec = Vec::new();
            while let Some(value) = seq.next_element::<String>()? {
                vec.push(value);
            }
            Ok(vec)
        }
    }

    deserializer.deserialize_any(InputVisitor)
}
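A test sketch (same-module assumption, so EmbeddingRequest and the deserializer are in scope) showing that both accepted input shapes normalize to a Vec<String>:

#[cfg(test)]
mod embedding_input_tests {
    use super::*;

    #[test]
    fn accepts_string_and_array_inputs() {
        // A bare string becomes a one-element vector.
        let single: EmbeddingRequest =
            serde_json::from_str(r#"{"input": "hello world", "model": "embed-model"}"#).unwrap();
        assert_eq!(single.input, vec!["hello world".to_string()]);

        // An array of strings passes through unchanged.
        let multi: EmbeddingRequest =
            serde_json::from_str(r#"{"input": ["a", "b"], "model": "embed-model"}"#).unwrap();
        assert_eq!(multi.input, vec!["a".to_string(), "b".to_string()]);
    }
}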
// OpenAI Embedding Response
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

#[derive(Debug, Serialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f32>,
    pub index: usize,
}

#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}

// Llama.cpp Embedding Request
#[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest {
    pub content: String,
}
// Llama.cpp Embedding Response - handles the nested array format returned by the server
#[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem {
    pub index: usize,
    pub embedding: Vec<Vec<f32>>, // llama.cpp wraps the vector: embedding is an array of arrays
}
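To make the expected shape concrete, here is a test sketch of the nested payload this struct is meant to absorb; the exact fields llama.cpp returns can vary by build, so treat the sample JSON as an assumption.

#[cfg(test)]
mod embedding_shape_tests {
    use super::*;

    #[test]
    fn parses_nested_embedding_array() {
        // Assumed llama.cpp /embedding payload: the vector is wrapped in an outer array.
        let raw = r#"[{"index": 0, "embedding": [[0.1, 0.2, 0.3]]}]"#;
        let items: Vec<LlamaCppEmbeddingResponseItem> = serde_json::from_str(raw).unwrap();

        // Flattening takes the first inner array, as the proxy handler below does.
        let flattened = items[0].embedding[0].clone();
        assert_eq!(items[0].index, 0);
        assert_eq!(flattened.len(), 3);
    }
}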
// Proxy endpoint for embeddings
#[post("/v1/embeddings")]
pub async fn embeddings_local(
    req_body: web::Json<EmbeddingRequest>,
    _req: HttpRequest,
) -> Result<HttpResponse> {
    dotenv().ok();

    // Get llama.cpp server URL
    let llama_url =
        env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());

    let client = Client::builder()
        .timeout(Duration::from_secs(120))
        .build()
        .map_err(|e| {
            error!("Error creating HTTP client: {}", e);
            actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
        })?;

    // Process each input text and get embeddings
    let mut embeddings_data = Vec::new();
    let mut total_tokens = 0;

    for (index, input_text) in req_body.input.iter().enumerate() {
        let llama_request = LlamaCppEmbeddingRequest {
            content: input_text.clone(),
        };

        let response = client
            .post(&format!("{}/embedding", llama_url))
            .header("Content-Type", "application/json")
            .json(&llama_request)
            .send()
            .await
            .map_err(|e| {
                error!("Error calling llama.cpp server for embedding: {}", e);
                actix_web::error::ErrorInternalServerError(
                    "Failed to call llama.cpp server for embedding",
                )
            })?;

        let status = response.status();

        if status.is_success() {
            // First, get the raw response text for debugging
            let raw_response = response.text().await.map_err(|e| {
                error!("Error reading response text: {}", e);
                actix_web::error::ErrorInternalServerError("Failed to read response")
            })?;

            // Parse the response as a vector of items with nested arrays
            let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
                serde_json::from_str(&raw_response).map_err(|e| {
                    error!("Error parsing llama.cpp embedding response: {}", e);
                    error!("Raw response: {}", raw_response);
                    actix_web::error::ErrorInternalServerError(
                        "Failed to parse llama.cpp embedding response",
                    )
                })?;
            // Extract the embedding from the nested array format
            if let Some(item) = llama_response.get(0) {
                // The embedding field contains Vec<Vec<f32>>, so we need to flatten it
                // If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
                let flattened_embedding = if !item.embedding.is_empty() {
                    item.embedding[0].clone() // Take the first (and typically only) inner array
                } else {
                    vec![] // Empty if no embedding data
                };
                // Estimate token count
                let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
                total_tokens += estimated_tokens;

                embeddings_data.push(EmbeddingData {
                    object: "embedding".to_string(),
                    embedding: flattened_embedding,
                    index,
                });
            } else {
                error!("No embedding data returned for input: {}", input_text);
                return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
                    "error": {
                        "message": format!("No embedding data returned for input {}", index),
                        "type": "server_error"
                    }
                })));
            }
        } else {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            error!("Llama.cpp server error ({}): {}", status, error_text);

            let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

            return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
                "error": {
                    "message": format!("Failed to get embedding for input {}: {}", index, error_text),
                    "type": "server_error"
                }
            })));
        }
    }

    // Build OpenAI-compatible response
    let openai_response = EmbeddingResponse {
        object: "list".to_string(),
        data: embeddings_data,
        model: req_body.model.clone(),
        usage: Usage {
            prompt_tokens: total_tokens,
            total_tokens,
        },
    };

    Ok(HttpResponse::Ok().json(openai_response))
}
@ -1,2 +0,0 @@
pub mod llm_azure;
pub mod llm_local;

17 src/llm_models/deepseek_r3.rs Normal file
@ -0,0 +1,17 @@
use super::ModelHandler;

pub struct DeepseekR3Handler;

impl ModelHandler for DeepseekR3Handler {
    fn is_analysis_complete(&self, buffer: &str) -> bool {
        buffer.contains("</think>")
    }

    fn process_content(&self, content: &str) -> String {
        content.replace("<think>", "").replace("</think>", "")
    }

    fn has_analysis_markers(&self, buffer: &str) -> bool {
        buffer.contains("<think>")
    }
}
31 src/llm_models/gpt_oss_120b.rs Normal file
@ -0,0 +1,31 @@
use super::ModelHandler;

pub struct GptOss120bHandler {
    model_name: String,
}

impl GptOss120bHandler {
    pub fn new(model_name: &str) -> Self {
        Self {
            model_name: model_name.to_string(),
        }
    }
}

impl ModelHandler for GptOss120bHandler {
    fn is_analysis_complete(&self, buffer: &str) -> bool {
        // GPT-120B uses explicit end marker
        buffer.contains("**end**")
    }

    fn process_content(&self, content: &str) -> String {
        // Remove both start and end markers from final output
        content.replace("**start**", "")
            .replace("**end**", "")
    }

    fn has_analysis_markers(&self, buffer: &str) -> bool {
        // GPT-120B uses explicit start marker
        buffer.contains("**start**")
    }
}
21 src/llm_models/gpt_oss_20b.rs Normal file
@ -0,0 +1,21 @@
use super::ModelHandler;

pub struct GptOss20bHandler;

impl ModelHandler for GptOss20bHandler {
    fn is_analysis_complete(&self, buffer: &str) -> bool {
        buffer.ends_with("final")
    }

    fn process_content(&self, content: &str) -> String {
        if let Some(pos) = content.find("final") {
            content[..pos].to_string()
        } else {
            content.to_string()
        }
    }

    fn has_analysis_markers(&self, buffer: &str) -> bool {
        buffer.contains("**")
    }
}
33 src/llm_models/mod.rs Normal file
@ -0,0 +1,33 @@
//! Module for handling model-specific behavior and token processing

pub mod gpt_oss_20b;
pub mod deepseek_r3;
pub mod gpt_oss_120b;

/// Trait for model-specific token processing
pub trait ModelHandler: Send + Sync {
    /// Check if the analysis buffer indicates completion
    fn is_analysis_complete(&self, buffer: &str) -> bool;

    /// Process the content, removing any model-specific tokens
    fn process_content(&self, content: &str) -> String;

    /// Check if the buffer contains analysis start markers
    fn has_analysis_markers(&self, buffer: &str) -> bool;
}

/// Get the appropriate handler based on model path from bot configuration
pub fn get_handler(model_path: &str) -> Box<dyn ModelHandler> {
    let path = model_path.to_lowercase();

    if path.contains("deepseek") {
        Box::new(deepseek_r3::DeepseekR3Handler)
    } else if path.contains("120b") {
        Box::new(gpt_oss_120b::GptOss120bHandler::new("default"))
    } else if path.contains("gpt-oss") || path.contains("gpt") {
        Box::new(gpt_oss_20b::GptOss20bHandler)
    } else {
        Box::new(gpt_oss_20b::GptOss20bHandler)
    }
}
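A minimal sketch of how a consumer might use the selected handler; the function and its buffering are hypothetical simplifications, while the real integration lives in the bot orchestrator.

// Hypothetical consumer; assumes it is compiled inside this crate so crate::llm_models resolves.
use crate::llm_models;

fn clean_streamed_output(model_path: &str, chunks: &[&str]) -> String {
    let handler = llm_models::get_handler(model_path);
    let mut buffer = String::new();

    for chunk in chunks {
        buffer.push_str(chunk);
    }

    // If the model emitted an analysis block, wait for it to finish before stripping markers.
    if handler.has_analysis_markers(&buffer) && !handler.is_analysis_complete(&buffer) {
        return String::new(); // analysis still in progress; nothing user-visible yet
    }
    handler.process_content(&buffer)
}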
14 src/main.rs
@ -25,7 +25,7 @@ mod ui;
mod file;
mod kb;
mod llm;
mod llm_legacy;
mod llm_models;
mod meet;
mod org;
mod package_manager;

@ -49,7 +49,7 @@ use crate::email::{
    get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
};
use crate::file::{init_drive, upload_file};
use crate::llm_legacy::llm_local::{
use crate::llm::local::{
    chat_completions_local, embeddings_local, ensure_llama_servers_running,
};
use crate::meet::{voice_start, voice_stop};
@ -124,8 +124,8 @@ async fn main() -> std::io::Result<()> {
        &std::env::var("DATABASE_URL")
            .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
    ) {
        Ok(mut conn) => AppConfig::from_database(&mut conn),
        Err(_) => AppConfig::from_env(),
        Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
        Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
    }
} else {
    match bootstrap.bootstrap().await {

@ -139,8 +139,8 @@ async fn main() -> std::io::Result<()> {
        &std::env::var("DATABASE_URL")
            .unwrap_or_else(|_| "postgres://gbuser:@localhost:5432/botserver".to_string()),
    ) {
        Ok(mut conn) => AppConfig::from_database(&mut conn),
        Err(_) => AppConfig::from_env(),
        Ok(mut conn) => AppConfig::from_database(&mut conn).expect("Failed to load config from DB"),
        Err(_) => AppConfig::from_env().expect("Failed to load config from env"),
    }
}
}
@ -158,7 +158,7 @@ async fn main() -> std::io::Result<()> {
    // Refresh configuration from environment to ensure latest DATABASE_URL and credentials
    dotenv().ok();
    let refreshed_cfg = AppConfig::from_env();
    let refreshed_cfg = AppConfig::from_env().expect("Failed to load config from env");
    let config = std::sync::Arc::new(refreshed_cfg.clone());
    let db_pool = match diesel::Connection::establish(&refreshed_cfg.database_url()) {
        Ok(conn) => Arc::new(Mutex::new(conn)),
@ -411,6 +411,9 @@ pub mod schema {
        bot_id -> Uuid,
        config_key -> Text,
        config_value -> Text,
        is_encrypted -> Bool,
        config_type -> Text,
        created_at -> Timestamptz,
        updated_at -> Timestamptz,
@ -34,6 +34,19 @@ pub fn extract_zip_recursive(
        if !parent.exists() {
            std::fs::create_dir_all(&parent)?;
        }

use crate::llm::LLMProvider;
use std::sync::Arc;
use serde_json::Value;

/// Unified chat utility to interact with any LLM provider
pub async fn chat_with_llm(
    provider: Arc<dyn LLMProvider>,
    prompt: &str,
    config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    provider.generate(prompt, config).await
}
        }
        let mut outfile = File::create(&outpath)?;
        std::io::copy(&mut file, &mut outfile)?;