2626 lines
85 KiB
Text
2626 lines
85 KiB
Text
|
|
Consolidated LLM Context
|
|||
|
|
* Prefer `use` imports over fully qualified `::` paths when calling methods.
|
|||
|
|
* Output a single `.sh` script using `cat` so it can be restored directly.
|
|||
|
|
* No placeholders, only real, production-ready code.
|
|||
|
|
* No comments, no explanations, no extra text.
|
|||
|
|
* Follow KISS principles.
|
|||
|
|
* Provide a complete, professional, working solution.
|
|||
|
|
* If the script is too long, split into multiple parts, but always return the **entire code**.
|
|||
|
|
* Output must be **only the code**, nothing else.
|
|||
|
|
|
|||
|
|
[package]
|
|||
|
|
name = "gbserver"
|
|||
|
|
version = "0.1.0"
|
|||
|
|
edition = "2021"
|
|||
|
|
authors = ["Rodrigo Rodriguez <me@rodrigorodriguez.com>"]
|
|||
|
|
description = "General Bots Server"
|
|||
|
|
license = "AGPL-3.0"
|
|||
|
|
repository = "https://alm.pragmatismo.com.br/generalbots/gbserver"
|
|||
|
|
|
|||
|
|
[features]
|
|||
|
|
default = ["postgres", "qdrant"]
|
|||
|
|
local_llm = []
|
|||
|
|
postgres = ["sqlx/postgres"]
|
|||
|
|
qdrant = ["langchain-rust/qdrant"]
|
|||
|
|
|
|||
|
|
[dependencies]
|
|||
|
|
actix-cors = "0.7"
|
|||
|
|
actix-multipart = "0.7"
|
|||
|
|
actix-web = "4.9"
|
|||
|
|
actix-ws = "0.3"
|
|||
|
|
anyhow = "1.0"
|
|||
|
|
async-stream = "0.3"
|
|||
|
|
async-trait = "0.1"
|
|||
|
|
aes-gcm = "0.10"
|
|||
|
|
argon2 = "0.5"
|
|||
|
|
base64 = "0.22"
|
|||
|
|
bytes = "1.8"
|
|||
|
|
chrono = { version = "0.4", features = ["serde"] }
|
|||
|
|
dotenv = "0.15"
|
|||
|
|
downloader = "0.2"
|
|||
|
|
env_logger = "0.11"
|
|||
|
|
futures = "0.3"
|
|||
|
|
futures-util = "0.3"
|
|||
|
|
imap = "2.4"
|
|||
|
|
langchain-rust = { version = "4.6", features = ["qdrant", "postgres"] }
|
|||
|
|
lettre = { version = "0.11", features = ["smtp-transport", "builder", "tokio1", "tokio1-native-tls"] }
|
|||
|
|
livekit = "0.7"
|
|||
|
|
log = "0.4"
|
|||
|
|
mailparse = "0.15"
|
|||
|
|
minio = { git = "https://github.com/minio/minio-rs", branch = "master" }
|
|||
|
|
native-tls = "0.2"
|
|||
|
|
num-format = "0.4"
|
|||
|
|
qdrant-client = "1.12"
|
|||
|
|
rhai = "1.22"
|
|||
|
|
redis = { version = "0.27", features = ["tokio-comp"] }
|
|||
|
|
regex = "1.11"
|
|||
|
|
reqwest = { version = "0.12", features = ["json", "stream"] }
|
|||
|
|
scraper = "0.20"
|
|||
|
|
serde = { version = "1.0", features = ["derive"] }
|
|||
|
|
serde_json = "1.0"
|
|||
|
|
smartstring = "1.0"
|
|||
|
|
sqlx = { version = "0.8", features = ["time", "uuid", "runtime-tokio-rustls", "postgres", "chrono"] }
|
|||
|
|
tempfile = "3"
|
|||
|
|
thirtyfour = "0.34"
|
|||
|
|
tokio = { version = "1.41", features = ["full"] }
|
|||
|
|
tokio-stream = "0.1"
|
|||
|
|
tracing = "0.1"
|
|||
|
|
tracing-subscriber = { version = "0.3", features = ["fmt"] }
|
|||
|
|
urlencoding = "2.1"
|
|||
|
|
uuid = { version = "1.11", features = ["serde", "v4"] }
|
|||
|
|
zip = "2.2"
|
|||
|
|
|
|||
|
|
You are fixing Rust code in a Cargo project. The user is providing problematic code that needs to be corrected.
|
|||
|
|
|
|||
|
|
## Your Task
|
|||
|
|
Fix ALL compiler errors and logical issues while maintaining the original intent. Return the COMPLETE corrected files as a SINGLE .sh script that can be executed from project root.
|
|||
|
|
Use Cargo.toml as the reference; do not change it.
|
|||
|
|
Only return the input files; all other files already exist.
|
|||
|
|
If something needs to be added to an external file, say so separately.
|
|||
|
|
|
|||
|
|
## Critical Requirements
|
|||
|
|
1. **Return as SINGLE .sh script** - Output must be a complete shell script using `cat > file << 'EOF'` pattern
|
|||
|
|
2. **Include ALL files** - Every corrected file must be included in the script
|
|||
|
|
3. **Respect Cargo.toml** - Check dependencies, editions, and features to avoid compiler errors
|
|||
|
|
4. **Type safety** - Ensure all types match and trait bounds are satisfied
|
|||
|
|
5. **Ownership rules** - Fix borrowing, ownership, and lifetime issues
|
|||
|
|
|
|||
|
|
## Output Format Requirements
|
|||
|
|
You MUST return output in exactly this format:
|
|||
|
|
|
|||
|
|
```sh
|
|||
|
|
#!/bin/bash
|
|||
|
|
|
|||
|
|
# Restore fixed Rust project
|
|||
|
|
|
|||
|
|
cat > src/<filenamehere>.rs << 'EOF'
|
|||
|
|
use std::io;
|
|||
|
|
|
|||
|
|
// test
EOF
|
|||
|
|
|
|||
|
|
cat > src/<anotherfile>.rs << 'EOF'
|
|||
|
|
// Fixed library code
|
|||
|
|
pub fn add(a: i32, b: i32) -> i32 {
|
|||
|
|
a + b
|
|||
|
|
}
|
|||
|
|
EOF
```
|
|||
|
|
|
|||
|
|
----
|
|||
|
|
|
|||
|
|
use std::sync::Arc;
|
|||
|
|
|
|||
|
|
use minio::s3::Client;
|
|||
|
|
|
|||
|
|
use crate::{config::AppConfig, web_automator::BrowserPool};
|
|||
|
|
|
|||
|
|
#[derive(Clone)]
|
|||
|
|
pub struct AppState {
|
|||
|
|
pub minio_client: Option<Client>,
|
|||
|
|
pub config: Option<AppConfig>,
|
|||
|
|
pub db: Option<sqlx::PgPool>,
|
|||
|
|
pub db_custom: Option<sqlx::PgPool>,
|
|||
|
|
pub browser_pool: Arc<BrowserPool>,
|
|||
|
|
pub orchestrator: Arc<BotOrchestrator>,
|
|||
|
|
pub web_adapter: Arc<WebChannelAdapter>,
|
|||
|
|
pub voice_adapter: Arc<VoiceAdapter>,
|
|||
|
|
pub whatsapp_adapter: Arc<WhatsAppAdapter>,
|
|||
|
|
tool_api: Arc<ToolApi>, // Add this
|
|||
|
|
}
|
|||
|
|
/// Per-bot settings: the bot's language and its working folder.
///
/// NOTE(review): the leading underscore presumably marks this type as
/// currently unused — confirm before relying on it.
pub struct _BotState {
    /// Language setting for the bot.
    pub language: String,
    /// Folder the bot works in.
    pub work_folder: String,
}
|
|||
|
|
|
|||
|
|
use crate::config::AIConfig;
|
|||
|
|
use langchain_rust::llm::OpenAI;
|
|||
|
|
use langchain_rust::{language_models::llm::LLM, llm::AzureConfig};
|
|||
|
|
use log::error;
|
|||
|
|
use log::{debug, warn};
|
|||
|
|
use rhai::{Array, Dynamic};
|
|||
|
|
use serde_json::{json, Value};
|
|||
|
|
use smartstring::SmartString;
|
|||
|
|
use sqlx::Column; // Required for .name() method
|
|||
|
|
use sqlx::TypeInfo; // Required for .type_info() method
|
|||
|
|
use sqlx::{postgres::PgRow, Row};
|
|||
|
|
use sqlx::{Decode, Type};
|
|||
|
|
use std::error::Error;
|
|||
|
|
use std::fs::File;
|
|||
|
|
use std::io::BufReader;
|
|||
|
|
use std::path::Path;
|
|||
|
|
use tokio::fs::File as TokioFile;
|
|||
|
|
use tokio_stream::StreamExt;
|
|||
|
|
use zip::ZipArchive;
|
|||
|
|
|
|||
|
|
use reqwest::Client;
|
|||
|
|
use tokio::io::AsyncWriteExt;
|
|||
|
|
|
|||
|
|
pub fn azure_from_config(config: &AIConfig) -> AzureConfig {
|
|||
|
|
AzureConfig::new()
|
|||
|
|
.with_api_base(&config.endpoint)
|
|||
|
|
.with_api_key(&config.key)
|
|||
|
|
.with_api_version(&config.version)
|
|||
|
|
.with_deployment_id(&config.instance)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn call_llm(
|
|||
|
|
text: &str,
|
|||
|
|
ai_config: &AIConfig,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
|
let azure_config = azure_from_config(&ai_config.clone());
|
|||
|
|
let open_ai = OpenAI::new(azure_config);
|
|||
|
|
|
|||
|
|
// Directly use the input text as prompt
|
|||
|
|
let prompt = text.to_string();
|
|||
|
|
|
|||
|
|
// Call LLM and return the raw text response
|
|||
|
|
match open_ai.invoke(&prompt).await {
|
|||
|
|
Ok(response_text) => Ok(response_text),
|
|||
|
|
Err(err) => {
|
|||
|
|
error!("Error invoking LLM API: {}", err);
|
|||
|
|
Err(Box::new(std::io::Error::new(
|
|||
|
|
std::io::ErrorKind::Other,
|
|||
|
|
"Failed to invoke LLM API",
|
|||
|
|
)))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub fn extract_zip_recursive(
|
|||
|
|
zip_path: &Path,
|
|||
|
|
destination_path: &Path,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let file = File::open(zip_path)?;
|
|||
|
|
let buf_reader = BufReader::new(file);
|
|||
|
|
let mut archive = ZipArchive::new(buf_reader)?;
|
|||
|
|
|
|||
|
|
for i in 0..archive.len() {
|
|||
|
|
let mut file = archive.by_index(i)?;
|
|||
|
|
let outpath = destination_path.join(file.mangled_name());
|
|||
|
|
|
|||
|
|
if file.is_dir() {
|
|||
|
|
std::fs::create_dir_all(&outpath)?;
|
|||
|
|
} else {
|
|||
|
|
if let Some(parent) = outpath.parent() {
|
|||
|
|
if !parent.exists() {
|
|||
|
|
std::fs::create_dir_all(&parent)?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
let mut outfile = File::create(&outpath)?;
|
|||
|
|
std::io::copy(&mut file, &mut outfile)?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
pub fn row_to_json(row: PgRow) -> Result<Value, Box<dyn Error>> {
|
|||
|
|
let mut result = serde_json::Map::new();
|
|||
|
|
let columns = row.columns();
|
|||
|
|
debug!("Converting row with {} columns", columns.len());
|
|||
|
|
|
|||
|
|
for (i, column) in columns.iter().enumerate() {
|
|||
|
|
let column_name = column.name();
|
|||
|
|
let type_name = column.type_info().name();
|
|||
|
|
|
|||
|
|
let value = match type_name {
|
|||
|
|
"INT4" | "int4" => handle_nullable_type::<i32>(&row, i, column_name),
|
|||
|
|
"INT8" | "int8" => handle_nullable_type::<i64>(&row, i, column_name),
|
|||
|
|
"FLOAT4" | "float4" => handle_nullable_type::<f32>(&row, i, column_name),
|
|||
|
|
"FLOAT8" | "float8" => handle_nullable_type::<f64>(&row, i, column_name),
|
|||
|
|
"TEXT" | "VARCHAR" | "text" | "varchar" => {
|
|||
|
|
handle_nullable_type::<String>(&row, i, column_name)
|
|||
|
|
}
|
|||
|
|
"BOOL" | "bool" => handle_nullable_type::<bool>(&row, i, column_name),
|
|||
|
|
"JSON" | "JSONB" | "json" | "jsonb" => handle_json(&row, i, column_name),
|
|||
|
|
_ => {
|
|||
|
|
warn!("Unknown type {} for column {}", type_name, column_name);
|
|||
|
|
handle_nullable_type::<String>(&row, i, column_name)
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
result.insert(column_name.to_string(), value);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(Value::Object(result))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fn handle_nullable_type<'r, T>(row: &'r PgRow, idx: usize, col_name: &str) -> Value
|
|||
|
|
where
|
|||
|
|
T: Type<sqlx::Postgres> + Decode<'r, sqlx::Postgres> + serde::Serialize + std::fmt::Debug,
|
|||
|
|
{
|
|||
|
|
match row.try_get::<Option<T>, _>(idx) {
|
|||
|
|
Ok(Some(val)) => {
|
|||
|
|
debug!("Successfully read column {} as {:?}", col_name, val);
|
|||
|
|
json!(val)
|
|||
|
|
}
|
|||
|
|
Ok(None) => {
|
|||
|
|
debug!("Column {} is NULL", col_name);
|
|||
|
|
Value::Null
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
warn!("Failed to read column {}: {}", col_name, e);
|
|||
|
|
Value::Null
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value {
|
|||
|
|
// First try to get as Option<Value>
|
|||
|
|
match row.try_get::<Option<Value>, _>(idx) {
|
|||
|
|
Ok(Some(val)) => {
|
|||
|
|
debug!("Successfully read JSON column {} as Value", col_name);
|
|||
|
|
return val;
|
|||
|
|
}
|
|||
|
|
Ok(None) => return Value::Null,
|
|||
|
|
Err(_) => (), // Fall through to other attempts
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Try as Option<String> that might contain JSON
|
|||
|
|
match row.try_get::<Option<String>, _>(idx) {
|
|||
|
|
Ok(Some(s)) => match serde_json::from_str(&s) {
|
|||
|
|
Ok(val) => val,
|
|||
|
|
Err(_) => {
|
|||
|
|
debug!("Column {} contains string that's not JSON", col_name);
|
|||
|
|
json!(s)
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
Ok(None) => Value::Null,
|
|||
|
|
Err(e) => {
|
|||
|
|
warn!("Failed to read JSON column {}: {}", col_name, e);
|
|||
|
|
Value::Null
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
pub fn json_value_to_dynamic(value: &Value) -> Dynamic {
|
|||
|
|
match value {
|
|||
|
|
Value::Null => Dynamic::UNIT,
|
|||
|
|
Value::Bool(b) => Dynamic::from(*b),
|
|||
|
|
Value::Number(n) => {
|
|||
|
|
if let Some(i) = n.as_i64() {
|
|||
|
|
Dynamic::from(i)
|
|||
|
|
} else if let Some(f) = n.as_f64() {
|
|||
|
|
Dynamic::from(f)
|
|||
|
|
} else {
|
|||
|
|
Dynamic::UNIT
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
Value::String(s) => Dynamic::from(s.clone()),
|
|||
|
|
Value::Array(arr) => Dynamic::from(
|
|||
|
|
arr.iter()
|
|||
|
|
.map(json_value_to_dynamic)
|
|||
|
|
.collect::<rhai::Array>(),
|
|||
|
|
),
|
|||
|
|
Value::Object(obj) => Dynamic::from(
|
|||
|
|
obj.iter()
|
|||
|
|
.map(|(k, v)| (SmartString::from(k), json_value_to_dynamic(v)))
|
|||
|
|
.collect::<rhai::Map>(),
|
|||
|
|
),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Converts any value to an array - single values become single-element arrays
|
|||
|
|
pub fn to_array(value: Dynamic) -> Array {
|
|||
|
|
if value.is_array() {
|
|||
|
|
// Already an array - return as-is
|
|||
|
|
value.cast::<Array>()
|
|||
|
|
} else if value.is_unit() || value.is::<()>() {
|
|||
|
|
// Handle empty/unit case
|
|||
|
|
Array::new()
|
|||
|
|
} else {
|
|||
|
|
// Convert single value to single-element array
|
|||
|
|
Array::from([value])
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn download_file(url: &str, output_path: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let client = Client::new();
|
|||
|
|
let response = client.get(url).send().await?;
|
|||
|
|
|
|||
|
|
if response.status().is_success() {
|
|||
|
|
let mut file = TokioFile::create(output_path).await?;
|
|||
|
|
|
|||
|
|
let mut stream = response.bytes_stream();
|
|||
|
|
|
|||
|
|
while let Some(chunk) = stream.next().await {
|
|||
|
|
file.write_all(&chunk?).await?;
|
|||
|
|
}
|
|||
|
|
debug!("File downloaded successfully to {}", output_path);
|
|||
|
|
} else {
|
|||
|
|
return Err("Failed to download file".into());
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Parses a single `KEY=VALUE` filter into a parameterized SQL condition.
///
/// Parsing rules:
/// - The key is split from the value at the first `=`, so values may contain
///   `=` themselves (e.g. base64 tokens).
/// - Both parts are trimmed.
/// - The key must be a non-empty identifier of ASCII letters, digits and `_`.
///   This check is what makes it safe to interpolate into SQL.
/// - The value is never interpolated; it is returned separately for binding as `$1`.
///
/// # Errors
/// Returns an error when there is no `=` or the column name is empty or invalid.
pub fn parse_filter(filter_str: &str) -> Result<(String, Vec<String>), Box<dyn Error>> {
    let (column, value) = filter_str
        .split_once('=')
        .ok_or("Invalid filter format. Expected 'KEY=VALUE'")?;
    let column = column.trim();
    let value = value.trim();

    if column.is_empty() || !column.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return Err("Invalid column name in filter".into());
    }

    Ok((format!("{} = $1", column), vec![value.to_string()]))
}
|
|||
|
|
|
|||
|
|
/// Parses `KEY=VALUE` conditions joined by `&` into a parameterized SQL `WHERE`
/// fragment whose placeholders start at `$offset + 1`.
///
/// Parsing rules:
/// - Conditions are combined with `AND`.
/// - Each condition is split at its first `=` and both parts are trimmed.
/// - Each key must be a non-empty identifier of ASCII letters, digits and `_`.
/// - Values are returned unquoted, in placeholder order, for binding.
///
/// # Errors
/// Returns an error if any condition lacks `=` (including the empty condition
/// produced by a stray `&`) or has an empty or invalid column name.
pub fn parse_filter_with_offset(
    filter_str: &str,
    offset: usize,
) -> Result<(String, Vec<String>), Box<dyn Error>> {
    let mut clauses = Vec::new();
    let mut params = Vec::new();

    for (i, condition) in filter_str.split('&').enumerate() {
        let (column, value) = condition.split_once('=').ok_or("Invalid filter format")?;
        let column = column.trim();
        let value = value.trim();

        if column.is_empty() || !column.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
            return Err("Invalid column name".into());
        }

        clauses.push(format!("{} = ${}", column, i + 1 + offset));
        params.push(value.to_string());
    }

    Ok((clauses.join(" AND "), params))
}
|
|||
|
|
|
|||
|
|
use chrono::{DateTime, Utc};
|
|||
|
|
use serde::{Deserialize, Serialize};
|
|||
|
|
use sqlx::FromRow;
|
|||
|
|
use uuid::Uuid;
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
|||
|
|
pub struct organization {
|
|||
|
|
pub org_id: Uuid,
|
|||
|
|
pub name: String,
|
|||
|
|
pub slug: String,
|
|||
|
|
pub created_at: DateTime<Utc>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
|||
|
|
pub struct Bot {
|
|||
|
|
pub bot_id: Uuid,
|
|||
|
|
pub name: String,
|
|||
|
|
pub status: BotStatus,
|
|||
|
|
pub config: serde_json::Value,
|
|||
|
|
pub created_at: chrono::DateTime<Utc>,
|
|||
|
|
pub updated_at: chrono::DateTime<Utc>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)]
|
|||
|
|
#[serde(rename_all = "snake_case")]
|
|||
|
|
#[sqlx(type_name = "bot_status", rename_all = "snake_case")]
|
|||
|
|
pub enum BotStatus {
|
|||
|
|
Active,
|
|||
|
|
Inactive,
|
|||
|
|
Maintenance,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Kind of event that fires an automation.
///
/// Each variant has an explicit `i32` discriminant. `Automation::kind` stores
/// the numeric trigger type; presumably it holds these values — confirm.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TriggerKind {
    Scheduled = 0,
    TableUpdate = 1,
    TableInsert = 2,
    TableDelete = 3,
}

impl TriggerKind {
    /// Maps a stored discriminant back to a `TriggerKind`.
    ///
    /// Returns `None` for any value outside `0..=3`.
    pub fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Scheduled),
            1 => Some(Self::TableUpdate),
            2 => Some(Self::TableInsert),
            3 => Some(Self::TableDelete),
            _ => None,
        }
    }
}

impl TryFrom<i32> for TriggerKind {
    /// The rejected value, returned unchanged.
    type Error = i32;

    /// Standard-trait equivalent of [`TriggerKind::from_i32`].
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        Self::from_i32(value).ok_or(value)
    }
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, FromRow, Serialize, Deserialize)]
|
|||
|
|
pub struct Automation {
|
|||
|
|
pub id: Uuid,
|
|||
|
|
pub kind: i32, // Using number for trigger type
|
|||
|
|
pub target: Option<String>,
|
|||
|
|
pub schedule: Option<String>,
|
|||
|
|
pub param: String,
|
|||
|
|
pub is_active: bool,
|
|||
|
|
pub last_triggered: Option<DateTime<Utc>>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub mod models;
|
|||
|
|
pub mod state;
|
|||
|
|
pub mod utils;
|
|||
|
|
|
|||
|
|
use chrono::{DateTime, Utc};
|
|||
|
|
use serde::{Deserialize, Serialize};
|
|||
|
|
use sqlx::FromRow;
|
|||
|
|
use uuid::Uuid;
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
|||
|
|
pub struct UserSession {
|
|||
|
|
pub id: Uuid,
|
|||
|
|
pub user_id: Uuid,
|
|||
|
|
pub bot_id: Uuid,
|
|||
|
|
pub title: String,
|
|||
|
|
pub context_data: serde_json::Value,
|
|||
|
|
pub answer_mode: String,
|
|||
|
|
pub current_tool: Option<String>,
|
|||
|
|
pub created_at: DateTime<Utc>,
|
|||
|
|
pub updated_at: DateTime<Utc>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|||
|
|
pub struct EmbeddingRequest {
|
|||
|
|
pub text: String,
|
|||
|
|
pub model: Option<String>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|||
|
|
pub struct EmbeddingResponse {
|
|||
|
|
pub embedding: Vec<f32>,
|
|||
|
|
pub model: String,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|||
|
|
pub struct SearchResult {
|
|||
|
|
pub text: String,
|
|||
|
|
pub similarity: f32,
|
|||
|
|
pub metadata: serde_json::Value,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|||
|
|
pub struct UserMessage {
|
|||
|
|
pub bot_id: String,
|
|||
|
|
pub user_id: String,
|
|||
|
|
pub session_id: String,
|
|||
|
|
pub channel: String,
|
|||
|
|
pub content: String,
|
|||
|
|
pub message_type: String,
|
|||
|
|
pub media_url: Option<String>,
|
|||
|
|
pub timestamp: DateTime<Utc>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|||
|
|
pub struct BotResponse {
|
|||
|
|
pub bot_id: String,
|
|||
|
|
pub user_id: String,
|
|||
|
|
pub session_id: String,
|
|||
|
|
pub channel: String,
|
|||
|
|
pub content: String,
|
|||
|
|
pub message_type: String,
|
|||
|
|
pub stream_token: Option<String>,
|
|||
|
|
pub is_complete: bool,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
use crate::{shared::state::AppState, whatsapp};
|
|||
|
|
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
|||
|
|
use actix_ws::Message as WsMessage;
|
|||
|
|
use chrono::Utc;
|
|||
|
|
use langchain_rust::{
|
|||
|
|
chain::{Chain, LLMChain},
|
|||
|
|
llm::openai::OpenAI,
|
|||
|
|
memory::SimpleMemory,
|
|||
|
|
prompt_args,
|
|||
|
|
tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder},
|
|||
|
|
vectorstore::qdrant::Qdrant as LangChainQdrant,
|
|||
|
|
vectorstore::{VecStoreOptions, VectorStore},
|
|||
|
|
};
|
|||
|
|
use log::info;
|
|||
|
|
use serde_json;
|
|||
|
|
use std::collections::HashMap;
|
|||
|
|
use std::fs;
|
|||
|
|
use std::sync::Arc;
|
|||
|
|
use tokio::sync::{mpsc, Mutex};
|
|||
|
|
use uuid::Uuid;
|
|||
|
|
|
|||
|
|
use crate::channels::{ChannelAdapter, VoiceAdapter, WebChannelAdapter};
|
|||
|
|
use crate::chart::ChartGenerator;
|
|||
|
|
use crate::llm::LLMProvider;
|
|||
|
|
use crate::session::SessionManager;
|
|||
|
|
use crate::tools::ToolManager;
|
|||
|
|
use crate::whatsapp::WhatsAppAdapter;
|
|||
|
|
use crate::{
|
|||
|
|
auth::AuthService,
|
|||
|
|
shared::{BotResponse, UserMessage, UserSession},
|
|||
|
|
};
|
|||
|
|
pub struct BotOrchestrator {
|
|||
|
|
session_manager: SessionManager,
|
|||
|
|
tool_manager: ToolManager,
|
|||
|
|
llm_provider: Arc<dyn LLMProvider>,
|
|||
|
|
auth_service: AuthService,
|
|||
|
|
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
|||
|
|
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
|||
|
|
chart_generator: Option<Arc<ChartGenerator>>,
|
|||
|
|
vector_store: Option<Arc<LangChainQdrant>>,
|
|||
|
|
sql_chain: Option<Arc<LLMChain>>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl BotOrchestrator {
|
|||
|
|
pub fn new(
|
|||
|
|
session_manager: SessionManager,
|
|||
|
|
tool_manager: ToolManager,
|
|||
|
|
llm_provider: Arc<dyn LLMProvider>,
|
|||
|
|
auth_service: AuthService,
|
|||
|
|
chart_generator: Option<Arc<ChartGenerator>>,
|
|||
|
|
vector_store: Option<Arc<LangChainQdrant>>,
|
|||
|
|
sql_chain: Option<Arc<LLMChain>>,
|
|||
|
|
) -> Self {
|
|||
|
|
Self {
|
|||
|
|
session_manager,
|
|||
|
|
tool_manager,
|
|||
|
|
llm_provider,
|
|||
|
|
auth_service,
|
|||
|
|
channels: HashMap::new(),
|
|||
|
|
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
|||
|
|
chart_generator,
|
|||
|
|
vector_store,
|
|||
|
|
sql_chain,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// Registers `adapter` for `channel_type`, replacing any adapter already
/// registered under that name.
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
    let key = channel_type.to_owned();
    self.channels.insert(key, adapter);
}
|
|||
|
|
|
|||
|
|
pub async fn register_response_channel(
|
|||
|
|
&self,
|
|||
|
|
session_id: String,
|
|||
|
|
sender: mpsc::Sender<BotResponse>,
|
|||
|
|
) {
|
|||
|
|
self.response_channels
|
|||
|
|
.lock()
|
|||
|
|
.await
|
|||
|
|
.insert(session_id, sender);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn set_user_answer_mode(
|
|||
|
|
&self,
|
|||
|
|
user_id: &str,
|
|||
|
|
bot_id: &str,
|
|||
|
|
mode: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
self.session_manager
|
|||
|
|
.update_answer_mode(user_id, bot_id, mode)
|
|||
|
|
.await?;
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn process_message(
|
|||
|
|
&self,
|
|||
|
|
message: UserMessage,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
info!(
|
|||
|
|
"Processing message from channel: {}, user: {}",
|
|||
|
|
message.channel, message.user_id
|
|||
|
|
);
|
|||
|
|
|
|||
|
|
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|||
|
|
let bot_id = Uuid::parse_str(&message.bot_id)
|
|||
|
|
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|||
|
|
|
|||
|
|
let session = match self
|
|||
|
|
.session_manager
|
|||
|
|
.get_user_session(user_id, bot_id)
|
|||
|
|
.await?
|
|||
|
|
{
|
|||
|
|
Some(session) => session,
|
|||
|
|
None => {
|
|||
|
|
self.session_manager
|
|||
|
|
.create_session(user_id, bot_id, "New Conversation")
|
|||
|
|
.await?
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
|||
|
|
self.tool_manager
|
|||
|
|
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
|||
|
|
.await?;
|
|||
|
|
return Ok(());
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
user_id,
|
|||
|
|
"user",
|
|||
|
|
&message.content,
|
|||
|
|
&message.message_type,
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let response_content = match session.answer_mode.as_str() {
|
|||
|
|
"document" => self.document_mode_handler(&message, &session).await?,
|
|||
|
|
"chart" => self.chart_mode_handler(&message, &session).await?,
|
|||
|
|
"database" => self.database_mode_handler(&message, &session).await?,
|
|||
|
|
"tool" => self.tool_mode_handler(&message, &session).await?,
|
|||
|
|
_ => self.direct_mode_handler(&message, &session).await?,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(session.id, user_id, "assistant", &response_content, "text")
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let bot_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id,
|
|||
|
|
user_id: message.user_id,
|
|||
|
|
session_id: message.session_id,
|
|||
|
|
channel: message.channel,
|
|||
|
|
content: response_content,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
|||
|
|
adapter.send_message(bot_response).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn document_mode_handler(
|
|||
|
|
&self,
|
|||
|
|
message: &UserMessage,
|
|||
|
|
session: &UserSession,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|||
|
|
if let Some(vector_store) = &self.vector_store {
|
|||
|
|
let similar_docs = vector_store
|
|||
|
|
.similarity_search(&message.content, 3, &VecStoreOptions::default())
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let mut enhanced_prompt = format!("User question: {}\n\n", message.content);
|
|||
|
|
|
|||
|
|
if !similar_docs.is_empty() {
|
|||
|
|
enhanced_prompt.push_str("Relevant documents:\n");
|
|||
|
|
for (i, doc) in similar_docs.iter().enumerate() {
|
|||
|
|
enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content));
|
|||
|
|
}
|
|||
|
|
enhanced_prompt.push_str(
|
|||
|
|
"\nPlease answer the user's question based on the provided documents.",
|
|||
|
|
);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
self.llm_provider
|
|||
|
|
.generate(&enhanced_prompt, &serde_json::Value::Null)
|
|||
|
|
.await
|
|||
|
|
} else {
|
|||
|
|
self.direct_mode_handler(message, session).await
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn chart_mode_handler(
|
|||
|
|
&self,
|
|||
|
|
message: &UserMessage,
|
|||
|
|
session: &UserSession,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|||
|
|
if let Some(chart_generator) = &self.chart_generator {
|
|||
|
|
let chart_response = chart_generator
|
|||
|
|
.generate_chart(&message.content, "bar")
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
session.user_id,
|
|||
|
|
"system",
|
|||
|
|
&format!("Generated chart for query: {}", message.content),
|
|||
|
|
"chart",
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
Ok(format!(
|
|||
|
|
"Chart generated for your query. Data retrieved: {}",
|
|||
|
|
chart_response.sql_query
|
|||
|
|
))
|
|||
|
|
} else {
|
|||
|
|
self.document_mode_handler(message, session).await
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn database_mode_handler(
|
|||
|
|
&self,
|
|||
|
|
message: &UserMessage,
|
|||
|
|
_session: &UserSession,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|||
|
|
if let Some(sql_chain) = &self.sql_chain {
|
|||
|
|
let input_variables = prompt_args! {
|
|||
|
|
"input" => message.content,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let result = sql_chain.invoke(input_variables).await?;
|
|||
|
|
Ok(result.to_string())
|
|||
|
|
} else {
|
|||
|
|
let db_url = std::env::var("DATABASE_URL")?;
|
|||
|
|
let engine = PostgreSQLEngine::new(&db_url).await?;
|
|||
|
|
let db = SQLDatabaseBuilder::new(engine).build().await?;
|
|||
|
|
|
|||
|
|
let llm = OpenAI::default();
|
|||
|
|
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
|
|||
|
|
.llm(llm)
|
|||
|
|
.top_k(5)
|
|||
|
|
.database(db)
|
|||
|
|
.build()?;
|
|||
|
|
|
|||
|
|
let input_variables = chain.prompt_builder().query(&message.content).build();
|
|||
|
|
let result = chain.invoke(input_variables).await?;
|
|||
|
|
|
|||
|
|
Ok(result.to_string())
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn tool_mode_handler(
|
|||
|
|
&self,
|
|||
|
|
message: &UserMessage,
|
|||
|
|
_session: &UserSession,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|||
|
|
if message.content.to_lowercase().contains("calculator") {
|
|||
|
|
if let Some(_adapter) = self.channels.get(&message.channel) {
|
|||
|
|
let (tx, _rx) = mpsc::channel(100);
|
|||
|
|
|
|||
|
|
self.register_response_channel(message.session_id.clone(), tx.clone())
|
|||
|
|
.await;
|
|||
|
|
|
|||
|
|
let tool_manager = self.tool_manager.clone();
|
|||
|
|
let user_id_str = message.user_id.clone();
|
|||
|
|
let bot_id_str = message.bot_id.clone();
|
|||
|
|
let session_manager = self.session_manager.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let _ = tool_manager
|
|||
|
|
.execute_tool("calculator", &user_id_str, &bot_id_str, session_manager, tx)
|
|||
|
|
.await;
|
|||
|
|
});
|
|||
|
|
}
|
|||
|
|
Ok("Starting calculator tool...".to_string())
|
|||
|
|
} else {
|
|||
|
|
let available_tools = self.tool_manager.list_tools();
|
|||
|
|
let tools_context = if !available_tools.is_empty() {
|
|||
|
|
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
|||
|
|
} else {
|
|||
|
|
String::new()
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let full_prompt = format!("{}{}", message.content, tools_context);
|
|||
|
|
|
|||
|
|
self.llm_provider
|
|||
|
|
.generate(&full_prompt, &serde_json::Value::Null)
|
|||
|
|
.await
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn direct_mode_handler(
|
|||
|
|
&self,
|
|||
|
|
message: &UserMessage,
|
|||
|
|
session: &UserSession,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|||
|
|
let history = self
|
|||
|
|
.session_manager
|
|||
|
|
.get_conversation_history(session.id, session.user_id)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let mut memory = SimpleMemory::new();
|
|||
|
|
for (role, content) in history {
|
|||
|
|
match role.as_str() {
|
|||
|
|
"user" => memory.add_user_message(&content),
|
|||
|
|
"assistant" => memory.add_ai_message(&content),
|
|||
|
|
_ => {}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let mut prompt = String::new();
|
|||
|
|
if let Some(chat_history) = memory.get_chat_history() {
|
|||
|
|
for message in chat_history {
|
|||
|
|
prompt.push_str(&format!(
|
|||
|
|
"{}: {}\n",
|
|||
|
|
message.message_type(),
|
|||
|
|
message.content()
|
|||
|
|
));
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
|||
|
|
|
|||
|
|
self.llm_provider
|
|||
|
|
.generate(&prompt, &serde_json::Value::Null)
|
|||
|
|
.await
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn stream_response(
|
|||
|
|
&self,
|
|||
|
|
message: UserMessage,
|
|||
|
|
mut response_tx: mpsc::Sender<BotResponse>,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
info!("Streaming response for user: {}", message.user_id);
|
|||
|
|
|
|||
|
|
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|||
|
|
let bot_id = Uuid::parse_str(&message.bot_id)
|
|||
|
|
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|||
|
|
|
|||
|
|
let session = match self
|
|||
|
|
.session_manager
|
|||
|
|
.get_user_session(user_id, bot_id)
|
|||
|
|
.await?
|
|||
|
|
{
|
|||
|
|
Some(session) => session,
|
|||
|
|
None => {
|
|||
|
|
self.session_manager
|
|||
|
|
.create_session(user_id, bot_id, "New Conversation")
|
|||
|
|
.await?
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
|||
|
|
self.tool_manager
|
|||
|
|
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
|||
|
|
.await?;
|
|||
|
|
return Ok(());
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
user_id,
|
|||
|
|
"user",
|
|||
|
|
&message.content,
|
|||
|
|
&message.message_type,
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let history = self
|
|||
|
|
.session_manager
|
|||
|
|
.get_conversation_history(session.id, user_id)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let mut memory = SimpleMemory::new();
|
|||
|
|
for (role, content) in history {
|
|||
|
|
match role.as_str() {
|
|||
|
|
"user" => memory.add_user_message(&content),
|
|||
|
|
"assistant" => memory.add_ai_message(&content),
|
|||
|
|
_ => {}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let mut prompt = String::new();
|
|||
|
|
if let Some(chat_history) = memory.get_chat_history() {
|
|||
|
|
for message in chat_history {
|
|||
|
|
prompt.push_str(&format!(
|
|||
|
|
"{}: {}\n",
|
|||
|
|
message.message_type(),
|
|||
|
|
message.content()
|
|||
|
|
));
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
|||
|
|
|
|||
|
|
let (stream_tx, mut stream_rx) = mpsc::channel(100);
|
|||
|
|
let llm_provider = self.llm_provider.clone();
|
|||
|
|
let prompt_clone = prompt.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let _ = llm_provider
|
|||
|
|
.generate_stream(&prompt_clone, &serde_json::Value::Null, stream_tx)
|
|||
|
|
.await;
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
let mut full_response = String::new();
|
|||
|
|
while let Some(chunk) = stream_rx.recv().await {
|
|||
|
|
full_response.push_str(&chunk);
|
|||
|
|
|
|||
|
|
let bot_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id.clone(),
|
|||
|
|
user_id: message.user_id.clone(),
|
|||
|
|
session_id: message.session_id.clone(),
|
|||
|
|
channel: message.channel.clone(),
|
|||
|
|
content: chunk,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: false,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if response_tx.send(bot_response).await.is_err() {
|
|||
|
|
break;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(session.id, user_id, "assistant", &full_response, "text")
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let final_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id,
|
|||
|
|
user_id: message.user_id,
|
|||
|
|
session_id: message.session_id,
|
|||
|
|
channel: message.channel,
|
|||
|
|
content: "".to_string(),
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
response_tx.send(final_response).await?;
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_user_sessions(
|
|||
|
|
&self,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
) -> Result<Vec<UserSession>, Box<dyn std::error::Error>> {
|
|||
|
|
self.session_manager.get_user_sessions(user_id).await
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_conversation_history(
|
|||
|
|
&self,
|
|||
|
|
session_id: Uuid,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
|
|||
|
|
self.session_manager
|
|||
|
|
.get_conversation_history(session_id, user_id)
|
|||
|
|
.await
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn process_message_with_tools(
|
|||
|
|
&self,
|
|||
|
|
message: UserMessage,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
info!(
|
|||
|
|
"Processing message with tools from user: {}",
|
|||
|
|
message.user_id
|
|||
|
|
);
|
|||
|
|
|
|||
|
|
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|||
|
|
let bot_id = Uuid::parse_str(&message.bot_id)
|
|||
|
|
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|||
|
|
|
|||
|
|
let session = match self
|
|||
|
|
.session_manager
|
|||
|
|
.get_user_session(user_id, bot_id)
|
|||
|
|
.await?
|
|||
|
|
{
|
|||
|
|
Some(session) => session,
|
|||
|
|
None => {
|
|||
|
|
self.session_manager
|
|||
|
|
.create_session(user_id, bot_id, "New Conversation")
|
|||
|
|
.await?
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
user_id,
|
|||
|
|
"user",
|
|||
|
|
&message.content,
|
|||
|
|
&message.message_type,
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let is_tool_waiting = self
|
|||
|
|
.tool_manager
|
|||
|
|
.is_tool_waiting(&message.session_id)
|
|||
|
|
.await
|
|||
|
|
.unwrap_or(false);
|
|||
|
|
|
|||
|
|
if is_tool_waiting {
|
|||
|
|
self.tool_manager
|
|||
|
|
.provide_input(&message.session_id, &message.content)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
|
|||
|
|
for output in tool_output {
|
|||
|
|
let bot_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id.clone(),
|
|||
|
|
user_id: message.user_id.clone(),
|
|||
|
|
session_id: message.session_id.clone(),
|
|||
|
|
channel: message.channel.clone(),
|
|||
|
|
content: output,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
|||
|
|
adapter.send_message(bot_response).await?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return Ok(());
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let available_tools = self.tool_manager.list_tools();
|
|||
|
|
let tools_context = if !available_tools.is_empty() {
|
|||
|
|
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
|||
|
|
} else {
|
|||
|
|
String::new()
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let full_prompt = format!("{}{}", message.content, tools_context);
|
|||
|
|
|
|||
|
|
let response = if message.content.to_lowercase().contains("calculator")
|
|||
|
|
|| message.content.to_lowercase().contains("calculate")
|
|||
|
|
|| message.content.to_lowercase().contains("math")
|
|||
|
|
{
|
|||
|
|
match self
|
|||
|
|
.tool_manager
|
|||
|
|
.execute_tool("calculator", &message.session_id, &message.user_id)
|
|||
|
|
.await
|
|||
|
|
{
|
|||
|
|
Ok(tool_result) => {
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
user_id,
|
|||
|
|
"assistant",
|
|||
|
|
&tool_result.output,
|
|||
|
|
"tool_start",
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
tool_result.output
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
format!("I encountered an error starting the calculator: {}", e)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
self.llm_provider
|
|||
|
|
.generate(&full_prompt, &serde_json::Value::Null)
|
|||
|
|
.await?
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(session.id, user_id, "assistant", &response, "text")
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let bot_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id,
|
|||
|
|
user_id: message.user_id,
|
|||
|
|
session_id: message.session_id,
|
|||
|
|
channel: message.channel,
|
|||
|
|
content: response,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
|||
|
|
adapter.send_message(bot_response).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn tool_mode_handler(
|
|||
|
|
&self,
|
|||
|
|
message: &UserMessage,
|
|||
|
|
_session: &UserSession,
|
|||
|
|
) -> Result<String, Box<dyn std::error::Error>> {
|
|||
|
|
if message.content.to_lowercase().contains("calculator") {
|
|||
|
|
if let Some(_adapter) = self.channels.get(&message.channel) {
|
|||
|
|
let (tx, _rx) = mpsc::channel(100);
|
|||
|
|
|
|||
|
|
self.register_response_channel(message.session_id.clone(), tx.clone())
|
|||
|
|
.await;
|
|||
|
|
|
|||
|
|
let tool_manager = self.tool_manager.clone();
|
|||
|
|
let user_id_str = message.user_id.clone();
|
|||
|
|
let bot_id_str = message.bot_id.clone();
|
|||
|
|
let session_manager = self.session_manager.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let _ = tool_manager
|
|||
|
|
.execute_tool_with_session(
|
|||
|
|
"calculator",
|
|||
|
|
&user_id_str,
|
|||
|
|
&bot_id_str,
|
|||
|
|
session_manager,
|
|||
|
|
tx,
|
|||
|
|
)
|
|||
|
|
.await;
|
|||
|
|
});
|
|||
|
|
}
|
|||
|
|
Ok("Starting calculator tool...".to_string())
|
|||
|
|
} else {
|
|||
|
|
let available_tools = self.tool_manager.list_tools();
|
|||
|
|
let tools_context = if !available_tools.is_empty() {
|
|||
|
|
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
|||
|
|
} else {
|
|||
|
|
String::new()
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let full_prompt = format!("{}{}", message.content, tools_context);
|
|||
|
|
|
|||
|
|
self.llm_provider
|
|||
|
|
.generate(&full_prompt, &serde_json::Value::Null)
|
|||
|
|
.await
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Fix the process_message_with_tools method
|
|||
|
|
pub async fn process_message_with_tools(
|
|||
|
|
&self,
|
|||
|
|
message: UserMessage,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
info!(
|
|||
|
|
"Processing message with tools from user: {}",
|
|||
|
|
message.user_id
|
|||
|
|
);
|
|||
|
|
|
|||
|
|
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
|||
|
|
let bot_id = Uuid::parse_str(&message.bot_id)
|
|||
|
|
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
|||
|
|
|
|||
|
|
let session = match self
|
|||
|
|
.session_manager
|
|||
|
|
.get_user_session(user_id, bot_id)
|
|||
|
|
.await?
|
|||
|
|
{
|
|||
|
|
Some(session) => session,
|
|||
|
|
None => {
|
|||
|
|
self.session_manager
|
|||
|
|
.create_session(user_id, bot_id, "New Conversation")
|
|||
|
|
.await?
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
user_id,
|
|||
|
|
"user",
|
|||
|
|
&message.content,
|
|||
|
|
&message.message_type,
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let is_tool_waiting = self
|
|||
|
|
.tool_manager
|
|||
|
|
.is_tool_waiting(&message.session_id)
|
|||
|
|
.await
|
|||
|
|
.unwrap_or(false);
|
|||
|
|
|
|||
|
|
if is_tool_waiting {
|
|||
|
|
self.tool_manager
|
|||
|
|
.provide_input(&message.session_id, &message.content)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
|
|||
|
|
for output in tool_output {
|
|||
|
|
let bot_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id.clone(),
|
|||
|
|
user_id: message.user_id.clone(),
|
|||
|
|
session_id: message.session_id.clone(),
|
|||
|
|
channel: message.channel.clone(),
|
|||
|
|
content: output,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
|||
|
|
adapter.send_message(bot_response).await?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return Ok(());
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let response = if message.content.to_lowercase().contains("calculator")
|
|||
|
|
|| message.content.to_lowercase().contains("calculate")
|
|||
|
|
|| message.content.to_lowercase().contains("math")
|
|||
|
|
{
|
|||
|
|
match self
|
|||
|
|
.tool_manager
|
|||
|
|
.execute_tool("calculator", &message.session_id, &message.user_id)
|
|||
|
|
.await
|
|||
|
|
{
|
|||
|
|
Ok(tool_result) => {
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(
|
|||
|
|
session.id,
|
|||
|
|
user_id,
|
|||
|
|
"assistant",
|
|||
|
|
&tool_result.output,
|
|||
|
|
"tool_start",
|
|||
|
|
)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
tool_result.output
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
format!("I encountered an error starting the calculator: {}", e)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
let available_tools = self.tool_manager.list_tools();
|
|||
|
|
let tools_context = if !available_tools.is_empty() {
|
|||
|
|
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
|||
|
|
} else {
|
|||
|
|
String::new()
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let full_prompt = format!("{}{}", message.content, tools_context);
|
|||
|
|
|
|||
|
|
self.llm_provider
|
|||
|
|
.generate(&full_prompt, &serde_json::Value::Null)
|
|||
|
|
.await?
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
self.session_manager
|
|||
|
|
.save_message(session.id, user_id, "assistant", &response, "text")
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let bot_response = BotResponse {
|
|||
|
|
bot_id: message.bot_id,
|
|||
|
|
user_id: message.user_id,
|
|||
|
|
session_id: message.session_id,
|
|||
|
|
channel: message.channel,
|
|||
|
|
content: response,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
|||
|
|
adapter.send_message(bot_response).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::get("/ws")]
|
|||
|
|
async fn websocket_handler(
|
|||
|
|
req: HttpRequest,
|
|||
|
|
stream: web::Payload,
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|||
|
|
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
|||
|
|
let session_id = Uuid::new_v4().to_string();
|
|||
|
|
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
|||
|
|
|
|||
|
|
data.orchestrator
|
|||
|
|
.register_response_channel(session_id.clone(), tx.clone())
|
|||
|
|
.await;
|
|||
|
|
data.web_adapter
|
|||
|
|
.add_connection(session_id.clone(), tx.clone())
|
|||
|
|
.await;
|
|||
|
|
data.voice_adapter
|
|||
|
|
.add_connection(session_id.clone(), tx.clone())
|
|||
|
|
.await;
|
|||
|
|
|
|||
|
|
let orchestrator = data.orchestrator.clone();
|
|||
|
|
let web_adapter = data.web_adapter.clone();
|
|||
|
|
|
|||
|
|
actix_web::rt::spawn(async move {
|
|||
|
|
while let Some(msg) = rx.recv().await {
|
|||
|
|
if let Ok(json) = serde_json::to_string(&msg) {
|
|||
|
|
let _ = session.text(json).await;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
actix_web::rt::spawn(async move {
|
|||
|
|
while let Some(Ok(msg)) = msg_stream.recv().await {
|
|||
|
|
match msg {
|
|||
|
|
WsMessage::Text(text) => {
|
|||
|
|
let user_message = UserMessage {
|
|||
|
|
bot_id: "default_bot".to_string(),
|
|||
|
|
user_id: "default_user".to_string(),
|
|||
|
|
session_id: session_id.clone(),
|
|||
|
|
channel: "web".to_string(),
|
|||
|
|
content: text.to_string(),
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
media_url: None,
|
|||
|
|
timestamp: Utc::now(),
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Err(e) = orchestrator.stream_response(user_message, tx.clone()).await {
|
|||
|
|
info!("Error processing message: {}", e);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
WsMessage::Close(_) => {
|
|||
|
|
web_adapter.remove_connection(&session_id).await;
|
|||
|
|
break;
|
|||
|
|
}
|
|||
|
|
_ => {}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
Ok(res)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::get("/api/whatsapp/webhook")]
|
|||
|
|
async fn whatsapp_webhook_verify(
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
web::Query(params): web::Query<HashMap<String, String>>,
|
|||
|
|
) -> Result<HttpResponse> {
|
|||
|
|
let mode = params.get("hub.mode").unwrap_or(&"".to_string());
|
|||
|
|
let token = params.get("hub.verify_token").unwrap_or(&"".to_string());
|
|||
|
|
let challenge = params.get("hub.challenge").unwrap_or(&"".to_string());
|
|||
|
|
|
|||
|
|
match data.whatsapp_adapter.verify_webhook(mode, token, challenge) {
|
|||
|
|
Ok(challenge_response) => Ok(HttpResponse::Ok().body(challenge_response)),
|
|||
|
|
Err(_) => Ok(HttpResponse::Forbidden().body("Verification failed")),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::post("/api/whatsapp/webhook")]
|
|||
|
|
async fn whatsapp_webhook(
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
payload: web::Json<crate::services::whatsapp::WhatsAppMessage>,
|
|||
|
|
) -> Result<HttpResponse> {
|
|||
|
|
match data
|
|||
|
|
.whatsapp_adapter
|
|||
|
|
.process_incoming_message(payload.into_inner())
|
|||
|
|
.await
|
|||
|
|
{
|
|||
|
|
Ok(user_messages) => {
|
|||
|
|
for user_message in user_messages {
|
|||
|
|
if let Err(e) = data.orchestrator.process_message(user_message).await {
|
|||
|
|
log::error!("Error processing WhatsApp message: {}", e);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
Ok(HttpResponse::Ok().body(""))
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
log::error!("Error processing WhatsApp webhook: {}", e);
|
|||
|
|
Ok(HttpResponse::BadRequest().body("Invalid message"))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::post("/api/voice/start")]
|
|||
|
|
async fn voice_start(
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
info: web::Json<serde_json::Value>,
|
|||
|
|
) -> Result<HttpResponse> {
|
|||
|
|
let session_id = info
|
|||
|
|
.get("session_id")
|
|||
|
|
.and_then(|s| s.as_str())
|
|||
|
|
.unwrap_or("");
|
|||
|
|
let user_id = info
|
|||
|
|
.get("user_id")
|
|||
|
|
.and_then(|u| u.as_str())
|
|||
|
|
.unwrap_or("user");
|
|||
|
|
|
|||
|
|
match data
|
|||
|
|
.voice_adapter
|
|||
|
|
.start_voice_session(session_id, user_id)
|
|||
|
|
.await
|
|||
|
|
{
|
|||
|
|
Ok(token) => {
|
|||
|
|
Ok(HttpResponse::Ok().json(serde_json::json!({"token": token, "status": "started"})))
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
Ok(HttpResponse::InternalServerError()
|
|||
|
|
.json(serde_json::json!({"error": e.to_string()})))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::post("/api/voice/stop")]
|
|||
|
|
async fn voice_stop(
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
info: web::Json<serde_json::Value>,
|
|||
|
|
) -> Result<HttpResponse> {
|
|||
|
|
let session_id = info
|
|||
|
|
.get("session_id")
|
|||
|
|
.and_then(|s| s.as_str())
|
|||
|
|
.unwrap_or("");
|
|||
|
|
|
|||
|
|
match data.voice_adapter.stop_voice_session(session_id).await {
|
|||
|
|
Ok(()) => Ok(HttpResponse::Ok().json(serde_json::json!({"status": "stopped"}))),
|
|||
|
|
Err(e) => {
|
|||
|
|
Ok(HttpResponse::InternalServerError()
|
|||
|
|
.json(serde_json::json!({"error": e.to_string()})))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::post("/api/sessions")]
|
|||
|
|
async fn create_session(_data: web::Data<AppState>) -> Result<HttpResponse> {
|
|||
|
|
let session_id = Uuid::new_v4();
|
|||
|
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
|||
|
|
"session_id": session_id,
|
|||
|
|
"title": "New Conversation",
|
|||
|
|
"created_at": Utc::now()
|
|||
|
|
})))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::get("/api/sessions")]
|
|||
|
|
async fn get_sessions(data: web::Data<AppState>) -> Result<HttpResponse> {
|
|||
|
|
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
|||
|
|
match data.orchestrator.get_user_sessions(user_id).await {
|
|||
|
|
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
|||
|
|
Err(e) => {
|
|||
|
|
Ok(HttpResponse::InternalServerError()
|
|||
|
|
.json(serde_json::json!({"error": e.to_string()})))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::get("/api/sessions/{session_id}")]
|
|||
|
|
async fn get_session_history(
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
path: web::Path<String>,
|
|||
|
|
) -> Result<HttpResponse> {
|
|||
|
|
let session_id = path.into_inner();
|
|||
|
|
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
|||
|
|
|
|||
|
|
match Uuid::parse_str(&session_id) {
|
|||
|
|
Ok(session_uuid) => match data
|
|||
|
|
.orchestrator
|
|||
|
|
.get_conversation_history(session_uuid, user_id)
|
|||
|
|
.await
|
|||
|
|
{
|
|||
|
|
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
|||
|
|
Err(e) => Ok(HttpResponse::InternalServerError()
|
|||
|
|
.json(serde_json::json!({"error": e.to_string()}))),
|
|||
|
|
},
|
|||
|
|
Err(_) => {
|
|||
|
|
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::post("/api/set_mode")]
|
|||
|
|
async fn set_mode_handler(
|
|||
|
|
data: web::Data<AppState>,
|
|||
|
|
info: web::Json<HashMap<String, String>>,
|
|||
|
|
) -> Result<HttpResponse> {
|
|||
|
|
let default_user = "default_user".to_string();
|
|||
|
|
let default_bot = "default_bot".to_string();
|
|||
|
|
let default_mode = "direct".to_string();
|
|||
|
|
|
|||
|
|
let user_id = info.get("user_id").unwrap_or(&default_user);
|
|||
|
|
let bot_id = info.get("bot_id").unwrap_or(&default_bot);
|
|||
|
|
let mode = info.get("mode").unwrap_or(&default_mode);
|
|||
|
|
|
|||
|
|
if let Err(e) = data
|
|||
|
|
.orchestrator
|
|||
|
|
.set_user_answer_mode(user_id, bot_id, mode)
|
|||
|
|
.await
|
|||
|
|
{
|
|||
|
|
return Ok(
|
|||
|
|
HttpResponse::InternalServerError().json(serde_json::json!({"error": e.to_string()}))
|
|||
|
|
);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(HttpResponse::Ok().json(serde_json::json!({"status": "mode_updated"})))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::get("/")]
|
|||
|
|
async fn index() -> Result<HttpResponse> {
|
|||
|
|
let html = fs::read_to_string("templates/index.html")
|
|||
|
|
.unwrap_or_else(|_| include_str!("../../static/index.html").to_string());
|
|||
|
|
Ok(HttpResponse::Ok().content_type("text/html").body(html))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[actix_web::get("/static/{filename:.*}")]
|
|||
|
|
async fn static_files(req: HttpRequest) -> Result<HttpResponse> {
|
|||
|
|
let filename = req.match_info().query("filename");
|
|||
|
|
let path = format!("static/{}", filename);
|
|||
|
|
|
|||
|
|
match fs::read(&path) {
|
|||
|
|
Ok(content) => {
|
|||
|
|
let content_type = match filename {
|
|||
|
|
f if f.ends_with(".js") => "application/javascript",
|
|||
|
|
f if f.ends_with(".css") => "text/css",
|
|||
|
|
f if f.ends_with(".png") => "image/png",
|
|||
|
|
f if f.ends_with(".jpg") | f.ends_with(".jpeg") => "image/jpeg",
|
|||
|
|
_ => "text/plain",
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
Ok(HttpResponse::Ok().content_type(content_type).body(content))
|
|||
|
|
}
|
|||
|
|
Err(_) => Ok(HttpResponse::NotFound().body("File not found")),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
use crate::shared::UserSession;
|
|||
|
|
use redis::{AsyncCommands, Client};
|
|||
|
|
use serde_json;
|
|||
|
|
use sqlx::{PgPool, Row};
|
|||
|
|
use std::sync::Arc;
|
|||
|
|
use uuid::Uuid;
|
|||
|
|
|
|||
|
|
pub struct SessionManager {
|
|||
|
|
pub pool: PgPool,
|
|||
|
|
pub redis: Option<Arc<Client>>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl SessionManager {
|
|||
|
|
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
|||
|
|
Self { pool, redis }
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_user_session(
|
|||
|
|
&self,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
bot_id: Uuid,
|
|||
|
|
) -> Result<Option<UserSession>, Box<dyn std::error::Error>> {
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
|||
|
|
let session_json: Option<String> = conn.get(&cache_key).await?;
|
|||
|
|
if let Some(json) = session_json {
|
|||
|
|
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
|||
|
|
return Ok(Some(session));
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let session = sqlx::query_as::<_, UserSession>(
|
|||
|
|
"SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1",
|
|||
|
|
)
|
|||
|
|
.bind(user_id)
|
|||
|
|
.bind(bot_id)
|
|||
|
|
.fetch_optional(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(ref session) = session {
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
|||
|
|
let session_json = serde_json::to_string(session)?;
|
|||
|
|
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(session)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn create_session(
|
|||
|
|
&self,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
bot_id: Uuid,
|
|||
|
|
title: &str,
|
|||
|
|
) -> Result<UserSession, Box<dyn std::error::Error>> {
|
|||
|
|
let session = sqlx::query_as::<_, UserSession>(
|
|||
|
|
"INSERT INTO user_sessions (user_id, bot_id, title) VALUES ($1, $2, $3) RETURNING *",
|
|||
|
|
)
|
|||
|
|
.bind(user_id)
|
|||
|
|
.bind(bot_id)
|
|||
|
|
.bind(title)
|
|||
|
|
.fetch_one(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
|||
|
|
let session_json = serde_json::to_string(&session)?;
|
|||
|
|
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(session)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn save_message(
|
|||
|
|
&self,
|
|||
|
|
session_id: Uuid,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
role: &str,
|
|||
|
|
content: &str,
|
|||
|
|
message_type: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let message_count: i64 =
|
|||
|
|
sqlx::query("SELECT COUNT(*) as count FROM message_history WHERE session_id = $1")
|
|||
|
|
.bind(session_id)
|
|||
|
|
.fetch_one(&self.pool)
|
|||
|
|
.await?
|
|||
|
|
.get("count");
|
|||
|
|
|
|||
|
|
sqlx::query(
|
|||
|
|
"INSERT INTO message_history (session_id, user_id, role, content_encrypted, message_type, message_index)
|
|||
|
|
VALUES ($1, $2, $3, $4, $5, $6)",
|
|||
|
|
)
|
|||
|
|
.bind(session_id)
|
|||
|
|
.bind(user_id)
|
|||
|
|
.bind(role)
|
|||
|
|
.bind(content)
|
|||
|
|
.bind(message_type)
|
|||
|
|
.bind(message_count + 1)
|
|||
|
|
.execute(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
sqlx::query("UPDATE user_sessions SET updated_at = NOW() WHERE id = $1")
|
|||
|
|
.bind(session_id)
|
|||
|
|
.execute(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
if let Some(session_info) =
|
|||
|
|
sqlx::query("SELECT user_id, bot_id FROM user_sessions WHERE id = $1")
|
|||
|
|
.bind(session_id)
|
|||
|
|
.fetch_optional(&self.pool)
|
|||
|
|
.await?
|
|||
|
|
{
|
|||
|
|
let user_id: Uuid = session_info.get("user_id");
|
|||
|
|
let bot_id: Uuid = session_info.get("bot_id");
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
|||
|
|
let _: () = conn.del(cache_key).await?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_conversation_history(
|
|||
|
|
&self,
|
|||
|
|
session_id: Uuid,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error>> {
|
|||
|
|
let messages = sqlx::query(
|
|||
|
|
"SELECT role, content_encrypted FROM message_history
|
|||
|
|
WHERE session_id = $1 AND user_id = $2
|
|||
|
|
ORDER BY message_index ASC",
|
|||
|
|
)
|
|||
|
|
.bind(session_id)
|
|||
|
|
.bind(user_id)
|
|||
|
|
.fetch_all(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let history = messages
|
|||
|
|
.into_iter()
|
|||
|
|
.map(|row| (row.get("role"), row.get("content_encrypted")))
|
|||
|
|
.collect();
|
|||
|
|
|
|||
|
|
Ok(history)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_user_sessions(
|
|||
|
|
&self,
|
|||
|
|
user_id: Uuid,
|
|||
|
|
) -> Result<Vec<UserSession>, Box<dyn std::error::Error>> {
|
|||
|
|
let sessions = sqlx::query_as::<_, UserSession>(
|
|||
|
|
"SELECT * FROM user_sessions WHERE user_id = $1 ORDER BY updated_at DESC",
|
|||
|
|
)
|
|||
|
|
.bind(user_id)
|
|||
|
|
.fetch_all(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
Ok(sessions)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn update_answer_mode(
|
|||
|
|
&self,
|
|||
|
|
user_id: &str,
|
|||
|
|
bot_id: &str,
|
|||
|
|
mode: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let user_uuid = Uuid::parse_str(user_id)?;
|
|||
|
|
let bot_uuid = Uuid::parse_str(bot_id)?;
|
|||
|
|
|
|||
|
|
sqlx::query(
|
|||
|
|
"UPDATE user_sessions
|
|||
|
|
SET answer_mode = $1, updated_at = NOW()
|
|||
|
|
WHERE user_id = $2 AND bot_id = $3",
|
|||
|
|
)
|
|||
|
|
.bind(mode)
|
|||
|
|
.bind(user_uuid)
|
|||
|
|
.bind(bot_uuid)
|
|||
|
|
.execute(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
|||
|
|
let _: () = conn.del(cache_key).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn update_current_tool(
|
|||
|
|
&self,
|
|||
|
|
user_id: &str,
|
|||
|
|
bot_id: &str,
|
|||
|
|
tool_name: Option<&str>,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let user_uuid = Uuid::parse_str(user_id)?;
|
|||
|
|
let bot_uuid = Uuid::parse_str(bot_id)?;
|
|||
|
|
|
|||
|
|
sqlx::query(
|
|||
|
|
"UPDATE user_sessions
|
|||
|
|
SET current_tool = $1, updated_at = NOW()
|
|||
|
|
WHERE user_id = $2 AND bot_id = $3",
|
|||
|
|
)
|
|||
|
|
.bind(tool_name)
|
|||
|
|
.bind(user_uuid)
|
|||
|
|
.bind(bot_uuid)
|
|||
|
|
.execute(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
|||
|
|
let _: () = conn.del(cache_key).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_session_by_id(
|
|||
|
|
&self,
|
|||
|
|
session_id: Uuid,
|
|||
|
|
) -> Result<Option<UserSession>, Box<dyn std::error::Error>> {
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session_by_id:{}", session_id);
|
|||
|
|
let session_json: Option<String> = conn.get(&cache_key).await?;
|
|||
|
|
if let Some(json) = session_json {
|
|||
|
|
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
|||
|
|
return Ok(Some(session));
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let session = sqlx::query_as::<_, UserSession>("SELECT * FROM user_sessions WHERE id = $1")
|
|||
|
|
.bind(session_id)
|
|||
|
|
.fetch_optional(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(ref session) = session {
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session_by_id:{}", session_id);
|
|||
|
|
let session_json = serde_json::to_string(session)?;
|
|||
|
|
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(session)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn cleanup_old_sessions(
|
|||
|
|
&self,
|
|||
|
|
days_old: i32,
|
|||
|
|
) -> Result<u64, Box<dyn std::error::Error>> {
|
|||
|
|
let result = sqlx::query(
|
|||
|
|
"DELETE FROM user_sessions
|
|||
|
|
WHERE updated_at < NOW() - INTERVAL '1 day' * $1",
|
|||
|
|
)
|
|||
|
|
.bind(days_old)
|
|||
|
|
.execute(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
Ok(result.rows_affected())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn set_current_tool(
|
|||
|
|
&self,
|
|||
|
|
user_id: &str,
|
|||
|
|
bot_id: &str,
|
|||
|
|
tool_name: Option<String>,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let user_uuid = Uuid::parse_str(user_id)?;
|
|||
|
|
let bot_uuid = Uuid::parse_str(bot_id)?;
|
|||
|
|
|
|||
|
|
sqlx::query(
|
|||
|
|
"UPDATE user_sessions
|
|||
|
|
SET current_tool = $1, updated_at = NOW()
|
|||
|
|
WHERE user_id = $2 AND bot_id = $3",
|
|||
|
|
)
|
|||
|
|
.bind(tool_name)
|
|||
|
|
.bind(user_uuid)
|
|||
|
|
.bind(bot_uuid)
|
|||
|
|
.execute(&self.pool)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
if let Some(redis_client) = &self.redis {
|
|||
|
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
|||
|
|
let _: () = conn.del(cache_key).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl Clone for SessionManager {
|
|||
|
|
fn clone(&self) -> Self {
|
|||
|
|
Self {
|
|||
|
|
pool: self.pool.clone(),
|
|||
|
|
redis: self.redis.clone(),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// src/tools/mod.rs
|
|||
|
|
use async_trait::async_trait;
|
|||
|
|
use redis::AsyncCommands;
|
|||
|
|
use rhai::{Engine, Scope};
|
|||
|
|
use serde::{Deserialize, Serialize};
|
|||
|
|
use std::collections::HashMap;
|
|||
|
|
use std::sync::Arc;
|
|||
|
|
use tokio::sync::{mpsc, Mutex};
|
|||
|
|
use uuid::Uuid;
|
|||
|
|
|
|||
|
|
use crate::channels::ChannelAdapter;
|
|||
|
|
use crate::session::SessionManager;
|
|||
|
|
use crate::shared::BotResponse;
|
|||
|
|
|
|||
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|||
|
|
pub struct ToolResult {
|
|||
|
|
pub success: bool,
|
|||
|
|
pub output: String,
|
|||
|
|
pub requires_input: bool,
|
|||
|
|
pub session_id: String,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/// A named, scripted tool that can be started for a user session.
#[derive(Clone, Debug)]
pub struct Tool {
    /// Name the tool is registered and looked up under.
    pub name: String,
    /// Human-readable summary of what the tool does.
    pub description: String,
    /// Parameter name mapped to a type or allowed-values hint (e.g. `"number"`).
    pub parameters: HashMap<String, String>,
    /// Rhai source evaluated when the tool runs.
    pub script: String,
}
|
|||
|
|
|
|||
|
|
#[async_trait]
|
|||
|
|
pub trait ToolExecutor: Send + Sync {
|
|||
|
|
async fn execute(
|
|||
|
|
&self,
|
|||
|
|
tool_name: &str,
|
|||
|
|
session_id: &str,
|
|||
|
|
user_id: &str,
|
|||
|
|
) -> Result<ToolResult, Box<dyn std::error::Error>>;
|
|||
|
|
async fn provide_input(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
input: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>>;
|
|||
|
|
async fn get_output(&self, session_id: &str)
|
|||
|
|
-> Result<Vec<String>, Box<dyn std::error::Error>>;
|
|||
|
|
async fn is_waiting_for_input(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
) -> Result<bool, Box<dyn std::error::Error>>;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub struct RedisToolExecutor {
|
|||
|
|
redis_client: redis::Client,
|
|||
|
|
web_adapter: Arc<dyn ChannelAdapter>,
|
|||
|
|
voice_adapter: Arc<dyn ChannelAdapter>,
|
|||
|
|
whatsapp_adapter: Arc<dyn ChannelAdapter>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl RedisToolExecutor {
|
|||
|
|
pub fn new(
|
|||
|
|
redis_url: &str,
|
|||
|
|
web_adapter: Arc<dyn ChannelAdapter>,
|
|||
|
|
voice_adapter: Arc<dyn ChannelAdapter>,
|
|||
|
|
whatsapp_adapter: Arc<dyn ChannelAdapter>,
|
|||
|
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
|||
|
|
let client = redis::Client::open(redis_url)?;
|
|||
|
|
Ok(Self {
|
|||
|
|
redis_client: client,
|
|||
|
|
web_adapter,
|
|||
|
|
voice_adapter,
|
|||
|
|
whatsapp_adapter,
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn send_tool_message(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
user_id: &str,
|
|||
|
|
channel: &str,
|
|||
|
|
message: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let response = BotResponse {
|
|||
|
|
bot_id: "tool_bot".to_string(),
|
|||
|
|
user_id: user_id.to_string(),
|
|||
|
|
session_id: session_id.to_string(),
|
|||
|
|
channel: channel.to_string(),
|
|||
|
|
content: message.to_string(),
|
|||
|
|
message_type: "tool".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
match channel {
|
|||
|
|
"web" => self.web_adapter.send_message(response).await,
|
|||
|
|
"voice" => self.voice_adapter.send_message(response).await,
|
|||
|
|
"whatsapp" => self.whatsapp_adapter.send_message(response).await,
|
|||
|
|
_ => Ok(()),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine {
|
|||
|
|
let mut engine = Engine::new();
|
|||
|
|
|
|||
|
|
// Clone for TALK function
|
|||
|
|
let tool_executor = Arc::new((
|
|||
|
|
self.redis_client.clone(),
|
|||
|
|
self.web_adapter.clone(),
|
|||
|
|
self.voice_adapter.clone(),
|
|||
|
|
self.whatsapp_adapter.clone(),
|
|||
|
|
));
|
|||
|
|
|
|||
|
|
let session_id_clone = session_id.clone();
|
|||
|
|
let user_id_clone = user_id.clone();
|
|||
|
|
let channel_clone = channel.clone();
|
|||
|
|
|
|||
|
|
engine.register_fn("talk", move |message: String| {
|
|||
|
|
let tool_executor = Arc::clone(&tool_executor);
|
|||
|
|
let session_id = session_id_clone.clone();
|
|||
|
|
let user_id = user_id_clone.clone();
|
|||
|
|
let channel = channel_clone.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor;
|
|||
|
|
|
|||
|
|
let response = BotResponse {
|
|||
|
|
bot_id: "tool_bot".to_string(),
|
|||
|
|
user_id: user_id.clone(),
|
|||
|
|
session_id: session_id.clone(),
|
|||
|
|
channel: channel.clone(),
|
|||
|
|
content: message.clone(),
|
|||
|
|
message_type: "tool".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let result = match channel.as_str() {
|
|||
|
|
"web" => web_adapter.send_message(response).await,
|
|||
|
|
"voice" => voice_adapter.send_message(response).await,
|
|||
|
|
"whatsapp" => whatsapp_adapter.send_message(response).await,
|
|||
|
|
_ => Ok(()),
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Err(e) = result {
|
|||
|
|
log::error!("Failed to send tool message: {}", e);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if let Ok(mut conn) = redis_client.get_async_connection().await {
|
|||
|
|
let output_key = format!("tool:{}:output", session_id);
|
|||
|
|
let _ = conn.lpush(&output_key, &message).await;
|
|||
|
|
}
|
|||
|
|
});
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
let hear_executor = self.redis_client.clone();
|
|||
|
|
let session_id_clone = session_id.clone();
|
|||
|
|
|
|||
|
|
engine.register_fn("hear", move || -> String {
|
|||
|
|
let hear_executor = hear_executor.clone();
|
|||
|
|
let session_id = session_id_clone.clone();
|
|||
|
|
|
|||
|
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
|||
|
|
rt.block_on(async move {
|
|||
|
|
match hear_executor.get_async_connection().await {
|
|||
|
|
Ok(mut conn) => {
|
|||
|
|
let input_key = format!("tool:{}:input", session_id);
|
|||
|
|
let waiting_key = format!("tool:{}:waiting", session_id);
|
|||
|
|
|
|||
|
|
let _ = conn.set_ex(&waiting_key, "true", 300).await;
|
|||
|
|
let result: Option<(String, String)> =
|
|||
|
|
conn.brpop(&input_key, 30).await.ok().flatten();
|
|||
|
|
let _ = conn.del(&waiting_key).await;
|
|||
|
|
|
|||
|
|
result
|
|||
|
|
.map(|(_, input)| input)
|
|||
|
|
.unwrap_or_else(|| "timeout".to_string())
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
log::error!("HEAR Redis error: {}", e);
|
|||
|
|
"error".to_string()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
engine
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn cleanup_session(&self, session_id: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
|
|||
|
|
let keys = vec![
|
|||
|
|
format!("tool:{}:output", session_id),
|
|||
|
|
format!("tool:{}:input", session_id),
|
|||
|
|
format!("tool:{}:waiting", session_id),
|
|||
|
|
format!("tool:{}:active", session_id),
|
|||
|
|
];
|
|||
|
|
|
|||
|
|
for key in keys {
|
|||
|
|
let _: () = conn.del(&key).await?;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[async_trait]
|
|||
|
|
impl ToolExecutor for RedisToolExecutor {
|
|||
|
|
async fn execute(
|
|||
|
|
&self,
|
|||
|
|
tool_name: &str,
|
|||
|
|
session_id: &str,
|
|||
|
|
user_id: &str,
|
|||
|
|
) -> Result<ToolResult, Box<dyn std::error::Error>> {
|
|||
|
|
let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?;
|
|||
|
|
|
|||
|
|
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let session_key = format!("tool:{}:session", session_id);
|
|||
|
|
let session_data = serde_json::json!({
|
|||
|
|
"user_id": user_id,
|
|||
|
|
"tool_name": tool_name,
|
|||
|
|
"started_at": chrono::Utc::now().to_rfc3339(),
|
|||
|
|
});
|
|||
|
|
conn.set_ex(&session_key, session_data.to_string(), 3600)
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let active_key = format!("tool:{}:active", session_id);
|
|||
|
|
conn.set_ex(&active_key, "true", 3600).await?;
|
|||
|
|
|
|||
|
|
let channel = "web";
|
|||
|
|
let _engine = self.create_rhai_engine(
|
|||
|
|
session_id.to_string(),
|
|||
|
|
user_id.to_string(),
|
|||
|
|
channel.to_string(),
|
|||
|
|
);
|
|||
|
|
|
|||
|
|
let redis_clone = self.redis_client.clone();
|
|||
|
|
let web_adapter_clone = self.web_adapter.clone();
|
|||
|
|
let voice_adapter_clone = self.voice_adapter.clone();
|
|||
|
|
let whatsapp_adapter_clone = self.whatsapp_adapter.clone();
|
|||
|
|
let session_id_clone = session_id.to_string();
|
|||
|
|
let user_id_clone = user_id.to_string();
|
|||
|
|
let tool_script = tool.script.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let mut engine = Engine::new();
|
|||
|
|
let mut scope = Scope::new();
|
|||
|
|
|
|||
|
|
let redis_client = redis_clone.clone();
|
|||
|
|
let web_adapter = web_adapter_clone.clone();
|
|||
|
|
let voice_adapter = voice_adapter_clone.clone();
|
|||
|
|
let whatsapp_adapter = whatsapp_adapter_clone.clone();
|
|||
|
|
let session_id = session_id_clone.clone();
|
|||
|
|
let user_id = user_id_clone.clone();
|
|||
|
|
|
|||
|
|
engine.register_fn("talk", move |message: String| {
|
|||
|
|
let redis_client = redis_client.clone();
|
|||
|
|
let web_adapter = web_adapter.clone();
|
|||
|
|
let voice_adapter = voice_adapter.clone();
|
|||
|
|
let whatsapp_adapter = whatsapp_adapter.clone();
|
|||
|
|
let session_id = session_id.clone();
|
|||
|
|
let user_id = user_id.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let channel = "web";
|
|||
|
|
|
|||
|
|
let response = BotResponse {
|
|||
|
|
bot_id: "tool_bot".to_string(),
|
|||
|
|
user_id: user_id.clone(),
|
|||
|
|
session_id: session_id.clone(),
|
|||
|
|
channel: channel.to_string(),
|
|||
|
|
content: message.clone(),
|
|||
|
|
message_type: "tool".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let send_result = match channel {
|
|||
|
|
"web" => web_adapter.send_message(response).await,
|
|||
|
|
"voice" => voice_adapter.send_message(response).await,
|
|||
|
|
"whatsapp" => whatsapp_adapter.send_message(response).await,
|
|||
|
|
_ => Ok(()),
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
if let Err(e) = send_result {
|
|||
|
|
log::error!("Failed to send tool message: {}", e);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if let Ok(mut conn) = redis_client.get_async_connection().await {
|
|||
|
|
let output_key = format!("tool:{}:output", session_id);
|
|||
|
|
let _ = conn.lpush(&output_key, &message).await;
|
|||
|
|
}
|
|||
|
|
});
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
let hear_redis = redis_clone.clone();
|
|||
|
|
let session_id_hear = session_id.clone();
|
|||
|
|
engine.register_fn("hear", move || -> String {
|
|||
|
|
let hear_redis = hear_redis.clone();
|
|||
|
|
let session_id = session_id_hear.clone();
|
|||
|
|
|
|||
|
|
let rt = tokio::runtime::Runtime::new().unwrap();
|
|||
|
|
rt.block_on(async move {
|
|||
|
|
match hear_redis.get_async_connection().await {
|
|||
|
|
Ok(mut conn) => {
|
|||
|
|
let input_key = format!("tool:{}:input", session_id);
|
|||
|
|
let waiting_key = format!("tool:{}:waiting", session_id);
|
|||
|
|
|
|||
|
|
let _ = conn.set_ex(&waiting_key, "true", 300).await;
|
|||
|
|
let result: Option<(String, String)> =
|
|||
|
|
conn.brpop(&input_key, 30).await.ok().flatten();
|
|||
|
|
let _ = conn.del(&waiting_key).await;
|
|||
|
|
|
|||
|
|
result
|
|||
|
|
.map(|(_, input)| input)
|
|||
|
|
.unwrap_or_else(|| "timeout".to_string())
|
|||
|
|
}
|
|||
|
|
Err(_) => "error".to_string(),
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
match engine.eval_with_scope::<()>(&mut scope, &tool_script) {
|
|||
|
|
Ok(_) => {
|
|||
|
|
log::info!(
|
|||
|
|
"Tool {} completed successfully for session {}",
|
|||
|
|
tool_name,
|
|||
|
|
session_id
|
|||
|
|
);
|
|||
|
|
|
|||
|
|
let completion_msg =
|
|||
|
|
"🛠️ Tool execution completed. How can I help you with anything else?";
|
|||
|
|
let response = BotResponse {
|
|||
|
|
bot_id: "tool_bot".to_string(),
|
|||
|
|
user_id: user_id_clone,
|
|||
|
|
session_id: session_id_clone.clone(),
|
|||
|
|
channel: "web".to_string(),
|
|||
|
|
content: completion_msg.to_string(),
|
|||
|
|
message_type: "tool_complete".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let _ = web_adapter_clone.send_message(response).await;
|
|||
|
|
}
|
|||
|
|
Err(e) => {
|
|||
|
|
log::error!("Tool execution failed: {}", e);
|
|||
|
|
|
|||
|
|
let error_msg = format!("❌ Tool error: {}", e);
|
|||
|
|
let response = BotResponse {
|
|||
|
|
bot_id: "tool_bot".to_string(),
|
|||
|
|
user_id: user_id_clone,
|
|||
|
|
session_id: session_id_clone.clone(),
|
|||
|
|
channel: "web".to_string(),
|
|||
|
|
content: error_msg,
|
|||
|
|
message_type: "tool_error".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let _ = web_adapter_clone.send_message(response).await;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if let Ok(mut conn) = redis_clone.get_async_connection().await {
|
|||
|
|
let active_key = format!("tool:{}:active", session_id_clone);
|
|||
|
|
let _ = conn.del(&active_key).await;
|
|||
|
|
}
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
Ok(ToolResult {
|
|||
|
|
success: true,
|
|||
|
|
output: format!(
|
|||
|
|
"🛠️ Starting {} tool. Please follow the tool's instructions.",
|
|||
|
|
tool_name
|
|||
|
|
),
|
|||
|
|
requires_input: true,
|
|||
|
|
session_id: session_id.to_string(),
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn provide_input(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
input: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let input_key = format!("tool:{}:input", session_id);
|
|||
|
|
conn.lpush(&input_key, input).await?;
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn get_output(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
|||
|
|
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let output_key = format!("tool:{}:output", session_id);
|
|||
|
|
let messages: Vec<String> = conn.lrange(&output_key, 0, -1).await?;
|
|||
|
|
let _: () = conn.del(&output_key).await?;
|
|||
|
|
Ok(messages)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
async fn is_waiting_for_input(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
) -> Result<bool, Box<dyn std::error::Error>> {
|
|||
|
|
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|||
|
|
let waiting_key = format!("tool:{}:waiting", session_id);
|
|||
|
|
let exists: bool = conn.exists(&waiting_key).await?;
|
|||
|
|
Ok(exists)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fn get_tool(name: &str) -> Option<Tool> {
|
|||
|
|
match name {
|
|||
|
|
"calculator" => Some(Tool {
|
|||
|
|
name: "calculator".to_string(),
|
|||
|
|
description: "Perform mathematical calculations".to_string(),
|
|||
|
|
parameters: HashMap::from([
|
|||
|
|
("operation".to_string(), "add|subtract|multiply|divide".to_string()),
|
|||
|
|
("a".to_string(), "number".to_string()),
|
|||
|
|
("b".to_string(), "number".to_string()),
|
|||
|
|
]),
|
|||
|
|
script: r#"
|
|||
|
|
let TALK = |message| {
|
|||
|
|
talk(message);
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
let HEAR = || {
|
|||
|
|
hear()
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
TALK("🔢 Calculator started!");
|
|||
|
|
TALK("Please enter the first number:");
|
|||
|
|
let a = HEAR();
|
|||
|
|
TALK("Please enter the second number:");
|
|||
|
|
let b = HEAR();
|
|||
|
|
TALK("Choose operation: add, subtract, multiply, or divide:");
|
|||
|
|
let op = HEAR();
|
|||
|
|
|
|||
|
|
let num_a = a.to_float();
|
|||
|
|
let num_b = b.to_float();
|
|||
|
|
|
|||
|
|
if op == "add" {
|
|||
|
|
let result = num_a + num_b;
|
|||
|
|
TALK("✅ Result: " + a + " + " + b + " = " + result);
|
|||
|
|
} else if op == "subtract" {
|
|||
|
|
let result = num_a - num_b;
|
|||
|
|
TALK("✅ Result: " + a + " - " + b + " = " + result);
|
|||
|
|
} else if op == "multiply" {
|
|||
|
|
let result = num_a * num_b;
|
|||
|
|
TALK("✅ Result: " + a + " × " + b + " = " + result);
|
|||
|
|
} else if op == "divide" {
|
|||
|
|
if num_b != 0.0 {
|
|||
|
|
let result = num_a / num_b;
|
|||
|
|
TALK("✅ Result: " + a + " ÷ " + b + " = " + result);
|
|||
|
|
} else {
|
|||
|
|
TALK("❌ Error: Cannot divide by zero!");
|
|||
|
|
}
|
|||
|
|
} else {
|
|||
|
|
TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide");
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
TALK("Calculator session completed. Thank you!");
|
|||
|
|
"#.to_string(),
|
|||
|
|
}),
|
|||
|
|
_ => None,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
#[derive(Clone)]
|
|||
|
|
pub struct ToolManager {
|
|||
|
|
tools: HashMap<String, Tool>,
|
|||
|
|
waiting_responses: Arc<Mutex<HashMap<String, mpsc::Sender<String>>>>,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl ToolManager {
|
|||
|
|
pub fn new() -> Self {
|
|||
|
|
let mut tools = HashMap::new();
|
|||
|
|
|
|||
|
|
let calculator_tool = Tool {
|
|||
|
|
name: "calculator".to_string(),
|
|||
|
|
description: "Perform calculations".to_string(),
|
|||
|
|
parameters: HashMap::from([
|
|||
|
|
(
|
|||
|
|
"operation".to_string(),
|
|||
|
|
"add|subtract|multiply|divide".to_string(),
|
|||
|
|
),
|
|||
|
|
("a".to_string(), "number".to_string()),
|
|||
|
|
("b".to_string(), "number".to_string()),
|
|||
|
|
]),
|
|||
|
|
script: r#"
|
|||
|
|
TALK("Calculator started. Enter first number:");
|
|||
|
|
let a = HEAR();
|
|||
|
|
TALK("Enter second number:");
|
|||
|
|
let b = HEAR();
|
|||
|
|
TALK("Operation (add/subtract/multiply/divide):");
|
|||
|
|
let op = HEAR();
|
|||
|
|
|
|||
|
|
let num_a = a.parse::<f64>().unwrap();
|
|||
|
|
let num_b = b.parse::<f64>().unwrap();
|
|||
|
|
let result = if op == "add" {
|
|||
|
|
num_a + num_b
|
|||
|
|
} else if op == "subtract" {
|
|||
|
|
num_a - num_b
|
|||
|
|
} else if op == "multiply" {
|
|||
|
|
num_a * num_b
|
|||
|
|
} else if op == "divide" {
|
|||
|
|
if num_b == 0.0 {
|
|||
|
|
TALK("Cannot divide by zero");
|
|||
|
|
return;
|
|||
|
|
}
|
|||
|
|
num_a / num_b
|
|||
|
|
} else {
|
|||
|
|
TALK("Invalid operation");
|
|||
|
|
return;
|
|||
|
|
};
|
|||
|
|
TALK("Result: ".to_string() + &result.to_string());
|
|||
|
|
"#
|
|||
|
|
.to_string(),
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
tools.insert(calculator_tool.name.clone(), calculator_tool);
|
|||
|
|
Self {
|
|||
|
|
tools,
|
|||
|
|
waiting_responses: Arc::new(Mutex::new(HashMap::new())),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub fn get_tool(&self, name: &str) -> Option<&Tool> {
|
|||
|
|
self.tools.get(name)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub fn list_tools(&self) -> Vec<String> {
|
|||
|
|
self.tools.keys().cloned().collect()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn execute_tool(
|
|||
|
|
&self,
|
|||
|
|
tool_name: &str,
|
|||
|
|
session_id: &str,
|
|||
|
|
user_id: &str,
|
|||
|
|
) -> Result<ToolResult, Box<dyn std::error::Error>> {
|
|||
|
|
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
|||
|
|
|
|||
|
|
Ok(ToolResult {
|
|||
|
|
success: true,
|
|||
|
|
output: format!("Tool {} started for user {}", tool_name, user_id),
|
|||
|
|
requires_input: true,
|
|||
|
|
session_id: session_id.to_string(),
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn is_tool_waiting(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
) -> Result<bool, Box<dyn std::error::Error>> {
|
|||
|
|
let waiting = self.waiting_responses.lock().await;
|
|||
|
|
Ok(waiting.contains_key(session_id))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn provide_input(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
input: &str,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
self.provide_user_response(session_id, "default_bot", input.to_string())
|
|||
|
|
.await
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn get_tool_output(
|
|||
|
|
&self,
|
|||
|
|
session_id: &str,
|
|||
|
|
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
|||
|
|
Ok(vec![])
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn execute_tool_with_session(
|
|||
|
|
&self,
|
|||
|
|
tool_name: &str,
|
|||
|
|
user_id: &str,
|
|||
|
|
bot_id: &str,
|
|||
|
|
session_manager: SessionManager,
|
|||
|
|
channel_sender: mpsc::Sender<BotResponse>,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
|||
|
|
session_manager
|
|||
|
|
.set_current_tool(user_id, bot_id, Some(tool_name.to_string()))
|
|||
|
|
.await?;
|
|||
|
|
|
|||
|
|
let user_id = user_id.to_string();
|
|||
|
|
let bot_id = bot_id.to_string();
|
|||
|
|
let script = tool.script.clone();
|
|||
|
|
let session_manager_clone = session_manager.clone();
|
|||
|
|
let waiting_responses = self.waiting_responses.clone();
|
|||
|
|
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let mut engine = rhai::Engine::new();
|
|||
|
|
let (talk_tx, mut talk_rx) = mpsc::channel(100);
|
|||
|
|
let (hear_tx, mut hear_rx) = mpsc::channel(100);
|
|||
|
|
|
|||
|
|
{
|
|||
|
|
let key = format!("{}:{}", user_id, bot_id);
|
|||
|
|
let mut waiting = waiting_responses.lock().await;
|
|||
|
|
waiting.insert(key, hear_tx);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let channel_sender_clone = channel_sender.clone();
|
|||
|
|
let user_id_clone = user_id.clone();
|
|||
|
|
let bot_id_clone = bot_id.clone();
|
|||
|
|
|
|||
|
|
let talk_tx_clone = talk_tx.clone();
|
|||
|
|
engine.register_fn("TALK", move |message: String| {
|
|||
|
|
let tx = talk_tx_clone.clone();
|
|||
|
|
tokio::spawn(async move {
|
|||
|
|
let _ = tx.send(message).await;
|
|||
|
|
});
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
let hear_rx_mutex = Arc::new(Mutex::new(hear_rx));
|
|||
|
|
engine.register_fn("HEAR", move || {
|
|||
|
|
let hear_rx = hear_rx_mutex.clone();
|
|||
|
|
tokio::task::block_in_place(|| {
|
|||
|
|
tokio::runtime::Handle::current().block_on(async move {
|
|||
|
|
let mut receiver = hear_rx.lock().await;
|
|||
|
|
receiver.recv().await.unwrap_or_default()
|
|||
|
|
})
|
|||
|
|
})
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
let script_result =
|
|||
|
|
tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await;
|
|||
|
|
|
|||
|
|
if let Ok(Err(e)) = script_result {
|
|||
|
|
let error_response = BotResponse {
|
|||
|
|
bot_id: bot_id_clone.clone(),
|
|||
|
|
user_id: user_id_clone.clone(),
|
|||
|
|
session_id: Uuid::new_v4().to_string(),
|
|||
|
|
channel: "test".to_string(),
|
|||
|
|
content: format!("Tool error: {}", e),
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
let _ = channel_sender_clone.send(error_response).await;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
while let Some(message) = talk_rx.recv().await {
|
|||
|
|
let response = BotResponse {
|
|||
|
|
bot_id: bot_id.clone(),
|
|||
|
|
user_id: user_id.clone(),
|
|||
|
|
session_id: Uuid::new_v4().to_string(),
|
|||
|
|
channel: "test".to_string(),
|
|||
|
|
content: message,
|
|||
|
|
message_type: "text".to_string(),
|
|||
|
|
stream_token: None,
|
|||
|
|
is_complete: true,
|
|||
|
|
};
|
|||
|
|
let _ = channel_sender.send(response).await;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
let _ = session_manager_clone
|
|||
|
|
.set_current_tool(&user_id, &bot_id, None)
|
|||
|
|
.await;
|
|||
|
|
});
|
|||
|
|
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
pub async fn provide_user_response(
|
|||
|
|
&self,
|
|||
|
|
user_id: &str,
|
|||
|
|
bot_id: &str,
|
|||
|
|
response: String,
|
|||
|
|
) -> Result<(), Box<dyn std::error::Error>> {
|
|||
|
|
let key = format!("{}:{}", user_id, bot_id);
|
|||
|
|
let mut waiting = self.waiting_responses.lock().await;
|
|||
|
|
if let Some(tx) = waiting.get_mut(&key) {
|
|||
|
|
let _ = tx.send(response).await;
|
|||
|
|
waiting.remove(&key);
|
|||
|
|
}
|
|||
|
|
Ok(())
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
impl Default for ToolManager {
|
|||
|
|
fn default() -> Self {
|
|||
|
|
Self::new()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
.
|
|||
|
|
└── src
|
|||
|
|
├── auth
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── automation
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── basic
|
|||
|
|
│ ├── keywords
|
|||
|
|
│ │ ├── create_draft.rs
|
|||
|
|
│ │ ├── create_site.rs
|
|||
|
|
│ │ ├── find.rs
|
|||
|
|
│ │ ├── first.rs
|
|||
|
|
│ │ ├── format.rs
|
|||
|
|
│ │ ├── for_next.rs
|
|||
|
|
│ │ ├── get.rs
|
|||
|
|
│ │ ├── get_website.rs
|
|||
|
|
│ │ ├── last.rs
|
|||
|
|
│ │ ├── llm_keyword.rs
|
|||
|
|
│ │ ├── mod.rs
|
|||
|
|
│ │ ├── on.rs
|
|||
|
|
│ │ ├── print.rs
|
|||
|
|
│ │ ├── set.rs
|
|||
|
|
│ │ ├── set_schedule.rs
|
|||
|
|
│ │ └── wait.rs
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── bot
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── channels
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── chart
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── config
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── context
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── email
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── file
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── llm
|
|||
|
|
│ ├── llm_generic.rs
|
|||
|
|
│ ├── llm_local.rs
|
|||
|
|
│ ├── llm_provider.rs
|
|||
|
|
│ ├── llm.rs
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── main.rs
|
|||
|
|
├── org
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── session
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── shared
|
|||
|
|
│ ├── models.rs
|
|||
|
|
│ ├── mod.rs
|
|||
|
|
│ ├── state.rs
|
|||
|
|
│ └── utils.rs
|
|||
|
|
├── tests
|
|||
|
|
│ ├── integration_email_list.rs
|
|||
|
|
│ ├── integration_file_list_test.rs
|
|||
|
|
│ └── integration_file_upload_test.rs
|
|||
|
|
├── tools
|
|||
|
|
│ └── mod.rs
|
|||
|
|
├── web_automation
|
|||
|
|
│ └── mod.rs
|
|||
|
|
└── whatsapp
|
|||
|
|
└── mod.rs
|
|||
|
|
|
|||
|
|
21 directories, 44 files
|