diff --git a/migrations/6.1.0_enterprise_suite/up.sql b/migrations/6.1.0_enterprise_suite/up.sql index d8c5c7890..89971ac40 100644 --- a/migrations/6.1.0_enterprise_suite/up.sql +++ b/migrations/6.1.0_enterprise_suite/up.sql @@ -2051,7 +2051,7 @@ COMMENT ON TABLE public.system_automations IS 'System automations with TriggerKi -- User organization memberships (users can belong to multiple orgs) CREATE TABLE IF NOT EXISTS public.user_organizations ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - user_id UUID NOT NULL REFERENCES public.users(user_id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES public.users(id) ON DELETE CASCADE, org_id UUID NOT NULL REFERENCES public.organizations(org_id) ON DELETE CASCADE, role VARCHAR(50) DEFAULT 'member', -- 'owner', 'admin', 'member', 'viewer' is_default BOOLEAN DEFAULT false, diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index f7df6c03a..1fbefed88 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -14,7 +14,7 @@ use uuid::Uuid; pub type Config = AppConfig; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct AppConfig { pub drive: DriveConfig, pub server: ServerConfig, @@ -22,19 +22,19 @@ pub struct AppConfig { pub site_path: String, pub data_dir: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct DriveConfig { pub server: String, pub access_key: String, pub secret_key: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct ServerConfig { pub host: String, pub port: u16, pub base_url: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct EmailConfig { pub server: String, pub port: u16, diff --git a/src/core/package_manager/installer.rs b/src/core/package_manager/installer.rs index 24996c276..e7a00112d 100644 --- a/src/core/package_manager/installer.rs +++ b/src/core/package_manager/installer.rs @@ -252,7 +252,7 @@ impl PackageManager { ]), data_download_list: Vec::new(), exec_cmd: "nohup {{BIN_PATH}}/minio server {{DATA_PATH}} --address :9000 --console-address :9001 > {{LOGS_PATH}}/minio.log 2>&1 &".to_string(), - check_cmd: "pgrep -f 'minio server' >/dev/null 2>&1".to_string(), + check_cmd: "curl -sf http://127.0.0.1:9000/minio/health/live >/dev/null 2>&1".to_string(), }, ); } @@ -1162,44 +1162,71 @@ EOF"#.to_string(), fn fetch_vault_credentials() -> HashMap { let mut credentials = HashMap::new(); + dotenvy::dotenv().ok(); + let vault_addr = std::env::var("VAULT_ADDR").unwrap_or_else(|_| "http://localhost:8200".to_string()); let vault_token = std::env::var("VAULT_TOKEN").unwrap_or_default(); if vault_token.is_empty() { - trace!("VAULT_TOKEN not set, skipping Vault credential fetch"); + warn!("VAULT_TOKEN not set, cannot fetch credentials from Vault"); return credentials; } - if let Ok(output) = std::process::Command::new("sh") + let base_path = std::env::var("BOTSERVER_STACK_PATH") + .map(std::path::PathBuf::from) + .unwrap_or_else(|_| { + std::env::current_dir() + .unwrap_or_else(|_| std::path::PathBuf::from(".")) + .join("botserver-stack") + }); + let vault_bin = base_path.join("bin/vault/vault"); + let vault_bin_str = vault_bin.to_string_lossy(); + + info!("Fetching drive credentials from Vault at {} using {}", vault_addr, vault_bin_str); + let drive_cmd = format!( + "unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} {} kv get -format=json secret/gbo/drive", + vault_addr, vault_token, vault_bin_str + ); + match std::process::Command::new("sh") .arg("-c") - .arg(format!( - "unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} ./botserver-stack/bin/vault/vault kv get -format=json secret/gbo/drive 2>/dev/null", - vault_addr, vault_token - )) + .arg(&drive_cmd) .output() { - if output.status.success() { - if let Ok(json_str) = String::from_utf8(output.stdout) { - if let Ok(json) = serde_json::from_str::(&json_str) { - if let Some(data) = json.get("data").and_then(|d| d.get("data")) { - if let Some(accesskey) = data.get("accesskey").and_then(|v| v.as_str()) { - credentials.insert("DRIVE_ACCESSKEY".to_string(), accesskey.to_string()); - } - if let Some(secret) = data.get("secret").and_then(|v| v.as_str()) { - credentials.insert("DRIVE_SECRET".to_string(), secret.to_string()); + Ok(output) => { + if output.status.success() { + let json_str = String::from_utf8_lossy(&output.stdout); + info!("Vault drive response: {}", json_str); + match serde_json::from_str::(&json_str) { + Ok(json) => { + if let Some(data) = json.get("data").and_then(|d| d.get("data")) { + if let Some(accesskey) = data.get("accesskey").and_then(|v| v.as_str()) { + info!("Found DRIVE_ACCESSKEY from Vault"); + credentials.insert("DRIVE_ACCESSKEY".to_string(), accesskey.to_string()); + } + if let Some(secret) = data.get("secret").and_then(|v| v.as_str()) { + info!("Found DRIVE_SECRET from Vault"); + credentials.insert("DRIVE_SECRET".to_string(), secret.to_string()); + } + } else { + warn!("Vault response missing data.data field"); } } + Err(e) => warn!("Failed to parse Vault JSON: {}", e), } + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + warn!("Vault drive command failed: {}", stderr); } } + Err(e) => warn!("Failed to execute Vault command: {}", e), } if let Ok(output) = std::process::Command::new("sh") .arg("-c") .arg(format!( - "unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} ./botserver-stack/bin/vault/vault kv get -format=json secret/gbo/cache 2>/dev/null", - vault_addr, vault_token + "unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} {} kv get -format=json secret/gbo/cache 2>/dev/null", + vault_addr, vault_token, vault_bin_str )) .output() { diff --git a/src/drive/drive_monitor/mod.rs b/src/drive/drive_monitor/mod.rs index 11a3afaaa..313e2d84f 100644 --- a/src/drive/drive_monitor/mod.rs +++ b/src/drive/drive_monitor/mod.rs @@ -189,6 +189,7 @@ impl DriveMonitor { } async fn check_gbot(&self, client: &Client) -> Result<(), Box> { let config_manager = ConfigManager::new(self.state.conn.clone()); + debug!("check_gbot: Checking bucket {} for config.csv changes", self.bucket_name); let mut continuation_token = None; loop { let list_objects = match tokio::time::timeout( @@ -202,27 +203,28 @@ impl DriveMonitor { .await { Ok(Ok(list)) => list, - Ok(Err(e)) => return Err(e.into()), + Ok(Err(e)) => { + error!("check_gbot: Failed to list objects in bucket {}: {}", self.bucket_name, e); + return Err(e.into()); + } Err(_) => { - log::error!("Timeout listing objects in bucket {}", self.bucket_name); + error!("Timeout listing objects in bucket {}", self.bucket_name); return Ok(()); } }; for obj in list_objects.contents.unwrap_or_default() { let path = obj.key().unwrap_or_default().to_string(); - let path_parts: Vec<&str> = path.split('/').collect(); - if path_parts.len() < 2 - || !std::path::Path::new(path_parts[0]) - .extension() - .is_some_and(|ext| ext.eq_ignore_ascii_case("gbot")) - { - continue; - } - if !path.eq_ignore_ascii_case("config.csv") - && !path.to_ascii_lowercase().ends_with("/config.csv") - { + let path_lower = path.to_ascii_lowercase(); + + let is_config_csv = path_lower == "config.csv" + || path_lower.ends_with("/config.csv") + || path_lower.contains(".gbot/config.csv"); + + if !is_config_csv { continue; } + + debug!("check_gbot: Found config.csv at path: {}", path); match client .head_object() .bucket(&self.bucket_name) @@ -248,12 +250,22 @@ impl DriveMonitor { let _ = config_manager.sync_gbot_config(&self.bot_id, &csv_content); } else { use crate::llm::local::ensure_llama_servers_running; + use crate::llm::DynamicLLMProvider; let mut restart_needed = false; - for line in llm_lines { + let mut llm_url_changed = false; + let mut new_llm_url = String::new(); + let mut new_llm_model = String::new(); + for line in &llm_lines { let parts: Vec<&str> = line.split(',').collect(); if parts.len() >= 2 { let key = parts[0].trim(); let new_value = parts[1].trim(); + if key == "llm-url" { + new_llm_url = new_value.to_string(); + } + if key == "llm-model" { + new_llm_model = new_value.to_string(); + } match config_manager.get_config(&self.bot_id, key, None) { Ok(old_value) => { if old_value != new_value { @@ -262,10 +274,16 @@ impl DriveMonitor { key, old_value, new_value ); restart_needed = true; + if key == "llm-url" || key == "llm-model" { + llm_url_changed = true; + } } } Err(_) => { restart_needed = true; + if key == "llm-url" || key == "llm-model" { + llm_url_changed = true; + } } } } @@ -278,6 +296,28 @@ impl DriveMonitor { log::error!("Failed to restart LLaMA servers after llm- config change: {}", e); } } + if llm_url_changed { + info!("check_gbot: LLM config changed, updating provider..."); + let effective_url = if new_llm_url.is_empty() { + config_manager.get_config(&self.bot_id, "llm-url", None).unwrap_or_default() + } else { + new_llm_url + }; + info!("check_gbot: Effective LLM URL: {}", effective_url); + if !effective_url.is_empty() { + if let Some(dynamic_provider) = self.state.extensions.get::>().await { + let model = if new_llm_model.is_empty() { None } else { Some(new_llm_model.clone()) }; + dynamic_provider.update_from_config(&effective_url, model).await; + info!("Updated LLM provider to use URL: {}, model: {:?}", effective_url, new_llm_model); + } else { + error!("DynamicLLMProvider not found in extensions, LLM provider cannot be updated dynamically"); + } + } else { + error!("check_gbot: No llm-url found in config, cannot update provider"); + } + } else { + debug!("check_gbot: No LLM config changes detected"); + } } if csv_content.lines().any(|line| line.starts_with("theme-")) { self.broadcast_theme_change(&csv_content).await?; diff --git a/src/email/mod.rs b/src/email/mod.rs index 357c61d81..21e874537 100644 --- a/src/email/mod.rs +++ b/src/email/mod.rs @@ -2,7 +2,7 @@ use crate::{config::EmailConfig, core::urls::ApiUrls, shared::state::AppState}; use axum::{ extract::{Path, Query, State}, http::StatusCode, - response::{IntoResponse, Response}, + response::{Html, IntoResponse, Response}, Json, }; use axum::{ diff --git a/src/lib.rs b/src/lib.rs index f828cd093..fb68f7159 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,8 @@ pub mod instagram; pub mod llm; #[cfg(feature = "llm")] pub use llm::cache::{CacheConfig, CachedLLMProvider, CachedResponse, LocalEmbeddingService}; +#[cfg(feature = "llm")] +pub use llm::DynamicLLMProvider; #[cfg(feature = "meet")] pub mod meet; diff --git a/src/llm/claude.rs b/src/llm/claude.rs new file mode 100644 index 000000000..1c45d6987 --- /dev/null +++ b/src/llm/claude.rs @@ -0,0 +1,530 @@ +use async_trait::async_trait; +use futures::StreamExt; +use log::{info, trace, warn}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::sync::mpsc; + +use super::{llm_models::get_handler, LLMProvider}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeMessage { + pub role: String, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeRequest { + pub model: String, + pub max_tokens: u32, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeContentBlock { + #[serde(rename = "type")] + pub content_type: String, + #[serde(default)] + pub text: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeResponse { + pub id: String, + #[serde(rename = "type")] + pub response_type: String, + pub role: String, + pub content: Vec, + pub model: String, + #[serde(default)] + pub stop_reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeStreamDelta { + #[serde(rename = "type")] + pub delta_type: String, + #[serde(default)] + pub text: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClaudeStreamEvent { + #[serde(rename = "type")] + pub event_type: String, + #[serde(default)] + pub delta: Option, + #[serde(default)] + pub index: Option, +} + +#[derive(Debug)] +pub struct ClaudeClient { + client: reqwest::Client, + base_url: String, + deployment_name: String, + is_azure: bool, +} + +impl ClaudeClient { + pub fn new(base_url: String, deployment_name: Option) -> Self { + let is_azure = base_url.contains("azure.com") || base_url.contains("openai.azure.com"); + + Self { + client: reqwest::Client::new(), + base_url, + deployment_name: deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string()), + is_azure, + } + } + + pub fn azure(endpoint: String, deployment_name: String) -> Self { + Self { + client: reqwest::Client::new(), + base_url: endpoint, + deployment_name, + is_azure: true, + } + } + + fn build_url(&self) -> String { + if self.is_azure { + format!( + "{}/deployments/{}/messages?api-version=2024-06-01", + self.base_url.trim_end_matches('/'), + self.deployment_name + ) + } else { + format!("{}/v1/messages", self.base_url.trim_end_matches('/')) + } + } + + fn build_headers(&self, api_key: &str) -> reqwest::header::HeaderMap { + let mut headers = reqwest::header::HeaderMap::new(); + + if self.is_azure { + if let Ok(val) = api_key.parse() { + headers.insert("api-key", val); + } + } else { + if let Ok(val) = api_key.parse() { + headers.insert("x-api-key", val); + } + if let Ok(val) = "2023-06-01".parse() { + headers.insert("anthropic-version", val); + } + } + + if let Ok(val) = "application/json".parse() { + headers.insert(reqwest::header::CONTENT_TYPE, val); + } + + headers + } + + pub fn build_messages( + system_prompt: &str, + context_data: &str, + history: &[(String, String)], + ) -> (Option, Vec) { + let mut system_parts = Vec::new(); + + if !system_prompt.is_empty() { + system_parts.push(system_prompt.to_string()); + } + if !context_data.is_empty() { + system_parts.push(context_data.to_string()); + } + + let system = if system_parts.is_empty() { + None + } else { + Some(system_parts.join("\n\n")) + }; + + let messages: Vec = history + .iter() + .map(|(role, content)| ClaudeMessage { + role: role.clone(), + content: content.clone(), + }) + .collect(); + + (system, messages) + } + + fn extract_text_from_response(&self, response: &ClaudeResponse) -> String { + response + .content + .iter() + .filter(|block| block.content_type == "text") + .map(|block| block.text.clone()) + .collect::>() + .join("") + } +} + +#[async_trait] +impl LLMProvider for ClaudeClient { + async fn generate( + &self, + prompt: &str, + messages: &Value, + model: &str, + key: &str, + ) -> Result> { + let url = self.build_url(); + let headers = self.build_headers(key); + + let model_name = if model.is_empty() { + &self.deployment_name + } else { + model + }; + + let empty_vec = vec![]; + let claude_messages: Vec = if messages.is_array() { + let arr = messages.as_array().unwrap_or(&empty_vec); + if arr.is_empty() { + vec![ClaudeMessage { + role: "user".to_string(), + content: prompt.to_string(), + }] + } else { + arr.iter() + .filter_map(|m| { + let role = m["role"].as_str().unwrap_or("user"); + let content = m["content"].as_str().unwrap_or(""); + if role == "system" { + None + } else { + Some(ClaudeMessage { + role: role.to_string(), + content: content.to_string(), + }) + } + }) + .collect() + } + } else { + vec![ClaudeMessage { + role: "user".to_string(), + content: prompt.to_string(), + }] + }; + + let system_prompt: Option = if messages.is_array() { + messages + .as_array() + .unwrap_or(&empty_vec) + .iter() + .filter(|m| m["role"].as_str() == Some("system")) + .map(|m| m["content"].as_str().unwrap_or("").to_string()) + .collect::>() + .join("\n\n") + .into() + } else { + None + }; + + let system = system_prompt.filter(|s| !s.is_empty()); + + let request = ClaudeRequest { + model: model_name.to_string(), + max_tokens: 4096, + messages: claude_messages, + system, + stream: None, + }; + + info!("Claude request to {}: model={}", url, model_name); + trace!("Claude request body: {:?}", serde_json::to_string(&request)); + + let response = self + .client + .post(&url) + .headers(headers) + .json(&request) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + warn!("Claude API error ({}): {}", status, error_text); + return Err(format!("Claude API error ({}): {}", status, error_text).into()); + } + + let result: ClaudeResponse = response.json().await?; + let raw_content = self.extract_text_from_response(&result); + + let handler = get_handler(model_name); + let content = handler.process_content(&raw_content); + + Ok(content) + } + + async fn generate_stream( + &self, + prompt: &str, + messages: &Value, + tx: mpsc::Sender, + model: &str, + key: &str, + ) -> Result<(), Box> { + let url = self.build_url(); + let headers = self.build_headers(key); + + let model_name = if model.is_empty() { + &self.deployment_name + } else { + model + }; + + let empty_vec = vec![]; + let claude_messages: Vec = if messages.is_array() { + let arr = messages.as_array().unwrap_or(&empty_vec); + if arr.is_empty() { + vec![ClaudeMessage { + role: "user".to_string(), + content: prompt.to_string(), + }] + } else { + arr.iter() + .filter_map(|m| { + let role = m["role"].as_str().unwrap_or("user"); + let content = m["content"].as_str().unwrap_or(""); + if role == "system" { + None + } else { + Some(ClaudeMessage { + role: role.to_string(), + content: content.to_string(), + }) + } + }) + .collect() + } + } else { + vec![ClaudeMessage { + role: "user".to_string(), + content: prompt.to_string(), + }] + }; + + let system_prompt: Option = if messages.is_array() { + messages + .as_array() + .unwrap_or(&empty_vec) + .iter() + .filter(|m| m["role"].as_str() == Some("system")) + .map(|m| m["content"].as_str().unwrap_or("").to_string()) + .collect::>() + .join("\n\n") + .into() + } else { + None + }; + + let system = system_prompt.filter(|s| !s.is_empty()); + + let request = ClaudeRequest { + model: model_name.to_string(), + max_tokens: 4096, + messages: claude_messages, + system, + stream: Some(true), + }; + + info!("Claude streaming request to {}: model={}", url, model_name); + + let response = self + .client + .post(&url) + .headers(headers) + .json(&request) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + warn!("Claude streaming API error ({}): {}", status, error_text); + return Err(format!("Claude streaming API error ({}): {}", status, error_text).into()); + } + + let handler = get_handler(model_name); + let mut stream = response.bytes_stream(); + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + let chunk_str = String::from_utf8_lossy(&chunk); + + for line in chunk_str.lines() { + let line = line.trim(); + + if line.starts_with("data: ") { + let data = &line[6..]; + + if data == "[DONE]" { + break; + } + + if let Ok(event) = serde_json::from_str::(data) { + if event.event_type == "content_block_delta" { + if let Some(delta) = event.delta { + if delta.delta_type == "text_delta" && !delta.text.is_empty() { + let processed = handler.process_content(&delta.text); + if !processed.is_empty() { + let _ = tx.send(processed).await; + } + } + } + } + } + } + } + } + + Ok(()) + } + + async fn cancel_job( + &self, + _session_id: &str, + ) -> Result<(), Box> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_claude_client_new() { + let client = ClaudeClient::new( + "https://api.anthropic.com".to_string(), + Some("claude-3-opus".to_string()), + ); + assert!(!client.is_azure); + assert_eq!(client.deployment_name, "claude-3-opus"); + } + + #[test] + fn test_claude_client_azure() { + let client = ClaudeClient::azure( + "https://myendpoint.openai.azure.com/anthropic".to_string(), + "claude-opus-4-5".to_string(), + ); + assert!(client.is_azure); + assert_eq!(client.deployment_name, "claude-opus-4-5"); + } + + #[test] + fn test_build_url_azure() { + let client = ClaudeClient::azure( + "https://myendpoint.openai.azure.com/anthropic".to_string(), + "claude-opus-4-5".to_string(), + ); + let url = client.build_url(); + assert!(url.contains("deployments/claude-opus-4-5/messages")); + assert!(url.contains("api-version=")); + } + + #[test] + fn test_build_url_anthropic() { + let client = ClaudeClient::new( + "https://api.anthropic.com".to_string(), + None, + ); + let url = client.build_url(); + assert_eq!(url, "https://api.anthropic.com/v1/messages"); + } + + #[test] + fn test_build_messages_empty() { + let (system, messages) = ClaudeClient::build_messages("", "", &[]); + assert!(system.is_none()); + assert!(messages.is_empty()); + } + + #[test] + fn test_build_messages_with_system() { + let (system, messages) = ClaudeClient::build_messages( + "You are a helpful assistant.", + "", + &[], + ); + assert_eq!(system, Some("You are a helpful assistant.".to_string())); + assert!(messages.is_empty()); + } + + #[test] + fn test_build_messages_with_history() { + let history = vec![ + ("user".to_string(), "Hello".to_string()), + ("assistant".to_string(), "Hi there!".to_string()), + ]; + let (system, messages) = ClaudeClient::build_messages("", "", &history); + assert!(system.is_none()); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, "user"); + assert_eq!(messages[0].content, "Hello"); + } + + #[test] + fn test_build_messages_full() { + let history = vec![ + ("user".to_string(), "What is 2+2?".to_string()), + ]; + let (system, messages) = ClaudeClient::build_messages( + "You are a math tutor.", + "Focus on step-by-step explanations.", + &history, + ); + assert!(system.is_some()); + assert!(system.unwrap().contains("math tutor")); + assert_eq!(messages.len(), 1); + } + + #[test] + fn test_claude_request_serialization() { + let request = ClaudeRequest { + model: "claude-3-opus".to_string(), + max_tokens: 4096, + messages: vec![ClaudeMessage { + role: "user".to_string(), + content: "Hello".to_string(), + }], + system: Some("Be helpful".to_string()), + stream: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("claude-3-opus")); + assert!(json.contains("max_tokens")); + assert!(json.contains("Be helpful")); + } + + #[test] + fn test_claude_response_deserialization() { + let json = r#"{ + "id": "msg_123", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}], + "model": "claude-3-opus", + "stop_reason": "end_turn" + }"#; + + let response: ClaudeResponse = serde_json::from_str(json).unwrap(); + assert_eq!(response.id, "msg_123"); + assert_eq!(response.content.len(), 1); + assert_eq!(response.content[0].text, "Hello!"); + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 80e633487..64629ed38 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -2,14 +2,17 @@ use async_trait::async_trait; use futures::StreamExt; use log::{info, trace}; use serde_json::Value; -use tokio::sync::mpsc; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; pub mod cache; +pub mod claude; pub mod episodic_memory; pub mod llm_models; pub mod local; pub mod observability; +pub use claude::ClaudeClient; pub use llm_models::get_handler; #[async_trait] @@ -184,6 +187,116 @@ pub fn start_llm_services(state: &std::sync::Arc info!("LLM services started (episodic memory scheduler)"); } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LLMProviderType { + OpenAI, + Claude, + AzureClaude, +} + +impl From<&str> for LLMProviderType { + fn from(s: &str) -> Self { + let lower = s.to_lowercase(); + if lower.contains("claude") || lower.contains("anthropic") { + if lower.contains("azure") { + Self::AzureClaude + } else { + Self::Claude + } + } else { + Self::OpenAI + } + } +} + +pub fn create_llm_provider( + provider_type: LLMProviderType, + base_url: String, + deployment_name: Option, +) -> std::sync::Arc { + match provider_type { + LLMProviderType::OpenAI => { + info!("Creating OpenAI LLM provider with URL: {}", base_url); + std::sync::Arc::new(OpenAIClient::new("empty".to_string(), Some(base_url))) + } + LLMProviderType::Claude => { + info!("Creating Claude LLM provider with URL: {}", base_url); + std::sync::Arc::new(ClaudeClient::new(base_url, deployment_name)) + } + LLMProviderType::AzureClaude => { + let deployment = deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string()); + info!( + "Creating Azure Claude LLM provider with URL: {}, deployment: {}", + base_url, deployment + ); + std::sync::Arc::new(ClaudeClient::azure(base_url, deployment)) + } + } +} + +pub fn create_llm_provider_from_url(url: &str, model: Option) -> std::sync::Arc { + let provider_type = LLMProviderType::from(url); + create_llm_provider(provider_type, url.to_string(), model) +} + +pub struct DynamicLLMProvider { + inner: RwLock>, +} + +impl DynamicLLMProvider { + pub fn new(provider: Arc) -> Self { + Self { + inner: RwLock::new(provider), + } + } + + pub async fn update_provider(&self, new_provider: Arc) { + let mut guard = self.inner.write().await; + *guard = new_provider; + info!("LLM provider updated dynamically"); + } + + pub async fn update_from_config(&self, url: &str, model: Option) { + let new_provider = create_llm_provider_from_url(url, model); + self.update_provider(new_provider).await; + } + + async fn get_provider(&self) -> Arc { + self.inner.read().await.clone() + } +} + +#[async_trait] +impl LLMProvider for DynamicLLMProvider { + async fn generate( + &self, + prompt: &str, + config: &Value, + model: &str, + key: &str, + ) -> Result> { + self.get_provider().await.generate(prompt, config, model, key).await + } + + async fn generate_stream( + &self, + prompt: &str, + config: &Value, + tx: mpsc::Sender, + model: &str, + key: &str, + ) -> Result<(), Box> { + self.get_provider().await.generate_stream(prompt, config, tx, model, key).await + } + + async fn cancel_job( + &self, + session_id: &str, + ) -> Result<(), Box> { + self.get_provider().await.cancel_job(session_id).await + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/main.rs b/src/main.rs index 9fd30c205..c2f78d1a0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -185,6 +185,7 @@ async fn run_axum_server( .add_anonymous_path("/api/v1/health") .add_anonymous_path("/ws") .add_anonymous_path("/auth") + .add_anonymous_path("/api/auth") .add_public_path("/static") .add_public_path("/favicon.ico")); @@ -750,10 +751,23 @@ async fn main() -> std::io::Result<()> { .unwrap_or_else(|_| "http://localhost:8081".to_string()); info!("LLM URL: {}", llm_url); - let base_llm_provider = Arc::new(botserver::llm::OpenAIClient::new( - "empty".to_string(), - Some(llm_url.clone()), - )) as Arc; + let llm_model = config_manager + .get_config(&default_bot_id, "llm-model", Some("")) + .unwrap_or_default(); + if !llm_model.is_empty() { + info!("LLM Model: {}", llm_model); + } + + let _llm_key = config_manager + .get_config(&default_bot_id, "llm-key", Some("")) + .unwrap_or_default(); + + let base_llm_provider = botserver::llm::create_llm_provider_from_url( + &llm_url, + if llm_model.is_empty() { None } else { Some(llm_model.clone()) }, + ); + + let dynamic_llm_provider = Arc::new(botserver::llm::DynamicLLMProvider::new(base_llm_provider)); let llm_provider: Arc = if let Some(ref cache) = redis_client { let embedding_url = config_manager @@ -784,14 +798,14 @@ async fn main() -> std::io::Result<()> { }; Arc::new(botserver::llm::cache::CachedLLMProvider::with_db_pool( - base_llm_provider, + dynamic_llm_provider.clone() as Arc, cache.clone(), cache_config, embedding_service, pool.clone(), )) } else { - base_llm_provider + dynamic_llm_provider.clone() as Arc }; let kb_manager = Arc::new(botserver::core::kb::KnowledgeBaseManager::new("work")); @@ -833,7 +847,11 @@ async fn main() -> std::io::Result<()> { voice_adapter: voice_adapter.clone(), kb_manager: Some(kb_manager.clone()), task_engine, - extensions: botserver::core::shared::state::Extensions::new(), + extensions: { + let ext = botserver::core::shared::state::Extensions::new(); + ext.insert_blocking(Arc::clone(&dynamic_llm_provider)); + ext + }, attendant_broadcast: Some(attendant_tx), }); @@ -868,6 +886,24 @@ async fn main() -> std::io::Result<()> { error!("Failed to mount bots: {}", e); } + #[cfg(feature = "drive")] + { + let drive_monitor_state = app_state.clone(); + let bucket_name = "default.gbai".to_string(); + let monitor_bot_id = default_bot_id; + tokio::spawn(async move { + let monitor = botserver::DriveMonitor::new( + drive_monitor_state, + bucket_name.clone(), + monitor_bot_id, + ); + info!("Starting DriveMonitor for bucket: {}", bucket_name); + if let Err(e) = monitor.start_monitoring().await { + error!("DriveMonitor failed: {}", e); + } + }); + } + let automation_state = app_state.clone(); tokio::spawn(async move { let automation = AutomationService::new(automation_state);