new(all): Initial import.
This commit is contained in:
parent
7f384469a9
commit
28cc734340
27 changed files with 573 additions and 1148 deletions
2
.cargo/config.toml
Normal file
2
.cargo/config.toml
Normal file
|
@ -0,0 +1,2 @@
|
|||
[term]
|
||||
quiet = true
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1 +1,2 @@
|
|||
target
|
||||
target
|
||||
.env
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2407,6 +2407,7 @@ dependencies = [
|
|||
"async-trait",
|
||||
"chromiumoxide",
|
||||
"fantoccini",
|
||||
"futures-util",
|
||||
"gb-core",
|
||||
"headless_chrome",
|
||||
"image",
|
||||
|
|
|
@ -16,6 +16,12 @@ members = [
|
|||
"gb-image", # Image processing capabilities
|
||||
]
|
||||
|
||||
# [workspace.lints.rust]
|
||||
# unused_imports = "allow"
|
||||
# dead_code = "allow"
|
||||
# unused_variables = "allow"
|
||||
# dependency_on_unit_never_type_fallback = "allow"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
|
|
@ -46,7 +46,7 @@ async fn handle_ws_connection(
|
|||
if let Ok(text) = msg.to_text() {
|
||||
if let Ok(envelope) = serde_json::from_str::<MessageEnvelope>(text) {
|
||||
let mut processor = state.message_processor.lock().await;
|
||||
if let Err(e) = processor.process_messages(vec![envelope]).await {
|
||||
if let Err(e) = processor.process_messages().await {
|
||||
error!("Failed to process message: {}", e);
|
||||
}
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ async fn send_message(
|
|||
};
|
||||
|
||||
let mut processor = state.message_processor.lock().await;
|
||||
processor.process_messages(vec![envelope.clone()]).await
|
||||
processor.process_messages().await
|
||||
.map_err(|e| Error::internal(format!("Failed to process message: {}", e)))?;
|
||||
|
||||
Ok(Json(MessageId(envelope.id)))
|
||||
|
@ -114,4 +114,4 @@ async fn join_room(
|
|||
Json(_user_id): Json<Uuid>,
|
||||
) -> Result<Json<Connection>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -54,3 +54,4 @@ tokio-test = "0.4"
|
|||
mockall = "0.12"
|
||||
axum-extra = { version = "0.7" }
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono", "json"] }
|
||||
|
||||
|
|
20
gb-auth/src/error.rs
Normal file
20
gb-auth/src/error.rs
Normal file
|
@ -0,0 +1,20 @@
|
|||
use gb_core::Error as CoreError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthError {
|
||||
#[error("Invalid token")]
|
||||
InvalidToken,
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(#[from] redis::RedisError),
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl From<CoreError> for AuthError {
|
||||
fn from(err: CoreError) -> Self {
|
||||
AuthError::Internal(err.to_string())
|
||||
}
|
||||
}
|
24
gb-auth/src/errors.rs
Normal file
24
gb-auth/src/errors.rs
Normal file
|
@ -0,0 +1,24 @@
|
|||
use gb_core::Error as CoreError;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthError {
|
||||
#[error("Invalid token")]
|
||||
InvalidToken,
|
||||
#[error("Invalid credentials")]
|
||||
InvalidCredentials,
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(#[from] redis::RedisError),
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl From<CoreError> for AuthError {
|
||||
fn from(err: CoreError) -> Self {
|
||||
match err {
|
||||
CoreError { .. } => AuthError::Internal(err.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,9 +3,7 @@ use axum::{
|
|||
middleware::Next,
|
||||
body::Body,
|
||||
};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_extra::headers::{Authorization, authorization::Bearer};
|
||||
use gb_core::User;
|
||||
use headers::{Authorization, authorization::Bearer};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::AuthError;
|
||||
|
@ -17,10 +15,10 @@ struct Claims {
|
|||
}
|
||||
|
||||
pub async fn auth_middleware(
|
||||
TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
|
||||
auth: Authorization<Bearer>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, AuthError> {
|
||||
) -> Result<Response<Body>, AuthError> {
|
||||
let token = auth.token();
|
||||
let key = DecodingKey::from_secret(b"secret");
|
||||
let validation = Validation::default();
|
||||
|
@ -32,4 +30,4 @@ pub async fn auth_middleware(
|
|||
}
|
||||
Err(_) => Err(AuthError::InvalidToken),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,3 +1,2 @@
|
|||
mod user;
|
||||
|
||||
pub mod user;
|
||||
pub use user::*;
|
|
@ -1,3 +1,6 @@
|
|||
use gb_core::{Result, Error};
|
||||
use crate::models::{LoginRequest, LoginResponse};
|
||||
use crate::models::user::DbUser;
|
||||
use std::sync::Arc;
|
||||
use sqlx::PgPool;
|
||||
use argon2::{
|
||||
|
@ -6,12 +9,6 @@ use argon2::{
|
|||
};
|
||||
use rand::rngs::OsRng;
|
||||
|
||||
use crate::{
|
||||
models::{LoginRequest, LoginResponse, User},
|
||||
AuthError,
|
||||
Result,
|
||||
};
|
||||
|
||||
pub struct AuthService {
|
||||
db: Arc<PgPool>,
|
||||
jwt_secret: String,
|
||||
|
@ -30,12 +27,17 @@ impl AuthService {
|
|||
pub async fn login(&self, request: LoginRequest) -> Result<LoginResponse> {
|
||||
let user = sqlx::query_as!(
|
||||
DbUser,
|
||||
"SELECT * FROM users WHERE email = $1",
|
||||
r#"
|
||||
SELECT id, email, password_hash, role
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
"#,
|
||||
request.email
|
||||
)
|
||||
.fetch_optional(&*self.db)
|
||||
.await?
|
||||
.ok_or(AuthError::InvalidCredentials)?;
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))?
|
||||
.ok_or_else(|| Error::internal("Invalid credentials"))?;
|
||||
|
||||
self.verify_password(&request.password, &user.password_hash)?;
|
||||
|
||||
|
@ -56,19 +58,19 @@ impl AuthService {
|
|||
argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map(|hash| hash.to_string())
|
||||
.map_err(|e| AuthError::Internal(e.to_string()))
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
}
|
||||
|
||||
fn verify_password(&self, password: &str, hash: &str) -> Result<()> {
|
||||
let parsed_hash = PasswordHash::new(hash)
|
||||
.map_err(|e| AuthError::Internal(e.to_string()))?;
|
||||
.map_err(|e| Error::internal(e.to_string()))?;
|
||||
|
||||
Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.map_err(|_| AuthError::InvalidCredentials)
|
||||
.map_err(|_| Error::internal("Invalid credentials"))
|
||||
}
|
||||
|
||||
fn generate_token(&self, user: &User) -> Result<String> {
|
||||
fn generate_token(&self, user: &DbUser) -> Result<String> {
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use chrono::{Utc, Duration};
|
||||
|
@ -94,6 +96,6 @@ impl AuthService {
|
|||
&claims,
|
||||
&EncodingKey::from_secret(self.jwt_secret.as_bytes()),
|
||||
)
|
||||
.map_err(|e| AuthError::Internal(e.to_string()))
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@ license.workspace = true
|
|||
gb-core = { path = "../gb-core" }
|
||||
image = { version = "0.24", features = ["webp", "jpeg", "png", "gif"] }
|
||||
chromiumoxide = { version = "0.5", features = ["tokio-runtime"] }
|
||||
futures-util = "0.3"
|
||||
async-trait.workspace = true
|
||||
tokio.workspace = true
|
||||
serde.workspace = true
|
||||
|
@ -24,4 +25,4 @@ async-recursion = "1.0"
|
|||
[dev-dependencies]
|
||||
rstest.workspace = true
|
||||
tokio-test = "0.4"
|
||||
mock_instant = "0.2"
|
||||
mock_instant = "0.2"
|
|
@ -1,36 +1,4 @@
|
|||
pub mod web;
|
||||
pub mod process;
|
||||
mod web;
|
||||
|
||||
pub use web::{WebAutomation, Element};
|
||||
pub use process::ProcessAutomation;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gb_core::Result;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_automation_integration() -> Result<()> {
|
||||
// Initialize automation components
|
||||
let web = WebAutomation::new().await?;
|
||||
let dir = tempdir()?;
|
||||
let process = ProcessAutomation::new(dir.path());
|
||||
|
||||
// Test web automation
|
||||
let page = web.new_page().await?;
|
||||
web.navigate(&page, "https://example.com").await?;
|
||||
let screenshot = web.screenshot(&page, "test.png").await?;
|
||||
|
||||
// Test process automation
|
||||
let output = process.execute("echo", &["Test output"]).await?;
|
||||
assert!(output.contains("Test output"));
|
||||
|
||||
// Test process spawning and cleanup
|
||||
let id = process.spawn("sleep", &["1"]).await?;
|
||||
process.kill(id).await?;
|
||||
process.cleanup().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
pub use chromiumoxide::element::Element;
|
||||
pub use web::WebAutomation;
|
|
@ -1,73 +1,61 @@
|
|||
use chromiumoxide::browser::{Browser, BrowserConfig};
|
||||
use chromiumoxide::element::Element;
|
||||
use chromiumoxide::{Browser, Element};
|
||||
use chromiumoxide::page::Page;
|
||||
use chromiumoxide::browser::BrowserConfig;
|
||||
use futures_util::StreamExt;
|
||||
use gb_core::{Error, Result};
|
||||
use tracing::instrument;
|
||||
use std::time::Duration;
|
||||
|
||||
pub struct WebAutomation {
|
||||
browser: Browser,
|
||||
}
|
||||
|
||||
impl WebAutomation {
|
||||
#[instrument]
|
||||
pub async fn new() -> Result<Self> {
|
||||
let config = BrowserConfig::builder()
|
||||
.build()
|
||||
.map_err(|e| Error::internal(e.to_string()))?;
|
||||
|
||||
let (browser, mut handler) = Browser::launch(config)
|
||||
|
||||
let (browser, handler) = Browser::launch(config)
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))?;
|
||||
|
||||
// Spawn the handler in the background
|
||||
tokio::spawn(async move {
|
||||
while let Some(h) = handler.next().await {
|
||||
if let Err(e) = h {
|
||||
tracing::error!("Browser handler error: {}", e);
|
||||
}
|
||||
}
|
||||
handler.for_each(|_| async {}).await;
|
||||
});
|
||||
|
||||
Ok(Self { browser })
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn new_page(&self) -> Result<Page> {
|
||||
let params = chromiumoxide::cdp::browser_protocol::target::CreateTarget::new()
|
||||
.url("about:blank");
|
||||
|
||||
self.browser.new_page(params)
|
||||
self.browser
|
||||
.new_page("about:blank")
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn navigate(&self, page: &Page, url: &str) -> Result<()> {
|
||||
page.goto(url)
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
.map_err(|e| Error::internal(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn get_element(&self, page: &Page, selector: &str) -> Result<Element> {
|
||||
page.find_element(selector)
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn screenshot(&self, page: &Page, _path: &str) -> Result<Vec<u8>> {
|
||||
let params = chromiumoxide::cdp::browser_protocol::page::CaptureScreenshot::new();
|
||||
pub async fn take_screenshot(&self, page: &Page) -> Result<Vec<u8>> {
|
||||
let params = chromiumoxide::page::ScreenshotParams::builder().build();
|
||||
|
||||
page.screenshot(params)
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn wait_for_selector(&self, page: &Page, selector: &str) -> Result<()> {
|
||||
page.find_element(selector)
|
||||
.await
|
||||
.map_err(|e| Error::internal(e.to_string()))?;
|
||||
Ok(())
|
||||
pub async fn find_element(&self, page: &Page, selector: &str, timeout: Duration) -> Result<Element> {
|
||||
tokio::time::timeout(
|
||||
timeout,
|
||||
page.find_element(selector)
|
||||
)
|
||||
.await
|
||||
.map_err(|_| Error::internal("Timeout waiting for element"))?
|
||||
.map_err(|e| Error::internal(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,204 +1,39 @@
|
|||
use gb_core::{Result, Error};
|
||||
use image::{
|
||||
DynamicImage, Rgba,
|
||||
};
|
||||
use imageproc::{
|
||||
drawing::draw_text_mut,
|
||||
};
|
||||
use rusttype::{Font, Scale};
|
||||
use std::path::Path;
|
||||
use tracing::instrument;
|
||||
use std::convert::TryInto;
|
||||
use gb_core::{Error, Result};
|
||||
use image::{DynamicImage, ImageOutputFormat};
|
||||
use std::io::Cursor;
|
||||
use tesseract::Tesseract;
|
||||
use tempfile::NamedTempFile;
|
||||
use std::io::Write;
|
||||
|
||||
pub struct ProcessingOptions {
|
||||
pub crop: Option<CropParams>,
|
||||
pub watermark: Option<DynamicImage>,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
}
|
||||
|
||||
pub struct CropParams {
|
||||
pub x: u32,
|
||||
pub y: u32,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
pub struct ImageProcessor {
|
||||
default_font: Font<'static>,
|
||||
}
|
||||
pub struct ImageProcessor;
|
||||
|
||||
impl ImageProcessor {
|
||||
pub fn new() -> Result<Self> {
|
||||
let font_data = include_bytes!("../assets/DejaVuSans.ttf");
|
||||
let font = Font::try_from_bytes(font_data)
|
||||
.ok_or_else(|| Error::internal("Failed to load font"))?;
|
||||
|
||||
Ok(Self {
|
||||
default_font: font,
|
||||
})
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn process_image(&self, mut image: DynamicImage, options: &ProcessingOptions) -> Result<DynamicImage> {
|
||||
if let Some(crop) = &options.crop {
|
||||
let cropped = image.crop_imm(
|
||||
crop.x,
|
||||
crop.y,
|
||||
crop.width,
|
||||
crop.height
|
||||
);
|
||||
image = cropped;
|
||||
}
|
||||
|
||||
if let Some(watermark) = &options.watermark {
|
||||
let x: i64 = options.x.try_into().map_err(|_| Error::internal("Invalid x coordinate"))?;
|
||||
let y: i64 = options.y.try_into().map_err(|_| Error::internal("Invalid y coordinate"))?;
|
||||
image::imageops::overlay(&mut image, watermark, x, y);
|
||||
}
|
||||
|
||||
Ok(image)
|
||||
}
|
||||
|
||||
#[instrument(skip(self, image_data))]
|
||||
pub fn load_image(&self, image_data: &[u8]) -> Result<DynamicImage> {
|
||||
image::load_from_memory(image_data)
|
||||
.map_err(|e| Error::internal(format!("Failed to load image: {}", e)))
|
||||
}
|
||||
|
||||
#[instrument(skip(self, image))]
|
||||
pub fn save_image(&self, image: &DynamicImage, path: &Path) -> Result<()> {
|
||||
image.save(path)
|
||||
.map_err(|e| Error::internal(format!("Failed to save image: {}", e)))
|
||||
}
|
||||
|
||||
#[instrument(skip(self, image))]
|
||||
pub fn crop(&self, image: &DynamicImage, x: u32, y: u32, width: u32, height: u32) -> Result<DynamicImage> {
|
||||
Ok(image.crop_imm(x, y, width, height))
|
||||
}
|
||||
|
||||
#[instrument(skip(self, image))]
|
||||
pub fn add_text(
|
||||
&self,
|
||||
image: &mut DynamicImage,
|
||||
text: &str,
|
||||
x: i32,
|
||||
y: i32,
|
||||
scale: f32,
|
||||
color: Rgba<u8>,
|
||||
) -> Result<()> {
|
||||
let scale = Scale::uniform(scale);
|
||||
|
||||
let mut img = image.to_rgba8();
|
||||
draw_text_mut(
|
||||
&mut img,
|
||||
color,
|
||||
x,
|
||||
y,
|
||||
scale,
|
||||
&self.default_font,
|
||||
text,
|
||||
);
|
||||
|
||||
*image = DynamicImage::ImageRgba8(img);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(self, image))]
|
||||
pub fn add_watermark(
|
||||
&self,
|
||||
image: &mut DynamicImage,
|
||||
watermark: &DynamicImage,
|
||||
x: u32,
|
||||
y: u32,
|
||||
) -> Result<()> {
|
||||
let x: i64 = x.try_into().map_err(|_| Error::internal("Invalid x coordinate"))?;
|
||||
let y: i64 = y.try_into().map_err(|_| Error::internal("Invalid y coordinate"))?;
|
||||
image::imageops::overlay(image, watermark, x, y);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(self, image))]
|
||||
pub fn extract_text(&self, image: &DynamicImage) -> Result<String> {
|
||||
use tesseract::Tesseract;
|
||||
|
||||
let temp_file = tempfile::NamedTempFile::new()
|
||||
pub async fn extract_text(&self, image: &DynamicImage) -> Result<String> {
|
||||
// Create a temporary file
|
||||
let mut temp_file = NamedTempFile::new()
|
||||
.map_err(|e| Error::internal(format!("Failed to create temp file: {}", e)))?;
|
||||
|
||||
image.save(&temp_file)
|
||||
.map_err(|e| Error::internal(format!("Failed to save temp image: {}", e)))?;
|
||||
|
||||
// Convert image to PNG and write to temp file
|
||||
let mut cursor = Cursor::new(Vec::new());
|
||||
image.write_to(&mut cursor, ImageOutputFormat::Png)
|
||||
.map_err(|e| Error::internal(format!("Failed to encode image: {}", e)))?;
|
||||
|
||||
temp_file.write_all(&cursor.into_inner())
|
||||
.map_err(|e| Error::internal(format!("Failed to write to temp file: {}", e)))?;
|
||||
|
||||
let mut api = Tesseract::new(None, Some("eng"))
|
||||
// Initialize Tesseract and process image
|
||||
let api = Tesseract::new(None, Some("eng"))
|
||||
.map_err(|e| Error::internal(format!("Failed to initialize Tesseract: {}", e)))?;
|
||||
|
||||
api.set_image(temp_file.path().to_str().unwrap())
|
||||
.map_err(|e| Error::internal(format!("Failed to set image: {}", e)))?;
|
||||
|
||||
api.recognize()
|
||||
.map_err(|e| Error::internal(format!("Failed to recognize text: {}", e)))?;
|
||||
|
||||
api.get_text()
|
||||
.map_err(|e| Error::internal(format!("Failed to set image: {}", e)))?
|
||||
.recognize()
|
||||
.map_err(|e| Error::internal(format!("Failed to recognize text: {}", e)))?
|
||||
.get_text()
|
||||
.map_err(|e| Error::internal(format!("Failed to get text: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[fixture]
|
||||
fn processor() -> ImageProcessor {
|
||||
ImageProcessor::new().unwrap()
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
fn test_image() -> DynamicImage {
|
||||
DynamicImage::new_rgb8(100, 100)
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_resize(processor: ImageProcessor, test_image: DynamicImage) {
|
||||
let resized = processor.resize(&test_image, 50, 50);
|
||||
assert_eq!(resized.width(), 50);
|
||||
assert_eq!(resized.height(), 50);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_crop(processor: ImageProcessor, test_image: DynamicImage) -> Result<()> {
|
||||
let cropped = processor.crop(&test_image, 25, 25, 50, 50)?;
|
||||
assert_eq!(cropped.width(), 50);
|
||||
assert_eq!(cropped.height(), 50);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_add_text(processor: ImageProcessor, mut test_image: DynamicImage) -> Result<()> {
|
||||
processor.add_text(
|
||||
&mut test_image,
|
||||
"Test",
|
||||
10,
|
||||
10,
|
||||
12.0,
|
||||
Rgba([255, 255, 255, 255]),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_extract_text(processor: ImageProcessor, mut test_image: DynamicImage) -> Result<()> {
|
||||
processor.add_text(
|
||||
&mut test_image,
|
||||
"Test OCR",
|
||||
10,
|
||||
10,
|
||||
24.0,
|
||||
Rgba([0, 0, 0, 255]),
|
||||
)?;
|
||||
|
||||
let text = processor.extract_text(&test_image)?;
|
||||
assert!(text.contains("Test OCR"));
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,112 +1,56 @@
|
|||
use gb_core::{Result, Error};
|
||||
use opus::{Decoder, Encoder};
|
||||
use opus::{Decoder, Encoder, Channels, Application};
|
||||
use std::io::Cursor;
|
||||
use tracing::{instrument, error};
|
||||
use opus::{Encoder, Decoder, Application, Channels};
|
||||
|
||||
pub struct AudioProcessor {
|
||||
sample_rate: i32,
|
||||
channels: i32,
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
sample_rate: u32,
|
||||
channels: Channels,
|
||||
}
|
||||
|
||||
impl AudioProcessor {
|
||||
pub fn new(sample_rate: i32, channels: i32) -> Self {
|
||||
Self {
|
||||
pub fn new(sample_rate: u32, channels: Channels) -> Result<Self> {
|
||||
let encoder = Encoder::new(
|
||||
sample_rate,
|
||||
channels,
|
||||
}
|
||||
Application::Audio
|
||||
).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?;
|
||||
|
||||
let decoder = Decoder::new(
|
||||
sample_rate,
|
||||
channels
|
||||
).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?;
|
||||
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
sample_rate,
|
||||
channels,
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(self, input))]
|
||||
pub fn encode(&self, input: &[i16]) -> Result<Vec<u8>> {
|
||||
let mut encoder = Encoder::new(
|
||||
self.sample_rate,
|
||||
if self.channels == 1 {
|
||||
opus::Channels::Mono
|
||||
} else {
|
||||
opus::Channels::Stereo
|
||||
},
|
||||
opus::Application::Voip,
|
||||
).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?;
|
||||
u32::try_from(self.sample_rate).map_err(|e| Error::internal(format!("Invalid sample rate: {}", e)))?,
|
||||
Channels::Mono,
|
||||
Application::Voip
|
||||
).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?;
|
||||
|
||||
let mut output = vec![0u8; 1024];
|
||||
let encoded_len = encoder.encode(input, &mut output)
|
||||
.map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))?;
|
||||
let encoded_size = self.encoder.encode(
|
||||
input,
|
||||
&mut output
|
||||
).map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))?;
|
||||
|
||||
output.truncate(encoded_len);
|
||||
output.truncate(encoded_size);
|
||||
Ok(output)
|
||||
encoder.encode(input)
|
||||
.map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))
|
||||
}
|
||||
|
||||
#[instrument(skip(self, input))]
|
||||
pub fn decode(&self, input: &[u8]) -> Result<Vec<i16>> {
|
||||
let mut decoder = Decoder::new(
|
||||
self.sample_rate,
|
||||
if self.channels == 1 {
|
||||
opus::Channels::Mono
|
||||
} else {
|
||||
opus::Channels::Stereo
|
||||
},
|
||||
).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?;
|
||||
u32::try_from(self.sample_rate).map_err(|e| Error::internal(format!("Invalid sample rate: {}", e)))?,
|
||||
Channels::Mono
|
||||
).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?;
|
||||
let max_size = (self.sample_rate as usize / 50) * self.channels.count();
|
||||
let mut output = vec![0i16; max_size];
|
||||
|
||||
let mut output = vec![0i16; 1024];
|
||||
let decoded_len = decoder.decode(input, &mut output, false)
|
||||
.map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))?;
|
||||
let decoded_size = self.decoder.decode(
|
||||
Some(input),
|
||||
&mut output,
|
||||
false
|
||||
).map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))?;
|
||||
|
||||
output.truncate(decoded_len);
|
||||
output.truncate(decoded_size);
|
||||
Ok(output)
|
||||
decoder.decode(input)
|
||||
.map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::*;
|
||||
|
||||
#[fixture]
|
||||
fn audio_processor() -> AudioProcessor {
|
||||
AudioProcessor::new(48000, 2)
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
fn test_audio() -> Vec<i16> {
|
||||
// Generate 1 second of 440Hz sine wave
|
||||
let sample_rate = 48000;
|
||||
let frequency = 440.0;
|
||||
let duration = 1.0;
|
||||
|
||||
(0..sample_rate)
|
||||
.flat_map(|i| {
|
||||
let t = i as f32 / sample_rate as f32;
|
||||
let value = (2.0 * std::f32::consts::PI * frequency * t).sin();
|
||||
let sample = (value * i16::MAX as f32) as i16;
|
||||
vec![sample, sample] // Stereo
|
||||
vec![sample, sample]
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
fn test_encode_decode(audio_processor: AudioProcessor, test_audio: Vec<i16>) {
|
||||
let encoded = audio_processor.encode(&test_audio).unwrap();
|
||||
let decoded = audio_processor.decode(&encoded).unwrap();
|
||||
|
||||
// Verify basic properties
|
||||
assert!(!encoded.is_empty());
|
||||
assert!(!decoded.is_empty());
|
||||
|
||||
// Opus is lossy, so we can't compare exact values
|
||||
// But we can verify the length is the same
|
||||
assert_eq!(decoded.len(), test_audio.len());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,44 +1,5 @@
|
|||
pub mod webrtc;
|
||||
pub mod processor;
|
||||
pub mod audio;
|
||||
mod processor;
|
||||
mod webrtc;
|
||||
|
||||
pub use webrtc::WebRTCService;
|
||||
pub use processor::{MediaProcessor, MediaMetadata};
|
||||
pub use audio::AudioProcessor;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_media_integration() {
|
||||
// Initialize services
|
||||
let webrtc = WebRTCService::new(vec!["stun:stun.l.google.com:19302".to_string()]);
|
||||
let processor = MediaProcessor::new().unwrap();
|
||||
let audio = AudioProcessor::new(48000, 2);
|
||||
|
||||
// Test room creation and joining
|
||||
let room_id = Uuid::new_v4();
|
||||
let user_id = Uuid::new_v4();
|
||||
|
||||
let connection = webrtc.join_room(room_id, user_id).await.unwrap();
|
||||
assert_eq!(connection.room_id, room_id);
|
||||
assert_eq!(connection.user_id, user_id);
|
||||
|
||||
// Test media processing
|
||||
let input_path = PathBuf::from("test_data/test.mp4");
|
||||
if input_path.exists() {
|
||||
let metadata = processor.extract_metadata(input_path.clone()).await.unwrap();
|
||||
assert!(metadata.width.is_some());
|
||||
assert!(metadata.height.is_some());
|
||||
}
|
||||
|
||||
// Test audio processing
|
||||
let test_audio: Vec<i16> = (0..1024).map(|i| i as i16).collect();
|
||||
let encoded = audio.encode(&test_audio).unwrap();
|
||||
let decoded = audio.decode(&encoded).unwrap();
|
||||
assert!(!decoded.is_empty());
|
||||
}
|
||||
}
|
||||
pub use processor::MediaProcessor;
|
||||
pub use webrtc::WebRTCService;
|
|
@ -1,41 +1,45 @@
|
|||
use gstreamer::{self as gst, prelude::*};
|
||||
use gstreamer::prelude::{
|
||||
ElementExt,
|
||||
GstBinExtManual,
|
||||
GstObjectExt,
|
||||
};
|
||||
use gb_core::{Result, Error};
|
||||
use gstreamer as gst;
|
||||
use gstreamer::prelude::*;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{error, instrument};
|
||||
|
||||
pub struct MediaProcessor {
|
||||
pipeline: gst::Pipeline,
|
||||
}
|
||||
|
||||
impl MediaProcessor {
|
||||
pub fn new() -> Result<Self> {
|
||||
gst::init().map_err(|e| Error::internal(format!("Failed to initialize GStreamer: {}", e)))?;
|
||||
|
||||
let pipeline = gst::Pipeline::new(None);
|
||||
|
||||
Ok(Self {
|
||||
pipeline,
|
||||
})
|
||||
|
||||
let pipeline = gst::Pipeline::new()
|
||||
.map_err(|e| Error::internal(format!("Failed to create pipeline: {}", e)))?;
|
||||
|
||||
Ok(Self { pipeline })
|
||||
}
|
||||
|
||||
|
||||
fn setup_pipeline(&mut self) -> Result<()> {
|
||||
self.pipeline.set_state(gst::State::Playing)
|
||||
.map_err(|e| Error::internal(format!("Failed to start pipeline: {}", e)))?;
|
||||
|
||||
let bus = self.pipeline.bus().expect("Pipeline without bus");
|
||||
|
||||
for msg in bus.iter_timed(gst::ClockTime::NONE) {
|
||||
use gst::MessageView;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_messages(&self) -> Result<()> {
|
||||
let bus = self.pipeline.bus().unwrap();
|
||||
|
||||
while let Some(msg) = bus.timed_pop(gst::ClockTime::from_seconds(1)) {
|
||||
match msg.view() {
|
||||
MessageView::Error(err) => {
|
||||
gst::MessageView::Error(err) => {
|
||||
error!("Error from {:?}: {} ({:?})",
|
||||
err.src().map(|s| s.path_string()),
|
||||
err.error(),
|
||||
err.error(),
|
||||
err.debug()
|
||||
);
|
||||
return Err(Error::internal(format!("Pipeline error: {}", err.error())));
|
||||
}
|
||||
MessageView::Eos(_) => break,
|
||||
_ => (),
|
||||
gst::MessageView::Eos(_) => break,
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,85 +51,32 @@ impl MediaProcessor {
|
|||
|
||||
#[instrument(skip(self, input_path, output_path))]
|
||||
pub async fn transcode(
|
||||
&self,
|
||||
&mut self,
|
||||
input_path: PathBuf,
|
||||
output_path: PathBuf,
|
||||
format: &str,
|
||||
format: &str
|
||||
) -> Result<()> {
|
||||
let src = gst::ElementFactory::make("filesrc")
|
||||
.property("location", input_path.to_str().unwrap())
|
||||
.build()
|
||||
let source = gst::ElementFactory::make("filesrc")
|
||||
.map_err(|e| Error::internal(format!("Failed to create source element: {}", e)))?;
|
||||
source.set_property("location", input_path.to_str().unwrap());
|
||||
|
||||
let sink = gst::ElementFactory::make("filesink")
|
||||
.property("location", output_path.to_str().unwrap())
|
||||
.build()
|
||||
.map_err(|e| Error::internal(format!("Failed to create sink element: {}", e)))?;
|
||||
sink.set_property("location", output_path.to_str().unwrap());
|
||||
|
||||
let decoder = match format {
|
||||
"h264" => gst::ElementFactory::make("h264parse").build(),
|
||||
"opus" => gst::ElementFactory::make("opusparse").build(),
|
||||
_ => return Err(Error::InvalidInput(format!("Unsupported format: {}", format))),
|
||||
let decoder = match format.to_lowercase().as_str() {
|
||||
"mp4" => gst::ElementFactory::make("qtdemux"),
|
||||
"webm" => gst::ElementFactory::make("matroskademux"),
|
||||
_ => return Err(Error::internal(format!("Unsupported format: {}", format)))
|
||||
}.map_err(|e| Error::internal(format!("Failed to create decoder: {}", e)))?;
|
||||
|
||||
self.pipeline.add_many(&[&src, &decoder, &sink])
|
||||
self.pipeline.add_many(&[&source, &decoder, &sink])
|
||||
.map_err(|e| Error::internal(format!("Failed to add elements: {}", e)))?;
|
||||
|
||||
gst::Element::link_many(&[&src, &decoder, &sink])
|
||||
gst::Element::link_many(&[&source, &decoder, &sink])
|
||||
.map_err(|e| Error::internal(format!("Failed to link elements: {}", e)))?;
|
||||
|
||||
self.setup_pipeline()?;
|
||||
|
||||
Ok(())
|
||||
self.process_messages()
|
||||
}
|
||||
|
||||
#[instrument(skip(self, input_path))]
|
||||
pub async fn extract_metadata(&self, input_path: PathBuf) -> Result<MediaMetadata> {
|
||||
let src = gst::ElementFactory::make("filesrc")
|
||||
.property("location", input_path.to_str().unwrap())
|
||||
.build()
|
||||
.map_err(|e| Error::internal(format!("Failed to create source element: {}", e)))?;
|
||||
|
||||
let decodebin = gst::ElementFactory::make("decodebin").build()
|
||||
.map_err(|e| Error::internal(format!("Failed to create decodebin: {}", e)))?;
|
||||
|
||||
self.pipeline.add_many(&[&src, &decodebin])
|
||||
.map_err(|e| Error::internal(format!("Failed to add elements: {}", e)))?;
|
||||
|
||||
gst::Element::link_many(&[&src, &decodebin])
|
||||
.map_err(|e| Error::internal(format!("Failed to link elements: {}", e)))?;
|
||||
|
||||
let mut metadata = MediaMetadata::default();
|
||||
|
||||
decodebin.connect_pad_added(move |_, pad| {
|
||||
let caps = pad.current_caps().unwrap();
|
||||
let structure = caps.structure(0).unwrap();
|
||||
|
||||
match structure.name() {
|
||||
"video/x-raw" => {
|
||||
if let Ok(width) = structure.get::<i32>("width") {
|
||||
metadata.width = Some(width);
|
||||
}
|
||||
if let Ok(height) = structure.get::<i32>("height") {
|
||||
metadata.height = Some(height);
|
||||
}
|
||||
if let Ok(framerate) = structure.get::<gst::Fraction>("framerate") {
|
||||
metadata.framerate = Some(framerate.numer() as f64 / framerate.denom() as f64);
|
||||
}
|
||||
},
|
||||
"audio/x-raw" => {
|
||||
if let Ok(channels) = structure.get::<i32>("channels") {
|
||||
metadata.channels = Some(channels);
|
||||
}
|
||||
if let Ok(rate) = structure.get::<i32>("rate") {
|
||||
metadata.sample_rate = Some(rate);
|
||||
}
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
});
|
||||
|
||||
self.setup_pipeline()?;
|
||||
Ok(metadata)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,163 +1,86 @@
|
|||
use async_trait::async_trait;
|
||||
use gb_core::{
|
||||
models::*,
|
||||
traits::*,
|
||||
Result, Error, Connection,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
use gb_core::{Result, Error};
|
||||
use webrtc::{
|
||||
api::APIBuilder,
|
||||
ice_transport::ice_server::RTCIceServer,
|
||||
peer_connection::configuration::RTCConfiguration,
|
||||
peer_connection::peer_connection_state::RTCPeerConnectionState,
|
||||
peer_connection::RTCPeerConnection,
|
||||
track::track_remote::TrackRemote,
|
||||
rtp::rtp_receiver::RTCRtpReceiver,
|
||||
rtp::rtp_transceiver::RTCRtpTransceiver,
|
||||
api::{API, APIBuilder},
|
||||
peer_connection::{
|
||||
RTCPeerConnection,
|
||||
peer_connection_state::RTCPeerConnectionState,
|
||||
configuration::RTCConfiguration,
|
||||
},
|
||||
track::{
|
||||
track_local::TrackLocal,
|
||||
track_remote::TrackRemote,
|
||||
},
|
||||
};
|
||||
use tracing::{instrument, error};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::instrument;
|
||||
use std::sync::Arc;
|
||||
use chrono::Utc;
|
||||
|
||||
pub struct WebRTCService {
|
||||
config: RTCConfiguration,
|
||||
api: Arc<API>,
|
||||
peer_connections: Vec<Arc<RTCPeerConnection>>,
|
||||
}
|
||||
|
||||
impl WebRTCService {
|
||||
pub fn new(ice_servers: Vec<String>) -> Self {
|
||||
let mut config = RTCConfiguration::default();
|
||||
config.ice_servers = ice_servers
|
||||
.into_iter()
|
||||
.map(|url| RTCIceServer {
|
||||
urls: vec![url],
|
||||
..Default::default()
|
||||
})
|
||||
.collect();
|
||||
pub fn new() -> Result<Self> {
|
||||
let api = APIBuilder::new().build();
|
||||
|
||||
Self { config }
|
||||
Ok(Self {
|
||||
api: Arc::new(api),
|
||||
peer_connections: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_peer_connection(&self) -> Result<RTCPeerConnection> {
|
||||
let api = APIBuilder::new().build();
|
||||
pub async fn create_peer_connection(&mut self) -> Result<Arc<RTCPeerConnection>> {
|
||||
let config = RTCConfiguration::default();
|
||||
|
||||
let peer_connection = api.new_peer_connection(self.config.clone())
|
||||
let peer_connection = self.api.new_peer_connection(config)
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("Failed to create peer connection: {}", e)))?;
|
||||
|
||||
Ok(peer_connection)
|
||||
let pc_arc = Arc::new(peer_connection);
|
||||
self.peer_connections.push(pc_arc.clone());
|
||||
|
||||
Ok(pc_arc)
|
||||
}
|
||||
|
||||
async fn handle_track(&self, track: Arc<TrackRemote>, receiver: Arc<RTCRtpReceiver>, transceiver: Arc<RTCRtpTransceiver>) {
|
||||
tracing::info!(
|
||||
"Received track: {} {}",
|
||||
track.kind(),
|
||||
track.id()
|
||||
);
|
||||
}
|
||||
pub async fn add_track(
|
||||
&self,
|
||||
pc: &RTCPeerConnection,
|
||||
track: Arc<dyn TrackLocal + Send + Sync>,
|
||||
) -> Result<()> {
|
||||
pc.add_track(track)
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("Failed to add track: {}", e)))?;
|
||||
|
||||
async fn create_connection(&self) -> Result<Connection> {
|
||||
Ok(Connection {
|
||||
id: Uuid::new_v4(),
|
||||
connected_at: Utc::now(),
|
||||
ice_servers: self.config.ice_servers.clone(),
|
||||
metadata: serde_json::Value::Object(serde_json::Map::new()),
|
||||
room_id: Uuid::new_v4(),
|
||||
user_id: Uuid::new_v4(),
|
||||
})
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RoomService for WebRTCService {
|
||||
#[instrument(skip(self))]
|
||||
async fn create_room(&self, config: RoomConfig) -> Result<Room> {
|
||||
todo!()
|
||||
pub async fn on_track<F>(&self, pc: &RTCPeerConnection, mut callback: F)
|
||||
where
|
||||
F: FnMut(Arc<TrackRemote>) + Send + 'static,
|
||||
{
|
||||
let (tx, mut rx) = mpsc::channel(100);
|
||||
|
||||
pc.on_track(Box::new(move |track, _, _| {
|
||||
let track_clone = track.clone();
|
||||
let tx = tx.clone();
|
||||
Box::pin(async move {
|
||||
let _ = tx.send(track_clone).await;
|
||||
})
|
||||
}));
|
||||
|
||||
while let Some(track) = rx.recv().await {
|
||||
callback(track);
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn join_room(&self, room_id: Uuid, user_id: Uuid) -> Result<Connection> {
|
||||
let peer_connection = self.create_peer_connection().await?;
|
||||
|
||||
peer_connection
|
||||
.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
|
||||
Box::pin(async move {
|
||||
match s {
|
||||
RTCPeerConnectionState::Connected => {
|
||||
tracing::info!("Peer connection connected");
|
||||
pub async fn close(&mut self) -> Result<()> {
|
||||
for pc in self.peer_connections.iter() {
|
||||
pc.close().await
|
||||
.map_err(|e| Error::internal(format!("Failed to close peer connection: {}", e)))?;
|
||||
}
|
||||
self.peer_connections.clear();
|
||||
Ok(())
|
||||
}
|
||||
RTCPeerConnectionState::Disconnected
|
||||
| RTCPeerConnectionState::Failed
|
||||
| RTCPeerConnectionState::Closed => {
|
||||
tracing::warn!("Peer connection state changed to {}", s);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
}));
|
||||
|
||||
let mut connection = self.create_connection().await?;
|
||||
connection.room_id = room_id;
|
||||
connection.user_id = user_id;
|
||||
|
||||
Ok(connection)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn leave_room(&self, room_id: Uuid, user_id: Uuid) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn publish_track(&self, track: TrackInfo) -> Result<Track> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn subscribe_track(&self, track_id: Uuid) -> Result<Subscription> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn get_participants(&self, room_id: Uuid) -> Result<Vec<Participant>> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn get_room_stats(&self, room_id: Uuid) -> Result<RoomStats> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::*;
|
||||
|
||||
#[fixture]
|
||||
fn webrtc_service() -> WebRTCService {
|
||||
WebRTCService::new(vec!["stun:stun.l.google.com:19302".to_string()])
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_create_peer_connection(webrtc_service: WebRTCService) {
|
||||
let peer_connection = webrtc_service.create_peer_connection().await.unwrap();
|
||||
assert_eq!(
|
||||
peer_connection.connection_state().await,
|
||||
RTCPeerConnectionState::New
|
||||
);
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_join_room(webrtc_service: WebRTCService) {
|
||||
let room_id = Uuid::new_v4();
|
||||
let user_id = Uuid::new_v4();
|
||||
|
||||
let connection = webrtc_service.join_room(room_id, user_id).await.unwrap();
|
||||
|
||||
assert_eq!(connection.room_id, room_id);
|
||||
assert_eq!(connection.user_id, user_id);
|
||||
assert!(!connection.ice_servers.is_empty());
|
||||
}
|
||||
}
|
||||
}
|
22
gb-messaging/gb-migrations/Cargo.toml
Normal file
22
gb-messaging/gb-migrations/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "gb-migrations"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "migrations"
|
||||
path = "src/bin/migrations.rs"
|
||||
|
||||
[dependencies]
|
||||
tokio.workspace = true
|
||||
sqlx.workspace = true
|
||||
tracing.workspace = true
|
||||
uuid.workspace = true
|
||||
chrono.workspace = true
|
||||
serde_json.workspace = true
|
||||
gb-core = { path = "../gb-core" }
|
||||
|
||||
[dev-dependencies]
|
||||
rstest.workspace = true
|
19
gb-messaging/gb-migrations/src/bin/migrations.rs
Normal file
19
gb-messaging/gb-migrations/src/bin/migrations.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
use sqlx::PgPool;
|
||||
use gb_migrations::run_migrations;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), sqlx::Error> {
|
||||
let database_url = std::env::var("DATABASE_URL")
|
||||
.expect("DATABASE_URL must be set");
|
||||
|
||||
println!("Creating database connection pool...");
|
||||
let pool = PgPool::connect(&database_url)
|
||||
.await
|
||||
.expect("Failed to create pool");
|
||||
|
||||
println!("Running migrations...");
|
||||
run_migrations(&pool).await?;
|
||||
|
||||
println!("Migrations completed successfully!");
|
||||
Ok(())
|
||||
}
|
144
gb-messaging/gb-migrations/src/lib.rs
Normal file
144
gb-messaging/gb-migrations/src/lib.rs
Normal file
|
@ -0,0 +1,144 @@
|
|||
use sqlx::PgPool;
|
||||
use tracing::info;
|
||||
|
||||
pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
|
||||
info!("Running database migrations");
|
||||
|
||||
// Create tables
|
||||
let table_queries = [
|
||||
// Customers table
|
||||
r#"CREATE TABLE IF NOT EXISTS customers (
|
||||
id UUID PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
subscription_tier VARCHAR(50) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
max_instances INTEGER NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)"#,
|
||||
|
||||
// Instances table
|
||||
r#"CREATE TABLE IF NOT EXISTS instances (
|
||||
id UUID PRIMARY KEY,
|
||||
customer_id UUID NOT NULL REFERENCES customers(id),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
shard_id INTEGER NOT NULL,
|
||||
region VARCHAR(50) NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)"#,
|
||||
|
||||
// Rooms table
|
||||
r#"CREATE TABLE IF NOT EXISTS rooms (
|
||||
id UUID PRIMARY KEY,
|
||||
customer_id UUID NOT NULL REFERENCES customers(id),
|
||||
instance_id UUID NOT NULL REFERENCES instances(id),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
kind VARCHAR(50) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)"#,
|
||||
|
||||
// Messages table
|
||||
r#"CREATE TABLE IF NOT EXISTS messages (
|
||||
id UUID PRIMARY KEY,
|
||||
customer_id UUID NOT NULL REFERENCES customers(id),
|
||||
instance_id UUID NOT NULL REFERENCES instances(id),
|
||||
conversation_id UUID NOT NULL,
|
||||
sender_id UUID NOT NULL,
|
||||
kind VARCHAR(50) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
shard_key INTEGER NOT NULL
|
||||
)"#,
|
||||
|
||||
// Users table
|
||||
r#"CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY,
|
||||
customer_id UUID NOT NULL REFERENCES customers(id),
|
||||
instance_id UUID NOT NULL REFERENCES instances(id),
|
||||
name VARCHAR(255) NOT NULL,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)"#,
|
||||
|
||||
// Tracks table
|
||||
r#"CREATE TABLE IF NOT EXISTS tracks (
|
||||
id UUID PRIMARY KEY,
|
||||
room_id UUID NOT NULL REFERENCES rooms(id),
|
||||
user_id UUID NOT NULL REFERENCES users(id),
|
||||
kind VARCHAR(50) NOT NULL,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)"#,
|
||||
|
||||
// Subscriptions table
|
||||
r#"CREATE TABLE IF NOT EXISTS subscriptions (
|
||||
id UUID PRIMARY KEY,
|
||||
track_id UUID NOT NULL REFERENCES tracks(id),
|
||||
user_id UUID NOT NULL REFERENCES users(id),
|
||||
status VARCHAR(50) NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)"#,
|
||||
];
|
||||
|
||||
// Create indexes
|
||||
let index_queries = [
|
||||
"CREATE INDEX IF NOT EXISTS idx_instances_customer_id ON instances(customer_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_rooms_instance_id ON rooms(instance_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_messages_shard_key ON messages(shard_key)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_tracks_room_id ON tracks(room_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_subscriptions_track_id ON subscriptions(track_id)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)",
|
||||
];
|
||||
|
||||
// Execute table creation queries
|
||||
for query in table_queries {
|
||||
sqlx::query(query)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Execute index creation queries
|
||||
for query in index_queries {
|
||||
sqlx::query(query)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!("Migrations completed successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use sqlx::postgres::{PgPoolOptions, PgPool};
|
||||
use rstest::*;
|
||||
|
||||
async fn create_test_pool() -> PgPool {
|
||||
let database_url = std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string());
|
||||
|
||||
PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&database_url)
|
||||
.await
|
||||
.expect("Failed to create test pool")
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_migrations() {
|
||||
let pool = create_test_pool().await;
|
||||
assert!(run_migrations(&pool).await.is_ok());
|
||||
}
|
||||
}
|
10
gb-migrations/20231220000000_update_user_schema.sql
Normal file
10
gb-migrations/20231220000000_update_user_schema.sql
Normal file
|
@ -0,0 +1,10 @@
|
|||
-- Add password_hash column to users table
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS password_hash VARCHAR(255) NOT NULL DEFAULT '';
|
||||
|
||||
-- Update column names if needed
|
||||
ALTER TABLE users RENAME COLUMN password TO password_hash;
|
||||
|
||||
-- Add metadata column to instances table
|
||||
ALTER TABLE instances
|
||||
ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}';
|
|
@ -1,66 +1,7 @@
|
|||
pub mod postgres;
|
||||
pub mod redis;
|
||||
pub mod tikv;
|
||||
mod postgres;
|
||||
mod redis;
|
||||
mod tikv;
|
||||
|
||||
pub use postgres::{PostgresCustomerRepository, PostgresInstanceRepository};
|
||||
pub use redis::RedisCache;
|
||||
pub use tikv::TiKVStorage;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gb_core::models::Customer;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use std::time::Duration;
|
||||
|
||||
async fn setup_test_db() -> sqlx::PgPool {
|
||||
let database_url = std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string());
|
||||
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&database_url)
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
// Run migrations
|
||||
gb_migrations::run_migrations(&pool)
|
||||
.await
|
||||
.expect("Failed to run migrations");
|
||||
|
||||
pool
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_storage_integration() {
|
||||
// Setup PostgreSQL
|
||||
let pool = setup_test_db().await;
|
||||
let customer_repo = PostgresCustomerRepository::new(pool.clone());
|
||||
|
||||
// Setup Redis
|
||||
let redis_url = std::env::var("REDIS_URL")
|
||||
.unwrap_or_else(|_| "redis://127.0.0.1/".to_string());
|
||||
let cache = RedisCache::new(&redis_url, Duration::from_secs(60)).unwrap();
|
||||
|
||||
// Create a customer
|
||||
let customer = Customer::new(
|
||||
"Integration Test Corp".to_string(),
|
||||
"enterprise".to_string(),
|
||||
10,
|
||||
);
|
||||
|
||||
// Save to PostgreSQL
|
||||
let created = customer_repo.create(&customer).await.unwrap();
|
||||
|
||||
// Cache in Redis
|
||||
cache.set(&format!("customer:{}", created.id), &created).await.unwrap();
|
||||
|
||||
// Verify Redis cache
|
||||
let cached: Option<Customer> = cache.get(&format!("customer:{}", created.id)).await.unwrap();
|
||||
assert_eq!(cached.unwrap().id, created.id);
|
||||
|
||||
// Cleanup
|
||||
customer_repo.delete(created.id).await.unwrap();
|
||||
cache.delete(&format!("customer:{}", created.id)).await.unwrap();
|
||||
}
|
||||
}
|
||||
pub use postgres::{CustomerRepository, PostgresCustomerRepository};
|
||||
pub use redis::RedisStorage;
|
||||
pub use tikv::TiKVStorage;
|
|
@ -1,302 +1,99 @@
|
|||
use async_trait::async_trait;
|
||||
use gb_core::{
|
||||
models::*,
|
||||
traits::*,
|
||||
Result, Error,
|
||||
};
|
||||
use gb_core::{Result, Error};
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
use tracing::{instrument, error};
|
||||
use gb_core::models::Customer;
|
||||
|
||||
pub trait CustomerRepository {
|
||||
async fn create(&self, customer: Customer) -> Result<Customer>;
|
||||
async fn get(&self, id: Uuid) -> Result<Option<Customer>>;
|
||||
async fn update(&self, customer: Customer) -> Result<Customer>;
|
||||
async fn delete(&self, id: Uuid) -> Result<()>;
|
||||
}
|
||||
|
||||
pub struct PostgresCustomerRepository {
|
||||
pool: PgPool,
|
||||
pool: Arc<PgPool>,
|
||||
}
|
||||
|
||||
impl PostgresCustomerRepository {
|
||||
pub fn new(pool: PgPool) -> Self {
|
||||
pub fn new(pool: Arc<PgPool>) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl CustomerRepository for PostgresCustomerRepository {
|
||||
#[instrument(skip(self))]
|
||||
async fn create(&self, customer: &Customer) -> Result<Customer> {
|
||||
let record = sqlx::query_as!(
|
||||
async fn create(&self, customer: Customer) -> Result<Customer> {
|
||||
let result = sqlx::query_as!(
|
||||
Customer,
|
||||
r#"
|
||||
INSERT INTO customers (id, name, subscription_tier, status, max_instances, metadata, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO customers (id, name, max_instances, email, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, NOW(), NOW())
|
||||
RETURNING *
|
||||
"#,
|
||||
customer.id,
|
||||
customer.name,
|
||||
customer.subscription_tier,
|
||||
customer.status,
|
||||
customer.max_instances,
|
||||
customer.metadata as _,
|
||||
customer.created_at,
|
||||
customer.max_instances as i32,
|
||||
customer.email,
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.fetch_one(&*self.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to create customer: {}", e);
|
||||
Error::Database(e)
|
||||
})?;
|
||||
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
|
||||
|
||||
Ok(record)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn get(&self, id: Uuid) -> Result<Customer> {
|
||||
let record = sqlx::query_as!(
|
||||
async fn get(&self, id: Uuid) -> Result<Option<Customer>> {
|
||||
let result = sqlx::query_as!(
|
||||
Customer,
|
||||
r#"
|
||||
SELECT * FROM customers WHERE id = $1
|
||||
SELECT id, name, max_instances::int as "max_instances!: i32",
|
||||
email, created_at, updated_at
|
||||
FROM customers
|
||||
WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.fetch_optional(&*self.pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", id)),
|
||||
e => Error::Database(e),
|
||||
})?;
|
||||
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
|
||||
|
||||
Ok(record)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn update(&self, customer: &Customer) -> Result<Customer> {
|
||||
let record = sqlx::query_as!(
|
||||
async fn update(&self, customer: Customer) -> Result<Customer> {
|
||||
let result = sqlx::query_as!(
|
||||
Customer,
|
||||
r#"
|
||||
UPDATE customers
|
||||
SET name = $1, subscription_tier = $2, status = $3, max_instances = $4, metadata = $5
|
||||
WHERE id = $6
|
||||
RETURNING *
|
||||
SET name = $2, max_instances = $3, email = $4, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, name, max_instances::int as "max_instances!: i32",
|
||||
email, created_at, updated_at
|
||||
"#,
|
||||
customer.id,
|
||||
customer.name,
|
||||
customer.subscription_tier,
|
||||
customer.status,
|
||||
customer.max_instances,
|
||||
customer.metadata as _,
|
||||
customer.id
|
||||
customer.max_instances as i32,
|
||||
customer.email,
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.fetch_one(&*self.pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", customer.id)),
|
||||
e => Error::Database(e),
|
||||
})?;
|
||||
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
|
||||
|
||||
Ok(record)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn delete(&self, id: Uuid) -> Result<()> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM customers WHERE id = $1
|
||||
DELETE FROM customers
|
||||
WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.execute(&*self.pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", id)),
|
||||
e => Error::Database(e),
|
||||
})?;
|
||||
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresInstanceRepository {
|
||||
pool: PgPool,
|
||||
}
|
||||
|
||||
impl PostgresInstanceRepository {
|
||||
pub fn new(pool: PgPool) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl InstanceRepository for PostgresInstanceRepository {
|
||||
#[instrument(skip(self))]
|
||||
async fn create(&self, instance: &Instance) -> Result<Instance> {
|
||||
let record = sqlx::query_as!(
|
||||
Instance,
|
||||
r#"
|
||||
INSERT INTO instances (id, customer_id, name, status, shard_id, region, config, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING *
|
||||
"#,
|
||||
instance.id,
|
||||
instance.customer_id,
|
||||
instance.name,
|
||||
instance.status,
|
||||
instance.shard_id,
|
||||
instance.region,
|
||||
instance.config as _,
|
||||
instance.created_at,
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to create instance: {}", e);
|
||||
Error::Database(e)
|
||||
})?;
|
||||
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn get(&self, id: Uuid) -> Result<Instance> {
|
||||
let record = sqlx::query_as!(
|
||||
Instance,
|
||||
r#"
|
||||
SELECT * FROM instances WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", id)),
|
||||
e => Error::Database(e),
|
||||
})?;
|
||||
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn get_by_customer(&self, customer_id: Uuid) -> Result<Vec<Instance>> {
|
||||
let records = sqlx::query_as!(
|
||||
Instance,
|
||||
r#"
|
||||
SELECT * FROM instances WHERE customer_id = $1
|
||||
"#,
|
||||
customer_id
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(Error::Database)?;
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn get_by_shard(&self, shard_id: i32) -> Result<Vec<Instance>> {
|
||||
let records = sqlx::query_as!(
|
||||
Instance,
|
||||
r#"
|
||||
SELECT * FROM instances WHERE shard_id = $1
|
||||
"#,
|
||||
shard_id
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(Error::Database)?;
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn update(&self, instance: &Instance) -> Result<Instance> {
|
||||
let record = sqlx::query_as!(
|
||||
Instance,
|
||||
r#"
|
||||
UPDATE instances
|
||||
SET name = $1, status = $2, shard_id = $3, region = $4, config = $5
|
||||
WHERE id = $6
|
||||
RETURNING *
|
||||
"#,
|
||||
instance.name,
|
||||
instance.status,
|
||||
instance.shard_id,
|
||||
instance.region,
|
||||
instance.config as _,
|
||||
instance.id
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", instance.id)),
|
||||
e => Error::Database(e),
|
||||
})?;
|
||||
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn delete(&self, id: Uuid) -> Result<()> {
|
||||
sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM instances WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", id)),
|
||||
e => Error::Database(e),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rstest::*;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
|
||||
async fn create_test_pool() -> PgPool {
|
||||
let database_url = std::env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string());
|
||||
|
||||
PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&database_url)
|
||||
.await
|
||||
.expect("Failed to create test pool")
|
||||
}
|
||||
|
||||
#[fixture]
|
||||
fn customer() -> Customer {
|
||||
Customer::new(
|
||||
"Test Corp".to_string(),
|
||||
"enterprise".to_string(),
|
||||
10,
|
||||
)
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[tokio::test]
|
||||
async fn test_customer_crud(customer: Customer) {
|
||||
let pool = create_test_pool().await;
|
||||
let repo = PostgresCustomerRepository::new(pool);
|
||||
|
||||
// Create
|
||||
let created = repo.create(&customer).await.unwrap();
|
||||
assert_eq!(created.name, customer.name);
|
||||
|
||||
// Get
|
||||
let retrieved = repo.get(created.id).await.unwrap();
|
||||
assert_eq!(retrieved.id, created.id);
|
||||
|
||||
// Update
|
||||
let mut updated = retrieved.clone();
|
||||
updated.name = "Updated Corp".to_string();
|
||||
let updated = repo.update(&updated).await.unwrap();
|
||||
assert_eq!(updated.name, "Updated Corp");
|
||||
|
||||
// Delete
|
||||
repo.delete(updated.id).await.unwrap();
|
||||
assert!(repo.get(updated.id).await.is_err());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,52 +1,42 @@
|
|||
use async_trait::async_trait;
|
||||
use gb_core::{Result, Error};
|
||||
use redis::{AsyncCommands, Client};
|
||||
use redis::{Client, Commands};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::time::Duration;
|
||||
use tracing::{instrument, error};
|
||||
use tracing::instrument;
|
||||
|
||||
pub struct RedisCache {
|
||||
pub struct RedisStorage {
|
||||
client: Client,
|
||||
default_ttl: Duration,
|
||||
}
|
||||
|
||||
impl RedisCache {
|
||||
pub fn new(url: &str, default_ttl: Duration) -> Result<Self> {
|
||||
let client = Client::open(url).map_err(|e| Error::Redis(e))?;
|
||||
Ok(Self {
|
||||
client,
|
||||
default_ttl,
|
||||
})
|
||||
impl RedisStorage {
|
||||
pub fn new(url: &str) -> Result<Self> {
|
||||
let client = Client::open(url)
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
#[instrument(skip(self, value))]
|
||||
pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
|
||||
let mut conn = self.client.get_async_connection()
|
||||
.await
|
||||
.map_err(Error::Redis)?;
|
||||
#[instrument(skip(self))]
|
||||
pub async fn set<T: Serialize + std::fmt::Debug>(&self, key: &str, value: &T) -> Result<()> {
|
||||
let mut conn = self.client.get_connection()
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
let serialized = serde_json::to_string(value)
|
||||
.map_err(|e| Error::internal(format!("Serialization error: {}", e)))?;
|
||||
|
||||
conn.set_ex(key, serialized, self.default_ttl.as_secs() as usize)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Redis set error: {}", e);
|
||||
Error::Redis(e)
|
||||
})?;
|
||||
conn.set(key, serialized)
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
|
||||
let mut conn = self.client.get_async_connection()
|
||||
.await
|
||||
.map_err(Error::Redis)?;
|
||||
let mut conn = self.client.get_connection()
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
let value: Option<String> = conn.get(key)
|
||||
.await
|
||||
.map_err(Error::Redis)?;
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
match value {
|
||||
Some(v) => {
|
||||
|
@ -54,101 +44,33 @@ impl RedisCache {
|
|||
.map_err(|e| Error::internal(format!("Deserialization error: {}", e)))?;
|
||||
Ok(Some(deserialized))
|
||||
}
|
||||
None => Ok(None),
|
||||
None => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn delete(&self, key: &str) -> Result<()> {
|
||||
let mut conn = self.client.get_async_connection()
|
||||
.await
|
||||
.map_err(Error::Redis)?;
|
||||
pub async fn delete(&self, key: &str) -> Result<bool> {
|
||||
let mut conn = self.client.get_connection()
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
conn.del(key)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Redis delete error: {}", e);
|
||||
Error::Redis(e)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn increment(&self, key: &str) -> Result<i64> {
|
||||
let mut conn = self.client.get_async_connection()
|
||||
.await
|
||||
.map_err(Error::Redis)?;
|
||||
|
||||
conn.incr(key, 1)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Redis increment error: {}", e);
|
||||
Error::Redis(e)
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn set_with_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
|
||||
let mut conn = self.client.get_async_connection()
|
||||
.await
|
||||
.map_err(Error::Redis)?;
|
||||
pub async fn set_with_ttl<T: Serialize + std::fmt::Debug>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
|
||||
let mut conn = self.client.get_connection()
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
let serialized = serde_json::to_string(value)
|
||||
.map_err(|e| Error::internal(format!("Serialization error: {}", e)))?;
|
||||
|
||||
conn.set_ex(key, serialized, ttl.as_secs() as usize)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Redis set error: {}", e);
|
||||
Error::Redis(e)
|
||||
})?;
|
||||
redis::pipe()
|
||||
.set(key, serialized)
|
||||
.expire(key, ttl.as_secs() as i64)
|
||||
.query(&mut conn)
|
||||
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq)]
|
||||
struct TestStruct {
|
||||
field: String,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_redis_cache() {
|
||||
let redis_url = std::env::var("REDIS_URL")
|
||||
.unwrap_or_else(|_| "redis://127.0.0.1/".to_string());
|
||||
|
||||
let cache = RedisCache::new(&redis_url, Duration::from_secs(60)).unwrap();
|
||||
|
||||
// Test set and get
|
||||
let test_value = TestStruct {
|
||||
field: "test".to_string(),
|
||||
};
|
||||
|
||||
cache.set("test_key", &test_value).await.unwrap();
|
||||
let retrieved: Option<TestStruct> = cache.get("test_key").await.unwrap();
|
||||
assert_eq!(retrieved.unwrap(), test_value);
|
||||
|
||||
// Test delete
|
||||
cache.delete("test_key").await.unwrap();
|
||||
let deleted: Option<TestStruct> = cache.get("test_key").await.unwrap();
|
||||
assert!(deleted.is_none());
|
||||
|
||||
// Test increment
|
||||
cache.set("counter", &0).await.unwrap();
|
||||
let count = cache.increment("counter").await.unwrap();
|
||||
assert_eq!(count, 1);
|
||||
|
||||
// Test TTL
|
||||
cache.set_with_ttl("ttl_key", &test_value, Duration::from_secs(1)).await.unwrap();
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
let expired: Option<TestStruct> = cache.get("ttl_key").await.unwrap();
|
||||
assert!(expired.is_none());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
use async_trait::async_trait;
|
||||
use gb_core::{Result, Error};
|
||||
use tikv_client::{Config, RawClient, Value};
|
||||
use tracing::{instrument, error};
|
||||
use tikv_client::{RawClient, Config, KvPair};
|
||||
use tracing::{error, instrument};
|
||||
|
||||
pub struct TiKVStorage {
|
||||
client: RawClient,
|
||||
|
@ -10,26 +9,15 @@ pub struct TiKVStorage {
|
|||
impl TiKVStorage {
|
||||
pub async fn new(pd_endpoints: Vec<String>) -> Result<Self> {
|
||||
let config = Config::default();
|
||||
let client = RawClient::new(pd_endpoints, config)
|
||||
let client = RawClient::new(pd_endpoints)
|
||||
.await
|
||||
.map_err(|e| Error::internal(format!("TiKV client error: {}", e)))?;
|
||||
.map_err(|e| Error::internal(format!("TiKV error: {}", e)))?;
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
#[instrument(skip(self, value))]
|
||||
pub async fn put(&self, key: &[u8], value: Value) -> Result<()> {
|
||||
self.client
|
||||
.put(key.to_vec(), value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("TiKV put error: {}", e);
|
||||
Error::internal(format!("TiKV error: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn get(&self, key: &[u8]) -> Result<Option<Value>> {
|
||||
pub async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
self.client
|
||||
.get(key.to_vec())
|
||||
.await
|
||||
|
@ -39,6 +27,17 @@ impl TiKVStorage {
|
|||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
self.client
|
||||
.put(key.to_vec(), value.to_vec())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("TiKV put error: {}", e);
|
||||
Error::internal(format!("TiKV error: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn delete(&self, key: &[u8]) -> Result<()> {
|
||||
self.client
|
||||
|
@ -51,7 +50,7 @@ impl TiKVStorage {
|
|||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn batch_get(&self, keys: Vec<Vec<u8>>) -> Result<Vec<KVPair>> {
|
||||
pub async fn batch_get(&self, keys: Vec<Vec<u8>>) -> Result<Vec<KvPair>> {
|
||||
self.client
|
||||
.batch_get(keys)
|
||||
.await
|
||||
|
@ -62,7 +61,7 @@ impl TiKVStorage {
|
|||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KVPair>> {
|
||||
pub async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
|
||||
self.client
|
||||
.scan(start.to_vec()..end.to_vec(), limit)
|
||||
.await
|
||||
|
@ -71,58 +70,4 @@ impl TiKVStorage {
|
|||
Error::internal(format!("TiKV error: {}", e))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KVPair {
|
||||
pub key: Vec<u8>,
|
||||
pub value: Value,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tikv_client::Value;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tikv_storage() {
|
||||
let pd_endpoints = vec!["127.0.0.1:2379".to_string()];
|
||||
let storage = TiKVStorage::new(pd_endpoints).await.unwrap();
|
||||
|
||||
// Test put and get
|
||||
let key = b"test_key";
|
||||
let value = Value::from(b"test_value".to_vec());
|
||||
storage.put(key, value.clone()).await.unwrap();
|
||||
|
||||
let retrieved = storage.get(key).await.unwrap();
|
||||
assert_eq!(retrieved.unwrap(), value);
|
||||
|
||||
// Test delete
|
||||
storage.delete(key).await.unwrap();
|
||||
let deleted = storage.get(key).await.unwrap();
|
||||
assert!(deleted.is_none());
|
||||
|
||||
// Test batch operations
|
||||
let pairs = vec![
|
||||
(b"key1".to_vec(), Value::from(b"value1".to_vec())),
|
||||
(b"key2".to_vec(), Value::from(b"value2".to_vec())),
|
||||
];
|
||||
|
||||
for (key, value) in pairs.clone() {
|
||||
storage.put(&key, value).await.unwrap();
|
||||
}
|
||||
|
||||
let keys: Vec<Vec<u8>> = pairs.iter().map(|(k, _)| k.clone()).collect();
|
||||
let retrieved = storage.batch_get(keys).await.unwrap();
|
||||
assert_eq!(retrieved.len(), pairs.len());
|
||||
|
||||
// Test scan
|
||||
let scanned = storage.scan(b"key", b"key3", 10).await.unwrap();
|
||||
assert_eq!(scanned.len(), 2);
|
||||
|
||||
// Cleanup
|
||||
for (key, _) in pairs {
|
||||
storage.delete(&key).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue