4249 lines
132 KiB
Text
4249 lines
132 KiB
Text
Consolidated LLM Context
|
||
* Prefer `use` imports over fully qualified `::` paths when calling functions and methods.
|
||
* Output a single `.sh` script using `cat` so it can be restored directly.
|
||
* No placeholders, only real, production-ready code.
|
||
* No comments, no explanations, no extra text.
|
||
* Follow KISS principles.
|
||
* Provide a complete, professional, working solution.
|
||
* If the script is too long, split into multiple parts, but always return the **entire code**.
|
||
* Output must be **only the code**, nothing else.
|
||
|
||
[package]
|
||
name = "botserver"
|
||
version = "0.1.0"
|
||
edition = "2021"
|
||
authors = ["Rodrigo Rodriguez <me@rodrigorodriguez.com>"]
|
||
description = "General Bots Server"
|
||
license = "AGPL-3.0"
|
||
repository = "https://github.pragmatismo.com.br/generalbots/botserver"
|
||
|
||
[features]
|
||
default = ["postgres", "qdrant"]
|
||
local_llm = []
|
||
postgres = ["sqlx/postgres"]
|
||
qdrant = ["langchain-rust/qdrant"]
|
||
|
||
[dependencies]
|
||
actix-cors = "0.7"
|
||
actix-multipart = "0.7"
|
||
actix-web = "4.9"
|
||
actix-ws = "0.3"
|
||
anyhow = "1.0"
|
||
async-stream = "0.3"
|
||
async-trait = "0.1"
|
||
aes-gcm = "0.10"
|
||
argon2 = "0.5"
|
||
base64 = "0.22"
|
||
bytes = "1.8"
|
||
chrono = { version = "0.4", features = ["serde"] }
|
||
dotenv = "0.15"
|
||
downloader = "0.2"
|
||
env_logger = "0.11"
|
||
futures = "0.3"
|
||
futures-util = "0.3"
|
||
imap = "2.4"
|
||
langchain-rust = { version = "4.6", features = ["qdrant", "postgres"] }
|
||
lettre = { version = "0.11", features = ["smtp-transport", "builder", "tokio1", "tokio1-native-tls"] }
|
||
livekit = "0.7"
|
||
log = "0.4"
|
||
mailparse = "0.15"
|
||
minio = { git = "https://github.com/minio/minio-rs", branch = "master" }
|
||
native-tls = "0.2"
|
||
num-format = "0.4"
|
||
qdrant-client = "1.12"
|
||
rhai = "1.22"
|
||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||
regex = "1.11"
|
||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||
scraper = "0.20"
|
||
serde = { version = "1.0", features = ["derive"] }
|
||
serde_json = "1.0"
|
||
smartstring = "1.0"
|
||
sqlx = { version = "0.8", features = ["time", "uuid", "runtime-tokio-rustls", "postgres", "chrono"] }
|
||
tempfile = "3"
|
||
thirtyfour = "0.34"
|
||
tokio = { version = "1.41", features = ["full"] }
|
||
tokio-stream = "0.1"
|
||
tracing = "0.1"
|
||
tracing-subscriber = { version = "0.3", features = ["fmt"] }
|
||
urlencoding = "2.1"
|
||
uuid = { version = "1.11", features = ["serde", "v4"] }
|
||
zip = "2.2"
|
||
|
||
You are fixing Rust code in a Cargo project. The user is providing problematic code that needs to be corrected.
|
||
|
||
## Your Task
|
||
Fix ALL compiler errors and logical issues while maintaining the original intent. Return the COMPLETE corrected files as a SINGLE .sh script that can be executed from project root.
|
||
Use Cargo.toml as reference, do not change it.
|
||
Only return the input files; all other files already exist.
|
||
If something needs to be added to an external file, mention it separately.
|
||
|
||
## Critical Requirements
|
||
1. **Return as SINGLE .sh script** - Output must be a complete shell script using `cat > file << 'EOF'` pattern
|
||
2. **Include ALL files** - Every corrected file must be included in the script
|
||
3. **Respect Cargo.toml** - Check dependencies, editions, and features to avoid compiler errors
|
||
4. **Type safety** - Ensure all types match and trait bounds are satisfied
|
||
5. **Ownership rules** - Fix borrowing, ownership, and lifetime issues
|
||
|
||
## Output Format Requirements
|
||
You MUST return the output in exactly this format:
|
||
|
||
```sh
|
||
#!/bin/bash
|
||
|
||
# Restore fixed Rust project
|
||
|
||
cat > src/<filenamehere>.rs << 'EOF'
|
||
use std::io;
|
||
|
||
// test
EOF
|
||
|
||
cat > src/<anotherfile>.rs << 'EOF'
|
||
// Fixed library code
|
||
pub fn add(a: i32, b: i32) -> i32 {
|
||
a + b
|
||
}
|
||
EOF
```
|
||
|
||
----
|
||
|
||
use async_trait::async_trait;
|
||
use chrono::Utc;
|
||
use livekit::prelude::*;
|
||
use log::info;
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use tokio::sync::{mpsc, Mutex};
|
||
|
||
use crate::shared::{BotResponse, UserMessage};
|
||
|
||
#[async_trait]
|
||
pub trait ChannelAdapter: Send + Sync {
|
||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>>;
|
||
}
|
||
|
||
pub struct WebChannelAdapter {
|
||
connections: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||
}
|
||
|
||
impl WebChannelAdapter {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
connections: Arc::new(Mutex::new(HashMap::new())),
|
||
}
|
||
}
|
||
|
||
pub async fn add_connection(&self, session_id: String, tx: mpsc::Sender<BotResponse>) {
|
||
self.connections.lock().await.insert(session_id, tx);
|
||
}
|
||
|
||
pub async fn remove_connection(&self, session_id: &str) {
|
||
self.connections.lock().await.remove(session_id);
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl ChannelAdapter for WebChannelAdapter {
|
||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
||
let connections = self.connections.lock().await;
|
||
if let Some(tx) = connections.get(&response.session_id) {
|
||
tx.send(response).await?;
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
pub struct VoiceAdapter {
|
||
livekit_url: String,
|
||
api_key: String,
|
||
api_secret: String,
|
||
rooms: Arc<Mutex<HashMap<String, Room>>>,
|
||
connections: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||
}
|
||
|
||
impl VoiceAdapter {
|
||
pub fn new(livekit_url: String, api_key: String, api_secret: String) -> Self {
|
||
Self {
|
||
livekit_url,
|
||
api_key,
|
||
api_secret,
|
||
rooms: Arc::new(Mutex::new(HashMap::new())),
|
||
connections: Arc::new(Mutex::new(HashMap::new())),
|
||
}
|
||
}
|
||
|
||
pub async fn start_voice_session(
|
||
&self,
|
||
session_id: &str,
|
||
user_id: &str,
|
||
) -> Result<String, Box<dyn std::error::Error>> {
|
||
let token = AccessToken::with_api_key(&self.api_key, &self.api_secret)
|
||
.with_identity(user_id)
|
||
.with_name(user_id)
|
||
.with_room_name(session_id)
|
||
.with_room_join(true)
|
||
.to_jwt()?;
|
||
|
||
let room_options = RoomOptions {
|
||
auto_subscribe: true,
|
||
..Default::default()
|
||
};
|
||
|
||
let (room, mut events) = Room::connect(&self.livekit_url, &token, room_options).await?;
|
||
self.rooms
|
||
.lock()
|
||
.await
|
||
.insert(session_id.to_string(), room.clone());
|
||
|
||
let rooms_clone = self.rooms.clone();
|
||
let connections_clone = self.connections.clone();
|
||
let session_id_clone = session_id.to_string();
|
||
|
||
tokio::spawn(async move {
|
||
while let Some(event) = events.recv().await {
|
||
match event {
|
||
RoomEvent::DataReceived(data_packet) => {
|
||
if let Ok(message) =
|
||
serde_json::from_slice::<UserMessage>(&data_packet.data)
|
||
{
|
||
info!("Received voice message: {}", message.content);
|
||
if let Some(tx) =
|
||
connections_clone.lock().await.get(&message.session_id)
|
||
{
|
||
let _ = tx
|
||
.send(BotResponse {
|
||
bot_id: message.bot_id,
|
||
user_id: message.user_id,
|
||
session_id: message.session_id,
|
||
channel: "voice".to_string(),
|
||
content: format!("🎤 Voice: {}", message.content),
|
||
message_type: "voice".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
})
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
RoomEvent::TrackSubscribed(track, publication, participant) => {
|
||
info!("Voice track subscribed from {}", participant.identity());
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
rooms_clone.lock().await.remove(&session_id_clone);
|
||
});
|
||
|
||
Ok(token)
|
||
}
|
||
|
||
pub async fn stop_voice_session(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
if let Some(room) = self.rooms.lock().await.remove(session_id) {
|
||
room.disconnect().await;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn add_connection(&self, session_id: String, tx: mpsc::Sender<BotResponse>) {
|
||
self.connections.lock().await.insert(session_id, tx);
|
||
}
|
||
|
||
pub async fn send_voice_response(
|
||
&self,
|
||
session_id: &str,
|
||
text: &str,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
if let Some(room) = self.rooms.lock().await.get(session_id) {
|
||
let voice_response = serde_json::json!({
|
||
"type": "voice_response",
|
||
"text": text,
|
||
"timestamp": Utc::now()
|
||
});
|
||
|
||
room.local_participant()
|
||
.publish_data(
|
||
serde_json::to_vec(&voice_response)?,
|
||
DataPacketKind::Reliable,
|
||
&[],
|
||
)
|
||
.await?;
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl ChannelAdapter for VoiceAdapter {
|
||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
||
info!("Sending voice response to: {}", response.user_id);
|
||
self.send_voice_response(&response.session_id, &response.content)
|
||
.await
|
||
}
|
||
}
|
||
|
||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||
use dotenv::dotenv;
|
||
use log::{error, info};
|
||
use reqwest::Client;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::env;
|
||
use tokio::time::{sleep, Duration};
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct ChatMessage {
|
||
role: String,
|
||
content: String,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct ChatCompletionRequest {
|
||
model: String,
|
||
messages: Vec<ChatMessage>,
|
||
stream: Option<bool>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct ChatCompletionResponse {
|
||
id: String,
|
||
object: String,
|
||
created: u64,
|
||
model: String,
|
||
choices: Vec<Choice>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct Choice {
|
||
message: ChatMessage,
|
||
finish_reason: String,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct LlamaCppRequest {
|
||
prompt: String,
|
||
n_predict: Option<i32>,
|
||
temperature: Option<f32>,
|
||
top_k: Option<i32>,
|
||
top_p: Option<f32>,
|
||
stream: Option<bool>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct LlamaCppResponse {
|
||
content: String,
|
||
stop: bool,
|
||
generation_settings: Option<serde_json::Value>,
|
||
}
|
||
|
||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
||
|
||
if llm_local.to_lowercase() != "true" {
|
||
info!("ℹ️ LLM_LOCAL is not enabled, skipping local server startup");
|
||
return Ok(());
|
||
}
|
||
|
||
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||
let embedding_url =
|
||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
|
||
let llm_model_path = env::var("LLM_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||
let embedding_model_path = env::var("EMBEDDING_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||
|
||
info!("🚀 Starting local llama.cpp servers...");
|
||
info!("📋 Configuration:");
|
||
info!(" LLM URL: {}", llm_url);
|
||
info!(" Embedding URL: {}", embedding_url);
|
||
info!(" LLM Model: {}", llm_model_path);
|
||
info!(" Embedding Model: {}", embedding_model_path);
|
||
|
||
let llm_running = is_server_running(&llm_url).await;
|
||
let embedding_running = is_server_running(&embedding_url).await;
|
||
|
||
if llm_running && embedding_running {
|
||
info!("✅ Both LLM and Embedding servers are already running");
|
||
return Ok(());
|
||
}
|
||
|
||
let mut tasks = vec![];
|
||
|
||
if !llm_running && !llm_model_path.is_empty() {
|
||
info!("🔄 Starting LLM server...");
|
||
tasks.push(tokio::spawn(start_llm_server(
|
||
llama_cpp_path.clone(),
|
||
llm_model_path.clone(),
|
||
llm_url.clone(),
|
||
)));
|
||
} else if llm_model_path.is_empty() {
|
||
info!("⚠️ LLM_MODEL_PATH not set, skipping LLM server");
|
||
}
|
||
|
||
if !embedding_running && !embedding_model_path.is_empty() {
|
||
info!("🔄 Starting Embedding server...");
|
||
tasks.push(tokio::spawn(start_embedding_server(
|
||
llama_cpp_path.clone(),
|
||
embedding_model_path.clone(),
|
||
embedding_url.clone(),
|
||
)));
|
||
} else if embedding_model_path.is_empty() {
|
||
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
||
}
|
||
|
||
for task in tasks {
|
||
task.await??;
|
||
}
|
||
|
||
info!("⏳ Waiting for servers to become ready...");
|
||
|
||
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
||
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
||
|
||
let mut attempts = 0;
|
||
let max_attempts = 60;
|
||
|
||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||
sleep(Duration::from_secs(2)).await;
|
||
|
||
info!(
|
||
"🔍 Checking server health (attempt {}/{})...",
|
||
attempts + 1,
|
||
max_attempts
|
||
);
|
||
|
||
if !llm_ready && !llm_model_path.is_empty() {
|
||
if is_server_running(&llm_url).await {
|
||
info!(" ✅ LLM server ready at {}", llm_url);
|
||
llm_ready = true;
|
||
} else {
|
||
info!(" ❌ LLM server not ready yet");
|
||
}
|
||
}
|
||
|
||
if !embedding_ready && !embedding_model_path.is_empty() {
|
||
if is_server_running(&embedding_url).await {
|
||
info!(" ✅ Embedding server ready at {}", embedding_url);
|
||
embedding_ready = true;
|
||
} else {
|
||
info!(" ❌ Embedding server not ready yet");
|
||
}
|
||
}
|
||
|
||
attempts += 1;
|
||
|
||
if attempts % 10 == 0 {
|
||
info!(
|
||
"⏰ Still waiting for servers... (attempt {}/{})",
|
||
attempts, max_attempts
|
||
);
|
||
}
|
||
}
|
||
|
||
if llm_ready && embedding_ready {
|
||
info!("🎉 All llama.cpp servers are ready and responding!");
|
||
Ok(())
|
||
} else {
|
||
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
|
||
if !llm_ready && !llm_model_path.is_empty() {
|
||
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
|
||
}
|
||
if !embedding_ready && !embedding_model_path.is_empty() {
|
||
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
|
||
}
|
||
Err(error_msg.into())
|
||
}
|
||
}
|
||
|
||
async fn start_llm_server(
|
||
llama_cpp_path: String,
|
||
model_path: String,
|
||
url: String,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let port = url.split(':').last().unwrap_or("8081");
|
||
|
||
std::env::set_var("OMP_NUM_THREADS", "20");
|
||
std::env::set_var("OMP_PLACES", "cores");
|
||
std::env::set_var("OMP_PROC_BIND", "close");
|
||
|
||
let mut cmd = tokio::process::Command::new("sh");
|
||
cmd.arg("-c").arg(format!(
|
||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
||
llama_cpp_path, model_path, port
|
||
));
|
||
|
||
cmd.spawn()?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn start_embedding_server(
|
||
llama_cpp_path: String,
|
||
model_path: String,
|
||
url: String,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let port = url.split(':').last().unwrap_or("8082");
|
||
|
||
let mut cmd = tokio::process::Command::new("sh");
|
||
cmd.arg("-c").arg(format!(
|
||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 &",
|
||
llama_cpp_path, model_path, port
|
||
));
|
||
|
||
cmd.spawn()?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn is_server_running(url: &str) -> bool {
|
||
let client = reqwest::Client::new();
|
||
match client.get(&format!("{}/health", url)).send().await {
|
||
Ok(response) => response.status().is_success(),
|
||
Err(_) => false,
|
||
}
|
||
}
|
||
|
||
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||
let mut prompt = String::new();
|
||
|
||
for message in messages {
|
||
match message.role.as_str() {
|
||
"system" => {
|
||
prompt.push_str(&format!("System: {}\n\n", message.content));
|
||
}
|
||
"user" => {
|
||
prompt.push_str(&format!("User: {}\n\n", message.content));
|
||
}
|
||
"assistant" => {
|
||
prompt.push_str(&format!("Assistant: {}\n\n", message.content));
|
||
}
|
||
_ => {
|
||
prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
|
||
}
|
||
}
|
||
}
|
||
|
||
prompt.push_str("Assistant: ");
|
||
prompt
|
||
}
|
||
|
||
#[post("/local/v1/chat/completions")]
|
||
pub async fn chat_completions_local(
|
||
req_body: web::Json<ChatCompletionRequest>,
|
||
_req: HttpRequest,
|
||
) -> Result<HttpResponse> {
|
||
dotenv().ok();
|
||
|
||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||
|
||
let prompt = messages_to_prompt(&req_body.messages);
|
||
|
||
let llama_request = LlamaCppRequest {
|
||
prompt,
|
||
n_predict: Some(500),
|
||
temperature: Some(0.7),
|
||
top_k: Some(40),
|
||
top_p: Some(0.9),
|
||
stream: req_body.stream,
|
||
};
|
||
|
||
let client = Client::builder()
|
||
.timeout(Duration::from_secs(120))
|
||
.build()
|
||
.map_err(|e| {
|
||
error!("Error creating HTTP client: {}", e);
|
||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||
})?;
|
||
|
||
let response = client
|
||
.post(&format!("{}/completion", llama_url))
|
||
.header("Content-Type", "application/json")
|
||
.json(&llama_request)
|
||
.send()
|
||
.await
|
||
.map_err(|e| {
|
||
error!("Error calling llama.cpp server: {}", e);
|
||
actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
|
||
})?;
|
||
|
||
let status = response.status();
|
||
|
||
if status.is_success() {
|
||
let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
|
||
error!("Error parsing llama.cpp response: {}", e);
|
||
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
||
})?;
|
||
|
||
let openai_response = ChatCompletionResponse {
|
||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||
object: "chat.completion".to_string(),
|
||
created: std::time::SystemTime::now()
|
||
.duration_since(std::time::UNIX_EPOCH)
|
||
.unwrap()
|
||
.as_secs(),
|
||
model: req_body.model.clone(),
|
||
choices: vec![Choice {
|
||
message: ChatMessage {
|
||
role: "assistant".to_string(),
|
||
content: llama_response.content.trim().to_string(),
|
||
},
|
||
finish_reason: if llama_response.stop {
|
||
"stop".to_string()
|
||
} else {
|
||
"length".to_string()
|
||
},
|
||
}],
|
||
};
|
||
|
||
Ok(HttpResponse::Ok().json(openai_response))
|
||
} else {
|
||
let error_text = response
|
||
.text()
|
||
.await
|
||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||
|
||
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||
|
||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||
|
||
Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||
"error": {
|
||
"message": error_text,
|
||
"type": "server_error"
|
||
}
|
||
})))
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct EmbeddingRequest {
|
||
#[serde(deserialize_with = "deserialize_input")]
|
||
pub input: Vec<String>,
|
||
pub model: String,
|
||
#[serde(default)]
|
||
pub _encoding_format: Option<String>,
|
||
}
|
||
|
||
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||
where
|
||
D: serde::Deserializer<'de>,
|
||
{
|
||
use serde::de::{self, Visitor};
|
||
use std::fmt;
|
||
|
||
struct InputVisitor;
|
||
|
||
impl<'de> Visitor<'de> for InputVisitor {
|
||
type Value = Vec<String>;
|
||
|
||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||
formatter.write_str("a string or an array of strings")
|
||
}
|
||
|
||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||
where
|
||
E: de::Error,
|
||
{
|
||
Ok(vec![value.to_string()])
|
||
}
|
||
|
||
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||
where
|
||
E: de::Error,
|
||
{
|
||
Ok(vec![value])
|
||
}
|
||
|
||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||
where
|
||
A: de::SeqAccess<'de>,
|
||
{
|
||
let mut vec = Vec::new();
|
||
while let Some(value) = seq.next_element::<String>()? {
|
||
vec.push(value);
|
||
}
|
||
Ok(vec)
|
||
}
|
||
}
|
||
|
||
deserializer.deserialize_any(InputVisitor)
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
pub struct EmbeddingResponse {
|
||
pub object: String,
|
||
pub data: Vec<EmbeddingData>,
|
||
pub model: String,
|
||
pub usage: Usage,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
pub struct EmbeddingData {
|
||
pub object: String,
|
||
pub embedding: Vec<f32>,
|
||
pub index: usize,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
pub struct Usage {
|
||
pub prompt_tokens: u32,
|
||
pub total_tokens: u32,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
struct LlamaCppEmbeddingRequest {
|
||
pub content: String,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct LlamaCppEmbeddingResponseItem {
|
||
pub index: usize,
|
||
pub embedding: Vec<Vec<f32>>,
|
||
}
|
||
|
||
#[post("/v1/embeddings")]
|
||
pub async fn embeddings_local(
|
||
req_body: web::Json<EmbeddingRequest>,
|
||
_req: HttpRequest,
|
||
) -> Result<HttpResponse> {
|
||
dotenv().ok();
|
||
|
||
let llama_url =
|
||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||
|
||
let client = Client::builder()
|
||
.timeout(Duration::from_secs(120))
|
||
.build()
|
||
.map_err(|e| {
|
||
error!("Error creating HTTP client: {}", e);
|
||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||
})?;
|
||
|
||
let mut embeddings_data = Vec::new();
|
||
let mut total_tokens = 0;
|
||
|
||
for (index, input_text) in req_body.input.iter().enumerate() {
|
||
let llama_request = LlamaCppEmbeddingRequest {
|
||
content: input_text.clone(),
|
||
};
|
||
|
||
let response = client
|
||
.post(&format!("{}/embedding", llama_url))
|
||
.header("Content-Type", "application/json")
|
||
.json(&llama_request)
|
||
.send()
|
||
.await
|
||
.map_err(|e| {
|
||
error!("Error calling llama.cpp server for embedding: {}", e);
|
||
actix_web::error::ErrorInternalServerError(
|
||
"Failed to call llama.cpp server for embedding",
|
||
)
|
||
})?;
|
||
|
||
let status = response.status();
|
||
|
||
if status.is_success() {
|
||
let raw_response = response.text().await.map_err(|e| {
|
||
error!("Error reading response text: {}", e);
|
||
actix_web::error::ErrorInternalServerError("Failed to read response")
|
||
})?;
|
||
|
||
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
||
serde_json::from_str(&raw_response).map_err(|e| {
|
||
error!("Error parsing llama.cpp embedding response: {}", e);
|
||
error!("Raw response: {}", raw_response);
|
||
actix_web::error::ErrorInternalServerError(
|
||
"Failed to parse llama.cpp embedding response",
|
||
)
|
||
})?;
|
||
|
||
if let Some(item) = llama_response.get(0) {
|
||
let flattened_embedding = if !item.embedding.is_empty() {
|
||
item.embedding[0].clone()
|
||
} else {
|
||
vec![]
|
||
};
|
||
|
||
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
||
total_tokens += estimated_tokens;
|
||
|
||
embeddings_data.push(EmbeddingData {
|
||
object: "embedding".to_string(),
|
||
embedding: flattened_embedding,
|
||
index,
|
||
});
|
||
} else {
|
||
error!("No embedding data returned for input: {}", input_text);
|
||
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
|
||
"error": {
|
||
"message": format!("No embedding data returned for input {}", index),
|
||
"type": "server_error"
|
||
}
|
||
})));
|
||
}
|
||
} else {
|
||
let error_text = response
|
||
.text()
|
||
.await
|
||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||
|
||
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||
|
||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||
|
||
return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||
"error": {
|
||
"message": format!("Failed to get embedding for input {}: {}", index, error_text),
|
||
"type": "server_error"
|
||
}
|
||
})));
|
||
}
|
||
}
|
||
|
||
let openai_response = EmbeddingResponse {
|
||
object: "list".to_string(),
|
||
data: embeddings_data,
|
||
model: req_body.model.clone(),
|
||
usage: Usage {
|
||
prompt_tokens: total_tokens,
|
||
total_tokens,
|
||
},
|
||
};
|
||
|
||
Ok(HttpResponse::Ok().json(openai_response))
|
||
}
|
||
|
||
#[actix_web::get("/health")]
|
||
pub async fn health() -> Result<HttpResponse> {
|
||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||
|
||
if is_server_running(&llama_url).await {
|
||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||
"status": "healthy",
|
||
"llama_server": "running"
|
||
})))
|
||
} else {
|
||
Ok(HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||
"status": "unhealthy",
|
||
"llama_server": "not running"
|
||
})))
|
||
}
|
||
}
|
||
|
||
use log::error;
|
||
|
||
use actix_web::{
|
||
web::{self, Bytes},
|
||
HttpResponse, Responder,
|
||
};
|
||
use anyhow::Result;
|
||
use futures::StreamExt;
|
||
use langchain_rust::{
|
||
chain::{Chain, LLMChainBuilder},
|
||
fmt_message, fmt_template,
|
||
language_models::llm::LLM,
|
||
llm::openai::OpenAI,
|
||
message_formatter,
|
||
prompt::HumanMessagePromptTemplate,
|
||
prompt_args,
|
||
schemas::messages::Message,
|
||
template_fstring,
|
||
};
|
||
|
||
use crate::{state::AppState, utils::azure_from_config};
|
||
|
||
#[derive(serde::Deserialize)]
|
||
struct ChatRequest {
|
||
input: String,
|
||
}
|
||
|
||
#[derive(serde::Serialize)]
|
||
struct ChatResponse {
|
||
text: String,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
action: Option<ChatAction>,
|
||
}
|
||
|
||
#[derive(serde::Serialize)]
|
||
#[serde(tag = "type", content = "content")]
|
||
enum ChatAction {
|
||
ReplyEmail { content: String },
|
||
// Add other action variants here as needed
|
||
}
|
||
|
||
#[actix_web::post("/chat")]
|
||
pub async fn chat(
|
||
web::Json(request): web::Json<String>,
|
||
state: web::Data<AppState>,
|
||
) -> Result<impl Responder, actix_web::Error> {
|
||
let azure_config = azure_from_config(&state.config.clone().unwrap().ai);
|
||
let open_ai = OpenAI::new(azure_config);
|
||
|
||
// Parse the context JSON
|
||
let context: serde_json::Value = match serde_json::from_str(&request) {
|
||
Ok(ctx) => ctx,
|
||
Err(_) => serde_json::json!({}),
|
||
};
|
||
|
||
// Check view type and prepare appropriate prompt
|
||
let view_type = context
|
||
.get("viewType")
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or("");
|
||
let (prompt, might_trigger_action) = match view_type {
|
||
"email" => (
|
||
format!(
|
||
"Respond to this email: {}. Keep it professional and concise. \
|
||
If the email requires a response, provide one in the 'replyEmail' action format.",
|
||
request
|
||
),
|
||
true,
|
||
),
|
||
_ => (request, false),
|
||
};
|
||
|
||
let response_text = match open_ai.invoke(&prompt).await {
|
||
Ok(res) => res,
|
||
Err(err) => {
|
||
error!("Error invoking API: {}", err);
|
||
return Err(actix_web::error::ErrorInternalServerError(
|
||
"Failed to invoke OpenAI API",
|
||
));
|
||
}
|
||
};
|
||
|
||
// Prepare response with potential action
|
||
let mut chat_response = ChatResponse {
|
||
text: response_text.clone(),
|
||
action: None,
|
||
};
|
||
|
||
// If in email view and the response looks like an email reply, add action
|
||
if might_trigger_action && view_type == "email" {
|
||
chat_response.action = Some(ChatAction::ReplyEmail {
|
||
content: response_text,
|
||
});
|
||
}
|
||
|
||
Ok(HttpResponse::Ok().json(chat_response))
|
||
}
|
||
|
||
#[actix_web::post("/stream")]
|
||
pub async fn chat_stream(
|
||
web::Json(request): web::Json<ChatRequest>,
|
||
state: web::Data<AppState>,
|
||
) -> Result<impl Responder, actix_web::Error> {
|
||
let azure_config = azure_from_config(&state.config.clone().unwrap().ai);
|
||
let open_ai = OpenAI::new(azure_config);
|
||
|
||
let prompt = message_formatter![
|
||
fmt_message!(Message::new_system_message(
|
||
"You are world class technical documentation writer."
|
||
)),
|
||
fmt_template!(HumanMessagePromptTemplate::new(template_fstring!(
|
||
"{input}", "input"
|
||
)))
|
||
];
|
||
|
||
let chain = LLMChainBuilder::new()
|
||
.prompt(prompt)
|
||
.llm(open_ai)
|
||
.build()
|
||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||
|
||
let mut stream = chain
|
||
.stream(prompt_args! { "input" => request.input })
|
||
.await
|
||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||
|
||
let actix_stream = async_stream::stream! {
|
||
while let Some(result) = stream.next().await {
|
||
match result {
|
||
Ok(value) => yield Ok::<_, actix_web::Error>(Bytes::from(value.content)),
|
||
Err(e) => yield Err(actix_web::error::ErrorInternalServerError(e)),
|
||
}
|
||
}
|
||
};
|
||
|
||
Ok(HttpResponse::Ok()
|
||
.content_type("text/event-stream")
|
||
.streaming(actix_stream))
|
||
}
|
||
|
||
pub mod llm_generic;
|
||
pub mod llm_local;
|
||
pub mod llm_provider;
|
||
|
||
pub use llm_provider::*;
|
||
|
||
use async_trait::async_trait;
|
||
use futures::StreamExt;
|
||
use langchain_rust::{
|
||
language_models::llm::LLM,
|
||
llm::{claude::Claude, openai::OpenAI},
|
||
schemas::Message,
|
||
};
|
||
use serde_json::Value;
|
||
use std::sync::Arc;
|
||
use tokio::sync::mpsc;
|
||
|
||
use crate::tools::ToolManager;
|
||
|
||
#[async_trait]
|
||
pub trait LLMProvider: Send + Sync {
|
||
async fn generate(
|
||
&self,
|
||
prompt: &str,
|
||
config: &Value,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||
|
||
async fn generate_stream(
|
||
&self,
|
||
prompt: &str,
|
||
config: &Value,
|
||
tx: mpsc::Sender<String>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||
|
||
async fn generate_with_tools(
|
||
&self,
|
||
prompt: &str,
|
||
config: &Value,
|
||
available_tools: &[String],
|
||
tool_manager: Arc<ToolManager>,
|
||
session_id: &str,
|
||
user_id: &str,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||
}
|
||
|
||
pub struct OpenAIClient {
|
||
client: OpenAI,
|
||
}
|
||
|
||
impl OpenAIClient {
|
||
pub fn new(client: OpenAI) -> Self {
|
||
Self { client }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl LLMProvider for OpenAIClient {
|
||
async fn generate(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
||
|
||
let result = self
|
||
.client
|
||
.invoke(&messages)
|
||
.await
|
||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||
|
||
Ok(result)
|
||
}
|
||
|
||
async fn generate_stream(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
mut tx: mpsc::Sender<String>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
||
|
||
let mut stream = self
|
||
.client
|
||
.stream(&messages)
|
||
.await
|
||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||
|
||
while let Some(result) = stream.next().await {
|
||
match result {
|
||
Ok(chunk) => {
|
||
let content = chunk.content;
|
||
if !content.is_empty() {
|
||
let _ = tx.send(content.to_string()).await;
|
||
}
|
||
}
|
||
Err(e) => {
|
||
eprintln!("Stream error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
async fn generate_with_tools(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
available_tools: &[String],
|
||
_tool_manager: Arc<ToolManager>,
|
||
_session_id: &str,
|
||
_user_id: &str,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let tools_info = if available_tools.is_empty() {
|
||
String::new()
|
||
} else {
|
||
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
||
};
|
||
|
||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||
|
||
let messages = vec![Message::new_human_message(enhanced_prompt)];
|
||
|
||
let result = self
|
||
.client
|
||
.invoke(&messages)
|
||
.await
|
||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||
|
||
Ok(result)
|
||
}
|
||
}
|
||
|
||
pub struct AnthropicClient {
|
||
client: Claude,
|
||
}
|
||
|
||
impl AnthropicClient {
|
||
pub fn new(api_key: String) -> Self {
|
||
let client = Claude::default().with_api_key(api_key);
|
||
Self { client }
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl LLMProvider for AnthropicClient {
|
||
async fn generate(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
||
|
||
let result = self
|
||
.client
|
||
.invoke(&messages)
|
||
.await
|
||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||
|
||
Ok(result)
|
||
}
|
||
|
||
async fn generate_stream(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
mut tx: mpsc::Sender<String>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
||
|
||
let mut stream = self
|
||
.client
|
||
.stream(&messages)
|
||
.await
|
||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||
|
||
while let Some(result) = stream.next().await {
|
||
match result {
|
||
Ok(chunk) => {
|
||
let content = chunk.content;
|
||
if !content.is_empty() {
|
||
let _ = tx.send(content.to_string()).await;
|
||
}
|
||
}
|
||
Err(e) => {
|
||
eprintln!("Stream error: {}", e);
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
async fn generate_with_tools(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
available_tools: &[String],
|
||
_tool_manager: Arc<ToolManager>,
|
||
_session_id: &str,
|
||
_user_id: &str,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let tools_info = if available_tools.is_empty() {
|
||
String::new()
|
||
} else {
|
||
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
||
};
|
||
|
||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||
|
||
let messages = vec![Message::new_human_message(enhanced_prompt)];
|
||
|
||
let result = self
|
||
.client
|
||
.invoke(&messages)
|
||
.await
|
||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||
|
||
Ok(result)
|
||
}
|
||
}
|
||
|
||
/// Deterministic, network-free `LLMProvider` whose replies echo the prompt.
#[derive(Debug, Default, Clone, Copy)]
pub struct MockLLMProvider;

impl MockLLMProvider {
    /// Creates a mock provider; equivalent to `MockLLMProvider::default()`.
    pub fn new() -> Self {
        Self
    }
}
|
||
|
||
#[async_trait]
|
||
impl LLMProvider for MockLLMProvider {
|
||
async fn generate(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
Ok(format!("Mock response to: {}", prompt))
|
||
}
|
||
|
||
async fn generate_stream(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
mut tx: mpsc::Sender<String>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let response = format!("Mock stream response to: {}", prompt);
|
||
for word in response.split_whitespace() {
|
||
let _ = tx.send(format!("{} ", word)).await;
|
||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
async fn generate_with_tools(
|
||
&self,
|
||
prompt: &str,
|
||
_config: &Value,
|
||
available_tools: &[String],
|
||
_tool_manager: Arc<ToolManager>,
|
||
_session_id: &str,
|
||
_user_id: &str,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let tools_list = if available_tools.is_empty() {
|
||
"no tools available".to_string()
|
||
} else {
|
||
available_tools.join(", ")
|
||
};
|
||
Ok(format!(
|
||
"Mock response with tools [{}] to: {}",
|
||
tools_list, prompt
|
||
))
|
||
}
|
||
}
|
||
|
||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||
use dotenv::dotenv;
|
||
use log::{error, info};
|
||
use regex::Regex;
|
||
use reqwest::Client;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::env;
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct ChatMessage {
|
||
role: String,
|
||
content: String,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct ChatCompletionRequest {
|
||
model: String,
|
||
messages: Vec<ChatMessage>,
|
||
stream: Option<bool>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct ChatCompletionResponse {
|
||
id: String,
|
||
object: String,
|
||
created: u64,
|
||
model: String,
|
||
choices: Vec<Choice>,
|
||
}
|
||
|
||
#[derive(Debug, Serialize, Deserialize)]
|
||
struct Choice {
|
||
message: ChatMessage,
|
||
finish_reason: String,
|
||
}
|
||
|
||
fn clean_request_body(body: &str) -> String {
|
||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||
re.replace_all(body, "").to_string()
|
||
}
|
||
|
||
#[post("/v1/chat/completions")]
|
||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
||
info!("Original POST Data: {}", body_str);
|
||
|
||
dotenv().ok();
|
||
|
||
let api_key = env::var("AI_KEY")
|
||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||
let model = env::var("AI_LLM_MODEL")
|
||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
||
let endpoint = env::var("AI_ENDPOINT")
|
||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||
|
||
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
||
|
||
if let Some(obj) = json_value.as_object_mut() {
|
||
obj.insert("model".to_string(), serde_json::Value::String(model));
|
||
}
|
||
|
||
let modified_body_str = serde_json::to_string(&json_value)
|
||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?;
|
||
|
||
info!("Modified POST Data: {}", modified_body_str);
|
||
|
||
let mut headers = reqwest::header::HeaderMap::new();
|
||
headers.insert(
|
||
"Authorization",
|
||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?,
|
||
);
|
||
headers.insert(
|
||
"Content-Type",
|
||
reqwest::header::HeaderValue::from_static("application/json"),
|
||
);
|
||
|
||
let client = Client::new();
|
||
let response = client
|
||
.post(&endpoint)
|
||
.headers(headers)
|
||
.body(modified_body_str)
|
||
.send()
|
||
.await
|
||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||
|
||
let status = response.status();
|
||
let raw_response = response
|
||
.text()
|
||
.await
|
||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||
|
||
info!("Provider response status: {}", status);
|
||
info!("Provider response body: {}", raw_response);
|
||
|
||
if status.is_success() {
|
||
match convert_to_openai_format(&raw_response) {
|
||
Ok(openai_response) => Ok(HttpResponse::Ok()
|
||
.content_type("application/json")
|
||
.body(openai_response)),
|
||
Err(e) => {
|
||
error!("Failed to convert response format: {}", e);
|
||
Ok(HttpResponse::Ok()
|
||
.content_type("application/json")
|
||
.body(raw_response))
|
||
}
|
||
}
|
||
} else {
|
||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||
|
||
Ok(HttpResponse::build(actix_status)
|
||
.content_type("application/json")
|
||
.body(raw_response))
|
||
}
|
||
}
|
||
|
||
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||
#[derive(serde::Deserialize)]
|
||
struct ProviderChoice {
|
||
message: ProviderMessage,
|
||
#[serde(default)]
|
||
finish_reason: Option<String>,
|
||
}
|
||
|
||
#[derive(serde::Deserialize)]
|
||
struct ProviderMessage {
|
||
role: Option<String>,
|
||
content: String,
|
||
}
|
||
|
||
#[derive(serde::Deserialize)]
|
||
struct ProviderResponse {
|
||
id: Option<String>,
|
||
object: Option<String>,
|
||
created: Option<u64>,
|
||
model: Option<String>,
|
||
choices: Vec<ProviderChoice>,
|
||
usage: Option<ProviderUsage>,
|
||
}
|
||
|
||
#[derive(serde::Deserialize, Default)]
|
||
struct ProviderUsage {
|
||
prompt_tokens: Option<u32>,
|
||
completion_tokens: Option<u32>,
|
||
total_tokens: Option<u32>,
|
||
}
|
||
|
||
#[derive(serde::Serialize)]
|
||
struct OpenAIResponse {
|
||
id: String,
|
||
object: String,
|
||
created: u64,
|
||
model: String,
|
||
choices: Vec<OpenAIChoice>,
|
||
usage: OpenAIUsage,
|
||
}
|
||
|
||
#[derive(serde::Serialize)]
|
||
struct OpenAIChoice {
|
||
index: u32,
|
||
message: OpenAIMessage,
|
||
finish_reason: String,
|
||
}
|
||
|
||
#[derive(serde::Serialize)]
|
||
struct OpenAIMessage {
|
||
role: String,
|
||
content: String,
|
||
}
|
||
|
||
#[derive(serde::Serialize)]
|
||
struct OpenAIUsage {
|
||
prompt_tokens: u32,
|
||
completion_tokens: u32,
|
||
total_tokens: u32,
|
||
}
|
||
|
||
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||
|
||
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||
let content = first_choice.message.content.clone();
|
||
let role = first_choice
|
||
.message
|
||
.role
|
||
.clone()
|
||
.unwrap_or_else(|| "assistant".to_string());
|
||
|
||
let usage = provider.usage.unwrap_or_default();
|
||
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||
let completion_tokens = usage
|
||
.completion_tokens
|
||
.unwrap_or_else(|| content.split_whitespace().count() as u32);
|
||
let total_tokens = usage
|
||
.total_tokens
|
||
.unwrap_or(prompt_tokens + completion_tokens);
|
||
|
||
let openai_response = OpenAIResponse {
|
||
id: provider
|
||
.id
|
||
.unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())),
|
||
object: provider
|
||
.object
|
||
.unwrap_or_else(|| "chat.completion".to_string()),
|
||
created: provider.created.unwrap_or_else(|| {
|
||
std::time::SystemTime::now()
|
||
.duration_since(std::time::UNIX_EPOCH)
|
||
.unwrap()
|
||
.as_secs()
|
||
}),
|
||
model: provider.model.unwrap_or_else(|| "llama".to_string()),
|
||
choices: vec![OpenAIChoice {
|
||
index: 0,
|
||
message: OpenAIMessage { role, content },
|
||
finish_reason: first_choice
|
||
.finish_reason
|
||
.clone()
|
||
.unwrap_or_else(|| "stop".to_string()),
|
||
}],
|
||
usage: OpenAIUsage {
|
||
prompt_tokens,
|
||
completion_tokens,
|
||
total_tokens,
|
||
},
|
||
};
|
||
|
||
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||
}
|
||
|
||
use async_trait::async_trait;
|
||
use log::info;
|
||
use reqwest::Client;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use tokio::sync::Mutex;
|
||
|
||
use crate::shared::BotResponse;
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppMessage {
|
||
pub entry: Vec<WhatsAppEntry>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppEntry {
|
||
pub changes: Vec<WhatsAppChange>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppChange {
|
||
pub value: WhatsAppValue,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppValue {
|
||
pub contacts: Option<Vec<WhatsAppContact>>,
|
||
pub messages: Option<Vec<WhatsAppMessageData>>,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppContact {
|
||
pub profile: WhatsAppProfile,
|
||
pub wa_id: String,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppProfile {
|
||
pub name: String,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppMessageData {
|
||
pub from: String,
|
||
pub id: String,
|
||
pub timestamp: String,
|
||
pub text: Option<WhatsAppText>,
|
||
pub r#type: String,
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
pub struct WhatsAppText {
|
||
pub body: String,
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
pub struct WhatsAppResponse {
|
||
pub messaging_product: String,
|
||
pub to: String,
|
||
pub text: WhatsAppResponseText,
|
||
}
|
||
|
||
#[derive(Serialize)]
|
||
pub struct WhatsAppResponseText {
|
||
pub body: String,
|
||
}
|
||
|
||
pub struct WhatsAppAdapter {
|
||
client: Client,
|
||
access_token: String,
|
||
phone_number_id: String,
|
||
webhook_verify_token: String,
|
||
sessions: Arc<Mutex<HashMap<String, String>>>,
|
||
}
|
||
|
||
impl WhatsAppAdapter {
|
||
pub fn new(access_token: String, phone_number_id: String, webhook_verify_token: String) -> Self {
|
||
Self {
|
||
client: Client::new(),
|
||
access_token,
|
||
phone_number_id,
|
||
webhook_verify_token,
|
||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||
}
|
||
}
|
||
|
||
pub async fn get_session_id(&self, phone: &str) -> String {
|
||
let sessions = self.sessions.lock().await;
|
||
sessions.get(phone).cloned().unwrap_or_else(|| {
|
||
drop(sessions);
|
||
let session_id = uuid::Uuid::new_v4().to_string();
|
||
let mut sessions = self.sessions.lock().await;
|
||
sessions.insert(phone.to_string(), session_id.clone());
|
||
session_id
|
||
})
|
||
}
|
||
|
||
pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||
let url = format!(
|
||
"https://graph.facebook.com/v17.0/{}/messages",
|
||
self.phone_number_id
|
||
);
|
||
|
||
let response_data = WhatsAppResponse {
|
||
messaging_product: "whatsapp".to_string(),
|
||
to: to.to_string(),
|
||
text: WhatsAppResponseText {
|
||
body: body.to_string(),
|
||
},
|
||
};
|
||
|
||
let response = self.client
|
||
.post(&url)
|
||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||
.json(&response_data)
|
||
.send()
|
||
.await?;
|
||
|
||
if response.status().is_success() {
|
||
info!("WhatsApp message sent to {}", to);
|
||
} else {
|
||
let error_text = response.text().await?;
|
||
log::error!("Failed to send WhatsApp message: {}", error_text);
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error>> {
|
||
let mut user_messages = Vec::new();
|
||
|
||
for entry in message.entry {
|
||
for change in entry.changes {
|
||
if let Some(messages) = change.value.messages {
|
||
for msg in messages {
|
||
if let Some(text) = msg.text {
|
||
let session_id = self.get_session_id(&msg.from).await;
|
||
|
||
let user_message = crate::shared::UserMessage {
|
||
bot_id: "default_bot".to_string(),
|
||
user_id: msg.from.clone(),
|
||
session_id: session_id.clone(),
|
||
channel: "whatsapp".to_string(),
|
||
content: text.body,
|
||
message_type: msg.r#type,
|
||
media_url: None,
|
||
timestamp: chrono::Utc::now(),
|
||
};
|
||
|
||
user_messages.push(user_message);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(user_messages)
|
||
}
|
||
|
||
pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||
if mode == "subscribe" && token == self.webhook_verify_token {
|
||
Ok(challenge.to_string())
|
||
} else {
|
||
Err("Invalid verification".into())
|
||
}
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl crate::channels::ChannelAdapter for WhatsAppAdapter {
|
||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
||
info!("Sending WhatsApp response to: {}", response.user_id);
|
||
self.send_whatsapp_message(&response.user_id, &response.content).await
|
||
}
|
||
}
|
||
|
||
use std::env;
|
||
|
||
/// Complete application configuration, normally built by `AppConfig::from_env`.
#[derive(Clone)]
pub struct AppConfig {
    /// Object-storage settings (`DRIVE_*` variables).
    pub minio: MinioConfig,
    /// HTTP listener address (`SERVER_HOST` / `SERVER_PORT`).
    pub server: ServerConfig,
    /// Primary PostgreSQL database (`TABLES_*` variables).
    pub database: DatabaseConfig,
    /// Secondary PostgreSQL database (`CUSTOM_*` variables).
    pub database_custom: DatabaseConfig,
    /// Outgoing mail settings (`EMAIL_*` variables).
    pub email: EmailConfig,
    /// LLM service settings (`AI_*` variables).
    pub ai: AIConfig,
    /// Sites root directory (`SITES_ROOT`).
    pub site_path: String,
}

/// Connection parameters for one PostgreSQL database.
#[derive(Clone)]
pub struct DatabaseConfig {
    pub username: String,
    pub password: String,
    pub server: String,
    pub port: u32,
    pub database: String,
}

/// Object-storage connection settings.
#[derive(Clone)]
pub struct MinioConfig {
    pub server: String,
    pub access_key: String,
    pub secret_key: String,
    pub use_ssl: bool,
    /// Filled from `DRIVE_ORG_PREFIX`.
    pub bucket: String,
}

/// Address the HTTP server binds to.
#[derive(Clone)]
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
}

/// SMTP settings for outgoing mail.
#[derive(Clone)]
pub struct EmailConfig {
    pub from: String,
    pub server: String,
    pub port: u16,
    pub username: String,
    pub password: String,
}

/// LLM service settings.
///
/// `azure_from_config` maps these fields as follows: `instance` is the
/// deployment id, `version` the API version, `endpoint` the API base.
#[derive(Clone)]
pub struct AIConfig {
    pub instance: String,
    pub key: String,
    pub version: String,
    pub endpoint: String,
}

impl AppConfig {
    /// `postgres://` URL for the primary database, with percent-encoded credentials.
    pub fn database_url(&self) -> String {
        postgres_url(&self.database)
    }

    /// `postgres://` URL for the custom database, with percent-encoded credentials.
    pub fn database_custom_url(&self) -> String {
        postgres_url(&self.database_custom)
    }

    /// Builds the configuration from environment variables.
    ///
    /// These variables fall back to defaults when missing or unparsable:
    /// - `TABLES_*` and `CUSTOM_*`: user `user`, password `pass`, host
    ///   `localhost`, port 5432, database `db`;
    /// - `SERVER_HOST` / `SERVER_PORT`: `127.0.0.1`:8080;
    /// - `DRIVE_USE_SSL`: `false`;
    /// - `DRIVE_ORG_PREFIX`: empty.
    ///
    /// All other drive, email and AI variables, plus `SITES_ROOT`, are required.
    ///
    /// # Panics
    /// Panics with a message naming the variable if a required variable is
    /// missing, or if `EMAIL_PORT` is not a valid port number.
    pub fn from_env() -> Self {
        let database = DatabaseConfig {
            username: env::var("TABLES_USERNAME").unwrap_or_else(|_| "user".to_string()),
            password: env::var("TABLES_PASSWORD").unwrap_or_else(|_| "pass".to_string()),
            server: env::var("TABLES_SERVER").unwrap_or_else(|_| "localhost".to_string()),
            port: env::var("TABLES_PORT")
                .ok()
                .and_then(|p| p.parse().ok())
                .unwrap_or(5432),
            database: env::var("TABLES_DATABASE").unwrap_or_else(|_| "db".to_string()),
        };

        let database_custom = DatabaseConfig {
            username: env::var("CUSTOM_USERNAME").unwrap_or_else(|_| "user".to_string()),
            password: env::var("CUSTOM_PASSWORD").unwrap_or_else(|_| "pass".to_string()),
            server: env::var("CUSTOM_SERVER").unwrap_or_else(|_| "localhost".to_string()),
            port: env::var("CUSTOM_PORT")
                .ok()
                .and_then(|p| p.parse().ok())
                .unwrap_or(5432),
            database: env::var("CUSTOM_DATABASE").unwrap_or_else(|_| "db".to_string()),
        };

        let minio = MinioConfig {
            server: env::var("DRIVE_SERVER").expect("DRIVE_SERVER not set"),
            access_key: env::var("DRIVE_ACCESSKEY").expect("DRIVE_ACCESSKEY not set"),
            secret_key: env::var("DRIVE_SECRET").expect("DRIVE_SECRET not set"),
            use_ssl: env::var("DRIVE_USE_SSL")
                .unwrap_or_else(|_| "false".to_string())
                .parse()
                .unwrap_or(false),
            bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "".to_string()),
        };

        let email = EmailConfig {
            from: env::var("EMAIL_FROM").expect("EMAIL_FROM not set"),
            server: env::var("EMAIL_SERVER").expect("EMAIL_SERVER not set"),
            port: env::var("EMAIL_PORT")
                .expect("EMAIL_PORT not set")
                .parse()
                .expect("EMAIL_PORT must be a number"),
            username: env::var("EMAIL_USER").expect("EMAIL_USER not set"),
            password: env::var("EMAIL_PASS").expect("EMAIL_PASS not set"),
        };

        let ai = AIConfig {
            instance: env::var("AI_INSTANCE").expect("AI_INSTANCE not set"),
            key: env::var("AI_KEY").expect("AI_KEY not set"),
            version: env::var("AI_VERSION").expect("AI_VERSION not set"),
            endpoint: env::var("AI_ENDPOINT").expect("AI_ENDPOINT not set"),
        };

        AppConfig {
            minio,
            server: ServerConfig {
                host: env::var("SERVER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
                port: env::var("SERVER_PORT")
                    .ok()
                    .and_then(|p| p.parse().ok())
                    .unwrap_or(8080),
            },
            database,
            database_custom,
            email,
            ai,
            site_path: env::var("SITES_ROOT").expect("SITES_ROOT not set"),
        }
    }
}

/// Formats a `postgres://user:password@host:port/database` URL for `db`.
///
/// The user name and password are percent-encoded, so characters such as `@`,
/// `:` or `/` in credentials cannot corrupt the URL.
fn postgres_url(db: &DatabaseConfig) -> String {
    format!(
        "postgres://{}:{}@{}:{}/{}",
        encode_userinfo(&db.username),
        encode_userinfo(&db.password),
        db.server,
        db.port,
        db.database
    )
}

/// Percent-encodes every byte outside the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - . _ ~`).
fn encode_userinfo(raw: &str) -> String {
    let mut encoded = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push_str(&format!("%{:02X}", byte));
        }
    }
    encoded
}
|
||
use argon2::{
|
||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||
Argon2,
|
||
};
|
||
use redis::Client;
|
||
use sqlx::{PgPool, Row}; // <-- required for .get()
|
||
use std::sync::Arc;
|
||
use uuid::Uuid;
|
||
|
||
pub struct AuthService {
|
||
pub pool: PgPool,
|
||
pub redis: Option<Arc<Client>>,
|
||
}
|
||
|
||
impl AuthService {
|
||
#[allow(clippy::new_without_default)]
|
||
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
||
Self { pool, redis }
|
||
}
|
||
|
||
pub async fn verify_user(
|
||
&self,
|
||
username: &str,
|
||
password: &str,
|
||
) -> Result<Option<Uuid>, Box<dyn std::error::Error>> {
|
||
let user = sqlx::query(
|
||
"SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true",
|
||
)
|
||
.bind(username)
|
||
.fetch_optional(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(row) = user {
|
||
let user_id: Uuid = row.get("id");
|
||
let password_hash: String = row.get("password_hash");
|
||
|
||
if let Ok(parsed_hash) = PasswordHash::new(&password_hash) {
|
||
if Argon2::default()
|
||
.verify_password(password.as_bytes(), &parsed_hash)
|
||
.is_ok()
|
||
{
|
||
return Ok(Some(user_id));
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(None)
|
||
}
|
||
|
||
pub async fn create_user(
|
||
&self,
|
||
username: &str,
|
||
email: &str,
|
||
password: &str,
|
||
) -> Result<Uuid, Box<dyn std::error::Error>> {
|
||
let salt = SaltString::generate(&mut OsRng);
|
||
let argon2 = Argon2::default();
|
||
let password_hash = match argon2.hash_password(password.as_bytes(), &salt) {
|
||
Ok(ph) => ph.to_string(),
|
||
Err(e) => {
|
||
return Err(Box::new(std::io::Error::new(
|
||
std::io::ErrorKind::Other,
|
||
e.to_string(),
|
||
)))
|
||
}
|
||
};
|
||
|
||
let row = sqlx::query(
|
||
"INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id",
|
||
)
|
||
.bind(username)
|
||
.bind(email)
|
||
.bind(&password_hash)
|
||
.fetch_one(&self.pool)
|
||
.await?;
|
||
|
||
Ok(row.get::<Uuid, _>("id"))
|
||
}
|
||
|
||
pub async fn delete_user_cache(
|
||
&self,
|
||
username: &str,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("auth:user:{}", username);
|
||
|
||
let _: () = redis::Cmd::del(&cache_key).query_async(&mut conn).await?;
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn update_user_password(
|
||
&self,
|
||
user_id: Uuid,
|
||
new_password: &str,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
let salt = SaltString::generate(&mut OsRng);
|
||
let argon2 = Argon2::default();
|
||
let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) {
|
||
Ok(ph) => ph.to_string(),
|
||
Err(e) => {
|
||
return Err(Box::new(std::io::Error::new(
|
||
std::io::ErrorKind::Other,
|
||
e.to_string(),
|
||
)))
|
||
}
|
||
};
|
||
|
||
sqlx::query("UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2")
|
||
.bind(&password_hash)
|
||
.bind(user_id)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(user_row) = sqlx::query("SELECT username FROM users WHERE id = $1")
|
||
.bind(user_id)
|
||
.fetch_optional(&self.pool)
|
||
.await?
|
||
{
|
||
let username: String = user_row.get("username");
|
||
self.delete_user_cache(&username).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl Clone for AuthService {
|
||
fn clone(&self) -> Self {
|
||
Self {
|
||
pool: self.pool.clone(),
|
||
redis: self.redis.clone(),
|
||
}
|
||
}
|
||
}
|
||
|
||
use std::sync::Arc;
|
||
|
||
use crate::{
|
||
bot::BotOrchestrator,
|
||
channels::{VoiceAdapter, WebChannelAdapter, WhatsAppAdapter},
|
||
config::AppConfig,
|
||
tools::ToolApi,
|
||
web_automation::BrowserPool
|
||
};
|
||
|
||
#[derive(Clone)]
|
||
pub struct AppState {
|
||
pub minio_client: Option<minio::s3::Client>,
|
||
pub config: Option<AppConfig>,
|
||
pub db: Option<sqlx::PgPool>,
|
||
pub db_custom: Option<sqlx::PgPool>,
|
||
pub browser_pool: Arc<BrowserPool>,
|
||
pub orchestrator: Arc<BotOrchestrator>,
|
||
pub web_adapter: Arc<WebChannelAdapter>,
|
||
pub voice_adapter: Arc<VoiceAdapter>,
|
||
pub whatsapp_adapter: Arc<WhatsAppAdapter>,
|
||
pub tool_api: Arc<ToolApi>,
|
||
}
|
||
|
||
/// Per-bot settings: the bot's language and its working folder.
pub struct BotState {
    /// Language setting of the bot.
    pub language: String,
    /// Path of the folder the bot works in.
    pub work_folder: String,
}
|
||
|
||
use chrono::{DateTime, Utc};
|
||
use langchain_rust::llm::AzureConfig;
|
||
use log::{debug, warn};
|
||
use rhai::{Array, Dynamic};
|
||
use serde_json::{json, Value};
|
||
use smartstring::SmartString;
|
||
use sqlx::{postgres::PgRow, Column, Decode, Row, Type, TypeInfo};
|
||
use std::error::Error;
|
||
use std::fs::File;
|
||
use std::io::BufReader;
|
||
use std::path::Path;
|
||
use tokio::fs::File as TokioFile;
|
||
use tokio_stream::StreamExt;
|
||
use zip::ZipArchive;
|
||
|
||
use crate::config::AIConfig;
|
||
use reqwest::Client;
|
||
use tokio::io::AsyncWriteExt;
|
||
|
||
pub fn azure_from_config(config: &AIConfig) -> AzureConfig {
|
||
AzureConfig::new()
|
||
.with_api_base(&config.endpoint)
|
||
.with_api_key(&config.key)
|
||
.with_api_version(&config.version)
|
||
.with_deployment_id(&config.instance)
|
||
}
|
||
|
||
pub async fn call_llm(
|
||
text: &str,
|
||
ai_config: &AIConfig,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let azure_config = azure_from_config(&ai_config.clone());
|
||
let open_ai = langchain_rust::llm::OpenAI::new(azure_config);
|
||
|
||
let prompt = text.to_string();
|
||
|
||
match open_ai.invoke(&prompt).await {
|
||
Ok(response_text) => Ok(response_text),
|
||
Err(err) => {
|
||
log::error!("Error invoking LLM API: {}", err);
|
||
Err(Box::new(std::io::Error::new(
|
||
std::io::ErrorKind::Other,
|
||
"Failed to invoke LLM API",
|
||
)))
|
||
}
|
||
}
|
||
}
|
||
|
||
pub fn extract_zip_recursive(
|
||
zip_path: &Path,
|
||
destination_path: &Path,
|
||
) -> Result<(), Box<dyn std::error::Error>> {
|
||
let file = File::open(zip_path)?;
|
||
let buf_reader = BufReader::new(file);
|
||
let mut archive = ZipArchive::new(buf_reader)?;
|
||
|
||
for i in 0..archive.len() {
|
||
let mut file = archive.by_index(i)?;
|
||
let outpath = destination_path.join(file.mangled_name());
|
||
|
||
if file.is_dir() {
|
||
std::fs::create_dir_all(&outpath)?;
|
||
} else {
|
||
if let Some(parent) = outpath.parent() {
|
||
if !parent.exists() {
|
||
std::fs::create_dir_all(&parent)?;
|
||
}
|
||
}
|
||
let mut outfile = File::create(&outpath)?;
|
||
std::io::copy(&mut file, &mut outfile)?;
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub fn row_to_json(row: PgRow) -> Result<Value, Box<dyn Error>> {
|
||
let mut result = serde_json::Map::new();
|
||
let columns = row.columns();
|
||
debug!("Converting row with {} columns", columns.len());
|
||
|
||
for (i, column) in columns.iter().enumerate() {
|
||
let column_name = column.name();
|
||
let type_name = column.type_info().name();
|
||
|
||
let value = match type_name {
|
||
"INT4" | "int4" => handle_nullable_type::<i32>(&row, i, column_name),
|
||
"INT8" | "int8" => handle_nullable_type::<i64>(&row, i, column_name),
|
||
"FLOAT4" | "float4" => handle_nullable_type::<f32>(&row, i, column_name),
|
||
"FLOAT8" | "float8" => handle_nullable_type::<f64>(&row, i, column_name),
|
||
"TEXT" | "VARCHAR" | "text" | "varchar" => {
|
||
handle_nullable_type::<String>(&row, i, column_name)
|
||
}
|
||
"BOOL" | "bool" => handle_nullable_type::<bool>(&row, i, column_name),
|
||
"JSON" | "JSONB" | "json" | "jsonb" => handle_json(&row, i, column_name),
|
||
_ => {
|
||
warn!("Unknown type {} for column {}", type_name, column_name);
|
||
handle_nullable_type::<String>(&row, i, column_name)
|
||
}
|
||
};
|
||
|
||
result.insert(column_name.to_string(), value);
|
||
}
|
||
|
||
Ok(Value::Object(result))
|
||
}
|
||
|
||
fn handle_nullable_type<'r, T>(row: &'r PgRow, idx: usize, col_name: &str) -> Value
|
||
where
|
||
T: Type<sqlx::Postgres> + Decode<'r, sqlx::Postgres> + serde::Serialize + std::fmt::Debug,
|
||
{
|
||
match row.try_get::<Option<T>, _>(idx) {
|
||
Ok(Some(val)) => {
|
||
debug!("Successfully read column {} as {:?}", col_name, val);
|
||
json!(val)
|
||
}
|
||
Ok(None) => {
|
||
debug!("Column {} is NULL", col_name);
|
||
Value::Null
|
||
}
|
||
Err(e) => {
|
||
warn!("Failed to read column {}: {}", col_name, e);
|
||
Value::Null
|
||
}
|
||
}
|
||
}
|
||
|
||
fn handle_json(row: &PgRow, idx: usize, col_name: &str) -> Value {
|
||
match row.try_get::<Option<Value>, _>(idx) {
|
||
Ok(Some(val)) => {
|
||
debug!("Successfully read JSON column {} as Value", col_name);
|
||
return val;
|
||
}
|
||
Ok(None) => return Value::Null,
|
||
Err(_) => (),
|
||
}
|
||
|
||
match row.try_get::<Option<String>, _>(idx) {
|
||
Ok(Some(s)) => match serde_json::from_str(&s) {
|
||
Ok(val) => val,
|
||
Err(_) => {
|
||
debug!("Column {} contains string that's not JSON", col_name);
|
||
json!(s)
|
||
}
|
||
},
|
||
Ok(None) => Value::Null,
|
||
Err(e) => {
|
||
warn!("Failed to read JSON column {}: {}", col_name, e);
|
||
Value::Null
|
||
}
|
||
}
|
||
}
|
||
|
||
pub fn json_value_to_dynamic(value: &Value) -> Dynamic {
|
||
match value {
|
||
Value::Null => Dynamic::UNIT,
|
||
Value::Bool(b) => Dynamic::from(*b),
|
||
Value::Number(n) => {
|
||
if let Some(i) = n.as_i64() {
|
||
Dynamic::from(i)
|
||
} else if let Some(f) = n.as_f64() {
|
||
Dynamic::from(f)
|
||
} else {
|
||
Dynamic::UNIT
|
||
}
|
||
}
|
||
Value::String(s) => Dynamic::from(s.clone()),
|
||
Value::Array(arr) => Dynamic::from(
|
||
arr.iter()
|
||
.map(json_value_to_dynamic)
|
||
.collect::<rhai::Array>(),
|
||
),
|
||
Value::Object(obj) => Dynamic::from(
|
||
obj.iter()
|
||
.map(|(k, v)| (SmartString::from(k), json_value_to_dynamic(v)))
|
||
.collect::<rhai::Map>(),
|
||
),
|
||
}
|
||
}
|
||
|
||
pub fn to_array(value: Dynamic) -> Array {
|
||
if value.is_array() {
|
||
value.cast::<Array>()
|
||
} else if value.is_unit() || value.is::<()>() {
|
||
Array::new()
|
||
} else {
|
||
Array::from([value])
|
||
}
|
||
}
|
||
|
||
pub async fn download_file(url: &str, output_path: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||
let client = Client::new();
|
||
let response = client.get(url).send().await?;
|
||
|
||
if response.status().is_success() {
|
||
let mut file = TokioFile::create(output_path).await?;
|
||
|
||
let mut stream = response.bytes_stream();
|
||
|
||
while let Some(chunk) = stream.next().await {
|
||
file.write_all(&chunk?).await?;
|
||
}
|
||
debug!("File downloaded successfully to {}", output_path);
|
||
} else {
|
||
return Err("Failed to download file".into());
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
/// Parses a single `KEY=VALUE` filter into a parameterized SQL condition.
///
/// Returns `("KEY = $1", vec![VALUE])`. Key and value are trimmed. The input is
/// split at the first `=`, so the value itself may contain `=`.
///
/// The key must be a non-empty run of ASCII letters, digits and underscores.
/// This keeps it safe to interpolate as a column name.
///
/// # Errors
/// Fails when there is no `=`, or when the key is empty or contains other characters.
pub fn parse_filter(filter_str: &str) -> Result<(String, Vec<String>), Box<dyn Error>> {
    let (column, value) = filter_str
        .split_once('=')
        .ok_or("Invalid filter format. Expected 'KEY=VALUE'")?;

    let column = column.trim();
    let value = value.trim();

    // An empty key would otherwise pass the character check and yield " = $1".
    if column.is_empty()
        || !column
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_')
    {
        return Err("Invalid column name in filter".into());
    }

    Ok((format!("{} = $1", column), vec![value.to_string()]))
}
|
||
|
||
/// Parses `KEY=VALUE` conditions joined by `&` into an `AND`-ed, parameterized
/// SQL clause whose placeholders start at `$offset+1`.
///
/// Example: `"a=1&b=2"` with offset 2 gives `("a = $3 AND b = $4", ["1", "2"])`.
///
/// Each condition is split at its first `=`, so values may contain `=` (but not
/// `&`). Keys and values are trimmed. Keys must be non-empty and consist only of
/// ASCII letters, digits and underscores.
///
/// # Errors
/// Fails if any condition (including an empty one) lacks `=`, or has an invalid key.
pub fn parse_filter_with_offset(
    filter_str: &str,
    offset: usize,
) -> Result<(String, Vec<String>), Box<dyn Error>> {
    let mut clauses = Vec::new();
    let mut params = Vec::new();

    for condition in filter_str.split('&') {
        let (column, value) = condition.split_once('=').ok_or("Invalid filter format")?;
        let column = column.trim();

        // An empty key would otherwise pass the character check and yield " = $N".
        if column.is_empty()
            || !column
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_')
        {
            return Err("Invalid column name".into());
        }

        params.push(value.trim().to_string());
        clauses.push(format!("{} = ${}", column, offset + params.len()));
    }

    Ok((clauses.join(" AND "), params))
}
|
||
|
||
use chrono::{DateTime, Utc};
|
||
use serde::{Deserialize, Serialize};
|
||
use sqlx::FromRow;
|
||
use uuid::Uuid;
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||
pub struct Organization {
|
||
pub org_id: Uuid,
|
||
pub name: String,
|
||
pub slug: String,
|
||
pub created_at: DateTime<Utc>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||
pub struct Bot {
|
||
pub bot_id: Uuid,
|
||
pub name: String,
|
||
pub status: BotStatus,
|
||
pub config: serde_json::Value,
|
||
pub created_at: DateTime<Utc>,
|
||
pub updated_at: DateTime<Utc>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)]
|
||
#[serde(rename_all = "snake_case")]
|
||
#[sqlx(type_name = "bot_status", rename_all = "snake_case")]
|
||
pub enum BotStatus {
|
||
Active,
|
||
Inactive,
|
||
Maintenance,
|
||
}
|
||
|
||
/// Kind of automation trigger, with a stable `i32` discriminant.
///
/// The discriminants (0..=3) are what `from_i32` and `TryFrom<i32>` accept.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriggerKind {
    Scheduled = 0,
    TableUpdate = 1,
    TableInsert = 2,
    TableDelete = 3,
}

impl TriggerKind {
    /// Maps a stored discriminant back to a `TriggerKind`.
    ///
    /// Returns `None` for any value outside 0..=3.
    pub fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Scheduled),
            1 => Some(Self::TableUpdate),
            2 => Some(Self::TableInsert),
            3 => Some(Self::TableDelete),
            _ => None,
        }
    }
}

/// Standard fallible conversion; the error carries the rejected value.
impl TryFrom<i32> for TriggerKind {
    type Error = i32;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        Self::from_i32(value).ok_or(value)
    }
}
|
||
|
||
#[derive(Debug, FromRow, Serialize, Deserialize)]
|
||
pub struct Automation {
|
||
pub id: Uuid,
|
||
pub kind: i32,
|
||
pub target: Option<String>,
|
||
pub schedule: Option<String>,
|
||
pub param: String,
|
||
pub is_active: bool,
|
||
pub last_triggered: Option<DateTime<Utc>>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||
pub struct UserSession {
|
||
pub id: Uuid,
|
||
pub user_id: Uuid,
|
||
pub bot_id: Uuid,
|
||
pub title: String,
|
||
pub context_data: serde_json::Value,
|
||
pub answer_mode: String,
|
||
pub current_tool: Option<String>,
|
||
pub created_at: DateTime<Utc>,
|
||
pub updated_at: DateTime<Utc>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct EmbeddingRequest {
|
||
pub text: String,
|
||
pub model: Option<String>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct EmbeddingResponse {
|
||
pub embedding: Vec<f32>,
|
||
pub model: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct SearchResult {
|
||
pub text: String,
|
||
pub similarity: f32,
|
||
pub metadata: serde_json::Value,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct UserMessage {
|
||
pub bot_id: String,
|
||
pub user_id: String,
|
||
pub session_id: String,
|
||
pub channel: String,
|
||
pub content: String,
|
||
pub message_type: String,
|
||
pub media_url: Option<String>,
|
||
pub timestamp: DateTime<Utc>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct BotResponse {
|
||
pub bot_id: String,
|
||
pub user_id: String,
|
||
pub session_id: String,
|
||
pub channel: String,
|
||
pub content: String,
|
||
pub message_type: String,
|
||
pub stream_token: Option<String>,
|
||
pub is_complete: bool,
|
||
}
|
||
|
||
pub mod models;
|
||
pub mod state;
|
||
pub mod utils;
|
||
|
||
pub use models::*;
|
||
pub use state::*;
|
||
pub use utils::*;
|
||
|
||
use actix_web::{web, HttpRequest, HttpResponse, Result};
|
||
use actix_ws::Message as WsMessage;
|
||
use chrono::Utc;
|
||
use langchain_rust::{
|
||
chain::{Chain, LLMChain},
|
||
llm::openai::OpenAI,
|
||
memory::SimpleMemory,
|
||
prompt_args,
|
||
tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder},
|
||
vectorstore::qdrant::Qdrant as LangChainQdrant,
|
||
vectorstore::{VecStoreOptions, VectorStore},
|
||
};
|
||
use log::info;
|
||
use serde_json;
|
||
use std::collections::HashMap;
|
||
use std::fs;
|
||
use std::sync::Arc;
|
||
use tokio::sync::{mpsc, Mutex};
|
||
use uuid::Uuid;
|
||
|
||
use crate::{
|
||
auth::AuthService,
|
||
channels::ChannelAdapter,
|
||
llm::LLMProvider,
|
||
session::SessionManager,
|
||
shared::{BotResponse, UserMessage, UserSession},
|
||
tools::ToolManager,
|
||
};
|
||
|
||
pub struct BotOrchestrator {
|
||
session_manager: SessionManager,
|
||
tool_manager: ToolManager,
|
||
llm_provider: Arc<dyn LLMProvider>,
|
||
auth_service: AuthService,
|
||
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||
vector_store: Option<Arc<LangChainQdrant>>,
|
||
sql_chain: Option<Arc<LLMChain>>,
|
||
}
|
||
|
||
impl BotOrchestrator {
|
||
pub fn new(
|
||
session_manager: SessionManager,
|
||
tool_manager: ToolManager,
|
||
llm_provider: Arc<dyn LLMProvider>,
|
||
auth_service: AuthService,
|
||
vector_store: Option<Arc<LangChainQdrant>>,
|
||
sql_chain: Option<Arc<LLMChain>>,
|
||
) -> Self {
|
||
Self {
|
||
session_manager,
|
||
tool_manager,
|
||
llm_provider,
|
||
auth_service,
|
||
channels: HashMap::new(),
|
||
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
||
vector_store,
|
||
sql_chain,
|
||
}
|
||
}
|
||
|
||
pub fn add_channel(&mut self, channel_type: &str, adapter: Arc<dyn ChannelAdapter>) {
|
||
self.channels.insert(channel_type.to_string(), adapter);
|
||
}
|
||
|
||
pub async fn register_response_channel(
|
||
&self,
|
||
session_id: String,
|
||
sender: mpsc::Sender<BotResponse>,
|
||
) {
|
||
self.response_channels
|
||
.lock()
|
||
.await
|
||
.insert(session_id, sender);
|
||
}
|
||
|
||
pub async fn set_user_answer_mode(
|
||
&self,
|
||
user_id: &str,
|
||
bot_id: &str,
|
||
mode: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
self.session_manager
|
||
.update_answer_mode(user_id, bot_id, mode)
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn process_message(
|
||
&self,
|
||
message: UserMessage,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
info!(
|
||
"Processing message from channel: {}, user: {}",
|
||
message.channel, message.user_id
|
||
);
|
||
|
||
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
||
let bot_id = Uuid::parse_str(&message.bot_id)
|
||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||
|
||
let session = match self
|
||
.session_manager
|
||
.get_user_session(user_id, bot_id)
|
||
.await?
|
||
{
|
||
Some(session) => session,
|
||
None => {
|
||
self.session_manager
|
||
.create_session(user_id, bot_id, "New Conversation")
|
||
.await?
|
||
}
|
||
};
|
||
|
||
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
||
self.tool_manager
|
||
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
||
.await?;
|
||
return Ok(());
|
||
}
|
||
|
||
self.session_manager
|
||
.save_message(
|
||
session.id,
|
||
user_id,
|
||
"user",
|
||
&message.content,
|
||
&message.message_type,
|
||
)
|
||
.await?;
|
||
|
||
let response_content = match session.answer_mode.as_str() {
|
||
"document" => self.document_mode_handler(&message, &session).await?,
|
||
"database" => self.database_mode_handler(&message, &session).await?,
|
||
"tool" => self.tool_mode_handler(&message, &session).await?,
|
||
_ => self.direct_mode_handler(&message, &session).await?,
|
||
};
|
||
|
||
self.session_manager
|
||
.save_message(session.id, user_id, "assistant", &response_content, "text")
|
||
.await?;
|
||
|
||
let bot_response = BotResponse {
|
||
bot_id: message.bot_id,
|
||
user_id: message.user_id,
|
||
session_id: message.session_id,
|
||
channel: message.channel,
|
||
content: response_content,
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||
adapter.send_message(bot_response).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
async fn document_mode_handler(
|
||
&self,
|
||
message: &UserMessage,
|
||
_session: &UserSession,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
if let Some(vector_store) = &self.vector_store {
|
||
let similar_docs = vector_store
|
||
.similarity_search(&message.content, 3, &VecStoreOptions::default())
|
||
.await?;
|
||
|
||
let mut enhanced_prompt = format!("User question: {}\n\n", message.content);
|
||
|
||
if !similar_docs.is_empty() {
|
||
enhanced_prompt.push_str("Relevant documents:\n");
|
||
for (i, doc) in similar_docs.iter().enumerate() {
|
||
enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content));
|
||
}
|
||
enhanced_prompt.push_str(
|
||
"\nPlease answer the user's question based on the provided documents.",
|
||
);
|
||
}
|
||
|
||
self.llm_provider
|
||
.generate(&enhanced_prompt, &serde_json::Value::Null)
|
||
.await
|
||
} else {
|
||
Ok("Document mode not available".to_string())
|
||
}
|
||
}
|
||
|
||
async fn database_mode_handler(
|
||
&self,
|
||
message: &UserMessage,
|
||
_session: &UserSession,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
if let Some(sql_chain) = &self.sql_chain {
|
||
let input_variables = prompt_args! {
|
||
"input" => message.content,
|
||
};
|
||
|
||
let result = sql_chain.invoke(input_variables).await?;
|
||
Ok(result.to_string())
|
||
} else {
|
||
let db_url = std::env::var("DATABASE_URL")?;
|
||
let engine = PostgreSQLEngine::new(&db_url).await?;
|
||
let db = SQLDatabaseBuilder::new(engine).build().await?;
|
||
|
||
let llm = OpenAI::default();
|
||
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
|
||
.llm(llm)
|
||
.top_k(5)
|
||
.database(db)
|
||
.build()?;
|
||
|
||
let input_variables = chain.prompt_builder().query(&message.content).build();
|
||
let result = chain.invoke(input_variables).await?;
|
||
|
||
Ok(result.to_string())
|
||
}
|
||
}
|
||
|
||
async fn tool_mode_handler(
|
||
&self,
|
||
message: &UserMessage,
|
||
_session: &UserSession,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
if message.content.to_lowercase().contains("calculator") {
|
||
if let Some(_adapter) = self.channels.get(&message.channel) {
|
||
let (tx, _rx) = mpsc::channel(100);
|
||
|
||
self.register_response_channel(message.session_id.clone(), tx.clone())
|
||
.await;
|
||
|
||
let tool_manager = self.tool_manager.clone();
|
||
let user_id_str = message.user_id.clone();
|
||
let bot_id_str = message.bot_id.clone();
|
||
let session_manager = self.session_manager.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let _ = tool_manager
|
||
.execute_tool_with_session(
|
||
"calculator",
|
||
&user_id_str,
|
||
&bot_id_str,
|
||
session_manager,
|
||
tx,
|
||
)
|
||
.await;
|
||
});
|
||
}
|
||
Ok("Starting calculator tool...".to_string())
|
||
} else {
|
||
let available_tools = self.tool_manager.list_tools();
|
||
let tools_context = if !available_tools.is_empty() {
|
||
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
||
} else {
|
||
String::new()
|
||
};
|
||
|
||
let full_prompt = format!("{}{}", message.content, tools_context);
|
||
|
||
self.llm_provider
|
||
.generate(&full_prompt, &serde_json::Value::Null)
|
||
.await
|
||
}
|
||
}
|
||
|
||
async fn direct_mode_handler(
|
||
&self,
|
||
message: &UserMessage,
|
||
session: &UserSession,
|
||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||
let history = self
|
||
.session_manager
|
||
.get_conversation_history(session.id, session.user_id)
|
||
.await?;
|
||
|
||
let mut memory = SimpleMemory::new();
|
||
for (role, content) in history {
|
||
match role.as_str() {
|
||
"user" => memory.add_user_message(&content),
|
||
"assistant" => memory.add_ai_message(&content),
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
let mut prompt = String::new();
|
||
if let Some(chat_history) = memory.get_chat_history() {
|
||
for message in chat_history {
|
||
prompt.push_str(&format!(
|
||
"{}: {}\n",
|
||
message.message_type(),
|
||
message.content()
|
||
));
|
||
}
|
||
}
|
||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||
|
||
self.llm_provider
|
||
.generate(&prompt, &serde_json::Value::Null)
|
||
.await
|
||
}
|
||
|
||
pub async fn stream_response(
|
||
&self,
|
||
message: UserMessage,
|
||
mut response_tx: mpsc::Sender<BotResponse>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
info!("Streaming response for user: {}", message.user_id);
|
||
|
||
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
||
let bot_id = Uuid::parse_str(&message.bot_id)
|
||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||
|
||
let session = match self
|
||
.session_manager
|
||
.get_user_session(user_id, bot_id)
|
||
.await?
|
||
{
|
||
Some(session) => session,
|
||
None => {
|
||
self.session_manager
|
||
.create_session(user_id, bot_id, "New Conversation")
|
||
.await?
|
||
}
|
||
};
|
||
|
||
if session.answer_mode == "tool" && session.current_tool.is_some() {
|
||
self.tool_manager
|
||
.provide_user_response(&message.user_id, &message.bot_id, message.content.clone())
|
||
.await?;
|
||
return Ok(());
|
||
}
|
||
|
||
self.session_manager
|
||
.save_message(
|
||
session.id,
|
||
user_id,
|
||
"user",
|
||
&message.content,
|
||
&message.message_type,
|
||
)
|
||
.await?;
|
||
|
||
let history = self
|
||
.session_manager
|
||
.get_conversation_history(session.id, user_id)
|
||
.await?;
|
||
|
||
let mut memory = SimpleMemory::new();
|
||
for (role, content) in history {
|
||
match role.as_str() {
|
||
"user" => memory.add_user_message(&content),
|
||
"assistant" => memory.add_ai_message(&content),
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
let mut prompt = String::new();
|
||
if let Some(chat_history) = memory.get_chat_history() {
|
||
for message in chat_history {
|
||
prompt.push_str(&format!(
|
||
"{}: {}\n",
|
||
message.message_type(),
|
||
message.content()
|
||
));
|
||
}
|
||
}
|
||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||
|
||
let (stream_tx, mut stream_rx) = mpsc::channel(100);
|
||
let llm_provider = self.llm_provider.clone();
|
||
let prompt_clone = prompt.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let _ = llm_provider
|
||
.generate_stream(&prompt_clone, &serde_json::Value::Null, stream_tx)
|
||
.await;
|
||
});
|
||
|
||
let mut full_response = String::new();
|
||
while let Some(chunk) = stream_rx.recv().await {
|
||
full_response.push_str(&chunk);
|
||
|
||
let bot_response = BotResponse {
|
||
bot_id: message.bot_id.clone(),
|
||
user_id: message.user_id.clone(),
|
||
session_id: message.session_id.clone(),
|
||
channel: message.channel.clone(),
|
||
content: chunk,
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: false,
|
||
};
|
||
|
||
if response_tx.send(bot_response).await.is_err() {
|
||
break;
|
||
}
|
||
}
|
||
|
||
self.session_manager
|
||
.save_message(session.id, user_id, "assistant", &full_response, "text")
|
||
.await?;
|
||
|
||
let final_response = BotResponse {
|
||
bot_id: message.bot_id,
|
||
user_id: message.user_id,
|
||
session_id: message.session_id,
|
||
channel: message.channel,
|
||
content: "".to_string(),
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
response_tx.send(final_response).await?;
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn get_user_sessions(
|
||
&self,
|
||
user_id: Uuid,
|
||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||
self.session_manager.get_user_sessions(user_id).await
|
||
}
|
||
|
||
pub async fn get_conversation_history(
|
||
&self,
|
||
session_id: Uuid,
|
||
user_id: Uuid,
|
||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||
self.session_manager
|
||
.get_conversation_history(session_id, user_id)
|
||
.await
|
||
}
|
||
|
||
pub async fn process_message_with_tools(
|
||
&self,
|
||
message: UserMessage,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
info!(
|
||
"Processing message with tools from user: {}",
|
||
message.user_id
|
||
);
|
||
|
||
let user_id = Uuid::parse_str(&message.user_id).unwrap_or_else(|_| Uuid::new_v4());
|
||
let bot_id = Uuid::parse_str(&message.bot_id)
|
||
.unwrap_or_else(|_| Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap());
|
||
|
||
let session = match self
|
||
.session_manager
|
||
.get_user_session(user_id, bot_id)
|
||
.await?
|
||
{
|
||
Some(session) => session,
|
||
None => {
|
||
self.session_manager
|
||
.create_session(user_id, bot_id, "New Conversation")
|
||
.await?
|
||
}
|
||
};
|
||
|
||
self.session_manager
|
||
.save_message(
|
||
session.id,
|
||
user_id,
|
||
"user",
|
||
&message.content,
|
||
&message.message_type,
|
||
)
|
||
.await?;
|
||
|
||
let is_tool_waiting = self
|
||
.tool_manager
|
||
.is_tool_waiting(&message.session_id)
|
||
.await
|
||
.unwrap_or(false);
|
||
|
||
if is_tool_waiting {
|
||
self.tool_manager
|
||
.provide_input(&message.session_id, &message.content)
|
||
.await?;
|
||
|
||
if let Ok(tool_output) = self.tool_manager.get_tool_output(&message.session_id).await {
|
||
for output in tool_output {
|
||
let bot_response = BotResponse {
|
||
bot_id: message.bot_id.clone(),
|
||
user_id: message.user_id.clone(),
|
||
session_id: message.session_id.clone(),
|
||
channel: message.channel.clone(),
|
||
content: output,
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||
adapter.send_message(bot_response).await?;
|
||
}
|
||
}
|
||
}
|
||
return Ok(());
|
||
}
|
||
|
||
let response = if message.content.to_lowercase().contains("calculator")
|
||
|| message.content.to_lowercase().contains("calculate")
|
||
|| message.content.to_lowercase().contains("math")
|
||
{
|
||
match self
|
||
.tool_manager
|
||
.execute_tool("calculator", &message.session_id, &message.user_id)
|
||
.await
|
||
{
|
||
Ok(tool_result) => {
|
||
self.session_manager
|
||
.save_message(
|
||
session.id,
|
||
user_id,
|
||
"assistant",
|
||
&tool_result.output,
|
||
"tool_start",
|
||
)
|
||
.await?;
|
||
|
||
tool_result.output
|
||
}
|
||
Err(e) => {
|
||
format!("I encountered an error starting the calculator: {}", e)
|
||
}
|
||
}
|
||
} else {
|
||
let available_tools = self.tool_manager.list_tools();
|
||
let tools_context = if !available_tools.is_empty() {
|
||
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
||
} else {
|
||
String::new()
|
||
};
|
||
|
||
let full_prompt = format!("{}{}", message.content, tools_context);
|
||
|
||
self.llm_provider
|
||
.generate(&full_prompt, &serde_json::Value::Null)
|
||
.await?
|
||
};
|
||
|
||
self.session_manager
|
||
.save_message(session.id, user_id, "assistant", &response, "text")
|
||
.await?;
|
||
|
||
let bot_response = BotResponse {
|
||
bot_id: message.bot_id,
|
||
user_id: message.user_id,
|
||
session_id: message.session_id,
|
||
channel: message.channel,
|
||
content: response,
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
if let Some(adapter) = self.channels.get(&message.channel) {
|
||
adapter.send_message(bot_response).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[actix_web::get("/ws")]
|
||
async fn websocket_handler(
|
||
req: HttpRequest,
|
||
stream: web::Payload,
|
||
data: web::Data<crate::shared::AppState>,
|
||
) -> Result<HttpResponse, actix_web::Error> {
|
||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||
let session_id = Uuid::new_v4().to_string();
|
||
let (tx, mut rx) = mpsc::channel::<BotResponse>(100);
|
||
|
||
data.orchestrator
|
||
.register_response_channel(session_id.clone(), tx.clone())
|
||
.await;
|
||
data.web_adapter
|
||
.add_connection(session_id.clone(), tx.clone())
|
||
.await;
|
||
data.voice_adapter
|
||
.add_connection(session_id.clone(), tx.clone())
|
||
.await;
|
||
|
||
let orchestrator = data.orchestrator.clone();
|
||
let web_adapter = data.web_adapter.clone();
|
||
|
||
actix_web::rt::spawn(async move {
|
||
while let Some(msg) = rx.recv().await {
|
||
if let Ok(json) = serde_json::to_string(&msg) {
|
||
let _ = session.text(json).await;
|
||
}
|
||
}
|
||
});
|
||
|
||
actix_web::rt::spawn(async move {
|
||
while let Some(Ok(msg)) = msg_stream.recv().await {
|
||
match msg {
|
||
WsMessage::Text(text) => {
|
||
let user_message = UserMessage {
|
||
bot_id: "default_bot".to_string(),
|
||
user_id: "default_user".to_string(),
|
||
session_id: session_id.clone(),
|
||
channel: "web".to_string(),
|
||
content: text.to_string(),
|
||
message_type: "text".to_string(),
|
||
media_url: None,
|
||
timestamp: Utc::now(),
|
||
};
|
||
|
||
if let Err(e) = orchestrator.stream_response(user_message, tx.clone()).await {
|
||
info!("Error processing message: {}", e);
|
||
}
|
||
}
|
||
WsMessage::Close(_) => {
|
||
web_adapter.remove_connection(&session_id).await;
|
||
break;
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
});
|
||
|
||
Ok(res)
|
||
}
|
||
|
||
#[actix_web::get("/api/whatsapp/webhook")]
|
||
async fn whatsapp_webhook_verify(
|
||
data: web::Data<crate::shared::AppState>,
|
||
web::Query(params): web::Query<HashMap<String, String>>,
|
||
) -> Result<HttpResponse> {
|
||
let mode = params.get("hub.mode").unwrap_or(&"".to_string());
|
||
let token = params.get("hub.verify_token").unwrap_or(&"".to_string());
|
||
let challenge = params.get("hub.challenge").unwrap_or(&"".to_string());
|
||
|
||
match data.whatsapp_adapter.verify_webhook(mode, token, challenge) {
|
||
Ok(challenge_response) => Ok(HttpResponse::Ok().body(challenge_response)),
|
||
Err(_) => Ok(HttpResponse::Forbidden().body("Verification failed")),
|
||
}
|
||
}
|
||
|
||
#[actix_web::post("/api/whatsapp/webhook")]
|
||
async fn whatsapp_webhook(
|
||
data: web::Data<crate::shared::AppState>,
|
||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
||
) -> Result<HttpResponse> {
|
||
match data
|
||
.whatsapp_adapter
|
||
.process_incoming_message(payload.into_inner())
|
||
.await
|
||
{
|
||
Ok(user_messages) => {
|
||
for user_message in user_messages {
|
||
if let Err(e) = data.orchestrator.process_message(user_message).await {
|
||
log::error!("Error processing WhatsApp message: {}", e);
|
||
}
|
||
}
|
||
Ok(HttpResponse::Ok().body(""))
|
||
}
|
||
Err(e) => {
|
||
log::error!("Error processing WhatsApp webhook: {}", e);
|
||
Ok(HttpResponse::BadRequest().body("Invalid message"))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[actix_web::post("/api/voice/start")]
|
||
async fn voice_start(
|
||
data: web::Data<crate::shared::AppState>,
|
||
info: web::Json<serde_json::Value>,
|
||
) -> Result<HttpResponse> {
|
||
let session_id = info
|
||
.get("session_id")
|
||
.and_then(|s| s.as_str())
|
||
.unwrap_or("");
|
||
let user_id = info
|
||
.get("user_id")
|
||
.and_then(|u| u.as_str())
|
||
.unwrap_or("user");
|
||
|
||
match data
|
||
.voice_adapter
|
||
.start_voice_session(session_id, user_id)
|
||
.await
|
||
{
|
||
Ok(token) => {
|
||
Ok(HttpResponse::Ok().json(serde_json::json!({"token": token, "status": "started"})))
|
||
}
|
||
Err(e) => {
|
||
Ok(HttpResponse::InternalServerError()
|
||
.json(serde_json::json!({"error": e.to_string()})))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[actix_web::post("/api/voice/stop")]
|
||
async fn voice_stop(
|
||
data: web::Data<crate::shared::AppState>,
|
||
info: web::Json<serde_json::Value>,
|
||
) -> Result<HttpResponse> {
|
||
let session_id = info
|
||
.get("session_id")
|
||
.and_then(|s| s.as_str())
|
||
.unwrap_or("");
|
||
|
||
match data.voice_adapter.stop_voice_session(session_id).await {
|
||
Ok(()) => Ok(HttpResponse::Ok().json(serde_json::json!({"status": "stopped"}))),
|
||
Err(e) => {
|
||
Ok(HttpResponse::InternalServerError()
|
||
.json(serde_json::json!({"error": e.to_string()})))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[actix_web::post("/api/sessions")]
|
||
async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
||
let session_id = Uuid::new_v4();
|
||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||
"session_id": session_id,
|
||
"title": "New Conversation",
|
||
"created_at": Utc::now()
|
||
})))
|
||
}
|
||
|
||
#[actix_web::get("/api/sessions")]
|
||
async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||
match data.orchestrator.get_user_sessions(user_id).await {
|
||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
||
Err(e) => {
|
||
Ok(HttpResponse::InternalServerError()
|
||
.json(serde_json::json!({"error": e.to_string()})))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[actix_web::get("/api/sessions/{session_id}")]
|
||
async fn get_session_history(
|
||
data: web::Data<crate::shared::AppState>,
|
||
path: web::Path<String>,
|
||
) -> Result<HttpResponse> {
|
||
let session_id = path.into_inner();
|
||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||
|
||
match Uuid::parse_str(&session_id) {
|
||
Ok(session_uuid) => match data
|
||
.orchestrator
|
||
.get_conversation_history(session_uuid, user_id)
|
||
.await
|
||
{
|
||
Ok(history) => Ok(HttpResponse::Ok().json(history)),
|
||
Err(e) => Ok(HttpResponse::InternalServerError()
|
||
.json(serde_json::json!({"error": e.to_string()}))),
|
||
},
|
||
Err(_) => {
|
||
Ok(HttpResponse::BadRequest().json(serde_json::json!({"error": "Invalid session ID"})))
|
||
}
|
||
}
|
||
}
|
||
|
||
#[actix_web::post("/api/set_mode")]
|
||
async fn set_mode_handler(
|
||
data: web::Data<crate::shared::AppState>,
|
||
info: web::Json<HashMap<String, String>>,
|
||
) -> Result<HttpResponse> {
|
||
let default_user = "default_user".to_string();
|
||
let default_bot = "default_bot".to_string();
|
||
let default_mode = "direct".to_string();
|
||
|
||
let user_id = info.get("user_id").unwrap_or(&default_user);
|
||
let bot_id = info.get("bot_id").unwrap_or(&default_bot);
|
||
let mode = info.get("mode").unwrap_or(&default_mode);
|
||
|
||
if let Err(e) = data
|
||
.orchestrator
|
||
.set_user_answer_mode(user_id, bot_id, mode)
|
||
.await
|
||
{
|
||
return Ok(
|
||
HttpResponse::InternalServerError().json(serde_json::json!({"error": e.to_string()}))
|
||
);
|
||
}
|
||
|
||
Ok(HttpResponse::Ok().json(serde_json::json!({"status": "mode_updated"})))
|
||
}
|
||
|
||
#[actix_web::get("/")]
|
||
async fn index() -> Result<HttpResponse> {
|
||
let html = fs::read_to_string("templates/index.html")
|
||
.unwrap_or_else(|_| include_str!("../../static/index.html").to_string());
|
||
Ok(HttpResponse::Ok().content_type("text/html").body(html))
|
||
}
|
||
|
||
#[actix_web::get("/static/{filename:.*}")]
|
||
async fn static_files(req: HttpRequest) -> Result<HttpResponse> {
|
||
let filename = req.match_info().query("filename");
|
||
let path = format!("static/{}", filename);
|
||
|
||
match fs::read(&path) {
|
||
Ok(content) => {
|
||
let content_type = match filename {
|
||
f if f.ends_with(".js") => "application/javascript",
|
||
f if f.ends_with(".css") => "text/css",
|
||
f if f.ends_with(".png") => "image/png",
|
||
f if f.ends_with(".jpg") | f.ends_with(".jpeg") => "image/jpeg",
|
||
_ => "text/plain",
|
||
};
|
||
|
||
Ok(HttpResponse::Ok().content_type(content_type).body(content))
|
||
}
|
||
Err(_) => Ok(HttpResponse::NotFound().body("File not found")),
|
||
}
|
||
}
|
||
|
||
use redis::{AsyncCommands, Client};
|
||
use serde_json;
|
||
use sqlx::{PgPool, Row};
|
||
use std::sync::Arc;
|
||
use uuid::Uuid;
|
||
|
||
use crate::shared::UserSession;
|
||
|
||
pub struct SessionManager {
|
||
pub pool: PgPool,
|
||
pub redis: Option<Arc<Client>>,
|
||
}
|
||
|
||
impl SessionManager {
|
||
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
||
Self { pool, redis }
|
||
}
|
||
|
||
pub async fn get_user_session(
|
||
&self,
|
||
user_id: Uuid,
|
||
bot_id: Uuid,
|
||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||
let session_json: Option<String> = conn.get(&cache_key).await?;
|
||
if let Some(json) = session_json {
|
||
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
||
return Ok(Some(session));
|
||
}
|
||
}
|
||
}
|
||
|
||
let session = sqlx::query_as::<_, UserSession>(
|
||
"SELECT * FROM user_sessions WHERE user_id = $1 AND bot_id = $2 ORDER BY updated_at DESC LIMIT 1",
|
||
)
|
||
.bind(user_id)
|
||
.bind(bot_id)
|
||
.fetch_optional(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(ref session) = session {
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||
let session_json = serde_json::to_string(session)?;
|
||
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
||
}
|
||
}
|
||
|
||
Ok(session)
|
||
}
|
||
|
||
pub async fn create_session(
|
||
&self,
|
||
user_id: Uuid,
|
||
bot_id: Uuid,
|
||
title: &str,
|
||
) -> Result<UserSession, Box<dyn std::error::Error + Send + Sync>> {
|
||
let session = sqlx::query_as::<_, UserSession>(
|
||
"INSERT INTO user_sessions (user_id, bot_id, title) VALUES ($1, $2, $3) RETURNING *",
|
||
)
|
||
.bind(user_id)
|
||
.bind(bot_id)
|
||
.bind(title)
|
||
.fetch_one(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||
let session_json = serde_json::to_string(&session)?;
|
||
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
||
}
|
||
|
||
Ok(session)
|
||
}
|
||
|
||
pub async fn save_message(
|
||
&self,
|
||
session_id: Uuid,
|
||
user_id: Uuid,
|
||
role: &str,
|
||
content: &str,
|
||
message_type: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let message_count: i64 =
|
||
sqlx::query("SELECT COUNT(*) as count FROM message_history WHERE session_id = $1")
|
||
.bind(session_id)
|
||
.fetch_one(&self.pool)
|
||
.await?
|
||
.get("count");
|
||
|
||
sqlx::query(
|
||
"INSERT INTO message_history (session_id, user_id, role, content_encrypted, message_type, message_index)
|
||
VALUES ($1, $2, $3, $4, $5, $6)",
|
||
)
|
||
.bind(session_id)
|
||
.bind(user_id)
|
||
.bind(role)
|
||
.bind(content)
|
||
.bind(message_type)
|
||
.bind(message_count + 1)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
|
||
sqlx::query("UPDATE user_sessions SET updated_at = NOW() WHERE id = $1")
|
||
.bind(session_id)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(redis_client) = &self.redis {
|
||
if let Some(session_info) =
|
||
sqlx::query("SELECT user_id, bot_id FROM user_sessions WHERE id = $1")
|
||
.bind(session_id)
|
||
.fetch_optional(&self.pool)
|
||
.await?
|
||
{
|
||
let user_id: Uuid = session_info.get("user_id");
|
||
let bot_id: Uuid = session_info.get("bot_id");
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_id, bot_id);
|
||
let _: () = conn.del(cache_key).await?;
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn get_conversation_history(
|
||
&self,
|
||
session_id: Uuid,
|
||
user_id: Uuid,
|
||
) -> Result<Vec<(String, String)>, Box<dyn std::error::Error + Send + Sync>> {
|
||
let messages = sqlx::query(
|
||
"SELECT role, content_encrypted FROM message_history
|
||
WHERE session_id = $1 AND user_id = $2
|
||
ORDER BY message_index ASC",
|
||
)
|
||
.bind(session_id)
|
||
.bind(user_id)
|
||
.fetch_all(&self.pool)
|
||
.await?;
|
||
|
||
let history = messages
|
||
.into_iter()
|
||
.map(|row| (row.get("role"), row.get("content_encrypted")))
|
||
.collect();
|
||
|
||
Ok(history)
|
||
}
|
||
|
||
pub async fn get_user_sessions(
|
||
&self,
|
||
user_id: Uuid,
|
||
) -> Result<Vec<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||
let sessions = sqlx::query_as::<_, UserSession>(
|
||
"SELECT * FROM user_sessions WHERE user_id = $1 ORDER BY updated_at DESC",
|
||
)
|
||
.bind(user_id)
|
||
.fetch_all(&self.pool)
|
||
.await?;
|
||
Ok(sessions)
|
||
}
|
||
|
||
pub async fn update_answer_mode(
|
||
&self,
|
||
user_id: &str,
|
||
bot_id: &str,
|
||
mode: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let user_uuid = Uuid::parse_str(user_id)?;
|
||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||
|
||
sqlx::query(
|
||
"UPDATE user_sessions
|
||
SET answer_mode = $1, updated_at = NOW()
|
||
WHERE user_id = $2 AND bot_id = $3",
|
||
)
|
||
.bind(mode)
|
||
.bind(user_uuid)
|
||
.bind(bot_uuid)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||
let _: () = conn.del(cache_key).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn update_current_tool(
|
||
&self,
|
||
user_id: &str,
|
||
bot_id: &str,
|
||
tool_name: Option<&str>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let user_uuid = Uuid::parse_str(user_id)?;
|
||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||
|
||
sqlx::query(
|
||
"UPDATE user_sessions
|
||
SET current_tool = $1, updated_at = NOW()
|
||
WHERE user_id = $2 AND bot_id = $3",
|
||
)
|
||
.bind(tool_name)
|
||
.bind(user_uuid)
|
||
.bind(bot_uuid)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||
let _: () = conn.del(cache_key).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn get_session_by_id(
|
||
&self,
|
||
session_id: Uuid,
|
||
) -> Result<Option<UserSession>, Box<dyn std::error::Error + Send + Sync>> {
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session_by_id:{}", session_id);
|
||
let session_json: Option<String> = conn.get(&cache_key).await?;
|
||
if let Some(json) = session_json {
|
||
if let Ok(session) = serde_json::from_str::<UserSession>(&json) {
|
||
return Ok(Some(session));
|
||
}
|
||
}
|
||
}
|
||
|
||
let session = sqlx::query_as::<_, UserSession>("SELECT * FROM user_sessions WHERE id = $1")
|
||
.bind(session_id)
|
||
.fetch_optional(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(ref session) = session {
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session_by_id:{}", session_id);
|
||
let session_json = serde_json::to_string(session)?;
|
||
let _: () = conn.set_ex(cache_key, session_json, 1800).await?;
|
||
}
|
||
}
|
||
|
||
Ok(session)
|
||
}
|
||
|
||
pub async fn cleanup_old_sessions(
|
||
&self,
|
||
days_old: i32,
|
||
) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
|
||
let result = sqlx::query(
|
||
"DELETE FROM user_sessions
|
||
WHERE updated_at < NOW() - INTERVAL '1 day' * $1",
|
||
)
|
||
.bind(days_old)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
Ok(result.rows_affected())
|
||
}
|
||
|
||
pub async fn set_current_tool(
|
||
&self,
|
||
user_id: &str,
|
||
bot_id: &str,
|
||
tool_name: Option<String>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let user_uuid = Uuid::parse_str(user_id)?;
|
||
let bot_uuid = Uuid::parse_str(bot_id)?;
|
||
|
||
sqlx::query(
|
||
"UPDATE user_sessions
|
||
SET current_tool = $1, updated_at = NOW()
|
||
WHERE user_id = $2 AND bot_id = $3",
|
||
)
|
||
.bind(tool_name)
|
||
.bind(user_uuid)
|
||
.bind(bot_uuid)
|
||
.execute(&self.pool)
|
||
.await?;
|
||
|
||
if let Some(redis_client) = &self.redis {
|
||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||
let cache_key = format!("session:{}:{}", user_uuid, bot_uuid);
|
||
let _: () = conn.del(cache_key).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl Clone for SessionManager {
|
||
fn clone(&self) -> Self {
|
||
Self {
|
||
pool: self.pool.clone(),
|
||
redis: self.redis.clone(),
|
||
}
|
||
}
|
||
}
|
||
|
||
use async_trait::async_trait;
|
||
use redis::AsyncCommands;
|
||
use rhai::{Engine, Scope};
|
||
use serde::{Deserialize, Serialize};
|
||
use std::collections::HashMap;
|
||
use std::sync::Arc;
|
||
use tokio::sync::{mpsc, Mutex};
|
||
use uuid::Uuid;
|
||
|
||
use crate::{
|
||
channels::ChannelAdapter,
|
||
session::SessionManager,
|
||
shared::BotResponse,
|
||
};
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct ToolResult {
|
||
pub success: bool,
|
||
pub output: String,
|
||
pub requires_input: bool,
|
||
pub session_id: String,
|
||
}
|
||
|
||
/// A named, script-driven tool.
///
/// `Debug` is derived so tools can be logged and inspected; all fields already support it.
#[derive(Debug, Clone)]
pub struct Tool {
    /// Unique name used to look the tool up.
    pub name: String,
    /// Short human-readable description.
    pub description: String,
    /// Parameter name -> informal type/choice description (e.g. `"number"`).
    pub parameters: HashMap<String, String>,
    /// Rhai source evaluated when the tool runs.
    pub script: String,
}
|
||
|
||
#[async_trait]
|
||
pub trait ToolExecutor: Send + Sync {
|
||
async fn execute(
|
||
&self,
|
||
tool_name: &str,
|
||
session_id: &str,
|
||
user_id: &str,
|
||
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>>;
|
||
async fn provide_input(
|
||
&self,
|
||
session_id: &str,
|
||
input: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||
async fn get_output(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>>;
|
||
async fn is_waiting_for_input(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
|
||
}
|
||
|
||
pub struct RedisToolExecutor {
|
||
redis_client: redis::Client,
|
||
web_adapter: Arc<dyn ChannelAdapter>,
|
||
voice_adapter: Arc<dyn ChannelAdapter>,
|
||
whatsapp_adapter: Arc<dyn ChannelAdapter>,
|
||
}
|
||
|
||
impl RedisToolExecutor {
|
||
pub fn new(
|
||
redis_url: &str,
|
||
web_adapter: Arc<dyn ChannelAdapter>,
|
||
voice_adapter: Arc<dyn ChannelAdapter>,
|
||
whatsapp_adapter: Arc<dyn ChannelAdapter>,
|
||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||
let client = redis::Client::open(redis_url)?;
|
||
Ok(Self {
|
||
redis_client: client,
|
||
web_adapter,
|
||
voice_adapter,
|
||
whatsapp_adapter,
|
||
})
|
||
}
|
||
|
||
async fn send_tool_message(
|
||
&self,
|
||
session_id: &str,
|
||
user_id: &str,
|
||
channel: &str,
|
||
message: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let response = BotResponse {
|
||
bot_id: "tool_bot".to_string(),
|
||
user_id: user_id.to_string(),
|
||
session_id: session_id.to_string(),
|
||
channel: channel.to_string(),
|
||
content: message.to_string(),
|
||
message_type: "tool".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
match channel {
|
||
"web" => self.web_adapter.send_message(response).await,
|
||
"voice" => self.voice_adapter.send_message(response).await,
|
||
"whatsapp" => self.whatsapp_adapter.send_message(response).await,
|
||
_ => Ok(()),
|
||
}
|
||
}
|
||
|
||
fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine {
|
||
let mut engine = Engine::new();
|
||
|
||
let tool_executor = Arc::new((
|
||
self.redis_client.clone(),
|
||
self.web_adapter.clone(),
|
||
self.voice_adapter.clone(),
|
||
self.whatsapp_adapter.clone(),
|
||
));
|
||
|
||
let session_id_clone = session_id.clone();
|
||
let user_id_clone = user_id.clone();
|
||
let channel_clone = channel.clone();
|
||
|
||
engine.register_fn("talk", move |message: String| {
|
||
let tool_executor = Arc::clone(&tool_executor);
|
||
let session_id = session_id_clone.clone();
|
||
let user_id = user_id_clone.clone();
|
||
let channel = channel_clone.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor;
|
||
|
||
let response = BotResponse {
|
||
bot_id: "tool_bot".to_string(),
|
||
user_id: user_id.clone(),
|
||
session_id: session_id.clone(),
|
||
channel: channel.clone(),
|
||
content: message.clone(),
|
||
message_type: "tool".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
let result = match channel.as_str() {
|
||
"web" => web_adapter.send_message(response).await,
|
||
"voice" => voice_adapter.send_message(response).await,
|
||
"whatsapp" => whatsapp_adapter.send_message(response).await,
|
||
_ => Ok(()),
|
||
};
|
||
|
||
if let Err(e) = result {
|
||
log::error!("Failed to send tool message: {}", e);
|
||
}
|
||
|
||
if let Ok(mut conn) = redis_client.get_async_connection().await {
|
||
let output_key = format!("tool:{}:output", session_id);
|
||
let _ = conn.lpush(&output_key, &message).await;
|
||
}
|
||
});
|
||
});
|
||
|
||
let hear_executor = self.redis_client.clone();
|
||
let session_id_clone = session_id.clone();
|
||
|
||
engine.register_fn("hear", move || -> String {
|
||
let hear_executor = hear_executor.clone();
|
||
let session_id = session_id_clone.clone();
|
||
|
||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||
rt.block_on(async move {
|
||
match hear_executor.get_async_connection().await {
|
||
Ok(mut conn) => {
|
||
let input_key = format!("tool:{}:input", session_id);
|
||
let waiting_key = format!("tool:{}:waiting", session_id);
|
||
|
||
let _ = conn.set_ex(&waiting_key, "true", 300).await;
|
||
let result: Option<(String, String)> =
|
||
conn.brpop(&input_key, 30).await.ok().flatten();
|
||
let _ = conn.del(&waiting_key).await;
|
||
|
||
result
|
||
.map(|(_, input)| input)
|
||
.unwrap_or_else(|| "timeout".to_string())
|
||
}
|
||
Err(e) => {
|
||
log::error!("HEAR Redis error: {}", e);
|
||
"error".to_string()
|
||
}
|
||
}
|
||
})
|
||
});
|
||
|
||
engine
|
||
}
|
||
|
||
async fn cleanup_session(&self, session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
||
|
||
let keys = vec![
|
||
format!("tool:{}:output", session_id),
|
||
format!("tool:{}:input", session_id),
|
||
format!("tool:{}:waiting", session_id),
|
||
format!("tool:{}:active", session_id),
|
||
];
|
||
|
||
for key in keys {
|
||
let _: () = conn.del(&key).await?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[async_trait]
|
||
impl ToolExecutor for RedisToolExecutor {
|
||
async fn execute(
|
||
&self,
|
||
tool_name: &str,
|
||
session_id: &str,
|
||
user_id: &str,
|
||
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
||
let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?;
|
||
|
||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
||
let session_key = format!("tool:{}:session", session_id);
|
||
let session_data = serde_json::json!({
|
||
"user_id": user_id,
|
||
"tool_name": tool_name,
|
||
"started_at": chrono::Utc::now().to_rfc3339(),
|
||
});
|
||
conn.set_ex(&session_key, session_data.to_string(), 3600)
|
||
.await?;
|
||
|
||
let active_key = format!("tool:{}:active", session_id);
|
||
conn.set_ex(&active_key, "true", 3600).await?;
|
||
|
||
let channel = "web";
|
||
let _engine = self.create_rhai_engine(
|
||
session_id.to_string(),
|
||
user_id.to_string(),
|
||
channel.to_string(),
|
||
);
|
||
|
||
let redis_clone = self.redis_client.clone();
|
||
let web_adapter_clone = self.web_adapter.clone();
|
||
let voice_adapter_clone = self.voice_adapter.clone();
|
||
let whatsapp_adapter_clone = self.whatsapp_adapter.clone();
|
||
let session_id_clone = session_id.to_string();
|
||
let user_id_clone = user_id.to_string();
|
||
let tool_script = tool.script.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let mut engine = Engine::new();
|
||
let mut scope = Scope::new();
|
||
|
||
let redis_client = redis_clone.clone();
|
||
let web_adapter = web_adapter_clone.clone();
|
||
let voice_adapter = voice_adapter_clone.clone();
|
||
let whatsapp_adapter = whatsapp_adapter_clone.clone();
|
||
let session_id = session_id_clone.clone();
|
||
let user_id = user_id_clone.clone();
|
||
|
||
engine.register_fn("talk", move |message: String| {
|
||
let redis_client = redis_client.clone();
|
||
let web_adapter = web_adapter.clone();
|
||
let voice_adapter = voice_adapter.clone();
|
||
let whatsapp_adapter = whatsapp_adapter.clone();
|
||
let session_id = session_id.clone();
|
||
let user_id = user_id.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let channel = "web";
|
||
|
||
let response = BotResponse {
|
||
bot_id: "tool_bot".to_string(),
|
||
user_id: user_id.clone(),
|
||
session_id: session_id.clone(),
|
||
channel: channel.to_string(),
|
||
content: message.clone(),
|
||
message_type: "tool".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
let send_result = match channel {
|
||
"web" => web_adapter.send_message(response).await,
|
||
"voice" => voice_adapter.send_message(response).await,
|
||
"whatsapp" => whatsapp_adapter.send_message(response).await,
|
||
_ => Ok(()),
|
||
};
|
||
|
||
if let Err(e) = send_result {
|
||
log::error!("Failed to send tool message: {}", e);
|
||
}
|
||
|
||
if let Ok(mut conn) = redis_client.get_async_connection().await {
|
||
let output_key = format!("tool:{}:output", session_id);
|
||
let _ = conn.lpush(&output_key, &message).await;
|
||
}
|
||
});
|
||
});
|
||
|
||
let hear_redis = redis_clone.clone();
|
||
let session_id_hear = session_id.clone();
|
||
engine.register_fn("hear", move || -> String {
|
||
let hear_redis = hear_redis.clone();
|
||
let session_id = session_id_hear.clone();
|
||
|
||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||
rt.block_on(async move {
|
||
match hear_redis.get_async_connection().await {
|
||
Ok(mut conn) => {
|
||
let input_key = format!("tool:{}:input", session_id);
|
||
let waiting_key = format!("tool:{}:waiting", session_id);
|
||
|
||
let _ = conn.set_ex(&waiting_key, "true", 300).await;
|
||
let result: Option<(String, String)> =
|
||
conn.brpop(&input_key, 30).await.ok().flatten();
|
||
let _ = conn.del(&waiting_key).await;
|
||
|
||
result
|
||
.map(|(_, input)| input)
|
||
.unwrap_or_else(|| "timeout".to_string())
|
||
}
|
||
Err(_) => "error".to_string(),
|
||
}
|
||
})
|
||
});
|
||
|
||
match engine.eval_with_scope::<()>(&mut scope, &tool_script) {
|
||
Ok(_) => {
|
||
log::info!(
|
||
"Tool {} completed successfully for session {}",
|
||
tool_name,
|
||
session_id
|
||
);
|
||
|
||
let completion_msg =
|
||
"🛠️ Tool execution completed. How can I help you with anything else?";
|
||
let response = BotResponse {
|
||
bot_id: "tool_bot".to_string(),
|
||
user_id: user_id_clone,
|
||
session_id: session_id_clone.clone(),
|
||
channel: "web".to_string(),
|
||
content: completion_msg.to_string(),
|
||
message_type: "tool_complete".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
let _ = web_adapter_clone.send_message(response).await;
|
||
}
|
||
Err(e) => {
|
||
log::error!("Tool execution failed: {}", e);
|
||
|
||
let error_msg = format!("❌ Tool error: {}", e);
|
||
let response = BotResponse {
|
||
bot_id: "tool_bot".to_string(),
|
||
user_id: user_id_clone,
|
||
session_id: session_id_clone.clone(),
|
||
channel: "web".to_string(),
|
||
content: error_msg,
|
||
message_type: "tool_error".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
|
||
let _ = web_adapter_clone.send_message(response).await;
|
||
}
|
||
}
|
||
|
||
if let Ok(mut conn) = redis_clone.get_async_connection().await {
|
||
let active_key = format!("tool:{}:active", session_id_clone);
|
||
let _ = conn.del(&active_key).await;
|
||
}
|
||
});
|
||
|
||
Ok(ToolResult {
|
||
success: true,
|
||
output: format!(
|
||
"🛠️ Starting {} tool. Please follow the tool's instructions.",
|
||
tool_name
|
||
),
|
||
requires_input: true,
|
||
session_id: session_id.to_string(),
|
||
})
|
||
}
|
||
|
||
async fn provide_input(
|
||
&self,
|
||
session_id: &str,
|
||
input: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
||
let input_key = format!("tool:{}:input", session_id);
|
||
conn.lpush(&input_key, input).await?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn get_output(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
||
let output_key = format!("tool:{}:output", session_id);
|
||
let messages: Vec<String> = conn.lrange(&output_key, 0, -1).await?;
|
||
let _: () = conn.del(&output_key).await?;
|
||
Ok(messages)
|
||
}
|
||
|
||
async fn is_waiting_for_input(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
||
let waiting_key = format!("tool:{}:waiting", session_id);
|
||
let exists: bool = conn.exists(&waiting_key).await?;
|
||
Ok(exists)
|
||
}
|
||
}
|
||
|
||
fn get_tool(name: &str) -> Option<Tool> {
|
||
match name {
|
||
"calculator" => Some(Tool {
|
||
name: "calculator".to_string(),
|
||
description: "Perform mathematical calculations".to_string(),
|
||
parameters: HashMap::from([
|
||
("operation".to_string(), "add|subtract|multiply|divide".to_string()),
|
||
("a".to_string(), "number".to_string()),
|
||
("b".to_string(), "number".to_string()),
|
||
]),
|
||
script: r#"
|
||
let TALK = |message| {
|
||
talk(message);
|
||
};
|
||
|
||
let HEAR = || {
|
||
hear()
|
||
};
|
||
|
||
TALK("🔢 Calculator started!");
|
||
TALK("Please enter the first number:");
|
||
let a = HEAR();
|
||
TALK("Please enter the second number:");
|
||
let b = HEAR();
|
||
TALK("Choose operation: add, subtract, multiply, or divide:");
|
||
let op = HEAR();
|
||
|
||
let num_a = a.to_float();
|
||
let num_b = b.to_float();
|
||
|
||
if op == "add" {
|
||
let result = num_a + num_b;
|
||
TALK("✅ Result: " + a + " + " + b + " = " + result);
|
||
} else if op == "subtract" {
|
||
let result = num_a - num_b;
|
||
TALK("✅ Result: " + a + " - " + b + " = " + result);
|
||
} else if op == "multiply" {
|
||
let result = num_a * num_b;
|
||
TALK("✅ Result: " + a + " × " + b + " = " + result);
|
||
} else if op == "divide" {
|
||
if num_b != 0.0 {
|
||
let result = num_a / num_b;
|
||
TALK("✅ Result: " + a + " ÷ " + b + " = " + result);
|
||
} else {
|
||
TALK("❌ Error: Cannot divide by zero!");
|
||
}
|
||
} else {
|
||
TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide");
|
||
}
|
||
|
||
TALK("Calculator session completed. Thank you!");
|
||
"#.to_string(),
|
||
}),
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
#[derive(Clone)]
|
||
pub struct ToolManager {
|
||
tools: HashMap<String, Tool>,
|
||
waiting_responses: Arc<Mutex<HashMap<String, mpsc::Sender<String>>>>,
|
||
}
|
||
|
||
impl ToolManager {
|
||
pub fn new() -> Self {
|
||
let mut tools = HashMap::new();
|
||
|
||
let calculator_tool = Tool {
|
||
name: "calculator".to_string(),
|
||
description: "Perform calculations".to_string(),
|
||
parameters: HashMap::from([
|
||
(
|
||
"operation".to_string(),
|
||
"add|subtract|multiply|divide".to_string(),
|
||
),
|
||
("a".to_string(), "number".to_string()),
|
||
("b".to_string(), "number".to_string()),
|
||
]),
|
||
script: r#"
|
||
TALK("Calculator started. Enter first number:");
|
||
let a = HEAR();
|
||
TALK("Enter second number:");
|
||
let b = HEAR();
|
||
TALK("Operation (add/subtract/multiply/divide):");
|
||
let op = HEAR();
|
||
|
||
let num_a = a.parse::<f64>().unwrap();
|
||
let num_b = b.parse::<f64>().unwrap();
|
||
let result = if op == "add" {
|
||
num_a + num_b
|
||
} else if op == "subtract" {
|
||
num_a - num_b
|
||
} else if op == "multiply" {
|
||
num_a * num_b
|
||
} else if op == "divide" {
|
||
if num_b == 0.0 {
|
||
TALK("Cannot divide by zero");
|
||
return;
|
||
}
|
||
num_a / num_b
|
||
} else {
|
||
TALK("Invalid operation");
|
||
return;
|
||
};
|
||
TALK("Result: ".to_string() + &result.to_string());
|
||
"#
|
||
.to_string(),
|
||
};
|
||
|
||
tools.insert(calculator_tool.name.clone(), calculator_tool);
|
||
Self {
|
||
tools,
|
||
waiting_responses: Arc::new(Mutex::new(HashMap::new())),
|
||
}
|
||
}
|
||
|
||
pub fn get_tool(&self, name: &str) -> Option<&Tool> {
|
||
self.tools.get(name)
|
||
}
|
||
|
||
pub fn list_tools(&self) -> Vec<String> {
|
||
self.tools.keys().cloned().collect()
|
||
}
|
||
|
||
pub async fn execute_tool(
|
||
&self,
|
||
tool_name: &str,
|
||
session_id: &str,
|
||
user_id: &str,
|
||
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
||
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
||
|
||
Ok(ToolResult {
|
||
success: true,
|
||
output: format!("Tool {} started for user {}", tool_name, user_id),
|
||
requires_input: true,
|
||
session_id: session_id.to_string(),
|
||
})
|
||
}
|
||
|
||
pub async fn is_tool_waiting(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
||
let waiting = self.waiting_responses.lock().await;
|
||
Ok(waiting.contains_key(session_id))
|
||
}
|
||
|
||
pub async fn provide_input(
|
||
&self,
|
||
session_id: &str,
|
||
input: &str,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
self.provide_user_response(session_id, "default_bot", input.to_string())
|
||
.await
|
||
}
|
||
|
||
pub async fn get_tool_output(
|
||
&self,
|
||
session_id: &str,
|
||
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||
Ok(vec![])
|
||
}
|
||
|
||
pub async fn execute_tool_with_session(
|
||
&self,
|
||
tool_name: &str,
|
||
user_id: &str,
|
||
bot_id: &str,
|
||
session_manager: SessionManager,
|
||
channel_sender: mpsc::Sender<BotResponse>,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
||
session_manager
|
||
.set_current_tool(user_id, bot_id, Some(tool_name.to_string()))
|
||
.await?;
|
||
|
||
let user_id = user_id.to_string();
|
||
let bot_id = bot_id.to_string();
|
||
let script = tool.script.clone();
|
||
let session_manager_clone = session_manager.clone();
|
||
let waiting_responses = self.waiting_responses.clone();
|
||
|
||
tokio::spawn(async move {
|
||
let mut engine = rhai::Engine::new();
|
||
let (talk_tx, mut talk_rx) = mpsc::channel(100);
|
||
let (hear_tx, mut hear_rx) = mpsc::channel(100);
|
||
|
||
{
|
||
let key = format!("{}:{}", user_id, bot_id);
|
||
let mut waiting = waiting_responses.lock().await;
|
||
waiting.insert(key, hear_tx);
|
||
}
|
||
|
||
let channel_sender_clone = channel_sender.clone();
|
||
let user_id_clone = user_id.clone();
|
||
let bot_id_clone = bot_id.clone();
|
||
|
||
let talk_tx_clone = talk_tx.clone();
|
||
engine.register_fn("TALK", move |message: String| {
|
||
let tx = talk_tx_clone.clone();
|
||
tokio::spawn(async move {
|
||
let _ = tx.send(message).await;
|
||
});
|
||
});
|
||
|
||
let hear_rx_mutex = Arc::new(Mutex::new(hear_rx));
|
||
engine.register_fn("HEAR", move || {
|
||
let hear_rx = hear_rx_mutex.clone();
|
||
tokio::task::block_in_place(|| {
|
||
tokio::runtime::Handle::current().block_on(async move {
|
||
let mut receiver = hear_rx.lock().await;
|
||
receiver.recv().await.unwrap_or_default()
|
||
})
|
||
})
|
||
});
|
||
|
||
let script_result =
|
||
tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await;
|
||
|
||
if let Ok(Err(e)) = script_result {
|
||
let error_response = BotResponse {
|
||
bot_id: bot_id_clone.clone(),
|
||
user_id: user_id_clone.clone(),
|
||
session_id: Uuid::new_v4().to_string(),
|
||
channel: "test".to_string(),
|
||
content: format!("Tool error: {}", e),
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
let _ = channel_sender_clone.send(error_response).await;
|
||
}
|
||
|
||
while let Some(message) = talk_rx.recv().await {
|
||
let response = BotResponse {
|
||
bot_id: bot_id.clone(),
|
||
user_id: user_id.clone(),
|
||
session_id: Uuid::new_v4().to_string(),
|
||
channel: "test".to_string(),
|
||
content: message,
|
||
message_type: "text".to_string(),
|
||
stream_token: None,
|
||
is_complete: true,
|
||
};
|
||
let _ = channel_sender.send(response).await;
|
||
}
|
||
|
||
let _ = session_manager_clone
|
||
.set_current_tool(&user_id, &bot_id, None)
|
||
.await;
|
||
});
|
||
|
||
Ok(())
|
||
}
|
||
|
||
pub async fn provide_user_response(
|
||
&self,
|
||
user_id: &str,
|
||
bot_id: &str,
|
||
response: String,
|
||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||
let key = format!("{}:{}", user_id, bot_id);
|
||
let mut waiting = self.waiting_responses.lock().await;
|
||
if let Some(tx) = waiting.get_mut(&key) {
|
||
let _ = tx.send(response).await;
|
||
waiting.remove(&key);
|
||
}
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl Default for ToolManager {
|
||
fn default() -> Self {
|
||
Self::new()
|
||
}
|
||
}
|
||
|
||
/// Stateless unit struct for the tool API; it carries no data.
///
/// `Default` is derived so `ToolApi::new()` has its conventional counterpart
/// (clippy `new_without_default`); `Debug`/`Clone`/`Copy` are free for a unit struct.
#[derive(Debug, Default, Clone, Copy)]
pub struct ToolApi;

impl ToolApi {
    /// Creates a `ToolApi`; equivalent to `ToolApi::default()`.
    pub fn new() -> Self {
        Self
    }
}
|
||
|
||
.
|
||
└── src
|
||
├── auth
|
||
│ └── mod.rs
|
||
├── automation
|
||
│ └── mod.rs
|
||
├── basic
|
||
│ ├── keywords
|
||
│ │ ├── create_draft.rs
|
||
│ │ ├── create_site.rs
|
||
│ │ ├── find.rs
|
||
│ │ ├── first.rs
|
||
│ │ ├── format.rs
|
||
│ │ ├── for_next.rs
|
||
│ │ ├── get.rs
|
||
│ │ ├── get_website.rs
|
||
│ │ ├── last.rs
|
||
│ │ ├── llm_keyword.rs
|
||
│ │ ├── mod.rs
|
||
│ │ ├── on.rs
|
||
│ │ ├── print.rs
|
||
│ │ ├── set.rs
|
||
│ │ ├── set_schedule.rs
|
||
│ │ └── wait.rs
|
||
│ └── mod.rs
|
||
├── bot
|
||
│ └── mod.rs
|
||
├── channels
|
||
│ └── mod.rs
|
||
├── chart
|
||
│ └── mod.rs
|
||
├── config
|
||
│ └── mod.rs
|
||
├── context
|
||
│ └── mod.rs
|
||
├── email
|
||
│ └── mod.rs
|
||
├── file
|
||
│ └── mod.rs
|
||
├── llm
|
||
│ ├── llm_generic.rs
|
||
│ ├── llm_local.rs
|
||
│ ├── llm_provider.rs
|
||
│ ├── llm.rs
|
||
│ └── mod.rs
|
||
├── main.rs
|
||
├── org
|
||
│ └── mod.rs
|
||
├── session
|
||
│ └── mod.rs
|
||
├── shared
|
||
│ ├── models.rs
|
||
│ ├── mod.rs
|
||
│ ├── state.rs
|
||
│ └── utils.rs
|
||
├── tests
|
||
│ ├── integration_email_list.rs
|
||
│ ├── integration_file_list_test.rs
|
||
│ └── integration_file_upload_test.rs
|
||
├── tools
|
||
│ └── mod.rs
|
||
├── web_automation
|
||
│ └── mod.rs
|
||
└── whatsapp
|
||
└── mod.rs
|
||
|
||
21 directories, 44 files
|