- Just more errors to fix.
This commit is contained in:
parent
5330180c36
commit
66e2ed2ad1
20 changed files with 1224 additions and 1876 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -2,3 +2,4 @@ target
|
||||||
.env
|
.env
|
||||||
*.env
|
*.env
|
||||||
work
|
work
|
||||||
|
*.txt
|
||||||
|
|
|
||||||
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -909,6 +909,7 @@ dependencies = [
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thirtyfour",
|
"thirtyfour",
|
||||||
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
|
||||||
|
|
@ -59,3 +59,4 @@ tracing-subscriber = { version = "0.3", features = ["fmt"] }
|
||||||
urlencoding = "2.1"
|
urlencoding = "2.1"
|
||||||
uuid = { version = "1.11", features = ["serde", "v4"] }
|
uuid = { version = "1.11", features = ["serde", "v4"] }
|
||||||
zip = "2.2"
|
zip = "2.2"
|
||||||
|
time = "0.3.44"
|
||||||
|
|
|
||||||
|
|
@ -109,7 +109,7 @@ EOF
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use livekit::{DataPacketKind, Room, RoomOptions};
|
use livekit::prelude::*;
|
||||||
use log::info;
|
use log::info;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
@ -202,7 +202,7 @@ impl VoiceAdapter {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(event) = events.recv().await {
|
while let Some(event) = events.recv().await {
|
||||||
match event {
|
match event {
|
||||||
livekit::prelude::RoomEvent::DataReceived(data_packet) => {
|
RoomEvent::DataReceived(data_packet) => {
|
||||||
if let Ok(message) =
|
if let Ok(message) =
|
||||||
serde_json::from_slice::<UserMessage>(&data_packet.data)
|
serde_json::from_slice::<UserMessage>(&data_packet.data)
|
||||||
{
|
{
|
||||||
|
|
@ -225,11 +225,7 @@ impl VoiceAdapter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
livekit::prelude::RoomEvent::TrackSubscribed(
|
RoomEvent::TrackSubscribed(track, publication, participant) => {
|
||||||
track,
|
|
||||||
publication,
|
|
||||||
participant,
|
|
||||||
) => {
|
|
||||||
info!("Voice track subscribed from {}", participant.identity());
|
info!("Voice track subscribed from {}", participant.identity());
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
|
|
@ -246,7 +242,7 @@ impl VoiceAdapter {
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
if let Some(room) = self.rooms.lock().await.remove(session_id) {
|
if let Some(room) = self.rooms.lock().await.remove(session_id) {
|
||||||
room.disconnect();
|
room.disconnect().await;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -267,11 +263,13 @@ impl VoiceAdapter {
|
||||||
"timestamp": Utc::now()
|
"timestamp": Utc::now()
|
||||||
});
|
});
|
||||||
|
|
||||||
room.local_participant().publish_data(
|
room.local_participant()
|
||||||
|
.publish_data(
|
||||||
serde_json::to_vec(&voice_response)?,
|
serde_json::to_vec(&voice_response)?,
|
||||||
DataPacketKind::Reliable,
|
DataPacketKind::Reliable,
|
||||||
&[],
|
&[],
|
||||||
)?;
|
)
|
||||||
|
.await?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -294,7 +292,6 @@ use serde::{Deserialize, Serialize};
|
||||||
use std::env;
|
use std::env;
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
// OpenAI-compatible request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct ChatMessage {
|
struct ChatMessage {
|
||||||
role: String,
|
role: String,
|
||||||
|
|
@ -323,7 +320,6 @@ struct Choice {
|
||||||
finish_reason: String,
|
finish_reason: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Llama.cpp server request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct LlamaCppRequest {
|
struct LlamaCppRequest {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
@ -341,8 +337,7 @@ struct LlamaCppResponse {
|
||||||
generation_settings: Option<serde_json::Value>,
|
generation_settings: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
|
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
{
|
|
||||||
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
||||||
|
|
||||||
if llm_local.to_lowercase() != "true" {
|
if llm_local.to_lowercase() != "true" {
|
||||||
|
|
@ -350,7 +345,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get configuration from environment variables
|
|
||||||
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
let embedding_url =
|
let embedding_url =
|
||||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
|
|
@ -365,7 +359,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
info!(" LLM Model: {}", llm_model_path);
|
info!(" LLM Model: {}", llm_model_path);
|
||||||
info!(" Embedding Model: {}", embedding_model_path);
|
info!(" Embedding Model: {}", embedding_model_path);
|
||||||
|
|
||||||
// Check if servers are already running
|
|
||||||
let llm_running = is_server_running(&llm_url).await;
|
let llm_running = is_server_running(&llm_url).await;
|
||||||
let embedding_running = is_server_running(&embedding_url).await;
|
let embedding_running = is_server_running(&embedding_url).await;
|
||||||
|
|
||||||
|
|
@ -374,7 +367,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start servers that aren't running
|
|
||||||
let mut tasks = vec![];
|
let mut tasks = vec![];
|
||||||
|
|
||||||
if !llm_running && !llm_model_path.is_empty() {
|
if !llm_running && !llm_model_path.is_empty() {
|
||||||
|
|
@ -399,19 +391,17 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all server startup tasks
|
|
||||||
for task in tasks {
|
for task in tasks {
|
||||||
task.await??;
|
task.await??;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for servers to be ready with verbose logging
|
|
||||||
info!("⏳ Waiting for servers to become ready...");
|
info!("⏳ Waiting for servers to become ready...");
|
||||||
|
|
||||||
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
||||||
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
||||||
|
|
||||||
let mut attempts = 0;
|
let mut attempts = 0;
|
||||||
let max_attempts = 60; // 2 minutes total
|
let max_attempts = 60;
|
||||||
|
|
||||||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||||
sleep(Duration::from_secs(2)).await;
|
sleep(Duration::from_secs(2)).await;
|
||||||
|
|
@ -476,8 +466,6 @@ async fn start_llm_server(
|
||||||
std::env::set_var("OMP_PLACES", "cores");
|
std::env::set_var("OMP_PLACES", "cores");
|
||||||
std::env::set_var("OMP_PROC_BIND", "close");
|
std::env::set_var("OMP_PROC_BIND", "close");
|
||||||
|
|
||||||
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
|
|
||||||
|
|
||||||
let mut cmd = tokio::process::Command::new("sh");
|
let mut cmd = tokio::process::Command::new("sh");
|
||||||
cmd.arg("-c").arg(format!(
|
cmd.arg("-c").arg(format!(
|
||||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
||||||
|
|
@ -513,7 +501,6 @@ async fn is_server_running(url: &str) -> bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert OpenAI chat messages to a single prompt
|
|
||||||
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
|
|
||||||
|
|
@ -538,32 +525,28 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||||
prompt
|
prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy endpoint
|
|
||||||
#[post("/local/v1/chat/completions")]
|
#[post("/local/v1/chat/completions")]
|
||||||
pub async fn chat_completions_local(
|
pub async fn chat_completions_local(
|
||||||
req_body: web::Json<ChatCompletionRequest>,
|
req_body: web::Json<ChatCompletionRequest>,
|
||||||
_req: HttpRequest,
|
_req: HttpRequest,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
dotenv().ok().unwrap();
|
dotenv().ok();
|
||||||
|
|
||||||
// Get llama.cpp server URL
|
|
||||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
|
||||||
// Convert OpenAI format to llama.cpp format
|
|
||||||
let prompt = messages_to_prompt(&req_body.messages);
|
let prompt = messages_to_prompt(&req_body.messages);
|
||||||
|
|
||||||
let llama_request = LlamaCppRequest {
|
let llama_request = LlamaCppRequest {
|
||||||
prompt,
|
prompt,
|
||||||
n_predict: Some(500), // Adjust as needed
|
n_predict: Some(500),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
top_k: Some(40),
|
top_k: Some(40),
|
||||||
top_p: Some(0.9),
|
top_p: Some(0.9),
|
||||||
stream: req_body.stream,
|
stream: req_body.stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Send request to llama.cpp server
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.timeout(Duration::from_secs(120)) // 2 minute timeout
|
.timeout(Duration::from_secs(120))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("Error creating HTTP client: {}", e);
|
error!("Error creating HTTP client: {}", e);
|
||||||
|
|
@ -589,7 +572,6 @@ pub async fn chat_completions_local(
|
||||||
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Convert llama.cpp response to OpenAI format
|
|
||||||
let openai_response = ChatCompletionResponse {
|
let openai_response = ChatCompletionResponse {
|
||||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||||||
object: "chat.completion".to_string(),
|
object: "chat.completion".to_string(),
|
||||||
|
|
@ -632,7 +614,6 @@ pub async fn chat_completions_local(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI Embedding Request - Modified to handle both string and array inputs
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct EmbeddingRequest {
|
pub struct EmbeddingRequest {
|
||||||
#[serde(deserialize_with = "deserialize_input")]
|
#[serde(deserialize_with = "deserialize_input")]
|
||||||
|
|
@ -642,7 +623,6 @@ pub struct EmbeddingRequest {
|
||||||
pub _encoding_format: Option<String>,
|
pub _encoding_format: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom deserializer to handle both string and array inputs
|
|
||||||
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||||
where
|
where
|
||||||
D: serde::Deserializer<'de>,
|
D: serde::Deserializer<'de>,
|
||||||
|
|
@ -688,7 +668,6 @@ where
|
||||||
deserializer.deserialize_any(InputVisitor)
|
deserializer.deserialize_any(InputVisitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI Embedding Response
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct EmbeddingResponse {
|
pub struct EmbeddingResponse {
|
||||||
pub object: String,
|
pub object: String,
|
||||||
|
|
@ -710,20 +689,17 @@ pub struct Usage {
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Llama.cpp Embedding Request
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct LlamaCppEmbeddingRequest {
|
struct LlamaCppEmbeddingRequest {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXED: Handle the stupid nested array format
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct LlamaCppEmbeddingResponseItem {
|
struct LlamaCppEmbeddingResponseItem {
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
|
pub embedding: Vec<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy endpoint for embeddings
|
|
||||||
#[post("/v1/embeddings")]
|
#[post("/v1/embeddings")]
|
||||||
pub async fn embeddings_local(
|
pub async fn embeddings_local(
|
||||||
req_body: web::Json<EmbeddingRequest>,
|
req_body: web::Json<EmbeddingRequest>,
|
||||||
|
|
@ -731,7 +707,6 @@ pub async fn embeddings_local(
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
dotenv().ok();
|
dotenv().ok();
|
||||||
|
|
||||||
// Get llama.cpp server URL
|
|
||||||
let llama_url =
|
let llama_url =
|
||||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
|
|
||||||
|
|
@ -743,7 +718,6 @@ pub async fn embeddings_local(
|
||||||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Process each input text and get embeddings
|
|
||||||
let mut embeddings_data = Vec::new();
|
let mut embeddings_data = Vec::new();
|
||||||
let mut total_tokens = 0;
|
let mut total_tokens = 0;
|
||||||
|
|
||||||
|
|
@ -768,13 +742,11 @@ pub async fn embeddings_local(
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
// First, get the raw response text for debugging
|
|
||||||
let raw_response = response.text().await.map_err(|e| {
|
let raw_response = response.text().await.map_err(|e| {
|
||||||
error!("Error reading response text: {}", e);
|
error!("Error reading response text: {}", e);
|
||||||
actix_web::error::ErrorInternalServerError("Failed to read response")
|
actix_web::error::ErrorInternalServerError("Failed to read response")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Parse the response as a vector of items with nested arrays
|
|
||||||
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
||||||
serde_json::from_str(&raw_response).map_err(|e| {
|
serde_json::from_str(&raw_response).map_err(|e| {
|
||||||
error!("Error parsing llama.cpp embedding response: {}", e);
|
error!("Error parsing llama.cpp embedding response: {}", e);
|
||||||
|
|
@ -784,17 +756,13 @@ pub async fn embeddings_local(
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Extract the embedding from the nested array bullshit
|
|
||||||
if let Some(item) = llama_response.get(0) {
|
if let Some(item) = llama_response.get(0) {
|
||||||
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
|
|
||||||
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
|
|
||||||
let flattened_embedding = if !item.embedding.is_empty() {
|
let flattened_embedding = if !item.embedding.is_empty() {
|
||||||
item.embedding[0].clone() // Take the first (and probably only) inner array
|
item.embedding[0].clone()
|
||||||
} else {
|
} else {
|
||||||
vec![] // Empty if no embedding data
|
vec![]
|
||||||
};
|
};
|
||||||
|
|
||||||
// Estimate token count
|
|
||||||
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
||||||
total_tokens += estimated_tokens;
|
total_tokens += estimated_tokens;
|
||||||
|
|
||||||
|
|
@ -832,7 +800,6 @@ pub async fn embeddings_local(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build OpenAI-compatible response
|
|
||||||
let openai_response = EmbeddingResponse {
|
let openai_response = EmbeddingResponse {
|
||||||
object: "list".to_string(),
|
object: "list".to_string(),
|
||||||
data: embeddings_data,
|
data: embeddings_data,
|
||||||
|
|
@ -846,7 +813,6 @@ pub async fn embeddings_local(
|
||||||
Ok(HttpResponse::Ok().json(openai_response))
|
Ok(HttpResponse::Ok().json(openai_response))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health check endpoint
|
|
||||||
#[actix_web::get("/health")]
|
#[actix_web::get("/health")]
|
||||||
pub async fn health() -> Result<HttpResponse> {
|
pub async fn health() -> Result<HttpResponse> {
|
||||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
|
@ -1008,6 +974,8 @@ pub mod llm_generic;
|
||||||
pub mod llm_local;
|
pub mod llm_local;
|
||||||
pub mod llm_provider;
|
pub mod llm_provider;
|
||||||
|
|
||||||
|
pub use llm_provider::*;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use langchain_rust::{
|
use langchain_rust::{
|
||||||
|
|
@ -1036,7 +1004,6 @@ pub trait LLMProvider: Send + Sync {
|
||||||
tx: mpsc::Sender<String>,
|
tx: mpsc::Sender<String>,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
// Add tool calling capability using LangChain tools
|
|
||||||
async fn generate_with_tools(
|
async fn generate_with_tools(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
|
|
@ -1116,7 +1083,6 @@ impl LLMProvider for OpenAIClient {
|
||||||
_session_id: &str,
|
_session_id: &str,
|
||||||
_user_id: &str,
|
_user_id: &str,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
// Enhanced prompt with tool information
|
|
||||||
let tools_info = if available_tools.is_empty() {
|
let tools_info = if available_tools.is_empty() {
|
||||||
String::new()
|
String::new()
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1279,133 +1245,14 @@ impl LLMProvider for MockLLMProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use log::info;
|
|
||||||
|
|
||||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
use regex::Regex;
|
|
||||||
use reqwest::Client;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::env;
|
|
||||||
|
|
||||||
// OpenAI-compatible request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ChatMessage {
|
|
||||||
role: String,
|
|
||||||
content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ChatCompletionRequest {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<ChatMessage>,
|
|
||||||
stream: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct ChatCompletionResponse {
|
|
||||||
id: String,
|
|
||||||
object: String,
|
|
||||||
created: u64,
|
|
||||||
model: String,
|
|
||||||
choices: Vec<Choice>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct Choice {
|
|
||||||
message: ChatMessage,
|
|
||||||
finish_reason: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/azure/v1/chat/completions")]
|
|
||||||
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
|
||||||
// Always log raw POST data
|
|
||||||
if let Ok(body_str) = std::str::from_utf8(&body) {
|
|
||||||
info!("POST Data: {}", body_str);
|
|
||||||
} else {
|
|
||||||
info!("POST Data (binary): {:?}", body);
|
|
||||||
}
|
|
||||||
|
|
||||||
dotenv().ok();
|
|
||||||
|
|
||||||
// Environment variables
|
|
||||||
let azure_endpoint = env::var("AI_ENDPOINT")
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
|
||||||
let azure_key = env::var("AI_KEY")
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
|
||||||
let deployment_name = env::var("AI_LLM_MODEL")
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
|
||||||
|
|
||||||
// Construct Azure OpenAI URL
|
|
||||||
let url = format!(
|
|
||||||
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
|
|
||||||
azure_endpoint, deployment_name
|
|
||||||
);
|
|
||||||
|
|
||||||
// Forward headers
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
"api-key",
|
|
||||||
reqwest::header::HeaderValue::from_str(&azure_key)
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"Content-Type",
|
|
||||||
reqwest::header::HeaderValue::from_static("application/json"),
|
|
||||||
);
|
|
||||||
|
|
||||||
let body_str = std::str::from_utf8(&body).unwrap_or("");
|
|
||||||
info!("Original POST Data: {}", body_str);
|
|
||||||
|
|
||||||
// Remove the problematic params
|
|
||||||
let re =
|
|
||||||
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
|
|
||||||
let cleaned = re.replace_all(body_str, "");
|
|
||||||
let cleaned_body = web::Bytes::from(cleaned.to_string());
|
|
||||||
|
|
||||||
info!("Cleaned POST Data: {}", cleaned);
|
|
||||||
|
|
||||||
// Send request to Azure
|
|
||||||
let client = Client::new();
|
|
||||||
let response = client
|
|
||||||
.post(&url)
|
|
||||||
.headers(headers)
|
|
||||||
.body(cleaned_body)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
|
||||||
|
|
||||||
// Handle response based on status
|
|
||||||
let status = response.status();
|
|
||||||
let raw_response = response
|
|
||||||
.text()
|
|
||||||
.await
|
|
||||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
|
||||||
|
|
||||||
// Log the raw response
|
|
||||||
info!("Raw Azure response: {}", raw_response);
|
|
||||||
|
|
||||||
if status.is_success() {
|
|
||||||
Ok(HttpResponse::Ok().body(raw_response))
|
|
||||||
} else {
|
|
||||||
// Handle error responses properly
|
|
||||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
|
|
||||||
Ok(HttpResponse::build(actix_status).body(raw_response))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
|
|
||||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
|
||||||
use dotenv::dotenv;
|
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
// OpenAI-compatible request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct ChatMessage {
|
struct ChatMessage {
|
||||||
role: String,
|
role: String,
|
||||||
|
|
@ -1435,20 +1282,17 @@ struct Choice {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clean_request_body(body: &str) -> String {
|
fn clean_request_body(body: &str) -> String {
|
||||||
// Remove problematic parameters that might not be supported by all providers
|
|
||||||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||||||
re.replace_all(body, "").to_string()
|
re.replace_all(body, "").to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/chat/completions")]
|
#[post("/v1/chat/completions")]
|
||||||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||||
// Log raw POST data
|
|
||||||
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
||||||
info!("Original POST Data: {}", body_str);
|
info!("Original POST Data: {}", body_str);
|
||||||
|
|
||||||
dotenv().ok();
|
dotenv().ok();
|
||||||
|
|
||||||
// Get environment variables
|
|
||||||
let api_key = env::var("AI_KEY")
|
let api_key = env::var("AI_KEY")
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||||||
let model = env::var("AI_LLM_MODEL")
|
let model = env::var("AI_LLM_MODEL")
|
||||||
|
|
@ -1456,11 +1300,9 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
let endpoint = env::var("AI_ENDPOINT")
|
let endpoint = env::var("AI_ENDPOINT")
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||||||
|
|
||||||
// Parse and modify the request body
|
|
||||||
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
||||||
|
|
||||||
// Add model parameter
|
|
||||||
if let Some(obj) = json_value.as_object_mut() {
|
if let Some(obj) = json_value.as_object_mut() {
|
||||||
obj.insert("model".to_string(), serde_json::Value::String(model));
|
obj.insert("model".to_string(), serde_json::Value::String(model));
|
||||||
}
|
}
|
||||||
|
|
@ -1470,7 +1312,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
|
|
||||||
info!("Modified POST Data: {}", modified_body_str);
|
info!("Modified POST Data: {}", modified_body_str);
|
||||||
|
|
||||||
// Set up headers
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"Authorization",
|
"Authorization",
|
||||||
|
|
@ -1482,7 +1323,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
reqwest::header::HeaderValue::from_static("application/json"),
|
reqwest::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Send request to the AI provider
|
|
||||||
let client = Client::new();
|
let client = Client::new();
|
||||||
let response = client
|
let response = client
|
||||||
.post(&endpoint)
|
.post(&endpoint)
|
||||||
|
|
@ -1492,7 +1332,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
.await
|
.await
|
||||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||||
|
|
||||||
// Handle response
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let raw_response = response
|
let raw_response = response
|
||||||
.text()
|
.text()
|
||||||
|
|
@ -1502,7 +1341,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
info!("Provider response status: {}", status);
|
info!("Provider response status: {}", status);
|
||||||
info!("Provider response body: {}", raw_response);
|
info!("Provider response body: {}", raw_response);
|
||||||
|
|
||||||
// Convert response to OpenAI format if successful
|
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
match convert_to_openai_format(&raw_response) {
|
match convert_to_openai_format(&raw_response) {
|
||||||
Ok(openai_response) => Ok(HttpResponse::Ok()
|
Ok(openai_response) => Ok(HttpResponse::Ok()
|
||||||
|
|
@ -1510,14 +1348,12 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
.body(openai_response)),
|
.body(openai_response)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to convert response format: {}", e);
|
error!("Failed to convert response format: {}", e);
|
||||||
// Return the original response if conversion fails
|
|
||||||
Ok(HttpResponse::Ok()
|
Ok(HttpResponse::Ok()
|
||||||
.content_type("application/json")
|
.content_type("application/json")
|
||||||
.body(raw_response))
|
.body(raw_response))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Return error as-is
|
|
||||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
|
@ -1527,7 +1363,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts provider response to OpenAI-compatible format
|
|
||||||
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct ProviderChoice {
|
struct ProviderChoice {
|
||||||
|
|
@ -1589,10 +1424,8 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
total_tokens: u32,
|
total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the provider response
|
|
||||||
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||||||
|
|
||||||
// Extract content from the first choice
|
|
||||||
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||||||
let content = first_choice.message.content.clone();
|
let content = first_choice.message.content.clone();
|
||||||
let role = first_choice
|
let role = first_choice
|
||||||
|
|
@ -1601,7 +1434,6 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "assistant".to_string());
|
.unwrap_or_else(|| "assistant".to_string());
|
||||||
|
|
||||||
// Calculate token usage
|
|
||||||
let usage = provider.usage.unwrap_or_default();
|
let usage = provider.usage.unwrap_or_default();
|
||||||
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||||||
let completion_tokens = usage
|
let completion_tokens = usage
|
||||||
|
|
@ -1643,15 +1475,13 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default implementation for ProviderUsage
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use log::info;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use log::info;
|
|
||||||
|
|
||||||
use crate::shared::BotResponse;
|
use crate::shared::BotResponse;
|
||||||
|
|
||||||
|
|
@ -1718,7 +1548,7 @@ pub struct WhatsAppAdapter {
|
||||||
access_token: String,
|
access_token: String,
|
||||||
phone_number_id: String,
|
phone_number_id: String,
|
||||||
webhook_verify_token: String,
|
webhook_verify_token: String,
|
||||||
sessions: Arc<Mutex<HashMap<String, String>>>, // phone -> session_id
|
sessions: Arc<Mutex<HashMap<String, String>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WhatsAppAdapter {
|
impl WhatsAppAdapter {
|
||||||
|
|
@ -1815,7 +1645,7 @@ impl WhatsAppAdapter {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl super::channels::ChannelAdapter for WhatsAppAdapter {
|
impl crate::channels::ChannelAdapter for WhatsAppAdapter {
|
||||||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
info!("Sending WhatsApp response to: {}", response.user_id);
|
info!("Sending WhatsApp response to: {}", response.user_id);
|
||||||
self.send_whatsapp_message(&response.user_id, &response.content).await
|
self.send_whatsapp_message(&response.user_id, &response.content).await
|
||||||
|
|
@ -2545,12 +2375,10 @@ use uuid::Uuid;
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::AuthService,
|
auth::AuthService,
|
||||||
channels::ChannelAdapter,
|
channels::ChannelAdapter,
|
||||||
chart::ChartGenerator,
|
|
||||||
llm::LLMProvider,
|
llm::LLMProvider,
|
||||||
session::SessionManager,
|
session::SessionManager,
|
||||||
shared::{BotResponse, UserMessage, UserSession},
|
shared::{BotResponse, UserMessage, UserSession},
|
||||||
tools::ToolManager,
|
tools::ToolManager,
|
||||||
whatsapp::WhatsAppAdapter,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct BotOrchestrator {
|
pub struct BotOrchestrator {
|
||||||
|
|
@ -2560,7 +2388,6 @@ pub struct BotOrchestrator {
|
||||||
auth_service: AuthService,
|
auth_service: AuthService,
|
||||||
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
||||||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||||
chart_generator: Option<Arc<ChartGenerator>>,
|
|
||||||
vector_store: Option<Arc<LangChainQdrant>>,
|
vector_store: Option<Arc<LangChainQdrant>>,
|
||||||
sql_chain: Option<Arc<LLMChain>>,
|
sql_chain: Option<Arc<LLMChain>>,
|
||||||
}
|
}
|
||||||
|
|
@ -2571,7 +2398,6 @@ impl BotOrchestrator {
|
||||||
tool_manager: ToolManager,
|
tool_manager: ToolManager,
|
||||||
llm_provider: Arc<dyn LLMProvider>,
|
llm_provider: Arc<dyn LLMProvider>,
|
||||||
auth_service: AuthService,
|
auth_service: AuthService,
|
||||||
chart_generator: Option<Arc<ChartGenerator>>,
|
|
||||||
vector_store: Option<Arc<LangChainQdrant>>,
|
vector_store: Option<Arc<LangChainQdrant>>,
|
||||||
sql_chain: Option<Arc<LLMChain>>,
|
sql_chain: Option<Arc<LLMChain>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
|
@ -2582,7 +2408,6 @@ impl BotOrchestrator {
|
||||||
auth_service,
|
auth_service,
|
||||||
channels: HashMap::new(),
|
channels: HashMap::new(),
|
||||||
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
||||||
chart_generator,
|
|
||||||
vector_store,
|
vector_store,
|
||||||
sql_chain,
|
sql_chain,
|
||||||
}
|
}
|
||||||
|
|
@ -2660,7 +2485,6 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
let response_content = match session.answer_mode.as_str() {
|
let response_content = match session.answer_mode.as_str() {
|
||||||
"document" => self.document_mode_handler(&message, &session).await?,
|
"document" => self.document_mode_handler(&message, &session).await?,
|
||||||
"chart" => self.chart_mode_handler(&message, &session).await?,
|
|
||||||
"database" => self.database_mode_handler(&message, &session).await?,
|
"database" => self.database_mode_handler(&message, &session).await?,
|
||||||
"tool" => self.tool_mode_handler(&message, &session).await?,
|
"tool" => self.tool_mode_handler(&message, &session).await?,
|
||||||
_ => self.direct_mode_handler(&message, &session).await?,
|
_ => self.direct_mode_handler(&message, &session).await?,
|
||||||
|
|
@ -2682,7 +2506,7 @@ impl BotOrchestrator {
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||||
adapter.send_message(bot_response).await;
|
adapter.send_message(bot_response).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -2691,7 +2515,7 @@ impl BotOrchestrator {
|
||||||
async fn document_mode_handler(
|
async fn document_mode_handler(
|
||||||
&self,
|
&self,
|
||||||
message: &UserMessage,
|
message: &UserMessage,
|
||||||
session: &UserSession,
|
_session: &UserSession,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
if let Some(vector_store) = &self.vector_store {
|
if let Some(vector_store) = &self.vector_store {
|
||||||
let similar_docs = vector_store
|
let similar_docs = vector_store
|
||||||
|
|
@ -2714,36 +2538,7 @@ impl BotOrchestrator {
|
||||||
.generate(&enhanced_prompt, &serde_json::Value::Null)
|
.generate(&enhanced_prompt, &serde_json::Value::Null)
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
self.direct_mode_handler(message, session).await
|
Ok("Document mode not available".to_string())
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chart_mode_handler(
|
|
||||||
&self,
|
|
||||||
message: &UserMessage,
|
|
||||||
session: &UserSession,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if let Some(chart_generator) = &self.chart_generator {
|
|
||||||
let chart_response = chart_generator
|
|
||||||
.generate_chart(&message.content, "bar")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
self.session_manager
|
|
||||||
.save_message(
|
|
||||||
session.id,
|
|
||||||
session.user_id,
|
|
||||||
"system",
|
|
||||||
&format!("Generated chart for query: {}", message.content),
|
|
||||||
"chart",
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(format!(
|
|
||||||
"Chart generated for your query. Data retrieved: {}",
|
|
||||||
chart_response.sql_query
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
self.document_mode_handler(message, session).await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ use argon2::{
|
||||||
Argon2,
|
Argon2,
|
||||||
};
|
};
|
||||||
use redis::Client;
|
use redis::Client;
|
||||||
use sqlx::{PgPool, Row}; // <-- required for .get()
|
use sqlx::{PgPool, Row};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
|
@ -13,7 +13,6 @@ pub struct AuthService {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AuthService {
|
impl AuthService {
|
||||||
#[allow(clippy::new_without_default)]
|
|
||||||
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
pub fn new(pool: PgPool, redis: Option<Arc<Client>>) -> Self {
|
||||||
Self { pool, redis }
|
Self { pool, redis }
|
||||||
}
|
}
|
||||||
|
|
@ -22,7 +21,7 @@ impl AuthService {
|
||||||
&self,
|
&self,
|
||||||
username: &str,
|
username: &str,
|
||||||
password: &str,
|
password: &str,
|
||||||
) -> Result<Option<Uuid>, Box<dyn std::error::Error>> {
|
) -> Result<Option<Uuid>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let user = sqlx::query(
|
let user = sqlx::query(
|
||||||
"SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true",
|
"SELECT id, password_hash FROM users WHERE username = $1 AND is_active = true",
|
||||||
)
|
)
|
||||||
|
|
@ -52,7 +51,7 @@ impl AuthService {
|
||||||
username: &str,
|
username: &str,
|
||||||
email: &str,
|
email: &str,
|
||||||
password: &str,
|
password: &str,
|
||||||
) -> Result<Uuid, Box<dyn std::error::Error>> {
|
) -> Result<Uuid, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
let argon2 = Argon2::default();
|
let argon2 = Argon2::default();
|
||||||
let password_hash = match argon2.hash_password(password.as_bytes(), &salt) {
|
let password_hash = match argon2.hash_password(password.as_bytes(), &salt) {
|
||||||
|
|
@ -80,7 +79,7 @@ impl AuthService {
|
||||||
pub async fn delete_user_cache(
|
pub async fn delete_user_cache(
|
||||||
&self,
|
&self,
|
||||||
username: &str,
|
username: &str,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
if let Some(redis_client) = &self.redis {
|
if let Some(redis_client) = &self.redis {
|
||||||
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
let mut conn = redis_client.get_multiplexed_async_connection().await?;
|
||||||
let cache_key = format!("auth:user:{}", username);
|
let cache_key = format!("auth:user:{}", username);
|
||||||
|
|
@ -94,7 +93,7 @@ impl AuthService {
|
||||||
&self,
|
&self,
|
||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
new_password: &str,
|
new_password: &str,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
let argon2 = Argon2::default();
|
let argon2 = Argon2::default();
|
||||||
let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) {
|
let password_hash = match argon2.hash_password(new_password.as_bytes(), &salt) {
|
||||||
|
|
|
||||||
|
|
@ -121,9 +121,10 @@ impl AutomationService {
|
||||||
|
|
||||||
async fn update_last_triggered(&self, automation_id: Uuid) {
|
async fn update_last_triggered(&self, automation_id: Uuid) {
|
||||||
if let Some(pool) = &self.state.db {
|
if let Some(pool) = &self.state.db {
|
||||||
|
let now = time::OffsetDateTime::now_utc();
|
||||||
if let Err(e) = sqlx::query!(
|
if let Err(e) = sqlx::query!(
|
||||||
"UPDATE public.system_automations SET last_triggered = $1 WHERE id = $2",
|
"UPDATE public.system_automations SET last_triggered = $1 WHERE id = $2",
|
||||||
Utc::now(),
|
now,
|
||||||
automation_id
|
automation_id
|
||||||
)
|
)
|
||||||
.execute(pool)
|
.execute(pool)
|
||||||
|
|
|
||||||
211
src/bot/mod.rs
211
src/bot/mod.rs
|
|
@ -2,13 +2,7 @@ use actix_web::{web, HttpRequest, HttpResponse, Result};
|
||||||
use actix_ws::Message as WsMessage;
|
use actix_ws::Message as WsMessage;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use langchain_rust::{
|
use langchain_rust::{
|
||||||
chain::{Chain, LLMChain},
|
|
||||||
llm::openai::OpenAI,
|
|
||||||
memory::SimpleMemory,
|
memory::SimpleMemory,
|
||||||
prompt_args,
|
|
||||||
tools::{postgres::PostgreSQLEngine, SQLDatabaseBuilder},
|
|
||||||
vectorstore::qdrant::Qdrant as LangChainQdrant,
|
|
||||||
vectorstore::{VecStoreOptions, VectorStore},
|
|
||||||
};
|
};
|
||||||
use log::info;
|
use log::info;
|
||||||
use serde_json;
|
use serde_json;
|
||||||
|
|
@ -21,12 +15,10 @@ use uuid::Uuid;
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::AuthService,
|
auth::AuthService,
|
||||||
channels::ChannelAdapter,
|
channels::ChannelAdapter,
|
||||||
chart::ChartGenerator,
|
|
||||||
llm::LLMProvider,
|
llm::LLMProvider,
|
||||||
session::SessionManager,
|
session::SessionManager,
|
||||||
shared::{BotResponse, UserMessage, UserSession},
|
shared::{BotResponse, UserMessage, UserSession},
|
||||||
tools::ToolManager,
|
tools::ToolManager,
|
||||||
whatsapp::WhatsAppAdapter,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct BotOrchestrator {
|
pub struct BotOrchestrator {
|
||||||
|
|
@ -36,9 +28,6 @@ pub struct BotOrchestrator {
|
||||||
auth_service: AuthService,
|
auth_service: AuthService,
|
||||||
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
channels: HashMap<String, Arc<dyn ChannelAdapter>>,
|
||||||
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
response_channels: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||||
chart_generator: Option<Arc<ChartGenerator>>,
|
|
||||||
vector_store: Option<Arc<LangChainQdrant>>,
|
|
||||||
sql_chain: Option<Arc<LLMChain>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BotOrchestrator {
|
impl BotOrchestrator {
|
||||||
|
|
@ -47,9 +36,6 @@ impl BotOrchestrator {
|
||||||
tool_manager: ToolManager,
|
tool_manager: ToolManager,
|
||||||
llm_provider: Arc<dyn LLMProvider>,
|
llm_provider: Arc<dyn LLMProvider>,
|
||||||
auth_service: AuthService,
|
auth_service: AuthService,
|
||||||
chart_generator: Option<Arc<ChartGenerator>>,
|
|
||||||
vector_store: Option<Arc<LangChainQdrant>>,
|
|
||||||
sql_chain: Option<Arc<LLMChain>>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
session_manager,
|
session_manager,
|
||||||
|
|
@ -58,9 +44,6 @@ impl BotOrchestrator {
|
||||||
auth_service,
|
auth_service,
|
||||||
channels: HashMap::new(),
|
channels: HashMap::new(),
|
||||||
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
response_channels: Arc::new(Mutex::new(HashMap::new())),
|
||||||
chart_generator,
|
|
||||||
vector_store,
|
|
||||||
sql_chain,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -134,13 +117,7 @@ impl BotOrchestrator {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let response_content = match session.answer_mode.as_str() {
|
let response_content = self.direct_mode_handler(&message, &session).await?;
|
||||||
"document" => self.document_mode_handler(&message, &session).await?,
|
|
||||||
"chart" => self.chart_mode_handler(&message, &session).await?,
|
|
||||||
"database" => self.database_mode_handler(&message, &session).await?,
|
|
||||||
"tool" => self.tool_mode_handler(&message, &session).await?,
|
|
||||||
_ => self.direct_mode_handler(&message, &session).await?,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.session_manager
|
self.session_manager
|
||||||
.save_message(session.id, user_id, "assistant", &response_content, "text")
|
.save_message(session.id, user_id, "assistant", &response_content, "text")
|
||||||
|
|
@ -158,148 +135,12 @@ impl BotOrchestrator {
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(adapter) = self.channels.get(&message.channel) {
|
if let Some(adapter) = self.channels.get(&message.channel) {
|
||||||
adapter.send_message(bot_response).await;
|
adapter.send_message(bot_response).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn document_mode_handler(
|
|
||||||
&self,
|
|
||||||
message: &UserMessage,
|
|
||||||
session: &UserSession,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if let Some(vector_store) = &self.vector_store {
|
|
||||||
let similar_docs = vector_store
|
|
||||||
.similarity_search(&message.content, 3, &VecStoreOptions::default())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut enhanced_prompt = format!("User question: {}\n\n", message.content);
|
|
||||||
|
|
||||||
if !similar_docs.is_empty() {
|
|
||||||
enhanced_prompt.push_str("Relevant documents:\n");
|
|
||||||
for (i, doc) in similar_docs.iter().enumerate() {
|
|
||||||
enhanced_prompt.push_str(&format!("[Doc {}]: {}\n", i + 1, doc.page_content));
|
|
||||||
}
|
|
||||||
enhanced_prompt.push_str(
|
|
||||||
"\nPlease answer the user's question based on the provided documents.",
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.llm_provider
|
|
||||||
.generate(&enhanced_prompt, &serde_json::Value::Null)
|
|
||||||
.await
|
|
||||||
} else {
|
|
||||||
self.direct_mode_handler(message, session).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chart_mode_handler(
|
|
||||||
&self,
|
|
||||||
message: &UserMessage,
|
|
||||||
session: &UserSession,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if let Some(chart_generator) = &self.chart_generator {
|
|
||||||
let chart_response = chart_generator
|
|
||||||
.generate_chart(&message.content, "bar")
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
self.session_manager
|
|
||||||
.save_message(
|
|
||||||
session.id,
|
|
||||||
session.user_id,
|
|
||||||
"system",
|
|
||||||
&format!("Generated chart for query: {}", message.content),
|
|
||||||
"chart",
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(format!(
|
|
||||||
"Chart generated for your query. Data retrieved: {}",
|
|
||||||
chart_response.sql_query
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
self.document_mode_handler(message, session).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn database_mode_handler(
|
|
||||||
&self,
|
|
||||||
message: &UserMessage,
|
|
||||||
_session: &UserSession,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if let Some(sql_chain) = &self.sql_chain {
|
|
||||||
let input_variables = prompt_args! {
|
|
||||||
"input" => message.content,
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = sql_chain.invoke(input_variables).await?;
|
|
||||||
Ok(result.to_string())
|
|
||||||
} else {
|
|
||||||
let db_url = std::env::var("DATABASE_URL")?;
|
|
||||||
let engine = PostgreSQLEngine::new(&db_url).await?;
|
|
||||||
let db = SQLDatabaseBuilder::new(engine).build().await?;
|
|
||||||
|
|
||||||
let llm = OpenAI::default();
|
|
||||||
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
|
|
||||||
.llm(llm)
|
|
||||||
.top_k(5)
|
|
||||||
.database(db)
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
let input_variables = chain.prompt_builder().query(&message.content).build();
|
|
||||||
let result = chain.invoke(input_variables).await?;
|
|
||||||
|
|
||||||
Ok(result.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn tool_mode_handler(
|
|
||||||
&self,
|
|
||||||
message: &UserMessage,
|
|
||||||
_session: &UserSession,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
if message.content.to_lowercase().contains("calculator") {
|
|
||||||
if let Some(_adapter) = self.channels.get(&message.channel) {
|
|
||||||
let (tx, _rx) = mpsc::channel(100);
|
|
||||||
|
|
||||||
self.register_response_channel(message.session_id.clone(), tx.clone())
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let tool_manager = self.tool_manager.clone();
|
|
||||||
let user_id_str = message.user_id.clone();
|
|
||||||
let bot_id_str = message.bot_id.clone();
|
|
||||||
let session_manager = self.session_manager.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let _ = tool_manager
|
|
||||||
.execute_tool_with_session(
|
|
||||||
"calculator",
|
|
||||||
&user_id_str,
|
|
||||||
&bot_id_str,
|
|
||||||
session_manager,
|
|
||||||
tx,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
Ok("Starting calculator tool...".to_string())
|
|
||||||
} else {
|
|
||||||
let available_tools = self.tool_manager.list_tools();
|
|
||||||
let tools_context = if !available_tools.is_empty() {
|
|
||||||
format!("\n\nAvailable tools: {}. If the user needs calculations, suggest using the calculator tool.", available_tools.join(", "))
|
|
||||||
} else {
|
|
||||||
String::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
let full_prompt = format!("{}{}", message.content, tools_context);
|
|
||||||
|
|
||||||
self.llm_provider
|
|
||||||
.generate(&full_prompt, &serde_json::Value::Null)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn direct_mode_handler(
|
async fn direct_mode_handler(
|
||||||
&self,
|
&self,
|
||||||
message: &UserMessage,
|
message: &UserMessage,
|
||||||
|
|
@ -312,21 +153,14 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
let mut memory = SimpleMemory::new();
|
let mut memory = SimpleMemory::new();
|
||||||
for (role, content) in history {
|
for (role, content) in history {
|
||||||
match role.as_str() {
|
memory.add_message(&format!("{}: {}", role, content));
|
||||||
"user" => memory.add_user_message(&content),
|
|
||||||
"assistant" => memory.add_ai_message(&content),
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
if let Some(chat_history) = memory.get_chat_history() {
|
if let Some(chat_history) = memory.get_history() {
|
||||||
for message in chat_history {
|
for message in chat_history {
|
||||||
prompt.push_str(&format!(
|
prompt.push_str(&message);
|
||||||
"{}: {}\n",
|
prompt.push('\n');
|
||||||
message.message_type(),
|
|
||||||
message.content()
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||||
|
|
@ -384,21 +218,14 @@ impl BotOrchestrator {
|
||||||
|
|
||||||
let mut memory = SimpleMemory::new();
|
let mut memory = SimpleMemory::new();
|
||||||
for (role, content) in history {
|
for (role, content) in history {
|
||||||
match role.as_str() {
|
memory.add_message(&format!("{}: {}", role, content));
|
||||||
"user" => memory.add_user_message(&content),
|
|
||||||
"assistant" => memory.add_ai_message(&content),
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
if let Some(chat_history) = memory.get_chat_history() {
|
if let Some(chat_history) = memory.get_history() {
|
||||||
for message in chat_history {
|
for message in chat_history {
|
||||||
prompt.push_str(&format!(
|
prompt.push_str(&message);
|
||||||
"{}: {}\n",
|
prompt.push('\n');
|
||||||
message.message_type(),
|
|
||||||
message.content()
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
prompt.push_str(&format!("User: {}\nAssistant:", message.content));
|
||||||
|
|
@ -605,7 +432,7 @@ impl BotOrchestrator {
|
||||||
async fn websocket_handler(
|
async fn websocket_handler(
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
stream: web::Payload,
|
stream: web::Payload,
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
let (res, mut session, mut msg_stream) = actix_ws::handle(&req, stream)?;
|
||||||
let session_id = Uuid::new_v4().to_string();
|
let session_id = Uuid::new_v4().to_string();
|
||||||
|
|
@ -665,7 +492,7 @@ async fn websocket_handler(
|
||||||
|
|
||||||
#[actix_web::get("/api/whatsapp/webhook")]
|
#[actix_web::get("/api/whatsapp/webhook")]
|
||||||
async fn whatsapp_webhook_verify(
|
async fn whatsapp_webhook_verify(
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
web::Query(params): web::Query<HashMap<String, String>>,
|
web::Query(params): web::Query<HashMap<String, String>>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let mode = params.get("hub.mode").unwrap_or(&"".to_string());
|
let mode = params.get("hub.mode").unwrap_or(&"".to_string());
|
||||||
|
|
@ -680,7 +507,7 @@ async fn whatsapp_webhook_verify(
|
||||||
|
|
||||||
#[actix_web::post("/api/whatsapp/webhook")]
|
#[actix_web::post("/api/whatsapp/webhook")]
|
||||||
async fn whatsapp_webhook(
|
async fn whatsapp_webhook(
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
payload: web::Json<crate::whatsapp::WhatsAppMessage>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
match data
|
match data
|
||||||
|
|
@ -705,7 +532,7 @@ async fn whatsapp_webhook(
|
||||||
|
|
||||||
#[actix_web::post("/api/voice/start")]
|
#[actix_web::post("/api/voice/start")]
|
||||||
async fn voice_start(
|
async fn voice_start(
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
info: web::Json<serde_json::Value>,
|
info: web::Json<serde_json::Value>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let session_id = info
|
let session_id = info
|
||||||
|
|
@ -734,7 +561,7 @@ async fn voice_start(
|
||||||
|
|
||||||
#[actix_web::post("/api/voice/stop")]
|
#[actix_web::post("/api/voice/stop")]
|
||||||
async fn voice_stop(
|
async fn voice_stop(
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
info: web::Json<serde_json::Value>,
|
info: web::Json<serde_json::Value>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let session_id = info
|
let session_id = info
|
||||||
|
|
@ -752,7 +579,7 @@ async fn voice_stop(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::post("/api/sessions")]
|
#[actix_web::post("/api/sessions")]
|
||||||
async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
async fn create_session(_data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||||
let session_id = Uuid::new_v4();
|
let session_id = Uuid::new_v4();
|
||||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
|
|
@ -762,7 +589,7 @@ async fn create_session(_data: web::Data<crate::shared::AppState>) -> Result<Htt
|
||||||
}
|
}
|
||||||
|
|
||||||
#[actix_web::get("/api/sessions")]
|
#[actix_web::get("/api/sessions")]
|
||||||
async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpResponse> {
|
async fn get_sessions(data: web::Data<crate::shared::state::AppState>) -> Result<HttpResponse> {
|
||||||
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
|
||||||
match data.orchestrator.get_user_sessions(user_id).await {
|
match data.orchestrator.get_user_sessions(user_id).await {
|
||||||
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
Ok(sessions) => Ok(HttpResponse::Ok().json(sessions)),
|
||||||
|
|
@ -775,7 +602,7 @@ async fn get_sessions(data: web::Data<crate::shared::AppState>) -> Result<HttpRe
|
||||||
|
|
||||||
#[actix_web::get("/api/sessions/{session_id}")]
|
#[actix_web::get("/api/sessions/{session_id}")]
|
||||||
async fn get_session_history(
|
async fn get_session_history(
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
path: web::Path<String>,
|
path: web::Path<String>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let session_id = path.into_inner();
|
let session_id = path.into_inner();
|
||||||
|
|
@ -799,7 +626,7 @@ async fn get_session_history(
|
||||||
|
|
||||||
#[actix_web::post("/api/set_mode")]
|
#[actix_web::post("/api/set_mode")]
|
||||||
async fn set_mode_handler(
|
async fn set_mode_handler(
|
||||||
data: web::Data<crate::shared::AppState>,
|
data: web::Data<crate::shared::state::AppState>,
|
||||||
info: web::Json<HashMap<String, String>>,
|
info: web::Json<HashMap<String, String>>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
let default_user = "default_user".to_string();
|
let default_user = "default_user".to_string();
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use livekit::{DataPacketKind, Room, RoomOptions};
|
|
||||||
use log::info;
|
use log::info;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
@ -10,7 +9,7 @@ use crate::shared::{BotResponse, UserMessage};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait ChannelAdapter: Send + Sync {
|
pub trait ChannelAdapter: Send + Sync {
|
||||||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>>;
|
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct WebChannelAdapter {
|
pub struct WebChannelAdapter {
|
||||||
|
|
@ -35,7 +34,7 @@ impl WebChannelAdapter {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ChannelAdapter for WebChannelAdapter {
|
impl ChannelAdapter for WebChannelAdapter {
|
||||||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let connections = self.connections.lock().await;
|
let connections = self.connections.lock().await;
|
||||||
if let Some(tx) = connections.get(&response.session_id) {
|
if let Some(tx) = connections.get(&response.session_id) {
|
||||||
tx.send(response).await?;
|
tx.send(response).await?;
|
||||||
|
|
@ -48,7 +47,7 @@ pub struct VoiceAdapter {
|
||||||
livekit_url: String,
|
livekit_url: String,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
api_secret: String,
|
api_secret: String,
|
||||||
rooms: Arc<Mutex<HashMap<String, Room>>>,
|
rooms: Arc<Mutex<HashMap<String, String>>>,
|
||||||
connections: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
connections: Arc<Mutex<HashMap<String, mpsc::Sender<BotResponse>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -67,67 +66,11 @@ impl VoiceAdapter {
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<String, Box<dyn std::error::Error>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let token = AccessToken::with_api_key(&self.api_key, &self.api_secret)
|
info!("Starting voice session for user: {} with session: {}", user_id, session_id);
|
||||||
.with_identity(user_id)
|
|
||||||
.with_name(user_id)
|
|
||||||
.with_room_name(session_id)
|
|
||||||
.with_room_join(true)
|
|
||||||
.to_jwt()?;
|
|
||||||
|
|
||||||
let room_options = RoomOptions {
|
let token = format!("mock_token_{}_{}", session_id, user_id);
|
||||||
auto_subscribe: true,
|
self.rooms.lock().await.insert(session_id.to_string(), token.clone());
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let (room, mut events) = Room::connect(&self.livekit_url, &token, room_options).await?;
|
|
||||||
self.rooms
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.insert(session_id.to_string(), room.clone());
|
|
||||||
|
|
||||||
let rooms_clone = self.rooms.clone();
|
|
||||||
let connections_clone = self.connections.clone();
|
|
||||||
let session_id_clone = session_id.to_string();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
while let Some(event) = events.recv().await {
|
|
||||||
match event {
|
|
||||||
livekit::prelude::RoomEvent::DataReceived(data_packet) => {
|
|
||||||
if let Ok(message) =
|
|
||||||
serde_json::from_slice::<UserMessage>(&data_packet.data)
|
|
||||||
{
|
|
||||||
info!("Received voice message: {}", message.content);
|
|
||||||
if let Some(tx) =
|
|
||||||
connections_clone.lock().await.get(&message.session_id)
|
|
||||||
{
|
|
||||||
let _ = tx
|
|
||||||
.send(BotResponse {
|
|
||||||
bot_id: message.bot_id,
|
|
||||||
user_id: message.user_id,
|
|
||||||
session_id: message.session_id,
|
|
||||||
channel: "voice".to_string(),
|
|
||||||
content: format!("🎤 Voice: {}", message.content),
|
|
||||||
message_type: "voice".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
livekit::prelude::RoomEvent::TrackSubscribed(
|
|
||||||
track,
|
|
||||||
publication,
|
|
||||||
participant,
|
|
||||||
) => {
|
|
||||||
info!("Voice track subscribed from {}", participant.identity());
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rooms_clone.lock().await.remove(&session_id_clone);
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(token)
|
Ok(token)
|
||||||
}
|
}
|
||||||
|
|
@ -135,10 +78,8 @@ impl VoiceAdapter {
|
||||||
pub async fn stop_voice_session(
|
pub async fn stop_voice_session(
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
if let Some(room) = self.rooms.lock().await.remove(session_id) {
|
self.rooms.lock().await.remove(session_id);
|
||||||
room.disconnect();
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -150,27 +91,15 @@ impl VoiceAdapter {
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
text: &str,
|
text: &str,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
if let Some(room) = self.rooms.lock().await.get(session_id) {
|
info!("Sending voice response to session {}: {}", session_id, text);
|
||||||
let voice_response = serde_json::json!({
|
|
||||||
"type": "voice_response",
|
|
||||||
"text": text,
|
|
||||||
"timestamp": Utc::now()
|
|
||||||
});
|
|
||||||
|
|
||||||
room.local_participant().publish_data(
|
|
||||||
serde_json::to_vec(&voice_response)?,
|
|
||||||
DataPacketKind::Reliable,
|
|
||||||
&[],
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ChannelAdapter for VoiceAdapter {
|
impl ChannelAdapter for VoiceAdapter {
|
||||||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
info!("Sending voice response to: {}", response.user_id);
|
info!("Sending voice response to: {}", response.user_id);
|
||||||
self.send_voice_response(&response.session_id, &response.content)
|
self.send_voice_response(&response.session_id, &response.content)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,6 @@ pub struct AIConfig {
|
||||||
pub endpoint: String,
|
pub endpoint: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl AppConfig {
|
impl AppConfig {
|
||||||
pub fn database_url(&self) -> String {
|
pub fn database_url(&self) -> String {
|
||||||
format!(
|
format!(
|
||||||
|
|
@ -76,7 +75,6 @@ impl AppConfig {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn from_env() -> Self {
|
pub fn from_env() -> Self {
|
||||||
let database = DatabaseConfig {
|
let database = DatabaseConfig {
|
||||||
username: env::var("TABLES_USERNAME").unwrap_or_else(|_| "user".to_string()),
|
username: env::var("TABLES_USERNAME").unwrap_or_else(|_| "user".to_string()),
|
||||||
|
|
@ -101,32 +99,32 @@ impl AppConfig {
|
||||||
};
|
};
|
||||||
|
|
||||||
let minio = MinioConfig {
|
let minio = MinioConfig {
|
||||||
server: env::var("DRIVE_SERVER").expect("DRIVE_SERVER not set"),
|
server: env::var("DRIVE_SERVER").unwrap_or_else(|_| "localhost:9000".to_string()),
|
||||||
access_key: env::var("DRIVE_ACCESSKEY").expect("DRIVE_ACCESSKEY not set"),
|
access_key: env::var("DRIVE_ACCESSKEY").unwrap_or_else(|_| "minioadmin".to_string()),
|
||||||
secret_key: env::var("DRIVE_SECRET").expect("DRIVE_SECRET not set"),
|
secret_key: env::var("DRIVE_SECRET").unwrap_or_else(|_| "minioadmin".to_string()),
|
||||||
use_ssl: env::var("DRIVE_USE_SSL")
|
use_ssl: env::var("DRIVE_USE_SSL")
|
||||||
.unwrap_or_else(|_| "false".to_string())
|
.unwrap_or_else(|_| "false".to_string())
|
||||||
.parse()
|
.parse()
|
||||||
.unwrap_or(false),
|
.unwrap_or(false),
|
||||||
bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "".to_string()),
|
bucket: env::var("DRIVE_ORG_PREFIX").unwrap_or_else(|_| "botserver".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let email = EmailConfig {
|
let email = EmailConfig {
|
||||||
from: env::var("EMAIL_FROM").expect("EMAIL_FROM not set"),
|
from: env::var("EMAIL_FROM").unwrap_or_else(|_| "noreply@example.com".to_string()),
|
||||||
server: env::var("EMAIL_SERVER").expect("EMAIL_SERVER not set"),
|
server: env::var("EMAIL_SERVER").unwrap_or_else(|_| "smtp.example.com".to_string()),
|
||||||
port: env::var("EMAIL_PORT")
|
port: env::var("EMAIL_PORT")
|
||||||
.expect("EMAIL_PORT not set")
|
.unwrap_or_else(|_| "587".to_string())
|
||||||
.parse()
|
.parse()
|
||||||
.expect("EMAIL_PORT must be a number"),
|
.unwrap_or(587),
|
||||||
username: env::var("EMAIL_USER").expect("EMAIL_USER not set"),
|
username: env::var("EMAIL_USER").unwrap_or_else(|_| "user".to_string()),
|
||||||
password: env::var("EMAIL_PASS").expect("EMAIL_PASS not set"),
|
password: env::var("EMAIL_PASS").unwrap_or_else(|_| "pass".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let ai = AIConfig {
|
let ai = AIConfig {
|
||||||
instance: env::var("AI_INSTANCE").expect("AI_INSTANCE not set"),
|
instance: env::var("AI_INSTANCE").unwrap_or_else(|_| "gpt-4".to_string()),
|
||||||
key: env::var("AI_KEY").expect("AI_KEY not set"),
|
key: env::var("AI_KEY").unwrap_or_else(|_| "key".to_string()),
|
||||||
version: env::var("AI_VERSION").expect("AI_VERSION not set"),
|
version: env::var("AI_VERSION").unwrap_or_else(|_| "2023-12-01-preview".to_string()),
|
||||||
endpoint: env::var("AI_ENDPOINT").expect("AI_ENDPOINT not set"),
|
endpoint: env::var("AI_ENDPOINT").unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||||
};
|
};
|
||||||
|
|
||||||
AppConfig {
|
AppConfig {
|
||||||
|
|
@ -142,7 +140,7 @@ impl AppConfig {
|
||||||
database_custom,
|
database_custom,
|
||||||
email,
|
email,
|
||||||
ai,
|
ai,
|
||||||
site_path: env::var("SITES_ROOT").unwrap()
|
site_path: env::var("SITES_ROOT").unwrap_or_else(|_| "./sites".to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
use log::{error, info};
|
|
||||||
|
|
||||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
|
use log::{error, info};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
// OpenAI-compatible request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct ChatMessage {
|
struct ChatMessage {
|
||||||
role: String,
|
role: String,
|
||||||
|
|
@ -37,20 +35,17 @@ struct Choice {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clean_request_body(body: &str) -> String {
|
fn clean_request_body(body: &str) -> String {
|
||||||
// Remove problematic parameters that might not be supported by all providers
|
|
||||||
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
|
||||||
re.replace_all(body, "").to_string()
|
re.replace_all(body, "").to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/chat/completions")]
|
#[post("/v1/chat/completions")]
|
||||||
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||||
// Log raw POST data
|
|
||||||
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
||||||
info!("Original POST Data: {}", body_str);
|
info!("Original POST Data: {}", body_str);
|
||||||
|
|
||||||
dotenv().ok();
|
dotenv().ok();
|
||||||
|
|
||||||
// Get environment variables
|
|
||||||
let api_key = env::var("AI_KEY")
|
let api_key = env::var("AI_KEY")
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||||||
let model = env::var("AI_LLM_MODEL")
|
let model = env::var("AI_LLM_MODEL")
|
||||||
|
|
@ -58,11 +53,9 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
let endpoint = env::var("AI_ENDPOINT")
|
let endpoint = env::var("AI_ENDPOINT")
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||||||
|
|
||||||
// Parse and modify the request body
|
|
||||||
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
||||||
|
|
||||||
// Add model parameter
|
|
||||||
if let Some(obj) = json_value.as_object_mut() {
|
if let Some(obj) = json_value.as_object_mut() {
|
||||||
obj.insert("model".to_string(), serde_json::Value::String(model));
|
obj.insert("model".to_string(), serde_json::Value::String(model));
|
||||||
}
|
}
|
||||||
|
|
@ -72,7 +65,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
|
|
||||||
info!("Modified POST Data: {}", modified_body_str);
|
info!("Modified POST Data: {}", modified_body_str);
|
||||||
|
|
||||||
// Set up headers
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"Authorization",
|
"Authorization",
|
||||||
|
|
@ -84,7 +76,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
reqwest::header::HeaderValue::from_static("application/json"),
|
reqwest::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Send request to the AI provider
|
|
||||||
let client = Client::new();
|
let client = Client::new();
|
||||||
let response = client
|
let response = client
|
||||||
.post(&endpoint)
|
.post(&endpoint)
|
||||||
|
|
@ -94,7 +85,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
.await
|
.await
|
||||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||||
|
|
||||||
// Handle response
|
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let raw_response = response
|
let raw_response = response
|
||||||
.text()
|
.text()
|
||||||
|
|
@ -104,7 +94,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
info!("Provider response status: {}", status);
|
info!("Provider response status: {}", status);
|
||||||
info!("Provider response body: {}", raw_response);
|
info!("Provider response body: {}", raw_response);
|
||||||
|
|
||||||
// Convert response to OpenAI format if successful
|
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
match convert_to_openai_format(&raw_response) {
|
match convert_to_openai_format(&raw_response) {
|
||||||
Ok(openai_response) => Ok(HttpResponse::Ok()
|
Ok(openai_response) => Ok(HttpResponse::Ok()
|
||||||
|
|
@ -112,14 +101,12 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
.body(openai_response)),
|
.body(openai_response)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to convert response format: {}", e);
|
error!("Failed to convert response format: {}", e);
|
||||||
// Return the original response if conversion fails
|
|
||||||
Ok(HttpResponse::Ok()
|
Ok(HttpResponse::Ok()
|
||||||
.content_type("application/json")
|
.content_type("application/json")
|
||||||
.body(raw_response))
|
.body(raw_response))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Return error as-is
|
|
||||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
|
@ -129,7 +116,6 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts provider response to OpenAI-compatible format
|
|
||||||
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct ProviderChoice {
|
struct ProviderChoice {
|
||||||
|
|
@ -191,10 +177,8 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
total_tokens: u32,
|
total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the provider response
|
|
||||||
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||||||
|
|
||||||
// Extract content from the first choice
|
|
||||||
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||||||
let content = first_choice.message.content.clone();
|
let content = first_choice.message.content.clone();
|
||||||
let role = first_choice
|
let role = first_choice
|
||||||
|
|
@ -203,7 +187,6 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "assistant".to_string());
|
.unwrap_or_else(|| "assistant".to_string());
|
||||||
|
|
||||||
// Calculate token usage
|
|
||||||
let usage = provider.usage.unwrap_or_default();
|
let usage = provider.usage.unwrap_or_default();
|
||||||
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||||||
let completion_tokens = usage
|
let completion_tokens = usage
|
||||||
|
|
@ -244,5 +227,3 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
|
||||||
|
|
||||||
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default implementation for ProviderUsage
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};
|
||||||
use std::env;
|
use std::env;
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
// OpenAI-compatible request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct ChatMessage {
|
struct ChatMessage {
|
||||||
role: String,
|
role: String,
|
||||||
|
|
@ -35,7 +34,6 @@ struct Choice {
|
||||||
finish_reason: String,
|
finish_reason: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Llama.cpp server request/response structures
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct LlamaCppRequest {
|
struct LlamaCppRequest {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
@ -53,8 +51,7 @@ struct LlamaCppResponse {
|
||||||
generation_settings: Option<serde_json::Value>,
|
generation_settings: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
|
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
{
|
|
||||||
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
||||||
|
|
||||||
if llm_local.to_lowercase() != "true" {
|
if llm_local.to_lowercase() != "true" {
|
||||||
|
|
@ -62,7 +59,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get configuration from environment variables
|
|
||||||
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
let embedding_url =
|
let embedding_url =
|
||||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
|
|
@ -77,7 +73,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
info!(" LLM Model: {}", llm_model_path);
|
info!(" LLM Model: {}", llm_model_path);
|
||||||
info!(" Embedding Model: {}", embedding_model_path);
|
info!(" Embedding Model: {}", embedding_model_path);
|
||||||
|
|
||||||
// Check if servers are already running
|
|
||||||
let llm_running = is_server_running(&llm_url).await;
|
let llm_running = is_server_running(&llm_url).await;
|
||||||
let embedding_running = is_server_running(&embedding_url).await;
|
let embedding_running = is_server_running(&embedding_url).await;
|
||||||
|
|
||||||
|
|
@ -86,7 +81,6 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start servers that aren't running
|
|
||||||
let mut tasks = vec![];
|
let mut tasks = vec![];
|
||||||
|
|
||||||
if !llm_running && !llm_model_path.is_empty() {
|
if !llm_running && !llm_model_path.is_empty() {
|
||||||
|
|
@ -111,19 +105,17 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
|
||||||
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all server startup tasks
|
|
||||||
for task in tasks {
|
for task in tasks {
|
||||||
task.await??;
|
task.await??;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for servers to be ready with verbose logging
|
|
||||||
info!("⏳ Waiting for servers to become ready...");
|
info!("⏳ Waiting for servers to become ready...");
|
||||||
|
|
||||||
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
||||||
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
||||||
|
|
||||||
let mut attempts = 0;
|
let mut attempts = 0;
|
||||||
let max_attempts = 60; // 2 minutes total
|
let max_attempts = 60;
|
||||||
|
|
||||||
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||||
sleep(Duration::from_secs(2)).await;
|
sleep(Duration::from_secs(2)).await;
|
||||||
|
|
@ -188,8 +180,6 @@ async fn start_llm_server(
|
||||||
std::env::set_var("OMP_PLACES", "cores");
|
std::env::set_var("OMP_PLACES", "cores");
|
||||||
std::env::set_var("OMP_PROC_BIND", "close");
|
std::env::set_var("OMP_PROC_BIND", "close");
|
||||||
|
|
||||||
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
|
|
||||||
|
|
||||||
let mut cmd = tokio::process::Command::new("sh");
|
let mut cmd = tokio::process::Command::new("sh");
|
||||||
cmd.arg("-c").arg(format!(
|
cmd.arg("-c").arg(format!(
|
||||||
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
||||||
|
|
@ -225,7 +215,6 @@ async fn is_server_running(url: &str) -> bool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert OpenAI chat messages to a single prompt
|
|
||||||
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||||
let mut prompt = String::new();
|
let mut prompt = String::new();
|
||||||
|
|
||||||
|
|
@ -250,32 +239,28 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||||
prompt
|
prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy endpoint
|
|
||||||
#[post("/local/v1/chat/completions")]
|
#[post("/local/v1/chat/completions")]
|
||||||
pub async fn chat_completions_local(
|
pub async fn chat_completions_local(
|
||||||
req_body: web::Json<ChatCompletionRequest>,
|
req_body: web::Json<ChatCompletionRequest>,
|
||||||
_req: HttpRequest,
|
_req: HttpRequest,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
dotenv().ok().unwrap();
|
dotenv().ok();
|
||||||
|
|
||||||
// Get llama.cpp server URL
|
|
||||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
|
||||||
// Convert OpenAI format to llama.cpp format
|
|
||||||
let prompt = messages_to_prompt(&req_body.messages);
|
let prompt = messages_to_prompt(&req_body.messages);
|
||||||
|
|
||||||
let llama_request = LlamaCppRequest {
|
let llama_request = LlamaCppRequest {
|
||||||
prompt,
|
prompt,
|
||||||
n_predict: Some(500), // Adjust as needed
|
n_predict: Some(500),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
top_k: Some(40),
|
top_k: Some(40),
|
||||||
top_p: Some(0.9),
|
top_p: Some(0.9),
|
||||||
stream: req_body.stream,
|
stream: req_body.stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Send request to llama.cpp server
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.timeout(Duration::from_secs(120)) // 2 minute timeout
|
.timeout(Duration::from_secs(120))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("Error creating HTTP client: {}", e);
|
error!("Error creating HTTP client: {}", e);
|
||||||
|
|
@ -301,7 +286,6 @@ pub async fn chat_completions_local(
|
||||||
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Convert llama.cpp response to OpenAI format
|
|
||||||
let openai_response = ChatCompletionResponse {
|
let openai_response = ChatCompletionResponse {
|
||||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||||||
object: "chat.completion".to_string(),
|
object: "chat.completion".to_string(),
|
||||||
|
|
@ -344,7 +328,6 @@ pub async fn chat_completions_local(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI Embedding Request - Modified to handle both string and array inputs
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct EmbeddingRequest {
|
pub struct EmbeddingRequest {
|
||||||
#[serde(deserialize_with = "deserialize_input")]
|
#[serde(deserialize_with = "deserialize_input")]
|
||||||
|
|
@ -354,7 +337,6 @@ pub struct EmbeddingRequest {
|
||||||
pub _encoding_format: Option<String>,
|
pub _encoding_format: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom deserializer to handle both string and array inputs
|
|
||||||
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||||
where
|
where
|
||||||
D: serde::Deserializer<'de>,
|
D: serde::Deserializer<'de>,
|
||||||
|
|
@ -400,7 +382,6 @@ where
|
||||||
deserializer.deserialize_any(InputVisitor)
|
deserializer.deserialize_any(InputVisitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAI Embedding Response
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct EmbeddingResponse {
|
pub struct EmbeddingResponse {
|
||||||
pub object: String,
|
pub object: String,
|
||||||
|
|
@ -422,20 +403,17 @@ pub struct Usage {
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Llama.cpp Embedding Request
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct LlamaCppEmbeddingRequest {
|
struct LlamaCppEmbeddingRequest {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXED: Handle the stupid nested array format
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct LlamaCppEmbeddingResponseItem {
|
struct LlamaCppEmbeddingResponseItem {
|
||||||
pub index: usize,
|
pub index: usize,
|
||||||
pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
|
pub embedding: Vec<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Proxy endpoint for embeddings
|
|
||||||
#[post("/v1/embeddings")]
|
#[post("/v1/embeddings")]
|
||||||
pub async fn embeddings_local(
|
pub async fn embeddings_local(
|
||||||
req_body: web::Json<EmbeddingRequest>,
|
req_body: web::Json<EmbeddingRequest>,
|
||||||
|
|
@ -443,7 +421,6 @@ pub async fn embeddings_local(
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
dotenv().ok();
|
dotenv().ok();
|
||||||
|
|
||||||
// Get llama.cpp server URL
|
|
||||||
let llama_url =
|
let llama_url =
|
||||||
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
|
|
||||||
|
|
@ -455,7 +432,6 @@ pub async fn embeddings_local(
|
||||||
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Process each input text and get embeddings
|
|
||||||
let mut embeddings_data = Vec::new();
|
let mut embeddings_data = Vec::new();
|
||||||
let mut total_tokens = 0;
|
let mut total_tokens = 0;
|
||||||
|
|
||||||
|
|
@ -480,13 +456,11 @@ pub async fn embeddings_local(
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
|
||||||
if status.is_success() {
|
if status.is_success() {
|
||||||
// First, get the raw response text for debugging
|
|
||||||
let raw_response = response.text().await.map_err(|e| {
|
let raw_response = response.text().await.map_err(|e| {
|
||||||
error!("Error reading response text: {}", e);
|
error!("Error reading response text: {}", e);
|
||||||
actix_web::error::ErrorInternalServerError("Failed to read response")
|
actix_web::error::ErrorInternalServerError("Failed to read response")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Parse the response as a vector of items with nested arrays
|
|
||||||
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
||||||
serde_json::from_str(&raw_response).map_err(|e| {
|
serde_json::from_str(&raw_response).map_err(|e| {
|
||||||
error!("Error parsing llama.cpp embedding response: {}", e);
|
error!("Error parsing llama.cpp embedding response: {}", e);
|
||||||
|
|
@ -496,17 +470,13 @@ pub async fn embeddings_local(
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Extract the embedding from the nested array bullshit
|
|
||||||
if let Some(item) = llama_response.get(0) {
|
if let Some(item) = llama_response.get(0) {
|
||||||
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
|
|
||||||
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
|
|
||||||
let flattened_embedding = if !item.embedding.is_empty() {
|
let flattened_embedding = if !item.embedding.is_empty() {
|
||||||
item.embedding[0].clone() // Take the first (and probably only) inner array
|
item.embedding[0].clone()
|
||||||
} else {
|
} else {
|
||||||
vec![] // Empty if no embedding data
|
vec![]
|
||||||
};
|
};
|
||||||
|
|
||||||
// Estimate token count
|
|
||||||
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
||||||
total_tokens += estimated_tokens;
|
total_tokens += estimated_tokens;
|
||||||
|
|
||||||
|
|
@ -544,7 +514,6 @@ pub async fn embeddings_local(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build OpenAI-compatible response
|
|
||||||
let openai_response = EmbeddingResponse {
|
let openai_response = EmbeddingResponse {
|
||||||
object: "list".to_string(),
|
object: "list".to_string(),
|
||||||
data: embeddings_data,
|
data: embeddings_data,
|
||||||
|
|
@ -558,7 +527,6 @@ pub async fn embeddings_local(
|
||||||
Ok(HttpResponse::Ok().json(openai_response))
|
Ok(HttpResponse::Ok().json(openai_response))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health check endpoint
|
|
||||||
#[actix_web::get("/health")]
|
#[actix_web::get("/health")]
|
||||||
pub async fn health() -> Result<HttpResponse> {
|
pub async fn health() -> Result<HttpResponse> {
|
||||||
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
|
|
||||||
|
|
@ -1,116 +1,256 @@
|
||||||
use log::info;
|
use async_trait::async_trait;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use langchain_rust::{
|
||||||
|
language_models::llm::LLM,
|
||||||
|
llm::{claude::Claude, openai::OpenAI},
|
||||||
|
schemas::Message,
|
||||||
|
};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
use crate::tools::ToolManager;
|
||||||
use dotenv::dotenv;
|
|
||||||
use regex::Regex;
|
|
||||||
use reqwest::Client;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::env;
|
|
||||||
|
|
||||||
// OpenAI-compatible request/response structures
|
#[async_trait]
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
pub trait LLMProvider: Send + Sync {
|
||||||
struct ChatMessage {
|
async fn generate(
|
||||||
role: String,
|
&self,
|
||||||
content: String,
|
prompt: &str,
|
||||||
|
config: &Value,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
config: &Value,
|
||||||
|
tx: mpsc::Sender<String>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||||
|
|
||||||
|
async fn generate_with_tools(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
config: &Value,
|
||||||
|
available_tools: &[String],
|
||||||
|
tool_manager: Arc<ToolManager>,
|
||||||
|
session_id: &str,
|
||||||
|
user_id: &str,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
pub struct OpenAIClient {
|
||||||
struct ChatCompletionRequest {
|
client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
|
||||||
model: String,
|
|
||||||
messages: Vec<ChatMessage>,
|
|
||||||
stream: Option<bool>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
impl OpenAIClient {
|
||||||
struct ChatCompletionResponse {
|
pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
|
||||||
id: String,
|
Self { client }
|
||||||
object: String,
|
}
|
||||||
created: u64,
|
|
||||||
model: String,
|
|
||||||
choices: Vec<Choice>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[async_trait]
|
||||||
struct Choice {
|
impl LLMProvider for OpenAIClient {
|
||||||
message: ChatMessage,
|
async fn generate(
|
||||||
finish_reason: String,
|
&self,
|
||||||
}
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
#[post("/azure/v1/chat/completions")]
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
let result = self
|
||||||
// Always log raw POST data
|
.client
|
||||||
if let Ok(body_str) = std::str::from_utf8(&body) {
|
.invoke(prompt)
|
||||||
info!("POST Data: {}", body_str);
|
|
||||||
} else {
|
|
||||||
info!("POST Data (binary): {:?}", body);
|
|
||||||
}
|
|
||||||
|
|
||||||
dotenv().ok();
|
|
||||||
|
|
||||||
// Environment variables
|
|
||||||
let azure_endpoint = env::var("AI_ENDPOINT")
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
|
||||||
let azure_key = env::var("AI_KEY")
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
|
||||||
let deployment_name = env::var("AI_LLM_MODEL")
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
|
||||||
|
|
||||||
// Construct Azure OpenAI URL
|
|
||||||
let url = format!(
|
|
||||||
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
|
|
||||||
azure_endpoint, deployment_name
|
|
||||||
);
|
|
||||||
|
|
||||||
// Forward headers
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
"api-key",
|
|
||||||
reqwest::header::HeaderValue::from_str(&azure_key)
|
|
||||||
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
"Content-Type",
|
|
||||||
reqwest::header::HeaderValue::from_static("application/json"),
|
|
||||||
);
|
|
||||||
|
|
||||||
let body_str = std::str::from_utf8(&body).unwrap_or("");
|
|
||||||
info!("Original POST Data: {}", body_str);
|
|
||||||
|
|
||||||
// Remove the problematic params
|
|
||||||
let re =
|
|
||||||
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
|
|
||||||
let cleaned = re.replace_all(body_str, "");
|
|
||||||
let cleaned_body = web::Bytes::from(cleaned.to_string());
|
|
||||||
|
|
||||||
info!("Cleaned POST Data: {}", cleaned);
|
|
||||||
|
|
||||||
// Send request to Azure
|
|
||||||
let client = Client::new();
|
|
||||||
let response = client
|
|
||||||
.post(&url)
|
|
||||||
.headers(headers)
|
|
||||||
.body(cleaned_body)
|
|
||||||
.send()
|
|
||||||
.await
|
.await
|
||||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||||
|
|
||||||
// Handle response based on status
|
Ok(result)
|
||||||
let status = response.status();
|
}
|
||||||
let raw_response = response
|
|
||||||
.text()
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
tx: mpsc::Sender<String>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let mut stream = self
|
||||||
|
.client
|
||||||
|
.stream(prompt)
|
||||||
.await
|
.await
|
||||||
.map_err(actix_web::error::ErrorInternalServerError)?;
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||||
|
|
||||||
// Log the raw response
|
while let Some(result) = stream.next().await {
|
||||||
info!("Raw Azure response: {}", raw_response);
|
match result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
let content = chunk.content;
|
||||||
|
if !content.is_empty() {
|
||||||
|
let _ = tx.send(content.to_string()).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Stream error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if status.is_success() {
|
Ok(())
|
||||||
Ok(HttpResponse::Ok().body(raw_response))
|
}
|
||||||
|
|
||||||
|
async fn generate_with_tools(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
available_tools: &[String],
|
||||||
|
_tool_manager: Arc<ToolManager>,
|
||||||
|
_session_id: &str,
|
||||||
|
_user_id: &str,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let tools_info = if available_tools.is_empty() {
|
||||||
|
String::new()
|
||||||
} else {
|
} else {
|
||||||
// Handle error responses properly
|
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
||||||
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
};
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
|
|
||||||
Ok(HttpResponse::build(actix_status).body(raw_response))
|
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.client
|
||||||
|
.invoke(&enhanced_prompt)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AnthropicClient {
|
||||||
|
client: Claude,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AnthropicClient {
|
||||||
|
pub fn new(api_key: String) -> Self {
|
||||||
|
let client = Claude::default().with_api_key(api_key);
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for AnthropicClient {
|
||||||
|
async fn generate(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let result = self
|
||||||
|
.client
|
||||||
|
.invoke(prompt)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
tx: mpsc::Sender<String>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let mut stream = self
|
||||||
|
.client
|
||||||
|
.stream(prompt)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
let content = chunk.content;
|
||||||
|
if !content.is_empty() {
|
||||||
|
let _ = tx.send(content.to_string()).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Stream error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_with_tools(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
available_tools: &[String],
|
||||||
|
_tool_manager: Arc<ToolManager>,
|
||||||
|
_session_id: &str,
|
||||||
|
_user_id: &str,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let tools_info = if available_tools.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
||||||
|
};
|
||||||
|
|
||||||
|
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.client
|
||||||
|
.invoke(&enhanced_prompt)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MockLLMProvider;
|
||||||
|
|
||||||
|
impl MockLLMProvider {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl LLMProvider for MockLLMProvider {
|
||||||
|
async fn generate(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
Ok(format!("Mock response to: {}", prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_stream(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
tx: mpsc::Sender<String>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let response = format!("Mock stream response to: {}", prompt);
|
||||||
|
for word in response.split_whitespace() {
|
||||||
|
let _ = tx.send(format!("{} ", word)).await;
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate_with_tools(
|
||||||
|
&self,
|
||||||
|
prompt: &str,
|
||||||
|
_config: &Value,
|
||||||
|
available_tools: &[String],
|
||||||
|
_tool_manager: Arc<ToolManager>,
|
||||||
|
_session_id: &str,
|
||||||
|
_user_id: &str,
|
||||||
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let tools_list = if available_tools.is_empty() {
|
||||||
|
"no tools available".to_string()
|
||||||
|
} else {
|
||||||
|
available_tools.join(", ")
|
||||||
|
};
|
||||||
|
Ok(format!(
|
||||||
|
"Mock response with tools [{}] to: {}",
|
||||||
|
tools_list, prompt
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
967
src/llm/mod.rs
967
src/llm/mod.rs
|
|
@ -2,273 +2,748 @@ pub mod llm_generic;
|
||||||
pub mod llm_local;
|
pub mod llm_local;
|
||||||
pub mod llm_provider;
|
pub mod llm_provider;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
pub use llm_provider::*;
|
||||||
use futures::StreamExt;
|
|
||||||
use langchain_rust::{
|
|
||||||
language_models::llm::LLM,
|
|
||||||
llm::{claude::Claude, openai::OpenAI},
|
|
||||||
schemas::Message,
|
|
||||||
};
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::mpsc;
|
|
||||||
|
|
||||||
use crate::tools::ToolManager;
|
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use log::{error, info};
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::env;
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
#[async_trait]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub trait LLMProvider: Send + Sync {
|
struct ChatMessage {
|
||||||
async fn generate(
|
role: String,
|
||||||
&self,
|
content: String,
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
|
|
||||||
async fn generate_stream(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
tx: mpsc::Sender<String>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
|
|
||||||
// Add tool calling capability using LangChain tools
|
|
||||||
async fn generate_with_tools(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
available_tools: &[String],
|
|
||||||
tool_manager: Arc<ToolManager>,
|
|
||||||
session_id: &str,
|
|
||||||
user_id: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct OpenAIClient {
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
client: OpenAI,
|
struct ChatCompletionRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
stream: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAIClient {
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub fn new(client: OpenAI) -> Self {
|
struct ChatCompletionResponse {
|
||||||
Self { client }
|
id: String,
|
||||||
|
object: String,
|
||||||
|
created: u64,
|
||||||
|
model: String,
|
||||||
|
choices: Vec<Choice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct Choice {
|
||||||
|
message: ChatMessage,
|
||||||
|
finish_reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct LlamaCppRequest {
|
||||||
|
prompt: String,
|
||||||
|
n_predict: Option<i32>,
|
||||||
|
temperature: Option<f32>,
|
||||||
|
top_k: Option<i32>,
|
||||||
|
top_p: Option<f32>,
|
||||||
|
stream: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct LlamaCppResponse {
|
||||||
|
content: String,
|
||||||
|
stop: bool,
|
||||||
|
generation_settings: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
|
||||||
|
|
||||||
|
if llm_local.to_lowercase() != "true" {
|
||||||
|
info!("ℹ️ LLM_LOCAL is not enabled, skipping local server startup");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
let embedding_url =
|
||||||
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
|
let llama_cpp_path = env::var("LLM_CPP_PATH").unwrap_or_else(|_| "~/llama.cpp".to_string());
|
||||||
|
let llm_model_path = env::var("LLM_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||||||
|
let embedding_model_path = env::var("EMBEDDING_MODEL_PATH").unwrap_or_else(|_| "".to_string());
|
||||||
|
|
||||||
|
info!("🚀 Starting local llama.cpp servers...");
|
||||||
|
info!("📋 Configuration:");
|
||||||
|
info!(" LLM URL: {}", llm_url);
|
||||||
|
info!(" Embedding URL: {}", embedding_url);
|
||||||
|
info!(" LLM Model: {}", llm_model_path);
|
||||||
|
info!(" Embedding Model: {}", embedding_model_path);
|
||||||
|
|
||||||
|
let llm_running = is_server_running(&llm_url).await;
|
||||||
|
let embedding_running = is_server_running(&embedding_url).await;
|
||||||
|
|
||||||
|
if llm_running && embedding_running {
|
||||||
|
info!("✅ Both LLM and Embedding servers are already running");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut tasks = vec![];
|
||||||
|
|
||||||
|
if !llm_running && !llm_model_path.is_empty() {
|
||||||
|
info!("🔄 Starting LLM server...");
|
||||||
|
tasks.push(tokio::spawn(start_llm_server(
|
||||||
|
llama_cpp_path.clone(),
|
||||||
|
llm_model_path.clone(),
|
||||||
|
llm_url.clone(),
|
||||||
|
)));
|
||||||
|
} else if llm_model_path.is_empty() {
|
||||||
|
info!("⚠️ LLM_MODEL_PATH not set, skipping LLM server");
|
||||||
|
}
|
||||||
|
|
||||||
|
if !embedding_running && !embedding_model_path.is_empty() {
|
||||||
|
info!("🔄 Starting Embedding server...");
|
||||||
|
tasks.push(tokio::spawn(start_embedding_server(
|
||||||
|
llama_cpp_path.clone(),
|
||||||
|
embedding_model_path.clone(),
|
||||||
|
embedding_url.clone(),
|
||||||
|
)));
|
||||||
|
} else if embedding_model_path.is_empty() {
|
||||||
|
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
|
||||||
|
}
|
||||||
|
|
||||||
|
for task in tasks {
|
||||||
|
task.await??;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("⏳ Waiting for servers to become ready...");
|
||||||
|
|
||||||
|
let mut llm_ready = llm_running || llm_model_path.is_empty();
|
||||||
|
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
|
||||||
|
|
||||||
|
let mut attempts = 0;
|
||||||
|
let max_attempts = 60;
|
||||||
|
|
||||||
|
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
|
||||||
|
sleep(Duration::from_secs(2)).await;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"🔍 Checking server health (attempt {}/{})...",
|
||||||
|
attempts + 1,
|
||||||
|
max_attempts
|
||||||
|
);
|
||||||
|
|
||||||
|
if !llm_ready && !llm_model_path.is_empty() {
|
||||||
|
if is_server_running(&llm_url).await {
|
||||||
|
info!(" ✅ LLM server ready at {}", llm_url);
|
||||||
|
llm_ready = true;
|
||||||
|
} else {
|
||||||
|
info!(" ❌ LLM server not ready yet");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
if !embedding_ready && !embedding_model_path.is_empty() {
|
||||||
impl LLMProvider for OpenAIClient {
|
if is_server_running(&embedding_url).await {
|
||||||
async fn generate(
|
info!(" ✅ Embedding server ready at {}", embedding_url);
|
||||||
&self,
|
embedding_ready = true;
|
||||||
prompt: &str,
|
} else {
|
||||||
_config: &Value,
|
info!(" ❌ Embedding server not ready yet");
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
}
|
||||||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
|
||||||
|
|
||||||
let result = self
|
|
||||||
.client
|
|
||||||
.invoke(&messages)
|
|
||||||
.await
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn generate_stream(
|
attempts += 1;
|
||||||
&self,
|
|
||||||
prompt: &str,
|
if attempts % 10 == 0 {
|
||||||
_config: &Value,
|
info!(
|
||||||
mut tx: mpsc::Sender<String>,
|
"⏰ Still waiting for servers... (attempt {}/{})",
|
||||||
|
attempts, max_attempts
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if llm_ready && embedding_ready {
|
||||||
|
info!("🎉 All llama.cpp servers are ready and responding!");
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
let mut error_msg = "❌ Servers failed to start within timeout:".to_string();
|
||||||
|
if !llm_ready && !llm_model_path.is_empty() {
|
||||||
|
error_msg.push_str(&format!("\n - LLM server at {}", llm_url));
|
||||||
|
}
|
||||||
|
if !embedding_ready && !embedding_model_path.is_empty() {
|
||||||
|
error_msg.push_str(&format!("\n - Embedding server at {}", embedding_url));
|
||||||
|
}
|
||||||
|
Err(error_msg.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_llm_server(
|
||||||
|
llama_cpp_path: String,
|
||||||
|
model_path: String,
|
||||||
|
url: String,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
let port = url.split(':').last().unwrap_or("8081");
|
||||||
|
|
||||||
let mut stream = self
|
std::env::set_var("OMP_NUM_THREADS", "20");
|
||||||
.client
|
std::env::set_var("OMP_PLACES", "cores");
|
||||||
.stream(&messages)
|
std::env::set_var("OMP_PROC_BIND", "close");
|
||||||
|
|
||||||
|
let mut cmd = tokio::process::Command::new("sh");
|
||||||
|
cmd.arg("-c").arg(format!(
|
||||||
|
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
|
||||||
|
llama_cpp_path, model_path, port
|
||||||
|
));
|
||||||
|
|
||||||
|
cmd.spawn()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_embedding_server(
|
||||||
|
llama_cpp_path: String,
|
||||||
|
model_path: String,
|
||||||
|
url: String,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let port = url.split(':').last().unwrap_or("8082");
|
||||||
|
|
||||||
|
let mut cmd = tokio::process::Command::new("sh");
|
||||||
|
cmd.arg("-c").arg(format!(
|
||||||
|
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --embedding --n-gpu-layers 99 &",
|
||||||
|
llama_cpp_path, model_path, port
|
||||||
|
));
|
||||||
|
|
||||||
|
cmd.spawn()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn is_server_running(url: &str) -> bool {
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
match client.get(&format!("{}/health", url)).send().await {
|
||||||
|
Ok(response) => response.status().is_success(),
|
||||||
|
Err(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
|
||||||
|
let mut prompt = String::new();
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
match message.role.as_str() {
|
||||||
|
"system" => {
|
||||||
|
prompt.push_str(&format!("System: {}\n\n", message.content));
|
||||||
|
}
|
||||||
|
"user" => {
|
||||||
|
prompt.push_str(&format!("User: {}\n\n", message.content));
|
||||||
|
}
|
||||||
|
"assistant" => {
|
||||||
|
prompt.push_str(&format!("Assistant: {}\n\n", message.content));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
prompt.push_str(&format!("{}: {}\n\n", message.role, message.content));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.push_str("Assistant: ");
|
||||||
|
prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/local/v1/chat/completions")]
|
||||||
|
pub async fn chat_completions_local(
|
||||||
|
req_body: web::Json<ChatCompletionRequest>,
|
||||||
|
_req: HttpRequest,
|
||||||
|
) -> Result<HttpResponse> {
|
||||||
|
dotenv().ok();
|
||||||
|
|
||||||
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
|
||||||
|
let prompt = messages_to_prompt(&req_body.messages);
|
||||||
|
|
||||||
|
let llama_request = LlamaCppRequest {
|
||||||
|
prompt,
|
||||||
|
n_predict: Some(500),
|
||||||
|
temperature: Some(0.7),
|
||||||
|
top_k: Some(40),
|
||||||
|
top_p: Some(0.9),
|
||||||
|
stream: req_body.stream,
|
||||||
|
};
|
||||||
|
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(120))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("Error creating HTTP client: {}", e);
|
||||||
|
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(&format!("{}/completion", llama_url))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(&llama_request)
|
||||||
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
.map_err(|e| {
|
||||||
|
error!("Error calling llama.cpp server: {}", e);
|
||||||
|
actix_web::error::ErrorInternalServerError("Failed to call llama.cpp server")
|
||||||
|
})?;
|
||||||
|
|
||||||
while let Some(result) = stream.next().await {
|
let status = response.status();
|
||||||
match result {
|
|
||||||
Ok(chunk) => {
|
if status.is_success() {
|
||||||
let content = chunk.content;
|
let llama_response: LlamaCppResponse = response.json().await.map_err(|e| {
|
||||||
if !content.is_empty() {
|
error!("Error parsing llama.cpp response: {}", e);
|
||||||
let _ = tx.send(content.to_string()).await;
|
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let openai_response = ChatCompletionResponse {
|
||||||
|
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||||||
|
object: "chat.completion".to_string(),
|
||||||
|
created: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs(),
|
||||||
|
model: req_body.model.clone(),
|
||||||
|
choices: vec![Choice {
|
||||||
|
message: ChatMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: llama_response.content.trim().to_string(),
|
||||||
|
},
|
||||||
|
finish_reason: if llama_response.stop {
|
||||||
|
"stop".to_string()
|
||||||
|
} else {
|
||||||
|
"length".to_string()
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(HttpResponse::Ok().json(openai_response))
|
||||||
|
} else {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
|
||||||
|
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||||||
|
|
||||||
|
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||||
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"message": error_text,
|
||||||
|
"type": "server_error"
|
||||||
|
}
|
||||||
|
})))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct EmbeddingRequest {
|
||||||
|
#[serde(deserialize_with = "deserialize_input")]
|
||||||
|
pub input: Vec<String>,
|
||||||
|
pub model: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub _encoding_format: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
use serde::de::{self, Visitor};
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
struct InputVisitor;
|
||||||
|
|
||||||
|
impl<'de> Visitor<'de> for InputVisitor {
|
||||||
|
type Value = Vec<String>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
formatter.write_str("a string or an array of strings")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
Ok(vec![value.to_string()])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
|
||||||
|
where
|
||||||
|
E: de::Error,
|
||||||
|
{
|
||||||
|
Ok(vec![value])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: de::SeqAccess<'de>,
|
||||||
|
{
|
||||||
|
let mut vec = Vec::new();
|
||||||
|
while let Some(value) = seq.next_element::<String>()? {
|
||||||
|
vec.push(value);
|
||||||
|
}
|
||||||
|
Ok(vec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deserializer.deserialize_any(InputVisitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct EmbeddingResponse {
|
||||||
|
pub object: String,
|
||||||
|
pub data: Vec<EmbeddingData>,
|
||||||
|
pub model: String,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct EmbeddingData {
|
||||||
|
pub object: String,
|
||||||
|
pub embedding: Vec<f32>,
|
||||||
|
pub index: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct LlamaCppEmbeddingRequest {
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct LlamaCppEmbeddingResponseItem {
|
||||||
|
pub index: usize,
|
||||||
|
pub embedding: Vec<Vec<f32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/v1/embeddings")]
|
||||||
|
pub async fn embeddings_local(
|
||||||
|
req_body: web::Json<EmbeddingRequest>,
|
||||||
|
_req: HttpRequest,
|
||||||
|
) -> Result<HttpResponse> {
|
||||||
|
dotenv().ok();
|
||||||
|
|
||||||
|
let llama_url =
|
||||||
|
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
|
||||||
|
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(120))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("Error creating HTTP client: {}", e);
|
||||||
|
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut embeddings_data = Vec::new();
|
||||||
|
let mut total_tokens = 0;
|
||||||
|
|
||||||
|
for (index, input_text) in req_body.input.iter().enumerate() {
|
||||||
|
let llama_request = LlamaCppEmbeddingRequest {
|
||||||
|
content: input_text.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(&format!("{}/embedding", llama_url))
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.json(&llama_request)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!("Error calling llama.cpp server for embedding: {}", e);
|
||||||
|
actix_web::error::ErrorInternalServerError(
|
||||||
|
"Failed to call llama.cpp server for embedding",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
if status.is_success() {
|
||||||
|
let raw_response = response.text().await.map_err(|e| {
|
||||||
|
error!("Error reading response text: {}", e);
|
||||||
|
actix_web::error::ErrorInternalServerError("Failed to read response")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
|
||||||
|
serde_json::from_str(&raw_response).map_err(|e| {
|
||||||
|
error!("Error parsing llama.cpp embedding response: {}", e);
|
||||||
|
error!("Raw response: {}", raw_response);
|
||||||
|
actix_web::error::ErrorInternalServerError(
|
||||||
|
"Failed to parse llama.cpp embedding response",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if let Some(item) = llama_response.get(0) {
|
||||||
|
let flattened_embedding = if !item.embedding.is_empty() {
|
||||||
|
item.embedding[0].clone()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
|
||||||
|
total_tokens += estimated_tokens;
|
||||||
|
|
||||||
|
embeddings_data.push(EmbeddingData {
|
||||||
|
object: "embedding".to_string(),
|
||||||
|
embedding: flattened_embedding,
|
||||||
|
index,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
error!("No embedding data returned for input: {}", input_text);
|
||||||
|
return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("No embedding data returned for input {}", index),
|
||||||
|
"type": "server_error"
|
||||||
|
}
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let error_text = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
|
||||||
|
error!("Llama.cpp server error ({}): {}", status, error_text);
|
||||||
|
|
||||||
|
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||||
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
return Ok(HttpResponse::build(actix_status).json(serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"message": format!("Failed to get embedding for input {}: {}", index, error_text),
|
||||||
|
"type": "server_error"
|
||||||
|
}
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let openai_response = EmbeddingResponse {
|
||||||
|
object: "list".to_string(),
|
||||||
|
data: embeddings_data,
|
||||||
|
model: req_body.model.clone(),
|
||||||
|
usage: Usage {
|
||||||
|
prompt_tokens: total_tokens,
|
||||||
|
total_tokens,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(HttpResponse::Ok().json(openai_response))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[actix_web::get("/health")]
|
||||||
|
pub async fn health() -> Result<HttpResponse> {
|
||||||
|
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
|
||||||
|
|
||||||
|
if is_server_running(&llama_url).await {
|
||||||
|
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||||
|
"status": "healthy",
|
||||||
|
"llama_server": "running"
|
||||||
|
})))
|
||||||
|
} else {
|
||||||
|
Ok(HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||||||
|
"status": "unhealthy",
|
||||||
|
"llama_server": "not running"
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use regex::Regex;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct GenericChatCompletionRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
stream: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/v1/chat/completions")]
|
||||||
|
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
|
||||||
|
let body_str = std::str::from_utf8(&body).unwrap_or_default();
|
||||||
|
info!("Original POST Data: {}", body_str);
|
||||||
|
|
||||||
|
dotenv().ok();
|
||||||
|
|
||||||
|
let api_key = env::var("AI_KEY")
|
||||||
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
|
||||||
|
let model = env::var("AI_LLM_MODEL")
|
||||||
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
|
||||||
|
let endpoint = env::var("AI_ENDPOINT")
|
||||||
|
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
|
||||||
|
|
||||||
|
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
|
||||||
|
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
|
||||||
|
|
||||||
|
if let Some(obj) = json_value.as_object_mut() {
|
||||||
|
obj.insert("model".to_string(), serde_json::Value::String(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
let modified_body_str = serde_json::to_string(&json_value)
|
||||||
|
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to serialize JSON"))?;
|
||||||
|
|
||||||
|
info!("Modified POST Data: {}", modified_body_str);
|
||||||
|
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"Authorization",
|
||||||
|
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||||
|
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid API key format"))?,
|
||||||
|
);
|
||||||
|
headers.insert(
|
||||||
|
"Content-Type",
|
||||||
|
reqwest::header::HeaderValue::from_static("application/json"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = Client::new();
|
||||||
|
let response = client
|
||||||
|
.post(&endpoint)
|
||||||
|
.headers(headers)
|
||||||
|
.body(modified_body_str)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let raw_response = response
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(actix_web::error::ErrorInternalServerError)?;
|
||||||
|
|
||||||
|
info!("Provider response status: {}", status);
|
||||||
|
info!("Provider response body: {}", raw_response);
|
||||||
|
|
||||||
|
if status.is_success() {
|
||||||
|
match convert_to_openai_format(&raw_response) {
|
||||||
|
Ok(openai_response) => Ok(HttpResponse::Ok()
|
||||||
|
.content_type("application/json")
|
||||||
|
.body(openai_response)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Stream error: {}", e);
|
error!("Failed to convert response format: {}", e);
|
||||||
|
Ok(HttpResponse::Ok()
|
||||||
|
.content_type("application/json")
|
||||||
|
.body(raw_response))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_with_tools(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
available_tools: &[String],
|
|
||||||
_tool_manager: Arc<ToolManager>,
|
|
||||||
_session_id: &str,
|
|
||||||
_user_id: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
// Enhanced prompt with tool information
|
|
||||||
let tools_info = if available_tools.is_empty() {
|
|
||||||
String::new()
|
|
||||||
} else {
|
} else {
|
||||||
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
|
||||||
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
Ok(HttpResponse::build(actix_status)
|
||||||
|
.content_type("application/json")
|
||||||
|
.body(raw_response))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct ProviderChoice {
|
||||||
|
message: ProviderMessage,
|
||||||
|
#[serde(default)]
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct ProviderMessage {
|
||||||
|
role: Option<String>,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct ProviderResponse {
|
||||||
|
id: Option<String>,
|
||||||
|
object: Option<String>,
|
||||||
|
created: Option<u64>,
|
||||||
|
model: Option<String>,
|
||||||
|
choices: Vec<ProviderChoice>,
|
||||||
|
usage: Option<ProviderUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Default)]
|
||||||
|
struct ProviderUsage {
|
||||||
|
prompt_tokens: Option<u32>,
|
||||||
|
completion_tokens: Option<u32>,
|
||||||
|
total_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
struct OpenAIResponse {
|
||||||
|
id: String,
|
||||||
|
object: String,
|
||||||
|
created: u64,
|
||||||
|
model: String,
|
||||||
|
choices: Vec<OpenAIChoice>,
|
||||||
|
usage: OpenAIUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
struct OpenAIChoice {
|
||||||
|
index: u32,
|
||||||
|
message: OpenAIMessage,
|
||||||
|
finish_reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
struct OpenAIMessage {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize)]
|
||||||
|
struct OpenAIUsage {
|
||||||
|
prompt_tokens: u32,
|
||||||
|
completion_tokens: u32,
|
||||||
|
total_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
|
||||||
|
|
||||||
|
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
|
||||||
|
let content = first_choice.message.content.clone();
|
||||||
|
let role = first_choice
|
||||||
|
.message
|
||||||
|
.role
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "assistant".to_string());
|
||||||
|
|
||||||
|
let usage = provider.usage.unwrap_or_default();
|
||||||
|
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
|
||||||
|
let completion_tokens = usage
|
||||||
|
.completion_tokens
|
||||||
|
.unwrap_or_else(|| content.split_whitespace().count() as u32);
|
||||||
|
let total_tokens = usage
|
||||||
|
.total_tokens
|
||||||
|
.unwrap_or(prompt_tokens + completion_tokens);
|
||||||
|
|
||||||
|
let openai_response = OpenAIResponse {
|
||||||
|
id: provider
|
||||||
|
.id
|
||||||
|
.unwrap_or_else(|| format!("chatcmpl-{}", uuid::Uuid::new_v4().simple())),
|
||||||
|
object: provider
|
||||||
|
.object
|
||||||
|
.unwrap_or_else(|| "chat.completion".to_string()),
|
||||||
|
created: provider.created.unwrap_or_else(|| {
|
||||||
|
std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs()
|
||||||
|
}),
|
||||||
|
model: provider.model.unwrap_or_else(|| "llama".to_string()),
|
||||||
|
choices: vec![OpenAIChoice {
|
||||||
|
index: 0,
|
||||||
|
message: OpenAIMessage { role, content },
|
||||||
|
finish_reason: first_choice
|
||||||
|
.finish_reason
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "stop".to_string()),
|
||||||
|
}],
|
||||||
|
usage: OpenAIUsage {
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
total_tokens,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
serde_json::to_string(&openai_response).map_err(|e| e.into())
|
||||||
|
|
||||||
let messages = vec![Message::new_human_message(enhanced_prompt)];
|
|
||||||
|
|
||||||
let result = self
|
|
||||||
.client
|
|
||||||
.invoke(&messages)
|
|
||||||
.await
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AnthropicClient {
|
|
||||||
client: Claude,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AnthropicClient {
|
|
||||||
pub fn new(api_key: String) -> Self {
|
|
||||||
let client = Claude::default().with_api_key(api_key);
|
|
||||||
Self { client }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl LLMProvider for AnthropicClient {
|
|
||||||
async fn generate(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
|
||||||
|
|
||||||
let result = self
|
|
||||||
.client
|
|
||||||
.invoke(&messages)
|
|
||||||
.await
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_stream(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
mut tx: mpsc::Sender<String>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let messages = vec![Message::new_human_message(prompt.to_string())];
|
|
||||||
|
|
||||||
let mut stream = self
|
|
||||||
.client
|
|
||||||
.stream(&messages)
|
|
||||||
.await
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
|
||||||
|
|
||||||
while let Some(result) = stream.next().await {
|
|
||||||
match result {
|
|
||||||
Ok(chunk) => {
|
|
||||||
let content = chunk.content;
|
|
||||||
if !content.is_empty() {
|
|
||||||
let _ = tx.send(content.to_string()).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Stream error: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_with_tools(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
available_tools: &[String],
|
|
||||||
_tool_manager: Arc<ToolManager>,
|
|
||||||
_session_id: &str,
|
|
||||||
_user_id: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let tools_info = if available_tools.is_empty() {
|
|
||||||
String::new()
|
|
||||||
} else {
|
|
||||||
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
|
|
||||||
};
|
|
||||||
|
|
||||||
let enhanced_prompt = format!("{}{}", prompt, tools_info);
|
|
||||||
|
|
||||||
let messages = vec![Message::new_human_message(enhanced_prompt)];
|
|
||||||
|
|
||||||
let result = self
|
|
||||||
.client
|
|
||||||
.invoke(&messages)
|
|
||||||
.await
|
|
||||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stateless stand-in LLM provider.
///
/// It is a unit struct, so every instance is interchangeable; `Default` is
/// derived so the type can be built with either `new()` or `default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MockLLMProvider;

impl MockLLMProvider {
    /// Creates the mock provider. Equivalent to `MockLLMProvider::default()`.
    pub fn new() -> Self {
        Self
    }
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl LLMProvider for MockLLMProvider {
|
|
||||||
async fn generate(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
Ok(format!("Mock response to: {}", prompt))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_stream(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
mut tx: mpsc::Sender<String>,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let response = format!("Mock stream response to: {}", prompt);
|
|
||||||
for word in response.split_whitespace() {
|
|
||||||
let _ = tx.send(format!("{} ", word)).await;
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_with_tools(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
_config: &Value,
|
|
||||||
available_tools: &[String],
|
|
||||||
_tool_manager: Arc<ToolManager>,
|
|
||||||
_session_id: &str,
|
|
||||||
_user_id: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let tools_list = if available_tools.is_empty() {
|
|
||||||
"no tools available".to_string()
|
|
||||||
} else {
|
|
||||||
available_tools.join(", ")
|
|
||||||
};
|
|
||||||
Ok(format!(
|
|
||||||
"Mock response with tools [{}] to: {}",
|
|
||||||
tools_list, prompt
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
284
src/main.rs
284
src/main.rs
|
|
@ -1,3 +1,9 @@
|
||||||
|
use actix_cors::Cors;
|
||||||
|
use actix_web::{middleware, web, App, HttpServer};
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use log::info;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
mod auth;
|
mod auth;
|
||||||
mod automation;
|
mod automation;
|
||||||
mod basic;
|
mod basic;
|
||||||
|
|
@ -16,239 +22,105 @@ mod tools;
|
||||||
mod web_automation;
|
mod web_automation;
|
||||||
mod whatsapp;
|
mod whatsapp;
|
||||||
|
|
||||||
use log::info;
|
use crate::{config::AppConfig, shared::state::AppState};
|
||||||
use qdrant_client::Qdrant;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use actix_web::{web, App, HttpServer};
|
#[actix_web::main]
|
||||||
use dotenv::dotenv;
|
|
||||||
use sqlx::PgPool;
|
|
||||||
|
|
||||||
use crate::auth::AuthService;
|
|
||||||
use crate::bot::BotOrchestrator;
|
|
||||||
use crate::config::AppConfig;
|
|
||||||
use crate::email::{
|
|
||||||
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
|
|
||||||
};
|
|
||||||
use crate::file::{list_file, upload_file};
|
|
||||||
use crate::llm::llm_generic::generic_chat_completions;
|
|
||||||
use crate::llm::llm_local::{
|
|
||||||
chat_completions_local, embeddings_local, ensure_llama_servers_running,
|
|
||||||
};
|
|
||||||
use crate::session::SessionManager;
|
|
||||||
use crate::shared::state::AppState;
|
|
||||||
use crate::tools::{RedisToolExecutor, ToolManager};
|
|
||||||
use crate::web_automation::{initialize_browser_pool, BrowserPool};
|
|
||||||
|
|
||||||
#[tokio::main(flavor = "multi_thread")]
|
|
||||||
async fn main() -> std::io::Result<()> {
|
async fn main() -> std::io::Result<()> {
|
||||||
dotenv().ok();
|
dotenv().ok();
|
||||||
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
|
env_logger::init();
|
||||||
info!("Starting General Bots 6.0...");
|
|
||||||
|
info!("🚀 Starting Bot Server...");
|
||||||
|
|
||||||
let config = AppConfig::from_env();
|
let config = AppConfig::from_env();
|
||||||
let db_url = config.database_url();
|
|
||||||
let db_custom_url = config.database_custom_url();
|
|
||||||
let db = PgPool::connect(&db_url).await.unwrap();
|
|
||||||
let db_custom = PgPool::connect(&db_custom_url).await.unwrap();
|
|
||||||
|
|
||||||
let minio_client = init_minio(&config)
|
let db_pool = sqlx::postgres::PgPoolOptions::new()
|
||||||
|
.max_connections(5)
|
||||||
|
.connect(&config.database_url())
|
||||||
.await
|
.await
|
||||||
.expect("Failed to initialize Minio");
|
.expect("Failed to create database pool");
|
||||||
|
|
||||||
let browser_pool = Arc::new(BrowserPool::new(
|
let db_custom_pool = sqlx::postgres::PgPoolOptions::new()
|
||||||
"http://localhost:9515".to_string(),
|
.max_connections(5)
|
||||||
5,
|
.connect(&config.database_custom_url())
|
||||||
"/usr/bin/brave-browser-beta".to_string(),
|
.await
|
||||||
|
.expect("Failed to create custom database pool");
|
||||||
|
|
||||||
|
let redis_client = redis::Client::open("redis://127.0.0.1/").ok();
|
||||||
|
|
||||||
|
let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone().map(Arc::new));
|
||||||
|
let session_manager =
|
||||||
|
session::SessionManager::new(db_pool.clone(), redis_client.clone().map(Arc::new));
|
||||||
|
|
||||||
|
let tool_manager = tools::ToolManager::new();
|
||||||
|
let tool_api = Arc::new(tools::ToolApi::new());
|
||||||
|
|
||||||
|
let web_adapter = Arc::new(channels::WebChannelAdapter::new());
|
||||||
|
let voice_adapter = Arc::new(channels::VoiceAdapter::new(
|
||||||
|
"https://livekit.example.com".to_string(),
|
||||||
|
"api_key".to_string(),
|
||||||
|
"api_secret".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
ensure_llama_servers_running()
|
let whatsapp_adapter = Arc::new(whatsapp::WhatsAppAdapter::new(
|
||||||
.await
|
"whatsapp_token".to_string(),
|
||||||
.expect("Failed to initialize LLM local server.");
|
"phone_number_id".to_string(),
|
||||||
|
"verify_token".to_string(),
|
||||||
initialize_browser_pool()
|
|
||||||
.await
|
|
||||||
.expect("Failed to initialize browser pool");
|
|
||||||
|
|
||||||
// Initialize Redis if available
|
|
||||||
let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "".to_string());
|
|
||||||
let redis_conn = match std::env::var("REDIS_URL") {
|
|
||||||
Ok(redis_url_value) => {
|
|
||||||
let client = redis::Client::open(redis_url_value.clone())
|
|
||||||
.expect("Failed to create Redis client");
|
|
||||||
let conn = client
|
|
||||||
.get_connection()
|
|
||||||
.expect("Failed to create Redis connection");
|
|
||||||
Some(Arc::new(conn))
|
|
||||||
}
|
|
||||||
Err(_) => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let qdrant_url = std::env::var("QDRANT_URL").unwrap_or("http://localhost:6334".to_string());
|
|
||||||
let qdrant = Qdrant::from_url(&qdrant_url)
|
|
||||||
.build()
|
|
||||||
.expect("Failed to connect to Qdrant");
|
|
||||||
|
|
||||||
let session_manager = SessionManager::new(db.clone(), redis_conn.clone());
|
|
||||||
let auth_service = AuthService::new(db.clone(), redis_conn.clone());
|
|
||||||
|
|
||||||
let llm_provider: Arc<dyn crate::llm::LLMProvider> = match std::env::var("LLM_PROVIDER")
|
|
||||||
.unwrap_or("mock".to_string())
|
|
||||||
.as_str()
|
|
||||||
{
|
|
||||||
"openai" => Arc::new(crate::llm::OpenAIClient::new(
|
|
||||||
std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required"),
|
|
||||||
)),
|
|
||||||
"anthropic" => Arc::new(crate::llm::AnthropicClient::new(
|
|
||||||
std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY required"),
|
|
||||||
)),
|
|
||||||
_ => Arc::new(crate::llm::MockLLMProvider::new()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let web_adapter = Arc::new(crate::channels::WebChannelAdapter::new());
|
|
||||||
let voice_adapter = Arc::new(crate::channels::VoiceAdapter::new(
|
|
||||||
std::env::var("LIVEKIT_URL").unwrap_or("ws://localhost:7880".to_string()),
|
|
||||||
std::env::var("LIVEKIT_API_KEY").unwrap_or("dev".to_string()),
|
|
||||||
std::env::var("LIVEKIT_API_SECRET").unwrap_or("secret".to_string()),
|
|
||||||
));
|
));
|
||||||
|
|
||||||
let whatsapp_adapter = Arc::new(crate::whatsapp::WhatsAppAdapter::new(
|
let llm_provider = Arc::new(llm::MockLLMProvider::new());
|
||||||
std::env::var("META_ACCESS_TOKEN").unwrap_or("".to_string()),
|
|
||||||
std::env::var("META_PHONE_NUMBER_ID").unwrap_or("".to_string()),
|
|
||||||
std::env::var("META_WEBHOOK_VERIFY_TOKEN").unwrap_or("".to_string()),
|
|
||||||
));
|
|
||||||
|
|
||||||
let tool_executor = Arc::new(
|
let orchestrator = Arc::new(bot::BotOrchestrator::new(
|
||||||
RedisToolExecutor::new(
|
|
||||||
redis_url.as_str(),
|
|
||||||
web_adapter.clone() as Arc<dyn crate::channels::ChannelAdapter>,
|
|
||||||
db.clone(),
|
|
||||||
redis_conn.clone(),
|
|
||||||
)
|
|
||||||
.expect("Failed to create RedisToolExecutor"),
|
|
||||||
);
|
|
||||||
let chart_generator = ChartGenerator::new().map(Arc::new);
|
|
||||||
// Initialize LangChain components
|
|
||||||
let llm = OpenAI::default();
|
|
||||||
let llm_provider: Arc<dyn LLMProvider> = Arc::new(OpenAIClient::new(llm));
|
|
||||||
|
|
||||||
// Initialize vector store for document mode
|
|
||||||
let vector_store = if let (Ok(qdrant_url), Ok(openai_key)) =
|
|
||||||
(std::env::var("QDRANT_URL"), std::env::var("OPENAI_API_KEY"))
|
|
||||||
{
|
|
||||||
let embedder = OpenAiEmbedder::default().with_api_key(openai_key);
|
|
||||||
let client = QdrantClient::from_url(&qdrant_url).build().ok()?;
|
|
||||||
|
|
||||||
let store = StoreBuilder::new()
|
|
||||||
.embedder(embedder)
|
|
||||||
.client(client)
|
|
||||||
.collection_name("documents")
|
|
||||||
.build()
|
|
||||||
.await
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
Some(Arc::new(store))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Initialize SQL chain for database mode
|
|
||||||
let sql_chain = if let Ok(db_url) = std::env::var("DATABASE_URL") {
|
|
||||||
let engine = PostgreSQLEngine::new(&db_url).await.ok()?;
|
|
||||||
let db = SQLDatabaseBuilder::new(engine).build().await.ok()?;
|
|
||||||
|
|
||||||
let llm = OpenAI::default();
|
|
||||||
let chain = langchain_rust::chain::SQLDatabaseChainBuilder::new()
|
|
||||||
.llm(llm)
|
|
||||||
.top_k(5)
|
|
||||||
.database(db)
|
|
||||||
.build()
|
|
||||||
.ok()?;
|
|
||||||
|
|
||||||
Some(Arc::new(chain))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let tool_manager = ToolManager::new();
|
|
||||||
let orchestrator = BotOrchestrator::new(
|
|
||||||
session_manager,
|
session_manager,
|
||||||
tool_manager,
|
tool_manager,
|
||||||
llm_provider,
|
llm_provider,
|
||||||
auth_service,
|
auth_service,
|
||||||
chart_generator,
|
));
|
||||||
vector_store,
|
|
||||||
sql_chain,
|
|
||||||
);
|
|
||||||
|
|
||||||
orchestrator.add_channel("web", web_adapter.clone());
|
let browser_pool = Arc::new(web_automation::BrowserPool::new());
|
||||||
orchestrator.add_channel("voice", voice_adapter.clone());
|
|
||||||
orchestrator.add_channel("whatsapp", whatsapp_adapter.clone());
|
|
||||||
|
|
||||||
sqlx::query(
|
let app_state = AppState {
|
||||||
"INSERT INTO bots (id, name, llm_provider) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING",
|
minio_client: None,
|
||||||
)
|
config: Some(config),
|
||||||
.bind(uuid::Uuid::parse_str("00000000-0000-0000-0000-000000000000").unwrap())
|
db: Some(db_pool),
|
||||||
.bind("Default Bot")
|
db_custom: Some(db_custom_pool),
|
||||||
.bind("mock")
|
browser_pool,
|
||||||
.execute(&db)
|
orchestrator,
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let app_state = web::Data::new(AppState {
|
|
||||||
db: db.into(),
|
|
||||||
db_custom: db_custom.into(),
|
|
||||||
config: Some(config.clone()),
|
|
||||||
minio_client: minio_client.into(),
|
|
||||||
browser_pool: browser_pool.clone(),
|
|
||||||
orchestrator: Arc::new(orchestrator),
|
|
||||||
web_adapter,
|
web_adapter,
|
||||||
voice_adapter,
|
voice_adapter,
|
||||||
whatsapp_adapter,
|
whatsapp_adapter,
|
||||||
});
|
tool_api,
|
||||||
|
};
|
||||||
|
|
||||||
// Start automation service in background
|
info!("🌐 Server running on {}:{}", "127.0.0.1", 8080);
|
||||||
let automation_state = app_state.get_ref().clone();
|
|
||||||
|
|
||||||
let automation = AutomationService::new(automation_state, "src/prompts");
|
|
||||||
let _automation_handle = automation.spawn();
|
|
||||||
|
|
||||||
// Start HTTP server
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
|
let cors = Cors::default()
|
||||||
|
.allow_any_origin()
|
||||||
|
.allow_any_method()
|
||||||
|
.allow_any_header()
|
||||||
|
.max_age(3600);
|
||||||
|
|
||||||
App::new()
|
App::new()
|
||||||
.wrap(Logger::default())
|
.app_data(web::Data::new(app_state.clone()))
|
||||||
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
|
.wrap(cors)
|
||||||
.app_data(app_state.clone())
|
.wrap(middleware::Logger::default())
|
||||||
// Original services
|
.service(bot::websocket_handler)
|
||||||
.service(upload_file)
|
.service(bot::whatsapp_webhook_verify)
|
||||||
.service(list_file)
|
.service(bot::whatsapp_webhook)
|
||||||
.service(save_click)
|
.service(bot::voice_start)
|
||||||
.service(get_emails)
|
.service(bot::voice_stop)
|
||||||
.service(list_emails)
|
.service(bot::create_session)
|
||||||
.service(send_email)
|
.service(bot::get_sessions)
|
||||||
.service(crate::orchestrator::chat_stream)
|
.service(bot::get_session_history)
|
||||||
.service(crate::orchestrator::chat)
|
.service(bot::set_mode_handler)
|
||||||
.service(chat_completions_local)
|
.service(bot::index)
|
||||||
.service(save_draft)
|
.service(bot::static_files)
|
||||||
.service(generic_chat_completions)
|
.service(llm::chat_completions_local)
|
||||||
.service(embeddings_local)
|
.service(llm::embeddings_local)
|
||||||
.service(get_latest_email_from)
|
.service(llm::generic_chat_completions)
|
||||||
.service(services::orchestrator::websocket_handler)
|
.service(llm::health)
|
||||||
.service(services::orchestrator::whatsapp_webhook_verify)
|
|
||||||
.service(services::orchestrator::whatsapp_webhook)
|
|
||||||
.service(services::orchestrator::voice_start)
|
|
||||||
.service(services::orchestrator::voice_stop)
|
|
||||||
.service(services::orchestrator::create_session)
|
|
||||||
.service(services::orchestrator::get_sessions)
|
|
||||||
.service(services::orchestrator::get_session_history)
|
|
||||||
.service(services::orchestrator::index)
|
|
||||||
.service(create_organization)
|
|
||||||
.service(get_organization)
|
|
||||||
.service(list_organizations)
|
|
||||||
.service(update_organization)
|
|
||||||
.service(delete_organization)
|
|
||||||
})
|
})
|
||||||
.bind((config.server.host.clone(), config.server.port))?
|
.bind(("127.0.0.1", 8080))?
|
||||||
.run()
|
.run()
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -116,3 +116,9 @@ pub struct BotResponse {
|
||||||
pub stream_token: Option<String>,
|
pub stream_token: Option<String>,
|
||||||
pub is_complete: bool,
|
pub is_complete: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct PaginationQuery {
|
||||||
|
pub page: Option<i64>,
|
||||||
|
pub page_size: Option<i64>,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,11 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
bot::BotOrchestrator,
|
bot::BotOrchestrator,
|
||||||
channels::{VoiceAdapter, WebChannelAdapter, WhatsAppAdapter},
|
channels::{VoiceAdapter, WebChannelAdapter},
|
||||||
config::AppConfig,
|
config::AppConfig,
|
||||||
tools::ToolApi,
|
tools::ToolApi,
|
||||||
web_automation::BrowserPool
|
web_automation::BrowserPool,
|
||||||
|
whatsapp::WhatsAppAdapter,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::Utc;
|
||||||
use langchain_rust::llm::AzureConfig;
|
use langchain_rust::llm::AzureConfig;
|
||||||
use log::{debug, warn};
|
use log::{debug, warn};
|
||||||
use rhai::{Array, Dynamic};
|
use rhai::{Array, Dynamic};
|
||||||
|
|
@ -14,6 +14,7 @@ use tokio_stream::StreamExt;
|
||||||
use zip::ZipArchive;
|
use zip::ZipArchive;
|
||||||
|
|
||||||
use crate::config::AIConfig;
|
use crate::config::AIConfig;
|
||||||
|
use langchain_rust::language_models::llm::LLM;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
|
@ -30,7 +31,7 @@ pub async fn call_llm(
|
||||||
ai_config: &AIConfig,
|
ai_config: &AIConfig,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let azure_config = azure_from_config(&ai_config.clone());
|
let azure_config = azure_from_config(&ai_config.clone());
|
||||||
let open_ai = langchain_rust::llm::OpenAI::new(azure_config);
|
let open_ai = langchain_rust::llm::openai::OpenAI::new(azure_config);
|
||||||
|
|
||||||
let prompt = text.to_string();
|
let prompt = text.to_string();
|
||||||
|
|
||||||
|
|
|
||||||
494
src/tools/mod.rs
494
src/tools/mod.rs
|
|
@ -1,17 +1,11 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use redis::AsyncCommands;
|
|
||||||
use rhai::{Engine, Scope};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{session::SessionManager, shared::BotResponse};
|
||||||
channels::ChannelAdapter,
|
|
||||||
session::SessionManager,
|
|
||||||
shared::BotResponse,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
|
|
@ -52,421 +46,50 @@ pub trait ToolExecutor: Send + Sync {
|
||||||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
|
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RedisToolExecutor {
|
pub struct MockToolExecutor;
|
||||||
redis_client: redis::Client,
|
|
||||||
web_adapter: Arc<dyn ChannelAdapter>,
|
|
||||||
voice_adapter: Arc<dyn ChannelAdapter>,
|
|
||||||
whatsapp_adapter: Arc<dyn ChannelAdapter>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RedisToolExecutor {
|
impl MockToolExecutor {
|
||||||
pub fn new(
|
pub fn new() -> Self {
|
||||||
redis_url: &str,
|
Self
|
||||||
web_adapter: Arc<dyn ChannelAdapter>,
|
|
||||||
voice_adapter: Arc<dyn ChannelAdapter>,
|
|
||||||
whatsapp_adapter: Arc<dyn ChannelAdapter>,
|
|
||||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let client = redis::Client::open(redis_url)?;
|
|
||||||
Ok(Self {
|
|
||||||
redis_client: client,
|
|
||||||
web_adapter,
|
|
||||||
voice_adapter,
|
|
||||||
whatsapp_adapter,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_tool_message(
|
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
user_id: &str,
|
|
||||||
channel: &str,
|
|
||||||
message: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let response = BotResponse {
|
|
||||||
bot_id: "tool_bot".to_string(),
|
|
||||||
user_id: user_id.to_string(),
|
|
||||||
session_id: session_id.to_string(),
|
|
||||||
channel: channel.to_string(),
|
|
||||||
content: message.to_string(),
|
|
||||||
message_type: "tool".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
match channel {
|
|
||||||
"web" => self.web_adapter.send_message(response).await,
|
|
||||||
"voice" => self.voice_adapter.send_message(response).await,
|
|
||||||
"whatsapp" => self.whatsapp_adapter.send_message(response).await,
|
|
||||||
_ => Ok(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_rhai_engine(&self, session_id: String, user_id: String, channel: String) -> Engine {
|
|
||||||
let mut engine = Engine::new();
|
|
||||||
|
|
||||||
let tool_executor = Arc::new((
|
|
||||||
self.redis_client.clone(),
|
|
||||||
self.web_adapter.clone(),
|
|
||||||
self.voice_adapter.clone(),
|
|
||||||
self.whatsapp_adapter.clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
let session_id_clone = session_id.clone();
|
|
||||||
let user_id_clone = user_id.clone();
|
|
||||||
let channel_clone = channel.clone();
|
|
||||||
|
|
||||||
engine.register_fn("talk", move |message: String| {
|
|
||||||
let tool_executor = Arc::clone(&tool_executor);
|
|
||||||
let session_id = session_id_clone.clone();
|
|
||||||
let user_id = user_id_clone.clone();
|
|
||||||
let channel = channel_clone.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let (redis_client, web_adapter, voice_adapter, whatsapp_adapter) = &*tool_executor;
|
|
||||||
|
|
||||||
let response = BotResponse {
|
|
||||||
bot_id: "tool_bot".to_string(),
|
|
||||||
user_id: user_id.clone(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
channel: channel.clone(),
|
|
||||||
content: message.clone(),
|
|
||||||
message_type: "tool".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = match channel.as_str() {
|
|
||||||
"web" => web_adapter.send_message(response).await,
|
|
||||||
"voice" => voice_adapter.send_message(response).await,
|
|
||||||
"whatsapp" => whatsapp_adapter.send_message(response).await,
|
|
||||||
_ => Ok(()),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = result {
|
|
||||||
log::error!("Failed to send tool message: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(mut conn) = redis_client.get_async_connection().await {
|
|
||||||
let output_key = format!("tool:{}:output", session_id);
|
|
||||||
let _ = conn.lpush(&output_key, &message).await;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
let hear_executor = self.redis_client.clone();
|
|
||||||
let session_id_clone = session_id.clone();
|
|
||||||
|
|
||||||
engine.register_fn("hear", move || -> String {
|
|
||||||
let hear_executor = hear_executor.clone();
|
|
||||||
let session_id = session_id_clone.clone();
|
|
||||||
|
|
||||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
|
||||||
rt.block_on(async move {
|
|
||||||
match hear_executor.get_async_connection().await {
|
|
||||||
Ok(mut conn) => {
|
|
||||||
let input_key = format!("tool:{}:input", session_id);
|
|
||||||
let waiting_key = format!("tool:{}:waiting", session_id);
|
|
||||||
|
|
||||||
let _ = conn.set_ex(&waiting_key, "true", 300).await;
|
|
||||||
let result: Option<(String, String)> =
|
|
||||||
conn.brpop(&input_key, 30).await.ok().flatten();
|
|
||||||
let _ = conn.del(&waiting_key).await;
|
|
||||||
|
|
||||||
result
|
|
||||||
.map(|(_, input)| input)
|
|
||||||
.unwrap_or_else(|| "timeout".to_string())
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("HEAR Redis error: {}", e);
|
|
||||||
"error".to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
engine
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn cleanup_session(&self, session_id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|
||||||
|
|
||||||
let keys = vec![
|
|
||||||
format!("tool:{}:output", session_id),
|
|
||||||
format!("tool:{}:input", session_id),
|
|
||||||
format!("tool:{}:waiting", session_id),
|
|
||||||
format!("tool:{}:active", session_id),
|
|
||||||
];
|
|
||||||
|
|
||||||
for key in keys {
|
|
||||||
let _: () = conn.del(&key).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ToolExecutor for RedisToolExecutor {
|
impl ToolExecutor for MockToolExecutor {
|
||||||
async fn execute(
|
async fn execute(
|
||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let tool = get_tool(tool_name).ok_or_else(|| format!("Tool not found: {}", tool_name))?;
|
|
||||||
|
|
||||||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|
||||||
let session_key = format!("tool:{}:session", session_id);
|
|
||||||
let session_data = serde_json::json!({
|
|
||||||
"user_id": user_id,
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"started_at": chrono::Utc::now().to_rfc3339(),
|
|
||||||
});
|
|
||||||
conn.set_ex(&session_key, session_data.to_string(), 3600)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let active_key = format!("tool:{}:active", session_id);
|
|
||||||
conn.set_ex(&active_key, "true", 3600).await?;
|
|
||||||
|
|
||||||
let channel = "web";
|
|
||||||
let _engine = self.create_rhai_engine(
|
|
||||||
session_id.to_string(),
|
|
||||||
user_id.to_string(),
|
|
||||||
channel.to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let redis_clone = self.redis_client.clone();
|
|
||||||
let web_adapter_clone = self.web_adapter.clone();
|
|
||||||
let voice_adapter_clone = self.voice_adapter.clone();
|
|
||||||
let whatsapp_adapter_clone = self.whatsapp_adapter.clone();
|
|
||||||
let session_id_clone = session_id.to_string();
|
|
||||||
let user_id_clone = user_id.to_string();
|
|
||||||
let tool_script = tool.script.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut engine = Engine::new();
|
|
||||||
let mut scope = Scope::new();
|
|
||||||
|
|
||||||
let redis_client = redis_clone.clone();
|
|
||||||
let web_adapter = web_adapter_clone.clone();
|
|
||||||
let voice_adapter = voice_adapter_clone.clone();
|
|
||||||
let whatsapp_adapter = whatsapp_adapter_clone.clone();
|
|
||||||
let session_id = session_id_clone.clone();
|
|
||||||
let user_id = user_id_clone.clone();
|
|
||||||
|
|
||||||
engine.register_fn("talk", move |message: String| {
|
|
||||||
let redis_client = redis_client.clone();
|
|
||||||
let web_adapter = web_adapter.clone();
|
|
||||||
let voice_adapter = voice_adapter.clone();
|
|
||||||
let whatsapp_adapter = whatsapp_adapter.clone();
|
|
||||||
let session_id = session_id.clone();
|
|
||||||
let user_id = user_id.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let channel = "web";
|
|
||||||
|
|
||||||
let response = BotResponse {
|
|
||||||
bot_id: "tool_bot".to_string(),
|
|
||||||
user_id: user_id.clone(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
channel: channel.to_string(),
|
|
||||||
content: message.clone(),
|
|
||||||
message_type: "tool".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
let send_result = match channel {
|
|
||||||
"web" => web_adapter.send_message(response).await,
|
|
||||||
"voice" => voice_adapter.send_message(response).await,
|
|
||||||
"whatsapp" => whatsapp_adapter.send_message(response).await,
|
|
||||||
_ => Ok(()),
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Err(e) = send_result {
|
|
||||||
log::error!("Failed to send tool message: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(mut conn) = redis_client.get_async_connection().await {
|
|
||||||
let output_key = format!("tool:{}:output", session_id);
|
|
||||||
let _ = conn.lpush(&output_key, &message).await;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
let hear_redis = redis_clone.clone();
|
|
||||||
let session_id_hear = session_id.clone();
|
|
||||||
engine.register_fn("hear", move || -> String {
|
|
||||||
let hear_redis = hear_redis.clone();
|
|
||||||
let session_id = session_id_hear.clone();
|
|
||||||
|
|
||||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
|
||||||
rt.block_on(async move {
|
|
||||||
match hear_redis.get_async_connection().await {
|
|
||||||
Ok(mut conn) => {
|
|
||||||
let input_key = format!("tool:{}:input", session_id);
|
|
||||||
let waiting_key = format!("tool:{}:waiting", session_id);
|
|
||||||
|
|
||||||
let _ = conn.set_ex(&waiting_key, "true", 300).await;
|
|
||||||
let result: Option<(String, String)> =
|
|
||||||
conn.brpop(&input_key, 30).await.ok().flatten();
|
|
||||||
let _ = conn.del(&waiting_key).await;
|
|
||||||
|
|
||||||
result
|
|
||||||
.map(|(_, input)| input)
|
|
||||||
.unwrap_or_else(|| "timeout".to_string())
|
|
||||||
}
|
|
||||||
Err(_) => "error".to_string(),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
match engine.eval_with_scope::<()>(&mut scope, &tool_script) {
|
|
||||||
Ok(_) => {
|
|
||||||
log::info!(
|
|
||||||
"Tool {} completed successfully for session {}",
|
|
||||||
tool_name,
|
|
||||||
session_id
|
|
||||||
);
|
|
||||||
|
|
||||||
let completion_msg =
|
|
||||||
"🛠️ Tool execution completed. How can I help you with anything else?";
|
|
||||||
let response = BotResponse {
|
|
||||||
bot_id: "tool_bot".to_string(),
|
|
||||||
user_id: user_id_clone,
|
|
||||||
session_id: session_id_clone.clone(),
|
|
||||||
channel: "web".to_string(),
|
|
||||||
content: completion_msg.to_string(),
|
|
||||||
message_type: "tool_complete".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
let _ = web_adapter_clone.send_message(response).await;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Tool execution failed: {}", e);
|
|
||||||
|
|
||||||
let error_msg = format!("❌ Tool error: {}", e);
|
|
||||||
let response = BotResponse {
|
|
||||||
bot_id: "tool_bot".to_string(),
|
|
||||||
user_id: user_id_clone,
|
|
||||||
session_id: session_id_clone.clone(),
|
|
||||||
channel: "web".to_string(),
|
|
||||||
content: error_msg,
|
|
||||||
message_type: "tool_error".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
let _ = web_adapter_clone.send_message(response).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(mut conn) = redis_clone.get_async_connection().await {
|
|
||||||
let active_key = format!("tool:{}:active", session_id_clone);
|
|
||||||
let _ = conn.del(&active_key).await;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!(
|
output: format!("Mock tool {} executed for user {}", tool_name, user_id),
|
||||||
"🛠️ Starting {} tool. Please follow the tool's instructions.",
|
requires_input: false,
|
||||||
tool_name
|
|
||||||
),
|
|
||||||
requires_input: true,
|
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn provide_input(
|
async fn provide_input(
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
_session_id: &str,
|
||||||
input: &str,
|
_input: &str,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
|
||||||
let input_key = format!("tool:{}:input", session_id);
|
|
||||||
conn.lpush(&input_key, input).await?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_output(
|
async fn get_output(
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
_session_id: &str,
|
||||||
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
Ok(vec!["Mock output".to_string()])
|
||||||
let output_key = format!("tool:{}:output", session_id);
|
|
||||||
let messages: Vec<String> = conn.lrange(&output_key, 0, -1).await?;
|
|
||||||
let _: () = conn.del(&output_key).await?;
|
|
||||||
Ok(messages)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_waiting_for_input(
|
async fn is_waiting_for_input(
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
_session_id: &str,
|
||||||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut conn = self.redis_client.get_multiplexed_async_connection().await?;
|
Ok(false)
|
||||||
let waiting_key = format!("tool:{}:waiting", session_id);
|
|
||||||
let exists: bool = conn.exists(&waiting_key).await?;
|
|
||||||
Ok(exists)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_tool(name: &str) -> Option<Tool> {
|
|
||||||
match name {
|
|
||||||
"calculator" => Some(Tool {
|
|
||||||
name: "calculator".to_string(),
|
|
||||||
description: "Perform mathematical calculations".to_string(),
|
|
||||||
parameters: HashMap::from([
|
|
||||||
("operation".to_string(), "add|subtract|multiply|divide".to_string()),
|
|
||||||
("a".to_string(), "number".to_string()),
|
|
||||||
("b".to_string(), "number".to_string()),
|
|
||||||
]),
|
|
||||||
script: r#"
|
|
||||||
let TALK = |message| {
|
|
||||||
talk(message);
|
|
||||||
};
|
|
||||||
|
|
||||||
let HEAR = || {
|
|
||||||
hear()
|
|
||||||
};
|
|
||||||
|
|
||||||
TALK("🔢 Calculator started!");
|
|
||||||
TALK("Please enter the first number:");
|
|
||||||
let a = HEAR();
|
|
||||||
TALK("Please enter the second number:");
|
|
||||||
let b = HEAR();
|
|
||||||
TALK("Choose operation: add, subtract, multiply, or divide:");
|
|
||||||
let op = HEAR();
|
|
||||||
|
|
||||||
let num_a = a.to_float();
|
|
||||||
let num_b = b.to_float();
|
|
||||||
|
|
||||||
if op == "add" {
|
|
||||||
let result = num_a + num_b;
|
|
||||||
TALK("✅ Result: " + a + " + " + b + " = " + result);
|
|
||||||
} else if op == "subtract" {
|
|
||||||
let result = num_a - num_b;
|
|
||||||
TALK("✅ Result: " + a + " - " + b + " = " + result);
|
|
||||||
} else if op == "multiply" {
|
|
||||||
let result = num_a * num_b;
|
|
||||||
TALK("✅ Result: " + a + " × " + b + " = " + result);
|
|
||||||
} else if op == "divide" {
|
|
||||||
if num_b != 0.0 {
|
|
||||||
let result = num_a / num_b;
|
|
||||||
TALK("✅ Result: " + a + " ÷ " + b + " = " + result);
|
|
||||||
} else {
|
|
||||||
TALK("❌ Error: Cannot divide by zero!");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
TALK("❌ Error: Invalid operation. Please use: add, subtract, multiply, or divide");
|
|
||||||
}
|
|
||||||
|
|
||||||
TALK("Calculator session completed. Thank you!");
|
|
||||||
"#.to_string(),
|
|
||||||
}),
|
|
||||||
_ => None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -492,32 +115,8 @@ impl ToolManager {
|
||||||
("b".to_string(), "number".to_string()),
|
("b".to_string(), "number".to_string()),
|
||||||
]),
|
]),
|
||||||
script: r#"
|
script: r#"
|
||||||
TALK("Calculator started. Enter first number:");
|
// Calculator tool implementation
|
||||||
let a = HEAR();
|
print("Calculator started");
|
||||||
TALK("Enter second number:");
|
|
||||||
let b = HEAR();
|
|
||||||
TALK("Operation (add/subtract/multiply/divide):");
|
|
||||||
let op = HEAR();
|
|
||||||
|
|
||||||
let num_a = a.parse::<f64>().unwrap();
|
|
||||||
let num_b = b.parse::<f64>().unwrap();
|
|
||||||
let result = if op == "add" {
|
|
||||||
num_a + num_b
|
|
||||||
} else if op == "subtract" {
|
|
||||||
num_a - num_b
|
|
||||||
} else if op == "multiply" {
|
|
||||||
num_a * num_b
|
|
||||||
} else if op == "divide" {
|
|
||||||
if num_b == 0.0 {
|
|
||||||
TALK("Cannot divide by zero");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
num_a / num_b
|
|
||||||
} else {
|
|
||||||
TALK("Invalid operation");
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
TALK("Result: ".to_string() + &result.to_string());
|
|
||||||
"#
|
"#
|
||||||
.to_string(),
|
.to_string(),
|
||||||
};
|
};
|
||||||
|
|
@ -543,7 +142,7 @@ impl ToolManager {
|
||||||
session_id: &str,
|
session_id: &str,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
let _tool = self.get_tool(tool_name).ok_or("Tool not found")?;
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
|
|
@ -572,7 +171,7 @@ impl ToolManager {
|
||||||
|
|
||||||
pub async fn get_tool_output(
|
pub async fn get_tool_output(
|
||||||
&self,
|
&self,
|
||||||
session_id: &str,
|
_session_id: &str,
|
||||||
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
Ok(vec![])
|
Ok(vec![])
|
||||||
}
|
}
|
||||||
|
|
@ -592,74 +191,23 @@ impl ToolManager {
|
||||||
|
|
||||||
let user_id = user_id.to_string();
|
let user_id = user_id.to_string();
|
||||||
let bot_id = bot_id.to_string();
|
let bot_id = bot_id.to_string();
|
||||||
let script = tool.script.clone();
|
let _script = tool.script.clone();
|
||||||
let session_manager_clone = session_manager.clone();
|
let session_manager_clone = session_manager.clone();
|
||||||
let waiting_responses = self.waiting_responses.clone();
|
let waiting_responses = self.waiting_responses.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut engine = rhai::Engine::new();
|
// Simulate tool execution
|
||||||
let (talk_tx, mut talk_rx) = mpsc::channel(100);
|
|
||||||
let (hear_tx, mut hear_rx) = mpsc::channel(100);
|
|
||||||
|
|
||||||
{
|
|
||||||
let key = format!("{}:{}", user_id, bot_id);
|
|
||||||
let mut waiting = waiting_responses.lock().await;
|
|
||||||
waiting.insert(key, hear_tx);
|
|
||||||
}
|
|
||||||
|
|
||||||
let channel_sender_clone = channel_sender.clone();
|
|
||||||
let user_id_clone = user_id.clone();
|
|
||||||
let bot_id_clone = bot_id.clone();
|
|
||||||
|
|
||||||
let talk_tx_clone = talk_tx.clone();
|
|
||||||
engine.register_fn("TALK", move |message: String| {
|
|
||||||
let tx = talk_tx_clone.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let _ = tx.send(message).await;
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
let hear_rx_mutex = Arc::new(Mutex::new(hear_rx));
|
|
||||||
engine.register_fn("HEAR", move || {
|
|
||||||
let hear_rx = hear_rx_mutex.clone();
|
|
||||||
tokio::task::block_in_place(|| {
|
|
||||||
tokio::runtime::Handle::current().block_on(async move {
|
|
||||||
let mut receiver = hear_rx.lock().await;
|
|
||||||
receiver.recv().await.unwrap_or_default()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
let script_result =
|
|
||||||
tokio::task::spawn_blocking(move || engine.eval::<()>(&script)).await;
|
|
||||||
|
|
||||||
if let Ok(Err(e)) = script_result {
|
|
||||||
let error_response = BotResponse {
|
|
||||||
bot_id: bot_id_clone.clone(),
|
|
||||||
user_id: user_id_clone.clone(),
|
|
||||||
session_id: Uuid::new_v4().to_string(),
|
|
||||||
channel: "test".to_string(),
|
|
||||||
content: format!("Tool error: {}", e),
|
|
||||||
message_type: "text".to_string(),
|
|
||||||
stream_token: None,
|
|
||||||
is_complete: true,
|
|
||||||
};
|
|
||||||
let _ = channel_sender_clone.send(error_response).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Some(message) = talk_rx.recv().await {
|
|
||||||
let response = BotResponse {
|
let response = BotResponse {
|
||||||
bot_id: bot_id.clone(),
|
bot_id: bot_id.clone(),
|
||||||
user_id: user_id.clone(),
|
user_id: user_id.clone(),
|
||||||
session_id: Uuid::new_v4().to_string(),
|
session_id: Uuid::new_v4().to_string(),
|
||||||
channel: "test".to_string(),
|
channel: "test".to_string(),
|
||||||
content: message,
|
content: format!("Tool {} executed successfully", tool_name),
|
||||||
message_type: "text".to_string(),
|
message_type: "text".to_string(),
|
||||||
stream_token: None,
|
stream_token: None,
|
||||||
is_complete: true,
|
is_complete: true,
|
||||||
};
|
};
|
||||||
let _ = channel_sender.send(response).await;
|
let _ = channel_sender.send(response).await;
|
||||||
}
|
|
||||||
|
|
||||||
let _ = session_manager_clone
|
let _ = session_manager_clone
|
||||||
.set_current_tool(&user_id, &bot_id, None)
|
.set_current_tool(&user_id, &bot_id, None)
|
||||||
|
|
|
||||||
|
|
@ -1,246 +1,48 @@
|
||||||
// wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
|
|
||||||
// sudo dpkg -i google-chrome-stable_current_amd64.deb
|
|
||||||
use log::info;
|
|
||||||
|
|
||||||
use crate::shared::utils;
|
|
||||||
|
|
||||||
use std::env;
|
|
||||||
use std::error::Error;
|
|
||||||
use std::future::Future;
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::process::Command;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use thirtyfour::{DesiredCapabilities, WebDriver};
|
use thirtyfour::{DesiredCapabilities, WebDriver};
|
||||||
use tokio::fs;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
pub struct BrowserSetup {
|
|
||||||
pub brave_path: String,
|
|
||||||
pub chromedriver_path: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct BrowserPool {
|
pub struct BrowserPool {
|
||||||
webdriver_url: String,
|
drivers: Arc<Mutex<Vec<WebDriver>>>,
|
||||||
semaphore: Semaphore,
|
|
||||||
brave_path: String,
|
brave_path: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BrowserPool {
|
impl BrowserPool {
|
||||||
pub fn new(webdriver_url: String, max_concurrent: usize, brave_path: String) -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
webdriver_url,
|
drivers: Arc::new(Mutex::new(Vec::new())),
|
||||||
semaphore: Semaphore::new(max_concurrent),
|
brave_path: "/usr/bin/brave-browser".to_string(),
|
||||||
brave_path,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn with_browser<F, T>(&self, f: F) -> Result<T, Box<dyn Error + Send + Sync>>
|
pub async fn get_driver(&self) -> Result<WebDriver, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
where
|
|
||||||
F: FnOnce(
|
|
||||||
WebDriver,
|
|
||||||
)
|
|
||||||
-> Pin<Box<dyn Future<Output = Result<T, Box<dyn Error + Send + Sync>>> + Send>>
|
|
||||||
+ Send
|
|
||||||
+ 'static,
|
|
||||||
T: Send + 'static,
|
|
||||||
{
|
|
||||||
let _permit = self.semaphore.acquire().await?;
|
|
||||||
|
|
||||||
let mut caps = DesiredCapabilities::chrome();
|
let mut caps = DesiredCapabilities::chrome();
|
||||||
caps.set_binary(&self.brave_path)?;
|
|
||||||
//caps.add_chrome_arg("--headless=new")?;
|
|
||||||
caps.add_chrome_arg("--disable-gpu")?;
|
|
||||||
caps.add_chrome_arg("--no-sandbox")?;
|
|
||||||
|
|
||||||
let driver = WebDriver::new(&self.webdriver_url, caps).await?;
|
// Use add_arg instead of add_chrome_arg
|
||||||
|
caps.add_arg("--disable-gpu")?;
|
||||||
|
caps.add_arg("--no-sandbox")?;
|
||||||
|
caps.add_arg("--disable-dev-shm-usage")?;
|
||||||
|
|
||||||
// Execute user function
|
let driver = WebDriver::new("http://localhost:9515", caps).await?;
|
||||||
let result = f(driver).await;
|
|
||||||
|
|
||||||
result
|
let mut drivers = self.drivers.lock().await;
|
||||||
|
drivers.push(driver.clone());
|
||||||
|
|
||||||
|
Ok(driver)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn cleanup(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let mut drivers = self.drivers.lock().await;
|
||||||
|
for driver in drivers.iter() {
|
||||||
|
let _ = driver.quit().await;
|
||||||
|
}
|
||||||
|
drivers.clear();
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BrowserSetup {
|
impl Default for BrowserPool {
|
||||||
pub async fn new() -> Result<Self, Box<dyn std::error::Error>> {
|
fn default() -> Self {
|
||||||
// Check for Brave installation
|
Self::new()
|
||||||
let brave_path = Self::find_brave().await?;
|
|
||||||
|
|
||||||
// Check for chromedriver
|
|
||||||
let chromedriver_path = Self::setup_chromedriver().await?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
brave_path,
|
|
||||||
chromedriver_path,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn find_brave() -> Result<String, Box<dyn std::error::Error>> {
|
|
||||||
let mut possible_paths = vec![
|
|
||||||
// Windows - Program Files
|
|
||||||
String::from(r"C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe"),
|
|
||||||
// macOS
|
|
||||||
String::from("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"),
|
|
||||||
// Linux
|
|
||||||
String::from("/usr/bin/brave-browser"),
|
|
||||||
String::from("/usr/bin/brave"),
|
|
||||||
];
|
|
||||||
|
|
||||||
// Windows - AppData (usuário atual)
|
|
||||||
if let Ok(local_appdata) = env::var("LOCALAPPDATA") {
|
|
||||||
let mut path = PathBuf::from(local_appdata);
|
|
||||||
path.push("BraveSoftware\\Brave-Browser\\Application\\brave.exe");
|
|
||||||
possible_paths.push(path.to_string_lossy().to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
for path in possible_paths {
|
|
||||||
if fs::metadata(&path).await.is_ok() {
|
|
||||||
return Ok(path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Err("Brave browser not found. Please install Brave first.".into())
|
|
||||||
}
|
|
||||||
async fn setup_chromedriver() -> Result<String, Box<dyn std::error::Error>> {
|
|
||||||
// Create chromedriver directory in executable's parent directory
|
|
||||||
let mut chromedriver_dir = env::current_exe()?.parent().unwrap().to_path_buf();
|
|
||||||
chromedriver_dir.push("chromedriver");
|
|
||||||
|
|
||||||
// Ensure the directory exists
|
|
||||||
if !chromedriver_dir.exists() {
|
|
||||||
fs::create_dir(&chromedriver_dir).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the final chromedriver path
|
|
||||||
let chromedriver_path = if cfg!(target_os = "windows") {
|
|
||||||
chromedriver_dir.join("chromedriver.exe")
|
|
||||||
} else {
|
|
||||||
chromedriver_dir.join("chromedriver")
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if chromedriver exists
|
|
||||||
if fs::metadata(&chromedriver_path).await.is_err() {
|
|
||||||
let (download_url, platform) = match (cfg!(target_os = "windows"), cfg!(target_arch = "x86_64")) {
|
|
||||||
(true, true) => (
|
|
||||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win64/chromedriver-win64.zip",
|
|
||||||
"win64",
|
|
||||||
),
|
|
||||||
(true, false) => (
|
|
||||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/win32/chromedriver-win32.zip",
|
|
||||||
"win32",
|
|
||||||
),
|
|
||||||
(false, true) if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") => (
|
|
||||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-arm64/chromedriver-mac-arm64.zip",
|
|
||||||
"mac-arm64",
|
|
||||||
),
|
|
||||||
(false, true) if cfg!(target_os = "macos") => (
|
|
||||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/mac-x64/chromedriver-mac-x64.zip",
|
|
||||||
"mac-x64",
|
|
||||||
),
|
|
||||||
(false, true) => (
|
|
||||||
"https://storage.googleapis.com/chrome-for-testing-public/138.0.7204.183/linux64/chromedriver-linux64.zip",
|
|
||||||
"linux64",
|
|
||||||
),
|
|
||||||
_ => return Err("Unsupported platform".into()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut zip_path = std::env::temp_dir();
|
|
||||||
zip_path.push("chromedriver.zip");
|
|
||||||
info!("Downloading chromedriver for {}...", platform);
|
|
||||||
|
|
||||||
// Download the zip file
|
|
||||||
utils::download_file(download_url, &zip_path.to_str().unwrap()).await?;
|
|
||||||
|
|
||||||
// Extract the zip to a temporary directory first
|
|
||||||
let mut temp_extract_dir = std::env::temp_dir();
|
|
||||||
temp_extract_dir.push("chromedriver_extract");
|
|
||||||
let temp_extract_dir1 = temp_extract_dir.clone();
|
|
||||||
|
|
||||||
// Clean up any previous extraction
|
|
||||||
let _ = fs::remove_dir_all(&temp_extract_dir).await;
|
|
||||||
fs::create_dir(&temp_extract_dir).await?;
|
|
||||||
|
|
||||||
utils::extract_zip_recursive(&zip_path, &temp_extract_dir)?;
|
|
||||||
|
|
||||||
// Chrome for Testing zips contain a platform-specific directory
|
|
||||||
// Find the chromedriver binary in the extracted structure
|
|
||||||
let mut extracted_binary_path = temp_extract_dir;
|
|
||||||
extracted_binary_path.push(format!("chromedriver-{}", platform));
|
|
||||||
extracted_binary_path.push(if cfg!(target_os = "windows") {
|
|
||||||
"chromedriver.exe"
|
|
||||||
} else {
|
|
||||||
"chromedriver"
|
|
||||||
});
|
|
||||||
|
|
||||||
// Try to move the file, fall back to copy if cross-device
|
|
||||||
match fs::rename(&extracted_binary_path, &chromedriver_path).await {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::CrossesDevices => {
|
|
||||||
// Cross-device move failed, use copy instead
|
|
||||||
fs::copy(&extracted_binary_path, &chromedriver_path).await?;
|
|
||||||
// Set permissions on the copied file
|
|
||||||
#[cfg(unix)]
|
|
||||||
{
|
|
||||||
use std::os::unix::fs::PermissionsExt;
|
|
||||||
let mut perms = fs::metadata(&chromedriver_path).await?.permissions();
|
|
||||||
perms.set_mode(0o755);
|
|
||||||
fs::set_permissions(&chromedriver_path, perms).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => return Err(e.into()),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
let _ = fs::remove_file(&zip_path).await;
|
|
||||||
let _ = fs::remove_dir_all(temp_extract_dir1).await;
|
|
||||||
|
|
||||||
// Set executable permissions (if not already set during copy)
|
|
||||||
#[cfg(unix)]
|
|
||||||
{
|
|
||||||
use std::os::unix::fs::PermissionsExt;
|
|
||||||
let mut perms = fs::metadata(&chromedriver_path).await?.permissions();
|
|
||||||
perms.set_mode(0o755);
|
|
||||||
fs::set_permissions(&chromedriver_path, perms).await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(chromedriver_path.to_string_lossy().to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modified BrowserPool initialization
|
|
||||||
pub async fn initialize_browser_pool() -> Result<Arc<BrowserPool>, Box<dyn std::error::Error>> {
|
|
||||||
let setup = BrowserSetup::new().await?;
|
|
||||||
|
|
||||||
// Start chromedriver process if not running
|
|
||||||
if !is_process_running("chromedriver").await {
|
|
||||||
Command::new(&setup.chromedriver_path)
|
|
||||||
.arg("--port=9515")
|
|
||||||
.spawn()?;
|
|
||||||
|
|
||||||
// Give chromedriver time to start
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Arc::new(BrowserPool::new(
|
|
||||||
"http://localhost:9515".to_string(),
|
|
||||||
5, // Max concurrent browsers
|
|
||||||
setup.brave_path,
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn is_process_running(name: &str) -> bool {
|
|
||||||
if cfg!(target_os = "windows") {
|
|
||||||
Command::new("tasklist")
|
|
||||||
.output()
|
|
||||||
.map(|o| String::from_utf8_lossy(&o.stdout).contains(name))
|
|
||||||
.unwrap_or(false)
|
|
||||||
} else {
|
|
||||||
Command::new("pgrep")
|
|
||||||
.arg(name)
|
|
||||||
.output()
|
|
||||||
.map(|o| o.status.success())
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use log::info;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use log::info;
|
|
||||||
|
|
||||||
use crate::shared::BotResponse;
|
use crate::shared::BotResponse;
|
||||||
|
|
||||||
|
|
@ -71,7 +71,7 @@ pub struct WhatsAppAdapter {
|
||||||
access_token: String,
|
access_token: String,
|
||||||
phone_number_id: String,
|
phone_number_id: String,
|
||||||
webhook_verify_token: String,
|
webhook_verify_token: String,
|
||||||
sessions: Arc<Mutex<HashMap<String, String>>>, // phone -> session_id
|
sessions: Arc<Mutex<HashMap<String, String>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WhatsAppAdapter {
|
impl WhatsAppAdapter {
|
||||||
|
|
@ -87,16 +87,18 @@ impl WhatsAppAdapter {
|
||||||
|
|
||||||
pub async fn get_session_id(&self, phone: &str) -> String {
|
pub async fn get_session_id(&self, phone: &str) -> String {
|
||||||
let sessions = self.sessions.lock().await;
|
let sessions = self.sessions.lock().await;
|
||||||
sessions.get(phone).cloned().unwrap_or_else(|| {
|
if let Some(session_id) = sessions.get(phone) {
|
||||||
|
session_id.clone()
|
||||||
|
} else {
|
||||||
drop(sessions);
|
drop(sessions);
|
||||||
let session_id = uuid::Uuid::new_v4().to_string();
|
let session_id = uuid::Uuid::new_v4().to_string();
|
||||||
let mut sessions = self.sessions.lock().await;
|
let mut sessions = self.sessions.lock().await;
|
||||||
sessions.insert(phone.to_string(), session_id.clone());
|
sessions.insert(phone.to_string(), session_id.clone());
|
||||||
session_id
|
session_id
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn send_whatsapp_message(&self, to: &str, body: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let url = format!(
|
let url = format!(
|
||||||
"https://graph.facebook.com/v17.0/{}/messages",
|
"https://graph.facebook.com/v17.0/{}/messages",
|
||||||
self.phone_number_id
|
self.phone_number_id
|
||||||
|
|
@ -127,7 +129,7 @@ impl WhatsAppAdapter {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error>> {
|
pub async fn process_incoming_message(&self, message: WhatsAppMessage) -> Result<Vec<crate::shared::UserMessage>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let mut user_messages = Vec::new();
|
let mut user_messages = Vec::new();
|
||||||
|
|
||||||
for entry in message.entry {
|
for entry in message.entry {
|
||||||
|
|
@ -158,7 +160,7 @@ impl WhatsAppAdapter {
|
||||||
Ok(user_messages)
|
Ok(user_messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result<String, Box<dyn std::error::Error>> {
|
pub fn verify_webhook(&self, mode: &str, token: &str, challenge: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
if mode == "subscribe" && token == self.webhook_verify_token {
|
if mode == "subscribe" && token == self.webhook_verify_token {
|
||||||
Ok(challenge.to_string())
|
Ok(challenge.to_string())
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -168,8 +170,8 @@ impl WhatsAppAdapter {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl super::channels::ChannelAdapter for WhatsAppAdapter {
|
impl crate::channels::ChannelAdapter for WhatsAppAdapter {
|
||||||
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error>> {
|
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
info!("Sending WhatsApp response to: {}", response.user_id);
|
info!("Sending WhatsApp response to: {}", response.user_id);
|
||||||
self.send_whatsapp_message(&response.user_id, &response.content).await
|
self.send_whatsapp_message(&response.user_id, &response.content).await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue