feat(llm): add support for API key in LLM generation calls

Include retrieval and passing of `llm-key` from configuration to LLM provider methods for secure authentication. Also refine role naming in compact prompts and remove unused logging import.
This commit is contained in:
Rodrigo Rodriguez (Pragmatismo) 2025-11-12 09:25:10 -03:00
parent 9bb8b64be7
commit 6c3812753f
6 changed files with 54 additions and 19 deletions

View file

@ -102,14 +102,16 @@ async fn compact_prompt_for_bots(
);
let mut conversation = String::new();
conversation.push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n");
conversation
.push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n");
for (role, content) in history.iter().skip(start_index) {
if role == "compact" {
continue;
}
conversation.push_str(&format!("{}: {}\n",
if role == "user" { "User" } else { "Bot" },
conversation.push_str(&format!(
"{}: {}\n",
if role == "user" { "user" } else { "assistant" },
content
));
}
@ -123,11 +125,18 @@ async fn compact_prompt_for_bots(
let llm_provider = state.llm_provider.clone();
trace!("Starting summarization for session {}", session.id);
let mut filtered = String::new();
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
let model = config_manager
.get_config(&Uuid::nil(), "llm-model", None)
.unwrap_or_default();
let key = config_manager
.get_config(&Uuid::nil(), "llm-key", None)
.unwrap_or_default();
let summarized = match llm_provider.generate(
"", &serde_json::Value::Array(messages), &model).await {
let summarized = match llm_provider
.generate("", &serde_json::Value::Array(messages), &model, &key)
.await
{
Ok(summary) => {
trace!(
"Successfully summarized session {} ({} chars)",
@ -138,7 +147,8 @@ async fn compact_prompt_for_bots(
let handler = llm_models::get_handler(
config_manager
.get_config(&session.bot_id, "llm-model", None)
.unwrap().as_str(),
.unwrap()
.as_str(),
);
filtered = handler.process_content(&summary);

View file

@ -44,8 +44,10 @@ fn build_llm_prompt(user_text: &str) -> String {
pub async fn execute_llm_generation(state: Arc<AppState>, prompt: String) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
let key = config_manager.get_config(&Uuid::nil(), "llm-key", None).unwrap_or_default();
let handler = crate::llm_models::get_handler(&model);
let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model).await?;
let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model, &key).await?;
let processed = handler.process_content(&raw_response);
Ok(processed)
}

View file

@ -5,7 +5,7 @@ use anyhow::Result;
use aws_config::BehaviorVersion;
use aws_sdk_s3::Client;
use diesel::connection::SimpleConnection;
use log::{error, info, trace};
use log::{error, trace};
use rand::distr::Alphanumeric;
use std::io::{self, Write};
use std::path::Path;

View file

@ -355,6 +355,11 @@ impl BotOrchestrator {
.rposition(|(role, _content)| role == "compact")
{
history = history.split_off(last_compacted_index);
for (role, content) in history.iter_mut() {
if role == "compact" {
*role = "user".to_string();
}
}
}
if history_limit > 0 && history.len() > history_limit as usize {
let start = history.len() - history_limit as usize;
@ -405,9 +410,16 @@ impl BotOrchestrator {
None,
)
.unwrap_or_default();
let key = config_manager
.get_config(
&Uuid::parse_str(&message.bot_id).unwrap_or_default(),
"llm-key",
None,
)
.unwrap_or_default();
let model1 = model.clone();
tokio::spawn(async move {
if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model).await {
if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model, &key).await {
error!("LLM streaming error: {}", e);
}
});

View file

@ -1,6 +1,6 @@
use async_trait::async_trait;
use futures::StreamExt;
use log::info;
use log::{info, trace};
use serde_json::Value;
use tokio::sync::mpsc;
pub mod local;
@ -11,13 +11,15 @@ pub trait LLMProvider: Send + Sync {
prompt: &str,
config: &Value,
model: &str,
key: &str
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
async fn generate_stream(
&self,
prompt: &str,
config: &Value,
tx: mpsc::Sender<String>,
model: &str
model: &str,
key: &str
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
async fn cancel_job(
&self,
@ -37,12 +39,14 @@ impl LLMProvider for OpenAIClient {
prompt: &str,
messages: &Value,
model: &str,
key: &str
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
let response = self
let response =
self
.client
.post(&format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Authorization", format!("Bearer {}", key))
.json(&serde_json::json!({
"model": model,
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
@ -70,13 +74,14 @@ impl LLMProvider for OpenAIClient {
prompt: &str,
messages: &Value,
tx: mpsc::Sender<String>,
model: &str
model: &str,
key: &str
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
let response = self
.client
.post(&format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Authorization", format!("Bearer {}", key))
.json(&serde_json::json!({
"model": model.clone(),
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
@ -89,6 +94,12 @@ impl LLMProvider for OpenAIClient {
}))
.send()
.await?;
let status = response.status();
if status != reqwest::StatusCode::OK {
let error_text = response.text().await.unwrap_or_default();
trace!("LLM generate_stream error: {}", error_text);
return Err(format!("LLM request failed with status: {}", status).into());
}
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {

View file

@ -274,7 +274,7 @@ impl SessionManager {
for (other_role, content) in messages {
let role_str = match other_role {
1 => "user".to_string(),
2 => "bot".to_string(),
2 => "assistant".to_string(),
3 => "system".to_string(),
9 => "compact".to_string(),
_ => "unknown".to_string(),