Fix bot_id: Use bot_id from URL path instead of client message

- Extract bot_name from WebSocket query parameters
- Look up bot_id from bot_name using database
- Pass bot_id to WebSocket message handler
- Use session's bot_id for LLM configuration instead of client-provided bot_id
- Fixes issue where client sends 'default' bot_id when accessing /edu
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-01-28 17:18:22 -03:00
parent 51c8a53a90
commit 26963f2caf

View file

@ -22,9 +22,9 @@ use axum::{
};
use diesel::PgConnection;
use futures::{sink::SinkExt, stream::StreamExt};
use log::{error, info, warn};
#[cfg(feature = "llm")]
use log::trace;
use log::{error, info, warn};
use serde_json;
use std::collections::HashMap;
use std::sync::Arc;
@ -90,7 +90,6 @@ impl BotOrchestrator {
let user_id = Uuid::parse_str(&message.user_id)?;
let session_id = Uuid::parse_str(&message.session_id)?;
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default();
let (session, context_data, history, model, key) = {
let state_clone = self.state.clone();
@ -119,10 +118,10 @@ impl BotOrchestrator {
let config_manager = ConfigManager::new(state_clone.conn.clone());
let model = config_manager
.get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo"))
.get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo"))
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
let key = config_manager
.get_config(&bot_id, "llm-key", Some(""))
.get_config(&session.bot_id, "llm-key", Some(""))
.unwrap_or_default();
Ok((session, context_data, history, model, key))
@ -161,7 +160,7 @@ impl BotOrchestrator {
let initial_tokens = crate::shared::utils::estimate_token_count(&context_data);
let config_manager = ConfigManager::new(self.state.conn.clone());
let max_context_size = config_manager
.get_config(&bot_id, "llm-server-ctx-size", None)
.get_config(&session.bot_id, "llm-server-ctx-size", None)
.unwrap_or_default()
.parse::<usize>()
.unwrap_or(0);
@ -319,7 +318,7 @@ impl BotOrchestrator {
response_tx: mpsc::Sender<BotResponse>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
warn!("LLM feature not enabled, cannot stream response");
let error_response = BotResponse {
bot_id: message.bot_id,
user_id: message.user_id,
@ -334,7 +333,7 @@ impl BotOrchestrator {
context_length: 0,
context_max_length: 0,
};
response_tx.send(error_response).await?;
Ok(())
}
@ -368,6 +367,10 @@ pub async fn websocket_handler(
.get("session_id")
.and_then(|s| Uuid::parse_str(s).ok());
let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok());
let bot_name = params
.get("bot_name")
.cloned()
.unwrap_or_else(|| "default".to_string());
if session_id.is_none() || user_id.is_none() {
return (
@ -380,10 +383,27 @@ pub async fn websocket_handler(
let session_id = session_id.unwrap_or_default();
let user_id = user_id.unwrap_or_default();
ws.on_upgrade(move |socket| {
handle_websocket(socket, state, session_id, user_id)
})
.into_response()
// Look up bot_id from bot_name
let bot_id = {
let conn = state.conn.get().ok();
if let Some(mut db_conn) = conn {
use crate::shared::models::schema::bots::dsl::*;
let result: Result<Uuid, _> = bots
.filter(name.eq(&bot_name))
.select(id)
.first(&mut db_conn);
result.unwrap_or_else(|_| {
log::warn!("Bot not found: {}, using nil bot_id", bot_name);
Uuid::nil()
})
} else {
log::warn!("Could not get database connection, using nil bot_id");
Uuid::nil()
}
};
ws.on_upgrade(move |socket| handle_websocket(socket, state, session_id, user_id, bot_id))
.into_response()
}
async fn handle_websocket(
@ -391,6 +411,7 @@ async fn handle_websocket(
state: Arc<AppState>,
session_id: Uuid,
user_id: Uuid,
bot_id: Uuid,
) {
let (mut sender, mut receiver) = socket.split();
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
@ -414,6 +435,7 @@ async fn handle_websocket(
"type": "connected",
"session_id": session_id,
"user_id": user_id,
"bot_id": bot_id,
"message": "Connected to bot server"
});
@ -447,8 +469,13 @@ async fn handle_websocket(
.await
.get(&session_id.to_string())
{
// Use bot_id from WebSocket connection instead of from message
let corrected_msg = UserMessage {
bot_id: bot_id.to_string(),
..user_msg
};
if let Err(e) = orchestrator
.stream_response(user_msg, tx_clone.clone())
.stream_response(corrected_msg, tx_clone.clone())
.await
{
error!("Failed to stream response: {}", e);