feat(llm): add message parsing for OpenAI client
Added parse_messages method to handle structured prompt input for OpenAI API. The method converts human/bot/compact prefixes to appropriate OpenAI roles (user/assistant/system) and properly formats multi-line messages. This enables more complex conversation structures in prompts while maintaining compatibility with the OpenAI API format. Removed the direct prompt-to-message conversion in generate and generate_stream methods, replacing it with the new parse_messages utility. Also reorganized the impl blocks for better code organization.
This commit is contained in:
parent
349bdd7984
commit
035d867c2f
1 changed file with 59 additions and 11 deletions
|
|
@ -34,15 +34,6 @@ pub struct OpenAIClient {
|
||||||
api_key: String,
|
api_key: String,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
}
|
}
|
||||||
impl OpenAIClient {
|
|
||||||
pub fn new(api_key: String, base_url: Option<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
client: reqwest::Client::new(),
|
|
||||||
api_key,
|
|
||||||
base_url: base_url.unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl LLMProvider for OpenAIClient {
|
impl LLMProvider for OpenAIClient {
|
||||||
async fn generate(
|
async fn generate(
|
||||||
|
|
@ -50,13 +41,15 @@ impl LLMProvider for OpenAIClient {
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
_config: &Value,
|
_config: &Value,
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let messages = self.parse_messages(prompt);
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post(&format!("{}/v1/chat/completions/", self.base_url))
|
.post(&format!("{}/v1/chat/completions/", self.base_url))
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&serde_json::json!({
|
.json(&serde_json::json!({
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": messages,
|
||||||
"max_tokens": 1000
|
"max_tokens": 1000
|
||||||
}))
|
}))
|
||||||
.send()
|
.send()
|
||||||
|
|
@ -79,13 +72,15 @@ impl LLMProvider for OpenAIClient {
|
||||||
_config: &Value,
|
_config: &Value,
|
||||||
tx: mpsc::Sender<String>,
|
tx: mpsc::Sender<String>,
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let messages = self.parse_messages(prompt);
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post(&format!("{}/v1/chat/completions", self.base_url))
|
.post(&format!("{}/v1/chat/completions", self.base_url))
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.json(&serde_json::json!({
|
.json(&serde_json::json!({
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": messages,
|
||||||
"max_tokens": 1000,
|
"max_tokens": 1000,
|
||||||
"stream": true
|
"stream": true
|
||||||
}))
|
}))
|
||||||
|
|
@ -116,3 +111,56 @@ impl LLMProvider for OpenAIClient {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl OpenAIClient {
|
||||||
|
pub fn new(api_key: String, base_url: Option<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
client: reqwest::Client::new(),
|
||||||
|
api_key,
|
||||||
|
base_url: base_url.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_messages(&self, prompt: &str) -> Vec<Value> {
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
let mut current_role = None;
|
||||||
|
let mut current_content = String::new();
|
||||||
|
|
||||||
|
for line in prompt.lines() {
|
||||||
|
if let Some(role_end) = line.find(':') {
|
||||||
|
let role_part = &line[..role_end].trim().to_lowercase();
|
||||||
|
let role = match role_part.as_str() {
|
||||||
|
"human" => "user",
|
||||||
|
"bot" => "assistant",
|
||||||
|
"compact" => "system",
|
||||||
|
_ => continue
|
||||||
|
};
|
||||||
|
{
|
||||||
|
if let Some(r) = current_role.take() {
|
||||||
|
messages.push(serde_json::json!({
|
||||||
|
"role": r,
|
||||||
|
"content": current_content.trim()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
current_role = Some(role);
|
||||||
|
current_content = line[role_end + 1..].trim_start().to_string();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(_) = current_role {
|
||||||
|
if !current_content.is_empty() {
|
||||||
|
current_content.push('\n');
|
||||||
|
}
|
||||||
|
current_content.push_str(line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(role) = current_role {
|
||||||
|
messages.push(serde_json::json!({
|
||||||
|
"role": role,
|
||||||
|
"content": current_content.trim()
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue