diff --git a/src/automation/compact_prompt.rs b/src/automation/compact_prompt.rs index 88bbbc000..ed7795158 100644 --- a/src/automation/compact_prompt.rs +++ b/src/automation/compact_prompt.rs @@ -102,14 +102,16 @@ async fn compact_prompt_for_bots( ); let mut conversation = String::new(); - conversation.push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n"); - + conversation + .push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n"); + for (role, content) in history.iter().skip(start_index) { if role == "compact" { continue; } - conversation.push_str(&format!("{}: {}\n", - if role == "user" { "User" } else { "Bot" }, + conversation.push_str(&format!( + "{}: {}\n", + if role == "user" { "user" } else { "assistant" }, content )); } @@ -123,11 +125,18 @@ async fn compact_prompt_for_bots( let llm_provider = state.llm_provider.clone(); trace!("Starting summarization for session {}", session.id); let mut filtered = String::new(); - let config_manager = crate::config::ConfigManager::new(state.conn.clone()); - let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default(); + let config_manager = crate::config::ConfigManager::new(state.conn.clone()); + let model = config_manager + .get_config(&Uuid::nil(), "llm-model", None) + .unwrap_or_default(); + let key = config_manager + .get_config(&Uuid::nil(), "llm-key", None) + .unwrap_or_default(); - let summarized = match llm_provider.generate( - "", &serde_json::Value::Array(messages), &model).await { + let summarized = match llm_provider + .generate("", &serde_json::Value::Array(messages), &model, &key) + .await + { Ok(summary) => { trace!( "Successfully summarized session {} ({} chars)", @@ -138,7 +147,8 @@ async fn compact_prompt_for_bots( let handler = llm_models::get_handler( config_manager .get_config(&session.bot_id, "llm-model", None) - .unwrap().as_str(), + .unwrap() + .as_str(), ); filtered = handler.process_content(&summary); diff --git a/src/basic/keywords/llm_keyword.rs b/src/basic/keywords/llm_keyword.rs index 8d71a33e7..3df7ab474 100644 --- a/src/basic/keywords/llm_keyword.rs +++ b/src/basic/keywords/llm_keyword.rs @@ -44,8 +44,10 @@ fn build_llm_prompt(user_text: &str) -> String { pub async fn execute_llm_generation(state: Arc, prompt: String) -> Result> { let config_manager = crate::config::ConfigManager::new(state.conn.clone()); let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default(); + let key = config_manager.get_config(&Uuid::nil(), "llm-key", None).unwrap_or_default(); + let handler = crate::llm_models::get_handler(&model); - let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model).await?; + let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model, &key).await?; let processed = handler.process_content(&raw_response); Ok(processed) } diff --git a/src/bootstrap/mod.rs b/src/bootstrap/mod.rs index 5c7184c17..38ab246eb 100644 --- a/src/bootstrap/mod.rs +++ b/src/bootstrap/mod.rs @@ -5,7 +5,7 @@ use anyhow::Result; use aws_config::BehaviorVersion; use aws_sdk_s3::Client; use diesel::connection::SimpleConnection; -use log::{error, info, trace}; +use log::{error, trace}; use rand::distr::Alphanumeric; use std::io::{self, Write}; use std::path::Path; diff --git a/src/bot/mod.rs b/src/bot/mod.rs index a1cc70198..3a67aa5ae 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -355,6 +355,11 @@ impl BotOrchestrator { .rposition(|(role, _content)| role == "compact") { history = history.split_off(last_compacted_index); + for (role, content) in history.iter_mut() { + if role == "compact" { + *role = "user".to_string(); + } + } } if history_limit > 0 && history.len() > history_limit as usize { let start = history.len() - history_limit as usize; @@ -405,9 +410,16 @@ impl BotOrchestrator { None, ) .unwrap_or_default(); + let key = config_manager + .get_config( + &Uuid::parse_str(&message.bot_id).unwrap_or_default(), + "llm-key", + None, + ) + .unwrap_or_default(); let model1 = model.clone(); tokio::spawn(async move { - if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model).await { + if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model, &key).await { error!("LLM streaming error: {}", e); } }); diff --git a/src/llm/mod.rs b/src/llm/mod.rs index f8d0cb27e..d75605658 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use futures::StreamExt; -use log::info; +use log::{info, trace}; use serde_json::Value; use tokio::sync::mpsc; pub mod local; @@ -11,13 +11,15 @@ pub trait LLMProvider: Send + Sync { prompt: &str, config: &Value, model: &str, + key: &str ) -> Result>; async fn generate_stream( &self, prompt: &str, config: &Value, tx: mpsc::Sender, - model: &str + model: &str, + key: &str ) -> Result<(), Box>; async fn cancel_job( &self, @@ -37,12 +39,14 @@ impl LLMProvider for OpenAIClient { prompt: &str, messages: &Value, model: &str, + key: &str ) -> Result> { let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); - let response = self + let response = + self .client .post(&format!("{}/v1/chat/completions", self.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Authorization", format!("Bearer {}", key)) .json(&serde_json::json!({ "model": model, "messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() { @@ -70,13 +74,14 @@ impl LLMProvider for OpenAIClient { prompt: &str, messages: &Value, tx: mpsc::Sender, - model: &str + model: &str, + key: &str ) -> Result<(), Box> { let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); let response = self .client .post(&format!("{}/v1/chat/completions", self.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Authorization", format!("Bearer {}", key)) .json(&serde_json::json!({ "model": model.clone(), "messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() { @@ -89,6 +94,12 @@ impl LLMProvider for OpenAIClient { })) .send() .await?; + let status = response.status(); + if status != reqwest::StatusCode::OK { + let error_text = response.text().await.unwrap_or_default(); + trace!("LLM generate_stream error: {}", error_text); + return Err(format!("LLM request failed with status: {}", status).into()); + } let mut stream = response.bytes_stream(); let mut buffer = String::new(); while let Some(chunk) = stream.next().await { diff --git a/src/session/mod.rs b/src/session/mod.rs index 89eaca1f7..a46e31c0c 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -274,7 +274,7 @@ impl SessionManager { for (other_role, content) in messages { let role_str = match other_role { 1 => "user".to_string(), - 2 => "bot".to_string(), + 2 => "assistant".to_string(), 3 => "system".to_string(), 9 => "compact".to_string(), _ => "unknown".to_string(),