use crate::auth::sign_with_key;
use crate::error::{ApiError, ApiErrorKind};
use crate::headers::util::get_header;
use crate::server::AppState;
use actix_web::dev::Payload;
use actix_web::{web::Data, FromRequest, HttpRequest};
use futures::future::LocalBoxFuture;
use futures::FutureExt;
use openssl::error::ErrorStack;
use uuid::Uuid;
use autopush_common::util::user_agent::UserAgentInfo;
/// Result of a successful local auth-token check for a request.
///
/// Returned by `AuthorizationCheck::validate_token` when the presented token matches the
/// token computed from one of the configured auth keys. The actix-web `FromRequest` impl
/// in this file runs that check against the `{uaid}` path segment and the
/// `Authorization` header.
pub struct AuthorizationCheck {
/// User agent information passed to `validate_token`; the `FromRequest` impl builds it
/// from the incoming request.
pub user_agent: UserAgentInfo,
}
impl AuthorizationCheck {
    /// Generates the auth token for `user` using `auth_key`.
    ///
    /// The signed payload is the simple (hyphen-less, lowercase hex) string form of the
    /// UUID, passed to `sign_with_key` together with the key's bytes.
    ///
    /// # Errors
    /// Returns the OpenSSL [`ErrorStack`] if signing fails.
    pub fn generate_token(auth_key: &str, user: &Uuid) -> Result<String, ErrorStack> {
        sign_with_key(auth_key.as_bytes(), user.as_simple().to_string().as_bytes())
    }

