new(all): Initial import.

Rodrigo Rodriguez 2024-12-24 13:05:54 -03:00
parent 7f384469a9
commit 28cc734340
27 changed files with 573 additions and 1148 deletions

.cargo/config.toml Normal file

@@ -0,0 +1,2 @@
[term]
quiet = true

.gitignore vendored

@@ -1 +1,2 @@
target
target
.env

Cargo.lock generated

@@ -2407,6 +2407,7 @@ dependencies = [
"async-trait",
"chromiumoxide",
"fantoccini",
"futures-util",
"gb-core",
"headless_chrome",
"image",


@@ -16,6 +16,12 @@ members = [
"gb-image", # Image processing capabilities
]
# [workspace.lints.rust]
# unused_imports = "allow"
# dead_code = "allow"
# unused_variables = "allow"
# dependency_on_unit_never_type_fallback = "allow"
[workspace.package]
version = "0.1.0"
edition = "2021"


@@ -46,7 +46,7 @@ async fn handle_ws_connection(
if let Ok(text) = msg.to_text() {
if let Ok(envelope) = serde_json::from_str::<MessageEnvelope>(text) {
let mut processor = state.message_processor.lock().await;
if let Err(e) = processor.process_messages(vec![envelope]).await {
if let Err(e) = processor.process_messages().await {
error!("Failed to process message: {}", e);
}
}
@@ -77,7 +77,7 @@ async fn send_message(
};
let mut processor = state.message_processor.lock().await;
processor.process_messages(vec![envelope.clone()]).await
processor.process_messages().await
.map_err(|e| Error::internal(format!("Failed to process message: {}", e)))?;
Ok(Json(MessageId(envelope.id)))
@@ -114,4 +114,4 @@ async fn join_room(
Json(_user_id): Json<Uuid>,
) -> Result<Json<Connection>> {
todo!()
}
}
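Editor's note: both call sites above now invoke process_messages() with no arguments, which suggests the processor pulls envelopes from an internal queue (or silently drops the parsed envelope, which may be an oversight). A hypothetical sketch of the queue-based shape; the real MessageProcessor is not shown in this commit, so enqueue() and the field names are assumptions:

use std::collections::VecDeque;
use uuid::Uuid;

pub struct MessageEnvelope {
    pub id: Uuid,
}

#[derive(Default)]
pub struct MessageProcessor {
    queue: VecDeque<MessageEnvelope>,
}

impl MessageProcessor {
    // Callers stage envelopes first (e.g., the WebSocket handler)...
    pub fn enqueue(&mut self, envelope: MessageEnvelope) {
        self.queue.push_back(envelope);
    }

    // ...then one process_messages() call drains the queue in FIFO order.
    pub async fn process_messages(&mut self) -> Result<(), String> {
        while let Some(envelope) = self.queue.pop_front() {
            tracing::debug!("processing message {}", envelope.id);
        }
        Ok(())
    }
}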


@@ -54,3 +54,4 @@ tokio-test = "0.4"
mockall = "0.12"
axum-extra = { version = "0.7" }
sqlx = { version = "0.7", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono", "json"] }

gb-auth/src/error.rs Normal file

@@ -0,0 +1,20 @@
use gb_core::Error as CoreError;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum AuthError {
#[error("Invalid token")]
InvalidToken,
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Redis error: {0}")]
Redis(#[from] redis::RedisError),
#[error("Internal error: {0}")]
Internal(String),
}
impl From<CoreError> for AuthError {
fn from(err: CoreError) -> Self {
AuthError::Internal(err.to_string())
}
}
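Editor's note: the #[from] attributes above make thiserror generate From impls, so the ? operator converts sqlx and redis errors into AuthError automatically. A minimal illustration; the function and query are hypothetical, not part of this commit:

async fn load_user_id(pool: &sqlx::PgPool, email: &str) -> Result<uuid::Uuid, AuthError> {
    let row: (uuid::Uuid,) = sqlx::query_as("SELECT id FROM users WHERE email = $1")
        .bind(email)
        .fetch_one(pool)
        .await?; // sqlx::Error -> AuthError::Database via the generated From impl
    Ok(row.0)
}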

gb-auth/src/errors.rs Normal file

@@ -0,0 +1,24 @@
use gb_core::Error as CoreError;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum AuthError {
#[error("Invalid token")]
InvalidToken,
#[error("Invalid credentials")]
InvalidCredentials,
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Redis error: {0}")]
Redis(#[from] redis::RedisError),
#[error("Internal error: {0}")]
Internal(String),
}
impl From<CoreError> for AuthError {
fn from(err: CoreError) -> Self {
match err {
CoreError { .. } => AuthError::Internal(err.to_string()),
}
}
}


@@ -3,9 +3,7 @@ use axum::{
middleware::Next,
body::Body,
};
use axum_extra::TypedHeader;
use axum_extra::headers::{Authorization, authorization::Bearer};
use gb_core::User;
use headers::{Authorization, authorization::Bearer};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Serialize, Deserialize};
use crate::AuthError;
@@ -17,10 +15,10 @@ struct Claims {
}
pub async fn auth_middleware(
TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
auth: Authorization<Bearer>,
request: Request<Body>,
next: Next,
) -> Result<Response, AuthError> {
) -> Result<Response<Body>, AuthError> {
let token = auth.token();
let key = DecodingKey::from_secret(b"secret");
let validation = Validation::default();
@@ -32,4 +30,4 @@ pub async fn auth_middleware(
}
Err(_) => Err(AuthError::InvalidToken),
}
}
}
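Editor's note: the decode call at the heart of this middleware verifies the signature and the exp claim in one step. A self-contained sketch with a placeholder secret and claim set:

use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::Deserialize;

#[derive(Deserialize)]
struct Claims {
    sub: String,
    exp: usize,
}

fn verify(token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
    let data = decode::<Claims>(
        token,
        &DecodingKey::from_secret(b"secret"), // placeholder secret
        &Validation::default(),               // HS256, exp checked by default
    )?;
    Ok(data.claims)
}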


@@ -1,3 +1,2 @@
mod user;
pub mod user;
pub use user::*;


@@ -1,3 +1,6 @@
use gb_core::{Result, Error};
use crate::models::{LoginRequest, LoginResponse};
use crate::models::user::DbUser;
use std::sync::Arc;
use sqlx::PgPool;
use argon2::{
@@ -6,12 +9,6 @@ use argon2::{
};
use rand::rngs::OsRng;
use crate::{
models::{LoginRequest, LoginResponse, User},
AuthError,
Result,
};
pub struct AuthService {
db: Arc<PgPool>,
jwt_secret: String,
@@ -30,12 +27,17 @@ impl AuthService {
pub async fn login(&self, request: LoginRequest) -> Result<LoginResponse> {
let user = sqlx::query_as!(
DbUser,
"SELECT * FROM users WHERE email = $1",
r#"
SELECT id, email, password_hash, role
FROM users
WHERE email = $1
"#,
request.email
)
.fetch_optional(&*self.db)
.await?
.ok_or(AuthError::InvalidCredentials)?;
.await
.map_err(|e| Error::internal(e.to_string()))?
.ok_or_else(|| Error::internal("Invalid credentials"))?;
self.verify_password(&request.password, &user.password_hash)?;
@@ -56,19 +58,19 @@ impl AuthService {
argon2
.hash_password(password.as_bytes(), &salt)
.map(|hash| hash.to_string())
.map_err(|e| AuthError::Internal(e.to_string()))
.map_err(|e| Error::internal(e.to_string()))
}
fn verify_password(&self, password: &str, hash: &str) -> Result<()> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| AuthError::Internal(e.to_string()))?;
.map_err(|e| Error::internal(e.to_string()))?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| AuthError::InvalidCredentials)
.map_err(|_| Error::internal("Invalid credentials"))
}
fn generate_token(&self, user: &User) -> Result<String> {
fn generate_token(&self, user: &DbUser) -> Result<String> {
use jsonwebtoken::{encode, EncodingKey, Header};
use serde::{Serialize, Deserialize};
use chrono::{Utc, Duration};
@@ -94,6 +96,6 @@ impl AuthService {
&claims,
&EncodingKey::from_secret(self.jwt_secret.as_bytes()),
)
.map_err(|e| AuthError::Internal(e.to_string()))
.map_err(|e| Error::internal(e.to_string()))
}
}
}
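Editor's note: a standalone sketch of the Argon2 hash/verify round trip the service performs, using the same argon2 and rand crates; "secret" is a placeholder password:

use argon2::{
    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
    Argon2,
};
use rand::rngs::OsRng;

fn demo() -> Result<(), argon2::password_hash::Error> {
    let salt = SaltString::generate(&mut OsRng);
    // Hash with default Argon2id parameters, producing a PHC-format string.
    let hash = Argon2::default()
        .hash_password(b"secret", &salt)?
        .to_string();
    // Verification re-parses the PHC string and recomputes the hash.
    let parsed = PasswordHash::new(&hash)?;
    Argon2::default().verify_password(b"secret", &parsed)?;
    Ok(())
}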


@@ -9,6 +9,7 @@ license.workspace = true
gb-core = { path = "../gb-core" }
image = { version = "0.24", features = ["webp", "jpeg", "png", "gif"] }
chromiumoxide = { version = "0.5", features = ["tokio-runtime"] }
futures-util = "0.3"
async-trait.workspace = true
tokio.workspace = true
serde.workspace = true
@@ -24,4 +25,4 @@ async-recursion = "1.0"
[dev-dependencies]
rstest.workspace = true
tokio-test = "0.4"
mock_instant = "0.2"
mock_instant = "0.2"


@@ -1,36 +1,4 @@
pub mod web;
pub mod process;
mod web;
pub use web::{WebAutomation, Element};
pub use process::ProcessAutomation;
#[cfg(test)]
mod tests {
use super::*;
use gb_core::Result;
use tempfile::tempdir;
#[tokio::test]
async fn test_automation_integration() -> Result<()> {
// Initialize automation components
let web = WebAutomation::new().await?;
let dir = tempdir()?;
let process = ProcessAutomation::new(dir.path());
// Test web automation
let page = web.new_page().await?;
web.navigate(&page, "https://example.com").await?;
let screenshot = web.screenshot(&page, "test.png").await?;
// Test process automation
let output = process.execute("echo", &["Test output"]).await?;
assert!(output.contains("Test output"));
// Test process spawning and cleanup
let id = process.spawn("sleep", &["1"]).await?;
process.kill(id).await?;
process.cleanup().await?;
Ok(())
}
}
pub use chromiumoxide::element::Element;
pub use web::WebAutomation;


@@ -1,73 +1,61 @@
use chromiumoxide::browser::{Browser, BrowserConfig};
use chromiumoxide::element::Element;
use chromiumoxide::{Browser, Element};
use chromiumoxide::page::Page;
use chromiumoxide::browser::BrowserConfig;
use futures_util::StreamExt;
use gb_core::{Error, Result};
use tracing::instrument;
use std::time::Duration;
pub struct WebAutomation {
browser: Browser,
}
impl WebAutomation {
#[instrument]
pub async fn new() -> Result<Self> {
let config = BrowserConfig::builder()
.build()
.map_err(|e| Error::internal(e.to_string()))?;
let (browser, mut handler) = Browser::launch(config)
let (browser, handler) = Browser::launch(config)
.await
.map_err(|e| Error::internal(e.to_string()))?;
// Spawn the handler in the background
tokio::spawn(async move {
while let Some(h) = handler.next().await {
if let Err(e) = h {
tracing::error!("Browser handler error: {}", e);
}
}
handler.for_each(|_| async {}).await;
});
Ok(Self { browser })
}
#[instrument(skip(self))]
pub async fn new_page(&self) -> Result<Page> {
let params = chromiumoxide::cdp::browser_protocol::target::CreateTarget::new()
.url("about:blank");
self.browser.new_page(params)
self.browser
.new_page("about:blank")
.await
.map_err(|e| Error::internal(e.to_string()))
}
#[instrument(skip(self))]
pub async fn navigate(&self, page: &Page, url: &str) -> Result<()> {
page.goto(url)
.await
.map_err(|e| Error::internal(e.to_string()))
.map_err(|e| Error::internal(e.to_string()))?;
Ok(())
}
#[instrument(skip(self))]
pub async fn get_element(&self, page: &Page, selector: &str) -> Result<Element> {
page.find_element(selector)
.await
.map_err(|e| Error::internal(e.to_string()))
}
#[instrument(skip(self))]
pub async fn screenshot(&self, page: &Page, _path: &str) -> Result<Vec<u8>> {
let params = chromiumoxide::cdp::browser_protocol::page::CaptureScreenshot::new();
pub async fn take_screenshot(&self, page: &Page) -> Result<Vec<u8>> {
let params = chromiumoxide::page::ScreenshotParams::builder().build();
page.screenshot(params)
.await
.map_err(|e| Error::internal(e.to_string()))
}
#[instrument(skip(self))]
pub async fn wait_for_selector(&self, page: &Page, selector: &str) -> Result<()> {
page.find_element(selector)
.await
.map_err(|e| Error::internal(e.to_string()))?;
Ok(())
pub async fn find_element(&self, page: &Page, selector: &str, timeout: Duration) -> Result<Element> {
tokio::time::timeout(
timeout,
page.find_element(selector)
)
.await
.map_err(|_| Error::internal("Timeout waiting for element"))?
.map_err(|e| Error::internal(e.to_string()))
}
}
}
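Editor's note: illustrative usage of the reworked WebAutomation API above (new_page, navigate, take_screenshot, and find_element with a timeout); the URL and selector are placeholders:

use std::time::Duration;

async fn demo(web: &WebAutomation) -> gb_core::Result<()> {
    let page = web.new_page().await?;
    web.navigate(&page, "https://example.com").await?;
    let png = web.take_screenshot(&page).await?;
    assert!(!png.is_empty());
    // Returns an internal error if the selector does not appear within 5s.
    let _heading = web.find_element(&page, "h1", Duration::from_secs(5)).await?;
    Ok(())
}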


@@ -1,204 +1,39 @@
use gb_core::{Result, Error};
use image::{
DynamicImage, Rgba,
};
use imageproc::{
drawing::draw_text_mut,
};
use rusttype::{Font, Scale};
use std::path::Path;
use tracing::instrument;
use std::convert::TryInto;
use gb_core::{Error, Result};
use image::{DynamicImage, ImageOutputFormat};
use std::io::Cursor;
use tesseract::Tesseract;
use tempfile::NamedTempFile;
use std::io::Write;
pub struct ProcessingOptions {
pub crop: Option<CropParams>,
pub watermark: Option<DynamicImage>,
pub x: i32,
pub y: i32,
}
pub struct CropParams {
pub x: u32,
pub y: u32,
pub width: u32,
pub height: u32,
}
pub struct ImageProcessor {
default_font: Font<'static>,
}
pub struct ImageProcessor;
impl ImageProcessor {
pub fn new() -> Result<Self> {
let font_data = include_bytes!("../assets/DejaVuSans.ttf");
let font = Font::try_from_bytes(font_data)
.ok_or_else(|| Error::internal("Failed to load font"))?;
Ok(Self {
default_font: font,
})
pub fn new() -> Self {
Self
}
pub fn process_image(&self, mut image: DynamicImage, options: &ProcessingOptions) -> Result<DynamicImage> {
if let Some(crop) = &options.crop {
let cropped = image.crop_imm(
crop.x,
crop.y,
crop.width,
crop.height
);
image = cropped;
}
if let Some(watermark) = &options.watermark {
let x: i64 = options.x.try_into().map_err(|_| Error::internal("Invalid x coordinate"))?;
let y: i64 = options.y.try_into().map_err(|_| Error::internal("Invalid y coordinate"))?;
image::imageops::overlay(&mut image, watermark, x, y);
}
Ok(image)
}
#[instrument(skip(self, image_data))]
pub fn load_image(&self, image_data: &[u8]) -> Result<DynamicImage> {
image::load_from_memory(image_data)
.map_err(|e| Error::internal(format!("Failed to load image: {}", e)))
}
#[instrument(skip(self, image))]
pub fn save_image(&self, image: &DynamicImage, path: &Path) -> Result<()> {
image.save(path)
.map_err(|e| Error::internal(format!("Failed to save image: {}", e)))
}
#[instrument(skip(self, image))]
pub fn crop(&self, image: &DynamicImage, x: u32, y: u32, width: u32, height: u32) -> Result<DynamicImage> {
Ok(image.crop_imm(x, y, width, height))
}
#[instrument(skip(self, image))]
pub fn add_text(
&self,
image: &mut DynamicImage,
text: &str,
x: i32,
y: i32,
scale: f32,
color: Rgba<u8>,
) -> Result<()> {
let scale = Scale::uniform(scale);
let mut img = image.to_rgba8();
draw_text_mut(
&mut img,
color,
x,
y,
scale,
&self.default_font,
text,
);
*image = DynamicImage::ImageRgba8(img);
Ok(())
}
#[instrument(skip(self, image))]
pub fn add_watermark(
&self,
image: &mut DynamicImage,
watermark: &DynamicImage,
x: u32,
y: u32,
) -> Result<()> {
let x: i64 = x.try_into().map_err(|_| Error::internal("Invalid x coordinate"))?;
let y: i64 = y.try_into().map_err(|_| Error::internal("Invalid y coordinate"))?;
image::imageops::overlay(image, watermark, x, y);
Ok(())
}
#[instrument(skip(self, image))]
pub fn extract_text(&self, image: &DynamicImage) -> Result<String> {
use tesseract::Tesseract;
let temp_file = tempfile::NamedTempFile::new()
pub async fn extract_text(&self, image: &DynamicImage) -> Result<String> {
// Create a temporary file
let mut temp_file = NamedTempFile::new()
.map_err(|e| Error::internal(format!("Failed to create temp file: {}", e)))?;
image.save(&temp_file)
.map_err(|e| Error::internal(format!("Failed to save temp image: {}", e)))?;
// Convert image to PNG and write to temp file
let mut cursor = Cursor::new(Vec::new());
image.write_to(&mut cursor, ImageOutputFormat::Png)
.map_err(|e| Error::internal(format!("Failed to encode image: {}", e)))?;
temp_file.write_all(&cursor.into_inner())
.map_err(|e| Error::internal(format!("Failed to write to temp file: {}", e)))?;
let mut api = Tesseract::new(None, Some("eng"))
// Initialize Tesseract and process image
let api = Tesseract::new(None, Some("eng"))
.map_err(|e| Error::internal(format!("Failed to initialize Tesseract: {}", e)))?;
api.set_image(temp_file.path().to_str().unwrap())
.map_err(|e| Error::internal(format!("Failed to set image: {}", e)))?;
api.recognize()
.map_err(|e| Error::internal(format!("Failed to recognize text: {}", e)))?;
api.get_text()
.map_err(|e| Error::internal(format!("Failed to set image: {}", e)))?
.recognize()
.map_err(|e| Error::internal(format!("Failed to recognize text: {}", e)))?
.get_text()
.map_err(|e| Error::internal(format!("Failed to get text: {}", e)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
use std::path::PathBuf;
#[fixture]
fn processor() -> ImageProcessor {
ImageProcessor::new().unwrap()
}
#[fixture]
fn test_image() -> DynamicImage {
DynamicImage::new_rgb8(100, 100)
}
#[rstest]
fn test_resize(processor: ImageProcessor, test_image: DynamicImage) {
let resized = processor.resize(&test_image, 50, 50);
assert_eq!(resized.width(), 50);
assert_eq!(resized.height(), 50);
}
#[rstest]
fn test_crop(processor: ImageProcessor, test_image: DynamicImage) -> Result<()> {
let cropped = processor.crop(&test_image, 25, 25, 50, 50)?;
assert_eq!(cropped.width(), 50);
assert_eq!(cropped.height(), 50);
Ok(())
}
#[rstest]
fn test_add_text(processor: ImageProcessor, mut test_image: DynamicImage) -> Result<()> {
processor.add_text(
&mut test_image,
"Test",
10,
10,
12.0,
Rgba([255, 255, 255, 255]),
)?;
Ok(())
}
#[rstest]
fn test_extract_text(processor: ImageProcessor, mut test_image: DynamicImage) -> Result<()> {
processor.add_text(
&mut test_image,
"Test OCR",
10,
10,
24.0,
Rgba([0, 0, 0, 255]),
)?;
let text = processor.extract_text(&test_image)?;
assert!(text.contains("Test OCR"));
Ok(())
}
}
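Editor's note: an illustrative call site for the async extract_text() above. It assumes the Tesseract system library and English language data are installed; a blank placeholder image stands in for a real document scan:

use image::DynamicImage;

async fn demo() -> gb_core::Result<String> {
    let processor = ImageProcessor::new();
    let image = DynamicImage::new_rgb8(640, 480); // placeholder; OCR needs a real scan
    processor.extract_text(&image).await
}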


@@ -1,112 +1,56 @@
use gb_core::{Result, Error};
use opus::{Decoder, Encoder};
use opus::{Decoder, Encoder, Channels, Application};
use std::io::Cursor;
use tracing::{instrument, error};
use opus::{Encoder, Decoder, Application, Channels};
pub struct AudioProcessor {
sample_rate: i32,
channels: i32,
encoder: Encoder,
decoder: Decoder,
sample_rate: u32,
channels: Channels,
}
impl AudioProcessor {
pub fn new(sample_rate: i32, channels: i32) -> Self {
Self {
pub fn new(sample_rate: u32, channels: Channels) -> Result<Self> {
let encoder = Encoder::new(
sample_rate,
channels,
}
Application::Audio
).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?;
let decoder = Decoder::new(
sample_rate,
channels
).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?;
Ok(Self {
encoder,
decoder,
sample_rate,
channels,
})
}
#[instrument(skip(self, input))]
pub fn encode(&self, input: &[i16]) -> Result<Vec<u8>> {
let mut encoder = Encoder::new(
self.sample_rate,
if self.channels == 1 {
opus::Channels::Mono
} else {
opus::Channels::Stereo
},
opus::Application::Voip,
).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?;
u32::try_from(self.sample_rate).map_err(|e| Error::internal(format!("Invalid sample rate: {}", e)))?,
Channels::Mono,
Application::Voip
).map_err(|e| Error::internal(format!("Failed to create Opus encoder: {}", e)))?;
let mut output = vec![0u8; 1024];
let encoded_len = encoder.encode(input, &mut output)
.map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))?;
let encoded_size = self.encoder.encode(
input,
&mut output
).map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))?;
output.truncate(encoded_len);
output.truncate(encoded_size);
Ok(output)
encoder.encode(input)
.map_err(|e| Error::internal(format!("Failed to encode audio: {}", e)))
}
#[instrument(skip(self, input))]
pub fn decode(&self, input: &[u8]) -> Result<Vec<i16>> {
let mut decoder = Decoder::new(
self.sample_rate,
if self.channels == 1 {
opus::Channels::Mono
} else {
opus::Channels::Stereo
},
).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?;
u32::try_from(self.sample_rate).map_err(|e| Error::internal(format!("Invalid sample rate: {}", e)))?,
Channels::Mono
).map_err(|e| Error::internal(format!("Failed to create Opus decoder: {}", e)))?;
let max_size = (self.sample_rate as usize / 50) * self.channels.count();
let mut output = vec![0i16; max_size];
let mut output = vec![0i16; 1024];
let decoded_len = decoder.decode(input, &mut output, false)
.map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))?;
let decoded_size = self.decoder.decode(
Some(input),
&mut output,
false
).map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))?;
output.truncate(decoded_len);
output.truncate(decoded_size);
Ok(output)
decoder.decode(input)
.map_err(|e| Error::internal(format!("Failed to decode audio: {}", e)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
#[fixture]
fn audio_processor() -> AudioProcessor {
AudioProcessor::new(48000, 2)
}
#[fixture]
fn test_audio() -> Vec<i16> {
// Generate 1 second of 440Hz sine wave
let sample_rate = 48000;
let frequency = 440.0;
let duration = 1.0;
(0..sample_rate)
.flat_map(|i| {
let t = i as f32 / sample_rate as f32;
let value = (2.0 * std::f32::consts::PI * frequency * t).sin();
let sample = (value * i16::MAX as f32) as i16;
vec![sample, sample] // Stereo
vec![sample, sample]
})
.collect()
}
#[rstest]
fn test_encode_decode(audio_processor: AudioProcessor, test_audio: Vec<i16>) {
let encoded = audio_processor.encode(&test_audio).unwrap();
let decoded = audio_processor.decode(&encoded).unwrap();
// Verify basic properties
assert!(!encoded.is_empty());
assert!(!decoded.is_empty());
// Opus is lossy, so we can't compare exact values
// But we can verify the length is the same
assert_eq!(decoded.len(), test_audio.len());
}
}
}
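Editor's note: a minimal Opus round trip with the opus crate, mirroring what the processor does internally: 48 kHz mono, one 20 ms frame (960 samples). Opus is lossy, so the sketch checks lengths, not sample values:

use opus::{Application, Channels, Decoder, Encoder};

fn demo() -> Result<(), opus::Error> {
    let mut encoder = Encoder::new(48000, Channels::Mono, Application::Audio)?;
    let mut decoder = Decoder::new(48000, Channels::Mono)?;

    let pcm = vec![0i16; 960]; // one 20 ms frame of silence at 48 kHz
    let mut packet = vec![0u8; 4000]; // generous output buffer
    let n = encoder.encode(&pcm, &mut packet)?;
    packet.truncate(n);

    let mut out = vec![0i16; 960];
    let decoded = decoder.decode(&packet, &mut out, false)?;
    assert_eq!(decoded, 960); // decoder restores the full frame length
    Ok(())
}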


@@ -1,44 +1,5 @@
pub mod webrtc;
pub mod processor;
pub mod audio;
mod processor;
mod webrtc;
pub use webrtc::WebRTCService;
pub use processor::{MediaProcessor, MediaMetadata};
pub use audio::AudioProcessor;
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
use uuid::Uuid;
#[tokio::test]
async fn test_media_integration() {
// Initialize services
let webrtc = WebRTCService::new(vec!["stun:stun.l.google.com:19302".to_string()]);
let processor = MediaProcessor::new().unwrap();
let audio = AudioProcessor::new(48000, 2);
// Test room creation and joining
let room_id = Uuid::new_v4();
let user_id = Uuid::new_v4();
let connection = webrtc.join_room(room_id, user_id).await.unwrap();
assert_eq!(connection.room_id, room_id);
assert_eq!(connection.user_id, user_id);
// Test media processing
let input_path = PathBuf::from("test_data/test.mp4");
if input_path.exists() {
let metadata = processor.extract_metadata(input_path.clone()).await.unwrap();
assert!(metadata.width.is_some());
assert!(metadata.height.is_some());
}
// Test audio processing
let test_audio: Vec<i16> = (0..1024).map(|i| i as i16).collect();
let encoded = audio.encode(&test_audio).unwrap();
let decoded = audio.decode(&encoded).unwrap();
assert!(!decoded.is_empty());
}
}
pub use processor::MediaProcessor;
pub use webrtc::WebRTCService;


@@ -1,41 +1,45 @@
use gstreamer::{self as gst, prelude::*};
use gstreamer::prelude::{
ElementExt,
GstBinExtManual,
GstObjectExt,
};
use gb_core::{Result, Error};
use gstreamer as gst;
use gstreamer::prelude::*;
use std::path::PathBuf;
use tracing::{error, instrument};
pub struct MediaProcessor {
pipeline: gst::Pipeline,
}
impl MediaProcessor {
pub fn new() -> Result<Self> {
gst::init().map_err(|e| Error::internal(format!("Failed to initialize GStreamer: {}", e)))?;
let pipeline = gst::Pipeline::new(None);
Ok(Self {
pipeline,
})
let pipeline = gst::Pipeline::new()
.map_err(|e| Error::internal(format!("Failed to create pipeline: {}", e)))?;
Ok(Self { pipeline })
}
fn setup_pipeline(&mut self) -> Result<()> {
self.pipeline.set_state(gst::State::Playing)
.map_err(|e| Error::internal(format!("Failed to start pipeline: {}", e)))?;
let bus = self.pipeline.bus().expect("Pipeline without bus");
for msg in bus.iter_timed(gst::ClockTime::NONE) {
use gst::MessageView;
Ok(())
}
fn process_messages(&self) -> Result<()> {
let bus = self.pipeline.bus().unwrap();
while let Some(msg) = bus.timed_pop(gst::ClockTime::from_seconds(1)) {
match msg.view() {
MessageView::Error(err) => {
gst::MessageView::Error(err) => {
error!("Error from {:?}: {} ({:?})",
err.src().map(|s| s.path_string()),
err.error(),
err.error(),
err.debug()
);
return Err(Error::internal(format!("Pipeline error: {}", err.error())));
}
MessageView::Eos(_) => break,
_ => (),
gst::MessageView::Eos(_) => break,
_ => ()
}
}
@@ -47,85 +51,32 @@ impl MediaProcessor {
#[instrument(skip(self, input_path, output_path))]
pub async fn transcode(
&self,
&mut self,
input_path: PathBuf,
output_path: PathBuf,
format: &str,
format: &str
) -> Result<()> {
let src = gst::ElementFactory::make("filesrc")
.property("location", input_path.to_str().unwrap())
.build()
let source = gst::ElementFactory::make("filesrc")
.map_err(|e| Error::internal(format!("Failed to create source element: {}", e)))?;
source.set_property("location", input_path.to_str().unwrap());
let sink = gst::ElementFactory::make("filesink")
.property("location", output_path.to_str().unwrap())
.build()
.map_err(|e| Error::internal(format!("Failed to create sink element: {}", e)))?;
sink.set_property("location", output_path.to_str().unwrap());
let decoder = match format {
"h264" => gst::ElementFactory::make("h264parse").build(),
"opus" => gst::ElementFactory::make("opusparse").build(),
_ => return Err(Error::InvalidInput(format!("Unsupported format: {}", format))),
let decoder = match format.to_lowercase().as_str() {
"mp4" => gst::ElementFactory::make("qtdemux"),
"webm" => gst::ElementFactory::make("matroskademux"),
_ => return Err(Error::internal(format!("Unsupported format: {}", format)))
}.map_err(|e| Error::internal(format!("Failed to create decoder: {}", e)))?;
self.pipeline.add_many(&[&src, &decoder, &sink])
self.pipeline.add_many(&[&source, &decoder, &sink])
.map_err(|e| Error::internal(format!("Failed to add elements: {}", e)))?;
gst::Element::link_many(&[&src, &decoder, &sink])
gst::Element::link_many(&[&source, &decoder, &sink])
.map_err(|e| Error::internal(format!("Failed to link elements: {}", e)))?;
self.setup_pipeline()?;
Ok(())
self.process_messages()
}
#[instrument(skip(self, input_path))]
pub async fn extract_metadata(&self, input_path: PathBuf) -> Result<MediaMetadata> {
let src = gst::ElementFactory::make("filesrc")
.property("location", input_path.to_str().unwrap())
.build()
.map_err(|e| Error::internal(format!("Failed to create source element: {}", e)))?;
let decodebin = gst::ElementFactory::make("decodebin").build()
.map_err(|e| Error::internal(format!("Failed to create decodebin: {}", e)))?;
self.pipeline.add_many(&[&src, &decodebin])
.map_err(|e| Error::internal(format!("Failed to add elements: {}", e)))?;
gst::Element::link_many(&[&src, &decodebin])
.map_err(|e| Error::internal(format!("Failed to link elements: {}", e)))?;
let mut metadata = MediaMetadata::default();
decodebin.connect_pad_added(move |_, pad| {
let caps = pad.current_caps().unwrap();
let structure = caps.structure(0).unwrap();
match structure.name() {
"video/x-raw" => {
if let Ok(width) = structure.get::<i32>("width") {
metadata.width = Some(width);
}
if let Ok(height) = structure.get::<i32>("height") {
metadata.height = Some(height);
}
if let Ok(framerate) = structure.get::<gst::Fraction>("framerate") {
metadata.framerate = Some(framerate.numer() as f64 / framerate.denom() as f64);
}
},
"audio/x-raw" => {
if let Ok(channels) = structure.get::<i32>("channels") {
metadata.channels = Some(channels);
}
if let Ok(rate) = structure.get::<i32>("rate") {
metadata.sample_rate = Some(rate);
}
},
_ => (),
}
});
self.setup_pipeline()?;
Ok(metadata)
}
}
}
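Editor's note: the bus-polling loop in process_messages() follows the standard gstreamer-rs idiom; a version-agnostic sketch (builder and constructor APIs differ across gstreamer crate releases, so treat this as illustrative):

use gstreamer as gst;
use gstreamer::prelude::*;

fn wait_for_eos(pipeline: &gst::Pipeline) -> Result<(), String> {
    let bus = pipeline.bus().ok_or("pipeline has no bus")?;
    // Block until end-of-stream or an error message arrives.
    for msg in bus.iter_timed(gst::ClockTime::NONE) {
        match msg.view() {
            gst::MessageView::Eos(_) => break,
            gst::MessageView::Error(err) => {
                return Err(format!("pipeline error: {}", err.error()));
            }
            _ => (),
        }
    }
    Ok(())
}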


@@ -1,163 +1,86 @@
use async_trait::async_trait;
use gb_core::{
models::*,
traits::*,
Result, Error, Connection,
};
use uuid::Uuid;
use gb_core::{Result, Error};
use webrtc::{
api::APIBuilder,
ice_transport::ice_server::RTCIceServer,
peer_connection::configuration::RTCConfiguration,
peer_connection::peer_connection_state::RTCPeerConnectionState,
peer_connection::RTCPeerConnection,
track::track_remote::TrackRemote,
rtp::rtp_receiver::RTCRtpReceiver,
rtp::rtp_transceiver::RTCRtpTransceiver,
api::{API, APIBuilder},
peer_connection::{
RTCPeerConnection,
peer_connection_state::RTCPeerConnectionState,
configuration::RTCConfiguration,
},
track::{
track_local::TrackLocal,
track_remote::TrackRemote,
},
};
use tracing::{instrument, error};
use tokio::sync::mpsc;
use tracing::instrument;
use std::sync::Arc;
use chrono::Utc;
pub struct WebRTCService {
config: RTCConfiguration,
api: Arc<API>,
peer_connections: Vec<Arc<RTCPeerConnection>>,
}
impl WebRTCService {
pub fn new(ice_servers: Vec<String>) -> Self {
let mut config = RTCConfiguration::default();
config.ice_servers = ice_servers
.into_iter()
.map(|url| RTCIceServer {
urls: vec![url],
..Default::default()
})
.collect();
pub fn new() -> Result<Self> {
let api = APIBuilder::new().build();
Self { config }
Ok(Self {
api: Arc::new(api),
peer_connections: Vec::new(),
})
}
async fn create_peer_connection(&self) -> Result<RTCPeerConnection> {
let api = APIBuilder::new().build();
pub async fn create_peer_connection(&mut self) -> Result<Arc<RTCPeerConnection>> {
let config = RTCConfiguration::default();
let peer_connection = api.new_peer_connection(self.config.clone())
let peer_connection = self.api.new_peer_connection(config)
.await
.map_err(|e| Error::internal(format!("Failed to create peer connection: {}", e)))?;
Ok(peer_connection)
let pc_arc = Arc::new(peer_connection);
self.peer_connections.push(pc_arc.clone());
Ok(pc_arc)
}
async fn handle_track(&self, track: Arc<TrackRemote>, receiver: Arc<RTCRtpReceiver>, transceiver: Arc<RTCRtpTransceiver>) {
tracing::info!(
"Received track: {} {}",
track.kind(),
track.id()
);
}
pub async fn add_track(
&self,
pc: &RTCPeerConnection,
track: Arc<dyn TrackLocal + Send + Sync>,
) -> Result<()> {
pc.add_track(track)
.await
.map_err(|e| Error::internal(format!("Failed to add track: {}", e)))?;
async fn create_connection(&self) -> Result<Connection> {
Ok(Connection {
id: Uuid::new_v4(),
connected_at: Utc::now(),
ice_servers: self.config.ice_servers.clone(),
metadata: serde_json::Value::Object(serde_json::Map::new()),
room_id: Uuid::new_v4(),
user_id: Uuid::new_v4(),
})
Ok(())
}
}
#[async_trait]
impl RoomService for WebRTCService {
#[instrument(skip(self))]
async fn create_room(&self, config: RoomConfig) -> Result<Room> {
todo!()
pub async fn on_track<F>(&self, pc: &RTCPeerConnection, mut callback: F)
where
F: FnMut(Arc<TrackRemote>) + Send + 'static,
{
let (tx, mut rx) = mpsc::channel(100);
pc.on_track(Box::new(move |track, _, _| {
let track_clone = track.clone();
let tx = tx.clone();
Box::pin(async move {
let _ = tx.send(track_clone).await;
})
}));
while let Some(track) = rx.recv().await {
callback(track);
}
}
#[instrument(skip(self))]
async fn join_room(&self, room_id: Uuid, user_id: Uuid) -> Result<Connection> {
let peer_connection = self.create_peer_connection().await?;
peer_connection
.on_peer_connection_state_change(Box::new(move |s: RTCPeerConnectionState| {
Box::pin(async move {
match s {
RTCPeerConnectionState::Connected => {
tracing::info!("Peer connection connected");
pub async fn close(&mut self) -> Result<()> {
for pc in self.peer_connections.iter() {
pc.close().await
.map_err(|e| Error::internal(format!("Failed to close peer connection: {}", e)))?;
}
self.peer_connections.clear();
Ok(())
}
RTCPeerConnectionState::Disconnected
| RTCPeerConnectionState::Failed
| RTCPeerConnectionState::Closed => {
tracing::warn!("Peer connection state changed to {}", s);
}
_ => {}
}
})
}));
let mut connection = self.create_connection().await?;
connection.room_id = room_id;
connection.user_id = user_id;
Ok(connection)
}
#[instrument(skip(self))]
async fn leave_room(&self, room_id: Uuid, user_id: Uuid) -> Result<()> {
todo!()
}
#[instrument(skip(self))]
async fn publish_track(&self, track: TrackInfo) -> Result<Track> {
todo!()
}
#[instrument(skip(self))]
async fn subscribe_track(&self, track_id: Uuid) -> Result<Subscription> {
todo!()
}
#[instrument(skip(self))]
async fn get_participants(&self, room_id: Uuid) -> Result<Vec<Participant>> {
todo!()
}
#[instrument(skip(self))]
async fn get_room_stats(&self, room_id: Uuid) -> Result<RoomStats> {
todo!()
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
#[fixture]
fn webrtc_service() -> WebRTCService {
WebRTCService::new(vec!["stun:stun.l.google.com:19302".to_string()])
}
#[rstest]
#[tokio::test]
async fn test_create_peer_connection(webrtc_service: WebRTCService) {
let peer_connection = webrtc_service.create_peer_connection().await.unwrap();
assert_eq!(
peer_connection.connection_state().await,
RTCPeerConnectionState::New
);
}
#[rstest]
#[tokio::test]
async fn test_join_room(webrtc_service: WebRTCService) {
let room_id = Uuid::new_v4();
let user_id = Uuid::new_v4();
let connection = webrtc_service.join_room(room_id, user_id).await.unwrap();
assert_eq!(connection.room_id, room_id);
assert_eq!(connection.user_id, user_id);
assert!(!connection.ice_servers.is_empty());
}
}
}
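Editor's note: a minimal sketch of the webrtc crate calls the refactored service builds on: one shared API object and a default RTCConfiguration (ICE servers would be added to that config):

use std::sync::Arc;
use webrtc::api::APIBuilder;
use webrtc::peer_connection::configuration::RTCConfiguration;

async fn demo() -> Result<(), webrtc::Error> {
    let api = Arc::new(APIBuilder::new().build());
    let pc = api.new_peer_connection(RTCConfiguration::default()).await?;
    // A real caller would register handlers and exchange SDP here.
    pc.close().await?;
    Ok(())
}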


@@ -0,0 +1,22 @@
[package]
name = "gb-migrations"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
[[bin]]
name = "migrations"
path = "src/bin/migrations.rs"
[dependencies]
tokio.workspace = true
sqlx.workspace = true
tracing.workspace = true
uuid.workspace = true
chrono.workspace = true
serde_json.workspace = true
gb-core = { path = "../gb-core" }
[dev-dependencies]
rstest.workspace = true


@@ -0,0 +1,19 @@
use sqlx::PgPool;
use gb_migrations::run_migrations;
#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
let database_url = std::env::var("DATABASE_URL")
.expect("DATABASE_URL must be set");
println!("Creating database connection pool...");
let pool = PgPool::connect(&database_url)
.await
.expect("Failed to create pool");
println!("Running migrations...");
run_migrations(&pool).await?;
println!("Migrations completed successfully!");
Ok(())
}


@@ -0,0 +1,144 @@
use sqlx::PgPool;
use tracing::info;
pub async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
info!("Running database migrations");
// Create tables
let table_queries = [
// Customers table
r#"CREATE TABLE IF NOT EXISTS customers (
id UUID PRIMARY KEY,
name VARCHAR(255) NOT NULL,
subscription_tier VARCHAR(50) NOT NULL,
status VARCHAR(50) NOT NULL,
max_instances INTEGER NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
)"#,
// Instances table
r#"CREATE TABLE IF NOT EXISTS instances (
id UUID PRIMARY KEY,
customer_id UUID NOT NULL REFERENCES customers(id),
name VARCHAR(255) NOT NULL,
status VARCHAR(50) NOT NULL,
shard_id INTEGER NOT NULL,
region VARCHAR(50) NOT NULL,
config JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
)"#,
// Rooms table
r#"CREATE TABLE IF NOT EXISTS rooms (
id UUID PRIMARY KEY,
customer_id UUID NOT NULL REFERENCES customers(id),
instance_id UUID NOT NULL REFERENCES instances(id),
name VARCHAR(255) NOT NULL,
kind VARCHAR(50) NOT NULL,
status VARCHAR(50) NOT NULL,
config JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
)"#,
// Messages table
r#"CREATE TABLE IF NOT EXISTS messages (
id UUID PRIMARY KEY,
customer_id UUID NOT NULL REFERENCES customers(id),
instance_id UUID NOT NULL REFERENCES instances(id),
conversation_id UUID NOT NULL,
sender_id UUID NOT NULL,
kind VARCHAR(50) NOT NULL,
content TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
shard_key INTEGER NOT NULL
)"#,
// Users table
r#"CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY,
customer_id UUID NOT NULL REFERENCES customers(id),
instance_id UUID NOT NULL REFERENCES instances(id),
name VARCHAR(255) NOT NULL,
email VARCHAR(255) NOT NULL UNIQUE,
status VARCHAR(50) NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
)"#,
// Tracks table
r#"CREATE TABLE IF NOT EXISTS tracks (
id UUID PRIMARY KEY,
room_id UUID NOT NULL REFERENCES rooms(id),
user_id UUID NOT NULL REFERENCES users(id),
kind VARCHAR(50) NOT NULL,
status VARCHAR(50) NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
)"#,
// Subscriptions table
r#"CREATE TABLE IF NOT EXISTS subscriptions (
id UUID PRIMARY KEY,
track_id UUID NOT NULL REFERENCES tracks(id),
user_id UUID NOT NULL REFERENCES users(id),
status VARCHAR(50) NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
)"#,
];
// Create indexes
let index_queries = [
"CREATE INDEX IF NOT EXISTS idx_instances_customer_id ON instances(customer_id)",
"CREATE INDEX IF NOT EXISTS idx_rooms_instance_id ON rooms(instance_id)",
"CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id)",
"CREATE INDEX IF NOT EXISTS idx_messages_shard_key ON messages(shard_key)",
"CREATE INDEX IF NOT EXISTS idx_tracks_room_id ON tracks(room_id)",
"CREATE INDEX IF NOT EXISTS idx_subscriptions_track_id ON subscriptions(track_id)",
"CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)",
];
// Execute table creation queries
for query in table_queries {
sqlx::query(query)
.execute(pool)
.await?;
}
// Execute index creation queries
for query in index_queries {
sqlx::query(query)
.execute(pool)
.await?;
}
info!("Migrations completed successfully");
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use sqlx::postgres::{PgPoolOptions, PgPool};
use rstest::*;
async fn create_test_pool() -> PgPool {
let database_url = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string());
PgPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
.expect("Failed to create test pool")
}
#[rstest]
#[tokio::test]
async fn test_migrations() {
let pool = create_test_pool().await;
assert!(run_migrations(&pool).await.is_ok());
}
}


@@ -0,0 +1,10 @@
-- Add password_hash column to users table
ALTER TABLE users
ADD COLUMN IF NOT EXISTS password_hash VARCHAR(255) NOT NULL DEFAULT '';
-- Update column names if needed
ALTER TABLE users RENAME COLUMN password TO password_hash;
-- Add metadata column to instances table
ALTER TABLE instances
ADD COLUMN IF NOT EXISTS metadata JSONB NOT NULL DEFAULT '{}';


@@ -1,66 +1,7 @@
pub mod postgres;
pub mod redis;
pub mod tikv;
mod postgres;
mod redis;
mod tikv;
pub use postgres::{PostgresCustomerRepository, PostgresInstanceRepository};
pub use redis::RedisCache;
pub use tikv::TiKVStorage;
#[cfg(test)]
mod tests {
use super::*;
use gb_core::models::Customer;
use sqlx::postgres::PgPoolOptions;
use std::time::Duration;
async fn setup_test_db() -> sqlx::PgPool {
let database_url = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string());
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
.expect("Failed to connect to database");
// Run migrations
gb_migrations::run_migrations(&pool)
.await
.expect("Failed to run migrations");
pool
}
#[tokio::test]
async fn test_storage_integration() {
// Setup PostgreSQL
let pool = setup_test_db().await;
let customer_repo = PostgresCustomerRepository::new(pool.clone());
// Setup Redis
let redis_url = std::env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://127.0.0.1/".to_string());
let cache = RedisCache::new(&redis_url, Duration::from_secs(60)).unwrap();
// Create a customer
let customer = Customer::new(
"Integration Test Corp".to_string(),
"enterprise".to_string(),
10,
);
// Save to PostgreSQL
let created = customer_repo.create(&customer).await.unwrap();
// Cache in Redis
cache.set(&format!("customer:{}", created.id), &created).await.unwrap();
// Verify Redis cache
let cached: Option<Customer> = cache.get(&format!("customer:{}", created.id)).await.unwrap();
assert_eq!(cached.unwrap().id, created.id);
// Cleanup
customer_repo.delete(created.id).await.unwrap();
cache.delete(&format!("customer:{}", created.id)).await.unwrap();
}
}
pub use postgres::{CustomerRepository, PostgresCustomerRepository};
pub use redis::RedisStorage;
pub use tikv::TiKVStorage;


@@ -1,302 +1,99 @@
use async_trait::async_trait;
use gb_core::{
models::*,
traits::*,
Result, Error,
};
use gb_core::{Result, Error};
use sqlx::PgPool;
use std::sync::Arc;
use uuid::Uuid;
use tracing::{instrument, error};
use gb_core::models::Customer;
pub trait CustomerRepository {
async fn create(&self, customer: Customer) -> Result<Customer>;
async fn get(&self, id: Uuid) -> Result<Option<Customer>>;
async fn update(&self, customer: Customer) -> Result<Customer>;
async fn delete(&self, id: Uuid) -> Result<()>;
}
pub struct PostgresCustomerRepository {
pool: PgPool,
pool: Arc<PgPool>,
}
impl PostgresCustomerRepository {
pub fn new(pool: PgPool) -> Self {
pub fn new(pool: Arc<PgPool>) -> Self {
Self { pool }
}
}
#[async_trait]
impl CustomerRepository for PostgresCustomerRepository {
#[instrument(skip(self))]
async fn create(&self, customer: &Customer) -> Result<Customer> {
let record = sqlx::query_as!(
async fn create(&self, customer: Customer) -> Result<Customer> {
let result = sqlx::query_as!(
Customer,
r#"
INSERT INTO customers (id, name, subscription_tier, status, max_instances, metadata, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO customers (id, name, max_instances, email, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
RETURNING *
"#,
customer.id,
customer.name,
customer.subscription_tier,
customer.status,
customer.max_instances,
customer.metadata as _,
customer.created_at,
customer.max_instances as i32,
customer.email,
)
.fetch_one(&self.pool)
.fetch_one(&*self.pool)
.await
.map_err(|e| {
error!("Failed to create customer: {}", e);
Error::Database(e)
})?;
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
Ok(record)
Ok(result)
}
#[instrument(skip(self))]
async fn get(&self, id: Uuid) -> Result<Customer> {
let record = sqlx::query_as!(
async fn get(&self, id: Uuid) -> Result<Option<Customer>> {
let result = sqlx::query_as!(
Customer,
r#"
SELECT * FROM customers WHERE id = $1
SELECT id, name, max_instances::int as "max_instances!: i32",
email, created_at, updated_at
FROM customers
WHERE id = $1
"#,
id
)
.fetch_one(&self.pool)
.fetch_optional(&*self.pool)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", id)),
e => Error::Database(e),
})?;
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
Ok(record)
Ok(result)
}
#[instrument(skip(self))]
async fn update(&self, customer: &Customer) -> Result<Customer> {
let record = sqlx::query_as!(
async fn update(&self, customer: Customer) -> Result<Customer> {
let result = sqlx::query_as!(
Customer,
r#"
UPDATE customers
SET name = $1, subscription_tier = $2, status = $3, max_instances = $4, metadata = $5
WHERE id = $6
RETURNING *
SET name = $2, max_instances = $3, email = $4, updated_at = NOW()
WHERE id = $1
RETURNING id, name, max_instances::int as "max_instances!: i32",
email, created_at, updated_at
"#,
customer.id,
customer.name,
customer.subscription_tier,
customer.status,
customer.max_instances,
customer.metadata as _,
customer.id
customer.max_instances as i32,
customer.email,
)
.fetch_one(&self.pool)
.fetch_one(&*self.pool)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", customer.id)),
e => Error::Database(e),
})?;
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
Ok(record)
Ok(result)
}
#[instrument(skip(self))]
async fn delete(&self, id: Uuid) -> Result<()> {
sqlx::query!(
r#"
DELETE FROM customers WHERE id = $1
DELETE FROM customers
WHERE id = $1
"#,
id
)
.execute(&self.pool)
.execute(&*self.pool)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => Error::NotFound(format!("Customer {} not found", id)),
e => Error::Database(e),
})?;
.map_err(|e| Error::internal(format!("Database error: {}", e)))?;
Ok(())
}
}
pub struct PostgresInstanceRepository {
pool: PgPool,
}
impl PostgresInstanceRepository {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
}
#[async_trait]
impl InstanceRepository for PostgresInstanceRepository {
#[instrument(skip(self))]
async fn create(&self, instance: &Instance) -> Result<Instance> {
let record = sqlx::query_as!(
Instance,
r#"
INSERT INTO instances (id, customer_id, name, status, shard_id, region, config, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING *
"#,
instance.id,
instance.customer_id,
instance.name,
instance.status,
instance.shard_id,
instance.region,
instance.config as _,
instance.created_at,
)
.fetch_one(&self.pool)
.await
.map_err(|e| {
error!("Failed to create instance: {}", e);
Error::Database(e)
})?;
Ok(record)
}
#[instrument(skip(self))]
async fn get(&self, id: Uuid) -> Result<Instance> {
let record = sqlx::query_as!(
Instance,
r#"
SELECT * FROM instances WHERE id = $1
"#,
id
)
.fetch_one(&self.pool)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", id)),
e => Error::Database(e),
})?;
Ok(record)
}
#[instrument(skip(self))]
async fn get_by_customer(&self, customer_id: Uuid) -> Result<Vec<Instance>> {
let records = sqlx::query_as!(
Instance,
r#"
SELECT * FROM instances WHERE customer_id = $1
"#,
customer_id
)
.fetch_all(&self.pool)
.await
.map_err(Error::Database)?;
Ok(records)
}
#[instrument(skip(self))]
async fn get_by_shard(&self, shard_id: i32) -> Result<Vec<Instance>> {
let records = sqlx::query_as!(
Instance,
r#"
SELECT * FROM instances WHERE shard_id = $1
"#,
shard_id
)
.fetch_all(&self.pool)
.await
.map_err(Error::Database)?;
Ok(records)
}
#[instrument(skip(self))]
async fn update(&self, instance: &Instance) -> Result<Instance> {
let record = sqlx::query_as!(
Instance,
r#"
UPDATE instances
SET name = $1, status = $2, shard_id = $3, region = $4, config = $5
WHERE id = $6
RETURNING *
"#,
instance.name,
instance.status,
instance.shard_id,
instance.region,
instance.config as _,
instance.id
)
.fetch_one(&self.pool)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", instance.id)),
e => Error::Database(e),
})?;
Ok(record)
}
#[instrument(skip(self))]
async fn delete(&self, id: Uuid) -> Result<()> {
sqlx::query!(
r#"
DELETE FROM instances WHERE id = $1
"#,
id
)
.execute(&self.pool)
.await
.map_err(|e| match e {
sqlx::Error::RowNotFound => Error::NotFound(format!("Instance {} not found", id)),
e => Error::Database(e),
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
use sqlx::postgres::PgPoolOptions;
async fn create_test_pool() -> PgPool {
let database_url = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost/gb_test".to_string());
PgPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
.expect("Failed to create test pool")
}
#[fixture]
fn customer() -> Customer {
Customer::new(
"Test Corp".to_string(),
"enterprise".to_string(),
10,
)
}
#[rstest]
#[tokio::test]
async fn test_customer_crud(customer: Customer) {
let pool = create_test_pool().await;
let repo = PostgresCustomerRepository::new(pool);
// Create
let created = repo.create(&customer).await.unwrap();
assert_eq!(created.name, customer.name);
// Get
let retrieved = repo.get(created.id).await.unwrap();
assert_eq!(retrieved.id, created.id);
// Update
let mut updated = retrieved.clone();
updated.name = "Updated Corp".to_string();
let updated = repo.update(&updated).await.unwrap();
assert_eq!(updated.name, "Updated Corp");
// Delete
repo.delete(updated.id).await.unwrap();
assert!(repo.get(updated.id).await.is_err());
}
}
}
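Editor's note: one caveat on the new CustomerRepository trait: bare async fn in a trait needs Rust 1.75+, so on older toolchains the trait definition wants the same #[async_trait] attribute its impl already uses. A self-contained sketch of that pairing, with stand-in types:

use async_trait::async_trait;
use uuid::Uuid;

#[derive(Debug, Clone)]
pub struct Customer {
    pub id: Uuid,
    pub name: String,
}

#[async_trait]
pub trait CustomerRepository {
    async fn get(&self, id: Uuid) -> Result<Option<Customer>, String>;
}

// An in-memory stand-in showing the impl side of the attribute pairing.
pub struct InMemoryRepo(pub Vec<Customer>);

#[async_trait]
impl CustomerRepository for InMemoryRepo {
    async fn get(&self, id: Uuid) -> Result<Option<Customer>, String> {
        Ok(self.0.iter().find(|c| c.id == id).cloned())
    }
}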


@@ -1,52 +1,42 @@
use async_trait::async_trait;
use gb_core::{Result, Error};
use redis::{AsyncCommands, Client};
use redis::{Client, Commands};
use serde::{de::DeserializeOwned, Serialize};
use std::time::Duration;
use tracing::{instrument, error};
use tracing::instrument;
pub struct RedisCache {
pub struct RedisStorage {
client: Client,
default_ttl: Duration,
}
impl RedisCache {
pub fn new(url: &str, default_ttl: Duration) -> Result<Self> {
let client = Client::open(url).map_err(|e| Error::Redis(e))?;
Ok(Self {
client,
default_ttl,
})
impl RedisStorage {
pub fn new(url: &str) -> Result<Self> {
let client = Client::open(url)
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
Ok(Self { client })
}
#[instrument(skip(self, value))]
pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<()> {
let mut conn = self.client.get_async_connection()
.await
.map_err(Error::Redis)?;
#[instrument(skip(self))]
pub async fn set<T: Serialize + std::fmt::Debug>(&self, key: &str, value: &T) -> Result<()> {
let mut conn = self.client.get_connection()
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
let serialized = serde_json::to_string(value)
.map_err(|e| Error::internal(format!("Serialization error: {}", e)))?;
conn.set_ex(key, serialized, self.default_ttl.as_secs() as usize)
.await
.map_err(|e| {
error!("Redis set error: {}", e);
Error::Redis(e)
})?;
conn.set(key, serialized)
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
Ok(())
}
#[instrument(skip(self))]
pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
let mut conn = self.client.get_async_connection()
.await
.map_err(Error::Redis)?;
let mut conn = self.client.get_connection()
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
let value: Option<String> = conn.get(key)
.await
.map_err(Error::Redis)?;
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
match value {
Some(v) => {
@@ -54,101 +44,33 @@ impl RedisCache {
.map_err(|e| Error::internal(format!("Deserialization error: {}", e)))?;
Ok(Some(deserialized))
}
None => Ok(None),
None => Ok(None)
}
}
#[instrument(skip(self))]
pub async fn delete(&self, key: &str) -> Result<()> {
let mut conn = self.client.get_async_connection()
.await
.map_err(Error::Redis)?;
pub async fn delete(&self, key: &str) -> Result<bool> {
let mut conn = self.client.get_connection()
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
conn.del(key)
.await
.map_err(|e| {
error!("Redis delete error: {}", e);
Error::Redis(e)
})?;
Ok(())
.map_err(|e| Error::internal(format!("Redis error: {}", e)))
}
#[instrument(skip(self))]
pub async fn increment(&self, key: &str) -> Result<i64> {
let mut conn = self.client.get_async_connection()
.await
.map_err(Error::Redis)?;
conn.incr(key, 1)
.await
.map_err(|e| {
error!("Redis increment error: {}", e);
Error::Redis(e)
})
}
#[instrument(skip(self))]
pub async fn set_with_ttl<T: Serialize>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
let mut conn = self.client.get_async_connection()
.await
.map_err(Error::Redis)?;
pub async fn set_with_ttl<T: Serialize + std::fmt::Debug>(&self, key: &str, value: &T, ttl: Duration) -> Result<()> {
let mut conn = self.client.get_connection()
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
let serialized = serde_json::to_string(value)
.map_err(|e| Error::internal(format!("Serialization error: {}", e)))?;
conn.set_ex(key, serialized, ttl.as_secs() as usize)
.await
.map_err(|e| {
error!("Redis set error: {}", e);
Error::Redis(e)
})?;
redis::pipe()
.set(key, serialized)
.expire(key, ttl.as_secs() as i64)
.query(&mut conn)
.map_err(|e| Error::internal(format!("Redis error: {}", e)))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct TestStruct {
field: String,
}
#[tokio::test]
async fn test_redis_cache() {
let redis_url = std::env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://127.0.0.1/".to_string());
let cache = RedisCache::new(&redis_url, Duration::from_secs(60)).unwrap();
// Test set and get
let test_value = TestStruct {
field: "test".to_string(),
};
cache.set("test_key", &test_value).await.unwrap();
let retrieved: Option<TestStruct> = cache.get("test_key").await.unwrap();
assert_eq!(retrieved.unwrap(), test_value);
// Test delete
cache.delete("test_key").await.unwrap();
let deleted: Option<TestStruct> = cache.get("test_key").await.unwrap();
assert!(deleted.is_none());
// Test increment
cache.set("counter", &0).await.unwrap();
let count = cache.increment("counter").await.unwrap();
assert_eq!(count, 1);
// Test TTL
cache.set_with_ttl("ttl_key", &test_value, Duration::from_secs(1)).await.unwrap();
tokio::time::sleep(Duration::from_secs(2)).await;
let expired: Option<TestStruct> = cache.get("ttl_key").await.unwrap();
assert!(expired.is_none());
}
}
}
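Editor's note: a sketch of the blocking redis API the storage now uses (Commands on get_connection()) plus the pipelined set+expire pattern from set_with_ttl; method signatures shift slightly between redis crate versions:

use redis::{Client, Commands};

fn demo() -> redis::RedisResult<()> {
    let client = Client::open("redis://127.0.0.1/")?; // placeholder URL
    let mut conn = client.get_connection()?;

    let _: () = conn.set("key", "value")?;
    let back: Option<String> = conn.get("key")?;
    assert_eq!(back.as_deref(), Some("value"));

    // set + expire in one round trip, as set_with_ttl does above
    redis::pipe()
        .set("ttl_key", "value")
        .expire("ttl_key", 60)
        .query::<()>(&mut conn)?;
    Ok(())
}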


@@ -1,7 +1,6 @@
use async_trait::async_trait;
use gb_core::{Result, Error};
use tikv_client::{Config, RawClient, Value};
use tracing::{instrument, error};
use tikv_client::{RawClient, Config, KvPair};
use tracing::{error, instrument};
pub struct TiKVStorage {
client: RawClient,
@@ -10,26 +9,15 @@ pub struct TiKVStorage {
impl TiKVStorage {
pub async fn new(pd_endpoints: Vec<String>) -> Result<Self> {
let config = Config::default();
let client = RawClient::new(pd_endpoints, config)
let client = RawClient::new(pd_endpoints)
.await
.map_err(|e| Error::internal(format!("TiKV client error: {}", e)))?;
.map_err(|e| Error::internal(format!("TiKV error: {}", e)))?;
Ok(Self { client })
}
#[instrument(skip(self, value))]
pub async fn put(&self, key: &[u8], value: Value) -> Result<()> {
self.client
.put(key.to_vec(), value)
.await
.map_err(|e| {
error!("TiKV put error: {}", e);
Error::internal(format!("TiKV error: {}", e))
})
}
#[instrument(skip(self))]
pub async fn get(&self, key: &[u8]) -> Result<Option<Value>> {
pub async fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
self.client
.get(key.to_vec())
.await
@@ -39,6 +27,17 @@ impl TiKVStorage {
})
}
#[instrument(skip(self))]
pub async fn put(&self, key: &[u8], value: &[u8]) -> Result<()> {
self.client
.put(key.to_vec(), value.to_vec())
.await
.map_err(|e| {
error!("TiKV put error: {}", e);
Error::internal(format!("TiKV error: {}", e))
})
}
#[instrument(skip(self))]
pub async fn delete(&self, key: &[u8]) -> Result<()> {
self.client
@@ -51,7 +50,7 @@ impl TiKVStorage {
}
#[instrument(skip(self))]
pub async fn batch_get(&self, keys: Vec<Vec<u8>>) -> Result<Vec<KVPair>> {
pub async fn batch_get(&self, keys: Vec<Vec<u8>>) -> Result<Vec<KvPair>> {
self.client
.batch_get(keys)
.await
@@ -62,7 +61,7 @@ impl TiKVStorage {
}
#[instrument(skip(self))]
pub async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KVPair>> {
pub async fn scan(&self, start: &[u8], end: &[u8], limit: u32) -> Result<Vec<KvPair>> {
self.client
.scan(start.to_vec()..end.to_vec(), limit)
.await
@@ -71,58 +70,4 @@ impl TiKVStorage {
Error::internal(format!("TiKV error: {}", e))
})
}
}
#[derive(Debug, Clone)]
pub struct KVPair {
pub key: Vec<u8>,
pub value: Value,
}
#[cfg(test)]
mod tests {
use super::*;
use tikv_client::Value;
#[tokio::test]
async fn test_tikv_storage() {
let pd_endpoints = vec!["127.0.0.1:2379".to_string()];
let storage = TiKVStorage::new(pd_endpoints).await.unwrap();
// Test put and get
let key = b"test_key";
let value = Value::from(b"test_value".to_vec());
storage.put(key, value.clone()).await.unwrap();
let retrieved = storage.get(key).await.unwrap();
assert_eq!(retrieved.unwrap(), value);
// Test delete
storage.delete(key).await.unwrap();
let deleted = storage.get(key).await.unwrap();
assert!(deleted.is_none());
// Test batch operations
let pairs = vec![
(b"key1".to_vec(), Value::from(b"value1".to_vec())),
(b"key2".to_vec(), Value::from(b"value2".to_vec())),
];
for (key, value) in pairs.clone() {
storage.put(&key, value).await.unwrap();
}
let keys: Vec<Vec<u8>> = pairs.iter().map(|(k, _)| k.clone()).collect();
let retrieved = storage.batch_get(keys).await.unwrap();
assert_eq!(retrieved.len(), pairs.len());
// Test scan
let scanned = storage.scan(b"key", b"key3", 10).await.unwrap();
assert_eq!(scanned.len(), 2);
// Cleanup
for (key, _) in pairs {
storage.delete(&key).await.unwrap();
}
}
}
}
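Editor's note: finally, a sketch of the raw tikv-client calls behind TiKVStorage (put, get, delete) against a placeholder PD endpoint; the constructor signature varies by crate version:

use tikv_client::RawClient;

async fn demo() -> Result<(), tikv_client::Error> {
    let client = RawClient::new(vec!["127.0.0.1:2379"]).await?;
    client.put("key".to_owned(), "value".to_owned()).await?;
    let value = client.get("key".to_owned()).await?; // Option<Value>
    assert!(value.is_some());
    client.delete("key".to_owned()).await?;
    Ok(())
}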