diff --git a/src/auto_task/container_session.rs b/src/auto_task/container_session.rs index 25049e35..ec4a13c5 100644 --- a/src/auto_task/container_session.rs +++ b/src/auto_task/container_session.rs @@ -25,7 +25,7 @@ impl ContainerSession { // Launch the container (this might take a moment if the image isn't cached locally) info!("Launching LXC container: {}", container_name); let launch_status = Command::new("lxc") - .args(&["launch", "ubuntu:22.04", &container_name]) + .args(["launch", "ubuntu:22.04", &container_name]) .output() .await .map_err(|e| format!("Failed to execute lxc launch: {}", e))?; @@ -50,7 +50,7 @@ impl ContainerSession { info!("Starting terminal session in container: {}", self.container_name); let mut child = Command::new("lxc") - .args(&["exec", &self.container_name, "--", "bash"]) + .args(["exec", &self.container_name, "--", "bash"]) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -115,7 +115,7 @@ impl ContainerSession { // Clean up container let status = Command::new("lxc") - .args(&["delete", &self.container_name, "--force"]) + .args(["delete", &self.container_name, "--force"]) .output() .await .map_err(|e| format!("Failed to delete container: {}", e))?; diff --git a/src/auto_task/orchestrator.rs b/src/auto_task/orchestrator.rs index 5855f814..ca3ea4c5 100644 --- a/src/auto_task/orchestrator.rs +++ b/src/auto_task/orchestrator.rs @@ -659,7 +659,7 @@ impl Orchestrator { if !classification.entities.features.is_empty() { for feature in &classification.entities.features { let slug = - feature.to_lowercase().replace(' ', "_").replace('-', "_"); + feature.to_lowercase().replace([' ', '-'], "_"); tasks.push(PipelineSubTask { name: feature.clone(), description: format!( @@ -681,7 +681,7 @@ impl Orchestrator { let inferred = infer_features_from_intent(&lower); for feature in &inferred { let slug = - feature.to_lowercase().replace(' ', "_").replace('-', "_"); + feature.to_lowercase().replace([' ', '-'], "_"); tasks.push(PipelineSubTask { name: feature.clone(), description: format!("Build {} UI with HTMX", feature), @@ -897,7 +897,7 @@ impl Orchestrator { let event = TaskProgressEvent::new( &self.task_id, step, - &format!("Mantis #{agent_id} activity"), + format!("Mantis #{agent_id} activity"), ) .with_event_type("agent_activity") .with_activity(activity.clone()); diff --git a/src/basic/keywords/send_mail.rs b/src/basic/keywords/send_mail.rs index 08ddf995..c45fc80c 100644 --- a/src/basic/keywords/send_mail.rs +++ b/src/basic/keywords/send_mail.rs @@ -350,9 +350,9 @@ async fn execute_send_mail( if email_service .send_email( - &to, - &subject, - &body, + to, + subject, + body, if attachments.is_empty() { None } else { diff --git a/src/basic/keywords/sms.rs b/src/basic/keywords/sms.rs index 670daf7c..97936218 100644 --- a/src/basic/keywords/sms.rs +++ b/src/basic/keywords/sms.rs @@ -47,18 +47,15 @@ pub enum SmsProvider { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Default)] pub enum SmsPriority { Low, + #[default] Normal, High, Urgent, } -impl Default for SmsPriority { - fn default() -> Self { - Self::Normal - } -} impl From<&str> for SmsPriority { fn from(s: &str) -> Self { diff --git a/src/basic/keywords/use_website.rs b/src/basic/keywords/use_website.rs index 6525c7b5..5edbc02d 100644 --- a/src/basic/keywords/use_website.rs +++ b/src/basic/keywords/use_website.rs @@ -594,7 +594,10 @@ pub fn register_website_for_crawling_with_refresh( Ok(()) } -/// Update refresh policy if the new interval is shorter than the existing one +/// Update refresh policy if the new interval is shorter than the existing one. +/// When the policy is updated, triggers an immediate crawl by setting next_crawl=NOW() +/// and crawl_status=0, ensuring the website is recrawled on the next crawler cycle +/// regardless of the previous schedule. fn update_refresh_policy_if_shorter( conn: &mut PgConnection, bot_id: &Uuid, @@ -636,7 +639,7 @@ fn update_refresh_policy_if_shorter( let expires_policy = days_to_expires_policy(new_days); diesel::sql_query( - "UPDATE website_crawls SET refresh_policy = $3, expires_policy = $4 + "UPDATE website_crawls SET refresh_policy = $3, expires_policy = $4, next_crawl = NOW(), crawl_status = 0 WHERE bot_id = $1 AND url = $2" ) .bind::(bot_id) @@ -645,6 +648,8 @@ fn update_refresh_policy_if_shorter( .bind::(expires_policy) .execute(conn) .map_err(|e| format!("Failed to update refresh policy: {}", e))?; + + info!("Refresh policy updated to {} for {} - immediate crawl scheduled", refresh_interval, url); } Ok(()) diff --git a/src/core/bootstrap/bootstrap_manager.rs b/src/core/bootstrap/bootstrap_manager.rs index e2f031c8..80caae2e 100644 --- a/src/core/bootstrap/bootstrap_manager.rs +++ b/src/core/bootstrap/bootstrap_manager.rs @@ -159,38 +159,15 @@ impl BootstrapManager { } if pm.is_installed("directory") { - // Wait for Zitadel to be ready - it might have been started during installation - // Use very aggressive backoff for fastest startup detection - let mut directory_already_running = zitadel_health_check(); - if !directory_already_running { - info!("Zitadel not responding to health check, waiting..."); - // Check every 500ms for fast detection (was: 1s, 2s, 5s, 10s) - let mut checks = 0; - let max_checks = 120; // 60 seconds max - while checks < max_checks { - if zitadel_health_check() { - info!("Zitadel/Directory service is now responding (checked {} times)", checks); - directory_already_running = true; - break; - } - sleep(Duration::from_millis(500)).await; - checks += 1; - // Log progress every 10 checks (5 seconds) - if checks % 10 == 0 { - info!("Zitadel health check: {}s elapsed, retrying...", checks / 2); - } - } - } + // Check once if Zitadel is already running + let directory_already_running = zitadel_health_check(); if directory_already_running { info!("Zitadel/Directory service is already running"); // Create OAuth client if config doesn't exist (even when already running) - // Check both Vault and file system for existing config let config_path = self.stack_dir("conf/system/directory_config.json"); - let has_config = config_path.exists(); - - if !has_config { + if !config_path.exists() { info!("Creating OAuth client for Directory service..."); match crate::core::package_manager::setup_directory().await { Ok(_) => info!("OAuth client created successfully"), @@ -200,6 +177,7 @@ impl BootstrapManager { info!("Directory config already exists, skipping OAuth setup"); } } else { + // Not running — start it immediately, then wait for it to become ready info!("Starting Zitadel/Directory service..."); match pm.start("directory") { Ok(_child) => { @@ -208,14 +186,18 @@ impl BootstrapManager { for i in 0..150 { sleep(Duration::from_secs(2)).await; if zitadel_health_check() { - info!("Zitadel/Directory service is responding"); + info!("Zitadel/Directory service is responding after {}s", (i + 1) * 2); zitadel_ready = true; break; } - if i == 149 { - warn!("Zitadel/Directory service did not respond after 300 seconds"); + // Log progress every 15 checks (30 seconds) + if i % 15 == 14 { + info!("Zitadel health check: {}s elapsed, retrying...", (i + 1) * 2); } } + if !zitadel_ready { + warn!("Zitadel/Directory service did not respond after 300 seconds"); + } // Create OAuth client if Zitadel is ready and config doesn't exist if zitadel_ready { diff --git a/src/core/bot/channels/whatsapp_rate_limiter.rs b/src/core/bot/channels/whatsapp_rate_limiter.rs index 9418b175..318f8d26 100644 --- a/src/core/bot/channels/whatsapp_rate_limiter.rs +++ b/src/core/bot/channels/whatsapp_rate_limiter.rs @@ -28,8 +28,10 @@ use tokio::time::sleep; /// WhatsApp throughput tier levels (matches Meta's tiers) #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Default)] pub enum WhatsAppTier { /// Tier 1: New phone numbers (40 msg/s, 1000 conv/day) + #[default] Tier1, /// Tier 2: Medium quality (80 msg/s, 10000 conv/day) Tier2, @@ -39,11 +41,6 @@ pub enum WhatsAppTier { Tier4, } -impl Default for WhatsAppTier { - fn default() -> Self { - Self::Tier1 - } -} impl WhatsAppTier { /// Get messages per second for this tier diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs index 55a0da6a..c8b14758 100644 --- a/src/core/bot/mod.rs +++ b/src/core/bot/mod.rs @@ -416,7 +416,7 @@ impl BotOrchestrator { let session_id = Uuid::parse_str(&message.session_id)?; let message_content = message.content.clone(); - let (session, context_data, history, model, key, system_prompt, bot_llm_url) = { + let (session, context_data, history, model, key, system_prompt, bot_llm_url, explicit_llm_provider) = { let state_clone = self.state.clone(); tokio::task::spawn_blocking( move || -> Result<_, Box> { @@ -458,6 +458,12 @@ impl BotOrchestrator { .get_bot_config_value(&session.bot_id, "llm-url") .ok(); + // Load explicit llm-provider from config.csv (e.g., "openai", "bedrock", "claude") + // This allows overriding auto-detection from URL + let explicit_llm_provider = config_manager + .get_bot_config_value(&session.bot_id, "llm-provider") + .ok(); + // Load system-prompt from config.csv, fallback to default let system_prompt = config_manager .get_config(&session.bot_id, "system-prompt", Some("You are a helpful assistant with access to tools that can help you complete tasks. When a user's request matches one of your available tools, use the appropriate tool instead of providing a generic response.")) @@ -465,7 +471,7 @@ impl BotOrchestrator { info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(500)]); - Ok((session, context_data, history, model, key, system_prompt, bot_llm_url)) + Ok((session, context_data, history, model, key, system_prompt, bot_llm_url, explicit_llm_provider)) }, ) .await?? @@ -624,7 +630,13 @@ impl BotOrchestrator { // Use bot-specific LLM provider if the bot has its own llm-url configured let llm: std::sync::Arc = if let Some(ref url) = bot_llm_url { info!("Bot has custom llm-url: {}, creating per-bot LLM provider", url); - crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None) + // Parse explicit provider type if configured (e.g., "openai", "bedrock", "claude") + let explicit_type = explicit_llm_provider.as_ref().map(|p| { + let parsed: crate::llm::LLMProviderType = p.as_str().into(); + info!("Using explicit llm-provider config: {:?} for bot {}", parsed, session.bot_id); + parsed + }); + crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None, explicit_type) } else { self.state.llm_provider.clone() }; diff --git a/src/core/config/watcher.rs b/src/core/config/watcher.rs index 3e9d82eb..0a9880a7 100644 --- a/src/core/config/watcher.rs +++ b/src/core/config/watcher.rs @@ -183,7 +183,7 @@ impl ConfigWatcher { }; info!("ConfigWatcher: Refreshing LLM provider with URL={}, model={}, endpoint={:?}", base_url, llm_model, endpoint_path); - dynamic_llm.update_from_config(base_url, Some(llm_model), endpoint_path.map(|s| s.to_string())).await; + dynamic_llm.update_from_config(base_url, Some(llm_model), endpoint_path.map(|s| s.to_string()), None).await; } } } diff --git a/src/core/config_reload.rs b/src/core/config_reload.rs index 656afffc..37b49bc3 100644 --- a/src/core/config_reload.rs +++ b/src/core/config_reload.rs @@ -38,7 +38,7 @@ pub async fn reload_config( // Update LLM provider if let Some(dynamic_llm) = &state.dynamic_llm_provider { dynamic_llm - .update_from_config(&llm_url, Some(llm_model.clone()), Some(llm_endpoint_path.clone())) + .update_from_config(&llm_url, Some(llm_model.clone()), Some(llm_endpoint_path.clone()), None) .await; Ok(Json(json!({ diff --git a/src/deployment/types.rs b/src/deployment/types.rs index c3a91522..47d99319 100644 --- a/src/deployment/types.rs +++ b/src/deployment/types.rs @@ -75,17 +75,14 @@ pub struct DeploymentConfig { } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Default)] pub enum DeploymentEnvironment { + #[default] Development, Staging, Production, } -impl Default for DeploymentEnvironment { - fn default() -> Self { - DeploymentEnvironment::Development - } -} impl std::fmt::Display for DeploymentEnvironment { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/src/drive/drive_monitor/mod.rs b/src/drive/drive_monitor/mod.rs index 09b75e52..999c67cb 100644 --- a/src/drive/drive_monitor/mod.rs +++ b/src/drive/drive_monitor/mod.rs @@ -693,6 +693,7 @@ impl DriveMonitor { &effective_url, Some(effective_model), Some(effective_endpoint_path), + None, ) .await; trace!("Dynamic LLM provider updated with new configuration"); diff --git a/src/email/messages.rs b/src/email/messages.rs index 399160fa..a9302951 100644 --- a/src/email/messages.rs +++ b/src/email/messages.rs @@ -83,16 +83,20 @@ fn inject_tracking_pixel(html_body: &str, tracking_id: &str, state: &Arc>, +struct EmailTrackingParams<'a> { tracking_id: Uuid, account_id: Uuid, bot_id: Uuid, - from_email: &str, - to_email: &str, - cc: Option<&str>, - bcc: Option<&str>, - subject: &str, + from_email: &'a str, + to_email: &'a str, + cc: Option<&'a str>, + bcc: Option<&'a str>, + subject: &'a str, +} + +fn save_email_tracking_record( + conn: diesel::r2d2::Pool>, + params: EmailTrackingParams, ) -> Result<(), String> { let mut db_conn = conn.get().map_err(|e| format!("DB connection error: {e}"))?; @@ -101,14 +105,14 @@ fn save_email_tracking_record( VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW())" ) .bind::(Uuid::new_v4()) - .bind::(tracking_id.to_string()) - .bind::(bot_id) - .bind::(account_id) - .bind::(from_email) - .bind::(to_email) - .bind::, _>(cc) - .bind::, _>(bcc) - .bind::(subject) + .bind::(params.tracking_id.to_string()) + .bind::(params.bot_id) + .bind::(params.account_id) + .bind::(params.from_email) + .bind::(params.to_email) + .bind::, _>(params.cc) + .bind::, _>(params.bcc) + .bind::(params.subject) .execute(&mut db_conn) .map_err(|e| format!("Failed to save tracking record: {e}"))?; @@ -368,14 +372,16 @@ pub async fn send_email( let _ = tokio::task::spawn_blocking(move || { save_email_tracking_record( conn, - tracking_id, - account_uuid, - Uuid::nil(), - &from_email, - &to_email, - cc_clone.as_deref(), - bcc_clone.as_deref(), - &subject, + EmailTrackingParams { + tracking_id, + account_id: account_uuid, + bot_id: Uuid::nil(), + from_email: &from_email, + to_email: &to_email, + cc: cc_clone.as_deref(), + bcc: bcc_clone.as_deref(), + subject: &subject, + }, ) }) .await; diff --git a/src/email/tracking.rs b/src/email/tracking.rs index d54d3794..24a8917b 100644 --- a/src/email/tracking.rs +++ b/src/email/tracking.rs @@ -49,16 +49,20 @@ pub fn inject_tracking_pixel(html_body: &str, tracking_id: &str, state: &Arc { + pub tracking_id: Uuid, + pub account_id: Uuid, + pub bot_id: Uuid, + pub from_email: &'a str, + pub to_email: &'a str, + pub cc: Option<&'a str>, + pub bcc: Option<&'a str>, + pub subject: &'a str, +} + pub fn save_email_tracking_record( conn: crate::core::shared::utils::DbPool, - tracking_id: Uuid, - account_id: Uuid, - bot_id: Uuid, - from_email: &str, - to_email: &str, - cc: Option<&str>, - bcc: Option<&str>, - subject: &str, + params: EmailTrackingParams, ) -> Result<(), String> { let mut db_conn = conn .get() @@ -73,19 +77,19 @@ pub fn save_email_tracking_record( VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, 0, false)" ) .bind::(id) - .bind::(tracking_id) - .bind::(bot_id) - .bind::(account_id) - .bind::(from_email) - .bind::(to_email) - .bind::, _>(cc) - .bind::, _>(bcc) - .bind::(subject) + .bind::(params.tracking_id) + .bind::(params.bot_id) + .bind::(params.account_id) + .bind::(params.from_email) + .bind::(params.to_email) + .bind::, _>(params.cc) + .bind::, _>(params.bcc) + .bind::(params.subject) .bind::(now) .execute(&mut db_conn) .map_err(|e| format!("Failed to save tracking record: {}", e))?; - debug!("Saved email tracking record: tracking_id={}", tracking_id); + debug!("Saved email tracking record: tracking_id={}", params.tracking_id); Ok(()) } diff --git a/src/llm/bedrock.rs b/src/llm/bedrock.rs index e5a58562..c8143d01 100644 --- a/src/llm/bedrock.rs +++ b/src/llm/bedrock.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use futures::StreamExt; -use log::{error, info}; +use log::{error, info, warn}; use serde_json::Value; use tokio::sync::mpsc; @@ -14,11 +14,165 @@ pub struct BedrockClient { impl BedrockClient { pub fn new(base_url: String) -> Self { + // Accept three URL formats: + // 1. OpenAI-compatible: .../openai/v1/chat/completions (use as-is) + // 2. Native invoke: .../model/{model-id}/invoke (use as-is, streaming swaps to invoke-with-response-stream) + // 3. Bare domain: https://bedrock-runtime.region.amazonaws.com (auto-append OpenAI path) + let url = if base_url.contains("/openai/") || base_url.contains("/chat/completions") || base_url.contains("/model/") { + base_url + } else { + let trimmed = base_url.trim_end_matches('/'); + format!("{}/openai/v1/chat/completions", trimmed) + }; + Self { client: reqwest::Client::new(), - base_url, + base_url: url, } } + + /// Check if URL is native Bedrock invoke endpoint (not OpenAI-compatible) + fn is_native_invoke(&self) -> bool { + self.base_url.contains("/model/") && self.base_url.contains("/invoke") + } + + /// Get streaming URL: for native invoke, swap /invoke to /invoke-with-response-stream + fn stream_url(&self) -> String { + if self.is_native_invoke() && self.base_url.ends_with("/invoke") { + self.base_url.replace("/invoke", "/invoke-with-response-stream") + } else { + self.base_url.clone() + } + } + + /// Build the auth header from the key + fn auth_header(key: &str) -> String { + if key.starts_with("Bearer ") { + key.to_string() + } else { + format!("Bearer {}", key) + } + } + + /// Build formatted messages from raw input + fn build_messages(raw_messages: &Value) -> Value { + let mut messages_limited = Vec::new(); + if let Some(msg_array) = raw_messages.as_array() { + for msg in msg_array { + messages_limited.push(msg.clone()); + } + } + Value::Array(messages_limited) + } + + /// Send a streaming request and process the response + async fn do_stream( + &self, + formatted_messages: &Value, + model: &str, + key: &str, + tools: Option<&Vec>, + tx: &mpsc::Sender, + ) -> Result<(), Box> { + let auth_header = Self::auth_header(key); + + let mut request_body = serde_json::json!({ + "model": model, + "messages": formatted_messages, + "stream": true + }); + + if let Some(tools_value) = tools { + if !tools_value.is_empty() { + request_body["tools"] = serde_json::json!(tools_value); + info!("Added {} tools to Bedrock request", tools_value.len()); + } + } + + let url = self.stream_url(); + info!("Sending streaming request to Bedrock endpoint: {}", url); + + let response = self.client + .post(&url) + .header("Authorization", &auth_header) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Bedrock generate_stream error: {}", error_text); + return Err(format!("Bedrock API error ({}): {}", status, error_text).into()); + } + + let mut stream = response.bytes_stream(); + let mut tool_call_buffer = String::new(); + + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if let Ok(text) = std::str::from_utf8(&chunk) { + for line in text.split('\n') { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + continue; + } + + if let Ok(json) = serde_json::from_str::(data) { + if let Some(choices) = json.get("choices") { + if let Some(first_choice) = choices.get(0) { + if let Some(delta) = first_choice.get("delta") { + if let Some(content) = delta.get("content") { + if let Some(content_str) = content.as_str() { + if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() { + return Ok(()); + } + } + } + + if let Some(tool_calls) = delta.get("tool_calls") { + if let Some(calls_array) = tool_calls.as_array() { + if let Some(first_call) = calls_array.first() { + if let Some(function) = first_call.get("function") { + if let Some(name) = function.get("name") { + if let Some(name_str) = name.as_str() { + tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str); + } + } + if let Some(args) = function.get("arguments") { + if let Some(args_str) = args.as_str() { + tool_call_buffer.push_str(args_str); + } + } + } + } + } + } + } + } + } + } + } + } + } + } + Err(e) => { + error!("Bedrock stream reading error: {}", e); + break; + } + } + } + + if !tool_call_buffer.is_empty() { + tool_call_buffer.push_str("\"}"); + let _ = tx.send(format!("`tool_call`: {}", tool_call_buffer)).await; + } + + Ok(()) + } } #[async_trait] @@ -38,20 +192,8 @@ impl LLMProvider for BedrockClient { &default_messages }; - - let mut messages_limited = Vec::new(); - if let Some(msg_array) = raw_messages.as_array() { - for msg in msg_array { - messages_limited.push(msg.clone()); - } - } - let formatted_messages = serde_json::Value::Array(messages_limited); - - let auth_header = if key.starts_with("Bearer ") { - key.to_string() - } else { - format!("Bearer {}", key) - }; + let formatted_messages = Self::build_messages(raw_messages); + let auth_header = Self::auth_header(key); let request_body = serde_json::json!({ "model": model, @@ -77,7 +219,7 @@ impl LLMProvider for BedrockClient { } let json: Value = response.json().await?; - + if let Some(choices) = json.get("choices") { if let Some(first_choice) = choices.get(0) { if let Some(message) = first_choice.get("message") { @@ -110,127 +252,26 @@ impl LLMProvider for BedrockClient { &default_messages }; - let mut messages_limited = Vec::new(); - if let Some(msg_array) = raw_messages.as_array() { - for msg in msg_array { - messages_limited.push(msg.clone()); - } - } - let formatted_messages = serde_json::Value::Array(messages_limited); + let formatted_messages = Self::build_messages(raw_messages); - let auth_header = if key.starts_with("Bearer ") { - key.to_string() - } else { - format!("Bearer {}", key) - }; + // Try with tools first + let result = self.do_stream(&formatted_messages, model, key, tools, &tx).await; - let mut request_body = serde_json::json!({ - "model": model, - "messages": formatted_messages, - "stream": true - }); - - if let Some(tools_value) = tools { - if !tools_value.is_empty() { - request_body["tools"] = serde_json::json!(tools_value); - info!("Added {} tools to Bedrock request", tools_value.len()); + if let Err(ref e) = result { + let err_str = e.to_string(); + // If error is "Operation not allowed" or validation_error, retry without tools + if (err_str.contains("Operation not allowed") || err_str.contains("validation_error")) + && tools.is_some() + { + warn!( + "Bedrock model '{}' does not support tools, retrying without tools", + model + ); + return self.do_stream(&formatted_messages, model, key, None, &tx).await; } } - let stream_url = if self.base_url.ends_with("/invoke") { - self.base_url.replace("/invoke", "/invoke-with-response-stream") - } else { - self.base_url.clone() - }; - - info!("Sending streaming request to Bedrock endpoint: {}", stream_url); - - let response = self.client - .post(&stream_url) - .header("Authorization", &auth_header) - .header("Content-Type", "application/json") - .json(&request_body) - .send() - .await?; - - let status = response.status(); - if !status.is_success() { - let error_text = response.text().await.unwrap_or_default(); - error!("Bedrock generate_stream error: {}", error_text); - return Err(format!("Bedrock API error ({}): {}", status, error_text).into()); - } - - let mut stream = response.bytes_stream(); - let mut tool_call_buffer = String::new(); - - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - if let Ok(text) = std::str::from_utf8(&chunk) { - for line in text.split('\n') { - let line = line.trim(); - if let Some(data) = line.strip_prefix("data: ") { - if data == "[DONE]" { - continue; - } - - if let Ok(json) = serde_json::from_str::(data) { - if let Some(choices) = json.get("choices") { - if let Some(first_choice) = choices.get(0) { - if let Some(delta) = first_choice.get("delta") { - // Handle standard content streaming - if let Some(content) = delta.get("content") { - if let Some(content_str) = content.as_str() { - if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() { - return Ok(()); - } - } - } - - // Handle tool calls streaming - if let Some(tool_calls) = delta.get("tool_calls") { - if let Some(calls_array) = tool_calls.as_array() { - if let Some(first_call) = calls_array.first() { - if let Some(function) = first_call.get("function") { - // Stream function JSON representation just like OpenAI does - if let Some(name) = function.get("name") { - if let Some(name_str) = name.as_str() { - tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str); - } - } - if let Some(args) = function.get("arguments") { - if let Some(args_str) = args.as_str() { - tool_call_buffer.push_str(args_str); - } - } - } - } - } - } - } - } - } - } - } - } - } - } - Err(e) => { - error!("Bedrock stream reading error: {}", e); - break; - } - } - } - - // Finalize tool call JSON parsing - if !tool_call_buffer.is_empty() { - tool_call_buffer.push_str("\"}"); - if tx.send(format!("`tool_call`: {}", tool_call_buffer)).await.is_err() { - return Ok(()); - } - } - - Ok(()) + result } async fn cancel_job(&self, _session_id: &str) -> Result<(), Box> { diff --git a/src/llm/hallucination_detector.rs b/src/llm/hallucination_detector.rs index 58699ab2..fe0318de 100644 --- a/src/llm/hallucination_detector.rs +++ b/src/llm/hallucination_detector.rs @@ -58,7 +58,7 @@ impl HallucinationDetector { // Ignore Markdown formatting patterns let md_patterns = ["**", "__", "*", "_", "`", "~~", "---", "***"]; - if md_patterns.iter().any(|p| trimmed == *p) { + if md_patterns.contains(&trimmed) { return false; } diff --git a/src/llm/local.rs b/src/llm/local.rs index c065bad2..da36930f 100644 --- a/src/llm/local.rs +++ b/src/llm/local.rs @@ -214,7 +214,7 @@ pub async fn ensure_llama_servers_running( log_jemalloc_stats(); trace!("ensure_llama_servers_running EXIT OK (non-blocking)"); - return Ok(()); + Ok(()) // OLD BLOCKING CODE - REMOVED TO PREVENT HTTP SERVER BLOCKING /* diff --git a/src/llm/mod.rs b/src/llm/mod.rs index a46ea00b..6fd3ba29 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -162,12 +162,19 @@ impl OpenAIClient { } pub fn new(_api_key: String, base_url: Option, endpoint_path: Option) -> Self { let base = base_url.unwrap_or_else(|| "https://api.openai.com".to_string()); + let trimmed_base = base.trim_end_matches('/').to_string(); + + // Detect if the base URL already contains a completions path + let has_v1_path = trimmed_base.contains("/v1/chat/completions"); + let has_chat_path = !has_v1_path && trimmed_base.contains("/chat/completions"); - // For z.ai API, use different endpoint path let endpoint = if let Some(path) = endpoint_path { path - } else if base.contains("z.ai") || base.contains("/v4") { - "/chat/completions".to_string() // z.ai uses /chat/completions, not /v1/chat/completions + } else if has_v1_path || (has_chat_path && !trimmed_base.contains("z.ai")) { + // Path already in base_url, use empty endpoint + "".to_string() + } else if trimmed_base.contains("z.ai") || trimmed_base.contains("/v4") { + "/chat/completions".to_string() // z.ai uses /chat/completions, not /v1/chat/completions } else { "/v1/chat/completions".to_string() }; @@ -416,7 +423,7 @@ impl LLMProvider for OpenAIClient { let _ = tx.send(processed).await; } } - + // Handle standard OpenAI tool_calls if let Some(tool_calls) = data["choices"][0]["delta"]["tool_calls"].as_array() { for tool_call in tool_calls { @@ -450,7 +457,7 @@ pub fn start_llm_services(state: &std::sync::Arc, endpoint_path: Option, + explicit_provider: Option, ) -> std::sync::Arc { - let provider_type = LLMProviderType::from(url); + let provider_type = explicit_provider.as_ref().map(|p| *p).unwrap_or_else(|| LLMProviderType::from(url)); + if explicit_provider.is_some() { + info!("Using explicit LLM provider type: {:?} for URL: {}", provider_type, url); + } create_llm_provider(provider_type, url.to_string(), model, endpoint_path) } @@ -555,8 +568,9 @@ impl DynamicLLMProvider { url: &str, model: Option, endpoint_path: Option, + explicit_provider: Option, ) { - let new_provider = create_llm_provider_from_url(url, model, endpoint_path); + let new_provider = create_llm_provider_from_url(url, model, endpoint_path, explicit_provider); self.update_provider(new_provider).await; } diff --git a/src/llm/vertex.rs b/src/llm/vertex.rs index c8ce4322..7ba5e9ea 100644 --- a/src/llm/vertex.rs +++ b/src/llm/vertex.rs @@ -470,7 +470,7 @@ impl LLMProvider for VertexClient { } else { // --- Native Gemini JSON format --- // It usually arrives as raw JSON objects, sometimes with leading commas or brackets in a stream array - let trimmed = line.trim_start_matches(|c| c == ',' || c == '[' || c == ']').trim_end_matches(']'); + let trimmed = line.trim_start_matches([',', '[', ']']).trim_end_matches(']'); if trimmed.is_empty() { continue; } if let Ok(json) = serde_json::from_str::(trimmed) { diff --git a/src/main_module/bootstrap.rs b/src/main_module/bootstrap.rs index 942b29b8..d2cbea14 100644 --- a/src/main_module/bootstrap.rs +++ b/src/main_module/bootstrap.rs @@ -473,6 +473,7 @@ pub async fn create_app_state( Some(llm_model.clone()) }, Some(llm_endpoint_path.clone()), + None, ); #[cfg(feature = "llm")] @@ -487,10 +488,12 @@ pub async fn create_app_state( llm_url, if llm_model.is_empty() { "(default)" } else { &llm_model }, llm_endpoint_path.clone()); + #[cfg(feature = "llm")] dynamic_llm_provider.update_from_config( &llm_url, if llm_model.is_empty() { None } else { Some(llm_model.clone()) }, Some(llm_endpoint_path), + None, ).await; info!("DynamicLLMProvider initialized successfully"); }