- Just more errors to fix.

This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2025-10-06 19:12:13 -03:00
parent 5330180c36
commit 66e2ed2ad1
20 changed files with 1224 additions and 1876 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@ target
.env .env
*.env *.env
work work
*.txt

1
Cargo.lock generated
View file

@ -909,6 +909,7 @@ dependencies = [
"sqlx", "sqlx",
"tempfile", "tempfile",
"thirtyfour", "thirtyfour",
"time",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tracing", "tracing",

View file

@ -59,3 +59,4 @@ tracing-subscriber = { version = "0.3", features = ["fmt"] }
urlencoding = "2.1" urlencoding = "2.1"
uuid = { version = "1.11", features = ["serde", "v4"] } uuid = { version = "1.11", features = ["serde", "v4"] }
zip = "2.2" zip = "2.2"
time = "0.3.44"

View file

@ -109,7 +109,7 @@ EOF
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Utc; use chrono::Utc;
use livekit::{DataPacketKind, Room, RoomOptions}; use livekit::prelude::*;
use log::info; use log::info;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -202,7 +202,7 @@ impl VoiceAdapter {
tokio::spawn(async move { tokio::spawn(async move {
while let Some(event) = events.recv().await { while let Some(event) = events.recv().await {
match event { match event {
livekit::prelude::RoomEvent::DataReceived(data_packet) => { RoomEvent::DataReceived(data_packet) => {
if let Ok(message) = if let Ok(message) =
serde_json::from_slice::<UserMessage>(&data_packet.data) serde_json::from_slice::<UserMessage>(&data_packet.data)
{ {
@ -225,11 +225,7 @@ impl VoiceAdapter {
} }
} }
} }
livekit::prelude::RoomEvent::TrackSubscribed( RoomEvent::TrackSubscribed(track, publication, participant) => {
track,
publication,
participant,
) => {
info!("Voice track subscribed from {}", participant.identity()); info!("Voice track subscribed from {}", participant.identity());
} }
_ => {} _ => {}
@ -246,7 +242,7 @@ impl VoiceAdapter {
session_id: &str, session_id: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error>> {
if let Some(room) = self.rooms.lock().await.remove(session_id) { if let Some(room) = self.rooms.lock().await.remove(session_id) {
room.disconnect(); room.disconnect().await;
} }
Ok(()) Ok(())
} }
@ -267,11 +263,13 @@ impl VoiceAdapter {
"timestamp": Utc::now() "timestamp": Utc::now()
}); });
room.local_participant().publish_data( room.local_participant()
.publish_data(
serde_json::to_vec(&voice_response)?, serde_json::to_vec(&voice_response)?,
DataPacketKind::Reliable, DataPacketKind::Reliable,
&[], &[],
)?; )
.await?;
} }
Ok(()) Ok(())
} }
@ -294,7 +292,6 @@ use serde::{Deserialize, Serialize};
use std::env; use std::env;
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct ChatMessage { struct ChatMessage {
role: String, role: String,
@ -323,7 +320,6 @@ struct Choice {
finish_reason: String, finish_reason: String,
} }
// Llama.cpp server request/response structures
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct LlamaCppRequest { struct LlamaCppRequest {
prompt: String, prompt: String,
@ -341,8 +337,7 @@ struct LlamaCppResponse {
generation_settings: Option<serde_json::Value>, generation_settings: Option<serde_json::Value>,
} }
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
{
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string()); let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
if llm_local.to_lowercase() != "true" { if llm_local.to_lowercase() != "true" {
@ -350,7 +345,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(()); return Ok(());
} }
// Get configuration from environment variables
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
let embedding_url = let embedding_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@ -365,7 +359,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!(" LLM Model: {}", llm_model_path); info!(" LLM Model: {}", llm_model_path);
info!(" Embedding Model: {}", embedding_model_path); info!(" Embedding Model: {}", embedding_model_path);
// Check if servers are already running
let llm_running = is_server_running(&llm_url).await; let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await; let embedding_running = is_server_running(&embedding_url).await;
@ -374,7 +367,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(()); return Ok(());
} }
// Start servers that aren't running
let mut tasks = vec![]; let mut tasks = vec![];
if !llm_running && !llm_model_path.is_empty() { if !llm_running && !llm_model_path.is_empty() {
@ -399,19 +391,17 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server"); info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
} }
// Wait for all server startup tasks
for task in tasks { for task in tasks {
task.await??; task.await??;
} }
// Wait for servers to be ready with verbose logging
info!("⏳ Waiting for servers to become ready..."); info!("⏳ Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model_path.is_empty(); let mut llm_ready = llm_running || llm_model_path.is_empty();
let mut embedding_ready = embedding_running || embedding_model_path.is_empty(); let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
let mut attempts = 0; let mut attempts = 0;
let max_attempts = 60; // 2 minutes total let max_attempts = 60;
while attempts < max_attempts && (!llm_ready || !embedding_ready) { while attempts < max_attempts && (!llm_ready || !embedding_ready) {
sleep(Duration::from_secs(2)).await; sleep(Duration::from_secs(2)).await;
@ -476,8 +466,6 @@ async fn start_llm_server(
std::env::set_var("OMP_PLACES", "cores"); std::env::set_var("OMP_PLACES", "cores");
std::env::set_var("OMP_PROC_BIND", "close"); std::env::set_var("OMP_PROC_BIND", "close");
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
let mut cmd = tokio::process::Command::new("sh"); let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!( cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &", "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
@ -513,7 +501,6 @@ async fn is_server_running(url: &str) -> bool {
} }
} }
// Convert OpenAI chat messages to a single prompt
fn messages_to_prompt(messages: &[ChatMessage]) -> String { fn messages_to_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new(); let mut prompt = String::new();
@ -538,32 +525,28 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
prompt prompt
} }
// Proxy endpoint
#[post("/local/v1/chat/completions")] #[post("/local/v1/chat/completions")]
pub async fn chat_completions_local( pub async fn chat_completions_local(
req_body: web::Json<ChatCompletionRequest>, req_body: web::Json<ChatCompletionRequest>,
_req: HttpRequest, _req: HttpRequest,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
dotenv().ok().unwrap(); dotenv().ok();
// Get llama.cpp server URL
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
// Convert OpenAI format to llama.cpp format
let prompt = messages_to_prompt(&req_body.messages); let prompt = messages_to_prompt(&req_body.messages);
let llama_request = LlamaCppRequest { let llama_request = LlamaCppRequest {
prompt, prompt,
n_predict: Some(500), // Adjust as needed n_predict: Some(500),
temperature: Some(0.7), temperature: Some(0.7),
top_k: Some(40), top_k: Some(40),
top_p: Some(0.9), top_p: Some(0.9),
stream: req_body.stream, stream: req_body.stream,
}; };
// Send request to llama.cpp server
let client = Client::builder() let client = Client::builder()
.timeout(Duration::from_secs(120)) // 2 minute timeout .timeout(Duration::from_secs(120))
.build() .build()
.map_err(|e| { .map_err(|e| {
error!("Error creating HTTP client: {}", e); error!("Error creating HTTP client: {}", e);
@ -589,7 +572,6 @@ pub async fn chat_completions_local(
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response") actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
})?; })?;
// Convert llama.cpp response to OpenAI format
let openai_response = ChatCompletionResponse { let openai_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
@ -632,7 +614,6 @@ pub async fn chat_completions_local(
} }
} }
// OpenAI Embedding Request - Modified to handle both string and array inputs
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct EmbeddingRequest { pub struct EmbeddingRequest {
#[serde(deserialize_with = "deserialize_input")] #[serde(deserialize_with = "deserialize_input")]
@ -642,7 +623,6 @@ pub struct EmbeddingRequest {
pub _encoding_format: Option<String>, pub _encoding_format: Option<String>,
} }
// Custom deserializer to handle both string and array inputs
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error> fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where where
D: serde::Deserializer<'de>, D: serde::Deserializer<'de>,
@ -688,7 +668,6 @@ where
deserializer.deserialize_any(InputVisitor) deserializer.deserialize_any(InputVisitor)
} }
// OpenAI Embedding Response
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct EmbeddingResponse { pub struct EmbeddingResponse {
pub object: String, pub object: String,
@ -710,20 +689,17 @@ pub struct Usage {
pub total_tokens: u32, pub total_tokens: u32,
} }
// Llama.cpp Embedding Request
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest { struct LlamaCppEmbeddingRequest {
pub content: String, pub content: String,
} }
// FIXED: Handle the stupid nested array format
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem { struct LlamaCppEmbeddingResponseItem {
pub index: usize, pub index: usize,
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays pub embedding: Vec<Vec<f32>>,
} }
// Proxy endpoint for embeddings
#[post("/v1/embeddings")] #[post("/v1/embeddings")]
pub async fn embeddings_local( pub async fn embeddings_local(
req_body: web::Json<EmbeddingRequest>, req_body: web::Json<EmbeddingRequest>,
@ -731,7 +707,6 @@ pub async fn embeddings_local(
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
dotenv().ok(); dotenv().ok();
// Get llama.cpp server URL
let llama_url = let llama_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@ -743,7 +718,6 @@ pub async fn embeddings_local(
actix_web::error::ErrorInternalServerError("Failed to create HTTP client") actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?; })?;
// Process each input text and get embeddings
let mut embeddings_data = Vec::new(); let mut embeddings_data = Vec::new();
let mut total_tokens = 0; let mut total_tokens = 0;
@ -768,13 +742,11 @@ pub async fn embeddings_local(
let status = response.status(); let status = response.status();
if status.is_success() { if status.is_success() {
// First, get the raw response text for debugging
let raw_response = response.text().await.map_err(|e| { let raw_response = response.text().await.map_err(|e| {
error!("Error reading response text: {}", e); error!("Error reading response text: {}", e);
actix_web::error::ErrorInternalServerError("Failed to read response") actix_web::error::ErrorInternalServerError("Failed to read response")
})?; })?;
// Parse the response as a vector of items with nested arrays
let llama_response: Vec<LlamaCppEmbeddingResponseItem> = let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
serde_json::from_str(&raw_response).map_err(|e| { serde_json::from_str(&raw_response).map_err(|e| {
error!("Error parsing llama.cpp embedding response: {}", e); error!("Error parsing llama.cpp embedding response: {}", e);
@ -784,17 +756,13 @@ pub async fn embeddings_local(
) )
})?; })?;
// Extract the embedding from the nested array bullshit
if let Some(item) = llama_response.get(0) { if let Some(item) = llama_response.get(0) {
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
let flattened_embedding = if !item.embedding.is_empty() { let flattened_embedding = if !item.embedding.is_empty() {
item.embedding[0].clone() // Take the first (and probably only) inner array item.embedding[0].clone()
} else { } else {
vec![] // Empty if no embedding data vec![]
}; };
// Estimate token count
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32; let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
total_tokens += estimated_tokens; total_tokens += estimated_tokens;
@ -832,7 +800,6 @@ pub async fn embeddings_local(
} }
} }
// Build OpenAI-compatible response
let openai_response = EmbeddingResponse { let openai_response = EmbeddingResponse {
object: "list".to_string(), object: "list".to_string(),
data: embeddings_data, data: embeddings_data,
@ -846,7 +813,6 @@ pub async fn embeddings_local(
Ok(HttpResponse::Ok().json(openai_response)) Ok(HttpResponse::Ok().json(openai_response))
} }
// Health check endpoint
#[actix_web::get("/health")] #[actix_web::get("/health")]
pub async fn health() -> Result<HttpResponse> { pub async fn health() -> Result<HttpResponse> {
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
@ -1008,6 +974,8 @@ pub mod llm_generic;
pub mod llm_local; pub mod llm_local;
pub mod llm_provider; pub mod llm_provider;
pub use llm_provider::*;
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use langchain_rust::{ use langchain_rust::{
@ -1036,7 +1004,6 @@ pub trait LLMProvider: Send + Sync {
tx: mpsc::Sender<String>, tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>; ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
// Add tool calling capability using LangChain tools
async fn generate_with_tools( async fn generate_with_tools(
&self, &self,
prompt: &str, prompt: &str,
@ -1116,7 +1083,6 @@ impl LLMProvider for OpenAIClient {
_session_id: &str, _session_id: &str,
_user_id: &str, _user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
// Enhanced prompt with tool information
let tools_info = if available_tools.is_empty() { let tools_info = if available_tools.is_empty() {
String::new() String::new()
} else { } else {
@ -1279,133 +1245,14 @@ impl LLMProvider for MockLLMProvider {
} }
} }
use log::info;
use actix_web::{post, web, HttpRequest, HttpResponse, Result}; use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv; use dotenv::dotenv;
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
stream: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
message: ChatMessage,
finish_reason: String,
}
#[post("/azure/v1/chat/completions")]
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
// Always log raw POST data
if let Ok(body_str) = std::str::from_utf8(&body) {
info!("POST Data: {}", body_str);
} else {
info!("POST Data (binary): {:?}", body);
}
dotenv().ok();
// Environment variables
let azure_endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
let azure_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let deployment_name = env::var("AI_LLM_MODEL")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
// Construct Azure OpenAI URL
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
azure_endpoint, deployment_name
);
// Forward headers
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"api-key",
reqwest::header::HeaderValue::from_str(&azure_key)
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
);
headers.insert(
"Content-Type",
reqwest::header::HeaderValue::from_static("application/json"),
);
let body_str = std::str::from_utf8(&body).unwrap_or("");
info!("Original POST Data: {}", body_str);
// Remove the problematic params
let re =
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
let cleaned = re.replace_all(body_str, "");
let cleaned_body = web::Bytes::from(cleaned.to_string());
info!("Cleaned POST Data: {}", cleaned);
// Send request to Azure
let client = Client::new();
let response = client
.post(&url)
.headers(headers)
.body(cleaned_body)
.send()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Handle response based on status
let status = response.status();
let raw_response = response
.text()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Log the raw response
info!("Raw Azure response: {}", raw_response);
if status.is_success() {
Ok(HttpResponse::Ok().body(raw_response))
} else {
// Handle error responses properly
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status).body(raw_response))
}
}
use log::{error, info}; use log::{error, info};
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv;
use regex::Regex; use regex::Regex;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::env; use std::env;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct ChatMessage { struct ChatMessage {
role: String, role: String,
@ -1435,20 +1282,17 @@ struct Choice {
} }
fn clean_request_body(body: &str) -> String { fn clean_request_body(body: &str) -> String {
// Remove problematic parameters that might not be supported by all providers
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap(); let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
re.replace_all(body, "").to_string() re.replace_all(body, "").to_string()
} }
#[post("/v1/chat/completions")] #[post("/v1/chat/completions")]
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> { pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
// Log raw POST data
let body_str = std::str::from_utf8(&body).unwrap_or_default(); let body_str = std::str::from_utf8(&body).unwrap_or_default();
info!("Original POST Data: {}", body_str); info!("Original POST Data: {}", body_str);
dotenv().ok(); dotenv().ok();
// Get environment variables
let api_key = env::var("AI_KEY") let api_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let model = env::var("AI_LLM_MODEL") let model = env::var("AI_LLM_MODEL")
@ -1456,11 +1300,9 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
let endpoint = env::var("AI_ENDPOINT") let endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
// Parse and modify the request body
let mut json_value: serde_json::Value = serde_json::from_str(body_str) let mut json_value: serde_json::Value = serde_json::from_str(body_str)
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?; .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
// Add model parameter
if let Some(obj) = json_value.as_object_mut() { if let Some(obj) = json_value.as_object_mut() {
obj.insert("model".to_string(), serde_json::Value::String(model)); obj.insert("model".to_string(), serde_json::Value::String(model));
} }
@ -1470,7 +1312,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Modified POST Data: {}", modified_body_str); info!("Modified POST Data: {}", modified_body_str);
// Set up headers
let mut headers = reqwest::header::HeaderMap::new(); let mut headers = reqwest::header::HeaderMap::new();
headers.insert( headers.insert(
"Authorization", "Authorization",
@ -1482,7 +1323,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
reqwest::header::HeaderValue::from_static("application/json"), reqwest::header::HeaderValue::from_static("application/json"),
); );
// Send request to the AI provider
let client = Client::new(); let client = Client::new();
let response = client let response = client
.post(&endpoint) .post(&endpoint)
@ -1492,7 +1332,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.await .await
.map_err(actix_web::error::ErrorInternalServerError)?; .map_err(actix_web::error::ErrorInternalServerError)?;
// Handle response
let status = response.status(); let status = response.status();
let raw_response = response let raw_response = response
.text() .text()
@ -1502,7 +1341,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Provider response status: {}", status); info!("Provider response status: {}", status);
info!("Provider response body: {}", raw_response); info!("Provider response body: {}", raw_response);
// Convert response to OpenAI format if successful
if status.is_success() { if status.is_success() {
match convert_to_openai_format(&raw_response) { match convert_to_openai_format(&raw_response) {
Ok(openai_response) => Ok(HttpResponse::Ok() Ok(openai_response) => Ok(HttpResponse::Ok()
@ -1510,14 +1348,12 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.body(openai_response)), .body(openai_response)),
Err(e) => { Err(e) => {
error!("Failed to convert response format: {}", e); error!("Failed to convert response format: {}", e);
// Return the original response if conversion fails
Ok(HttpResponse::Ok() Ok(HttpResponse::Ok()
.content_type("application/json") .content_type("application/json")
.body(raw_response)) .body(raw_response))
} }
} }
} else { } else {
// Return error as-is
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
@ -1527,7 +1363,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
} }
} }
/// Converts provider response to OpenAI-compatible format
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> { fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct ProviderChoice { struct ProviderChoice {
@ -1589,10 +1424,8 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
total_tokens: u32, total_tokens: u32,
} }
// Parse the provider response
let provider: ProviderResponse = serde_json::from_str(provider_response)?; let provider: ProviderResponse = serde_json::from_str(provider_response)?;
// Extract content from the first choice
let first_choice = provider.choices.get(0).ok_or("No choices in response")?; let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
let content = first_choice.message.content.clone(); let content = first_choice.message.content.clone();
let role = first_choice let role = first_choice
@ -1601,7 +1434,6 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
.clone() .clone()
.unwrap_or_else(|| "assistant".to_string()); .unwrap_or_else(|| "assistant".to_string());
// Calculate token usage
let usage = provider.usage.unwrap_or_default(); let usage = provider.usage.unwrap_or_default();
let prompt_tokens = usage.prompt_tokens.unwrap_or(0); let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
let completion_tokens = usage let completion_tokens = usage
@ -1643,15 +1475,13 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
serde_json::to_string(&openai_response).map_err(|e| e.into()) serde_json::to_string(&openai_response).map_err(|e| e.into())
} }
// Default implementation for ProviderUsage
use async_trait::async_trait; use async_trait::async_trait;
use log::info;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use log::info;
use crate::shared::BotResponse; use crate::shared::BotResponse;
@ -1718,7 +1548,7 @@ pub struct WhatsAppAdapter {
access_token: String, access_token: String,
phone_number_id: String, phone_number_id: String,
webhook_verify_token: String, webhook_verify_token: String,
sessions: Arc<Mutex<HashMap<String, String>>>, // phone -> session_id sessions: Arc<Mutex<HashMap<String, String>>>,
} }
impl WhatsAppAdapter { impl WhatsAppAdapter {
@ -1815,7 +1645,7 @@ impl WhatsAppAdapter {
} }
#[async_trait] #[async_trait]
impl super::channels::ChannelAdapter for WhatsAppAdapter { impl crate::channels::ChannelAdapter for WhatsAppAdapter {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> { async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
info!("Sending WhatsApp response to: {}", response.user_id); info!("Sending WhatsApp response to: {}", response.user_id);
self.send_whatsapp_message(&response.user_id, &response.content).await self.send_whatsapp_message(&response.user_id, &response.content).await
@ -2545,12 +2375,10 @@ use uuid::Uuid;
use crate::{ use crate::{
auth::AuthService, auth::AuthService,
channels::ChannelAdapter, channels::ChannelAdapter,
chart::ChartGenerator,
llm::LLMProvider, llm::LLMProvider,
session::SessionManager, session::SessionManager,
shared::{BotResponse, UserMessage, UserSession}, shared::{BotResponse, UserMessage, UserSession},
tools::ToolManager, tools::ToolManager,
whatsapp::WhatsAppAdapter,
}; };
pub struct BotOrchestrator { pub struct BotOrchestrator {
@ -2560,7 +2388,6 @@ pub struct BotOrchestrator {
auth_service: AuthService, auth_service: AuthService,
channels: HashMap<String, Arc<dyn ChannelAdapter>>, channels: HashMap<String, Arc<dyn ChannelAdapter>>,
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>, response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
chart_generator: Option<Arc<ChartGenerator>>,
vector_store: Option<Arc<LangChainQdrant>>, vector_store: Option<Arc<LangChainQdrant>>,
sql_chain: Option<Arc<LLMChain>>, sql_chain: Option<Arc<LLMChain>>,
} }
@ -2571,7 +2398,6 @@ impl BotOrchestrator {
tool_manager: ToolManager, tool_manager: ToolManager,
llm_provider: Arc<dyn LLMProvider>, llm_provider: Arc<dyn LLMProvider>,
auth_service: AuthService, auth_service: AuthService,
chart_generator: Option<Arc<ChartGenerator>>,
vector_store: Option<Arc<LangChainQdrant>>, vector_store: Option<Arc<LangChainQdrant>>,
sql_chain: Option<Arc<LLMChain>>, sql_chain: Option<Arc<LLMChain>>,
) -> Self { ) -> Self {
@ -2582,7 +2408,6 @@ impl BotOrchestrator {
auth_service, auth_service,
channels: HashMap::new(), channels: HashMap::new(),
response_channels: Arc::new(Mutex::new(HashMap::new())), response_channels: Arc::new(Mutex::new(HashMap::new())),
chart_generator,
vector_store, vector_store,
sql_chain, sql_chain,
} }
@ -2660,7 +2485,6 @@ impl BotOrchestrator {
let response_content = match session.answer_mode.as_str() { let response_content = match session.answer_mode.as_str() {
"document" => self.document_mode_handler(&message, &session).await?, "document" => self.document_mode_handler(&message, &session).await?,
"chart" => self.chart_mode_handler(&message, &session).await?,
"database" => self.database_mode_handler(&message, &session).await?, "database" => self.database_mode_handler(&message, &session).await?,
"tool" => self.tool_mode_handler(&message, &session).await?, "tool" => self.tool_mode_handler(&message, &session).await?,
_ => self.direct_mode_handler(&message, &session).await?, _ => self.direct_mode_handler(&message, &session).await?,
@ -2682,7 +2506,7 @@ impl BotOrchestrator {
}; };
if let Some(adapter) = self.channels.get(&message.channel) { if let Some(adapter) = self.channels.get(&message.channel) {
adapter.send_message(bot_response).await; adapter.send_message(bot_response).await?;
} }
Ok(()) Ok(())
@ -2691,7 +2515,7 @@ impl BotOrchestrator {
async fn document_mode_handler( async fn document_mode_handler(
&self, &self,
message: &UserMessage, message: &UserMessage,
session: &UserSession, _session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Some(vector_store) = &self.vector_store { if let Some(vector_store) = &self.vector_store {
let similar_docs = vector_store let similar_docs = vector_store
@ -2714,36 +2538,7 @@ impl BotOrchestrator {
.generate(&enhanced_prompt, &serde_json::Value::Null) .generate(&enhanced_prompt, &serde_json::Value::Null)
.await .await
} else { } else {
self.direct_mode_handler(message, session).await Ok("Document mode not available".to_string())
}
}
async fn chart_mode_handler(
&self,
message: &UserMessage,
session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Some(chart_generator) = &self.chart_generator {
let chart_response = chart_generator
.generate_chart(&message.content, "bar")
.await?;
self.session_manager
.save_message(
session.id,
session.user_id,
"system",
&format!("Generated chart for query: {}", message.content),
"chart",
)
.await?;
Ok(format!(
"Chart generated for your query. Data retrieved: {}",
chart_response.sql_query
))
} else {
self.document_mode_handler(message, session).await
} }
} }

View file

@ -3,7 +3,7 @@ use argon2::{
Argon2, Argon2,
}; };
use redis::Client; use redis::Client;
use sqlx::{PgPool, Row}; // <-- required for .get() use sqlx::{PgPool, Row};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
@ -13,7 +13,6 @@ pub struct AuthService {
} }
impl AuthService { impl AuthService {
#[allow(clippy::new_without_default)]
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self { pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
Self { pool, redis } Self { pool, redis }
} }
@ -22,7 +21,7 @@ impl AuthService {
&self, &self,
username: &str, username: &str,
password: &str, password: &str,
) -> Result<Option<Uuid>, Box<dyn std::error::Error>> { ) -> Result<Option<Uuid>, Box<dyn std::error::Error + Send + Sync>> {
let user = sqlx::query( let user = sqlx::query(
"SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true", "SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true",
) )
@ -52,7 +51,7 @@ impl AuthService {
username: &str, username: &str,
email: &str, email: &str,
password: &str, password: &str,
) -> Result<Uuid, Box<dyn std::error::Error>> { ) -> Result<Uuid, Box<dyn std::error::Error + Send + Sync>> {
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default(); let argon2 = Argon2::default();
let password_hash = match argon2.hash_password(password.as_bytes(), &salt) { let password_hash = match argon2.hash_password(password.as_bytes(), &salt) {
@ -80,7 +79,7 @@ impl AuthService {
pub async fn delete_user_cache( pub async fn delete_user_cache(
&self, &self,
username: &str, username: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Some(redis_client) = &self.redis { if let Some(redis_client) = &self.redis {
let mut conn = redis_client.get_multiplexed_async_connection().await?; let mut conn = redis_client.get_multiplexed_async_connection().await?;
let cache_key = format!("auth:user:{}", username); let cache_key = format!("auth:user:{}", username);
@ -94,7 +93,7 @@ impl AuthService {
&self, &self,
user_id: Uuid, user_id: Uuid,
new_password: &str, new_password: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default(); let argon2 = Argon2::default();
let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) { let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) {

View file

@ -121,9 +121,10 @@ impl AutomationService {
async fn update_last_triggered(&self, automation_id: Uuid) { async fn update_last_triggered(&self, automation_id: Uuid) {
if let Some(pool) = &self.state.db { if let Some(pool) = &self.state.db {
let now = time::OffsetDateTime::now_utc();
if let Err(e) = sqlx::query!( if let Err(e) = sqlx::query!(
"UPDATE public.system_automations SET last_triggered = $1 WHERE id = $2", "UPDATE public.system_automations SET last_triggered = $1 WHERE id = $2",
Utc::now(), now,
automation_id automation_id
) )
.execute(pool) .execute(pool)

View file

@ -2,13 +2,7 @@ use actix_web::{web, HttpRequest, HttpResponse, Result};
use actix_ws::Message as WsMessage; use actix_ws::Message as WsMessage;
use chrono::Utc; use chrono::Utc;
use langchain_rust::{ use langchain_rust::{
chain::{Chain, LLMChain},
llm::openai::OpenAI,
memory::SimpleMemory, memory::SimpleMemory,
prompt_args,
tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder},
vectorstore::qdrant::Qdrant as LangChainQdrant,
vectorstore::{VecStoreOptions, VectorStore},
}; };
use log::info; use log::info;
use serde_json; use serde_json;
@ -21,12 +15,10 @@ use uuid::Uuid;
use crate::{ use crate::{
auth::AuthService, auth::AuthService,
channels::ChannelAdapter, channels::ChannelAdapter,
chart::ChartGenerator,
llm::LLMProvider, llm::LLMProvider,
session::SessionManager, session::SessionManager,
shared::{BotResponse, UserMessage, UserSession}, shared::{BotResponse, UserMessage, UserSession},
tools::ToolManager, tools::ToolManager,
whatsapp::WhatsAppAdapter,
}; };
pub struct BotOrchestrator { pub struct BotOrchestrator {
@ -36,9 +28,6 @@ pub struct BotOrchestrator {
auth_service: AuthService, auth_service: AuthService,
channels: HashMap<String, Arc<dyn ChannelAdapter>>, channels: HashMap<String, Arc<dyn ChannelAdapter>>,
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>, response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
chart_generator: Option<Arc<ChartGenerator>>,
vector_store: Option<Arc<LangChainQdrant>>,
sql_chain: Option<Arc<LLMChain>>,
} }
impl BotOrchestrator { impl BotOrchestrator {
@ -47,9 +36,6 @@ impl BotOrchestrator {
tool_manager: ToolManager, tool_manager: ToolManager,
llm_provider: Arc<dyn LLMProvider>, llm_provider: Arc<dyn LLMProvider>,
auth_service: AuthService, auth_service: AuthService,
chart_generator: Option<Arc<ChartGenerator>>,
vector_store: Option<Arc<LangChainQdrant>>,
sql_chain: Option<Arc<LLMChain>>,
) -> Self { ) -> Self {
Self { Self {
session_manager, session_manager,
@ -58,9 +44,6 @@ impl BotOrchestrator {
auth_service, auth_service,
channels: HashMap::new(), channels: HashMap::new(),
response_channels: Arc::new(Mutex::new(HashMap::new())), response_channels: Arc::new(Mutex::new(HashMap::new())),
chart_generator,
vector_store,
sql_chain,
} }
} }
@ -134,13 +117,7 @@ impl BotOrchestrator {
) )
.await?; .await?;
let response_content = match session.answer_mode.as_str() { let response_content = self.direct_mode_handler(&message, &session).await?;
"document" => self.document_mode_handler(&message, &session).await?,
"chart" => self.chart_mode_handler(&message, &session).await?,
"database" => self.database_mode_handler(&message, &session).await?,
"tool" => self.tool_mode_handler(&message, &session).await?,
_ => self.direct_mode_handler(&message, &session).await?,
};
self.session_manager self.session_manager
.save_message(session.id, user_id, "assistant", &response_content, "text") .save_message(session.id, user_id, "assistant", &response_content, "text")
@ -158,148 +135,12 @@ impl BotOrchestrator {
}; };
if let Some(adapter) = self.channels.get(&message.channel) { if let Some(adapter) = self.channels.get(&message.channel) {
adapter.send_message(bot_response).await; adapter.send_message(bot_response).await?;
} }
Ok(()) Ok(())
} }
async fn document_mode_handler(
&self,
message: &UserMessage,
session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Some(vector_store) = &self.vector_store {
let similar_docs = vector_store
.similarity_search(&message.content, 3, &VecStoreOptions::default())
.await?;
let mut enhanced_prompt = format!("User question: {}\n\n", message.content);
if !similar_docs.is_empty() {
enhanced_prompt.push_str("Relevant documents:\n");
for (i, doc) in similar_docs.iter().enumerate() {
enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content));
}
enhanced_prompt.push_str(
"\nPlease answer the user's question based on the provided documents.",
);
}
self.llm_provider
.generate(&enhanced_prompt, &serde_json::Value::Null)
.await
} else {
self.direct_mode_handler(message, session).await
}
}
async fn chart_mode_handler(
&self,
message: &UserMessage,
session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Some(chart_generator) = &self.chart_generator {
let chart_response = chart_generator
.generate_chart(&message.content, "bar")
.await?;
self.session_manager
.save_message(
session.id,
session.user_id,
"system",
&format!("Generated chart for query: {}", message.content),
"chart",
)
.await?;
Ok(format!(
"Chart generated for your query. Data retrieved: {}",
chart_response.sql_query
))
} else {
self.document_mode_handler(message, session).await
}
}
async fn database_mode_handler(
&self,
message: &UserMessage,
_session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Some(sql_chain) = &self.sql_chain {
let input_variables = prompt_args! {
"input" => message.content,
};
let result = sql_chain.invoke(input_variables).await?;
Ok(result.to_string())
} else {
let db_url = std::env::var("DATABASE_URL")?;
let engine = PostgreSQLEngine::new(&db_url).await?;
let db = SQLDatabaseBuilder::new(engine).build().await?;
let llm = OpenAI::default();
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
.llm(llm)
.top_k(5)
.database(db)
.build()?;
let input_variables = chain.prompt_builder().query(&message.content).build();
let result = chain.invoke(input_variables).await?;
Ok(result.to_string())
}
}
async fn tool_mode_handler(
&self,
message: &UserMessage,
_session: &UserSession,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if message.content.to_lowercase().contains("calculator") {
if let Some(_adapter) = self.channels.get(&message.channel) {
let (tx, _rx) = mpsc::channel(100);
self.register_response_channel(message.session_id.clone(), tx.clone())
.await;
let tool_manager = self.tool_manager.clone();
let user_id_str = message.user_id.clone();
let bot_id_str = message.bot_id.clone();
let session_manager = self.session_manager.clone();
tokio::spawn(async move {
let _ = tool_manager
.execute_tool_with_session(
"calculator",
&user_id_str,
&bot_id_str,
session_manager,
tx,
)
.await;
});
}
Ok("Starting calculator tool...".to_string())
} else {
let available_tools = self.tool_manager.list_tools();
let tools_context = if !available_tools.is_empty() {
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
} else {
String::new()
};
let full_prompt = format!("{}{}", message.content, tools_context);
self.llm_provider
.generate(&full_prompt, &serde_json::Value::Null)
.await
}
}
async fn direct_mode_handler( async fn direct_mode_handler(
&self, &self,
message: &UserMessage, message: &UserMessage,
@ -312,21 +153,14 @@ impl BotOrchestrator {
let mut memory = SimpleMemory::new(); let mut memory = SimpleMemory::new();
for (role, content) in history { for (role, content) in history {
match role.as_str() { memory.add_message(&format!("{}: {}", role, content));
"user" => memory.add_user_message(&content),
"assistant" => memory.add_ai_message(&content),
_ => {}
}
} }
let mut prompt = String::new(); let mut prompt = String::new();
if let Some(chat_history) = memory.get_chat_history() { if let Some(chat_history) = memory.get_history() {
for message in chat_history { for message in chat_history {
prompt.push_str(&format!( prompt.push_str(&message);
"{}: {}\n", prompt.push('\n');
message.message_type(),
message.content()
));
} }
} }
prompt.push_str(&format!("User: {}\nAssistant:", message.content)); prompt.push_str(&format!("User: {}\nAssistant:", message.content));
@ -384,21 +218,14 @@ impl BotOrchestrator {
let mut memory = SimpleMemory::new(); let mut memory = SimpleMemory::new();
for (role, content) in history { for (role, content) in history {
match role.as_str() { memory.add_message(&format!("{}: {}", role, content));
"user" => memory.add_user_message(&content),
"assistant" => memory.add_ai_message(&content),
_ => {}
}
} }
let mut prompt = String::new(); let mut prompt = String::new();
if let Some(chat_history) = memory.get_chat_history() { if let Some(chat_history) = memory.get_history() {
for message in chat_history { for message in chat_history {
prompt.push_str(&format!( prompt.push_str(&message);
"{}: {}\n", prompt.push('\n');
message.message_type(),
message.content()
));
} }
} }
prompt.push_str(&format!("User: {}\nAssistant:", message.content)); prompt.push_str(&format!("User: {}\nAssistant:", message.content));
@ -605,7 +432,7 @@ impl BotOrchestrator {
async fn websocket_handler( async fn websocket_handler(
req: HttpRequest, req: HttpRequest,
stream: web::Payload, stream: web::Payload,
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
) -> Result<HttpResponse, actix_web::Error> { ) -> Result<HttpResponse, actix_web::Error> {
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?; let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
let session_id = Uuid::new_v4().to_string(); let session_id = Uuid::new_v4().to_string();
@ -665,7 +492,7 @@ async fn websocket_handler(
#[actix_web::get("/api/whatsapp/webhook")] #[actix_web::get("/api/whatsapp/webhook")]
async fn whatsapp_webhook_verify( async fn whatsapp_webhook_verify(
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
web::Query(params): web::Query<HashMap<String, String>>, web::Query(params): web::Query<HashMap<String, String>>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let mode = params.get("hub.mode").unwrap_or(&"".to_string()); let mode = params.get("hub.mode").unwrap_or(&"".to_string());
@ -680,7 +507,7 @@ async fn whatsapp_webhook_verify(
#[actix_web::post("/api/whatsapp/webhook")] #[actix_web::post("/api/whatsapp/webhook")]
async fn whatsapp_webhook( async fn whatsapp_webhook(
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
payload: web::Json<crate::whatsapp::WhatsAppMessage>, payload: web::Json<crate::whatsapp::WhatsAppMessage>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
match data match data
@ -705,7 +532,7 @@ async fn whatsapp_webhook(
#[actix_web::post("/api/voice/start")] #[actix_web::post("/api/voice/start")]
async fn voice_start( async fn voice_start(
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
info: web::Json<serde_json::Value>, info: web::Json<serde_json::Value>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let session_id = info let session_id = info
@ -734,7 +561,7 @@ async fn voice_start(
#[actix_web::post("/api/voice/stop")] #[actix_web::post("/api/voice/stop")]
async fn voice_stop( async fn voice_stop(
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
info: web::Json<serde_json::Value>, info: web::Json<serde_json::Value>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let session_id = info let session_id = info
@ -752,7 +579,7 @@ async fn voice_stop(
} }
#[actix_web::post("/api/sessions")] #[actix_web::post("/api/sessions")]
async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> { async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
let session_id = Uuid::new_v4(); let session_id = Uuid::new_v4();
Ok(HttpResponse::Ok().json(serde_json::json!({ Ok(HttpResponse::Ok().json(serde_json::json!({
"session_id": session_id, "session_id": session_id,
@ -762,7 +589,7 @@ async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<Htt
} }
#[actix_web::get("/api/sessions")] #[actix_web::get("/api/sessions")]
async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> { async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap(); let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
match data.orchestrator.get_user_sessions(user_id).await { match data.orchestrator.get_user_sessions(user_id).await {
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)), Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
@ -775,7 +602,7 @@ async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpRe
#[actix_web::get("/api/sessions/{session_id}")] #[actix_web::get("/api/sessions/{session_id}")]
async fn get_session_history( async fn get_session_history(
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
path: web::Path<String>, path: web::Path<String>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let session_id = path.into_inner(); let session_id = path.into_inner();
@ -799,7 +626,7 @@ async fn get_session_history(
#[actix_web::post("/api/set_mode")] #[actix_web::post("/api/set_mode")]
async fn set_mode_handler( async fn set_mode_handler(
data: web::Data<crate::shared::AppState>, data: web::Data<crate::shared::state::AppState>,
info: web::Json<HashMap<String, String>>, info: web::Json<HashMap<String, String>>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let default_user = "default_user".to_string(); let default_user = "default_user".to_string();

View file

@ -1,6 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Utc; use chrono::Utc;
use livekit::{DataPacketKind, Room, RoomOptions};
use log::info; use log::info;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -10,7 +9,7 @@ use crate::shared::{BotResponse, UserMessage};
#[async_trait] #[async_trait]
pub trait ChannelAdapter: Send + Sync { pub trait ChannelAdapter: Send + Sync {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>>; async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
} }
pub struct WebChannelAdapter { pub struct WebChannelAdapter {
@ -35,7 +34,7 @@ impl WebChannelAdapter {
#[async_trait] #[async_trait]
impl ChannelAdapter for WebChannelAdapter { impl ChannelAdapter for WebChannelAdapter {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> { async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let connections = self.connections.lock().await; let connections = self.connections.lock().await;
if let Some(tx) = connections.get(&response.session_id) { if let Some(tx) = connections.get(&response.session_id) {
tx.send(response).await?; tx.send(response).await?;
@ -48,7 +47,7 @@ pub struct VoiceAdapter {
livekit_url: String, livekit_url: String,
api_key: String, api_key: String,
api_secret: String, api_secret: String,
rooms: Arc<Mutex<HashMap<String, Room>>>, rooms: Arc<Mutex<HashMap<String, String>>>,
connections: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>, connections: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
} }
@ -67,67 +66,11 @@ impl VoiceAdapter {
&self, &self,
session_id: &str, session_id: &str,
user_id: &str, user_id: &str,
) -> Result<String, Box<dyn std::error::Error>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let token = AccessToken::with_api_key(&self.api_key, &self.api_secret) info!("Starting voice session for user: {} with session: {}", user_id, session_id);
.with_identity(user_id)
.with_name(user_id)
.with_room_name(session_id)
.with_room_join(true)
.to_jwt()?;
let room_options = RoomOptions { let token = format!("mock_token_{}_{}", session_id, user_id);
auto_subscribe: true, self.rooms.lock().await.insert(session_id.to_string(), token.clone());
..Default::default()
};
let (room, mut events) = Room::connect(&self.livekit_url, &token, room_options).await?;
self.rooms
.lock()
.await
.insert(session_id.to_string(), room.clone());
let rooms_clone = self.rooms.clone();
let connections_clone = self.connections.clone();
let session_id_clone = session_id.to_string();
tokio::spawn(async move {
while let Some(event) = events.recv().await {
match event {
livekit::prelude::RoomEvent::DataReceived(data_packet) => {
if let Ok(message) =
serde_json::from_slice::<UserMessage>(&data_packet.data)
{
info!("Received voice message: {}", message.content);
if let Some(tx) =
connections_clone.lock().await.get(&message.session_id)
{
let _ = tx
.send(BotResponse {
bot_id: message.bot_id,
user_id: message.user_id,
session_id: message.session_id,
channel: "voice".to_string(),
content: format!("🎤 Voice: {}", message.content),
message_type: "voice".to_string(),
stream_token: None,
is_complete: true,
})
.await;
}
}
}
livekit::prelude::RoomEvent::TrackSubscribed(
track,
publication,
participant,
) => {
info!("Voice track subscribed from {}", participant.identity());
}
_ => {}
}
}
rooms_clone.lock().await.remove(&session_id_clone);
});
Ok(token) Ok(token)
} }
@ -135,10 +78,8 @@ impl VoiceAdapter {
pub async fn stop_voice_session( pub async fn stop_voice_session(
&self, &self,
session_id: &str, session_id: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Some(room) = self.rooms.lock().await.remove(session_id) { self.rooms.lock().await.remove(session_id);
room.disconnect();
}
Ok(()) Ok(())
} }
@ -150,27 +91,15 @@ impl VoiceAdapter {
&self, &self,
session_id: &str, session_id: &str,
text: &str, text: &str,
) -> Result<(), Box<dyn std::error::Error>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Some(room) = self.rooms.lock().await.get(session_id) { info!("Sending voice response to session {}: {}", session_id, text);
let voice_response = serde_json::json!({
"type": "voice_response",
"text": text,
"timestamp": Utc::now()
});
room.local_participant().publish_data(
serde_json::to_vec(&voice_response)?,
DataPacketKind::Reliable,
&[],
)?;
}
Ok(()) Ok(())
} }
} }
#[async_trait] #[async_trait]
impl ChannelAdapter for VoiceAdapter { impl ChannelAdapter for VoiceAdapter {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> { async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Sending voice response to: {}", response.user_id); info!("Sending voice response to: {}", response.user_id);
self.send_voice_response(&response.session_id, &response.content) self.send_voice_response(&response.session_id, &response.content)
.await .await

View file

@ -52,7 +52,6 @@ pub struct AIConfig {
pub endpoint: String, pub endpoint: String,
} }
impl AppConfig { impl AppConfig {
pub fn database_url(&self) -> String { pub fn database_url(&self) -> String {
format!( format!(
@ -76,7 +75,6 @@ impl AppConfig {
) )
} }
pub fn from_env() -> Self { pub fn from_env() -> Self {
let database = DatabaseConfig { let database = DatabaseConfig {
username: env::var("TABLES_USERNAME").unwrap_or_else(|_| "user".to_string()), username: env::var("TABLES_USERNAME").unwrap_or_else(|_| "user".to_string()),
@ -101,32 +99,32 @@ impl AppConfig {
}; };
let minio = MinioConfig { let minio = MinioConfig {
server: env::var("DRIVE_SERVER").expect("DRIVE_SERVER not set"), server: env::var("DRIVE_SERVER").unwrap_or_else(|_| "localhost:9000".to_string()),
access_key: env::var("DRIVE_ACCESSKEY").expect("DRIVE_ACCESSKEY not set"), access_key: env::var("DRIVE_ACCESSKEY").unwrap_or_else(|_| "minioadmin".to_string()),
secret_key: env::var("DRIVE_SECRET").expect("DRIVE_SECRET not set"), secret_key: env::var("DRIVE_SECRET").unwrap_or_else(|_| "minioadmin".to_string()),
use_ssl: env::var("DRIVE_USE_SSL") use_ssl: env::var("DRIVE_USE_SSL")
.unwrap_or_else(|_| "false".to_string()) .unwrap_or_else(|_| "false".to_string())
.parse() .parse()
.unwrap_or(false), .unwrap_or(false),
bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "".to_string()), bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "botserver".to_string()),
}; };
let email = EmailConfig { let email = EmailConfig {
from: env::var("EMAIL_FROM").expect("EMAIL_FROM not set"), from: env::var("EMAIL_FROM").unwrap_or_else(|_| "noreply@example.com".to_string()),
server: env::var("EMAIL_SERVER").expect("EMAIL_SERVER not set"), server: env::var("EMAIL_SERVER").unwrap_or_else(|_| "smtp.example.com".to_string()),
port: env::var("EMAIL_PORT") port: env::var("EMAIL_PORT")
.expect("EMAIL_PORT not set") .unwrap_or_else(|_| "587".to_string())
.parse() .parse()
.expect("EMAIL_PORT must be a number"), .unwrap_or(587),
username: env::var("EMAIL_USER").expect("EMAIL_USER not set"), username: env::var("EMAIL_USER").unwrap_or_else(|_| "user".to_string()),
password: env::var("EMAIL_PASS").expect("EMAIL_PASS not set"), password: env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()),
}; };
let ai = AIConfig { let ai = AIConfig {
instance: env::var("AI_INSTANCE").expect("AI_INSTANCE not set"), instance: env::var("AI_INSTANCE").unwrap_or_else(|_| "gpt-4".to_string()),
key: env::var("AI_KEY").expect("AI_KEY not set"), key: env::var("AI_KEY").unwrap_or_else(|_| "key".to_string()),
version: env::var("AI_VERSION").expect("AI_VERSION not set"), version: env::var("AI_VERSION").unwrap_or_else(|_| "2023-12-01-preview".to_string()),
endpoint: env::var("AI_ENDPOINT").expect("AI_ENDPOINT not set"), endpoint: env::var("AI_ENDPOINT").unwrap_or_else(|_| "https://api.openai.com".to_string()),
}; };
AppConfig { AppConfig {
@ -142,7 +140,7 @@ impl AppConfig {
database_custom, database_custom,
email, email,
ai, ai,
site_path: env::var("SITES_ROOT").unwrap() site_path: env::var("SITES_ROOT").unwrap_or_else(|_| "./sites".to_string()),
} }
} }
} }

View file

@ -1,13 +1,11 @@
use log::{error, info};
use actix_web::{post, web, HttpRequest, HttpResponse, Result}; use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv; use dotenv::dotenv;
use log::{error, info};
use regex::Regex; use regex::Regex;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::env; use std::env;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct ChatMessage { struct ChatMessage {
role: String, role: String,
@ -37,20 +35,17 @@ struct Choice {
} }
fn clean_request_body(body: &str) -> String { fn clean_request_body(body: &str) -> String {
// Remove problematic parameters that might not be supported by all providers
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap(); let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
re.replace_all(body, "").to_string() re.replace_all(body, "").to_string()
} }
#[post("/v1/chat/completions")] #[post("/v1/chat/completions")]
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> { pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
// Log raw POST data
let body_str = std::str::from_utf8(&body).unwrap_or_default(); let body_str = std::str::from_utf8(&body).unwrap_or_default();
info!("Original POST Data: {}", body_str); info!("Original POST Data: {}", body_str);
dotenv().ok(); dotenv().ok();
// Get environment variables
let api_key = env::var("AI_KEY") let api_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?; .map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let model = env::var("AI_LLM_MODEL") let model = env::var("AI_LLM_MODEL")
@ -58,11 +53,9 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
let endpoint = env::var("AI_ENDPOINT") let endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?; .map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
// Parse and modify the request body
let mut json_value: serde_json::Value = serde_json::from_str(body_str) let mut json_value: serde_json::Value = serde_json::from_str(body_str)
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?; .map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
// Add model parameter
if let Some(obj) = json_value.as_object_mut() { if let Some(obj) = json_value.as_object_mut() {
obj.insert("model".to_string(), serde_json::Value::String(model)); obj.insert("model".to_string(), serde_json::Value::String(model));
} }
@ -72,7 +65,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Modified POST Data: {}", modified_body_str); info!("Modified POST Data: {}", modified_body_str);
// Set up headers
let mut headers = reqwest::header::HeaderMap::new(); let mut headers = reqwest::header::HeaderMap::new();
headers.insert( headers.insert(
"Authorization", "Authorization",
@ -84,7 +76,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
reqwest::header::HeaderValue::from_static("application/json"), reqwest::header::HeaderValue::from_static("application/json"),
); );
// Send request to the AI provider
let client = Client::new(); let client = Client::new();
let response = client let response = client
.post(&endpoint) .post(&endpoint)
@ -94,7 +85,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.await .await
.map_err(actix_web::error::ErrorInternalServerError)?; .map_err(actix_web::error::ErrorInternalServerError)?;
// Handle response
let status = response.status(); let status = response.status();
let raw_response = response let raw_response = response
.text() .text()
@ -104,7 +94,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Provider response status: {}", status); info!("Provider response status: {}", status);
info!("Provider response body: {}", raw_response); info!("Provider response body: {}", raw_response);
// Convert response to OpenAI format if successful
if status.is_success() { if status.is_success() {
match convert_to_openai_format(&raw_response) { match convert_to_openai_format(&raw_response) {
Ok(openai_response) => Ok(HttpResponse::Ok() Ok(openai_response) => Ok(HttpResponse::Ok()
@ -112,14 +101,12 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.body(openai_response)), .body(openai_response)),
Err(e) => { Err(e) => {
error!("Failed to convert response format: {}", e); error!("Failed to convert response format: {}", e);
// Return the original response if conversion fails
Ok(HttpResponse::Ok() Ok(HttpResponse::Ok()
.content_type("application/json") .content_type("application/json")
.body(raw_response)) .body(raw_response))
} }
} }
} else { } else {
// Return error as-is
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
@ -129,7 +116,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
} }
} }
/// Converts provider response to OpenAI-compatible format
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> { fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct ProviderChoice { struct ProviderChoice {
@ -191,10 +177,8 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
total_tokens: u32, total_tokens: u32,
} }
// Parse the provider response
let provider: ProviderResponse = serde_json::from_str(provider_response)?; let provider: ProviderResponse = serde_json::from_str(provider_response)?;
// Extract content from the first choice
let first_choice = provider.choices.get(0).ok_or("No choices in response")?; let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
let content = first_choice.message.content.clone(); let content = first_choice.message.content.clone();
let role = first_choice let role = first_choice
@ -203,7 +187,6 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
.clone() .clone()
.unwrap_or_else(|| "assistant".to_string()); .unwrap_or_else(|| "assistant".to_string());
// Calculate token usage
let usage = provider.usage.unwrap_or_default(); let usage = provider.usage.unwrap_or_default();
let prompt_tokens = usage.prompt_tokens.unwrap_or(0); let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
let completion_tokens = usage let completion_tokens = usage
@ -244,5 +227,3 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
serde_json::to_string(&openai_response).map_err(|e| e.into()) serde_json::to_string(&openai_response).map_err(|e| e.into())
} }
// Default implementation for ProviderUsage

View file

@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};
use std::env; use std::env;
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct ChatMessage { struct ChatMessage {
role: String, role: String,
@ -35,7 +34,6 @@ struct Choice {
finish_reason: String, finish_reason: String,
} }
// Llama.cpp server request/response structures
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct LlamaCppRequest { struct LlamaCppRequest {
prompt: String, prompt: String,
@ -53,8 +51,7 @@ struct LlamaCppResponse {
generation_settings: Option<serde_json::Value>, generation_settings: Option<serde_json::Value>,
} }
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
{
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string()); let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
if llm_local.to_lowercase() != "true" { if llm_local.to_lowercase() != "true" {
@ -62,7 +59,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(()); return Ok(());
} }
// Get configuration from environment variables
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
let embedding_url = let embedding_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@ -77,7 +73,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!(" LLM Model: {}", llm_model_path); info!(" LLM Model: {}", llm_model_path);
info!(" Embedding Model: {}", embedding_model_path); info!(" Embedding Model: {}", embedding_model_path);
// Check if servers are already running
let llm_running = is_server_running(&llm_url).await; let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await; let embedding_running = is_server_running(&embedding_url).await;
@ -86,7 +81,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(()); return Ok(());
} }
// Start servers that aren't running
let mut tasks = vec![]; let mut tasks = vec![];
if !llm_running && !llm_model_path.is_empty() { if !llm_running && !llm_model_path.is_empty() {
@ -111,19 +105,17 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server"); info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
} }
// Wait for all server startup tasks
for task in tasks { for task in tasks {
task.await??; task.await??;
} }
// Wait for servers to be ready with verbose logging
info!("⏳ Waiting for servers to become ready..."); info!("⏳ Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model_path.is_empty(); let mut llm_ready = llm_running || llm_model_path.is_empty();
let mut embedding_ready = embedding_running || embedding_model_path.is_empty(); let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
let mut attempts = 0; let mut attempts = 0;
let max_attempts = 60; // 2 minutes total let max_attempts = 60;
while attempts < max_attempts && (!llm_ready || !embedding_ready) { while attempts < max_attempts && (!llm_ready || !embedding_ready) {
sleep(Duration::from_secs(2)).await; sleep(Duration::from_secs(2)).await;
@ -188,8 +180,6 @@ async fn start_llm_server(
std::env::set_var("OMP_PLACES", "cores"); std::env::set_var("OMP_PLACES", "cores");
std::env::set_var("OMP_PROC_BIND", "close"); std::env::set_var("OMP_PROC_BIND", "close");
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
let mut cmd = tokio::process::Command::new("sh"); let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!( cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &", "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
@ -225,7 +215,6 @@ async fn is_server_running(url: &str) -> bool {
} }
} }
// Convert OpenAI chat messages to a single prompt
fn messages_to_prompt(messages: &[ChatMessage]) -> String { fn messages_to_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new(); let mut prompt = String::new();
@ -250,32 +239,28 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
prompt prompt
} }
// Proxy endpoint
#[post("/local/v1/chat/completions")] #[post("/local/v1/chat/completions")]
pub async fn chat_completions_local( pub async fn chat_completions_local(
req_body: web::Json<ChatCompletionRequest>, req_body: web::Json<ChatCompletionRequest>,
_req: HttpRequest, _req: HttpRequest,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
dotenv().ok().unwrap(); dotenv().ok();
// Get llama.cpp server URL
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
// Convert OpenAI format to llama.cpp format
let prompt = messages_to_prompt(&req_body.messages); let prompt = messages_to_prompt(&req_body.messages);
let llama_request = LlamaCppRequest { let llama_request = LlamaCppRequest {
prompt, prompt,
n_predict: Some(500), // Adjust as needed n_predict: Some(500),
temperature: Some(0.7), temperature: Some(0.7),
top_k: Some(40), top_k: Some(40),
top_p: Some(0.9), top_p: Some(0.9),
stream: req_body.stream, stream: req_body.stream,
}; };
// Send request to llama.cpp server
let client = Client::builder() let client = Client::builder()
.timeout(Duration::from_secs(120)) // 2 minute timeout .timeout(Duration::from_secs(120))
.build() .build()
.map_err(|e| { .map_err(|e| {
error!("Error creating HTTP client: {}", e); error!("Error creating HTTP client: {}", e);
@ -301,7 +286,6 @@ pub async fn chat_completions_local(
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response") actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
})?; })?;
// Convert llama.cpp response to OpenAI format
let openai_response = ChatCompletionResponse { let openai_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
@ -344,7 +328,6 @@ pub async fn chat_completions_local(
} }
} }
// OpenAI Embedding Request - Modified to handle both string and array inputs
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct EmbeddingRequest { pub struct EmbeddingRequest {
#[serde(deserialize_with = "deserialize_input")] #[serde(deserialize_with = "deserialize_input")]
@ -354,7 +337,6 @@ pub struct EmbeddingRequest {
pub _encoding_format: Option<String>, pub _encoding_format: Option<String>,
} }
// Custom deserializer to handle both string and array inputs
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error> fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where where
D: serde::Deserializer<'de>, D: serde::Deserializer<'de>,
@ -400,7 +382,6 @@ where
deserializer.deserialize_any(InputVisitor) deserializer.deserialize_any(InputVisitor)
} }
// OpenAI Embedding Response
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct EmbeddingResponse { pub struct EmbeddingResponse {
pub object: String, pub object: String,
@ -422,20 +403,17 @@ pub struct Usage {
pub total_tokens: u32, pub total_tokens: u32,
} }
// Llama.cpp Embedding Request
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest { struct LlamaCppEmbeddingRequest {
pub content: String, pub content: String,
} }
// FIXED: Handle the stupid nested array format
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem { struct LlamaCppEmbeddingResponseItem {
pub index: usize, pub index: usize,
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays pub embedding: Vec<Vec<f32>>,
} }
// Proxy endpoint for embeddings
#[post("/v1/embeddings")] #[post("/v1/embeddings")]
pub async fn embeddings_local( pub async fn embeddings_local(
req_body: web::Json<EmbeddingRequest>, req_body: web::Json<EmbeddingRequest>,
@ -443,7 +421,6 @@ pub async fn embeddings_local(
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
dotenv().ok(); dotenv().ok();
// Get llama.cpp server URL
let llama_url = let llama_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string()); env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@ -455,7 +432,6 @@ pub async fn embeddings_local(
actix_web::error::ErrorInternalServerError("Failed to create HTTP client") actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?; })?;
// Process each input text and get embeddings
let mut embeddings_data = Vec::new(); let mut embeddings_data = Vec::new();
let mut total_tokens = 0; let mut total_tokens = 0;
@ -480,13 +456,11 @@ pub async fn embeddings_local(
let status = response.status(); let status = response.status();
if status.is_success() { if status.is_success() {
// First, get the raw response text for debugging
let raw_response = response.text().await.map_err(|e| { let raw_response = response.text().await.map_err(|e| {
error!("Error reading response text: {}", e); error!("Error reading response text: {}", e);
actix_web::error::ErrorInternalServerError("Failed to read response") actix_web::error::ErrorInternalServerError("Failed to read response")
})?; })?;
// Parse the response as a vector of items with nested arrays
let llama_response: Vec<LlamaCppEmbeddingResponseItem> = let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
serde_json::from_str(&raw_response).map_err(|e| { serde_json::from_str(&raw_response).map_err(|e| {
error!("Error parsing llama.cpp embedding response: {}", e); error!("Error parsing llama.cpp embedding response: {}", e);
@ -496,17 +470,13 @@ pub async fn embeddings_local(
) )
})?; })?;
// Extract the embedding from the nested array bullshit
if let Some(item) = llama_response.get(0) { if let Some(item) = llama_response.get(0) {
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
let flattened_embedding = if !item.embedding.is_empty() { let flattened_embedding = if !item.embedding.is_empty() {
item.embedding[0].clone() // Take the first (and probably only) inner array item.embedding[0].clone()
} else { } else {
vec![] // Empty if no embedding data vec![]
}; };
// Estimate token count
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32; let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
total_tokens += estimated_tokens; total_tokens += estimated_tokens;
@ -544,7 +514,6 @@ pub async fn embeddings_local(
} }
} }
// Build OpenAI-compatible response
let openai_response = EmbeddingResponse { let openai_response = EmbeddingResponse {
object: "list".to_string(), object: "list".to_string(),
data: embeddings_data, data: embeddings_data,
@ -558,7 +527,6 @@ pub async fn embeddings_local(
Ok(HttpResponse::Ok().json(openai_response)) Ok(HttpResponse::Ok().json(openai_response))
} }
// Health check endpoint
#[actix_web::get("/health")] #[actix_web::get("/health")]
pub async fn health() -> Result<HttpResponse> { pub async fn health() -> Result<HttpResponse> {
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string()); let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());

View file

@ -1,116 +1,256 @@
use log::info; use async_trait::async_trait;
use futures::StreamExt;
use langchain_rust::{
language_models::llm::LLM,
llm::{claude::Claude, openai::OpenAI},
schemas::Message,
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
use actix_web::{post, web, HttpRequest, HttpResponse, Result}; use crate::tools::ToolManager;
use dotenv::dotenv;
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
// OpenAI-compatible request/response structures #[async_trait]
#[derive(Debug, Serialize, Deserialize)] pub trait LLMProvider: Send + Sync {
struct ChatMessage { async fn generate(
role: String, &self,
content: String, prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
tool_manager: Arc<ToolManager>,
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
} }
#[derive(Debug, Serialize, Deserialize)] pub struct OpenAIClient {
struct ChatCompletionRequest { client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
model: String,
messages: Vec<ChatMessage>,
stream: Option<bool>,
} }
#[derive(Debug, Serialize, Deserialize)] impl OpenAIClient {
struct ChatCompletionResponse { pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
id: String, Self { client }
object: String, }
created: u64,
model: String,
choices: Vec<Choice>,
} }
#[derive(Debug, Serialize, Deserialize)] #[async_trait]
struct Choice { impl LLMProvider for OpenAIClient {
message: ChatMessage, async fn generate(
finish_reason: String, &self,
} prompt: &str,
_config: &Value,
#[post("/azure/v1/chat/completions")] ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> { let result = self
// Always log raw POST data .client
if let Ok(body_str) = std::str::from_utf8(&body) { .invoke(prompt)
info!("POST Data: {}", body_str);
} else {
info!("POST Data (binary): {:?}", body);
}
dotenv().ok();
// Environment variables
let azure_endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
let azure_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let deployment_name = env::var("AI_LLM_MODEL")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
// Construct Azure OpenAI URL
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
azure_endpoint, deployment_name
);
// Forward headers
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"api-key",
reqwest::header::HeaderValue::from_str(&azure_key)
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
);
headers.insert(
"Content-Type",
reqwest::header::HeaderValue::from_static("application/json"),
);
let body_str = std::str::from_utf8(&body).unwrap_or("");
info!("Original POST Data: {}", body_str);
// Remove the problematic params
let re =
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
let cleaned = re.replace_all(body_str, "");
let cleaned_body = web::Bytes::from(cleaned.to_string());
info!("Cleaned POST Data: {}", cleaned);
// Send request to Azure
let client = Client::new();
let response = client
.post(&url)
.headers(headers)
.body(cleaned_body)
.send()
.await .await
.map_err(actix_web::error::ErrorInternalServerError)?; .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
// Handle response based on status Ok(result)
let status = response.status(); }
let raw_response = response
.text() async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut stream = self
.client
.stream(prompt)
.await .await
.map_err(actix_web::error::ErrorInternalServerError)?; .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
// Log the raw response while let Some(result) = stream.next().await {
info!("Raw Azure response: {}", raw_response); match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
if status.is_success() { Ok(())
Ok(HttpResponse::Ok().body(raw_response)) }
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else { } else {
// Handle error responses properly format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16()) };
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status).body(raw_response)) let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
/// LLM provider backed by Anthropic's Claude via langchain-rust.
pub struct AnthropicClient {
    // Pre-configured Claude client; constructed in `AnthropicClient::new`.
    client: Claude,
}
impl AnthropicClient {
    /// Builds a provider from an Anthropic API key.
    pub fn new(api_key: String) -> Self {
        Self {
            client: Claude::default().with_api_key(api_key),
        }
    }
}
#[async_trait]
impl LLMProvider for AnthropicClient {
    /// One-shot completion through Claude; the runtime config is currently unused.
    async fn generate(
        &self,
        prompt: &str,
        _config: &Value,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        self.client
            .invoke(prompt)
            .await
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
    }

    /// Streams completion chunks into `tx`. Send failures are deliberately
    /// ignored (receiver may have hung up) and stream errors are only logged.
    async fn generate_stream(
        &self,
        prompt: &str,
        _config: &Value,
        tx: mpsc::Sender<String>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let mut chunks = self
            .client
            .stream(prompt)
            .await
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;

        while let Some(next) = chunks.next().await {
            match next {
                Err(e) => {
                    eprintln!("Stream error: {}", e);
                }
                Ok(chunk) => {
                    let piece = chunk.content;
                    if !piece.is_empty() {
                        let _ = tx.send(piece.to_string()).await;
                    }
                }
            }
        }

        Ok(())
    }

    /// Prompt-level tool support: appends the advertised tool names to the
    /// prompt so the model can suggest them; no tool is actually invoked here.
    async fn generate_with_tools(
        &self,
        prompt: &str,
        _config: &Value,
        available_tools: &[String],
        _tool_manager: Arc<ToolManager>,
        _session_id: &str,
        _user_id: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        let mut enhanced_prompt = prompt.to_string();
        if !available_tools.is_empty() {
            enhanced_prompt.push_str(&format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")));
        }

        self.client
            .invoke(&enhanced_prompt)
            .await
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
    }
}
/// No-op provider for tests and offline development: echoes prompts back
/// instead of calling a real model.
///
/// Deriving `Default` fixes the clippy `new_without_default` lint and lets
/// the type be used anywhere a `Default` bound is required.
#[derive(Default)]
pub struct MockLLMProvider;

impl MockLLMProvider {
    /// Creates a mock provider; equivalent to `MockLLMProvider::default()`.
    pub fn new() -> Self {
        Self
    }
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(format!("Mock response to: {}", prompt))
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = format!("Mock stream response to: {}", prompt);
for word in response.split_whitespace() {
let _ = tx.send(format!("{} ", word)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_list = if available_tools.is_empty() {
"no tools available".to_string()
} else {
available_tools.join(", ")
};
Ok(format!(
"Mock response with tools [{}] to: {}",
tools_list, prompt
))
} }
} }

View file

@ -2,273 +2,748 @@ pub mod llm_generic;
pub mod llm_local; pub mod llm_local;
pub mod llm_provider; pub mod llm_provider;
use async_trait::async_trait; pub use llm_provider::*;
use futures::StreamExt;
use langchain_rust::{
language_models::llm::LLM,
llm::{claude::Claude, openai::OpenAI},
schemas::Message,
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::tools::ToolManager; use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv;
use log::{error, info};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use tokio::time::{sleep, Duration};
#[async_trait] #[derive(Debug, Serialize, Deserialize)]
pub trait LLMProvider: Send + Sync { struct ChatMessage {
async fn generate( role: String,
&self, content: String,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
// Add tool calling capability using LangChain tools
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
tool_manager: Arc<ToolManager>,
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
} }
pub struct OpenAIClient { #[derive(Debug, Serialize, Deserialize)]
client: OpenAI, struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
stream: Option<bool>,
} }
impl OpenAIClient { #[derive(Debug, Serialize, Deserialize)]
pub fn new(client: OpenAI) -> Self { struct ChatCompletionResponse {
Self { client } id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
message: ChatMessage,
finish_reason: String,
}
// Request body for llama.cpp's native completion API.
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppRequest {
    prompt: String,
    n_predict: Option<i32>, // maximum number of tokens to generate
    temperature: Option<f32>,
    top_k: Option<i32>,
    top_p: Option<f32>,
    stream: Option<bool>,
}
// Response body from llama.cpp's native completion API.
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppResponse {
    content: String,
    stop: bool, // true when generation ended on a stop condition (vs. length)
    generation_settings: Option<serde_json::Value>,
}
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
if llm_local.to_lowercase() != "true" {
info!(" LLM_LOCAL is not enabled, skipping local server startup");
return Ok(());
}
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
let embedding_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
let llm_model_path = env::var("LLM_MODEL_PATH").unwrap_or_else(|_| "".to_string());
let embedding_model_path = env::var("EMBEDDING_MODEL_PATH").unwrap_or_else(|_| "".to_string());
info!("🚀 Starting local llama.cpp servers...");
info!("📋 Configuration:");
info!(" LLM URL: {}", llm_url);
info!(" Embedding URL: {}", embedding_url);
info!(" LLM Model: {}", llm_model_path);
info!(" Embedding Model: {}", embedding_model_path);
let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await;
if llm_running && embedding_running {
info!("✅ Both LLM and Embedding servers are already running");
return Ok(());
}
let mut tasks = vec![];
if !llm_running && !llm_model_path.is_empty() {
info!("🔄 Starting LLM server...");
tasks.push(tokio::spawn(start_llm_server(
llama_cpp_path.clone(),
llm_model_path.clone(),
llm_url.clone(),
)));
} else if llm_model_path.is_empty() {
info!("⚠️ LLM_MODEL_PATH not set, skipping LLM server");
}
if !embedding_running && !embedding_model_path.is_empty() {
info!("🔄 Starting Embedding server...");
tasks.push(tokio::spawn(start_embedding_server(
llama_cpp_path.clone(),
embedding_model_path.clone(),
embedding_url.clone(),
)));
} else if embedding_model_path.is_empty() {
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
}
for task in tasks {
task.await??;
}
info!("⏳ Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model_path.is_empty();
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
let mut attempts = 0;
let max_attempts = 60;
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
sleep(Duration::from_secs(2)).await;
info!(
"🔍 Checking server health (attempt {}/{})...",
attempts + 1,
max_attempts
);
if !llm_ready && !llm_model_path.is_empty() {
if is_server_running(&llm_url).await {
info!(" ✅ LLM server ready at {}", llm_url);
llm_ready = true;
} else {
info!(" ❌ LLM server not ready yet");
} }
} }
#[async_trait] if !embedding_ready && !embedding_model_path.is_empty() {
impl LLMProvider for OpenAIClient { if is_server_running(&embedding_url).await {
async fn generate( info!(" ✅ Embedding server ready at {}", embedding_url);
&self, embedding_ready = true;
prompt: &str, } else {
_config: &Value, info!(" ❌ Embedding server not ready yet");
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { }
let messages = vec![Message::new_human_message(prompt.to_string())];
let result = self
.client
.invoke(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
} }
async fn generate_stream( attempts += 1;
&self,
prompt: &str, if attempts % 10 == 0 {
_config: &Value, info!(
mut tx: mpsc::Sender<String>, "⏰ Still waiting for servers... (attempt {}/{})",
attempts, max_attempts
);
}
}
if llm_ready && embedding_ready {
info!("🎉 All llama.cpp servers are ready and responding!");
Ok(())
} else {
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
if !llm_ready && !llm_model_path.is_empty() {
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
}
if !embedding_ready && !embedding_model_path.is_empty() {
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
}
Err(error_msg.into())
}
}
async fn start_llm_server(
llama_cpp_path: String,
model_path: String,
url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![Message::new_human_message(prompt.to_string())]; let port = url.split(':').last().unwrap_or("8081");
let mut stream = self std::env::set_var("OMP_NUM_THREADS", "20");
.client std::env::set_var("OMP_PLACES", "cores");
.stream(&messages) std::env::set_var("OMP_PROC_BIND", "close");
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
llama_cpp_path, model_path, port
));
cmd.spawn()?;
Ok(())
}
/// Launches the llama.cpp embedding server in the background via `sh -c ... &`.
/// Success only means the shell command was spawned; callers poll
/// `is_server_running` to confirm the server is actually up.
async fn start_embedding_server(
    llama_cpp_path: String,
    model_path: String,
    url: String,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Port is the last `:`-separated component of the URL (e.g. "8082").
    let port = url.split(':').last().unwrap_or("8082");
    let shell_line = format!(
        "cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 &",
        llama_cpp_path, model_path, port
    );
    tokio::process::Command::new("sh")
        .arg("-c")
        .arg(shell_line)
        .spawn()?;
    Ok(())
}
/// Probes `{url}/health`; true only when the request succeeds with an HTTP
/// success status, false on any transport error.
async fn is_server_running(url: &str) -> bool {
    let health_endpoint = format!("{}/health", url);
    reqwest::Client::new()
        .get(&health_endpoint)
        .send()
        .await
        .map(|resp| resp.status().is_success())
        .unwrap_or(false)
}
/// Flattens OpenAI-style chat messages into a single plain-text prompt,
/// labelling each turn and ending with "Assistant: " so the model continues
/// in the assistant role.
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
    let mut prompt = String::new();
    for msg in messages {
        // Known roles get capitalised labels; anything else is passed through.
        let speaker = match msg.role.as_str() {
            "system" => "System",
            "user" => "User",
            "assistant" => "Assistant",
            other => other,
        };
        prompt.push_str(&format!("{}: {}\n\n", speaker, msg.content));
    }
    prompt.push_str("Assistant: ");
    prompt
}
#[post("/local/v1/chat/completions")]
pub async fn chat_completions_local(
req_body: web::Json<ChatCompletionRequest>,
_req: HttpRequest,
) -> Result<HttpResponse> {
dotenv().ok();
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
let prompt = messages_to_prompt(&req_body.messages);
let llama_request = LlamaCppRequest {
prompt,
n_predict: Some(500),
temperature: Some(0.7),
top_k: Some(40),
top_p: Some(0.9),
stream: req_body.stream,
};
let client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.map_err(|e| {
error!("Error creating HTTP client: {}", e);
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?;
let response = client
.post(&format!("{}/completion", llama_url))
.header("Content-Type", "application/json")
.json(&llama_request)
.send()
.await .await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?; .map_err(|e| {
error!("Error calling llama.cpp server: {}", e);
actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
})?;
while let Some(result) = stream.next().await { let status = response.status();
match result {
Ok(chunk) => { if status.is_success() {
let content = chunk.content; let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
if !content.is_empty() { error!("Error parsing llama.cpp response: {}", e);
let _ = tx.send(content.to_string()).await; actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
})?;
let openai_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs(),
model: req_body.model.clone(),
choices: vec![Choice {
message: ChatMessage {
role: "assistant".to_string(),
content: llama_response.content.trim().to_string(),
},
finish_reason: if llama_response.stop {
"stop".to_string()
} else {
"length".to_string()
},
}],
};
Ok(HttpResponse::Ok().json(openai_response))
} else {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
error!("Llama.cpp server error ({}): {}", status, error_text);
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status).json(serde_json::json!({
"error": {
"message": error_text,
"type": "server_error"
}
})))
} }
} }
// OpenAI-compatible embedding request. `input` accepts either a single
// string or an array of strings (normalised by `deserialize_input`).
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
    #[serde(deserialize_with = "deserialize_input")]
    pub input: Vec<String>,
    pub model: String,
    #[serde(default)]
    pub _encoding_format: Option<String>, // accepted for compatibility, unused
}
/// Deserializes the OpenAI `input` field, which may be either a bare JSON
/// string or an array of strings, into a uniform `Vec<String>`.
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de::{self, Visitor};
    use std::fmt;

    // Visitor that normalizes `"text"` and `["a", "b"]` into a vector.
    struct StringOrSeq;

    impl<'de> Visitor<'de> for StringOrSeq {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("a string or an array of strings")
        }

        // A borrowed string becomes a one-element vector.
        fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
            Ok(vec![value.to_owned()])
        }

        // An owned string is wrapped directly without re-allocating.
        fn visit_string<E: de::Error>(self, value: String) -> Result<Self::Value, E> {
            Ok(vec![value])
        }

        // A sequence is collected element by element.
        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: de::SeqAccess<'de>,
        {
            let mut items = Vec::with_capacity(seq.size_hint().unwrap_or(0));
            while let Some(item) = seq.next_element::<String>()? {
                items.push(item);
            }
            Ok(items)
        }
    }

    deserializer.deserialize_any(StringOrSeq)
}
/// OpenAI-compatible `/v1/embeddings` response envelope.
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
    // Always "list", mirroring the OpenAI API.
    pub object: String,
    // One entry per input text, in request order.
    pub data: Vec<EmbeddingData>,
    // Model name echoed back from the request.
    pub model: String,
    // Token accounting; values are estimated, not backend-reported.
    pub usage: Usage,
}
/// A single embedding vector within an `EmbeddingResponse`.
#[derive(Debug, Serialize)]
pub struct EmbeddingData {
    // Always "embedding", mirroring the OpenAI API.
    pub object: String,
    // The flattened embedding vector returned by llama.cpp.
    pub embedding: Vec<f32>,
    // Position of the corresponding input text in the request.
    pub index: usize,
}
/// Token usage statistics; values are rough estimates (input bytes / 4),
/// since the local backend does not report token counts.
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}
/// Request body for the llama.cpp `/embedding` endpoint.
#[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest {
    // Single text to embed (one request per input string).
    pub content: String,
}
/// One item of the llama.cpp `/embedding` response array.
#[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem {
    // Index reported by the backend (parsed but not read by the handler).
    pub index: usize,
    // Nested embedding rows; only the first row is consumed downstream.
    pub embedding: Vec<Vec<f32>>,
}
#[post("/v1/embeddings")]
pub async fn embeddings_local(
req_body: web::Json<EmbeddingRequest>,
_req: HttpRequest,
) -> Result<HttpResponse> {
dotenv().ok();
let llama_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
let client = Client::builder()
.timeout(Duration::from_secs(120))
.build()
.map_err(|e| {
error!("Error creating HTTP client: {}", e);
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?;
let mut embeddings_data = Vec::new();
let mut total_tokens = 0;
for (index, input_text) in req_body.input.iter().enumerate() {
let llama_request = LlamaCppEmbeddingRequest {
content: input_text.clone(),
};
let response = client
.post(&format!("{}/embedding", llama_url))
.header("Content-Type", "application/json")
.json(&llama_request)
.send()
.await
.map_err(|e| {
error!("Error calling llama.cpp server for embedding: {}", e);
actix_web::error::ErrorInternalServerError(
"Failed to call llama.cpp server for embedding",
)
})?;
let status = response.status();
if status.is_success() {
let raw_response = response.text().await.map_err(|e| {
error!("Error reading response text: {}", e);
actix_web::error::ErrorInternalServerError("Failed to read response")
})?;
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
serde_json::from_str(&raw_response).map_err(|e| {
error!("Error parsing llama.cpp embedding response: {}", e);
error!("Raw response: {}", raw_response);
actix_web::error::ErrorInternalServerError(
"Failed to parse llama.cpp embedding response",
)
})?;
if let Some(item) = llama_response.get(0) {
let flattened_embedding = if !item.embedding.is_empty() {
item.embedding[0].clone()
} else {
vec![]
};
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
total_tokens += estimated_tokens;
embeddings_data.push(EmbeddingData {
object: "embedding".to_string(),
embedding: flattened_embedding,
index,
});
} else {
error!("No embedding data returned for input: {}", input_text);
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
"error": {
"message": format!("No embedding data returned for input {}", index),
"type": "server_error"
}
})));
}
} else {
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
error!("Llama.cpp server error ({}): {}", status, error_text);
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
"error": {
"message": format!("Failed to get embedding for input {}: {}", index, error_text),
"type": "server_error"
}
})));
}
}
let openai_response = EmbeddingResponse {
object: "list".to_string(),
data: embeddings_data,
model: req_body.model.clone(),
usage: Usage {
prompt_tokens: total_tokens,
total_tokens,
},
};
Ok(HttpResponse::Ok().json(openai_response))
}
/// Health probe: reports whether the llama.cpp backend is reachable.
/// Returns 200 when the backend answers, 503 otherwise.
#[actix_web::get("/health")]
pub async fn health() -> Result<HttpResponse> {
    let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
    let backend_up = is_server_running(&llama_url).await;
    let response = if backend_up {
        HttpResponse::Ok().json(serde_json::json!({
            "status": "healthy",
            "llama_server": "running"
        }))
    } else {
        HttpResponse::ServiceUnavailable().json(serde_json::json!({
            "status": "unhealthy",
            "llama_server": "not running"
        }))
    };
    Ok(response)
}
use regex::Regex;
/// Minimal OpenAI-style chat-completion request shape.
/// NOTE(review): `generic_chat_completions` parses the raw JSON body
/// directly instead of using this type, so it appears unused — confirm
/// before removing.
#[derive(Debug, Serialize, Deserialize)]
struct GenericChatCompletionRequest {
    // Target model name (the handler overwrites it with AI_LLM_MODEL).
    model: String,
    // Conversation history in OpenAI message format.
    messages: Vec<ChatMessage>,
    // Optional streaming flag.
    stream: Option<bool>,
}
#[post("/v1/chat/completions")]
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
let body_str = std::str::from_utf8(&body).unwrap_or_default();
info!("Original POST Data: {}", body_str);
dotenv().ok();
let api_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let model = env::var("AI_LLM_MODEL")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
let endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
if let Some(obj) = json_value.as_object_mut() {
obj.insert("model".to_string(), serde_json::Value::String(model));
}
let modified_body_str = serde_json::to_string(&json_value)
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?;
info!("Modified POST Data: {}", modified_body_str);
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?,
);
headers.insert(
"Content-Type",
reqwest::header::HeaderValue::from_static("application/json"),
);
let client = Client::new();
let response = client
.post(&endpoint)
.headers(headers)
.body(modified_body_str)
.send()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
let status = response.status();
let raw_response = response
.text()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
info!("Provider response status: {}", status);
info!("Provider response body: {}", raw_response);
if status.is_success() {
match convert_to_openai_format(&raw_response) {
Ok(openai_response) => Ok(HttpResponse::Ok()
.content_type("application/json")
.body(openai_response)),
Err(e) => { Err(e) => {
eprintln!("Stream error: {}", e); error!("Failed to convert response format: {}", e);
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(raw_response))
} }
} }
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
// Enhanced prompt with tool information
let tools_info = if available_tools.is_empty() {
String::new()
} else { } else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", ")) let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status)
.content_type("application/json")
.body(raw_response))
}
}
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
#[derive(serde::Deserialize)]
struct ProviderChoice {
message: ProviderMessage,
#[serde(default)]
finish_reason: Option<String>,
}
#[derive(serde::Deserialize)]
struct ProviderMessage {
role: Option<String>,
content: String,
}
#[derive(serde::Deserialize)]
struct ProviderResponse {
id: Option<String>,
object: Option<String>,
created: Option<u64>,
model: Option<String>,
choices: Vec<ProviderChoice>,
usage: Option<ProviderUsage>,
}
#[derive(serde::Deserialize, Default)]
struct ProviderUsage {
prompt_tokens: Option<u32>,
completion_tokens: Option<u32>,
total_tokens: Option<u32>,
}
#[derive(serde::Serialize)]
struct OpenAIResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<OpenAIChoice>,
usage: OpenAIUsage,
}
#[derive(serde::Serialize)]
struct OpenAIChoice {
index: u32,
message: OpenAIMessage,
finish_reason: String,
}
#[derive(serde::Serialize)]
struct OpenAIMessage {
role: String,
content: String,
}
#[derive(serde::Serialize)]
struct OpenAIUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
let content = first_choice.message.content.clone();
let role = first_choice
.message
.role
.clone()
.unwrap_or_else(|| "assistant".to_string());
let usage = provider.usage.unwrap_or_default();
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
let completion_tokens = usage
.completion_tokens
.unwrap_or_else(|| content.split_whitespace().count() as u32);
let total_tokens = usage
.total_tokens
.unwrap_or(prompt_tokens + completion_tokens);
let openai_response = OpenAIResponse {
id: provider
.id
.unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())),
object: provider
.object
.unwrap_or_else(|| "chat.completion".to_string()),
created: provider.created.unwrap_or_else(|| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs()
}),
model: provider.model.unwrap_or_else(|| "llama".to_string()),
choices: vec![OpenAIChoice {
index: 0,
message: OpenAIMessage { role, content },
finish_reason: first_choice
.finish_reason
.clone()
.unwrap_or_else(|| "stop".to_string()),
}],
usage: OpenAIUsage {
prompt_tokens,
completion_tokens,
total_tokens,
},
}; };
let enhanced_prompt = format!("{}{}", prompt, tools_info); serde_json::to_string(&openai_response).map_err(|e| e.into())
let messages = vec![Message::new_human_message(enhanced_prompt)];
let result = self
.client
.invoke(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
/// LLM provider backed by the Anthropic Claude API via langchain-rust.
pub struct AnthropicClient {
    // Claude client configured with the caller-supplied API key.
    client: Claude,
}
impl AnthropicClient {
    /// Builds a client authenticated with the given Anthropic API key.
    pub fn new(api_key: String) -> Self {
        Self {
            client: Claude::default().with_api_key(api_key),
        }
    }
}
#[async_trait]
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![Message::new_human_message(prompt.to_string())];
let result = self
.client
.invoke(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
mut tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![Message::new_human_message(prompt.to_string())];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let messages = vec![Message::new_human_message(enhanced_prompt)];
let result = self
.client
.invoke(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
/// No-op LLM provider used for tests and as the default when no real
/// provider is configured; returns canned responses echoing the prompt.
pub struct MockLLMProvider;
impl MockLLMProvider {
    /// Creates the mock provider (stateless unit struct).
    pub fn new() -> Self {
        Self
    }
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(format!("Mock response to: {}", prompt))
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
mut tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = format!("Mock stream response to: {}", prompt);
for word in response.split_whitespace() {
let _ = tx.send(format!("{} ", word)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_list = if available_tools.is_empty() {
"no tools available".to_string()
} else {
available_tools.join(", ")
};
Ok(format!(
"Mock response with tools [{}] to: {}",
tools_list, prompt
))
}
} }

View file

@ -1,3 +1,9 @@
use actix_cors::Cors;
use actix_web::{middleware, web, App, HttpServer};
use dotenv::dotenv;
use log::info;
use std::sync::Arc;
mod auth; mod auth;
mod automation; mod automation;
mod basic; mod basic;
@ -16,239 +22,105 @@ mod tools;
mod web_automation; mod web_automation;
mod whatsapp; mod whatsapp;
use log::info; use crate::{config::AppConfig, shared::state::AppState};
use qdrant_client::Qdrant;
use std::sync::Arc;
use actix_web::{web, App, HttpServer}; #[actix_web::main]
use dotenv::dotenv;
use sqlx::PgPool;
use crate::auth::AuthService;
use crate::bot::BotOrchestrator;
use crate::config::AppConfig;
use crate::email::{
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
};
use crate::file::{list_file, upload_file};
use crate::llm::llm_generic::generic_chat_completions;
use crate::llm::llm_local::{
chat_completions_local, embeddings_local, ensure_llama_servers_running,
};
use crate::session::SessionManager;
use crate::shared::state::AppState;
use crate::tools::{RedisToolExecutor, ToolManager};
use crate::web_automation::{initialize_browser_pool, BrowserPool};
#[tokio::main(flavor = "multi_thread")]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
dotenv().ok(); dotenv().ok();
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); env_logger::init();
info!("Starting General Bots 6.0...");
info!("🚀 Starting Bot Server...");
let config = AppConfig::from_env(); let config = AppConfig::from_env();
let db_url = config.database_url();
let db_custom_url = config.database_custom_url();
let db = PgPool::connect(&db_url).await.unwrap();
let db_custom = PgPool::connect(&db_custom_url).await.unwrap();
let minio_client = init_minio(&config) let db_pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(5)
.connect(&config.database_url())
.await .await
.expect("Failed to initialize Minio"); .expect("Failed to create database pool");
let browser_pool = Arc::new(BrowserPool::new( let db_custom_pool = sqlx::postgres::PgPoolOptions::new()
"http://localhost:9515".to_string(), .max_connections(5)
5, .connect(&config.database_custom_url())
"/usr/bin/brave-browser-beta".to_string(), .await
.expect("Failed to create custom database pool");
let redis_client = redis::Client::open("redis://127.0.0.1/").ok();
let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone().map(Arc::new));
let session_manager =
session::SessionManager::new(db_pool.clone(), redis_client.clone().map(Arc::new));
let tool_manager = tools::ToolManager::new();
let tool_api = Arc::new(tools::ToolApi::new());
let web_adapter = Arc::new(channels::WebChannelAdapter::new());
let voice_adapter = Arc::new(channels::VoiceAdapter::new(
"https://livekit.example.com".to_string(),
"api_key".to_string(),
"api_secret".to_string(),
)); ));
ensure_llama_servers_running() let whatsapp_adapter = Arc::new(whatsapp::WhatsAppAdapter::new(
.await "whatsapp_token".to_string(),
.expect("Failed to initialize LLM local server."); "phone_number_id".to_string(),
"verify_token".to_string(),
initialize_browser_pool()
.await
.expect("Failed to initialize browser pool");
// Initialize Redis if available
let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "".to_string());
let redis_conn = match std::env::var("REDIS_URL") {
Ok(redis_url_value) => {
let client = redis::Client::open(redis_url_value.clone())
.expect("Failed to create Redis client");
let conn = client
.get_connection()
.expect("Failed to create Redis connection");
Some(Arc::new(conn))
}
Err(_) => None,
};
let qdrant_url = std::env::var("QDRANT_URL").unwrap_or("http://localhost:6334".to_string());
let qdrant = Qdrant::from_url(&qdrant_url)
.build()
.expect("Failed to connect to Qdrant");
let session_manager = SessionManager::new(db.clone(), redis_conn.clone());
let auth_service = AuthService::new(db.clone(), redis_conn.clone());
let llm_provider: Arc<dyn crate::llm::LLMProvider> = match std::env::var("LLM_PROVIDER")
.unwrap_or("mock".to_string())
.as_str()
{
"openai" => Arc::new(crate::llm::OpenAIClient::new(
std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required"),
)),
"anthropic" => Arc::new(crate::llm::AnthropicClient::new(
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY required"),
)),
_ => Arc::new(crate::llm::MockLLMProvider::new()),
};
let web_adapter = Arc::new(crate::channels::WebChannelAdapter::new());
let voice_adapter = Arc::new(crate::channels::VoiceAdapter::new(
std::env::var("LIVEKIT_URL").unwrap_or("ws://localhost:7880".to_string()),
std::env::var("LIVEKIT_API_KEY").unwrap_or("dev".to_string()),
std::env::var("LIVEKIT_API_SECRET").unwrap_or("secret".to_string()),
)); ));
let whatsapp_adapter = Arc::new(crate::whatsapp::WhatsAppAdapter::new( let llm_provider = Arc::new(llm::MockLLMProvider::new());
std::env::var("META_ACCESS_TOKEN").unwrap_or("".to_string()),
std::env::var("META_PHONE_NUMBER_ID").unwrap_or("".to_string()),
std::env::var("META_WEBHOOK_VERIFY_TOKEN").unwrap_or("".to_string()),
));
let tool_executor = Arc::new( let orchestrator = Arc::new(bot::BotOrchestrator::new(
RedisToolExecutor::new(
redis_url.as_str(),
web_adapter.clone() as Arc<dyn crate::channels::ChannelAdapter>,
db.clone(),
redis_conn.clone(),
)
.expect("Failed to create RedisToolExecutor"),
);
let chart_generator = ChartGenerator::new().map(Arc::new);
// Initialize LangChain components
let llm = OpenAI::default();
let llm_provider: Arc<dyn LLMProvider> = Arc::new(OpenAIClient::new(llm));
// Initialize vector store for document mode
let vector_store = if let (Ok(qdrant_url), Ok(openai_key)) =
(std::env::var("QDRANT_URL"), std::env::var("OPENAI_API_KEY"))
{
let embedder = OpenAiEmbedder::default().with_api_key(openai_key);
let client = QdrantClient::from_url(&qdrant_url).build().ok()?;
let store = StoreBuilder::new()
.embedder(embedder)
.client(client)
.collection_name("documents")
.build()
.await
.ok()?;
Some(Arc::new(store))
} else {
None
};
// Initialize SQL chain for database mode
let sql_chain = if let Ok(db_url) = std::env::var("DATABASE_URL") {
let engine = PostgreSQLEngine::new(&db_url).await.ok()?;
let db = SQLDatabaseBuilder::new(engine).build().await.ok()?;
let llm = OpenAI::default();
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
.llm(llm)
.top_k(5)
.database(db)
.build()
.ok()?;
Some(Arc::new(chain))
} else {
None
};
let tool_manager = ToolManager::new();
let orchestrator = BotOrchestrator::new(
session_manager, session_manager,
tool_manager, tool_manager,
llm_provider, llm_provider,
auth_service, auth_service,
chart_generator, ));
vector_store,
sql_chain,
);
orchestrator.add_channel("web", web_adapter.clone()); let browser_pool = Arc::new(web_automation::BrowserPool::new());
orchestrator.add_channel("voice", voice_adapter.clone());
orchestrator.add_channel("whatsapp", whatsapp_adapter.clone());
sqlx::query( let app_state = AppState {
"INSERT INTO bots (id, name, llm_provider) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", minio_client: None,
) config: Some(config),
.bind(uuid::Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap()) db: Some(db_pool),
.bind("Default Bot") db_custom: Some(db_custom_pool),
.bind("mock") browser_pool,
.execute(&db) orchestrator,
.await
.unwrap();
let app_state = web::Data::new(AppState {
db: db.into(),
db_custom: db_custom.into(),
config: Some(config.clone()),
minio_client: minio_client.into(),
browser_pool: browser_pool.clone(),
orchestrator: Arc::new(orchestrator),
web_adapter, web_adapter,
voice_adapter, voice_adapter,
whatsapp_adapter, whatsapp_adapter,
}); tool_api,
};
// Start automation service in background info!("🌐 Server running on {}:{}", "127.0.0.1", 8080);
let automation_state = app_state.get_ref().clone();
let automation = AutomationService::new(automation_state, "src/prompts");
let _automation_handle = automation.spawn();
// Start HTTP server
HttpServer::new(move || { HttpServer::new(move || {
let cors = Cors::default()
.allow_any_origin()
.allow_any_method()
.allow_any_header()
.max_age(3600);
App::new() App::new()
.wrap(Logger::default()) .app_data(web::Data::new(app_state.clone()))
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i")) .wrap(cors)
.app_data(app_state.clone()) .wrap(middleware::Logger::default())
// Original services .service(bot::websocket_handler)
.service(upload_file) .service(bot::whatsapp_webhook_verify)
.service(list_file) .service(bot::whatsapp_webhook)
.service(save_click) .service(bot::voice_start)
.service(get_emails) .service(bot::voice_stop)
.service(list_emails) .service(bot::create_session)
.service(send_email) .service(bot::get_sessions)
.service(crate::orchestrator::chat_stream) .service(bot::get_session_history)
.service(crate::orchestrator::chat) .service(bot::set_mode_handler)
.service(chat_completions_local) .service(bot::index)
.service(save_draft) .service(bot::static_files)
.service(generic_chat_completions) .service(llm::chat_completions_local)
.service(embeddings_local) .service(llm::embeddings_local)
.service(get_latest_email_from) .service(llm::generic_chat_completions)
.service(services::orchestrator::websocket_handler) .service(llm::health)
.service(services::orchestrator::whatsapp_webhook_verify)
.service(services::orchestrator::whatsapp_webhook)
.service(services::orchestrator::voice_start)
.service(services::orchestrator::voice_stop)
.service(services::orchestrator::create_session)
.service(services::orchestrator::get_sessions)
.service(services::orchestrator::get_session_history)
.service(services::orchestrator::index)
.service(create_organization)
.service(get_organization)
.service(list_organizations)
.service(update_organization)
.service(delete_organization)
}) })
.bind((config.server.host.clone(), config.server.port))? .bind(("127.0.0.1", 8080))?
.run() .run()
.await .await
} }

View file

@ -116,3 +116,9 @@ pub struct BotResponse {
pub stream_token: Option<String>, pub stream_token: Option<String>,
pub is_complete: bool, pub is_complete: bool,
} }
#[derive(Debug, Deserialize)]
pub struct PaginationQuery {
pub page: Option<i64>,
pub page_size: Option<i64>,
}

View file

@ -2,10 +2,11 @@ use std::sync::Arc;
use crate::{ use crate::{
bot::BotOrchestrator, bot::BotOrchestrator,
channels::{VoiceAdapter, WebChannelAdapter, WhatsAppAdapter}, channels::{VoiceAdapter, WebChannelAdapter},
config::AppConfig, config::AppConfig,
tools::ToolApi, tools::ToolApi,
web_automation::BrowserPool web_automation::BrowserPool,
whatsapp::WhatsAppAdapter,
}; };
#[derive(Clone)] #[derive(Clone)]

View file

@ -1,4 +1,4 @@
use chrono::{DateTime, Utc}; use chrono::Utc;
use langchain_rust::llm::AzureConfig; use langchain_rust::llm::AzureConfig;
use log::{debug, warn}; use log::{debug, warn};
use rhai::{Array, Dynamic}; use rhai::{Array, Dynamic};
@ -14,6 +14,7 @@ use tokio_stream::StreamExt;
use zip::ZipArchive; use zip::ZipArchive;
use crate::config::AIConfig; use crate::config::AIConfig;
use langchain_rust::language_models::llm::LLM;
use reqwest::Client; use reqwest::Client;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
@ -30,7 +31,7 @@ pub async fn call_llm(
ai_config: &AIConfig, ai_config: &AIConfig,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let azure_config = azure_from_config(&ai_config.clone()); let azure_config = azure_from_config(&ai_config.clone());
let open_ai = langchain_rust::llm::OpenAI::new(azure_config); let open_ai = langchain_rust::llm::openai::OpenAI::new(azure_config);
let prompt = text.to_string(); let prompt = text.to_string();

View file

@ -1,17 +1,11 @@
use async_trait::async_trait; use async_trait::async_trait;
use redis::AsyncCommands;
use rhai::{Engine, Scope};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{session::SessionManager, shared::BotResponse};
channels::ChannelAdapter,
session::SessionManager,
shared::BotResponse,
};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult { pub struct ToolResult {
@ -52,421 +46,50 @@ pub trait ToolExecutor: Send + Sync {
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>; ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
} }
pub struct RedisToolExecutor { pub struct MockToolExecutor;
redis_client: redis::Client,
web_adapter: Arc<dyn ChannelAdapter>,
voice_adapter: Arc<dyn ChannelAdapter>,
whatsapp_adapter: Arc<dyn ChannelAdapter>,
}
impl RedisToolExecutor { impl MockToolExecutor {
pub fn new( pub fn new() -> Self {
redis_url: &str, Self
web_adapter: Arc<dyn ChannelAdapter>,
voice_adapter: Arc<dyn ChannelAdapter>,
whatsapp_adapter: Arc<dyn ChannelAdapter>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let client = redis::Client::open(redis_url)?;
Ok(Self {
redis_client: client,
web_adapter,
voice_adapter,
whatsapp_adapter,
})
}
async fn send_tool_message(
&self,
session_id: &str,
user_id: &str,
channel: &str,
message: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id.to_string(),
session_id: session_id.to_string(),
channel: channel.to_string(),
content: message.to_string(),
message_type: "tool".to_string(),
stream_token: None,
is_complete: true,
};
match channel {
"web" => self.web_adapter.send_message(response).await,
"voice" => self.voice_adapter.send_message(response).await,
"whatsapp" => self.whatsapp_adapter.send_message(response).await,
_ => Ok(()),
}
}
fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine {
let mut engine = Engine::new();
let tool_executor = Arc::new((
self.redis_client.clone(),
self.web_adapter.clone(),
self.voice_adapter.clone(),
self.whatsapp_adapter.clone(),
));
let session_id_clone = session_id.clone();
let user_id_clone = user_id.clone();
let channel_clone = channel.clone();
engine.register_fn("talk", move |message: String| {
let tool_executor = Arc::clone(&tool_executor);
let session_id = session_id_clone.clone();
let user_id = user_id_clone.clone();
let channel = channel_clone.clone();
tokio::spawn(async move {
let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor;
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id.clone(),
session_id: session_id.clone(),
channel: channel.clone(),
content: message.clone(),
message_type: "tool".to_string(),
stream_token: None,
is_complete: true,
};
let result = match channel.as_str() {
"web" => web_adapter.send_message(response).await,
"voice" => voice_adapter.send_message(response).await,
"whatsapp" => whatsapp_adapter.send_message(response).await,
_ => Ok(()),
};
if let Err(e) = result {
log::error!("Failed to send tool message: {}", e);
}
if let Ok(mut conn) = redis_client.get_async_connection().await {
let output_key = format!("tool:{}:output", session_id);
let _ = conn.lpush(&output_key, &message).await;
}
});
});
let hear_executor = self.redis_client.clone();
let session_id_clone = session_id.clone();
engine.register_fn("hear", move || -> String {
let hear_executor = hear_executor.clone();
let session_id = session_id_clone.clone();
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
match hear_executor.get_async_connection().await {
Ok(mut conn) => {
let input_key = format!("tool:{}:input", session_id);
let waiting_key = format!("tool:{}:waiting", session_id);
let _ = conn.set_ex(&waiting_key, "true", 300).await;
let result: Option<(String, String)> =
conn.brpop(&input_key, 30).await.ok().flatten();
let _ = conn.del(&waiting_key).await;
result
.map(|(_, input)| input)
.unwrap_or_else(|| "timeout".to_string())
}
Err(e) => {
log::error!("HEAR Redis error: {}", e);
"error".to_string()
}
}
})
});
engine
}
async fn cleanup_session(&self, session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let keys = vec![
format!("tool:{}:output", session_id),
format!("tool:{}:input", session_id),
format!("tool:{}:waiting", session_id),
format!("tool:{}:active", session_id),
];
for key in keys {
let _: () = conn.del(&key).await?;
}
Ok(())
} }
} }
#[async_trait] #[async_trait]
impl ToolExecutor for RedisToolExecutor { impl ToolExecutor for MockToolExecutor {
async fn execute( async fn execute(
&self, &self,
tool_name: &str, tool_name: &str,
session_id: &str, session_id: &str,
user_id: &str, user_id: &str,
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?;
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let session_key = format!("tool:{}:session", session_id);
let session_data = serde_json::json!({
"user_id": user_id,
"tool_name": tool_name,
"started_at": chrono::Utc::now().to_rfc3339(),
});
conn.set_ex(&session_key, session_data.to_string(), 3600)
.await?;
let active_key = format!("tool:{}:active", session_id);
conn.set_ex(&active_key, "true", 3600).await?;
let channel = "web";
let _engine = self.create_rhai_engine(
session_id.to_string(),
user_id.to_string(),
channel.to_string(),
);
let redis_clone = self.redis_client.clone();
let web_adapter_clone = self.web_adapter.clone();
let voice_adapter_clone = self.voice_adapter.clone();
let whatsapp_adapter_clone = self.whatsapp_adapter.clone();
let session_id_clone = session_id.to_string();
let user_id_clone = user_id.to_string();
let tool_script = tool.script.clone();
tokio::spawn(async move {
let mut engine = Engine::new();
let mut scope = Scope::new();
let redis_client = redis_clone.clone();
let web_adapter = web_adapter_clone.clone();
let voice_adapter = voice_adapter_clone.clone();
let whatsapp_adapter = whatsapp_adapter_clone.clone();
let session_id = session_id_clone.clone();
let user_id = user_id_clone.clone();
engine.register_fn("talk", move |message: String| {
let redis_client = redis_client.clone();
let web_adapter = web_adapter.clone();
let voice_adapter = voice_adapter.clone();
let whatsapp_adapter = whatsapp_adapter.clone();
let session_id = session_id.clone();
let user_id = user_id.clone();
tokio::spawn(async move {
let channel = "web";
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id.clone(),
session_id: session_id.clone(),
channel: channel.to_string(),
content: message.clone(),
message_type: "tool".to_string(),
stream_token: None,
is_complete: true,
};
let send_result = match channel {
"web" => web_adapter.send_message(response).await,
"voice" => voice_adapter.send_message(response).await,
"whatsapp" => whatsapp_adapter.send_message(response).await,
_ => Ok(()),
};
if let Err(e) = send_result {
log::error!("Failed to send tool message: {}", e);
}
if let Ok(mut conn) = redis_client.get_async_connection().await {
let output_key = format!("tool:{}:output", session_id);
let _ = conn.lpush(&output_key, &message).await;
}
});
});
let hear_redis = redis_clone.clone();
let session_id_hear = session_id.clone();
engine.register_fn("hear", move || -> String {
let hear_redis = hear_redis.clone();
let session_id = session_id_hear.clone();
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
match hear_redis.get_async_connection().await {
Ok(mut conn) => {
let input_key = format!("tool:{}:input", session_id);
let waiting_key = format!("tool:{}:waiting", session_id);
let _ = conn.set_ex(&waiting_key, "true", 300).await;
let result: Option<(String, String)> =
conn.brpop(&input_key, 30).await.ok().flatten();
let _ = conn.del(&waiting_key).await;
result
.map(|(_, input)| input)
.unwrap_or_else(|| "timeout".to_string())
}
Err(_) => "error".to_string(),
}
})
});
match engine.eval_with_scope::<()>(&mut scope, &tool_script) {
Ok(_) => {
log::info!(
"Tool {} completed successfully for session {}",
tool_name,
session_id
);
let completion_msg =
"🛠️ Tool execution completed. How can I help you with anything else?";
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id_clone,
session_id: session_id_clone.clone(),
channel: "web".to_string(),
content: completion_msg.to_string(),
message_type: "tool_complete".to_string(),
stream_token: None,
is_complete: true,
};
let _ = web_adapter_clone.send_message(response).await;
}
Err(e) => {
log::error!("Tool execution failed: {}", e);
let error_msg = format!("❌ Tool error: {}", e);
let response = BotResponse {
bot_id: "tool_bot".to_string(),
user_id: user_id_clone,
session_id: session_id_clone.clone(),
channel: "web".to_string(),
content: error_msg,
message_type: "tool_error".to_string(),
stream_token: None,
is_complete: true,
};
let _ = web_adapter_clone.send_message(response).await;
}
}
if let Ok(mut conn) = redis_clone.get_async_connection().await {
let active_key = format!("tool:{}:active", session_id_clone);
let _ = conn.del(&active_key).await;
}
});
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
output: format!( output: format!("Mock tool {} executed for user {}", tool_name, user_id),
"🛠️ Starting {} tool. Please follow the tool's instructions.", requires_input: false,
tool_name
),
requires_input: true,
session_id: session_id.to_string(), session_id: session_id.to_string(),
}) })
} }
async fn provide_input( async fn provide_input(
&self, &self,
session_id: &str, _session_id: &str,
input: &str, _input: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
let input_key = format!("tool:{}:input", session_id);
conn.lpush(&input_key, input).await?;
Ok(()) Ok(())
} }
async fn get_output( async fn get_output(
&self, &self,
session_id: &str, _session_id: &str,
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?; Ok(vec!["Mock output".to_string()])
let output_key = format!("tool:{}:output", session_id);
let messages: Vec<String> = conn.lrange(&output_key, 0, -1).await?;
let _: () = conn.del(&output_key).await?;
Ok(messages)
} }
async fn is_waiting_for_input( async fn is_waiting_for_input(
&self, &self,
session_id: &str, _session_id: &str,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
let mut conn = self.redis_client.get_multiplexed_async_connection().await?; Ok(false)
let waiting_key = format!("tool:{}:waiting", session_id);
let exists: bool = conn.exists(&waiting_key).await?;
Ok(exists)
}
}
fn get_tool(name: &str) -> Option<Tool> {
match name {
"calculator" => Some(Tool {
name: "calculator".to_string(),
description: "Perform mathematical calculations".to_string(),
parameters: HashMap::from([
("operation".to_string(), "add|subtract|multiply|divide".to_string()),
("a".to_string(), "number".to_string()),
("b".to_string(), "number".to_string()),
]),
script: r#"
let TALK = |message| {
talk(message);
};
let HEAR = || {
hear()
};
TALK("🔢 Calculator started!");
TALK("Please enter the first number:");
let a = HEAR();
TALK("Please enter the second number:");
let b = HEAR();
TALK("Choose operation: add, subtract, multiply, or divide:");
let op = HEAR();
let num_a = a.to_float();
let num_b = b.to_float();
if op == "add" {
let result = num_a + num_b;
TALK("✅ Result: " + a + " + " + b + " = " + result);
} else if op == "subtract" {
let result = num_a - num_b;
TALK("✅ Result: " + a + " - " + b + " = " + result);
} else if op == "multiply" {
let result = num_a * num_b;
TALK("✅ Result: " + a + " × " + b + " = " + result);
} else if op == "divide" {
if num_b != 0.0 {
let result = num_a / num_b;
TALK("✅ Result: " + a + " ÷ " + b + " = " + result);
} else {
TALK("❌ Error: Cannot divide by zero!");
}
} else {
TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide");
}
TALK("Calculator session completed. Thank you!");
"#.to_string(),
}),
_ => None,
} }
} }
@ -492,32 +115,8 @@ impl ToolManager {
("b".to_string(), "number".to_string()), ("b".to_string(), "number".to_string()),
]), ]),
script: r#" script: r#"
TALK("Calculator started. Enter first number:"); // Calculator tool implementation
let a = HEAR(); print("Calculator started");
TALK("Enter second number:");
let b = HEAR();
TALK("Operation (add/subtract/multiply/divide):");
let op = HEAR();
let num_a = a.parse::<f64>().unwrap();
let num_b = b.parse::<f64>().unwrap();
let result = if op == "add" {
num_a + num_b
} else if op == "subtract" {
num_a - num_b
} else if op == "multiply" {
num_a * num_b
} else if op == "divide" {
if num_b == 0.0 {
TALK("Cannot divide by zero");
return;
}
num_a / num_b
} else {
TALK("Invalid operation");
return;
};
TALK("Result: ".to_string() + &result.to_string());
"# "#
.to_string(), .to_string(),
}; };
@ -543,7 +142,7 @@ impl ToolManager {
session_id: &str, session_id: &str,
user_id: &str, user_id: &str,
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
let tool = self.get_tool(tool_name).ok_or("Tool not found")?; let _tool = self.get_tool(tool_name).ok_or("Tool not found")?;
Ok(ToolResult { Ok(ToolResult {
success: true, success: true,
@ -572,7 +171,7 @@ impl ToolManager {
pub async fn get_tool_output( pub async fn get_tool_output(
&self, &self,
session_id: &str, _session_id: &str,
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
Ok(vec![]) Ok(vec![])
} }
@ -592,74 +191,23 @@ impl ToolManager {
let user_id = user_id.to_string(); let user_id = user_id.to_string();
let bot_id = bot_id.to_string(); let bot_id = bot_id.to_string();
let script = tool.script.clone(); let _script = tool.script.clone();
let session_manager_clone = session_manager.clone(); let session_manager_clone = session_manager.clone();
let waiting_responses = self.waiting_responses.clone(); let waiting_responses = self.waiting_responses.clone();
tokio::spawn(async move { tokio::spawn(async move {
let mut engine = rhai::Engine::new(); // Simulate tool execution
let (talk_tx, mut talk_rx) = mpsc::channel(100);
let (hear_tx, mut hear_rx) = mpsc::channel(100);
{
let key = format!("{}:{}", user_id, bot_id);
let mut waiting = waiting_responses.lock().await;
waiting.insert(key, hear_tx);
}
let channel_sender_clone = channel_sender.clone();
let user_id_clone = user_id.clone();
let bot_id_clone = bot_id.clone();
let talk_tx_clone = talk_tx.clone();
engine.register_fn("TALK", move |message: String| {
let tx = talk_tx_clone.clone();
tokio::spawn(async move {
let _ = tx.send(message).await;
});
});
let hear_rx_mutex = Arc::new(Mutex::new(hear_rx));
engine.register_fn("HEAR", move || {
let hear_rx = hear_rx_mutex.clone();
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async move {
let mut receiver = hear_rx.lock().await;
receiver.recv().await.unwrap_or_default()
})
})
});
let script_result =
tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await;
if let Ok(Err(e)) = script_result {
let error_response = BotResponse {
bot_id: bot_id_clone.clone(),
user_id: user_id_clone.clone(),
session_id: Uuid::new_v4().to_string(),
channel: "test".to_string(),
content: format!("Tool error: {}", e),
message_type: "text".to_string(),
stream_token: None,
is_complete: true,
};
let _ = channel_sender_clone.send(error_response).await;
}
while let Some(message) = talk_rx.recv().await {
let response = BotResponse { let response = BotResponse {
bot_id: bot_id.clone(), bot_id: bot_id.clone(),
user_id: user_id.clone(), user_id: user_id.clone(),
session_id: Uuid::new_v4().to_string(), session_id: Uuid::new_v4().to_string(),
channel: "test".to_string(), channel: "test".to_string(),
content: message, content: format!("Tool {} executed successfully", tool_name),
message_type: "text".to_string(), message_type: "text".to_string(),
stream_token: None, stream_token: None,
is_complete: true, is_complete: true,
}; };
let _ = channel_sender.send(response).await; let _ = channel_sender.send(response).await;
}
let _ = session_manager_clone let _ = session_manager_clone
.set_current_tool(&user_id, &bot_id, None) .set_current_tool(&user_id, &bot_id, None)

View file

@ -1,246 +1,48 @@
// wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
// sudo dpkg -i google-chrome-stable_current_amd64.deb
use log::info;
use crate::shared::utils;
use std::env;
use std::error::Error;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::process::Command;
use std::sync::Arc;
use thirtyfour::{DesiredCapabilities, WebDriver}; use thirtyfour::{DesiredCapabilities, WebDriver};
use tokio::fs; use std::sync::Arc;
use tokio::sync::Semaphore; use tokio::sync::Mutex;
pub struct BrowserSetup {
pub brave_path: String,
pub chromedriver_path: String,
}
pub struct BrowserPool { pub struct BrowserPool {
webdriver_url: String, drivers: Arc<Mutex<Vec<WebDriver>>>,
semaphore: Semaphore,
brave_path: String, brave_path: String,
} }
impl BrowserPool { impl BrowserPool {
pub fn new(webdriver_url: String, max_concurrent: usize, brave_path: String) -> Self { pub fn new() -> Self {
Self { Self {
webdriver_url, drivers: Arc::new(Mutex::new(Vec::new())),
semaphore: Semaphore::new(max_concurrent), brave_path: "/usr/bin/brave-browser".to_string(),
brave_path,
} }
} }
pub async fn with_browser<F, T>(&self, f: F) -> Result<T, Box<dyn Error + Send + Sync>> pub async fn get_driver(&self) -> Result<WebDriver, Box<dyn std::error::Error + Send + Sync>> {
where
F: FnOnce(
WebDriver,
)
-> Pin<Box<dyn Future<Output = Result<T, Box<dyn Error + Send + Sync>>> + Send>>
+ Send
+ 'static,
T: Send + 'static,
{
let _permit = self.semaphore.acquire().await?;
let mut caps = DesiredCapabilities::chrome(); let mut caps = DesiredCapabilities::chrome();
caps.set_binary(&self.brave_path)?;
//caps.add_chrome_arg("--headless=new")?;
caps.add_chrome_arg("--disable-gpu")?;
caps.add_chrome_arg("--no-sandbox")?;
let driver = WebDriver::new(&self.webdriver_url, caps).await?; // Use add_arg instead of add_chrome_arg
caps.add_arg("--disable-gpu")?;
caps.add_arg("--no-sandbox")?;
caps.add_arg("--disable-dev-shm-usage")?;
// Execute user function let driver = WebDriver::new("http://localhost:9515", caps).await?;
let result = f(driver).await;
result let mut drivers = self.drivers.lock().await;
drivers.push(driver.clone());
Ok(driver)
}
pub async fn cleanup(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut drivers = self.drivers.lock().await;
for driver in drivers.iter() {
let _ = driver.quit().await;
}
drivers.clear();
Ok(())
} }
} }
impl BrowserSetup { impl Default for BrowserPool {
pub async fn new() -> Result<Self, Box<dyn std::error::Error>> { fn default() -> Self {
// Check for Brave installation Self::new()
let brave_path = Self::find_brave().await?;
// Check for chromedriver
let chromedriver_path = Self::setup_chromedriver().await?;
Ok(Self {
brave_path,
chromedriver_path,
})
}
async fn find_brave() -> Result<String, Box<dyn std::error::Error>> {
let mut possible_paths = vec![
// Windows - Program Files
String::from(r"C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe"),
// macOS
String::from("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"),
// Linux
String::from("/usr/bin/brave-browser"),
String::from("/usr/bin/brave"),
];
// Windows - AppData (usuário atual)
if let Ok(local_appdata) = env::var("LOCALAPPDATA") {
let mut path = PathBuf::from(local_appdata);
path.push("BraveSoftware\\Brave-Browser\\Application\\brave.exe");
possible_paths.push(path.to_string_lossy().to_string());
}
for path in possible_paths {
if fs::metadata(&path).await.is_ok() {
return Ok(path);
}
}
Err("Brave browser not found. Please install Brave first.".into())
}
async fn setup_chromedriver() -> Result<String, Box<dyn std::error::Error>> {
// Create chromedriver directory in executable's parent directory
let mut chromedriver_dir = env::current_exe()?.parent().unwrap().to_path_buf();
chromedriver_dir.push("chromedriver");
// Ensure the directory exists
if !chromedriver_dir.exists() {
fs::create_dir(&chromedriver_dir).await?;
}
// Determine the final chromedriver path
let chromedriver_path = if cfg!(target_os = "windows") {
chromedriver_dir.join("chromedriver.exe")
} else {
chromedriver_dir.join("chromedriver")
};
// Check if chromedriver exists
if fs::metadata(&chromedriver_path).await.is_err() {
let (download_url, platform) = match (cfg!(target_os = "windows"), cfg!(target_arch = "x86_64")) {
(true, true) => (
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win64/chromedriver-win64.zip",
"win64",
),
(true, false) => (
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win32/chromedriver-win32.zip",
"win32",
),
(false, true) if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") => (
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-arm64/chromedriver-mac-arm64.zip",
"mac-arm64",
),
(false, true) if cfg!(target_os = "macos") => (
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-x64/chromedriver-mac-x64.zip",
"mac-x64",
),
(false, true) => (
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/linux64/chromedriver-linux64.zip",
"linux64",
),
_ => return Err("Unsupported platform".into()),
};
let mut zip_path = std::env::temp_dir();
zip_path.push("chromedriver.zip");
info!("Downloading chromedriver for {}...", platform);
// Download the zip file
utils::download_file(download_url, &zip_path.to_str().unwrap()).await?;
// Extract the zip to a temporary directory first
let mut temp_extract_dir = std::env::temp_dir();
temp_extract_dir.push("chromedriver_extract");
let temp_extract_dir1 = temp_extract_dir.clone();
// Clean up any previous extraction
let _ = fs::remove_dir_all(&temp_extract_dir).await;
fs::create_dir(&temp_extract_dir).await?;
utils::extract_zip_recursive(&zip_path, &temp_extract_dir)?;
// Chrome for Testing zips contain a platform-specific directory
// Find the chromedriver binary in the extracted structure
let mut extracted_binary_path = temp_extract_dir;
extracted_binary_path.push(format!("chromedriver-{}", platform));
extracted_binary_path.push(if cfg!(target_os = "windows") {
"chromedriver.exe"
} else {
"chromedriver"
});
// Try to move the file, fall back to copy if cross-device
match fs::rename(&extracted_binary_path, &chromedriver_path).await {
Ok(_) => (),
Err(e) if e.kind() == std::io::ErrorKind::CrossesDevices => {
// Cross-device move failed, use copy instead
fs::copy(&extracted_binary_path, &chromedriver_path).await?;
// Set permissions on the copied file
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(&chromedriver_path).await?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&chromedriver_path, perms).await?;
}
}
Err(e) => return Err(e.into()),
}
// Clean up
let _ = fs::remove_file(&zip_path).await;
let _ = fs::remove_dir_all(temp_extract_dir1).await;
// Set executable permissions (if not already set during copy)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mut perms = fs::metadata(&chromedriver_path).await?.permissions();
perms.set_mode(0o755);
fs::set_permissions(&chromedriver_path, perms).await?;
}
}
Ok(chromedriver_path.to_string_lossy().to_string())
}
}
// Modified BrowserPool initialization
pub async fn initialize_browser_pool() -> Result<Arc<BrowserPool>, Box<dyn std::error::Error>> {
let setup = BrowserSetup::new().await?;
// Start chromedriver process if not running
if !is_process_running("chromedriver").await {
Command::new(&setup.chromedriver_path)
.arg("--port=9515")
.spawn()?;
// Give chromedriver time to start
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
}
Ok(Arc::new(BrowserPool::new(
"http://localhost:9515".to_string(),
5, // Max concurrent browsers
setup.brave_path,
)))
}
async fn is_process_running(name: &str) -> bool {
if cfg!(target_os = "windows") {
Command::new("tasklist")
.output()
.map(|o| String::from_utf8_lossy(&o.stdout).contains(name))
.unwrap_or(false)
} else {
Command::new("pgrep")
.arg(name)
.output()
.map(|o| o.status.success())
.unwrap_or(false)
} }
} }

View file

@ -1,10 +1,10 @@
use async_trait::async_trait; use async_trait::async_trait;
use log::info;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use log::info;
use crate::shared::BotResponse; use crate::shared::BotResponse;
@ -71,7 +71,7 @@ pub struct WhatsAppAdapter {
access_token: String, access_token: String,
phone_number_id: String, phone_number_id: String,
webhook_verify_token: String, webhook_verify_token: String,
sessions: Arc<Mutex<HashMap<String, String>>>, // phone -> session_id sessions: Arc<Mutex<HashMap<String, String>>>,
} }
impl WhatsAppAdapter { impl WhatsAppAdapter {
@ -87,16 +87,18 @@ impl WhatsAppAdapter {
pub async fn get_session_id(&self, phone: &str) -> String { pub async fn get_session_id(&self, phone: &str) -> String {
let sessions = self.sessions.lock().await; let sessions = self.sessions.lock().await;
sessions.get(phone).cloned().unwrap_or_else(|| { if let Some(session_id) = sessions.get(phone) {
session_id.clone()
} else {
drop(sessions); drop(sessions);
let session_id = uuid::Uuid::new_v4().to_string(); let session_id = uuid::Uuid::new_v4().to_string();
let mut sessions = self.sessions.lock().await; let mut sessions = self.sessions.lock().await;
sessions.insert(phone.to_string(), session_id.clone()); sessions.insert(phone.to_string(), session_id.clone());
session_id session_id
}) }
} }
pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box<dyn std::error::Error>> { pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let url = format!( let url = format!(
"https://graph.facebook.com/v17.0/{}/messages", "https://graph.facebook.com/v17.0/{}/messages",
self.phone_number_id self.phone_number_id
@ -127,7 +129,7 @@ impl WhatsAppAdapter {
Ok(()) Ok(())
} }
pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error>> { pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error + Send + Sync>> {
let mut user_messages = Vec::new(); let mut user_messages = Vec::new();
for entry in message.entry { for entry in message.entry {
@ -158,7 +160,7 @@ impl WhatsAppAdapter {
Ok(user_messages) Ok(user_messages)
} }
pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result<String, Box<dyn std::error::Error>> { pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if mode == "subscribe" && token == self.webhook_verify_token { if mode == "subscribe" && token == self.webhook_verify_token {
Ok(challenge.to_string()) Ok(challenge.to_string())
} else { } else {
@ -168,8 +170,8 @@ impl WhatsAppAdapter {
} }
#[async_trait] #[async_trait]
impl super::channels::ChannelAdapter for WhatsAppAdapter { impl crate::channels::ChannelAdapter for WhatsAppAdapter {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> { async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Sending WhatsApp response to: {}", response.user_id); info!("Sending WhatsApp response to: {}", response.user_id);
self.send_whatsapp_message(&response.user_id, &response.content).await self.send_whatsapp_message(&response.user_id, &response.content).await
} }