From 82bfd0a44360678a551a67ecca8d071d0d0cf5ab Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Tue, 10 Mar 2026 12:36:24 -0300 Subject: [PATCH] Fix Bedrock config for OpenAI GPT-OSS models --- src/attendance/llm_assist_config.rs | 2 +- src/core/bootstrap/bootstrap_manager.rs | 35 +- src/core/bootstrap/bootstrap_utils.rs | 25 +- src/core/bot/channels/whatsapp.rs | 2 +- src/core/bot/mod.rs | 22 +- src/core/config/mod.rs | 2 +- src/llm/bedrock.rs | 239 +++++++++++ src/llm/mod.rs | 35 +- src/llm/vertex.rs | 515 ++++++++++++++++++++++++ src/main_module/bootstrap.rs | 24 +- 10 files changed, 857 insertions(+), 44 deletions(-) create mode 100644 src/llm/bedrock.rs create mode 100644 src/llm/vertex.rs diff --git a/src/attendance/llm_assist_config.rs b/src/attendance/llm_assist_config.rs index f1a3745e..9b915b40 100644 --- a/src/attendance/llm_assist_config.rs +++ b/src/attendance/llm_assist_config.rs @@ -23,7 +23,7 @@ impl LlmAssistConfig { if let Ok(content) = std::fs::read_to_string(&path) { for line in content.lines() { - let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect(); + let parts: Vec<&str> = line.splitn(2, ',').map(|s| s.trim()).collect(); if parts.len() < 2 { continue; diff --git a/src/core/bootstrap/bootstrap_manager.rs b/src/core/bootstrap/bootstrap_manager.rs index 856c87f2..e2f031c8 100644 --- a/src/core/bootstrap/bootstrap_manager.rs +++ b/src/core/bootstrap/bootstrap_manager.rs @@ -160,30 +160,25 @@ impl BootstrapManager { if pm.is_installed("directory") { // Wait for Zitadel to be ready - it might have been started during installation - // Use incremental backoff: check frequently at first, then slow down + // Use very aggressive backoff for fastest startup detection let mut directory_already_running = zitadel_health_check(); if !directory_already_running { - info!("Zitadel not responding to health check, waiting with incremental backoff..."); - // Check intervals: 1s x5, 2s x5, 5s x5, 10s x3 = ~60s total - let intervals: [(u64, u32); 4] = [(1, 5), (2, 5), (5, 5), (10, 3)]; - let mut total_waited: u64 = 0; - for (interval_secs, count) in intervals { - for i in 0..count { - if zitadel_health_check() { - info!("Zitadel/Directory service is now responding (waited {}s)", total_waited); - directory_already_running = true; - break; - } - sleep(Duration::from_secs(interval_secs)).await; - total_waited += interval_secs; - // Show incremental progress every ~10s - if total_waited % 10 == 0 || i == 0 { - info!("Zitadel health check: {}s elapsed, retrying...", total_waited); - } - } - if directory_already_running { + info!("Zitadel not responding to health check, waiting..."); + // Check every 500ms for fast detection (was: 1s, 2s, 5s, 10s) + let mut checks = 0; + let max_checks = 120; // 60 seconds max + while checks < max_checks { + if zitadel_health_check() { + info!("Zitadel/Directory service is now responding (checked {} times)", checks); + directory_already_running = true; break; } + sleep(Duration::from_millis(500)).await; + checks += 1; + // Log progress every 10 checks (5 seconds) + if checks % 10 == 0 { + info!("Zitadel health check: {}s elapsed, retrying...", checks / 2); + } } } diff --git a/src/core/bootstrap/bootstrap_utils.rs b/src/core/bootstrap/bootstrap_utils.rs index 19a8b0d5..d6963efb 100644 --- a/src/core/bootstrap/bootstrap_utils.rs +++ b/src/core/bootstrap/bootstrap_utils.rs @@ -207,16 +207,27 @@ pub enum BotExistsResult { /// Check if Zitadel directory is healthy pub fn zitadel_health_check() -> bool { // Check if Zitadel is responding on port 8300 - if let Ok(output) = Command::new("curl") - .args(["-f", "-s", "--connect-timeout", "2", "http://localhost:8300/debug/healthz"]) - .output() - { - if output.status.success() { - return true; + // Use very short timeout for fast startup detection + let output = Command::new("curl") + .args(["-f", "-s", "--connect-timeout", "1", "-m", "2", "http://localhost:8300/debug/healthz"]) + .output(); + + match output { + Ok(result) => { + if result.status.success() { + let response = String::from_utf8_lossy(&result.stdout); + debug!("Zitadel health check response: {}", response); + return response.trim() == "ok"; + } + let stderr = String::from_utf8_lossy(&result.stderr); + debug!("Zitadel health check failed: {}", stderr); + } + Err(e) => { + debug!("Zitadel health check error: {}", e); } } - // Fallback: just check if port 8300 is listening + // Fast fallback: just check if port 8300 is listening match Command::new("nc") .args(["-z", "-w", "1", "127.0.0.1", "8300"]) .output() diff --git a/src/core/bot/channels/whatsapp.rs b/src/core/bot/channels/whatsapp.rs index da0bf044..e27ab440 100644 --- a/src/core/bot/channels/whatsapp.rs +++ b/src/core/bot/channels/whatsapp.rs @@ -503,7 +503,7 @@ impl WhatsAppAdapter { } if !list_block.is_empty() { - if !current_part.is_empty() && current_part.len() + list_block.len() + 1 <= max_length { + if !current_part.is_empty() && current_part.len() + list_block.len() < max_length { current_part.push('\n'); current_part.push_str(&list_block); } else { diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs index cd38290d..55a0da6a 100644 --- a/src/core/bot/mod.rs +++ b/src/core/bot/mod.rs @@ -416,7 +416,7 @@ impl BotOrchestrator { let session_id = Uuid::parse_str(&message.session_id)?; let message_content = message.content.clone(); - let (session, context_data, history, model, key, system_prompt) = { + let (session, context_data, history, model, key, system_prompt, bot_llm_url) = { let state_clone = self.state.clone(); tokio::task::spawn_blocking( move || -> Result<_, Box> { @@ -453,14 +453,19 @@ impl BotOrchestrator { .get_config(&session.bot_id, "llm-key", Some("")) .unwrap_or_default(); + // Load bot-specific llm-url (may differ from global default) + let bot_llm_url = config_manager + .get_bot_config_value(&session.bot_id, "llm-url") + .ok(); + // Load system-prompt from config.csv, fallback to default let system_prompt = config_manager .get_config(&session.bot_id, "system-prompt", Some("You are a helpful assistant with access to tools that can help you complete tasks. When a user's request matches one of your available tools, use the appropriate tool instead of providing a generic response.")) - .unwrap_or_else(|_| "You are a helpful assistant.".to_string()); + .unwrap_or_else(|_| "You are a helpful General Bots assistant.".to_string()); - info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(200)]); + info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(500)]); - Ok((session, context_data, history, model, key, system_prompt)) + Ok((session, context_data, history, model, key, system_prompt, bot_llm_url)) }, ) .await?? @@ -615,7 +620,14 @@ impl BotOrchestrator { } let (stream_tx, mut stream_rx) = mpsc::channel::(100); - let llm = self.state.llm_provider.clone(); + + // Use bot-specific LLM provider if the bot has its own llm-url configured + let llm: std::sync::Arc = if let Some(ref url) = bot_llm_url { + info!("Bot has custom llm-url: {}, creating per-bot LLM provider", url); + crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None) + } else { + self.state.llm_provider.clone() + }; let model_clone = model.clone(); let key_clone = key.clone(); diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index f5d986f3..b842eaf4 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -482,7 +482,7 @@ impl ConfigManager { let mut updated = 0; for line in content.lines().skip(1) { - let parts: Vec<&str> = line.split(',').collect(); + let parts: Vec<&str> = line.splitn(2, ',').collect(); if parts.len() >= 2 { let key = parts[0].trim(); let value = parts[1].trim(); diff --git a/src/llm/bedrock.rs b/src/llm/bedrock.rs new file mode 100644 index 00000000..e5a58562 --- /dev/null +++ b/src/llm/bedrock.rs @@ -0,0 +1,239 @@ +use async_trait::async_trait; +use futures::StreamExt; +use log::{error, info}; +use serde_json::Value; +use tokio::sync::mpsc; + +use crate::llm::LLMProvider; + +#[derive(Debug)] +pub struct BedrockClient { + client: reqwest::Client, + base_url: String, +} + +impl BedrockClient { + pub fn new(base_url: String) -> Self { + Self { + client: reqwest::Client::new(), + base_url, + } + } +} + +#[async_trait] +impl LLMProvider for BedrockClient { + async fn generate( + &self, + prompt: &str, + messages: &Value, + model: &str, + key: &str, + ) -> Result> { + let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); + + let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() { + messages + } else { + &default_messages + }; + + + let mut messages_limited = Vec::new(); + if let Some(msg_array) = raw_messages.as_array() { + for msg in msg_array { + messages_limited.push(msg.clone()); + } + } + let formatted_messages = serde_json::Value::Array(messages_limited); + + let auth_header = if key.starts_with("Bearer ") { + key.to_string() + } else { + format!("Bearer {}", key) + }; + + let request_body = serde_json::json!({ + "model": model, + "messages": formatted_messages, + "stream": false + }); + + info!("Sending request to Bedrock endpoint: {}", self.base_url); + + let response = self.client + .post(&self.base_url) + .header("Authorization", &auth_header) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Bedrock generate error: {}", error_text); + return Err(format!("Bedrock API error ({}): {}", status, error_text).into()); + } + + let json: Value = response.json().await?; + + if let Some(choices) = json.get("choices") { + if let Some(first_choice) = choices.get(0) { + if let Some(message) = first_choice.get("message") { + if let Some(content) = message.get("content") { + if let Some(content_str) = content.as_str() { + return Ok(content_str.to_string()); + } + } + } + } + } + + Err("Failed to parse response from Bedrock".into()) + } + + async fn generate_stream( + &self, + prompt: &str, + messages: &Value, + tx: mpsc::Sender, + model: &str, + key: &str, + tools: Option<&Vec>, + ) -> Result<(), Box> { + let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); + + let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() { + messages + } else { + &default_messages + }; + + let mut messages_limited = Vec::new(); + if let Some(msg_array) = raw_messages.as_array() { + for msg in msg_array { + messages_limited.push(msg.clone()); + } + } + let formatted_messages = serde_json::Value::Array(messages_limited); + + let auth_header = if key.starts_with("Bearer ") { + key.to_string() + } else { + format!("Bearer {}", key) + }; + + let mut request_body = serde_json::json!({ + "model": model, + "messages": formatted_messages, + "stream": true + }); + + if let Some(tools_value) = tools { + if !tools_value.is_empty() { + request_body["tools"] = serde_json::json!(tools_value); + info!("Added {} tools to Bedrock request", tools_value.len()); + } + } + + let stream_url = if self.base_url.ends_with("/invoke") { + self.base_url.replace("/invoke", "/invoke-with-response-stream") + } else { + self.base_url.clone() + }; + + info!("Sending streaming request to Bedrock endpoint: {}", stream_url); + + let response = self.client + .post(&stream_url) + .header("Authorization", &auth_header) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Bedrock generate_stream error: {}", error_text); + return Err(format!("Bedrock API error ({}): {}", status, error_text).into()); + } + + let mut stream = response.bytes_stream(); + let mut tool_call_buffer = String::new(); + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if let Ok(text) = std::str::from_utf8(&chunk) { + for line in text.split('\n') { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + continue; + } + + if let Ok(json) = serde_json::from_str::(data) { + if let Some(choices) = json.get("choices") { + if let Some(first_choice) = choices.get(0) { + if let Some(delta) = first_choice.get("delta") { + // Handle standard content streaming + if let Some(content) = delta.get("content") { + if let Some(content_str) = content.as_str() { + if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() { + return Ok(()); + } + } + } + + // Handle tool calls streaming + if let Some(tool_calls) = delta.get("tool_calls") { + if let Some(calls_array) = tool_calls.as_array() { + if let Some(first_call) = calls_array.first() { + if let Some(function) = first_call.get("function") { + // Stream function JSON representation just like OpenAI does + if let Some(name) = function.get("name") { + if let Some(name_str) = name.as_str() { + tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str); + } + } + if let Some(args) = function.get("arguments") { + if let Some(args_str) = args.as_str() { + tool_call_buffer.push_str(args_str); + } + } + } + } + } + } + } + } + } + } + } + } + } + } + Err(e) => { + error!("Bedrock stream reading error: {}", e); + break; + } + } + } + + // Finalize tool call JSON parsing + if !tool_call_buffer.is_empty() { + tool_call_buffer.push_str("\"}"); + if tx.send(format!("`tool_call`: {}", tool_call_buffer)).await.is_err() { + return Ok(()); + } + } + + Ok(()) + } + + async fn cancel_job(&self, _session_id: &str) -> Result<(), Box> { + Ok(()) + } +} diff --git a/src/llm/mod.rs b/src/llm/mod.rs index daa70961..a46ea00b 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -14,11 +14,15 @@ pub mod llm_models; pub mod local; pub mod rate_limiter; pub mod smart_router; +pub mod vertex; +pub mod bedrock; pub use claude::ClaudeClient; pub use glm::GLMClient; pub use llm_models::get_handler; pub use rate_limiter::{ApiRateLimiter, RateLimits}; +pub use vertex::VertexTokenManager; +pub use bedrock::BedrockClient; #[async_trait] pub trait LLMProvider: Send + Sync { @@ -248,14 +252,16 @@ impl LLMProvider for OpenAIClient { // GLM-4.7 has 202750 tokens, other models vary let context_limit = if model.contains("glm-4") || model.contains("GLM-4") { 202750 - } else if model.contains("gpt-4") { - 128000 + } else if model.contains("gemini") { + 1000000 // Google Gemini models have 1M+ token context + } else if model.contains("gpt-oss") || model.contains("gpt-4") { + 128000 // Cerebras gpt-oss models and GPT-4 variants } else if model.contains("gpt-3.5") { 16385 } else if model.starts_with("http://localhost:808") || model == "local" { 768 // Local llama.cpp server context limit } else { - 4096 // Default conservative limit + 32768 // Default conservative limit for modern models }; let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit); @@ -329,14 +335,16 @@ impl LLMProvider for OpenAIClient { // GLM-4.7 has 202750 tokens, other models vary let context_limit = if model.contains("glm-4") || model.contains("GLM-4") { 202750 - } else if model.contains("gpt-4") { - 128000 + } else if model.contains("gemini") { + 1000000 // Google Gemini models have 1M+ token context + } else if model.contains("gpt-oss") || model.contains("gpt-4") { + 128000 // Cerebras gpt-oss models and GPT-4 variants } else if model.contains("gpt-3.5") { 16385 } else if model.starts_with("http://localhost:808") || model == "local" { 768 // Local llama.cpp server context limit } else { - 4096 // Default conservative limit + 32768 // Default conservative limit for modern models }; let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit); @@ -448,6 +456,8 @@ pub enum LLMProviderType { Claude, AzureClaude, GLM, + Bedrock, + Vertex, } impl From<&str> for LLMProviderType { @@ -461,6 +471,10 @@ impl From<&str> for LLMProviderType { } } else if lower.contains("z.ai") || lower.contains("glm") { Self::GLM + } else if lower.contains("bedrock") { + Self::Bedrock + } else if lower.contains("googleapis.com") || lower.contains("vertex") || lower.contains("generativelanguage") { + Self::Vertex } else { Self::OpenAI } @@ -498,6 +512,15 @@ pub fn create_llm_provider( info!("Creating GLM/z.ai LLM provider with URL: {}", base_url); std::sync::Arc::new(GLMClient::new(base_url)) } + LLMProviderType::Bedrock => { + info!("Creating Bedrock LLM provider with exact URL: {}", base_url); + std::sync::Arc::new(BedrockClient::new(base_url)) + } + LLMProviderType::Vertex => { + info!("Creating Vertex/Gemini LLM provider with URL: {}", base_url); + // Re-export the struct if we haven't already + std::sync::Arc::new(crate::llm::vertex::VertexClient::new(base_url, endpoint_path)) + } } } diff --git a/src/llm/vertex.rs b/src/llm/vertex.rs new file mode 100644 index 00000000..c8ce4322 --- /dev/null +++ b/src/llm/vertex.rs @@ -0,0 +1,515 @@ +//! Google Vertex AI and Gemini Native API Integration +//! Support for both OpenAI-compatible endpoints and native Google AI Studio / Vertex AI endpoints. + +use log::{error, info, trace}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::RwLock; + +const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token"; +const GOOGLE_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; + +#[derive(Debug, Deserialize)] +struct ServiceAccountKey { + client_email: String, + private_key: String, + token_uri: Option, +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: u64, +} + +#[derive(Debug, Clone)] +struct CachedToken { + access_token: String, + expires_at: std::time::Instant, +} + +/// Manages OAuth2 access tokens for Google Vertex AI +#[derive(Debug, Clone)] +pub struct VertexTokenManager { + credentials_path: String, + cached_token: Arc>>, +} + +impl VertexTokenManager { + pub fn new(credentials_path: &str) -> Self { + Self { + credentials_path: credentials_path.to_string(), + cached_token: Arc::new(RwLock::new(None)), + } + } + + /// Get a valid access token, refreshing if necessary + pub async fn get_access_token(&self) -> Result> { + // Check cached token + { + let cache = self.cached_token.read().await; + if let Some(ref token) = *cache { + // Return cached token if it expires in more than 60 seconds + if token.expires_at > std::time::Instant::now() + std::time::Duration::from_secs(60) { + trace!("Using cached Vertex AI access token"); + return Ok(token.access_token.clone()); + } + } + } + + // Generate new token + info!("Generating new Vertex AI access token from service account"); + let token = self.generate_token().await?; + + // Cache it + { + let mut cache = self.cached_token.write().await; + *cache = Some(token.clone()); + } + + Ok(token.access_token) + } + + async fn generate_token(&self) -> Result> { + let content = if self.credentials_path.trim().starts_with('{') { + self.credentials_path.clone() + } else { + // Expand ~ to home directory + let expanded_path = if self.credentials_path.starts_with('~') { + let home = std::env::var("HOME").unwrap_or_else(|_| "/root".to_string()); + self.credentials_path.replacen('~', &home, 1) + } else { + self.credentials_path.clone() + }; + + tokio::fs::read_to_string(&expanded_path) + .await + .map_err(|e| format!("Failed to read service account key from '{}': {}", expanded_path, e))? + }; + + let sa_key: ServiceAccountKey = serde_json::from_str(&content) + .map_err(|e| format!("Failed to parse service account JSON: {}", e))?; + + let token_uri = sa_key.token_uri.as_deref().unwrap_or(GOOGLE_TOKEN_URL); + + // Create JWT assertion + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_err(|e| format!("System time error: {}", e))? + .as_secs(); + + let jwt_claims = JwtClaims { + iss: sa_key.client_email.clone(), + scope: GOOGLE_SCOPE.to_string(), + aud: token_uri.to_string(), + iat: now, + exp: now + 3600, + }; + + let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256); + let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(sa_key.private_key.as_bytes()) + .map_err(|e| format!("Failed to parse RSA private key: {}", e))?; + + let jwt = jsonwebtoken::encode(&header, &jwt_claims, &encoding_key) + .map_err(|e| format!("Failed to encode JWT: {}", e))?; + + // Exchange JWT for access token + let client = reqwest::Client::new(); + let response = client + .post(token_uri) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &jwt), + ]) + .send() + .await + .map_err(|e| format!("Failed to request token: {}", e))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + error!("Google OAuth2 token request failed: {} - {}", status, body); + return Err(format!("Token request failed with status {}: {}", status, body).into()); + } + + let token_response: TokenResponse = response + .json() + .await + .map_err(|e| format!("Failed to parse token response: {}", e))?; + + info!("Successfully obtained Vertex AI access token (expires in {}s)", token_response.expires_in); + + Ok(CachedToken { + access_token: token_response.access_token, + expires_at: std::time::Instant::now() + std::time::Duration::from_secs(token_response.expires_in), + }) + } +} + +#[derive(Debug, Serialize)] +struct JwtClaims { + iss: String, + scope: String, + aud: String, + iat: u64, + exp: u64, +} + +/// Builds the Vertex AI OpenAI-compatible endpoint URL +pub fn build_vertex_url(project: &str, location: &str) -> String { + format!( + "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/endpoints/openapi" + ) +} + +/// Check if a URL is a Vertex AI endpoint +pub fn is_vertex_ai_url(url: &str) -> bool { + url.contains("aiplatform.googleapis.com") || url.contains("generativelanguage.googleapis.com") +} + +use crate::llm::LLMProvider; +use serde_json::Value; +use tokio::sync::mpsc; +use async_trait::async_trait; +use futures::StreamExt; + +#[derive(Debug)] +pub struct VertexClient { + client: reqwest::Client, + base_url: String, + endpoint_path: String, + token_manager: Arc>>>, +} + +impl VertexClient { + pub fn new(base_url: String, endpoint_path: Option) -> Self { + let endpoint = endpoint_path.unwrap_or_else(|| { + if base_url.contains("generativelanguage.googleapis.com") { + // Default to OpenAI compatibility if possible, but generate_stream will auto-detect + "/v1beta/openai/chat/completions".to_string() + } else if base_url.contains("aiplatform.googleapis.com") && !base_url.contains("openapi") { + // If it's a naked aiplatform URL, it might be missing the project/location path + // We'll leave it empty and let the caller provide a full URL if needed + "".to_string() + } else { + "".to_string() + } + }); + + Self { + client: reqwest::Client::new(), + base_url, + endpoint_path: endpoint, + token_manager: Arc::new(tokio::sync::RwLock::new(None)), + } + } + + async fn get_auth_header(&self, key: &str) -> (&'static str, String) { + // If the key is a path or JSON (starts with {), it's a Service Account + if key.starts_with('{') || key.starts_with('~') || key.starts_with('/') || key.ends_with(".json") { + let mut manager_opt = self.token_manager.write().await; + if manager_opt.is_none() { + *manager_opt = Some(Arc::new(VertexTokenManager::new(key))); + } + if let Some(manager) = manager_opt.as_ref() { + match manager.get_access_token().await { + Ok(token) => return ("Authorization", format!("Bearer {}", token)), + Err(e) => error!("Failed to get Vertex OAuth token: {}", e), + } + } + } + + // Default to Google API Key (Google AI Studio / Gemini Developer APIs) + ("x-goog-api-key", key.to_string()) + } + + fn convert_messages_to_gemini(&self, messages: &Value) -> Value { + let mut contents = Vec::new(); + if let Some(msg_array) = messages.as_array() { + for msg in msg_array { + let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("user"); + let content = msg.get("content").and_then(|c| c.as_str()).unwrap_or(""); + + let gemini_role = match role { + "assistant" => "model", + "system" => "user", // Gemini doesn't have system role in 'contents' by default, often wrapped in systemInstruction + _ => "user", + }; + + contents.push(serde_json::json!({ + "role": gemini_role, + "parts": [{"text": content}] + })); + } + } + serde_json::json!(contents) + } +} + +#[async_trait] +impl LLMProvider for VertexClient { + async fn generate( + &self, + prompt: &str, + messages: &Value, + model: &str, + key: &str, + ) -> Result> { + let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); + let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() { + messages + } else { + &default_messages + }; + + let (header_name, auth_value) = self.get_auth_header(key).await; + + // Auto-detect if we should use native Gemini API + let is_native = self.endpoint_path.contains(":generateContent") || + self.base_url.contains("generativelanguage") && !self.endpoint_path.contains("openai"); + + let full_url = if is_native && !self.endpoint_path.contains(":generateContent") { + // Build native URL for generativelanguage + format!("{}/v1beta/models/{}:generateContent", self.base_url.trim_end_matches('/'), model) + } else { + format!("{}{}", self.base_url, self.endpoint_path) + }.trim_end_matches('/').to_string(); + + let request_body = if is_native { + serde_json::json!({ + "contents": self.convert_messages_to_gemini(raw_messages), + "generationConfig": { + "temperature": 0.7, + } + }) + } else { + serde_json::json!({ + "model": model, + "messages": raw_messages, + "stream": false + }) + }; + + info!("Sending request to Vertex/Gemini endpoint: {}", full_url); + + let response = self.client + .post(full_url) + .header(header_name, &auth_value) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Vertex/Gemini generate error: {}", error_text); + return Err(format!("Google API error ({}): {}", status, error_text).into()); + } + + let json: Value = response.json().await?; + + // Try parsing OpenAI format + if let Some(choices) = json.get("choices") { + if let Some(first_choice) = choices.get(0) { + if let Some(message) = first_choice.get("message") { + if let Some(content) = message.get("content") { + if let Some(content_str) = content.as_str() { + return Ok(content_str.to_string()); + } + } + } + } + } + + // Try parsing Native Gemini format + if let Some(candidates) = json.get("candidates") { + if let Some(first_candidate) = candidates.get(0) { + if let Some(content) = first_candidate.get("content") { + if let Some(parts) = content.get("parts") { + if let Some(first_part) = parts.get(0) { + if let Some(text) = first_part.get("text") { + if let Some(text_str) = text.as_str() { + return Ok(text_str.to_string()); + } + } + } + } + } + } + } + + Err("Failed to parse response from Vertex/Gemini (unknown format)".into()) + } + + async fn generate_stream( + &self, + prompt: &str, + messages: &Value, + tx: mpsc::Sender, + model: &str, + key: &str, + tools: Option<&Vec>, + ) -> Result<(), Box> { + let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); + let raw_messages = if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() { + messages + } else { + &default_messages + }; + + let (header_name, auth_value) = self.get_auth_header(key).await; + + // Auto-detect if we should use native Gemini API + let is_native = self.endpoint_path.contains("GenerateContent") || + (self.base_url.contains("googleapis.com") && + !self.endpoint_path.contains("openai") && + !self.endpoint_path.contains("completions")); + + let full_url = if is_native && !self.endpoint_path.contains("GenerateContent") { + // Build native URL for Gemini if it looks like we need it + if self.base_url.contains("aiplatform") { + // Global Vertex endpoint format + format!("{}/v1/publishers/google/models/{}:streamGenerateContent", self.base_url.trim_end_matches('/'), model) + } else { + // Google AI Studio format + format!("{}/v1beta/models/{}:streamGenerateContent", self.base_url.trim_end_matches('/'), model) + } + } else { + format!("{}{}", self.base_url, self.endpoint_path) + }.trim_end_matches('/').to_string(); + + let mut request_body = if is_native { + let mut body = serde_json::json!({ + "contents": self.convert_messages_to_gemini(raw_messages), + "generationConfig": { + "temperature": 0.7, + } + }); + + // Handle thinking models if requested + if model.contains("preview") || model.contains("3.1-flash") { + body["generationConfig"]["thinkingConfig"] = serde_json::json!({ + "thinkingLevel": "LOW" + }); + } + body + } else { + serde_json::json!({ + "model": model, + "messages": raw_messages, + "stream": true + }) + }; + + if let Some(tools_value) = tools { + if !tools_value.is_empty() { + request_body["tools"] = serde_json::json!(tools_value); + info!("Added {} tools to Vertex request", tools_value.len()); + } + } + + info!("Sending streaming request to Vertex/Gemini endpoint: {}", full_url); + + let response = self.client + .post(full_url) + .header(header_name, &auth_value) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Vertex/Gemini generate_stream error: {}", error_text); + return Err(format!("Google API error ({}): {}", status, error_text).into()); + } + + let mut stream = response.bytes_stream(); + let mut tool_call_buffer = String::new(); + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if let Ok(text) = std::str::from_utf8(&chunk) { + for line in text.split('\n') { + let line = line.trim(); + if line.is_empty() { continue; } + + if let Some(data) = line.strip_prefix("data: ") { + // --- OpenAI SSE Format --- + if data == "[DONE]" { continue; } + if let Ok(json) = serde_json::from_str::(data) { + if let Some(choices) = json.get("choices") { + if let Some(first_choice) = choices.get(0) { + if let Some(delta) = first_choice.get("delta") { + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + if !content.is_empty() { + let _ = tx.send(content.to_string()).await; + } + } + + // Handle tool calls... + if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) { + if let Some(first_call) = tool_calls.first() { + if let Some(function) = first_call.get("function") { + if let Some(name) = function.get("name").and_then(|n| n.as_str()) { + tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name); + } + if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) { + tool_call_buffer.push_str(args); + } + } + } + } + } + } + } + } + } else { + // --- Native Gemini JSON format --- + // It usually arrives as raw JSON objects, sometimes with leading commas or brackets in a stream array + let trimmed = line.trim_start_matches(|c| c == ',' || c == '[' || c == ']').trim_end_matches(']'); + if trimmed.is_empty() { continue; } + + if let Ok(json) = serde_json::from_str::(trimmed) { + if let Some(candidates) = json.get("candidates").and_then(|c| c.as_array()) { + if let Some(candidate) = candidates.first() { + if let Some(content) = candidate.get("content") { + if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) { + for part in parts { + if let Some(text) = part.get("text").and_then(|t| t.as_str()) { + if !text.is_empty() { + let _ = tx.send(text.to_string()).await; + } + } + } + } + } + } + } + } + } + } + } + } + Err(e) => { + error!("Vertex stream reading error: {}", e); + break; + } + } + } + + if !tool_call_buffer.is_empty() { + tool_call_buffer.push_str("\"}"); + let _ = tx.send(format!("`tool_call`: {}", tool_call_buffer)).await; + } + + Ok(()) + } + + async fn cancel_job(&self, _session_id: &str) -> Result<(), Box> { + Ok(()) + } +} diff --git a/src/main_module/bootstrap.rs b/src/main_module/bootstrap.rs index cf948c0b..942b29b8 100644 --- a/src/main_module/bootstrap.rs +++ b/src/main_module/bootstrap.rs @@ -880,12 +880,30 @@ async fn start_drive_monitors( .await .unwrap_or_default(); + let local_dev_bots = tokio::task::spawn_blocking(move || { + let mut bots = std::collections::HashSet::new(); + let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/opt/gbo/data".to_string()); + if let Ok(entries) = std::fs::read_dir(data_dir) { + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|e| e.to_str()).map(|e| e.eq_ignore_ascii_case("gbai")).unwrap_or(false) { + if let Some(bot_name) = path.file_stem().and_then(|s| s.to_str()) { + bots.insert(bot_name.to_string()); + } + } + } + } + bots + }) + .await + .unwrap_or_default(); + info!("Found {} active bots to monitor", bots_to_monitor.len()); for (bot_id, bot_name) in bots_to_monitor { - // Skip default bot - it's managed locally via ConfigWatcher - if bot_name == "default" { - info!("Skipping DriveMonitor for 'default' bot - managed via ConfigWatcher"); + // Skip default bot and local dev bots - they are managed locally + if bot_name == "default" || local_dev_bots.contains(&bot_name) { + info!("Skipping DriveMonitor for '{}' bot - managed locally via ConfigWatcher/LocalFileMonitor", bot_name); continue; }