diff --git a/src/basic/keywords/add_suggestion.rs b/src/basic/keywords/add_suggestion.rs
index 355748f4..b7688bc9 100644
--- a/src/basic/keywords/add_suggestion.rs
+++ b/src/basic/keywords/add_suggestion.rs
@@ -9,7 +9,7 @@ pub fn clear_suggestions_keyword(state: Arc, user_session: UserSession
     let cache = state.cache.clone();
 
     engine
-        .register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |context, _inputs| {
+        .register_custom_syntax(&["CLEAR_SUGGESTIONS"], true, move |_context, _inputs| {
             info!("CLEAR_SUGGESTIONS command executed");
 
             if let Some(cache_client) = &cache {
diff --git a/src/bot/mod.rs b/src/bot/mod.rs
index 00a9e8dd..b41186e7 100644
--- a/src/bot/mod.rs
+++ b/src/bot/mod.rs
@@ -1367,6 +1367,11 @@ async fn websocket_handler(
                     .unregister_response_channel(&session_id_clone2)
                     .await;
 
+                // Cancel any ongoing LLM jobs for this session
+                if let Err(e) = data.llm_provider.cancel_job(&session_id_clone2).await {
+                    warn!("Failed to cancel LLM job for session {}: {}", session_id_clone2, e);
+                }
+
                 info!("WebSocket fully closed for session {}", session_id_clone2);
                 break;
             }
diff --git a/src/llm/llama.rs b/src/llm/llama.rs
new file mode 100644
index 00000000..2060979d
--- /dev/null
+++ b/src/llm/llama.rs
@@ -0,0 +1,123 @@
+use async_trait::async_trait;
+use reqwest::Client;
+use serde_json::Value;
+use std::sync::Arc;
+use tokio::sync::mpsc;
+use futures_util::StreamExt;
+use crate::tools::ToolManager;
+use super::LLMProvider;
+
+pub struct LlamaClient {
+    client: Client,
+    base_url: String,
+}
+
+impl LlamaClient {
+    pub fn new(base_url: String) -> Self {
+        Self {
+            client: Client::new(),
+            base_url,
+        }
+    }
+}
+
+#[async_trait]
+impl LLMProvider for LlamaClient {
+    async fn generate(
+        &self,
+        prompt: &str,
+        config: &Value,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let response = self.client
+            .post(&format!("{}/completion", self.base_url))
+            .json(&serde_json::json!({
+                "prompt": prompt,
+                "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
+                "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
+            }))
+            .send()
+            .await?;
+
+        let result: Value = response.json().await?;
+        Ok(result["content"]
+            .as_str()
+            .unwrap_or("")
+            .to_string())
+    }
+
+    async fn generate_stream(
+        &self,
+        prompt: &str,
+        config: &Value,
+        tx: mpsc::Sender<String>,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        let response = self.client
+            .post(&format!("{}/completion", self.base_url))
+            .json(&serde_json::json!({
+                "prompt": prompt,
+                "n_predict": config.get("max_tokens").and_then(|v| v.as_i64()).unwrap_or(512),
+                "temperature": config.get("temperature").and_then(|v| v.as_f64()).unwrap_or(0.7),
+                "stream": true
+            }))
+            .send()
+            .await?;
+
+        let mut stream = response.bytes_stream();
+        let mut buffer = String::new();
+
+        while let Some(chunk) = stream.next().await {
+            let chunk = chunk?;
+            let chunk_str = String::from_utf8_lossy(&chunk);
+
+            for line in chunk_str.lines() {
+                if let Ok(data) = serde_json::from_str::<Value>(&line) {
+                    if let Some(content) = data["content"].as_str() {
+                        buffer.push_str(content);
+                        let _ = tx.send(content.to_string()).await;
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    async fn generate_with_tools(
+        &self,
+        prompt: &str,
+        config: &Value,
+        available_tools: &[String],
+        _tool_manager: Arc<ToolManager>,
+        _session_id: &str,
+        _user_id: &str,
+    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
+        let tools_info = if available_tools.is_empty() {
+            String::new()
+        } else {
+            format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
+        };
+
+        let enhanced_prompt = format!("{}{}", prompt, tools_info);
+        self.generate(&enhanced_prompt, config).await
+    }
+
+    async fn cancel_job(
+        &self,
+        session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // llama.cpp cancellation endpoint
+        let response = self.client
+            .post(&format!("{}/cancel", self.base_url))
+            .json(&serde_json::json!({
+                "session_id": session_id
+            }))
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            Ok(())
+        } else {
+            Err(format!("Failed to cancel job for session {}", session_id).into())
+        }
+    }
+}
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index aca32f86..1ba6db08 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -6,6 +6,8 @@ use tokio::sync::mpsc;
 
 use crate::tools::ToolManager;
 
+pub mod llama;
+
 #[async_trait]
 pub trait LLMProvider: Send + Sync {
     async fn generate(
@@ -30,6 +32,11 @@
         session_id: &str,
         user_id: &str,
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
+
+    async fn cancel_job(
+        &self,
+        session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
 }
 
 pub struct OpenAIClient {
@@ -144,6 +151,14 @@
 
         self.generate(&enhanced_prompt, &Value::Null).await
     }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // OpenAI doesn't support job cancellation
+        Ok(())
+    }
 }
 
 pub struct AnthropicClient {
@@ -254,6 +269,14 @@
 
         self.generate(&enhanced_prompt, &Value::Null).await
     }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // Anthropic doesn't support job cancellation
+        Ok(())
+    }
 }
 
 pub struct MockLLMProvider;
@@ -307,4 +330,12 @@
             tools_list, prompt
         ))
     }
+
+    async fn cancel_job(
+        &self,
+        _session_id: &str,
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+        // Mock provider has no running job to cancel
+        Ok(())
+    }
 }