From 2e1c0a9a68903d1726f4b87a6ca6b4c581c8d60e Mon Sep 17 00:00:00 2001 From: "Rodrigo Rodriguez (Pragmatismo)" Date: Fri, 24 Oct 2025 23:36:16 -0300 Subject: [PATCH] Add caching mechanism to BotOrchestrator for user input responses --- src/bot/mod.rs | 47 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/bot/mod.rs b/src/bot/mod.rs index 3f4d4ae0c..73bbed0c2 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -9,16 +9,24 @@ use serde_json; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::mpsc; +use tokio::sync::Mutex; use uuid::Uuid; +use redis::AsyncCommands; +use reqwest::Client; pub struct BotOrchestrator { pub state: Arc, + pub cache: Arc>>, } impl BotOrchestrator { - pub fn new(state: Arc) -> Self { - Self { state } +pub fn new(state: Arc) -> Self { + Self { + state, + cache: Arc::new(Mutex::new(std::collections::HashMap::new())), } +} +} pub async fn handle_user_input( &self, @@ -293,15 +301,39 @@ impl BotOrchestrator { session_manager.get_conversation_history(session.id, session.user_id)? }; - for (role, content) in history { + // Prompt compactor: keep only last 10 entries to limit size + let recent_history = if history.len() > 10 { + &history[history.len() - 10..] + } else { + &history[..] + }; + + for (role, content) in recent_history { prompt.push_str(&format!("{}: {}\n", role, content)); } prompt.push_str(&format!("User: {}\nAssistant:", message.content)); - self.state + // Check in-memory cache for existing response + { + let cache = self.cache.lock().await; + if let Some(cached) = cache.get(&prompt) { + return Ok(cached.clone()); + } + } + + // Generate response via LLM provider + let response = self.state .llm_provider .generate(&prompt, &serde_json::Value::Null) - .await + .await?; + + // Store the new response in cache + { + let mut cache = self.cache.lock().await; + cache.insert(prompt.clone(), response.clone()); + } + + Ok(response) } pub async fn stream_response( @@ -447,12 +479,12 @@ impl BotOrchestrator { } analysis_buffer.push_str(&chunk); - if analysis_buffer.contains("<|channel|>") && !in_analysis { + if analysis_buffer.contains("**") && !in_analysis { in_analysis = true; } if in_analysis { - if analysis_buffer.ends_with("final<|message|>") { + if analysis_buffer.ends_with("final") { debug!( "Analysis section completed, buffer length: {}", analysis_buffer.len() @@ -695,6 +727,7 @@ impl Default for BotOrchestrator { fn default() -> Self { Self { state: Arc::new(AppState::default()), + cache: Arc::new(Mutex::new(std::collections::HashMap::new())), } } }