fix(llm-config): Fix ConfigManager fallback logic for LLM configuration
Some checks failed
GBCI / build (push) Failing after 12m26s
Some checks failed
GBCI / build (push) Failing after 12m26s
- Fix ConfigManager to treat 'none', 'null', 'n/a', and empty values as placeholders and fall back to the default bot's configuration instead of using them as literal values.
- Fix ConfigManager to detect local file paths (e.g., .gguf, .bin, ../) and fall back to the default bot's model when using a remote API. This lets bots keep a local model config for the local LLM server while automatically using the remote model for API calls.
- Fix get_default_bot() to return the bot actually named 'default' instead of the first active bot by ID, ensuring consistent fallback behavior.
- Add comprehensive debug logging to trace LLM configuration from the database to the API call.

This fixes the issue where bots with incomplete or local LLM configuration would fail with 401/400 errors when trying to use the remote API, instead of automatically falling back to the default bot's configuration from config.csv.

Closes: #llm-config-fallback
This commit is contained in:
parent
39c4dba838
commit
5fb4c889b7
3 changed files with 549 additions and 35 deletions
|
|
@ -1,5 +1,7 @@
|
|||
#[cfg(any(feature = "research", feature = "llm"))]
|
||||
pub mod kb_context;
|
||||
#[cfg(any(feature = "research", feature = "llm"))]
|
||||
use kb_context::inject_kb_context;
|
||||
#[cfg(feature = "llm")]
|
||||
use crate::core::config::ConfigManager;
|
||||
|
||||
|
|
@ -20,7 +22,10 @@ use axum::{
|
|||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use diesel::ExpressionMethods;
|
||||
use diesel::PgConnection;
|
||||
use diesel::QueryDsl;
|
||||
use diesel::RunQueryDsl;
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
#[cfg(feature = "llm")]
|
||||
use log::trace;
|
||||
|
|
@ -39,7 +44,9 @@ pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) {
|
|||
use crate::shared::models::schema::bots::dsl::*;
|
||||
use diesel::prelude::*;
|
||||
|
||||
// First try to get the bot named "default"
|
||||
match bots
|
||||
.filter(name.eq("default"))
|
||||
.filter(is_active.eq(true))
|
||||
.select((id, name))
|
||||
.first::<(Uuid, String)>(conn)
|
||||
|
|
@ -47,8 +54,24 @@ pub fn get_default_bot(conn: &mut PgConnection) -> (Uuid, String) {
|
|||
{
|
||||
Ok(Some((bot_id, bot_name))) => (bot_id, bot_name),
|
||||
Ok(None) => {
|
||||
warn!("No active bots found, using nil UUID");
|
||||
(Uuid::nil(), "default".to_string())
|
||||
warn!("Bot named 'default' not found, falling back to first active bot");
|
||||
// Fall back to first active bot
|
||||
match bots
|
||||
.filter(is_active.eq(true))
|
||||
.select((id, name))
|
||||
.first::<(Uuid, String)>(conn)
|
||||
.optional()
|
||||
{
|
||||
Ok(Some((bot_id, bot_name))) => (bot_id, bot_name),
|
||||
Ok(None) => {
|
||||
warn!("No active bots found, using nil UUID");
|
||||
(Uuid::nil(), "default".to_string())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to query fallback bot: {}", e);
|
||||
(Uuid::nil(), "default".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to query default bot: {}", e);
|
||||
|
|
@ -72,10 +95,116 @@ impl BotOrchestrator {
|
|||
}
|
||||
|
||||
pub fn mount_all_bots(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
info!("mount_all_bots called");
|
||||
info!("Scanning drive for .gbai files to mount bots...");
|
||||
|
||||
let mut bots_mounted = 0;
|
||||
let mut bots_created = 0;
|
||||
|
||||
let directories_to_scan: Vec<std::path::PathBuf> = vec![
|
||||
self.state
|
||||
.config
|
||||
.as_ref()
|
||||
.map(|c| c.site_path.clone())
|
||||
.unwrap_or_else(|| "./botserver-stack/sites".to_string())
|
||||
.into(),
|
||||
"./templates".into(),
|
||||
"../bottemplates".into(),
|
||||
];
|
||||
|
||||
for dir_path in directories_to_scan {
|
||||
info!("Checking directory for bots: {}", dir_path.display());
|
||||
|
||||
if !dir_path.exists() {
|
||||
info!("Directory does not exist, skipping: {}", dir_path.display());
|
||||
continue;
|
||||
}
|
||||
|
||||
match self.scan_directory(&dir_path, &mut bots_mounted, &mut bots_created) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!("Failed to scan directory {}: {}", dir_path.display(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"Bot mounting complete: {} bots processed ({} created, {} already existed)",
|
||||
bots_mounted,
|
||||
bots_created,
|
||||
bots_mounted - bots_created
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn scan_directory(
|
||||
&self,
|
||||
dir_path: &std::path::Path,
|
||||
bots_mounted: &mut i32,
|
||||
_bots_created: &mut i32,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let entries =
|
||||
std::fs::read_dir(dir_path).map_err(|e| format!("Failed to read directory: {}", e))?;
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let name = entry.file_name();
|
||||
|
||||
let bot_name = match name.to_str() {
|
||||
Some(n) if n.ends_with(".gbai") => n.trim_end_matches(".gbai"),
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
info!("Found .gbai file: {}", bot_name);
|
||||
|
||||
match self.ensure_bot_exists(bot_name) {
|
||||
Ok(true) => {
|
||||
info!("Bot '{}' already exists in database, mounting", bot_name);
|
||||
*bots_mounted += 1;
|
||||
}
|
||||
Ok(false) => {
|
||||
info!(
|
||||
"Bot '{}' does not exist in database, skipping (run import to create)",
|
||||
bot_name
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to check if bot '{}' exists: {}", bot_name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_bot_exists(
|
||||
&self,
|
||||
bot_name: &str,
|
||||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use diesel::sql_query;
|
||||
|
||||
let mut conn = self
|
||||
.state
|
||||
.conn
|
||||
.get()
|
||||
.map_err(|e| format!("Failed to get database connection: {e}"))?;
|
||||
|
||||
#[derive(diesel::QueryableByName)]
|
||||
#[diesel(check_for_backend(diesel::pg::Pg))]
|
||||
struct BotExistsResult {
|
||||
#[diesel(sql_type = diesel::sql_types::Bool)]
|
||||
exists: bool,
|
||||
}
|
||||
|
||||
let exists: BotExistsResult = sql_query(
|
||||
"SELECT EXISTS(SELECT 1 FROM bots WHERE name = $1 AND is_active = true) as exists",
|
||||
)
|
||||
.bind::<diesel::sql_types::Text, _>(bot_name)
|
||||
.get_result(&mut conn)
|
||||
.map_err(|e| format!("Failed to check if bot exists: {e}"))?;
|
||||
|
||||
Ok(exists.exists)
|
||||
}
|
||||
|
||||
#[cfg(feature = "llm")]
|
||||
pub async fn stream_response(
|
||||
&self,
|
||||
|
|
@ -90,6 +219,7 @@ impl BotOrchestrator {
|
|||
|
||||
let user_id = Uuid::parse_str(&message.user_id)?;
|
||||
let session_id = Uuid::parse_str(&message.session_id)?;
|
||||
let message_content = message.content.clone();
|
||||
|
||||
let (session, context_data, history, model, key) = {
|
||||
let state_clone = self.state.clone();
|
||||
|
|
@ -117,13 +247,24 @@ impl BotOrchestrator {
|
|||
};
|
||||
|
||||
let config_manager = ConfigManager::new(state_clone.conn.clone());
|
||||
|
||||
// DEBUG: Log which bot we're getting config for
|
||||
info!("[CONFIG_TRACE] Getting LLM config for bot_id: {}", session.bot_id);
|
||||
|
||||
let model = config_manager
|
||||
.get_config(&session.bot_id, "llm-model", Some("gpt-3.5-turbo"))
|
||||
.unwrap_or_else(|_| "gpt-3.5-turbo".to_string());
|
||||
|
||||
let key = config_manager
|
||||
.get_config(&session.bot_id, "llm-key", Some(""))
|
||||
.unwrap_or_default();
|
||||
|
||||
// DEBUG: Log the exact config values retrieved
|
||||
info!("[CONFIG_TRACE] Model: '{}'", model);
|
||||
info!("[CONFIG_TRACE] API Key: '{}' ({} chars)", key, key.len());
|
||||
info!("[CONFIG_TRACE] API Key first 10 chars: '{}'", &key.chars().take(10).collect::<String>());
|
||||
info!("[CONFIG_TRACE] API Key last 10 chars: '{}'", &key.chars().rev().take(10).collect::<String>());
|
||||
|
||||
Ok((session, context_data, history, model, key))
|
||||
},
|
||||
)
|
||||
|
|
@ -131,7 +272,39 @@ impl BotOrchestrator {
|
|||
};
|
||||
|
||||
let system_prompt = "You are a helpful assistant.".to_string();
|
||||
let messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history);
|
||||
let mut messages = OpenAIClient::build_messages(&system_prompt, &context_data, &history);
|
||||
|
||||
#[cfg(any(feature = "research", feature = "llm"))]
|
||||
{
|
||||
if let Some(kb_manager) = self.state.kb_manager.as_ref() {
|
||||
let bot_name_for_kb = {
|
||||
let conn = self.state.conn.get().ok();
|
||||
if let Some(mut db_conn) = conn {
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
bots.filter(id.eq(session.bot_id))
|
||||
.select(name)
|
||||
.first::<String>(&mut db_conn)
|
||||
.unwrap_or_else(|_| "default".to_string())
|
||||
} else {
|
||||
"default".to_string()
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = inject_kb_context(
|
||||
kb_manager.clone(),
|
||||
self.state.conn.clone(),
|
||||
session_id,
|
||||
&bot_name_for_kb,
|
||||
&message_content,
|
||||
&mut messages,
|
||||
8000,
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Failed to inject KB context: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (stream_tx, mut stream_rx) = mpsc::channel::<String>(100);
|
||||
let llm = self.state.llm_provider.clone();
|
||||
|
|
@ -139,6 +312,16 @@ impl BotOrchestrator {
|
|||
let model_clone = model.clone();
|
||||
let key_clone = key.clone();
|
||||
let messages_clone = messages.clone();
|
||||
|
||||
// DEBUG: Log exact values being passed to LLM
|
||||
info!("[LLM_CALL] Calling generate_stream with:");
|
||||
info!("[LLM_CALL] Model: '{}'", model_clone);
|
||||
info!("[LLM_CALL] Key length: {} chars", key_clone.len());
|
||||
info!("[LLM_CALL] Key preview: '{}...{}'",
|
||||
&key_clone.chars().take(8).collect::<String>(),
|
||||
&key_clone.chars().rev().take(8).collect::<String>()
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = llm
|
||||
.generate_stream("", &messages_clone, stream_tx, &model_clone, &key_clone)
|
||||
|
|
@ -388,10 +571,18 @@ pub async fn websocket_handler(
|
|||
let conn = state.conn.get().ok();
|
||||
if let Some(mut db_conn) = conn {
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
let result: Result<Uuid, _> = bots
|
||||
.filter(name.eq(&bot_name))
|
||||
.select(id)
|
||||
.first(&mut db_conn);
|
||||
|
||||
// Try to parse as UUID first, if that fails treat as bot name
|
||||
let result: Result<Uuid, _> = if let Ok(uuid) = Uuid::parse_str(&bot_name) {
|
||||
// Parameter is a UUID, look up by id
|
||||
bots.filter(id.eq(uuid)).select(id).first(&mut db_conn)
|
||||
} else {
|
||||
// Parameter is a bot name, look up by name
|
||||
bots.filter(name.eq(&bot_name))
|
||||
.select(id)
|
||||
.first(&mut db_conn)
|
||||
};
|
||||
|
||||
result.unwrap_or_else(|_| {
|
||||
log::warn!("Bot not found: {}, using nil bot_id", bot_name);
|
||||
Uuid::nil()
|
||||
|
|
@ -427,8 +618,8 @@ async fn handle_websocket(
|
|||
}
|
||||
|
||||
info!(
|
||||
"WebSocket connected for session: {}, user: {}",
|
||||
session_id, user_id
|
||||
"WebSocket connected for session: {}, user: {}, bot: {}",
|
||||
session_id, user_id, bot_id
|
||||
);
|
||||
|
||||
let welcome = serde_json::json!({
|
||||
|
|
@ -445,6 +636,89 @@ async fn handle_websocket(
|
|||
}
|
||||
}
|
||||
|
||||
// Execute start.bas automatically on connection (similar to auth.ast pattern)
|
||||
{
|
||||
let bot_name_result = {
|
||||
let conn = state.conn.get().ok();
|
||||
if let Some(mut db_conn) = conn {
|
||||
use crate::shared::models::schema::bots::dsl::*;
|
||||
bots.filter(id.eq(bot_id))
|
||||
.select(name)
|
||||
.first::<String>(&mut db_conn)
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// DEBUG: Log start script execution attempt
|
||||
info!(
|
||||
"Checking for start.bas: bot_id={}, bot_name_result={:?}",
|
||||
bot_id,
|
||||
bot_name_result
|
||||
);
|
||||
|
||||
if let Some(bot_name) = bot_name_result {
|
||||
let start_script_path = format!("./work/{}.gbai/{}.gbdialog/start.bas", bot_name, bot_name);
|
||||
|
||||
info!("Looking for start.bas at: {}", start_script_path);
|
||||
|
||||
if let Ok(metadata) = tokio::fs::metadata(&start_script_path).await {
|
||||
if metadata.is_file() {
|
||||
info!("Found start.bas file, reading contents...");
|
||||
if let Ok(start_script) = tokio::fs::read_to_string(&start_script_path).await {
|
||||
info!(
|
||||
"Executing start.bas for bot {} on session {}",
|
||||
bot_name, session_id
|
||||
);
|
||||
|
||||
let state_for_start = state.clone();
|
||||
let _tx_for_start = tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let session_result = {
|
||||
let mut sm = state_for_start.session_manager.lock().await;
|
||||
sm.get_session_by_id(session_id)
|
||||
};
|
||||
|
||||
if let Ok(Some(session)) = session_result {
|
||||
info!("Executing start.bas for bot {} on session {}", bot_name, session_id);
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let mut script_service = crate::basic::ScriptService::new(
|
||||
state_for_start.clone(),
|
||||
session.clone()
|
||||
);
|
||||
script_service.load_bot_config_params(&state_for_start, bot_id);
|
||||
|
||||
match script_service.compile(&start_script) {
|
||||
Ok(ast) => match script_service.run(&ast) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => Err(format!("Script execution error: {}", e)),
|
||||
},
|
||||
Err(e) => Err(format!("Script compilation error: {}", e)),
|
||||
}
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {
|
||||
info!("start.bas executed successfully for bot {}", bot_name);
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
error!("start.bas error for bot {}: {}", bot_name, e);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("start.bas task error for bot {}: {}", bot_name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
while let Some(response) = rx.recv().await {
|
||||
if let Ok(json_str) = serde_json::to_string(&response) {
|
||||
|
|
|
|||
|
|
@ -362,14 +362,55 @@ impl ConfigManager {
|
|||
use crate::shared::models::schema::bot_configuration::dsl::*;
|
||||
let mut conn = self.get_conn()?;
|
||||
let fallback_str = fallback.unwrap_or("");
|
||||
|
||||
/// Returns `true` when `value` should be treated as "not configured".
///
/// After trimming surrounding whitespace, a value is a placeholder when it is
/// empty or equals `none`, `null`, or `n/a`, ignoring ASCII case. The
/// comparison is done in place, without allocating a lowercased copy.
fn is_placeholder_value(value: &str) -> bool {
    const PLACEHOLDERS: [&str; 3] = ["none", "null", "n/a"];

    let trimmed = value.trim();
    trimmed.is_empty()
        || PLACEHOLDERS
            .iter()
            .any(|placeholder| trimmed.eq_ignore_ascii_case(placeholder))
}
|
||||
|
||||
/// Returns `true` when `value` looks like a path to a local model file.
///
/// Such values are meant for a local LLM server; the caller uses this check
/// to fall back to the default bot's configuration instead.
///
/// After trimming surrounding whitespace, a value matches when it either
/// starts with one of the path prefixes (`../`, `./`, `/`, `~`, `data/`,
/// `models/`) or contains one of the model-file markers (`.gguf`, `.bin`,
/// `.safetensors`) anywhere in the string. Matching is case-sensitive.
fn is_local_file_path(value: &str) -> bool {
    // "../../" is intentionally absent: anything starting with it already
    // starts with "../", so a separate check could never change the result.
    const PATH_PREFIXES: [&str; 6] = ["../", "./", "/", "~", "data/", "models/"];
    const MODEL_FILE_MARKERS: [&str; 3] = [".gguf", ".bin", ".safetensors"];

    let value = value.trim();

    PATH_PREFIXES
        .iter()
        .any(|prefix| value.starts_with(*prefix))
        || MODEL_FILE_MARKERS
            .iter()
            .any(|marker| value.contains(*marker))
}
|
||||
|
||||
// Try to get value for the specific bot
|
||||
let result = bot_configuration
|
||||
.filter(bot_id.eq(code_bot_id))
|
||||
.filter(config_key.eq(key))
|
||||
.select(config_value)
|
||||
.first::<String>(&mut conn);
|
||||
|
||||
let value = match result {
|
||||
Ok(v) => v,
|
||||
Ok(v) => {
|
||||
// Check if it's a placeholder value or local file path - if so, fall back to default bot
|
||||
// Local file paths are valid for local LLM server but NOT for remote APIs
|
||||
if is_placeholder_value(&v) || is_local_file_path(&v) {
|
||||
let (default_bot_id, _default_bot_name) = crate::bot::get_default_bot(&mut conn);
|
||||
bot_configuration
|
||||
.filter(bot_id.eq(default_bot_id))
|
||||
.filter(config_key.eq(key))
|
||||
.select(config_value)
|
||||
.first::<String>(&mut conn)
|
||||
.unwrap_or_else(|_| fallback_str.to_string())
|
||||
} else {
|
||||
v
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Value not found, fall back to default bot
|
||||
let (default_bot_id, _default_bot_name) = crate::bot::get_default_bot(&mut conn);
|
||||
bot_configuration
|
||||
.filter(bot_id.eq(default_bot_id))
|
||||
|
|
@ -379,7 +420,15 @@ impl ConfigManager {
|
|||
.unwrap_or_else(|_| fallback_str.to_string())
|
||||
}
|
||||
};
|
||||
Ok(value)
|
||||
|
||||
// Final check: if the result is still a placeholder value, use the fallback_str
|
||||
let final_value = if is_placeholder_value(&value) {
|
||||
fallback_str.to_string()
|
||||
} else {
|
||||
value
|
||||
};
|
||||
|
||||
Ok(final_value)
|
||||
}
|
||||
|
||||
pub fn get_bot_config_value(
|
||||
|
|
|
|||
237
src/llm/mod.rs
237
src/llm/mod.rs
|
|
@ -1,6 +1,6 @@
|
|||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use log::{info, trace};
|
||||
use log::{error, info};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
|
|
@ -44,13 +44,116 @@ pub trait LLMProvider: Send + Sync {
|
|||
pub struct OpenAIClient {
|
||||
client: reqwest::Client,
|
||||
base_url: String,
|
||||
endpoint_path: String,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
pub fn new(_api_key: String, base_url: Option<String>) -> Self {
|
||||
/// Gives a rough token estimate for `text`: one token per four bytes,
/// rounded up.
///
/// The length used is the UTF-8 byte length (`str::len`), not the number of
/// characters, so non-ASCII text yields a proportionally larger estimate.
/// This is a heuristic only.
fn estimate_tokens(text: &str) -> usize {
    let byte_len = text.len();
    // Ceiling division by 4 without relying on addition before the divide.
    byte_len / 4 + usize::from(byte_len % 4 != 0)
}
|
||||
|
||||
/// Estimates total tokens for a messages array
|
||||
fn estimate_messages_tokens(messages: &Value) -> usize {
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
msg_array
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
||||
Self::estimate_tokens(content)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates messages to fit within the max_tokens limit
|
||||
/// Keeps system messages and the most recent user/assistant messages
|
||||
fn truncate_messages(messages: &Value, max_tokens: usize) -> Value {
|
||||
let mut result = Vec::new();
|
||||
let mut token_count = 0;
|
||||
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
// First pass: keep all system messages
|
||||
for msg in msg_array {
|
||||
if let Some(role) = msg.get("role").and_then(|r| r.as_str()) {
|
||||
if role == "system" {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
||||
let msg_tokens = Self::estimate_tokens(content);
|
||||
if token_count + msg_tokens <= max_tokens {
|
||||
result.push(msg.clone());
|
||||
token_count += msg_tokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: add user/assistant messages from newest to oldest
|
||||
let mut recent_messages: Vec<&Value> = msg_array
|
||||
.iter()
|
||||
.filter(|msg| msg.get("role").and_then(|r| r.as_str()) != Some("system"))
|
||||
.collect();
|
||||
|
||||
// Reverse to get newest first
|
||||
recent_messages.reverse();
|
||||
|
||||
for msg in recent_messages {
|
||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
||||
let msg_tokens = Self::estimate_tokens(content);
|
||||
if token_count + msg_tokens <= max_tokens {
|
||||
result.push(msg.clone());
|
||||
token_count += msg_tokens;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse back to chronological order for non-system messages
|
||||
// But keep system messages at the beginning
|
||||
let system_count = result.len()
|
||||
- result
|
||||
.iter()
|
||||
.filter(|m| m.get("role").and_then(|r| r.as_str()) != Some("system"))
|
||||
.count();
|
||||
let mut user_messages: Vec<Value> = result.drain(system_count..).collect();
|
||||
user_messages.reverse();
|
||||
result.extend(user_messages);
|
||||
}
|
||||
|
||||
serde_json::Value::Array(result)
|
||||
}
|
||||
|
||||
/// Ensures messages fit within model's context limit
|
||||
fn ensure_token_limit(messages: &Value, model_context_limit: usize) -> Value {
|
||||
let estimated_tokens = Self::estimate_messages_tokens(messages);
|
||||
|
||||
// Use 90% of context limit to leave room for response
|
||||
let safe_limit = (model_context_limit as f64 * 0.9) as usize;
|
||||
|
||||
if estimated_tokens > safe_limit {
|
||||
log::warn!(
|
||||
"Messages exceed token limit ({} > {}), truncating...",
|
||||
estimated_tokens,
|
||||
safe_limit
|
||||
);
|
||||
Self::truncate_messages(messages, safe_limit)
|
||||
} else {
|
||||
messages.clone()
|
||||
}
|
||||
}
|
||||
pub fn new(_api_key: String, base_url: Option<String>, endpoint_path: Option<String>) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
base_url: base_url.unwrap_or_else(|| "https://api.openai.com".to_string()),
|
||||
endpoint_path: endpoint_path.unwrap_or_else(|| "/v1/chat/completions".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -92,21 +195,64 @@ impl LLMProvider for OpenAIClient {
|
|||
key: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
|
||||
// Get the messages to use
|
||||
let raw_messages =
|
||||
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
// Ensure messages fit within model's context limit
|
||||
// GLM-4.7 has 202750 tokens, other models vary
|
||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
||||
202750
|
||||
} else if model.contains("gpt-4") {
|
||||
128000
|
||||
} else if model.contains("gpt-3.5") {
|
||||
16385
|
||||
} else if model.starts_with("http://localhost:808") || model == "local" {
|
||||
768 // Local llama.cpp server context limit
|
||||
} else {
|
||||
4096 // Default conservative limit
|
||||
};
|
||||
|
||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||
|
||||
let full_url = format!("{}{}", self.base_url, self.endpoint_path);
|
||||
let auth_header = format!("Bearer {}", key);
|
||||
|
||||
// Debug logging to help troubleshoot 401 errors
|
||||
info!("LLM Request Details:");
|
||||
info!(" URL: {}", full_url);
|
||||
info!(" Authorization: Bearer <{} chars>", key.len());
|
||||
info!(" Model: {}", model);
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
info!(" Messages: {} messages", msg_array.len());
|
||||
}
|
||||
info!(" API Key First 8 chars: '{}...'", &key.chars().take(8).collect::<String>());
|
||||
info!(" API Key Last 8 chars: '...{}'", &key.chars().rev().take(8).collect::<String>());
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/v1/chat/completions", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", key))
|
||||
.post(&full_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"messages": if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
}
|
||||
"messages": messages,
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
error!("LLM generate error: {}", error_text);
|
||||
return Err(format!("LLM request failed with status: {}", status).into());
|
||||
}
|
||||
|
||||
let result: Value = response.json().await?;
|
||||
let raw_content = result["choices"][0]["message"]["content"]
|
||||
.as_str()
|
||||
|
|
@ -127,18 +273,51 @@ impl LLMProvider for OpenAIClient {
|
|||
key: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||
|
||||
// Get the messages to use
|
||||
let raw_messages =
|
||||
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
info!("Using provided messages: {:?}", messages);
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
};
|
||||
|
||||
// Ensure messages fit within model's context limit
|
||||
// GLM-4.7 has 202750 tokens, other models vary
|
||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
||||
202750
|
||||
} else if model.contains("gpt-4") {
|
||||
128000
|
||||
} else if model.contains("gpt-3.5") {
|
||||
16385
|
||||
} else if model.starts_with("http://localhost:808") || model == "local" {
|
||||
768 // Local llama.cpp server context limit
|
||||
} else {
|
||||
4096 // Default conservative limit
|
||||
};
|
||||
|
||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
||||
|
||||
let full_url = format!("{}{}", self.base_url, self.endpoint_path);
|
||||
let auth_header = format!("Bearer {}", key);
|
||||
|
||||
// Debug logging to help troubleshoot 401 errors
|
||||
info!("LLM Request Details:");
|
||||
info!(" URL: {}", full_url);
|
||||
info!(" Authorization: Bearer <{} chars>", key.len());
|
||||
info!(" Model: {}", model);
|
||||
if let Some(msg_array) = messages.as_array() {
|
||||
info!(" Messages: {} messages", msg_array.len());
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/v1/chat/completions", self.base_url))
|
||||
.header("Authorization", format!("Bearer {}", key))
|
||||
.post(&full_url)
|
||||
.header("Authorization", &auth_header)
|
||||
.json(&serde_json::json!({
|
||||
"model": model,
|
||||
"messages": if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
||||
info!("Using provided messages: {:?}", messages);
|
||||
messages
|
||||
} else {
|
||||
&default_messages
|
||||
},
|
||||
"messages": messages,
|
||||
"stream": true
|
||||
}))
|
||||
.send()
|
||||
|
|
@ -147,7 +326,7 @@ impl LLMProvider for OpenAIClient {
|
|||
let status = response.status();
|
||||
if status != reqwest::StatusCode::OK {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
trace!("LLM generate_stream error: {}", error_text);
|
||||
error!("LLM generate_stream error: {}", error_text);
|
||||
return Err(format!("LLM request failed with status: {}", status).into());
|
||||
}
|
||||
|
||||
|
|
@ -213,11 +392,16 @@ pub fn create_llm_provider(
|
|||
provider_type: LLMProviderType,
|
||||
base_url: String,
|
||||
deployment_name: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
match provider_type {
|
||||
LLMProviderType::OpenAI => {
|
||||
info!("Creating OpenAI LLM provider with URL: {}", base_url);
|
||||
std::sync::Arc::new(OpenAIClient::new("empty".to_string(), Some(base_url)))
|
||||
std::sync::Arc::new(OpenAIClient::new(
|
||||
"empty".to_string(),
|
||||
Some(base_url),
|
||||
endpoint_path,
|
||||
))
|
||||
}
|
||||
LLMProviderType::Claude => {
|
||||
info!("Creating Claude LLM provider with URL: {}", base_url);
|
||||
|
|
@ -237,9 +421,10 @@ pub fn create_llm_provider(
|
|||
pub fn create_llm_provider_from_url(
|
||||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
) -> std::sync::Arc<dyn LLMProvider> {
|
||||
let provider_type = LLMProviderType::from(url);
|
||||
create_llm_provider(provider_type, url.to_string(), model)
|
||||
create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
|
||||
}
|
||||
|
||||
pub struct DynamicLLMProvider {
|
||||
|
|
@ -259,8 +444,13 @@ impl DynamicLLMProvider {
|
|||
info!("LLM provider updated dynamically");
|
||||
}
|
||||
|
||||
pub async fn update_from_config(&self, url: &str, model: Option<String>) {
|
||||
let new_provider = create_llm_provider_from_url(url, model);
|
||||
pub async fn update_from_config(
|
||||
&self,
|
||||
url: &str,
|
||||
model: Option<String>,
|
||||
endpoint_path: Option<String>,
|
||||
) {
|
||||
let new_provider = create_llm_provider_from_url(url, model, endpoint_path);
|
||||
self.update_provider(new_provider).await;
|
||||
}
|
||||
|
||||
|
|
@ -490,7 +680,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_openai_client_new_default_url() {
|
||||
let client = OpenAIClient::new("test_key".to_string(), None);
|
||||
let client = OpenAIClient::new("test_key".to_string(), None, None);
|
||||
assert_eq!(client.base_url, "https://api.openai.com");
|
||||
}
|
||||
|
||||
|
|
@ -499,6 +689,7 @@ mod tests {
|
|||
let client = OpenAIClient::new(
|
||||
"test_key".to_string(),
|
||||
Some("http://localhost:8080".to_string()),
|
||||
None,
|
||||
);
|
||||
assert_eq!(client.base_url, "http://localhost:8080");
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue