gbserver/scripts/dev/llm_context.txt


Consolidated LLM Context
* Prefer imports over calling methods through `::` paths.
* Output a single `.sh` script using `cat` so it can be restored directly.
* No placeholders, only real, production-ready code.
* No comments, no explanations, no extra text.
* Follow KISS principles.
* Provide a complete, professional, working solution.
* If the script is too long, split into multiple parts, but always return the **entire code**.
* Output must be **only the code**, nothing else.
[package]
name = "gbserver"
version = "0.1.0"
edition = "2021"
authors = ["Rodrigo Rodriguez <me@rodrigorodriguez.com>"]
description = "General Bots Server"
license = "AGPL-3.0"
repository = "https://alm.pragmatismo.com.br/generalbots/gbserver"
[features]
default = ["postgres", "qdrant"]
local_llm = []
postgres = ["sqlx/postgres"]
qdrant = ["langchain-rust/qdrant"]
[dependencies]
actix-cors = "0.7"
actix-multipart = "0.7"
actix-web = "4.9"
actix-ws = "0.3"
anyhow = "1.0"
async-stream = "0.3"
async-trait = "0.1"
aes-gcm = "0.10"
argon2 = "0.5"
base64 = "0.22"
bytes = "1.8"
chrono = { version = "0.4", features = ["serde"] }
dotenv = "0.15"
downloader = "0.2"
env_logger = "0.11"
futures = "0.3"
futures-util = "0.3"
imap = "2.4"
langchain-rust = { version = "4.6", features = ["qdrant", "postgres"] }
lettre = { version = "0.11", features = ["smtp-transport", "builder", "tokio1", "tokio1-native-tls"] }
livekit = "0.7"
log = "0.4"
mailparse = "0.15"
minio = { git = "https://github.com/minio/minio-rs", branch = "master" }
native-tls = "0.2"
num-format = "0.4"
qdrant-client = "1.12"
rhai = "1.22"
redis = { version = "0.27", features = ["tokio-comp"] }
regex = "1.11"
reqwest = { version = "0.12", features = ["json", "stream"] }
scraper = "0.20"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
smartstring = "1.0"
sqlx = { version = "0.8", features = ["time", "uuid", "runtime-tokio-rustls", "postgres", "chrono"] }
tempfile = "3"
thirtyfour = "0.34"
tokio = { version = "1.41", features = ["full"] }
tokio-stream = "0.1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["fmt"] }
urlencoding = "2.1"
uuid = { version = "1.11", features = ["serde", "v4"] }
zip = "2.2"
You are fixing Rust code in a Cargo project. The user is providing problematic code that needs to be corrected.
## Your Task
Fix ALL compiler errors and logical issues while maintaining the original intent. Return the COMPLETE corrected files as a SINGLE .sh script that can be executed from the project root.
Use Cargo.toml as a reference; do not change it.
Only return the input files; all other files already exist.
If something needs to be added to an external file, mention it separately.
## Critical Requirements
1. **Return as SINGLE .sh script** - Output must be a complete shell script using the `cat > file << 'EOF'` pattern
2. **Include ALL files** - Every corrected file must be included in the script
3. **Respect Cargo.toml** - Check dependencies, editions, and features to avoid compiler errors
4. **Type safety** - Ensure all types match and trait bounds are satisfied
5. **Ownership rules** - Fix borrowing, ownership, and lifetime issues
## Output Format Requirements
You MUST return exactly this example format:
```sh
#!/bin/bash
# Restore fixed Rust project
cat > src/<filenamehere>.rs << 'EOF'
use std::io;
// test
EOF
cat > src/<anotherfile>.rs << 'EOF'
// Fixed library code
pub fn add(a: i32, b: i32) -> i32 {
    a + b
}
EOF
```
----
use std::sync::Arc;
use minio::s3::Client;
use crate::channels::{VoiceAdapter, WebChannelAdapter};
use crate::whatsapp::WhatsAppAdapter;
use crate::{config::AppConfig, web_automator::BrowserPool};
#[derive(Clone)]
pub struct AppState {
pub minio_client: Option<Client>,
pub config: Option<AppConfig>,
pub db: Option<sqlx::PgPool>,
pub db_custom: Option<sqlx::PgPool>,
pub browser_pool: Arc<BrowserPool>,
pub orchestrator: Arc<BotOrchestrator>,
pub web_adapter: Arc<WebChannelAdapter>,
pub voice_adapter: Arc<VoiceAdapter>,
pub whatsapp_adapter: Arc<WhatsAppAdapter>,
pub tool_api: Arc<ToolApi>,
}
pub struct _BotState {
pub language: String,
pub work_folder: String,
}
use crate::config::AIConfig;
use langchain_rust::llm::OpenAI;
use langchain_rust::{language_models::llm::LLM, llm::AzureConfig};
use log::error;
use log::{debug, warn};
use rhai::{Array, Dynamic};
use serde_json::{json, Value};
use smartstring::SmartString;
use sqlx::Column; // Provides Column::name() and Column::type_info()
use sqlx::TypeInfo; // Provides TypeInfo::name() on the column's type info
use sqlx::{postgres::PgRow, Row};
use sqlx::{Decode, Type};
use std::error::Error;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use tokio::fs::File as TokioFile;
use tokio_stream::StreamExt;
use zip::ZipArchive;
use reqwest::Client;
use tokio::io::AsyncWriteExt;
pub fn azure_from_config(config: &AIConfig) -> AzureConfig {
AzureConfig::new()
.with_api_base(&config.endpoint)
.with_api_key(&config.key)
.with_api_version(&config.version)
.with_deployment_id(&config.instance)
}
pub async fn call_llm(
text: &str,
ai_config: &AIConfig,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let azure_config = azure_from_config(ai_config);
let open_ai = OpenAI::new(azure_config);
// Directly use the input text as prompt
let prompt = text.to_string();
// Call LLM and return the raw text response
match open_ai.invoke(&prompt).await {
Ok(response_text) => Ok(response_text),
Err(err) => {
error!("Error invoking LLM API: {}", err);
Err(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to invoke LLM API",
)))
}
}
}
pub fn extract_zip_recursive(
zip_path: &Path,
destination_path: &Path,
) -> Result<(), Box<dyn std::error::Error>> {
let file = File::open(zip_path)?;
let buf_reader = BufReader::new(file);
let mut archive = ZipArchive::new(buf_reader)?;
for i in 0..archive.len() {
let mut file = archive.by_index(i)?;
let outpath = destination_path.join(file.mangled_name());
if file.is_dir() {
std::fs::create_dir_all(&outpath)?;
} else {
if let Some(parent) = outpath.parent() {
if !parent.exists() {
std::fs::create_dir_all(&parent)?;
}
}
let mut outfile = File::create(&outpath)?;
std::io::copy(&mut file, &mut outfile)?;
}
}
Ok(())
}
pub fn row_to_json(row: PgRow) -> Result<Value, Box<dyn Error>> {
let mut result = serde_json::Map::new();
let columns = row.columns();
debug!("Converting row with {} columns", columns.len());
for (i, column) in columns.iter().enumerate() {
let column_name = column.name();
let type_name = column.type_info().name();
let value = match type_name {
"INT4" | "int4" => handle_nullable_type::<i32>(&row, i, column_name),
"INT8" | "int8" => handle_nullable_type::<i64>(&row, i, column_name),
"FLOAT4" | "float4" => handle_nullable_type::<f32>(&row, i, column_name),
"FLOAT8" | "float8" => handle_nullable_type::<f64>(&row, i, column_name),
"TEXT" | "VARCHAR" | "text" | "varchar" => {
handle_nullable_type::<String>(&row, i, column_name)
}
"BOOL" | "bool" => handle_nullable_type::<bool>(&row, i, column_name),
"JSON" | "JSONB" | "json" | "jsonb" => handle_json(&row, i, column_name),
_ => {
warn!("Unknown type {} for column {}", type_name, column_name);
handle_nullable_type::<String>(&row, i, column_name)
}
};
result.insert(column_name.to_string(), value);
}
Ok(Value::Object(result))
}
fn handle_nullable_type<'r, T>(row: &'r PgRow, idx: usize, col_name: &str) -> Value
where
T: Type<sqlx::Postgres> + Decode<'r, sqlx::Postgres> + serde::Serialize + std::fmt::Debug,
{
match row.try_get::<Option<T>, _>(idx) {
Ok(Some(val)) => {
debug!("Successfully read column {} as {:?}", col_name, val);
json!(val)
}
Ok(None) => {
debug!("Column {} is NULL", col_name);
Value::Null
}
Err(e) => {
warn!("Failed to read column {}: {}", col_name, e);
Value::Null
}
}
}
fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value {
// First try to get as Option<Value>
match row.try_get::<Option<Value>, _>(idx) {
Ok(Some(val)) => {
debug!("Successfully read JSON column {} as Value", col_name);
return val;
}
Ok(None) => return Value::Null,
Err(_) => (), // Fall through to other attempts
}
// Try as Option<String> that might contain JSON
match row.try_get::<Option<String>, _>(idx) {
Ok(Some(s)) => match serde_json::from_str(&s) {
Ok(val) => val,
Err(_) => {
debug!("Column {} contains string that's not JSON", col_name);
json!(s)
}
},
Ok(None) => Value::Null,
Err(e) => {
warn!("Failed to read JSON column {}: {}", col_name, e);
Value::Null
}
}
}
pub fn json_value_to_dynamic(value: &Value) -> Dynamic {
match value {
Value::Null => Dynamic::UNIT,
Value::Bool(b) => Dynamic::from(*b),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
Dynamic::from(i)
} else if let Some(f) = n.as_f64() {
Dynamic::from(f)
} else {
Dynamic::UNIT
}
}
Value::String(s) => Dynamic::from(s.clone()),
Value::Array(arr) => Dynamic::from(
arr.iter()
.map(json_value_to_dynamic)
.collect::<rhai::Array>(),
),
Value::Object(obj) => Dynamic::from(
obj.iter()
.map(|(k, v)| (SmartString::from(k), json_value_to_dynamic(v)))
.collect::<rhai::Map>(),
),
}
}
/// Converts any value to an array - single values become single-element arrays
pub fn to_array(value: Dynamic) -> Array {
if value.is_array() {
// Already an array - return as-is
value.cast::<Array>()
} else if value.is_unit() {
// Handle empty/unit case
Array::new()
} else {
// Convert single value to single-element array
Array::from([value])
}
}
pub async fn download_file(url: &str, output_path: &str) -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new();
let response = client.get(url).send().await?;
if response.status().is_success() {
let mut file = TokioFile::create(output_path).await?;
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
file.write_all(&chunk?).await?;
}
debug!("File downloaded successfully to {}", output_path);
} else {
return Err(format!("Failed to download file: HTTP {}", response.status()).into());
}
Ok(())
}
// Helper function to parse the filter string into SQL WHERE clause and parameters
pub fn parse_filter(filter_str: &str) -> Result<(String, Vec<String>), Box<dyn Error>> {
let parts: Vec<&str> = filter_str.split('=').collect();
if parts.len() != 2 {
return Err("Invalid filter format. Expected 'KEY=VALUE'".into());
}
let column = parts[0].trim();
let value = parts[1].trim();
// Validate column name to prevent SQL injection
if !column
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_')
{
return Err("Invalid column name in filter".into());
}
// Return the parameterized query part and the value separately
Ok((format!("{} = $1", column), vec![value.to_string()]))
}
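// Parse KEY=VALUE conditions joined by '&' into a WHERE clause whose placeholders start at $(offset + 1); values are returned unquoted for parameter binding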
pub fn parse_filter_with_offset(
filter_str: &str,
offset: usize,
) -> Result<(String, Vec<String>), Box<dyn Error>> {
let mut clauses = Vec::new();
let mut params = Vec::new();
for (i, condition) in filter_str.split('&').enumerate() {
let parts: Vec<&str> = condition.split('=').collect();
if parts.len() != 2 {
return Err("Invalid filter format".into());
}
let column = parts[0].trim();
let value = parts[1].trim();
if !column
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_')
{
return Err("Invalid column name".into());
}
clauses.push(format!("{} = ${}", column, i + 1 + offset));
params.push(value.to_string()); // Store raw value without quotes
}
Ok((clauses.join(" AND "), params))
}
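The placeholder numbering in `parse_filter_with_offset` is easiest to see with a concrete filter. The sketch below is a std-only copy of that function (it returns `String` errors instead of `Box<dyn Error>` so it can stand alone); `offset` is the number of placeholders already used earlier in the statement, so the first filter value is bound to `$(offset + 1)`.

```rust
/// Std-only copy of `parse_filter_with_offset`, used only to illustrate
/// how placeholder indices are assigned. Errors are plain Strings here.
fn parse_filter_with_offset(filter_str: &str, offset: usize) -> Result<(String, Vec<String>), String> {
    let mut clauses = Vec::new();
    let mut params = Vec::new();
    for (i, condition) in filter_str.split('&').enumerate() {
        // Each condition must be exactly KEY=VALUE.
        let parts: Vec<&str> = condition.split('=').collect();
        if parts.len() != 2 {
            return Err("Invalid filter format".to_string());
        }
        let column = parts[0].trim();
        let value = parts[1].trim();
        // Only ASCII letters, digits and '_' are accepted, so the column
        // name cannot carry SQL syntax into the generated clause.
        if !column.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
            return Err("Invalid column name".to_string());
        }
        // Placeholder index continues after the `offset` parameters already bound.
        clauses.push(format!("{} = ${}", column, i + 1 + offset));
        params.push(value.to_string());
    }
    Ok((clauses.join(" AND "), params))
}

fn main() {
    // Hypothetical case: `$1` is already used elsewhere in the statement.
    let (clause, params) = parse_filter_with_offset("name=alice&age=30", 1).unwrap();
    println!("{}", clause); // name = $2 AND age = $3
    println!("{:?}", params); // ["alice", "30"]
    assert!(parse_filter_with_offset("name;drop=1", 0).is_err());
}
```

Because values never enter the SQL text, they can be bound in order with the driver's parameter API after any values occupying the first `offset` slots.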
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct organization {
pub org_id: Uuid,
pub name: String,
pub slug: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Bot {
pub bot_id: Uuid,
pub name: String,
pub status: BotStatus,
pub config: serde_json::Value,
pub created_at: chrono::DateTime<Utc>,
pub updated_at: chrono::DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)]
#[serde(rename_all = "snake_case")]
#[sqlx(type_name = "bot_status", rename_all = "snake_case")]
pub enum BotStatus {
Active,
Inactive,
Maintenance,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TriggerKind {
Scheduled = 0,
TableUpdate = 1,
TableInsert = 2,
TableDelete = 3,
}
impl TriggerKind {
pub fn from_i32(value: i32) -> Option<Self> {
match value {
0 => Some(Self::Scheduled),
1 => Some(Self::TableUpdate),
2 => Some(Self::TableInsert),
3 => Some(Self::TableDelete),
_ => None,
}
}
}
#[derive(Debug, FromRow, Serialize, Deserialize)]
pub struct Automation {
pub id: Uuid,
pub kind: i32, // Using number for trigger type
pub target: Option<String>,
pub schedule: Option<String>,
pub param: String,
pub is_active: bool,
pub last_triggered: Option<DateTime<Utc>>,
}
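`TriggerKind::from_i32` presumably turns the numeric `kind` stored on an `Automation` row back into an enum. If conversion through the standard `TryFrom` trait is preferred, a std-only sketch could look like this; the `TryFrom<i32>` impl and its `i32` error type are assumptions, not part of the existing code.

```rust
use std::convert::TryFrom;

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TriggerKind {
    Scheduled = 0,
    TableUpdate = 1,
    TableInsert = 2,
    TableDelete = 3,
}

// Hypothetical TryFrom impl mirroring `TriggerKind::from_i32`;
// an unknown value is handed back as the error.
impl TryFrom<i32> for TriggerKind {
    type Error = i32;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Scheduled),
            1 => Ok(Self::TableUpdate),
            2 => Ok(Self::TableInsert),
            3 => Ok(Self::TableDelete),
            other => Err(other),
        }
    }
}

fn main() {
    // Each discriminant cast to i32 converts back to the same variant.
    let all = [
        TriggerKind::Scheduled,
        TriggerKind::TableUpdate,
        TriggerKind::TableInsert,
        TriggerKind::TableDelete,
    ];
    for kind in all {
        assert_eq!(TriggerKind::try_from(kind as i32), Ok(kind));
    }
    assert_eq!(TriggerKind::try_from(7), Err(7));
}
```

Returning the rejected integer as the error keeps the offending `kind` value available for logging.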
pub mod models;
pub mod state;
pub mod utils;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct UserSession {
pub id: Uuid,
pub user_id: Uuid,
pub bot_id: Uuid,
pub title: String,
pub context_data: serde_json::Value,
pub answer_mode: String,
pub current_tool: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub text: String,
pub model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
pub embedding: Vec<f32>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub text: String,
pub similarity: f32,
pub metadata: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMessage {
pub bot_id: String,
pub user_id: String,
pub session_id: String,
pub channel: String,
pub content: String,
pub message_type: String,
pub media_url: Option<String>,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BotResponse {
pub bot_id: String,
pub user_id: String,
pub session_id: String,
pub channel: String,
pub content: String,
pub message_type: String,
pub stream_token: Option<String>,
pub is_complete: bool,
}
use crate::{shared::state::AppState, whatsapp};
use actix_web::{web, HttpRequest, HttpResponse, Result};
use actix_ws::Message as WsMessage;
use chrono::Utc;
use langchain_rust::{
chain::{Chain, LLMChain},
llm::openai::OpenAI,
memory::SimpleMemory,
prompt_args,
tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder},
vectorstore::qdrant::Qdrant as LangChainQdrant,
vectorstore::{VecStoreOptions, VectorStore},
};
use log::info;
use serde_json;
use std::collections::HashMap;
use std::fs;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use uuid::Uuid;
use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter};
use crate::chart::ChartGenerator;
use crate::llm::LLMProvider;
use crate::session::SessionManager;
use crate::tools::ToolManager;
use crate::whatsapp::WhatsAppAdapter;
use crate::{
auth::AuthService,
shared::{BotResponse, UserMessage, UserSession},
};
pub struct BotOrchestrator {
session_manager: SessionManager,
tool_manager: ToolManager,
llm_provider: Arc<dyn LLMProvider>,
auth_service: AuthService,
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
chart_generator: Option<Arc<ChartGenerator>>,
vector_store: Option<Arc<LangChainQdrant>>,
sql_chain: Option<Arc<LLMChain>>,
}
impl BotOrchestrator {
pub fn new(
session_manager: SessionManager,
tool_manager: ToolManager,
llm_provider: Arc<dyn LLMProvider>,
auth_service: AuthService,
chart_generator: Option<Arc<ChartGenerator>>,
vector_store: Option<Arc<LangChainQdrant>>,
sql_chain: Option<Arc<LLMChain>>,
) -> Self {
Self {
session_manager,
tool_manager,
llm_provider,
auth_service,
channels: HashMap::new(),
response_channels: Arc::new(Mutex::new(HashMap::new())),
chart_generator,
vector_store,
sql_chain,
}
}
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
self.channels.insert(channel_type.to_string(), adapter);
}
pub async fn register_response_channel(
&self,
session_id: String,
sender: mpsc::Sender<BotResponse>,
) {
self.response_channels
.lock()
.await
.insert(session_id, sender);
}
pub async fn set_user_answer_mode(
&self,
user_id: &str,
bot_id: &str,
mode: &str,
) -> Result<(), Box<dyn std::error::Error>> {
self.session_manager
.update_answer_mode(user_id, bot_id, mode)
.await?;
Ok(())
}
pub async fn process_message(
&self,
message: UserMessage,
) -> Result<(), Box<dyn std::error::Error>> {
info!(
"Processing message from channel: {}, user: {}",
message.channel, message.user_id
);
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
let bot_id = Uuid::parse_str(&message.bot_id)
.unwrap_or_else(|_| Uuid::nil());
let session = match self
.session_manager
.get_user_session(user_id, bot_id)
.await?
{
Some(session) => session,
None => {
self.session_manager
.create_session(user_id, bot_id, "New Conversation")
.await?
}
};
if session.answer_mode == "tool" && session.current_tool.is_some() {
self.tool_manager
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
.await?;
return Ok(());
}
self.session_manager
.save_message(
session.id,
user_id,
"user",
&message.content,
&message.message_type,
)
.await?;
let response_content = match session.answer_mode.as_str() {
"document" => self.document_mode_handler(&message, &session).await?,
"chart" => self.chart_mode_handler(&message, &session).await?,
"database" => self.database_mode_handler(&message, &session).await?,
"tool" => self.tool_mode_handler(&message, &session).await?,
_ => self.direct_mode_handler(&message, &session).await?,
};
self.session_manager
.save_message(session.id, user_id, "assistant", &response_content, "text")
.await?;
let bot_response = BotResponse {
bot_id: message.bot_id,
user_id: message.user_id,
session_id: message.session_id,
channel: message.channel.clone(),
content: response_content,
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
if let Some(adapter) = self.channels.get(&message.channel) {
adapter.send_message(bot_response).await?;
}
Ok(())
}
async fn document_mode_handler(
&self,
message: &UserMessage,
session: &UserSession,
) -> Result<String, Box<dyn std::error::Error>> {
if let Some(vector_store) = &self.vector_store {
let similar_docs = vector_store
.similarity_search(&message.content, 3, &VecStoreOptions::default())
.await?;
let mut enhanced_prompt = format!("User question: {}\n\n", message.content);
if !similar_docs.is_empty() {
enhanced_prompt.push_str("Relevant documents:\n");
for (i, doc) in similar_docs.iter().enumerate() {
enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content));
}
enhanced_prompt.push_str(
"\nPlease answer the user's question based on the provided documents.",
);
}
self.llm_provider
.generate(&enhanced_prompt, &serde_json::Value::Null)
.await
} else {
self.direct_mode_handler(message, session).await
}
}
async fn chart_mode_handler(
&self,
message: &UserMessage,
session: &UserSession,
) -> Result<String, Box<dyn std::error::Error>> {
if let Some(chart_generator) = &self.chart_generator {
let chart_response = chart_generator
.generate_chart(&message.content, "bar")
.await?;
self.session_manager
.save_message(
session.id,
session.user_id,
"system",
&format!("Generated chart for query: {}", message.content),
"chart",
)
.await?;
Ok(format!(
"Chart generated for your query. Data retrieved: {}",
chart_response.sql_query
))
} else {
self.document_mode_handler(message, session).await
}
}
async fn database_mode_handler(
&self,
message: &UserMessage,
_session: &UserSession,
) -> Result<String, Box<dyn std::error::Error>> {
if let Some(sql_chain) = &self.sql_chain {
let input_variables = prompt_args! {
"input" => message.content,
};
let result = sql_chain.invoke(input_variables).await?;
Ok(result.to_string())
} else {
let db_url = std::env::var("DATABASE_URL")?;
let engine = PostgreSQLEngine::new(&db_url).await?;
let db = SQLDatabaseBuilder::new(engine).build().await?;
let llm = OpenAI::default();
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
.llm(llm)
.top_k(5)
.database(db)
.build()?;
let input_variables = chain.prompt_builder().query(&message.content).build();
let result = chain.invoke(input_variables).await?;
Ok(result.to_string())
}
}
async fn direct_mode_handler(
&self,
message: &UserMessage,
session: &UserSession,
) -> Result<String, Box<dyn std::error::Error>> {
let history = self
.session_manager
.get_conversation_history(session.id, session.user_id)
.await?;
let mut prompt = String::new();
for (role, content) in history {
match role.as_str() {
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => {}
}
}
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
self.llm_provider
.generate(&prompt, &serde_json::Value::Null)
.await
}
pub async fn stream_response(
&self,
message: UserMessage,
response_tx: mpsc::Sender<BotResponse>,
) -> Result<(), Box<dyn std::error::Error>> {
info!("Streaming response for user: {}", message.user_id);
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
let bot_id = Uuid::parse_str(&message.bot_id)
.unwrap_or_else(|_| Uuid::nil());
let session = match self
.session_manager
.get_user_session(user_id, bot_id)
.await?
{
Some(session) => session,
None => {
self.session_manager
.create_session(user_id, bot_id, "New Conversation")
.await?
}
};
if session.answer_mode == "tool" && session.current_tool.is_some() {
self.tool_manager
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
.await?;
return Ok(());
}
self.session_manager
.save_message(
session.id,
user_id,
"user",
&message.content,
&message.message_type,
)
.await?;
let history = self
.session_manager
.get_conversation_history(session.id, user_id)
.await?;
let mut prompt = String::new();
for (role, content) in history {
match role.as_str() {
"user" => prompt.push_str(&format!("User: {}\n", content)),
"assistant" => prompt.push_str(&format!("Assistant: {}\n", content)),
_ => {}
}
}
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
let (stream_tx, mut stream_rx) = mpsc::channel(100);
let llm_provider = self.llm_provider.clone();
let prompt_clone = prompt.clone();
tokio::spawn(async move {
let _ = llm_provider
.generate_stream(&prompt_clone, &serde_json::Value::Null, stream_tx)
.await;
});
let mut full_response = String::new();
while let Some(chunk) = stream_rx.recv().await {
full_response.push_str(&chunk);
let bot_response = BotResponse {
bot_id: message.bot_id.clone(),
user_id: message.user_id.clone(),
session_id: message.session_id.clone(),
channel: message.channel.clone(),
content: chunk,
message_type: "text".to_string(),
stream_token: None,
is_complete: false,
};
if response_tx.send(bot_response).await.is_err() {
break;
}
}
self.session_manager
.save_message(session.id, user_id, "assistant", &full_response, "text")
.await?;
let final_response = BotResponse {
bot_id: message.bot_id,
user_id: message.user_id,
session_id: message.session_id,
channel: message.channel,
content: "".to_string(),
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
response_tx.send(final_response).await?;
Ok(())
}
pub async fn get_user_sessions(
&self,
user_id: Uuid,
) -> Result<Vec<UserSession>, Box<dyn std::error::Error>> {
self.session_manager.get_user_sessions(user_id).await
}
pub async fn get_conversation_history(
&self,
session_id: Uuid,
user_id: Uuid,
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
self.session_manager
.get_conversation_history(session_id, user_id)
.await
}
async fn tool_mode_handler(
&self,
message: &UserMessage,
_session: &UserSession,
) -> Result<String, Box<dyn std::error::Error>> {
if message.content.to_lowercase().contains("calculator") {
if let Some(_adapter) = self.channels.get(&message.channel) {
let (tx, _rx) = mpsc::channel(100);
self.register_response_channel(message.session_id.clone(), tx.clone())
.await;
let tool_manager = self.tool_manager.clone();
let user_id_str = message.user_id.clone();
let bot_id_str = message.bot_id.clone();
let session_manager = self.session_manager.clone();
tokio::spawn(async move {
let _ = tool_manager
.execute_tool_with_session(
"calculator",
&user_id_str,
&bot_id_str,
session_manager,
tx,
)
.await;
});
}
Ok("Starting calculator tool...".to_string())
} else {
let available_tools = self.tool_manager.list_tools();
let tools_context = if !available_tools.is_empty() {
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
} else {
String::new()
};
let full_prompt = format!("{}{}", message.content, tools_context);
self.llm_provider
.generate(&full_prompt, &serde_json::Value::Null)
.await
}
}
pub async fn process_message_with_tools(
&self,
message: UserMessage,
) -> Result<(), Box<dyn std::error::Error>> {
info!(
"Processing message with tools from user: {}",
message.user_id
);
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
let bot_id = Uuid::parse_str(&message.bot_id)
.unwrap_or_else(|_| Uuid::nil());
let session = match self
.session_manager
.get_user_session(user_id, bot_id)
.await?
{
Some(session) => session,
None => {
self.session_manager
.create_session(user_id, bot_id, "New Conversation")
.await?
}
};
self.session_manager
.save_message(
session.id,
user_id,
"user",
&message.content,
&message.message_type,
)
.await?;
let is_tool_waiting = self
.tool_manager
.is_tool_waiting(&message.session_id)
.await
.unwrap_or(false);
if is_tool_waiting {
self.tool_manager
.provide_input(&message.session_id, &message.content)
.await?;
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
for output in tool_output {
let bot_response = BotResponse {
bot_id: message.bot_id.clone(),
user_id: message.user_id.clone(),
session_id: message.session_id.clone(),
channel: message.channel.clone(),
content: output,
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
if let Some(adapter) = self.channels.get(&message.channel) {
adapter.send_message(bot_response).await?;
}
}
}
return Ok(());
}
let response = if message.content.to_lowercase().contains("calculator")
|| message.content.to_lowercase().contains("calculate")
|| message.content.to_lowercase().contains("math")
{
match self
.tool_manager
.execute_tool("calculator", &message.session_id, &message.user_id)
.await
{
Ok(tool_result) => {
self.session_manager
.save_message(
session.id,
user_id,
"assistant",
&tool_result.output,
"tool_start",
)
.await?;
tool_result.output
}
Err(e) => {
format!("I encountered an error starting the calculator: {}", e)
}
}
} else {
let available_tools = self.tool_manager.list_tools();
let tools_context = if !available_tools.is_empty() {
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
} else {
String::new()
};
let full_prompt = format!("{}{}", message.content, tools_context);
self.llm_provider
.generate(&full_prompt, &serde_json::Value::Null)
.await?
};
self.session_manager
.save_message(session.id, user_id, "assistant", &response, "text")
.await?;
let bot_response = BotResponse {
bot_id: message.bot_id,
user_id: message.user_id,
session_id: message.session_id,
channel: message.channel.clone(),
content: response,
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
if let Some(adapter) = self.channels.get(&message.channel) {
adapter.send_message(bot_response).await?;
}
Ok(())
}
}
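The tool routing in `process_message` is a case-insensitive substring match on three keywords. The std-only sketch below pulls that rule into a hypothetical `detect_tool` helper with a hypothetical `CALCULATOR_KEYWORDS` table. Neither exists in the codebase. Extracting the rule this way lets it be tested in isolation:

```rust
/// Keywords that send a message to the calculator tool (mirrors the inline check).
const CALCULATOR_KEYWORDS: [&str; 3] = ["calculator", "calculate", "math"];

/// Returns the tool to start for `content`, or `None` to fall through to the LLM.
pub fn detect_tool(content: &str) -> Option<&'static str> {
    let lowered = content.to_lowercase();
    if CALCULATOR_KEYWORDS.iter().any(|kw| lowered.contains(*kw)) {
        Some("calculator")
    } else {
        None
    }
}
```

Because this is substring matching, it also fires on words such as "mathematics" or "aftermath", exactly as the inline check does.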
#[actix_web::get("/ws")]
async fn websocket_handler(
req: HttpRequest,
stream: web::Payload,
data: web::Data<AppState>,
) -> Result<HttpResponse, actix_web::Error> {
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
let session_id = Uuid::new_v4().to_string();
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
data.orchestrator
.register_response_channel(session_id.clone(), tx.clone())
.await;
data.web_adapter
.add_connection(session_id.clone(), tx.clone())
.await;
data.voice_adapter
.add_connection(session_id.clone(), tx.clone())
.await;
let orchestrator = data.orchestrator.clone();
let web_adapter = data.web_adapter.clone();
actix_web::rt::spawn(async move {
while let Some(msg) = rx.recv().await {
if let Ok(json) = serde_json::to_string(&msg) {
let _ = session.text(json).await;
}
}
});
actix_web::rt::spawn(async move {
while let Some(Ok(msg)) = msg_stream.recv().await {
match msg {
WsMessage::Text(text) => {
let user_message = UserMessage {
bot_id: "default_bot".to_string(),
user_id: "default_user".to_string(),
session_id: session_id.clone(),
channel: "web".to_string(),
content: text.to_string(),
message_type: "text".to_string(),
media_url: None,
timestamp: Utc::now(),
};
if let Err(e) = orchestrator.stream_response(user_message, tx.clone()).await {
log::error!("Error processing message: {}", e);
}
}
WsMessage::Close(_) => {
web_adapter.remove_connection(&session_id).await;
break;
}
_ => {}
}
}
});
Ok(res)
}
#[actix_web::get("/api/whatsapp/webhook")]
async fn whatsapp_webhook_verify(
data: web::Data<AppState>,
web::Query(params): web::Query<HashMap<String, String>>,
) -> Result<HttpResponse> {
let empty = String::new();
let mode = params.get("hub.mode").unwrap_or(&empty);
let token = params.get("hub.verify_token").unwrap_or(&empty);
let challenge = params.get("hub.challenge").unwrap_or(&empty);
match data.whatsapp_adapter.verify_webhook(mode, token, challenge) {
Ok(challenge_response) => Ok(HttpResponse::Ok().body(challenge_response)),
Err(_) => Ok(HttpResponse::Forbidden().body("Verification failed")),
}
}
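`whatsapp_webhook_verify` delegates to `verify_webhook`, whose body is not shown here. Meta's webhook verification handshake sends `hub.mode=subscribe`, `hub.verify_token` and `hub.challenge`. The endpoint must echo the challenge only when the token matches the one configured for the app.

Below is a minimal sketch of that check. It assumes a hypothetical free function that receives the expected token explicitly. This is not the adapter's real signature:

```rust
/// Hypothetical stand-in for the adapter's check: echo the challenge only for a
/// "subscribe" request whose token matches a non-empty configured token.
pub fn verify_webhook(
    mode: &str,
    token: &str,
    challenge: &str,
    expected_token: &str,
) -> Result<String, &'static str> {
    if mode == "subscribe" && !expected_token.is_empty() && token == expected_token {
        Ok(challenge.to_string())
    } else {
        Err("verification failed")
    }
}
```

Rejecting an empty configured token prevents the handler from accepting requests that omit `hub.verify_token`. Without that check, the handler's empty-string defaults would match.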
#[actix_web::post("/api/whatsapp/webhook")]
async fn whatsapp_webhook(
data: web::Data<AppState>,
payload: web::Json<crate::services::whatsapp::WhatsAppMessage>,
) -> Result<HttpResponse> {
match data
.whatsapp_adapter
.process_incoming_message(payload.into_inner())
.await
{
Ok(user_messages) => {
for user_message in user_messages {
if let Err(e) = data.orchestrator.process_message(user_message).await {
log::error!("Error processing WhatsApp message: {}", e);
}
}
Ok(HttpResponse::Ok().body(""))
}
Err(e) => {
log::error!("Error processing WhatsApp webhook: {}", e);
Ok(HttpResponse::BadRequest().body("Invalid message"))
}
}
}
#[actix_web::post("/api/voice/start")]
async fn voice_start(
data: web::Data<AppState>,
info: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
let session_id = info
.get("session_id")
.and_then(|s| s.as_str())
.unwrap_or("");
let user_id = info
.get("user_id")
.and_then(|u| u.as_str())
.unwrap_or("user");
match data
.voice_adapter
.start_voice_session(session_id, user_id)
.await
{
Ok(token) => {
Ok(HttpResponse::Ok().json(serde_json::json!({"token": token, "status": "started"})))
}
Err(e) => {
Ok(HttpResponse::InternalServerError()
.json(serde_json::json!({"error": e.to_string()})))
}
}
}
#[actix_web::post("/api/voice/stop")]
async fn voice_stop(
data: web::Data<AppState>,
info: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
let session_id = info
.get("session_id")
.and_then(|s| s.as_str())
.unwrap_or("");
match data.voice_adapter.stop_voice_session(session_id).await {
Ok(()) => Ok(HttpResponse::Ok().json(serde_json::json!({"status": "stopped"}))),
Err(e) => {
Ok(HttpResponse::InternalServerError()
.json(serde_json::json!({"error": e.to_string()})))
}
}
}
#[actix_web::post("/api/sessions")]
async fn create_session(_data: web::Data<AppState>) -> Result<HttpResponse> {
let session_id = Uuid::new_v4();
Ok(HttpResponse::Ok().json(serde_json::json!({
"session_id": session_id,
"title": "New Conversation",
"created_at": Utc::now()
})))
}
#[actix_web::get("/api/sessions")]
async fn get_sessions(data: web::Data<AppState>) -> Result<HttpResponse> {
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
match data.orchestrator.get_user_sessions(user_id).await {
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
Err(e) => {
Ok(HttpResponse::InternalServerError()
.json(serde_json::json!({"error": e.to_string()})))
}
}
}
#[actix_web::get("/api/sessions/{session_id}")]
async fn get_session_history(
data: web::Data<AppState>,
path: web::Path<String>,
) -> Result<HttpResponse> {
let session_id = path.into_inner();
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
match Uuid::parse_str(&session_id) {
Ok(session_uuid) => match data
.orchestrator
.get_conversation_history(session_uuid, user_id)
.await
{
Ok(history) => Ok(HttpResponse::Ok().json(history)),
Err(e) => Ok(HttpResponse::InternalServerError()
.json(serde_json::json!({"error": e.to_string()}))),
},
Err(_) => {
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
}
}
}
#[actix_web::post("/api/set_mode")]
async fn set_mode_handler(
data: web::Data<AppState>,
info: web::Json<HashMap<String, String>>,
) -> Result<HttpResponse> {
let default_user = "default_user".to_string();
let default_bot = "default_bot".to_string();
let default_mode = "direct".to_string();
let user_id = info.get("user_id").unwrap_or(&default_user);
let bot_id = info.get("bot_id").unwrap_or(&default_bot);
let mode = info.get("mode").unwrap_or(&default_mode);
if let Err(e) = data
.orchestrator
.set_user_answer_mode(user_id, bot_id, mode)
.await
{
return Ok(
HttpResponse::InternalServerError().json(serde_json::json!({"error": e.to_string()}))
);
}
Ok(HttpResponse::Ok().json(serde_json::json!({"status": "mode_updated"})))
}
#[actix_web::get("/")]
async fn index() -> Result<HttpResponse> {
let html = fs::read_to_string("templates/index.html")
.unwrap_or_else(|_| include_str!("../../static/index.html").to_string());
Ok(HttpResponse::Ok().content_type("text/html").body(html))
}
#[actix_web::get("/static/{filename:.*}")]
async fn static_files(req: HttpRequest) -> Result<HttpResponse> {
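let filename = req.match_info().query("filename");
if filename.split('/').any(|part| part == "..") {
return Ok(HttpResponse::Forbidden().body("Invalid path"));
}
let path = format!("static/{}", filename);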
match fs::read(&path) {
Ok(content) => {
let content_type = match filename {
f if f.ends_with(".js") => "application/javascript",
f if f.ends_with(".css") => "text/css",
f if f.ends_with(".png") => "image/png",
f if f.ends_with(".jpg") || f.ends_with(".jpeg") => "image/jpeg",
_ => "text/plain",
};
Ok(HttpResponse::Ok().content_type(content_type).body(content))
}
Err(_) => Ok(HttpResponse::NotFound().body("File not found")),
}
}
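The `static_files` handler maps a URL suffix onto `static/` and picks a MIME type from the file extension. The std-only sketch below shows a stricter variant using two hypothetical helpers, `is_safe_static_path` and `content_type_for`. It makes two changes the handler does not:

- It accepts only plain path components, so `..`, root and `.` segments are all rejected.
- It matches extensions case-insensitively.

```rust
use std::path::{Component, Path};

/// Accepts only relative paths made of normal components (no `..`, `/`, or `.`).
pub fn is_safe_static_path(requested: &str) -> bool {
    !requested.is_empty()
        && Path::new(requested)
            .components()
            .all(|c| matches!(c, Component::Normal(_)))
}

/// Maps a file extension to a MIME type, ignoring extension case.
pub fn content_type_for(filename: &str) -> &'static str {
    let ext = Path::new(filename)
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_ascii_lowercase());
    match ext.as_deref() {
        Some("js") => "application/javascript",
        Some("css") => "text/css",
        Some("png") => "image/png",
        Some("jpg") | Some("jpeg") => "image/jpeg",
        _ => "text/plain",
    }
}
```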
use crate::shared::UserSession;
use redis::{AsyncCommands, Client};
use sqlx::{PgPool, Row};
use std::sync::Arc;
use uuid::Uuid;
pub struct SessionManager {
pub pool: PgPool,
pub redis: Option<Arc<Client>>,
}
impl SessionManager {
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
Self { pool, redis }
}
pub async fn get_user_session(
&self,
user_id: Uuid,
bot_id: Uuid,
) -> Result<Option<UserSession>, Box<dyn std::error::Error>> {
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session:{}:{}", user_id, bot_id);
let session_json: Option<String> = conn.get(&cache_key).await?;
if let Some(json) = session_json {
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
return Ok(Some(session));
}
}
}
let session = sqlx::query_as::<_, UserSession>(
"SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1",
)
.bind(user_id)
.bind(bot_id)
.fetch_optional(&self.pool)
.await?;
if let Some(ref session) = session {
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session:{}:{}", user_id, bot_id);
let session_json = serde_json::to_string(session)?;
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
}
}
Ok(session)
}
pub async fn create_session(
&self,
user_id: Uuid,
bot_id: Uuid,
title: &str,
) -> Result<UserSession, Box<dyn std::error::Error>> {
let session = sqlx::query_as::<_, UserSession>(
"INSERT INTO user_sessions (user_id, bot_id, title) VALUES ($1, $2, $3) RETURNING *",
)
.bind(user_id)
.bind(bot_id)
.bind(title)
.fetch_one(&self.pool)
.await?;
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session:{}:{}", user_id, bot_id);
let session_json = serde_json::to_string(&session)?;
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
}
Ok(session)
}
pub async fn save_message(
&self,
session_id: Uuid,
user_id: Uuid,
role: &str,
content: &str,
message_type: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let message_count: i64 =
sqlx::query("SELECT COUNT(*) as count FROM message_history WHERE session_id = $1")
.bind(session_id)
.fetch_one(&self.pool)
.await?
.get("count");
sqlx::query(
"INSERT INTO message_history (session_id, user_id, role, content_encrypted, message_type, message_index)
VALUES ($1, $2, $3, $4, $5, $6)",
)
.bind(session_id)
.bind(user_id)
.bind(role)
.bind(content)
.bind(message_type)
.bind(message_count + 1)
.execute(&self.pool)
.await?;
sqlx::query("UPDATE user_sessions SET updated_at = NOW() WHERE id = $1")
.bind(session_id)
.execute(&self.pool)
.await?;
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let _: () = conn.del(format!("session_by_id:{}", session_id)).await?;
if let Some(session_info) =
sqlx::query("SELECT user_id, bot_id FROM user_sessions WHERE id = $1")
.bind(session_id)
.fetch_optional(&self.pool)
.await?
{
let user_id: Uuid = session_info.get("user_id");
let bot_id: Uuid = session_info.get("bot_id");
let cache_key = format!("session:{}:{}", user_id, bot_id);
let _: () = conn.del(cache_key).await?;
}
}
Ok(())
}
pub async fn get_conversation_history(
&self,
session_id: Uuid,
user_id: Uuid,
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
let messages = sqlx::query(
"SELECT role, content_encrypted FROM message_history
WHERE session_id = $1 AND user_id = $2
ORDER BY message_index ASC",
)
.bind(session_id)
.bind(user_id)
.fetch_all(&self.pool)
.await?;
let history = messages
.into_iter()
.map(|row| (row.get("role"), row.get("content_encrypted")))
.collect();
Ok(history)
}
pub async fn get_user_sessions(
&self,
user_id: Uuid,
) -> Result<Vec<UserSession>, Box<dyn std::error::Error>> {
let sessions = sqlx::query_as::<_, UserSession>(
"SELECT * FROM user_sessions WHERE user_id = $1 ORDER BY updated_at DESC",
)
.bind(user_id)
.fetch_all(&self.pool)
.await?;
Ok(sessions)
}
pub async fn update_answer_mode(
&self,
user_id: &str,
bot_id: &str,
mode: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let user_uuid = Uuid::parse_str(user_id)?;
let bot_uuid = Uuid::parse_str(bot_id)?;
sqlx::query(
"UPDATE user_sessions
SET answer_mode = $1, updated_at = NOW()
WHERE user_id = $2 AND bot_id = $3",
)
.bind(mode)
.bind(user_uuid)
.bind(bot_uuid)
.execute(&self.pool)
.await?;
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
let _: () = conn.del(cache_key).await?;
}
Ok(())
}
pub async fn update_current_tool(
&self,
user_id: &str,
bot_id: &str,
tool_name: Option<&str>,
) -> Result<(), Box<dyn std::error::Error>> {
let user_uuid = Uuid::parse_str(user_id)?;
let bot_uuid = Uuid::parse_str(bot_id)?;
sqlx::query(
"UPDATE user_sessions
SET current_tool = $1, updated_at = NOW()
WHERE user_id = $2 AND bot_id = $3",
)
.bind(tool_name)
.bind(user_uuid)
.bind(bot_uuid)
.execute(&self.pool)
.await?;
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
let _: () = conn.del(cache_key).await?;
}
Ok(())
}
pub async fn get_session_by_id(
&self,
session_id: Uuid,
) -> Result<Option<UserSession>, Box<dyn std::error::Error>> {
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session_by_id:{}", session_id);
let session_json: Option<String> = conn.get(&cache_key).await?;
if let Some(json) = session_json {
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
return Ok(Some(session));
}
}
}
let session = sqlx::query_as::<_, UserSession>("SELECT * FROM user_sessions WHERE id = $1")
.bind(session_id)
.fetch_optional(&self.pool)
.await?;
if let Some(ref session) = session {
if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("session_by_id:{}", session_id);
let session_json = serde_json::to_string(session)?;
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
}
}
Ok(session)
}
pub async fn cleanup_old_sessions(
&self,
days_old: i32,
) -> Result<u64, Box<dyn std::error::Error>> {
let result = sqlx::query(
"DELETE FROM user_sessions
WHERE updated_at < NOW() - INTERVAL '1 day' * $1",
)
.bind(days_old)
.execute(&self.pool)
.await?;
Ok(result.rows_affected())
}
pub async fn set_current_tool(
&self,
user_id: &str,
bot_id: &str,
tool_name: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
self.update_current_tool(user_id, bot_id, tool_name.as_deref())
.await
}
}
impl Clone for SessionManager {
fn clone(&self) -> Self {
Self {
pool: self.pool.clone(),
redis: self.redis.clone(),
}
}
}
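`SessionManager` uses a cache-aside pattern:

- Reads try Redis first, under `session:{user_id}:{bot_id}` or `session_by_id:{session_id}`.
- On a miss, they fall back to Postgres and repopulate the cache with an 1800-second TTL.
- Writes delete the key.

The sketch below models only those key and TTL semantics, in memory. A hypothetical `TtlCache` stands in for Redis. The current time is passed in explicitly so that expiry can be checked deterministically:

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// TTL used by the session cache writes (`set_ex(..., 1800)`).
pub const SESSION_TTL: Duration = Duration::from_secs(1800);

pub fn session_key(user_id: &str, bot_id: &str) -> String {
    format!("session:{}:{}", user_id, bot_id)
}

pub fn session_by_id_key(session_id: &str) -> String {
    format!("session_by_id:{}", session_id)
}

/// Hypothetical in-memory stand-in for Redis SETEX / GET / DEL.
#[derive(Default)]
pub struct TtlCache {
    entries: HashMap<String, (String, Instant)>,
}

impl TtlCache {
    /// Stores `value` until `now + ttl`, like SETEX.
    pub fn set_ex(&mut self, key: String, value: String, ttl: Duration, now: Instant) {
        self.entries.insert(key, (value, now + ttl));
    }

    /// Returns the value only if it has not expired at `now`.
    pub fn get(&self, key: &str, now: Instant) -> Option<&str> {
        self.entries
            .get(key)
            .filter(|(_, expires)| now < *expires)
            .map(|(value, _)| value.as_str())
    }

    /// Invalidates a key, as the update methods do after writing to Postgres.
    pub fn del(&mut self, key: &str) {
        self.entries.remove(key);
    }
}
```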
// src/tools/mod.rs
use async_trait::async_trait;
use redis::AsyncCommands;
use rhai::{Engine, Scope};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use uuid::Uuid;
use crate::channels::ChannelAdapter;
use crate::session::SessionManager;
use crate::shared::BotResponse;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
pub success: bool,
pub output: String,
pub requires_input: bool,
pub session_id: String,
}
#[derive(Clone)]
pub struct Tool {
pub name: String,
pub description: String,
pub parameters: HashMap<String, String>,
pub script: String,
}
#[async_trait]
pub trait ToolExecutor: Send + Sync {
async fn execute(
&self,
tool_name: &str,
session_id: &str,
user_id: &str,
) -> Result<ToolResult, Box<dyn std::error::Error>>;
async fn provide_input(
&self,
session_id: &str,
input: &str,
) -> Result<(), Box<dyn std::error::Error>>;
async fn get_output(&self, session_id: &str)
-> Result<Vec<String>, Box<dyn std::error::Error>>;
async fn is_waiting_for_input(
&self,
session_id: &str,
) -> Result<bool, Box<dyn std::error::Error>>;
}
pub struct RedisToolExecutor {
redis_client: redis::Client,
web_adapter: Arc<dyn ChannelAdapter>,
voice_adapter: Arc<dyn ChannelAdapter>,
whatsapp_adapter: Arc<dyn ChannelAdapter>,
}
impl RedisToolExecutor {
pub fn new(
redis_url: &str,
web_adapter: Arc<dyn ChannelAdapter>,
voice_adapter: Arc<dyn ChannelAdapter>,
whatsapp_adapter: Arc<dyn ChannelAdapter>,
) -> Result<Self, Box<dyn std::error::Error>> {
let client = redis::Client::open(redis_url)?;
Ok(Self {
redis_client: client,
web_adapter,
voice_adapter,
whatsapp_adapter,
})
}
async fn send_tool_message(
&self,
session_id: &str,
user_id: &str,
channel: &str,
message: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id.to_string(),
session_id: session_id.to_string(),
channel: channel.to_string(),
content: message.to_string(),
message_type: "tool".to_string(),
stream_token: None,
is_complete: true,
};
match channel {
"web" => self.web_adapter.send_message(response).await,
"voice" => self.voice_adapter.send_message(response).await,
"whatsapp" => self.whatsapp_adapter.send_message(response).await,
_ => Ok(()),
}
}
fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine {
let mut engine = Engine::new();
// Clone for TALK function
let tool_executor = Arc::new((
self.redis_client.clone(),
self.web_adapter.clone(),
self.voice_adapter.clone(),
self.whatsapp_adapter.clone(),
));
let session_id_clone = session_id.clone();
let user_id_clone = user_id.clone();
let channel_clone = channel.clone();
engine.register_fn("talk", move |message: String| {
let tool_executor = Arc::clone(&tool_executor);
let session_id = session_id_clone.clone();
let user_id = user_id_clone.clone();
let channel = channel_clone.clone();
tokio::spawn(async move {
let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor;
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id.clone(),
session_id: session_id.clone(),
channel: channel.clone(),
content: message.clone(),
message_type: "tool".to_string(),
stream_token: None,
is_complete: true,
};
let result = match channel.as_str() {
"web" => web_adapter.send_message(response).await,
"voice" => voice_adapter.send_message(response).await,
"whatsapp" => whatsapp_adapter.send_message(response).await,
_ => Ok(()),
};
if let Err(e) = result {
log::error!("Failed to send tool message: {}", e);
}
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let output_key = format!("tool:{}:output", session_id);
let _: redis::RedisResult<()> = conn.lpush(&output_key, &message).await;
}
});
});
let hear_executor = self.redis_client.clone();
let session_id_clone = session_id.clone();
engine.register_fn("hear", move || -> String {
let hear_executor = hear_executor.clone();
let session_id = session_id_clone.clone();
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
match hear_executor.get_async_connection().await {
Ok(mut conn) => {
let input_key = format!("tool:{}:input", session_id);
let waiting_key = format!("tool:{}:waiting", session_id);
let _: redis::RedisResult<()> = conn.set_ex(&waiting_key, "true", 300).await;
let result: Option<(String, String)> =
conn.brpop(&input_key, 30).await.ok().flatten();
let _: redis::RedisResult<()> = conn.del(&waiting_key).await;
result
.map(|(_, input)| input)
.unwrap_or_else(|| "timeout".to_string())
}
Err(e) => {
log::error!("HEAR Redis error: {}", e);
"error".to_string()
}
}
})
});
engine
}
async fn cleanup_session(&self, session_id: &str) -> Result<(), Box<dyn std::error::Error>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let keys = vec![
format!("tool:{}:output", session_id),
format!("tool:{}:input", session_id),
format!("tool:{}:waiting", session_id),
format!("tool:{}:active", session_id),
];
for key in keys {
let _: () = conn.del(&key).await?;
}
Ok(())
}
}
#[async_trait]
impl ToolExecutor for RedisToolExecutor {
async fn execute(
&self,
tool_name: &str,
session_id: &str,
user_id: &str,
) -> Result<ToolResult, Box<dyn std::error::Error>> {
let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?;
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let session_key = format!("tool:{}:session", session_id);
let session_data = serde_json::json!({
"user_id": user_id,
"tool_name": tool_name,
"started_at": chrono::Utc::now().to_rfc3339(),
});
let _: () = conn
.set_ex(&session_key, session_data.to_string(), 3600)
.await?;
let active_key = format!("tool:{}:active", session_id);
let _: () = conn.set_ex(&active_key, "true", 3600).await?;
let channel = "web";
let _engine = self.create_rhai_engine(
session_id.to_string(),
user_id.to_string(),
channel.to_string(),
);
let redis_clone = self.redis_client.clone();
let web_adapter_clone = self.web_adapter.clone();
let voice_adapter_clone = self.voice_adapter.clone();
let whatsapp_adapter_clone = self.whatsapp_adapter.clone();
let session_id_clone = session_id.to_string();
let user_id_clone = user_id.to_string();
let tool_script = tool.script.clone();
tokio::spawn(async move {
let mut engine = Engine::new();
let mut scope = Scope::new();
let redis_client = redis_clone.clone();
let web_adapter = web_adapter_clone.clone();
let voice_adapter = voice_adapter_clone.clone();
let whatsapp_adapter = whatsapp_adapter_clone.clone();
let session_id = session_id_clone.clone();
let user_id = user_id_clone.clone();
engine.register_fn("talk", move |message: String| {
let redis_client = redis_client.clone();
let web_adapter = web_adapter.clone();
let voice_adapter = voice_adapter.clone();
let whatsapp_adapter = whatsapp_adapter.clone();
let session_id = session_id.clone();
let user_id = user_id.clone();
tokio::spawn(async move {
let channel = "web";
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id.clone(),
session_id: session_id.clone(),
channel: channel.to_string(),
content: message.clone(),
message_type: "tool".to_string(),
stream_token: None,
is_complete: true,
};
let send_result = match channel {
"web" => web_adapter.send_message(response).await,
"voice" => voice_adapter.send_message(response).await,
"whatsapp" => whatsapp_adapter.send_message(response).await,
_ => Ok(()),
};
if let Err(e) = send_result {
log::error!("Failed to send tool message: {}", e);
}
if let Ok(mut conn) = redis_client.get_multiplexed_async_connection().await {
let output_key = format!("tool:{}:output", session_id);
let _: redis::RedisResult<()> = conn.lpush(&output_key, &message).await;
}
});
});
let hear_redis = redis_clone.clone();
let session_id_hear = session_id_clone.clone();
engine.register_fn("hear", move || -> String {
let hear_redis = hear_redis.clone();
let session_id = session_id_hear.clone();
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
match hear_redis.get_async_connection().await {
Ok(mut conn) => {
let input_key = format!("tool:{}:input", session_id);
let waiting_key = format!("tool:{}:waiting", session_id);
let _: redis::RedisResult<()> = conn.set_ex(&waiting_key, "true", 300).await;
let result: Option<(String, String)> =
conn.brpop(&input_key, 30).await.ok().flatten();
let _: redis::RedisResult<()> = conn.del(&waiting_key).await;
result
.map(|(_, input)| input)
.unwrap_or_else(|| "timeout".to_string())
}
Err(_) => "error".to_string(),
}
})
});
match engine.eval_with_scope::<()>(&mut scope, &tool_script) {
Ok(_) => {
log::info!(
"Tool {} completed successfully for session {}",
tool.name,
session_id_clone
);
let completion_msg =
"🛠️ Tool execution completed. How can I help you with anything else?";
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id_clone,
session_id: session_id_clone.clone(),
channel: "web".to_string(),
content: completion_msg.to_string(),
message_type: "tool_complete".to_string(),
stream_token: None,
is_complete: true,
};
let _ = web_adapter_clone.send_message(response).await;
}
Err(e) => {
log::error!("Tool execution failed: {}", e);
let error_msg = format!("❌ Tool error: {}", e);
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id_clone,
session_id: session_id_clone.clone(),
channel: "web".to_string(),
content: error_msg,
message_type: "tool_error".to_string(),
stream_token: None,
is_complete: true,
};
let _ = web_adapter_clone.send_message(response).await;
}
}
if let Ok(mut conn) = redis_clone.get_multiplexed_async_connection().await {
let active_key = format!("tool:{}:active", session_id_clone);
let _: redis::RedisResult<()> = conn.del(&active_key).await;
}
});
Ok(ToolResult {
success: true,
output: format!(
"🛠️ Starting {} tool. Please follow the tool's instructions.",
tool_name
),
requires_input: true,
session_id: session_id.to_string(),
})
}
async fn provide_input(
&self,
session_id: &str,
input: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let input_key = format!("tool:{}:input", session_id);
let _: () = conn.lpush(&input_key, input).await?;
Ok(())
}
async fn get_output(
&self,
session_id: &str,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let output_key = format!("tool:{}:output", session_id);
let mut messages: Vec<String> = conn.lrange(&output_key, 0, -1).await?;
messages.reverse();
let _: () = conn.del(&output_key).await?;
Ok(messages)
}
async fn is_waiting_for_input(
&self,
session_id: &str,
) -> Result<bool, Box<dyn std::error::Error>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let waiting_key = format!("tool:{}:waiting", session_id);
let exists: bool = conn.exists(&waiting_key).await?;
Ok(exists)
}
}
fn get_tool(name: &str) -> Option<Tool> {
match name {
"calculator" => Some(Tool {
name: "calculator".to_string(),
description: "Perform mathematical calculations".to_string(),
parameters: HashMap::from([
("operation".to_string(), "add|subtract|multiply|divide".to_string()),
("a".to_string(), "number".to_string()),
("b".to_string(), "number".to_string()),
]),
script: r#"
let TALK = |message| {
talk(message);
};
let HEAR = || {
hear()
};
TALK("🔢 Calculator started!");
TALK("Please enter the first number:");
let a = HEAR();
TALK("Please enter the second number:");
let b = HEAR();
TALK("Choose operation: add, subtract, multiply, or divide:");
let op = HEAR();
let num_a = parse_float(a);
let num_b = parse_float(b);
if op == "add" {
let result = num_a + num_b;
TALK("✅ Result: " + a + " + " + b + " = " + result);
} else if op == "subtract" {
let result = num_a - num_b;
TALK("✅ Result: " + a + " - " + b + " = " + result);
} else if op == "multiply" {
let result = num_a * num_b;
TALK("✅ Result: " + a + " × " + b + " = " + result);
} else if op == "divide" {
if num_b != 0.0 {
let result = num_a / num_b;
TALK("✅ Result: " + a + " ÷ " + b + " = " + result);
} else {
TALK("❌ Error: Cannot divide by zero!");
}
} else {
TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide");
}
TALK("Calculator session completed. Thank you!");
"#.to_string(),
}),
_ => None,
}
}
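The calculator script collects two operands and an operation through `HEAR`, then branches on the operation. The same rules, written as a plain Rust function, are easier to unit-test than the Rhai script. The function below is a hypothetical `calculate` helper that is not part of the codebase. Unlike the script, it trims whitespace from the user's input and reports parse failures as errors:

```rust
/// Hypothetical Rust mirror of the calculator script's branching.
pub fn calculate(a: &str, b: &str, op: &str) -> Result<f64, String> {
    let num_a: f64 = a.trim().parse().map_err(|_| format!("invalid number: {}", a))?;
    let num_b: f64 = b.trim().parse().map_err(|_| format!("invalid number: {}", b))?;
    match op.trim() {
        "add" => Ok(num_a + num_b),
        "subtract" => Ok(num_a - num_b),
        "multiply" => Ok(num_a * num_b),
        // Same guard as the script: refuse to divide by zero.
        "divide" if num_b == 0.0 => Err("Cannot divide by zero".to_string()),
        "divide" => Ok(num_a / num_b),
        other => Err(format!("Invalid operation: {}", other)),
    }
}
```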
#[derive(Clone)]
pub struct ToolManager {
tools: HashMap<String, Tool>,
waiting_responses: Arc<Mutex<HashMap<String, mpsc::Sender<String>>>>,
}
impl ToolManager {
pub fn new() -> Self {
let mut tools = HashMap::new();
let calculator_tool = Tool {
name: "calculator".to_string(),
description: "Perform calculations".to_string(),
parameters: HashMap::from([
(
"operation".to_string(),
"add|subtract|multiply|divide".to_string(),
),
("a".to_string(), "number".to_string()),
("b".to_string(), "number".to_string()),
]),
script: r#"
TALK("Calculator started. Enter first number:");
let a = HEAR();
TALK("Enter second number:");
let b = HEAR();
TALK("Operation (add/subtract/multiply/divide):");
let op = HEAR();
let num_a = parse_float(a);
let num_b = parse_float(b);
let result = if op == "add" {
num_a + num_b
} else if op == "subtract" {
num_a - num_b
} else if op == "multiply" {
num_a * num_b
} else if op == "divide" {
if num_b == 0.0 {
TALK("Cannot divide by zero");
return;
}
num_a / num_b
} else {
TALK("Invalid operation");
return;
};
TALK("Result: " + result);
"#
.to_string(),
};
tools.insert(calculator_tool.name.clone(), calculator_tool);
Self {
tools,
waiting_responses: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn get_tool(&self, name: &str) -> Option<&Tool> {
self.tools.get(name)
}
pub fn list_tools(&self) -> Vec<String> {
self.tools.keys().cloned().collect()
}
pub async fn execute_tool(
&self,
tool_name: &str,
session_id: &str,
user_id: &str,
) -> Result<ToolResult, Box<dyn std::error::Error>> {
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
Ok(ToolResult {
success: true,
output: format!("Tool {} started for user {}", tool_name, user_id),
requires_input: true,
session_id: session_id.to_string(),
})
}
pub async fn is_tool_waiting(
&self,
session_id: &str,
) -> Result<bool, Box<dyn std::error::Error>> {
let waiting = self.waiting_responses.lock().await;
Ok(waiting.contains_key(session_id))
}
pub async fn provide_input(
&self,
session_id: &str,
input: &str,
) -> Result<(), Box<dyn std::error::Error>> {
self.provide_user_response(session_id, "default_bot", input.to_string())
.await
}
pub async fn get_tool_output(
&self,
session_id: &str,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
Ok(vec![])
}
pub async fn execute_tool_with_session(
&self,
tool_name: &str,
user_id: &str,
bot_id: &str,
session_manager: SessionManager,
channel_sender: mpsc::Sender<BotResponse>,
) -> Result<(), Box<dyn std::error::Error>> {
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
session_manager
.set_current_tool(user_id, bot_id, Some(tool_name.to_string()))
.await?;
let user_id = user_id.to_string();
let bot_id = bot_id.to_string();
let script = tool.script.clone();
let session_manager_clone = session_manager.clone();
let waiting_responses = self.waiting_responses.clone();
tokio::spawn(async move {
let key = format!("{}:{}", user_id, bot_id);
let (talk_tx, mut talk_rx) = mpsc::channel::<String>(100);
let (hear_tx, hear_rx) = mpsc::channel::<String>(100);
waiting_responses.lock().await.insert(key.clone(), hear_tx);
// Forward TALK output while the script runs so prompts reach the user before HEAR blocks.
let forward_sender = channel_sender.clone();
let forward_user_id = user_id.clone();
let forward_bot_id = bot_id.clone();
let forwarder = tokio::spawn(async move {
while let Some(message) = talk_rx.recv().await {
let response = BotResponse {
bot_id: forward_bot_id.clone(),
user_id: forward_user_id.clone(),
session_id: Uuid::new_v4().to_string(),
channel: "test".to_string(),
content: message,
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
let _ = forward_sender.send(response).await;
}
});
// The engine lives on a blocking thread, so TALK/HEAR can use blocking channel calls.
let script_result = tokio::task::spawn_blocking(move || {
let mut engine = rhai::Engine::new();
engine.register_fn("TALK", move |message: String| {
let _ = talk_tx.blocking_send(message);
});
let hear_rx = Arc::new(Mutex::new(hear_rx));
engine.register_fn("HEAR", move || -> String {
hear_rx.blocking_lock().blocking_recv().unwrap_or_default()
});
engine.eval::<()>(&script).map_err(|e| e.to_string())
})
.await;
// Dropping the engine dropped the last TALK sender, so the forwarder finishes.
let _ = forwarder.await;
if let Ok(Err(e)) = script_result {
let error_response = BotResponse {
bot_id: bot_id.clone(),
user_id: user_id.clone(),
session_id: Uuid::new_v4().to_string(),
channel: "test".to_string(),
content: format!("Tool error: {}", e),
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
let _ = channel_sender.send(error_response).await;
}
waiting_responses.lock().await.remove(&key);
let _ = session_manager_clone
.set_current_tool(&user_id, &bot_id, None)
.await;
});
Ok(())
}
pub async fn provide_user_response(
&self,
user_id: &str,
bot_id: &str,
response: String,
) -> Result<(), Box<dyn std::error::Error>> {
let key = format!("{}:{}", user_id, bot_id);
let sender = self.waiting_responses.lock().await.get(&key).cloned();
if let Some(tx) = sender {
let _ = tx.send(response).await;
}
Ok(())
}
}
impl Default for ToolManager {
fn default() -> Self {
Self::new()
}
}
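`RedisToolExecutor` coordinates a running script and the chat loop through per-session Redis lists:

- `provide_input` does `LPUSH tool:{session}:input`, and `HEAR` does `BRPOP` on the same key. Together these yield a FIFO queue.
- `TALK` does `LPUSH tool:{session}:output`, and `get_output` reads it back with `LRANGE 0 -1`, which returns the newest message first.

The std-only model below reproduces those list semantics with a hypothetical `RedisList` backed by a `VecDeque`. It shows why the output must be reversed to restore the order in which the messages were sent:

```rust
use std::collections::VecDeque;

/// Key layout used by the executor: tool:{session_id}:{input|output|waiting|active}.
pub fn tool_key(session_id: &str, suffix: &str) -> String {
    format!("tool:{}:{}", session_id, suffix)
}

/// Hypothetical in-memory model of a Redis list (head = front).
#[derive(Default)]
pub struct RedisList {
    items: VecDeque<String>,
}

impl RedisList {
    /// LPUSH: insert at the head.
    pub fn lpush(&mut self, value: &str) {
        self.items.push_front(value.to_string());
    }

    /// RPOP (the non-blocking core of BRPOP): remove from the tail.
    pub fn rpop(&mut self) -> Option<String> {
        self.items.pop_back()
    }

    /// LRANGE 0 -1: head to tail, i.e. newest LPUSH first.
    pub fn lrange_all(&self) -> Vec<String> {
        self.items.iter().cloned().collect()
    }
}

/// LRANGE + DEL as in get_output, reversed into send order.
pub fn drain_output_in_order(list: &mut RedisList) -> Vec<String> {
    let mut messages = list.lrange_all();
    messages.reverse();
    list.items.clear();
    messages
}
```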
.
└── src
├── auth
│   └── mod.rs
├── automation
│   └── mod.rs
├── basic
│   ├── keywords
│   │   ├── create_draft.rs
│   │   ├── create_site.rs
│   │   ├── find.rs
│   │   ├── first.rs
│   │   ├── format.rs
│   │   ├── for_next.rs
│   │   ├── get.rs
│   │   ├── get_website.rs
│   │   ├── last.rs
│   │   ├── llm_keyword.rs
│   │   ├── mod.rs
│   │   ├── on.rs
│   │   ├── print.rs
│   │   ├── set.rs
│   │   ├── set_schedule.rs
│   │   └── wait.rs
│   └── mod.rs
├── bot
│   └── mod.rs
├── channels
│   └── mod.rs
├── chart
│   └── mod.rs
├── config
│   └── mod.rs
├── context
│   └── mod.rs
├── email
│   └── mod.rs
├── file
│   └── mod.rs
├── llm
│   ├── llm_generic.rs
│   ├── llm_local.rs
│   ├── llm_provider.rs
│   ├── llm.rs
│   └── mod.rs
├── main.rs
├── org
│   └── mod.rs
├── session
│   └── mod.rs
├── shared
│   ├── models.rs
│   ├── mod.rs
│   ├── state.rs
│   └── utils.rs
├── tests
│   ├── integration_email_list.rs
│   ├── integration_file_list_test.rs
│   └── integration_file_upload_test.rs
├── tools
│   └── mod.rs
├── web_automation
│   └── mod.rs
└── whatsapp
└── mod.rs
21 directories, 44 files