refactor: apply various fixes across botserver
Some checks failed
BotServer CI / build (push) Has been cancelled
Some checks failed
BotServer CI / build (push) Has been cancelled
This commit is contained in:
parent
82bfd0a443
commit
260a13e77d
20 changed files with 302 additions and 243 deletions
|
|
@ -25,7 +25,7 @@ impl ContainerSession {
|
|||
// Launch the container (this might take a moment if the image isn't cached locally)
|
||||
info!("Launching LXC container: {}", container_name);
|
||||
let launch_status = Command::new("lxc")
|
||||
.args(&["launch", "ubuntu:22.04", &container_name])
|
||||
.args(["launch", "ubuntu:22.04", &container_name])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to execute lxc launch: {}", e))?;
|
||||
|
|
@ -50,7 +50,7 @@ impl ContainerSession {
|
|||
info!("Starting terminal session in container: {}", self.container_name);
|
||||
|
||||
let mut child = Command::new("lxc")
|
||||
.args(&["exec", &self.container_name, "--", "bash"])
|
||||
.args(["exec", &self.container_name, "--", "bash"])
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
|
|
@ -115,7 +115,7 @@ impl ContainerSession {
|
|||
|
||||
// Clean up container
|
||||
let status = Command::new("lxc")
|
||||
.args(&["delete", &self.container_name, "--force"])
|
||||
.args(["delete", &self.container_name, "--force"])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to delete container: {}", e))?;
|
||||
|
|
|
|||
|
|
@ -659,7 +659,7 @@ impl Orchestrator {
|
|||
if !classification.entities.features.is_empty() {
|
||||
for feature in &classification.entities.features {
|
||||
let slug =
|
||||
feature.to_lowercase().replace(' ', "_").replace('-', "_");
|
||||
feature.to_lowercase().replace([' ', '-'], "_");
|
||||
tasks.push(PipelineSubTask {
|
||||
name: feature.clone(),
|
||||
description: format!(
|
||||
|
|
@ -681,7 +681,7 @@ impl Orchestrator {
|
|||
let inferred = infer_features_from_intent(&lower);
|
||||
for feature in &inferred {
|
||||
let slug =
|
||||
feature.to_lowercase().replace(' ', "_").replace('-', "_");
|
||||
feature.to_lowercase().replace([' ', '-'], "_");
|
||||
tasks.push(PipelineSubTask {
|
||||
name: feature.clone(),
|
||||
description: format!("Build {} UI with HTMX", feature),
|
||||
|
|
@ -897,7 +897,7 @@ impl Orchestrator {
|
|||
let event = TaskProgressEvent::new(
|
||||
&self.task_id,
|
||||
step,
|
||||
&format!("Mantis #{agent_id} activity"),
|
||||
format!("Mantis #{agent_id} activity"),
|
||||
)
|
||||
.with_event_type("agent_activity")
|
||||
.with_activity(activity.clone());
|
||||
|
|
|
|||
|
|
@ -350,9 +350,9 @@ async fn execute_send_mail(
|
|||
|
||||
if email_service
|
||||
.send_email(
|
||||
&to,
|
||||
&subject,
|
||||
&body,
|
||||
to,
|
||||
subject,
|
||||
body,
|
||||
if attachments.is_empty() {
|
||||
None
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -47,18 +47,15 @@ pub enum SmsProvider {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub enum SmsPriority {
|
||||
Low,
|
||||
#[default]
|
||||
Normal,
|
||||
High,
|
||||
Urgent,
|
||||
}
|
||||
|
||||
impl Default for SmsPriority {
|
||||
fn default() -> Self {
|
||||
Self::Normal
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for SmsPriority {
|
||||
fn from(s: &str) -> Self {
|
||||
|
|
|
|||
|
|
@ -594,7 +594,10 @@ pub fn register_website_for_crawling_with_refresh(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Update refresh policy if the new interval is shorter than the existing one
|
||||
/// Update refresh policy if the new interval is shorter than the existing one.
|
||||
/// When the policy is updated, triggers an immediate crawl by setting next_crawl=NOW()
|
||||
/// and crawl_status=0, ensuring the website is recrawled on the next crawler cycle
|
||||
/// regardless of the previous schedule.
|
||||
fn update_refresh_policy_if_shorter(
|
||||
conn: &mut PgConnection,
|
||||
bot_id: &Uuid,
|
||||
|
|
@ -636,7 +639,7 @@ fn update_refresh_policy_if_shorter(
|
|||
let expires_policy = days_to_expires_policy(new_days);
|
||||
|
||||
diesel::sql_query(
|
||||
"UPDATE website_crawls SET refresh_policy = $3, expires_policy = $4
|
||||
"UPDATE website_crawls SET refresh_policy = $3, expires_policy = $4, next_crawl = NOW(), crawl_status = 0
|
||||
WHERE bot_id = $1 AND url = $2"
|
||||
)
|
||||
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
||||
|
|
@ -645,6 +648,8 @@ fn update_refresh_policy_if_shorter(
|
|||
.bind::<diesel::sql_types::Text, _>(expires_policy)
|
||||
.execute(conn)
|
||||
.map_err(|e| format!("Failed to update refresh policy: {}", e))?;
|
||||
|
||||
info!("Refresh policy updated to {} for {} - immediate crawl scheduled", refresh_interval, url);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -159,38 +159,15 @@ impl BootstrapManager {
|
|||
}
|
||||
|
||||
if pm.is_installed("directory") {
|
||||
// Wait for Zitadel to be ready - it might have been started during installation
|
||||
// Use very aggressive backoff for fastest startup detection
|
||||
let mut directory_already_running = zitadel_health_check();
|
||||
if !directory_already_running {
|
||||
info!("Zitadel not responding to health check, waiting...");
|
||||
// Check every 500ms for fast detection (was: 1s, 2s, 5s, 10s)
|
||||
let mut checks = 0;
|
||||
let max_checks = 120; // 60 seconds max
|
||||
while checks < max_checks {
|
||||
if zitadel_health_check() {
|
||||
info!("Zitadel/Directory service is now responding (checked {} times)", checks);
|
||||
directory_already_running = true;
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
checks += 1;
|
||||
// Log progress every 10 checks (5 seconds)
|
||||
if checks % 10 == 0 {
|
||||
info!("Zitadel health check: {}s elapsed, retrying...", checks / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check once if Zitadel is already running
|
||||
let directory_already_running = zitadel_health_check();
|
||||
|
||||
if directory_already_running {
|
||||
info!("Zitadel/Directory service is already running");
|
||||
|
||||
// Create OAuth client if config doesn't exist (even when already running)
|
||||
// Check both Vault and file system for existing config
|
||||
let config_path = self.stack_dir("conf/system/directory_config.json");
|
||||
let has_config = config_path.exists();
|
||||
|
||||
if !has_config {
|
||||
if !config_path.exists() {
|
||||
info!("Creating OAuth client for Directory service...");
|
||||
match crate::core::package_manager::setup_directory().await {
|
||||
Ok(_) => info!("OAuth client created successfully"),
|
||||
|
|
@ -200,6 +177,7 @@ impl BootstrapManager {
|
|||
info!("Directory config already exists, skipping OAuth setup");
|
||||
}
|
||||
} else {
|
||||
// Not running — start it immediately, then wait for it to become ready
|
||||
info!("Starting Zitadel/Directory service...");
|
||||
match pm.start("directory") {
|
||||
Ok(_child) => {
|
||||
|
|
@ -208,14 +186,18 @@ impl BootstrapManager {
|
|||
for i in 0..150 {
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
if zitadel_health_check() {
|
||||
info!("Zitadel/Directory service is responding");
|
||||
info!("Zitadel/Directory service is responding after {}s", (i + 1) * 2);
|
||||
zitadel_ready = true;
|
||||
break;
|
||||
}
|
||||
if i == 149 {
|
||||
warn!("Zitadel/Directory service did not respond after 300 seconds");
|
||||
// Log progress every 15 checks (30 seconds)
|
||||
if i % 15 == 14 {
|
||||
info!("Zitadel health check: {}s elapsed, retrying...", (i + 1) * 2);
|
||||
}
|
||||
}
|
||||
if !zitadel_ready {
|
||||
warn!("Zitadel/Directory service did not respond after 300 seconds");
|
||||
}
|
||||
|
||||
// Create OAuth client if Zitadel is ready and config doesn't exist
|
||||
if zitadel_ready {
|
||||
|
|
|
|||
|
|
@ -28,8 +28,10 @@ use tokio::time::sleep;
|
|||
|
||||
/// WhatsApp throughput tier levels (matches Meta's tiers)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Default)]
|
||||
pub enum WhatsAppTier {
|
||||
/// Tier 1: New phone numbers (40 msg/s, 1000 conv/day)
|
||||
#[default]
|
||||
Tier1,
|
||||
/// Tier 2: Medium quality (80 msg/s, 10000 conv/day)
|
||||
Tier2,
|
||||
|
|
@ -39,11 +41,6 @@ pub enum WhatsAppTier {
|
|||
Tier4,
|
||||
}
|
||||
|
||||
impl Default for WhatsAppTier {
|
||||
fn default() -> Self {
|
||||
Self::Tier1
|
||||
}
|
||||
}
|
||||
|
||||
impl WhatsAppTier {
|
||||
/// Get messages per second for this tier
|
||||
|
|
|
|||
|
|
@ -416,7 +416,7 @@ impl BotOrchestrator {
|
|||
let session_id = Uuid::parse_str(&message.session_id)?;
|
||||
let message_content = message.content.clone();
|
||||
|
||||
let (session, context_data, history, model, key, system_prompt, bot_llm_url) = {
|
||||
let (session, context_data, history, model, key, system_prompt, bot_llm_url, explicit_llm_provider) = {
|
||||
let state_clone = self.state.clone();
|
||||
tokio::task::spawn_blocking(
|
||||
move || -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
|
||||
|
|
@ -458,6 +458,12 @@ impl BotOrchestrator {
|
|||
.get_bot_config_value(&session.bot_id, "llm-url")
|
||||
.ok();
|
||||
|
||||
// Load explicit llm-provider from config.csv (e.g., "openai", "bedrock", "claude")
|
||||
// This allows overriding auto-detection from URL
|
||||
let explicit_llm_provider = config_manager
|
||||
.get_bot_config_value(&session.bot_id, "llm-provider")
|
||||
.ok();
|
||||
|
||||
// Load system-prompt from config.csv, fallback to default
|
||||
let system_prompt = config_manager
|
||||
.get_config(&session.bot_id, "system-prompt", Some("You are a helpful assistant with access to tools that can help you complete tasks. When a user's request matches one of your available tools, use the appropriate tool instead of providing a generic response."))
|
||||
|
|
@ -465,7 +471,7 @@ impl BotOrchestrator {
|
|||
|
||||
info!("Loaded system-prompt for bot {}: {}", session.bot_id, &system_prompt[..system_prompt.len().min(500)]);
|
||||
|
||||
Ok((session, context_data, history, model, key, system_prompt, bot_llm_url))
|
||||
Ok((session, context_data, history, model, key, system_prompt, bot_llm_url, explicit_llm_provider))
|
||||
},
|
||||
)
|
||||
.await??
|
||||
|
|
@ -624,7 +630,13 @@ impl BotOrchestrator {
|
|||
// Use bot-specific LLM provider if the bot has its own llm-url configured
|
||||
let llm: std::sync::Arc<dyn crate::llm::LLMProvider> = if let Some(ref url) = bot_llm_url {
|
||||
info!("Bot has custom llm-url: {}, creating per-bot LLM provider", url);
|
||||
crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None)
|
||||
// Parse explicit provider type if configured (e.g., "openai", "bedrock", "claude")
|
||||
let explicit_type = explicit_llm_provider.as_ref().map(|p| {
|
||||
let parsed: crate::llm::LLMProviderType = p.as_str().into();
|
||||
info!("Using explicit llm-provider config: {:?} for bot {}", parsed, session.bot_id);
|
||||
parsed
|
||||
});
|
||||
crate::llm::create_llm_provider_from_url(url, Some(model.clone()), None, explicit_type)
|
||||
} else {
|
||||
self.state.llm_provider.clone()
|
||||
};
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ impl ConfigWatcher {
|
|||
};
|
||||
|
||||
info!("ConfigWatcher: Refreshing LLM provider with URL={}, model={}, endpoint={:?}", base_url, llm_model, endpoint_path);
|
||||
dynamic_llm.update_from_config(base_url, Some(llm_model), endpoint_path.map(|s| s.to_string())).await;
|
||||
dynamic_llm.update_from_config(base_url, Some(llm_model), endpoint_path.map(|s| s.to_string()), None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ pub async fn reload_config(
|
|||
// Update LLM provider
|
||||
if let Some(dynamic_llm) = &state.dynamic_llm_provider {
|
||||
dynamic_llm
|
||||
.update_from_config(&llm_url, Some(llm_model.clone()), Some(llm_endpoint_path.clone()))
|
||||
.update_from_config(&llm_url, Some(llm_model.clone()), Some(llm_endpoint_path.clone()), None)
|
||||
.await;
|
||||
|
||||
Ok(Json(json!({
|
||||
|
|
|
|||
|
|
@ -75,17 +75,14 @@ pub struct DeploymentConfig {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[derive(Default)]
|
||||
pub enum DeploymentEnvironment {
|
||||
#[default]
|
||||
Development,
|
||||
Staging,
|
||||
Production,
|
||||
}
|
||||
|
||||
impl Default for DeploymentEnvironment {
|
||||
fn default() -> Self {
|
||||
DeploymentEnvironment::Development
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DeploymentEnvironment {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
|
|
|
|||
|
|
@ -693,6 +693,7 @@ impl DriveMonitor {
|
|||
&effective_url,
|
||||
Some(effective_model),
|
||||
Some(effective_endpoint_path),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
trace!("Dynamic LLM provider updated with new configuration");
|
||||
|
|
|
|||
|
|
@ -83,16 +83,20 @@ fn inject_tracking_pixel(html_body: &str, tracking_id: &str, state: &Arc<AppStat
|
|||
}
|
||||
}
|
||||
|
||||
fn save_email_tracking_record(
|
||||
conn: diesel::r2d2::Pool<diesel::r2d2::ConnectionManager<diesel::PgConnection>>,
|
||||
struct EmailTrackingParams<'a> {
|
||||
tracking_id: Uuid,
|
||||
account_id: Uuid,
|
||||
bot_id: Uuid,
|
||||
from_email: &str,
|
||||
to_email: &str,
|
||||
cc: Option<&str>,
|
||||
bcc: Option<&str>,
|
||||
subject: &str,
|
||||
from_email: &'a str,
|
||||
to_email: &'a str,
|
||||
cc: Option<&'a str>,
|
||||
bcc: Option<&'a str>,
|
||||
subject: &'a str,
|
||||
}
|
||||
|
||||
fn save_email_tracking_record(
|
||||
conn: diesel::r2d2::Pool<diesel::r2d2::ConnectionManager<diesel::PgConnection>>,
|
||||
params: EmailTrackingParams,
|
||||
) -> Result<(), String> {
|
||||
let mut db_conn = conn.get().map_err(|e| format!("DB connection error: {e}"))?;
|
||||
|
||||
|
|
@ -101,14 +105,14 @@ fn save_email_tracking_record(
|
|||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, NOW())"
|
||||
)
|
||||
.bind::<diesel::sql_types::Uuid, _>(Uuid::new_v4())
|
||||
.bind::<diesel::sql_types::Text, _>(tracking_id.to_string())
|
||||
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(account_id)
|
||||
.bind::<diesel::sql_types::Text, _>(from_email)
|
||||
.bind::<diesel::sql_types::Text, _>(to_email)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(cc)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(bcc)
|
||||
.bind::<diesel::sql_types::Text, _>(subject)
|
||||
.bind::<diesel::sql_types::Text, _>(params.tracking_id.to_string())
|
||||
.bind::<diesel::sql_types::Uuid, _>(params.bot_id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(params.account_id)
|
||||
.bind::<diesel::sql_types::Text, _>(params.from_email)
|
||||
.bind::<diesel::sql_types::Text, _>(params.to_email)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(params.cc)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(params.bcc)
|
||||
.bind::<diesel::sql_types::Text, _>(params.subject)
|
||||
.execute(&mut db_conn)
|
||||
.map_err(|e| format!("Failed to save tracking record: {e}"))?;
|
||||
|
||||
|
|
@ -368,14 +372,16 @@ pub async fn send_email(
|
|||
let _ = tokio::task::spawn_blocking(move || {
|
||||
save_email_tracking_record(
|
||||
conn,
|
||||
EmailTrackingParams {
|
||||
tracking_id,
|
||||
account_uuid,
|
||||
Uuid::nil(),
|
||||
&from_email,
|
||||
&to_email,
|
||||
cc_clone.as_deref(),
|
||||
bcc_clone.as_deref(),
|
||||
&subject,
|
||||
account_id: account_uuid,
|
||||
bot_id: Uuid::nil(),
|
||||
from_email: &from_email,
|
||||
to_email: &to_email,
|
||||
cc: cc_clone.as_deref(),
|
||||
bcc: bcc_clone.as_deref(),
|
||||
subject: &subject,
|
||||
},
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
|
|
|||
|
|
@ -49,16 +49,20 @@ pub fn inject_tracking_pixel(html_body: &str, tracking_id: &str, state: &Arc<App
|
|||
}
|
||||
}
|
||||
|
||||
pub struct EmailTrackingParams<'a> {
|
||||
pub tracking_id: Uuid,
|
||||
pub account_id: Uuid,
|
||||
pub bot_id: Uuid,
|
||||
pub from_email: &'a str,
|
||||
pub to_email: &'a str,
|
||||
pub cc: Option<&'a str>,
|
||||
pub bcc: Option<&'a str>,
|
||||
pub subject: &'a str,
|
||||
}
|
||||
|
||||
pub fn save_email_tracking_record(
|
||||
conn: crate::core::shared::utils::DbPool,
|
||||
tracking_id: Uuid,
|
||||
account_id: Uuid,
|
||||
bot_id: Uuid,
|
||||
from_email: &str,
|
||||
to_email: &str,
|
||||
cc: Option<&str>,
|
||||
bcc: Option<&str>,
|
||||
subject: &str,
|
||||
params: EmailTrackingParams,
|
||||
) -> Result<(), String> {
|
||||
let mut db_conn = conn
|
||||
.get()
|
||||
|
|
@ -73,19 +77,19 @@ pub fn save_email_tracking_record(
|
|||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, 0, false)"
|
||||
)
|
||||
.bind::<diesel::sql_types::Uuid, _>(id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(tracking_id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(bot_id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(account_id)
|
||||
.bind::<diesel::sql_types::Text, _>(from_email)
|
||||
.bind::<diesel::sql_types::Text, _>(to_email)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(cc)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(bcc)
|
||||
.bind::<diesel::sql_types::Text, _>(subject)
|
||||
.bind::<diesel::sql_types::Uuid, _>(params.tracking_id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(params.bot_id)
|
||||
.bind::<diesel::sql_types::Uuid, _>(params.account_id)
|
||||
.bind::<diesel::sql_types::Text, _>(params.from_email)
|
||||
.bind::<diesel::sql_types::Text, _>(params.to_email)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(params.cc)
|
||||
.bind::<diesel::sql_types::Nullable<diesel::sql_types::Text>, _>(params.bcc)
|
||||
.bind::<diesel::sql_types::Text, _>(params.subject)
|
||||
.bind::<diesel::sql_types::Timestamptz, _>(now)
|
||||
.execute(&mut db_conn)
|
||||
.map_err(|e| format!("Failed to save tracking record: {}", e))?;
|
||||
|
||||
debug!("Saved email tracking record: tracking_id={}", tracking_id);
|
||||
debug!("Saved email tracking record: tracking_id={}", params.tracking_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use log::{error, info};
|
||||
use log::{error, info, warn};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
|
|
@ -14,11 +14,165 @@ pub struct BedrockClient {
|
|||
|
||||
impl BedrockClient {
|
||||
pub fn new(base_url: String) -> Self {
|
||||
// Accept three URL formats:
|
||||
// 1. OpenAI-compatible: .../openai/v1/chat/completions (use as-is)
|
||||
// 2. Native invoke: .../model/{model-id}/invoke (use as-is, streaming swaps to invoke-with-response-stream)
|
||||
// 3. Bare domain: https://bedrock-runtime.region.amazonaws.com (auto-append OpenAI path)
|
||||
let url = if base_url.contains("/openai/") || base_url.contains("/chat/completions") || base_url.contains("/model/") {
|
||||
base_url
|
||||
} else {
|
||||
let trimmed = base_url.trim_end_matches('/');
|
||||
format!("{}/openai/v1/chat/completions", trimmed)
|
||||
};
|
||||
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url,
|
||||
base_url: url,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if URL is native Bedrock invoke endpoint (not OpenAI-compatible)
|
||||
fn is_native_invoke(&self) -> bool {
|
||||
self.base_url.contains("/model/") && self.base_url.contains("/invoke")
|
||||
}
|
||||
|
||||
/// Get streaming URL: for native invoke, swap /invoke to /invoke-with-response-stream
|
||||
fn stream_url(&self) -> String {
|
||||
if self.is_native_invoke() && self.base_url.ends_with("/invoke") {
|
||||
self.base_url.replace("/invoke", "/invoke-with-response-stream")
|
||||
} else {
|
||||
self.base_url.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the auth header from the key
|
||||
fn auth_header(key: &str) -> String {
|
||||
if key.starts_with("Bearer ") {
|
||||
key.to_string()
|
||||
} else {
|
||||
format!("Bearer {}", key)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build formatted messages from raw input
|
||||
fn build_messages(raw_messages: &Value) -> Value {
|
||||
let mut messages_limited = Vec::new();
|
||||
if let Some(msg_array) = raw_messages.as_array() {
|
||||
for msg in msg_array {
|
||||
messages_limited.push(msg.clone());
|
||||
}
|
||||
}
|
||||
Value::Array(messages_limited)
|
||||
}
|
||||
|
||||
/// Send a streaming request and process the response
|
||||
async fn do_stream(
|
||||
&self,
|
||||
formatted_messages: &Value,
|
||||
model: &str,
|
||||
key: &str,
|
||||
tools: Option<&Vec<Value>>,
|
||||
tx: &mpsc::Sender<String>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let auth_header = Self::auth_header(key);
|
||||
|
||||
let mut request_body = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": formatted_messages,
|
||||
"stream": true
|
||||
});
|
||||
|
||||
if let Some(tools_value) = tools {
|
||||
if !tools_value.is_empty() {
|
||||
request_body["tools"] = serde_json::json!(tools_value);
|
||||
info!("Added {} tools to Bedrock request", tools_value.len());
|
||||
}
|
||||
}
|
||||
|
||||
let url = self.stream_url();
|
||||
info!("Sending streaming request to Bedrock endpoint: {}", url);
|
||||
|
||||
let response = self.client
|
||||
.post(&url)
|
||||
.header("Authorization", &auth_header)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("Bedrock generate_stream error: {}", error_text);
|
||||
return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut tool_call_buffer = String::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
for line in text.split('\n') {
|
||||
let line = line.trim();
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
||||
if let Some(choices) = json.get("choices") {
|
||||
if let Some(first_choice) = choices.get(0) {
|
||||
if let Some(delta) = first_choice.get("delta") {
|
||||
if let Some(content) = delta.get("content") {
|
||||
if let Some(content_str) = content.as_str() {
|
||||
if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = delta.get("tool_calls") {
|
||||
if let Some(calls_array) = tool_calls.as_array() {
|
||||
if let Some(first_call) = calls_array.first() {
|
||||
if let Some(function) = first_call.get("function") {
|
||||
if let Some(name) = function.get("name") {
|
||||
if let Some(name_str) = name.as_str() {
|
||||
tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str);
|
||||
}
|
||||
}
|
||||
if let Some(args) = function.get("arguments") {
|
||||
if let Some(args_str) = args.as_str() {
|
||||
tool_call_buffer.push_str(args_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Bedrock stream reading error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tool_call_buffer.is_empty() {
|
||||
tool_call_buffer.push_str("\"}");
|
||||
let _ = tx.send(format!("`tool_call`: {}", tool_call_buffer)).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -38,20 +192,8 @@ impl LLMProvider for BedrockClient {
|
|||
&default_messages
|
||||
};
|
||||
|
||||
|
||||
let mut messages_limited = Vec::new();
|
||||
if let Some(msg_array) = raw_messages.as_array() {
|
||||
for msg in msg_array {
|
||||
messages_limited.push(msg.clone());
|
||||
}
|
||||
}
|
||||
let formatted_messages = serde_json::Value::Array(messages_limited);
|
||||
|
||||
let auth_header = if key.starts_with("Bearer ") {
|
||||
key.to_string()
|
||||
} else {
|
||||
format!("Bearer {}", key)
|
||||
};
|
||||
let formatted_messages = Self::build_messages(raw_messages);
|
||||
let auth_header = Self::auth_header(key);
|
||||
|
||||
let request_body = serde_json::json!({
|
||||
"model": model,
|
||||
|
|
@ -110,127 +252,26 @@ impl LLMProvider for BedrockClient {
|
|||
&default_messages
|
||||
};
|
||||
|
||||
let mut messages_limited = Vec::new();
|
||||
if let Some(msg_array) = raw_messages.as_array() {
|
||||
for msg in msg_array {
|
||||
messages_limited.push(msg.clone());
|
||||
}
|
||||
}
|
||||
let formatted_messages = serde_json::Value::Array(messages_limited);
|
||||
let formatted_messages = Self::build_messages(raw_messages);
|
||||
|
||||
let auth_header = if key.starts_with("Bearer ") {
|
||||
key.to_string()
|
||||
} else {
|
||||
format!("Bearer {}", key)
|
||||
};
|
||||
// Try with tools first
|
||||
let result = self.do_stream(&formatted_messages, model, key, tools, &tx).await;
|
||||
|
||||
let mut request_body = serde_json::json!({
|
||||
"model": model,
|
||||
"messages": formatted_messages,
|
||||
"stream": true
|
||||
});
|
||||
|
||||
if let Some(tools_value) = tools {
|
||||
if !tools_value.is_empty() {
|
||||
request_body["tools"] = serde_json::json!(tools_value);
|
||||
info!("Added {} tools to Bedrock request", tools_value.len());
|
||||
if let Err(ref e) = result {
|
||||
let err_str = e.to_string();
|
||||
// If error is "Operation not allowed" or validation_error, retry without tools
|
||||
if (err_str.contains("Operation not allowed") || err_str.contains("validation_error"))
|
||||
&& tools.is_some()
|
||||
{
|
||||
warn!(
|
||||
"Bedrock model '{}' does not support tools, retrying without tools",
|
||||
model
|
||||
);
|
||||
return self.do_stream(&formatted_messages, model, key, None, &tx).await;
|
||||
}
|
||||
}
|
||||
|
||||
let stream_url = if self.base_url.ends_with("/invoke") {
|
||||
self.base_url.replace("/invoke", "/invoke-with-response-stream")
|
||||
} else {
|
||||
self.base_url.clone()
|
||||
};
|
||||
|
||||
info!("Sending streaming request to Bedrock endpoint: {}", stream_url);
|
||||
|
||||
let response = self.client
|
||||
.post(&stream_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("Bedrock generate_stream error: {}", error_text);
|
||||
return Err(format!("Bedrock API error ({}): {}", status, error_text).into());
|
||||
}
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut tool_call_buffer = String::new();
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||
for line in text.split('\n') {
|
||||
let line = line.trim();
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
||||
if let Some(choices) = json.get("choices") {
|
||||
if let Some(first_choice) = choices.get(0) {
|
||||
if let Some(delta) = first_choice.get("delta") {
|
||||
// Handle standard content streaming
|
||||
if let Some(content) = delta.get("content") {
|
||||
if let Some(content_str) = content.as_str() {
|
||||
if !content_str.is_empty() && tx.send(content_str.to_string()).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool calls streaming
|
||||
if let Some(tool_calls) = delta.get("tool_calls") {
|
||||
if let Some(calls_array) = tool_calls.as_array() {
|
||||
if let Some(first_call) = calls_array.first() {
|
||||
if let Some(function) = first_call.get("function") {
|
||||
// Stream function JSON representation just like OpenAI does
|
||||
if let Some(name) = function.get("name") {
|
||||
if let Some(name_str) = name.as_str() {
|
||||
tool_call_buffer = format!("{{\"name\": \"{}\", \"arguments\": \"", name_str);
|
||||
}
|
||||
}
|
||||
if let Some(args) = function.get("arguments") {
|
||||
if let Some(args_str) = args.as_str() {
|
||||
tool_call_buffer.push_str(args_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Bedrock stream reading error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize tool call JSON parsing
|
||||
if !tool_call_buffer.is_empty() {
|
||||
tool_call_buffer.push_str("\"}");
|
||||
if tx.send(format!("`tool_call`: {}", tool_call_buffer)).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
result
|
||||
}
|
||||
|
||||
async fn cancel_job(&self, _session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ impl HallucinationDetector {
|
|||
|
||||
// Ignore Markdown formatting patterns
|
||||
let md_patterns = ["**", "__", "*", "_", "`", "~~", "---", "***"];
|
||||
if md_patterns.iter().any(|p| trimmed == *p) {
|
||||
if md_patterns.contains(&trimmed) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ pub async fn ensure_llama_servers_running(
|
|||
log_jemalloc_stats();
|
||||
|
||||
trace!("ensure_llama_servers_running EXIT OK (non-blocking)");
|
||||
return Ok(());
|
||||
Ok(())
|
||||
|
||||
// OLD BLOCKING CODE - REMOVED TO PREVENT HTTP SERVER BLOCKING
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -162,11 +162,18 @@ impl OpenAIClient {
|
|||
}
|
||||
pub fn new(_api_key: String, base_url: Option<String>, endpoint_path: Option<String>) -> Self {
|
||||
let base = base_url.unwrap_or_else(|| "https://api.openai.com".to_string());
|
||||
let trimmed_base = base.trim_end_matches('/').to_string();
|
||||
|
||||
// Detect if the base URL already contains a completions path
|
||||
let has_v1_path = trimmed_base.contains("/v1/chat/completions");
|
||||
let has_chat_path = !has_v1_path && trimmed_base.contains("/chat/completions");
|
||||
|
||||
// For z.ai API, use different endpoint path
|
||||
let endpoint = if let Some(path) = endpoint_path {
|
||||
path
|
||||
} else if base.contains("z.ai") || base.contains("/v4") {
|
||||
} else if has_v1_path || (has_chat_path && !trimmed_base.contains("z.ai")) {
|
||||
// Path already in base_url, use empty endpoint
|
||||
"".to_string()
|
||||
} else if trimmed_base.contains("z.ai") || trimmed_base.contains("/v4") {
|
||||
"/chat/completions".to_string() // z.ai uses /chat/completions, not /v1/chat/completions
|
||||
} else {
|
||||
"/v1/chat/completions".to_string()
|
||||
|
|
@ -450,7 +457,7 @@ pub fn start_llm_services(state: &std::sync::Arc<crate::core::shared::state::App
|
|||
info!("LLM services started (episodic memory scheduler)");
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LLMProviderType {
|
||||
OpenAI,
|
||||
Claude,
|
||||
|
|
@ -524,12 +531,18 @@ pub fn create_llm_provider(
|
|||
}
|
||||
}
|
||||
|
||||
/// Create LLM provider from URL with optional explicit provider type override.
|
||||
/// If explicit_provider is Some, it takes precedence over URL-based detection.
|
||||
pub fn create_llm_provider_from_url(
|
||||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
explicit_provider: Option<LLMProviderType>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
let provider_type = LLMProviderType::from(url);
|
||||
let provider_type = explicit_provider.as_ref().map(|p| *p).unwrap_or_else(|| LLMProviderType::from(url));
|
||||
if explicit_provider.is_some() {
|
||||
info!("Using explicit LLM provider type: {:?} for URL: {}", provider_type, url);
|
||||
}
|
||||
create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
|
||||
}
|
||||
|
||||
|
|
@ -555,8 +568,9 @@ impl DynamicLLMProvider {
|
|||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
explicit_provider: Option<LLMProviderType>,
|
||||
) {
|
||||
let new_provider = create_llm_provider_from_url(url, model, endpoint_path);
|
||||
let new_provider = create_llm_provider_from_url(url, model, endpoint_path, explicit_provider);
|
||||
self.update_provider(new_provider).await;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -470,7 +470,7 @@ impl LLMProvider for VertexClient {
|
|||
} else {
|
||||
// --- Native Gemini JSON format ---
|
||||
// It usually arrives as raw JSON objects, sometimes with leading commas or brackets in a stream array
|
||||
let trimmed = line.trim_start_matches(|c| c == ',' || c == '[' || c == ']').trim_end_matches(']');
|
||||
let trimmed = line.trim_start_matches([',', '[', ']']).trim_end_matches(']');
|
||||
if trimmed.is_empty() { continue; }
|
||||
|
||||
if let Ok(json) = serde_json::from_str::<Value>(trimmed) {
|
||||
|
|
|
|||
|
|
@ -473,6 +473,7 @@ pub async fn create_app_state(
|
|||
Some(llm_model.clone())
|
||||
},
|
||||
Some(llm_endpoint_path.clone()),
|
||||
None,
|
||||
);
|
||||
|
||||
#[cfg(feature = "llm")]
|
||||
|
|
@ -487,10 +488,12 @@ pub async fn create_app_state(
|
|||
llm_url,
|
||||
if llm_model.is_empty() { "(default)" } else { &llm_model },
|
||||
llm_endpoint_path.clone());
|
||||
#[cfg(feature = "llm")]
|
||||
dynamic_llm_provider.update_from_config(
|
||||
&llm_url,
|
||||
if llm_model.is_empty() { None } else { Some(llm_model.clone()) },
|
||||
Some(llm_endpoint_path),
|
||||
None,
|
||||
).await;
|
||||
info!("DynamicLLMProvider initialized successfully");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue