feat(bot): add semantic caching and improve message handling

Enhances BotOrchestrator with semantic caching for faster LLM responses, using LangCache when the LLM_CACHE environment variable is enabled and a Qdrant-backed cache otherwise. Also refactors message handling so the incoming message is saved before direct mode handling and the generated response is saved afterward, and simplifies context-change handling so type 4 messages return as soon as the context switch is applied.
Rodrigo Rodriguez (Pragmatismo) 2025-11-03 12:54:21 -03:00
parent b5e1501454
commit 8aeb6f31ea
2 changed files with 107 additions and 45 deletions
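
For orientation, the sketch below reduces the cache-or-generate pattern that the diff implements to its essentials. The Embedder, VectorStore, and Llm traits and the cached_generate function are illustrative stand-ins invented for this sketch; the real code uses the project's own helpers (generate_embeddings, get_langcache_client, get_qdrant_client, ensure_collection_exists) and the orchestrator state's llm_provider.

use std::env;

pub trait Embedder {
    fn embed(&self, text: &str) -> Vec<f32>;
}

pub trait VectorStore {
    // Top-1 nearest-neighbour search; Some(response) on a sufficiently close hit.
    fn search(&self, vector: &[f32]) -> Option<String>;
    // Persist the vector together with the cache key, prompt, and response.
    fn upsert(&mut self, vector: Vec<f32>, key: &str, prompt: &str, response: &str);
}

pub trait Llm {
    fn generate(&self, prompt: &str) -> String;
}

pub fn cached_generate(
    store: &mut dyn VectorStore,
    embedder: &dyn Embedder,
    llm: &dyn Llm,
    question: &str,
    prompt: &str,
) -> String {
    // In the diff, LLM_CACHE picks the backend (LangCache vs. the Qdrant fallback)
    // and decides whether the cache key is the isolated question or the full prompt;
    // a single hypothetical store stands in for either backend here.
    let use_langcache = env::var("LLM_CACHE")
        .unwrap_or_else(|_| "false".to_string())
        .eq_ignore_ascii_case("true");
    let key = if use_langcache { question.trim() } else { prompt };

    let vector = embedder.embed(key);
    if let Some(cached) = store.search(&vector) {
        return cached; // semantic cache hit: answer without calling the model
    }

    let response = llm.generate(prompt); // cache miss: call the model as before
    store.upsert(vector, key, prompt, &response); // remember it for similar questions
    response
}

Because the cache is keyed on embedding similarity rather than exact string equality, paraphrases of a previously answered question can also hit the cache.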


@@ -455,28 +455,6 @@ impl BotOrchestrator {
return Ok(());
}
// Handle context change messages (type 4) immediately
// before any other processing
if message.message_type == 4 {
if let Some(context_name) = &message.context_name {
self
.handle_context_change(
&message.user_id,
&message.bot_id,
&message.session_id,
&message.channel,
context_name,
)
.await?;
}
}
let response_content = self.direct_mode_handler(&message, &session).await?;
{
let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message(
@@ -486,10 +464,31 @@ impl BotOrchestrator {
&message.content,
message.message_type,
)?;
}
let response_content = self.direct_mode_handler(&message, &session).await?;
{
let mut session_manager = self.state.session_manager.lock().await;
session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
}
// Create regular response for non-context-change messages
// Handle context change messages (type 4) first
if message.message_type == 4 {
if let Some(context_name) = &message.context_name {
return self
.handle_context_change(
&message.user_id,
&message.bot_id,
&message.session_id,
&message.channel,
context_name,
)
.await;
}
}
// Create regular response
let channel = message.channel.clone();
let config_manager = ConfigManager::new(Arc::clone(&self.state.conn));
let max_context_size = config_manager
@@ -566,30 +565,94 @@ impl BotOrchestrator {
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
let use_langcache = std::env::var("LLM_CACHE")
.unwrap_or_else(|_| "false".to_string())
.eq_ignore_ascii_case("true");
let user_message = UserMessage {
bot_id: "default".to_string(),
user_id: session.user_id.to_string(),
session_id: session.id.to_string(),
channel: "web".to_string(),
content: message.content.clone(),
message_type: 1,
media_url: None,
timestamp: Utc::now(),
context_name: None,
};
if use_langcache {
ensure_collection_exists(&self.state, "semantic_cache").await?;
let langcache_client = get_langcache_client()?;
let isolated_question = message.content.trim().to_string();
let question_embeddings = generate_embeddings(vec![isolated_question.clone()]).await?;
let question_embedding = question_embeddings
.get(0)
.ok_or_else(|| "Failed to generate embedding for question")?
.clone();
let (response_tx, mut response_rx) = mpsc::channel::<BotResponse>(100);
if let Err(e) = self.stream_response(user_message, response_tx).await {
error!("Failed to stream response in direct_mode_handler: {}", e);
let search_results = langcache_client
.search("semantic_cache", question_embedding.clone(), 1)
.await?;
if let Some(result) = search_results.first() {
let payload = &result.payload;
if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
return Ok(resp.to_string());
}
}
let response = self
.state
.llm_provider
.generate(&prompt, &serde_json::Value::Null)
.await?;
let point = QdrantPoint {
id: uuid::Uuid::new_v4().to_string(),
vector: question_embedding,
payload: serde_json::json!({
"question": isolated_question,
"prompt": prompt,
"response": response
}),
};
langcache_client
.upsert_points("semantic_cache", vec![point])
.await?;
Ok(response)
} else {
ensure_collection_exists(&self.state, "semantic_cache").await?;
let qdrant_client = get_qdrant_client(&self.state)?;
let embeddings = generate_embeddings(vec![prompt.clone()]).await?;
let embedding = embeddings
.get(0)
.ok_or_else(|| "Failed to generate embedding")?
.clone();
let search_results = qdrant_client
.search("semantic_cache", embedding.clone(), 1)
.await?;
if let Some(result) = search_results.first() {
if let Some(payload) = &result.payload {
if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
return Ok(resp.to_string());
}
}
}
let response = self
.state
.llm_provider
.generate(&prompt, &serde_json::Value::Null)
.await?;
let point = QdrantPoint {
id: uuid::Uuid::new_v4().to_string(),
vector: embedding,
payload: serde_json::json!({
"prompt": prompt,
"response": response
}),
};
qdrant_client
.upsert_points("semantic_cache", vec![point])
.await?;
Ok(response)
}
let mut full_response = String::new();
while let Some(response) = response_rx.recv().await {
full_response.push_str(&response.content);
}
Ok(full_response)
}
pub async fn stream_response(


@@ -1,4 +1,3 @@
SET_SCHEDULE "*/10 * * * *"
let text = GET "announcements.gbkb/news/news.pdf"