feat(bot): add semantic caching and improve message handling
Enhances BotOrchestrator by integrating optional semantic caching via LangCache for faster LLM responses. Also refactors message saving to occur before and after direct mode handling, and simplifies context change logic for better clarity and flow.
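For orientation, here is a minimal sketch of the lookup-generate-store pattern this change introduces, with an in-memory store standing in for the project's LangCache/Qdrant clients. The embed-search-return-hit-or-generate-then-upsert sequence mirrors the diff below; the InMemoryCache type, the embed stub, the cosine check, and the 0.95 threshold are illustrative assumptions, not the project's API, and the LLM_CACHE toggle (which in the diff only selects the backend) is omitted.

// A sketch of the cache-aside flow: embed the prompt, search the semantic
// cache, return a hit if one is similar enough, otherwise call the model and
// store the new (embedding, response) pair. InMemoryCache, embed, cosine, and
// the 0.95 threshold are stand-ins for the real clients and embedding model.

struct InMemoryCache {
    entries: Vec<(Vec<f32>, String)>, // (embedding, cached response)
}

fn embed(text: &str) -> Vec<f32> {
    // Stand-in for generate_embeddings(); a real embedding model goes here.
    text.bytes().take(8).map(|b| b as f32 / 255.0).collect()
}

fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 { 0.0 } else { dot / (na * nb) }
}

fn respond(cache: &mut InMemoryCache, prompt: &str, llm: impl Fn(&str) -> String) -> String {
    let vector = embed(prompt);
    // 1. Look for a sufficiently similar prior prompt in the cache.
    if let Some((_, hit)) = cache.entries.iter().find(|(v, _)| cosine(v, &vector) > 0.95) {
        return hit.clone(); // cache hit: skip the LLM call
    }
    // 2. Miss: generate, then store the pair so the next similar prompt hits.
    let response = llm(prompt);
    cache.entries.push((vector, response.clone()));
    response
}

fn main() {
    let mut cache = InMemoryCache { entries: Vec::new() };
    let llm = |p: &str| format!("(model answer to: {p})");
    println!("{}", respond(&mut cache, "User: hello\nAssistant:", &llm)); // miss
    println!("{}", respond(&mut cache, "User: hello\nAssistant:", &llm)); // hit
}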
This commit is contained in:
parent b5e1501454
commit 8aeb6f31ea
2 changed files with 107 additions and 45 deletions
145 src/bot/mod.rs
src/bot/mod.rs
@@ -455,28 +455,6 @@ impl BotOrchestrator {
             return Ok(());
         }
 
-
-
-        // Handle context change messages (type 4) immediately
-        // before any other processing
-        if message.message_type == 4 {
-            if let Some(context_name) = &message.context_name {
-                self
-                    .handle_context_change(
-                        &message.user_id,
-                        &message.bot_id,
-                        &message.session_id,
-                        &message.channel,
-                        context_name,
-                    )
-                    .await?;
-
-            }
-        }
-
-
-        let response_content = self.direct_mode_handler(&message, &session).await?;
-
         {
             let mut session_manager = self.state.session_manager.lock().await;
             session_manager.save_message(
@@ -486,10 +464,31 @@ impl BotOrchestrator {
                 &message.content,
                 message.message_type,
             )?;
+        }
+
+        let response_content = self.direct_mode_handler(&message, &session).await?;
+
+        {
+            let mut session_manager = self.state.session_manager.lock().await;
             session_manager.save_message(session.id, user_id, 2, &response_content, 1)?;
         }
 
-        // Create regular response for non-context-change messages
+        // Handle context change messages (type 4) first
+        if message.message_type == 4 {
+            if let Some(context_name) = &message.context_name {
+                return self
+                    .handle_context_change(
+                        &message.user_id,
+                        &message.bot_id,
+                        &message.session_id,
+                        &message.channel,
+                        context_name,
+                    )
+                    .await;
+            }
+        }
+
+        // Create regular response
         let channel = message.channel.clone();
         let config_manager = ConfigManager::new(Arc::clone(&self.state.conn));
         let max_context_size = config_manager
@@ -566,30 +565,94 @@ impl BotOrchestrator {
 
         prompt.push_str(&format!("User: {}\nAssistant:", message.content));
 
-        let user_message = UserMessage {
-            bot_id: "default".to_string(),
-            user_id: session.user_id.to_string(),
-            session_id: session.id.to_string(),
-            channel: "web".to_string(),
-            content: message.content.clone(),
-            message_type: 1,
-            media_url: None,
-            timestamp: Utc::now(),
-            context_name: None,
-        };
-
-        let (response_tx, mut response_rx) = mpsc::channel::<BotResponse>(100);
-        if let Err(e) = self.stream_response(user_message, response_tx).await {
-            error!("Failed to stream response in direct_mode_handler: {}", e);
-        }
-
-        let mut full_response = String::new();
-        while let Some(response) = response_rx.recv().await {
-            full_response.push_str(&response.content);
-        }
-
-        Ok(full_response)
+        let use_langcache = std::env::var("LLM_CACHE")
+            .unwrap_or_else(|_| "false".to_string())
+            .eq_ignore_ascii_case("true");
+
+        if use_langcache {
+            ensure_collection_exists(&self.state, "semantic_cache").await?;
+            let langcache_client = get_langcache_client()?;
+            let isolated_question = message.content.trim().to_string();
+            let question_embeddings = generate_embeddings(vec![isolated_question.clone()]).await?;
+            let question_embedding = question_embeddings
+                .get(0)
+                .ok_or_else(|| "Failed to generate embedding for question")?
+                .clone();
+
+            let search_results = langcache_client
+                .search("semantic_cache", question_embedding.clone(), 1)
+                .await?;
+
+            if let Some(result) = search_results.first() {
+                let payload = &result.payload;
+                if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
+                    return Ok(resp.to_string());
+                }
+            }
+
+            let response = self
+                .state
+                .llm_provider
+                .generate(&prompt, &serde_json::Value::Null)
+                .await?;
+
+            let point = QdrantPoint {
+                id: uuid::Uuid::new_v4().to_string(),
+                vector: question_embedding,
+                payload: serde_json::json!({
+                    "question": isolated_question,
+                    "prompt": prompt,
+                    "response": response
+                }),
+            };
+
+            langcache_client
+                .upsert_points("semantic_cache", vec![point])
+                .await?;
+
+            Ok(response)
+        } else {
+            ensure_collection_exists(&self.state, "semantic_cache").await?;
+            let qdrant_client = get_qdrant_client(&self.state)?;
+            let embeddings = generate_embeddings(vec![prompt.clone()]).await?;
+            let embedding = embeddings
+                .get(0)
+                .ok_or_else(|| "Failed to generate embedding")?
+                .clone();
+
+            let search_results = qdrant_client
+                .search("semantic_cache", embedding.clone(), 1)
+                .await?;
+
+            if let Some(result) = search_results.first() {
+                if let Some(payload) = &result.payload {
+                    if let Some(resp) = payload.get("response").and_then(|v| v.as_str()) {
+                        return Ok(resp.to_string());
+                    }
+                }
+            }
+
+            let response = self
+                .state
+                .llm_provider
+                .generate(&prompt, &serde_json::Value::Null)
+                .await?;
+
+            let point = QdrantPoint {
+                id: uuid::Uuid::new_v4().to_string(),
+                vector: embedding,
+                payload: serde_json::json!({
+                    "prompt": prompt,
+                    "response": response
+                }),
+            };
+
+            qdrant_client
+                .upsert_points("semantic_cache", vec![point])
+                .await?;
+
+            Ok(response)
+        }
     }
 
     pub async fn stream_response(

(second changed file; name not captured in this view)
@@ -1,4 +1,3 @@
-SET_SCHEDULE "*/10 * * * *"
 
 
 let text = GET "announcements.gbkb/news/news.pdf"