diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 189dd08c..6e6c03b4 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -34,15 +34,6 @@ pub struct OpenAIClient { api_key: String, base_url: String, } -impl OpenAIClient { - pub fn new(api_key: String, base_url: Option) -> Self { - Self { - client: reqwest::Client::new(), - api_key, - base_url: base_url.unwrap() - } - } -} #[async_trait] impl LLMProvider for OpenAIClient { async fn generate( @@ -50,13 +41,15 @@ impl LLMProvider for OpenAIClient { prompt: &str, _config: &Value, ) -> Result> { + let messages = self.parse_messages(prompt); + let response = self .client .post(&format!("{}/v1/chat/completions/", self.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&serde_json::json!({ "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": prompt}], + "messages": messages, "max_tokens": 1000 })) .send() @@ -79,13 +72,15 @@ impl LLMProvider for OpenAIClient { _config: &Value, tx: mpsc::Sender, ) -> Result<(), Box> { + let messages = self.parse_messages(prompt); + let response = self .client .post(&format!("{}/v1/chat/completions", self.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&serde_json::json!({ "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": prompt}], + "messages": messages, "max_tokens": 1000, "stream": true })) @@ -116,3 +111,56 @@ impl LLMProvider for OpenAIClient { Ok(()) } } + +impl OpenAIClient { + pub fn new(api_key: String, base_url: Option) -> Self { + Self { + client: reqwest::Client::new(), + api_key, + base_url: base_url.unwrap() + } + } + + fn parse_messages(&self, prompt: &str) -> Vec { + let mut messages = Vec::new(); + let mut current_role = None; + let mut current_content = String::new(); + + for line in prompt.lines() { + if let Some(role_end) = line.find(':') { + let role_part = &line[..role_end].trim().to_lowercase(); + let role = match role_part.as_str() { + "human" => "user", + "bot" => "assistant", + "compact" => "system", + _ => continue + }; + { + if let Some(r) = current_role.take() { + messages.push(serde_json::json!({ + "role": r, + "content": current_content.trim() + })); + } + current_role = Some(role); + current_content = line[role_end + 1..].trim_start().to_string(); + continue; + } + } + if let Some(_) = current_role { + if !current_content.is_empty() { + current_content.push('\n'); + } + current_content.push_str(line); + } + } + + if let Some(role) = current_role { + messages.push(serde_json::json!({ + "role": role, + "content": current_content.trim() + })); + } + messages + } +}