Fix bot_id: Use bot_id from URL path instead of client message
- Extract bot_name from WebSocket query parameters - Look up bot_id from bot_name using database - Pass bot_id to WebSocket message handler - Use session's bot_id for LLM configuration instead of client-provided bot_id - Fixes issue where client sends 'default' bot_id when accessing /edu
This commit is contained in:
parent
51c8a53a90
commit
26963f2caf
1 changed files with 39 additions and 12 deletions
|
|
@ -22,9 +22,9 @@ use axum::{
|
|||
};
|
||||
use diesel::PgConnection;
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
use log::{error, info, warn};
|
||||
#[cfg(feature = "llm")]
|
||||
use log::trace;
|
||||
use log::{error, info, warn};
|
||||
use serde_json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -90,7 +90,6 @@ impl BotOrchestrator {
|
|||
|
||||
let user_id = Uuid::parse_str(&message.user_id)?;
|
||||
let session_id = Uuid::parse_str(&message.session_id)?;
|
||||
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default();
|
||||
|
||||
let (session, context_data, history, model, key) = {
|
||||
let state_clone = self.state.clone();
|
||||
|
|
@ -119,10 +118,10 @@ impl BotOrchestrator {
|
|||
|
||||
let config_manager = ConfigManager::new(state_clone.conn.clone());
|
||||
let model = config_manager
|
||||
.get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo"))
|
||||
.get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo"))
|
||||
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
|
||||
let key = config_manager
|
||||
.get_config(&bot_id, "llm-key", Some(""))
|
||||
.get_config(&session.bot_id, "llm-key", Some(""))
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok((session, context_data, history, model, key))
|
||||
|
|
@ -161,7 +160,7 @@ impl BotOrchestrator {
|
|||
let initial_tokens = crate::shared::utils::estimate_token_count(&context_data);
|
||||
let config_manager = ConfigManager::new(self.state.conn.clone());
|
||||
let max_context_size = config_manager
|
||||
.get_config(&bot_id, "llm-server-ctx-size", None)
|
||||
.get_config(&session.bot_id, "llm-server-ctx-size", None)
|
||||
.unwrap_or_default()
|
||||
.parse::<usize>()
|
||||
.unwrap_or(0);
|
||||
|
|
@ -368,6 +367,10 @@ pub async fn websocket_handler(
|
|||
.get("session_id")
|
||||
.and_then(|s| Uuid::parse_str(s).ok());
|
||||
let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok());
|
||||
let bot_name = params
|
||||
.get("bot_name")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
|
||||
if session_id.is_none() || user_id.is_none() {
|
||||
return (
|
||||
|
|
@ -380,9 +383,26 @@ pub async fn websocket_handler(
|
|||
let session_id = session_id.unwrap_or_default();
|
||||
let user_id = user_id.unwrap_or_default();
|
||||
|
||||
ws.on_upgrade(move |socket| {
|
||||
handle_websocket(socket, state, session_id, user_id)
|
||||
// Look up bot_id from bot_name
|
||||
let bot_id = {
|
||||
let conn = state.conn.get().ok();
|
||||
if let Some(mut db_conn) = conn {
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
let result: Result<Uuid, _> = bots
|
||||
.filter(name.eq(&bot_name))
|
||||
.select(id)
|
||||
.first(&mut db_conn);
|
||||
result.unwrap_or_else(|_| {
|
||||
log::warn!("Bot not found: {}, using nil bot_id", bot_name);
|
||||
Uuid::nil()
|
||||
})
|
||||
} else {
|
||||
log::warn!("Could not get database connection, using nil bot_id");
|
||||
Uuid::nil()
|
||||
}
|
||||
};
|
||||
|
||||
ws.on_upgrade(move |socket| handle_websocket(socket, state, session_id, user_id, bot_id))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
|
|
@ -391,6 +411,7 @@ async fn handle_websocket(
|
|||
state: Arc<AppState>,
|
||||
session_id: Uuid,
|
||||
user_id: Uuid,
|
||||
bot_id: Uuid,
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
||||
|
|
@ -414,6 +435,7 @@ async fn handle_websocket(
|
|||
"type": "connected",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"bot_id": bot_id,
|
||||
"message": "Connected to bot server"
|
||||
});
|
||||
|
||||
|
|
@ -447,8 +469,13 @@ async fn handle_websocket(
|
|||
.await
|
||||
.get(&session_id.to_string())
|
||||
{
|
||||
// Use bot_id from WebSocket connection instead of from message
|
||||
let corrected_msg = UserMessage {
|
||||
bot_id: bot_id.to_string(),
|
||||
..user_msg
|
||||
};
|
||||
if let Err(e) = orchestrator
|
||||
.stream_response(user_msg, tx_clone.clone())
|
||||
.stream_response(corrected_msg, tx_clone.clone())
|
||||
.await
|
||||
{
|
||||
error!("Failed to stream response: {}", e);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue