autoconnect_ws_sm/identified/
on_client_msg.rs1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6 broadcast::Broadcast,
7 protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10 endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11 util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18 pub async fn on_client_msg(
21 &mut self,
22 msg: ClientMessage,
23 ) -> Result<Vec<ServerMessage>, SMError> {
24 match msg {
25 ClientMessage::Hello { .. } => {
26 Err(SMError::invalid_message("Already Hello'd".to_owned()))
27 }
28 ClientMessage::Register { channel_id, key } => {
29 Ok(vec![self.register(channel_id, key).await?])
30 }
31 ClientMessage::Unregister { channel_id, code } => {
32 Ok(vec![self.unregister(channel_id, code).await?])
33 }
34 ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35 .broadcast_subscribe(broadcasts)
36 .await?
37 .map_or_else(Vec::new, |smsg| vec![smsg])),
38 ClientMessage::Ack { updates } => self.ack(&updates).await,
39 ClientMessage::Nack { code, .. } => {
40 self.nack(code);
41 Ok(vec![])
42 }
43 ClientMessage::Ping => Ok(vec![self.ping()?]),
44 }
45 }
46
47 async fn register(
49 &mut self,
50 channel_id_str: String,
51 key: Option<String>,
52 ) -> Result<ServerMessage, SMError> {
53 let uaid_str = self.uaid.to_string();
54 trace!("WebPushClient:register";
55 "uaid" => &uaid_str,
56 "channel_id" => &channel_id_str,
57 "key" => &key,
58 "message_type" => MessageType::Register.as_ref(),
59 );
60 let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
61 SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
62 })?;
63 if channel_id.as_hyphenated().to_string() != channel_id_str {
64 return Err(SMError::invalid_message(format!(
65 "Invalid UUID format, not lower-case/dashed: {channel_id}",
66 )));
67 }
68
69 let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
70 Ok(endpoint) => {
71 let _ = self.app_state.metrics.incr(MetricName::UaCommandRegister);
72 self.stats.registers += 1;
73 (200, endpoint)
74 }
75 Err(SMErrorKind::MakeEndpoint(msg)) => {
76 error!("WebPushClient::register make_endpoint failed: {}", msg);
77 (400, "Failed to generate endpoint".to_owned())
78 }
79 Err(e) => {
80 error!("WebPushClient::register failed: {}", e);
81 (500, "".to_owned())
82 }
83 };
84 Ok(ServerMessage::Register {
85 channel_id,
86 status,
87 push_endpoint,
88 })
89 }
90
91 async fn do_register(
92 &mut self,
93 channel_id: &Uuid,
94 key: Option<String>,
95 ) -> Result<String, SMErrorKind> {
96 if let Some(user) = &self.deferred_add_user {
97 debug!(
98 "💬WebPushClient::register: User not yet registered: {}",
99 &user.uaid
100 );
101 self.app_state.db.add_user(user).await?;
102 self.deferred_add_user = None;
103 }
104
105 let endpoint = make_endpoint(
106 &self.uaid,
107 channel_id,
108 key.as_deref(),
109 &self.app_state.endpoint_url,
110 &self.app_state.fernet,
111 )
112 .map_err(SMErrorKind::MakeEndpoint)?;
113 self.app_state
114 .db
115 .add_channel(&self.uaid, channel_id)
116 .await?;
117 Ok(endpoint)
118 }
119
120 async fn unregister(
122 &mut self,
123 channel_id: Uuid,
124 code: Option<u32>,
125 ) -> Result<ServerMessage, SMError> {
126 let uaid_str = self.uaid.to_string();
127 let chid_str = channel_id.to_string();
128 trace!("WebPushClient:unregister";
129 "uaid" => &uaid_str,
130 "channel_id" => &chid_str,
131 "code" => &code,
132 "message_type" => MessageType::Unregister.as_ref(),
133 );
134 let result = self
138 .app_state
139 .db
140 .remove_channel(&self.uaid, &channel_id)
141 .await;
142 let status = match result {
143 Ok(_) => {
144 self.app_state
145 .metrics
146 .incr_with_tags(MetricName::UaCommandUnregister)
147 .with_tag("code", &code.unwrap_or(200).to_string())
148 .send();
149 self.stats.unregisters += 1;
150 200
151 }
152 Err(e) => {
153 error!("WebPushClient::unregister failed: {}", e);
154 500
155 }
156 };
157 Ok(ServerMessage::Unregister { channel_id, status })
158 }
159
160 async fn broadcast_subscribe(
162 &mut self,
163 broadcasts: HashMap<String, String>,
164 ) -> Result<Option<ServerMessage>, SMError> {
165 trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
166 let broadcasts = Broadcast::from_hashmap(broadcasts);
167 let mut response: HashMap<String, BroadcastValue> = HashMap::new();
168
169 let bc = self.app_state.broadcaster.read().await;
170 if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
171 response.extend(Broadcast::vec_into_hashmap(delta));
172 };
173 let missing = bc.missing_broadcasts(&broadcasts);
174 if !missing.is_empty() {
175 response.insert(
176 "errors".to_owned(),
177 BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
178 );
179 }
180
181 Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
182 broadcasts: response,
183 }))
184 }
185
186 async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
188 trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
189 let _ = self.app_state.metrics.incr(MetricName::UaCommandAck);
190
191 for notif in updates {
192 let pos = self
196 .ack_state
197 .unacked_direct_notifs
198 .iter()
199 .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
200 if let Some(pos) = pos {
202 debug!("✅ Ack (Direct)";
203 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
204 "version" => ¬if.version
205 );
206 self.ack_state.unacked_direct_notifs.remove(pos);
207 self.stats.direct_acked += 1;
208 continue;
209 };
210
211 let pos = self
213 .ack_state
214 .unacked_stored_notifs
215 .iter()
216 .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
217 if let Some(pos) = pos {
218 debug!(
219 "✅ Ack (Stored)";
220 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
221 "version" => ¬if.version,
222 "message_type" => MessageType::Ack.as_ref()
223 );
224 let acked_notification = &mut self.ack_state.unacked_stored_notifs[pos];
226 let is_topic = acked_notification.topic.is_some();
227 debug!("✅ Ack notif: {:?}", &acked_notification);
228 if is_topic {
234 let chid = &acked_notification.chidmessageid();
235 debug!("✅ WebPushClient:ack removing Stored, sort_key: {}", &chid);
236 self.app_state.db.remove_message(&self.uaid, chid).await?;
237 #[cfg(feature = "reliable_report")]
241 acked_notification
242 .record_reliability(&self.app_state.reliability, notif.reliability_state())
243 .await;
244 }
245 let _n = self.ack_state.unacked_stored_notifs.remove(pos);
246 #[cfg(feature = "reliable_report")]
247 if !is_topic {
248 self.ack_state.acked_stored_timestamp_notifs.push(_n);
249 }
250 self.stats.stored_acked += 1;
251 continue;
252 };
253 }
254
255 if self.ack_state.unacked_notifs() {
256 Ok(vec![])
259 } else {
260 self.post_process_all_acked().await
261 }
262 }
263
264 fn nack(&mut self, code: Option<i32>) {
267 trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
268 let code = code
270 .and_then(|code| (301..=303).contains(&code).then_some(code))
271 .unwrap_or(0);
272 self.app_state
273 .metrics
274 .incr_with_tags(MetricName::UaCommandNack)
275 .with_tag("code", &code.to_string())
276 .send();
277 self.stats.nacks += 1;
278 }
279
280 fn ping(&mut self) -> Result<ServerMessage, SMError> {
285 trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
286 if sec_since_epoch() - self.last_ping >= 45 {
290 trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
291 self.last_ping = sec_since_epoch();
292 Ok(ServerMessage::Ping)
293 } else {
294 Err(SMErrorKind::ExcessivePing.into())
295 }
296 }
297
298 async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
306 trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
307 let flags = &self.flags;
308 if flags.check_storage {
309 if flags.increment_storage {
310 debug!(
311 "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
312 "message_type" => MessageType::Notification.as_ref()
313 );
314 self.increment_storage().await?;
315 }
316
317 debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
318 let smsgs = self.check_storage_loop().await?;
319 if !smsgs.is_empty() {
320 debug_assert!(self.flags.check_storage);
321 return Ok(smsgs);
325 }
326 debug_assert!(!self.flags.check_storage);
328 debug_assert!(!self.flags.increment_storage);
329 }
330
331 debug_assert!(!self.ack_state.unacked_notifs());
333 let flags = &self.flags;
334 if flags.old_record_version {
335 debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
336 self.app_state
337 .metrics
338 .incr_with_tags(MetricName::UaExpiration)
339 .with_tag("reason", "old_record_version")
340 .send();
341 self.app_state.db.remove_user(&self.uaid).await?;
342 Err(SMErrorKind::UaidReset.into())
343 } else {
344 Ok(vec![])
345 }
346 }
347}