342 lines
No EOL
9.8 KiB
Rust
342 lines
No EOL
9.8 KiB
Rust
|
|
|
|
|
|
#![allow(unused_imports)]
|
|
#![allow(unused_variables)]
|
|
#![allow(dead_code)]
|
|
|
|
use serde_json;
|
|
|
|
|
|
|
|
use async_trait::async_trait;
|
|
use serde_json::json;
|
|
use std::sync::Arc;
|
|
use tokio::sync::mpsc;
|
|
|
|
|
|
/// Test double for an LLM provider: returns a fixed canned response and
/// records how many times it was invoked, so tests can assert whether a
/// call was served from cache (count unchanged) or hit the provider.
struct MockLLMProvider {
    // Canned string returned by every generate/generate_stream call.
    response: String,
    // Invocation counter; atomic (SeqCst) so async test tasks can read it safely.
    call_count: std::sync::atomic::AtomicUsize,
}
|
|
|
|
impl MockLLMProvider {
|
|
fn new(response: &str) -> Self {
|
|
Self {
|
|
response: response.to_string(),
|
|
call_count: std::sync::atomic::AtomicUsize::new(0),
|
|
}
|
|
}
|
|
|
|
fn get_call_count(&self) -> usize {
|
|
self.call_count.load(std::sync::atomic::Ordering::SeqCst)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl LLMProvider for MockLLMProvider {
|
|
async fn generate(
|
|
&self,
|
|
_prompt: &str,
|
|
_messages: &serde_json::Value,
|
|
_model: &str,
|
|
_key: &str,
|
|
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
|
|
self.call_count
|
|
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
|
Ok(self.response.clone())
|
|
}
|
|
|
|
async fn generate_stream(
|
|
&self,
|
|
_prompt: &str,
|
|
_messages: &serde_json::Value,
|
|
tx: mpsc::Sender<String>,
|
|
_model: &str,
|
|
_key: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
self.call_count
|
|
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
|
let _ = tx.send(self.response.clone()).await;
|
|
Ok(())
|
|
}
|
|
|
|
async fn cancel_job(
|
|
&self,
|
|
_session_id: &str,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
struct MockEmbeddingService;
|
|
|
|
#[async_trait]
|
|
impl EmbeddingService for MockEmbeddingService {
|
|
async fn get_embedding(
|
|
&self,
|
|
text: &str,
|
|
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
|
|
|
let hash = text.bytes().fold(0u32, |acc, b| acc.wrapping_add(b as u32));
|
|
Ok(vec![hash as f32 / 255.0; 10])
|
|
}
|
|
|
|
async fn compute_similarity(&self, embedding1: &[f32], embedding2: &[f32]) -> f32 {
|
|
if embedding1.len() != embedding2.len() {
|
|
return 0.0;
|
|
}
|
|
|
|
|
|
let diff: f32 = embedding1
|
|
.iter()
|
|
.zip(embedding2.iter())
|
|
.map(|(a, b)| (a - b).abs())
|
|
.sum();
|
|
|
|
1.0 - (diff / embedding1.len() as f32).min(1.0)
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_exact_cache_hit() {
|
|
|
|
let mock_provider = Arc::new(MockLLMProvider::new("Test response"));
|
|
let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
|
|
|
|
let config = CacheConfig {
|
|
ttl: 60,
|
|
semantic_matching: false,
|
|
similarity_threshold: 0.95,
|
|
max_similarity_checks: 10,
|
|
key_prefix: "test_cache".to_string(),
|
|
};
|
|
|
|
let cached_provider =
|
|
CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None);
|
|
|
|
let prompt = "What is the weather?";
|
|
let messages = json!([{"role": "user", "content": prompt}]);
|
|
let model = "test-model";
|
|
let key = "test-key";
|
|
|
|
|
|
let result1 = cached_provider
|
|
.generate(prompt, &messages, model, key)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(result1, "Test response");
|
|
assert_eq!(mock_provider.get_call_count(), 1);
|
|
|
|
|
|
let result2 = cached_provider
|
|
.generate(prompt, &messages, model, key)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(result2, "Test response");
|
|
assert_eq!(mock_provider.get_call_count(), 1);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_semantic_cache_hit() {
|
|
|
|
let mock_provider = Arc::new(MockLLMProvider::new("Weather is sunny"));
|
|
let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
|
|
|
|
let config = CacheConfig {
|
|
ttl: 60,
|
|
semantic_matching: true,
|
|
similarity_threshold: 0.8,
|
|
max_similarity_checks: 10,
|
|
key_prefix: "test_semantic".to_string(),
|
|
};
|
|
|
|
let embedding_service = Arc::new(MockEmbeddingService);
|
|
let cached_provider = CachedLLMProvider::new(
|
|
mock_provider.clone(),
|
|
cache_client,
|
|
config,
|
|
Some(embedding_service),
|
|
);
|
|
|
|
let messages = json!([{"role": "user", "content": "test"}]);
|
|
let model = "test-model";
|
|
let key = "test-key";
|
|
|
|
|
|
let result1 = cached_provider
|
|
.generate("What's the weather?", &messages, model, key)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(result1, "Weather is sunny");
|
|
assert_eq!(mock_provider.get_call_count(), 1);
|
|
|
|
|
|
let result2 = cached_provider
|
|
.generate("What is the weather?", &messages, model, key)
|
|
.await
|
|
.unwrap();
|
|
|
|
|
|
|
|
|
|
assert_eq!(result2, "Weather is sunny");
|
|
assert_eq!(mock_provider.get_call_count(), 2);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_cache_miss_different_model() {
|
|
|
|
let mock_provider = Arc::new(MockLLMProvider::new("Response"));
|
|
let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
|
|
|
|
let config = CacheConfig::default();
|
|
let cached_provider =
|
|
CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None);
|
|
|
|
let prompt = "Same prompt";
|
|
let messages = json!([{"role": "user", "content": prompt}]);
|
|
let key = "test-key";
|
|
|
|
|
|
let _ = cached_provider
|
|
.generate(prompt, &messages, "model1", key)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(mock_provider.get_call_count(), 1);
|
|
|
|
|
|
let _ = cached_provider
|
|
.generate(prompt, &messages, "model2", key)
|
|
.await
|
|
.unwrap();
|
|
assert_eq!(mock_provider.get_call_count(), 2);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_cache_statistics() {
|
|
|
|
let mock_provider = Arc::new(MockLLMProvider::new("Response"));
|
|
let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
|
|
|
|
let config = CacheConfig {
|
|
ttl: 60,
|
|
semantic_matching: false,
|
|
similarity_threshold: 0.95,
|
|
max_similarity_checks: 10,
|
|
key_prefix: "test_stats".to_string(),
|
|
};
|
|
|
|
let cached_provider = CachedLLMProvider::new(mock_provider, cache_client, config, None);
|
|
|
|
|
|
let _ = cached_provider.clear_cache(None).await;
|
|
|
|
|
|
let messages = json!([]);
|
|
for i in 0..5 {
|
|
let _ = cached_provider
|
|
.generate(&format!("prompt_{}", i), &messages, "model", "key")
|
|
.await;
|
|
}
|
|
|
|
|
|
for i in 0..3 {
|
|
let _ = cached_provider
|
|
.generate(&format!("prompt_{}", i), &messages, "model", "key")
|
|
.await;
|
|
}
|
|
|
|
|
|
let stats = cached_provider.get_cache_stats().await.unwrap();
|
|
assert_eq!(stats.total_entries, 5);
|
|
assert_eq!(stats.total_hits, 3);
|
|
assert!(stats.total_size_bytes > 0);
|
|
assert_eq!(stats.model_distribution.get("model"), Some(&5));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_stream_generation_with_cache() {
|
|
|
|
let mock_provider = Arc::new(MockLLMProvider::new("Streamed response"));
|
|
let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
|
|
|
|
let config = CacheConfig {
|
|
ttl: 60,
|
|
semantic_matching: false,
|
|
similarity_threshold: 0.95,
|
|
max_similarity_checks: 10,
|
|
key_prefix: "test_stream".to_string(),
|
|
};
|
|
|
|
let cached_provider =
|
|
CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None);
|
|
|
|
let prompt = "Stream this";
|
|
let messages = json!([{"role": "user", "content": prompt}]);
|
|
let model = "test-model";
|
|
let key = "test-key";
|
|
|
|
|
|
let (tx1, mut rx1) = mpsc::channel(100);
|
|
cached_provider
|
|
.generate_stream(prompt, &messages, tx1, model, key)
|
|
.await
|
|
.unwrap();
|
|
|
|
let mut result1 = String::new();
|
|
while let Some(chunk) = rx1.recv().await {
|
|
result1.push_str(&chunk);
|
|
}
|
|
assert_eq!(result1, "Streamed response");
|
|
assert_eq!(mock_provider.get_call_count(), 1);
|
|
|
|
|
|
let (tx2, mut rx2) = mpsc::channel(100);
|
|
cached_provider
|
|
.generate_stream(prompt, &messages, tx2, model, key)
|
|
.await
|
|
.unwrap();
|
|
|
|
let mut result2 = String::new();
|
|
while let Some(chunk) = rx2.recv().await {
|
|
result2.push_str(&chunk);
|
|
}
|
|
assert!(result2.contains("Streamed response"));
|
|
assert_eq!(mock_provider.get_call_count(), 1);
|
|
}
|
|
|
|
/// Cosine similarity of identical, orthogonal, and opposite vectors.
///
/// Fixes over the previous version:
/// - one `tokio` runtime is created and reused instead of building a fresh
///   runtime for every `block_on` call;
/// - `f32` results are compared with a tolerance instead of `assert_eq!`,
///   since e.g. `sqrt(0.75) * sqrt(0.75)` need not be exactly `0.75` in f32,
///   which made exact equality against 1.0/0.0/-1.0 fragile.
#[test]
fn test_cosine_similarity_calculation() {
    let service = LocalEmbeddingService::new(
        "http://localhost:8082".to_string(),
        "test-model".to_string(),
    );

    // Single runtime shared by all three cases.
    let rt = tokio::runtime::Runtime::new().unwrap();

    // Tolerance for f32 rounding in the similarity computation.
    const EPS: f32 = 1e-6;

    // Identical vectors -> similarity ~1.
    let vec1 = vec![0.5, 0.5, 0.5];
    let vec2 = vec![0.5, 0.5, 0.5];
    let similarity = rt.block_on(service.compute_similarity(&vec1, &vec2));
    assert!((similarity - 1.0).abs() < EPS, "expected ~1.0, got {}", similarity);

    // Orthogonal vectors -> similarity ~0.
    let vec3 = vec![1.0, 0.0];
    let vec4 = vec![0.0, 1.0];
    let similarity = rt.block_on(service.compute_similarity(&vec3, &vec4));
    assert!(similarity.abs() < EPS, "expected ~0.0, got {}", similarity);

    // Opposite vectors -> similarity ~-1.
    let vec5 = vec![1.0, 1.0];
    let vec6 = vec![-1.0, -1.0];
    let similarity = rt.block_on(service.compute_similarity(&vec5, &vec6));
    assert!((similarity + 1.0).abs() < EPS, "expected ~-1.0, got {}", similarity);
}