    /// Validates `token` for `uaid` against every key in `auth_keys`.
    ///
    /// Keys are tried in order. A token produced by any configured key is accepted, so
    /// multiple keys can be configured at once (e.g. during key rotation).
    ///
    /// # Errors
    /// - `ApiErrorKind::RegistrationSecretHash` if signing with one of the keys fails.
    /// - `ApiErrorKind::InvalidLocalAuth` if no key yields a matching token, including
    ///   when `auth_keys` is empty.
    pub fn validate_token(
        token: &str,
        uaid: &Uuid,
        auth_keys: &[String],
        user_agent: UserAgentInfo,
    ) -> Result<Self, ApiError> {
        // The signed payload does not depend on the key, so build it once.
        let payload = uaid.as_simple().to_string();
        for key in auth_keys {
            let expected_token = sign_with_key(key.as_bytes(), payload.as_bytes())
                .map_err(ApiErrorKind::RegistrationSecretHash)?;
            // Deliberately no logging here: `expected_token` is a valid credential for
            // this UAID (and `token` may be one too), so writing either to the logs would
            // let anyone with log access impersonate the client.
            //
            // `openssl::memcmp::eq` panics when the lengths differ, hence the length guard.
            // The byte comparison itself runs in time independent of the contents, so it
            // does not leak how much of the token matched.
            if expected_token.len() == token.len()
                && openssl::memcmp::eq(expected_token.as_bytes(), token.as_bytes())
            {
                return Ok(Self { user_agent });
            }
        }
        Err(ApiErrorKind::InvalidLocalAuth("incorrect auth token".to_owned()).into())
    }
}
impl FromRequest for AuthorizationCheck {
type Error = ApiError;
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
let req = req.clone();
async move {
let uaid = req
.match_info()
.get("uaid")
.expect("{uaid} must be part of the path")
.parse::<Uuid>()
.map_err(|_| ApiErrorKind::NoUser)?;
let state: Data<AppState> = Data::extract(&req)
.into_inner()
.expect("No server state found");
let auth_header = get_header(&req, "Authorization")
.ok_or_else(|| ApiErrorKind::InvalidLocalAuth("missing auth header".to_owned()))?;
let token = get_token_from_auth_header(auth_header)
.ok_or_else(|| ApiErrorKind::InvalidLocalAuth("missing auth token".to_owned()))?;
let user_agent = UserAgentInfo::from(&req);
Self::validate_token(token, &uaid, &state.settings.auth_keys(), user_agent)
}
.boxed_local()
}
}
/// Extracts the credential from an `Authorization` header value.
///
/// Only the `bearer` and `webpush` schemes are accepted, compared case-insensitively.
/// The scheme is separated from the credential by one or more spaces.
///
/// Returns `None` when the scheme is missing or unrecognised, or when no non-empty
/// credential follows it. A blank credential is thus reported as "missing" rather than
/// as an incorrect token.
fn get_token_from_auth_header(header: &str) -> Option<&str> {
    let (scheme, rest) = header.split_once(' ')?;
    // Auth schemes are case-insensitive and these names are pure ASCII, so an ASCII
    // comparison suffices and avoids allocating a lowercased copy per request.
    let recognised = ["bearer", "webpush"]
        .iter()
        .any(|known| scheme.eq_ignore_ascii_case(known));
    if !recognised {
        return None;
    }
    // RFC 7235 permits `1*SP` between scheme and credentials, so tolerate surrounding
    // whitespace instead of treating it as part of the token.
    let token = rest.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
#[cfg(test)]
mod test {
    use crate::error::ApiResult;
    use autopush_common::util::user_agent::UserAgentInfo;

    use super::*;

    /// UAID shared by the tests; its token under `AUTH_KEY` is pinned in
    /// `test_legacy_signature`.
    const UAID: &str = "729e5104f5f04abc9196085340317dea";
    /// Auth key used to sign test tokens.
    const AUTH_KEY: &str = "HJVPy4ZwF4Yz_JdvXTL8hRcwIhv742vC60Tg5Ycrvw8=";

    fn test_uaid() -> Uuid {
        UAID.parse().unwrap()
    }

    #[test]
    fn test_signature() -> ApiResult<()> {
        let uaid = test_uaid();
        let auth_keys = vec![AUTH_KEY.to_owned()];
        let token = AuthorizationCheck::generate_token(AUTH_KEY, &uaid).unwrap();
        AuthorizationCheck::validate_token(&token, &uaid, &auth_keys, UserAgentInfo::default())?;
        Ok(())
    }

    #[test]
    fn test_signature_with_multiple_keys() -> ApiResult<()> {
        // A token signed by any configured key must be accepted, not just the first.
        let uaid = test_uaid();
        let auth_keys = vec!["some-other-key".to_owned(), AUTH_KEY.to_owned()];
        let token = AuthorizationCheck::generate_token(AUTH_KEY, &uaid).unwrap();
        AuthorizationCheck::validate_token(&token, &uaid, &auth_keys, UserAgentInfo::default())?;
        Ok(())
    }

    #[test]
    fn test_rejects_wrong_token() {
        let uaid = test_uaid();
        let auth_keys = vec![AUTH_KEY.to_owned()];
        let result = AuthorizationCheck::validate_token(
            "not-a-valid-token",
            &uaid,
            &auth_keys,
            UserAgentInfo::default(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_token_for_other_uaid() {
        let uaid = test_uaid();
        let other: Uuid = "11111111222233334444555555555555".parse().unwrap();
        let auth_keys = vec![AUTH_KEY.to_owned()];
        let token = AuthorizationCheck::generate_token(AUTH_KEY, &other).unwrap();
        let result =
            AuthorizationCheck::validate_token(&token, &uaid, &auth_keys, UserAgentInfo::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_when_no_keys_configured() {
        let uaid = test_uaid();
        let token = AuthorizationCheck::generate_token(AUTH_KEY, &uaid).unwrap();
        let result =
            AuthorizationCheck::validate_token(&token, &uaid, &[], UserAgentInfo::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_legacy_signature() -> ApiResult<()> {
        let uaid = test_uaid();
        let legacy_token = "f694963453adf5dedcc379bbdd6900d692b6e09f1c91f44169bfcd2f941bf36c";
        let token = AuthorizationCheck::generate_token(AUTH_KEY, &uaid).unwrap();
        assert_eq!(&token, legacy_token);
        Ok(())
    }

    #[test]
    fn test_token_extractor() -> ApiResult<()> {
        let uaid = test_uaid();
        let token = AuthorizationCheck::generate_token(AUTH_KEY, &uaid).unwrap();
        assert_eq!(
            get_token_from_auth_header(&format!("bearer {}", &token)),
            Some(token.as_str())
        );
        assert_eq!(
            get_token_from_auth_header(&format!("webpush {}", &token)),
            Some(token.as_str())
        );
        // Scheme matching is case-insensitive.
        assert_eq!(
            get_token_from_auth_header(&format!("Bearer {}", &token)),
            Some(token.as_str())
        );
        assert!(get_token_from_auth_header(&format!("random {}", &token)).is_none());
        // A scheme with no credential at all yields nothing.
        assert!(get_token_from_auth_header("bearer").is_none());
        Ok(())
    }
}