autoconnect_ws_sm/identified/
on_client_msg.rs1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6 broadcast::Broadcast,
7 protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10 endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11 util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18 pub async fn on_client_msg(
21 &mut self,
22 msg: ClientMessage,
23 ) -> Result<Vec<ServerMessage>, SMError> {
24 match msg {
25 ClientMessage::Hello { .. } => {
26 Err(SMError::invalid_message("Already Hello'd".to_owned()))
27 }
28 ClientMessage::Register { channel_id, key } => {
29 Ok(vec![self.register(channel_id, key).await?])
30 }
31 ClientMessage::Unregister { channel_id, code } => {
32 Ok(vec![self.unregister(channel_id, code).await?])
33 }
34 ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35 .broadcast_subscribe(broadcasts)
36 .await?
37 .map_or_else(Vec::new, |smsg| vec![smsg])),
38 ClientMessage::Ack { updates } => self.ack(&updates).await,
39 ClientMessage::Nack { code, .. } => {
40 self.nack(code);
41 Ok(vec![])
42 }
43 ClientMessage::Ping => Ok(vec![self.ping()?]),
44 }
45 }
46
47 async fn register(
49 &mut self,
50 channel_id_str: String,
51 key: Option<String>,
52 ) -> Result<ServerMessage, SMError> {
53 trace!("WebPushClient:register";
54 "uaid" => &self.uaid.to_string(),
55 "channel_id" => &channel_id_str,
56 "key" => &key,
57 "message_type" => MessageType::Register.as_ref(),
58 );
59 let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
60 SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
61 })?;
62 if channel_id.as_hyphenated().to_string() != channel_id_str {
63 return Err(SMError::invalid_message(format!(
64 "Invalid UUID format, not lower-case/dashed: {channel_id}",
65 )));
66 }
67
68 let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
69 Ok(endpoint) => {
70 let _ = self
71 .app_state
72 .metrics
73 .incr(MetricName::UaCommand(MessageType::Register.to_string()));
74 self.stats.registers += 1;
75 (200, endpoint)
76 }
77 Err(SMErrorKind::MakeEndpoint(msg)) => {
78 error!("WebPushClient::register make_endpoint failed: {}", msg);
79 (400, "Failed to generate endpoint".to_owned())
80 }
81 Err(e) => {
82 error!("WebPushClient::register failed: {}", e);
83 (500, "".to_owned())
84 }
85 };
86 Ok(ServerMessage::Register {
87 channel_id,
88 status,
89 push_endpoint,
90 })
91 }
92
93 async fn do_register(
94 &mut self,
95 channel_id: &Uuid,
96 key: Option<String>,
97 ) -> Result<String, SMErrorKind> {
98 if let Some(user) = &self.deferred_add_user {
99 debug!(
100 "💬WebPushClient::register: User not yet registered: {}",
101 &user.uaid
102 );
103 self.app_state.db.add_user(user).await?;
104 self.deferred_add_user = None;
105 }
106
107 let endpoint = make_endpoint(
108 &self.uaid,
109 channel_id,
110 key.as_deref(),
111 &self.app_state.endpoint_url,
112 &self.app_state.fernet,
113 )
114 .map_err(SMErrorKind::MakeEndpoint)?;
115 self.app_state
116 .db
117 .add_channel(&self.uaid, channel_id)
118 .await?;
119 Ok(endpoint)
120 }
121
122 async fn unregister(
124 &mut self,
125 channel_id: Uuid,
126 code: Option<u32>,
127 ) -> Result<ServerMessage, SMError> {
128 trace!("WebPushClient:unregister";
129 "uaid" => &self.uaid.to_string(),
130 "channel_id" => &channel_id.to_string(),
131 "code" => &code,
132 "message_type" => MessageType::Unregister.as_ref(),
133 );
134 let result = self
138 .app_state
139 .db
140 .remove_channel(&self.uaid, &channel_id)
141 .await;
142 let status = match result {
143 Ok(_) => {
144 self.app_state
145 .metrics
146 .incr_with_tags(MetricName::UaCommand(MessageType::Unregister.to_string()))
147 .with_tag("code", &code.unwrap_or(200).to_string())
148 .send();
149 self.stats.unregisters += 1;
150 200
151 }
152 Err(e) => {
153 error!("WebPushClient::unregister failed: {}", e);
154 500
155 }
156 };
157 Ok(ServerMessage::Unregister { channel_id, status })
158 }
159
160 async fn broadcast_subscribe(
162 &mut self,
163 broadcasts: HashMap<String, String>,
164 ) -> Result<Option<ServerMessage>, SMError> {
165 trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
166 let broadcasts = Broadcast::from_hashmap(broadcasts);
167 let mut response: HashMap<String, BroadcastValue> = HashMap::new();
168
169 let bc = self.app_state.broadcaster.read().await;
170 if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
171 response.extend(Broadcast::vec_into_hashmap(delta));
172 };
173 let missing = bc.missing_broadcasts(&broadcasts);
174 if !missing.is_empty() {
175 response.insert(
176 "errors".to_owned(),
177 BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
178 );
179 }
180
181 Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
182 broadcasts: response,
183 }))
184 }
185
186 async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
188 trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
189 let _ = self
190 .app_state
191 .metrics
192 .incr(MetricName::UaCommand(MessageType::Ack.to_string()));
193
194 for notif in updates {
195 let pos = self
199 .ack_state
200 .unacked_direct_notifs
201 .iter()
202 .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
203 if let Some(pos) = pos {
205 debug!("✅ Ack (Direct)";
206 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
207 "version" => ¬if.version
208 );
209 self.ack_state.unacked_direct_notifs.remove(pos);
210 self.stats.direct_acked += 1;
211 continue;
212 };
213
214 let pos = self
216 .ack_state
217 .unacked_stored_notifs
218 .iter()
219 .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
220 if let Some(pos) = pos {
221 debug!(
222 "✅ Ack (Stored)";
223 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
224 "version" => ¬if.version,
225 "message_type" => MessageType::Ack.as_ref()
226 );
227 let acked_notification = &mut self.ack_state.unacked_stored_notifs[pos];
229 let is_topic = acked_notification.topic.is_some();
230 debug!("✅ Ack notif: {:?}", &acked_notification);
231 if is_topic {
237 debug!(
238 "✅ WebPushClient:ack removing Stored, sort_key: {}",
239 &acked_notification.chidmessageid()
240 );
241 self.app_state
242 .db
243 .remove_message(&self.uaid, &acked_notification.chidmessageid())
244 .await?;
245 #[cfg(feature = "reliable_report")]
249 acked_notification
250 .record_reliability(&self.app_state.reliability, notif.reliability_state())
251 .await;
252 }
253 let n = self.ack_state.unacked_stored_notifs.remove(pos);
254 #[cfg(feature = "reliable_report")]
255 if !is_topic {
256 self.ack_state.acked_stored_timestamp_notifs.push(n);
257 }
258 self.stats.stored_acked += 1;
259 continue;
260 };
261 }
262
263 if self.ack_state.unacked_notifs() {
264 Ok(vec![])
267 } else {
268 self.post_process_all_acked().await
269 }
270 }
271
272 fn nack(&mut self, code: Option<i32>) {
275 trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
276 let code = code
278 .and_then(|code| (301..=303).contains(&code).then_some(code))
279 .unwrap_or(0);
280 self.app_state
281 .metrics
282 .incr_with_tags(MetricName::UaCommand(MessageType::Nack.to_string()))
283 .with_tag("code", &code.to_string())
284 .send();
285 self.stats.nacks += 1;
286 }
287
288 fn ping(&mut self) -> Result<ServerMessage, SMError> {
293 trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
294 if sec_since_epoch() - self.last_ping >= 45 {
298 trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
299 self.last_ping = sec_since_epoch();
300 Ok(ServerMessage::Ping)
301 } else {
302 Err(SMErrorKind::ExcessivePing.into())
303 }
304 }
305
306 async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
314 trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
315 let flags = &self.flags;
316 if flags.check_storage {
317 if flags.increment_storage {
318 debug!(
319 "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
320 "message_type" => MessageType::Notification.as_ref()
321 );
322 self.increment_storage().await?;
323 }
324
325 debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
326 let smsgs = self.check_storage_loop().await?;
327 if !smsgs.is_empty() {
328 debug_assert!(self.flags.check_storage);
329 return Ok(smsgs);
333 }
334 debug_assert!(!self.flags.check_storage);
336 debug_assert!(!self.flags.increment_storage);
337 }
338
339 debug_assert!(!self.ack_state.unacked_notifs());
341 let flags = &self.flags;
342 if flags.old_record_version {
343 debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
344 self.app_state
345 .metrics
346 .incr_with_tags(MetricName::UaExpiration)
347 .with_tag("reason", "old_record_version")
348 .send();
349 self.app_state.db.remove_user(&self.uaid).await?;
350 Err(SMErrorKind::UaidReset.into())
351 } else {
352 Ok(vec![])
353 }
354 }
355}