- Warning removal and restore of old code.

Rodrigo Rodriguez (Pragmatismo) 2025-10-07 07:16:03 -03:00
parent 76dc51e980
commit 8a9cd104d6
11 changed files with 522 additions and 339 deletions

View file

@@ -1,15 +1,17 @@
use async_trait::async_trait;
-use chrono::Utc;
use log::info;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
-use crate::shared::{BotResponse, UserMessage};
+use crate::shared::BotResponse;
#[async_trait]
pub trait ChannelAdapter: Send + Sync {
-async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
+async fn send_message(
+&self,
+response: BotResponse,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
pub struct WebChannelAdapter {
@@ -34,7 +36,10 @@ impl WebChannelAdapter {
#[async_trait]
impl ChannelAdapter for WebChannelAdapter {
-async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+async fn send_message(
+&self,
+response: BotResponse,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let connections = self.connections.lock().await;
if let Some(tx) = connections.get(&response.session_id) {
tx.send(response).await?;
@@ -67,11 +72,17 @@ impl VoiceAdapter {
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-info!("Starting voice session for user: {} with session: {}", user_id, session_id);
+info!(
+"Starting voice session for user: {} with session: {}",
+user_id, session_id
+);
let token = format!("mock_token_{}_{}", session_id, user_id);
-self.rooms.lock().await.insert(session_id.to_string(), token.clone());
+self.rooms
+.lock()
+.await
+.insert(session_id.to_string(), token.clone());
Ok(token)
}
@@ -99,7 +110,10 @@ impl VoiceAdapter {
#[async_trait]
impl ChannelAdapter for VoiceAdapter {
-async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+async fn send_message(
+&self,
+response: BotResponse,
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Sending voice response to: {}", response.user_id);
self.send_voice_response(&response.session_id, &response.content)
.await
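For context (not part of the commit): a minimal sketch of driving the reformatted ChannelAdapter trait through trait objects. The adapter registry and the dispatch helper below are hypothetical; ChannelAdapter and BotResponse are the types shown in the diff above.

use std::collections::HashMap;
use std::sync::Arc;

// Hypothetical helper: look up an adapter by channel name and deliver one response.
async fn dispatch(
    adapters: &HashMap<String, Arc<dyn ChannelAdapter>>,
    channel: &str,
    response: BotResponse,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    match adapters.get(channel) {
        // The single trait method hides the transport (WebSocket, voice, WhatsApp).
        Some(adapter) => adapter.send_message(response).await,
        None => Err(format!("unknown channel: {}", channel).into()),
    }
}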

View file

@@ -1,257 +0,0 @@
use async_trait::async_trait;
use futures::StreamExt;
use langchain_rust::{
language_models::llm::LLM,
llm::{claude::Claude, openai::OpenAI},
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::tools::ToolManager;
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(
&self,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
tool_manager: Arc<ToolManager>,
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
}
pub struct OpenAIClient {
client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
}
impl OpenAIClient {
pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
Self { client }
}
}
#[async_trait]
impl LLMProvider for OpenAIClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct AnthropicClient {
client: Claude,
}
impl AnthropicClient {
pub fn new(api_key: String) -> Self {
let client = Claude::default().with_api_key(api_key);
Self { client }
}
}
#[async_trait]
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct MockLLMProvider;
impl MockLLMProvider {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(format!("Mock response to: {}", prompt))
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = format!("Mock stream response to: {}", prompt);
for word in response.split_whitespace() {
let _ = tx.send(format!("{} ", word)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_list = if available_tools.is_empty() {
"no tools available".to_string()
} else {
available_tools.join(", ")
};
Ok(format!(
"Mock response with tools [{}] to: {}",
tools_list, prompt
))
}
}

View file

@@ -1,3 +1,257 @@
-pub mod llm_provider;
use async_trait::async_trait;
use futures::StreamExt;
use langchain_rust::{
language_models::llm::LLM,
llm::{claude::Claude, openai::OpenAI},
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
-pub use llm_provider::*;
use crate::tools::ToolManager;
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(
&self,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
tool_manager: Arc<ToolManager>,
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
}
pub struct OpenAIClient {
client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
}
impl OpenAIClient {
pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
Self { client }
}
}
#[async_trait]
impl LLMProvider for OpenAIClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct AnthropicClient {
client: Claude,
}
impl AnthropicClient {
pub fn new(api_key: String) -> Self {
let client = Claude::default().with_api_key(api_key);
Self { client }
}
}
#[async_trait]
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct MockLLMProvider;
impl MockLLMProvider {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(format!("Mock response to: {}", prompt))
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = format!("Mock stream response to: {}", prompt);
for word in response.split_whitespace() {
let _ = tx.send(format!("{} ", word)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_list = if available_tools.is_empty() {
"no tools available".to_string()
} else {
available_tools.join(", ")
};
Ok(format!(
"Mock response with tools [{}] to: {}",
tools_list, prompt
))
}
}
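For context (not part of the commit): a minimal sketch of exercising the LLMProvider trait above through the MockLLMProvider, for example from a test. It assumes the items above are in scope and that tokio and serde_json are available as dependencies.

use serde_json::json;

#[tokio::test]
async fn mock_provider_replies() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Hold the provider as a trait object, as a caller of this module would.
    let provider: Box<dyn LLMProvider> = Box::new(MockLLMProvider::new());
    let reply = provider.generate("ping", &json!({})).await?;
    assert!(reply.contains("ping")); // MockLLMProvider echoes the prompt back
    Ok(())
}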

View file

@@ -1,29 +0,0 @@
use actix_web::{web, HttpResponse, Result};
use serde_json::json;
use crate::shared::state::AppState;
use crate::shared::utils::azure_from_config;
pub async fn health() -> Result<HttpResponse> {
Ok(HttpResponse::Ok().json(json!({"status": "healthy"})))
}
pub async fn chat_completions_local(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
Ok(HttpResponse::NotImplemented().json(json!({"error": "Local LLM not implemented"})))
}
pub async fn embeddings_local(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
Ok(HttpResponse::NotImplemented().json(json!({"error": "Local embeddings not implemented"})))
}
pub async fn generic_chat_completions(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
Ok(HttpResponse::NotImplemented().json(json!({"error": "Generic chat not implemented"})))
}

src/llm_legacy/llm_azure.rs (new file, 116 lines)
View file

@@ -0,0 +1,116 @@
use log::info;
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv;
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
stream: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
message: ChatMessage,
finish_reason: String,
}
#[post("/azure/v1/chat/completions")]
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
// Always log raw POST data
if let Ok(body_str) = std::str::from_utf8(&body) {
info!("POST Data: {}", body_str);
} else {
info!("POST Data (binary): {:?}", body);
}
dotenv().ok();
// Environment variables
let azure_endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
let azure_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let deployment_name = env::var("AI_LLM_MODEL")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
// Construct Azure OpenAI URL
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
azure_endpoint, deployment_name
);
// Forward headers
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"api-key",
reqwest::header::HeaderValue::from_str(&azure_key)
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
);
headers.insert(
"Content-Type",
reqwest::header::HeaderValue::from_static("application/json"),
);
let body_str = std::str::from_utf8(&body).unwrap_or("");
info!("Original POST Data: {}", body_str);
// Remove the problematic params
let re =
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
let cleaned = re.replace_all(body_str, "");
let cleaned_body = web::Bytes::from(cleaned.to_string());
info!("Cleaned POST Data: {}", cleaned);
// Send request to Azure
let client = Client::new();
let response = client
.post(&url)
.headers(headers)
.body(cleaned_body)
.send()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Handle response based on status
let status = response.status();
let raw_response = response
.text()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Log the raw response
info!("Raw Azure response: {}", raw_response);
if status.is_success() {
Ok(HttpResponse::Ok().body(raw_response))
} else {
// Handle error responses properly
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status).body(raw_response))
}
}
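For context (not part of the commit): a minimal client-side sketch of calling the proxy above with an OpenAI-style body. The host and port are assumptions; the model field is ignored by the proxy, which resolves the Azure deployment from AI_LLM_MODEL, and max_completion_tokens/parallel_tool_calls are stripped by the regex before forwarding.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "model": "ignored-by-proxy",
        "messages": [{ "role": "user", "content": "Hello" }],
        "stream": false,
        "max_completion_tokens": 128,  // removed by the proxy before forwarding
        "parallel_tool_calls": false   // removed by the proxy before forwarding
    });
    let resp = reqwest::Client::new()
        .post("http://localhost:8080/azure/v1/chat/completions") // host/port assumed
        .json(&body)
        .send()
        .await?;
    let status = resp.status();
    println!("status: {}", status);
    println!("body: {}", resp.text().await?);
    Ok(())
}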

View file

@@ -1,11 +1,13 @@
-use log::{error, info};
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv;
+use log::{error, info};
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
+// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
@@ -35,17 +37,20 @@ struct Choice {
}
fn clean_request_body(body: &str) -> String {
+// Remove problematic parameters that might not be supported by all providers
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
re.replace_all(body, "").to_string()
}
#[post("/v1/chat/completions")]
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
+// Log raw POST data
let body_str = std::str::from_utf8(&body).unwrap_or_default();
info!("Original POST Data: {}", body_str);
dotenv().ok();
+// Get environment variables
let api_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let model = env::var("AI_LLM_MODEL")
@@ -53,9 +58,11 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
let endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
+// Parse and modify the request body
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
+// Add model parameter
if let Some(obj) = json_value.as_object_mut() {
obj.insert("model".to_string(), serde_json::Value::String(model));
}
@@ -65,6 +72,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Modified POST Data: {}", modified_body_str);
+// Set up headers
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
@@ -76,6 +84,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
reqwest::header::HeaderValue::from_static("application/json"),
);
+// Send request to the AI provider
let client = Client::new();
let response = client
.post(&endpoint)
@@ -85,6 +94,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
+// Handle response
let status = response.status();
let raw_response = response
.text()
@@ -94,6 +104,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Provider response status: {}", status);
info!("Provider response body: {}", raw_response);
+// Convert response to OpenAI format if successful
if status.is_success() {
match convert_to_openai_format(&raw_response) {
Ok(openai_response) => Ok(HttpResponse::Ok()
@@ -101,12 +112,14 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.body(openai_response)),
Err(e) => {
error!("Failed to convert response format: {}", e);
+// Return the original response if conversion fails
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(raw_response))
}
}
} else {
+// Return error as-is
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
@@ -116,6 +129,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
}
}
+/// Converts provider response to OpenAI-compatible format
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
#[derive(serde::Deserialize)]
struct ProviderChoice {
@@ -177,8 +191,10 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
total_tokens: u32,
}
+// Parse the provider response
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
+// Extract content from the first choice
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
let content = first_choice.message.content.clone();
let role = first_choice
@@ -187,6 +203,7 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
.clone()
.unwrap_or_else(|| "assistant".to_string());
+// Calculate token usage
let usage = provider.usage.unwrap_or_default();
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
let completion_tokens = usage

View file

@@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
use std::env;
use tokio::time::{sleep, Duration};
+// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
@@ -34,6 +35,7 @@ struct Choice {
finish_reason: String,
}
+// Llama.cpp server request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppRequest {
prompt: String,
@@ -51,7 +53,8 @@ struct LlamaCppResponse {
generation_settings: Option<serde_json::Value>,
}
-pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
+{
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
if llm_local.to_lowercase() != "true" {
@@ -59,6 +62,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(());
}
+// Get configuration from environment variables
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
let embedding_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@@ -73,6 +77,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!(" LLM Model: {}", llm_model_path);
info!(" Embedding Model: {}", embedding_model_path);
+// Check if servers are already running
let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await;
@@ -81,6 +86,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(());
}
+// Start servers that aren't running
let mut tasks = vec![];
if !llm_running && !llm_model_path.is_empty() {
@@ -105,17 +111,19 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
}
+// Wait for all server startup tasks
for task in tasks {
task.await??;
}
+// Wait for servers to be ready with verbose logging
info!("⏳ Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model_path.is_empty();
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
let mut attempts = 0;
-let max_attempts = 60;
+let max_attempts = 60; // 2 minutes total
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
sleep(Duration::from_secs(2)).await;
@@ -180,6 +188,8 @@ async fn start_llm_server(
std::env::set_var("OMP_PLACES", "cores");
std::env::set_var("OMP_PROC_BIND", "close");
+// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
@@ -215,6 +225,7 @@ async fn is_server_running(url: &str) -> bool {
}
}
+// Convert OpenAI chat messages to a single prompt
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
@@ -239,28 +250,32 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
prompt
}
+// Proxy endpoint
#[post("/local/v1/chat/completions")]
pub async fn chat_completions_local(
req_body: web::Json<ChatCompletionRequest>,
_req: HttpRequest,
) -> Result<HttpResponse> {
-dotenv().ok();
+dotenv().ok().unwrap();
+// Get llama.cpp server URL
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
+// Convert OpenAI format to llama.cpp format
let prompt = messages_to_prompt(&req_body.messages);
let llama_request = LlamaCppRequest {
prompt,
-n_predict: Some(500),
+n_predict: Some(500), // Adjust as needed
temperature: Some(0.7),
top_k: Some(40),
top_p: Some(0.9),
stream: req_body.stream,
};
+// Send request to llama.cpp server
let client = Client::builder()
-.timeout(Duration::from_secs(120))
+.timeout(Duration::from_secs(120)) // 2 minute timeout
.build()
.map_err(|e| {
error!("Error creating HTTP client: {}", e);
@@ -286,6 +301,7 @@ pub async fn chat_completions_local(
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
})?;
+// Convert llama.cpp response to OpenAI format
let openai_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
@@ -328,6 +344,7 @@ pub async fn chat_completions_local(
}
}
+// OpenAI Embedding Request - Modified to handle both string and array inputs
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
#[serde(deserialize_with = "deserialize_input")]
@@ -337,6 +354,7 @@ pub struct EmbeddingRequest {
pub _encoding_format: Option<String>,
}
+// Custom deserializer to handle both string and array inputs
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
@@ -382,6 +400,7 @@ where
deserializer.deserialize_any(InputVisitor)
}
+// OpenAI Embedding Response
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
pub object: String,
@@ -403,17 +422,20 @@ pub struct Usage {
pub total_tokens: u32,
}
+// Llama.cpp Embedding Request
#[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest {
pub content: String,
}
+// FIXED: Handle the stupid nested array format
#[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem {
pub index: usize,
-pub embedding: Vec<Vec<f32>>,
+pub embedding: Vec<Vec<f32>>, // This is the up part - embedding is an array of arrays
}
+// Proxy endpoint for embeddings
#[post("/v1/embeddings")]
pub async fn embeddings_local(
req_body: web::Json<EmbeddingRequest>,
@@ -421,6 +443,7 @@ pub async fn embeddings_local(
) -> Result<HttpResponse> {
dotenv().ok();
+// Get llama.cpp server URL
let llama_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@@ -432,6 +455,7 @@ pub async fn embeddings_local(
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?;
+// Process each input text and get embeddings
let mut embeddings_data = Vec::new();
let mut total_tokens = 0;
@@ -456,11 +480,13 @@ pub async fn embeddings_local(
let status = response.status();
if status.is_success() {
+// First, get the raw response text for debugging
let raw_response = response.text().await.map_err(|e| {
error!("Error reading response text: {}", e);
actix_web::error::ErrorInternalServerError("Failed to read response")
})?;
+// Parse the response as a vector of items with nested arrays
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
serde_json::from_str(&raw_response).map_err(|e| {
error!("Error parsing llama.cpp embedding response: {}", e);
@@ -470,13 +496,17 @@ pub async fn embeddings_local(
)
})?;
+// Extract the embedding from the nested array bullshit
if let Some(item) = llama_response.get(0) {
+// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
+// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
let flattened_embedding = if !item.embedding.is_empty() {
-item.embedding[0].clone()
+item.embedding[0].clone() // Take the first (and probably only) inner array
} else {
-vec![]
+vec![] // Empty if no embedding data
};
+// Estimate token count
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
total_tokens += estimated_tokens;
@@ -514,6 +544,7 @@ pub async fn embeddings_local(
}
}
+// Build OpenAI-compatible response
let openai_response = EmbeddingResponse {
object: "list".to_string(),
data: embeddings_data,
@@ -527,6 +558,7 @@ pub async fn embeddings_local(
Ok(HttpResponse::Ok().json(openai_response))
}
+// Health check endpoint
#[actix_web::get("/health")]
pub async fn health() -> Result<HttpResponse> {
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
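For context (not part of the commit): the deserialize_input helper above accepts either a single string or an array of strings for the input field of EmbeddingRequest. A sketch of the two equivalent request shapes (other fields of the struct are omitted, since it is only partially shown in the diff):

use serde_json::json;

fn example_embedding_inputs() -> (serde_json::Value, serde_json::Value) {
    // Both shapes deserialize into the same Vec<String> via deserialize_input.
    let single = json!({ "input": "hello world" });
    let batch = json!({ "input": ["hello", "world"] });
    (single, batch)
}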

View file

@@ -1,3 +1,3 @@
-pub mod llm;
+pub mod llm_azure;
pub mod llm_generic;
pub mod llm_local;

View file

@@ -1,4 +1,5 @@
use actix_cors::Cors;
+use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use dotenv::dotenv;
use log::info;
@@ -23,23 +24,30 @@ mod tools;
mod web_automation;
mod whatsapp;
+use crate::automation::AutomationService;
use crate::bot::{
create_session, get_session_history, get_sessions, index, set_mode_handler, static_files,
voice_start, voice_stop, websocket_handler, whatsapp_webhook, whatsapp_webhook_verify,
};
use crate::channels::{VoiceAdapter, WebChannelAdapter};
use crate::config::AppConfig;
-use crate::email::send_email;
+use crate::email::{
+get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
+};
use crate::file::{list_file, upload_file};
+use crate::llm_legacy::llm_generic::generic_chat_completions;
+use crate::llm_legacy::llm_local::{
+chat_completions_local, embeddings_local, ensure_llama_servers_running,
+};
use crate::shared::AppState;
use crate::whatsapp::WhatsAppAdapter;
#[actix_web::main]
async fn main() -> std::io::Result<()> {
dotenv().ok();
-env_logger::init();
+env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
-info!("Starting BotServer...");
+info!("Starting General Bots 6.0...");
let config = AppConfig::from_env();
@@ -59,17 +67,17 @@ async fn main() -> std::io::Result<()> {
};
// Optional custom database pool
-let db_custom_pool: Option<sqlx::Pool<sqlx::Postgres>> =
-match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await {
+let db_custom_pool = match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await
+{
Ok(pool) => {
info!("Connected to custom database");
Some(pool)
}
Err(e) => {
log::warn!("Failed to connect to custom database: {}", e);
None
}
};
// Optional Redis client
let redis_client = match redis::Client::open("redis://127.0.0.1/") {
@@ -83,9 +91,28 @@ async fn main() -> std::io::Result<()> {
}
};
-// Placeholder for MinIO (not yet implemented)
+// Initialize MinIO client
-let minio_client = None;
+let minio_client = file::init_minio(&config)
+.await
+.expect("Failed to initialize Minio");
+// Initialize browser pool
+let browser_pool = Arc::new(web_automation::BrowserPool::new(
+"chrome".to_string(),
+2,
+"headless".to_string(),
+));
+// Initialize LLM servers
+ensure_llama_servers_running()
+.await
+.expect("Failed to initialize LLM local server.");
+web_automation::initialize_browser_pool()
+.await
+.expect("Failed to initialize browser pool");
+// Initialize services from new architecture
let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone());
let session_manager = session::SessionManager::new(db_pool.clone(), redis_client.clone());
@@ -110,20 +137,13 @@ async fn main() -> std::io::Result<()> {
let tool_api = Arc::new(tools::ToolApi::new());
-// Browser pool constructed with the required three arguments.
+// Create unified app state
-// Adjust the strings as needed for your environment.
-let browser_pool = Arc::new(web_automation::BrowserPool::new(
-"chrome".to_string(),
-2,
-"headless".to_string(),
-));
let app_state = AppState {
-minio_client,
+minio_client: Some(minio_client),
config: Some(config.clone()),
db: Some(db_pool.clone()),
-db_custom: db_custom_pool,
+db_custom: db_custom_pool.clone(),
-browser_pool,
+browser_pool: browser_pool.clone(),
orchestrator: Arc::new(orchestrator),
web_adapter,
voice_adapter,
@@ -131,6 +151,11 @@ async fn main() -> std::io::Result<()> {
tool_api,
};
+// Start automation service in background
+let automation_state = app_state.clone();
+let automation = AutomationService::new(automation_state, "src/prompts");
+let _automation_handle = automation.spawn();
info!(
"Starting server on {}:{}",
config.server.host, config.server.port
@@ -145,7 +170,22 @@ async fn main() -> std::io::Result<()> {
App::new()
.wrap(cors)
+.wrap(Logger::default())
+.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
.app_data(web::Data::new(app_state.clone()))
+// Legacy services
+.service(upload_file)
+.service(list_file)
+.service(save_click)
+.service(get_emails)
+.service(list_emails)
+.service(send_email)
+.service(chat_completions_local)
+.service(save_draft)
+.service(generic_chat_completions)
+.service(embeddings_local)
+.service(get_latest_email_from)
+// New bot services
.service(index)
.service(static_files)
.service(websocket_handler)
@@ -157,9 +197,6 @@ async fn main() -> std::io::Result<()> {
.service(get_sessions)
.service(get_session_history)
.service(set_mode_handler)
-.service(send_email)
-.service(upload_file)
-.service(list_file)
})
.bind((config.server.host.clone(), config.server.port))?
.run()

View file

@@ -1,4 +1,3 @@
-use chrono::Utc;
use langchain_rust::llm::AzureConfig;
use log::{debug, warn};
use rhai::{Array, Dynamic};

View file

@@ -1,7 +1,7 @@
<!doctype html>
<html>
<head>
-<title>General Bots - ChatGPT Clone</title>
+<title>General Bots</title>
<style>
* {
margin: 0;