use std::collections::HashMap;
use cadence::CountedExt;
use uuid::Uuid;
use autoconnect_common::{
broadcast::Broadcast,
protocol::{BroadcastValue, ClientAck, ClientMessage, ServerMessage},
};
use autopush_common::{endpoint::make_endpoint, util::sec_since_epoch};
use super::WebPushClient;
use crate::error::{SMError, SMErrorKind};
impl WebPushClient {
    /// Handle a single [`ClientMessage`] from a Client that has already
    /// completed its Hello handshake.
    ///
    /// Returns the (possibly empty) list of [`ServerMessage`]s to send back.
    ///
    /// # Errors
    ///
    /// Returns an invalid-message error for a repeated `Hello`, and
    /// propagates any error returned by the individual command handlers.
    pub async fn on_client_msg(
        &mut self,
        msg: ClientMessage,
    ) -> Result<Vec<ServerMessage>, SMError> {
        match msg {
            ClientMessage::Hello { .. } => {
                Err(SMError::invalid_message("Already Hello'd".to_owned()))
            }
            ClientMessage::Register { channel_id, key } => {
                Ok(vec![self.register(channel_id, key).await?])
            }
            ClientMessage::Unregister { channel_id, code } => {
                Ok(vec![self.unregister(channel_id, code).await?])
            }
            // A subscription with nothing to report produces no reply at all.
            ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
                .broadcast_subscribe(broadcasts)
                .await?
                .map_or_else(Vec::new, |smsg| vec![smsg])),
            ClientMessage::Ack { updates } => self.ack(&updates).await,
            ClientMessage::Nack { code, .. } => {
                self.nack(code);
                Ok(vec![])
            }
            ClientMessage::Ping => Ok(vec![self.ping()?]),
        }
    }

    /// Register a new channel for this Client's UAID.
    ///
    /// `channel_id_str` must be a UUID in canonical lower-case, hyphenated
    /// form. The reply's status is:
    /// - 200 with the push endpoint on success;
    /// - 400 with a generic message if the endpoint could not be generated;
    /// - 500 with an empty endpoint for any other failure.
    ///
    /// # Errors
    ///
    /// Returns an invalid-message error when `channel_id_str` does not parse
    /// as a UUID or is not in canonical lower-case, hyphenated form.
    async fn register(
        &mut self,
        channel_id_str: String,
        key: Option<String>,
    ) -> Result<ServerMessage, SMError> {
        trace!("WebPushClient:register";
            "uaid" => &self.uaid.to_string(),
            "channel_id" => &channel_id_str,
            "key" => &key,
        );
        let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
            SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
        })?;
        // Reject UUIDs that parse but are not canonical (e.g. upper-case hex).
        // Report the Client's original string: the parsed `Uuid` always
        // displays lower-case/dashed, so it would not show what was wrong.
        if channel_id.as_hyphenated().to_string() != channel_id_str {
            return Err(SMError::invalid_message(format!(
                "Invalid UUID format, not lower-case/dashed: {channel_id_str}",
            )));
        }
        let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
            Ok(endpoint) => {
                let _ = self.app_state.metrics.incr("ua.command.register");
                self.stats.registers += 1;
                (200, endpoint)
            }
            Err(SMErrorKind::MakeEndpoint(msg)) => {
                error!("WebPushClient::register make_endpoint failed: {}", msg);
                (400, "Failed to generate endpoint".to_owned())
            }
            Err(e) => {
                error!("WebPushClient::register failed: {}", e);
                (500, "".to_owned())
            }
        };
        Ok(ServerMessage::Register {
            channel_id,
            status,
            push_endpoint,
        })
    }

    /// Store a new channel for this UAID and build its push endpoint URL.
    ///
    /// If creating this Client's user record was deferred
    /// (`deferred_add_user` is set), that record is written first and the
    /// deferral is cleared.
    ///
    /// # Errors
    ///
    /// Returns [`SMErrorKind::MakeEndpoint`] if the endpoint cannot be
    /// generated, or any error produced by the database calls.
    async fn do_register(
        &mut self,
        channel_id: &Uuid,
        key: Option<String>,
    ) -> Result<String, SMErrorKind> {
        if let Some(user) = &self.deferred_add_user {
            debug!(
                "💬WebPushClient::register: User not yet registered: {}",
                &user.uaid
            );
            self.app_state.db.add_user(user).await?;
            self.deferred_add_user = None;
        }
        let endpoint = make_endpoint(
            &self.uaid,
            channel_id,
            key.as_deref(),
            &self.app_state.endpoint_url,
            &self.app_state.fernet,
        )
        .map_err(SMErrorKind::MakeEndpoint)?;
        self.app_state
            .db
            .add_channel(&self.uaid, channel_id)
            .await?;
        Ok(endpoint)
    }

    /// Remove a channel from this UAID.
    ///
    /// Replies with status 200 on success or 500 if the database removal
    /// failed. On success, the `code` supplied by the Client (defaulting to
    /// 200) is attached as a metric tag.
    async fn unregister(
        &mut self,
        channel_id: Uuid,
        code: Option<u32>,
    ) -> Result<ServerMessage, SMError> {
        trace!("WebPushClient:unregister";
            "uaid" => &self.uaid.to_string(),
            "channel_id" => &channel_id.to_string(),
            "code" => &code,
        );
        let result = self
            .app_state
            .db
            .remove_channel(&self.uaid, &channel_id)
            .await;
        let status = match result {
            Ok(_) => {
                self.app_state
                    .metrics
                    .incr_with_tags("ua.command.unregister")
                    .with_tag("code", &code.unwrap_or(200).to_string())
                    .send();
                self.stats.unregisters += 1;
                200
            }
            Err(e) => {
                error!("WebPushClient::unregister failed: {}", e);
                500
            }
        };
        Ok(ServerMessage::Unregister { channel_id, status })
    }

    /// Subscribe the Client to the requested broadcasts.
    ///
    /// The reply contains any broadcast deltas returned by the subscription.
    /// Broadcasts unknown to the broadcaster are listed under an `"errors"`
    /// key. Returns `Ok(None)` when there is nothing to report.
    async fn broadcast_subscribe(
        &mut self,
        broadcasts: HashMap<String, String>,
    ) -> Result<Option<ServerMessage>, SMError> {
        trace!("WebPushClient:broadcast_subscribe");
        let broadcasts = Broadcast::from_hashmap(broadcasts);
        let mut response: HashMap<String, BroadcastValue> = HashMap::new();

        // The broadcaster read guard lives until the end of this function.
        let bc = self.app_state.broadcaster.read().await;
        if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
            response.extend(Broadcast::vec_into_hashmap(delta));
        }
        let missing = bc.missing_broadcasts(&broadcasts);
        if !missing.is_empty() {
            response.insert(
                "errors".to_owned(),
                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
            );
        }

        Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
            broadcasts: response,
        }))
    }

    /// Process the Client's acknowledgements of delivered notifications.
    ///
    /// Each ack is matched by `(channel_id, version)` against the outstanding
    /// direct notifications first, then the outstanding stored ones. Acks that
    /// match neither list are ignored. Once nothing remains unacked, processing
    /// continues in `post_process_all_acked`.
    ///
    /// # Errors
    ///
    /// Propagates database errors from deleting acked stored messages and
    /// errors from `post_process_all_acked`.
    async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
        trace!("✅ WebPushClient:ack");
        let _ = self.app_state.metrics.incr("ua.command.ack");

        for notif in updates {
            let pos = self
                .ack_state
                .unacked_direct_notifs
                .iter()
                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
            if let Some(pos) = pos {
                debug!("✅ Ack (Direct)";
                    "channel_id" => notif.channel_id.as_hyphenated().to_string(),
                    "version" => &notif.version
                );
                self.ack_state.unacked_direct_notifs.remove(pos);
                self.stats.direct_acked += 1;
                continue;
            }

            let pos = self
                .ack_state
                .unacked_stored_notifs
                .iter()
                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
            if let Some(pos) = pos {
                debug!(
                    "✅ Ack (Stored)";
                    "channel_id" => notif.channel_id.as_hyphenated().to_string(),
                    "version" => &notif.version
                );
                let n = &self.ack_state.unacked_stored_notifs[pos];
                debug!("✅ Ack notif: {:?}", &n);
                // Only stored notifications without a `sortkey_timestamp` are
                // deleted here. Timestamped ones are left in storage.
                // NOTE(review): presumably those are handled via
                // increment_storage / TTL expiry; confirm.
                if n.sortkey_timestamp.is_none() {
                    let chidmessageid = n.chidmessageid();
                    debug!(
                        "✅ WebPushClient:ack removing Stored, sort_key: {}",
                        &chidmessageid
                    );
                    self.app_state
                        .db
                        .remove_message(&self.uaid, &chidmessageid)
                        .await?;
                }
                self.ack_state.unacked_stored_notifs.remove(pos);
                self.stats.stored_acked += 1;
            }
        }

        if self.ack_state.unacked_notifs() {
            Ok(vec![])
        } else {
            self.post_process_all_acked().await
        }
    }

    /// Record a negative acknowledgement from the Client.
    ///
    /// The metric tag uses the Client's code only when it is in `301..=303`.
    /// A missing or out-of-range code is reported as `0`.
    fn nack(&mut self, code: Option<i32>) {
        trace!("WebPushClient:nack");
        let code = code
            .and_then(|code| (301..=303).contains(&code).then_some(code))
            .unwrap_or(0);
        self.app_state
            .metrics
            .incr_with_tags("ua.command.nack")
            .with_tag("code", &code.to_string())
            .send();
        self.stats.nacks += 1;
    }

    /// Answer a WebPush Ping, enforcing a minimum interval between Pings.
    ///
    /// # Errors
    ///
    /// Returns [`SMErrorKind::ExcessivePing`] if the previously accepted Ping
    /// was less than 45 seconds ago.
    fn ping(&mut self) -> Result<ServerMessage, SMError> {
        trace!("WebPushClient:ping");
        let now = sec_since_epoch();
        // `checked_sub` avoids an underflow if the wall clock has moved
        // backwards past `last_ping`; that underflow panics in debug builds.
        // Such a Ping is accepted, which matches the wrapped result in default
        // release builds, and `last_ping` is resynchronised to `now`.
        let ping_allowed = match now.checked_sub(self.last_ping) {
            Some(elapsed) => elapsed >= 45,
            None => true,
        };
        if ping_allowed {
            trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
            self.last_ping = now;
            Ok(ServerMessage::Ping)
        } else {
            Err(SMErrorKind::ExcessivePing.into())
        }
    }

    /// Run follow-up work once every outstanding notification has been acked.
    ///
    /// Steps:
    /// - If `check_storage` is set, optionally increment storage first (when
    ///   `increment_storage` is set), then check storage again.
    /// - If that check yields more notifications, return them so they are sent
    ///   before any further processing.
    /// - Otherwise, if `old_record_version` is set, remove the user record and
    ///   fail with [`SMErrorKind::UaidReset`].
    ///
    /// # Errors
    ///
    /// Propagates storage/database errors, and returns
    /// [`SMErrorKind::UaidReset`] after removing a user with an old record
    /// version.
    async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
        trace!("▶️ WebPushClient:post_process_all_acked");
        let flags = &self.flags;
        if flags.check_storage {
            if flags.increment_storage {
                debug!(
                    "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage"
                );
                self.increment_storage().await?;
            }

            debug!("▶️ WebPushClient:post_process_all_acked check_storage");
            let smsgs = self.check_storage_loop().await?;
            if !smsgs.is_empty() {
                // More notifications to deliver: send them and wait for the
                // Client to ack them before finishing post-processing.
                debug_assert!(self.flags.check_storage);
                return Ok(smsgs);
            }
            // An empty result is expected to coincide with both storage
            // flags having been cleared.
            debug_assert!(!self.flags.check_storage);
            debug_assert!(!self.flags.increment_storage);
        }

        debug_assert!(!self.ack_state.unacked_notifs());
        let flags = &self.flags;
        if flags.old_record_version {
            debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid");
            self.app_state
                .metrics
                .incr_with_tags("ua.expiration")
                .with_tag("reason", "old_record_version")
                .send();
            self.app_state.db.remove_user(&self.uaid).await?;
            Err(SMErrorKind::UaidReset.into())
        } else {
            Ok(vec![])
        }
    }
}