autoconnect_ws_sm/identified/
on_client_msg.rs1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6 broadcast::Broadcast,
7 protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10 endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11 util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18 pub async fn on_client_msg(
21 &mut self,
22 msg: ClientMessage,
23 ) -> Result<Vec<ServerMessage>, SMError> {
24 match msg {
25 ClientMessage::Hello { .. } => {
26 Err(SMError::invalid_message("Already Hello'd".to_owned()))
27 }
28 ClientMessage::Register { channel_id, key } => {
29 Ok(vec![self.register(channel_id, key).await?])
30 }
31 ClientMessage::Unregister { channel_id, code } => {
32 Ok(vec![self.unregister(channel_id, code).await?])
33 }
34 ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35 .broadcast_subscribe(broadcasts)
36 .await?
37 .map_or_else(Vec::new, |smsg| vec![smsg])),
38 ClientMessage::Ack { updates } => self.ack(&updates).await,
39 ClientMessage::Nack { code, .. } => {
40 self.nack(code);
41 Ok(vec![])
42 }
43 ClientMessage::Ping => Ok(vec![self.ping()?]),
44 }
45 }
46
47 async fn register(
49 &mut self,
50 channel_id_str: String,
51 key: Option<String>,
52 ) -> Result<ServerMessage, SMError> {
53 let uaid_str = self.uaid.to_string();
54 trace!("WebPushClient:register";
55 "uaid" => &uaid_str,
56 "channel_id" => &channel_id_str,
57 "key" => &key,
58 "message_type" => MessageType::Register.as_ref(),
59 );
60 let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
61 SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
62 })?;
63 if channel_id.as_hyphenated().to_string() != channel_id_str {
64 return Err(SMError::invalid_message(format!(
65 "Invalid UUID format, not lower-case/dashed: {channel_id}",
66 )));
67 }
68
69 let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
70 Ok(endpoint) => {
71 let _ = self.app_state.metrics.incr(MetricName::UaCommandRegister);
72 self.stats.registers += 1;
73 (200, endpoint)
74 }
75 Err(SMErrorKind::MakeEndpoint(msg)) => {
76 error!("WebPushClient::register make_endpoint failed: {}", msg);
77 (400, "Failed to generate endpoint".to_owned())
78 }
79 Err(e) => {
80 error!("WebPushClient::register failed: {}", e);
81 (500, "".to_owned())
82 }
83 };
84 Ok(ServerMessage::Register {
85 channel_id,
86 status,
87 push_endpoint,
88 })
89 }
90
91 async fn do_register(
92 &mut self,
93 channel_id: &Uuid,
94 key: Option<String>,
95 ) -> Result<String, SMErrorKind> {
96 if let Some(user) = &self.deferred_add_user {
97 debug!(
98 "💬WebPushClient::register: User not yet registered: {}",
99 &user.uaid
100 );
101 self.app_state.db.add_user(user).await?;
102 self.deferred_add_user = None;
103 }
104
105 let endpoint = make_endpoint(
106 &self.uaid,
107 channel_id,
108 key.as_deref(),
109 &self.app_state.endpoint_url,
110 &self.app_state.fernet,
111 )
112 .map_err(SMErrorKind::MakeEndpoint)?;
113 self.app_state
114 .db
115 .add_channel(&self.uaid, channel_id)
116 .await?;
117 Ok(endpoint)
118 }
119
120 async fn unregister(
122 &mut self,
123 channel_id: Uuid,
124 code: Option<u32>,
125 ) -> Result<ServerMessage, SMError> {
126 let uaid_str = self.uaid.to_string();
127 let chid_str = channel_id.to_string();
128 trace!("WebPushClient:unregister";
129 "uaid" => &uaid_str,
130 "channel_id" => &chid_str,
131 "code" => &code,
132 "message_type" => MessageType::Unregister.as_ref(),
133 );
134 let result = self
138 .app_state
139 .db
140 .remove_channel(&self.uaid, &channel_id)
141 .await;
142 let status = match result {
143 Ok(_) => {
144 self.app_state
145 .metrics
146 .incr_with_tags(MetricName::UaCommandUnregister)
147 .with_tag("code", &code.unwrap_or(200).to_string())
148 .send();
149 self.stats.unregisters += 1;
150 200
151 }
152 Err(e) => {
153 error!("WebPushClient::unregister failed: {}", e);
154 500
155 }
156 };
157 Ok(ServerMessage::Unregister { channel_id, status })
158 }
159
160 async fn broadcast_subscribe(
162 &mut self,
163 broadcasts: HashMap<String, String>,
164 ) -> Result<Option<ServerMessage>, SMError> {
165 trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
166 let broadcasts = Broadcast::from_hashmap(broadcasts);
167 let mut response: HashMap<String, BroadcastValue> = HashMap::new();
168
169 let bc = self.app_state.broadcaster.read().await;
170 if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
171 response.extend(Broadcast::vec_into_hashmap(delta));
172 };
173 let missing = bc.missing_broadcasts(&broadcasts);
174 if !missing.is_empty() {
175 response.insert(
176 "errors".to_owned(),
177 BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
178 );
179 }
180
181 Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
182 broadcasts: response,
183 }))
184 }
185
186 async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
188 trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
189 let _ = self.app_state.metrics.incr(MetricName::UaCommandAck);
190
191 for notif in updates {
192 let key = notif.version.clone();
193 if self.ack_state.unacked_direct_notifs.remove(&key).is_some() {
195 debug!("✅ Ack (Direct)";
196 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
197 "version" => ¬if.version
198 );
199 self.stats.direct_acked += 1;
200 continue;
201 };
202
203 #[allow(unused_mut)]
205 if let Some(mut acked_notification) = self.ack_state.unacked_stored_notifs.remove(&key)
206 {
207 debug!(
208 "✅ Ack (Stored)";
209 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
210 "version" => ¬if.version,
211 "message_type" => MessageType::Ack.as_ref()
212 );
213 let is_topic = acked_notification
215 .topic
216 .as_ref()
217 .map(|t| !t.is_empty())
218 .unwrap_or(false);
219 debug!("✅ Ack notif: {:?}", &acked_notification);
220 if is_topic {
226 let chid_message_id = &acked_notification.chidmessageid();
227 debug!(
228 "✅🗑 WebPushClient:ack removing Stored, sort_key: {}, version: {}",
229 &chid_message_id, &acked_notification.version
230 );
231 self.app_state
232 .db
233 .remove_message(&self.uaid, chid_message_id)
234 .await?;
235 #[cfg(feature = "reliable_report")]
239 acked_notification
240 .record_reliability(&self.app_state.reliability, notif.reliability_state())
241 .await;
242 }
243 #[cfg(feature = "reliable_report")]
244 if !is_topic {
245 self.ack_state
246 .acked_stored_timestamp_notifs
247 .push(acked_notification);
248 }
249 self.stats.stored_acked += 1;
250 continue;
251 };
252 }
253
254 if self.ack_state.unacked_notifs() {
255 Ok(vec![])
258 } else {
259 self.post_process_all_acked().await
260 }
261 }
262
263 fn nack(&mut self, code: Option<i32>) {
266 trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
267 let code = code
269 .and_then(|code| (301..=303).contains(&code).then_some(code))
270 .unwrap_or(0);
271 self.app_state
272 .metrics
273 .incr_with_tags(MetricName::UaCommandNack)
274 .with_tag("code", &code.to_string())
275 .send();
276 self.stats.nacks += 1;
277 }
278
279 fn ping(&mut self) -> Result<ServerMessage, SMError> {
284 trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
285 if sec_since_epoch() - self.last_ping >= 45 {
289 trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
290 self.last_ping = sec_since_epoch();
291 Ok(ServerMessage::Ping)
292 } else {
293 Err(SMErrorKind::ExcessivePing.into())
294 }
295 }
296
297 async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
305 trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
306 let flags = &self.flags;
307 if flags.check_storage {
308 if flags.increment_storage {
309 debug!(
310 "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
311 "message_type" => MessageType::Notification.as_ref()
312 );
313 self.increment_storage().await?;
314 }
315
316 debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
317 let smsgs = self.check_storage_loop().await?;
318 if !smsgs.is_empty() {
319 debug_assert!(self.flags.check_storage);
320 return Ok(smsgs);
324 }
325 debug_assert!(!self.flags.check_storage);
327 debug_assert!(!self.flags.increment_storage);
328 }
329
330 debug_assert!(!self.ack_state.unacked_notifs());
332 let flags = &self.flags;
333 if flags.old_record_version {
334 debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
335 self.app_state
336 .metrics
337 .incr_with_tags(MetricName::UaExpiration)
338 .with_tag("reason", "old_record_version")
339 .send();
340 self.app_state.db.remove_user(&self.uaid).await?;
341 Err(SMErrorKind::UaidReset.into())
342 } else {
343 Ok(vec![])
344 }
345 }
346}