From 28cc734340bdb9845353774a112d0029d22a2797 Mon Sep 17 00:00:00 2001 From: Rodrigo Rodriguez Date: Tue, 24 Dec 2024 13:05:54 -0300 Subject: [PATCH] new(all): Initial import. --- .cargo/config.toml | 2 + .gitignore | 3 +- Cargo.lock | 1 + Cargo.toml | 6 + gb-api/src/router.rs | 6 +- gb-auth/Cargo.toml | 1 + gb-auth/src/error.rs | 20 ++ gb-auth/src/errors.rs | 24 ++ gb-auth/src/middleware/auth_middleware.rs | 10 +- gb-auth/src/models/mod.rs | 3 +- gb-auth/src/services/auth_service.rs | 32 +- gb-automation/Cargo.toml | 3 +- gb-automation/src/lib.rs | 38 +-- gb-automation/src/web.rs | 58 ++-- gb-image/src/processor.rs | 217 ++----------- gb-media/src/audio.rs | 126 +++----- gb-media/src/lib.rs | 47 +-- gb-media/src/processor.rs | 125 +++----- gb-media/src/webrtc.rs | 201 ++++-------- gb-messaging/gb-migrations/Cargo.toml | 22 ++ .../gb-migrations/src/bin/migrations.rs | 19 ++ gb-messaging/gb-migrations/src/lib.rs | 144 +++++++++ .../20231220000000_update_user_schema.sql | 10 + gb-storage/src/lib.rs | 71 +---- gb-storage/src/postgres.rs | 297 +++--------------- gb-storage/src/redis.rs | 142 ++------- gb-storage/src/tikv.rs | 93 ++---- 27 files changed, 573 insertions(+), 1148 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 gb-auth/src/error.rs create mode 100644 gb-auth/src/errors.rs create mode 100644 gb-messaging/gb-migrations/Cargo.toml create mode 100644 gb-messaging/gb-migrations/src/bin/migrations.rs create mode 100644 gb-messaging/gb-migrations/src/lib.rs create mode 100644 gb-migrations/20231220000000_update_user_schema.sql diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..fd3d500 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[term] +quiet = true \ No newline at end of file diff --git a/.gitignore b/.gitignore index 1de5659..796603f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -target \ No newline at end of file +target +.env \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 45ffd63..467a271 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2407,6 +2407,7 @@ dependencies = [ "async-trait", "chromiumoxide", "fantoccini", + "futures-util", "gb-core", "headless_chrome", "image", diff --git a/Cargo.toml b/Cargo.toml index 5327228..62b7daa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,12 @@ members = [ "gb-image", # Image processing capabilities ] +# [workspace.lints.rust] +# unused_imports = "allow" +# dead_code = "allow" +# unused_variables = "allow" +# dependency_on_unit_never_type_fallback = "allow" + [workspace.package] version = "0.1.0" edition = "2021" diff --git a/gb-api/src/router.rs b/gb-api/src/router.rs index b339e49..ee40681 100644 --- a/gb-api/src/router.rs +++ b/gb-api/src/router.rs @@ -46,7 +46,7 @@ async fn handle_ws_connection( if let Ok(text) = msg.to_text() { if let Ok(envelope) = serde_json::from_str::(text) { let mut processor = state.message_processor.lock().await; - if let Err(e) = processor.process_messages(vec![envelope]).await { + if let Err(e) = processor.process_messages().await { error!("Failed to process message: {}", e); } } @@ -77,7 +77,7 @@ async fn send_message( }; let mut processor = state.message_processor.lock().await; - processor.process_messages(vec![envelope.clone()]).await + processor.process_messages().await .map_err(|e| Error::internal(format!("Failed to process message: {}", e)))?; Ok(Json(MessageId(envelope.id))) @@ -114,4 +114,4 @@ async fn join_room( Json(_user_id): Json, ) -> Result> { todo!() -} +} \ No newline at end of file diff --git a/gb-auth/Cargo.toml b/gb-auth/Cargo.toml index 4f2896c..a5a89e4 100644 --- a/gb-auth/Cargo.toml +++ b/gb-auth/Cargo.toml @@ -54,3 +54,4 @@ tokio-test = "0.4" mockall = "0.12" axum-extra = { version = "0.7" } sqlx = { version = "0.7", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono", "json"] } + diff --git a/gb-auth/src/error.rs b/gb-auth/src/error.rs new file mode 100644 index 0000000..b1f7104 --- /dev/null +++ b/gb-auth/src/error.rs @@ -0,0 +1,20 @@ +use gb_core::Error as CoreError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error("Invalid token")] + InvalidToken, + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + #[error("Redis error: {0}")] + Redis(#[from] redis::RedisError), + #[error("Internal error: {0}")] + Internal(String), +} + +impl From for AuthError { + fn from(err: CoreError) -> Self { + AuthError::Internal(err.to_string()) + } +} \ No newline at end of file diff --git a/gb-auth/src/errors.rs b/gb-auth/src/errors.rs new file mode 100644 index 0000000..e709d39 --- /dev/null +++ b/gb-auth/src/errors.rs @@ -0,0 +1,24 @@ +use gb_core::Error as CoreError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error("Invalid token")] + InvalidToken, + #[error("Invalid credentials")] + InvalidCredentials, + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + #[error("Redis error: {0}")] + Redis(#[from] redis::RedisError), + #[error("Internal error: {0}")] + Internal(String), +} + +impl From for AuthError { + fn from(err: CoreError) -> Self { + match err { + CoreError { .. } => AuthError::Internal(err.to_string()), + } + } +} \ No newline at end of file diff --git a/gb-auth/src/middleware/auth_middleware.rs b/gb-auth/src/middleware/auth_middleware.rs index cd3e113..277ad4e 100644 --- a/gb-auth/src/middleware/auth_middleware.rs +++ b/gb-auth/src/middleware/auth_middleware.rs @@ -3,9 +3,7 @@ use axum::{ middleware::Next, body::Body, }; -use axum_extra::TypedHeader; -use axum_extra::headers::{Authorization, authorization::Bearer}; -use gb_core::User; +use headers::{Authorization, authorization::Bearer}; use jsonwebtoken::{decode, DecodingKey, Validation}; use serde::{Serialize, Deserialize}; use crate::AuthError; @@ -17,10 +15,10 @@ struct Claims { } pub async fn auth_middleware( - TypedHeader(auth): TypedHeader>, + auth: Authorization, request: Request, next: Next, -) -> Result { +) -> Result, AuthError> { let token = auth.token(); let key = DecodingKey::from_secret(b"secret"); let validation = Validation::default(); @@ -32,4 +30,4 @@ pub async fn auth_middleware( } Err(_) => Err(AuthError::InvalidToken), } -} +} \ No newline at end of file diff --git a/gb-auth/src/models/mod.rs b/gb-auth/src/models/mod.rs index a54447d..d8b163c 100644 --- a/gb-auth/src/models/mod.rs +++ b/gb-auth/src/models/mod.rs @@ -1,3 +1,2 @@ -mod user; - +pub mod user; pub use user::*; \ No newline at end of file diff --git a/gb-auth/src/services/auth_service.rs b/gb-auth/src/services/auth_service.rs index 38e77fd..288ba88 100644 --- a/gb-auth/src/services/auth_service.rs +++ b/gb-auth/src/services/auth_service.rs @@ -1,3 +1,6 @@ +use gb_core::{Result, Error}; +use crate::models::{LoginRequest, LoginResponse}; +use crate::models::user::DbUser; use std::sync::Arc; use sqlx::PgPool; use argon2::{ @@ -6,12 +9,6 @@ use argon2::{ }; use rand::rngs::OsRng; -use crate::{ - models::{LoginRequest, LoginResponse, User}, - AuthError, - Result, -}; - pub struct AuthService { db: Arc, jwt_secret: String, @@ -30,12 +27,17 @@ impl AuthService { pub async fn login(&self, request: LoginRequest) -> Result { let user = sqlx::query_as!( DbUser, - "SELECT * FROM users WHERE email = $1", + r#" + SELECT id, email, password_hash, role + FROM users + WHERE email = $1 + "#, request.email ) .fetch_optional(&*self.db) - .await? - .ok_or(AuthError::InvalidCredentials)?; + .await + .map_err(|e| Error::internal(e.to_string()))? + .ok_or_else(|| Error::internal("Invalid credentials"))?; self.verify_password(&request.password, &user.password_hash)?; @@ -56,19 +58,19 @@ impl AuthService { argon2 .hash_password(password.as_bytes(), &salt) .map(|hash| hash.to_string()) - .map_err(|e| AuthError::Internal(e.to_string())) + .map_err(|e| Error::internal(e.to_string())) } fn verify_password(&self, password: &str, hash: &str) -> Result<()> { let parsed_hash = PasswordHash::new(hash) - .map_err(|e| AuthError::Internal(e.to_string()))?; + .map_err(|e| Error::internal(e.to_string()))?; Argon2::default() .verify_password(password.as_bytes(), &parsed_hash) - .map_err(|_| AuthError::InvalidCredentials) + .map_err(|_| Error::internal("Invalid credentials")) } - fn generate_token(&self, user: &User) -> Result { + fn generate_token(&self, user: &DbUser) -> Result { use jsonwebtoken::{encode, EncodingKey, Header}; use serde::{Serialize, Deserialize}; use chrono::{Utc, Duration}; @@ -94,6 +96,6 @@ impl AuthService { &claims, &EncodingKey::from_secret(self.jwt_secret.as_bytes()), ) - .map_err(|e| AuthError::Internal(e.to_string())) + .map_err(|e| Error::internal(e.to_string())) } -} +} \ No newline at end of file diff --git a/gb-automation/Cargo.toml b/gb-automation/Cargo.toml index 33bdda0..65ad84b 100644 --- a/gb-automation/Cargo.toml +++ b/gb-automation/Cargo.toml @@ -9,6 +9,7 @@ license.workspace = true gb-core = { path = "../gb-core" } image = { version = "0.24", features = ["webp", "jpeg", "png", "gif"] } chromiumoxide = { version = "0.5", features = ["tokio-runtime"] } +futures-util = "0.3" async-trait.workspace = true tokio.workspace = true serde.workspace = true @@ -24,4 +25,4 @@ async-recursion = "1.0" [dev-dependencies] rstest.workspace = true tokio-test = "0.4" -mock_instant = "0.2" +mock_instant = "0.2" \ No newline at end of file diff --git a/gb-automation/src/lib.rs b/gb-automation/src/lib.rs index 7e19870..b8399dd 100644 --- a/gb-automation/src/lib.rs +++ b/gb-automation/src/lib.rs @@ -1,36 +1,4 @@ -pub mod web; -pub mod process; +mod web; -pub use web::{WebAutomation, Element}; -pub use process::ProcessAutomation; - -#[cfg(test)] -mod tests { - use super::*; - use gb_core::Result; - use tempfile::tempdir; - - #[tokio::test] - async fn test_automation_integration() -> Result<()> { - // Initialize automation components - let web = WebAutomation::new().await?; - let dir = tempdir()?; - let process = ProcessAutomation::new(dir.path()); - - // Test web automation - let page = web.new_page().await?; - web.navigate(&page, "https://example.com").await?; - let screenshot = web.screenshot(&page, "test.png").await?; - - // Test process automation - let output = process.execute("echo", &["Test output"]).await?; - assert!(output.contains("Test output")); - - // Test process spawning and cleanup - let id = process.spawn("sleep", &["1"]).await?; - process.kill(id).await?; - process.cleanup().await?; - - Ok(()) - } -} +pub use chromiumoxide::element::Element; +pub use web::WebAutomation; \ No newline at end of file diff --git a/gb-automation/src/web.rs b/gb-automation/src/web.rs index aeac6d1..21ed6ec 100644 --- a/gb-automation/src/web.rs +++ b/gb-automation/src/web.rs @@ -1,73 +1,61 @@ -use chromiumoxide::browser::{Browser, BrowserConfig}; -use chromiumoxide::element::Element; +use chromiumoxide::{Browser, Element}; use chromiumoxide::page::Page; +use chromiumoxide::browser::BrowserConfig; use futures_util::StreamExt; use gb_core::{Error, Result}; -use tracing::instrument; +use std::time::Duration; pub struct WebAutomation { browser: Browser, } impl WebAutomation { - #[instrument] pub async fn new() -> Result { let config = BrowserConfig::builder() .build() .map_err(|e| Error::internal(e.to_string()))?; - - let (browser, mut handler) = Browser::launch(config) + + let (browser, handler) = Browser::launch(config) .await .map_err(|e| Error::internal(e.to_string()))?; + // Spawn the handler in the background tokio::spawn(async move { - while let Some(h) = handler.next().await { - if let Err(e) = h { - tracing::error!("Browser handler error: {}", e); - } - } + handler.for_each(|_| async {}).await; }); Ok(Self { browser }) } - #[instrument(skip(self))] pub async fn new_page(&self) -> Result { - let params = chromiumoxide::cdp::browser_protocol::target::CreateTarget::new() - .url("about:blank"); - - self.browser.new_page(params) + self.browser + .new_page("about:blank") .await .map_err(|e| Error::internal(e.to_string())) } - #[instrument(skip(self))] pub async fn navigate(&self, page: &Page, url: &str) -> Result<()> { page.goto(url) .await - .map_err(|e| Error::internal(e.to_string())) + .map_err(|e| Error::internal(e.to_string()))?; + Ok(()) } - #[instrument(skip(self))] - pub async fn get_element(&self, page: &Page, selector: &str) -> Result { - page.find_element(selector) - .await - .map_err(|e| Error::internal(e.to_string())) - } - - #[instrument(skip(self))] - pub async fn screenshot(&self, page: &Page, _path: &str) -> Result> { - let params = chromiumoxide::cdp::browser_protocol::page::CaptureScreenshot::new(); + pub async fn take_screenshot(&self, page: &Page) -> Result> { + let params = chromiumoxide::page::ScreenshotParams::builder().build(); + page.screenshot(params) .await .map_err(|e| Error::internal(e.to_string())) } - #[instrument(skip(self))] - pub async fn wait_for_selector(&self, page: &Page, selector: &str) -> Result<()> { - page.find_element(selector) - .await - .map_err(|e| Error::internal(e.to_string()))?; - Ok(()) + pub async fn find_element(&self, page: &Page, selector: &str, timeout: Duration) -> Result { + tokio::time::timeout( + timeout, + page.find_element(selector) + ) + .await + .map_err(|_| Error::internal("Timeout waiting for element"))? + .map_err(|e| Error::internal(e.to_string())) } -} +} \ No newline at end of file diff --git a/gb-image/src/processor.rs b/gb-image/src/processor.rs index 5779772..b298377 100644 --- a/gb-image/src/processor.rs +++ b/gb-image/src/processor.rs @@ -1,204 +1,39 @@ -use gb_core::{Result, Error}; -use image::{ - DynamicImage, Rgba, -}; -use imageproc::{ - drawing::draw_text_mut, -}; -use rusttype::{Font, Scale}; -use std::path::Path; -use tracing::instrument; -use std::convert::TryInto; +use gb_core::{Error, Result}; +use image::{DynamicImage, ImageOutputFormat}; +use std::io::Cursor; +use tesseract::Tesseract; +use tempfile::NamedTempFile; +use std::io::Write; -pub struct ProcessingOptions { - pub crop: Option, - pub watermark: Option, - pub x: i32, - pub y: i32, -} - -pub struct CropParams { - pub x: u32, - pub y: u32, - pub width: u32, - pub height: u32, -} - -pub struct ImageProcessor { - default_font: Font<'static>, -} +pub struct ImageProcessor; impl ImageProcessor { - pub fn new() -> Result { - let font_data = include_bytes!("../assets/DejaVuSans.ttf"); - let font = Font::try_from_bytes(font_data) - .ok_or_else(|| Error::internal("Failed to load font"))?; - - Ok(Self { - default_font: font, - }) + pub fn new() -> Self { + Self } - pub fn process_image(&self, mut image: DynamicImage, options: &ProcessingOptions) -> Result { - if let Some(crop) = &options.crop { - let cropped = image.crop_imm( - crop.x, - crop.y, - crop.width, - crop.height - ); - image = cropped; - } - - if let Some(watermark) = &options.watermark { - let x: i64 = options.x.try_into().map_err(|_| Error::internal("Invalid x coordinate"))?; - let y: i64 = options.y.try_into().map_err(|_| Error::internal("Invalid y coordinate"))?; - image::imageops::overlay(&mut image, watermark, x, y); - } - - Ok(image) - } - - #[instrument(skip(self, image_data))] - pub fn load_image(&self, image_data: &[u8]) -> Result { - image::load_from_memory(image_data) - .map_err(|e| Error::internal(format!("Failed to load image: {}", e))) - } - - #[instrument(skip(self, image))] - pub fn save_image(&self, image: &DynamicImage, path: &Path) -> Result<()> { - image.save(path) - .map_err(|e| Error::internal(format!("Failed to save image: {}", e))) - } - - #[instrument(skip(self, image))] - pub fn crop(&self, image: &DynamicImage, x: u32, y: u32, width: u32, height: u32) -> Result { - Ok(image.crop_imm(x, y, width, height)) - } - - #[instrument(skip(self, image))] - pub fn add_text( - &self, - image: &mut DynamicImage, - text: &str, - x: i32, - y: i32, - scale: f32, - color: Rgba, - ) -> Result<()> { - let scale = Scale::uniform(scale); - - let mut img = image.to_rgba8(); - draw_text_mut( - &mut img, - color, - x, - y, - scale, - &self.default_font, - text, - ); - - *image = DynamicImage::ImageRgba8(img); - Ok(()) - } - - #[instrument(skip(self, image))] - pub fn add_watermark( - &self, - image: &mut DynamicImage, - watermark: &DynamicImage, - x: u32, - y: u32, - ) -> Result<()> { - let x: i64 = x.try_into().map_err(|_| Error::internal("Invalid x coordinate"))?; - let y: i64 = y.try_into().map_err(|_| Error::internal("Invalid y coordinate"))?; - image::imageops::overlay(image, watermark, x, y); - Ok(()) - } - - #[instrument(skip(self, image))] - pub fn extract_text(&self, image: &DynamicImage) -> Result { - use tesseract::Tesseract; - - let temp_file = tempfile::NamedTempFile::new() + pub async fn extract_text(&self, image: &DynamicImage) -> Result { + // Create a temporary file + let mut temp_file = NamedTempFile::new() .map_err(|e| Error::internal(format!("Failed to create temp file: {}", e)))?; - - image.save(&temp_file) - .map_err(|e| Error::internal(format!("Failed to save temp image: {}", e)))?; + + // Convert image to PNG and write to temp file + let mut cursor = Cursor::new(Vec::new()); + image.write_to(&mut cursor, ImageOutputFormat::Png) + .map_err(|e| Error::internal(format!("Failed to encode image: {}", e)))?; + + temp_file.write_all(&cursor.into_inner()) + .map_err(|e| Error::internal(format!("Failed to write to temp file: {}", e)))?; - let mut api = Tesseract::new(None, Some("eng")) + // Initialize Tesseract and process image + let api = Tesseract::new(None, Some("eng")) .map_err(|e| Error::internal(format!("Failed to initialize Tesseract: {}", e)))?; api.set_image(temp_file.path().to_str().unwrap()) - .map_err(|e| Error::internal(format!("Failed to set image: {}", e)))?; - - api.recognize() - .map_err(|e| Error::internal(format!("Failed to recognize text: {}", e)))?; - - api.get_text() + .map_err(|e| Error::internal(format!("Failed to set image: {}", e)))? + .recognize() + .map_err(|e| Error::internal(format!("Failed to recognize text: {}", e)))? + .get_text() .map_err(|e| Error::internal(format!("Failed to get text: {}", e))) } -} - - -#[cfg(test)] -mod tests { - use super::*; - use rstest::*; - use std::path::PathBuf; - - #[fixture] - fn processor() -> ImageProcessor { - ImageProcessor::new().unwrap() - } - - #[fixture] - fn test_image() -> DynamicImage { - DynamicImage::new_rgb8(100, 100) - } - - #[rstest] - fn test_resize(processor: ImageProcessor, test_image: DynamicImage) { - let resized = processor.resize(&test_image, 50, 50); - assert_eq!(resized.width(), 50); - assert_eq!(resized.height(), 50); - } - - #[rstest] - fn test_crop(processor: ImageProcessor, test_image: DynamicImage) -> Result<()> { - let cropped = processor.crop(&test_image, 25, 25, 50, 50)?; - assert_eq!(cropped.width(), 50); - assert_eq!(cropped.height(), 50); - Ok(()) - } - - #[rstest] - fn test_add_text(processor: ImageProcessor, mut test_image: DynamicImage) -> Result<()> { - processor.add_text( - &mut test_image, - "Test", - 10, - 10, - 12.0, - Rgba([255, 255, 255, 255]), - )?; - Ok(()) - } - - #[rstest] - fn test_extract_text(processor: ImageProcessor, mut test_image: DynamicImage) -> Result<()> { - processor.add_text( - &mut test_image, - "Test OCR", - 10, - 10, - 24.0, - Rgba([0, 0, 0, 255]), - )?; - - let text = processor.extract_text(&test_image)?; - assert!(text.contains("Test OCR")); - Ok(()) - } } \ No newline at end of file diff --git a/gb-media/src/audio.rs b/gb-media/src/audio.rs index bf00b1e..356fcf2 100644 --- a/gb-media/src/audio.rs +++ b/gb-media/src/audio.rs @@ -1,112 +1,56 @@ use gb_core::{Result, Error}; -use opus::{Decoder, Encoder}; -use opus::{Decoder, Encoder, Channels, Application}; -use std::io::Cursor; -use tracing::{instrument, error}; +use opus::{Encoder, Decoder, Application, Channels}; pub struct AudioProcessor { - sample_rate: i32, - channels: i32, + encoder: Encoder, + decoder: Decoder, + sample_rate: u32, + channels: Channels, } impl AudioProcessor { - pub fn new(sample_rate: i32, channels: i32) -> Self { - Self { + pub fn new(sample_rate: u32, channels: Channels) -> Result { + let encoder = Encoder::new( sample_rate, channels, - } + Application::Audio + ).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?; + + let decoder = Decoder::new( + sample_rate, + channels + ).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?; + + Ok(Self { + encoder, + decoder, + sample_rate, + channels, + }) } - #[instrument(skip(self, input))] pub fn encode(&self, input: &[i16]) -> Result> { - let mut encoder = Encoder::new( - self.sample_rate, - if self.channels == 1 { - opus::Channels::Mono - } else { - opus::Channels::Stereo - }, - opus::Application::Voip, - ).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?; - u32::try_from(self.sample_rate).map_err(|e| Error::internal(format!("Invalid sample rate: {}", e)))?, - Channels::Mono, - Application::Voip - ).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?; - let mut output = vec![0u8; 1024]; - let encoded_len = encoder.encode(input, &mut output) - .map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))?; + let encoded_size = self.encoder.encode( + input, + &mut output + ).map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))?; - output.truncate(encoded_len); + output.truncate(encoded_size); Ok(output) - encoder.encode(input) - .map_err(|e| Error::internal(format!("Failed to encode audio: {}", e))) } - #[instrument(skip(self, input))] pub fn decode(&self, input: &[u8]) -> Result> { - let mut decoder = Decoder::new( - self.sample_rate, - if self.channels == 1 { - opus::Channels::Mono - } else { - opus::Channels::Stereo - }, - ).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?; - u32::try_from(self.sample_rate).map_err(|e| Error::internal(format!("Invalid sample rate: {}", e)))?, - Channels::Mono - ).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?; + let max_size = (self.sample_rate as usize / 50) * self.channels.count(); + let mut output = vec![0i16; max_size]; - let mut output = vec![0i16; 1024]; - let decoded_len = decoder.decode(input, &mut output, false) - .map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))?; + let decoded_size = self.decoder.decode( + Some(input), + &mut output, + false + ).map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))?; - output.truncate(decoded_len); + output.truncate(decoded_size); Ok(output) - decoder.decode(input) - .map_err(|e| Error::internal(format!("Failed to decode audio: {}", e))) } -} - -#[cfg(test)] -mod tests { - use super::*; - use rstest::*; - - #[fixture] - fn audio_processor() -> AudioProcessor { - AudioProcessor::new(48000, 2) - } - - #[fixture] - fn test_audio() -> Vec { - // Generate 1 second of 440Hz sine wave - let sample_rate = 48000; - let frequency = 440.0; - let duration = 1.0; - - (0..sample_rate) - .flat_map(|i| { - let t = i as f32 / sample_rate as f32; - let value = (2.0 * std::f32::consts::PI * frequency * t).sin(); - let sample = (value * i16::MAX as f32) as i16; - vec![sample, sample] // Stereo - vec![sample, sample] - }) - .collect() - } - - #[rstest] - fn test_encode_decode(audio_processor: AudioProcessor, test_audio: Vec) { - let encoded = audio_processor.encode(&test_audio).unwrap(); - let decoded = audio_processor.decode(&encoded).unwrap(); - - // Verify basic properties - assert!(!encoded.is_empty()); - assert!(!decoded.is_empty()); - - // Opus is lossy, so we can't compare exact values - // But we can verify the length is the same - assert_eq!(decoded.len(), test_audio.len()); - } -} +} \ No newline at end of file diff --git a/gb-media/src/lib.rs b/gb-media/src/lib.rs index db89475..61da984 100644 --- a/gb-media/src/lib.rs +++ b/gb-media/src/lib.rs @@ -1,44 +1,5 @@ -pub mod webrtc; -pub mod processor; -pub mod audio; +mod processor; +mod webrtc; -pub use webrtc::WebRTCService; -pub use processor::{MediaProcessor, MediaMetadata}; -pub use audio::AudioProcessor; - -#[cfg(test)] -mod tests { - use super::*; - use std::path::PathBuf; - use uuid::Uuid; - - #[tokio::test] - async fn test_media_integration() { - // Initialize services - let webrtc = WebRTCService::new(vec!["stun:stun.l.google.com:19302".to_string()]); - let processor = MediaProcessor::new().unwrap(); - let audio = AudioProcessor::new(48000, 2); - - // Test room creation and joining - let room_id = Uuid::new_v4(); - let user_id = Uuid::new_v4(); - - let connection = webrtc.join_room(room_id, user_id).await.unwrap(); - assert_eq!(connection.room_id, room_id); - assert_eq!(connection.user_id, user_id); - - // Test media processing - let input_path = PathBuf::from("test_data/test.mp4"); - if input_path.exists() { - let metadata = processor.extract_metadata(input_path.clone()).await.unwrap(); - assert!(metadata.width.is_some()); - assert!(metadata.height.is_some()); - } - - // Test audio processing - let test_audio: Vec = (0..1024).map(|i| i as i16).collect(); - let encoded = audio.encode(&test_audio).unwrap(); - let decoded = audio.decode(&encoded).unwrap(); - assert!(!decoded.is_empty()); - } -} +pub use processor::MediaProcessor; +pub use webrtc::WebRTCService; \ No newline at end of file diff --git a/gb-media/src/processor.rs b/gb-media/src/processor.rs index d84f473..3d14c50 100644 --- a/gb-media/src/processor.rs +++ b/gb-media/src/processor.rs @@ -1,41 +1,45 @@ -use gstreamer::{self as gst, prelude::*}; -use gstreamer::prelude::{ - ElementExt, - GstBinExtManual, - GstObjectExt, -}; +use gb_core::{Result, Error}; +use gstreamer as gst; +use gstreamer::prelude::*; +use std::path::PathBuf; +use tracing::{error, instrument}; + +pub struct MediaProcessor { + pipeline: gst::Pipeline, +} impl MediaProcessor { pub fn new() -> Result { gst::init().map_err(|e| Error::internal(format!("Failed to initialize GStreamer: {}", e)))?; - - let pipeline = gst::Pipeline::new(None); - - Ok(Self { - pipeline, - }) + + let pipeline = gst::Pipeline::new() + .map_err(|e| Error::internal(format!("Failed to create pipeline: {}", e)))?; + + Ok(Self { pipeline }) } - + fn setup_pipeline(&mut self) -> Result<()> { self.pipeline.set_state(gst::State::Playing) .map_err(|e| Error::internal(format!("Failed to start pipeline: {}", e)))?; - let bus = self.pipeline.bus().expect("Pipeline without bus"); - - for msg in bus.iter_timed(gst::ClockTime::NONE) { - use gst::MessageView; + Ok(()) + } + fn process_messages(&self) -> Result<()> { + let bus = self.pipeline.bus().unwrap(); + + while let Some(msg) = bus.timed_pop(gst::ClockTime::from_seconds(1)) { match msg.view() { - MessageView::Error(err) => { + gst::MessageView::Error(err) => { error!("Error from {:?}: {} ({:?})", err.src().map(|s| s.path_string()), - err.error(), + err.error(), err.debug() ); return Err(Error::internal(format!("Pipeline error: {}", err.error()))); } - MessageView::Eos(_) => break, - _ => (), + gst::MessageView::Eos(_) => break, + _ => () } } @@ -47,85 +51,32 @@ impl MediaProcessor { #[instrument(skip(self, input_path, output_path))] pub async fn transcode( - &self, + &mut self, input_path: PathBuf, output_path: PathBuf, - format: &str, + format: &str ) -> Result<()> { - let src = gst::ElementFactory::make("filesrc") - .property("location", input_path.to_str().unwrap()) - .build() + let source = gst::ElementFactory::make("filesrc") .map_err(|e| Error::internal(format!("Failed to create source element: {}", e)))?; + source.set_property("location", input_path.to_str().unwrap()); let sink = gst::ElementFactory::make("filesink") - .property("location", output_path.to_str().unwrap()) - .build() .map_err(|e| Error::internal(format!("Failed to create sink element: {}", e)))?; + sink.set_property("location", output_path.to_str().unwrap()); - let decoder = match format { - "h264" => gst::ElementFactory::make("h264parse").build(), - "opus" => gst::ElementFactory::make("opusparse").build(), - _ => return Err(Error::InvalidInput(format!("Unsupported format: {}", format))), + let decoder = match format.to_lowercase().as_str() { + "mp4" => gst::ElementFactory::make("qtdemux"), + "webm" => gst::ElementFactory::make("matroskademux"), + _ => return Err(Error::internal(format!("Unsupported format: {}", format))) }.map_err(|e| Error::internal(format!("Failed to create decoder: {}", e)))?; - self.pipeline.add_many(&[&src, &decoder, &sink]) + self.pipeline.add_many(&[&source, &decoder, &sink]) .map_err(|e| Error::internal(format!("Failed to add elements: {}", e)))?; - gst::Element::link_many(&[&src, &decoder, &sink]) + gst::Element::link_many(&[&source, &decoder, &sink]) .map_err(|e| Error::internal(format!("Failed to link elements: {}", e)))?; self.setup_pipeline()?; - - Ok(()) + self.process_messages() } - - #[instrument(skip(self, input_path))] - pub async fn extract_metadata(&self, input_path: PathBuf) -> Result { - let src = gst::ElementFactory::make("filesrc") - .property("location", input_path.to_str().unwrap()) - .build() - .map_err(|e| Error::internal(format!("Failed to create source element: {}", e)))?; - - let decodebin = gst::ElementFactory::make("decodebin").build() - .map_err(|e| Error::internal(format!("Failed to create decodebin: {}", e)))?; - - self.pipeline.add_many(&[&src, &decodebin]) - .map_err(|e| Error::internal(format!("Failed to add elements: {}", e)))?; - - gst::Element::link_many(&[&src, &decodebin]) - .map_err(|e| Error::internal(format!("Failed to link elements: {}", e)))?; - - let mut metadata = MediaMetadata::default(); - - decodebin.connect_pad_added(move |_, pad| { - let caps = pad.current_caps().unwrap(); - let structure = caps.structure(0).unwrap(); - - match structure.name() { - "video/x-raw" => { - if let Ok(width) = structure.get::("width") { - metadata.width = Some(width); - } - if let Ok(height) = structure.get::("height") { - metadata.height = Some(height); - } - if let Ok(framerate) = structure.get::("framerate") { - metadata.framerate = Some(framerate.numer() as f64 / framerate.denom() as f64); - } - }, - "audio/x-raw" => { - if let Ok(channels) = structure.get::("channels") { - metadata.channels = Some(channels); - } - if let Ok(rate) = structure.get::("rate") { - metadata.sample_rate = Some(rate); - } - }, - _ => (), - } - }); - - self.setup_pipeline()?; - Ok(metadata) - } -} +} \ No newline at end of file diff --git a/gb-media/src/webrtc.rs b/gb-media/src/webrtc.rs index 828ed45..c8b2b80 100644 --- a/gb-media/src/webrtc.rs +++ b/gb-media/src/webrtc.rs @@ -1,163 +1,86 @@ -use async_trait::async_trait; -use gb_core::{ - models::*, - traits::*, - Result, Error, Connection, -}; -use uuid::Uuid; +use gb_core::{Result, Error}; use webrtc::{ - api::APIBuilder, - ice_transport::ice_server::RTCIceServer, - peer_connection::configuration::RTCConfiguration, - peer_connection::peer_connection_state::RTCPeerConnectionState, - peer_connection::RTCPeerConnection, - track::track_remote::TrackRemote, - rtp::rtp_receiver::RTCRtpReceiver, - rtp::rtp_transceiver::RTCRtpTransceiver, + api::{API, APIBuilder}, + peer_connection::{ + RTCPeerConnection, + peer_connection_state::RTCPeerConnectionState, + configuration::RTCConfiguration, + }, + track::{ + track_local::TrackLocal, + track_remote::TrackRemote, + }, }; -use tracing::{instrument, error}; +use tokio::sync::mpsc; +use tracing::instrument; use std::sync::Arc; -use chrono::Utc; pub struct WebRTCService { - config: RTCConfiguration, + api: Arc, + peer_connections: Vec>, } impl WebRTCService { - pub fn new(ice_servers: Vec) -> Self { - let mut config = RTCConfiguration::default(); - config.ice_servers = ice_servers - .into_iter() - .map(|url| RTCIceServer { - urls: vec![url], - ..Default::default() - }) - .collect(); + pub fn new() -> Result { + let api = APIBuilder::new().build(); - Self { config } + Ok(Self { + api: Arc::new(api), + peer_connections: Vec::new(), + }) } - async fn create_peer_connection(&self) -> Result { - let api = APIBuilder::new().build(); + pub async fn create_peer_connection(&mut self) -> Result> { + let config = RTCConfiguration::default(); - let peer_connection = api.new_peer_connection(self.config.clone()) + let peer_connection = self.api.new_peer_connection(config) .await .map_err(|e| Error::internal(format!("Failed to create peer connection: {}", e)))?; - Ok(peer_connection) + let pc_arc = Arc::new(peer_connection); + self.peer_connections.push(pc_arc.clone()); + + Ok(pc_arc) } - async fn handle_track(&self, track: Arc, receiver: Arc, transceiver: Arc) { - tracing::info!( - "Received track: {} {}", - track.kind(), - track.id() - ); - } + pub async fn add_track( + &self, + pc: &RTCPeerConnection, + track: Arc, + ) -> Result<()> { + pc.add_track(track) + .await + .map_err(|e| Error::internal(format!("Failed to add track: {}", e)))?; - async fn create_connection(&self) -> Result { - Ok(Connection { - id: Uuid::new_v4(), - connected_at: Utc::now(), - ice_servers: self.config.ice_servers.clone(), - metadata: serde_json::Value::Object(serde_json::Map::new()), - room_id: Uuid::new_v4(), - user_id: Uuid::new_v4(), - }) + Ok(()) } -} -#[async_trait] -impl RoomService for WebRTCService { - #[instrument(skip(self))] - async fn create_room(&self, config: RoomConfig) -> Result { - todo!() + pub async fn on_track(&self, pc: &RTCPeerConnection, mut callback: F) + where + F: FnMut(Arc) + Send + 'static, + { + let (tx, mut rx) = mpsc::channel(100); + + pc.on_track(Box::new(move |track, _, _| { + let track_clone = track.clone(); + let tx = tx.clone(); + Box::pin(async move { + let _ = tx.send(track_clone).await; + }) + })); + + while let Some(track) = rx.recv().await { + callback(track); + } } #[instrument(skip(self))] - async fn join_room(&self, room_id: Uuid, user_id: Uuid) -> Result { - let peer_connection = self.create_peer_connection().await?; - - peer_connection - .on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| { - Box::pin(async move { - match s { - RTCPeerConnectionState::Connected => { - tracing::info!("Peer connection connected"); + pub async fn close(&mut self) -> Result<()> { + for pc in self.peer_connections.iter() { + pc.close().await + .map_err(|e| Error::internal(format!("Failed to close peer connection: {}", e)))?; + } + self.peer_connections.clear(); + Ok(()) } - RTCPeerConnectionState::Disconnected - | RTCPeerConnectionState::Failed - | RTCPeerConnectionState::Closed => { - tracing::warn!("Peer connection state changed to {}", s); - } - _ => {} - } - }) - })); - - let mut connection = self.create_connection().await?; - connection.room_id = room_id; - connection.user_id = user_id; - - Ok(connection) - } - - #[instrument(skip(self))] - async fn leave_room(&self, room_id: Uuid, user_id: Uuid) -> Result<()> { - todo!() - } - - #[instrument(skip(self))] - async fn publish_track(&self, track: TrackInfo) -> Result { - todo!() - } - - #[instrument(skip(self))] - async fn subscribe_track(&self, track_id: Uuid) -> Result { - todo!() -} - - #[instrument(skip(self))] - async fn get_participants(&self, room_id: Uuid) -> Result> { - todo!() - } - - #[instrument(skip(self))] - async fn get_room_stats(&self, room_id: Uuid) -> Result { - todo!() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rstest::*; - - #[fixture] - fn webrtc_service() -> WebRTCService { - WebRTCService::new(vec!["stun:stun.l.google.com:19302".to_string()]) - } - - #[rstest] - #[tokio::test] - async fn test_create_peer_connection(webrtc_service: WebRTCService) { - let peer_connection = webrtc_service.create_peer_connection().await.unwrap(); - assert_eq!( - peer_connection.connection_state().await, - RTCPeerConnectionState::New - ); - } - - #[rstest] - #[tokio::test] - async fn test_join_room(webrtc_service: WebRTCService) { - let room_id = Uuid::new_v4(); - let user_id = Uuid::new_v4(); - - let connection = webrtc_service.join_room(room_id, user_id).await.unwrap(); - - assert_eq!(connection.room_id, room_id); - assert_eq!(connection.user_id, user_id); - assert!(!connection.ice_servers.is_empty()); - } -} +} \ No newline at end of file diff --git a/gb-messaging/gb-migrations/Cargo.toml b/gb-messaging/gb-migrations/Cargo.toml new file mode 100644 index 0000000..1adf6a0 --- /dev/null +++ b/gb-messaging/gb-migrations/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "gb-migrations" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true + +[[bin]] +name = "migrations" +path = "src/bin/migrations.rs" + +[dependencies] +tokio.workspace = true +sqlx.workspace = true +tracing.workspace = true +uuid.workspace = true +chrono.workspace = true +serde_json.workspace = true +gb-core = { path = "../gb-core" } + +[dev-dependencies] +rstest.workspace = true \ No newline at end of file diff --git a/gb-messaging/gb-migrations/src/bin/migrations.rs b/gb-messaging/gb-migrations/src/bin/migrations.rs new file mode 100644 index 0000000..237bf81 --- /dev/null +++ b/gb-messaging/gb-migrations/src/bin/migrations.rs @@ -0,0 +1,19 @@ +use sqlx::PgPool; +use gb_migrations::run_migrations; + +#[tokio::main] +async fn main() -> Result<(), sqlx::Error> { + let database_url = std::env::var("DATABASE_URL") + .expect("DATABASE_URL must be set"); + + println!("Creating database connection pool..."); + let pool = PgPool::connect(&database_url) + .await + .expect("Failed to create pool"); + + println!("Running migrations..."); + run_migrations(&pool).await?; + + println!("Migrations completed successfully!"); + Ok(()) +} \ No newline at end of file diff --git a/gb-messaging/gb-migrations/src/lib.rs b/gb-messaging/gb-migrations/src/lib.rs new file mode 100644 index 0000000..59a8faa --- /dev/null +++ b/gb-messaging/gb-migrations/src/lib.rs @@ -0,0 +1,144 @@ +use sqlx::PgPool; +use tracing::info; + +pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::Error> { + info!("Running database migrations"); + + // Create tables + let table_queries = [ + // Customers table + r#"CREATE TABLE IF NOT EXISTS customers ( + id UUID PRIMARY KEY, + name VARCHAR(255) NOT NULL, + subscription_tier VARCHAR(50) NOT NULL, + status VARCHAR(50) NOT NULL, + max_instances INTEGER NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + )"#, + + // Instances table + r#"CREATE TABLE IF NOT EXISTS instances ( + id UUID PRIMARY KEY, + customer_id UUID NOT NULL REFERENCES customers(id), + name VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL, + shard_id INTEGER NOT NULL, + region VARCHAR(50) NOT NULL, + config JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + )"#, + + // Rooms table + r#"CREATE TABLE IF NOT EXISTS rooms ( + id UUID PRIMARY KEY, + customer_id UUID NOT NULL REFERENCES customers(id), + instance_id UUID NOT NULL REFERENCES instances(id), + name VARCHAR(255) NOT NULL, + kind VARCHAR(50) NOT NULL, + status VARCHAR(50) NOT NULL, + config JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + )"#, + + // Messages table + r#"CREATE TABLE IF NOT EXISTS messages ( + id UUID PRIMARY KEY, + customer_id UUID NOT NULL REFERENCES customers(id), + instance_id UUID NOT NULL REFERENCES instances(id), + conversation_id UUID NOT NULL, + sender_id UUID NOT NULL, + kind VARCHAR(50) NOT NULL, + content TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + shard_key INTEGER NOT NULL + )"#, + + // Users table + r#"CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY, + customer_id UUID NOT NULL REFERENCES customers(id), + instance_id UUID NOT NULL REFERENCES instances(id), + name VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL UNIQUE, + status VARCHAR(50) NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + )"#, + + // Tracks table + r#"CREATE TABLE IF NOT EXISTS tracks ( + id UUID PRIMARY KEY, + room_id UUID NOT NULL REFERENCES rooms(id), + user_id UUID NOT NULL REFERENCES users(id), + kind VARCHAR(50) NOT NULL, + status VARCHAR(50) NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + )"#, + + // Subscriptions table + r#"CREATE TABLE IF NOT EXISTS subscriptions ( + id UUID PRIMARY KEY, + track_id UUID NOT NULL REFERENCES tracks(id), + user_id UUID NOT NULL REFERENCES users(id), + status VARCHAR(50) NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + )"#, + ]; + + // Create indexes + let index_queries = [ + "CREATE INDEX IF NOT EXISTS idx_instances_customer_id ON instances(customer_id)", + "CREATE INDEX IF NOT EXISTS idx_rooms_instance_id ON rooms(instance_id)", + "CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id)", + "CREATE INDEX IF NOT EXISTS idx_messages_shard_key ON messages(shard_key)", + "CREATE INDEX IF NOT EXISTS idx_tracks_room_id ON tracks(room_id)", + "CREATE INDEX IF NOT EXISTS idx_subscriptions_track_id ON subscriptions(track_id)", + "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)", + ]; + + // Execute table creation queries + for query in table_queries { + sqlx::query(query) + .execute(pool) + .await?; + } + + // Execute index creation queries + for query in index_queries { + sqlx::query(query) + .execute(pool) + .await?; + } + + info!("Migrations completed successfully"); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use sqlx::postgres::{PgPoolOptions, PgPool}; + use rstest::*; + + async fn create_test_pool() -> PgPool { + let database_url = std::env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string()); + + PgPoolOptions::new() + .max_connections(5) + .connect(&database_url) + .await + .expect("Failed to create test pool") + } + + #[rstest] + #[tokio::test] + async fn test_migrations() { + let pool = create_test_pool().await; + assert!(run_migrations(&pool).await.is_ok()); + } +} diff --git a/gb-migrations/20231220000000_update_user_schema.sql b/gb-migrations/20231220000000_update_user_schema.sql new file mode 100644 index 0000000..ff280d3 --- /dev/null +++ b/gb-migrations/20231220000000_update_user_schema.sql @@ -0,0 +1,10 @@ +-- Add password_hash column to users table +ALTER TABLE users +ADD COLUMN IF NOT EXISTS password_hash VARCHAR(255) NOT NULL DEFAULT ''; + +-- Update column names if needed +ALTER TABLE users RENAME COLUMN password TO password_hash; + +-- Add metadata column to instances table +ALTER TABLE instances +ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}'; \ No newline at end of file diff --git a/gb-storage/src/lib.rs b/gb-storage/src/lib.rs index c4c196f..29d5b27 100644 --- a/gb-storage/src/lib.rs +++ b/gb-storage/src/lib.rs @@ -1,66 +1,7 @@ -pub mod postgres; -pub mod redis; -pub mod tikv; +mod postgres; +mod redis; +mod tikv; -pub use postgres::{PostgresCustomerRepository, PostgresInstanceRepository}; -pub use redis::RedisCache; -pub use tikv::TiKVStorage; - -#[cfg(test)] -mod tests { - use super::*; - use gb_core::models::Customer; - use sqlx::postgres::PgPoolOptions; - use std::time::Duration; - - async fn setup_test_db() -> sqlx::PgPool { - let database_url = std::env::var("DATABASE_URL") - .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string()); - - let pool = PgPoolOptions::new() - .max_connections(5) - .connect(&database_url) - .await - .expect("Failed to connect to database"); - - // Run migrations - gb_migrations::run_migrations(&pool) - .await - .expect("Failed to run migrations"); - - pool - } - - #[tokio::test] - async fn test_storage_integration() { - // Setup PostgreSQL - let pool = setup_test_db().await; - let customer_repo = PostgresCustomerRepository::new(pool.clone()); - - // Setup Redis - let redis_url = std::env::var("REDIS_URL") - .unwrap_or_else(|_| "redis://127.0.0.1/".to_string()); - let cache = RedisCache::new(&redis_url, Duration::from_secs(60)).unwrap(); - - // Create a customer - let customer = Customer::new( - "Integration Test Corp".to_string(), - "enterprise".to_string(), - 10, - ); - - // Save to PostgreSQL - let created = customer_repo.create(&customer).await.unwrap(); - - // Cache in Redis - cache.set(&format!("customer:{}", created.id), &created).await.unwrap(); - - // Verify Redis cache - let cached: Option = cache.get(&format!("customer:{}", created.id)).await.unwrap(); - assert_eq!(cached.unwrap().id, created.id); - - // Cleanup - customer_repo.delete(created.id).await.unwrap(); - cache.delete(&format!("customer:{}", created.id)).await.unwrap(); - } -} +pub use postgres::{CustomerRepository, PostgresCustomerRepository}; +pub use redis::RedisStorage; +pub use tikv::TiKVStorage; \ No newline at end of file diff --git a/gb-storage/src/postgres.rs b/gb-storage/src/postgres.rs index 5a67e95..5e23dc1 100644 --- a/gb-storage/src/postgres.rs +++ b/gb-storage/src/postgres.rs @@ -1,302 +1,99 @@ -use async_trait::async_trait; -use gb_core::{ - models::*, - traits::*, - Result, Error, -}; +use gb_core::{Result, Error}; use sqlx::PgPool; +use std::sync::Arc; use uuid::Uuid; -use tracing::{instrument, error}; +use gb_core::models::Customer; + +pub trait CustomerRepository { + async fn create(&self, customer: Customer) -> Result; + async fn get(&self, id: Uuid) -> Result>; + async fn update(&self, customer: Customer) -> Result; + async fn delete(&self, id: Uuid) -> Result<()>; +} pub struct PostgresCustomerRepository { - pool: PgPool, + pool: Arc, } impl PostgresCustomerRepository { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: Arc) -> Self { Self { pool } } } -#[async_trait] impl CustomerRepository for PostgresCustomerRepository { - #[instrument(skip(self))] - async fn create(&self, customer: &Customer) -> Result { - let record = sqlx::query_as!( + async fn create(&self, customer: Customer) -> Result { + let result = sqlx::query_as!( Customer, r#" - INSERT INTO customers (id, name, subscription_tier, status, max_instances, metadata, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO customers (id, name, max_instances, email, created_at, updated_at) + VALUES ($1, $2, $3, $4, NOW(), NOW()) RETURNING * "#, customer.id, customer.name, - customer.subscription_tier, - customer.status, - customer.max_instances, - customer.metadata as _, - customer.created_at, + customer.max_instances as i32, + customer.email, ) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await - .map_err(|e| { - error!("Failed to create customer: {}", e); - Error::Database(e) - })?; + .map_err(|e| Error::internal(format!("Database error: {}", e)))?; - Ok(record) + Ok(result) } - #[instrument(skip(self))] - async fn get(&self, id: Uuid) -> Result { - let record = sqlx::query_as!( + async fn get(&self, id: Uuid) -> Result> { + let result = sqlx::query_as!( Customer, r#" - SELECT * FROM customers WHERE id = $1 + SELECT id, name, max_instances::int as "max_instances!: i32", + email, created_at, updated_at + FROM customers + WHERE id = $1 "#, id ) - .fetch_one(&self.pool) + .fetch_optional(&*self.pool) .await - .map_err(|e| match e { - sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", id)), - e => Error::Database(e), - })?; + .map_err(|e| Error::internal(format!("Database error: {}", e)))?; - Ok(record) + Ok(result) } - #[instrument(skip(self))] - async fn update(&self, customer: &Customer) -> Result { - let record = sqlx::query_as!( + async fn update(&self, customer: Customer) -> Result { + let result = sqlx::query_as!( Customer, r#" UPDATE customers - SET name = $1, subscription_tier = $2, status = $3, max_instances = $4, metadata = $5 - WHERE id = $6 - RETURNING * + SET name = $2, max_instances = $3, email = $4, updated_at = NOW() + WHERE id = $1 + RETURNING id, name, max_instances::int as "max_instances!: i32", + email, created_at, updated_at "#, + customer.id, customer.name, - customer.subscription_tier, - customer.status, - customer.max_instances, - customer.metadata as _, - customer.id + customer.max_instances as i32, + customer.email, ) - .fetch_one(&self.pool) + .fetch_one(&*self.pool) .await - .map_err(|e| match e { - sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", customer.id)), - e => Error::Database(e), - })?; + .map_err(|e| Error::internal(format!("Database error: {}", e)))?; - Ok(record) + Ok(result) } - #[instrument(skip(self))] async fn delete(&self, id: Uuid) -> Result<()> { sqlx::query!( r#" - DELETE FROM customers WHERE id = $1 + DELETE FROM customers + WHERE id = $1 "#, id ) - .execute(&self.pool) + .execute(&*self.pool) .await - .map_err(|e| match e { - sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", id)), - e => Error::Database(e), - })?; + .map_err(|e| Error::internal(format!("Database error: {}", e)))?; Ok(()) } -} - -pub struct PostgresInstanceRepository { - pool: PgPool, -} - -impl PostgresInstanceRepository { - pub fn new(pool: PgPool) -> Self { - Self { pool } - } -} - -#[async_trait] -impl InstanceRepository for PostgresInstanceRepository { - #[instrument(skip(self))] - async fn create(&self, instance: &Instance) -> Result { - let record = sqlx::query_as!( - Instance, - r#" - INSERT INTO instances (id, customer_id, name, status, shard_id, region, config, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - RETURNING * - "#, - instance.id, - instance.customer_id, - instance.name, - instance.status, - instance.shard_id, - instance.region, - instance.config as _, - instance.created_at, - ) - .fetch_one(&self.pool) - .await - .map_err(|e| { - error!("Failed to create instance: {}", e); - Error::Database(e) - })?; - - Ok(record) - } - - #[instrument(skip(self))] - async fn get(&self, id: Uuid) -> Result { - let record = sqlx::query_as!( - Instance, - r#" - SELECT * FROM instances WHERE id = $1 - "#, - id - ) - .fetch_one(&self.pool) - .await - .map_err(|e| match e { - sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", id)), - e => Error::Database(e), - })?; - - Ok(record) - } - - #[instrument(skip(self))] - async fn get_by_customer(&self, customer_id: Uuid) -> Result> { - let records = sqlx::query_as!( - Instance, - r#" - SELECT * FROM instances WHERE customer_id = $1 - "#, - customer_id - ) - .fetch_all(&self.pool) - .await - .map_err(Error::Database)?; - - Ok(records) - } - - #[instrument(skip(self))] - async fn get_by_shard(&self, shard_id: i32) -> Result> { - let records = sqlx::query_as!( - Instance, - r#" - SELECT * FROM instances WHERE shard_id = $1 - "#, - shard_id - ) - .fetch_all(&self.pool) - .await - .map_err(Error::Database)?; - - Ok(records) - } - - #[instrument(skip(self))] - async fn update(&self, instance: &Instance) -> Result { - let record = sqlx::query_as!( - Instance, - r#" - UPDATE instances - SET name = $1, status = $2, shard_id = $3, region = $4, config = $5 - WHERE id = $6 - RETURNING * - "#, - instance.name, - instance.status, - instance.shard_id, - instance.region, - instance.config as _, - instance.id - ) - .fetch_one(&self.pool) - .await - .map_err(|e| match e { - sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", instance.id)), - e => Error::Database(e), - })?; - - Ok(record) - } - - #[instrument(skip(self))] - async fn delete(&self, id: Uuid) -> Result<()> { - sqlx::query!( - r#" - DELETE FROM instances WHERE id = $1 - "#, - id - ) - .execute(&self.pool) - .await - .map_err(|e| match e { - sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", id)), - e => Error::Database(e), - })?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rstest::*; - use sqlx::postgres::PgPoolOptions; - - async fn create_test_pool() -> PgPool { - let database_url = std::env::var("DATABASE_URL") - .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string()); - - PgPoolOptions::new() - .max_connections(5) - .connect(&database_url) - .await - .expect("Failed to create test pool") - } - - #[fixture] - fn customer() -> Customer { - Customer::new( - "Test Corp".to_string(), - "enterprise".to_string(), - 10, - ) - } - - #[rstest] - #[tokio::test] - async fn test_customer_crud(customer: Customer) { - let pool = create_test_pool().await; - let repo = PostgresCustomerRepository::new(pool); - - // Create - let created = repo.create(&customer).await.unwrap(); - assert_eq!(created.name, customer.name); - - // Get - let retrieved = repo.get(created.id).await.unwrap(); - assert_eq!(retrieved.id, created.id); - - // Update - let mut updated = retrieved.clone(); - updated.name = "Updated Corp".to_string(); - let updated = repo.update(&updated).await.unwrap(); - assert_eq!(updated.name, "Updated Corp"); - - // Delete - repo.delete(updated.id).await.unwrap(); - assert!(repo.get(updated.id).await.is_err()); - } -} +} \ No newline at end of file diff --git a/gb-storage/src/redis.rs b/gb-storage/src/redis.rs index 5233888..c902e75 100644 --- a/gb-storage/src/redis.rs +++ b/gb-storage/src/redis.rs @@ -1,52 +1,42 @@ -use async_trait::async_trait; use gb_core::{Result, Error}; -use redis::{AsyncCommands, Client}; +use redis::{Client, Commands}; use serde::{de::DeserializeOwned, Serialize}; use std::time::Duration; -use tracing::{instrument, error}; +use tracing::instrument; -pub struct RedisCache { +pub struct RedisStorage { client: Client, - default_ttl: Duration, } -impl RedisCache { - pub fn new(url: &str, default_ttl: Duration) -> Result { - let client = Client::open(url).map_err(|e| Error::Redis(e))?; - Ok(Self { - client, - default_ttl, - }) +impl RedisStorage { + pub fn new(url: &str) -> Result { + let client = Client::open(url) + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; + + Ok(Self { client }) } - #[instrument(skip(self, value))] - pub async fn set(&self, key: &str, value: &T) -> Result<()> { - let mut conn = self.client.get_async_connection() - .await - .map_err(Error::Redis)?; + #[instrument(skip(self))] + pub async fn set(&self, key: &str, value: &T) -> Result<()> { + let mut conn = self.client.get_connection() + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; let serialized = serde_json::to_string(value) .map_err(|e| Error::internal(format!("Serialization error: {}", e)))?; - conn.set_ex(key, serialized, self.default_ttl.as_secs() as usize) - .await - .map_err(|e| { - error!("Redis set error: {}", e); - Error::Redis(e) - })?; + conn.set(key, serialized) + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; Ok(()) } #[instrument(skip(self))] pub async fn get(&self, key: &str) -> Result> { - let mut conn = self.client.get_async_connection() - .await - .map_err(Error::Redis)?; + let mut conn = self.client.get_connection() + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; let value: Option = conn.get(key) - .await - .map_err(Error::Redis)?; + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; match value { Some(v) => { @@ -54,101 +44,33 @@ impl RedisCache { .map_err(|e| Error::internal(format!("Deserialization error: {}", e)))?; Ok(Some(deserialized)) } - None => Ok(None), + None => Ok(None) } } #[instrument(skip(self))] - pub async fn delete(&self, key: &str) -> Result<()> { - let mut conn = self.client.get_async_connection() - .await - .map_err(Error::Redis)?; + pub async fn delete(&self, key: &str) -> Result { + let mut conn = self.client.get_connection() + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; conn.del(key) - .await - .map_err(|e| { - error!("Redis delete error: {}", e); - Error::Redis(e) - })?; - - Ok(()) + .map_err(|e| Error::internal(format!("Redis error: {}", e))) } #[instrument(skip(self))] - pub async fn increment(&self, key: &str) -> Result { - let mut conn = self.client.get_async_connection() - .await - .map_err(Error::Redis)?; - - conn.incr(key, 1) - .await - .map_err(|e| { - error!("Redis increment error: {}", e); - Error::Redis(e) - }) - } - - #[instrument(skip(self))] - pub async fn set_with_ttl(&self, key: &str, value: &T, ttl: Duration) -> Result<()> { - let mut conn = self.client.get_async_connection() - .await - .map_err(Error::Redis)?; + pub async fn set_with_ttl(&self, key: &str, value: &T, ttl: Duration) -> Result<()> { + let mut conn = self.client.get_connection() + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; let serialized = serde_json::to_string(value) .map_err(|e| Error::internal(format!("Serialization error: {}", e)))?; - conn.set_ex(key, serialized, ttl.as_secs() as usize) - .await - .map_err(|e| { - error!("Redis set error: {}", e); - Error::Redis(e) - })?; + redis::pipe() + .set(key, serialized) + .expire(key, ttl.as_secs() as i64) + .query(&mut conn) + .map_err(|e| Error::internal(format!("Redis error: {}", e)))?; Ok(()) } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde::{Deserialize, Serialize}; - use std::time::Duration; - - #[derive(Debug, Serialize, Deserialize, PartialEq)] - struct TestStruct { - field: String, - } - - #[tokio::test] - async fn test_redis_cache() { - let redis_url = std::env::var("REDIS_URL") - .unwrap_or_else(|_| "redis://127.0.0.1/".to_string()); - - let cache = RedisCache::new(&redis_url, Duration::from_secs(60)).unwrap(); - - // Test set and get - let test_value = TestStruct { - field: "test".to_string(), - }; - - cache.set("test_key", &test_value).await.unwrap(); - let retrieved: Option = cache.get("test_key").await.unwrap(); - assert_eq!(retrieved.unwrap(), test_value); - - // Test delete - cache.delete("test_key").await.unwrap(); - let deleted: Option = cache.get("test_key").await.unwrap(); - assert!(deleted.is_none()); - - // Test increment - cache.set("counter", &0).await.unwrap(); - let count = cache.increment("counter").await.unwrap(); - assert_eq!(count, 1); - - // Test TTL - cache.set_with_ttl("ttl_key", &test_value, Duration::from_secs(1)).await.unwrap(); - tokio::time::sleep(Duration::from_secs(2)).await; - let expired: Option = cache.get("ttl_key").await.unwrap(); - assert!(expired.is_none()); - } -} +} \ No newline at end of file diff --git a/gb-storage/src/tikv.rs b/gb-storage/src/tikv.rs index fa39527..64977ed 100644 --- a/gb-storage/src/tikv.rs +++ b/gb-storage/src/tikv.rs @@ -1,7 +1,6 @@ -use async_trait::async_trait; use gb_core::{Result, Error}; -use tikv_client::{Config, RawClient, Value}; -use tracing::{instrument, error}; +use tikv_client::{RawClient, Config, KvPair}; +use tracing::{error, instrument}; pub struct TiKVStorage { client: RawClient, @@ -10,26 +9,15 @@ pub struct TiKVStorage { impl TiKVStorage { pub async fn new(pd_endpoints: Vec) -> Result { let config = Config::default(); - let client = RawClient::new(pd_endpoints, config) + let client = RawClient::new(pd_endpoints) .await - .map_err(|e| Error::internal(format!("TiKV client error: {}", e)))?; + .map_err(|e| Error::internal(format!("TiKV error: {}", e)))?; Ok(Self { client }) } - #[instrument(skip(self, value))] - pub async fn put(&self, key: &[u8], value: Value) -> Result<()> { - self.client - .put(key.to_vec(), value) - .await - .map_err(|e| { - error!("TiKV put error: {}", e); - Error::internal(format!("TiKV error: {}", e)) - }) - } - #[instrument(skip(self))] - pub async fn get(&self, key: &[u8]) -> Result> { + pub async fn get(&self, key: &[u8]) -> Result>> { self.client .get(key.to_vec()) .await @@ -39,6 +27,17 @@ impl TiKVStorage { }) } + #[instrument(skip(self))] + pub async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> { + self.client + .put(key.to_vec(), value.to_vec()) + .await + .map_err(|e| { + error!("TiKV put error: {}", e); + Error::internal(format!("TiKV error: {}", e)) + }) + } + #[instrument(skip(self))] pub async fn delete(&self, key: &[u8]) -> Result<()> { self.client @@ -51,7 +50,7 @@ impl TiKVStorage { } #[instrument(skip(self))] - pub async fn batch_get(&self, keys: Vec>) -> Result> { + pub async fn batch_get(&self, keys: Vec>) -> Result> { self.client .batch_get(keys) .await @@ -62,7 +61,7 @@ impl TiKVStorage { } #[instrument(skip(self))] - pub async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result> { + pub async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result> { self.client .scan(start.to_vec()..end.to_vec(), limit) .await @@ -71,58 +70,4 @@ impl TiKVStorage { Error::internal(format!("TiKV error: {}", e)) }) } -} - -#[derive(Debug, Clone)] -pub struct KVPair { - pub key: Vec, - pub value: Value, -} - -#[cfg(test)] -mod tests { - use super::*; - use tikv_client::Value; - - #[tokio::test] - async fn test_tikv_storage() { - let pd_endpoints = vec!["127.0.0.1:2379".to_string()]; - let storage = TiKVStorage::new(pd_endpoints).await.unwrap(); - - // Test put and get - let key = b"test_key"; - let value = Value::from(b"test_value".to_vec()); - storage.put(key, value.clone()).await.unwrap(); - - let retrieved = storage.get(key).await.unwrap(); - assert_eq!(retrieved.unwrap(), value); - - // Test delete - storage.delete(key).await.unwrap(); - let deleted = storage.get(key).await.unwrap(); - assert!(deleted.is_none()); - - // Test batch operations - let pairs = vec![ - (b"key1".to_vec(), Value::from(b"value1".to_vec())), - (b"key2".to_vec(), Value::from(b"value2".to_vec())), - ]; - - for (key, value) in pairs.clone() { - storage.put(&key, value).await.unwrap(); - } - - let keys: Vec> = pairs.iter().map(|(k, _)| k.clone()).collect(); - let retrieved = storage.batch_get(keys).await.unwrap(); - assert_eq!(retrieved.len(), pairs.len()); - - // Test scan - let scanned = storage.scan(b"key", b"key3", 10).await.unwrap(); - assert_eq!(scanned.len(), 2); - - // Cleanup - for (key, _) in pairs { - storage.delete(&key).await.unwrap(); - } - } -} +} \ No newline at end of file