// bottest/tests/unit/llm/llm_cache_test.rs
#![allow(unused_imports)]
#![allow(unused_variables)]
#![allow(dead_code)]

use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use tokio::sync::mpsc;

// The cache types under test live in the crate itself; the module path below
// is an assumption and may need adjusting to match the actual crate layout.
use bottest::llm::{
    CacheConfig, CachedLLMProvider, EmbeddingService, LLMProvider, LocalEmbeddingService,
};
/// Mock LLM provider that returns a fixed response and counts invocations, so
/// tests can tell whether a call was served from the cache or hit the provider.
struct MockLLMProvider {
    response: String,
    call_count: std::sync::atomic::AtomicUsize,
}

impl MockLLMProvider {
    fn new(response: &str) -> Self {
        Self {
            response: response.to_string(),
            call_count: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    fn get_call_count(&self) -> usize {
        self.call_count.load(std::sync::atomic::Ordering::SeqCst)
    }
}
#[async_trait]
impl LLMProvider for MockLLMProvider {
    async fn generate(
        &self,
        _prompt: &str,
        _messages: &serde_json::Value,
        _model: &str,
        _key: &str,
    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
        self.call_count
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Ok(self.response.clone())
    }

    async fn generate_stream(
        &self,
        _prompt: &str,
        _messages: &serde_json::Value,
        tx: mpsc::Sender<String>,
        _model: &str,
        _key: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.call_count
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        // Send the whole response as a single chunk; a dropped receiver is ignored.
        let _ = tx.send(self.response.clone()).await;
        Ok(())
    }

    async fn cancel_job(
        &self,
        _session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}
/// Mock embedding service: embeds text as ten copies of a byte-sum hash and
/// compares embeddings with an L1-distance-based score (not true cosine),
/// which is enough to exercise the semantic-matching path deterministically.
struct MockEmbeddingService;

#[async_trait]
impl EmbeddingService for MockEmbeddingService {
    async fn get_embedding(
        &self,
        text: &str,
    ) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
        let hash = text.bytes().fold(0u32, |acc, b| acc.wrapping_add(b as u32));
        Ok(vec![hash as f32 / 255.0; 10])
    }

    async fn compute_similarity(&self, embedding1: &[f32], embedding2: &[f32]) -> f32 {
        if embedding1.len() != embedding2.len() {
            return 0.0;
        }
        let diff: f32 = embedding1
            .iter()
            .zip(embedding2.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        // 1.0 for identical vectors, falling linearly toward 0.0 as the mean
        // absolute difference approaches 1.0.
        1.0 - (diff / embedding1.len() as f32).min(1.0)
    }
}
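
// For orientation, the tests below assume roughly this shape for the cache
// API. This is a sketch inferred from the call sites in this file, not the
// actual definitions, which live in the crate under test:
//
//     pub struct CacheConfig {
//         pub ttl: u64,
//         pub semantic_matching: bool,
//         pub similarity_threshold: f32,
//         pub max_similarity_checks: usize,
//         pub key_prefix: String,
//     }
//
//     impl CachedLLMProvider {
//         pub fn new(
//             inner: Arc<MockLLMProvider>, // any Arc<impl LLMProvider>
//             cache: Arc<redis::Client>,
//             config: CacheConfig,
//             embeddings: Option<Arc<MockEmbeddingService>>,
//         ) -> Self { /* ... */ }
//     }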
// NOTE: the cache tests below are integration-style and expect a Redis server
// listening on redis://127.0.0.1/.

/// With exact matching only, an identical (prompt, messages, model, key) tuple
/// must be served from the cache, leaving the provider call count at 1.
#[tokio::test]
async fn test_exact_cache_hit() {
    let mock_provider = Arc::new(MockLLMProvider::new("Test response"));
    let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
    let config = CacheConfig {
        ttl: 60,
        semantic_matching: false,
        similarity_threshold: 0.95,
        max_similarity_checks: 10,
        key_prefix: "test_cache".to_string(),
    };
    let cached_provider =
        CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None);

    let prompt = "What is the weather?";
    let messages = json!([{"role": "user", "content": prompt}]);
    let model = "test-model";
    let key = "test-key";

    // First call misses the cache and reaches the provider.
    let result1 = cached_provider
        .generate(prompt, &messages, model, key)
        .await
        .unwrap();
    assert_eq!(result1, "Test response");
    assert_eq!(mock_provider.get_call_count(), 1);

    // Second identical call is served from the cache; the count stays at 1.
    let result2 = cached_provider
        .generate(prompt, &messages, model, key)
        .await
        .unwrap();
    assert_eq!(result2, "Test response");
    assert_eq!(mock_provider.get_call_count(), 1);
}
/// With semantic matching enabled, a reworded prompt is only served from the
/// cache when the embedding similarity clears the threshold. The mock
/// embedding scores these two prompts at roughly 0.62, below the 0.8
/// threshold, so the second call still reaches the provider.
#[tokio::test]
async fn test_semantic_cache_miss_below_threshold() {
    let mock_provider = Arc::new(MockLLMProvider::new("Weather is sunny"));
    let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
    let config = CacheConfig {
        ttl: 60,
        semantic_matching: true,
        similarity_threshold: 0.8,
        max_similarity_checks: 10,
        key_prefix: "test_semantic".to_string(),
    };
    let embedding_service = Arc::new(MockEmbeddingService);
    let cached_provider = CachedLLMProvider::new(
        mock_provider.clone(),
        cache_client,
        config,
        Some(embedding_service),
    );
    let messages = json!([{"role": "user", "content": "test"}]);
    let model = "test-model";
    let key = "test-key";

    let result1 = cached_provider
        .generate("What's the weather?", &messages, model, key)
        .await
        .unwrap();
    assert_eq!(result1, "Weather is sunny");
    assert_eq!(mock_provider.get_call_count(), 1);

    // The reworded prompt is neither an exact nor a semantic match, so the
    // provider is called a second time.
    let result2 = cached_provider
        .generate("What is the weather?", &messages, model, key)
        .await
        .unwrap();
    assert_eq!(result2, "Weather is sunny");
    assert_eq!(mock_provider.get_call_count(), 2);
}
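
// A sanity check on the mock embedding math used above: the two prompts hash
// to different constant vectors, and the L1-based score lands below the 0.8
// threshold, which is why the semantic lookup misses. A minimal sketch using
// only the mocks defined in this file:
#[tokio::test]
async fn test_mock_embedding_similarity_below_threshold() {
    let service = MockEmbeddingService;
    let e1 = service.get_embedding("What's the weather?").await.unwrap();
    let e2 = service.get_embedding("What is the weather?").await.unwrap();
    let similarity = service.compute_similarity(&e1, &e2).await;
    assert!(
        similarity < 0.8,
        "similarity {} unexpectedly cleared the threshold",
        similarity
    );
}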
/// The model name is part of the cache key: the same prompt against a
/// different model must miss the cache and reach the provider again.
#[tokio::test]
async fn test_cache_miss_different_model() {
    let mock_provider = Arc::new(MockLLMProvider::new("Response"));
    let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
    let config = CacheConfig::default();
    let cached_provider =
        CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None);
    let prompt = "Same prompt";
    let messages = json!([{"role": "user", "content": prompt}]);
    let key = "test-key";

    let _ = cached_provider
        .generate(prompt, &messages, "model1", key)
        .await
        .unwrap();
    assert_eq!(mock_provider.get_call_count(), 1);

    let _ = cached_provider
        .generate(prompt, &messages, "model2", key)
        .await
        .unwrap();
    assert_eq!(mock_provider.get_call_count(), 2);
}
/// Five distinct prompts create five cache entries; re-requesting three of
/// them records three hits. The cache is cleared first so earlier runs do not
/// skew the counts.
#[tokio::test]
async fn test_cache_statistics() {
    let mock_provider = Arc::new(MockLLMProvider::new("Response"));
    let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
    let config = CacheConfig {
        ttl: 60,
        semantic_matching: false,
        similarity_threshold: 0.95,
        max_similarity_checks: 10,
        key_prefix: "test_stats".to_string(),
    };
    let cached_provider = CachedLLMProvider::new(mock_provider, cache_client, config, None);
    let _ = cached_provider.clear_cache(None).await;

    let messages = json!([]);
    for i in 0..5 {
        let _ = cached_provider
            .generate(&format!("prompt_{}", i), &messages, "model", "key")
            .await;
    }
    for i in 0..3 {
        let _ = cached_provider
            .generate(&format!("prompt_{}", i), &messages, "model", "key")
            .await;
    }

    let stats = cached_provider.get_cache_stats().await.unwrap();
    assert_eq!(stats.total_entries, 5);
    assert_eq!(stats.total_hits, 3);
    assert!(stats.total_size_bytes > 0);
    assert_eq!(stats.model_distribution.get("model"), Some(&5));
}
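
// The assertions above imply a stats struct of roughly this shape. The field
// names are taken from the call sites; the concrete types are a guess, and the
// real definition lives in the crate under test:
//
//     pub struct CacheStats {
//         pub total_entries: u64,
//         pub total_hits: u64,
//         pub total_size_bytes: u64,
//         pub model_distribution: std::collections::HashMap<String, u64>,
//     }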
/// Streaming calls share the same cache: the first stream reaches the
/// provider, while a replay of the same request is served from the cache and
/// the provider count stays at 1.
#[tokio::test]
async fn test_stream_generation_with_cache() {
    let mock_provider = Arc::new(MockLLMProvider::new("Streamed response"));
    let cache_client = Arc::new(redis::Client::open("redis://127.0.0.1/").unwrap());
    let config = CacheConfig {
        ttl: 60,
        semantic_matching: false,
        similarity_threshold: 0.95,
        max_similarity_checks: 10,
        key_prefix: "test_stream".to_string(),
    };
    let cached_provider =
        CachedLLMProvider::new(mock_provider.clone(), cache_client, config, None);
    let prompt = "Stream this";
    let messages = json!([{"role": "user", "content": prompt}]);
    let model = "test-model";
    let key = "test-key";

    // First stream: cache miss, chunks come from the provider.
    let (tx1, mut rx1) = mpsc::channel(100);
    cached_provider
        .generate_stream(prompt, &messages, tx1, model, key)
        .await
        .unwrap();
    let mut result1 = String::new();
    while let Some(chunk) = rx1.recv().await {
        result1.push_str(&chunk);
    }
    assert_eq!(result1, "Streamed response");
    assert_eq!(mock_provider.get_call_count(), 1);

    // Second stream: served from the cache, so the provider is not called again.
    let (tx2, mut rx2) = mpsc::channel(100);
    cached_provider
        .generate_stream(prompt, &messages, tx2, model, key)
        .await
        .unwrap();
    let mut result2 = String::new();
    while let Some(chunk) = rx2.recv().await {
        result2.push_str(&chunk);
    }
    assert!(result2.contains("Streamed response"));
    assert_eq!(mock_provider.get_call_count(), 1);
}
/// Cosine similarity should be 1.0 for identical vectors, 0.0 for orthogonal
/// vectors, and -1.0 for opposite vectors. A single runtime is reused for all
/// three checks rather than building a fresh one per call.
#[test]
fn test_cosine_similarity_calculation() {
    let service = LocalEmbeddingService::new(
        "http://localhost:8082".to_string(),
        "test-model".to_string(),
    );
    let rt = tokio::runtime::Runtime::new().unwrap();

    // Identical vectors.
    let vec1 = vec![0.5, 0.5, 0.5];
    let vec2 = vec![0.5, 0.5, 0.5];
    let similarity = rt.block_on(service.compute_similarity(&vec1, &vec2));
    assert_eq!(similarity, 1.0);

    // Orthogonal vectors.
    let vec3 = vec![1.0, 0.0];
    let vec4 = vec![0.0, 1.0];
    let similarity = rt.block_on(service.compute_similarity(&vec3, &vec4));
    assert_eq!(similarity, 0.0);

    // Opposite vectors.
    let vec5 = vec![1.0, 1.0];
    let vec6 = vec![-1.0, -1.0];
    let similarity = rt.block_on(service.compute_similarity(&vec5, &vec6));
    assert_eq!(similarity, -1.0);
}
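
// For reference, the expected values above follow from the standard cosine
// formula, cos(a, b) = (a . b) / (|a| * |b|). Below is a free-standing sketch
// of that formula (not LocalEmbeddingService's actual implementation), checked
// with a small tolerance to absorb f32 rounding in the square roots:
fn reference_cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0 // treat similarity against a zero vector as 0
    } else {
        dot / (norm_a * norm_b)
    }
}

#[test]
fn test_reference_cosine_matches_expectations() {
    let close = |a: f32, b: f32| (a - b).abs() < 1e-6;
    assert!(close(reference_cosine(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5]), 1.0));
    assert!(close(reference_cosine(&[1.0, 0.0], &[0.0, 1.0]), 0.0));
    assert!(close(reference_cosine(&[1.0, 1.0], &[-1.0, -1.0]), -1.0));
}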