Fix bot_id: Use bot_id from URL path instead of client message

- Extract bot_name from WebSocket query parameters
- Look up bot_id from bot_name using database
- Pass bot_id to WebSocket message handler
- Use session's bot_id for LLM configuration instead of client-provided bot_id
- Fixes issue where client sends 'default' bot_id when accessing /edu
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-01-28 17:18:22 -03:00
parent 51c8a53a90
commit 26963f2caf

View file

@ -22,9 +22,9 @@ use axum::{
}; };
use diesel::PgConnection; use diesel::PgConnection;
use futures::{sink::SinkExt, stream::StreamExt}; use futures::{sink::SinkExt, stream::StreamExt};
use log::{error, info, warn};
#[cfg(feature = "llm")] #[cfg(feature = "llm")]
use log::trace; use log::trace;
use log::{error, info, warn};
use serde_json; use serde_json;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -90,7 +90,6 @@ impl BotOrchestrator {
let user_id = Uuid::parse_str(&message.user_id)?; let user_id = Uuid::parse_str(&message.user_id)?;
let session_id = Uuid::parse_str(&message.session_id)?; let session_id = Uuid::parse_str(&message.session_id)?;
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default();
let (session, context_data, history, model, key) = { let (session, context_data, history, model, key) = {
let state_clone = self.state.clone(); let state_clone = self.state.clone();
@ -119,10 +118,10 @@ impl BotOrchestrator {
let config_manager = ConfigManager::new(state_clone.conn.clone()); let config_manager = ConfigManager::new(state_clone.conn.clone());
let model = config_manager let model = config_manager
.get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo")) .get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo"))
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string()); .unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
let key = config_manager let key = config_manager
.get_config(&bot_id, "llm-key", Some("")) .get_config(&session.bot_id, "llm-key", Some(""))
.unwrap_or_default(); .unwrap_or_default();
Ok((session, context_data, history, model, key)) Ok((session, context_data, history, model, key))
@ -161,7 +160,7 @@ impl BotOrchestrator {
let initial_tokens = crate::shared::utils::estimate_token_count(&context_data); let initial_tokens = crate::shared::utils::estimate_token_count(&context_data);
let config_manager = ConfigManager::new(self.state.conn.clone()); let config_manager = ConfigManager::new(self.state.conn.clone());
let max_context_size = config_manager let max_context_size = config_manager
.get_config(&bot_id, "llm-server-ctx-size", None) .get_config(&session.bot_id, "llm-server-ctx-size", None)
.unwrap_or_default() .unwrap_or_default()
.parse::<usize>() .parse::<usize>()
.unwrap_or(0); .unwrap_or(0);
@ -368,6 +367,10 @@ pub async fn websocket_handler(
.get("session_id") .get("session_id")
.and_then(|s| Uuid::parse_str(s).ok()); .and_then(|s| Uuid::parse_str(s).ok());
let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok()); let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok());
let bot_name = params
.get("bot_name")
.cloned()
.unwrap_or_else(|| "default".to_string());
if session_id.is_none() || user_id.is_none() { if session_id.is_none() || user_id.is_none() {
return ( return (
@ -380,10 +383,27 @@ pub async fn websocket_handler(
let session_id = session_id.unwrap_or_default(); let session_id = session_id.unwrap_or_default();
let user_id = user_id.unwrap_or_default(); let user_id = user_id.unwrap_or_default();
ws.on_upgrade(move |socket| { // Look up bot_id from bot_name
handle_websocket(socket, state, session_id, user_id) let bot_id = {
}) let conn = state.conn.get().ok();
.into_response() if let Some(mut db_conn) = conn {
use crate::shared::models::schema::bots::dsl::*;
let result: Result<Uuid, _> = bots
.filter(name.eq(&bot_name))
.select(id)
.first(&mut db_conn);
result.unwrap_or_else(|_| {
log::warn!("Bot not found: {}, using nil bot_id", bot_name);
Uuid::nil()
})
} else {
log::warn!("Could not get database connection, using nil bot_id");
Uuid::nil()
}
};
ws.on_upgrade(move |socket| handle_websocket(socket, state, session_id, user_id, bot_id))
.into_response()
} }
async fn handle_websocket( async fn handle_websocket(
@ -391,6 +411,7 @@ async fn handle_websocket(
state: Arc<AppState>, state: Arc<AppState>,
session_id: Uuid, session_id: Uuid,
user_id: Uuid, user_id: Uuid,
bot_id: Uuid,
) { ) {
let (mut sender, mut receiver) = socket.split(); let (mut sender, mut receiver) = socket.split();
let (tx, mut rx) = mpsc::channel::<BotResponse>(100); let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
@ -414,6 +435,7 @@ async fn handle_websocket(
"type": "connected", "type": "connected",
"session_id": session_id, "session_id": session_id,
"user_id": user_id, "user_id": user_id,
"bot_id": bot_id,
"message": "Connected to bot server" "message": "Connected to bot server"
}); });
@ -447,8 +469,13 @@ async fn handle_websocket(
.await .await
.get(&session_id.to_string()) .get(&session_id.to_string())
{ {
// Use bot_id from WebSocket connection instead of from message
let corrected_msg = UserMessage {
bot_id: bot_id.to_string(),
..user_msg
};
if let Err(e) = orchestrator if let Err(e) = orchestrator
.stream_response(user_msg, tx_clone.clone()) .stream_response(corrected_msg, tx_clone.clone())
.await .await
{ {
error!("Failed to stream response: {}", e); error!("Failed to stream response: {}", e);