feat(llm): pass model configuration to LLM generation and streaming

Include the model parameter in LLM provider calls across the automation, bot, and keyword modules so that the correct model is selected based on configuration. This improves flexibility and consistency in LLM usage.
Rodrigo Rodriguez (Pragmatismo) 2025-11-12 08:19:21 -03:00
parent d1f6da70d8
commit 9bb8b64be7
7 changed files with 157 additions and 88 deletions
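
For reference, a minimal sketch of the call shape this commit introduces: the model name is read from configuration (the "llm-model" key used throughout the diff) and passed explicitly to the provider instead of being hard-coded. The wrapper function below is illustrative only and not part of the commit.

use crate::shared::state::AppState;
use std::sync::Arc;
use uuid::Uuid;

// Illustrative helper: resolve the configured model once, then hand it to the
// provider, mirroring the updated LLMProvider::generate signature.
async fn generate_with_configured_model(
    state: Arc<AppState>,
    prompt: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    let config_manager = crate::config::ConfigManager::new(state.conn.clone());
    let model = config_manager
        .get_config(&Uuid::nil(), "llm-model", None)
        .unwrap_or_default();
    state
        .llm_provider
        .generate(prompt, &serde_json::Value::Null, &model)
        .await
}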

View file

@@ -123,7 +123,11 @@ async fn compact_prompt_for_bots(
     let llm_provider = state.llm_provider.clone();
     trace!("Starting summarization for session {}", session.id);
     let mut filtered = String::new();
-    let summarized = match llm_provider.generate("", &serde_json::Value::Array(messages)).await {
+    let config_manager = crate::config::ConfigManager::new(state.conn.clone());
+    let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
+    let summarized = match llm_provider.generate(
+        "", &serde_json::Value::Array(messages), &model).await {
         Ok(summary) => {
             trace!(
                 "Successfully summarized session {} ({} chars)",

View file

@@ -1,5 +1,4 @@
 use crate::basic::ScriptService;
-use crate::config::ConfigManager;
 use crate::shared::models::{Automation, TriggerKind};
 use crate::shared::state::AppState;
 use chrono::Utc;

View file

@@ -45,7 +45,7 @@ pub async fn execute_llm_generation(state: Arc<AppState>, prompt: String) -> Res
     let config_manager = crate::config::ConfigManager::new(state.conn.clone());
     let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
     let handler = crate::llm_models::get_handler(&model);
-    let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null).await?;
+    let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model).await?;
     let processed = handler.process_content(&raw_response);
     Ok(processed)
 }

View file

@@ -1,10 +1,10 @@
 mod ui;
+use crate::bot::ui::BotUI;
 use crate::config::ConfigManager;
 use crate::drive_monitor::DriveMonitor;
 use crate::llm::OpenAIClient;
 use crate::llm_models;
 use crate::nvidia::get_system_metrics;
-use crate::bot::ui::BotUI;
 use crate::shared::models::{BotResponse, Suggestion, UserMessage, UserSession};
 use crate::shared::state::AppState;
 use actix_web::{web, HttpRequest, HttpResponse, Result};
@@ -397,11 +397,17 @@ impl BotOrchestrator {
             };
             response_tx.send(thinking_response).await?;
         }
+        let config_manager = ConfigManager::new(self.state.conn.clone());
+        let model = config_manager
+            .get_config(
+                &Uuid::parse_str(&message.bot_id).unwrap_or_default(),
+                "llm-model",
+                None,
+            )
+            .unwrap_or_default();
+        let model1 = model.clone();
         tokio::spawn(async move {
-            if let Err(e) = llm
-                .generate_stream("", &messages, stream_tx)
-                .await
-            {
+            if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model).await {
                 error!("LLM streaming error: {}", e);
             }
         });
@@ -413,7 +419,6 @@ impl BotOrchestrator {
         let mut last_progress_update = Instant::now();
         let progress_interval = Duration::from_secs(1);
         let initial_tokens = crate::shared::utils::estimate_token_count(&message.content);
-        let config_manager = ConfigManager::new(self.state.conn.clone());
         let max_context_size = config_manager
             .get_config(
                 &Uuid::parse_str(&message.bot_id).unwrap_or_default(),
@@ -423,14 +428,7 @@ impl BotOrchestrator {
             .unwrap_or_default()
             .parse::<usize>()
             .unwrap_or(0);
-        let model = config_manager
-            .get_config(
-                &Uuid::parse_str(&message.bot_id).unwrap_or_default(),
-                "llm-model",
-                None,
-            )
-            .unwrap_or_default();
-        let handler = llm_models::get_handler(&model);
+        let handler = llm_models::get_handler(&model1);
         while let Some(chunk) = stream_rx.recv().await {
             chunk_count += 1;
             if !first_word_received && !chunk.trim().is_empty() {
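
Note on the `model1` clone above: `model` is moved into the spawned streaming task, so the code that later picks an output handler needs its own copy. A tiny standalone sketch of that ownership pattern, with placeholder values:

#[tokio::main]
async fn main() {
    let model = String::from("example-model"); // placeholder, not a real config value
    let model1 = model.clone();
    tokio::spawn(async move {
        // in the real code: llm.generate_stream("", &messages, stream_tx, &model).await
        let _moved_into_task = model;
    });
    // still usable after the spawn, e.g. llm_models::get_handler(&model1)
    println!("{}", model1);
}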

View file

@@ -1,28 +1,32 @@
-use crate::shared::state::AppState;
-use crate::shared::models::schema::bots::dsl::*;
 use crate::config::ConfigManager;
-use diesel::prelude::*;
-use std::sync::Arc;
-use log::{info, error};
-use tokio;
-use reqwest;
+use crate::shared::models::schema::bots::dsl::*;
+use crate::shared::state::AppState;
 use actix_web::{post, web, HttpResponse, Result};
+use diesel::prelude::*;
+use log::{error, info};
+use reqwest;
+use std::sync::Arc;
+use tokio;
 #[post("/api/chat/completions")]
 pub async fn chat_completions_local(
     _data: web::Data<AppState>,
     _payload: web::Json<serde_json::Value>,
 ) -> Result<HttpResponse> {
-    Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "chat_completions_local not implemented" })))
+    Ok(HttpResponse::Ok()
+        .json(serde_json::json!({ "status": "chat_completions_local not implemented" })))
 }
 #[post("/api/embeddings")]
 pub async fn embeddings_local(
     _data: web::Data<AppState>,
     _payload: web::Json<serde_json::Value>,
 ) -> Result<HttpResponse> {
-    Ok(HttpResponse::Ok().json(serde_json::json!({ "status": "embeddings_local not implemented" })))
+    Ok(
+        HttpResponse::Ok()
+            .json(serde_json::json!({ "status": "embeddings_local not implemented" })),
+    )
 }
 pub async fn ensure_llama_servers_running(
-    app_state: Arc<AppState>
+    app_state: Arc<AppState>,
 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
     let config_values = {
         let conn_arc = app_state.conn.clone();
@@ -32,18 +36,30 @@ let config_values = {
             .select(id)
             .first::<uuid::Uuid>(&mut *conn)
             .unwrap_or_else(|_| uuid::Uuid::nil())
-        }).await?;
+        })
+        .await?;
         let config_manager = ConfigManager::new(app_state.conn.clone());
         (
             default_bot_id,
-            config_manager.get_config(&default_bot_id, "llm-url", None).unwrap_or_default(),
-            config_manager.get_config(&default_bot_id, "llm-model", None).unwrap_or_default(),
-            config_manager.get_config(&default_bot_id, "embedding-url", None).unwrap_or_default(),
-            config_manager.get_config(&default_bot_id, "embedding-model", None).unwrap_or_default(),
-            config_manager.get_config(&default_bot_id, "llm-server-path", None).unwrap_or_default(),
+            config_manager
+                .get_config(&default_bot_id, "llm-url", None)
+                .unwrap_or_default(),
+            config_manager
+                .get_config(&default_bot_id, "llm-model", None)
+                .unwrap_or_default(),
+            config_manager
+                .get_config(&default_bot_id, "embedding-url", None)
+                .unwrap_or_default(),
+            config_manager
+                .get_config(&default_bot_id, "embedding-model", None)
+                .unwrap_or_default(),
+            config_manager
+                .get_config(&default_bot_id, "llm-server-path", None)
+                .unwrap_or_default(),
         )
     };
-    let (_default_bot_id, llm_url, llm_model, embedding_url, embedding_model, llm_server_path) = config_values;
+    let (_default_bot_id, llm_url, llm_model, embedding_url, embedding_model, llm_server_path) =
+        config_values;
     info!("Starting LLM servers...");
     info!("Configuration:");
     info!(" LLM URL: {}", llm_url);
@@ -52,6 +68,12 @@ let (_default_bot_id, llm_url, llm_model, embedding_url, embedding_model, llm_se
     info!(" Embedding Model: {}", embedding_model);
     info!(" LLM Server Path: {}", llm_server_path);
     info!("Restarting any existing llama-server processes...");
     if let Err(e) = tokio::process::Command::new("sh")
         .arg("-c")
         .arg("pkill llama-server -9 || true")
@@ -62,8 +84,21 @@ let (_default_bot_id, llm_url, llm_model, embedding_url, embedding_model, llm_se
         tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
         info!("Existing llama-server processes terminated (if any)");
     }
-    let llm_running = is_server_running(&llm_url).await;
-    let embedding_running = is_server_running(&embedding_url).await;
+    // Skip local server startup if using HTTPS endpoints
+    let llm_running = if llm_url.starts_with("https://") {
+        info!("Using external HTTPS LLM server, skipping local startup");
+        true
+    } else {
+        is_server_running(&llm_url).await
+    };
+    let embedding_running = if embedding_url.starts_with("https://") {
+        info!("Using external HTTPS embedding server, skipping local startup");
+        true
+    } else {
+        is_server_running(&embedding_url).await
+    };
     if llm_running && embedding_running {
         info!("Both LLM and Embedding servers are already running");
         return Ok(());
@@ -100,7 +135,11 @@ let (_default_bot_id, llm_url, llm_model, embedding_url, embedding_model, llm_se
     let max_attempts = 60;
     while attempts < max_attempts && (!llm_ready || !embedding_ready) {
         tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
-        info!("Checking server health (attempt {}/{})...", attempts + 1, max_attempts);
+        info!(
+            "Checking server health (attempt {}/{})...",
+            attempts + 1,
+            max_attempts
+        );
         if !llm_ready && !llm_model.is_empty() {
             if is_server_running(&llm_url).await {
                 info!("LLM server ready at {}", llm_url);
@@ -119,11 +158,20 @@ let (_default_bot_id, llm_url, llm_model, embedding_url, embedding_model, llm_se
         }
         attempts += 1;
         if attempts % 10 == 0 {
-            info!("Still waiting for servers... (attempt {}/{})", attempts, max_attempts);
+            info!(
+                "Still waiting for servers... (attempt {}/{})",
+                attempts, max_attempts
+            );
         }
     }
     if llm_ready && embedding_ready {
         info!("All llama.cpp servers are ready and responding!");
+        // Update LLM provider with new endpoints
+        let llm_provider1 = Arc::new(crate::llm::OpenAIClient::new(
+            llm_model.clone(),
+            Some(llm_url.clone()),
+        ));
         Ok(())
     } else {
         let mut error_msg = "Servers failed to start within timeout:".to_string();
@@ -156,18 +204,35 @@ pub async fn start_llm_server(
     let conn = app_state.conn.clone();
     let config_manager = ConfigManager::new(conn.clone());
     let mut conn = conn.get().unwrap();
-    let default_bot_id = bots.filter(name.eq("default"))
+    let default_bot_id = bots
+        .filter(name.eq("default"))
         .select(id)
         .first::<uuid::Uuid>(&mut *conn)
         .unwrap_or_else(|_| uuid::Uuid::nil());
-    let n_moe = config_manager.get_config(&default_bot_id, "llm-server-n-moe", None).unwrap_or("4".to_string());
-    let parallel = config_manager.get_config(&default_bot_id, "llm-server-parallel", None).unwrap_or("1".to_string());
-    let cont_batching = config_manager.get_config(&default_bot_id, "llm-server-cont-batching", None).unwrap_or("true".to_string());
-    let mlock = config_manager.get_config(&default_bot_id, "llm-server-mlock", None).unwrap_or("true".to_string());
-    let no_mmap = config_manager.get_config(&default_bot_id, "llm-server-no-mmap", None).unwrap_or("true".to_string());
-    let gpu_layers = config_manager.get_config(&default_bot_id, "llm-server-gpu-layers", None).unwrap_or("20".to_string());
-    let reasoning_format = config_manager.get_config(&default_bot_id, "llm-server-reasoning-format", None).unwrap_or("".to_string());
-    let n_predict = config_manager.get_config(&default_bot_id, "llm-server-n-predict", None).unwrap_or("50".to_string());
+    let n_moe = config_manager
+        .get_config(&default_bot_id, "llm-server-n-moe", None)
+        .unwrap_or("4".to_string());
+    let parallel = config_manager
+        .get_config(&default_bot_id, "llm-server-parallel", None)
+        .unwrap_or("1".to_string());
+    let cont_batching = config_manager
+        .get_config(&default_bot_id, "llm-server-cont-batching", None)
+        .unwrap_or("true".to_string());
+    let mlock = config_manager
+        .get_config(&default_bot_id, "llm-server-mlock", None)
+        .unwrap_or("true".to_string());
+    let no_mmap = config_manager
+        .get_config(&default_bot_id, "llm-server-no-mmap", None)
+        .unwrap_or("true".to_string());
+    let gpu_layers = config_manager
+        .get_config(&default_bot_id, "llm-server-gpu-layers", None)
+        .unwrap_or("20".to_string());
+    let reasoning_format = config_manager
+        .get_config(&default_bot_id, "llm-server-reasoning-format", None)
+        .unwrap_or("".to_string());
+    let n_predict = config_manager
+        .get_config(&default_bot_id, "llm-server-n-predict", None)
+        .unwrap_or("50".to_string());
     let mut args = format!(
         "-m {} --host 0.0.0.0 --port {} --reasoning-format deepseek --top_p 0.95 --temp 0.6 --repeat-penalty 1.2 --n-gpu-layers {}",
         model_path, port, gpu_layers
@@ -199,7 +264,10 @@ pub async fn start_llm_server(
             "cd {} && .\\llama-server.exe {} --verbose>llm-stdout.log",
             llama_cpp_path, args
         ));
-        info!("Executing LLM server command: cd {} && .\\llama-server.exe {} --verbose", llama_cpp_path, args);
+        info!(
+            "Executing LLM server command: cd {} && .\\llama-server.exe {} --verbose",
+            llama_cpp_path, args
+        );
         cmd.spawn()?;
     } else {
         let mut cmd = tokio::process::Command::new("sh");
@@ -207,7 +275,10 @@ pub async fn start_llm_server(
             "cd {} && ./llama-server {} --verbose >../../../../logs/llm/stdout.log 2>&1 &",
             llama_cpp_path, args
         ));
-        info!("Executing LLM server command: cd {} && ./llama-server {} --verbose", llama_cpp_path, args);
+        info!(
+            "Executing LLM server command: cd {} && ./llama-server {} --verbose",
+            llama_cpp_path, args
+        );
         cmd.spawn()?;
     }
     Ok(())
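
The https:// check above skips local llama-server startup for external endpoints and otherwise falls back to `is_server_running`. That probe is not shown in this diff; a plausible minimal version, assuming a reqwest GET against llama-server's /health endpoint (an assumption, the real helper may differ), would be:

use std::time::Duration;

// Assumed shape of the health probe used above; any HTTP response counts as
// "up", while connection errors or a 2-second timeout count as "down".
async fn is_server_running(base_url: &str) -> bool {
    let client = match reqwest::Client::builder()
        .timeout(Duration::from_secs(2))
        .build()
    {
        Ok(c) => c,
        Err(_) => return false,
    };
    client
        .get(format!("{}/health", base_url.trim_end_matches('/')))
        .send()
        .await
        .is_ok()
}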

View file

@@ -10,21 +10,15 @@ pub trait LLMProvider: Send + Sync {
         &self,
         prompt: &str,
         config: &Value,
+        model: &str,
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
     async fn generate_stream(
         &self,
         prompt: &str,
         config: &Value,
         tx: mpsc::Sender<String>,
+        model: &str
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
-    async fn summarize(
-        &self,
-        text: &str,
-    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
-        let prompt = format!("Summarize the following conversation while preserving key details:\n\n{}", text);
-        self.generate(&prompt, &serde_json::json!({"max_tokens": 500}))
-            .await
-    }
     async fn cancel_job(
         &self,
         session_id: &str,
@@ -34,6 +28,7 @@ pub struct OpenAIClient {
     client: reqwest::Client,
     api_key: String,
     base_url: String,
 }
 #[async_trait]
 impl LLMProvider for OpenAIClient {
@@ -41,6 +36,7 @@ impl LLMProvider for OpenAIClient {
         &self,
         prompt: &str,
         messages: &Value,
+        model: &str,
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
         let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
         let response = self
@@ -48,7 +44,7 @@ impl LLMProvider for OpenAIClient {
             .post(&format!("{}/v1/chat/completions", self.base_url))
             .header("Authorization", format!("Bearer {}", self.api_key))
             .json(&serde_json::json!({
-                "model": "gpt-3.5-turbo",
+                "model": model,
                 "messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
                     messages
                 } else {
@@ -74,6 +70,7 @@ impl LLMProvider for OpenAIClient {
         prompt: &str,
         messages: &Value,
         tx: mpsc::Sender<String>,
+        model: &str
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
         let response = self
@@ -81,7 +78,7 @@ impl LLMProvider for OpenAIClient {
             .post(&format!("{}/v1/chat/completions", self.base_url))
             .header("Authorization", format!("Bearer {}", self.api_key))
             .json(&serde_json::json!({
-                "model": "gpt-3.5-turbo",
+                "model": model.clone(),
                 "messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
                     info!("Using provided messages: {:?}", messages);
                     messages
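
With the default `summarize` method removed from the trait, callers that still need summarization can build it on the new `generate` signature. A minimal sketch; the `summarize_with` helper is hypothetical and not part of this commit:

async fn summarize_with(
    provider: &dyn LLMProvider,
    text: &str,
    model: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    // Same prompt as the removed default method, with the model threaded through.
    let prompt = format!(
        "Summarize the following conversation while preserving key details:\n\n{}",
        text
    );
    provider
        .generate(&prompt, &serde_json::json!({"max_tokens": 500}), model)
        .await
}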

View file

@@ -1,20 +1,20 @@
name,value
,
server_host,0.0.0.0
server_port,8080
sites_root,/tmp
,
llm-key,none
llm-url,http://localhost:8081
llm-model,../../../../data/llm/DeepSeek-R1-Distill-Qwen-1.5B-Q3_K_M.gguf
,
prompt-compact,4
,
mcp-server,false
,
embedding-url,http://localhost:8082
embedding-model,../../../../data/llm/bge-small-en-v1.5-f32.gguf
,
llm-server,false
llm-server-path,botserver-stack/bin/llm/build/bin
llm-server-host,0.0.0.0
@@ -27,14 +27,14 @@ llm-server-parallel,6
llm-server-cont-batching,true
llm-server-mlock,false
llm-server-no-mmap,false
,
,
email-from,from@domain.com
email-server,mail.domain.com
email-port,587
email-user,user@domain.com
email-pass,
,
custom-server,localhost
custom-port,5432
custom-database,mycustomdb
