Fix organizations foreign key reference (org_id not id)
This commit is contained in:
parent
4fdad88333
commit
38f9abb7db
9 changed files with 794 additions and 46 deletions
|
|
@ -2051,7 +2051,7 @@ COMMENT ON TABLE public.system_automations IS 'System automations with TriggerKi
|
|||
-- User organization memberships (users can belong to multiple orgs)
|
||||
CREATE TABLE IF NOT EXISTS public.user_organizations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES public.users(user_id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES public.users(id) ON DELETE CASCADE,
|
||||
org_id UUID NOT NULL REFERENCES public.organizations(org_id) ON DELETE CASCADE,
|
||||
role VARCHAR(50) DEFAULT 'member', -- 'owner', 'admin', 'member', 'viewer'
|
||||
is_default BOOLEAN DEFAULT false,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ use uuid::Uuid;
|
|||
|
||||
pub type Config = AppConfig;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct AppConfig {
|
||||
pub drive: DriveConfig,
|
||||
pub server: ServerConfig,
|
||||
|
|
@ -22,19 +22,19 @@ pub struct AppConfig {
|
|||
pub site_path: String,
|
||||
pub data_dir: String,
|
||||
}
|
||||
/// Settings for the drive service, as named by its fields: a server address
/// plus an access-key / secret-key pair.
///
/// `Default` yields empty strings for every field.
// NOTE(review): the derived `Debug` output includes `secret_key`; avoid
// logging this struct with `{:?}` where logs are not trusted.
#[derive(Clone, Debug, Default)]
pub struct DriveConfig {
    pub server: String,
    pub access_key: String,
    pub secret_key: String,
}
|
||||
/// Server settings: the `host` and `port` fields plus a `base_url` string.
///
/// `Default` yields empty strings and port `0`.
#[derive(Clone, Debug, Default)]
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub base_url: String,
}
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct EmailConfig {
|
||||
pub server: String,
|
||||
pub port: u16,
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ impl PackageManager {
|
|||
]),
|
||||
data_download_list: Vec::new(),
|
||||
exec_cmd: "nohup {{BIN_PATH}}/minio server {{DATA_PATH}} --address :9000 --console-address :9001 > {{LOGS_PATH}}/minio.log 2>&1 &".to_string(),
|
||||
check_cmd: "pgrep -f 'minio server' >/dev/null 2>&1".to_string(),
|
||||
check_cmd: "curl -sf http://127.0.0.1:9000/minio/health/live >/dev/null 2>&1".to_string(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
|
@ -1162,44 +1162,71 @@ EOF"#.to_string(),
|
|||
fn fetch_vault_credentials() -> HashMap<String, String> {
|
||||
let mut credentials = HashMap::new();
|
||||
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
let vault_addr =
|
||||
std::env::var("VAULT_ADDR").unwrap_or_else(|_| "http://localhost:8200".to_string());
|
||||
let vault_token = std::env::var("VAULT_TOKEN").unwrap_or_default();
|
||||
|
||||
if vault_token.is_empty() {
|
||||
trace!("VAULT_TOKEN not set, skipping Vault credential fetch");
|
||||
warn!("VAULT_TOKEN not set, cannot fetch credentials from Vault");
|
||||
return credentials;
|
||||
}
|
||||
|
||||
if let Ok(output) = std::process::Command::new("sh")
|
||||
let base_path = std::env::var("BOTSERVER_STACK_PATH")
|
||||
.map(std::path::PathBuf::from)
|
||||
.unwrap_or_else(|_| {
|
||||
std::env::current_dir()
|
||||
.unwrap_or_else(|_| std::path::PathBuf::from("."))
|
||||
.join("botserver-stack")
|
||||
});
|
||||
let vault_bin = base_path.join("bin/vault/vault");
|
||||
let vault_bin_str = vault_bin.to_string_lossy();
|
||||
|
||||
info!("Fetching drive credentials from Vault at {} using {}", vault_addr, vault_bin_str);
|
||||
let drive_cmd = format!(
|
||||
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} {} kv get -format=json secret/gbo/drive",
|
||||
vault_addr, vault_token, vault_bin_str
|
||||
);
|
||||
match std::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(format!(
|
||||
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} ./botserver-stack/bin/vault/vault kv get -format=json secret/gbo/drive 2>/dev/null",
|
||||
vault_addr, vault_token
|
||||
))
|
||||
.arg(&drive_cmd)
|
||||
.output()
|
||||
{
|
||||
if output.status.success() {
|
||||
if let Ok(json_str) = String::from_utf8(output.stdout) {
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&json_str) {
|
||||
if let Some(data) = json.get("data").and_then(|d| d.get("data")) {
|
||||
if let Some(accesskey) = data.get("accesskey").and_then(|v| v.as_str()) {
|
||||
credentials.insert("DRIVE_ACCESSKEY".to_string(), accesskey.to_string());
|
||||
}
|
||||
if let Some(secret) = data.get("secret").and_then(|v| v.as_str()) {
|
||||
credentials.insert("DRIVE_SECRET".to_string(), secret.to_string());
|
||||
Ok(output) => {
|
||||
if output.status.success() {
|
||||
let json_str = String::from_utf8_lossy(&output.stdout);
|
||||
info!("Vault drive response: {}", json_str);
|
||||
match serde_json::from_str::<serde_json::Value>(&json_str) {
|
||||
Ok(json) => {
|
||||
if let Some(data) = json.get("data").and_then(|d| d.get("data")) {
|
||||
if let Some(accesskey) = data.get("accesskey").and_then(|v| v.as_str()) {
|
||||
info!("Found DRIVE_ACCESSKEY from Vault");
|
||||
credentials.insert("DRIVE_ACCESSKEY".to_string(), accesskey.to_string());
|
||||
}
|
||||
if let Some(secret) = data.get("secret").and_then(|v| v.as_str()) {
|
||||
info!("Found DRIVE_SECRET from Vault");
|
||||
credentials.insert("DRIVE_SECRET".to_string(), secret.to_string());
|
||||
}
|
||||
} else {
|
||||
warn!("Vault response missing data.data field");
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to parse Vault JSON: {}", e),
|
||||
}
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
warn!("Vault drive command failed: {}", stderr);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to execute Vault command: {}", e),
|
||||
}
|
||||
|
||||
if let Ok(output) = std::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(format!(
|
||||
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} ./botserver-stack/bin/vault/vault kv get -format=json secret/gbo/cache 2>/dev/null",
|
||||
vault_addr, vault_token
|
||||
"unset VAULT_CLIENT_CERT VAULT_CLIENT_KEY VAULT_CACERT; VAULT_ADDR={} VAULT_TOKEN={} {} kv get -format=json secret/gbo/cache 2>/dev/null",
|
||||
vault_addr, vault_token, vault_bin_str
|
||||
))
|
||||
.output()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -189,6 +189,7 @@ impl DriveMonitor {
|
|||
}
|
||||
async fn check_gbot(&self, client: &Client) -> Result<(), Box<dyn Error + Send + Sync>> {
|
||||
let config_manager = ConfigManager::new(self.state.conn.clone());
|
||||
debug!("check_gbot: Checking bucket {} for config.csv changes", self.bucket_name);
|
||||
let mut continuation_token = None;
|
||||
loop {
|
||||
let list_objects = match tokio::time::timeout(
|
||||
|
|
@ -202,27 +203,28 @@ impl DriveMonitor {
|
|||
.await
|
||||
{
|
||||
Ok(Ok(list)) => list,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Ok(Err(e)) => {
|
||||
error!("check_gbot: Failed to list objects in bucket {}: {}", self.bucket_name, e);
|
||||
return Err(e.into());
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Timeout listing objects in bucket {}", self.bucket_name);
|
||||
error!("Timeout listing objects in bucket {}", self.bucket_name);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
for obj in list_objects.contents.unwrap_or_default() {
|
||||
let path = obj.key().unwrap_or_default().to_string();
|
||||
let path_parts: Vec<&str> = path.split('/').collect();
|
||||
if path_parts.len() < 2
|
||||
|| !std::path::Path::new(path_parts[0])
|
||||
.extension()
|
||||
.is_some_and(|ext| ext.eq_ignore_ascii_case("gbot"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if !path.eq_ignore_ascii_case("config.csv")
|
||||
&& !path.to_ascii_lowercase().ends_with("/config.csv")
|
||||
{
|
||||
let path_lower = path.to_ascii_lowercase();
|
||||
|
||||
let is_config_csv = path_lower == "config.csv"
|
||||
|| path_lower.ends_with("/config.csv")
|
||||
|| path_lower.contains(".gbot/config.csv");
|
||||
|
||||
if !is_config_csv {
|
||||
continue;
|
||||
}
|
||||
|
||||
debug!("check_gbot: Found config.csv at path: {}", path);
|
||||
match client
|
||||
.head_object()
|
||||
.bucket(&self.bucket_name)
|
||||
|
|
@ -248,12 +250,22 @@ impl DriveMonitor {
|
|||
let _ = config_manager.sync_gbot_config(&self.bot_id, &csv_content);
|
||||
} else {
|
||||
use crate::llm::local::ensure_llama_servers_running;
|
||||
use crate::llm::DynamicLLMProvider;
|
||||
let mut restart_needed = false;
|
||||
for line in llm_lines {
|
||||
let mut llm_url_changed = false;
|
||||
let mut new_llm_url = String::new();
|
||||
let mut new_llm_model = String::new();
|
||||
for line in &llm_lines {
|
||||
let parts: Vec<&str> = line.split(',').collect();
|
||||
if parts.len() >= 2 {
|
||||
let key = parts[0].trim();
|
||||
let new_value = parts[1].trim();
|
||||
if key == "llm-url" {
|
||||
new_llm_url = new_value.to_string();
|
||||
}
|
||||
if key == "llm-model" {
|
||||
new_llm_model = new_value.to_string();
|
||||
}
|
||||
match config_manager.get_config(&self.bot_id, key, None) {
|
||||
Ok(old_value) => {
|
||||
if old_value != new_value {
|
||||
|
|
@ -262,10 +274,16 @@ impl DriveMonitor {
|
|||
key, old_value, new_value
|
||||
);
|
||||
restart_needed = true;
|
||||
if key == "llm-url" || key == "llm-model" {
|
||||
llm_url_changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
restart_needed = true;
|
||||
if key == "llm-url" || key == "llm-model" {
|
||||
llm_url_changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -278,6 +296,28 @@ impl DriveMonitor {
|
|||
log::error!("Failed to restart LLaMA servers after llm- config change: {}", e);
|
||||
}
|
||||
}
|
||||
if llm_url_changed {
|
||||
info!("check_gbot: LLM config changed, updating provider...");
|
||||
let effective_url = if new_llm_url.is_empty() {
|
||||
config_manager.get_config(&self.bot_id, "llm-url", None).unwrap_or_default()
|
||||
} else {
|
||||
new_llm_url
|
||||
};
|
||||
info!("check_gbot: Effective LLM URL: {}", effective_url);
|
||||
if !effective_url.is_empty() {
|
||||
if let Some(dynamic_provider) = self.state.extensions.get::<Arc<DynamicLLMProvider>>().await {
|
||||
let model = if new_llm_model.is_empty() { None } else { Some(new_llm_model.clone()) };
|
||||
dynamic_provider.update_from_config(&effective_url, model).await;
|
||||
info!("Updated LLM provider to use URL: {}, model: {:?}", effective_url, new_llm_model);
|
||||
} else {
|
||||
error!("DynamicLLMProvider not found in extensions, LLM provider cannot be updated dynamically");
|
||||
}
|
||||
} else {
|
||||
error!("check_gbot: No llm-url found in config, cannot update provider");
|
||||
}
|
||||
} else {
|
||||
debug!("check_gbot: No LLM config changes detected");
|
||||
}
|
||||
}
|
||||
if csv_content.lines().any(|line| line.starts_with("theme-")) {
|
||||
self.broadcast_theme_change(&csv_content).await?;
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use crate::{config::EmailConfig, core::urls::ApiUrls, shared::state::AppState};
|
|||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
response::{Html, IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use axum::{
|
||||
|
|
|
|||
|
|
@ -63,6 +63,8 @@ pub mod instagram;
|
|||
pub mod llm;
|
||||
#[cfg(feature = "llm")]
|
||||
pub use llm::cache::{CacheConfig, CachedLLMProvider, CachedResponse, LocalEmbeddingService};
|
||||
#[cfg(feature = "llm")]
|
||||
pub use llm::DynamicLLMProvider;
|
||||
|
||||
#[cfg(feature = "meet")]
|
||||
pub mod meet;
|
||||
|
|
|
|||
530
src/llm/claude.rs
Normal file
530
src/llm/claude.rs
Normal file
|
|
@ -0,0 +1,530 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use log::{info, trace, warn};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::{llm_models::get_handler, LLMProvider};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeRequest {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
pub messages: Vec<ClaudeMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeContentBlock {
|
||||
#[serde(rename = "type")]
|
||||
pub content_type: String,
|
||||
#[serde(default)]
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<ClaudeContentBlock>,
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub stop_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeStreamDelta {
|
||||
#[serde(rename = "type")]
|
||||
pub delta_type: String,
|
||||
#[serde(default)]
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClaudeStreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: String,
|
||||
#[serde(default)]
|
||||
pub delta: Option<ClaudeStreamDelta>,
|
||||
#[serde(default)]
|
||||
pub index: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClaudeClient {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
deployment_name: String,
|
||||
is_azure: bool,
|
||||
}
|
||||
|
||||
impl ClaudeClient {
|
||||
pub fn new(base_url: String, deployment_name: Option<String>) -> Self {
|
||||
let is_azure = base_url.contains("azure.com") || base_url.contains("openai.azure.com");
|
||||
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
deployment_name: deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string()),
|
||||
is_azure,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn azure(endpoint: String, deployment_name: String) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url: endpoint,
|
||||
deployment_name,
|
||||
is_azure: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_url(&self) -> String {
|
||||
if self.is_azure {
|
||||
format!(
|
||||
"{}/deployments/{}/messages?api-version=2024-06-01",
|
||||
self.base_url.trim_end_matches('/'),
|
||||
self.deployment_name
|
||||
)
|
||||
} else {
|
||||
format!("{}/v1/messages", self.base_url.trim_end_matches('/'))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_headers(&self, api_key: &str) -> reqwest::header::HeaderMap {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
|
||||
if self.is_azure {
|
||||
if let Ok(val) = api_key.parse() {
|
||||
headers.insert("api-key", val);
|
||||
}
|
||||
} else {
|
||||
if let Ok(val) = api_key.parse() {
|
||||
headers.insert("x-api-key", val);
|
||||
}
|
||||
if let Ok(val) = "2023-06-01".parse() {
|
||||
headers.insert("anthropic-version", val);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(val) = "application/json".parse() {
|
||||
headers.insert(reqwest::header::CONTENT_TYPE, val);
|
||||
}
|
||||
|
||||
headers
|
||||
}
|
||||
|
||||
pub fn build_messages(
|
||||
system_prompt: &str,
|
||||
context_data: &str,
|
||||
history: &[(String, String)],
|
||||
) -> (Option<String>, Vec<ClaudeMessage>) {
|
||||
let mut system_parts = Vec::new();
|
||||
|
||||
if !system_prompt.is_empty() {
|
||||
system_parts.push(system_prompt.to_string());
|
||||
}
|
||||
if !context_data.is_empty() {
|
||||
system_parts.push(context_data.to_string());
|
||||
}
|
||||
|
||||
let system = if system_parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(system_parts.join("\n\n"))
|
||||
};
|
||||
|
||||
let messages: Vec<ClaudeMessage> = history
|
||||
.iter()
|
||||
.map(|(role, content)| ClaudeMessage {
|
||||
role: role.clone(),
|
||||
content: content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
(system, messages)
|
||||
}
|
||||
|
||||
fn extract_text_from_response(&self, response: &ClaudeResponse) -> String {
|
||||
response
|
||||
.content
|
||||
.iter()
|
||||
.filter(|block| block.content_type == "text")
|
||||
.map(|block| block.text.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for ClaudeClient {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let url = self.build_url();
|
||||
let headers = self.build_headers(key);
|
||||
|
||||
let model_name = if model.is_empty() {
|
||||
&self.deployment_name
|
||||
} else {
|
||||
model
|
||||
};
|
||||
|
||||
let empty_vec = vec![];
|
||||
let claude_messages: Vec<ClaudeMessage> = if messages.is_array() {
|
||||
let arr = messages.as_array().unwrap_or(&empty_vec);
|
||||
if arr.is_empty() {
|
||||
vec![ClaudeMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}]
|
||||
} else {
|
||||
arr.iter()
|
||||
.filter_map(|m| {
|
||||
let role = m["role"].as_str().unwrap_or("user");
|
||||
let content = m["content"].as_str().unwrap_or("");
|
||||
if role == "system" {
|
||||
None
|
||||
} else {
|
||||
Some(ClaudeMessage {
|
||||
role: role.to_string(),
|
||||
content: content.to_string(),
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![ClaudeMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}]
|
||||
};
|
||||
|
||||
let system_prompt: Option<String> = if messages.is_array() {
|
||||
messages
|
||||
.as_array()
|
||||
.unwrap_or(&empty_vec)
|
||||
.iter()
|
||||
.filter(|m| m["role"].as_str() == Some("system"))
|
||||
.map(|m| m["content"].as_str().unwrap_or("").to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n")
|
||||
.into()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let system = system_prompt.filter(|s| !s.is_empty());
|
||||
|
||||
let request = ClaudeRequest {
|
||||
model: model_name.to_string(),
|
||||
max_tokens: 4096,
|
||||
messages: claude_messages,
|
||||
system,
|
||||
stream: None,
|
||||
};
|
||||
|
||||
info!("Claude request to {}: model={}", url, model_name);
|
||||
trace!("Claude request body: {:?}", serde_json::to_string(&request));
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.headers(headers)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
warn!("Claude API error ({}): {}", status, error_text);
|
||||
return Err(format!("Claude API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let result: ClaudeResponse = response.json().await?;
|
||||
let raw_content = self.extract_text_from_response(&result);
|
||||
|
||||
let handler = get_handler(model_name);
|
||||
let content = handler.process_content(&raw_content);
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
messages: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let url = self.build_url();
|
||||
let headers = self.build_headers(key);
|
||||
|
||||
let model_name = if model.is_empty() {
|
||||
&self.deployment_name
|
||||
} else {
|
||||
model
|
||||
};
|
||||
|
||||
let empty_vec = vec![];
|
||||
let claude_messages: Vec<ClaudeMessage> = if messages.is_array() {
|
||||
let arr = messages.as_array().unwrap_or(&empty_vec);
|
||||
if arr.is_empty() {
|
||||
vec![ClaudeMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}]
|
||||
} else {
|
||||
arr.iter()
|
||||
.filter_map(|m| {
|
||||
let role = m["role"].as_str().unwrap_or("user");
|
||||
let content = m["content"].as_str().unwrap_or("");
|
||||
if role == "system" {
|
||||
None
|
||||
} else {
|
||||
Some(ClaudeMessage {
|
||||
role: role.to_string(),
|
||||
content: content.to_string(),
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![ClaudeMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt.to_string(),
|
||||
}]
|
||||
};
|
||||
|
||||
let system_prompt: Option<String> = if messages.is_array() {
|
||||
messages
|
||||
.as_array()
|
||||
.unwrap_or(&empty_vec)
|
||||
.iter()
|
||||
.filter(|m| m["role"].as_str() == Some("system"))
|
||||
.map(|m| m["content"].as_str().unwrap_or("").to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n")
|
||||
.into()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let system = system_prompt.filter(|s| !s.is_empty());
|
||||
|
||||
let request = ClaudeRequest {
|
||||
model: model_name.to_string(),
|
||||
max_tokens: 4096,
|
||||
messages: claude_messages,
|
||||
system,
|
||||
stream: Some(true),
|
||||
};
|
||||
|
||||
info!("Claude streaming request to {}: model={}", url, model_name);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.headers(headers)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
warn!("Claude streaming API error ({}): {}", status, error_text);
|
||||
return Err(format!("Claude streaming API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let handler = get_handler(model_name);
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
||||
|
||||
for line in chunk_str.lines() {
|
||||
let line = line.trim();
|
||||
|
||||
if line.starts_with("data: ") {
|
||||
let data = &line[6..];
|
||||
|
||||
if data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
if let Ok(event) = serde_json::from_str::<ClaudeStreamEvent>(data) {
|
||||
if event.event_type == "content_block_delta" {
|
||||
if let Some(delta) = event.delta {
|
||||
if delta.delta_type == "text_delta" && !delta.text.is_empty() {
|
||||
let processed = handler.process_content(&delta.text);
|
||||
if !processed.is_empty() {
|
||||
let _ = tx.send(processed).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
_session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Endpoints shared by several tests.
    const ANTHROPIC_URL: &str = "https://api.anthropic.com";
    const AZURE_URL: &str = "https://myendpoint.openai.azure.com/anthropic";

    /// Builds the Azure client used by the Azure-specific tests.
    fn azure_client() -> ClaudeClient {
        ClaudeClient::azure(AZURE_URL.to_string(), "claude-opus-4-5".to_string())
    }

    #[test]
    fn test_claude_client_new() {
        let deployment = Some(String::from("claude-3-opus"));
        let client = ClaudeClient::new(ANTHROPIC_URL.to_string(), deployment);
        assert!(!client.is_azure);
        assert_eq!(client.deployment_name, "claude-3-opus");
    }

    #[test]
    fn test_claude_client_azure() {
        let client = azure_client();
        assert!(client.is_azure);
        assert_eq!(client.deployment_name, "claude-opus-4-5");
    }

    #[test]
    fn test_build_url_azure() {
        let url = azure_client().build_url();
        assert!(url.contains("deployments/claude-opus-4-5/messages"));
        assert!(url.contains("api-version="));
    }

    #[test]
    fn test_build_url_anthropic() {
        let url = ClaudeClient::new(ANTHROPIC_URL.to_string(), None).build_url();
        assert_eq!(url, "https://api.anthropic.com/v1/messages");
    }

    #[test]
    fn test_build_messages_empty() {
        let (system, messages) = ClaudeClient::build_messages("", "", &[]);
        assert!(system.is_none());
        assert!(messages.is_empty());
    }

    #[test]
    fn test_build_messages_with_system() {
        let (system, messages) =
            ClaudeClient::build_messages("You are a helpful assistant.", "", &[]);
        assert_eq!(system, Some("You are a helpful assistant.".to_string()));
        assert!(messages.is_empty());
    }

    #[test]
    fn test_build_messages_with_history() {
        let history = [("user", "Hello"), ("assistant", "Hi there!")]
            .iter()
            .map(|(role, text)| (role.to_string(), text.to_string()))
            .collect::<Vec<_>>();
        let (system, messages) = ClaudeClient::build_messages("", "", &history);
        assert!(system.is_none());
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "Hello");
    }

    #[test]
    fn test_build_messages_full() {
        let history = vec![(String::from("user"), String::from("What is 2+2?"))];
        let (system, messages) = ClaudeClient::build_messages(
            "You are a math tutor.",
            "Focus on step-by-step explanations.",
            &history,
        );
        assert!(system.is_some());
        assert!(system.unwrap().contains("math tutor"));
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_claude_request_serialization() {
        let greeting = ClaudeMessage {
            role: String::from("user"),
            content: String::from("Hello"),
        };
        let request = ClaudeRequest {
            model: String::from("claude-3-opus"),
            max_tokens: 4096,
            messages: vec![greeting],
            system: Some(String::from("Be helpful")),
            stream: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        for expected in ["claude-3-opus", "max_tokens", "Be helpful"] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn test_claude_response_deserialization() {
        let json = r#"{
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello!"}],
"model": "claude-3-opus",
"stop_reason": "end_turn"
}"#;

        let response: ClaudeResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "msg_123");
        assert_eq!(response.content.len(), 1);
        assert_eq!(response.content[0].text, "Hello!");
    }
}
|
||||
115
src/llm/mod.rs
115
src/llm/mod.rs
|
|
@ -2,14 +2,17 @@ use async_trait::async_trait;
|
|||
use futures::StreamExt;
|
||||
use log::{info, trace};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
|
||||
pub mod cache;
|
||||
pub mod claude;
|
||||
pub mod episodic_memory;
|
||||
pub mod llm_models;
|
||||
pub mod local;
|
||||
pub mod observability;
|
||||
|
||||
pub use claude::ClaudeClient;
|
||||
pub use llm_models::get_handler;
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -184,6 +187,116 @@ pub fn start_llm_services(state: &std::sync::Arc<crate::shared::state::AppState>
|
|||
info!("LLM services started (episodic memory scheduler)");
|
||||
}
|
||||
|
||||
/// Kind of backend an LLM endpoint talks to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LLMProviderType {
    OpenAI,
    Claude,
    AzureClaude,
}

impl From<&str> for LLMProviderType {
    /// Guesses the provider from a URL-like string, ignoring case.
    ///
    /// Strings mentioning `claude` or `anthropic` map to `AzureClaude` when
    /// they also mention `azure`, otherwise to `Claude`. Everything else is
    /// treated as `OpenAI`.
    fn from(s: &str) -> Self {
        let haystack = s.to_lowercase();
        let mentions = |needle: &str| haystack.contains(needle);

        match (mentions("claude") || mentions("anthropic"), mentions("azure")) {
            (true, true) => Self::AzureClaude,
            (true, false) => Self::Claude,
            (false, _) => Self::OpenAI,
        }
    }
}
|
||||
|
||||
pub fn create_llm_provider(
|
||||
provider_type: LLMProviderType,
|
||||
base_url: String,
|
||||
deployment_name: Option<String>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
match provider_type {
|
||||
LLMProviderType::OpenAI => {
|
||||
info!("Creating OpenAI LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(OpenAIClient::new("empty".to_string(), Some(base_url)))
|
||||
}
|
||||
LLMProviderType::Claude => {
|
||||
info!("Creating Claude LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(ClaudeClient::new(base_url, deployment_name))
|
||||
}
|
||||
LLMProviderType::AzureClaude => {
|
||||
let deployment = deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string());
|
||||
info!(
|
||||
"Creating Azure Claude LLM provider with URL: {}, deployment: {}",
|
||||
base_url, deployment
|
||||
);
|
||||
std::sync::Arc::new(ClaudeClient::azure(base_url, deployment))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_llm_provider_from_url(url: &str, model: Option<String>) -> std::sync::Arc<dyn LLMProvider> {
|
||||
let provider_type = LLMProviderType::from(url);
|
||||
create_llm_provider(provider_type, url.to_string(), model)
|
||||
}
|
||||
|
||||
pub struct DynamicLLMProvider {
|
||||
inner: RwLock<Arc<dyn LLMProvider>>,
|
||||
}
|
||||
|
||||
impl DynamicLLMProvider {
|
||||
pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
|
||||
Self {
|
||||
inner: RwLock::new(provider),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_provider(&self, new_provider: Arc<dyn LLMProvider>) {
|
||||
let mut guard = self.inner.write().await;
|
||||
*guard = new_provider;
|
||||
info!("LLM provider updated dynamically");
|
||||
}
|
||||
|
||||
pub async fn update_from_config(&self, url: &str, model: Option<String>) {
|
||||
let new_provider = create_llm_provider_from_url(url, model);
|
||||
self.update_provider(new_provider).await;
|
||||
}
|
||||
|
||||
async fn get_provider(&self) -> Arc<dyn LLMProvider> {
|
||||
self.inner.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LLMProvider for DynamicLLMProvider {
|
||||
async fn generate(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.get_provider().await.generate(prompt, config, model, key).await
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
prompt: &str,
|
||||
config: &Value,
|
||||
tx: mpsc::Sender<String>,
|
||||
model: &str,
|
||||
key: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.get_provider().await.generate_stream(prompt, config, tx, model, key).await
|
||||
}
|
||||
|
||||
async fn cancel_job(
|
||||
&self,
|
||||
session_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
self.get_provider().await.cancel_job(session_id).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
50
src/main.rs
50
src/main.rs
|
|
@ -185,6 +185,7 @@ async fn run_axum_server(
|
|||
.add_anonymous_path("/api/v1/health")
|
||||
.add_anonymous_path("/ws")
|
||||
.add_anonymous_path("/auth")
|
||||
.add_anonymous_path("/api/auth")
|
||||
.add_public_path("/static")
|
||||
.add_public_path("/favicon.ico"));
|
||||
|
||||
|
|
@ -750,10 +751,23 @@ async fn main() -> std::io::Result<()> {
|
|||
.unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||
info!("LLM URL: {}", llm_url);
|
||||
|
||||
let base_llm_provider = Arc::new(botserver::llm::OpenAIClient::new(
|
||||
"empty".to_string(),
|
||||
Some(llm_url.clone()),
|
||||
)) as Arc<dyn botserver::llm::LLMProvider>;
|
||||
let llm_model = config_manager
|
||||
.get_config(&default_bot_id, "llm-model", Some(""))
|
||||
.unwrap_or_default();
|
||||
if !llm_model.is_empty() {
|
||||
info!("LLM Model: {}", llm_model);
|
||||
}
|
||||
|
||||
let _llm_key = config_manager
|
||||
.get_config(&default_bot_id, "llm-key", Some(""))
|
||||
.unwrap_or_default();
|
||||
|
||||
let base_llm_provider = botserver::llm::create_llm_provider_from_url(
|
||||
&llm_url,
|
||||
if llm_model.is_empty() { None } else { Some(llm_model.clone()) },
|
||||
);
|
||||
|
||||
let dynamic_llm_provider = Arc::new(botserver::llm::DynamicLLMProvider::new(base_llm_provider));
|
||||
|
||||
let llm_provider: Arc<dyn botserver::llm::LLMProvider> = if let Some(ref cache) = redis_client {
|
||||
let embedding_url = config_manager
|
||||
|
|
@ -784,14 +798,14 @@ async fn main() -> std::io::Result<()> {
|
|||
};
|
||||
|
||||
Arc::new(botserver::llm::cache::CachedLLMProvider::with_db_pool(
|
||||
base_llm_provider,
|
||||
dynamic_llm_provider.clone() as Arc<dyn botserver::llm::LLMProvider>,
|
||||
cache.clone(),
|
||||
cache_config,
|
||||
embedding_service,
|
||||
pool.clone(),
|
||||
))
|
||||
} else {
|
||||
base_llm_provider
|
||||
dynamic_llm_provider.clone() as Arc<dyn botserver::llm::LLMProvider>
|
||||
};
|
||||
|
||||
let kb_manager = Arc::new(botserver::core::kb::KnowledgeBaseManager::new("work"));
|
||||
|
|
@ -833,7 +847,11 @@ async fn main() -> std::io::Result<()> {
|
|||
voice_adapter: voice_adapter.clone(),
|
||||
kb_manager: Some(kb_manager.clone()),
|
||||
task_engine,
|
||||
extensions: botserver::core::shared::state::Extensions::new(),
|
||||
extensions: {
|
||||
let ext = botserver::core::shared::state::Extensions::new();
|
||||
ext.insert_blocking(Arc::clone(&dynamic_llm_provider));
|
||||
ext
|
||||
},
|
||||
attendant_broadcast: Some(attendant_tx),
|
||||
});
|
||||
|
||||
|
|
@ -868,6 +886,24 @@ async fn main() -> std::io::Result<()> {
|
|||
error!("Failed to mount bots: {}", e);
|
||||
}
|
||||
|
||||
#[cfg(feature = "drive")]
|
||||
{
|
||||
let drive_monitor_state = app_state.clone();
|
||||
let bucket_name = "default.gbai".to_string();
|
||||
let monitor_bot_id = default_bot_id;
|
||||
tokio::spawn(async move {
|
||||
let monitor = botserver::DriveMonitor::new(
|
||||
drive_monitor_state,
|
||||
bucket_name.clone(),
|
||||
monitor_bot_id,
|
||||
);
|
||||
info!("Starting DriveMonitor for bucket: {}", bucket_name);
|
||||
if let Err(e) = monitor.start_monitoring().await {
|
||||
error!("DriveMonitor failed: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let automation_state = app_state.clone();
|
||||
tokio::spawn(async move {
|
||||
let automation = AutomationService::new(automation_state);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue