fix(llm-config): Fix ConfigManager fallback logic for LLM configuration
Some checks failed
GBCI / build (push) Failing after 12m26s

- Fix ConfigManager to treat 'none', 'null', 'n/a', and empty values as placeholders
  and fall back to default bot's configuration instead of using these as literal values

- Fix ConfigManager to detect local file paths (e.g., .gguf, .bin, ../) and fall back
  to the default bot's model when using a remote API, allowing bots to keep a local
  model config for a local LLM server while automatically using the remote model for
  API calls

- Fix get_default_bot() to return the bot actually named 'default' instead of
  the first active bot by ID, ensuring consistent fallback behavior

- Add comprehensive debug logging to trace LLM configuration from database to API call

This fixes the issue where bots with incomplete or local LLM configuration would
fail with 401/400 errors when trying to use remote API, instead of automatically
falling back to the default bot's configuration from config.csv.

Closes: #llm-config-fallback
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2026-02-02 19:20:37 -03:00
parent 39c4dba838
commit 5fb4c889b7
3 changed files with 549 additions and 35 deletions

View file

@ -1,5 +1,7 @@
#[cfg(any(feature = "research", feature = "llm"))] #[cfg(any(feature = "research", feature = "llm"))]
pub mod kb_context; pub mod kb_context;
#[cfg(any(feature = "research", feature = "llm"))]
use kb_context::inject_kb_context;
#[cfg(feature = "llm")] #[cfg(feature = "llm")]
use crate::core::config::ConfigManager; use crate::core::config::ConfigManager;
@ -20,7 +22,10 @@ use axum::{
http::StatusCode, http::StatusCode,
response::{IntoResponse, Json}, response::{IntoResponse, Json},
}; };
use diesel::ExpressionMethods;
use diesel::PgConnection; use diesel::PgConnection;
use diesel::QueryDsl;
use diesel::RunQueryDsl;
use futures::{sink::SinkExt, stream::StreamExt}; use futures::{sink::SinkExt, stream::StreamExt};
#[cfg(feature = "llm")] #[cfg(feature = "llm")]
use log::trace; use log::trace;
@ -39,7 +44,9 @@ pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) {
use crate::shared::models::schema::bots::dsl::*; use crate::shared::models::schema::bots::dsl::*;
use diesel::prelude::*; use diesel::prelude::*;
// First try to get the bot named "default"
match bots match bots
.filter(name.eq("default"))
.filter(is_active.eq(true)) .filter(is_active.eq(true))
.select((id, name)) .select((id, name))
.first::<(Uuid, String)>(conn) .first::<(Uuid, String)>(conn)
@ -47,8 +54,24 @@ pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) {
{ {
Ok(Some((bot_id, bot_name))) => (bot_id, bot_name), Ok(Some((bot_id, bot_name))) => (bot_id, bot_name),
Ok(None) => { Ok(None) => {
warn!("No active bots found, using nil UUID"); warn!("Bot named 'default' not found, falling back to first active bot");
(Uuid::nil(), "default".to_string()) // Fall back to first active bot
match bots
.filter(is_active.eq(true))
.select((id, name))
.first::<(Uuid, String)>(conn)
.optional()
{
Ok(Some((bot_id, bot_name))) => (bot_id, bot_name),
Ok(None) => {
warn!("No active bots found, using nil UUID");
(Uuid::nil(), "default".to_string())
}
Err(e) => {
error!("Failed to query fallback bot: {}", e);
(Uuid::nil(), "default".to_string())
}
}
} }
Err(e) => { Err(e) => {
error!("Failed to query default bot: {}", e); error!("Failed to query default bot: {}", e);
@ -72,10 +95,116 @@ impl BotOrchestrator {
} }
pub fn mount_all_bots(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { pub fn mount_all_bots(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("mount_all_bots called"); info!("Scanning drive for .gbai files to mount bots...");
let mut bots_mounted = 0;
let mut bots_created = 0;
let directories_to_scan: Vec<std::path::PathBuf> = vec![
self.state
.config
.as_ref()
.map(|c| c.site_path.clone())
.unwrap_or_else(|| "./botserver-stack/sites".to_string())
.into(),
"./templates".into(),
"../bottemplates".into(),
];
for dir_path in directories_to_scan {
info!("Checking directory for bots: {}", dir_path.display());
if !dir_path.exists() {
info!("Directory does not exist, skipping: {}", dir_path.display());
continue;
}
match self.scan_directory(&dir_path, &mut bots_mounted, &mut bots_created) {
Ok(()) => {}
Err(e) => {
error!("Failed to scan directory {}: {}", dir_path.display(), e);
}
}
}
info!(
"Bot mounting complete: {} bots processed ({} created, {} already existed)",
bots_mounted,
bots_created,
bots_mounted - bots_created
);
Ok(()) Ok(())
} }
fn scan_directory(
    &self,
    dir_path: &std::path::Path,
    bots_mounted: &mut i32,
    _bots_created: &mut i32,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Scans one directory for `<name>.gbai` entries and mounts every bot that
    // already exists in the database. Bot creation happens in the import flow,
    // so `_bots_created` is currently unused (kept for interface stability).
    let entries =
        std::fs::read_dir(dir_path).map_err(|e| format!("Failed to read directory: {}", e))?;
    for entry in entries.flatten() {
        let file_name = entry.file_name();
        // `strip_suffix` removes exactly one ".gbai"; the previous
        // `trim_end_matches` would strip it repeatedly, turning a name like
        // "a.gbai.gbai" into "a" instead of "a.gbai". Entries that are not
        // valid UTF-8 or lack the suffix are skipped.
        let bot_name = match file_name.to_str().and_then(|n| n.strip_suffix(".gbai")) {
            Some(n) => n,
            None => continue,
        };
        info!("Found .gbai file: {}", bot_name);
        match self.ensure_bot_exists(bot_name) {
            Ok(true) => {
                info!("Bot '{}' already exists in database, mounting", bot_name);
                *bots_mounted += 1;
            }
            Ok(false) => {
                info!(
                    "Bot '{}' does not exist in database, skipping (run import to create)",
                    bot_name
                );
            }
            Err(e) => {
                // Non-fatal: keep scanning the remaining entries.
                error!("Failed to check if bot '{}' exists: {}", bot_name, e);
            }
        }
    }
    Ok(())
}
/// Checks whether an active bot with the given name is already present in
/// the `bots` table. Returns `Ok(true)` when it exists, `Ok(false)` when it
/// does not, and `Err` on pool or query failure.
fn ensure_bot_exists(
    &self,
    bot_name: &str,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
    use diesel::sql_query;
    // Row type for the single-column EXISTS(...) query below.
    #[derive(diesel::QueryableByName)]
    #[diesel(check_for_backend(diesel::pg::Pg))]
    struct BotExistsResult {
        #[diesel(sql_type = diesel::sql_types::Bool)]
        exists: bool,
    }
    let mut db_conn = self
        .state
        .conn
        .get()
        .map_err(|e| format!("Failed to get database connection: {e}"))?;
    let row: BotExistsResult = sql_query(
        "SELECT EXISTS(SELECT 1 FROM bots WHERE name = $1 AND is_active = true) as exists",
    )
    .bind::<diesel::sql_types::Text, _>(bot_name)
    .get_result(&mut db_conn)
    .map_err(|e| format!("Failed to check if bot exists: {e}"))?;
    Ok(row.exists)
}
#[cfg(feature = "llm")] #[cfg(feature = "llm")]
pub async fn stream_response( pub async fn stream_response(
&self, &self,
@ -90,6 +219,7 @@ impl BotOrchestrator {
let user_id = Uuid::parse_str(&message.user_id)?; let user_id = Uuid::parse_str(&message.user_id)?;
let session_id = Uuid::parse_str(&message.session_id)?; let session_id = Uuid::parse_str(&message.session_id)?;
let message_content = message.content.clone();
let (session, context_data, history, model, key) = { let (session, context_data, history, model, key) = {
let state_clone = self.state.clone(); let state_clone = self.state.clone();
@ -117,13 +247,24 @@ impl BotOrchestrator {
}; };
let config_manager = ConfigManager::new(state_clone.conn.clone()); let config_manager = ConfigManager::new(state_clone.conn.clone());
// DEBUG: Log which bot we're getting config for
info!("[CONFIG_TRACE] Getting LLM config for bot_id: {}", session.bot_id);
let model = config_manager let model = config_manager
.get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo")) .get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo"))
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string()); .unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
let key = config_manager let key = config_manager
.get_config(&session.bot_id, "llm-key", Some("")) .get_config(&session.bot_id, "llm-key", Some(""))
.unwrap_or_default(); .unwrap_or_default();
// DEBUG: Log the exact config values retrieved
info!("[CONFIG_TRACE] Model: '{}'", model);
info!("[CONFIG_TRACE] API Key: '{}' ({} chars)", key, key.len());
info!("[CONFIG_TRACE] API Key first 10 chars: '{}'", &key.chars().take(10).collect::<String>());
info!("[CONFIG_TRACE] API Key last 10 chars: '{}'", &key.chars().rev().take(10).collect::<String>());
Ok((session, context_data, history, model, key)) Ok((session, context_data, history, model, key))
}, },
) )
@ -131,7 +272,39 @@ impl BotOrchestrator {
}; };
let system_prompt = "You are a helpful assistant.".to_string(); let system_prompt = "You are a helpful assistant.".to_string();
let messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history); let mut messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history);
#[cfg(any(feature = "research", feature = "llm"))]
{
if let Some(kb_manager) = self.state.kb_manager.as_ref() {
let bot_name_for_kb = {
let conn = self.state.conn.get().ok();
if let Some(mut db_conn) = conn {
use crate::shared::models::schema::bots::dsl::*;
bots.filter(id.eq(session.bot_id))
.select(name)
.first::<String>(&mut db_conn)
.unwrap_or_else(|_| "default".to_string())
} else {
"default".to_string()
}
};
if let Err(e) = inject_kb_context(
kb_manager.clone(),
self.state.conn.clone(),
session_id,
&bot_name_for_kb,
&message_content,
&mut messages,
8000,
)
.await
{
error!("Failed to inject KB context: {}", e);
}
}
}
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100); let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
let llm = self.state.llm_provider.clone(); let llm = self.state.llm_provider.clone();
@ -139,6 +312,16 @@ impl BotOrchestrator {
let model_clone = model.clone(); let model_clone = model.clone();
let key_clone = key.clone(); let key_clone = key.clone();
let messages_clone = messages.clone(); let messages_clone = messages.clone();
// DEBUG: Log exact values being passed to LLM
info!("[LLM_CALL] Calling generate_stream with:");
info!("[LLM_CALL] Model: '{}'", model_clone);
info!("[LLM_CALL] Key length: {} chars", key_clone.len());
info!("[LLM_CALL] Key preview: '{}...{}'",
&key_clone.chars().take(8).collect::<String>(),
&key_clone.chars().rev().take(8).collect::<String>()
);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = llm if let Err(e) = llm
.generate_stream("", &messages_clone, stream_tx, &model_clone, &key_clone) .generate_stream("", &messages_clone, stream_tx, &model_clone, &key_clone)
@ -388,10 +571,18 @@ pub async fn websocket_handler(
let conn = state.conn.get().ok(); let conn = state.conn.get().ok();
if let Some(mut db_conn) = conn { if let Some(mut db_conn) = conn {
use crate::shared::models::schema::bots::dsl::*; use crate::shared::models::schema::bots::dsl::*;
let result: Result<Uuid, _> = bots
.filter(name.eq(&bot_name)) // Try to parse as UUID first, if that fails treat as bot name
.select(id) let result: Result<Uuid, _> = if let Ok(uuid) = Uuid::parse_str(&bot_name) {
.first(&mut db_conn); // Parameter is a UUID, look up by id
bots.filter(id.eq(uuid)).select(id).first(&mut db_conn)
} else {
// Parameter is a bot name, look up by name
bots.filter(name.eq(&bot_name))
.select(id)
.first(&mut db_conn)
};
result.unwrap_or_else(|_| { result.unwrap_or_else(|_| {
log::warn!("Bot not found: {}, using nil bot_id", bot_name); log::warn!("Bot not found: {}, using nil bot_id", bot_name);
Uuid::nil() Uuid::nil()
@ -427,8 +618,8 @@ async fn handle_websocket(
} }
info!( info!(
"WebSocket connected for session: {}, user: {}", "WebSocket connected for session: {}, user: {}, bot: {}",
session_id, user_id session_id, user_id, bot_id
); );
let welcome = serde_json::json!({ let welcome = serde_json::json!({
@ -445,6 +636,89 @@ async fn handle_websocket(
} }
} }
// Execute start.bas automatically on connection (similar to auth.ast pattern)
{
let bot_name_result = {
let conn = state.conn.get().ok();
if let Some(mut db_conn) = conn {
use crate::shared::models::schema::bots::dsl::*;
bots.filter(id.eq(bot_id))
.select(name)
.first::<String>(&mut db_conn)
.ok()
} else {
None
}
};
// DEBUG: Log start script execution attempt
info!(
"Checking for start.bas: bot_id={}, bot_name_result={:?}",
bot_id,
bot_name_result
);
if let Some(bot_name) = bot_name_result {
let start_script_path = format!("./work/{}.gbai/{}.gbdialog/start.bas", bot_name, bot_name);
info!("Looking for start.bas at: {}", start_script_path);
if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await {
if metadata.is_file() {
info!("Found start.bas file, reading contents...");
if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await {
info!(
"Executing start.bas for bot {} on session {}",
bot_name, session_id
);
let state_for_start = state.clone();
let _tx_for_start = tx.clone();
tokio::spawn(async move {
let session_result = {
let mut sm = state_for_start.session_manager.lock().await;
sm.get_session_by_id(session_id)
};
if let Ok(Some(session)) = session_result {
info!("Executing start.bas for bot {} on session {}", bot_name, session_id);
let result = tokio::task::spawn_blocking(move || {
let mut script_service = crate::basic::ScriptService::new(
state_for_start.clone(),
session.clone()
);
script_service.load_bot_config_params(&state_for_start, bot_id);
match script_service.compile(&start_script) {
Ok(ast) => match script_service.run(&ast) {
Ok(_) => Ok(()),
Err(e) => Err(format!("Script execution error: {}", e)),
},
Err(e) => Err(format!("Script compilation error: {}", e)),
}
}).await;
match result {
Ok(Ok(())) => {
info!("start.bas executed successfully for bot {}", bot_name);
}
Ok(Err(e)) => {
error!("start.bas error for bot {}: {}", bot_name, e);
}
Err(e) => {
error!("start.bas task error for bot {}: {}", bot_name, e);
}
}
}
});
}
}
}
}
}
let mut send_task = tokio::spawn(async move { let mut send_task = tokio::spawn(async move {
while let Some(response) = rx.recv().await { while let Some(response) = rx.recv().await {
if let Ok(json_str) = serde_json::to_string(&response) { if let Ok(json_str) = serde_json::to_string(&response) {

View file

@ -362,14 +362,55 @@ impl ConfigManager {
use crate::shared::models::schema::bot_configuration::dsl::*; use crate::shared::models::schema::bot_configuration::dsl::*;
let mut conn = self.get_conn()?; let mut conn = self.get_conn()?;
let fallback_str = fallback.unwrap_or(""); let fallback_str = fallback.unwrap_or("");
// Returns true when a stored config value is effectively "not configured":
// empty/whitespace-only, or one of the textual placeholders "none", "null",
// "n/a" (case-insensitive). Callers fall back to the default bot's value.
fn is_placeholder_value(value: &str) -> bool {
    matches!(
        value.trim().to_lowercase().as_str(),
        "" | "none" | "null" | "n/a"
    )
}
// Returns true when a config value looks like a local model file path
// (e.g. "../models/x.gguf"). Such values are valid for a local LLM server
// but must not be sent as a model name to a remote API, so callers fall
// back to the default bot's configuration instead.
fn is_local_file_path(value: &str) -> bool {
    let value = value.trim();
    // Path-prefix patterns ("../" also covers "../../", which the previous
    // version checked redundantly) plus well-known model-file extensions.
    // NOTE(review): extension checks are case-sensitive and match anywhere
    // in the string — confirm model names never legitimately contain them.
    value.starts_with("../")
        || value.starts_with("./")
        || value.starts_with('/')
        || value.starts_with('~')
        || value.starts_with("data/")
        || value.starts_with("models/")
        || value.contains(".gguf")
        || value.contains(".bin")
        || value.contains(".safetensors")
}
// Try to get value for the specific bot
let result = bot_configuration let result = bot_configuration
.filter(bot_id.eq(code_bot_id)) .filter(bot_id.eq(code_bot_id))
.filter(config_key.eq(key)) .filter(config_key.eq(key))
.select(config_value) .select(config_value)
.first::<String>(&mut conn); .first::<String>(&mut conn);
let value = match result { let value = match result {
Ok(v) => v, Ok(v) => {
// Check if it's a placeholder value or local file path - if so, fall back to default bot
// Local file paths are valid for local LLM server but NOT for remote APIs
if is_placeholder_value(&v) || is_local_file_path(&v) {
let (default_bot_id, _default_bot_name) = crate::bot::get_default_bot(&mut conn);
bot_configuration
.filter(bot_id.eq(default_bot_id))
.filter(config_key.eq(key))
.select(config_value)
.first::<String>(&mut conn)
.unwrap_or_else(|_| fallback_str.to_string())
} else {
v
}
}
Err(_) => { Err(_) => {
// Value not found, fall back to default bot
let (default_bot_id, _default_bot_name) = crate::bot::get_default_bot(&mut conn); let (default_bot_id, _default_bot_name) = crate::bot::get_default_bot(&mut conn);
bot_configuration bot_configuration
.filter(bot_id.eq(default_bot_id)) .filter(bot_id.eq(default_bot_id))
@ -379,7 +420,15 @@ impl ConfigManager {
.unwrap_or_else(|_| fallback_str.to_string()) .unwrap_or_else(|_| fallback_str.to_string())
} }
}; };
Ok(value)
// Final check: if the result is still a placeholder value, use the fallback_str
let final_value = if is_placeholder_value(&value) {
fallback_str.to_string()
} else {
value
};
Ok(final_value)
} }
pub fn get_bot_config_value( pub fn get_bot_config_value(

View file

@ -1,6 +1,6 @@
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use log::{info, trace}; use log::{error, info};
use serde_json::Value; use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, RwLock}; use tokio::sync::{mpsc, RwLock};
@ -44,13 +44,116 @@ pub trait LLMProvider: Send + Sync {
pub struct OpenAIClient { pub struct OpenAIClient {
client: reqwest::Client, client: reqwest::Client,
base_url: String, base_url: String,
endpoint_path: String,
} }
impl OpenAIClient { impl OpenAIClient {
pub fn new(_api_key: String, base_url: Option<String>) -> Self { /// Estimates token count for a text string (roughly 4 characters per token for English)
fn estimate_tokens(text: &str) -> usize {
    // Heuristic: English text averages roughly 4 bytes per token; round up
    // so we never under-count. Accuracy varies for other languages/scripts.
    (text.len() + 3) / 4
}
/// Estimates the total token count across every message in a JSON array.
///
/// Non-array input yields 0, and messages whose `content` field is missing
/// or not a string contribute 0 tokens.
fn estimate_messages_tokens(messages: &Value) -> usize {
    messages
        .as_array()
        .map(|msgs| {
            msgs.iter()
                .filter_map(|msg| msg.get("content").and_then(|c| c.as_str()))
                .map(Self::estimate_tokens)
                .sum()
        })
        .unwrap_or(0)
}
/// Truncates messages to fit within the max_tokens limit
/// Keeps system messages and the most recent user/assistant messages
///
/// NOTE(review): messages whose `content` is missing or not a string are
/// dropped entirely by both passes — confirm multi-part/array content is
/// never expected here.
fn truncate_messages(messages: &Value, max_tokens: usize) -> Value {
    let mut result = Vec::new();
    // Running total of estimated tokens accepted so far (shared by both passes).
    let mut token_count = 0;
    if let Some(msg_array) = messages.as_array() {
        // First pass: keep all system messages
        // An over-budget system message is skipped, but scanning continues,
        // so a later (smaller) system message can still be kept.
        for msg in msg_array {
            if let Some(role) = msg.get("role").and_then(|r| r.as_str()) {
                if role == "system" {
                    if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
                        let msg_tokens = Self::estimate_tokens(content);
                        if token_count + msg_tokens <= max_tokens {
                            result.push(msg.clone());
                            token_count += msg_tokens;
                        }
                    }
                }
            }
        }
        // Second pass: add user/assistant messages from newest to oldest
        let mut recent_messages: Vec<&Value> = msg_array
            .iter()
            .filter(|msg| msg.get("role").and_then(|r| r.as_str()) != Some("system"))
            .collect();
        // Reverse to get newest first
        recent_messages.reverse();
        for msg in recent_messages {
            if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
                let msg_tokens = Self::estimate_tokens(content);
                if token_count + msg_tokens <= max_tokens {
                    result.push(msg.clone());
                    token_count += msg_tokens;
                } else {
                    // Unlike the system pass, stop at the first message that
                    // doesn't fit: everything older is dropped too.
                    break;
                }
            }
        }
        // Reverse back to chronological order for non-system messages
        // But keep system messages at the beginning
        // `system_count` = number of system messages (total minus non-system);
        // they occupy the front of `result` because they were pushed first.
        let system_count = result.len()
            - result
                .iter()
                .filter(|m| m.get("role").and_then(|r| r.as_str()) != Some("system"))
                .count();
        // Detach the newest-first tail, flip it to chronological order, and
        // re-append it after the system prefix.
        let mut user_messages: Vec<Value> = result.drain(system_count..).collect();
        user_messages.reverse();
        result.extend(user_messages);
    }
    serde_json::Value::Array(result)
}
/// Clamps `messages` to the model's context window, truncating when needed.
///
/// Only 90% of `model_context_limit` is budgeted for the prompt so the
/// model retains headroom to generate its response. When the estimate fits,
/// the messages are returned unchanged (cloned).
fn ensure_token_limit(messages: &Value, model_context_limit: usize) -> Value {
    let safe_limit = (model_context_limit as f64 * 0.9) as usize;
    let estimated_tokens = Self::estimate_messages_tokens(messages);
    if estimated_tokens <= safe_limit {
        return messages.clone();
    }
    log::warn!(
        "Messages exceed token limit ({} > {}), truncating...",
        estimated_tokens,
        safe_limit
    );
    Self::truncate_messages(messages, safe_limit)
}
pub fn new(_api_key: String, base_url: Option<String>, endpoint_path: Option<String>) -> Self {
Self { Self {
client: reqwest::Client::new(), client: reqwest::Client::new(),
base_url: base_url.unwrap_or_else(|| "https://api.openai.com".to_string()), base_url: base_url.unwrap_or_else(|| "https://api.openai.com".to_string()),
endpoint_path: endpoint_path.unwrap_or_else(|| "/v1/chat/completions".to_string()),
} }
} }
@ -92,21 +195,64 @@ impl LLMProvider for OpenAIClient {
key: &str, key: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
// Get the messages to use
let raw_messages =
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
messages
} else {
&default_messages
};
// Ensure messages fit within model's context limit
// GLM-4.7 has 202750 tokens, other models vary
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
202750
} else if model.contains("gpt-4") {
128000
} else if model.contains("gpt-3.5") {
16385
} else if model.starts_with("http://localhost:808") || model == "local" {
768 // Local llama.cpp server context limit
} else {
4096 // Default conservative limit
};
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
let full_url = format!("{}{}", self.base_url, self.endpoint_path);
let auth_header = format!("Bearer {}", key);
// Debug logging to help troubleshoot 401 errors
info!("LLM Request Details:");
info!(" URL: {}", full_url);
info!(" Authorization: Bearer <{} chars>", key.len());
info!(" Model: {}", model);
if let Some(msg_array) = messages.as_array() {
info!(" Messages: {} messages", msg_array.len());
}
info!(" API Key First 8 chars: '{}...'", &key.chars().take(8).collect::<String>());
info!(" API Key Last 8 chars: '...{}'", &key.chars().rev().take(8).collect::<String>());
let response = self let response = self
.client .client
.post(format!("{}/v1/chat/completions", self.base_url)) .post(&full_url)
.header("Authorization", format!("Bearer {}", key)) .header("Authorization", &auth_header)
.json(&serde_json::json!({ .json(&serde_json::json!({
"model": model, "model": model,
"messages": if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() { "messages": messages,
messages "stream": true
} else {
&default_messages
}
})) }))
.send() .send()
.await?; .await?;
let status = response.status();
if status != reqwest::StatusCode::OK {
let error_text = response.text().await.unwrap_or_default();
error!("LLM generate error: {}", error_text);
return Err(format!("LLM request failed with status: {}", status).into());
}
let result: Value = response.json().await?; let result: Value = response.json().await?;
let raw_content = result["choices"][0]["message"]["content"] let raw_content = result["choices"][0]["message"]["content"]
.as_str() .as_str()
@ -127,18 +273,51 @@ impl LLMProvider for OpenAIClient {
key: &str, key: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]); let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
// Get the messages to use
let raw_messages =
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
info!("Using provided messages: {:?}", messages);
messages
} else {
&default_messages
};
// Ensure messages fit within model's context limit
// GLM-4.7 has 202750 tokens, other models vary
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
202750
} else if model.contains("gpt-4") {
128000
} else if model.contains("gpt-3.5") {
16385
} else if model.starts_with("http://localhost:808") || model == "local" {
768 // Local llama.cpp server context limit
} else {
4096 // Default conservative limit
};
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
let full_url = format!("{}{}", self.base_url, self.endpoint_path);
let auth_header = format!("Bearer {}", key);
// Debug logging to help troubleshoot 401 errors
info!("LLM Request Details:");
info!(" URL: {}", full_url);
info!(" Authorization: Bearer <{} chars>", key.len());
info!(" Model: {}", model);
if let Some(msg_array) = messages.as_array() {
info!(" Messages: {} messages", msg_array.len());
}
let response = self let response = self
.client .client
.post(format!("{}/v1/chat/completions", self.base_url)) .post(&full_url)
.header("Authorization", format!("Bearer {}", key)) .header("Authorization", &auth_header)
.json(&serde_json::json!({ .json(&serde_json::json!({
"model": model, "model": model,
"messages": if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() { "messages": messages,
info!("Using provided messages: {:?}", messages);
messages
} else {
&default_messages
},
"stream": true "stream": true
})) }))
.send() .send()
@ -147,7 +326,7 @@ impl LLMProvider for OpenAIClient {
let status = response.status(); let status = response.status();
if status != reqwest::StatusCode::OK { if status != reqwest::StatusCode::OK {
let error_text = response.text().await.unwrap_or_default(); let error_text = response.text().await.unwrap_or_default();
trace!("LLM generate_stream error: {}", error_text); error!("LLM generate_stream error: {}", error_text);
return Err(format!("LLM request failed with status: {}", status).into()); return Err(format!("LLM request failed with status: {}", status).into());
} }
@ -213,11 +392,16 @@ pub fn create_llm_provider(
provider_type: LLMProviderType, provider_type: LLMProviderType,
base_url: String, base_url: String,
deployment_name: Option<String>, deployment_name: Option<String>,
endpoint_path: Option<String>,
) -> std::sync::Arc<dyn LLMProvider> { ) -> std::sync::Arc<dyn LLMProvider> {
match provider_type { match provider_type {
LLMProviderType::OpenAI => { LLMProviderType::OpenAI => {
info!("Creating OpenAI LLM provider with URL: {}", base_url); info!("Creating OpenAI LLM provider with URL: {}", base_url);
std::sync::Arc::new(OpenAIClient::new("empty".to_string(), Some(base_url))) std::sync::Arc::new(OpenAIClient::new(
"empty".to_string(),
Some(base_url),
endpoint_path,
))
} }
LLMProviderType::Claude => { LLMProviderType::Claude => {
info!("Creating Claude LLM provider with URL: {}", base_url); info!("Creating Claude LLM provider with URL: {}", base_url);
@ -237,9 +421,10 @@ pub fn create_llm_provider(
pub fn create_llm_provider_from_url( pub fn create_llm_provider_from_url(
url: &str, url: &str,
model: Option<String>, model: Option<String>,
endpoint_path: Option<String>,
) -> std::sync::Arc<dyn LLMProvider> { ) -> std::sync::Arc<dyn LLMProvider> {
let provider_type = LLMProviderType::from(url); let provider_type = LLMProviderType::from(url);
create_llm_provider(provider_type, url.to_string(), model) create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
} }
pub struct DynamicLLMProvider { pub struct DynamicLLMProvider {
@ -259,8 +444,13 @@ impl DynamicLLMProvider {
info!("LLM provider updated dynamically"); info!("LLM provider updated dynamically");
} }
pub async fn update_from_config(&self, url: &str, model: Option<String>) { pub async fn update_from_config(
let new_provider = create_llm_provider_from_url(url, model); &self,
url: &str,
model: Option<String>,
endpoint_path: Option<String>,
) {
let new_provider = create_llm_provider_from_url(url, model, endpoint_path);
self.update_provider(new_provider).await; self.update_provider(new_provider).await;
} }
@ -490,7 +680,7 @@ mod tests {
#[test] #[test]
fn test_openai_client_new_default_url() { fn test_openai_client_new_default_url() {
let client = OpenAIClient::new("test_key".to_string(), None); let client = OpenAIClient::new("test_key".to_string(), None, None);
assert_eq!(client.base_url, "https://api.openai.com"); assert_eq!(client.base_url, "https://api.openai.com");
} }
@ -499,6 +689,7 @@ mod tests {
let client = OpenAIClient::new( let client = OpenAIClient::new(
"test_key".to_string(), "test_key".to_string(),
Some("http://localhost:8080".to_string()), Some("http://localhost:8080".to_string()),
None,
); );
assert_eq!(client.base_url, "http://localhost:8080"); assert_eq!(client.base_url, "http://localhost:8080");
} }