feat(llm): add support for API key in LLM generation calls
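Retrieve `llm-key` from configuration and pass it to the LLM provider's `generate` and `generate_stream` methods, which now send it as the Bearer token for secure authentication. Also use the `user`/`assistant` role names in compact prompts and session history, and remove the unused `info` logging import.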
This commit is contained in:
parent
9bb8b64be7
commit
6c3812753f
6 changed files with 54 additions and 19 deletions
|
|
@ -102,14 +102,16 @@ async fn compact_prompt_for_bots(
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut conversation = String::new();
|
let mut conversation = String::new();
|
||||||
conversation.push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n");
|
conversation
|
||||||
|
.push_str("Please summarize this conversation between user and bot: \n\n [[[***** \n");
|
||||||
|
|
||||||
for (role, content) in history.iter().skip(start_index) {
|
for (role, content) in history.iter().skip(start_index) {
|
||||||
if role == "compact" {
|
if role == "compact" {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
conversation.push_str(&format!("{}: {}\n",
|
conversation.push_str(&format!(
|
||||||
if role == "user" { "User" } else { "Bot" },
|
"{}: {}\n",
|
||||||
|
if role == "user" { "user" } else { "assistant" },
|
||||||
content
|
content
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
@ -124,10 +126,17 @@ async fn compact_prompt_for_bots(
|
||||||
trace!("Starting summarization for session {}", session.id);
|
trace!("Starting summarization for session {}", session.id);
|
||||||
let mut filtered = String::new();
|
let mut filtered = String::new();
|
||||||
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
|
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
|
||||||
let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
|
let model = config_manager
|
||||||
|
.get_config(&Uuid::nil(), "llm-model", None)
|
||||||
|
.unwrap_or_default();
|
||||||
|
let key = config_manager
|
||||||
|
.get_config(&Uuid::nil(), "llm-key", None)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
let summarized = match llm_provider.generate(
|
let summarized = match llm_provider
|
||||||
"", &serde_json::Value::Array(messages), &model).await {
|
.generate("", &serde_json::Value::Array(messages), &model, &key)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(summary) => {
|
Ok(summary) => {
|
||||||
trace!(
|
trace!(
|
||||||
"Successfully summarized session {} ({} chars)",
|
"Successfully summarized session {} ({} chars)",
|
||||||
|
|
@ -138,7 +147,8 @@ async fn compact_prompt_for_bots(
|
||||||
let handler = llm_models::get_handler(
|
let handler = llm_models::get_handler(
|
||||||
config_manager
|
config_manager
|
||||||
.get_config(&session.bot_id, "llm-model", None)
|
.get_config(&session.bot_id, "llm-model", None)
|
||||||
.unwrap().as_str(),
|
.unwrap()
|
||||||
|
.as_str(),
|
||||||
);
|
);
|
||||||
|
|
||||||
filtered = handler.process_content(&summary);
|
filtered = handler.process_content(&summary);
|
||||||
|
|
|
||||||
|
|
@ -44,8 +44,10 @@ fn build_llm_prompt(user_text: &str) -> String {
|
||||||
pub async fn execute_llm_generation(state: Arc<AppState>, prompt: String) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
pub async fn execute_llm_generation(state: Arc<AppState>, prompt: String) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
|
let config_manager = crate::config::ConfigManager::new(state.conn.clone());
|
||||||
let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
|
let model = config_manager.get_config(&Uuid::nil(), "llm-model", None).unwrap_or_default();
|
||||||
|
let key = config_manager.get_config(&Uuid::nil(), "llm-key", None).unwrap_or_default();
|
||||||
|
|
||||||
let handler = crate::llm_models::get_handler(&model);
|
let handler = crate::llm_models::get_handler(&model);
|
||||||
let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model).await?;
|
let raw_response = state.llm_provider.generate(&prompt, &serde_json::Value::Null, &model, &key).await?;
|
||||||
let processed = handler.process_content(&raw_response);
|
let processed = handler.process_content(&raw_response);
|
||||||
Ok(processed)
|
Ok(processed)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ use anyhow::Result;
|
||||||
use aws_config::BehaviorVersion;
|
use aws_config::BehaviorVersion;
|
||||||
use aws_sdk_s3::Client;
|
use aws_sdk_s3::Client;
|
||||||
use diesel::connection::SimpleConnection;
|
use diesel::connection::SimpleConnection;
|
||||||
use log::{error, info, trace};
|
use log::{error, trace};
|
||||||
use rand::distr::Alphanumeric;
|
use rand::distr::Alphanumeric;
|
||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
|
||||||
|
|
@ -355,6 +355,11 @@ impl BotOrchestrator {
|
||||||
.rposition(|(role, _content)| role == "compact")
|
.rposition(|(role, _content)| role == "compact")
|
||||||
{
|
{
|
||||||
history = history.split_off(last_compacted_index);
|
history = history.split_off(last_compacted_index);
|
||||||
|
for (role, content) in history.iter_mut() {
|
||||||
|
if role == "compact" {
|
||||||
|
*role = "user".to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if history_limit > 0 && history.len() > history_limit as usize {
|
if history_limit > 0 && history.len() > history_limit as usize {
|
||||||
let start = history.len() - history_limit as usize;
|
let start = history.len() - history_limit as usize;
|
||||||
|
|
@ -405,9 +410,16 @@ impl BotOrchestrator {
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
let key = config_manager
|
||||||
|
.get_config(
|
||||||
|
&Uuid::parse_str(&message.bot_id).unwrap_or_default(),
|
||||||
|
"llm-key",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.unwrap_or_default();
|
||||||
let model1 = model.clone();
|
let model1 = model.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model).await {
|
if let Err(e) = llm.generate_stream("", &messages, stream_tx, &model, &key).await {
|
||||||
error!("LLM streaming error: {}", e);
|
error!("LLM streaming error: {}", e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use log::info;
|
use log::{info, trace};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
pub mod local;
|
pub mod local;
|
||||||
|
|
@ -11,13 +11,15 @@ pub trait LLMProvider: Send + Sync {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
config: &Value,
|
config: &Value,
|
||||||
model: &str,
|
model: &str,
|
||||||
|
key: &str
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
&self,
|
&self,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
config: &Value,
|
config: &Value,
|
||||||
tx: mpsc::Sender<String>,
|
tx: mpsc::Sender<String>,
|
||||||
model: &str
|
model: &str,
|
||||||
|
key: &str
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
||||||
async fn cancel_job(
|
async fn cancel_job(
|
||||||
&self,
|
&self,
|
||||||
|
|
@ -37,12 +39,14 @@ impl LLMProvider for OpenAIClient {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
messages: &Value,
|
messages: &Value,
|
||||||
model: &str,
|
model: &str,
|
||||||
|
key: &str
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||||
let response = self
|
let response =
|
||||||
|
self
|
||||||
.client
|
.client
|
||||||
.post(&format!("{}/v1/chat/completions", self.base_url))
|
.post(&format!("{}/v1/chat/completions", self.base_url))
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", key))
|
||||||
.json(&serde_json::json!({
|
.json(&serde_json::json!({
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
|
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
|
||||||
|
|
@ -70,13 +74,14 @@ impl LLMProvider for OpenAIClient {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
messages: &Value,
|
messages: &Value,
|
||||||
tx: mpsc::Sender<String>,
|
tx: mpsc::Sender<String>,
|
||||||
model: &str
|
model: &str,
|
||||||
|
key: &str
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post(&format!("{}/v1/chat/completions", self.base_url))
|
.post(&format!("{}/v1/chat/completions", self.base_url))
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", key))
|
||||||
.json(&serde_json::json!({
|
.json(&serde_json::json!({
|
||||||
"model": model.clone(),
|
"model": model.clone(),
|
||||||
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
|
"messages": if messages.is_array() && !messages.as_array().unwrap().is_empty() {
|
||||||
|
|
@ -89,6 +94,12 @@ impl LLMProvider for OpenAIClient {
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
let status = response.status();
|
||||||
|
if status != reqwest::StatusCode::OK {
|
||||||
|
let error_text = response.text().await.unwrap_or_default();
|
||||||
|
trace!("LLM generate_stream error: {}", error_text);
|
||||||
|
return Err(format!("LLM request failed with status: {}", status).into());
|
||||||
|
}
|
||||||
let mut stream = response.bytes_stream();
|
let mut stream = response.bytes_stream();
|
||||||
let mut buffer = String::new();
|
let mut buffer = String::new();
|
||||||
while let Some(chunk) = stream.next().await {
|
while let Some(chunk) = stream.next().await {
|
||||||
|
|
|
||||||
|
|
@ -274,7 +274,7 @@ impl SessionManager {
|
||||||
for (other_role, content) in messages {
|
for (other_role, content) in messages {
|
||||||
let role_str = match other_role {
|
let role_str = match other_role {
|
||||||
1 => "user".to_string(),
|
1 => "user".to_string(),
|
||||||
2 => "bot".to_string(),
|
2 => "assistant".to_string(),
|
||||||
3 => "system".to_string(),
|
3 => "system".to_string(),
|
||||||
9 => "compact".to_string(),
|
9 => "compact".to_string(),
|
||||||
_ => "unknown".to_string(),
|
_ => "unknown".to_string(),
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue