Add caching mechanism to BotOrchestrator for user input responses
This commit is contained in:
parent
6d95c3acd5
commit
2e1c0a9a68
1 changed files with 40 additions and 7 deletions
|
|
@ -9,16 +9,24 @@ use serde_json;
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
use redis::AsyncCommands;
|
||||
use reqwest::Client;
|
||||
|
||||
pub struct BotOrchestrator {
|
||||
pub state: Arc<AppState>,
|
||||
pub cache: Arc<Mutex<std::collections::HashMap<String, String>>>,
|
||||
}
|
||||
|
||||
impl BotOrchestrator {
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self { state }
|
||||
pub fn new(state: Arc<AppState>) -> Self {
|
||||
Self {
|
||||
state,
|
||||
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_user_input(
|
||||
&self,
|
||||
|
|
@ -293,15 +301,39 @@ impl BotOrchestrator {
|
|||
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||
};
|
||||
|
||||
for (role, content) in history {
|
||||
// Prompt compactor: keep only last 10 entries to limit size
|
||||
let recent_history = if history.len() > 10 {
|
||||
&history[history.len() - 10..]
|
||||
} else {
|
||||
&history[..]
|
||||
};
|
||||
|
||||
for (role, content) in recent_history {
|
||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
||||
}
|
||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||
|
||||
self.state
|
||||
// Check in-memory cache for existing response
|
||||
{
|
||||
let cache = self.cache.lock().await;
|
||||
if let Some(cached) = cache.get(&prompt) {
|
||||
return Ok(cached.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Generate response via LLM provider
|
||||
let response = self.state
|
||||
.llm_provider
|
||||
.generate(&prompt, &serde_json::Value::Null)
|
||||
.await
|
||||
.await?;
|
||||
|
||||
// Store the new response in cache
|
||||
{
|
||||
let mut cache = self.cache.lock().await;
|
||||
cache.insert(prompt.clone(), response.clone());
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn stream_response(
|
||||
|
|
@ -447,12 +479,12 @@ impl BotOrchestrator {
|
|||
}
|
||||
|
||||
analysis_buffer.push_str(&chunk);
|
||||
if analysis_buffer.contains("<|channel|>") && !in_analysis {
|
||||
if analysis_buffer.contains("**") && !in_analysis {
|
||||
in_analysis = true;
|
||||
}
|
||||
|
||||
if in_analysis {
|
||||
if analysis_buffer.ends_with("final<|message|>") {
|
||||
if analysis_buffer.ends_with("final") {
|
||||
debug!(
|
||||
"Analysis section completed, buffer length: {}",
|
||||
analysis_buffer.len()
|
||||
|
|
@ -695,6 +727,7 @@ impl Default for BotOrchestrator {
|
|||
fn default() -> Self {
|
||||
Self {
|
||||
state: Arc::new(AppState::default()),
|
||||
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue