botserver/src/core/bot/mod.rs
Rodrigo Rodriguez (Pragmatismo) 4b5d18db71 Add metadata and refactor BASIC compiler
- Add package metadata (keywords, categories) to Cargo.toml
- Add #[must_use] attributes to constructor methods
- Simplify conditional logic by inverting if-else blocks
- Replace string formatting with interpolation syntax
2025-11-27 15:19:17 -03:00

746 lines
25 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use crate::core::config::ConfigManager;
use crate::drive::drive_monitor::DriveMonitor;
use crate::llm::OpenAIClient;
use crate::shared::models::{BotResponse, UserMessage, UserSession};
use crate::shared::state::AppState;
use axum::extract::ws::{Message, WebSocket};
use axum::{
extract::{ws::WebSocketUpgrade, Extension, Query, State},
http::StatusCode,
response::{IntoResponse, Json},
};
use diesel::PgConnection;
use futures::{sink::SinkExt, stream::StreamExt};
use log::{error, info, trace, warn};
use serde_json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex as AsyncMutex;
use uuid::Uuid;
pub mod channels;
pub mod multimedia;
/// Looks up an active bot to serve as the default and returns its `(id, name)`.
///
/// The query takes the first row with `is_active = true`. No ordering is
/// requested, so when several bots are active the choice is left to the
/// database.
///
/// This function never fails: if no active bot exists, or if the query itself
/// errors, the condition is logged and the nil UUID paired with the name
/// `"default"` is returned instead.
pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) {
    use crate::shared::models::schema::bots::dsl::*;
    use diesel::prelude::*;

    let query_outcome = bots
        .filter(is_active.eq(true))
        .select((id, name))
        .first::<(Uuid, String)>(conn)
        .optional();

    // A failed query and an empty result share the same fallback value but
    // are logged at different levels.
    let active_bot = match query_outcome {
        Ok(maybe_bot) => maybe_bot,
        Err(e) => {
            error!("Failed to query default bot: {}", e);
            return (Uuid::nil(), "default".to_string());
        }
    };

    active_bot.unwrap_or_else(|| {
        warn!("No active bots found, using nil UUID");
        (Uuid::nil(), "default".to_string())
    })
}
/// Coordinates bot mounting and LLM response streaming on top of the shared
/// application state.
#[derive(Debug)]
pub struct BotOrchestrator {
/// Shared application state; the methods of this type use its database
/// connections, session manager and LLM provider.
pub state: Arc<AppState>,
/// Drive monitors of the bots mounted through this orchestrator, keyed by a
/// bot id string. The map is created empty by `new`, so it is private to each
/// orchestrator instance unless the `Arc` is shared explicitly.
pub mounted_bots: Arc<AsyncMutex<HashMap<String, Arc<DriveMonitor>>>>,
}
impl BotOrchestrator {
/// Builds an orchestrator around `state` with an empty set of mounted bots.
pub fn new(state: Arc<AppState>) -> Self {
    let mounted_bots = Arc::new(AsyncMutex::new(HashMap::new()));
    Self {
        state,
        mounted_bots,
    }
}
// ... (All existing methods unchanged) ...
pub async fn mount_all_bots(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Starting to mount all bots");
// Get all active bots from database
let bots = {
let mut conn = self.state.conn.get()?;
use crate::shared::models::schema::bots::dsl::*;
use diesel::prelude::*;
bots.filter(is_active.eq(true))
.select((id, name))
.load::<(Uuid, String)>(&mut conn)?
};
info!("Found {} active bots to mount", bots.len());
// Mount each bot
for (bot_id, bot_name) in bots {
info!("Mounting bot: {} ({})", bot_name, bot_id);
// Create DriveMonitor for this bot
let drive_monitor = Arc::new(DriveMonitor::new(
self.state.clone(),
format!("bot-{}", bot_id), // bucket name
bot_id,
));
// Start monitoring
let monitor_clone = drive_monitor.clone();
tokio::spawn(async move {
if let Err(e) = monitor_clone.start_monitoring().await {
error!("Failed to start monitoring for bot {}: {}", bot_id, e);
}
});
// Store in mounted_bots
self.mounted_bots
.lock()
.await
.insert(bot_id.to_string(), drive_monitor);
info!("Bot {} mounted successfully", bot_name);
}
info!("All bots mounted successfully");
Ok(())
}
/// Streams an LLM-generated reply for `message` over `response_tx`.
///
/// Steps:
/// 1. On a blocking thread: load the session, persist the user message, and
///    gather the context data, the conversation history and the bot's
///    `llm-model`, `llm-key` and `llm-cache` configuration values.
/// 2. Build the prompt and attach `bot_id` / `llm_cache` to it for the cache
///    layer.
/// 3. Forward every chunk produced by the LLM provider as a `BotResponse` with
///    `is_complete: false`.
/// 4. Send one last `BotResponse` with `is_complete: true` carrying the full
///    text, then persist that text as the bot's message.
///
/// # Errors
///
/// Returns an error when `user_id` or `session_id` is not a valid UUID, when
/// the session does not exist, when a session-manager call fails, when a
/// blocking task cannot be joined, or when the final response cannot be sent
/// (in which case the bot message is not persisted). An error inside the LLM
/// stream itself is only logged by the background task.
pub async fn stream_response(
    &self,
    message: UserMessage,
    response_tx: mpsc::Sender<BotResponse>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    trace!(
        "Streaming response for user: {}, session: {}",
        message.user_id,
        message.session_id
    );
    let user_id = Uuid::parse_str(&message.user_id)?;
    let session_id = Uuid::parse_str(&message.session_id)?;
    // An unparsable bot id falls back to the nil UUID instead of failing.
    let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_default();

    // All database operations happen in one blocking section so the async
    // executor is never blocked on the session manager or the database.
    let state_clone = self.state.clone();
    let (session, context_data, history, model, key, cache_enabled) =
        tokio::task::spawn_blocking(
            move || -> Result<_, Box<dyn std::error::Error + Send + Sync>> {
                // Get session
                let session = {
                    let mut sm = state_clone.session_manager.blocking_lock();
                    sm.get_session_by_id(session_id)?
                }
                .ok_or("Session not found")?;

                // Save user message.
                // NOTE(review): the literal codes 1/1 here and 2/2 for the bot
                // reply are presumably role / message-type identifiers; confirm
                // against the session manager's `save_message`.
                {
                    let mut sm = state_clone.session_manager.blocking_lock();
                    sm.save_message(session.id, user_id, 1, &message.content, 1)?;
                }

                // The context lookup is async, so drive it to completion on
                // the current runtime handle from this blocking thread.
                let context_data = {
                    let sm = state_clone.session_manager.blocking_lock();
                    let rt = tokio::runtime::Handle::current();
                    rt.block_on(async {
                        sm.get_session_context_data(&session.id, &session.user_id)
                            .await
                    })?
                };
                let history = {
                    let mut sm = state_clone.session_manager.blocking_lock();
                    sm.get_conversation_history(session.id, user_id)?
                };

                // Per-bot LLM settings; each falls back to a default when the
                // lookup fails.
                let config_manager = ConfigManager::new(state_clone.conn.clone());
                let model = config_manager
                    .get_config(&bot_id, "llm-model", Some("gpt-3.5-turbo"))
                    .unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
                let key = config_manager
                    .get_config(&bot_id, "llm-key", Some(""))
                    .unwrap_or_default();
                let cache_enabled = config_manager
                    .get_config(&bot_id, "llm-cache", Some("true"))
                    .unwrap_or_else(|_| "true".to_string());

                Ok((session, context_data, history, model, key, cache_enabled))
            },
        )
        .await??;

    // Build messages and attach bot_id / cache config for the cache layer.
    let system_prompt = std::env::var("SYSTEM_PROMPT")
        .unwrap_or_else(|_| "You are a helpful assistant.".to_string());
    let messages = match OpenAIClient::build_messages(&system_prompt, &context_data, &history) {
        serde_json::Value::Object(mut map) => {
            map.insert("bot_id".to_string(), serde_json::json!(bot_id.to_string()));
            map.insert("llm_cache".to_string(), serde_json::json!(cache_enabled));
            serde_json::Value::Object(map)
        }
        // A bare array is wrapped in an object; the array is moved into the
        // wrapper rather than copied.
        array @ serde_json::Value::Array(_) => {
            let mut wrapper = serde_json::Map::new();
            wrapper.insert("messages".to_string(), array);
            wrapper.insert("bot_id".to_string(), serde_json::json!(bot_id.to_string()));
            wrapper.insert("llm_cache".to_string(), serde_json::json!(cache_enabled));
            serde_json::Value::Object(wrapper)
        }
        // Any other JSON shape is passed through untouched.
        other => other,
    };

    // Stream from LLM: the provider writes chunks into `stream_tx` from a
    // background task; the loop below ends once that sender is dropped.
    let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
    let llm = self.state.llm_provider.clone();
    tokio::spawn(async move {
        if let Err(e) = llm
            .generate_stream("", &messages, stream_tx, &model, &key)
            .await
        {
            error!("LLM streaming error: {}", e);
        }
    });

    let mut full_response = String::new();
    let mut chunk_count = 0;
    while let Some(chunk) = stream_rx.recv().await {
        chunk_count += 1;
        info!("Received LLM chunk #{}: {:?}", chunk_count, chunk);
        full_response.push_str(&chunk);
        let response = BotResponse {
            bot_id: message.bot_id.clone(),
            user_id: message.user_id.clone(),
            session_id: message.session_id.clone(),
            channel: message.channel.clone(),
            content: chunk,
            message_type: 2,
            stream_token: None,
            is_complete: false,
            suggestions: Vec::new(),
            context_name: None,
            context_length: 0,
            context_max_length: 0,
        };
        info!("Sending streaming chunk to WebSocket");
        if let Err(e) = response_tx.send(response).await {
            // The receiving side is gone; stop forwarding chunks.
            error!("Failed to send streaming chunk: {}", e);
            break;
        }
    }
    info!(
        "LLM streaming complete, received {} chunks, total length: {}",
        chunk_count,
        full_response.len()
    );

    // Send final complete response
    let final_response = BotResponse {
        bot_id: message.bot_id.clone(),
        user_id: message.user_id.clone(),
        session_id: message.session_id.clone(),
        channel: message.channel.clone(),
        content: full_response.clone(),
        message_type: 2,
        stream_token: None,
        is_complete: true,
        suggestions: Vec::new(),
        context_name: None,
        context_length: 0,
        context_max_length: 0,
    };
    info!("Sending final complete response to WebSocket");
    response_tx.send(final_response).await?;
    info!("Final response sent successfully");

    // Save bot response in blocking context; `full_response` is moved into
    // the task because nothing uses it afterwards.
    let state_for_save = self.state.clone();
    tokio::task::spawn_blocking(
        move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let mut sm = state_for_save.session_manager.blocking_lock();
            sm.save_message(session.id, user_id, 2, &full_response, 2)?;
            Ok(())
        },
    )
    .await??;
    Ok(())
}
// ... (Other methods unchanged) ...
/// Returns the sessions the session manager reports for `user_id`.
///
/// # Errors
///
/// Propagates any error returned by the session manager.
pub async fn get_user_sessions(
    &self,
    user_id: Uuid,
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
    // The guard is a temporary, so the lock is released at the end of this statement.
    let sessions = self
        .state
        .session_manager
        .lock()
        .await
        .get_user_sessions(user_id)?;
    Ok(sessions)
}
/// Returns the conversation history the session manager holds for the given
/// session and user, as pairs of strings.
///
/// # Errors
///
/// Propagates any error returned by the session manager.
pub async fn get_conversation_history(
    &self,
    session_id: Uuid,
    user_id: Uuid,
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
    // The guard is a temporary, so the lock is released at the end of this statement.
    let history = self
        .state
        .session_manager
        .lock()
        .await
        .get_conversation_history(session_id, user_id)?;
    Ok(history)
}
/// Removes `bot_id` from `mounted_bots` and stops its drive monitor.
///
/// Unmounting a bot that is not in the map only logs a warning and succeeds.
///
/// # Errors
///
/// Propagates the error from `stop_monitoring`. The entry has already been
/// removed from the map by then.
pub async fn unmount_bot(
    &self,
    bot_id: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // The map stays locked until the monitor has been stopped.
    let mut mounted = self.mounted_bots.lock().await;
    match mounted.remove(bot_id) {
        Some(monitor) => {
            monitor.stop_monitoring().await?;
            info!("Bot {} unmounted successfully", bot_id);
        }
        None => {
            warn!("Bot {} was not mounted", bot_id);
        }
    }
    Ok(())
}
/// Returns the keys of `mounted_bots` (bot id strings) in unspecified order.
pub async fn get_mounted_bots(&self) -> Vec<String> {
    let bot_ids: Vec<String> = self.mounted_bots.lock().await.keys().cloned().collect();
    bot_ids
}
/// Reports whether `bot_id` is a key in this orchestrator's `mounted_bots` map.
pub async fn is_bot_mounted(&self, bot_id: &str) -> bool {
    self.mounted_bots.lock().await.contains_key(bot_id)
}
// ... (Remaining BotOrchestrator methods unchanged) ...
}
/* Axum handler placeholders that delegate to BotOrchestrator where appropriate */
/// WebSocket handler that upgrades an HTTP connection to a WebSocket.
///
/// Both the `session_id` and `user_id` query parameters must be present and
/// parse as UUIDs. Otherwise the request is answered with `400 Bad Request`
/// and a JSON error body, and no upgrade takes place.
pub async fn websocket_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> impl IntoResponse {
    // A missing or malformed parameter yields `None`.
    let uuid_param = |key: &str| params.get(key).and_then(|raw| Uuid::parse_str(raw).ok());

    match (uuid_param("session_id"), uuid_param("user_id")) {
        (Some(session_id), Some(user_id)) => ws
            .on_upgrade(move |socket| handle_websocket(socket, state, session_id, user_id))
            .into_response(),
        _ => (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "session_id and user_id are required" })),
        )
            .into_response(),
    }
}
/// Handles an individual WebSocket connection for its whole lifetime.
///
/// Sequence:
/// 1. Split the socket and create a bounded channel (capacity 100) of
///    `BotResponse` values for this connection.
/// 2. Register the channel's sender under the session id string with both the
///    web adapter and `state.response_channels`.
/// 3. If the session can be loaded, run `start.bas` on a blocking thread
///    (fire-and-forget: the join handle is dropped).
/// 4. Send a JSON "connected" welcome message.
/// 5. Run two tasks: one forwards channel responses to the socket, the other
///    reads incoming frames and processes text frames as `UserMessage`s.
/// 6. When either task finishes, abort the other and unregister the session
///    from the web adapter and `state.response_channels`.
async fn handle_websocket(
socket: WebSocket,
state: Arc<AppState>,
session_id: Uuid,
user_id: Uuid,
) {
let (mut sender, mut receiver) = socket.split();
// Create a channel for this WebSocket connection
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
// Register this connection with the web adapter
state
.web_adapter
.add_connection(session_id.to_string(), tx.clone())
.await;
// Also register in response_channels for BotOrchestrator
// (process_user_message looks the sender up there by session id).
{
let mut channels = state.response_channels.lock().await;
channels.insert(session_id.to_string(), tx.clone());
}
info!(
"WebSocket connected for session: {}, user: {}",
session_id, user_id
);
// Execute start.bas if it exists
let state_for_start = state.clone();
// A lookup error and a missing session both collapse to `None`, in which
// case the start script is skipped without logging.
let session_for_start = {
let mut sm = state.session_manager.lock().await;
sm.get_session_by_id(session_id).ok().and_then(|opt| opt)
};
if let Some(session_clone) = session_for_start {
// The script is read, compiled and run on a blocking thread; the returned
// join handle is dropped, so the rest of this function does not wait for it.
tokio::task::spawn_blocking(move || {
use crate::basic::ScriptService;
let bot_name = "default"; // TODO: Get from session
let start_script_path =
format!("./work/{}.gbai/{}.gbdialog/start.bas", bot_name, bot_name);
if let Ok(start_content) = std::fs::read_to_string(&start_script_path) {
info!("Executing start.bas for session {}", session_id);
let script_service = ScriptService::new(state_for_start, session_clone);
match script_service.compile(&start_content) {
Ok(ast) => {
if let Err(e) = script_service.run(&ast) {
error!("Failed to execute start.bas: {}", e);
} else {
info!("start.bas executed successfully for session {}", session_id);
}
}
Err(e) => {
error!("Failed to compile start.bas: {}", e);
}
}
} else {
// Any read failure (not only a missing file) ends up here.
info!("No start.bas found for bot {}", bot_name);
}
});
}
// Send initial welcome message
let welcome = serde_json::json!({
"type": "connected",
"session_id": session_id,
"user_id": user_id,
"message": "Connected to bot server"
});
if let Ok(welcome_str) = serde_json::to_string(&welcome) {
info!("Sending welcome message to session {}", session_id);
// A failed welcome send is logged but does not end the connection here.
if let Err(e) = sender.send(Message::Text(welcome_str.into())).await {
error!("Failed to send welcome message: {}", e);
}
}
// Spawn task to send messages from the channel to the WebSocket.
// Responses that fail to serialize are skipped silently; the task stops on
// the first failed socket send.
let mut send_task = tokio::spawn(async move {
while let Some(response) = rx.recv().await {
if let Ok(json_str) = serde_json::to_string(&response) {
if sender.send(Message::Text(json_str.into())).await.is_err() {
break;
}
}
}
});
// Handle incoming messages from the WebSocket.
// The loop ends when the stream is exhausted, yields an error, or a Close
// frame arrives.
let state_clone = state.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
info!("WebSocket received raw message type: {:?}", msg);
match msg {
Message::Text(text) => {
info!(
"Received WebSocket text message (length {}): {}",
text.len(),
text
);
match serde_json::from_str::<UserMessage>(&text) {
Ok(user_msg) => {
info!(
"Successfully parsed user message from session: {}, content: {}",
session_id, user_msg.content
);
// Process the message through the bot system.
// This is awaited inline, so the next frame is not read
// until processing of this message has finished.
if let Err(e) = process_user_message(
state_clone.clone(),
session_id,
user_id,
user_msg,
)
.await
{
error!("Error processing user message: {}", e);
}
}
Err(e) => {
// Malformed messages are logged and otherwise ignored.
error!(
"Failed to parse user message from session {}: {} - Parse error: {}",
session_id, text, e
);
}
}
}
Message::Close(_) => {
info!(
"WebSocket close message received for session: {}",
session_id
);
break;
}
Message::Ping(_data) => {
// Pings are automatically handled by axum
}
Message::Pong(_) => {
// Pongs are automatically handled by axum
}
// Any other frame kind (e.g. binary) is ignored.
_ => {}
}
}
});
// Wait for either task to finish, then cancel the one still running
tokio::select! {
_ = (&mut send_task) => {
recv_task.abort();
}
_ = (&mut recv_task) => {
send_task.abort();
}
}
// Clean up: remove the connection from the adapter
state
.web_adapter
.remove_connection(&session_id.to_string())
.await;
// Also remove from response_channels
{
let mut channels = state.response_channels.lock().await;
channels.remove(&session_id.to_string());
}
info!("WebSocket disconnected for session: {}", session_id);
}
/// Process a user message received via WebSocket.
///
/// Looks up the response channel registered for `session_id` and, when one
/// exists, streams the reply through a freshly created `BotOrchestrator`.
/// A missing channel or a streaming failure is logged rather than returned,
/// so this function currently always yields `Ok(())`.
async fn process_user_message(
    state: Arc<AppState>,
    session_id: Uuid,
    user_id: Uuid,
    user_msg: UserMessage,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    info!(
        "Processing message from user {} in session {}: {}",
        user_id, session_id, user_msg.content
    );

    // Clone the sender out of the map so the lock is released before
    // streaming starts.
    let registered_tx = state
        .response_channels
        .lock()
        .await
        .get(&session_id.to_string())
        .cloned();

    let response_tx = match registered_tx {
        Some(sender) => sender,
        None => {
            error!("No response channel found for session {}", session_id);
            return Ok(());
        }
    };

    // Use BotOrchestrator to stream the response
    let orchestrator = BotOrchestrator::new(state);
    if let Err(e) = orchestrator.stream_response(user_msg, response_tx).await {
        error!("Failed to stream response: {}", e);
    }
    Ok(())
}
/// Create a new bot (placeholder implementation)
pub async fn create_bot_handler(
Extension(state): Extension<Arc<AppState>>,
Json(payload): Json<HashMap<String, String>>,
) -> impl IntoResponse {
let bot_name = payload
.get("bot_name")
.cloned()
.unwrap_or_else(|| "default".to_string());
// Use state to create the bot in the database
let mut conn = match state.conn.get() {
Ok(conn) => conn,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Database error: {}", e) })),
)
}
};
use crate::shared::models::schema::bots::dsl::*;
use diesel::prelude::*;
let new_bot = (
name.eq(&bot_name),
description.eq(format!("Bot created via API: {}", bot_name)),
llm_provider.eq("openai"),
llm_config.eq(serde_json::json!({"model": "gpt-4"})),
context_provider.eq("none"),
context_config.eq(serde_json::json!({})),
is_active.eq(true),
);
match diesel::insert_into(bots)
.values(&new_bot)
.execute(&mut conn)
{
Ok(_) => (
StatusCode::OK,
Json(serde_json::json!({
"status": format!("bot '{}' created successfully", bot_name),
"bot_name": bot_name
})),
),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Failed to create bot: {}", e) })),
),
}
}
/// Mount an existing bot (placeholder implementation)
pub async fn mount_bot_handler(
Extension(state): Extension<Arc<AppState>>,
Json(payload): Json<HashMap<String, String>>,
) -> impl IntoResponse {
let bot_guid = payload.get("bot_guid").cloned().unwrap_or_default();
// Parse bot UUID
let bot_uuid = match Uuid::parse_str(&bot_guid) {
Ok(uuid) => uuid,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": format!("Invalid bot UUID: {}", e) })),
);
}
};
// Verify bot exists in database
let bot_name = {
let mut conn = match state.conn.get() {
Ok(conn) => conn,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Database error: {}", e) })),
);
}
};
use crate::shared::models::schema::bots::dsl::*;
use diesel::prelude::*;
match bots
.filter(id.eq(bot_uuid))
.select(name)
.first::<String>(&mut conn)
{
Ok(n) => n,
Err(_) => {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": "Bot not found" })),
);
}
}
};
// Create DriveMonitor for this bot
let drive_monitor = Arc::new(DriveMonitor::new(
state.clone(),
format!("bot-{}", bot_uuid),
bot_uuid,
));
// Start monitoring
let monitor_clone = drive_monitor.clone();
tokio::spawn(async move {
if let Err(e) = monitor_clone.start_monitoring().await {
error!("Failed to start monitoring for bot {}: {}", bot_uuid, e);
}
});
// Mount the bot
let orchestrator = BotOrchestrator::new(state.clone());
orchestrator
.mounted_bots
.lock()
.await
.insert(bot_guid.clone(), drive_monitor);
info!("Bot {} ({}) mounted successfully", bot_name, bot_guid);
(
StatusCode::OK,
Json(serde_json::json!({
"status": format!("bot '{}' mounted", bot_guid),
"bot_name": bot_name
})),
)
}
/// Handle user input for a bot (placeholder implementation).
///
/// Reads `session_id` and `input` from the JSON payload (a missing key becomes
/// an empty string) and echoes both back in a status message. No bot
/// processing happens here yet.
pub async fn handle_user_input_handler(
    Extension(_state): Extension<Arc<AppState>>,
    Json(payload): Json<HashMap<String, String>>,
) -> impl IntoResponse {
    let field = |key: &str| payload.get(key).cloned().unwrap_or_default();
    let session_id = field("session_id");
    let user_input = field("input");

    let status_text = format!("input '{}' processed for session {}", user_input, session_id);
    (
        StatusCode::OK,
        Json(serde_json::json!({ "status": status_text })),
    )
}
/// Retrieve user sessions (placeholder implementation).
///
/// Ignores both the state and the payload and always answers `200 OK` with an
/// empty `sessions` array.
pub async fn get_user_sessions_handler(
    Extension(_state): Extension<Arc<AppState>>,
    Json(_payload): Json<HashMap<String, String>>,
) -> impl IntoResponse {
    let response_body = serde_json::json!({ "sessions": [] });
    (StatusCode::OK, Json(response_body))
}
/// Retrieve conversation history (placeholder implementation).
///
/// Ignores both the state and the payload and always answers `200 OK` with an
/// empty `history` array.
pub async fn get_conversation_history_handler(
    Extension(_state): Extension<Arc<AppState>>,
    Json(_payload): Json<HashMap<String, String>>,
) -> impl IntoResponse {
    let response_body = serde_json::json!({ "history": [] });
    (StatusCode::OK, Json(response_body))
}
/// Send warning (placeholder implementation).
///
/// Ignores both the state and the payload and always answers `200 OK` with a
/// fixed acknowledgement; no warning is delivered anywhere.
pub async fn send_warning_handler(
    Extension(_state): Extension<Arc<AppState>>,
    Json(_payload): Json<HashMap<String, String>>,
) -> impl IntoResponse {
    let response_body = serde_json::json!({ "status": "warning acknowledged" });
    (StatusCode::OK, Json(response_body))
}