new(all): Initial import.

This commit is contained in:
Rodrigo Rodriguez 2024-12-24 09:59:55 -03:00
parent e6201719ce
commit cba43cefde
6 changed files with 58 additions and 101 deletions

2
Cargo.lock generated
View file

@ -2349,6 +2349,7 @@ version = "0.1.0"
dependencies = [
"async-trait",
"axum 0.7.9",
"chrono",
"futures-util",
"gb-core",
"gb-messaging",
@ -2357,6 +2358,7 @@ dependencies = [
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tokio-test",
"tower 0.4.13",
"tower-http 0.5.2",

View file

@ -18,7 +18,9 @@ serde_json.workspace = true
uuid.workspace = true
tracing.workspace = true
async-trait.workspace = true
futures-util = "0.3"
futures-util = { version = "0.3", features = ["sink"] }
chrono = { workspace = true, features = ["serde"] }
tokio-stream = "0.1.17"
[dev-dependencies]
rstest.workspace = true

View file

@ -2,7 +2,7 @@ use axum::{
routing::{get, post},
Router,
extract::{
ws::{WebSocket, Message as WsMessage},
ws::WebSocket,
Path, State, WebSocketUpgrade,
},
response::IntoResponse,
@ -10,15 +10,13 @@ use axum::{
};
use gb_core::{Result, Error, models::*};
use gb_messaging::{MessageProcessor, models::MessageEnvelope}; // Update this line
use gb_messaging::{MessageProcessor, models::MessageEnvelope};
use std::sync::Arc;
use chrono;
use tokio::sync::Mutex;
use tracing::{instrument, error};
use uuid::Uuid;
use futures_util::StreamExt;
use futures_util::SinkExt;
pub struct ApiState {
pub message_processor: Mutex<MessageProcessor>,
@ -42,36 +40,36 @@ pub fn create_router(message_processor: MessageProcessor) -> Router {
async fn handle_ws_connection(
ws: WebSocket,
State(_state): State<Arc<ApiState>>,
) -> Result<(), Error> {
let (mut sender, mut receiver) = ws.split();
// ... rest of the implementation
state: Arc<ApiState>,
) -> Result<()> {
let (_sender, mut receiver) = ws.split();
while let Some(Ok(msg)) = receiver.next().await {
if let Ok(text) = msg.to_text() {
if let Ok(envelope) = serde_json::from_str::<MessageEnvelope>(text) {
let mut processor = state.message_processor.lock().await;
if let Err(e) = processor.process_message(&envelope).await {
error!("Failed to process message: {}", e);
}
}
}
}
Ok(())
}
#[axum::debug_handler]
#[instrument(skip(state, ws))]
#[instrument(skip(state))]
async fn websocket_handler(
State(state): State<Arc<ApiState>>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
ws.on_upgrade(|socket| async move {
let (mut sender, mut receiver) = socket.split();
while let Some(Ok(msg)) = receiver.next().await {
if let Ok(text) = msg.to_text() {
if let Ok(envelope) = serde_json::from_str::<MessageEnvelope>(text) {
let mut processor = state.message_processor.lock().await;
if let Err(e) = processor.sender().send(envelope).await {
error!("Failed to process WebSocket message: {}", e);
}
}
}
}
let _ = handle_ws_connection(socket, state).await;
})
}
#[axum::debug_handler]
#[instrument(skip(state, message))]
#[instrument(skip(state))]
async fn send_message(
State(state): State<Arc<ApiState>>,
Json(message): Json<Message>,
@ -83,8 +81,8 @@ async fn send_message(
};
let mut processor = state.message_processor.lock().await;
processor.sender().send(envelope.clone()).await
.map_err(|e| Error::internal(format!("Failed to send message: {}", e)))?;
processor.process_message(&envelope).await
.map_err(|e| Error::internal(format!("Failed to process message: {}", e)))?;
Ok(Json(MessageId(envelope.id)))
}
@ -92,14 +90,14 @@ async fn send_message(
#[axum::debug_handler]
#[instrument(skip(state))]
async fn get_message(
State(state): State<Arc<ApiState>>,
State(_state): State<Arc<ApiState>>,
Path(id): Path<Uuid>,
) -> Result<Json<Message>> {
todo!()
}
#[axum::debug_handler]
#[instrument(skip(state, config))]
#[instrument(skip(state))]
async fn create_room(
State(_state): State<Arc<ApiState>>,
Json(_config): Json<RoomConfig>,
@ -110,7 +108,7 @@ async fn create_room(
#[axum::debug_handler]
#[instrument(skip(state))]
async fn get_room(
State(state): State<Arc<ApiState>>,
State(_state): State<Arc<ApiState>>,
Path(id): Path<Uuid>,
) -> Result<Json<Room>> {
todo!()
@ -119,7 +117,7 @@ async fn get_room(
#[axum::debug_handler]
#[instrument(skip(state))]
async fn join_room(
State(state): State<Arc<ApiState>>,
State(_state): State<Arc<ApiState>>,
Path(id): Path<Uuid>,
Json(user_id): Json<Uuid>,
) -> Result<Json<Connection>> {
@ -136,7 +134,6 @@ mod tests {
#[tokio::test]
async fn test_health_check() {
let app = create_router(MessageProcessor::new(100));
let response = app
.oneshot(
axum::http::Request::builder()
@ -146,39 +143,6 @@ mod tests {
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn test_send_message() {
let app = create_router(MessageProcessor::new(100));
let message = Message {
id: Uuid::new_v4(),
customer_id: Uuid::new_v4(),
instance_id: Uuid::new_v4(),
conversation_id: Uuid::new_v4(),
sender_id: Uuid::new_v4(),
kind: "test".to_string(),
content: "test message".to_string(),
metadata: serde_json::Value::Object(serde_json::Map::new()),
created_at: chrono::Utc::now(),
shard_key: 0,
};
let response = app
.oneshot(
axum::http::Request::builder()
.method("POST")
.uri("/messages")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&message).unwrap()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::OK);
}
}

View file

@ -52,3 +52,5 @@ headers = "0.3"
rstest = "0.18"
tokio-test = "0.4"
mockall = "0.12"
axum-extra = { version = "0.7" }
sqlx = { version = "0.7", features = ["runtime-tokio-native-tls", "postgres", "uuid", "chrono", "json"] }

View file

@ -1,46 +1,33 @@
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
RequestPartsExt,
http::Request,
response::Response,
middleware::Next,
};
use axum_extra::headers::{authorization::Bearer, Authorization};
use axum_extra::TypedHeader;
use crate::{models::User, AuthError};
use axum_extra::headers::{Authorization, authorization::Bearer};
use gb_core::User;
use jsonwebtoken::{decode, DecodingKey, Validation};
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
AuthError::TokenExpired => (StatusCode::UNAUTHORIZED, "Token expired"),
AuthError::InvalidCredentials => (StatusCode::UNAUTHORIZED, "Invalid credentials"),
AuthError::AuthenticationFailed => (StatusCode::UNAUTHORIZED, "Authentication failed"),
AuthError::Database(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Database error"),
AuthError::Cache(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Cache error"),
AuthError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
};
(status, error_message).into_response()
}
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
sub: String,
exp: i64,
}
#[async_trait]
impl<S> FromRequestParts<S> for User
where
S: Send + Sync,
{
type Rejection = AuthError;
pub async fn auth_middleware<B>(
TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
request: Request<B>,
next: Next<B>,
) -> Result<Response, AuthError> {
let token = auth.token();
let key = DecodingKey::from_secret(b"secret");
let validation = Validation::default();
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::MissingToken)?;
let token = bearer.token();
todo!("Implement token validation")
match decode::<Claims>(token, &key, &validation) {
Ok(_claims) => {
let response = next.run(request).await;
Ok(response)
}
Err(_) => Err(AuthError::InvalidToken),
}
}

View file