- Remove warnings and restore old code.

Rodrigo Rodriguez (Pragmatismo) 2025-10-07 07:16:03 -03:00
parent 76dc51e980
commit 8a9cd104d6
11 changed files with 522 additions and 339 deletions

View file

@ -1,15 +1,17 @@
use async_trait::async_trait;
use chrono::Utc;
use log::info;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use crate::shared::{BotResponse, UserMessage};
use crate::shared::BotResponse;
#[async_trait]
pub trait ChannelAdapter: Send + Sync {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn send_message(
&self,
response: BotResponse,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
pub struct WebChannelAdapter {
@ -34,7 +36,10 @@ impl WebChannelAdapter {
#[async_trait]
impl ChannelAdapter for WebChannelAdapter {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async fn send_message(
&self,
response: BotResponse,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let connections = self.connections.lock().await;
if let Some(tx) = connections.get(&response.session_id) {
tx.send(response).await?;
@ -67,11 +72,17 @@ impl VoiceAdapter {
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
info!("Starting voice session for user: {} with session: {}", user_id, session_id);
info!(
"Starting voice session for user: {} with session: {}",
user_id, session_id
);
let token = format!("mock_token_{}_{}", session_id, user_id);
self.rooms.lock().await.insert(session_id.to_string(), token.clone());
self.rooms
.lock()
.await
.insert(session_id.to_string(), token.clone());
Ok(token)
}
@ -99,7 +110,10 @@ impl VoiceAdapter {
#[async_trait]
impl ChannelAdapter for VoiceAdapter {
async fn send_message(&self, response: BotResponse) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async fn send_message(
&self,
response: BotResponse,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Sending voice response to: {}", response.user_id);
self.send_voice_response(&response.session_id, &response.content)
.await
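
For reference, the reformatted ChannelAdapter signature stays source-compatible with other implementors. Below is a minimal sketch of one; the LogChannelAdapter name and its println! body are illustrative only (not part of this commit), and the BotResponse field names follow their use in the adapters above.
use async_trait::async_trait;

use crate::channels::ChannelAdapter;
use crate::shared::BotResponse;

// Illustrative only: logs the outgoing message instead of delivering it.
pub struct LogChannelAdapter;

#[async_trait]
impl ChannelAdapter for LogChannelAdapter {
    async fn send_message(
        &self,
        response: BotResponse,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // BotResponse fields (session_id, user_id, content) as used above.
        println!(
            "[{}] -> {}: {}",
            response.session_id, response.user_id, response.content
        );
        Ok(())
    }
}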

View file

@ -1,257 +0,0 @@
use async_trait::async_trait;
use futures::StreamExt;
use langchain_rust::{
language_models::llm::LLM,
llm::{claude::Claude, openai::OpenAI},
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::tools::ToolManager;
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(
&self,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
tool_manager: Arc<ToolManager>,
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
}
pub struct OpenAIClient {
client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
}
impl OpenAIClient {
pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
Self { client }
}
}
#[async_trait]
impl LLMProvider for OpenAIClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct AnthropicClient {
client: Claude,
}
impl AnthropicClient {
pub fn new(api_key: String) -> Self {
let client = Claude::default().with_api_key(api_key);
Self { client }
}
}
#[async_trait]
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct MockLLMProvider;
impl MockLLMProvider {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(format!("Mock response to: {}", prompt))
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = format!("Mock stream response to: {}", prompt);
for word in response.split_whitespace() {
let _ = tx.send(format!("{} ", word)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_list = if available_tools.is_empty() {
"no tools available".to_string()
} else {
available_tools.join(", ")
};
Ok(format!(
"Mock response with tools [{}] to: {}",
tools_list, prompt
))
}
}

View file

@ -1,3 +1,257 @@
pub mod llm_provider;
use async_trait::async_trait;
use futures::StreamExt;
use langchain_rust::{
language_models::llm::LLM,
llm::{claude::Claude, openai::OpenAI},
};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::mpsc;
pub use llm_provider::*;
use crate::tools::ToolManager;
#[async_trait]
pub trait LLMProvider: Send + Sync {
async fn generate(
&self,
prompt: &str,
config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn generate_with_tools(
&self,
prompt: &str,
config: &Value,
available_tools: &[String],
tool_manager: Arc<ToolManager>,
session_id: &str,
user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
}
pub struct OpenAIClient {
client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>,
}
impl OpenAIClient {
pub fn new(client: OpenAI<langchain_rust::llm::openai::OpenAIConfig>) -> Self {
Self { client }
}
}
#[async_trait]
impl LLMProvider for OpenAIClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct AnthropicClient {
client: Claude,
}
impl AnthropicClient {
pub fn new(api_key: String) -> Self {
let client = Claude::default().with_api_key(api_key);
Self { client }
}
}
#[async_trait]
impl LLMProvider for AnthropicClient {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let result = self
.client
.invoke(prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let messages = vec![langchain_rust::schemas::Message::new_human_message(prompt)];
let mut stream = self
.client
.stream(&messages)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
while let Some(result) = stream.next().await {
match result {
Ok(chunk) => {
let content = chunk.content;
if !content.is_empty() {
let _ = tx.send(content.to_string()).await;
}
}
Err(e) => {
eprintln!("Stream error: {}", e);
}
}
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_info = if available_tools.is_empty() {
String::new()
} else {
format!("\n\nAvailable tools: {}. You can suggest using these tools if they would help answer the user's question.", available_tools.join(", "))
};
let enhanced_prompt = format!("{}{}", prompt, tools_info);
let result = self
.client
.invoke(&enhanced_prompt)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
Ok(result)
}
}
pub struct MockLLMProvider;
impl MockLLMProvider {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
async fn generate(
&self,
prompt: &str,
_config: &Value,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(format!("Mock response to: {}", prompt))
}
async fn generate_stream(
&self,
prompt: &str,
_config: &Value,
tx: mpsc::Sender<String>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let response = format!("Mock stream response to: {}", prompt);
for word in response.split_whitespace() {
let _ = tx.send(format!("{} ", word)).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
Ok(())
}
async fn generate_with_tools(
&self,
prompt: &str,
_config: &Value,
available_tools: &[String],
_tool_manager: Arc<ToolManager>,
_session_id: &str,
_user_id: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let tools_list = if available_tools.is_empty() {
"no tools available".to_string()
} else {
available_tools.join(", ")
};
Ok(format!(
"Mock response with tools [{}] to: {}",
tools_list, prompt
))
}
}
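
For reference, a minimal caller sketch against the restored LLMProvider trait. It assumes the trait and MockLLMProvider are re-exported from a crate::llm module (as pub use llm_provider::* suggests); the demo function, prompts, and channel capacity are illustrative, not part of this commit.
use std::sync::Arc;

use serde_json::json;
use tokio::sync::mpsc;

use crate::llm::{LLMProvider, MockLLMProvider};

// Illustrative caller: any provider behind Arc<dyn LLMProvider> can be swapped
// for OpenAIClient or AnthropicClient without touching this code.
async fn demo() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let provider: Arc<dyn LLMProvider> = Arc::new(MockLLMProvider::new());

    // One-shot generation; the mock ignores the config value.
    let reply = provider.generate("hello", &json!({})).await?;
    println!("{}", reply);

    // Streaming: chunks arrive over an mpsc channel as they are produced.
    let (tx, mut rx) = mpsc::channel(32);
    let streaming = Arc::clone(&provider);
    tokio::spawn(async move {
        let _ = streaming.generate_stream("hello again", &json!({}), tx).await;
    });
    while let Some(chunk) = rx.recv().await {
        print!("{}", chunk);
    }
    Ok(())
}
Because all three providers implement the same trait, swapping the mock for a real client is a one-line change at the construction site.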

View file

@ -1,29 +0,0 @@
use actix_web::{web, HttpResponse, Result};
use serde_json::json;
use crate::shared::state::AppState;
use crate::shared::utils::azure_from_config;
pub async fn health() -> Result<HttpResponse> {
Ok(HttpResponse::Ok().json(json!({"status": "healthy"})))
}
pub async fn chat_completions_local(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
Ok(HttpResponse::NotImplemented().json(json!({"error": "Local LLM not implemented"})))
}
pub async fn embeddings_local(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
Ok(HttpResponse::NotImplemented().json(json!({"error": "Local embeddings not implemented"})))
}
pub async fn generic_chat_completions(
_data: web::Data<AppState>,
_payload: web::Json<serde_json::Value>,
) -> Result<HttpResponse> {
Ok(HttpResponse::NotImplemented().json(json!({"error": "Generic chat not implemented"})))
}

src/llm_legacy/llm_azure.rs (new file, 116 lines)
View file

@ -0,0 +1,116 @@
use log::info;
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv;
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
content: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
stream: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionResponse {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<Choice>,
}
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
message: ChatMessage,
finish_reason: String,
}
#[post("/azure/v1/chat/completions")]
async fn chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
// Always log raw POST data
if let Ok(body_str) = std::str::from_utf8(&body) {
info!("POST Data: {}", body_str);
} else {
info!("POST Data (binary): {:?}", body);
}
dotenv().ok();
// Environment variables
let azure_endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
let azure_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let deployment_name = env::var("AI_LLM_MODEL")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_LLM_MODEL not set."))?;
// Construct Azure OpenAI URL
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version=2025-01-01-preview",
azure_endpoint, deployment_name
);
// Forward headers
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"api-key",
reqwest::header::HeaderValue::from_str(&azure_key)
.map_err(|_| actix_web::error::ErrorInternalServerError("Invalid Azure key"))?,
);
headers.insert(
"Content-Type",
reqwest::header::HeaderValue::from_static("application/json"),
);
let body_str = std::str::from_utf8(&body).unwrap_or("");
info!("Original POST Data: {}", body_str);
// Remove the problematic params
let re =
Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls)"\s*:\s*[^,}]*"#).unwrap();
let cleaned = re.replace_all(body_str, "");
let cleaned_body = web::Bytes::from(cleaned.to_string());
info!("Cleaned POST Data: {}", cleaned);
// Send request to Azure
let client = Client::new();
let response = client
.post(&url)
.headers(headers)
.body(cleaned_body)
.send()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Handle response based on status
let status = response.status();
let raw_response = response
.text()
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Log the raw response
info!("Raw Azure response: {}", raw_response);
if status.is_success() {
Ok(HttpResponse::Ok().body(raw_response))
} else {
// Handle error responses properly
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
Ok(HttpResponse::build(actix_status).body(raw_response))
}
}
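
For reference, a hedged client-side sketch of calling the new proxy route. The localhost address, port, and model string are assumptions rather than values from this commit; the body follows the OpenAI-compatible shape shown above, and max_completion_tokens is stripped by the regex before forwarding.
use reqwest::Client;
use serde_json::json;

// Illustrative client call against the proxy route added above.
async fn call_azure_proxy() -> Result<String, reqwest::Error> {
    let body = json!({
        "model": "gpt-4o",            // Azure resolves the deployment from AI_LLM_MODEL
        "messages": [{ "role": "user", "content": "Hello" }],
        "stream": false,
        "max_completion_tokens": 256  // removed by the handler before forwarding
    });

    Client::new()
        .post("http://localhost:8080/azure/v1/chat/completions")
        .json(&body)
        .send()
        .await?
        .text()
        .await
}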

View file

@ -1,11 +1,13 @@
use log::{error, info};
use actix_web::{post, web, HttpRequest, HttpResponse, Result};
use dotenv::dotenv;
use log::{error, info};
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
@ -35,17 +37,20 @@ struct Choice {
}
fn clean_request_body(body: &str) -> String {
// Remove problematic parameters that might not be supported by all providers
let re = Regex::new(r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#).unwrap();
re.replace_all(body, "").to_string()
}
#[post("/v1/chat/completions")]
pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Result<HttpResponse> {
// Log raw POST data
let body_str = std::str::from_utf8(&body).unwrap_or_default();
info!("Original POST Data: {}", body_str);
dotenv().ok();
// Get environment variables
let api_key = env::var("AI_KEY")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_KEY not set."))?;
let model = env::var("AI_LLM_MODEL")
@ -53,9 +58,11 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
let endpoint = env::var("AI_ENDPOINT")
.map_err(|_| actix_web::error::ErrorInternalServerError("AI_ENDPOINT not set."))?;
// Parse and modify the request body
let mut json_value: serde_json::Value = serde_json::from_str(body_str)
.map_err(|_| actix_web::error::ErrorInternalServerError("Failed to parse JSON"))?;
// Add model parameter
if let Some(obj) = json_value.as_object_mut() {
obj.insert("model".to_string(), serde_json::Value::String(model));
}
@ -65,6 +72,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Modified POST Data: {}", modified_body_str);
// Set up headers
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
@ -76,6 +84,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
reqwest::header::HeaderValue::from_static("application/json"),
);
// Send request to the AI provider
let client = Client::new();
let response = client
.post(&endpoint)
@ -85,6 +94,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.await
.map_err(actix_web::error::ErrorInternalServerError)?;
// Handle response
let status = response.status();
let raw_response = response
.text()
@ -94,6 +104,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
info!("Provider response status: {}", status);
info!("Provider response body: {}", raw_response);
// Convert response to OpenAI format if successful
if status.is_success() {
match convert_to_openai_format(&raw_response) {
Ok(openai_response) => Ok(HttpResponse::Ok()
@ -101,12 +112,14 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
.body(openai_response)),
Err(e) => {
error!("Failed to convert response format: {}", e);
// Return the original response if conversion fails
Ok(HttpResponse::Ok()
.content_type("application/json")
.body(raw_response))
}
}
} else {
// Return error as-is
let actix_status = actix_web::http::StatusCode::from_u16(status.as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
@ -116,6 +129,7 @@ pub async fn generic_chat_completions(body: web::Bytes, _req: HttpRequest) -> Re
}
}
/// Converts provider response to OpenAI-compatible format
fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn std::error::Error>> {
#[derive(serde::Deserialize)]
struct ProviderChoice {
@ -177,8 +191,10 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
total_tokens: u32,
}
// Parse the provider response
let provider: ProviderResponse = serde_json::from_str(provider_response)?;
// Extract content from the first choice
let first_choice = provider.choices.get(0).ok_or("No choices in response")?;
let content = first_choice.message.content.clone();
let role = first_choice
@ -187,6 +203,7 @@ fn convert_to_openai_format(provider_response: &str) -> Result<String, Box<dyn s
.clone()
.unwrap_or_else(|| "assistant".to_string());
// Calculate token usage
let usage = provider.usage.unwrap_or_default();
let prompt_tokens = usage.prompt_tokens.unwrap_or(0);
let completion_tokens = usage

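For reference, a stand-alone sketch of what the clean_request_body pattern strips before the body is forwarded; the sample JSON is made up for illustration.
use regex::Regex;

// Stand-alone illustration of the same pattern clean_request_body uses.
fn main() {
    let re = Regex::new(
        r#","?\s*"(max_completion_tokens|parallel_tool_calls|top_p|frequency_penalty|presence_penalty)"\s*:\s*[^,}]*"#,
    )
    .unwrap();
    let raw = r#"{"model":"any","messages":[{"role":"user","content":"hi"}],"top_p":0.9}"#;
    // Prints: {"model":"any","messages":[{"role":"user","content":"hi"}]}
    println!("{}", re.replace_all(raw, ""));
}
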
View file

@ -6,6 +6,7 @@ use serde::{Deserialize, Serialize};
use std::env;
use tokio::time::{sleep, Duration};
// OpenAI-compatible request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
role: String,
@ -34,6 +35,7 @@ struct Choice {
finish_reason: String,
}
// Llama.cpp server request/response structures
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppRequest {
prompt: String,
@ -51,7 +53,8 @@ struct LlamaCppResponse {
generation_settings: Option<serde_json::Value>,
}
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Error + Send + Sync>>
{
let llm_local = env::var("LLM_LOCAL").unwrap_or_else(|_| "false".to_string());
if llm_local.to_lowercase() != "true" {
@ -59,6 +62,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(());
}
// Get configuration from environment variables
let llm_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
let embedding_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@ -73,6 +77,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!(" LLM Model: {}", llm_model_path);
info!(" Embedding Model: {}", embedding_model_path);
// Check if servers are already running
let llm_running = is_server_running(&llm_url).await;
let embedding_running = is_server_running(&embedding_url).await;
@ -81,6 +86,7 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
return Ok(());
}
// Start servers that aren't running
let mut tasks = vec![];
if !llm_running && !llm_model_path.is_empty() {
@ -105,17 +111,19 @@ pub async fn ensure_llama_servers_running() -> Result<(), Box<dyn std::error::Er
info!("⚠️ EMBEDDING_MODEL_PATH not set, skipping Embedding server");
}
// Wait for all server startup tasks
for task in tasks {
task.await??;
}
// Wait for servers to be ready with verbose logging
info!("⏳ Waiting for servers to become ready...");
let mut llm_ready = llm_running || llm_model_path.is_empty();
let mut embedding_ready = embedding_running || embedding_model_path.is_empty();
let mut attempts = 0;
let max_attempts = 60;
let max_attempts = 60; // 2 minutes total
while attempts < max_attempts && (!llm_ready || !embedding_ready) {
sleep(Duration::from_secs(2)).await;
@ -180,6 +188,8 @@ async fn start_llm_server(
std::env::set_var("OMP_PLACES", "cores");
std::env::set_var("OMP_PROC_BIND", "close");
// "cd {} && numactl --interleave=all ./llama-server -m {} --host 0.0.0.0 --port {} --threads 20 --threads-batch 40 --temp 0.7 --parallel 1 --repeat-penalty 1.1 --ctx-size 8192 --batch-size 8192 -n 4096 --mlock --no-mmap --flash-attn --no-kv-offload --no-mmap &",
let mut cmd = tokio::process::Command::new("sh");
cmd.arg("-c").arg(format!(
"cd {} && ./llama-server -m {} --host 0.0.0.0 --port {} --n-gpu-layers 99 &",
@ -215,6 +225,7 @@ async fn is_server_running(url: &str) -> bool {
}
}
// Convert OpenAI chat messages to a single prompt
fn messages_to_prompt(messages: &[ChatMessage]) -> String {
let mut prompt = String::new();
@ -239,28 +250,32 @@ fn messages_to_prompt(messages: &[ChatMessage]) -> String {
prompt
}
// Proxy endpoint
#[post("/local/v1/chat/completions")]
pub async fn chat_completions_local(
req_body: web::Json<ChatCompletionRequest>,
_req: HttpRequest,
) -> Result<HttpResponse> {
dotenv().ok();
// Get llama.cpp server URL
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
// Convert OpenAI format to llama.cpp format
let prompt = messages_to_prompt(&req_body.messages);
let llama_request = LlamaCppRequest {
prompt,
n_predict: Some(500),
n_predict: Some(500), // Adjust as needed
temperature: Some(0.7),
top_k: Some(40),
top_p: Some(0.9),
stream: req_body.stream,
};
// Send request to llama.cpp server
let client = Client::builder()
.timeout(Duration::from_secs(120))
.timeout(Duration::from_secs(120)) // 2 minute timeout
.build()
.map_err(|e| {
error!("Error creating HTTP client: {}", e);
@ -286,6 +301,7 @@ pub async fn chat_completions_local(
actix_web::error::ErrorInternalServerError("Failed to parse llama.cpp response")
})?;
// Convert llama.cpp response to OpenAI format
let openai_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
@ -328,6 +344,7 @@ pub async fn chat_completions_local(
}
}
// OpenAI Embedding Request - Modified to handle both string and array inputs
#[derive(Debug, Deserialize)]
pub struct EmbeddingRequest {
#[serde(deserialize_with = "deserialize_input")]
@ -337,6 +354,7 @@ pub struct EmbeddingRequest {
pub _encoding_format: Option<String>,
}
// Custom deserializer to handle both string and array inputs
fn deserialize_input<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
@ -382,6 +400,7 @@ where
deserializer.deserialize_any(InputVisitor)
}
// OpenAI Embedding Response
#[derive(Debug, Serialize)]
pub struct EmbeddingResponse {
pub object: String,
@ -403,17 +422,20 @@ pub struct Usage {
pub total_tokens: u32,
}
// Llama.cpp Embedding Request
#[derive(Debug, Serialize)]
struct LlamaCppEmbeddingRequest {
pub content: String,
}
// FIXED: handle the nested-array format llama.cpp returns for embeddings
#[derive(Debug, Deserialize)]
struct LlamaCppEmbeddingResponseItem {
pub index: usize,
pub embedding: Vec<Vec<f32>>,
pub embedding: Vec<Vec<f32>>, // The tricky part: embedding is an array of arrays
}
// Proxy endpoint for embeddings
#[post("/v1/embeddings")]
pub async fn embeddings_local(
req_body: web::Json<EmbeddingRequest>,
@ -421,6 +443,7 @@ pub async fn embeddings_local(
) -> Result<HttpResponse> {
dotenv().ok();
// Get llama.cpp server URL
let llama_url =
env::var("EMBEDDING_URL").unwrap_or_else(|_| "http://localhost:8082".to_string());
@ -432,6 +455,7 @@ pub async fn embeddings_local(
actix_web::error::ErrorInternalServerError("Failed to create HTTP client")
})?;
// Process each input text and get embeddings
let mut embeddings_data = Vec::new();
let mut total_tokens = 0;
@ -456,11 +480,13 @@ pub async fn embeddings_local(
let status = response.status();
if status.is_success() {
// First, get the raw response text for debugging
let raw_response = response.text().await.map_err(|e| {
error!("Error reading response text: {}", e);
actix_web::error::ErrorInternalServerError("Failed to read response")
})?;
// Parse the response as a vector of items with nested arrays
let llama_response: Vec<LlamaCppEmbeddingResponseItem> =
serde_json::from_str(&raw_response).map_err(|e| {
error!("Error parsing llama.cpp embedding response: {}", e);
@ -470,13 +496,17 @@ pub async fn embeddings_local(
)
})?;
// Extract the embedding from the nested-array response
if let Some(item) = llama_response.get(0) {
// The embedding field contains Vec<Vec<f32>>, so we need to flatten it
// If it's [[0.1, 0.2, 0.3]], we want [0.1, 0.2, 0.3]
let flattened_embedding = if !item.embedding.is_empty() {
item.embedding[0].clone()
item.embedding[0].clone() // Take the first (and probably only) inner array
} else {
vec![]
vec![] // Empty if no embedding data
};
// Estimate token count
let estimated_tokens = (input_text.len() as f32 / 4.0).ceil() as u32;
total_tokens += estimated_tokens;
@ -514,6 +544,7 @@ pub async fn embeddings_local(
}
}
// Build OpenAI-compatible response
let openai_response = EmbeddingResponse {
object: "list".to_string(),
data: embeddings_data,
@ -527,6 +558,7 @@ pub async fn embeddings_local(
Ok(HttpResponse::Ok().json(openai_response))
}
// Health check endpoint
#[actix_web::get("/health")]
pub async fn health() -> Result<HttpResponse> {
let llama_url = env::var("LLM_URL").unwrap_or_else(|_| "http://localhost:8081".to_string());
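
For reference, a stand-alone sketch of the "string or array" input shape the embeddings endpoint accepts. The handler above uses a custom visitor; this toy version reaches the same Vec<String> with serde's untagged enum, purely for illustration.
use serde::Deserialize;
use serde_json::json;

// Accepts either "text" or ["text", "more text"], like the endpoint above.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EmbeddingInput {
    Single(String),
    Many(Vec<String>),
}

fn main() {
    let one: EmbeddingInput = serde_json::from_value(json!("hello world")).unwrap();
    let many: EmbeddingInput = serde_json::from_value(json!(["a", "b"])).unwrap();

    // Normalize both shapes to Vec<String>, as the handler does.
    let normalize = |input: EmbeddingInput| -> Vec<String> {
        match input {
            EmbeddingInput::Single(s) => vec![s],
            EmbeddingInput::Many(v) => v,
        }
    };
    assert_eq!(normalize(one), vec!["hello world".to_string()]);
    assert_eq!(normalize(many), vec!["a".to_string(), "b".to_string()]);
}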

View file

@ -1,3 +1,3 @@
pub mod llm;
pub mod llm_azure;
pub mod llm_generic;
pub mod llm_local;

View file

@ -1,4 +1,5 @@
use actix_cors::Cors;
use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
use dotenv::dotenv;
use log::info;
@ -23,23 +24,30 @@ mod tools;
mod web_automation;
mod whatsapp;
use crate::automation::AutomationService;
use crate::bot::{
create_session, get_session_history, get_sessions, index, set_mode_handler, static_files,
voice_start, voice_stop, websocket_handler, whatsapp_webhook, whatsapp_webhook_verify,
};
use crate::channels::{VoiceAdapter, WebChannelAdapter};
use crate::config::AppConfig;
use crate::email::send_email;
use crate::email::{
get_emails, get_latest_email_from, list_emails, save_click, save_draft, send_email,
};
use crate::file::{list_file, upload_file};
use crate::llm_legacy::llm_generic::generic_chat_completions;
use crate::llm_legacy::llm_local::{
chat_completions_local, embeddings_local, ensure_llama_servers_running,
};
use crate::shared::AppState;
use crate::whatsapp::WhatsAppAdapter;
#[actix_web::main]
async fn main() -> std::io::Result<()> {
dotenv().ok();
env_logger::init();
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
info!("Starting BotServer...");
info!("Starting General Bots 6.0...");
let config = AppConfig::from_env();
@ -59,17 +67,17 @@ async fn main() -> std::io::Result<()> {
};
// Optional custom database pool
let db_custom_pool: Option<sqlx::Pool<sqlx::Postgres>> =
match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await {
Ok(pool) => {
info!("Connected to custom database");
Some(pool)
}
Err(e) => {
log::warn!("Failed to connect to custom database: {}", e);
None
}
};
let db_custom_pool = match sqlx::postgres::PgPool::connect(&config.database_custom_url()).await
{
Ok(pool) => {
info!("Connected to custom database");
Some(pool)
}
Err(e) => {
log::warn!("Failed to connect to custom database: {}", e);
None
}
};
// Optional Redis client
let redis_client = match redis::Client::open("redis://127.0.0.1/") {
@ -83,9 +91,28 @@ async fn main() -> std::io::Result<()> {
}
};
// Placeholder for MinIO (not yet implemented)
let minio_client = None;
// Initialize MinIO client
let minio_client = file::init_minio(&config)
.await
.expect("Failed to initialize Minio");
// Initialize browser pool
let browser_pool = Arc::new(web_automation::BrowserPool::new(
"chrome".to_string(),
2,
"headless".to_string(),
));
// Initialize LLM servers
ensure_llama_servers_running()
.await
.expect("Failed to initialize LLM local server.");
web_automation::initialize_browser_pool()
.await
.expect("Failed to initialize browser pool");
// Initialize services from new architecture
let auth_service = auth::AuthService::new(db_pool.clone(), redis_client.clone());
let session_manager = session::SessionManager::new(db_pool.clone(), redis_client.clone());
@ -110,20 +137,13 @@ async fn main() -> std::io::Result<()> {
let tool_api = Arc::new(tools::ToolApi::new());
// Browser pool constructed with the required three arguments.
// Adjust the strings as needed for your environment.
let browser_pool = Arc::new(web_automation::BrowserPool::new(
"chrome".to_string(),
2,
"headless".to_string(),
));
// Create unified app state
let app_state = AppState {
minio_client,
minio_client: Some(minio_client),
config: Some(config.clone()),
db: Some(db_pool.clone()),
db_custom: db_custom_pool,
browser_pool,
db_custom: db_custom_pool.clone(),
browser_pool: browser_pool.clone(),
orchestrator: Arc::new(orchestrator),
web_adapter,
voice_adapter,
@ -131,6 +151,11 @@ async fn main() -> std::io::Result<()> {
tool_api,
};
// Start automation service in background
let automation_state = app_state.clone();
let automation = AutomationService::new(automation_state, "src/prompts");
let _automation_handle = automation.spawn();
info!(
"Starting server on {}:{}",
config.server.host, config.server.port
@ -145,7 +170,22 @@ async fn main() -> std::io::Result<()> {
App::new()
.wrap(cors)
.wrap(Logger::default())
.wrap(Logger::new("HTTP REQUEST: %a %{User-Agent}i"))
.app_data(web::Data::new(app_state.clone()))
// Legacy services
.service(upload_file)
.service(list_file)
.service(save_click)
.service(get_emails)
.service(list_emails)
.service(send_email)
.service(chat_completions_local)
.service(save_draft)
.service(generic_chat_completions)
.service(embeddings_local)
.service(get_latest_email_from)
// New bot services
.service(index)
.service(static_files)
.service(websocket_handler)
@ -157,9 +197,6 @@ async fn main() -> std::io::Result<()> {
.service(get_sessions)
.service(get_session_history)
.service(set_mode_handler)
.service(send_email)
.service(upload_file)
.service(list_file)
})
.bind((config.server.host.clone(), config.server.port))?
.run()

View file

@ -1,4 +1,3 @@
use chrono::Utc;
use langchain_rust::llm::AzureConfig;
use log::{debug, warn};
use rhai::{Array, Dynamic};

View file

@ -1,7 +1,7 @@
<!doctype html>
<html>
<head>
<title>General Bots - ChatGPT Clone</title>
<title>General Bots</title>
<style>
* {
margin: 0;