botserver/src/core/oauth/mod.rs

293 lines
8.3 KiB
Rust
Raw Normal View History

//! OAuth2 Authentication Module
//!
//! Provides OAuth2 authentication support for multiple providers:
//! - Google
//! - Discord
//! - Reddit
//! - Twitter (X)
//! - Microsoft
//! - Facebook
pub mod providers;
pub mod routes;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Supported OAuth2 providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OAuthProvider {
Google,
Discord,
Reddit,
Twitter,
Microsoft,
Facebook,
}
impl OAuthProvider {
/// Get all available providers
pub fn all() -> Vec<OAuthProvider> {
vec![
OAuthProvider::Google,
OAuthProvider::Discord,
OAuthProvider::Reddit,
OAuthProvider::Twitter,
OAuthProvider::Microsoft,
OAuthProvider::Facebook,
]
}
/// Get provider from string
pub fn from_str(s: &str) -> Option<OAuthProvider> {
match s.to_lowercase().as_str() {
"google" => Some(OAuthProvider::Google),
"discord" => Some(OAuthProvider::Discord),
"reddit" => Some(OAuthProvider::Reddit),
"twitter" | "x" => Some(OAuthProvider::Twitter),
"microsoft" => Some(OAuthProvider::Microsoft),
"facebook" => Some(OAuthProvider::Facebook),
_ => None,
}
}
/// Get the config key prefix for this provider
pub fn config_prefix(&self) -> &'static str {
match self {
OAuthProvider::Google => "oauth-google",
OAuthProvider::Discord => "oauth-discord",
OAuthProvider::Reddit => "oauth-reddit",
OAuthProvider::Twitter => "oauth-twitter",
OAuthProvider::Microsoft => "oauth-microsoft",
OAuthProvider::Facebook => "oauth-facebook",
}
}
/// Get display name for UI
pub fn display_name(&self) -> &'static str {
match self {
OAuthProvider::Google => "Google",
OAuthProvider::Discord => "Discord",
OAuthProvider::Reddit => "Reddit",
OAuthProvider::Twitter => "Twitter",
OAuthProvider::Microsoft => "Microsoft",
OAuthProvider::Facebook => "Facebook",
}
}
/// Get icon/emoji for UI
pub fn icon(&self) -> &'static str {
match self {
OAuthProvider::Google => "🔵",
OAuthProvider::Discord => "🎮",
OAuthProvider::Reddit => "🟠",
OAuthProvider::Twitter => "🐦",
OAuthProvider::Microsoft => "🪟",
OAuthProvider::Facebook => "📘",
}
}
}
impl fmt::Display for OAuthProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.display_name())
}
}
/// OAuth configuration for a provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
pub provider: OAuthProvider,
pub client_id: String,
pub client_secret: String,
pub redirect_uri: String,
pub enabled: bool,
}
impl OAuthConfig {
/// Create a new OAuth config
pub fn new(
provider: OAuthProvider,
client_id: String,
client_secret: String,
redirect_uri: String,
) -> Self {
Self {
provider,
client_id,
client_secret,
redirect_uri,
enabled: true,
}
}
/// Check if the config is valid (has required fields)
pub fn is_valid(&self) -> bool {
self.enabled
&& !self.client_id.is_empty()
&& !self.client_secret.is_empty()
&& !self.redirect_uri.is_empty()
}
}
/// User information returned from OAuth provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthUserInfo {
/// Provider-specific user ID
pub provider_id: String,
/// OAuth provider
pub provider: OAuthProvider,
/// User's email (if available)
pub email: Option<String>,
/// User's display name
pub name: Option<String>,
/// User's avatar URL
pub avatar_url: Option<String>,
/// Raw response from provider (for debugging/additional fields)
#[serde(skip_serializing_if = "Option::is_none")]
pub raw: Option<serde_json::Value>,
}
/// OAuth token response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthTokenResponse {
pub access_token: String,
#[serde(default)]
pub token_type: String,
#[serde(default)]
pub expires_in: Option<i64>,
#[serde(default)]
pub refresh_token: Option<String>,
#[serde(default)]
pub scope: Option<String>,
}
/// OAuth error types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthError {
pub error: String,
pub error_description: Option<String>,
}
impl fmt::Display for OAuthError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(desc) = &self.error_description {
write!(f, "{}: {}", self.error, desc)
} else {
write!(f, "{}", self.error)
}
}
}
impl std::error::Error for OAuthError {}
/// State parameter for OAuth flow (CSRF protection)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthState {
/// Random state token
pub token: String,
/// Provider being used
pub provider: OAuthProvider,
/// Optional redirect URL after login
pub redirect_after: Option<String>,
/// Timestamp when state was created
pub created_at: i64,
}
impl OAuthState {
/// Create a new OAuth state
pub fn new(provider: OAuthProvider, redirect_after: Option<String>) -> Self {
use std::time::{SystemTime, UNIX_EPOCH};
let token = uuid::Uuid::new_v4().to_string();
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
Self {
token,
provider,
redirect_after,
created_at,
}
}
/// Check if state is expired (default: 10 minutes)
pub fn is_expired(&self) -> bool {
use std::time::{SystemTime, UNIX_EPOCH};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
now - self.created_at > 600 // 10 minutes
}
/// Encode state to URL-safe string
pub fn encode(&self) -> String {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
let json = serde_json::to_string(self).unwrap_or_default();
URL_SAFE_NO_PAD.encode(json.as_bytes())
}
/// Decode state from URL-safe string
pub fn decode(encoded: &str) -> Option<Self> {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
let bytes = URL_SAFE_NO_PAD.decode(encoded).ok()?;
let json = String::from_utf8(bytes).ok()?;
serde_json::from_str(&json).ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_from_str() {
assert_eq!(
OAuthProvider::from_str("google"),
Some(OAuthProvider::Google)
);
assert_eq!(
OAuthProvider::from_str("DISCORD"),
Some(OAuthProvider::Discord)
);
assert_eq!(
OAuthProvider::from_str("Twitter"),
Some(OAuthProvider::Twitter)
);
assert_eq!(OAuthProvider::from_str("x"), Some(OAuthProvider::Twitter));
assert_eq!(OAuthProvider::from_str("invalid"), None);
}
#[test]
fn test_oauth_state_encode_decode() {
let state = OAuthState::new(OAuthProvider::Google, Some("/dashboard".to_string()));
let encoded = state.encode();
let decoded = OAuthState::decode(&encoded).unwrap();
assert_eq!(decoded.provider, OAuthProvider::Google);
assert_eq!(decoded.redirect_after, Some("/dashboard".to_string()));
assert!(!decoded.is_expired());
}
#[test]
fn test_oauth_config_validation() {
let valid_config = OAuthConfig::new(
OAuthProvider::Google,
"client_id".to_string(),
"client_secret".to_string(),
"http://localhost/callback".to_string(),
);
assert!(valid_config.is_valid());
let mut invalid_config = valid_config.clone();
invalid_config.client_id = String::new();
assert!(!invalid_config.is_valid());
}
}