Skip to content
Closed
255 changes: 183 additions & 72 deletions src/webserver/oidc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use std::future::ready;
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
use std::{
future::Future,
pin::Pin,
str::FromStr,
sync::Arc,
time::{Duration, Instant},
};

use crate::webserver::http_client::get_http_client_from_appdata;
use crate::{app_config::AppConfig, AppState};
Expand All @@ -20,6 +26,7 @@ use openidconnect::{
TokenResponse,
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use super::http_client::make_http_client;

Expand All @@ -29,6 +36,48 @@ const SQLPAGE_AUTH_COOKIE_NAME: &str = "sqlpage_auth";
const SQLPAGE_REDIRECT_URI: &str = "/sqlpage/oidc_callback";
const SQLPAGE_STATE_COOKIE_NAME: &str = "sqlpage_oidc_state";

// Cache configuration based on industry best practices
const PROVIDER_METADATA_CACHE_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours
const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(5 * 60); // 5 minutes (rate limiting)

#[derive(Clone, Debug)]
struct CachedProvider {
client: OidcClient,
metadata: openidconnect::core::CoreProviderMetadata,
cached_at: Instant,
last_refresh_attempt: Instant,
}

impl CachedProvider {
fn new(client: OidcClient, metadata: openidconnect::core::CoreProviderMetadata) -> Self {
let now = Instant::now();
Self {
client,
metadata,
cached_at: now,
last_refresh_attempt: now,
}
}

fn is_stale(&self) -> bool {
self.cached_at.elapsed() > PROVIDER_METADATA_CACHE_DURATION
}

fn can_refresh(&self) -> bool {
self.last_refresh_attempt.elapsed() > MIN_REFRESH_INTERVAL
}

fn update(&mut self, client: OidcClient, metadata: openidconnect::core::CoreProviderMetadata) {
self.client = client;
self.metadata = metadata;
self.cached_at = Instant::now();
}

fn mark_refresh_attempt(&mut self) {
self.last_refresh_attempt = Instant::now();
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct OidcAdditionalClaims(pub(crate) serde_json::Map<String, serde_json::Value>);
Expand Down Expand Up @@ -117,7 +166,63 @@ fn get_app_host(config: &AppConfig) -> String {

pub struct OidcState {
pub config: Arc<OidcConfig>,
pub client: Arc<OidcClient>,
cached_provider: Arc<RwLock<CachedProvider>>,
}

impl OidcState {
/// Get the current OIDC client, checking if cache is stale but not attempting refresh
pub fn get_client(&self) -> OidcClient {
// For now, we'll use a simple approach - get the current client
// In a production system, you might want to check if cache is stale
// and trigger an async refresh task
futures_util::executor::block_on(async {
self.cached_provider.read().await.client.clone()
})
}

/// Get the current OIDC client, refreshing if stale and possible
pub async fn get_client_with_refresh(&self, app_config: &AppConfig) -> OidcClient {
// Try to refresh if cache is stale and we haven't tried recently
{
let cache = self.cached_provider.read().await;
if cache.is_stale() && cache.can_refresh() {
// Release read lock before attempting refresh
drop(cache);
if let Err(e) = self.refresh_provider(app_config).await {
log::warn!("Failed to refresh OIDC provider: {}", e);
}
}
}

self.cached_provider.read().await.client.clone()
}

/// Refresh provider metadata and client from the OIDC provider
async fn refresh_provider(&self, app_config: &AppConfig) -> anyhow::Result<()> {
let mut cache = self.cached_provider.write().await;

// Double-check we can refresh (another thread might have just done it)
if !cache.can_refresh() {
return Ok(());
}

cache.mark_refresh_attempt();

log::debug!(
"Refreshing OIDC provider metadata for {}",
self.config.issuer_url
);

let http_client = make_http_client(app_config)?;
let new_metadata =
discover_provider_metadata(&http_client, self.config.issuer_url.clone()).await?;
let new_client = make_oidc_client(&self.config, new_metadata.clone())?;

cache.update(new_client, new_metadata);

log::debug!("Successfully refreshed OIDC provider");
Ok(())
}
}

pub async fn initialize_oidc_state(
Expand All @@ -130,14 +235,18 @@ pub async fn initialize_oidc_state(
};

let http_client = make_http_client(app_config)?;

// Initial metadata discovery
let provider_metadata =
discover_provider_metadata(&http_client, oidc_cfg.issuer_url.clone()).await?;
let client = make_oidc_client(&oidc_cfg, provider_metadata)?;
let client = make_oidc_client(&oidc_cfg, provider_metadata.clone())?;

Ok(Some(Arc::new(OidcState {
let oidc_state = Arc::new(OidcState {
config: oidc_cfg,
client: Arc::new(client),
})))
cached_provider: Arc::new(RwLock::new(CachedProvider::new(client, provider_metadata))),
});

Ok(Some(oidc_state))
}

pub struct OidcMiddleware {
Expand Down Expand Up @@ -203,60 +312,57 @@ where
oidc_state,
}
}
}

fn handle_unauthenticated_request(
&self,
request: ServiceRequest,
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
log::debug!("Handling unauthenticated request to {}", request.path());
if request.path() == SQLPAGE_REDIRECT_URI {
log::debug!("The request is the OIDC callback");
return self.handle_oidc_callback(request);
}

if self.oidc_state.config.is_public_path(request.path()) {
log::debug!(
"The request path {} is not in a public path, skipping OIDC authentication",
request.path()
);
return Box::pin(self.service.call(request));
}
async fn handle_unauthenticated_request<S>(
oidc_state: Arc<OidcState>,
request: ServiceRequest,
service: S,
) -> Result<ServiceResponse<BoxBody>, Error>
where
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
{
log::debug!("Handling unauthenticated request to {}", request.path());

log::debug!("Redirecting to OIDC provider");
if request.path() == SQLPAGE_REDIRECT_URI {
log::debug!("The request is the OIDC callback");
return handle_oidc_callback(oidc_state, request).await;
}

let response = build_auth_provider_redirect_response(
&self.oidc_state.client,
&self.oidc_state.config,
&request,
if oidc_state.config.is_public_path(request.path()) {
log::debug!(
"The request path {} is public, skipping OIDC authentication",
request.path()
);
Box::pin(async move { Ok(request.into_response(response)) })
return service.call(request).await;
}

fn handle_oidc_callback(
&self,
request: ServiceRequest,
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
let oidc_client = Arc::clone(&self.oidc_state.client);
let oidc_config = Arc::clone(&self.oidc_state.config);
log::debug!("Redirecting to OIDC provider");
let client = oidc_state.get_client();
let response = build_auth_provider_redirect_response(&client, &oidc_state.config, &request);
Ok(request.into_response(response))
}

Box::pin(async move {
let query_string = request.query_string();
match process_oidc_callback(&oidc_client, query_string, &request).await {
Ok(response) => Ok(request.into_response(response)),
Err(e) => {
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
let resp =
build_auth_provider_redirect_response(&oidc_client, &oidc_config, &request);
Ok(request.into_response(resp))
}
}
})
async fn handle_oidc_callback(
oidc_state: Arc<OidcState>,
request: ServiceRequest,
) -> Result<ServiceResponse<BoxBody>, Error> {
let oidc_client = oidc_state.get_client();
let query_string = request.query_string();
match process_oidc_callback(&oidc_client, query_string, &request).await {
Ok(response) => Ok(request.into_response(response)),
Err(e) => {
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
let resp =
build_auth_provider_redirect_response(&oidc_client, &oidc_state.config, &request);
Ok(request.into_response(resp))
}
}
}

impl<S> Service<ServiceRequest> for OidcService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + Clone,
S::Future: 'static,
{
type Response = ServiceResponse<BoxBody>;
Expand All @@ -268,31 +374,36 @@ where
fn call(&self, request: ServiceRequest) -> Self::Future {
log::trace!("Started OIDC middleware request handling");

let oidc_client = Arc::clone(&self.oidc_state.client);
match get_authenticated_user_info(&oidc_client, &request) {
Ok(Some(claims)) => {
log::trace!("Storing authenticated user info in request extensions: {claims:?}");
request.extensions_mut().insert(claims);
}
Ok(None) => {
log::trace!("No authenticated user found");
return self.handle_unauthenticated_request(request);
}
Err(e) => {
log::debug!(
"{:?}",
e.context(
"An auth cookie is present but could not be verified. \
Redirecting to OIDC provider to re-authenticate."
)
);
return self.handle_unauthenticated_request(request);
}
}
let future = self.service.call(request);
let oidc_state = Arc::clone(&self.oidc_state);
let service = self.service.clone();

Box::pin(async move {
let response = future.await?;
Ok(response)
let oidc_client = oidc_state.get_client();
match get_authenticated_user_info(&oidc_client, &request) {
Ok(Some(claims)) => {
log::trace!(
"Storing authenticated user info in request extensions: {claims:?}"
);
request.extensions_mut().insert(claims);
let future = service.call(request);
let response = future.await?;
Ok(response)
}
Ok(None) => {
log::trace!("No authenticated user found");
handle_unauthenticated_request(oidc_state, request, service).await
}
Err(e) => {
log::debug!(
"{:?}",
e.context(
"An auth cookie is present but could not be verified. \
Redirecting to OIDC provider to re-authenticate."
)
);
handle_unauthenticated_request(oidc_state, request, service).await
}
}
})
}
}
Expand Down
Loading