//! OAuth2 Authentication Module //! //! Provides OAuth2 authentication support for multiple providers: //! - Google //! - Discord //! - Reddit //! - Twitter (X) //! - Microsoft //! - Facebook pub mod providers; pub mod routes; use serde::{Deserialize, Serialize}; use std::fmt; /// Supported OAuth2 providers #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum OAuthProvider { Google, Discord, Reddit, Twitter, Microsoft, Facebook, } impl OAuthProvider { /// Get all available providers pub fn all() -> Vec { vec![ OAuthProvider::Google, OAuthProvider::Discord, OAuthProvider::Reddit, OAuthProvider::Twitter, OAuthProvider::Microsoft, OAuthProvider::Facebook, ] } /// Get provider from string pub fn from_str(s: &str) -> Option { match s.to_lowercase().as_str() { "google" => Some(OAuthProvider::Google), "discord" => Some(OAuthProvider::Discord), "reddit" => Some(OAuthProvider::Reddit), "twitter" | "x" => Some(OAuthProvider::Twitter), "microsoft" => Some(OAuthProvider::Microsoft), "facebook" => Some(OAuthProvider::Facebook), _ => None, } } /// Get the config key prefix for this provider pub fn config_prefix(&self) -> &'static str { match self { OAuthProvider::Google => "oauth-google", OAuthProvider::Discord => "oauth-discord", OAuthProvider::Reddit => "oauth-reddit", OAuthProvider::Twitter => "oauth-twitter", OAuthProvider::Microsoft => "oauth-microsoft", OAuthProvider::Facebook => "oauth-facebook", } } /// Get display name for UI pub fn display_name(&self) -> &'static str { match self { OAuthProvider::Google => "Google", OAuthProvider::Discord => "Discord", OAuthProvider::Reddit => "Reddit", OAuthProvider::Twitter => "Twitter", OAuthProvider::Microsoft => "Microsoft", OAuthProvider::Facebook => "Facebook", } } /// Get icon/emoji for UI pub fn icon(&self) -> &'static str { match self { OAuthProvider::Google => "", OAuthProvider::Discord => "", OAuthProvider::Reddit => "", OAuthProvider::Twitter => "", OAuthProvider::Microsoft => "", OAuthProvider::Facebook => "", } } } impl fmt::Display for OAuthProvider { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.display_name()) } } /// OAuth configuration for a provider #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthConfig { pub provider: OAuthProvider, pub client_id: String, pub client_secret: String, pub redirect_uri: String, pub enabled: bool, } impl OAuthConfig { /// Create a new OAuth config pub fn new( provider: OAuthProvider, client_id: String, client_secret: String, redirect_uri: String, ) -> Self { Self { provider, client_id, client_secret, redirect_uri, enabled: true, } } /// Check if the config is valid (has required fields) pub fn is_valid(&self) -> bool { self.enabled && !self.client_id.is_empty() && !self.client_secret.is_empty() && !self.redirect_uri.is_empty() } } /// User information returned from OAuth provider #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthUserInfo { /// Provider-specific user ID pub provider_id: String, /// OAuth provider pub provider: OAuthProvider, /// User's email (if available) pub email: Option, /// User's display name pub name: Option, /// User's avatar URL pub avatar_url: Option, /// Raw response from provider (for debugging/additional fields) #[serde(skip_serializing_if = "Option::is_none")] pub raw: Option, } /// OAuth token response #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthTokenResponse { pub access_token: String, #[serde(default)] pub token_type: String, #[serde(default)] pub expires_in: Option, #[serde(default)] pub refresh_token: Option, #[serde(default)] pub scope: Option, } /// OAuth error types #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthError { pub error: String, pub error_description: Option, } impl fmt::Display for OAuthError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(desc) = &self.error_description { write!(f, "{}: {}", self.error, desc) } else { write!(f, "{}", self.error) } } } impl std::error::Error for OAuthError {} /// State parameter for OAuth flow (CSRF protection) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthState { /// Random state token pub token: String, /// Provider being used pub provider: OAuthProvider, /// Optional redirect URL after login pub redirect_after: Option, /// Timestamp when state was created pub created_at: i64, } impl OAuthState { /// Create a new OAuth state pub fn new(provider: OAuthProvider, redirect_after: Option) -> Self { use std::time::{SystemTime, UNIX_EPOCH}; let token = uuid::Uuid::new_v4().to_string(); let created_at = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() as i64; Self { token, provider, redirect_after, created_at, } } /// Check if state is expired (default: 10 minutes) pub fn is_expired(&self) -> bool { use std::time::{SystemTime, UNIX_EPOCH}; let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() as i64; now - self.created_at > 600 // 10 minutes } /// Encode state to URL-safe string pub fn encode(&self) -> String { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; let json = serde_json::to_string(self).unwrap_or_default(); URL_SAFE_NO_PAD.encode(json.as_bytes()) } /// Decode state from URL-safe string pub fn decode(encoded: &str) -> Option { use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; let bytes = URL_SAFE_NO_PAD.decode(encoded).ok()?; let json = String::from_utf8(bytes).ok()?; serde_json::from_str(&json).ok() } }