Add caching mechanism to BotOrchestrator for user input responses
This commit is contained in:
parent
6d95c3acd5
commit
2e1c0a9a68
1 changed files with 40 additions and 7 deletions
|
|
@ -9,16 +9,24 @@ use serde_json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
use redis::AsyncCommands;
|
||||||
|
use reqwest::Client;
|
||||||
|
|
||||||
pub struct BotOrchestrator {
|
pub struct BotOrchestrator {
|
||||||
pub state: Arc<AppState>,
|
pub state: Arc<AppState>,
|
||||||
|
pub cache: Arc<Mutex<std::collections::HashMap<String, String>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BotOrchestrator {
|
impl BotOrchestrator {
|
||||||
pub fn new(state: Arc<AppState>) -> Self {
|
pub fn new(state: Arc<AppState>) -> Self {
|
||||||
Self { state }
|
Self {
|
||||||
|
state,
|
||||||
|
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn handle_user_input(
|
pub async fn handle_user_input(
|
||||||
&self,
|
&self,
|
||||||
|
|
@ -293,15 +301,39 @@ impl BotOrchestrator {
|
||||||
session_manager.get_conversation_history(session.id, session.user_id)?
|
session_manager.get_conversation_history(session.id, session.user_id)?
|
||||||
};
|
};
|
||||||
|
|
||||||
for (role, content) in history {
|
// Prompt compactor: keep only last 10 entries to limit size
|
||||||
|
let recent_history = if history.len() > 10 {
|
||||||
|
&history[history.len() - 10..]
|
||||||
|
} else {
|
||||||
|
&history[..]
|
||||||
|
};
|
||||||
|
|
||||||
|
for (role, content) in recent_history {
|
||||||
prompt.push_str(&format!("{}: {}\n", role, content));
|
prompt.push_str(&format!("{}: {}\n", role, content));
|
||||||
}
|
}
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||||
|
|
||||||
self.state
|
// Check in-memory cache for existing response
|
||||||
|
{
|
||||||
|
let cache = self.cache.lock().await;
|
||||||
|
if let Some(cached) = cache.get(&prompt) {
|
||||||
|
return Ok(cached.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate response via LLM provider
|
||||||
|
let response = self.state
|
||||||
.llm_provider
|
.llm_provider
|
||||||
.generate(&prompt, &serde_json::Value::Null)
|
.generate(&prompt, &serde_json::Value::Null)
|
||||||
.await
|
.await?;
|
||||||
|
|
||||||
|
// Store the new response in cache
|
||||||
|
{
|
||||||
|
let mut cache = self.cache.lock().await;
|
||||||
|
cache.insert(prompt.clone(), response.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stream_response(
|
pub async fn stream_response(
|
||||||
|
|
@ -447,12 +479,12 @@ impl BotOrchestrator {
|
||||||
}
|
}
|
||||||
|
|
||||||
analysis_buffer.push_str(&chunk);
|
analysis_buffer.push_str(&chunk);
|
||||||
if analysis_buffer.contains("<|channel|>") && !in_analysis {
|
if analysis_buffer.contains("**") && !in_analysis {
|
||||||
in_analysis = true;
|
in_analysis = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if in_analysis {
|
if in_analysis {
|
||||||
if analysis_buffer.ends_with("final<|message|>") {
|
if analysis_buffer.ends_with("final") {
|
||||||
debug!(
|
debug!(
|
||||||
"Analysis section completed, buffer length: {}",
|
"Analysis section completed, buffer length: {}",
|
||||||
analysis_buffer.len()
|
analysis_buffer.len()
|
||||||
|
|
@ -695,6 +727,7 @@ impl Default for BotOrchestrator {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
state: Arc::new(AppState::default()),
|
state: Arc::new(AppState::default()),
|
||||||
|
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue