Fix organizations foreign key reference (org_id not id)

This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2025-12-29 08:07:42 -03:00
parent 4fdad88333
commit 38f9abb7db
9 changed files with 794 additions and 46 deletions

View file

@ -2051,7 +2051,7 @@ COMMENT ON TABLE public.system_automations IS 'System automations with TriggerKi
-- User organization memberships (users can belong to multiple orgs) -- User organization memberships (users can belong to multiple orgs)
CREATE TABLE IF NOT EXISTS public.user_organizations ( CREATE TABLE IF NOT EXISTS public.user_organizations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES public.users(user_id) ON DELETE CASCADE, user_id UUID NOT NULL REFERENCES public.users(id) ON DELETE CASCADE,
org_id UUID NOT NULL REFERENCES public.organizations(org_id) ON DELETE CASCADE, org_id UUID NOT NULL REFERENCES public.organizations(org_id) ON DELETE CASCADE,
role VARCHAR(50) DEFAULT 'member', -- 'owner', 'admin', 'member', 'viewer' role VARCHAR(50) DEFAULT 'member', -- 'owner', 'admin', 'member', 'viewer'
is_default BOOLEAN DEFAULT false, is_default BOOLEAN DEFAULT false,

View file

@ -14,7 +14,7 @@ use uuid::Uuid;
pub type Config = AppConfig; pub type Config = AppConfig;
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct AppConfig { pub struct AppConfig {
pub drive: DriveConfig, pub drive: DriveConfig,
pub server: ServerConfig, pub server: ServerConfig,
@ -22,19 +22,19 @@ pub struct AppConfig {
pub site_path: String, pub site_path: String,
pub data_dir: String, pub data_dir: String,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct DriveConfig { pub struct DriveConfig {
pub server: String, pub server: String,
pub access_key: String, pub access_key: String,
pub secret_key: String, pub secret_key: String,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct ServerConfig { pub struct ServerConfig {
pub host: String, pub host: String,
pub port: u16, pub port: u16,
pub base_url: String, pub base_url: String,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct EmailConfig { pub struct EmailConfig {
pub server: String, pub server: String,
pub port: u16, pub port: u16,

View file

@ -252,7 +252,7 @@ impl PackageManager {
]), ]),
data_download_list: Vec::new(), data_download_list: Vec::new(),
exec_cmd: "nohup {{BIN_PATH}}/minio server {{DATA_PATH}} --address :9000 --console-address :9001 > {{LOGS_PATH}}/minio.log 2>&1 &".to_string(), exec_cmd: "nohup {{BIN_PATH}}/minio server {{DATA_PATH}} --address :9000 --console-address :9001 > {{LOGS_PATH}}/minio.log 2>&1 &".to_string(),
check_cmd: "pgrep -f 'minio server' >/dev/null 2>&1".to_string(), check_cmd: "curl -sf http://127.0.0.1:9000/minio/health/live >/dev/null 2>&1".to_string(),
}, },
); );
} }
@ -1162,44 +1162,71 @@ EOF"#.to_string(),
fn fetch_vault_credentials() -> HashMap<String, String> { fn fetch_vault_credentials() -> HashMap<String, String> {
let mut credentials = HashMap::new(); let mut credentials = HashMap::new();
dotenvy::dotenv().ok();
let vault_addr = let vault_addr =
std::env::var("VAULT_ADDR").unwrap_or_else(|_| "http://localhost:8200".to_string()); std::env::var("VAULT_ADDR").unwrap_or_else(|_| "http://localhost:8200".to_string());
let vault_token = std::env::var("VAULT_TOKEN").unwrap_or_default(); let vault_token = std::env::var("VAULT_TOKEN").unwrap_or_default();
if vault_token.is_empty() { if vault_token.is_empty() {
trace!("VAULT_TOKEN not set, skipping Vault credential fetch"); warn!("VAULT_TOKEN not set, cannot fetch credentials from Vault");
return credentials; return credentials;
} }
if let Ok(output) = std::process::Command::new("sh") let base_path = std::env::var("BOTSERVER_STACK_PATH")
.map(std::path::PathBuf::from)
.unwrap_or_else(|_| {
std::env::current_dir()
.unwrap_or_else(|_| std::path::PathBuf::from("."))
.join("botserver-stack")
});
let vault_bin = base_path.join("bin/vault/vault");
let vault_bin_str = vault_bin.to_string_lossy();
info!("Fetching drive credentials from Vault at {} using {}", vault_addr, vault_bin_str);
let drive_cmd = format!(
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} {} kv get -format=json secret/gbo/drive",
vault_addr, vault_token, vault_bin_str
);
match std::process::Command::new("sh")
.arg("-c") .arg("-c")
.arg(format!( .arg(&drive_cmd)
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} ./botserver-stack/bin/vault/vault kv get -format=json secret/gbo/drive 2>/dev/null",
vault_addr, vault_token
))
.output() .output()
{ {
if output.status.success() { Ok(output) => {
if let Ok(json_str) = String::from_utf8(output.stdout) { if output.status.success() {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&json_str) { let json_str = String::from_utf8_lossy(&output.stdout);
if let Some(data) = json.get("data").and_then(|d| d.get("data")) { info!("Vault drive response: {}", json_str);
if let Some(accesskey) = data.get("accesskey").and_then(|v| v.as_str()) { match serde_json::from_str::<serde_json::Value>(&json_str) {
credentials.insert("DRIVE_ACCESSKEY".to_string(), accesskey.to_string()); Ok(json) => {
} if let Some(data) = json.get("data").and_then(|d| d.get("data")) {
if let Some(secret) = data.get("secret").and_then(|v| v.as_str()) { if let Some(accesskey) = data.get("accesskey").and_then(|v| v.as_str()) {
credentials.insert("DRIVE_SECRET".to_string(), secret.to_string()); info!("Found DRIVE_ACCESSKEY from Vault");
credentials.insert("DRIVE_ACCESSKEY".to_string(), accesskey.to_string());
}
if let Some(secret) = data.get("secret").and_then(|v| v.as_str()) {
info!("Found DRIVE_SECRET from Vault");
credentials.insert("DRIVE_SECRET".to_string(), secret.to_string());
}
} else {
warn!("Vault response missing data.data field");
} }
} }
Err(e) => warn!("Failed to parse Vault JSON: {}", e),
} }
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
warn!("Vault drive command failed: {}", stderr);
} }
} }
Err(e) => warn!("Failed to execute Vault command: {}", e),
} }
if let Ok(output) = std::process::Command::new("sh") if let Ok(output) = std::process::Command::new("sh")
.arg("-c") .arg("-c")
.arg(format!( .arg(format!(
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} ./botserver-stack/bin/vault/vault kv get -format=json secret/gbo/cache 2>/dev/null", "unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} {} kv get -format=json secret/gbo/cache 2>/dev/null",
vault_addr, vault_token vault_addr, vault_token, vault_bin_str
)) ))
.output() .output()
{ {

View file

@ -189,6 +189,7 @@ impl DriveMonitor {
} }
async fn check_gbot(&self, client: &Client) -> Result<(), Box<dyn Error + Send + Sync>> { async fn check_gbot(&self, client: &Client) -> Result<(), Box<dyn Error + Send + Sync>> {
let config_manager = ConfigManager::new(self.state.conn.clone()); let config_manager = ConfigManager::new(self.state.conn.clone());
debug!("check_gbot: Checking bucket {} for config.csv changes", self.bucket_name);
let mut continuation_token = None; let mut continuation_token = None;
loop { loop {
let list_objects = match tokio::time::timeout( let list_objects = match tokio::time::timeout(
@ -202,27 +203,28 @@ impl DriveMonitor {
.await .await
{ {
Ok(Ok(list)) => list, Ok(Ok(list)) => list,
Ok(Err(e)) => return Err(e.into()), Ok(Err(e)) => {
error!("check_gbot: Failed to list objects in bucket {}: {}", self.bucket_name, e);
return Err(e.into());
}
Err(_) => { Err(_) => {
log::error!("Timeout listing objects in bucket {}", self.bucket_name); error!("Timeout listing objects in bucket {}", self.bucket_name);
return Ok(()); return Ok(());
} }
}; };
for obj in list_objects.contents.unwrap_or_default() { for obj in list_objects.contents.unwrap_or_default() {
let path = obj.key().unwrap_or_default().to_string(); let path = obj.key().unwrap_or_default().to_string();
let path_parts: Vec<&str> = path.split('/').collect(); let path_lower = path.to_ascii_lowercase();
if path_parts.len() < 2
|| !std::path::Path::new(path_parts[0]) let is_config_csv = path_lower == "config.csv"
.extension() || path_lower.ends_with("/config.csv")
.is_some_and(|ext| ext.eq_ignore_ascii_case("gbot")) || path_lower.contains(".gbot/config.csv");
{
continue; if !is_config_csv {
}
if !path.eq_ignore_ascii_case("config.csv")
&& !path.to_ascii_lowercase().ends_with("/config.csv")
{
continue; continue;
} }
debug!("check_gbot: Found config.csv at path: {}", path);
match client match client
.head_object() .head_object()
.bucket(&self.bucket_name) .bucket(&self.bucket_name)
@ -248,12 +250,22 @@ impl DriveMonitor {
let _ = config_manager.sync_gbot_config(&self.bot_id, &csv_content); let _ = config_manager.sync_gbot_config(&self.bot_id, &csv_content);
} else { } else {
use crate::llm::local::ensure_llama_servers_running; use crate::llm::local::ensure_llama_servers_running;
use crate::llm::DynamicLLMProvider;
let mut restart_needed = false; let mut restart_needed = false;
for line in llm_lines { let mut llm_url_changed = false;
let mut new_llm_url = String::new();
let mut new_llm_model = String::new();
for line in &llm_lines {
let parts: Vec<&str> = line.split(',').collect(); let parts: Vec<&str> = line.split(',').collect();
if parts.len() >= 2 { if parts.len() >= 2 {
let key = parts[0].trim(); let key = parts[0].trim();
let new_value = parts[1].trim(); let new_value = parts[1].trim();
if key == "llm-url" {
new_llm_url = new_value.to_string();
}
if key == "llm-model" {
new_llm_model = new_value.to_string();
}
match config_manager.get_config(&self.bot_id, key, None) { match config_manager.get_config(&self.bot_id, key, None) {
Ok(old_value) => { Ok(old_value) => {
if old_value != new_value { if old_value != new_value {
@ -262,10 +274,16 @@ impl DriveMonitor {
key, old_value, new_value key, old_value, new_value
); );
restart_needed = true; restart_needed = true;
if key == "llm-url" || key == "llm-model" {
llm_url_changed = true;
}
} }
} }
Err(_) => { Err(_) => {
restart_needed = true; restart_needed = true;
if key == "llm-url" || key == "llm-model" {
llm_url_changed = true;
}
} }
} }
} }
@ -278,6 +296,28 @@ impl DriveMonitor {
log::error!("Failed to restart LLaMA servers after llm- config change: {}", e); log::error!("Failed to restart LLaMA servers after llm- config change: {}", e);
} }
} }
if llm_url_changed {
info!("check_gbot: LLM config changed, updating provider...");
let effective_url = if new_llm_url.is_empty() {
config_manager.get_config(&self.bot_id, "llm-url", None).unwrap_or_default()
} else {
new_llm_url
};
info!("check_gbot: Effective LLM URL: {}", effective_url);
if !effective_url.is_empty() {
if let Some(dynamic_provider) = self.state.extensions.get::<Arc<DynamicLLMProvider>>().await {
let model = if new_llm_model.is_empty() { None } else { Some(new_llm_model.clone()) };
dynamic_provider.update_from_config(&effective_url, model).await;
info!("Updated LLM provider to use URL: {}, model: {:?}", effective_url, new_llm_model);
} else {
error!("DynamicLLMProvider not found in extensions, LLM provider cannot be updated dynamically");
}
} else {
error!("check_gbot: No llm-url found in config, cannot update provider");
}
} else {
debug!("check_gbot: No LLM config changes detected");
}
} }
if csv_content.lines().any(|line| line.starts_with("theme-")) { if csv_content.lines().any(|line| line.starts_with("theme-")) {
self.broadcast_theme_change(&csv_content).await?; self.broadcast_theme_change(&csv_content).await?;

View file

@ -2,7 +2,7 @@ use crate::{config::EmailConfig, core::urls::ApiUrls, shared::state::AppState};
use axum::{ use axum::{
extract::{Path, Query, State}, extract::{Path, Query, State},
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{Html, IntoResponse, Response},
Json, Json,
}; };
use axum::{ use axum::{

View file

@ -63,6 +63,8 @@ pub mod instagram;
pub mod llm; pub mod llm;
#[cfg(feature = "llm")] #[cfg(feature = "llm")]
pub use llm::cache::{CacheConfig, CachedLLMProvider, CachedResponse, LocalEmbeddingService}; pub use llm::cache::{CacheConfig, CachedLLMProvider, CachedResponse, LocalEmbeddingService};
#[cfg(feature = "llm")]
pub use llm::DynamicLLMProvider;
#[cfg(feature = "meet")] #[cfg(feature = "meet")]
pub mod meet; pub mod meet;

530
src/llm/claude.rs Normal file
View file

@ -0,0 +1,530 @@
use async_trait::async_trait;
use futures::StreamExt;
use log::{info, trace, warn};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc;
use super::{llm_models::get_handler, LLMProvider};
/// A single chat turn in the Anthropic Messages API wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeMessage {
    /// Author role: "user" or "assistant" (system text travels separately).
    pub role: String,
    /// Plain-text message content.
    pub content: String,
}
/// Request body for the Anthropic Messages API (`POST …/messages`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeRequest {
    /// Model name (Anthropic) or deployment name (Azure) to invoke.
    pub model: String,
    /// Hard cap on tokens generated in the reply.
    pub max_tokens: u32,
    /// Conversation turns; system text goes in `system`, not here.
    pub messages: Vec<ClaudeMessage>,
    /// Optional system prompt, sent as a dedicated top-level field.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// `Some(true)` requests server-sent-event streaming; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}
/// One content block of a Claude reply; only blocks with `type == "text"`
/// carry usable text here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeContentBlock {
    /// Block kind as reported by the API (e.g. "text").
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload; defaults to empty for non-text block kinds.
    #[serde(default)]
    pub text: String,
}
/// Top-level non-streaming response from the Messages API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeResponse {
    /// Server-assigned message id (e.g. "msg_…").
    pub id: String,
    /// Response object type as reported by the API.
    #[serde(rename = "type")]
    pub response_type: String,
    /// Author role of the reply.
    pub role: String,
    /// Ordered content blocks making up the reply.
    pub content: Vec<ClaudeContentBlock>,
    /// Model that actually served the request.
    pub model: String,
    /// Why generation stopped (e.g. "end_turn"), when provided.
    #[serde(default)]
    pub stop_reason: Option<String>,
}
/// Incremental payload carried inside a streaming event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeStreamDelta {
    /// Delta kind; text arrives as "text_delta".
    #[serde(rename = "type")]
    pub delta_type: String,
    /// Newly generated text fragment; defaults to empty.
    #[serde(default)]
    pub text: String,
}
/// One server-sent event from a streaming Messages request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClaudeStreamEvent {
    /// Event kind; text deltas arrive as "content_block_delta".
    #[serde(rename = "type")]
    pub event_type: String,
    /// Delta payload, present on content_block_delta events.
    #[serde(default)]
    pub delta: Option<ClaudeStreamDelta>,
    /// Index of the content block this event applies to, if any.
    #[serde(default)]
    pub index: Option<u32>,
}
/// HTTP client for Claude, supporting both the public Anthropic API and
/// Azure-hosted deployments (different URL layout and auth header).
#[derive(Debug)]
pub struct ClaudeClient {
    /// Shared reqwest client (connection pooling).
    client: reqwest::Client,
    /// API root, e.g. "https://api.anthropic.com" or an Azure endpoint.
    base_url: String,
    /// Fallback model/deployment name used when the caller passes an
    /// empty model string.
    deployment_name: String,
    /// Selects the Azure URL scheme and the "api-key" auth header.
    is_azure: bool,
}
impl ClaudeClient {
    /// Builds a client from a base URL, auto-detecting Azure-hosted
    /// endpoints from the host name. The deployment name defaults to
    /// "claude-opus-4-5" when none is supplied.
    pub fn new(base_url: String, deployment_name: Option<String>) -> Self {
        // "openai.azure.com" already contains "azure.com", so a single
        // substring check covers both spellings the old code tested for.
        let is_azure = base_url.contains("azure.com");
        Self {
            client: reqwest::Client::new(),
            base_url,
            deployment_name: deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string()),
            is_azure,
        }
    }

    /// Builds a client that explicitly targets an Azure deployment,
    /// bypassing URL-based detection.
    pub fn azure(endpoint: String, deployment_name: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: endpoint,
            deployment_name,
            is_azure: true,
        }
    }

    /// Returns the Messages endpoint URL: Azure uses a per-deployment path
    /// plus an `api-version` query; Anthropic uses the fixed `/v1/messages`.
    fn build_url(&self) -> String {
        if self.is_azure {
            format!(
                "{}/deployments/{}/messages?api-version=2024-06-01",
                self.base_url.trim_end_matches('/'),
                self.deployment_name
            )
        } else {
            format!("{}/v1/messages", self.base_url.trim_end_matches('/'))
        }
    }

    /// Builds auth and content-type headers. Azure authenticates with
    /// `api-key`; Anthropic uses `x-api-key` plus `anthropic-version`.
    /// Values that fail to parse as header values are silently skipped.
    fn build_headers(&self, api_key: &str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        if self.is_azure {
            if let Ok(val) = api_key.parse() {
                headers.insert("api-key", val);
            }
        } else {
            if let Ok(val) = api_key.parse() {
                headers.insert("x-api-key", val);
            }
            if let Ok(val) = "2023-06-01".parse() {
                headers.insert("anthropic-version", val);
            }
        }
        if let Ok(val) = "application/json".parse() {
            headers.insert(reqwest::header::CONTENT_TYPE, val);
        }
        headers
    }

    /// Assembles a `(system, messages)` pair from a system prompt, extra
    /// context, and `(role, content)` history tuples. The two system parts
    /// are joined with a blank line; empty parts are omitted, and an empty
    /// result yields `None`.
    pub fn build_messages(
        system_prompt: &str,
        context_data: &str,
        history: &[(String, String)],
    ) -> (Option<String>, Vec<ClaudeMessage>) {
        let mut system_parts = Vec::new();
        if !system_prompt.is_empty() {
            system_parts.push(system_prompt.to_string());
        }
        if !context_data.is_empty() {
            system_parts.push(context_data.to_string());
        }
        let system = if system_parts.is_empty() {
            None
        } else {
            Some(system_parts.join("\n\n"))
        };
        let messages: Vec<ClaudeMessage> = history
            .iter()
            .map(|(role, content)| ClaudeMessage {
                role: role.clone(),
                content: content.clone(),
            })
            .collect();
        (system, messages)
    }

    /// Concatenates the `text` of all `"text"` content blocks, skipping any
    /// other block kinds the API may return.
    fn extract_text_from_response(&self, response: &ClaudeResponse) -> String {
        // Collecting `&str`s straight into a String avoids the per-block
        // clone and the intermediate Vec the previous version allocated.
        response
            .content
            .iter()
            .filter(|block| block.content_type == "text")
            .map(|block| block.text.as_str())
            .collect()
    }
}
/// Converts a generic `messages` JSON array into the Anthropic wire format.
///
/// System-role entries are pulled out and concatenated (blank-line
/// separated) into the dedicated `system` field; every other entry becomes
/// a [`ClaudeMessage`]. When `messages` is not an array or is empty, the
/// raw `prompt` is wrapped as a single user message. This replaces the
/// identical logic that was previously duplicated in both `generate` and
/// `generate_stream`.
fn to_claude_payload(prompt: &str, messages: &Value) -> (Option<String>, Vec<ClaudeMessage>) {
    let prompt_only = || {
        vec![ClaudeMessage {
            role: "user".to_string(),
            content: prompt.to_string(),
        }]
    };
    let arr = match messages.as_array() {
        Some(arr) => arr,
        None => return (None, prompt_only()),
    };
    let claude_messages = if arr.is_empty() {
        prompt_only()
    } else {
        arr.iter()
            .filter_map(|m| {
                let role = m["role"].as_str().unwrap_or("user");
                // System text is carried in the top-level `system` field.
                if role == "system" {
                    return None;
                }
                Some(ClaudeMessage {
                    role: role.to_string(),
                    content: m["content"].as_str().unwrap_or("").to_string(),
                })
            })
            .collect()
    };
    let system_text = arr
        .iter()
        .filter(|m| m["role"].as_str() == Some("system"))
        .map(|m| m["content"].as_str().unwrap_or(""))
        .collect::<Vec<_>>()
        .join("\n\n");
    let system = if system_text.is_empty() {
        None
    } else {
        Some(system_text)
    };
    (system, claude_messages)
}

#[async_trait]
impl LLMProvider for ClaudeClient {
    /// Sends one non-streaming Messages request and returns the full reply
    /// text after model-specific post-processing.
    async fn generate(
        &self,
        prompt: &str,
        messages: &Value,
        model: &str,
        key: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let url = self.build_url();
        let headers = self.build_headers(key);
        // An empty model argument falls back to the configured deployment.
        let model_name = if model.is_empty() {
            &self.deployment_name
        } else {
            model
        };
        let (system, claude_messages) = to_claude_payload(prompt, messages);
        let request = ClaudeRequest {
            model: model_name.to_string(),
            max_tokens: 4096,
            messages: claude_messages,
            system,
            stream: None,
        };
        info!("Claude request to {}: model={}", url, model_name);
        trace!("Claude request body: {:?}", serde_json::to_string(&request));
        let response = self
            .client
            .post(&url)
            .headers(headers)
            .json(&request)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            warn!("Claude API error ({}): {}", status, error_text);
            return Err(format!("Claude API error ({}): {}", status, error_text).into());
        }
        let result: ClaudeResponse = response.json().await?;
        let raw_content = self.extract_text_from_response(&result);
        // Model-specific post-processing (e.g. stripping reasoning markup).
        let handler = get_handler(model_name);
        Ok(handler.process_content(&raw_content))
    }

    /// Streams a Messages request, forwarding each processed text delta to
    /// `tx` as soon as it arrives.
    async fn generate_stream(
        &self,
        prompt: &str,
        messages: &Value,
        tx: mpsc::Sender<String>,
        model: &str,
        key: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let url = self.build_url();
        let headers = self.build_headers(key);
        let model_name = if model.is_empty() {
            &self.deployment_name
        } else {
            model
        };
        let (system, claude_messages) = to_claude_payload(prompt, messages);
        let request = ClaudeRequest {
            model: model_name.to_string(),
            max_tokens: 4096,
            messages: claude_messages,
            system,
            stream: Some(true),
        };
        info!("Claude streaming request to {}: model={}", url, model_name);
        let response = self
            .client
            .post(&url)
            .headers(headers)
            .json(&request)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            warn!("Claude streaming API error ({}): {}", status, error_text);
            return Err(format!("Claude streaming API error ({}): {}", status, error_text).into());
        }
        let handler = get_handler(model_name);
        let mut stream = response.bytes_stream();
        // SSE events can be split across network chunks. The previous code
        // decoded and line-split each chunk independently, silently dropping
        // any event whose JSON straddled a chunk boundary (and mangling
        // multibyte UTF-8 split mid-character). Buffer raw bytes and parse
        // only complete, newline-terminated lines.
        let mut pending: Vec<u8> = Vec::new();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            pending.extend_from_slice(&chunk);
            while let Some(pos) = pending.iter().position(|&b| b == b'\n') {
                let line_bytes: Vec<u8> = pending.drain(..=pos).collect();
                let line = String::from_utf8_lossy(&line_bytes);
                let line = line.trim();
                let data = match line.strip_prefix("data: ") {
                    Some(data) => data,
                    None => continue,
                };
                // OpenAI-style terminator; Anthropic just closes the stream,
                // so this only guards against lookalike proxies.
                if data == "[DONE]" {
                    continue;
                }
                if let Ok(event) = serde_json::from_str::<ClaudeStreamEvent>(data) {
                    if event.event_type == "content_block_delta" {
                        if let Some(delta) = event.delta {
                            if delta.delta_type == "text_delta" && !delta.text.is_empty() {
                                let processed = handler.process_content(&delta.text);
                                if !processed.is_empty() {
                                    // Receiver may have hung up; dropping the
                                    // rest of the stream is acceptable.
                                    let _ = tx.send(processed).await;
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// No-op: the Messages API has no server-side job cancellation;
    /// dropping the response stream is sufficient.
    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_claude_client_new() {
        let c = ClaudeClient::new(
            "https://api.anthropic.com".to_string(),
            Some("claude-3-opus".to_string()),
        );
        assert_eq!(c.deployment_name, "claude-3-opus");
        assert!(!c.is_azure);
    }

    #[test]
    fn test_claude_client_azure() {
        let c = ClaudeClient::azure(
            "https://myendpoint.openai.azure.com/anthropic".to_string(),
            "claude-opus-4-5".to_string(),
        );
        assert_eq!(c.deployment_name, "claude-opus-4-5");
        assert!(c.is_azure);
    }

    #[test]
    fn test_build_url_azure() {
        // Azure URLs embed the deployment name and an api-version query.
        let url = ClaudeClient::azure(
            "https://myendpoint.openai.azure.com/anthropic".to_string(),
            "claude-opus-4-5".to_string(),
        )
        .build_url();
        assert!(url.contains("deployments/claude-opus-4-5/messages"));
        assert!(url.contains("api-version="));
    }

    #[test]
    fn test_build_url_anthropic() {
        let url = ClaudeClient::new("https://api.anthropic.com".to_string(), None).build_url();
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    #[test]
    fn test_build_messages_empty() {
        let (system, msgs) = ClaudeClient::build_messages("", "", &[]);
        assert!(msgs.is_empty());
        assert!(system.is_none());
    }

    #[test]
    fn test_build_messages_with_system() {
        let (system, msgs) =
            ClaudeClient::build_messages("You are a helpful assistant.", "", &[]);
        assert!(msgs.is_empty());
        assert_eq!(system.as_deref(), Some("You are a helpful assistant."));
    }

    #[test]
    fn test_build_messages_with_history() {
        let turns = vec![
            ("user".to_string(), "Hello".to_string()),
            ("assistant".to_string(), "Hi there!".to_string()),
        ];
        let (system, msgs) = ClaudeClient::build_messages("", "", &turns);
        assert!(system.is_none());
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, "user");
        assert_eq!(msgs[0].content, "Hello");
    }

    #[test]
    fn test_build_messages_full() {
        let turns = vec![("user".to_string(), "What is 2+2?".to_string())];
        let (system, msgs) = ClaudeClient::build_messages(
            "You are a math tutor.",
            "Focus on step-by-step explanations.",
            &turns,
        );
        let sys = system.expect("system prompt should be present");
        assert!(sys.contains("math tutor"));
        assert_eq!(msgs.len(), 1);
    }

    #[test]
    fn test_claude_request_serialization() {
        let req = ClaudeRequest {
            model: "claude-3-opus".to_string(),
            max_tokens: 4096,
            messages: vec![ClaudeMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            system: Some("Be helpful".to_string()),
            stream: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        for needle in ["claude-3-opus", "max_tokens", "Be helpful"] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn test_claude_response_deserialization() {
        let payload = r#"{
            "id": "msg_123",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": "Hello!"}],
            "model": "claude-3-opus",
            "stop_reason": "end_turn"
        }"#;
        let resp: ClaudeResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(resp.id, "msg_123");
        assert_eq!(resp.content.len(), 1);
        assert_eq!(resp.content[0].text, "Hello!");
    }
}

View file

@ -2,14 +2,17 @@ use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use log::{info, trace}; use log::{info, trace};
use serde_json::Value; use serde_json::Value;
use tokio::sync::mpsc; use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
pub mod cache; pub mod cache;
pub mod claude;
pub mod episodic_memory; pub mod episodic_memory;
pub mod llm_models; pub mod llm_models;
pub mod local; pub mod local;
pub mod observability; pub mod observability;
pub use claude::ClaudeClient;
pub use llm_models::get_handler; pub use llm_models::get_handler;
#[async_trait] #[async_trait]
@ -184,6 +187,116 @@ pub fn start_llm_services(state: &std::sync::Arc<crate::shared::state::AppState>
info!("LLM services started (episodic memory scheduler)"); info!("LLM services started (episodic memory scheduler)");
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LLMProviderType {
OpenAI,
Claude,
AzureClaude,
}
impl From<&str> for LLMProviderType {
fn from(s: &str) -> Self {
let lower = s.to_lowercase();
if lower.contains("claude") || lower.contains("anthropic") {
if lower.contains("azure") {
Self::AzureClaude
} else {
Self::Claude
}
} else {
Self::OpenAI
}
}
}
pub fn create_llm_provider(
provider_type: LLMProviderType,
base_url: String,
deployment_name: Option<String>,
) -> std::sync::Arc<dyn LLMProvider> {
match provider_type {
LLMProviderType::OpenAI => {
info!("Creating OpenAI LLM provider with URL: {}", base_url);
std::sync::Arc::new(OpenAIClient::new("empty".to_string(), Some(base_url)))
}
LLMProviderType::Claude => {
info!("Creating Claude LLM provider with URL: {}", base_url);
std::sync::Arc::new(ClaudeClient::new(base_url, deployment_name))
}
LLMProviderType::AzureClaude => {
let deployment = deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string());
info!(
"Creating Azure Claude LLM provider with URL: {}, deployment: {}",
base_url, deployment
);
std::sync::Arc::new(ClaudeClient::azure(base_url, deployment))
}
}
}
pub fn create_llm_provider_from_url(url: &str, model: Option<String>) -> std::sync::Arc<dyn LLMProvider> {
let provider_type = LLMProviderType::from(url);
create_llm_provider(provider_type, url.to_string(), model)
}
pub struct DynamicLLMProvider {
inner: RwLock<Arc<dyn LLMProvider>>,
}
impl DynamicLLMProvider {
pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
Self {
inner: RwLock::new(provider),
}
}
pub async fn update_provider(&self, new_provider: Arc<dyn LLMProvider>) {
let mut guard = self.inner.write().await;
*guard = new_provider;
info!("LLM provider updated dynamically");
}
pub async fn update_from_config(&self, url: &str, model: Option<String>) {
let new_provider = create_llm_provider_from_url(url, model);
self.update_provider(new_provider).await;
}
async fn get_provider(&self) -> Arc<dyn LLMProvider> {
self.inner.read().await.clone()
}
}
#[async_trait]
impl LLMProvider for DynamicLLMProvider {
async fn generate(
&self,
prompt: &str,
config: &Value,
model: &str,
key: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
self.get_provider().await.generate(prompt, config, model, key).await
}
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
model: &str,
key: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.get_provider().await.generate_stream(prompt, config, tx, model, key).await
}
async fn cancel_job(
&self,
session_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self.get_provider().await.cancel_job(session_id).await
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -185,6 +185,7 @@ async fn run_axum_server(
.add_anonymous_path("/api/v1/health") .add_anonymous_path("/api/v1/health")
.add_anonymous_path("/ws") .add_anonymous_path("/ws")
.add_anonymous_path("/auth") .add_anonymous_path("/auth")
.add_anonymous_path("/api/auth")
.add_public_path("/static") .add_public_path("/static")
.add_public_path("/favicon.ico")); .add_public_path("/favicon.ico"));
@ -750,10 +751,23 @@ async fn main() -> std::io::Result<()> {
.unwrap_or_else(|_| "http://localhost:8081".to_string()); .unwrap_or_else(|_| "http://localhost:8081".to_string());
info!("LLM URL: {}", llm_url); info!("LLM URL: {}", llm_url);
let base_llm_provider = Arc::new(botserver::llm::OpenAIClient::new( let llm_model = config_manager
"empty".to_string(), .get_config(&default_bot_id, "llm-model", Some(""))
Some(llm_url.clone()), .unwrap_or_default();
)) as Arc<dyn botserver::llm::LLMProvider>; if !llm_model.is_empty() {
info!("LLM Model: {}", llm_model);
}
let _llm_key = config_manager
.get_config(&default_bot_id, "llm-key", Some(""))
.unwrap_or_default();
let base_llm_provider = botserver::llm::create_llm_provider_from_url(
&llm_url,
if llm_model.is_empty() { None } else { Some(llm_model.clone()) },
);
let dynamic_llm_provider = Arc::new(botserver::llm::DynamicLLMProvider::new(base_llm_provider));
let llm_provider: Arc<dyn botserver::llm::LLMProvider> = if let Some(ref cache) = redis_client { let llm_provider: Arc<dyn botserver::llm::LLMProvider> = if let Some(ref cache) = redis_client {
let embedding_url = config_manager let embedding_url = config_manager
@ -784,14 +798,14 @@ async fn main() -> std::io::Result<()> {
}; };
Arc::new(botserver::llm::cache::CachedLLMProvider::with_db_pool( Arc::new(botserver::llm::cache::CachedLLMProvider::with_db_pool(
base_llm_provider, dynamic_llm_provider.clone() as Arc<dyn botserver::llm::LLMProvider>,
cache.clone(), cache.clone(),
cache_config, cache_config,
embedding_service, embedding_service,
pool.clone(), pool.clone(),
)) ))
} else { } else {
base_llm_provider dynamic_llm_provider.clone() as Arc<dyn botserver::llm::LLMProvider>
}; };
let kb_manager = Arc::new(botserver::core::kb::KnowledgeBaseManager::new("work")); let kb_manager = Arc::new(botserver::core::kb::KnowledgeBaseManager::new("work"));
@ -833,7 +847,11 @@ async fn main() -> std::io::Result<()> {
voice_adapter: voice_adapter.clone(), voice_adapter: voice_adapter.clone(),
kb_manager: Some(kb_manager.clone()), kb_manager: Some(kb_manager.clone()),
task_engine, task_engine,
extensions: botserver::core::shared::state::Extensions::new(), extensions: {
let ext = botserver::core::shared::state::Extensions::new();
ext.insert_blocking(Arc::clone(&dynamic_llm_provider));
ext
},
attendant_broadcast: Some(attendant_tx), attendant_broadcast: Some(attendant_tx),
}); });
@ -868,6 +886,24 @@ async fn main() -> std::io::Result<()> {
error!("Failed to mount bots: {}", e); error!("Failed to mount bots: {}", e);
} }
#[cfg(feature = "drive")]
{
let drive_monitor_state = app_state.clone();
let bucket_name = "default.gbai".to_string();
let monitor_bot_id = default_bot_id;
tokio::spawn(async move {
let monitor = botserver::DriveMonitor::new(
drive_monitor_state,
bucket_name.clone(),
monitor_bot_id,
);
info!("Starting DriveMonitor for bucket: {}", bucket_name);
if let Err(e) = monitor.start_monitoring().await {
error!("DriveMonitor failed: {}", e);
}
});
}
let automation_state = app_state.clone(); let automation_state = app_state.clone();
tokio::spawn(async move { tokio::spawn(async move {
let automation = AutomationService::new(automation_state); let automation = AutomationService::new(automation_state);