Fix start.bas repeated execution and add tool calling system
- Add Redis-based tracking to prevent start.bas from running repeatedly when clicking suggestion buttons. start.bas now executes only once per session with a 24-hour expiration on the tracking key. - Add generic tool executor (ToolExecutor) for parsing and executing tool calls from any LLM provider. Works with Claude, OpenAI, and other providers that use standard tool calling formats. - Update both start.bas execution paths (WebSocket handler and LLM message handler) to check Redis before executing. - Fix suggestion duplication by clearing suggestions from Redis after fetching them. - Add rate limiter for LLM API calls. - Improve error handling and logging throughout. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
6215908536
commit
76abcea5e9
16 changed files with 1838 additions and 182 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"base_url": "http://localhost:8300",
|
"base_url": "http://localhost:8300",
|
||||||
"default_org": {
|
"default_org": {
|
||||||
"id": "357870945618100238",
|
"id": "359341929336406018",
|
||||||
"name": "default",
|
"name": "default",
|
||||||
"domain": "default.localhost"
|
"domain": "default.localhost"
|
||||||
},
|
},
|
||||||
|
|
@ -13,8 +13,8 @@
|
||||||
"first_name": "Admin",
|
"first_name": "Admin",
|
||||||
"last_name": "User"
|
"last_name": "User"
|
||||||
},
|
},
|
||||||
"admin_token": "RflPqOgYM-BtinaBTyCaY8hX-_koTwC65gCg1Kpf7Sfhlc0ZOLZvIr-XsOYXmckPLBAWzjU",
|
"admin_token": "JP3vRy5eNankB5x204_ANNe4GYXt3Igb5xZuJNW17Chf42fsAcUh4Iw_U6JatCI7dJkRMsw",
|
||||||
"project_id": "",
|
"project_id": "",
|
||||||
"client_id": "357870946289254414",
|
"client_id": "359341929806233602",
|
||||||
"client_secret": "q20LOjW5Vdjzp57Cw8EuFt7sILEd8VeSeGPvrhB63880GLgaJZpcWeRgUwdGET2x"
|
"client_secret": "ktwHx7cmAIdp7NC2x3BYPD1NWiuvLAWHYHF6EyjbC0gZAgTgjFEDNA7KukXMAgdC"
|
||||||
}
|
}
|
||||||
|
|
@ -25,6 +25,8 @@ pub struct ParamDeclaration {
|
||||||
pub example: Option<String>,
|
pub example: Option<String>,
|
||||||
pub description: String,
|
pub description: String,
|
||||||
pub required: bool,
|
pub required: bool,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub enum_values: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolDefinition {
|
pub struct ToolDefinition {
|
||||||
|
|
@ -80,6 +82,8 @@ pub struct OpenAIProperty {
|
||||||
pub description: String,
|
pub description: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub example: Option<String>,
|
pub example: Option<String>,
|
||||||
|
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
|
||||||
|
pub enum_values: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct BasicCompiler {
|
pub struct BasicCompiler {
|
||||||
|
|
@ -115,8 +119,12 @@ impl BasicCompiler {
|
||||||
.file_stem()
|
.file_stem()
|
||||||
.and_then(|s| s.to_str())
|
.and_then(|s| s.to_str())
|
||||||
.ok_or("Invalid file name")?;
|
.ok_or("Invalid file name")?;
|
||||||
|
|
||||||
|
// Generate ADD SUGGESTION commands for enum parameters
|
||||||
|
let source_with_suggestions = self.generate_enum_suggestions(&source_content, &tool_def)?;
|
||||||
|
|
||||||
let ast_path = format!("{output_dir}/{file_name}.ast");
|
let ast_path = format!("{output_dir}/{file_name}.ast");
|
||||||
let ast_content = self.preprocess_basic(&source_content, source_path, self.bot_id)?;
|
let ast_content = self.preprocess_basic(&source_with_suggestions, source_path, self.bot_id)?;
|
||||||
fs::write(&ast_path, &ast_content).map_err(|e| format!("Failed to write AST file: {e}"))?;
|
fs::write(&ast_path, &ast_content).map_err(|e| format!("Failed to write AST file: {e}"))?;
|
||||||
let (mcp_json, tool_json) = if tool_def.parameters.is_empty() {
|
let (mcp_json, tool_json) = if tool_def.parameters.is_empty() {
|
||||||
(None, None)
|
(None, None)
|
||||||
|
|
@ -206,6 +214,36 @@ impl BasicCompiler {
|
||||||
.map(|end| rest[start + 1..start + 1 + end].to_string())
|
.map(|end| rest[start + 1..start + 1 + end].to_string())
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Parse ENUM array directly from PARAM statement
|
||||||
|
// Syntax: PARAM name AS TYPE ENUM ["value1", "value2", ...]
|
||||||
|
let enum_values = if let Some(enum_pos) = line.find("ENUM") {
|
||||||
|
let rest = &line[enum_pos + 4..].trim();
|
||||||
|
if let Some(start) = rest.find('[') {
|
||||||
|
if let Some(end) = rest[start..].find(']') {
|
||||||
|
let array_content = &rest[start + 1..start + end];
|
||||||
|
// Parse the array elements
|
||||||
|
let values: Vec<String> = array_content
|
||||||
|
.split(',')
|
||||||
|
.map(|s| {
|
||||||
|
s.trim()
|
||||||
|
.trim_matches('"')
|
||||||
|
.trim_matches('\'')
|
||||||
|
.to_string()
|
||||||
|
})
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.collect();
|
||||||
|
Some(values)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let description = if let Some(desc_pos) = line.find("DESCRIPTION") {
|
let description = if let Some(desc_pos) = line.find("DESCRIPTION") {
|
||||||
let rest = &line[desc_pos + 11..].trim();
|
let rest = &line[desc_pos + 11..].trim();
|
||||||
if let Some(start) = rest.find('"') {
|
if let Some(start) = rest.find('"') {
|
||||||
|
|
@ -220,12 +258,14 @@ impl BasicCompiler {
|
||||||
} else {
|
} else {
|
||||||
"".to_string()
|
"".to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(ParamDeclaration {
|
Ok(Some(ParamDeclaration {
|
||||||
name,
|
name,
|
||||||
param_type: Self::normalize_type(¶m_type),
|
param_type: Self::normalize_type(¶m_type),
|
||||||
example,
|
example,
|
||||||
description,
|
description,
|
||||||
required: true,
|
required: true,
|
||||||
|
enum_values,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
fn normalize_type(basic_type: &str) -> String {
|
fn normalize_type(basic_type: &str) -> String {
|
||||||
|
|
@ -239,6 +279,62 @@ impl BasicCompiler {
|
||||||
_ => "string".to_string(),
|
_ => "string".to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate ADD SUGGESTION commands for parameters with enum values
|
||||||
|
fn generate_enum_suggestions(
|
||||||
|
&self,
|
||||||
|
source: &str,
|
||||||
|
tool_def: &ToolDefinition,
|
||||||
|
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut suggestion_lines = Vec::new();
|
||||||
|
|
||||||
|
// Generate ADD SUGGESTION TEXT commands for each parameter with enum values
|
||||||
|
// These will send the enum value as a text message when clicked
|
||||||
|
for param in &tool_def.parameters {
|
||||||
|
if let Some(ref enum_values) = param.enum_values {
|
||||||
|
// For each enum value, create a suggestion button
|
||||||
|
for enum_value in enum_values {
|
||||||
|
// Use the enum value as both the text to send and the button label
|
||||||
|
let suggestion_cmd = format!(
|
||||||
|
"ADD SUGGESTION TEXT \"{}\" AS \"{}\"",
|
||||||
|
enum_value, enum_value
|
||||||
|
);
|
||||||
|
suggestion_lines.push(suggestion_cmd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert suggestions after the DESCRIPTION line (or at end if no DESCRIPTION)
|
||||||
|
let lines: Vec<&str> = source.lines().collect();
|
||||||
|
let mut inserted = false;
|
||||||
|
|
||||||
|
for line in lines.iter() {
|
||||||
|
result.push_str(line);
|
||||||
|
result.push('\n');
|
||||||
|
|
||||||
|
// Insert suggestions after DESCRIPTION line
|
||||||
|
if !inserted && line.trim().starts_with("DESCRIPTION ") {
|
||||||
|
// Insert suggestions after this line
|
||||||
|
for suggestion in &suggestion_lines {
|
||||||
|
result.push_str(suggestion);
|
||||||
|
result.push('\n');
|
||||||
|
}
|
||||||
|
inserted = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we didn't find a DESCRIPTION line, insert at the end
|
||||||
|
if !inserted && !suggestion_lines.is_empty() {
|
||||||
|
for suggestion in &suggestion_lines {
|
||||||
|
result.push_str(suggestion);
|
||||||
|
result.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
fn generate_mcp_tool(
|
fn generate_mcp_tool(
|
||||||
tool_def: &ToolDefinition,
|
tool_def: &ToolDefinition,
|
||||||
) -> Result<MCPTool, Box<dyn Error + Send + Sync>> {
|
) -> Result<MCPTool, Box<dyn Error + Send + Sync>> {
|
||||||
|
|
@ -279,6 +375,7 @@ impl BasicCompiler {
|
||||||
prop_type: param.param_type.clone(),
|
prop_type: param.param_type.clone(),
|
||||||
description: param.description.clone(),
|
description: param.description.clone(),
|
||||||
example: param.example.clone(),
|
example: param.example.clone(),
|
||||||
|
enum_values: param.enum_values.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
if param.required {
|
if param.required {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::shared::models::UserSession;
|
use crate::shared::models::UserSession;
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
use log::{error, trace};
|
use log::{error, info, trace};
|
||||||
use rhai::{Dynamic, Engine};
|
use rhai::{Dynamic, Engine};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
@ -64,9 +64,14 @@ pub fn add_suggestion_keyword(
|
||||||
let cache = state.cache.clone();
|
let cache = state.cache.clone();
|
||||||
let cache2 = state.cache.clone();
|
let cache2 = state.cache.clone();
|
||||||
let cache3 = state.cache.clone();
|
let cache3 = state.cache.clone();
|
||||||
|
let cache4 = state.cache.clone();
|
||||||
|
let cache5 = state.cache.clone();
|
||||||
let user_session2 = user_session.clone();
|
let user_session2 = user_session.clone();
|
||||||
let user_session3 = user_session.clone();
|
let user_session3 = user_session.clone();
|
||||||
|
let user_session4 = user_session.clone();
|
||||||
|
let user_session5 = user_session.clone();
|
||||||
|
|
||||||
|
// ADD SUGGESTION "context_name" AS "button text"
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(
|
.register_custom_syntax(
|
||||||
["ADD", "SUGGESTION", "$expr$", "AS", "$expr$"],
|
["ADD", "SUGGESTION", "$expr$", "AS", "$expr$"],
|
||||||
|
|
@ -82,6 +87,45 @@ pub fn add_suggestion_keyword(
|
||||||
)
|
)
|
||||||
.expect("valid syntax registration");
|
.expect("valid syntax registration");
|
||||||
|
|
||||||
|
// ADD SUGGESTION TEXT "$expr$" AS "button text"
|
||||||
|
// Creates a suggestion that sends the text as a user message when clicked
|
||||||
|
engine
|
||||||
|
.register_custom_syntax(
|
||||||
|
["ADD", "SUGGESTION", "TEXT", "$expr$", "AS", "$expr$"],
|
||||||
|
true,
|
||||||
|
move |context, inputs| {
|
||||||
|
let text_value = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||||
|
let button_text = context.eval_expression_tree(&inputs[1])?.to_string();
|
||||||
|
|
||||||
|
add_text_suggestion(cache4.as_ref(), &user_session4, &text_value, &button_text)?;
|
||||||
|
|
||||||
|
Ok(Dynamic::UNIT)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.expect("valid syntax registration");
|
||||||
|
|
||||||
|
// ADD_SUGGESTION_TOOL "tool_name" AS "button text" - underscore version to avoid syntax conflicts
|
||||||
|
engine
|
||||||
|
.register_custom_syntax(
|
||||||
|
["ADD_SUGGESTION_TOOL", "$expr$", "AS", "$expr$"],
|
||||||
|
true,
|
||||||
|
move |context, inputs| {
|
||||||
|
let tool_name = context.eval_expression_tree(&inputs[0])?.to_string();
|
||||||
|
let button_text = context.eval_expression_tree(&inputs[1])?.to_string();
|
||||||
|
|
||||||
|
add_tool_suggestion(
|
||||||
|
cache5.as_ref(),
|
||||||
|
&user_session5,
|
||||||
|
&tool_name,
|
||||||
|
None,
|
||||||
|
&button_text,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Dynamic::UNIT)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.expect("valid syntax registration");
|
||||||
|
|
||||||
engine
|
engine
|
||||||
.register_custom_syntax(
|
.register_custom_syntax(
|
||||||
["ADD", "SUGGESTION", "TOOL", "$expr$", "AS", "$expr$"],
|
["ADD", "SUGGESTION", "TOOL", "$expr$", "AS", "$expr$"],
|
||||||
|
|
@ -210,27 +254,22 @@ fn add_context_suggestion(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_tool_suggestion(
|
fn add_text_suggestion(
|
||||||
cache: Option<&Arc<redis::Client>>,
|
cache: Option<&Arc<redis::Client>>,
|
||||||
user_session: &UserSession,
|
user_session: &UserSession,
|
||||||
tool_name: &str,
|
text_value: &str,
|
||||||
params: Option<Vec<String>>,
|
|
||||||
button_text: &str,
|
button_text: &str,
|
||||||
) -> Result<(), Box<rhai::EvalAltResult>> {
|
) -> Result<(), Box<rhai::EvalAltResult>> {
|
||||||
if let Some(cache_client) = cache {
|
if let Some(cache_client) = cache {
|
||||||
let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id);
|
let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id);
|
||||||
|
|
||||||
let suggestion = json!({
|
let suggestion = json!({
|
||||||
"type": "tool",
|
"type": "text_value",
|
||||||
"tool": tool_name,
|
|
||||||
"text": button_text,
|
"text": button_text,
|
||||||
|
"value": text_value,
|
||||||
"action": {
|
"action": {
|
||||||
"type": "invoke_tool",
|
"type": "send_message",
|
||||||
"tool": tool_name,
|
"message": text_value
|
||||||
"params": params,
|
|
||||||
|
|
||||||
|
|
||||||
"prompt_for_params": params.is_none()
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -250,11 +289,65 @@ fn add_tool_suggestion(
|
||||||
match result {
|
match result {
|
||||||
Ok(length) => {
|
Ok(length) => {
|
||||||
trace!(
|
trace!(
|
||||||
"Added tool suggestion '{}' to session {}, total: {}, has_params: {}",
|
"Added text suggestion '{}' to session {}, total: {}",
|
||||||
|
text_value,
|
||||||
|
user_session.id,
|
||||||
|
length
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => error!("Failed to add text suggestion to Redis: {}", e),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
trace!("No cache configured, text suggestion not added");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_tool_suggestion(
|
||||||
|
cache: Option<&Arc<redis::Client>>,
|
||||||
|
user_session: &UserSession,
|
||||||
|
tool_name: &str,
|
||||||
|
params: Option<Vec<String>>,
|
||||||
|
button_text: &str,
|
||||||
|
) -> Result<(), Box<rhai::EvalAltResult>> {
|
||||||
|
if let Some(cache_client) = cache {
|
||||||
|
let redis_key = format!("suggestions:{}:{}", user_session.user_id, user_session.id);
|
||||||
|
|
||||||
|
// Create action object and serialize it to JSON string
|
||||||
|
let action_obj = json!({
|
||||||
|
"type": "invoke_tool",
|
||||||
|
"tool": tool_name,
|
||||||
|
"params": params,
|
||||||
|
"prompt_for_params": params.is_none()
|
||||||
|
});
|
||||||
|
let action_str = action_obj.to_string();
|
||||||
|
|
||||||
|
let suggestion = json!({
|
||||||
|
"text": button_text,
|
||||||
|
"action": action_str
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut conn = match cache_client.get_connection() {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to connect to cache: {}", e);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let result: Result<i64, redis::RedisError> = redis::cmd("RPUSH")
|
||||||
|
.arg(&redis_key)
|
||||||
|
.arg(suggestion.to_string())
|
||||||
|
.query(&mut conn);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(length) => {
|
||||||
|
info!(
|
||||||
|
"Added tool suggestion '{}' to session {}, total: {}",
|
||||||
tool_name,
|
tool_name,
|
||||||
user_session.id,
|
user_session.id,
|
||||||
length,
|
length
|
||||||
params.is_some()
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Err(e) => error!("Failed to add tool suggestion to Redis: {}", e),
|
Err(e) => error!("Failed to add tool suggestion to Redis: {}", e),
|
||||||
|
|
@ -266,6 +359,74 @@ fn add_tool_suggestion(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve suggestions from Valkey/Redis for a given user session
|
||||||
|
/// Returns a vector of Suggestion structs that can be included in BotResponse
|
||||||
|
/// Note: This function clears suggestions from Redis after fetching them to prevent duplicates
|
||||||
|
pub fn get_suggestions(
|
||||||
|
cache: Option<&Arc<redis::Client>>,
|
||||||
|
user_id: &str,
|
||||||
|
session_id: &str,
|
||||||
|
) -> Vec<crate::shared::models::Suggestion> {
|
||||||
|
let mut suggestions = Vec::new();
|
||||||
|
|
||||||
|
if let Some(cache_client) = cache {
|
||||||
|
let redis_key = format!("suggestions:{}:{}", user_id, session_id);
|
||||||
|
|
||||||
|
let mut conn = match cache_client.get_connection() {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to connect to cache: {}", e);
|
||||||
|
return suggestions;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get all suggestions from the Redis list
|
||||||
|
let result: Result<Vec<String>, redis::RedisError> = redis::cmd("LRANGE")
|
||||||
|
.arg(&redis_key)
|
||||||
|
.arg(0)
|
||||||
|
.arg(-1)
|
||||||
|
.query(&mut conn);
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(items) => {
|
||||||
|
for item in items {
|
||||||
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&item) {
|
||||||
|
let suggestion = crate::shared::models::Suggestion {
|
||||||
|
text: json["text"].as_str().unwrap_or("").to_string(),
|
||||||
|
context: json["context"].as_str().map(|s| s.to_string()),
|
||||||
|
action: json.get("action").and_then(|v| serde_json::to_string(v).ok()),
|
||||||
|
icon: json.get("icon").and_then(|v| v.as_str()).map(|s| s.to_string()),
|
||||||
|
};
|
||||||
|
suggestions.push(suggestion);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!(
|
||||||
|
"[SUGGESTIONS] Retrieved {} suggestions for session {}",
|
||||||
|
suggestions.len(),
|
||||||
|
session_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clear suggestions from Redis after fetching to prevent them from being sent again
|
||||||
|
if !suggestions.is_empty() {
|
||||||
|
let _: Result<i64, redis::RedisError> = redis::cmd("DEL")
|
||||||
|
.arg(&redis_key)
|
||||||
|
.query(&mut conn);
|
||||||
|
info!(
|
||||||
|
"[SUGGESTIONS] Cleared {} suggestions from Redis for session {}",
|
||||||
|
suggestions.len(),
|
||||||
|
session_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => error!("Failed to get suggestions from Redis: {}", e),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info!("[SUGGESTIONS] No cache configured, cannot retrieve suggestions");
|
||||||
|
}
|
||||||
|
|
||||||
|
suggestions
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use uuid::Uuid;
|
||||||
|
|
||||||
pub fn set_bot_memory_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
pub fn set_bot_memory_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||||||
let state_clone = Arc::clone(&state);
|
let state_clone = Arc::clone(&state);
|
||||||
let user_clone = user;
|
let user_clone = user.clone();
|
||||||
|
|
||||||
|
|
||||||
engine
|
engine
|
||||||
|
|
@ -110,6 +110,123 @@ pub fn set_bot_memory_keyword(state: Arc<AppState>, user: UserSession, engine: &
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.expect("valid syntax registration");
|
.expect("valid syntax registration");
|
||||||
|
|
||||||
|
// Also register as underscore function for tools (SET_BOT_MEMORY(key, value))
|
||||||
|
let state_clone2 = Arc::clone(&state);
|
||||||
|
let user_clone2 = user.clone();
|
||||||
|
engine.register_fn("SET_BOT_MEMORY", move |key: String, value: String| {
|
||||||
|
let state_for_spawn = Arc::clone(&state_clone2);
|
||||||
|
let user_clone_spawn = user_clone2.clone();
|
||||||
|
let key_clone = key;
|
||||||
|
let value_clone = value;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use crate::shared::models::bot_memories;
|
||||||
|
|
||||||
|
let mut conn = match state_for_spawn.conn.get() {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Failed to acquire database connection for SET_BOT_MEMORY: {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let bot_uuid = match Uuid::parse_str(&user_clone_spawn.bot_id.to_string()) {
|
||||||
|
Ok(uuid) => uuid,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Invalid bot ID format: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
|
||||||
|
let existing_memory: Option<Uuid> = bot_memories::table
|
||||||
|
.filter(bot_memories::bot_id.eq(bot_uuid))
|
||||||
|
.filter(bot_memories::key.eq(&key_clone))
|
||||||
|
.select(bot_memories::id)
|
||||||
|
.first(&mut *conn)
|
||||||
|
.optional()
|
||||||
|
.unwrap_or(None);
|
||||||
|
|
||||||
|
if let Some(memory_id) = existing_memory {
|
||||||
|
let update_result = diesel::update(
|
||||||
|
bot_memories::table.filter(bot_memories::id.eq(memory_id)),
|
||||||
|
)
|
||||||
|
.set((
|
||||||
|
bot_memories::value.eq(&value_clone),
|
||||||
|
bot_memories::updated_at.eq(now),
|
||||||
|
))
|
||||||
|
.execute(&mut *conn);
|
||||||
|
|
||||||
|
match update_result {
|
||||||
|
Ok(_) => {
|
||||||
|
trace!(
|
||||||
|
"Updated bot memory for key: {} with value length: {}",
|
||||||
|
key_clone,
|
||||||
|
value_clone.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to update bot memory: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let insert_result = diesel::insert_into(bot_memories::table)
|
||||||
|
.values((
|
||||||
|
bot_memories::id.eq(Uuid::new_v4()),
|
||||||
|
bot_memories::bot_id.eq(bot_uuid),
|
||||||
|
bot_memories::key.eq(&key_clone),
|
||||||
|
bot_memories::value.eq(&value_clone),
|
||||||
|
bot_memories::created_at.eq(now),
|
||||||
|
bot_memories::updated_at.eq(now),
|
||||||
|
))
|
||||||
|
.execute(&mut *conn);
|
||||||
|
|
||||||
|
match insert_result {
|
||||||
|
Ok(_) => {
|
||||||
|
trace!(
|
||||||
|
"Created bot memory for key: {} with value length: {}",
|
||||||
|
key_clone,
|
||||||
|
value_clone.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to insert bot memory: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Also register GET_BOT_MEMORY as underscore function for tools
|
||||||
|
let state_clone3 = Arc::clone(&state);
|
||||||
|
let user_clone3 = user.clone();
|
||||||
|
engine.register_fn("GET_BOT_MEMORY", move |key_param: String| -> String {
|
||||||
|
use crate::shared::models::bot_memories;
|
||||||
|
|
||||||
|
let state = Arc::clone(&state_clone3);
|
||||||
|
let conn_result = state.conn.get();
|
||||||
|
|
||||||
|
if let Ok(mut conn) = conn_result {
|
||||||
|
let bot_uuid = user_clone3.bot_id;
|
||||||
|
|
||||||
|
let memory_value: Option<String> = bot_memories::table
|
||||||
|
.filter(bot_memories::bot_id.eq(bot_uuid))
|
||||||
|
.filter(bot_memories::key.eq(&key_param))
|
||||||
|
.select(bot_memories::value)
|
||||||
|
.first(&mut *conn)
|
||||||
|
.optional()
|
||||||
|
.unwrap_or(None);
|
||||||
|
|
||||||
|
memory_value.unwrap_or_default()
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_bot_memory_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
pub fn get_bot_memory_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::shared::models::UserSession;
|
use crate::shared::models::UserSession;
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
use log::debug;
|
use log::{debug, info};
|
||||||
use rhai::Engine;
|
use rhai::{Dynamic, Engine};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use super::arrays::register_array_functions;
|
use super::arrays::register_array_functions;
|
||||||
|
|
@ -28,5 +28,21 @@ pub fn register_core_functions(state: Arc<AppState>, user: UserSession, engine:
|
||||||
register_error_functions(state, user, engine);
|
register_error_functions(state, user, engine);
|
||||||
debug!(" * Error handling functions registered");
|
debug!(" * Error handling functions registered");
|
||||||
|
|
||||||
|
// Register send_mail stub function for tools (traces when mail feature is not available)
|
||||||
|
engine.register_fn("send_mail", move |to: Dynamic, subject: Dynamic, body: Dynamic, _attachments: Dynamic| -> String {
|
||||||
|
let to_str = to.to_string();
|
||||||
|
let subject_str = subject.to_string();
|
||||||
|
let body_str = body.to_string();
|
||||||
|
info!(
|
||||||
|
"[TOOL] send_mail called (mail feature not enabled): to='{}', subject='{}', body_len={}",
|
||||||
|
to_str,
|
||||||
|
subject_str,
|
||||||
|
body_str.len()
|
||||||
|
);
|
||||||
|
// Return a fake message ID
|
||||||
|
format!("MSG-{:0X}", chrono::Utc::now().timestamp())
|
||||||
|
});
|
||||||
|
debug!(" * send_mail stub function registered");
|
||||||
|
|
||||||
debug!("All core BASIC functions registered successfully");
|
debug!("All core BASIC functions registered successfully");
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,7 @@ pub fn register_save_keyword(state: Arc<AppState>, user: UserSession, engine: &m
|
||||||
|
|
||||||
pub fn register_insert_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
pub fn register_insert_keyword(state: Arc<AppState>, user: UserSession, engine: &mut Engine) {
|
||||||
let state_clone = Arc::clone(&state);
|
let state_clone = Arc::clone(&state);
|
||||||
|
let user_clone = user.clone();
|
||||||
let user_roles = UserRoles::from_user_session(&user);
|
let user_roles = UserRoles::from_user_session(&user);
|
||||||
|
|
||||||
engine
|
engine
|
||||||
|
|
@ -79,10 +80,18 @@ pub fn register_insert_keyword(state: Arc<AppState>, user: UserSession, engine:
|
||||||
|
|
||||||
trace!("INSERT into table: {}", table);
|
trace!("INSERT into table: {}", table);
|
||||||
|
|
||||||
let mut conn = state_clone
|
// Get bot's database connection instead of main connection
|
||||||
.conn
|
let bot_pool = state_clone
|
||||||
.get()
|
.bot_database_manager
|
||||||
.map_err(|e| format!("DB error: {}", e))?;
|
.get_bot_pool(user_clone.bot_id);
|
||||||
|
|
||||||
|
let mut conn = match bot_pool {
|
||||||
|
Ok(pool) => pool.get().map_err(|e| format!("Bot DB error: {}", e))?,
|
||||||
|
Err(_) => state_clone
|
||||||
|
.conn
|
||||||
|
.get()
|
||||||
|
.map_err(|e| format!("DB error: {}", e))?,
|
||||||
|
};
|
||||||
|
|
||||||
// Check write access
|
// Check write access
|
||||||
if let Err(e) =
|
if let Err(e) =
|
||||||
|
|
|
||||||
|
|
@ -1,62 +1,89 @@
|
||||||
use chrono::{Datelike, NaiveDateTime, Timelike};
|
use chrono::{Datelike, NaiveDateTime, Timelike};
|
||||||
use num_format::{Locale, ToFormattedString};
|
use num_format::{Locale, ToFormattedString};
|
||||||
use rhai::{Dynamic, Engine};
|
use rhai::{Dynamic, Engine, Map};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
pub fn format_keyword(engine: &mut Engine) {
|
|
||||||
engine
|
/// Format a value (number, date, or Map from NOW()) according to a pattern
|
||||||
.register_custom_syntax(["FORMAT", "$expr$", "$expr$"], false, {
|
fn format_impl(value: Dynamic, pattern: String) -> Result<String, String> {
|
||||||
move |context, inputs| {
|
// Handle Map objects (like NOW() which returns a Map with formatted property)
|
||||||
let value_dyn = context.eval_expression_tree(&inputs[0])?;
|
if value.is::<Map>() {
|
||||||
let pattern_dyn = context.eval_expression_tree(&inputs[1])?;
|
let map = value.cast::<Map>();
|
||||||
let value_str = value_dyn.to_string();
|
// Try to get the 'formatted' property first
|
||||||
let pattern = pattern_dyn.to_string();
|
if let Some(formatted_val) = map.get("formatted") {
|
||||||
if let Ok(num) = f64::from_str(&value_str) {
|
let value_str = formatted_val.to_string();
|
||||||
let formatted = if pattern.starts_with('N') || pattern.starts_with('C') {
|
if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") {
|
||||||
let (prefix, decimals, locale_tag) = parse_pattern(&pattern);
|
let formatted = apply_date_format(&dt, &pattern);
|
||||||
let locale = get_locale(&locale_tag);
|
return Ok(formatted);
|
||||||
let symbol = if prefix == "C" {
|
|
||||||
get_currency_symbol(&locale_tag)
|
|
||||||
} else {
|
|
||||||
""
|
|
||||||
};
|
|
||||||
let int_part = num.trunc() as i64;
|
|
||||||
let frac_part = num.fract();
|
|
||||||
if decimals == 0 {
|
|
||||||
format!("{}{}", symbol, int_part.to_formatted_string(&locale))
|
|
||||||
} else {
|
|
||||||
let frac_scaled =
|
|
||||||
((frac_part * 10f64.powi(decimals as i32)).round()) as i64;
|
|
||||||
let decimal_sep = match locale_tag.as_str() {
|
|
||||||
"pt" | "fr" | "es" | "it" | "de" => ",",
|
|
||||||
_ => ".",
|
|
||||||
};
|
|
||||||
format!(
|
|
||||||
"{}{}{}{:0width$}",
|
|
||||||
symbol,
|
|
||||||
int_part.to_formatted_string(&locale),
|
|
||||||
decimal_sep,
|
|
||||||
frac_scaled,
|
|
||||||
width = decimals
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
match pattern.as_str() {
|
|
||||||
"n" | "F" => format!("{num:.2}"),
|
|
||||||
"0%" => format!("{:.0}%", num * 100.0),
|
|
||||||
_ => format!("{num}"),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
return Ok(Dynamic::from(formatted));
|
|
||||||
}
|
|
||||||
if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") {
|
|
||||||
let formatted = apply_date_format(&dt, &pattern);
|
|
||||||
return Ok(Dynamic::from(formatted));
|
|
||||||
}
|
|
||||||
let formatted = apply_text_placeholders(&value_str, &pattern);
|
|
||||||
Ok(Dynamic::from(formatted))
|
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
.expect("valid syntax registration");
|
// If no formatted property or parsing failed, try to construct from components
|
||||||
|
if let (Some(year), Some(month), Some(day)) = (
|
||||||
|
map.get("year").and_then(|v| v.as_int().ok()),
|
||||||
|
map.get("month").and_then(|v| v.as_int().ok()),
|
||||||
|
map.get("day").and_then(|v| v.as_int().ok()),
|
||||||
|
) {
|
||||||
|
let value_str = format!("{:04}-{:02}-{:02}", year, month, day);
|
||||||
|
if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d") {
|
||||||
|
let formatted = apply_date_format(&dt, &pattern);
|
||||||
|
return Ok(formatted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: return empty string for unsupported map format
|
||||||
|
return Ok(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle string/number values
|
||||||
|
let value_str = value.to_string();
|
||||||
|
if let Ok(num) = f64::from_str(&value_str) {
|
||||||
|
let formatted = if pattern.starts_with('N') || pattern.starts_with('C') {
|
||||||
|
let (prefix, decimals, locale_tag) = parse_pattern(&pattern);
|
||||||
|
let locale = get_locale(&locale_tag);
|
||||||
|
let symbol = if prefix == "C" {
|
||||||
|
get_currency_symbol(&locale_tag)
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
};
|
||||||
|
let int_part = num.trunc() as i64;
|
||||||
|
let frac_part = num.fract();
|
||||||
|
if decimals == 0 {
|
||||||
|
format!("{}{}", symbol, int_part.to_formatted_string(&locale))
|
||||||
|
} else {
|
||||||
|
let frac_scaled =
|
||||||
|
((frac_part * 10f64.powi(decimals as i32)).round()) as i64;
|
||||||
|
let decimal_sep = match locale_tag.as_str() {
|
||||||
|
"pt" | "fr" | "es" | "it" | "de" => ",",
|
||||||
|
_ => ".",
|
||||||
|
};
|
||||||
|
format!(
|
||||||
|
"{}{}{}{:0width$}",
|
||||||
|
symbol,
|
||||||
|
int_part.to_formatted_string(&locale),
|
||||||
|
decimal_sep,
|
||||||
|
frac_scaled,
|
||||||
|
width = decimals
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match pattern.as_str() {
|
||||||
|
"n" | "F" => format!("{num:.2}"),
|
||||||
|
"0%" => format!("{:.0}%", num * 100.0),
|
||||||
|
_ => format!("{num}"),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return Ok(formatted);
|
||||||
|
}
|
||||||
|
if let Ok(dt) = NaiveDateTime::parse_from_str(&value_str, "%Y-%m-%d %H:%M:%S") {
|
||||||
|
let formatted = apply_date_format(&dt, &pattern);
|
||||||
|
return Ok(formatted);
|
||||||
|
}
|
||||||
|
let formatted = apply_text_placeholders(&value_str, &pattern);
|
||||||
|
Ok(formatted)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn format_keyword(engine: &mut Engine) {
|
||||||
|
// Register FORMAT as a regular function with two parameters
|
||||||
|
engine.register_fn("FORMAT", format_impl);
|
||||||
|
engine.register_fn("format", format_impl);
|
||||||
}
|
}
|
||||||
fn parse_pattern(pattern: &str) -> (String, usize, String) {
|
fn parse_pattern(pattern: &str) -> (String, usize, String) {
|
||||||
let mut prefix = String::new();
|
let mut prefix = String::new();
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,7 @@ pub fn get_all_keywords() -> Vec<String> {
|
||||||
"SEND TEMPLATE".to_string(),
|
"SEND TEMPLATE".to_string(),
|
||||||
"SMS".to_string(),
|
"SMS".to_string(),
|
||||||
"ADD SUGGESTION".to_string(),
|
"ADD SUGGESTION".to_string(),
|
||||||
|
"ADD_SUGGESTION_TOOL".to_string(),
|
||||||
"CLEAR SUGGESTIONS".to_string(),
|
"CLEAR SUGGESTIONS".to_string(),
|
||||||
"ADD TOOL".to_string(),
|
"ADD TOOL".to_string(),
|
||||||
"CLEAR TOOLS".to_string(),
|
"CLEAR TOOLS".to_string(),
|
||||||
|
|
|
||||||
|
|
@ -257,6 +257,69 @@ pub fn send_mail_keyword(state: Arc<AppState>, user: UserSession, engine: &mut E
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.expect("valid syntax registration");
|
.expect("valid syntax registration");
|
||||||
|
|
||||||
|
// Register send_mail as a regular function (for function call style: send_mail(to, subject, body, []))
|
||||||
|
let state_fn = Arc::clone(&state);
|
||||||
|
let user_fn = user.clone();
|
||||||
|
engine.register_fn("send_mail", move |to: Dynamic, subject: Dynamic, body: Dynamic, attachments: Dynamic| -> String {
|
||||||
|
// Convert parameters to strings
|
||||||
|
let to_str = to.to_string();
|
||||||
|
let subject_str = subject.to_string();
|
||||||
|
// Convert body to string
|
||||||
|
let body_str = body.to_string();
|
||||||
|
// Convert body to string
|
||||||
|
let body_str = body.to_string();
|
||||||
|
|
||||||
|
// Convert attachments to Vec<String>
|
||||||
|
let mut atts = Vec::new();
|
||||||
|
if attachments.is_array() {
|
||||||
|
if let Ok(arr) = attachments.cast::<rhai::Array>() {
|
||||||
|
for item in arr.iter() {
|
||||||
|
atts.push(item.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !attachments.to_string().is_empty() {
|
||||||
|
atts.push(attachments.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute in blocking thread
|
||||||
|
let (tx, rx) = std::sync::mpsc::channel();
|
||||||
|
let state_for_task = Arc::clone(&state_fn);
|
||||||
|
let user_for_task = user_fn.clone();
|
||||||
|
|
||||||
|
std::thread::spawn(move || {
|
||||||
|
let rt = tokio::runtime::Builder::new_multi_thread()
|
||||||
|
.worker_threads(2)
|
||||||
|
.enable_all()
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let result = if let Ok(rt) = rt {
|
||||||
|
rt.block_on(async move {
|
||||||
|
execute_send_mail(&state_for_task, &user_for_task, &to_str, &subject_str, &body_str, atts, None).await
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err("Failed to build tokio runtime".to_string())
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = tx.send(result);
|
||||||
|
});
|
||||||
|
|
||||||
|
match rx.recv_timeout(std::time::Duration::from_secs(30)) {
|
||||||
|
Ok(Ok(message_id)) => message_id,
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
log::error!("send_mail failed: {}", e);
|
||||||
|
String::new()
|
||||||
|
},
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
||||||
|
log::error!("send_mail timed out");
|
||||||
|
String::new()
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("send_mail thread failed: {}", e);
|
||||||
|
String::new()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_send_mail(
|
async fn execute_send_mail(
|
||||||
|
|
|
||||||
|
|
@ -248,36 +248,56 @@ fn parse_field_definition(
|
||||||
return Err("Empty field definition".into());
|
return Err("Empty field definition".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let field_name = parts[0].to_string();
|
// Handle FIELD keyword: "FIELD id AS STRING" -> ["FIELD", "id", "AS", "STRING"]
|
||||||
let mut field_type = String::new();
|
// Skip "FIELD" keyword if present
|
||||||
|
let start_idx = if parts[0].eq_ignore_ascii_case("FIELD") { 1 } else { 0 };
|
||||||
|
let as_idx = parts.iter().position(|p| p.eq_ignore_ascii_case("AS"));
|
||||||
|
|
||||||
|
if parts.len() < start_idx + 1 {
|
||||||
|
return Err("Invalid field definition - missing field name".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let field_name = parts[start_idx].to_string();
|
||||||
|
|
||||||
|
// Parse type from "AS type" or use default if no AS keyword
|
||||||
|
let mut field_type = if let Some(as_idx) = as_idx {
|
||||||
|
if as_idx + 1 < parts.len() {
|
||||||
|
parts[as_idx + 1].to_lowercase()
|
||||||
|
} else {
|
||||||
|
"text".to_string()
|
||||||
|
}
|
||||||
|
} else if parts.len() > start_idx + 1 {
|
||||||
|
parts[start_idx + 1].to_lowercase()
|
||||||
|
} else {
|
||||||
|
"text".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle type parameters like STRING(255)
|
||||||
let mut length: Option<i32> = None;
|
let mut length: Option<i32> = None;
|
||||||
let mut precision: Option<i32> = None;
|
let mut precision: Option<i32> = None;
|
||||||
|
if let Some(paren_start) = field_type.find('(') {
|
||||||
|
let base_type = field_type[..paren_start].to_lowercase();
|
||||||
|
let params = &field_type[paren_start + 1..field_type.len() - 1];
|
||||||
|
let param_parts: Vec<&str> = params.split(',').collect();
|
||||||
|
|
||||||
|
if !param_parts.is_empty() {
|
||||||
|
length = param_parts[0].trim().parse().ok();
|
||||||
|
}
|
||||||
|
if param_parts.len() > 1 {
|
||||||
|
precision = param_parts[1].trim().parse().ok();
|
||||||
|
}
|
||||||
|
field_type = base_type;
|
||||||
|
}
|
||||||
|
|
||||||
let mut is_key = false;
|
let mut is_key = false;
|
||||||
let mut reference_table: Option<String> = None;
|
let mut reference_table: Option<String> = None;
|
||||||
|
|
||||||
if parts.len() >= 2 {
|
// Process remaining parts for KEY and REFERENCES keywords
|
||||||
let type_part = parts[1];
|
for i in (start_idx + 1)..parts.len() {
|
||||||
|
|
||||||
if let Some(paren_start) = type_part.find('(') {
|
|
||||||
field_type = type_part[..paren_start].to_lowercase();
|
|
||||||
let params = &type_part[paren_start + 1..type_part.len() - 1];
|
|
||||||
let param_parts: Vec<&str> = params.split(',').collect();
|
|
||||||
|
|
||||||
if !param_parts.is_empty() {
|
|
||||||
length = param_parts[0].trim().parse().ok();
|
|
||||||
}
|
|
||||||
if param_parts.len() > 1 {
|
|
||||||
precision = param_parts[1].trim().parse().ok();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
field_type = type_part.to_lowercase();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in 2..parts.len() {
|
|
||||||
let part = parts[i].to_lowercase();
|
let part = parts[i].to_lowercase();
|
||||||
match part.as_str() {
|
match part.as_str() {
|
||||||
"key" => is_key = true,
|
"key" => is_key = true,
|
||||||
|
"as" => {} // Skip AS keyword - already handled above
|
||||||
_ if parts
|
_ if parts
|
||||||
.get(i - 1)
|
.get(i - 1)
|
||||||
.map(|p| p.eq_ignore_ascii_case("references"))
|
.map(|p| p.eq_ignore_ascii_case("references"))
|
||||||
|
|
@ -619,14 +639,20 @@ pub fn process_table_definitions(
|
||||||
table.name, table.connection_name
|
table.name, table.connection_name
|
||||||
);
|
);
|
||||||
|
|
||||||
store_table_definition(&mut conn, bot_id, table)?;
|
// No need to store table definition - just create the table in the bot's database
|
||||||
|
|
||||||
if table.connection_name == "default" {
|
if table.connection_name == "default" {
|
||||||
let create_sql = generate_create_table_sql(table, "postgres");
|
let create_sql = generate_create_table_sql(table, "postgres");
|
||||||
info!("Creating table {} on default connection", table.name);
|
info!("Creating table {} in bot's database (bot_id={})", table.name, bot_id);
|
||||||
trace!("SQL: {}", create_sql);
|
trace!("SQL: {}", create_sql);
|
||||||
|
|
||||||
sql_query(&create_sql).execute(&mut conn)?;
|
// Create table in the bot's own database using BotDatabaseManager
|
||||||
|
if let Err(e) = state.bot_database_manager.create_table_in_bot_database(bot_id, &create_sql) {
|
||||||
|
error!("Failed to create table {} in bot database: {}", table.name, e);
|
||||||
|
// Continue with other tables even if this one fails
|
||||||
|
} else {
|
||||||
|
info!("Successfully created table {} in bot database", table.name);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
match load_connection_config(&state, bot_id, &table.connection_name) {
|
match load_connection_config(&state, bot_id, &table.connection_name) {
|
||||||
Ok(ext_conn) => {
|
Ok(ext_conn) => {
|
||||||
|
|
|
||||||
297
src/basic/mod.rs
297
src/basic/mod.rs
|
|
@ -165,6 +165,9 @@ impl ScriptService {
|
||||||
#[cfg(feature = "chat")]
|
#[cfg(feature = "chat")]
|
||||||
register_bot_keywords(&state, &user, &mut engine);
|
register_bot_keywords(&state, &user, &mut engine);
|
||||||
|
|
||||||
|
// ===== PROCEDURE KEYWORDS (RETURN, etc.) =====
|
||||||
|
keywords::procedures::register_procedure_keywords(state.clone(), user.clone(), &mut engine);
|
||||||
|
|
||||||
// ===== WORKFLOW ORCHESTRATION KEYWORDS =====
|
// ===== WORKFLOW ORCHESTRATION KEYWORDS =====
|
||||||
keywords::orchestration::register_orchestrate_workflow(state.clone(), user.clone(), &mut engine);
|
keywords::orchestration::register_orchestrate_workflow(state.clone(), user.clone(), &mut engine);
|
||||||
keywords::orchestration::register_step_keyword(state.clone(), user.clone(), &mut engine);
|
keywords::orchestration::register_step_keyword(state.clone(), user.clone(), &mut engine);
|
||||||
|
|
@ -561,10 +564,268 @@ impl ScriptService {
|
||||||
Err(parse_error) => Err(Box::new(parse_error.into())),
|
Err(parse_error) => Err(Box::new(parse_error.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compile a tool script (.bas file with PARAM/DESCRIPTION metadata lines)
|
||||||
|
/// Filters out tool metadata before compiling
|
||||||
|
pub fn compile_tool_script(&self, script: &str) -> Result<rhai::AST, Box<EvalAltResult>> {
|
||||||
|
// Filter out PARAM, DESCRIPTION, comment, and empty lines (tool metadata)
|
||||||
|
let executable_script: String = script
|
||||||
|
.lines()
|
||||||
|
.filter(|line| {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
// Keep lines that are NOT PARAM, DESCRIPTION, comments, or empty
|
||||||
|
!(trimmed.starts_with("PARAM ") ||
|
||||||
|
trimmed.starts_with("PARAM\t") ||
|
||||||
|
trimmed.starts_with("DESCRIPTION ") ||
|
||||||
|
trimmed.starts_with("DESCRIPTION\t") ||
|
||||||
|
trimmed.starts_with('\'') || // BASIC comment lines
|
||||||
|
trimmed.is_empty())
|
||||||
|
})
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
info!("[TOOL] Filtered tool metadata: {} -> {} chars", script.len(), executable_script.len());
|
||||||
|
|
||||||
|
// Apply minimal preprocessing for tools (skip variable normalization to avoid breaking multi-line strings)
|
||||||
|
let script = preprocess_switch(&executable_script);
|
||||||
|
let script = Self::convert_multiword_keywords(&script);
|
||||||
|
// Skip normalize_variables_to_lowercase for tools - it breaks multi-line strings
|
||||||
|
// Note: FORMAT is registered as a regular function, so FORMAT(expr, pattern) works directly
|
||||||
|
|
||||||
|
info!("[TOOL] Preprocessed tool script for Rhai compilation");
|
||||||
|
// Convert IF ... THEN / END IF to if ... { }
|
||||||
|
let script = Self::convert_if_then_syntax(&script);
|
||||||
|
// Convert BASIC keywords to lowercase (but preserve variable casing)
|
||||||
|
let script = Self::convert_keywords_to_lowercase(&script);
|
||||||
|
// Save to file for debugging
|
||||||
|
if let Err(e) = std::fs::write("/tmp/tool_preprocessed.bas", &script) {
|
||||||
|
log::warn!("Failed to write preprocessed script: {}", e);
|
||||||
|
}
|
||||||
|
match self.engine.compile(&script) {
|
||||||
|
Ok(ast) => Ok(ast),
|
||||||
|
Err(parse_error) => Err(Box::new(parse_error.into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
pub fn run(&mut self, ast: &rhai::AST) -> Result<Dynamic, Box<EvalAltResult>> {
|
pub fn run(&mut self, ast: &rhai::AST) -> Result<Dynamic, Box<EvalAltResult>> {
|
||||||
self.engine.eval_ast_with_scope(&mut self.scope, ast)
|
self.engine.eval_ast_with_scope(&mut self.scope, ast)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set a variable in the script scope (for tool parameters)
|
||||||
|
pub fn set_variable(&mut self, name: &str, value: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
use rhai::Dynamic;
|
||||||
|
self.scope.set_or_push(name, Dynamic::from(value.to_string()));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert FORMAT(expr, pattern) to FORMAT expr pattern (custom syntax format)
|
||||||
|
/// Also handles RANDOM and other functions that need space-separated arguments
|
||||||
|
fn convert_format_syntax(script: &str) -> String {
|
||||||
|
use regex::Regex;
|
||||||
|
let mut result = script.to_string();
|
||||||
|
|
||||||
|
// First, process RANDOM to ensure commas are preserved
|
||||||
|
// RANDOM(min, max) stays as RANDOM(min, max) - no conversion needed
|
||||||
|
|
||||||
|
// Convert FORMAT(expr, pattern) → FORMAT expr pattern
|
||||||
|
// Need to handle nested functions carefully
|
||||||
|
// Match FORMAT( ... ) but don't include inner function parentheses
|
||||||
|
// This regex matches FORMAT followed by parentheses containing two comma-separated expressions
|
||||||
|
if let Ok(re) = Regex::new(r"(?i)FORMAT\s*\(([^()]+(?:\([^()]*\)[^()]*)*),([^)]+)\)") {
|
||||||
|
result = re.replace_all(&result, "FORMAT $1$2").to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert BASIC IF ... THEN / END IF syntax to Rhai's if ... { } syntax
|
||||||
|
fn convert_if_then_syntax(script: &str) -> String {
|
||||||
|
let mut result = String::new();
|
||||||
|
let mut if_stack: Vec<bool> = Vec::new(); // Track if we're inside an IF block
|
||||||
|
let mut in_with_block = false; // Track if we're inside a WITH block
|
||||||
|
let mut line_buffer = String::new();
|
||||||
|
|
||||||
|
log::info!("[TOOL] Converting IF/THEN syntax, input has {} lines", script.lines().count());
|
||||||
|
|
||||||
|
for line in script.lines() {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
let upper = trimmed.to_uppercase();
|
||||||
|
|
||||||
|
// Skip empty lines and comments
|
||||||
|
if trimmed.is_empty() || trimmed.starts_with('\'') || trimmed.starts_with("//") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle IF ... THEN
|
||||||
|
if upper.starts_with("IF ") && upper.contains(" THEN") {
|
||||||
|
let then_pos = upper.find(" THEN").unwrap();
|
||||||
|
let condition = &trimmed[3..then_pos].trim();
|
||||||
|
log::info!("[TOOL] Converting IF statement: condition='{}'", condition);
|
||||||
|
result.push_str("if ");
|
||||||
|
result.push_str(condition);
|
||||||
|
result.push_str(" {\n");
|
||||||
|
if_stack.push(true);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle ELSE
|
||||||
|
if upper == "ELSE" {
|
||||||
|
log::info!("[TOOL] Converting ELSE statement");
|
||||||
|
result.push_str("} else {\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle END IF
|
||||||
|
if upper == "END IF" {
|
||||||
|
log::info!("[TOOL] Converting END IF statement");
|
||||||
|
if let Some(_) = if_stack.pop() {
|
||||||
|
result.push_str("}\n");
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle WITH ... END WITH (BASIC object creation)
|
||||||
|
if upper.starts_with("WITH ") {
|
||||||
|
let object_name = &trimmed[5..].trim();
|
||||||
|
log::info!("[TOOL] Converting WITH statement: object='{}'", object_name);
|
||||||
|
// Convert WITH obj → let obj = #{ (start object literal)
|
||||||
|
result.push_str("let ");
|
||||||
|
result.push_str(object_name);
|
||||||
|
result.push_str(" = #{\n");
|
||||||
|
in_with_block = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if upper == "END WITH" {
|
||||||
|
log::info!("[TOOL] Converting END WITH statement");
|
||||||
|
result.push_str("};\n");
|
||||||
|
in_with_block = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inside a WITH block - convert property assignments (key = value → key: value)
|
||||||
|
if in_with_block {
|
||||||
|
// Check if this is a property assignment (identifier = value)
|
||||||
|
if trimmed.contains('=') && !trimmed.contains("==") && !trimmed.contains("!=") && !trimmed.contains("+=") && !trimmed.contains("-=") {
|
||||||
|
// Convert assignment to object property syntax
|
||||||
|
let parts: Vec<&str> = trimmed.splitn(2, '=').collect();
|
||||||
|
if parts.len() == 2 {
|
||||||
|
let property_name = parts[0].trim();
|
||||||
|
let property_value = parts[1].trim();
|
||||||
|
// Remove trailing semicolon if present
|
||||||
|
let property_value = property_value.trim_end_matches(';');
|
||||||
|
result.push_str(&format!(" {}: {},\n", property_name, property_value));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Regular line in WITH block - add indentation
|
||||||
|
result.push_str(" ");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle SAVE table, object → INSERT table, object
|
||||||
|
// BASIC SAVE uses 2 parameters but Rhai SAVE needs 3
|
||||||
|
// INSERT uses 2 parameters which matches the BASIC syntax
|
||||||
|
if upper.starts_with("SAVE") && upper.contains(',') {
|
||||||
|
log::info!("[TOOL] Processing SAVE line: '{}'", trimmed);
|
||||||
|
// Extract table and object name
|
||||||
|
let after_save = &trimmed[4..].trim(); // Skip "SAVE"
|
||||||
|
let parts: Vec<&str> = after_save.split(',').collect();
|
||||||
|
log::info!("[TOOL] SAVE parts: {:?}", parts);
|
||||||
|
if parts.len() == 2 {
|
||||||
|
let table = parts[0].trim().trim_matches('"');
|
||||||
|
let object_name = parts[1].trim().trim_end_matches(';');
|
||||||
|
// Convert to INSERT table, object
|
||||||
|
let converted = format!("INSERT \"{}\", {};\n", table, object_name);
|
||||||
|
log::info!("[TOOL] Converted SAVE to INSERT: '{}'", converted);
|
||||||
|
result.push_str(&converted);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle SEND EMAIL → send_mail (function call style)
|
||||||
|
// Syntax: SEND EMAIL to, subject, body → send_mail(to, subject, body, [])
|
||||||
|
if upper.starts_with("SEND EMAIL") {
|
||||||
|
log::info!("[TOOL] Processing SEND EMAIL line: '{}'", trimmed);
|
||||||
|
let after_send = &trimmed[11..].trim(); // Skip "SEND EMAIL " (10 chars + space = 11)
|
||||||
|
let parts: Vec<&str> = after_send.split(',').collect();
|
||||||
|
log::info!("[TOOL] SEND EMAIL parts: {:?}", parts);
|
||||||
|
if parts.len() == 3 {
|
||||||
|
let to = parts[0].trim();
|
||||||
|
let subject = parts[1].trim();
|
||||||
|
let body = parts[2].trim().trim_end_matches(';');
|
||||||
|
// Convert to send_mail(to, subject, body, []) function call
|
||||||
|
let converted = format!("send_mail({}, {}, {}, []);\n", to, subject, body);
|
||||||
|
log::info!("[TOOL] Converted SEND EMAIL to: '{}'", converted);
|
||||||
|
result.push_str(&converted);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular line - add indentation if inside IF block
|
||||||
|
if !if_stack.is_empty() {
|
||||||
|
result.push_str(" ");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if line is a simple statement (not containing THEN or other control flow)
|
||||||
|
if !upper.starts_with("IF ") && !upper.starts_with("ELSE") && !upper.starts_with("END IF") {
|
||||||
|
// Check if this is a variable assignment (identifier = expression)
|
||||||
|
// Pattern: starts with letter/underscore, contains = but not ==, !=, <=, >=, +=, -=
|
||||||
|
let is_var_assignment = trimmed.chars().next().map_or(false, |c| c.is_alphabetic() || c == '_')
|
||||||
|
&& trimmed.contains('=')
|
||||||
|
&& !trimmed.contains("==")
|
||||||
|
&& !trimmed.contains("!=")
|
||||||
|
&& !trimmed.contains("<=")
|
||||||
|
&& !trimmed.contains(">=")
|
||||||
|
&& !trimmed.contains("+=")
|
||||||
|
&& !trimmed.contains("-=")
|
||||||
|
&& !trimmed.contains("*=")
|
||||||
|
&& !trimmed.contains("/=");
|
||||||
|
|
||||||
|
if is_var_assignment {
|
||||||
|
// Add 'let' for variable declarations
|
||||||
|
result.push_str("let ");
|
||||||
|
}
|
||||||
|
result.push_str(trimmed);
|
||||||
|
// Add semicolon if line doesn't have one and doesn't end with { or }
|
||||||
|
if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}') {
|
||||||
|
result.push(';');
|
||||||
|
}
|
||||||
|
result.push('\n');
|
||||||
|
} else {
|
||||||
|
result.push_str(trimmed);
|
||||||
|
result.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("[TOOL] IF/THEN conversion complete, output has {} lines", result.lines().count());
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert BASIC keywords to lowercase without touching variables
|
||||||
|
/// This is a simplified version of normalize_variables_to_lowercase for tools
|
||||||
|
fn convert_keywords_to_lowercase(script: &str) -> String {
|
||||||
|
let keywords = [
|
||||||
|
"IF", "THEN", "ELSE", "END IF", "FOR", "NEXT", "WHILE", "WEND",
|
||||||
|
"DO", "LOOP", "RETURN", "EXIT", "SELECT", "CASE", "END SELECT",
|
||||||
|
"WITH", "END WITH", "AND", "OR", "NOT", "MOD",
|
||||||
|
"DIM", "AS", "NEW", "FUNCTION", "SUB", "CALL",
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut result = String::new();
|
||||||
|
for line in script.lines() {
|
||||||
|
let mut processed_line = line.to_string();
|
||||||
|
for keyword in &keywords {
|
||||||
|
// Use word boundaries to avoid replacing parts of variable names
|
||||||
|
let pattern = format!(r"\b{}\b", regex::escape(keyword));
|
||||||
|
if let Ok(re) = regex::Regex::new(&pattern) {
|
||||||
|
processed_line = re.replace_all(&processed_line, keyword.to_lowercase()).to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push_str(&processed_line);
|
||||||
|
result.push('\n');
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_variables_to_lowercase(script: &str) -> String {
|
fn normalize_variables_to_lowercase(script: &str) -> String {
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
|
||||||
|
|
@ -809,6 +1070,19 @@ impl ScriptService {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip lines with custom syntax that should not be lowercased
|
||||||
|
// These are registered directly with Rhai in uppercase
|
||||||
|
let trimmed_upper = trimmed.to_uppercase();
|
||||||
|
if trimmed_upper.contains("ADD_SUGGESTION_TOOL") ||
|
||||||
|
trimmed_upper.contains("ADD_SUGGESTION_TEXT") ||
|
||||||
|
trimmed_upper.starts_with("ADD_SUGGESTION_") ||
|
||||||
|
trimmed_upper.starts_with("ADD_MEMBER") {
|
||||||
|
// Keep original line as-is
|
||||||
|
result.push_str(line);
|
||||||
|
result.push('\n');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let mut processed_line = String::new();
|
let mut processed_line = String::new();
|
||||||
let mut chars = line.chars().peekable();
|
let mut chars = line.chars().peekable();
|
||||||
let mut in_string = false;
|
let mut in_string = false;
|
||||||
|
|
@ -889,8 +1163,10 @@ impl ScriptService {
|
||||||
(r#"CLEAR\s+TOOLS"#, 0, 0, vec![]),
|
(r#"CLEAR\s+TOOLS"#, 0, 0, vec![]),
|
||||||
(r#"CLEAR\s+WEBSITES"#, 0, 0, vec![]),
|
(r#"CLEAR\s+WEBSITES"#, 0, 0, vec![]),
|
||||||
|
|
||||||
// ADD family
|
// ADD family - ADD_SUGGESTION_TOOL must come before ADD\s+SUGGESTION
|
||||||
(r#"ADD\s+SUGGESTION"#, 2, 2, vec!["title", "text"]),
|
(r#"ADD_SUGGESTION_TOOL"#, 2, 2, vec!["tool", "text"]),
|
||||||
|
(r#"ADD\s+SUGGESTION\s+TEXT"#, 2, 2, vec!["value", "text"]),
|
||||||
|
(r#"ADD\s+SUGGESTION(?!\s*TEXT|\s*TOOL|_TOOL)"#, 2, 2, vec!["context", "text"]),
|
||||||
(r#"ADD\s+MEMBER"#, 2, 2, vec!["name", "role"]),
|
(r#"ADD\s+MEMBER"#, 2, 2, vec!["name", "role"]),
|
||||||
|
|
||||||
// CREATE family
|
// CREATE family
|
||||||
|
|
@ -916,6 +1192,23 @@ impl ScriptService {
|
||||||
let trimmed = line.trim();
|
let trimmed = line.trim();
|
||||||
let mut converted = false;
|
let mut converted = false;
|
||||||
|
|
||||||
|
// Skip lines that already use underscore-style custom syntax
|
||||||
|
// These are registered directly with Rhai and should not be converted
|
||||||
|
let trimmed_upper = trimmed.to_uppercase();
|
||||||
|
if trimmed_upper.contains("ADD_SUGGESTION_TOOL") ||
|
||||||
|
trimmed_upper.contains("ADD_SUGGESTION_TEXT") ||
|
||||||
|
trimmed_upper.starts_with("ADD_SUGGESTION_") ||
|
||||||
|
trimmed_upper.starts_with("ADD_MEMBER") ||
|
||||||
|
(trimmed_upper.starts_with("USE_") && trimmed.contains('(')) {
|
||||||
|
// Keep original line and add semicolon if needed
|
||||||
|
result.push_str(line);
|
||||||
|
if !trimmed.ends_with(';') && !trimmed.ends_with('{') && !trimmed.ends_with('}') {
|
||||||
|
result.push(';');
|
||||||
|
}
|
||||||
|
result.push('\n');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Try each pattern
|
// Try each pattern
|
||||||
for (pattern, min_params, max_params, _param_names) in &multiword_patterns {
|
for (pattern, min_params, max_params, _param_names) in &multiword_patterns {
|
||||||
// Build regex pattern: KEYWORD params...
|
// Build regex pattern: KEYWORD params...
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ pub mod kb_context;
|
||||||
use kb_context::inject_kb_context;
|
use kb_context::inject_kb_context;
|
||||||
pub mod tool_context;
|
pub mod tool_context;
|
||||||
use tool_context::get_session_tools;
|
use tool_context::get_session_tools;
|
||||||
|
pub mod tool_executor;
|
||||||
|
use tool_executor::ToolExecutor;
|
||||||
#[cfg(feature = "llm")]
|
#[cfg(feature = "llm")]
|
||||||
use crate::core::config::ConfigManager;
|
use crate::core::config::ConfigManager;
|
||||||
|
|
||||||
|
|
@ -18,6 +20,8 @@ use crate::nvidia::get_system_metrics;
|
||||||
use crate::shared::message_types::MessageType;
|
use crate::shared::message_types::MessageType;
|
||||||
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
use crate::shared::models::{BotResponse, UserMessage, UserSession};
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
|
#[cfg(feature = "chat")]
|
||||||
|
use crate::basic::keywords::add_suggestion::get_suggestions;
|
||||||
use axum::extract::ws::{Message, WebSocket};
|
use axum::extract::ws::{Message, WebSocket};
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{ws::WebSocketUpgrade, Extension, Query, State},
|
extract::{ws::WebSocketUpgrade, Extension, Query, State},
|
||||||
|
|
@ -422,86 +426,101 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
#[cfg(any(feature = "research", feature = "llm"))]
|
#[cfg(any(feature = "research", feature = "llm"))]
|
||||||
{
|
{
|
||||||
// Execute start.bas on first message to load tools
|
// Execute start.bas on first message - ONLY run once per session to load suggestions
|
||||||
// Check if tools have been loaded for this session
|
let actual_session_id = session.id.to_string();
|
||||||
let has_tools_loaded = {
|
|
||||||
use crate::shared::models::schema::session_tool_associations::dsl::*;
|
// Check if start.bas has already been executed for this session
|
||||||
let conn = self.state.conn.get().ok();
|
let start_bas_key = format!("start_bas_executed:{}", actual_session_id);
|
||||||
if let Some(mut db_conn) = conn {
|
let should_execute_start_bas = if let Some(cache) = &self.state.cache {
|
||||||
let tool_names: Vec<String> = session_tool_associations
|
if let Ok(mut conn) = cache.get_multiplexed_async_connection().await {
|
||||||
.filter(session_id.eq(&session_id_str))
|
let executed: Result<Option<String>, redis::RedisError> = redis::cmd("GET")
|
||||||
.select(tool_name)
|
.arg(&start_bas_key)
|
||||||
.limit(1)
|
.query_async(&mut conn)
|
||||||
.load(&mut db_conn)
|
.await;
|
||||||
.unwrap_or_default();
|
executed.is_ok() && executed.unwrap().is_none()
|
||||||
!tool_names.is_empty()
|
|
||||||
} else {
|
} else {
|
||||||
false
|
true // If cache fails, try to execute
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
true // If no cache, try to execute
|
||||||
};
|
};
|
||||||
|
|
||||||
// If no tools are loaded, execute start.bas
|
if should_execute_start_bas {
|
||||||
if !has_tools_loaded {
|
// Always execute start.bas for this session (blocking - wait for completion)
|
||||||
let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
|
let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
|
||||||
let data_dir = format!("{}/data", home_dir);
|
let data_dir = format!("{}/data", home_dir);
|
||||||
let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name_for_context, bot_name_for_context);
|
let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name_for_context, bot_name_for_context);
|
||||||
|
|
||||||
info!("[START_BAS] Checking for start.bas at: {}", start_script_path);
|
info!("[START_BAS] Executing start.bas for session {} at: {}", actual_session_id, start_script_path);
|
||||||
|
|
||||||
if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await {
|
if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await {
|
||||||
if metadata.is_file() {
|
if metadata.is_file() {
|
||||||
info!("[START_BAS] Found start.bas, executing for session {}", session_id);
|
info!("[START_BAS] Found start.bas, executing for session {}", actual_session_id);
|
||||||
|
|
||||||
if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await {
|
if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await {
|
||||||
let state_clone = self.state.clone();
|
let state_clone = self.state.clone();
|
||||||
let session_id_clone = session_id;
|
let actual_session_id_for_task = session.id;
|
||||||
let bot_id_clone = session.bot_id;
|
let bot_id_clone = session.bot_id;
|
||||||
let bot_name_clone = bot_name_for_context.clone();
|
|
||||||
let bot_name_for_log = bot_name_clone.clone();
|
|
||||||
let start_script_clone = start_script.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
// Execute start.bas synchronously (blocking)
|
||||||
let session_result = state_clone.session_manager.lock().await.get_session_by_id(session_id_clone);
|
let result = tokio::task::spawn_blocking(move || {
|
||||||
|
let session_result = {
|
||||||
|
let mut sm = state_clone.session_manager.blocking_lock();
|
||||||
|
sm.get_session_by_id(actual_session_id_for_task)
|
||||||
|
};
|
||||||
|
|
||||||
if let Ok(Some(session)) = session_result {
|
let sess = match session_result {
|
||||||
let result = tokio::task::spawn_blocking(move || {
|
Ok(Some(s)) => s,
|
||||||
let mut script_service = crate::basic::ScriptService::new(
|
Ok(None) => {
|
||||||
state_clone.clone(),
|
return Err(format!("Session {} not found during start.bas execution", actual_session_id_for_task));
|
||||||
session.clone()
|
}
|
||||||
);
|
Err(e) => return Err(format!("Failed to get session: {}", e)),
|
||||||
script_service.load_bot_config_params(&state_clone, bot_id_clone);
|
};
|
||||||
|
|
||||||
match script_service.compile(&start_script_clone) {
|
let mut script_service = crate::basic::ScriptService::new(
|
||||||
Ok(ast) => match script_service.run(&ast) {
|
state_clone.clone(),
|
||||||
Ok(_) => {
|
sess
|
||||||
info!("[START_BAS] Executed start.bas successfully for bot {}", bot_name_clone);
|
);
|
||||||
Ok(())
|
script_service.load_bot_config_params(&state_clone, bot_id_clone);
|
||||||
},
|
|
||||||
Err(e) => Err(format!("Script execution error: {}", e)),
|
|
||||||
},
|
|
||||||
Err(e) => Err(format!("Script compilation error: {}", e)),
|
|
||||||
}
|
|
||||||
}).await;
|
|
||||||
|
|
||||||
match result {
|
match script_service.compile(&start_script) {
|
||||||
Ok(Ok(())) => {
|
Ok(ast) => match script_service.run(&ast) {
|
||||||
info!("[START_BAS] start.bas completed for bot {}", bot_name_for_log);
|
Ok(_) => Ok(()),
|
||||||
}
|
Err(e) => Err(format!("Script execution error: {}", e)),
|
||||||
Ok(Err(e)) => {
|
},
|
||||||
error!("[START_BAS] start.bas error for bot {}: {}", bot_name_for_log, e);
|
Err(e) => Err(format!("Script compilation error: {}", e)),
|
||||||
}
|
}
|
||||||
Err(e) => {
|
}).await;
|
||||||
error!("[START_BAS] start.bas task error for bot {}: {}", bot_name_for_log, e);
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(())) => {
|
||||||
|
info!("[START_BAS] start.bas completed successfully for session {}", actual_session_id);
|
||||||
|
|
||||||
|
// Mark start.bas as executed for this session to prevent re-running
|
||||||
|
if let Some(cache) = &self.state.cache {
|
||||||
|
if let Ok(mut conn) = cache.get_multiplexed_async_connection().await {
|
||||||
|
let _: Result<(), redis::RedisError> = redis::cmd("SET")
|
||||||
|
.arg(&start_bas_key)
|
||||||
|
.arg("1")
|
||||||
|
.arg("EX")
|
||||||
|
.arg("86400") // Expire after 24 hours
|
||||||
|
.query_async(&mut conn)
|
||||||
|
.await;
|
||||||
|
info!("[START_BAS] Marked start.bas as executed for session {}", actual_session_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
Ok(Err(e)) => {
|
||||||
|
error!("[START_BAS] start.bas error for session {}: {}", actual_session_id, e);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("[START_BAS] start.bas task error for session {}: {}", actual_session_id, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
info!("[START_BAS] start.bas not found for bot {}", bot_name_for_context);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} // End of if should_execute_start_bas
|
||||||
|
|
||||||
if let Some(kb_manager) = self.state.kb_manager.as_ref() {
|
if let Some(kb_manager) = self.state.kb_manager.as_ref() {
|
||||||
if let Err(e) = inject_kb_context(
|
if let Err(e) = inject_kb_context(
|
||||||
|
|
@ -539,15 +558,15 @@ impl BotOrchestrator {
|
||||||
let model_clone = model.clone();
|
let model_clone = model.clone();
|
||||||
let key_clone = key.clone();
|
let key_clone = key.clone();
|
||||||
|
|
||||||
// Retrieve session tools for tool calling
|
// Retrieve session tools for tool calling (use actual session.id after potential creation)
|
||||||
let session_tools = get_session_tools(&self.state.conn, &bot_name_for_context, &session_id);
|
let session_tools = get_session_tools(&self.state.conn, &bot_name_for_context, &session.id);
|
||||||
let tools_for_llm = match session_tools {
|
let tools_for_llm = match session_tools {
|
||||||
Ok(tools) => {
|
Ok(tools) => {
|
||||||
if !tools.is_empty() {
|
if !tools.is_empty() {
|
||||||
info!("[TOOLS] Loaded {} tools for session {}", tools.len(), session_id);
|
info!("[TOOLS] Loaded {} tools for session {}", tools.len(), session.id);
|
||||||
Some(tools)
|
Some(tools)
|
||||||
} else {
|
} else {
|
||||||
info!("[TOOLS] No tools associated with session {}", session_id);
|
info!("[TOOLS] No tools associated with session {}", session.id);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -615,6 +634,89 @@ impl BotOrchestrator {
|
||||||
info!("[STREAM_DEBUG] Received chunk: '{}', len: {}", chunk, chunk.len());
|
info!("[STREAM_DEBUG] Received chunk: '{}', len: {}", chunk, chunk.len());
|
||||||
trace!("Received LLM chunk: {:?}", chunk);
|
trace!("Received LLM chunk: {:?}", chunk);
|
||||||
|
|
||||||
|
// ===== GENERIC TOOL EXECUTION =====
|
||||||
|
// Check if this chunk contains a tool call (works with all LLM providers)
|
||||||
|
if let Some(tool_call) = ToolExecutor::parse_tool_call(&chunk) {
|
||||||
|
info!(
|
||||||
|
"[TOOL_CALL] Detected tool '{}' from LLM, executing...",
|
||||||
|
tool_call.tool_name
|
||||||
|
);
|
||||||
|
|
||||||
|
let execution_result = ToolExecutor::execute_tool_call(
|
||||||
|
&self.state,
|
||||||
|
&bot_name_for_context,
|
||||||
|
&tool_call,
|
||||||
|
&session_id,
|
||||||
|
&user_id,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if execution_result.success {
|
||||||
|
info!(
|
||||||
|
"[TOOL_EXEC] Tool '{}' executed successfully: {}",
|
||||||
|
tool_call.tool_name, execution_result.result
|
||||||
|
);
|
||||||
|
|
||||||
|
// Send tool execution result to user
|
||||||
|
let response = BotResponse {
|
||||||
|
bot_id: message.bot_id.clone(),
|
||||||
|
user_id: message.user_id.clone(),
|
||||||
|
session_id: message.session_id.clone(),
|
||||||
|
channel: message.channel.clone(),
|
||||||
|
content: execution_result.result,
|
||||||
|
message_type: MessageType::BOT_RESPONSE,
|
||||||
|
stream_token: None,
|
||||||
|
is_complete: false,
|
||||||
|
suggestions: Vec::new(),
|
||||||
|
context_name: None,
|
||||||
|
context_length: 0,
|
||||||
|
context_max_length: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if response_tx.send(response).await.is_err() {
|
||||||
|
warn!("Response channel closed during tool execution");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error!(
|
||||||
|
"[TOOL_EXEC] Tool '{}' execution failed: {:?}",
|
||||||
|
tool_call.tool_name, execution_result.error
|
||||||
|
);
|
||||||
|
|
||||||
|
// Send error to user
|
||||||
|
let error_msg = format!(
|
||||||
|
"Erro ao executar ferramenta '{}': {:?}",
|
||||||
|
tool_call.tool_name,
|
||||||
|
execution_result.error
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = BotResponse {
|
||||||
|
bot_id: message.bot_id.clone(),
|
||||||
|
user_id: message.user_id.clone(),
|
||||||
|
session_id: message.session_id.clone(),
|
||||||
|
channel: message.channel.clone(),
|
||||||
|
content: error_msg,
|
||||||
|
message_type: MessageType::BOT_RESPONSE,
|
||||||
|
stream_token: None,
|
||||||
|
is_complete: false,
|
||||||
|
suggestions: Vec::new(),
|
||||||
|
context_name: None,
|
||||||
|
context_length: 0,
|
||||||
|
context_max_length: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
if response_tx.send(response).await.is_err() {
|
||||||
|
warn!("Response channel closed during tool error");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't add tool_call JSON to full_response or analysis_buffer
|
||||||
|
// Continue to next chunk
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// ===== END TOOL EXECUTION =====
|
||||||
|
|
||||||
analysis_buffer.push_str(&chunk);
|
analysis_buffer.push_str(&chunk);
|
||||||
|
|
||||||
if !in_analysis && handler.has_analysis_markers(&analysis_buffer) {
|
if !in_analysis && handler.has_analysis_markers(&analysis_buffer) {
|
||||||
|
|
@ -732,6 +834,15 @@ impl BotOrchestrator {
|
||||||
)
|
)
|
||||||
.await??;
|
.await??;
|
||||||
|
|
||||||
|
// Extract user_id and session_id before moving them into BotResponse
|
||||||
|
let user_id_str = message.user_id.clone();
|
||||||
|
let session_id_str = message.session_id.clone();
|
||||||
|
|
||||||
|
#[cfg(feature = "chat")]
|
||||||
|
let suggestions = get_suggestions(self.state.cache.as_ref(), &user_id_str, &session_id_str);
|
||||||
|
#[cfg(not(feature = "chat"))]
|
||||||
|
let suggestions: Vec<crate::shared::models::Suggestion> = Vec::new();
|
||||||
|
|
||||||
let final_response = BotResponse {
|
let final_response = BotResponse {
|
||||||
bot_id: message.bot_id,
|
bot_id: message.bot_id,
|
||||||
user_id: message.user_id,
|
user_id: message.user_id,
|
||||||
|
|
@ -741,7 +852,7 @@ impl BotOrchestrator {
|
||||||
message_type: MessageType::BOT_RESPONSE,
|
message_type: MessageType::BOT_RESPONSE,
|
||||||
stream_token: None,
|
stream_token: None,
|
||||||
is_complete: true,
|
is_complete: true,
|
||||||
suggestions: Vec::new(),
|
suggestions,
|
||||||
context_name: None,
|
context_name: None,
|
||||||
context_length: 0,
|
context_length: 0,
|
||||||
context_max_length: 0,
|
context_max_length: 0,
|
||||||
|
|
@ -937,16 +1048,33 @@ async fn handle_websocket(
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(bot_name) = bot_name_result {
|
if let Some(bot_name) = bot_name_result {
|
||||||
let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
|
// Check if start.bas has already been executed for this session
|
||||||
let data_dir = format!("{}/data", home_dir);
|
let start_bas_key = format!("start_bas_executed:{}", session_id);
|
||||||
let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name, bot_name);
|
let should_execute_start_bas = if let Some(cache) = &state.cache {
|
||||||
|
if let Ok(mut conn) = cache.get_multiplexed_async_connection().await {
|
||||||
|
let executed: Result<Option<String>, redis::RedisError> = redis::cmd("GET")
|
||||||
|
.arg(&start_bas_key)
|
||||||
|
.query_async(&mut conn)
|
||||||
|
.await;
|
||||||
|
executed.is_ok() && executed.unwrap().is_none()
|
||||||
|
} else {
|
||||||
|
true // If cache fails, try to execute
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
true // If no cache, try to execute
|
||||||
|
};
|
||||||
|
|
||||||
info!("Looking for start.bas at: {}", start_script_path);
|
if should_execute_start_bas {
|
||||||
|
let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
|
||||||
|
let data_dir = format!("{}/data", home_dir);
|
||||||
|
let start_script_path = format!("{}/{}.gbai/{}.gbdialog/start.bas", data_dir, bot_name, bot_name);
|
||||||
|
|
||||||
if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await {
|
info!("Looking for start.bas at: {}", start_script_path);
|
||||||
if metadata.is_file() {
|
|
||||||
info!("Found start.bas file, reading contents...");
|
if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await {
|
||||||
if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await {
|
if metadata.is_file() {
|
||||||
|
info!("Found start.bas file, reading contents...");
|
||||||
|
if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await {
|
||||||
info!(
|
info!(
|
||||||
"Executing start.bas for bot {} on session {}",
|
"Executing start.bas for bot {} on session {}",
|
||||||
bot_name, session_id
|
bot_name, session_id
|
||||||
|
|
@ -964,6 +1092,9 @@ async fn handle_websocket(
|
||||||
if let Ok(Some(session)) = session_result {
|
if let Ok(Some(session)) = session_result {
|
||||||
info!("Executing start.bas for bot {} on session {}", bot_name, session_id);
|
info!("Executing start.bas for bot {} on session {}", bot_name, session_id);
|
||||||
|
|
||||||
|
// Clone state_for_start for use in Redis SET after execution
|
||||||
|
let state_for_redis = state_for_start.clone();
|
||||||
|
|
||||||
let result = tokio::task::spawn_blocking(move || {
|
let result = tokio::task::spawn_blocking(move || {
|
||||||
let mut script_service = crate::basic::ScriptService::new(
|
let mut script_service = crate::basic::ScriptService::new(
|
||||||
state_for_start.clone(),
|
state_for_start.clone(),
|
||||||
|
|
@ -983,6 +1114,20 @@ async fn handle_websocket(
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(())) => {
|
Ok(Ok(())) => {
|
||||||
info!("start.bas executed successfully for bot {}", bot_name);
|
info!("start.bas executed successfully for bot {}", bot_name);
|
||||||
|
|
||||||
|
// Mark start.bas as executed for this session to prevent re-running
|
||||||
|
if let Some(cache) = &state_for_redis.cache {
|
||||||
|
if let Ok(mut conn) = cache.get_multiplexed_async_connection().await {
|
||||||
|
let _: Result<(), redis::RedisError> = redis::cmd("SET")
|
||||||
|
.arg(&start_bas_key)
|
||||||
|
.arg("1")
|
||||||
|
.arg("EX")
|
||||||
|
.arg("86400") // Expire after 24 hours
|
||||||
|
.query_async(&mut conn)
|
||||||
|
.await;
|
||||||
|
info!("Marked start.bas as executed for session {}", session_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
error!("start.bas error for bot {}: {}", bot_name, e);
|
error!("start.bas error for bot {}: {}", bot_name, e);
|
||||||
|
|
@ -996,6 +1141,7 @@ async fn handle_websocket(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} // End of if should_execute_start_bas
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1017,15 +1163,52 @@ async fn handle_websocket(
|
||||||
info!("Received WebSocket message: {}", text);
|
info!("Received WebSocket message: {}", text);
|
||||||
if let Ok(user_msg) = serde_json::from_str::<UserMessage>(&text) {
|
if let Ok(user_msg) = serde_json::from_str::<UserMessage>(&text) {
|
||||||
let orchestrator = BotOrchestrator::new(state_clone.clone());
|
let orchestrator = BotOrchestrator::new(state_clone.clone());
|
||||||
|
info!("[WS_DEBUG] Looking up response channel for session: {}", session_id);
|
||||||
if let Some(tx_clone) = state_clone
|
if let Some(tx_clone) = state_clone
|
||||||
.response_channels
|
.response_channels
|
||||||
.lock()
|
.lock()
|
||||||
.await
|
.await
|
||||||
.get(&session_id.to_string())
|
.get(&session_id.to_string())
|
||||||
{
|
{
|
||||||
|
info!("[WS_DEBUG] Response channel found, calling stream_response");
|
||||||
|
|
||||||
|
// Ensure session exists - create if not
|
||||||
|
let session_result = {
|
||||||
|
let mut sm = state_clone.session_manager.lock().await;
|
||||||
|
sm.get_session_by_id(session_id)
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = match session_result {
|
||||||
|
Ok(Some(sess)) => {
|
||||||
|
info!("[WS_DEBUG] Session exists: {}", session_id);
|
||||||
|
sess
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
info!("[WS_DEBUG] Session not found, creating via session manager");
|
||||||
|
// Use session manager to create session (will generate new UUID)
|
||||||
|
let mut sm = state_clone.session_manager.lock().await;
|
||||||
|
match sm.create_session(user_id, bot_id, "WebSocket Chat") {
|
||||||
|
Ok(new_session) => {
|
||||||
|
info!("[WS_DEBUG] Session created: {} (note: different from WebSocket session_id)", new_session.id);
|
||||||
|
new_session
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("[WS_DEBUG] Failed to create session: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("[WS_DEBUG] Error getting session: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Use bot_id from WebSocket connection instead of from message
|
// Use bot_id from WebSocket connection instead of from message
|
||||||
let corrected_msg = UserMessage {
|
let corrected_msg = UserMessage {
|
||||||
bot_id: bot_id.to_string(),
|
bot_id: bot_id.to_string(),
|
||||||
|
user_id: session.user_id.to_string(),
|
||||||
|
session_id: session.id.to_string(),
|
||||||
..user_msg
|
..user_msg
|
||||||
};
|
};
|
||||||
if let Err(e) = orchestrator
|
if let Err(e) = orchestrator
|
||||||
|
|
@ -1034,7 +1217,11 @@ async fn handle_websocket(
|
||||||
{
|
{
|
||||||
error!("Failed to stream response: {}", e);
|
error!("Failed to stream response: {}", e);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
warn!("[WS_DEBUG] Response channel NOT found for session: {}", session_id);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
warn!("[WS_DEBUG] Failed to parse UserMessage from: {}", text);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Message::Close(_) => {
|
Message::Close(_) => {
|
||||||
|
|
|
||||||
384
src/core/bot/tool_executor.rs
Normal file
384
src/core/bot/tool_executor.rs
Normal file
|
|
@ -0,0 +1,384 @@
|
||||||
|
/// Generic tool executor for LLM tool calls
|
||||||
|
/// Works across all LLM providers (GLM, OpenAI, Claude, etc.)
|
||||||
|
use log::{error, info, warn};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::fs::OpenOptions;
|
||||||
|
use std::io::Write;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::basic::ScriptService;
|
||||||
|
use crate::shared::state::AppState;
|
||||||
|
use crate::shared::models::schema::bots;
|
||||||
|
use diesel::prelude::*;
|
||||||
|
|
||||||
|
/// Represents a parsed tool call from an LLM
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ParsedToolCall {
|
||||||
|
pub id: String,
|
||||||
|
pub tool_name: String,
|
||||||
|
pub arguments: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of tool execution
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolExecutionResult {
|
||||||
|
pub tool_call_id: String,
|
||||||
|
pub success: bool,
|
||||||
|
pub result: String,
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generic tool executor - works with any LLM provider
|
||||||
|
pub struct ToolExecutor;
|
||||||
|
|
||||||
|
impl ToolExecutor {
|
||||||
|
/// Log tool execution errors to a dedicated log file
|
||||||
|
fn log_tool_error(bot_name: &str, tool_name: &str, error_msg: &str) {
|
||||||
|
let log_path = Path::new("work").join(format!("{}_tool_errors.log", bot_name));
|
||||||
|
|
||||||
|
// Create work directory if it doesn't exist
|
||||||
|
if let Some(parent) = log_path.parent() {
|
||||||
|
let _ = std::fs::create_dir_all(parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
let timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC");
|
||||||
|
let log_entry = format!(
|
||||||
|
"[{}] TOOL: {} | ERROR: {}\n",
|
||||||
|
timestamp, tool_name, error_msg
|
||||||
|
);
|
||||||
|
|
||||||
|
// Append to log file
|
||||||
|
if let Ok(mut file) = OpenOptions::new()
|
||||||
|
.create(true)
|
||||||
|
.append(true)
|
||||||
|
.open(&log_path)
|
||||||
|
{
|
||||||
|
let _ = file.write_all(log_entry.as_bytes());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also log to system logger
|
||||||
|
error!("[TOOL_ERROR] Bot: {}, Tool: {}, Error: {}", bot_name, tool_name, error_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert internal errors to user-friendly messages for browser
|
||||||
|
fn format_user_friendly_error(tool_name: &str, error: &str) -> String {
|
||||||
|
// Don't expose internal errors to browser - log them and return generic message
|
||||||
|
if error.contains("Compilation error") {
|
||||||
|
"Desculpe, houve um erro ao processar sua solicitação. Por favor, tente novamente ou entre em contato com a administração."
|
||||||
|
.to_string()
|
||||||
|
} else if error.contains("Execution error") {
|
||||||
|
format!("O processamento da ferramenta '{}' encontrou um problema. Nossa equipe foi notificada.", tool_name)
|
||||||
|
} else if error.contains("not found") {
|
||||||
|
"Ferramenta não disponível no momento. Por favor, tente novamente mais tarde.".to_string()
|
||||||
|
} else {
|
||||||
|
"Ocorreu um erro ao processar sua solicitação. Por favor, tente novamente.".to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Parse a tool call JSON from any LLM provider
|
||||||
|
/// Handles OpenAI, GLM, Claude formats
|
||||||
|
pub fn parse_tool_call(chunk: &str) -> Option<ParsedToolCall> {
|
||||||
|
// Try to parse as JSON
|
||||||
|
let json: Value = serde_json::from_str(chunk).ok()?;
|
||||||
|
|
||||||
|
// Check if this is a tool_call type (from GLM wrapper)
|
||||||
|
if let Some(tool_type) = json.get("type").and_then(|t| t.as_str()) {
|
||||||
|
if tool_type == "tool_call" {
|
||||||
|
if let Some(content) = json.get("content") {
|
||||||
|
return Self::extract_tool_call(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try direct OpenAI format
|
||||||
|
if json.get("function").is_some() {
|
||||||
|
return Self::extract_tool_call(&json);
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract tool call information from various formats
|
||||||
|
fn extract_tool_call(tool_data: &Value) -> Option<ParsedToolCall> {
|
||||||
|
let function = tool_data.get("function")?;
|
||||||
|
let tool_name = function.get("name")?.as_str()?.to_string();
|
||||||
|
let arguments_str = function.get("arguments")?.as_str()?;
|
||||||
|
|
||||||
|
// Parse arguments string to JSON
|
||||||
|
let arguments: Value = serde_json::from_str(arguments_str).ok()?;
|
||||||
|
|
||||||
|
let id = tool_data
|
||||||
|
.get("id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("unknown")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
Some(ParsedToolCall {
|
||||||
|
id,
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a tool call by running the corresponding .bas file
|
||||||
|
pub async fn execute_tool_call(
|
||||||
|
state: &Arc<AppState>,
|
||||||
|
bot_name: &str,
|
||||||
|
tool_call: &ParsedToolCall,
|
||||||
|
session_id: &Uuid,
|
||||||
|
_user_id: &Uuid,
|
||||||
|
) -> ToolExecutionResult {
|
||||||
|
info!(
|
||||||
|
"[TOOL_EXEC] Executing tool '{}' for bot '{}', session '{}'",
|
||||||
|
tool_call.tool_name, bot_name, session_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get bot_id
|
||||||
|
let bot_id = match Self::get_bot_id(state, bot_name) {
|
||||||
|
Some(id) => id,
|
||||||
|
None => {
|
||||||
|
let error_msg = format!("Bot '{}' not found", bot_name);
|
||||||
|
Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg);
|
||||||
|
return ToolExecutionResult {
|
||||||
|
tool_call_id: tool_call.id.clone(),
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load the .bas tool file
|
||||||
|
let bas_path = Self::get_tool_bas_path(bot_name, &tool_call.tool_name);
|
||||||
|
|
||||||
|
if !bas_path.exists() {
|
||||||
|
let error_msg = format!("Tool file not found: {:?}", bas_path);
|
||||||
|
Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg);
|
||||||
|
return ToolExecutionResult {
|
||||||
|
tool_call_id: tool_call.id.clone(),
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the .bas file
|
||||||
|
let bas_script = match tokio::fs::read_to_string(&bas_path).await {
|
||||||
|
Ok(script) => script,
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg = format!("Failed to read tool file: {}", e);
|
||||||
|
Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg);
|
||||||
|
return ToolExecutionResult {
|
||||||
|
tool_call_id: tool_call.id.clone(),
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get session for ScriptService
|
||||||
|
let session = match state.session_manager.lock().await.get_session_by_id(*session_id) {
|
||||||
|
Ok(Some(sess)) => sess,
|
||||||
|
Ok(None) => {
|
||||||
|
let error_msg = "Session not found".to_string();
|
||||||
|
Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg);
|
||||||
|
return ToolExecutionResult {
|
||||||
|
tool_call_id: tool_call.id.clone(),
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg = format!("Failed to get session: {}", e);
|
||||||
|
Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg);
|
||||||
|
return ToolExecutionResult {
|
||||||
|
tool_call_id: tool_call.id.clone(),
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Execute in blocking thread for ScriptService (which is not async)
|
||||||
|
let bot_name_clone = bot_name.to_string();
|
||||||
|
let tool_name_clone = tool_call.tool_name.clone();
|
||||||
|
let tool_call_id_clone = tool_call.id.clone();
|
||||||
|
let arguments_clone = tool_call.arguments.clone();
|
||||||
|
let state_clone = state.clone();
|
||||||
|
let bot_id_clone = bot_id;
|
||||||
|
|
||||||
|
let execution_result = tokio::task::spawn_blocking(move || {
|
||||||
|
Self::execute_tool_script(
|
||||||
|
&state_clone,
|
||||||
|
&bot_name_clone,
|
||||||
|
bot_id_clone,
|
||||||
|
&session,
|
||||||
|
&bas_script,
|
||||||
|
&tool_name_clone,
|
||||||
|
&arguments_clone,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match execution_result {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg = format!("Task execution error: {}", e);
|
||||||
|
Self::log_tool_error(bot_name, &tool_call.tool_name, &error_msg);
|
||||||
|
ToolExecutionResult {
|
||||||
|
tool_call_id: tool_call_id_clone,
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(Self::format_user_friendly_error(&tool_call.tool_name, &error_msg)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute the tool script with parameters
|
||||||
|
fn execute_tool_script(
|
||||||
|
state: &Arc<AppState>,
|
||||||
|
bot_name: &str,
|
||||||
|
bot_id: Uuid,
|
||||||
|
session: &crate::shared::models::UserSession,
|
||||||
|
bas_script: &str,
|
||||||
|
tool_name: &str,
|
||||||
|
arguments: &Value,
|
||||||
|
) -> ToolExecutionResult {
|
||||||
|
let tool_call_id = format!("tool_{}", uuid::Uuid::new_v4());
|
||||||
|
|
||||||
|
// Create ScriptService
|
||||||
|
let mut script_service = ScriptService::new(state.clone(), session.clone());
|
||||||
|
script_service.load_bot_config_params(state, bot_id);
|
||||||
|
|
||||||
|
// Set tool parameters as variables in the engine scope
|
||||||
|
if let Some(obj) = arguments.as_object() {
|
||||||
|
for (key, value) in obj {
|
||||||
|
let value_str = match value {
|
||||||
|
Value::String(s) => s.clone(),
|
||||||
|
Value::Number(n) => n.to_string(),
|
||||||
|
Value::Bool(b) => b.to_string(),
|
||||||
|
_ => value.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set variable in script scope
|
||||||
|
if let Err(e) = script_service.set_variable(key, &value_str) {
|
||||||
|
warn!("[TOOL_EXEC] Failed to set variable '{}': {}", key, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compile tool script (filters PARAM/DESCRIPTION lines and converts BASIC to Rhai)
|
||||||
|
let ast = match script_service.compile_tool_script(&bas_script) {
|
||||||
|
Ok(ast) => ast,
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg = format!("Compilation error: {}", e);
|
||||||
|
Self::log_tool_error(bot_name, tool_name, &error_msg);
|
||||||
|
let user_message = Self::format_user_friendly_error(tool_name, &error_msg);
|
||||||
|
return ToolExecutionResult {
|
||||||
|
tool_call_id,
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(user_message),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run the script
|
||||||
|
match script_service.run(&ast) {
|
||||||
|
Ok(result) => {
|
||||||
|
info!("[TOOL_EXEC] Tool '{}' executed successfully", tool_name);
|
||||||
|
|
||||||
|
// Convert result to string
|
||||||
|
let result_str = result.to_string();
|
||||||
|
|
||||||
|
ToolExecutionResult {
|
||||||
|
tool_call_id,
|
||||||
|
success: true,
|
||||||
|
result: result_str,
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg = format!("Execution error: {}", e);
|
||||||
|
Self::log_tool_error(bot_name, tool_name, &error_msg);
|
||||||
|
let user_message = Self::format_user_friendly_error(tool_name, &error_msg);
|
||||||
|
ToolExecutionResult {
|
||||||
|
tool_call_id,
|
||||||
|
success: false,
|
||||||
|
result: String::new(),
|
||||||
|
error: Some(user_message),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the bot_id from bot_name
|
||||||
|
fn get_bot_id(state: &Arc<AppState>, bot_name: &str) -> Option<Uuid> {
|
||||||
|
let mut conn = state.conn.get().ok()?;
|
||||||
|
bots::table
|
||||||
|
.filter(bots::name.eq(bot_name))
|
||||||
|
.select(bots::id)
|
||||||
|
.first(&mut *conn)
|
||||||
|
.ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the path to a tool's .bas file
|
||||||
|
fn get_tool_bas_path(bot_name: &str, tool_name: &str) -> std::path::PathBuf {
|
||||||
|
let home_dir = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
|
||||||
|
|
||||||
|
// Try data directory first
|
||||||
|
let data_path = Path::new(&home_dir)
|
||||||
|
.join("data")
|
||||||
|
.join(format!("{}.gbai", bot_name))
|
||||||
|
.join(format!("{}.gbdialog", bot_name))
|
||||||
|
.join(format!("{}.bas", tool_name));
|
||||||
|
|
||||||
|
if data_path.exists() {
|
||||||
|
return data_path;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try work directory (for development/testing)
|
||||||
|
let work_path = Path::new(&home_dir)
|
||||||
|
.join("gb")
|
||||||
|
.join("work")
|
||||||
|
.join(format!("{}.gbai", bot_name))
|
||||||
|
.join(format!("{}.gbdialog", bot_name))
|
||||||
|
.join(format!("{}.bas", tool_name));
|
||||||
|
|
||||||
|
work_path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_tool_call_glm_format() {
|
||||||
|
let chunk = r#"{"type":"tool_call","content":{"id":"call_123","type":"function","function":{"name":"test_tool","arguments":"{\"param1\":\"value1\"}"}}}"#;
|
||||||
|
|
||||||
|
let result = ToolExecutor::parse_tool_call(chunk);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let tool_call = result.unwrap();
|
||||||
|
assert_eq!(tool_call.tool_name, "test_tool");
|
||||||
|
assert_eq!(tool_call.arguments["param1"], "value1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_tool_call_openai_format() {
|
||||||
|
let chunk = r#"{"id":"call_123","type":"function","function":{"name":"test_tool","arguments":"{\"param1\":\"value1\"}"}}"#;
|
||||||
|
|
||||||
|
let result = ToolExecutor::parse_tool_call(chunk);
|
||||||
|
assert!(result.is_some());
|
||||||
|
|
||||||
|
let tool_call = result.unwrap();
|
||||||
|
assert_eq!(tool_call.tool_name, "test_tool");
|
||||||
|
assert_eq!(tool_call.arguments["param1"], "value1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use crate::basic::compiler::BasicCompiler;
|
use crate::basic::compiler::BasicCompiler;
|
||||||
use crate::shared::state::AppState;
|
use crate::shared::state::AppState;
|
||||||
|
use diesel::prelude::*;
|
||||||
use log::{debug, error, info, warn};
|
use log::{debug, error, info, warn};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
|
@ -11,6 +12,7 @@ use tokio::sync::RwLock;
|
||||||
use tokio::time::Duration;
|
use tokio::time::Duration;
|
||||||
use notify::{RecursiveMode, EventKind, RecommendedWatcher, Watcher};
|
use notify::{RecursiveMode, EventKind, RecommendedWatcher, Watcher};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct LocalFileState {
|
struct LocalFileState {
|
||||||
|
|
@ -184,8 +186,8 @@ impl LocalFileMonitor {
|
||||||
.and_then(|s| s.to_str())
|
.and_then(|s| s.to_str())
|
||||||
.unwrap_or("unknown");
|
.unwrap_or("unknown");
|
||||||
|
|
||||||
// Look for .gbdialog folder inside
|
// Look for <botname>.gbdialog folder inside (e.g., cristo.gbai/cristo.gbdialog)
|
||||||
let gbdialog_path = path.join(".gbdialog");
|
let gbdialog_path = path.join(format!("{}.gbdialog", bot_name));
|
||||||
if gbdialog_path.exists() {
|
if gbdialog_path.exists() {
|
||||||
self.compile_gbdialog(&bot_name, &gbdialog_path).await?;
|
self.compile_gbdialog(&bot_name, &gbdialog_path).await?;
|
||||||
}
|
}
|
||||||
|
|
@ -264,7 +266,19 @@ impl LocalFileMonitor {
|
||||||
let work_dir_clone = work_dir.clone();
|
let work_dir_clone = work_dir.clone();
|
||||||
let tool_name_clone = tool_name.to_string();
|
let tool_name_clone = tool_name.to_string();
|
||||||
let source_content_clone = source_content.clone();
|
let source_content_clone = source_content.clone();
|
||||||
let bot_id = uuid::Uuid::new_v4(); // Generate a bot ID or get from somewhere
|
let bot_name_clone = bot_name.to_string();
|
||||||
|
|
||||||
|
// Get the actual bot_id from the database for this bot_name
|
||||||
|
let bot_id = {
|
||||||
|
use crate::core::shared::models::schema::bots::dsl::*;
|
||||||
|
let mut conn = state_clone.conn.get()
|
||||||
|
.map_err(|e| format!("Failed to get DB connection: {}", e))?;
|
||||||
|
|
||||||
|
bots.filter(name.eq(&bot_name_clone))
|
||||||
|
.select(id)
|
||||||
|
.first::<Uuid>(&mut *conn)
|
||||||
|
.map_err(|e| format!("Failed to get bot_id for '{}': {}", bot_name_clone, e))?
|
||||||
|
};
|
||||||
|
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
std::fs::create_dir_all(&work_dir_clone)?;
|
std::fs::create_dir_all(&work_dir_clone)?;
|
||||||
|
|
@ -274,8 +288,9 @@ impl LocalFileMonitor {
|
||||||
let result = compiler.compile_file(local_source_path.to_str().unwrap(), work_dir_clone.to_str().unwrap())?;
|
let result = compiler.compile_file(local_source_path.to_str().unwrap(), work_dir_clone.to_str().unwrap())?;
|
||||||
if let Some(mcp_tool) = result.mcp_tool {
|
if let Some(mcp_tool) = result.mcp_tool {
|
||||||
info!(
|
info!(
|
||||||
"[LOCAL_MONITOR] MCP tool generated with {} parameters",
|
"[LOCAL_MONITOR] MCP tool generated with {} parameters for bot {}",
|
||||||
mcp_tool.input_schema.properties.len()
|
mcp_tool.input_schema.properties.len(),
|
||||||
|
bot_name_clone
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Ok::<(), Box<dyn Error + Send + Sync>>(())
|
Ok::<(), Box<dyn Error + Send + Sync>>(())
|
||||||
|
|
|
||||||
|
|
@ -11,11 +11,13 @@ pub mod episodic_memory;
|
||||||
pub mod glm;
|
pub mod glm;
|
||||||
pub mod llm_models;
|
pub mod llm_models;
|
||||||
pub mod local;
|
pub mod local;
|
||||||
|
pub mod rate_limiter;
|
||||||
pub mod smart_router;
|
pub mod smart_router;
|
||||||
|
|
||||||
pub use claude::ClaudeClient;
|
pub use claude::ClaudeClient;
|
||||||
pub use glm::GLMClient;
|
pub use glm::GLMClient;
|
||||||
pub use llm_models::get_handler;
|
pub use llm_models::get_handler;
|
||||||
|
pub use rate_limiter::{ApiRateLimiter, RateLimits};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait LLMProvider: Send + Sync {
|
pub trait LLMProvider: Send + Sync {
|
||||||
|
|
@ -48,6 +50,7 @@ pub struct OpenAIClient {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
endpoint_path: String,
|
endpoint_path: String,
|
||||||
|
rate_limiter: Arc<ApiRateLimiter>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIClient {
|
impl OpenAIClient {
|
||||||
|
|
@ -164,10 +167,21 @@ impl OpenAIClient {
|
||||||
"/v1/chat/completions".to_string()
|
"/v1/chat/completions".to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Detect API provider and set appropriate rate limits
|
||||||
|
let rate_limiter = if base.contains("groq.com") {
|
||||||
|
ApiRateLimiter::new(RateLimits::groq_free_tier())
|
||||||
|
} else if base.contains("openai.com") {
|
||||||
|
ApiRateLimiter::new(RateLimits::openai_free_tier())
|
||||||
|
} else {
|
||||||
|
// Default to unlimited for other providers (local models, etc.)
|
||||||
|
ApiRateLimiter::unlimited()
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
base_url: base,
|
base_url: base,
|
||||||
endpoint_path: endpoint,
|
endpoint_path: endpoint,
|
||||||
|
rate_limiter: Arc::new(rate_limiter),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -315,6 +329,13 @@ impl LLMProvider for OpenAIClient {
|
||||||
|
|
||||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||||
|
|
||||||
|
// Check rate limits before making the request
|
||||||
|
let estimated_tokens = OpenAIClient::estimate_messages_tokens(&messages);
|
||||||
|
if let Err(e) = self.rate_limiter.acquire(estimated_tokens).await {
|
||||||
|
error!("Rate limit exceeded: {}", e);
|
||||||
|
return Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>);
|
||||||
|
}
|
||||||
|
|
||||||
let full_url = format!("{}{}", self.base_url, self.endpoint_path);
|
let full_url = format!("{}{}", self.base_url, self.endpoint_path);
|
||||||
let auth_header = format!("Bearer {}", key);
|
let auth_header = format!("Bearer {}", key);
|
||||||
|
|
||||||
|
|
|
||||||
239
src/llm/rate_limiter.rs
Normal file
239
src/llm/rate_limiter.rs
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
// Rate limiter for LLM API calls
|
||||||
|
use governor::{
|
||||||
|
clock::DefaultClock,
|
||||||
|
state::{InMemoryState, NotKeyed},
|
||||||
|
Quota, RateLimiter,
|
||||||
|
};
|
||||||
|
use std::num::NonZeroU32;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
|
/// Rate limits for an API provider
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct RateLimits {
|
||||||
|
pub requests_per_minute: u32,
|
||||||
|
pub tokens_per_minute: u32,
|
||||||
|
pub requests_per_day: u32,
|
||||||
|
pub tokens_per_day: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimits {
|
||||||
|
/// Groq free tier rate limits
|
||||||
|
pub const fn groq_free_tier() -> Self {
|
||||||
|
Self {
|
||||||
|
requests_per_minute: 30,
|
||||||
|
tokens_per_minute: 8_000,
|
||||||
|
requests_per_day: 1_000,
|
||||||
|
tokens_per_day: 200_000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OpenAI free tier rate limits
|
||||||
|
pub const fn openai_free_tier() -> Self {
|
||||||
|
Self {
|
||||||
|
requests_per_minute: 3,
|
||||||
|
tokens_per_minute: 40_000,
|
||||||
|
requests_per_day: 200,
|
||||||
|
tokens_per_day: 150_000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// No rate limiting (for local models)
|
||||||
|
pub const fn unlimited() -> Self {
|
||||||
|
Self {
|
||||||
|
requests_per_minute: u32::MAX,
|
||||||
|
tokens_per_minute: u32::MAX,
|
||||||
|
requests_per_day: u32::MAX,
|
||||||
|
tokens_per_day: u32::MAX,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A rate limiter for API requests
|
||||||
|
pub struct ApiRateLimiter {
|
||||||
|
requests_per_minute: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
|
||||||
|
tokens_per_minute: Arc<Semaphore>,
|
||||||
|
// Track daily request count with a simple counter and reset time
|
||||||
|
daily_request_count: Arc<std::sync::atomic::AtomicU32>,
|
||||||
|
daily_request_reset: Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
daily_token_count: Arc<std::sync::atomic::AtomicU32>,
|
||||||
|
daily_token_reset: Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
requests_per_day: u32,
|
||||||
|
tokens_per_day: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for ApiRateLimiter {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("ApiRateLimiter")
|
||||||
|
.field("requests_per_minute", &self.requests_per_minute)
|
||||||
|
.field("tokens_per_minute", &"Semaphore")
|
||||||
|
.field("daily_request_count", &self.daily_request_count)
|
||||||
|
.field("daily_token_count", &self.daily_token_count)
|
||||||
|
.field("requests_per_day", &self.requests_per_day)
|
||||||
|
.field("tokens_per_day", &self.tokens_per_day)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for ApiRateLimiter {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
requests_per_minute: Arc::clone(&self.requests_per_minute),
|
||||||
|
tokens_per_minute: Arc::clone(&self.tokens_per_minute),
|
||||||
|
daily_request_count: Arc::clone(&self.daily_request_count),
|
||||||
|
daily_request_reset: Arc::clone(&self.daily_request_reset),
|
||||||
|
daily_token_count: Arc::clone(&self.daily_token_count),
|
||||||
|
daily_token_reset: Arc::clone(&self.daily_token_reset),
|
||||||
|
requests_per_day: self.requests_per_day,
|
||||||
|
tokens_per_day: self.tokens_per_day,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApiRateLimiter {
|
||||||
|
/// Create a new rate limiter with the specified limits
|
||||||
|
pub fn new(limits: RateLimits) -> Self {
|
||||||
|
// Requests per minute limiter
|
||||||
|
let rpm_quota = Quota::per_minute(NonZeroU32::new(limits.requests_per_minute).unwrap());
|
||||||
|
let requests_per_minute = Arc::new(RateLimiter::direct(rpm_quota));
|
||||||
|
|
||||||
|
// Tokens per minute (using semaphore as we need to track token count)
|
||||||
|
let tokens_per_minute = Arc::new(Semaphore::new(
|
||||||
|
limits.tokens_per_minute.try_into().unwrap_or(usize::MAX)
|
||||||
|
));
|
||||||
|
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
let tomorrow = now + 86400;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
requests_per_minute,
|
||||||
|
tokens_per_minute,
|
||||||
|
daily_request_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
|
||||||
|
daily_request_reset: Arc::new(std::sync::atomic::AtomicU64::new(tomorrow)),
|
||||||
|
daily_token_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
|
||||||
|
daily_token_reset: Arc::new(std::sync::atomic::AtomicU64::new(tomorrow)),
|
||||||
|
requests_per_day: limits.requests_per_day,
|
||||||
|
tokens_per_day: limits.tokens_per_day,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an unlimited rate limiter (for local models)
|
||||||
|
pub fn unlimited() -> Self {
|
||||||
|
Self::new(RateLimits::unlimited())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if daily limits need resetting and reset if needed
|
||||||
|
fn check_and_reset_daily(&self) {
|
||||||
|
let now = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs();
|
||||||
|
let reset_time = self.daily_request_reset.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
|
||||||
|
if now >= reset_time {
|
||||||
|
// Reset counters
|
||||||
|
self.daily_request_count.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
self.daily_token_count.store(0, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
|
||||||
|
// Set new reset time to tomorrow
|
||||||
|
let tomorrow = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() + 86400;
|
||||||
|
self.daily_request_reset.store(tomorrow, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
self.daily_token_reset.store(tomorrow, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Acquire permission for a request with estimated token count
|
||||||
|
/// Returns when the request can proceed
|
||||||
|
pub async fn acquire(&self, estimated_tokens: usize) -> Result<(), RateLimitError> {
|
||||||
|
// Check and reset daily limits if needed
|
||||||
|
self.check_and_reset_daily();
|
||||||
|
|
||||||
|
// Check request rate limits
|
||||||
|
self.requests_per_minute.until_ready().await;
|
||||||
|
|
||||||
|
// Check daily request limit
|
||||||
|
let current_requests = self.daily_request_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if current_requests >= self.requests_per_day {
|
||||||
|
self.daily_request_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
return Err(RateLimitError::DailyRateLimitExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check token rate limits
|
||||||
|
let tokens_to_acquire = (estimated_tokens.min(8_000) as u32) as usize;
|
||||||
|
|
||||||
|
// Try to acquire token permits for minute limit
|
||||||
|
let tpm_available = self.tokens_per_minute.available_permits();
|
||||||
|
if tpm_available < tokens_to_acquire {
|
||||||
|
return Err(RateLimitError::TokenRateLimitExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check daily token limit
|
||||||
|
let current_tokens = self.daily_token_count.fetch_add(tokens_to_acquire as u32, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if current_tokens + (tokens_to_acquire as u32) > self.tokens_per_day {
|
||||||
|
self.daily_token_count.fetch_sub(tokens_to_acquire as u32, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
return Err(RateLimitError::DailyTokenLimitExceeded);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acquire the permits (this will wait if needed)
|
||||||
|
let semaphore = Arc::clone(&self.tokens_per_minute);
|
||||||
|
let _permits = semaphore.acquire_many_owned(tokens_to_acquire as u32).await;
|
||||||
|
// Permits are held until the request completes
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Release token permits after request completes
|
||||||
|
pub fn release_tokens(&self, _tokens: u32) {
|
||||||
|
// Note: We don't release the daily token count as it's already "used"
|
||||||
|
// But we do need to release the semaphore permits for the minute limit
|
||||||
|
// The permits will be automatically released when dropped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum RateLimitError {
|
||||||
|
RateLimitExceeded,
|
||||||
|
DailyRateLimitExceeded,
|
||||||
|
TokenRateLimitExceeded,
|
||||||
|
DailyTokenLimitExceeded,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for RateLimitError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
RateLimitError::RateLimitExceeded => write!(f, "Rate limit exceeded"),
|
||||||
|
RateLimitError::DailyRateLimitExceeded => write!(f, "Daily request limit exceeded"),
|
||||||
|
RateLimitError::TokenRateLimitExceeded => write!(f, "Token per minute limit exceeded"),
|
||||||
|
RateLimitError::DailyTokenLimitExceeded => write!(f, "Daily token limit exceeded"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for RateLimitError {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_limits_display() {
|
||||||
|
let limits = RateLimits::groq_free_tier();
|
||||||
|
assert_eq!(limits.requests_per_minute, 30);
|
||||||
|
assert_eq!(limits.tokens_per_minute, 8_000);
|
||||||
|
assert_eq!(limits.requests_per_day, 1_000);
|
||||||
|
assert_eq!(limits.tokens_per_day, 200_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unlimited_limits() {
|
||||||
|
let limits = RateLimits::unlimited();
|
||||||
|
assert_eq!(limits.requests_per_minute, u32::MAX);
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue