diff --git a/.zed/debug.json b/.zed/debug.json index fb5d47718..9254d7c0c 100644 --- a/.zed/debug.json +++ b/.zed/debug.json @@ -5,7 +5,7 @@ "command": "cargo", "args": ["build"] }, - "program": "$ZED_WORKTREE_ROOT/target/debug/gbserver", + "program": "$ZED_WORKTREE_ROOT/target/debug/botserver", "sourceLanguages": ["rust"], "request": "launch", diff --git a/scripts/dev/build_prompt.sh b/scripts/dev/build_prompt.sh index 927fda7cc..a0a88e75d 100755 --- a/scripts/dev/build_prompt.sh +++ b/scripts/dev/build_prompt.sh @@ -9,8 +9,8 @@ echo "Consolidated LLM Context" > "$OUTPUT_FILE" prompts=( "../../prompts/dev/shared.md" "../../Cargo.toml" - #"../../prompts/dev/fix.md" - "../../prompts/dev/generation.md" + "../../prompts/dev/fix.md" + #"../../prompts/dev/generation.md" ) for file in "${prompts[@]}"; do @@ -23,18 +23,18 @@ dirs=( #"automation" #"basic" "bot" - "channels" + #"channels" #"config" - "context" + #"context" #"email" #"file" - "llm" + #"llm" #"llm_legacy" #"org" - "session" + #"session" "shared" #"tests" - "tools" + #"tools" #"web_automation" #"whatsapp" ) @@ -47,10 +47,10 @@ done # Also append the specific files you mentioned cat "$PROJECT_ROOT/src/main.rs" >> "$OUTPUT_FILE" -cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE" - -echo "This BASIC file will run as soon as the conversation is created. " >> "$OUTPUT_FILE" -cat "$PROJECT_ROOT/templates/annoucements.gbai/annoucements.gbdialog/start.bas" >> "$OUTPUT_FILE" +# cat "$PROJECT_ROOT/src/basic/keywords/hear_talk.rs" >> "$OUTPUT_FILE" +# cat "$PROJECT_ROOT/src/basic/keywords/get.rs" >> "$OUTPUT_FILE" +cat "$PROJECT_ROOT/src/basic/keywords/llm_keyword.rs" >> "$OUTPUT_FILE" +# cat "$PROJECT_ROOT/src/basic/mod.rs" >> "$OUTPUT_FILE" echo "" >> "$OUTPUT_FILE" diff --git a/src/automation/mod.rs b/src/automation/mod.rs index 483cb40c2..90c8332ef 100644 --- a/src/automation/mod.rs +++ b/src/automation/mod.rs @@ -10,12 +10,12 @@ use tokio::time::Duration; use uuid::Uuid; pub struct AutomationService { - state: AppState, + state: Arc, scripts_dir: String, } impl AutomationService { - pub fn new(state: AppState, scripts_dir: &str) -> Self { + pub fn new(state: Arc, scripts_dir: &str) -> Self { Self { state, scripts_dir: scripts_dir.to_string(), @@ -61,13 +61,11 @@ impl AutomationService { async fn check_table_changes(&self, automations: &[Automation], since: DateTime) { for automation in automations { - // Resolve the trigger kind, disambiguating the `from_i32` call. let trigger_kind = match crate::shared::models::TriggerKind::from_i32(automation.kind) { Some(k) => k, None => continue, }; - // We're only interested in table‑change triggers. if !matches!( trigger_kind, TriggerKind::TableUpdate | TriggerKind::TableInsert | TriggerKind::TableDelete @@ -75,50 +73,41 @@ impl AutomationService { continue; } - // Table name must be present. let table = match &automation.target { Some(t) => t, None => continue, }; - // Choose the appropriate timestamp column. let column = match trigger_kind { TriggerKind::TableInsert => "created_at", _ => "updated_at", }; - // Build a simple COUNT(*) query; alias the column so Diesel can map it directly to i64. let query = format!( "SELECT COUNT(*) as count FROM {} WHERE {} > $1", table, column ); - // Acquire a connection for this query. let mut conn_guard = self.state.conn.lock().unwrap(); let conn = &mut *conn_guard; - // Define a struct to capture the query result #[derive(diesel::QueryableByName)] struct CountResult { #[diesel(sql_type = diesel::sql_types::BigInt)] count: i64, } - // Execute the query, retrieving a plain i64 count. let count_result = diesel::sql_query(&query) .bind::(since.naive_utc()) .get_result::(conn); match count_result { Ok(result) if result.count > 0 => { - // Release the lock before awaiting asynchronous work. drop(conn_guard); self.execute_action(&automation.param).await; self.update_last_triggered(automation.id).await; } - Ok(_result) => { - // No relevant rows changed; continue to the next automation. - } + Ok(_result) => {} Err(e) => { error!("Error checking changes for table '{}': {}", table, e); } @@ -215,7 +204,7 @@ impl AutomationService { updated_at: Utc::now(), }; - let script_service = ScriptService::new(&self.state, user_session); + let script_service = ScriptService::new(Arc::clone(&self.state), user_session); let ast = match script_service.compile(&script_content) { Ok(ast) => ast, Err(e) => { diff --git a/src/basic/keywords/get.rs b/src/basic/keywords/get.rs index 1e583b3f0..3e813c17d 100644 --- a/src/basic/keywords/get.rs +++ b/src/basic/keywords/get.rs @@ -1,59 +1,45 @@ -use log::info; -use crate::shared::state::AppState; use crate::shared::models::UserSession; +use crate::shared::state::AppState; +use log::info; use reqwest::{self, Client}; use rhai::{Dynamic, Engine}; use std::error::Error; -pub fn get_keyword(_state: &AppState, _user: UserSession, engine: &mut Engine) { +pub fn get_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) { + let state_clone = state.clone(); + engine - .register_custom_syntax( - &["GET", "$expr$"], - false, - move |context, inputs| { - let url = context.eval_expression_tree(&inputs[0])?; - let url_str = url.to_string(); + .register_custom_syntax(&["GET", "$expr$"], false, move |context, inputs| { + let url = context.eval_expression_tree(&inputs[0])?; + let url_str = url.to_string(); - if url_str.contains("..") { - return Err("URL contains invalid path traversal sequences like '..'.".into()); - } + if url_str.contains("..") { + return Err("URL contains invalid path traversal sequences like '..'.".into()); + } - let modified_url = if url_str.starts_with("/") { - let work_root = std::env::var("WORK_ROOT").unwrap_or_else(|_| "./work".to_string()); - let full_path = std::path::Path::new(&work_root) - .join(url_str.trim_start_matches('/')) - .to_string_lossy() - .into_owned(); + let state_for_async = state_clone.clone(); + let url_for_async = url_str.clone(); - let base_url = "file://"; - format!("{}{}", base_url, full_path) - } else { - url_str.to_string() - }; + if url_str.starts_with("https://") { + info!("HTTPS GET request: {}", url_for_async); - if modified_url.starts_with("https://") { - info!("HTTPS GET request: {}", modified_url); + let fut = execute_get(&url_for_async); + let result = + tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut)) + .map_err(|e| format!("HTTP request failed: {}", e))?; - let fut = execute_get(&modified_url); - let result = - tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut)) - .map_err(|e| format!("HTTP request failed: {}", e))?; + Ok(Dynamic::from(result)) + } else { + info!("Local file GET request from bucket: {}", url_for_async); - Ok(Dynamic::from(result)) - } else if modified_url.starts_with("file://") { - let file_path = modified_url.trim_start_matches("file://"); - match std::fs::read_to_string(file_path) { - Ok(content) => Ok(Dynamic::from(content)), - Err(e) => Err(format!("Failed to read file: {}", e).into()), - } - } else { - Err( - format!("GET request failed: URL must begin with 'https://' or 'file://'") - .into(), - ) - } - }, - ) + let fut = get_from_bucket(&state_for_async, &url_for_async); + let result = + tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut)) + .map_err(|e| format!("Bucket GET failed: {}", e))?; + + Ok(Dynamic::from(result)) + } + }) .unwrap(); } @@ -69,3 +55,29 @@ pub async fn execute_get(url: &str) -> Result Result> { + info!("Getting file from bucket: {}", file_path); + + if let Some(s3_client) = &state.s3_client { + let bucket_name = + std::env::var("DEFAULT_BUCKET").unwrap_or_else(|_| "default-bucket".to_string()); + + let response = s3_client + .get_object() + .bucket(&bucket_name) + .key(file_path) + .send() + .await?; + + let data = response.body.collect().await?; + let content = String::from_utf8(data.into_bytes().to_vec())?; + + Ok(content) + } else { + Err("S3 client not configured".into()) + } +} diff --git a/src/basic/keywords/hear_talk.rs b/src/basic/keywords/hear_talk.rs index 1675265c0..5be24338c 100644 --- a/src/basic/keywords/hear_talk.rs +++ b/src/basic/keywords/hear_talk.rs @@ -1,10 +1,11 @@ -use crate::shared::models::UserSession; use crate::shared::state::AppState; +use crate::{channels::ChannelAdapter, shared::models::UserSession}; use log::info; use rhai::{Dynamic, Engine, EvalAltResult}; -pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) { +pub fn hear_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { let session_id = user.id; + let cache = state.redis_client.clone(); engine .register_custom_syntax(&["HEAR", "$ident$"], true, move |_context, inputs| { @@ -18,19 +19,35 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) { variable_name ); - // Spawn a background task to handle the input‑waiting logic. - // The actual waiting implementation should be added here. + let cache_clone = cache.clone(); + let session_id_clone = session_id; + let var_name_clone = variable_name.clone(); + tokio::spawn(async move { log::debug!( "HEAR: Starting async task for session {} and variable '{}'", - session_id, - variable_name + session_id_clone, + var_name_clone ); - // TODO: implement actual waiting logic here without using the orchestrator - // For now, just log that we would wait for input + + if let Some(cache_client) = &cache_clone { + let mut conn = match cache_client.get_multiplexed_async_connection().await { + Ok(conn) => conn, + Err(e) => { + log::error!("Failed to connect to cache: {}", e); + return; + } + }; + + let key = format!("hear:{}:{}", session_id_clone, var_name_clone); + let _: Result<(), _> = redis::cmd("SET") + .arg(&key) + .arg("waiting") + .query_async(&mut conn) + .await; + } }); - // Interrupt the current Rhai evaluation flow until the user input is received. Err(Box::new(EvalAltResult::ErrorRuntime( "Waiting for user input".into(), rhai::Position::NONE, @@ -40,7 +57,6 @@ pub fn hear_keyword(_state: &AppState, user: UserSession, engine: &mut Engine) { } pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { - // Import the BotResponse type directly to satisfy diagnostics. use crate::shared::models::BotResponse; let state_clone = state.clone(); @@ -63,13 +79,11 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { is_complete: true, }; - // Send response through a channel or queue instead of accessing orchestrator directly - let _state_for_spawn = state_clone.clone(); + let state_for_spawn = state_clone.clone(); tokio::spawn(async move { - // Use a more thread-safe approach to send the message - // This avoids capturing the orchestrator directly which isn't Send + Sync - // TODO: Implement proper response handling once response_sender field is added to AppState - log::debug!("TALK: Would send response: {:?}", response); + if let Err(e) = state_for_spawn.web_adapter.send_message(response).await { + log::error!("Failed to send TALK message: {}", e); + } }); Ok(Dynamic::UNIT) @@ -78,7 +92,7 @@ pub fn talk_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { } pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Engine) { - let state_clone = state.clone(); + let cache = state.redis_client.clone(); engine .register_custom_syntax( @@ -91,14 +105,14 @@ pub fn set_context_keyword(state: &AppState, user: UserSession, engine: &mut Eng let redis_key = format!("context:{}:{}", user.user_id, user.id); - let state_for_redis = state_clone.clone(); + let cache_clone = cache.clone(); tokio::spawn(async move { - if let Some(redis_client) = &state_for_redis.redis_client { - let mut conn = match redis_client.get_multiplexed_async_connection().await { + if let Some(cache_client) = &cache_clone { + let mut conn = match cache_client.get_multiplexed_async_connection().await { Ok(conn) => conn, Err(e) => { - log::error!("Failed to connect to Redis: {}", e); + log::error!("Failed to connect to cache: {}", e); return; } }; diff --git a/src/basic/keywords/llm_keyword.rs b/src/basic/keywords/llm_keyword.rs index 2554db799..8c4bbeb1f 100644 --- a/src/basic/keywords/llm_keyword.rs +++ b/src/basic/keywords/llm_keyword.rs @@ -1,11 +1,10 @@ use crate::shared::models::UserSession; use crate::shared::state::AppState; -use crate::shared::utils::call_llm; use log::info; use rhai::{Dynamic, Engine}; pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) { - let ai_config = state.config.clone().unwrap().ai.clone(); + let state_clone = state.clone(); engine .register_custom_syntax(&["LLM", "$expr$"], false, move |context, inputs| { @@ -14,10 +13,18 @@ pub fn llm_keyword(state: &AppState, _user: UserSession, engine: &mut Engine) { info!("LLM processing text: {}", text_str); - let fut = call_llm(&text_str, &ai_config); + let state_inner = state_clone.clone(); + let fut = async move { + let prompt = text_str; + state_inner + .llm_provider + .generate(&prompt, &serde_json::Value::Null) + .await + .map_err(|e| format!("LLM call failed: {}", e)) + }; + let result = - tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut)) - .map_err(|e| format!("LLM call failed: {}", e))?; + tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(fut))?; Ok(Dynamic::from(result)) }) diff --git a/src/basic/mod.rs b/src/basic/mod.rs index 6506ce5f3..643e1c27e 100644 --- a/src/basic/mod.rs +++ b/src/basic/mod.rs @@ -1,7 +1,10 @@ -pub mod keywords; +use crate::shared::models::UserSession; +use crate::shared::state::AppState; +use log::info; +use rhai::{Dynamic, Engine, EvalAltResult}; +use std::sync::Arc; -#[cfg(feature = "email")] -use self::keywords::create_draft_keyword; +pub mod keywords; use self::keywords::create_site::create_site_keyword; use self::keywords::find::find_keyword; @@ -17,43 +20,54 @@ use self::keywords::print::print_keyword; use self::keywords::set::set_keyword; use self::keywords::set_schedule::set_schedule_keyword; use self::keywords::wait::wait_keyword; -use crate::shared::models::UserSession; -use crate::shared::state::AppState; -use log::info; -use rhai::{Dynamic, Engine, EvalAltResult}; + +#[cfg(feature = "email")] +use self::keywords::create_draft_keyword; + +#[cfg(feature = "web_automation")] +use self::keywords::get_website::get_website_keyword; pub struct ScriptService { engine: Engine, + state: Arc, + user: UserSession, } impl ScriptService { - pub fn new(state: &AppState, user: UserSession) -> Self { + pub fn new(state: Arc, user: UserSession) -> Self { let mut engine = Engine::new(); engine.set_allow_anonymous_fn(true); engine.set_allow_looping(true); #[cfg(feature = "email")] - create_draft_keyword(state, user.clone(), &mut engine); + create_draft_keyword(&state, user.clone(), &mut engine); - create_site_keyword(state, user.clone(), &mut engine); - find_keyword(state, user.clone(), &mut engine); - for_keyword(state, user.clone(), &mut engine); + create_site_keyword(&state, user.clone(), &mut engine); + find_keyword(&state, user.clone(), &mut engine); + for_keyword(&state, user.clone(), &mut engine); first_keyword(&mut engine); last_keyword(&mut engine); format_keyword(&mut engine); - llm_keyword(state, user.clone(), &mut engine); - get_keyword(state, user.clone(), &mut engine); - set_keyword(state, user.clone(), &mut engine); - wait_keyword(state, user.clone(), &mut engine); - print_keyword(state, user.clone(), &mut engine); - on_keyword(state, user.clone(), &mut engine); - set_schedule_keyword(state, user.clone(), &mut engine); - hear_keyword(state, user.clone(), &mut engine); - talk_keyword(state, user.clone(), &mut engine); - set_context_keyword(state, user.clone(), &mut engine); + llm_keyword(&state, user.clone(), &mut engine); + get_keyword(&state, user.clone(), &mut engine); + set_keyword(&state, user.clone(), &mut engine); + wait_keyword(&state, user.clone(), &mut engine); + print_keyword(&state, user.clone(), &mut engine); + on_keyword(&state, user.clone(), &mut engine); + set_schedule_keyword(&state, user.clone(), &mut engine); + hear_keyword(&state, user.clone(), &mut engine); + talk_keyword(&state, user.clone(), &mut engine); + set_context_keyword(&state, user.clone(), &mut engine); - ScriptService { engine } + #[cfg(feature = "web_automation")] + get_website_keyword(&state, user.clone(), &mut engine); + + ScriptService { + engine, + state, + user, + } } fn preprocess_basic_script(&self, script: &str) -> String { diff --git a/src/bot/mod.rs b/src/bot/mod.rs index 4b43dbd67..1e7fbd49a 100644 --- a/src/bot/mod.rs +++ b/src/bot/mod.rs @@ -6,40 +6,20 @@ use serde_json; use std::collections::HashMap; use std::fs; use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::mpsc; use uuid::Uuid; -use crate::auth::AuthService; use crate::channels::ChannelAdapter; -use crate::llm::LLMProvider; -use crate::session::SessionManager; use crate::shared::models::{BotResponse, UserMessage, UserSession}; -use crate::tools::ToolManager; +use crate::shared::state::AppState; pub struct BotOrchestrator { - pub session_manager: Arc>, - tool_manager: Arc, - llm_provider: Arc, - auth_service: Arc>, - pub channels: HashMap>, - response_channels: Arc>>>, + pub state: Arc, } impl BotOrchestrator { - pub fn new( - session_manager: SessionManager, - tool_manager: ToolManager, - llm_provider: Arc, - auth_service: AuthService, - ) -> Self { - Self { - session_manager: Arc::new(Mutex::new(session_manager)), - tool_manager: Arc::new(tool_manager), - llm_provider, - auth_service: Arc::new(Mutex::new(auth_service)), - channels: HashMap::new(), - response_channels: Arc::new(Mutex::new(HashMap::new())), - } + pub fn new(state: Arc) -> Self { + Self { state } } pub async fn handle_user_input( @@ -47,18 +27,22 @@ impl BotOrchestrator { session_id: Uuid, user_input: &str, ) -> Result, Box> { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.provide_input(session_id, user_input.to_string())?; Ok(None) } pub async fn is_waiting_for_input(&self, session_id: Uuid) -> bool { - let session_manager = self.session_manager.lock().await; + let session_manager = self.state.session_manager.lock().await; session_manager.is_waiting_for_input(&session_id) } - pub fn add_channel(&mut self, channel_type: &str, adapter: Arc) { - self.channels.insert(channel_type.to_string(), adapter); + pub fn add_channel(&self, channel_type: &str, adapter: Arc) { + self.state + .channels + .lock() + .unwrap() + .insert(channel_type.to_string(), adapter); } pub async fn register_response_channel( @@ -66,7 +50,8 @@ impl BotOrchestrator { session_id: String, sender: mpsc::Sender, ) { - self.response_channels + self.state + .response_channels .lock() .await .insert(session_id, sender); @@ -78,7 +63,7 @@ impl BotOrchestrator { bot_id: &str, mode: i32, ) -> Result<(), Box> { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.update_answer_mode(user_id, bot_id, mode)?; Ok(()) } @@ -97,10 +82,15 @@ impl BotOrchestrator { .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); let session = { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; match session_manager.get_user_session(user_id, bot_id)? { Some(session) => session, - None => session_manager.create_session(user_id, bot_id, "New Conversation")?, + None => { + let new_session = + session_manager.create_session(user_id, bot_id, "New Conversation")?; + Self::run_start_script(&new_session, Arc::clone(&self.state)).await; + new_session + } } }; @@ -113,7 +103,7 @@ impl BotOrchestrator { variable_name, session.id ); - if let Some(adapter) = self.channels.get(&message.channel) { + if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) { let ack_response = BotResponse { bot_id: message.bot_id.clone(), user_id: message.user_id.clone(), @@ -131,7 +121,7 @@ impl BotOrchestrator { } if session.answer_mode == 1 && session.current_tool.is_some() { - self.tool_manager.provide_user_response( + self.state.tool_manager.provide_user_response( &message.user_id, &message.bot_id, message.content.clone(), @@ -140,7 +130,7 @@ impl BotOrchestrator { } { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.save_message( session.id, user_id, @@ -153,7 +143,7 @@ impl BotOrchestrator { let response_content = self.direct_mode_handler(&message, &session).await?; { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.save_message(session.id, user_id, 2, &response_content, 1)?; } @@ -168,7 +158,7 @@ impl BotOrchestrator { is_complete: true, }; - if let Some(adapter) = self.channels.get(&message.channel) { + if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) { adapter.send_message(bot_response).await?; } @@ -180,21 +170,19 @@ impl BotOrchestrator { message: &UserMessage, session: &UserSession, ) -> Result> { - // Retrieve conversation history while holding a mutable lock on the session manager. let history = { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.get_conversation_history(session.id, session.user_id)? }; - // Build the prompt from the conversation history. let mut prompt = String::new(); for (role, content) in history { prompt.push_str(&format!("{}: {}\n", role, content)); } prompt.push_str(&format!("User: {}\nAssistant:", message.content)); - // Generate the assistant's response using the LLM provider. - self.llm_provider + self.state + .llm_provider .generate(&prompt, &serde_json::Value::Null) .await } @@ -206,31 +194,31 @@ impl BotOrchestrator { ) -> Result<(), Box> { info!("Streaming response for user: {}", message.user_id); - // Parse identifiers, falling back to safe defaults. let mut user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4()); let bot_id = Uuid::parse_str(&message.bot_id).unwrap_or_else(|_| Uuid::nil()); - let mut auth = self.auth_service.lock().await; + let mut auth = self.state.auth_service.lock().await; let user_exists = auth.get_user_by_id(user_id)?; if user_exists.is_none() { - // User does not exist, invoke Authentication service to create them user_id = auth.create_user("anonymous1", "anonymous@local", "password")?; } else { user_id = user_exists.unwrap().id; } - // Retrieve an existing session or create a new one. let session = { - let mut sm = self.session_manager.lock().await; + let mut sm = self.state.session_manager.lock().await; match sm.get_user_session(user_id, bot_id)? { Some(sess) => sess, - None => sm.create_session(user_id, bot_id, "New Conversation")?, + None => { + let new_session = sm.create_session(user_id, bot_id, "New Conversation")?; + Self::run_start_script(&new_session, Arc::clone(&self.state)).await; + new_session + } } }; - // If the session is awaiting tool input, forward the user's answer to the tool manager. if session.answer_mode == 1 && session.current_tool.is_some() { - self.tool_manager.provide_user_response( + self.state.tool_manager.provide_user_response( &message.user_id, &message.bot_id, message.content.clone(), @@ -238,9 +226,8 @@ impl BotOrchestrator { return Ok(()); } - // Persist the incoming user message. { - let mut sm = self.session_manager.lock().await; + let mut sm = self.state.session_manager.lock().await; sm.save_message( session.id, user_id, @@ -250,9 +237,8 @@ impl BotOrchestrator { )?; } - // Build the prompt from the conversation history. let prompt = { - let mut sm = self.session_manager.lock().await; + let mut sm = self.state.session_manager.lock().await; let history = sm.get_conversation_history(session.id, user_id)?; let mut p = String::new(); for (role, content) in history { @@ -262,11 +248,9 @@ impl BotOrchestrator { p }; - // Set up a channel for the streaming LLM output. let (stream_tx, mut stream_rx) = mpsc::channel::(100); - let llm = self.llm_provider.clone(); + let llm = self.state.llm_provider.clone(); - // Spawn the LLM streaming task. tokio::spawn(async move { if let Err(e) = llm .generate_stream(&prompt, &serde_json::Value::Null, stream_tx) @@ -276,7 +260,6 @@ impl BotOrchestrator { } }); - // Forward each chunk to the client as it arrives. let mut full_response = String::new(); while let Some(chunk) = stream_rx.recv().await { full_response.push_str(&chunk); @@ -293,18 +276,15 @@ impl BotOrchestrator { }; if response_tx.send(partial).await.is_err() { - // Receiver has been dropped; stop streaming. break; } } - // Save the complete assistant reply. { - let mut sm = self.session_manager.lock().await; + let mut sm = self.state.session_manager.lock().await; sm.save_message(session.id, user_id, 2, &full_response, 1)?; } - // Notify the client that the stream is finished. let final_msg = BotResponse { bot_id: message.bot_id, user_id: message.user_id, @@ -324,7 +304,7 @@ impl BotOrchestrator { &self, user_id: Uuid, ) -> Result, Box> { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.get_user_sessions(user_id) } @@ -333,7 +313,7 @@ impl BotOrchestrator { session_id: Uuid, user_id: Uuid, ) -> Result, Box> { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.get_conversation_history(session_id, user_id) } @@ -351,15 +331,20 @@ impl BotOrchestrator { .unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()); let session = { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; match session_manager.get_user_session(user_id, bot_id)? { Some(session) => session, - None => session_manager.create_session(user_id, bot_id, "New Conversation")?, + None => { + let new_session = + session_manager.create_session(user_id, bot_id, "New Conversation")?; + Self::run_start_script(&new_session, Arc::clone(&self.state)).await; + new_session + } } }; { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.save_message( session.id, user_id, @@ -370,17 +355,24 @@ impl BotOrchestrator { } let is_tool_waiting = self + .state .tool_manager .is_tool_waiting(&message.session_id) .await .unwrap_or(false); if is_tool_waiting { - self.tool_manager + self.state + .tool_manager .provide_input(&message.session_id, &message.content) .await?; - if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await { + if let Ok(tool_output) = self + .state + .tool_manager + .get_tool_output(&message.session_id) + .await + { for output in tool_output { let bot_response = BotResponse { bot_id: message.bot_id.clone(), @@ -393,7 +385,8 @@ impl BotOrchestrator { is_complete: true, }; - if let Some(adapter) = self.channels.get(&message.channel) { + if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) + { adapter.send_message(bot_response).await?; } } @@ -406,12 +399,13 @@ impl BotOrchestrator { || message.content.to_lowercase().contains("math") { match self + .state .tool_manager .execute_tool("calculator", &message.session_id, &message.user_id) .await { Ok(tool_result) => { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.save_message(session.id, user_id, 2, &tool_result.output, 2)?; tool_result.output @@ -421,7 +415,7 @@ impl BotOrchestrator { } } } else { - let available_tools = self.tool_manager.list_tools(); + let available_tools = self.state.tool_manager.list_tools(); let tools_context = if !available_tools.is_empty() { format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", ")) } else { @@ -430,13 +424,14 @@ impl BotOrchestrator { let full_prompt = format!("{}{}", message.content, tools_context); - self.llm_provider + self.state + .llm_provider .generate(&full_prompt, &serde_json::Value::Null) .await? }; { - let mut session_manager = self.session_manager.lock().await; + let mut session_manager = self.state.session_manager.lock().await; session_manager.save_message(session.id, user_id, 2, &response, 1)?; } @@ -451,25 +446,63 @@ impl BotOrchestrator { is_complete: true, }; - if let Some(adapter) = self.channels.get(&message.channel) { + if let Some(adapter) = self.state.channels.lock().unwrap().get(&message.channel) { adapter.send_message(bot_response).await?; } Ok(()) } + + async fn run_start_script(session: &UserSession, state: Arc) { + let start_script = r#" +TALK "Welcome to General Bots!" +HEAR name +TALK "Hello, " + name + +text = GET "default.pdf" +SET CONTEXT text + +resume = LLM "Build a resume from " + text +"#; + + info!("Running start.bas for session: {}", session.id); + + let session_clone = session.clone(); + let state_clone = state.clone(); + tokio::spawn(async move { + let state_for_run = state_clone.clone(); + if let Err(e) = crate::basic::ScriptService::new(state_clone, session_clone.clone()) + .compile(start_script) + .and_then(|ast| { + crate::basic::ScriptService::new(state_for_run, session_clone.clone()).run(&ast) + }) + { + log::error!("Failed to run start.bas: {}", e); + } + }); + } +} + +impl Default for BotOrchestrator { + fn default() -> Self { + Self { + state: Arc::new(AppState::default()), + } + } } #[actix_web::get("/ws")] async fn websocket_handler( req: HttpRequest, stream: web::Payload, - data: web::Data, + data: web::Data, ) -> Result { let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; let session_id = Uuid::new_v4().to_string(); let (tx, mut rx) = mpsc::channel::(100); - data.orchestrator + let orchestrator = BotOrchestrator::new(Arc::clone(&data)); + orchestrator .register_response_channel(session_id.clone(), tx.clone()) .await; data.web_adapter @@ -479,7 +512,6 @@ async fn websocket_handler( .add_connection(session_id.clone(), tx.clone()) .await; - let orchestrator = data.orchestrator.clone(); let web_adapter = data.web_adapter.clone(); actix_web::rt::spawn(async move { @@ -523,7 +555,7 @@ async fn websocket_handler( #[actix_web::get("/api/whatsapp/webhook")] async fn whatsapp_webhook_verify( - data: web::Data, + data: web::Data, web::Query(params): web::Query>, ) -> Result { let empty = String::new(); @@ -539,7 +571,7 @@ async fn whatsapp_webhook_verify( #[actix_web::post("/api/whatsapp/webhook")] async fn whatsapp_webhook( - data: web::Data, + data: web::Data, payload: web::Json, ) -> Result { match data @@ -549,7 +581,8 @@ async fn whatsapp_webhook( { Ok(user_messages) => { for user_message in user_messages { - if let Err(e) = data.orchestrator.process_message(user_message).await { + let orchestrator = BotOrchestrator::new(Arc::clone(&data)); + if let Err(e) = orchestrator.process_message(user_message).await { log::error!("Error processing WhatsApp message: {}", e); } } @@ -564,7 +597,7 @@ async fn whatsapp_webhook( #[actix_web::post("/api/voice/start")] async fn voice_start( - data: web::Data, + data: web::Data, info: web::Json, ) -> Result { let session_id = info @@ -593,7 +626,7 @@ async fn voice_start( #[actix_web::post("/api/voice/stop")] async fn voice_stop( - data: web::Data, + data: web::Data, info: web::Json, ) -> Result { let session_id = info @@ -611,7 +644,7 @@ async fn voice_stop( } #[actix_web::post("/api/sessions")] -async fn create_session(_data: web::Data) -> Result { +async fn create_session(_data: web::Data) -> Result { let session_id = Uuid::new_v4(); Ok(HttpResponse::Ok().json(serde_json::json!({ "session_id": session_id, @@ -621,9 +654,10 @@ async fn create_session(_data: web::Data) -> Res } #[actix_web::get("/api/sessions")] -async fn get_sessions(data: web::Data) -> Result { +async fn get_sessions(data: web::Data) -> Result { let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); - match data.orchestrator.get_user_sessions(user_id).await { + let orchestrator = BotOrchestrator::new(Arc::clone(&data)); + match orchestrator.get_user_sessions(user_id).await { Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)), Err(e) => { Ok(HttpResponse::InternalServerError() @@ -634,22 +668,24 @@ async fn get_sessions(data: web::Data) -> Result #[actix_web::get("/api/sessions/{session_id}")] async fn get_session_history( - data: web::Data, + data: web::Data, path: web::Path, ) -> Result { let session_id = path.into_inner(); let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); match Uuid::parse_str(&session_id) { - Ok(session_uuid) => match data - .orchestrator - .get_conversation_history(session_uuid, user_id) - .await - { - Ok(history) => Ok(HttpResponse::Ok().json(history)), - Err(e) => Ok(HttpResponse::InternalServerError() - .json(serde_json::json!({"error": e.to_string()}))), - }, + Ok(session_uuid) => { + let orchestrator = BotOrchestrator::new(Arc::clone(&data)); + match orchestrator + .get_conversation_history(session_uuid, user_id) + .await + { + Ok(history) => Ok(HttpResponse::Ok().json(history)), + Err(e) => Ok(HttpResponse::InternalServerError() + .json(serde_json::json!({"error": e.to_string()}))), + } + } Err(_) => { Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"}))) } @@ -658,7 +694,7 @@ async fn get_session_history( #[actix_web::post("/api/set_mode")] async fn set_mode_handler( - data: web::Data, + data: web::Data, info: web::Json>, ) -> Result { let default_user = "default_user".to_string(); @@ -671,8 +707,8 @@ async fn set_mode_handler( let mode = mode_str.parse::().unwrap_or(0); - if let Err(e) = data - .orchestrator + let orchestrator = BotOrchestrator::new(Arc::clone(&data)); + if let Err(e) = orchestrator .set_user_answer_mode(user_id, bot_id, mode) .await { diff --git a/src/main.rs b/src/main.rs index f88dae4a1..2d86b417a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use actix_web::middleware::Logger; use actix_web::{web, App, HttpServer}; use dotenvy::dotenv; use log::info; +use std::collections::HashMap; use std::sync::{Arc, Mutex}; mod auth; @@ -105,7 +106,7 @@ async fn main() -> std::io::Result<()> { }; let tool_manager = Arc::new(tools::ToolManager::new()); - let llm_provider = Arc::new(llm::OpenAIClient::new( + let llm_provider = Arc::new(crate::llm::OpenAIClient::new( "empty".to_string(), Some("http://localhost:8081".to_string()), )); @@ -125,69 +126,40 @@ async fn main() -> std::io::Result<()> { let tool_api = Arc::new(tools::ToolApi::new()); - let base_app_state = AppState { + let session_manager = Arc::new(tokio::sync::Mutex::new(session::SessionManager::new( + diesel::Connection::establish(&cfg.database_url()).unwrap(), + redis_client.clone(), + ))); + + let auth_service = Arc::new(tokio::sync::Mutex::new(auth::AuthService::new( + diesel::Connection::establish(&cfg.database_url()).unwrap(), + redis_client.clone(), + ))); + + let app_state = Arc::new(AppState { s3_client: None, config: Some(cfg.clone()), conn: db_pool.clone(), custom_conn: db_custom_pool.clone(), redis_client: redis_client.clone(), - orchestrator: Arc::new(bot::BotOrchestrator::new( - session::SessionManager::new( - diesel::Connection::establish(&cfg.database_url()).unwrap(), - redis_client.clone(), - ), - (*tool_manager).clone(), - llm_provider.clone(), - auth::AuthService::new( - diesel::Connection::establish(&cfg.database_url()).unwrap(), - redis_client.clone(), - ), - )), - web_adapter, - voice_adapter, - whatsapp_adapter, - tool_api, - }; + session_manager: session_manager.clone(), + tool_manager: tool_manager.clone(), + llm_provider: llm_provider.clone(), + auth_service: auth_service.clone(), + channels: Arc::new(Mutex::new(HashMap::new())), + response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + web_adapter: web_adapter.clone(), + voice_adapter: voice_adapter.clone(), + whatsapp_adapter: whatsapp_adapter.clone(), + tool_api: tool_api.clone(), + }); info!( "Starting server on {}:{}", config.server.host, config.server.port ); - let closure_config = config.clone(); - HttpServer::new(move || { - let cfg = closure_config.clone(); - - let auth_service = auth::AuthService::new( - diesel::Connection::establish(&cfg.database_url()).unwrap(), - redis_client.clone(), - ); - let session_manager = session::SessionManager::new( - diesel::Connection::establish(&cfg.database_url()).unwrap(), - redis_client.clone(), - ); - - let orchestrator = Arc::new(bot::BotOrchestrator::new( - session_manager, - (*tool_manager).clone(), - llm_provider.clone(), - auth_service, - )); - - let app_state = AppState { - s3_client: base_app_state.s3_client.clone(), - config: base_app_state.config.clone(), - conn: base_app_state.conn.clone(), - custom_conn: base_app_state.custom_conn.clone(), - redis_client: base_app_state.redis_client.clone(), - orchestrator, - web_adapter: base_app_state.web_adapter.clone(), - voice_adapter: base_app_state.voice_adapter.clone(), - whatsapp_adapter: base_app_state.whatsapp_adapter.clone(), - tool_api: base_app_state.tool_api.clone(), - }; - let cors = Cors::default() .allow_any_origin() .allow_any_method() @@ -200,7 +172,7 @@ async fn main() -> std::io::Result<()> { .wrap(cors) .wrap(Logger::default()) .wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i")) - .app_data(web::Data::new(app_state_clone)); + .app_data(web::Data::from(app_state_clone)); app = app .service(upload_file) diff --git a/src/shared/state.rs b/src/shared/state.rs index 87673950b..07eeaf8ef 100644 --- a/src/shared/state.rs +++ b/src/shared/state.rs @@ -1,21 +1,33 @@ -use crate::bot::BotOrchestrator; -use crate::channels::{VoiceAdapter, WebChannelAdapter}; +use crate::auth::AuthService; +use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter}; use crate::config::AppConfig; -use crate::tools::ToolApi; +use crate::llm::LLMProvider; +use crate::session::SessionManager; +use crate::tools::{ToolApi, ToolManager}; use crate::whatsapp::WhatsAppAdapter; -use diesel::PgConnection; +use diesel::{Connection, PgConnection}; use redis::Client; +use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; +use tokio::sync::mpsc; + +use crate::shared::models::BotResponse; pub struct AppState { pub s3_client: Option, pub config: Option, pub conn: Arc>, pub custom_conn: Arc>, - pub redis_client: Option>, - pub orchestrator: Arc, + + pub session_manager: Arc>, + pub tool_manager: Arc, + pub llm_provider: Arc, + pub auth_service: Arc>, + pub channels: Arc>>>, + pub response_channels: Arc>>>, + pub web_adapter: Arc, pub voice_adapter: Arc, pub whatsapp_adapter: Arc, @@ -30,7 +42,12 @@ impl Clone for AppState { conn: Arc::clone(&self.conn), custom_conn: Arc::clone(&self.custom_conn), redis_client: self.redis_client.clone(), - orchestrator: Arc::clone(&self.orchestrator), + session_manager: Arc::clone(&self.session_manager), + tool_manager: Arc::clone(&self.tool_manager), + llm_provider: Arc::clone(&self.llm_provider), + auth_service: Arc::clone(&self.auth_service), + channels: Arc::clone(&self.channels), + response_channels: Arc::clone(&self.response_channels), web_adapter: Arc::clone(&self.web_adapter), voice_adapter: Arc::clone(&self.voice_adapter), whatsapp_adapter: Arc::clone(&self.whatsapp_adapter), @@ -38,3 +55,46 @@ impl Clone for AppState { } } } + +impl Default for AppState { + fn default() -> Self { + Self { + s3_client: None, + config: None, + conn: Arc::new(Mutex::new( + diesel::PgConnection::establish("postgres://localhost/test").unwrap(), + )), + custom_conn: Arc::new(Mutex::new( + diesel::PgConnection::establish("postgres://localhost/test").unwrap(), + )), + redis_client: None, + session_manager: Arc::new(tokio::sync::Mutex::new(SessionManager::new( + diesel::PgConnection::establish("postgres://localhost/test").unwrap(), + None, + ))), + tool_manager: Arc::new(ToolManager::new()), + llm_provider: Arc::new(crate::llm::OpenAIClient::new( + "empty".to_string(), + Some("http://localhost:8081".to_string()), + )), + auth_service: Arc::new(tokio::sync::Mutex::new(AuthService::new( + diesel::PgConnection::establish("postgres://localhost/test").unwrap(), + None, + ))), + channels: Arc::new(Mutex::new(HashMap::new())), + response_channels: Arc::new(tokio::sync::Mutex::new(HashMap::new())), + web_adapter: Arc::new(WebChannelAdapter::new()), + voice_adapter: Arc::new(VoiceAdapter::new( + "https://livekit.example.com".to_string(), + "api_key".to_string(), + "api_secret".to_string(), + )), + whatsapp_adapter: Arc::new(WhatsAppAdapter::new( + "whatsapp_token".to_string(), + "phone_number_id".to_string(), + "verify_token".to_string(), + )), + tool_api: Arc::new(ToolApi::new()), + } + } +}