chore: Fix warnings and clean TODO refs
This commit is contained in:
parent
ac5b814536
commit
98813fbdc8
5 changed files with 10 additions and 727 deletions
|
|
@ -318,7 +318,7 @@ When a file grows beyond this limit:
|
||||||
| `attendance/llm_assist.rs` | 2053 | → 5 files |
|
| `attendance/llm_assist.rs` | 2053 | → 5 files |
|
||||||
| `drive/mod.rs` | 1522 | → 4 files |
|
| `drive/mod.rs` | 1522 | → 4 files |
|
||||||
|
|
||||||
**See `TODO-refactor1.md` for detailed refactoring plans**
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@ -465,7 +465,7 @@ We welcome contributions! Please read our contributing guidelines before submitt
|
||||||
|
|
||||||
1. **Replace 955 unwrap()/expect() calls** with proper error handling
|
1. **Replace 955 unwrap()/expect() calls** with proper error handling
|
||||||
2. **Optimize 12,973 clone()/to_string() calls** for performance
|
2. **Optimize 12,973 clone()/to_string() calls** for performance
|
||||||
3. **Refactor 5 large files** following TODO-refactor1.md
|
3. **Refactor 5 large files** following refactoring plan
|
||||||
4. **Add missing error handling** in critical paths
|
4. **Add missing error handling** in critical paths
|
||||||
5. **Implement proper logging** instead of panicking
|
5. **Implement proper logging** instead of panicking
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,10 +26,10 @@ pub struct MigrationResult {
|
||||||
|
|
||||||
/// Column metadata from database
|
/// Column metadata from database
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct DbColumn {
|
pub struct DbColumn {
|
||||||
name: String,
|
pub name: String,
|
||||||
data_type: String,
|
pub data_type: String,
|
||||||
is_nullable: bool,
|
pub is_nullable: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compare and sync table schema with definition
|
/// Compare and sync table schema with definition
|
||||||
|
|
|
||||||
|
|
@ -714,6 +714,7 @@ impl ScriptService {
|
||||||
/// Convert FORMAT(expr, pattern) to FORMAT expr pattern (custom syntax format)
|
/// Convert FORMAT(expr, pattern) to FORMAT expr pattern (custom syntax format)
|
||||||
/// Also handles RANDOM and other functions that need space-separated arguments
|
/// Also handles RANDOM and other functions that need space-separated arguments
|
||||||
/// This properly handles nested function calls by counting parentheses
|
/// This properly handles nested function calls by counting parentheses
|
||||||
|
#[allow(dead_code)]
|
||||||
fn convert_format_syntax(script: &str) -> String {
|
fn convert_format_syntax(script: &str) -> String {
|
||||||
let mut result = String::new();
|
let mut result = String::new();
|
||||||
let mut chars = script.chars().peekable();
|
let mut chars = script.chars().peekable();
|
||||||
|
|
|
||||||
|
|
@ -1,720 +0,0 @@
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::StreamExt;
|
|
||||||
use log::{error, info};
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::{mpsc, RwLock};
|
|
||||||
|
|
||||||
pub mod cache;
|
|
||||||
pub mod claude;
|
|
||||||
pub mod episodic_memory;
|
|
||||||
pub mod llm_models;
|
|
||||||
pub mod local;
|
|
||||||
pub mod smart_router;
|
|
||||||
|
|
||||||
pub use claude::ClaudeClient;
|
|
||||||
pub use llm_models::get_handler;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
pub trait LLMProvider: Send + Sync {
|
|
||||||
async fn generate(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
model: &str,
|
|
||||||
key: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
|
|
||||||
async fn generate_stream(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
tx: mpsc::Sender<String>,
|
|
||||||
model: &str,
|
|
||||||
key: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
|
|
||||||
async fn cancel_job(
|
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct OpenAIClient {
|
|
||||||
client: reqwest::Client,
|
|
||||||
base_url: String,
|
|
||||||
endpoint_path: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIClient {
|
|
||||||
/// Estimates token count for a text string (roughly 4 characters per token for English)
|
|
||||||
fn estimate_tokens(text: &str) -> usize {
|
|
||||||
// Rough estimate: ~4 characters per token for English text
|
|
||||||
// This is a heuristic and may not be accurate for all languages
|
|
||||||
text.len().div_ceil(4)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Estimates total tokens for a messages array
|
|
||||||
fn estimate_messages_tokens(messages: &Value) -> usize {
|
|
||||||
if let Some(msg_array) = messages.as_array() {
|
|
||||||
msg_array
|
|
||||||
.iter()
|
|
||||||
.map(|msg| {
|
|
||||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
|
||||||
Self::estimate_tokens(content)
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.sum()
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Truncates messages to fit within the max_tokens limit
|
|
||||||
/// Keeps system messages and the most recent user/assistant messages
|
|
||||||
fn truncate_messages(messages: &Value, max_tokens: usize) -> Value {
|
|
||||||
let mut result = Vec::new();
|
|
||||||
let mut token_count = 0;
|
|
||||||
|
|
||||||
if let Some(msg_array) = messages.as_array() {
|
|
||||||
// First pass: keep all system messages
|
|
||||||
for msg in msg_array {
|
|
||||||
if let Some(role) = msg.get("role").and_then(|r| r.as_str()) {
|
|
||||||
if role == "system" {
|
|
||||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
|
||||||
let msg_tokens = Self::estimate_tokens(content);
|
|
||||||
if token_count + msg_tokens <= max_tokens {
|
|
||||||
result.push(msg.clone());
|
|
||||||
token_count += msg_tokens;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second pass: add user/assistant messages from newest to oldest
|
|
||||||
let mut recent_messages: Vec<&Value> = msg_array
|
|
||||||
.iter()
|
|
||||||
.filter(|msg| msg.get("role").and_then(|r| r.as_str()) != Some("system"))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Reverse to get newest first
|
|
||||||
recent_messages.reverse();
|
|
||||||
|
|
||||||
for msg in recent_messages {
|
|
||||||
if let Some(content) = msg.get("content").and_then(|c| c.as_str()) {
|
|
||||||
let msg_tokens = Self::estimate_tokens(content);
|
|
||||||
if token_count + msg_tokens <= max_tokens {
|
|
||||||
result.push(msg.clone());
|
|
||||||
token_count += msg_tokens;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reverse back to chronological order for non-system messages
|
|
||||||
// But keep system messages at the beginning
|
|
||||||
let system_count = result.len()
|
|
||||||
- result
|
|
||||||
.iter()
|
|
||||||
.filter(|m| m.get("role").and_then(|r| r.as_str()) != Some("system"))
|
|
||||||
.count();
|
|
||||||
let mut user_messages: Vec<Value> = result.drain(system_count..).collect();
|
|
||||||
user_messages.reverse();
|
|
||||||
result.extend(user_messages);
|
|
||||||
}
|
|
||||||
|
|
||||||
serde_json::Value::Array(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ensures messages fit within model's context limit
|
|
||||||
fn ensure_token_limit(messages: &Value, model_context_limit: usize) -> Value {
|
|
||||||
let estimated_tokens = Self::estimate_messages_tokens(messages);
|
|
||||||
|
|
||||||
// Use 90% of context limit to leave room for response
|
|
||||||
let safe_limit = (model_context_limit as f64 * 0.9) as usize;
|
|
||||||
|
|
||||||
if estimated_tokens > safe_limit {
|
|
||||||
log::warn!(
|
|
||||||
"Messages exceed token limit ({} > {}), truncating...",
|
|
||||||
estimated_tokens,
|
|
||||||
safe_limit
|
|
||||||
);
|
|
||||||
Self::truncate_messages(messages, safe_limit)
|
|
||||||
} else {
|
|
||||||
messages.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pub fn new(_api_key: String, base_url: Option<String>, endpoint_path: Option<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
client: reqwest::Client::new(),
|
|
||||||
base_url: base_url.unwrap_or_else(|| "https://api.openai.com".to_string()),
|
|
||||||
endpoint_path: endpoint_path.unwrap_or_else(|| "/v1/chat/completions".to_string()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build_messages(
|
|
||||||
system_prompt: &str,
|
|
||||||
context_data: &str,
|
|
||||||
history: &[(String, String)],
|
|
||||||
) -> Value {
|
|
||||||
let mut messages = Vec::new();
|
|
||||||
if !system_prompt.is_empty() {
|
|
||||||
messages.push(serde_json::json!({
|
|
||||||
"role": "system",
|
|
||||||
"content": system_prompt
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
if !context_data.is_empty() {
|
|
||||||
messages.push(serde_json::json!({
|
|
||||||
"role": "system",
|
|
||||||
"content": context_data
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
for (role, content) in history {
|
|
||||||
messages.push(serde_json::json!({
|
|
||||||
"role": role,
|
|
||||||
"content": content
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
serde_json::Value::Array(messages)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl LLMProvider for OpenAIClient {
|
|
||||||
async fn generate(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
messages: &Value,
|
|
||||||
model: &str,
|
|
||||||
key: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
|
||||||
|
|
||||||
// Get the messages to use
|
|
||||||
let raw_messages =
|
|
||||||
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
|
||||||
messages
|
|
||||||
} else {
|
|
||||||
&default_messages
|
|
||||||
};
|
|
||||||
|
|
||||||
// Ensure messages fit within model's context limit
|
|
||||||
// GLM-4.7 has 202750 tokens, other models vary
|
|
||||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
|
||||||
202750
|
|
||||||
} else if model.contains("gpt-4") {
|
|
||||||
128000
|
|
||||||
} else if model.contains("gpt-3.5") {
|
|
||||||
16385
|
|
||||||
} else {
|
|
||||||
model.starts_with("http://localhost:808") ? 768 : 4096 // Local llama.cpp or default limit
|
|
||||||
};
|
|
||||||
|
|
||||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}{}", self.base_url, self.endpoint_path))
|
|
||||||
.header("Authorization", format!("Bearer {}", key))
|
|
||||||
.json(&serde_json::json!({
|
|
||||||
"model": model,
|
|
||||||
"messages": messages
|
|
||||||
}))
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
if status != reqwest::StatusCode::OK {
|
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
|
||||||
error!("LLM generate error: {}", error_text);
|
|
||||||
return Err(format!("LLM request failed with status: {}", status).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let result: Value = response.json().await?;
|
|
||||||
let raw_content = result["choices"][0]["message"]["content"]
|
|
||||||
.as_str()
|
|
||||||
.unwrap_or("");
|
|
||||||
|
|
||||||
let handler = get_handler(model);
|
|
||||||
let content = handler.process_content(raw_content);
|
|
||||||
|
|
||||||
Ok(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_stream(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
messages: &Value,
|
|
||||||
tx: mpsc::Sender<String>,
|
|
||||||
model: &str,
|
|
||||||
key: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
let default_messages = serde_json::json!([{"role": "user", "content": prompt}]);
|
|
||||||
|
|
||||||
// Get the messages to use
|
|
||||||
let raw_messages =
|
|
||||||
if messages.is_array() && !messages.as_array().unwrap_or(&vec![]).is_empty() {
|
|
||||||
info!("Using provided messages: {:?}", messages);
|
|
||||||
messages
|
|
||||||
} else {
|
|
||||||
&default_messages
|
|
||||||
};
|
|
||||||
|
|
||||||
// Ensure messages fit within model's context limit
|
|
||||||
// GLM-4.7 has 202750 tokens, other models vary
|
|
||||||
let context_limit = if model.contains("glm-4") || model.contains("GLM-4") {
|
|
||||||
202750
|
|
||||||
} else if model.contains("gpt-4") {
|
|
||||||
128000
|
|
||||||
} else if model.contains("gpt-3.5") {
|
|
||||||
16385
|
|
||||||
} else {
|
|
||||||
model.starts_with("http://localhost:808") ? 768 : 4096 // Local llama.cpp or default limit
|
|
||||||
};
|
|
||||||
|
|
||||||
let messages = OpenAIClient::ensure_token_limit(raw_messages, context_limit);
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}{}", self.base_url, self.endpoint_path))
|
|
||||||
.header("Authorization", format!("Bearer {}", key))
|
|
||||||
.json(&serde_json::json!({
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"stream": true
|
|
||||||
}))
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
if status != reqwest::StatusCode::OK {
|
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
|
||||||
error!("LLM generate_stream error: {}", error_text);
|
|
||||||
return Err(format!("LLM request failed with status: {}", status).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let handler = get_handler(model);
|
|
||||||
let mut stream = response.bytes_stream();
|
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
|
||||||
let chunk = chunk_result?;
|
|
||||||
let chunk_str = String::from_utf8_lossy(&chunk);
|
|
||||||
for line in chunk_str.lines() {
|
|
||||||
if line.starts_with("data: ") && !line.contains("[DONE]") {
|
|
||||||
if let Ok(data) = serde_json::from_str::<Value>(&line[6..]) {
|
|
||||||
if let Some(content) = data["choices"][0]["delta"]["content"].as_str() {
|
|
||||||
let processed = handler.process_content(content);
|
|
||||||
if !processed.is_empty() {
|
|
||||||
let _ = tx.send(processed).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn cancel_job(
|
|
||||||
&self,
|
|
||||||
_session_id: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn start_llm_services(state: &std::sync::Arc<crate::shared::state::AppState>) {
|
|
||||||
episodic_memory::start_episodic_memory_scheduler(std::sync::Arc::clone(state));
|
|
||||||
info!("LLM services started (episodic memory scheduler)");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
||||||
pub enum LLMProviderType {
|
|
||||||
OpenAI,
|
|
||||||
Claude,
|
|
||||||
AzureClaude,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<&str> for LLMProviderType {
|
|
||||||
fn from(s: &str) -> Self {
|
|
||||||
let lower = s.to_lowercase();
|
|
||||||
if lower.contains("claude") || lower.contains("anthropic") {
|
|
||||||
if lower.contains("azure") {
|
|
||||||
Self::AzureClaude
|
|
||||||
} else {
|
|
||||||
Self::Claude
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Self::OpenAI
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_llm_provider(
|
|
||||||
provider_type: LLMProviderType,
|
|
||||||
base_url: String,
|
|
||||||
deployment_name: Option<String>,
|
|
||||||
endpoint_path: Option<String>,
|
|
||||||
) -> std::sync::Arc<dyn LLMProvider> {
|
|
||||||
match provider_type {
|
|
||||||
LLMProviderType::OpenAI => {
|
|
||||||
info!("Creating OpenAI LLM provider with URL: {}", base_url);
|
|
||||||
std::sync::Arc::new(OpenAIClient::new(
|
|
||||||
"empty".to_string(),
|
|
||||||
Some(base_url),
|
|
||||||
endpoint_path,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
LLMProviderType::Claude => {
|
|
||||||
info!("Creating Claude LLM provider with URL: {}", base_url);
|
|
||||||
std::sync::Arc::new(ClaudeClient::new(base_url, deployment_name))
|
|
||||||
}
|
|
||||||
LLMProviderType::AzureClaude => {
|
|
||||||
let deployment = deployment_name.unwrap_or_else(|| "claude-opus-4-5".to_string());
|
|
||||||
info!(
|
|
||||||
"Creating Azure Claude LLM provider with URL: {}, deployment: {}",
|
|
||||||
base_url, deployment
|
|
||||||
);
|
|
||||||
std::sync::Arc::new(ClaudeClient::azure(base_url, deployment))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_llm_provider_from_url(
|
|
||||||
url: &str,
|
|
||||||
model: Option<String>,
|
|
||||||
endpoint_path: Option<String>,
|
|
||||||
) -> std::sync::Arc<dyn LLMProvider> {
|
|
||||||
let provider_type = LLMProviderType::from(url);
|
|
||||||
create_llm_provider(provider_type, url.to_string(), model, endpoint_path)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DynamicLLMProvider {
|
|
||||||
inner: RwLock<Arc<dyn LLMProvider>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DynamicLLMProvider {
|
|
||||||
pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
|
|
||||||
Self {
|
|
||||||
inner: RwLock::new(provider),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_provider(&self, new_provider: Arc<dyn LLMProvider>) {
|
|
||||||
let mut guard = self.inner.write().await;
|
|
||||||
*guard = new_provider;
|
|
||||||
info!("LLM provider updated dynamically");
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn update_from_config(
|
|
||||||
&self,
|
|
||||||
url: &str,
|
|
||||||
model: Option<String>,
|
|
||||||
endpoint_path: Option<String>,
|
|
||||||
) {
|
|
||||||
let new_provider = create_llm_provider_from_url(url, model, endpoint_path);
|
|
||||||
self.update_provider(new_provider).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_provider(&self) -> Arc<dyn LLMProvider> {
|
|
||||||
self.inner.read().await.clone()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl LLMProvider for DynamicLLMProvider {
|
|
||||||
async fn generate(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
model: &str,
|
|
||||||
key: &str,
|
|
||||||
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
self.get_provider()
|
|
||||||
.await
|
|
||||||
.generate(prompt, config, model, key)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn generate_stream(
|
|
||||||
&self,
|
|
||||||
prompt: &str,
|
|
||||||
config: &Value,
|
|
||||||
tx: mpsc::Sender<String>,
|
|
||||||
model: &str,
|
|
||||||
key: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
self.get_provider()
|
|
||||||
.await
|
|
||||||
.generate_stream(prompt, config, tx, model, key)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn cancel_job(
|
|
||||||
&self,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
self.get_provider().await.cancel_job(session_id).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ToolCall {
|
|
||||||
pub id: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub r#type: String,
|
|
||||||
pub function: ToolFunction,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ToolFunction {
|
|
||||||
pub name: String,
|
|
||||||
pub arguments: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ChatMessage {
|
|
||||||
role: String,
|
|
||||||
content: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ChatCompletionResponse {
|
|
||||||
id: String,
|
|
||||||
object: String,
|
|
||||||
created: i64,
|
|
||||||
model: String,
|
|
||||||
choices: Vec<ChatChoice>,
|
|
||||||
usage: Usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ChatChoice {
|
|
||||||
index: i32,
|
|
||||||
message: ChatMessage,
|
|
||||||
finish_reason: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct Usage {
|
|
||||||
#[serde(rename = "prompt_tokens")]
|
|
||||||
prompt: i32,
|
|
||||||
#[serde(rename = "completion_tokens")]
|
|
||||||
completion: i32,
|
|
||||||
#[serde(rename = "total_tokens")]
|
|
||||||
total: i32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ErrorResponse {
|
|
||||||
error: ErrorDetail,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
struct ErrorDetail {
|
|
||||||
message: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
r#type: String,
|
|
||||||
code: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_tool_call_serialization() {
|
|
||||||
let tool_call = ToolCall {
|
|
||||||
id: "call_123".to_string(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: ToolFunction {
|
|
||||||
name: "get_weather".to_string(),
|
|
||||||
arguments: r#"{"location": "NYC"}"#.to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&tool_call).unwrap();
|
|
||||||
assert!(json.contains("get_weather"));
|
|
||||||
assert!(json.contains("call_123"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_completion_response_serialization() {
|
|
||||||
let response = ChatCompletionResponse {
|
|
||||||
id: "test-id".to_string(),
|
|
||||||
object: "chat.completion".to_string(),
|
|
||||||
created: 1_234_567_890,
|
|
||||||
model: "gpt-4".to_string(),
|
|
||||||
choices: vec![ChatChoice {
|
|
||||||
index: 0,
|
|
||||||
message: ChatMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: Some("Hello!".to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
},
|
|
||||||
finish_reason: "stop".to_string(),
|
|
||||||
}],
|
|
||||||
usage: Usage {
|
|
||||||
prompt: 10,
|
|
||||||
completion: 5,
|
|
||||||
total: 15,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&response).unwrap();
|
|
||||||
assert!(json.contains("chat.completion"));
|
|
||||||
assert!(json.contains("Hello!"));
|
|
||||||
assert!(json.contains("gpt-4"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_error_response_serialization() {
|
|
||||||
let error = ErrorResponse {
|
|
||||||
error: ErrorDetail {
|
|
||||||
message: "Test error".to_string(),
|
|
||||||
r#type: "test_error".to_string(),
|
|
||||||
code: "test_code".to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&error).unwrap();
|
|
||||||
assert!(json.contains("Test error"));
|
|
||||||
assert!(json.contains("test_code"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_messages_empty() {
|
|
||||||
let messages = OpenAIClient::build_messages("", "", &[]);
|
|
||||||
assert!(messages.is_array());
|
|
||||||
assert!(messages.as_array().unwrap().is_empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_messages_with_system_prompt() {
|
|
||||||
let messages = OpenAIClient::build_messages("You are a helpful assistant.", "", &[]);
|
|
||||||
let arr = messages.as_array().unwrap();
|
|
||||||
assert_eq!(arr.len(), 1);
|
|
||||||
assert_eq!(arr[0]["role"], "system");
|
|
||||||
assert_eq!(arr[0]["content"], "You are a helpful assistant.");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_messages_with_context() {
|
|
||||||
let messages = OpenAIClient::build_messages("System prompt", "Context data", &[]);
|
|
||||||
let arr = messages.as_array().unwrap();
|
|
||||||
assert_eq!(arr.len(), 2);
|
|
||||||
assert_eq!(arr[0]["content"], "System prompt");
|
|
||||||
assert_eq!(arr[1]["content"], "Context data");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_messages_with_history() {
|
|
||||||
let history = vec![
|
|
||||||
("user".to_string(), "Hello".to_string()),
|
|
||||||
("assistant".to_string(), "Hi there!".to_string()),
|
|
||||||
];
|
|
||||||
let messages = OpenAIClient::build_messages("", "", &history);
|
|
||||||
let arr = messages.as_array().unwrap();
|
|
||||||
assert_eq!(arr.len(), 2);
|
|
||||||
assert_eq!(arr[0]["role"], "user");
|
|
||||||
assert_eq!(arr[0]["content"], "Hello");
|
|
||||||
assert_eq!(arr[1]["role"], "assistant");
|
|
||||||
assert_eq!(arr[1]["content"], "Hi there!");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_messages_full() {
|
|
||||||
let history = vec![("user".to_string(), "What is the weather?".to_string())];
|
|
||||||
let messages = OpenAIClient::build_messages(
|
|
||||||
"You are a weather bot.",
|
|
||||||
"Current location: NYC",
|
|
||||||
&history,
|
|
||||||
);
|
|
||||||
let arr = messages.as_array().unwrap();
|
|
||||||
assert_eq!(arr.len(), 3);
|
|
||||||
assert_eq!(arr[0]["role"], "system");
|
|
||||||
assert_eq!(arr[1]["role"], "system");
|
|
||||||
assert_eq!(arr[2]["role"], "user");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_openai_client_new_default_url() {
|
|
||||||
let client = OpenAIClient::new("test_key".to_string(), None, None);
|
|
||||||
assert_eq!(client.base_url, "https://api.openai.com");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_openai_client_new_custom_url() {
|
|
||||||
let client = OpenAIClient::new(
|
|
||||||
"test_key".to_string(),
|
|
||||||
Some("http://localhost:8080".to_string()),
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
assert_eq!(client.base_url, "http://localhost:8080");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_message_with_tool_calls() {
|
|
||||||
let message = ChatMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: None,
|
|
||||||
tool_calls: Some(vec![ToolCall {
|
|
||||||
id: "call_1".to_string(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: ToolFunction {
|
|
||||||
name: "search".to_string(),
|
|
||||||
arguments: r#"{"query": "test"}"#.to_string(),
|
|
||||||
},
|
|
||||||
}]),
|
|
||||||
};
|
|
||||||
|
|
||||||
let json = serde_json::to_string(&message).unwrap();
|
|
||||||
assert!(json.contains("tool_calls"));
|
|
||||||
assert!(json.contains("search"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_usage_calculation() {
|
|
||||||
let usage = Usage {
|
|
||||||
prompt: 100,
|
|
||||||
completion: 50,
|
|
||||||
total: 150,
|
|
||||||
};
|
|
||||||
assert_eq!(usage.prompt + usage.completion, usage.total);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_choice_finish_reasons() {
|
|
||||||
let stop_choice = ChatChoice {
|
|
||||||
index: 0,
|
|
||||||
message: ChatMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: Some("Done".to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
},
|
|
||||||
finish_reason: "stop".to_string(),
|
|
||||||
};
|
|
||||||
assert_eq!(stop_choice.finish_reason, "stop");
|
|
||||||
|
|
||||||
let tool_choice = ChatChoice {
|
|
||||||
index: 0,
|
|
||||||
message: ChatMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: None,
|
|
||||||
tool_calls: Some(vec![]),
|
|
||||||
},
|
|
||||||
finish_reason: "tool_calls".to_string(),
|
|
||||||
};
|
|
||||||
assert_eq!(tool_choice.finish_reason, "tool_calls");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -216,7 +216,9 @@ pub async fn init_database(
|
||||||
progress_tx.send(BootstrapProgress::ConnectingDatabase).ok();
|
progress_tx.send(BootstrapProgress::ConnectingDatabase).ok();
|
||||||
|
|
||||||
// Ensure secrets manager is initialized before creating database connection
|
// Ensure secrets manager is initialized before creating database connection
|
||||||
crate::core::shared::utils::init_secrets_manager().await;
|
crate::core::shared::utils::init_secrets_manager()
|
||||||
|
.await
|
||||||
|
.expect("Failed to initialize secrets manager");
|
||||||
|
|
||||||
let pool = match create_conn() {
|
let pool = match create_conn() {
|
||||||
Ok(pool) => {
|
Ok(pool) => {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue