diff --git a/config/directory_config.json b/config/directory_config.json index 3cdcd2525..f7655638f 100644 --- a/config/directory_config.json +++ b/config/directory_config.json @@ -1,7 +1,7 @@ { "base_url": "http://localhost:8300", "default_org": { - "id": "357870945618100238", + "id": "359341929336406018", "name": "default", "domain": "default.localhost" }, @@ -13,8 +13,8 @@ "first_name": "Admin", "last_name": "User" }, - "admin_token": "RflPqOgYM-BtinaBTyCaY8hX-_koTwC65gCg1Kpf7Sfhlc0ZOLZvIr-XsOYXmckPLBAWzjU", + "admin_token": "JP3vRy5eNankB5x204_ANNe4GYXt3Igb5xZuJNW17Chf42fsAcUh4Iw_U6JatCI7dJkRMsw", "project_id": "", - "client_id": "357870946289254414", - "client_secret": "q20LOjW5Vdjzp57Cw8EuFt7sILEd8VeSeGPvrhB63880GLgaJZpcWeRgUwdGET2x" + "client_id": "359341929806233602", + "client_secret": "ktwHx7cmAIdp7NC2x3BYPD1NWiuvLAWHYHF6EyjbC0gZAgTgjFEDNA7KukXMAgdC" } \ No newline at end of file diff --git a/src/basic/compiler/mod.rs b/src/basic/compiler/mod.rs index 12f01bcb8..8717a39f0 100644 --- a/src/basic/compiler/mod.rs +++ b/src/basic/compiler/mod.rs @@ -25,6 +25,8 @@ pub struct ParamDeclaration { pub example: Option, pub description: String, pub required: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub enum_values: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolDefinition { @@ -80,6 +82,8 @@ pub struct OpenAIProperty { pub description: String, #[serde(skip_serializing_if = "Option::is_none")] pub example: Option, + #[serde(rename = "enum", skip_serializing_if = "Option::is_none")] + pub enum_values: Option>, } #[derive(Debug)] pub struct BasicCompiler { @@ -115,8 +119,12 @@ impl BasicCompiler { .file_stem() .and_then(|s| s.to_str()) .ok_or("Invalid file name")?; + + // Generate ADD SUGGESTION commands for enum parameters + let source_with_suggestions = self.generate_enum_suggestions(&source_content, &tool_def)?; + let ast_path = format!("{output_dir}/{file_name}.ast"); - let ast_content = self.preprocess_basic(&source_content, source_path, self.bot_id)?; + let ast_content = self.preprocess_basic(&source_with_suggestions, source_path, self.bot_id)?; fs::write(&ast_path, &ast_content).map_err(|e| format!("Failed to write AST file: {e}"))?; let (mcp_json, tool_json) = if tool_def.parameters.is_empty() { (None, None) @@ -206,6 +214,36 @@ impl BasicCompiler { .map(|end| rest[start + 1..start + 1 + end].to_string()) }) }); + + // Parse ENUM array directly from PARAM statement + // Syntax: PARAM name AS TYPE ENUM ["value1", "value2", ...] + let enum_values = if let Some(enum_pos) = line.find("ENUM") { + let rest = &line[enum_pos + 4..].trim(); + if let Some(start) = rest.find('[') { + if let Some(end) = rest[start..].find(']') { + let array_content = &rest[start + 1..start + end]; + // Parse the array elements + let values: Vec = array_content + .split(',') + .map(|s| { + s.trim() + .trim_matches('"') + .trim_matches('\'') + .to_string() + }) + .filter(|s| !s.is_empty()) + .collect(); + Some(values) + } else { + None + } + } else { + None + } + } else { + None + }; + let description = if let Some(desc_pos) = line.find("DESCRIPTION") { let rest = &line[desc_pos + 11..].trim(); if let Some(start) = rest.find('"') { @@ -220,12 +258,14 @@ impl BasicCompiler { } else { "".to_string() }; + Ok(Some(ParamDeclaration { name, param_type: Self::normalize_type(¶m_type), example, description, required: true, + enum_values, })) } fn normalize_type(basic_type: &str) -> String { @@ -239,6 +279,62 @@ impl BasicCompiler { _ => "string".to_string(), } } + + /// Generate ADD SUGGESTION commands for parameters with enum values + fn generate_enum_suggestions( + &self, + source: &str, + tool_def: &ToolDefinition, + ) -> Result> { + let mut result = String::new(); + let mut suggestion_lines = Vec::new(); + + // Generate ADD SUGGESTION TEXT commands for each parameter with enum values + // These will send the enum value as a text message when clicked + for param in &tool_def.parameters { + if let Some(ref enum_values) = param.enum_values { + // For each enum value, create a suggestion button + for enum_value in enum_values { + // Use the enum value as both the text to send and the button label + let suggestion_cmd = format!( + "ADD SUGGESTION TEXT \"{}\" AS \"{}\"", + enum_value, enum_value + ); + suggestion_lines.push(suggestion_cmd); + } + } + } + + // Insert suggestions after the DESCRIPTION line (or at end if no DESCRIPTION) + let lines: Vec<&str> = source.lines().collect(); + let mut inserted = false; + + for line in lines.iter() { + result.push_str(line); + result.push('\n'); + + // Insert suggestions after DESCRIPTION line + if !inserted && line.trim().starts_with("DESCRIPTION ") { + // Insert suggestions after this line + for suggestion in &suggestion_lines { + result.push_str(suggestion); + result.push('\n'); + } + inserted = true; + } + } + + // If we didn't find a DESCRIPTION line, insert at the end + if !inserted && !suggestion_lines.is_empty() { + for suggestion in &suggestion_lines { + result.push_str(suggestion); + result.push('\n'); + } + } + + Ok(result) + } + fn generate_mcp_tool( tool_def: &ToolDefinition, ) -> Result> { @@ -279,6 +375,7 @@ impl BasicCompiler { prop_type: param.param_type.clone(), description: param.description.clone(), example: param.example.clone(), + enum_values: param.enum_values.clone(), }, ); if param.required { diff --git a/src/basic/keywords/add_suggestion.rs b/src/basic/keywords/add_suggestion.rs index b92bb62ee..204a61634 100644 --- a/src/basic/keywords/add_suggestion.rs +++ b/src/basic/keywords/add_suggestion.rs @@ -1,6 +1,6 @@ use crate::shared::models::UserSession; use crate::shared::state::AppState; -use log::{error, trace}; +use log::{error, info, trace}; use rhai::{Dynamic, Engine}; use serde_json::json; use std::sync::Arc; @@ -64,9 +64,14 @@ pub fn add_suggestion_keyword( let cache = state.cache.clone(); let cache2 = state.cache.clone(); let cache3 = state.cache.clone(); + let cache4 = state.cache.clone(); + let cache5 = state.cache.clone(); let user_session2 = user_session.clone(); let user_session3 = user_session.clone(); + let user_session4 = user_session.clone(); + let user_session5 = user_session.clone(); + // ADD SUGGESTION "context_name" AS "button text" engine .register_custom_syntax( ["ADD", "SUGGESTION", "$expr$", "AS", "$expr$"], @@ -82,6 +87,45 @@ pub fn add_suggestion_keyword( ) .expect("valid syntax registration"); + // ADD SUGGESTION TEXT "$expr$" AS "button text" + // Creates a suggestion that sends the text as a user message when clicked + engine + .register_custom_syntax( + ["ADD", "SUGGESTION", "TEXT", "$expr$", "AS", "$expr$"], + true, + move |context, inputs| { + let text_value = context.eval_expression_tree(&inputs[0])?.to_string(); + let button_text = context.eval_expression_tree(&inputs[1])?.to_string(); + + add_text_suggestion(cache4.as_ref(), &user_session4, &text_value, &button_text)?; + + Ok(Dynamic::UNIT) + }, + ) + .expect("valid syntax registration"); + + // ADD_SUGGESTION_TOOL "tool_name" AS "button text" - underscore version to avoid syntax conflicts + engine + .register_custom_syntax( + ["ADD_SUGGESTION_TOOL", "$expr$", "AS", "$expr$"], + true, + move |context, inputs| { + let tool_name = context.eval_expression_tree(&inputs[0])?.to_string(); + let button_text = context.eval_expression_tree(&inputs[1])?.to_string(); + + add_tool_suggestion( + cache5.as_ref(), + &user_session5, + &tool_name, + None, + &button_text, + )?; + + Ok(Dynamic::UNIT) + }, + ) + .expect("valid syntax registration"); + engine .register_custom_syntax( ["ADD", "SUGGESTION", "TOOL", "$expr$", "AS", "$expr$"], @@ -210,27 +254,22 @@ fn add_context_suggestion( Ok(()) } -fn add_tool_suggestion( +fn add_text_suggestion( cache: Option<&Arc>, user_session: &UserSession, - tool_name: &str, - params: Option>, + text_value: &str, button_text: &str, ) -> Result<(), Box> { if let Some(cache_client) = cache { let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id); let suggestion = json!({ - "type": "tool", - "tool": tool_name, + "type": "text_value", "text": button_text, + "value": text_value, "action": { - "type": "invoke_tool", - "tool": tool_name, - "params": params, - - - "prompt_for_params": params.is_none() + "type": "send_message", + "message": text_value } }); @@ -250,11 +289,65 @@ fn add_tool_suggestion( match result { Ok(length) => { trace!( - "Added tool suggestion '{}' to session {}, total: {}, has_params: {}", + "Added text suggestion '{}' to session {}, total: {}", + text_value, + user_session.id, + length + ); + } + Err(e) => error!("Failed to add text suggestion to Redis: {}", e), + } + } else { + trace!("No cache configured, text suggestion not added"); + } + + Ok(()) +} + +fn add_tool_suggestion( + cache: Option<&Arc>, + user_session: &UserSession, + tool_name: &str, + params: Option>, + button_text: &str, +) -> Result<(), Box> { + if let Some(cache_client) = cache { + let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id); + + // Create action object and serialize it to JSON string + let action_obj = json!({ + "type": "invoke_tool", + "tool": tool_name, + "params": params, + "prompt_for_params": params.is_none() + }); + let action_str = action_obj.to_string(); + + let suggestion = json!({ + "text": button_text, + "action": action_str + }); + + let mut conn = match cache_client.get_connection() { + Ok(conn) => conn, + Err(e) => { + error!("Failed to connect to cache: {}", e); + return Ok(()); + } + }; + + let result: Result = redis::cmd("RPUSH") + .arg(&redis_key) + .arg(suggestion.to_string()) + .query(&mut conn); + + match result { + Ok(length) => { + info!( + "Added tool suggestion '{}' to session {}, total: {}", tool_name, user_session.id, - length, - params.is_some() + length ); } Err(e) => error!("Failed to add tool suggestion to Redis: {}", e), @@ -266,6 +359,74 @@ fn add_tool_suggestion( Ok(()) } +/// Retrieve suggestions from Valkey/Redis for a given user session +/// Returns a vector of Suggestion structs that can be included in BotResponse +/// Note: This function clears suggestions from Redis after fetching them to prevent duplicates +pub fn get_suggestions( + cache: Option<&Arc>, + user_id: &str, + session_id: &str, +) -> Vec { + let mut suggestions = Vec::new(); + + if let Some(cache_client) = cache { + let redis_key = format!("suggestions:{}:{}", user_id, session_id); + + let mut conn = match cache_client.get_connection() { + Ok(conn) => conn, + Err(e) => { + error!("Failed to connect to cache: {}", e); + return suggestions; + } + }; + + // Get all suggestions from the Redis list + let result: Result, redis::RedisError> = redis::cmd("LRANGE") + .arg(&redis_key) + .arg(0) + .arg(-1) + .query(&mut conn); + + match result { + Ok(items) => { + for item in items { + if let Ok(json) = serde_json::from_str::(&item) { + let suggestion = crate::shared::models::Suggestion { + text: json["text"].as_str().unwrap_or("").to_string(), + context: json["context"].as_str().map(|s| s.to_string()), + action: json.get("action").and_then(|v| serde_json::to_string(v).ok()), + icon: json.get("icon").and_then(|v| v.as_str()).map(|s| s.to_string()), + }; + suggestions.push(suggestion); + } + } + info!( + "[SUGGESTIONS] Retrieved {} suggestions for session {}", + suggestions.len(), + session_id + ); + + // Clear suggestions from Redis after fetching to prevent them from being sent again + if !suggestions.is_empty() { + let _: Result = redis::cmd("DEL") + .arg(&redis_key) + .query(&mut conn); + info!( + "[SUGGESTIONS] Cleared {} suggestions from Redis for session {}", + suggestions.len(), + session_id + ); + } + } + Err(e) => error!("Failed to get suggestions from Redis: {}", e), + } + } else { + info!("[SUGGESTIONS] No cache configured, cannot retrieve suggestions"); + } + + suggestions +} + #[cfg(test)] mod tests { use serde_json::json; diff --git a/src/basic/keywords/bot_memory.rs b/src/basic/keywords/bot_memory.rs index 233cd4bbf..9cbd5b8df 100644 --- a/src/basic/keywords/bot_memory.rs +++ b/src/basic/keywords/bot_memory.rs @@ -8,7 +8,7 @@ use uuid::Uuid; pub fn set_bot_memory_keyword(state: Arc, user: UserSession, engine: &mut Engine) { let state_clone = Arc::clone(&state); - let user_clone = user; + let user_clone = user.clone(); engine @@ -110,6 +110,123 @@ pub fn set_bot_memory_keyword(state: Arc, user: UserSession, engine: & }, ) .expect("valid syntax registration"); + + // Also register as underscore function for tools (SET_BOT_MEMORY(key, value)) + let state_clone2 = Arc::clone(&state); + let user_clone2 = user.clone(); + engine.register_fn("SET_BOT_MEMORY", move |key: String, value: String| { + let state_for_spawn = Arc::clone(&state_clone2); + let user_clone_spawn = user_clone2.clone(); + let key_clone = key; + let value_clone = value; + + tokio::spawn(async move { + use crate::shared::models::bot_memories; + + let mut conn = match state_for_spawn.conn.get() { + Ok(conn) => conn, + Err(e) => { + error!( + "Failed to acquire database connection for SET_BOT_MEMORY: {}", + e + ); + return; + } + }; + + let bot_uuid = match Uuid::parse_str(&user_clone_spawn.bot_id.to_string()) { + Ok(uuid) => uuid, + Err(e) => { + error!("Invalid bot ID format: {}", e); + return; + } + }; + + let now = chrono::Utc::now(); + + let existing_memory: Option = bot_memories::table + .filter(bot_memories::bot_id.eq(bot_uuid)) + .filter(bot_memories::key.eq(&key_clone)) + .select(bot_memories::id) + .first(&mut *conn) + .optional() + .unwrap_or(None); + + if let Some(memory_id) = existing_memory { + let update_result = diesel::update( + bot_memories::table.filter(bot_memories::id.eq(memory_id)), + ) + .set(( + bot_memories::value.eq(&value_clone), + bot_memories::updated_at.eq(now), + )) + .execute(&mut *conn); + + match update_result { + Ok(_) => { + trace!( + "Updated bot memory for key: {} with value length: {}", + key_clone, + value_clone.len() + ); + } + Err(e) => { + error!("Failed to update bot memory: {}", e); + } + } + } else { + let insert_result = diesel::insert_into(bot_memories::table) + .values(( + bot_memories::id.eq(Uuid::new_v4()), + bot_memories::bot_id.eq(bot_uuid), + bot_memories::key.eq(&key_clone), + bot_memories::value.eq(&value_clone), + bot_memories::created_at.eq(now), + bot_memories::updated_at.eq(now), + )) + .execute(&mut *conn); + + match insert_result { + Ok(_) => { + trace!( + "Created bot memory for key: {} with value length: {}", + key_clone, + value_clone.len() + ); + } + Err(e) => { + error!("Failed to insert bot memory: {}", e); + } + } + } + }); + }); + + // Also register GET_BOT_MEMORY as underscore function for tools + let state_clone3 = Arc::clone(&state); + let user_clone3 = user.clone(); + engine.register_fn("GET_BOT_MEMORY", move |key_param: String| -> String { + use crate::shared::models::bot_memories; + + let state = Arc::clone(&state_clone3); + let conn_result = state.conn.get(); + + if let Ok(mut conn) = conn_result { + let bot_uuid = user_clone3.bot_id; + + let memory_value: Option = bot_memories::table + .filter(bot_memories::bot_id.eq(bot_uuid)) + .filter(bot_memories::key.eq(&key_param)) + .select(bot_memories::value) + .first(&mut *conn) + .optional() + .unwrap_or(None); + + memory_value.unwrap_or_default() + } else { + String::new() + } + }); } pub fn get_bot_memory_keyword(state: Arc, user: UserSession, engine: &mut Engine) { diff --git a/src/basic/keywords/core_functions.rs b/src/basic/keywords/core_functions.rs index 4f6d58d19..083d8961b 100644 --- a/src/basic/keywords/core_functions.rs +++ b/src/basic/keywords/core_functions.rs @@ -1,7 +1,7 @@ use crate::shared::models::UserSession; use crate::shared::state::AppState; -use log::debug; -use rhai::Engine; +use log::{debug, info}; +use rhai::{Dynamic, Engine}; use std::sync::Arc; use super::arrays::register_array_functions; @@ -28,5 +28,21 @@ pub fn register_core_functions(state: Arc, user: UserSession, engine: register_error_functions(state, user, engine); debug!(" * Error handling functions registered"); + // Register send_mail stub function for tools (traces when mail feature is not available) + engine.register_fn("send_mail", move |to: Dynamic, subject: Dynamic, body: Dynamic, _attachments: Dynamic| -> String { + let to_str = to.to_string(); + let subject_str = subject.to_string(); + let body_str = body.to_string(); + info!( + "[TOOL] send_mail called (mail feature not enabled): to='{}', subject='{}', body_len={}", + to_str, + subject_str, + body_str.len() + ); + // Return a fake message ID + format!("MSG-{:0X}", chrono::Utc::now().timestamp()) + }); + debug!(" * send_mail stub function registered"); + debug!("All core BASIC functions registered successfully"); } diff --git a/src/basic/keywords/data_operations.rs b/src/basic/keywords/data_operations.rs index 1034749e9..307db927b 100644 --- a/src/basic/keywords/data_operations.rs +++ b/src/basic/keywords/data_operations.rs @@ -67,6 +67,7 @@ pub fn register_save_keyword(state: Arc, user: UserSession, engine: &m pub fn register_insert_keyword(state: Arc, user: UserSession, engine: &mut Engine) { let state_clone = Arc::clone(&state); + let user_clone = user.clone(); let user_roles = UserRoles::from_user_session(&user); engine @@ -79,10 +80,18 @@ pub fn register_insert_keyword(state: Arc, user: UserSession, engine: trace!("INSERT into table: {}", table); - let mut conn = state_clone - .conn - .get() - .map_err(|e| format!("DB error: {}", e))?; + // Get bot's database connection instead of main connection + let bot_pool = state_clone + .bot_database_manager + .get_bot_pool(user_clone.bot_id); + + let mut conn = match bot_pool { + Ok(pool) => pool.get().map_err(|e| format!("Bot DB error: {}", e))?, + Err(_) => state_clone + .conn + .get() + .map_err(|e| format!("DB error: {}", e))?, + }; // Check write access if let Err(e) = diff --git a/src/basic/keywords/format.rs b/src/basic/keywords/format.rs index 8f567408d..88d8fd78a 100644 --- a/src/basic/keywords/format.rs +++ b/src/basic/keywords/format.rs @@ -1,62 +1,89 @@ use chrono::{Datelike, NaiveDateTime, Timelike}; use num_format::{Locale, ToFormattedString}; -use rhai::{Dynamic, Engine}; +use rhai::{Dynamic, Engine, Map}; use std::str::FromStr; -pub fn format_keyword(engine: &mut Engine) { - engine - .register_custom_syntax(["FORMAT", "$expr$", "$expr$"], false, { - move |context, inputs| { - let value_dyn = context.eval_expression_tree(&inputs[0])?; - let pattern_dyn = context.eval_expression_tree(&inputs[1])?; - let value_str = value_dyn.to_string(); - let pattern = pattern_dyn.to_string(); - if let Ok(num) = f64::from_str(&value_str) { - let formatted = if pattern.starts_with('N') || pattern.starts_with('C') { - let (prefix, decimals, locale_tag) = parse_pattern(&pattern); - let locale = get_locale(&locale_tag); - let symbol = if prefix == "C" { - get_currency_symbol(&locale_tag) - } else { - "" - }; - let int_part = num.trunc() as i64; - let frac_part = num.fract(); - if decimals == 0 { - format!("{}{}", symbol, int_part.to_formatted_string(&locale)) - } else { - let frac_scaled = - ((frac_part * 10f64.powi(decimals as i32)).round()) as i64; - let decimal_sep = match locale_tag.as_str() { - "pt" | "fr" | "es" | "it" | "de" => ",", - _ => ".", - }; - format!( - "{}{}{}{:0width$}", - symbol, - int_part.to_formatted_string(&locale), - decimal_sep, - frac_scaled, - width = decimals - ) - } - } else { - match pattern.as_str() { - "n" | "F" => format!("{num:.2}"), - "0%" => format!("{:.0}%", num * 100.0), - _ => format!("{num}"), - } - }; - return Ok(Dynamic::from(formatted)); - } - if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") { - let formatted = apply_date_format(&dt, &pattern); - return Ok(Dynamic::from(formatted)); - } - let formatted = apply_text_placeholders(&value_str, &pattern); - Ok(Dynamic::from(formatted)) + +/// Format a value (number, date, or Map from NOW()) according to a pattern +fn format_impl(value: Dynamic, pattern: String) -> Result { + // Handle Map objects (like NOW() which returns a Map with formatted property) + if value.is::() { + let map = value.cast::(); + // Try to get the 'formatted' property first + if let Some(formatted_val) = map.get("formatted") { + let value_str = formatted_val.to_string(); + if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") { + let formatted = apply_date_format(&dt, &pattern); + return Ok(formatted); } - }) - .expect("valid syntax registration"); + } + // If no formatted property or parsing failed, try to construct from components + if let (Some(year), Some(month), Some(day)) = ( + map.get("year").and_then(|v| v.as_int().ok()), + map.get("month").and_then(|v| v.as_int().ok()), + map.get("day").and_then(|v| v.as_int().ok()), + ) { + let value_str = format!("{:04}-{:02}-{:02}", year, month, day); + if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d") { + let formatted = apply_date_format(&dt, &pattern); + return Ok(formatted); + } + } + // Fallback: return empty string for unsupported map format + return Ok(String::new()); + } + + // Handle string/number values + let value_str = value.to_string(); + if let Ok(num) = f64::from_str(&value_str) { + let formatted = if pattern.starts_with('N') || pattern.starts_with('C') { + let (prefix, decimals, locale_tag) = parse_pattern(&pattern); + let locale = get_locale(&locale_tag); + let symbol = if prefix == "C" { + get_currency_symbol(&locale_tag) + } else { + "" + }; + let int_part = num.trunc() as i64; + let frac_part = num.fract(); + if decimals == 0 { + format!("{}{}", symbol, int_part.to_formatted_string(&locale)) + } else { + let frac_scaled = + ((frac_part * 10f64.powi(decimals as i32)).round()) as i64; + let decimal_sep = match locale_tag.as_str() { + "pt" | "fr" | "es" | "it" | "de" => ",", + _ => ".", + }; + format!( + "{}{}{}{:0width$}", + symbol, + int_part.to_formatted_string(&locale), + decimal_sep, + frac_scaled, + width = decimals + ) + } + } else { + match pattern.as_str() { + "n" | "F" => format!("{num:.2}"), + "0%" => format!("{:.0}%", num * 100.0), + _ => format!("{num}"), + } + }; + return Ok(formatted); + } + if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") { + let formatted = apply_date_format(&dt, &pattern); + return Ok(formatted); + } + let formatted = apply_text_placeholders(&value_str, &pattern); + Ok(formatted) +} + +pub fn format_keyword(engine: &mut Engine) { + // Register FORMAT as a regular function with two parameters + engine.register_fn("FORMAT", format_impl); + engine.register_fn("format", format_impl); } fn parse_pattern(pattern: &str) -> (String, usize, String) { let mut prefix = String::new(); diff --git a/src/basic/keywords/mod.rs b/src/basic/keywords/mod.rs index 1540a9b92..ea0f846b4 100644 --- a/src/basic/keywords/mod.rs +++ b/src/basic/keywords/mod.rs @@ -198,6 +198,7 @@ pub fn get_all_keywords() -> Vec { "SEND TEMPLATE".to_string(), "SMS".to_string(), "ADD SUGGESTION".to_string(), + "ADD_SUGGESTION_TOOL".to_string(), "CLEAR SUGGESTIONS".to_string(), "ADD TOOL".to_string(), "CLEAR TOOLS".to_string(), diff --git a/src/basic/keywords/send_mail.rs b/src/basic/keywords/send_mail.rs index eef4ca05a..02c747eb6 100644 --- a/src/basic/keywords/send_mail.rs +++ b/src/basic/keywords/send_mail.rs @@ -257,6 +257,69 @@ pub fn send_mail_keyword(state: Arc, user: UserSession, engine: &mut E }, ) .expect("valid syntax registration"); + + // Register send_mail as a regular function (for function call style: send_mail(to, subject, body, [])) + let state_fn = Arc::clone(&state); + let user_fn = user.clone(); + engine.register_fn("send_mail", move |to: Dynamic, subject: Dynamic, body: Dynamic, attachments: Dynamic| -> String { + // Convert parameters to strings + let to_str = to.to_string(); + let subject_str = subject.to_string(); + // Convert body to string + let body_str = body.to_string(); + // Convert body to string + let body_str = body.to_string(); + + // Convert attachments to Vec + let mut atts = Vec::new(); + if attachments.is_array() { + if let Ok(arr) = attachments.cast::() { + for item in arr.iter() { + atts.push(item.to_string()); + } + } + } else if !attachments.to_string().is_empty() { + atts.push(attachments.to_string()); + } + + // Execute in blocking thread + let (tx, rx) = std::sync::mpsc::channel(); + let state_for_task = Arc::clone(&state_fn); + let user_for_task = user_fn.clone(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build(); + + let result = if let Ok(rt) = rt { + rt.block_on(async move { + execute_send_mail(&state_for_task, &user_for_task, &to_str, &subject_str, &body_str, atts, None).await + }) + } else { + Err("Failed to build tokio runtime".to_string()) + }; + + let _ = tx.send(result); + }); + + match rx.recv_timeout(std::time::Duration::from_secs(30)) { + Ok(Ok(message_id)) => message_id, + Ok(Err(e)) => { + log::error!("send_mail failed: {}", e); + String::new() + }, + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { + log::error!("send_mail timed out"); + String::new() + }, + Err(e) => { + log::error!("send_mail thread failed: {}", e); + String::new() + }, + } + }); } async fn execute_send_mail( diff --git a/src/basic/keywords/table_definition.rs b/src/basic/keywords/table_definition.rs index 642b0b8ec..20c7a84fa 100644 --- a/src/basic/keywords/table_definition.rs +++ b/src/basic/keywords/table_definition.rs @@ -248,36 +248,56 @@ fn parse_field_definition( return Err("Empty field definition".into()); } - let field_name = parts[0].to_string(); - let mut field_type = String::new(); + // Handle FIELD keyword: "FIELD id AS STRING" -> ["FIELD", "id", "AS", "STRING"] + // Skip "FIELD" keyword if present + let start_idx = if parts[0].eq_ignore_ascii_case("FIELD") { 1 } else { 0 }; + let as_idx = parts.iter().position(|p| p.eq_ignore_ascii_case("AS")); + + if parts.len() < start_idx + 1 { + return Err("Invalid field definition - missing field name".into()); + } + + let field_name = parts[start_idx].to_string(); + + // Parse type from "AS type" or use default if no AS keyword + let mut field_type = if let Some(as_idx) = as_idx { + if as_idx + 1 < parts.len() { + parts[as_idx + 1].to_lowercase() + } else { + "text".to_string() + } + } else if parts.len() > start_idx + 1 { + parts[start_idx + 1].to_lowercase() + } else { + "text".to_string() + }; + + // Handle type parameters like STRING(255) let mut length: Option = None; let mut precision: Option = None; + if let Some(paren_start) = field_type.find('(') { + let base_type = field_type[..paren_start].to_lowercase(); + let params = &field_type[paren_start + 1..field_type.len() - 1]; + let param_parts: Vec<&str> = params.split(',').collect(); + + if !param_parts.is_empty() { + length = param_parts[0].trim().parse().ok(); + } + if param_parts.len() > 1 { + precision = param_parts[1].trim().parse().ok(); + } + field_type = base_type; + } + let mut is_key = false; let mut reference_table: Option = None; - if parts.len() >= 2 { - let type_part = parts[1]; - - if let Some(paren_start) = type_part.find('(') { - field_type = type_part[..paren_start].to_lowercase(); - let params = &type_part[paren_start + 1..type_part.len() - 1]; - let param_parts: Vec<&str> = params.split(',').collect(); - - if !param_parts.is_empty() { - length = param_parts[0].trim().parse().ok(); - } - if param_parts.len() > 1 { - precision = param_parts[1].trim().parse().ok(); - } - } else { - field_type = type_part.to_lowercase(); - } - } - - for i in 2..parts.len() { + // Process remaining parts for KEY and REFERENCES keywords + for i in (start_idx + 1)..parts.len() { let part = parts[i].to_lowercase(); match part.as_str() { "key" => is_key = true, + "as" => {} // Skip AS keyword - already handled above _ if parts .get(i - 1) .map(|p| p.eq_ignore_ascii_case("references")) @@ -619,14 +639,20 @@ pub fn process_table_definitions( table.name, table.connection_name ); - store_table_definition(&mut conn, bot_id, table)?; + // No need to store table definition - just create the table in the bot's database if table.connection_name == "default" { let create_sql = generate_create_table_sql(table, "postgres"); - info!("Creating table {} on default connection", table.name); + info!("Creating table {} in bot's database (bot_id={})", table.name, bot_id); trace!("SQL: {}", create_sql); - sql_query(&create_sql).execute(&mut conn)?; + // Create table in the bot's own database using BotDatabaseManager + if let Err(e) = state.bot_database_manager.create_table_in_bot_database(bot_id, &create_sql) { + error!("Failed to create table {} in bot database: {}", table.name, e); + // Continue with other tables even if this one fails + } else { + info!("Successfully created table {} in bot database", table.name); + } } else { match load_connection_config(&state, bot_id, &table.connection_name) { Ok(ext_conn) => { diff --git a/src/basic/mod.rs b/src/basic/mod.rs index 412841269..f307ed90f 100644 --- a/src/basic/mod.rs +++ b/src/basic/mod.rs @@ -165,6 +165,9 @@ impl ScriptService { #[cfg(feature = "chat")] register_bot_keywords(&state, &user, &mut engine); + // ===== PROCEDURE KEYWORDS (RETURN, etc.) ===== + keywords::procedures::register_procedure_keywords(state.clone(), user.clone(), &mut engine); + // ===== WORKFLOW ORCHESTRATION KEYWORDS ===== keywords::orchestration::register_orchestrate_workflow(state.clone(), user.clone(), &mut engine); keywords::orchestration::register_step_keyword(state.clone(), user.clone(), &mut engine); @@ -561,10 +564,268 @@ impl ScriptService { Err(parse_error) => Err(Box::new(parse_error.into())), } } + + /// Compile a tool script (.bas file with PARAM/DESCRIPTION metadata lines) + /// Filters out tool metadata before compiling + pub fn compile_tool_script(&self, script: &str) -> Result> { + // Filter out PARAM, DESCRIPTION, comment, and empty lines (tool metadata) + let executable_script: String = script + .lines() + .filter(|line| { + let trimmed = line.trim(); + // Keep lines that are NOT PARAM, DESCRIPTION, comments, or empty + !(trimmed.starts_with("PARAM ") || + trimmed.starts_with("PARAM\t") || + trimmed.starts_with("DESCRIPTION ") || + trimmed.starts_with("DESCRIPTION\t") || + trimmed.starts_with('\'') || // BASIC comment lines + trimmed.is_empty()) + }) + .collect::>() + .join("\n"); + + info!("[TOOL] Filtered tool metadata: {} -> {} chars", script.len(), executable_script.len()); + + // Apply minimal preprocessing for tools (skip variable normalization to avoid breaking multi-line strings) + let script = preprocess_switch(&executable_script); + let script = Self::convert_multiword_keywords(&script); + // Skip normalize_variables_to_lowercase for tools - it breaks multi-line strings + // Note: FORMAT is registered as a regular function, so FORMAT(expr, pattern) works directly + + info!("[TOOL] Preprocessed tool script for Rhai compilation"); + // Convert IF ... THEN / END IF to if ... { } + let script = Self::convert_if_then_syntax(&script); + // Convert BASIC keywords to lowercase (but preserve variable casing) + let script = Self::convert_keywords_to_lowercase(&script); + // Save to file for debugging + if let Err(e) = std::fs::write("/tmp/tool_preprocessed.bas", &script) { + log::warn!("Failed to write preprocessed script: {}", e); + } + match self.engine.compile(&script) { + Ok(ast) => Ok(ast), + Err(parse_error) => Err(Box::new(parse_error.into())), + } + } pub fn run(&mut self, ast: &rhai::AST) -> Result> { self.engine.eval_ast_with_scope(&mut self.scope, ast) } + /// Set a variable in the script scope (for tool parameters) + pub fn set_variable(&mut self, name: &str, value: &str) -> Result<(), Box> { + use rhai::Dynamic; + self.scope.set_or_push(name, Dynamic::from(value.to_string())); + Ok(()) + } + + /// Convert FORMAT(expr, pattern) to FORMAT expr pattern (custom syntax format) + /// Also handles RANDOM and other functions that need space-separated arguments + fn convert_format_syntax(script: &str) -> String { + use regex::Regex; + let mut result = script.to_string(); + + // First, process RANDOM to ensure commas are preserved + // RANDOM(min, max) stays as RANDOM(min, max) - no conversion needed + + // Convert FORMAT(expr, pattern) → FORMAT expr pattern + // Need to handle nested functions carefully + // Match FORMAT( ... ) but don't include inner function parentheses + // This regex matches FORMAT followed by parentheses containing two comma-separated expressions + if let Ok(re) = Regex::new(r"(?i)FORMAT\s*\(([^()]+(?:\([^()]*\)[^()]*)*),([^)]+)\)") { + result = re.replace_all(&result, "FORMAT $1$2").to_string(); + } + + result + } + + /// Convert BASIC IF ... THEN / END IF syntax to Rhai's if ... { } syntax + fn convert_if_then_syntax(script: &str) -> String { + let mut result = String::new(); + let mut if_stack: Vec = Vec::new(); // Track if we're inside an IF block + let mut in_with_block = false; // Track if we're inside a WITH block + let mut line_buffer = String::new(); + + log::info!("[TOOL] Converting IF/THEN syntax, input has {} lines", script.lines().count()); + + for line in script.lines() { + let trimmed = line.trim(); + let upper = trimmed.to_uppercase(); + + // Skip empty lines and comments + if trimmed.is_empty() || trimmed.starts_with('\'') || trimmed.starts_with("//") { + continue; + } + + // Handle IF ... THEN + if upper.starts_with("IF ") && upper.contains(" THEN") { + let then_pos = upper.find(" THEN").unwrap(); + let condition = &trimmed[3..then_pos].trim(); + log::info!("[TOOL] Converting IF statement: condition='{}'", condition); + result.push_str("if "); + result.push_str(condition); + result.push_str(" {\n"); + if_stack.push(true); + continue; + } + + // Handle ELSE + if upper == "ELSE" { + log::info!("[TOOL] Converting ELSE statement"); + result.push_str("} else {\n"); + continue; + } + + // Handle END IF + if upper == "END IF" { + log::info!("[TOOL] Converting END IF statement"); + if let Some(_) = if_stack.pop() { + result.push_str("}\n"); + } + continue; + } + + // Handle WITH ... END WITH (BASIC object creation) + if upper.starts_with("WITH ") { + let object_name = &trimmed[5..].trim(); + log::info!("[TOOL] Converting WITH statement: object='{}'", object_name); + // Convert WITH obj → let obj = #{ (start object literal) + result.push_str("let "); + result.push_str(object_name); + result.push_str(" = #{\n"); + in_with_block = true; + continue; + } + + if upper == "END WITH" { + log::info!("[TOOL] Converting END WITH statement"); + result.push_str("};\n"); + in_with_block = false; + continue; + } + + // Inside a WITH block - convert property assignments (key = value → key: value) + if in_with_block { + // Check if this is a property assignment (identifier = value) + if trimmed.contains('=') && !trimmed.contains("==") && !trimmed.contains("!=") && !trimmed.contains("+=") && !trimmed.contains("-=") { + // Convert assignment to object property syntax + let parts: Vec<&str> = trimmed.splitn(2, '=').collect(); + if parts.len() == 2 { + let property_name = parts[0].trim(); + let property_value = parts[1].trim(); + // Remove trailing semicolon if present + let property_value = property_value.trim_end_matches(';'); + result.push_str(&format!(" {}: {},\n", property_name, property_value)); + continue; + } + } + // Regular line in WITH block - add indentation + result.push_str(" "); + } + + // Handle SAVE table, object → INSERT table, object + // BASIC SAVE uses 2 parameters but Rhai SAVE needs 3 + // INSERT uses 2 parameters which matches the BASIC syntax + if upper.starts_with("SAVE") && upper.contains(',') { + log::info!("[TOOL] Processing SAVE line: '{}'", trimmed); + // Extract table and object name + let after_save = &trimmed[4..].trim(); // Skip "SAVE" + let parts: Vec<&str> = after_save.split(',').collect(); + log::info!("[TOOL] SAVE parts: {:?}", parts); + if parts.len() == 2 { + let table = parts[0].trim().trim_matches('"'); + let object_name = parts[1].trim().trim_end_matches(';'); + // Convert to INSERT table, object + let converted = format!("INSERT \"{}\", {};\n", table, object_name); + log::info!("[TOOL] Converted SAVE to INSERT: '{}'", converted); + result.push_str(&converted); + continue; + } + } + + // Handle SEND EMAIL → send_mail (function call style) + // Syntax: SEND EMAIL to, subject, body → send_mail(to, subject, body, []) + if upper.starts_with("SEND EMAIL") { + log::info!("[TOOL] Processing SEND EMAIL line: '{}'", trimmed); + let after_send = &trimmed[11..].trim(); // Skip "SEND EMAIL " (10 chars + space = 11) + let parts: Vec<&str> = after_send.split(',').collect(); + log::info!("[TOOL] SEND EMAIL parts: {:?}", parts); + if parts.len() == 3 { + let to = parts[0].trim(); + let subject = parts[1].trim(); + let body = parts[2].trim().trim_end_matches(';'); + // Convert to send_mail(to, subject, body, []) function call + let converted = format!("send_mail({}, {}, {}, []);\n", to, subject, body); + log::info!("[TOOL] Converted SEND EMAIL to: '{}'", converted); + result.push_str(&converted); + continue; + } + } + + // Regular line - add indentation if inside IF block + if !if_stack.is_empty() { + result.push_str(" "); + } + + // Check if line is a simple statement (not containing THEN or other control flow) + if !upper.starts_with("IF ") && !upper.starts_with("ELSE") && !upper.starts_with("END IF") { + // Check if this is a variable assignment (identifier = expression) + // Pattern: starts with letter/underscore, contains = but not ==, !=, <=, >=, +=, -= + let is_var_assignment = trimmed.chars().next().map_or(false, |c| c.is_alphabetic() || c == '_') + && trimmed.contains('=') + && !trimmed.contains("==") + && !trimmed.contains("!=") + && !trimmed.contains("<=") + && !trimmed.contains(">=") + && !trimmed.contains("+=") + && !trimmed.contains("-=") + && !trimmed.contains("*=") + && !trimmed.contains("/="); + + if is_var_assignment { + // Add 'let' for variable declarations + result.push_str("let "); + } + result.push_str(trimmed); + // Add semicolon if line doesn't have one and doesn't end with { or } + if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}') { + result.push(';'); + } + result.push('\n'); + } else { + result.push_str(trimmed); + result.push('\n'); + } + } + + log::info!("[TOOL] IF/THEN conversion complete, output has {} lines", result.lines().count()); + result + } + + /// Convert BASIC keywords to lowercase without touching variables + /// This is a simplified version of normalize_variables_to_lowercase for tools + fn convert_keywords_to_lowercase(script: &str) -> String { + let keywords = [ + "IF", "THEN", "ELSE", "END IF", "FOR", "NEXT", "WHILE", "WEND", + "DO", "LOOP", "RETURN", "EXIT", "SELECT", "CASE", "END SELECT", + "WITH", "END WITH", "AND", "OR", "NOT", "MOD", + "DIM", "AS", "NEW", "FUNCTION", "SUB", "CALL", + ]; + + let mut result = String::new(); + for line in script.lines() { + let mut processed_line = line.to_string(); + for keyword in &keywords { + // Use word boundaries to avoid replacing parts of variable names + let pattern = format!(r"\b{}\b", regex::escape(keyword)); + if let Ok(re) = regex::Regex::new(&pattern) { + processed_line = re.replace_all(&processed_line, keyword.to_lowercase()).to_string(); + } + } + result.push_str(&processed_line); + result.push('\n'); + } + result + } + fn normalize_variables_to_lowercase(script: &str) -> String { use regex::Regex; @@ -809,6 +1070,19 @@ impl ScriptService { continue; } + // Skip lines with custom syntax that should not be lowercased + // These are registered directly with Rhai in uppercase + let trimmed_upper = trimmed.to_uppercase(); + if trimmed_upper.contains("ADD_SUGGESTION_TOOL") || + trimmed_upper.contains("ADD_SUGGESTION_TEXT") || + trimmed_upper.starts_with("ADD_SUGGESTION_") || + trimmed_upper.starts_with("ADD_MEMBER") { + // Keep original line as-is + result.push_str(line); + result.push('\n'); + continue; + } + let mut processed_line = String::new(); let mut chars = line.chars().peekable(); let mut in_string = false; @@ -889,8 +1163,10 @@ impl ScriptService { (r#"CLEAR\s+TOOLS"#, 0, 0, vec![]), (r#"CLEAR\s+WEBSITES"#, 0, 0, vec![]), - // ADD family - (r#"ADD\s+SUGGESTION"#, 2, 2, vec!["title", "text"]), + // ADD family - ADD_SUGGESTION_TOOL must come before ADD\s+SUGGESTION + (r#"ADD_SUGGESTION_TOOL"#, 2, 2, vec!["tool", "text"]), + (r#"ADD\s+SUGGESTION\s+TEXT"#, 2, 2, vec!["value", "text"]), + (r#"ADD\s+SUGGESTION(?!\s*TEXT|\s*TOOL|_TOOL)"#, 2, 2, vec!["context", "text"]), (r#"ADD\s+MEMBER"#, 2, 2, vec!["name", "role"]), // CREATE family @@ -916,6 +1192,23 @@ impl ScriptService { let trimmed = line.trim(); let mut converted = false; + // Skip lines that already use underscore-style custom syntax + // These are registered directly with Rhai and should not be converted + let trimmed_upper = trimmed.to_uppercase(); + if trimmed_upper.contains("ADD_SUGGESTION_TOOL") || + trimmed_upper.contains("ADD_SUGGESTION_TEXT") || + trimmed_upper.starts_with("ADD_SUGGESTION_") || + trimmed_upper.starts_with("ADD_MEMBER") || + (trimmed_upper.starts_with("USE_") && trimmed.contains('(')) { + // Keep original line and add semicolon if needed + result.push_str(line); + if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}') { + result.push(';'); + } + result.push('\n'); + continue; + } + // Try each pattern for (pattern, min_params, max_params, _param_names) in &multiword_patterns { // Build regex pattern: KEYWORD params... diff --git a/src/core/bot/mod.rs b/src/core/bot/mod.rs index b33656ef2..27c4c70a3 100644 --- a/src/core/bot/mod.rs +++ b/src/core/bot/mod.rs @@ -4,6 +4,8 @@ pub mod kb_context; use kb_context::inject_kb_context; pub mod tool_context; use tool_context::get_session_tools; +pub mod tool_executor; +use tool_executor::ToolExecutor; #[cfg(feature = "llm")] use crate::core::config::ConfigManager; @@ -18,6 +20,8 @@ use crate::nvidia::get_system_metrics; use crate::shared::message_types::MessageType; use crate::shared::models::{BotResponse, UserMessage, UserSession}; use crate::shared::state::AppState; +#[cfg(feature = "chat")] +use crate::basic::keywords::add_suggestion::get_suggestions; use axum::extract::ws::{Message, WebSocket}; use axum::{ extract::{ws::WebSocketUpgrade, Extension, Query, State}, @@ -422,86 +426,101 @@ impl BotOrchestrator { #[cfg(any(feature = "research", feature = "llm"))] { - // Execute start.bas on first message to load tools - // Check if tools have been loaded for this session - let has_tools_loaded = { - use crate::shared::models::schema::session_tool_associations::dsl::*; - let conn = self.state.conn.get().ok(); - if let Some(mut db_conn) = conn { - let tool_names: Vec = session_tool_associations - .filter(session_id.eq(&session_id_str)) - .select(tool_name) - .limit(1) - .load(&mut db_conn) - .unwrap_or_default(); - !tool_names.is_empty() + // Execute start.bas on first message - ONLY run once per session to load suggestions + let actual_session_id = session.id.to_string(); + + // Check if start.bas has already been executed for this session + let start_bas_key = format!("start_bas_executed:{}", actual_session_id); + let should_execute_start_bas = if let Some(cache) = &self.state.cache { + if let Ok(mut conn) = cache.get_multiplexed_async_connection().await { + let executed: Result, redis::RedisError> = redis::cmd("GET") + .arg(&start_bas_key) + .query_async(&mut conn) + .await; + executed.is_ok() && executed.unwrap().is_none() } else { - false + true // If cache fails, try to execute } + } else { + true // If no cache, try to execute }; - // If no tools are loaded, execute start.bas - if !has_tools_loaded { + if should_execute_start_bas { + // Always execute start.bas for this session (blocking - wait for completion) let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); let data_dir = format!("{}/data", home_dir); let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name_for_context, bot_name_for_context); - info!("[START_BAS] Checking for start.bas at: {}", start_script_path); + info!("[START_BAS] Executing start.bas for session {} at: {}", actual_session_id, start_script_path); if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await { if metadata.is_file() { - info!("[START_BAS] Found start.bas, executing for session {}", session_id); + info!("[START_BAS] Found start.bas, executing for session {}", actual_session_id); if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await { let state_clone = self.state.clone(); - let session_id_clone = session_id; + let actual_session_id_for_task = session.id; let bot_id_clone = session.bot_id; - let bot_name_clone = bot_name_for_context.clone(); - let bot_name_for_log = bot_name_clone.clone(); - let start_script_clone = start_script.clone(); - tokio::spawn(async move { - let session_result = state_clone.session_manager.lock().await.get_session_by_id(session_id_clone); + // Execute start.bas synchronously (blocking) + let result = tokio::task::spawn_blocking(move || { + let session_result = { + let mut sm = state_clone.session_manager.blocking_lock(); + sm.get_session_by_id(actual_session_id_for_task) + }; - if let Ok(Some(session)) = session_result { - let result = tokio::task::spawn_blocking(move || { - let mut script_service = crate::basic::ScriptService::new( - state_clone.clone(), - session.clone() - ); - script_service.load_bot_config_params(&state_clone, bot_id_clone); + let sess = match session_result { + Ok(Some(s)) => s, + Ok(None) => { + return Err(format!("Session {} not found during start.bas execution", actual_session_id_for_task)); + } + Err(e) => return Err(format!("Failed to get session: {}", e)), + }; - match script_service.compile(&start_script_clone) { - Ok(ast) => match script_service.run(&ast) { - Ok(_) => { - info!("[START_BAS] Executed start.bas successfully for bot {}", bot_name_clone); - Ok(()) - }, - Err(e) => Err(format!("Script execution error: {}", e)), - }, - Err(e) => Err(format!("Script compilation error: {}", e)), - } - }).await; + let mut script_service = crate::basic::ScriptService::new( + state_clone.clone(), + sess + ); + script_service.load_bot_config_params(&state_clone, bot_id_clone); - match result { - Ok(Ok(())) => { - info!("[START_BAS] start.bas completed for bot {}", bot_name_for_log); - } - Ok(Err(e)) => { - error!("[START_BAS] start.bas error for bot {}: {}", bot_name_for_log, e); - } - Err(e) => { - error!("[START_BAS] start.bas task error for bot {}: {}", bot_name_for_log, e); + match script_service.compile(&start_script) { + Ok(ast) => match script_service.run(&ast) { + Ok(_) => Ok(()), + Err(e) => Err(format!("Script execution error: {}", e)), + }, + Err(e) => Err(format!("Script compilation error: {}", e)), + } + }).await; + + match result { + Ok(Ok(())) => { + info!("[START_BAS] start.bas completed successfully for session {}", actual_session_id); + + // Mark start.bas as executed for this session to prevent re-running + if let Some(cache) = &self.state.cache { + if let Ok(mut conn) = cache.get_multiplexed_async_connection().await { + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&start_bas_key) + .arg("1") + .arg("EX") + .arg("86400") // Expire after 24 hours + .query_async(&mut conn) + .await; + info!("[START_BAS] Marked start.bas as executed for session {}", actual_session_id); } } } - }); + Ok(Err(e)) => { + error!("[START_BAS] start.bas error for session {}: {}", actual_session_id, e); + } + Err(e) => { + error!("[START_BAS] start.bas task error for session {}: {}", actual_session_id, e); + } + } } - } else { - info!("[START_BAS] start.bas not found for bot {}", bot_name_for_context); } } - } + } // End of if should_execute_start_bas if let Some(kb_manager) = self.state.kb_manager.as_ref() { if let Err(e) = inject_kb_context( @@ -539,15 +558,15 @@ impl BotOrchestrator { let model_clone = model.clone(); let key_clone = key.clone(); - // Retrieve session tools for tool calling - let session_tools = get_session_tools(&self.state.conn, &bot_name_for_context, &session_id); + // Retrieve session tools for tool calling (use actual session.id after potential creation) + let session_tools = get_session_tools(&self.state.conn, &bot_name_for_context, &session.id); let tools_for_llm = match session_tools { Ok(tools) => { if !tools.is_empty() { - info!("[TOOLS] Loaded {} tools for session {}", tools.len(), session_id); + info!("[TOOLS] Loaded {} tools for session {}", tools.len(), session.id); Some(tools) } else { - info!("[TOOLS] No tools associated with session {}", session_id); + info!("[TOOLS] No tools associated with session {}", session.id); None } } @@ -615,6 +634,89 @@ impl BotOrchestrator { info!("[STREAM_DEBUG] Received chunk: '{}', len: {}", chunk, chunk.len()); trace!("Received LLM chunk: {:?}", chunk); + // ===== GENERIC TOOL EXECUTION ===== + // Check if this chunk contains a tool call (works with all LLM providers) + if let Some(tool_call) = ToolExecutor::parse_tool_call(&chunk) { + info!( + "[TOOL_CALL] Detected tool '{}' from LLM, executing...", + tool_call.tool_name + ); + + let execution_result = ToolExecutor::execute_tool_call( + &self.state, + &bot_name_for_context, + &tool_call, + &session_id, + &user_id, + ) + .await; + + if execution_result.success { + info!( + "[TOOL_EXEC] Tool '{}' executed successfully: {}", + tool_call.tool_name, execution_result.result + ); + + // Send tool execution result to user + let response = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: execution_result.result, + message_type: MessageType::BOT_RESPONSE, + stream_token: None, + is_complete: false, + suggestions: Vec::new(), + context_name: None, + context_length: 0, + context_max_length: 0, + }; + + if response_tx.send(response).await.is_err() { + warn!("Response channel closed during tool execution"); + break; + } + } else { + error!( + "[TOOL_EXEC] Tool '{}' execution failed: {:?}", + tool_call.tool_name, execution_result.error + ); + + // Send error to user + let error_msg = format!( + "Erro ao executar ferramenta '{}': {:?}", + tool_call.tool_name, + execution_result.error + ); + + let response = BotResponse { + bot_id: message.bot_id.clone(), + user_id: message.user_id.clone(), + session_id: message.session_id.clone(), + channel: message.channel.clone(), + content: error_msg, + message_type: MessageType::BOT_RESPONSE, + stream_token: None, + is_complete: false, + suggestions: Vec::new(), + context_name: None, + context_length: 0, + context_max_length: 0, + }; + + if response_tx.send(response).await.is_err() { + warn!("Response channel closed during tool error"); + break; + } + } + + // Don't add tool_call JSON to full_response or analysis_buffer + // Continue to next chunk + continue; + } + // ===== END TOOL EXECUTION ===== + analysis_buffer.push_str(&chunk); if !in_analysis && handler.has_analysis_markers(&analysis_buffer) { @@ -732,6 +834,15 @@ impl BotOrchestrator { ) .await??; + // Extract user_id and session_id before moving them into BotResponse + let user_id_str = message.user_id.clone(); + let session_id_str = message.session_id.clone(); + + #[cfg(feature = "chat")] + let suggestions = get_suggestions(self.state.cache.as_ref(), &user_id_str, &session_id_str); + #[cfg(not(feature = "chat"))] + let suggestions: Vec = Vec::new(); + let final_response = BotResponse { bot_id: message.bot_id, user_id: message.user_id, @@ -741,7 +852,7 @@ impl BotOrchestrator { message_type: MessageType::BOT_RESPONSE, stream_token: None, is_complete: true, - suggestions: Vec::new(), + suggestions, context_name: None, context_length: 0, context_max_length: 0, @@ -937,16 +1048,33 @@ async fn handle_websocket( ); if let Some(bot_name) = bot_name_result { - let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); - let data_dir = format!("{}/data", home_dir); - let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name, bot_name); + // Check if start.bas has already been executed for this session + let start_bas_key = format!("start_bas_executed:{}", session_id); + let should_execute_start_bas = if let Some(cache) = &state.cache { + if let Ok(mut conn) = cache.get_multiplexed_async_connection().await { + let executed: Result, redis::RedisError> = redis::cmd("GET") + .arg(&start_bas_key) + .query_async(&mut conn) + .await; + executed.is_ok() && executed.unwrap().is_none() + } else { + true // If cache fails, try to execute + } + } else { + true // If no cache, try to execute + }; - info!("Looking for start.bas at: {}", start_script_path); + if should_execute_start_bas { + let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); + let data_dir = format!("{}/data", home_dir); + let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name, bot_name); - if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await { - if metadata.is_file() { - info!("Found start.bas file, reading contents..."); - if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await { + info!("Looking for start.bas at: {}", start_script_path); + + if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await { + if metadata.is_file() { + info!("Found start.bas file, reading contents..."); + if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await { info!( "Executing start.bas for bot {} on session {}", bot_name, session_id @@ -964,6 +1092,9 @@ async fn handle_websocket( if let Ok(Some(session)) = session_result { info!("Executing start.bas for bot {} on session {}", bot_name, session_id); + // Clone state_for_start for use in Redis SET after execution + let state_for_redis = state_for_start.clone(); + let result = tokio::task::spawn_blocking(move || { let mut script_service = crate::basic::ScriptService::new( state_for_start.clone(), @@ -983,6 +1114,20 @@ async fn handle_websocket( match result { Ok(Ok(())) => { info!("start.bas executed successfully for bot {}", bot_name); + + // Mark start.bas as executed for this session to prevent re-running + if let Some(cache) = &state_for_redis.cache { + if let Ok(mut conn) = cache.get_multiplexed_async_connection().await { + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&start_bas_key) + .arg("1") + .arg("EX") + .arg("86400") // Expire after 24 hours + .query_async(&mut conn) + .await; + info!("Marked start.bas as executed for session {}", session_id); + } + } } Ok(Err(e)) => { error!("start.bas error for bot {}: {}", bot_name, e); @@ -996,6 +1141,7 @@ async fn handle_websocket( } } } + } // End of if should_execute_start_bas } } @@ -1017,15 +1163,52 @@ async fn handle_websocket( info!("Received WebSocket message: {}", text); if let Ok(user_msg) = serde_json::from_str::(&text) { let orchestrator = BotOrchestrator::new(state_clone.clone()); + info!("[WS_DEBUG] Looking up response channel for session: {}", session_id); if let Some(tx_clone) = state_clone .response_channels .lock() .await .get(&session_id.to_string()) { + info!("[WS_DEBUG] Response channel found, calling stream_response"); + + // Ensure session exists - create if not + let session_result = { + let mut sm = state_clone.session_manager.lock().await; + sm.get_session_by_id(session_id) + }; + + let session = match session_result { + Ok(Some(sess)) => { + info!("[WS_DEBUG] Session exists: {}", session_id); + sess + } + Ok(None) => { + info!("[WS_DEBUG] Session not found, creating via session manager"); + // Use session manager to create session (will generate new UUID) + let mut sm = state_clone.session_manager.lock().await; + match sm.create_session(user_id, bot_id, "WebSocket Chat") { + Ok(new_session) => { + info!("[WS_DEBUG] Session created: {} (note: different from WebSocket session_id)", new_session.id); + new_session + } + Err(e) => { + error!("[WS_DEBUG] Failed to create session: {}", e); + continue; + } + } + } + Err(e) => { + error!("[WS_DEBUG] Error getting session: {}", e); + continue; + } + }; + // Use bot_id from WebSocket connection instead of from message let corrected_msg = UserMessage { bot_id: bot_id.to_string(), + user_id: session.user_id.to_string(), + session_id: session.id.to_string(), ..user_msg }; if let Err(e) = orchestrator @@ -1034,7 +1217,11 @@ async fn handle_websocket( { error!("Failed to stream response: {}", e); } + } else { + warn!("[WS_DEBUG] Response channel NOT found for session: {}", session_id); } + } else { + warn!("[WS_DEBUG] Failed to parse UserMessage from: {}", text); } } Message::Close(_) => { diff --git a/src/core/bot/tool_executor.rs b/src/core/bot/tool_executor.rs new file mode 100644 index 000000000..21dc1ec87 --- /dev/null +++ b/src/core/bot/tool_executor.rs @@ -0,0 +1,384 @@ +/// Generic tool executor for LLM tool calls +/// Works across all LLM providers (GLM, OpenAI, Claude, etc.) +use log::{error, info, warn}; +use serde_json::Value; +use std::fs::OpenOptions; +use std::io::Write; +use std::path::Path; +use std::sync::Arc; +use uuid::Uuid; + +use crate::basic::ScriptService; +use crate::shared::state::AppState; +use crate::shared::models::schema::bots; +use diesel::prelude::*; + +/// Represents a parsed tool call from an LLM +#[derive(Debug, Clone)] +pub struct ParsedToolCall { + pub id: String, + pub tool_name: String, + pub arguments: Value, +} + +/// Result of tool execution +#[derive(Debug, Clone)] +pub struct ToolExecutionResult { + pub tool_call_id: String, + pub success: bool, + pub result: String, + pub error: Option, +} + +/// Generic tool executor - works with any LLM provider +pub struct ToolExecutor; + +impl ToolExecutor { + /// Log tool execution errors to a dedicated log file + fn log_tool_error(bot_name: &str, tool_name: &str, error_msg: &str) { + let log_path = Path::new("work").join(format!("{}_tool_errors.log", bot_name)); + + // Create work directory if it doesn't exist + if let Some(parent) = log_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + + let timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC"); + let log_entry = format!( + "[{}] TOOL: {} | ERROR: {}\n", + timestamp, tool_name, error_msg + ); + + // Append to log file + if let Ok(mut file) = OpenOptions::new() + .create(true) + .append(true) + .open(&log_path) + { + let _ = file.write_all(log_entry.as_bytes()); + } + + // Also log to system logger + error!("[TOOL_ERROR] Bot: {}, Tool: {}, Error: {}", bot_name, tool_name, error_msg); + } + + /// Convert internal errors to user-friendly messages for browser + fn format_user_friendly_error(tool_name: &str, error: &str) -> String { + // Don't expose internal errors to browser - log them and return generic message + if error.contains("Compilation error") { + "Desculpe, houve um erro ao processar sua solicitação. Por favor, tente novamente ou entre em contato com a administração." + .to_string() + } else if error.contains("Execution error") { + format!("O processamento da ferramenta '{}' encontrou um problema. Nossa equipe foi notificada.", tool_name) + } else if error.contains("not found") { + "Ferramenta não disponível no momento. Por favor, tente novamente mais tarde.".to_string() + } else { + "Ocorreu um erro ao processar sua solicitação. Por favor, tente novamente.".to_string() + } + } + /// Parse a tool call JSON from any LLM provider + /// Handles OpenAI, GLM, Claude formats + pub fn parse_tool_call(chunk: &str) -> Option { + // Try to parse as JSON + let json: Value = serde_json::from_str(chunk).ok()?; + + // Check if this is a tool_call type (from GLM wrapper) + if let Some(tool_type) = json.get("type").and_then(|t| t.as_str()) { + if tool_type == "tool_call" { + if let Some(content) = json.get("content") { + return Self::extract_tool_call(content); + } + } + } + + // Try direct OpenAI format + if json.get("function").is_some() { + return Self::extract_tool_call(&json); + } + + None + } + + /// Extract tool call information from various formats + fn extract_tool_call(tool_data: &Value) -> Option { + let function = tool_data.get("function")?; + let tool_name = function.get("name")?.as_str()?.to_string(); + let arguments_str = function.get("arguments")?.as_str()?; + + // Parse arguments string to JSON + let arguments: Value = serde_json::from_str(arguments_str).ok()?; + + let id = tool_data + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + Some(ParsedToolCall { + id, + tool_name, + arguments, + }) + } + + /// Execute a tool call by running the corresponding .bas file + pub async fn execute_tool_call( + state: &Arc, + bot_name: &str, + tool_call: &ParsedToolCall, + session_id: &Uuid, + _user_id: &Uuid, + ) -> ToolExecutionResult { + info!( + "[TOOL_EXEC] Executing tool '{}' for bot '{}', session '{}'", + tool_call.tool_name, bot_name, session_id + ); + + // Get bot_id + let bot_id = match Self::get_bot_id(state, bot_name) { + Some(id) => id, + None => { + let error_msg = format!("Bot '{}' not found", bot_name); + Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg); + return ToolExecutionResult { + tool_call_id: tool_call.id.clone(), + success: false, + result: String::new(), + error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)), + }; + } + }; + + // Load the .bas tool file + let bas_path = Self::get_tool_bas_path(bot_name, &tool_call.tool_name); + + if !bas_path.exists() { + let error_msg = format!("Tool file not found: {:?}", bas_path); + Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg); + return ToolExecutionResult { + tool_call_id: tool_call.id.clone(), + success: false, + result: String::new(), + error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)), + }; + } + + // Read the .bas file + let bas_script = match tokio::fs::read_to_string(&bas_path).await { + Ok(script) => script, + Err(e) => { + let error_msg = format!("Failed to read tool file: {}", e); + Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg); + return ToolExecutionResult { + tool_call_id: tool_call.id.clone(), + success: false, + result: String::new(), + error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)), + }; + } + }; + + // Get session for ScriptService + let session = match state.session_manager.lock().await.get_session_by_id(*session_id) { + Ok(Some(sess)) => sess, + Ok(None) => { + let error_msg = "Session not found".to_string(); + Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg); + return ToolExecutionResult { + tool_call_id: tool_call.id.clone(), + success: false, + result: String::new(), + error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)), + }; + } + Err(e) => { + let error_msg = format!("Failed to get session: {}", e); + Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg); + return ToolExecutionResult { + tool_call_id: tool_call.id.clone(), + success: false, + result: String::new(), + error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)), + }; + } + }; + + // Execute in blocking thread for ScriptService (which is not async) + let bot_name_clone = bot_name.to_string(); + let tool_name_clone = tool_call.tool_name.clone(); + let tool_call_id_clone = tool_call.id.clone(); + let arguments_clone = tool_call.arguments.clone(); + let state_clone = state.clone(); + let bot_id_clone = bot_id; + + let execution_result = tokio::task::spawn_blocking(move || { + Self::execute_tool_script( + &state_clone, + &bot_name_clone, + bot_id_clone, + &session, + &bas_script, + &tool_name_clone, + &arguments_clone, + ) + }) + .await; + + match execution_result { + Ok(result) => result, + Err(e) => { + let error_msg = format!("Task execution error: {}", e); + Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg); + ToolExecutionResult { + tool_call_id: tool_call_id_clone, + success: false, + result: String::new(), + error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)), + } + } + } + } + + /// Execute the tool script with parameters + fn execute_tool_script( + state: &Arc, + bot_name: &str, + bot_id: Uuid, + session: &crate::shared::models::UserSession, + bas_script: &str, + tool_name: &str, + arguments: &Value, + ) -> ToolExecutionResult { + let tool_call_id = format!("tool_{}", uuid::Uuid::new_v4()); + + // Create ScriptService + let mut script_service = ScriptService::new(state.clone(), session.clone()); + script_service.load_bot_config_params(state, bot_id); + + // Set tool parameters as variables in the engine scope + if let Some(obj) = arguments.as_object() { + for (key, value) in obj { + let value_str = match value { + Value::String(s) => s.clone(), + Value::Number(n) => n.to_string(), + Value::Bool(b) => b.to_string(), + _ => value.to_string(), + }; + + // Set variable in script scope + if let Err(e) = script_service.set_variable(key, &value_str) { + warn!("[TOOL_EXEC] Failed to set variable '{}': {}", key, e); + } + } + } + + // Compile tool script (filters PARAM/DESCRIPTION lines and converts BASIC to Rhai) + let ast = match script_service.compile_tool_script(&bas_script) { + Ok(ast) => ast, + Err(e) => { + let error_msg = format!("Compilation error: {}", e); + Self::log_tool_error(bot_name, tool_name, &error_msg); + let user_message = Self::format_user_friendly_error(tool_name, &error_msg); + return ToolExecutionResult { + tool_call_id, + success: false, + result: String::new(), + error: Some(user_message), + }; + } + }; + + // Run the script + match script_service.run(&ast) { + Ok(result) => { + info!("[TOOL_EXEC] Tool '{}' executed successfully", tool_name); + + // Convert result to string + let result_str = result.to_string(); + + ToolExecutionResult { + tool_call_id, + success: true, + result: result_str, + error: None, + } + } + Err(e) => { + let error_msg = format!("Execution error: {}", e); + Self::log_tool_error(bot_name, tool_name, &error_msg); + let user_message = Self::format_user_friendly_error(tool_name, &error_msg); + ToolExecutionResult { + tool_call_id, + success: false, + result: String::new(), + error: Some(user_message), + } + } + } + } + + /// Get the bot_id from bot_name + fn get_bot_id(state: &Arc, bot_name: &str) -> Option { + let mut conn = state.conn.get().ok()?; + bots::table + .filter(bots::name.eq(bot_name)) + .select(bots::id) + .first(&mut *conn) + .ok() + } + + /// Get the path to a tool's .bas file + fn get_tool_bas_path(bot_name: &str, tool_name: &str) -> std::path::PathBuf { + let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); + + // Try data directory first + let data_path = Path::new(&home_dir) + .join("data") + .join(format!("{}.gbai", bot_name)) + .join(format!("{}.gbdialog", bot_name)) + .join(format!("{}.bas", tool_name)); + + if data_path.exists() { + return data_path; + } + + // Try work directory (for development/testing) + let work_path = Path::new(&home_dir) + .join("gb") + .join("work") + .join(format!("{}.gbai", bot_name)) + .join(format!("{}.gbdialog", bot_name)) + .join(format!("{}.bas", tool_name)); + + work_path + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_tool_call_glm_format() { + let chunk = r#"{"type":"tool_call","content":{"id":"call_123","type":"function","function":{"name":"test_tool","arguments":"{\"param1\":\"value1\"}"}}}"#; + + let result = ToolExecutor::parse_tool_call(chunk); + assert!(result.is_some()); + + let tool_call = result.unwrap(); + assert_eq!(tool_call.tool_name, "test_tool"); + assert_eq!(tool_call.arguments["param1"], "value1"); + } + + #[test] + fn test_parse_tool_call_openai_format() { + let chunk = r#"{"id":"call_123","type":"function","function":{"name":"test_tool","arguments":"{\"param1\":\"value1\"}"}}"#; + + let result = ToolExecutor::parse_tool_call(chunk); + assert!(result.is_some()); + + let tool_call = result.unwrap(); + assert_eq!(tool_call.tool_name, "test_tool"); + assert_eq!(tool_call.arguments["param1"], "value1"); + } +} diff --git a/src/drive/local_file_monitor.rs b/src/drive/local_file_monitor.rs index 3c2448af1..8dd407604 100644 --- a/src/drive/local_file_monitor.rs +++ b/src/drive/local_file_monitor.rs @@ -1,5 +1,6 @@ use crate::basic::compiler::BasicCompiler; use crate::shared::state::AppState; +use diesel::prelude::*; use log::{debug, error, info, warn}; use std::collections::HashMap; use std::error::Error; @@ -11,6 +12,7 @@ use tokio::sync::RwLock; use tokio::time::Duration; use notify::{RecursiveMode, EventKind, RecommendedWatcher, Watcher}; use serde::{Deserialize, Serialize}; +use uuid::Uuid; #[derive(Debug, Clone, Serialize, Deserialize)] struct LocalFileState { @@ -184,8 +186,8 @@ impl LocalFileMonitor { .and_then(|s| s.to_str()) .unwrap_or("unknown"); - // Look for .gbdialog folder inside - let gbdialog_path = path.join(".gbdialog"); + // Look for .gbdialog folder inside (e.g., cristo.gbai/cristo.gbdialog) + let gbdialog_path = path.join(format!("{}.gbdialog", bot_name)); if gbdialog_path.exists() { self.compile_gbdialog(&bot_name, &gbdialog_path).await?; } @@ -264,7 +266,19 @@ impl LocalFileMonitor { let work_dir_clone = work_dir.clone(); let tool_name_clone = tool_name.to_string(); let source_content_clone = source_content.clone(); - let bot_id = uuid::Uuid::new_v4(); // Generate a bot ID or get from somewhere + let bot_name_clone = bot_name.to_string(); + + // Get the actual bot_id from the database for this bot_name + let bot_id = { + use crate::core::shared::models::schema::bots::dsl::*; + let mut conn = state_clone.conn.get() + .map_err(|e| format!("Failed to get DB connection: {}", e))?; + + bots.filter(name.eq(&bot_name_clone)) + .select(id) + .first::(&mut *conn) + .map_err(|e| format!("Failed to get bot_id for '{}': {}", bot_name_clone, e))? + }; tokio::task::spawn_blocking(move || { std::fs::create_dir_all(&work_dir_clone)?; @@ -274,8 +288,9 @@ impl LocalFileMonitor { let result = compiler.compile_file(local_source_path.to_str().unwrap(), work_dir_clone.to_str().unwrap())?; if let Some(mcp_tool) = result.mcp_tool { info!( - "[LOCAL_MONITOR] MCP tool generated with {} parameters", - mcp_tool.input_schema.properties.len() + "[LOCAL_MONITOR] MCP tool generated with {} parameters for bot {}", + mcp_tool.input_schema.properties.len(), + bot_name_clone ); } Ok::<(), Box>(()) diff --git a/src/llm/mod.rs b/src/llm/mod.rs index ce7c1414a..5a6f5c9ef 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -11,11 +11,13 @@ pub mod episodic_memory; pub mod glm; pub mod llm_models; pub mod local; +pub mod rate_limiter; pub mod smart_router; pub use claude::ClaudeClient; pub use glm::GLMClient; pub use llm_models::get_handler; +pub use rate_limiter::{ApiRateLimiter, RateLimits}; #[async_trait] pub trait LLMProvider: Send + Sync { @@ -48,6 +50,7 @@ pub struct OpenAIClient { client: reqwest::Client, base_url: String, endpoint_path: String, + rate_limiter: Arc, } impl OpenAIClient { @@ -164,10 +167,21 @@ impl OpenAIClient { "/v1/chat/completions".to_string() }; + // Detect API provider and set appropriate rate limits + let rate_limiter = if base.contains("groq.com") { + ApiRateLimiter::new(RateLimits::groq_free_tier()) + } else if base.contains("openai.com") { + ApiRateLimiter::new(RateLimits::openai_free_tier()) + } else { + // Default to unlimited for other providers (local models, etc.) + ApiRateLimiter::unlimited() + }; + Self { client: reqwest::Client::new(), base_url: base, endpoint_path: endpoint, + rate_limiter: Arc::new(rate_limiter), } } @@ -315,6 +329,13 @@ impl LLMProvider for OpenAIClient { let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit); + // Check rate limits before making the request + let estimated_tokens = OpenAIClient::estimate_messages_tokens(&messages); + if let Err(e) = self.rate_limiter.acquire(estimated_tokens).await { + error!("Rate limit exceeded: {}", e); + return Err(Box::new(e) as Box); + } + let full_url = format!("{}{}", self.base_url, self.endpoint_path); let auth_header = format!("Bearer {}", key); diff --git a/src/llm/rate_limiter.rs b/src/llm/rate_limiter.rs new file mode 100644 index 000000000..793cc4d1c --- /dev/null +++ b/src/llm/rate_limiter.rs @@ -0,0 +1,239 @@ +// Rate limiter for LLM API calls +use governor::{ + clock::DefaultClock, + state::{InMemoryState, NotKeyed}, + Quota, RateLimiter, +}; +use std::num::NonZeroU32; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::Semaphore; + +/// Rate limits for an API provider +#[derive(Debug, Clone, Copy)] +pub struct RateLimits { + pub requests_per_minute: u32, + pub tokens_per_minute: u32, + pub requests_per_day: u32, + pub tokens_per_day: u32, +} + +impl RateLimits { + /// Groq free tier rate limits + pub const fn groq_free_tier() -> Self { + Self { + requests_per_minute: 30, + tokens_per_minute: 8_000, + requests_per_day: 1_000, + tokens_per_day: 200_000, + } + } + + /// OpenAI free tier rate limits + pub const fn openai_free_tier() -> Self { + Self { + requests_per_minute: 3, + tokens_per_minute: 40_000, + requests_per_day: 200, + tokens_per_day: 150_000, + } + } + + /// No rate limiting (for local models) + pub const fn unlimited() -> Self { + Self { + requests_per_minute: u32::MAX, + tokens_per_minute: u32::MAX, + requests_per_day: u32::MAX, + tokens_per_day: u32::MAX, + } + } +} + +/// A rate limiter for API requests +pub struct ApiRateLimiter { + requests_per_minute: Arc>, + tokens_per_minute: Arc, + // Track daily request count with a simple counter and reset time + daily_request_count: Arc, + daily_request_reset: Arc, + daily_token_count: Arc, + daily_token_reset: Arc, + requests_per_day: u32, + tokens_per_day: u32, +} + +impl std::fmt::Debug for ApiRateLimiter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ApiRateLimiter") + .field("requests_per_minute", &self.requests_per_minute) + .field("tokens_per_minute", &"Semaphore") + .field("daily_request_count", &self.daily_request_count) + .field("daily_token_count", &self.daily_token_count) + .field("requests_per_day", &self.requests_per_day) + .field("tokens_per_day", &self.tokens_per_day) + .finish() + } +} + +impl Clone for ApiRateLimiter { + fn clone(&self) -> Self { + Self { + requests_per_minute: Arc::clone(&self.requests_per_minute), + tokens_per_minute: Arc::clone(&self.tokens_per_minute), + daily_request_count: Arc::clone(&self.daily_request_count), + daily_request_reset: Arc::clone(&self.daily_request_reset), + daily_token_count: Arc::clone(&self.daily_token_count), + daily_token_reset: Arc::clone(&self.daily_token_reset), + requests_per_day: self.requests_per_day, + tokens_per_day: self.tokens_per_day, + } + } +} + +impl ApiRateLimiter { + /// Create a new rate limiter with the specified limits + pub fn new(limits: RateLimits) -> Self { + // Requests per minute limiter + let rpm_quota = Quota::per_minute(NonZeroU32::new(limits.requests_per_minute).unwrap()); + let requests_per_minute = Arc::new(RateLimiter::direct(rpm_quota)); + + // Tokens per minute (using semaphore as we need to track token count) + let tokens_per_minute = Arc::new(Semaphore::new( + limits.tokens_per_minute.try_into().unwrap_or(usize::MAX) + )); + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tomorrow = now + 86400; + + Self { + requests_per_minute, + tokens_per_minute, + daily_request_count: Arc::new(std::sync::atomic::AtomicU32::new(0)), + daily_request_reset: Arc::new(std::sync::atomic::AtomicU64::new(tomorrow)), + daily_token_count: Arc::new(std::sync::atomic::AtomicU32::new(0)), + daily_token_reset: Arc::new(std::sync::atomic::AtomicU64::new(tomorrow)), + requests_per_day: limits.requests_per_day, + tokens_per_day: limits.tokens_per_day, + } + } + + /// Create an unlimited rate limiter (for local models) + pub fn unlimited() -> Self { + Self::new(RateLimits::unlimited()) + } + + /// Check if daily limits need resetting and reset if needed + fn check_and_reset_daily(&self) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let reset_time = self.daily_request_reset.load(std::sync::atomic::Ordering::Relaxed); + + if now >= reset_time { + // Reset counters + self.daily_request_count.store(0, std::sync::atomic::Ordering::Relaxed); + self.daily_token_count.store(0, std::sync::atomic::Ordering::Relaxed); + + // Set new reset time to tomorrow + let tomorrow = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + 86400; + self.daily_request_reset.store(tomorrow, std::sync::atomic::Ordering::Relaxed); + self.daily_token_reset.store(tomorrow, std::sync::atomic::Ordering::Relaxed); + } + } + + /// Acquire permission for a request with estimated token count + /// Returns when the request can proceed + pub async fn acquire(&self, estimated_tokens: usize) -> Result<(), RateLimitError> { + // Check and reset daily limits if needed + self.check_and_reset_daily(); + + // Check request rate limits + self.requests_per_minute.until_ready().await; + + // Check daily request limit + let current_requests = self.daily_request_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + if current_requests >= self.requests_per_day { + self.daily_request_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + return Err(RateLimitError::DailyRateLimitExceeded); + } + + // Check token rate limits + let tokens_to_acquire = (estimated_tokens.min(8_000) as u32) as usize; + + // Try to acquire token permits for minute limit + let tpm_available = self.tokens_per_minute.available_permits(); + if tpm_available < tokens_to_acquire { + return Err(RateLimitError::TokenRateLimitExceeded); + } + + // Check daily token limit + let current_tokens = self.daily_token_count.fetch_add(tokens_to_acquire as u32, std::sync::atomic::Ordering::Relaxed); + if current_tokens + (tokens_to_acquire as u32) > self.tokens_per_day { + self.daily_token_count.fetch_sub(tokens_to_acquire as u32, std::sync::atomic::Ordering::Relaxed); + return Err(RateLimitError::DailyTokenLimitExceeded); + } + + // Acquire the permits (this will wait if needed) + let semaphore = Arc::clone(&self.tokens_per_minute); + let _permits = semaphore.acquire_many_owned(tokens_to_acquire as u32).await; + // Permits are held until the request completes + + Ok(()) + } + + /// Release token permits after request completes + pub fn release_tokens(&self, _tokens: u32) { + // Note: We don't release the daily token count as it's already "used" + // But we do need to release the semaphore permits for the minute limit + // The permits will be automatically released when dropped + } +} + +#[derive(Debug, Clone)] +pub enum RateLimitError { + RateLimitExceeded, + DailyRateLimitExceeded, + TokenRateLimitExceeded, + DailyTokenLimitExceeded, +} + +impl std::fmt::Display for RateLimitError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RateLimitError::RateLimitExceeded => write!(f, "Rate limit exceeded"), + RateLimitError::DailyRateLimitExceeded => write!(f, "Daily request limit exceeded"), + RateLimitError::TokenRateLimitExceeded => write!(f, "Token per minute limit exceeded"), + RateLimitError::DailyTokenLimitExceeded => write!(f, "Daily token limit exceeded"), + } + } +} + +impl std::error::Error for RateLimitError {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rate_limits_display() { + let limits = RateLimits::groq_free_tier(); + assert_eq!(limits.requests_per_minute, 30); + assert_eq!(limits.tokens_per_minute, 8_000); + assert_eq!(limits.requests_per_day, 1_000); + assert_eq!(limits.tokens_per_day, 200_000); + } + + #[test] + fn test_unlimited_limits() { + let limits = RateLimits::unlimited(); + assert_eq!(limits.requests_per_minute, u32::MAX); + } +}