feat(llm): add support for API key in LLM generation calls
Retrieve the `llm-key` value from configuration and pass it to the LLM provider's `generate` and `generate_stream` methods, so requests authenticate with the configured key rather than the provider's stored one. Also align role naming in compact prompts and session history with the "user"/"assistant" convention, and remove an unused logging import.
parent 9bb8b64be7
commit 6c3812753f

6 changed files with 54 additions and 19 deletions
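In short, every call site now reads the configured key and passes it explicitly to the provider. A minimal sketch of the new call shape, assuming the `AppState`/`ConfigManager` types from this diff (the wrapper function itself is invented for illustration):

```rust
use std::sync::Arc;
use uuid::Uuid;

// Hypothetical wrapper showing the new call shape; mirrors
// execute_llm_generation in the diff below.
async fn ask_llm(
    state: Arc<AppState>,
    prompt: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    let config_manager = crate::config::ConfigManager::new(state.conn.clone());
    let model = config_manager
        .get_config(&Uuid::nil(), "llm-model", None)
        .unwrap_or_default();
    // New in this commit: the API key comes from configuration and is
    // passed per call instead of being read from the provider's api_key field.
    let key = config_manager
        .get_config(&Uuid::nil(), "llm-key", None)
        .unwrap_or_default();
    state
        .llm_provider
        .generate(prompt, &serde_json::Value::Null, &model, &key)
        .await
}
```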
@@ -102,14 +102,16 @@ async fn compact_prompt_for_bots(
     );
 
     let mut conversation = String::new();
-    conversation.push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n");
+    conversation
+        .push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n");
 
     for (role, content) in history.iter().skip(start_index) {
         if role == "compact" {
             continue;
         }
-        conversation.push_str(&format!("{}: {}\n",
-            if role == "user" { "User" } else { "Bot" },
+        conversation.push_str(&format!(
+            "{}: {}\n",
+            if role == "user" { "user" } else { "assistant" },
             content
         ));
     }
@@ -123,11 +125,18 @@ async fn compact_prompt_for_bots(
     let llm_provider = state.llm_provider.clone();
     trace!("Starting summarization for session {}", session.id);
     let mut filtered = String::new();
-    let config_manager = crate::config::ConfigManager::new(state.conn.clone());
-    let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
+    let config_manager = crate::config::ConfigManager::new(state.conn.clone());
+    let model = config_manager
+        .get_config(&Uuid::nil(), "llm-model", None)
+        .unwrap_or_default();
+    let key = config_manager
+        .get_config(&Uuid::nil(), "llm-key", None)
+        .unwrap_or_default();
 
-    let summarized = match llm_provider.generate(
-        "", &serde_json::Value::Array(messages), &model).await {
+    let summarized = match llm_provider
+        .generate("", &serde_json::Value::Array(messages), &model, &key)
+        .await
+    {
         Ok(summary) => {
             trace!(
                 "Successfully summarized session {} ({} chars)",
@@ -138,7 +147,8 @@ async fn compact_prompt_for_bots(
             let handler = llm_models::get_handler(
                 config_manager
                     .get_config(&session.bot_id, "llm-model", None)
-                    .unwrap().as_str(),
+                    .unwrap()
+                    .as_str(),
             );
 
             filtered = handler.process_content(&summary);
@@ -44,8 +44,10 @@ fn build_llm_prompt(user_text: &str) -> String {
 pub async fn execute_llm_generation(state: Arc<AppState>, prompt: String) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
     let config_manager = crate::config::ConfigManager::new(state.conn.clone());
     let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
+    let key = config_manager.get_config(&Uuid::nil(), "llm-key", None).unwrap_or_default();
 
     let handler = crate::llm_models::get_handler(&model);
-    let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model).await?;
+    let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model, &key).await?;
     let processed = handler.process_content(&raw_response);
     Ok(processed)
 }
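Note the two lookup scopes used in this commit: summarization and one-shot generation read `llm-key` under `Uuid::nil()` (the instance-wide default), while the streaming path below reads it under the bot's own id. A hypothetical helper, not part of this commit, that prefers the per-bot key and falls back to the global one (assuming `get_config` returns `Option<String>`, consistent with the `.unwrap_or_default()` calls in the diff):

```rust
use uuid::Uuid;

// Hypothetical fallback helper (illustration only, not in this commit).
fn resolve_llm_key(config_manager: &crate::config::ConfigManager, bot_id: &Uuid) -> String {
    config_manager
        .get_config(bot_id, "llm-key", None)
        .or_else(|| config_manager.get_config(&Uuid::nil(), "llm-key", None))
        .unwrap_or_default()
}
```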
@@ -5,7 +5,7 @@ use anyhow::Result;
 use aws_config::BehaviorVersion;
 use aws_sdk_s3::Client;
 use diesel::connection::SimpleConnection;
-use log::{error, info, trace};
+use log::{error, trace};
 use rand::distr::Alphanumeric;
 use std::io::{self, Write};
 use std::path::Path;
@@ -355,6 +355,11 @@ impl BotOrchestrator {
             .rposition(|(role, _content)| role == "compact")
         {
             history = history.split_off(last_compacted_index);
+            for (role, content) in history.iter_mut() {
+                if role == "compact" {
+                    *role = "user".to_string();
+                }
+            }
         }
         if history_limit > 0 && history.len() > history_limit as usize {
             let start = history.len() - history_limit as usize;
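The new loop above rewrites any retained "compact" marker to a plain "user" turn, so the truncated history replays the summary as ordinary user input. A self-contained illustration of that transformation (the sample history is invented):

```rust
// Standalone illustration of the role rewrite added in this hunk.
fn main() {
    let mut history = vec![
        ("compact".to_string(), "Summary of earlier turns".to_string()),
        ("user".to_string(), "What did we decide?".to_string()),
        ("assistant".to_string(), "We chose option B.".to_string()),
    ];
    // Same rewrite as in BotOrchestrator: the kept summary becomes a user turn.
    for (role, _content) in history.iter_mut() {
        if role == "compact" {
            *role = "user".to_string();
        }
    }
    assert_eq!(history[0].0, "user");
}
```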
@@ -405,9 +410,16 @@ impl BotOrchestrator {
                 None,
             )
             .unwrap_or_default();
+        let key = config_manager
+            .get_config(
+                &Uuid::parse_str(&message.bot_id).unwrap_or_default(),
+                "llm-key",
+                None,
+            )
+            .unwrap_or_default();
         let model1 = model.clone();
         tokio::spawn(async move {
-            if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model).await {
+            if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model, &key).await {
                 error!("LLM streaming error: {}", e);
             }
         });
@@ -1,6 +1,6 @@
 use async_trait::async_trait;
 use futures::StreamExt;
-use log::info;
+use log::{info, trace};
 use serde_json::Value;
 use tokio::sync::mpsc;
 pub mod local;
@@ -11,13 +11,15 @@ pub trait LLMProvider: Send + Sync {
         prompt: &str,
         config: &Value,
         model: &str,
+        key: &str
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
     async fn generate_stream(
         &self,
         prompt: &str,
         config: &Value,
         tx: mpsc::Sender<String>,
-        model: &str
+        model: &str,
+        key: &str
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
     async fn cancel_job(
         &self,
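With the trait updated, streaming callers supply the key alongside the model. A sketch of driving `generate_stream` through an mpsc channel, mirroring the `BotOrchestrator` call site above (the channel capacity and the printing consumer are assumptions):

```rust
use std::sync::Arc;
use tokio::sync::mpsc;

// Illustration of the updated generate_stream signature.
async fn stream_reply(
    llm: Arc<dyn LLMProvider>,
    messages: serde_json::Value,
    model: String,
    key: String,
) {
    let (tx, mut rx) = mpsc::channel::<String>(32);
    tokio::spawn(async move {
        // The key is now threaded through per call, as in the orchestrator.
        if let Err(e) = llm.generate_stream("", &messages, tx, &model, &key).await {
            eprintln!("LLM streaming error: {}", e);
        }
    });
    while let Some(token) = rx.recv().await {
        print!("{}", token);
    }
}
```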
@@ -37,12 +39,14 @@ impl LLMProvider for OpenAIClient {
         prompt: &str,
         messages: &Value,
         model: &str,
+        key: &str
     ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
         let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
-        let response = self
+        let response =
+            self
             .client
             .post(&format!("{}/v1/chat/completions", self.base_url))
-            .header("Authorization", format!("Bearer {}", self.api_key))
+            .header("Authorization", format!("Bearer {}", key))
             .json(&serde_json::json!({
                 "model": model,
                 "messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
@@ -70,13 +74,14 @@ impl LLMProvider for OpenAIClient {
         prompt: &str,
         messages: &Value,
         tx: mpsc::Sender<String>,
-        model: &str
+        model: &str,
+        key: &str
     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
         let response = self
             .client
             .post(&format!("{}/v1/chat/completions", self.base_url))
-            .header("Authorization", format!("Bearer {}", self.api_key))
+            .header("Authorization", format!("Bearer {}", key))
             .json(&serde_json::json!({
                 "model": model.clone(),
                 "messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
@@ -89,6 +94,12 @@ impl LLMProvider for OpenAIClient {
             }))
             .send()
             .await?;
+        let status = response.status();
+        if status != reqwest::StatusCode::OK {
+            let error_text = response.text().await.unwrap_or_default();
+            trace!("LLM generate_stream error: {}", error_text);
+            return Err(format!("LLM request failed with status: {}", status).into());
+        }
         let mut stream = response.bytes_stream();
         let mut buffer = String::new();
         while let Some(chunk) = stream.next().await {
@@ -274,7 +274,7 @@ impl SessionManager {
         for (other_role, content) in messages {
             let role_str = match other_role {
                 1 => "user".to_string(),
-                2 => "bot".to_string(),
+                2 => "assistant".to_string(),
                 3 => "system".to_string(),
                 9 => "compact".to_string(),
                 _ => "unknown".to_string(),