Fix bot_id: Use bot_id from URL path instead of client message
- Extract bot_name from WebSocket query parameters - Look up bot_id from bot_name using database - Pass bot_id to WebSocket message handler - Use session's bot_id for LLM configuration instead of client-provided bot_id - Fixes issue where client sends 'default' bot_id when accessing /edu
This commit is contained in:
parent
51c8a53a90
commit
26963f2caf
1 changed files with 39 additions and 12 deletions
|
|
@ -22,9 +22,9 @@ use axum::{
|
||||||
};
|
};
|
||||||
use diesel::PgConnection;
|
use diesel::PgConnection;
|
||||||
use futures::{sink::SinkExt, stream::StreamExt};
|
use futures::{sink::SinkExt, stream::StreamExt};
|
||||||
use log::{error, info, warn};
|
|
||||||
#[cfg(feature = "llm")]
|
#[cfg(feature = "llm")]
|
||||||
use log::trace;
|
use log::trace;
|
||||||
|
use log::{error, info, warn};
|
||||||
use serde_json;
|
use serde_json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
@ -90,7 +90,6 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
let user_id = Uuid::parse_str(&message.user_id)?;
|
let user_id = Uuid::parse_str(&message.user_id)?;
|
||||||
let session_id = Uuid::parse_str(&message.session_id)?;
|
let session_id = Uuid::parse_str(&message.session_id)?;
|
||||||
let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default();
|
|
||||||
|
|
||||||
let (session, context_data, history, model, key) = {
|
let (session, context_data, history, model, key) = {
|
||||||
let state_clone = self.state.clone();
|
let state_clone = self.state.clone();
|
||||||
|
|
@ -119,10 +118,10 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
let config_manager = ConfigManager::new(state_clone.conn.clone());
|
let config_manager = ConfigManager::new(state_clone.conn.clone());
|
||||||
let model = config_manager
|
let model = config_manager
|
||||||
.get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo"))
|
.get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo"))
|
||||||
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
|
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
|
||||||
let key = config_manager
|
let key = config_manager
|
||||||
.get_config(&bot_id, "llm-key", Some(""))
|
.get_config(&session.bot_id, "llm-key", Some(""))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
Ok((session, context_data, history, model, key))
|
Ok((session, context_data, history, model, key))
|
||||||
|
|
@ -161,7 +160,7 @@ impl BotOrchestrator {
|
||||||
let initial_tokens = crate::shared::utils::estimate_token_count(&context_data);
|
let initial_tokens = crate::shared::utils::estimate_token_count(&context_data);
|
||||||
let config_manager = ConfigManager::new(self.state.conn.clone());
|
let config_manager = ConfigManager::new(self.state.conn.clone());
|
||||||
let max_context_size = config_manager
|
let max_context_size = config_manager
|
||||||
.get_config(&bot_id, "llm-server-ctx-size", None)
|
.get_config(&session.bot_id, "llm-server-ctx-size", None)
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
.parse::<usize>()
|
.parse::<usize>()
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
|
|
@ -368,6 +367,10 @@ pub async fn websocket_handler(
|
||||||
.get("session_id")
|
.get("session_id")
|
||||||
.and_then(|s| Uuid::parse_str(s).ok());
|
.and_then(|s| Uuid::parse_str(s).ok());
|
||||||
let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok());
|
let user_id = params.get("user_id").and_then(|s| Uuid::parse_str(s).ok());
|
||||||
|
let bot_name = params
|
||||||
|
.get("bot_name")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| "default".to_string());
|
||||||
|
|
||||||
if session_id.is_none() || user_id.is_none() {
|
if session_id.is_none() || user_id.is_none() {
|
||||||
return (
|
return (
|
||||||
|
|
@ -380,10 +383,27 @@ pub async fn websocket_handler(
|
||||||
let session_id = session_id.unwrap_or_default();
|
let session_id = session_id.unwrap_or_default();
|
||||||
let user_id = user_id.unwrap_or_default();
|
let user_id = user_id.unwrap_or_default();
|
||||||
|
|
||||||
ws.on_upgrade(move |socket| {
|
// Look up bot_id from bot_name
|
||||||
handle_websocket(socket, state, session_id, user_id)
|
let bot_id = {
|
||||||
})
|
let conn = state.conn.get().ok();
|
||||||
.into_response()
|
if let Some(mut db_conn) = conn {
|
||||||
|
use crate::shared::models::schema::bots::dsl::*;
|
||||||
|
let result: Result<Uuid, _> = bots
|
||||||
|
.filter(name.eq(&bot_name))
|
||||||
|
.select(id)
|
||||||
|
.first(&mut db_conn);
|
||||||
|
result.unwrap_or_else(|_| {
|
||||||
|
log::warn!("Bot not found: {}, using nil bot_id", bot_name);
|
||||||
|
Uuid::nil()
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
log::warn!("Could not get database connection, using nil bot_id");
|
||||||
|
Uuid::nil()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.on_upgrade(move |socket| handle_websocket(socket, state, session_id, user_id, bot_id))
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_websocket(
|
async fn handle_websocket(
|
||||||
|
|
@ -391,6 +411,7 @@ async fn handle_websocket(
|
||||||
state: Arc<AppState>,
|
state: Arc<AppState>,
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
|
bot_id: Uuid,
|
||||||
) {
|
) {
|
||||||
let (mut sender, mut receiver) = socket.split();
|
let (mut sender, mut receiver) = socket.split();
|
||||||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
||||||
|
|
@ -414,6 +435,7 @@ async fn handle_websocket(
|
||||||
"type": "connected",
|
"type": "connected",
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
"bot_id": bot_id,
|
||||||
"message": "Connected to bot server"
|
"message": "Connected to bot server"
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -447,8 +469,13 @@ async fn handle_websocket(
|
||||||
.await
|
.await
|
||||||
.get(&session_id.to_string())
|
.get(&session_id.to_string())
|
||||||
{
|
{
|
||||||
|
// Use bot_id from WebSocket connection instead of from message
|
||||||
|
let corrected_msg = UserMessage {
|
||||||
|
bot_id: bot_id.to_string(),
|
||||||
|
..user_msg
|
||||||
|
};
|
||||||
if let Err(e) = orchestrator
|
if let Err(e) = orchestrator
|
||||||
.stream_response(user_msg, tx_clone.clone())
|
.stream_response(corrected_msg, tx_clone.clone())
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
error!("Failed to stream response: {}", e);
|
error!("Failed to stream response: {}", e);
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue