// autoconnect_ws_sm/identified/on_client_msg.rs

1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6    broadcast::Broadcast,
7    protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10    endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11    util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18    /// Handle a WebPush `ClientMessage` sent from the user agent over the
19    /// WebSocket for this user
20    pub async fn on_client_msg(
21        &mut self,
22        msg: ClientMessage,
23    ) -> Result<Vec<ServerMessage>, SMError> {
24        match msg {
25            ClientMessage::Hello { .. } => {
26                Err(SMError::invalid_message("Already Hello'd".to_owned()))
27            }
28            ClientMessage::Register { channel_id, key } => {
29                Ok(vec![self.register(channel_id, key).await?])
30            }
31            ClientMessage::Unregister { channel_id, code } => {
32                Ok(vec![self.unregister(channel_id, code).await?])
33            }
34            ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35                .broadcast_subscribe(broadcasts)
36                .await?
37                .map_or_else(Vec::new, |smsg| vec![smsg])),
38            ClientMessage::Ack { updates } => self.ack(&updates).await,
39            ClientMessage::Nack { code, .. } => {
40                self.nack(code);
41                Ok(vec![])
42            }
43            ClientMessage::Ping => Ok(vec![self.ping()?]),
44        }
45    }
46
47    /// Register a new Push subscription
48    async fn register(
49        &mut self,
50        channel_id_str: String,
51        key: Option<String>,
52    ) -> Result<ServerMessage, SMError> {
53        let uaid_str = self.uaid.to_string();
54        trace!("WebPushClient:register";
55               "uaid" => &uaid_str,
56               "channel_id" => &channel_id_str,
57               "key" => &key,
58               "message_type" => MessageType::Register.as_ref(),
59        );
60        let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
61            SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
62        })?;
63        if channel_id.as_hyphenated().to_string() != channel_id_str {
64            return Err(SMError::invalid_message(format!(
65                "Invalid UUID format, not lower-case/dashed: {channel_id}",
66            )));
67        }
68
69        let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
70            Ok(endpoint) => {
71                let _ = self
72                    .app_state
73                    .metrics
74                    .incr(MetricName::UaCommand(MessageType::Register.to_string()));
75                self.stats.registers += 1;
76                (200, endpoint)
77            }
78            Err(SMErrorKind::MakeEndpoint(msg)) => {
79                error!("WebPushClient::register make_endpoint failed: {}", msg);
80                (400, "Failed to generate endpoint".to_owned())
81            }
82            Err(e) => {
83                error!("WebPushClient::register failed: {}", e);
84                (500, "".to_owned())
85            }
86        };
87        Ok(ServerMessage::Register {
88            channel_id,
89            status,
90            push_endpoint,
91        })
92    }
93
94    async fn do_register(
95        &mut self,
96        channel_id: &Uuid,
97        key: Option<String>,
98    ) -> Result<String, SMErrorKind> {
99        if let Some(user) = &self.deferred_add_user {
100            debug!(
101                "💬WebPushClient::register: User not yet registered: {}",
102                &user.uaid
103            );
104            self.app_state.db.add_user(user).await?;
105            self.deferred_add_user = None;
106        }
107
108        let endpoint = make_endpoint(
109            &self.uaid,
110            channel_id,
111            key.as_deref(),
112            &self.app_state.endpoint_url,
113            &self.app_state.fernet,
114        )
115        .map_err(SMErrorKind::MakeEndpoint)?;
116        self.app_state
117            .db
118            .add_channel(&self.uaid, channel_id)
119            .await?;
120        Ok(endpoint)
121    }
122
123    /// Unregister an existing Push subscription
124    async fn unregister(
125        &mut self,
126        channel_id: Uuid,
127        code: Option<u32>,
128    ) -> Result<ServerMessage, SMError> {
129        let uaid_str = self.uaid.to_string();
130        let chid_str = channel_id.to_string();
131        trace!("WebPushClient:unregister";
132               "uaid" => &uaid_str,
133               "channel_id" => &chid_str,
134               "code" => &code,
135               "message_type" => MessageType::Unregister.as_ref(),
136        );
137        // TODO: (copied from previous state machine) unregister should check
138        // the format of channel_id like register does
139
140        let result = self
141            .app_state
142            .db
143            .remove_channel(&self.uaid, &channel_id)
144            .await;
145        let status = match result {
146            Ok(_) => {
147                self.app_state
148                    .metrics
149                    .incr_with_tags(MetricName::UaCommand(MessageType::Unregister.to_string()))
150                    .with_tag("code", &code.unwrap_or(200).to_string())
151                    .send();
152                self.stats.unregisters += 1;
153                200
154            }
155            Err(e) => {
156                error!("WebPushClient::unregister failed: {}", e);
157                500
158            }
159        };
160        Ok(ServerMessage::Unregister { channel_id, status })
161    }
162
163    /// Subscribe to a new set of Broadcasts
164    async fn broadcast_subscribe(
165        &mut self,
166        broadcasts: HashMap<String, String>,
167    ) -> Result<Option<ServerMessage>, SMError> {
168        trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
169        let broadcasts = Broadcast::from_hashmap(broadcasts);
170        let mut response: HashMap<String, BroadcastValue> = HashMap::new();
171
172        let bc = self.app_state.broadcaster.read().await;
173        if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
174            response.extend(Broadcast::vec_into_hashmap(delta));
175        };
176        let missing = bc.missing_broadcasts(&broadcasts);
177        if !missing.is_empty() {
178            response.insert(
179                "errors".to_owned(),
180                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
181            );
182        }
183
184        Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
185            broadcasts: response,
186        }))
187    }
188
189    /// Acknowledge receipt of one or more Push Notifications
190    async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
191        trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
192        let _ = self
193            .app_state
194            .metrics
195            .incr(MetricName::UaCommand(MessageType::Ack.to_string()));
196
197        for notif in updates {
198            // Check the list of unacked "direct" (unstored) notifications. We only want to
199            // ack messages we've not yet seen and we have the right version, otherwise we could
200            // have gotten an older, inaccurate ACK.
201            let pos = self
202                .ack_state
203                .unacked_direct_notifs
204                .iter()
205                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
206            // We found one, so delete it from our list of unacked messages
207            if let Some(pos) = pos {
208                debug!("✅ Ack (Direct)";
209                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
210                       "version" => &notif.version
211                );
212                self.ack_state.unacked_direct_notifs.remove(pos);
213                self.stats.direct_acked += 1;
214                continue;
215            };
216
217            // Now, check the list of stored notifications
218            let pos = self
219                .ack_state
220                .unacked_stored_notifs
221                .iter()
222                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
223            if let Some(pos) = pos {
224                debug!(
225                    "✅ Ack (Stored)";
226                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
227                       "version" => &notif.version,
228                       "message_type" => MessageType::Ack.as_ref()
229                );
230                // Get the stored notification record.
231                let acked_notification = &mut self.ack_state.unacked_stored_notifs[pos];
232                let is_topic = acked_notification.topic.is_some();
233                debug!("✅ Ack notif: {:?}", &acked_notification);
234                // Only force delete Topic messages, since they don't have a timestamp.
235                // Other messages persist in the database, to be, eventually, cleaned up by their
236                // TTL. We will need to update the `CurrentTimestamp` field for the channel
237                // record. Use that field to set the baseline timestamp for when to pull messages
238                // in the future.
239                if is_topic {
240                    let chid = &acked_notification.chidmessageid();
241                    debug!("✅ WebPushClient:ack removing Stored, sort_key: {}", &chid);
242                    self.app_state.db.remove_message(&self.uaid, chid).await?;
243                    // NOTE: timestamp messages may still be in state of flux: they're not fully
244                    // ack'd (removed/unable to be resurrected) until increment_storage is called,
245                    // so their reliability is recorded there
246                    #[cfg(feature = "reliable_report")]
247                    acked_notification
248                        .record_reliability(&self.app_state.reliability, notif.reliability_state())
249                        .await;
250                }
251                let n = self.ack_state.unacked_stored_notifs.remove(pos);
252                #[cfg(feature = "reliable_report")]
253                if !is_topic {
254                    self.ack_state.acked_stored_timestamp_notifs.push(n);
255                }
256                self.stats.stored_acked += 1;
257                continue;
258            };
259        }
260
261        if self.ack_state.unacked_notifs() {
262            // Wait for the Client to Ack all notifications before further
263            // processing
264            Ok(vec![])
265        } else {
266            self.post_process_all_acked().await
267        }
268    }
269
270    /// Negative Acknowledgement (a Client error occurred) of one or more Push
271    /// Notifications
272    fn nack(&mut self, code: Option<i32>) {
273        trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
274        // only metric codes expected from the client (or 0)
275        let code = code
276            .and_then(|code| (301..=303).contains(&code).then_some(code))
277            .unwrap_or(0);
278        self.app_state
279            .metrics
280            .incr_with_tags(MetricName::UaCommand(MessageType::Nack.to_string()))
281            .with_tag("code", &code.to_string())
282            .send();
283        self.stats.nacks += 1;
284    }
285
286    /// Handle a WebPush Ping
287    ///
288    /// Note this is the WebPush Protocol level's Ping: this differs from the
289    /// lower level WebSocket Ping frame (handled by the `webpush_ws` handler).
290    fn ping(&mut self) -> Result<ServerMessage, SMError> {
291        trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
292        // TODO: why is this 45 vs the comment describing a minute? and 45
293        // should be a setting
294        // Clients shouldn't ping > than once per minute or we disconnect them
295        if sec_since_epoch() - self.last_ping >= 45 {
296            trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
297            self.last_ping = sec_since_epoch();
298            Ok(ServerMessage::Ping)
299        } else {
300            Err(SMErrorKind::ExcessivePing.into())
301        }
302    }
303
304    /// Post process the Client succesfully Ack'ing all Push Notifications it's
305    /// been sent.
306    ///
307    /// Notifications are read in small batches (approximately 10). We wait for
308    /// the Client to Ack every Notification in that batch (invoking this
309    /// method) before proceeding to read the next batch (or potential other
310    /// actions such as `reset_uaid`).
311    async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
312        trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
313        let flags = &self.flags;
314        if flags.check_storage {
315            if flags.increment_storage {
316                debug!(
317                    "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
318                    "message_type" => MessageType::Notification.as_ref()
319                );
320                self.increment_storage().await?;
321            }
322
323            debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
324            let smsgs = self.check_storage_loop().await?;
325            if !smsgs.is_empty() {
326                debug_assert!(self.flags.check_storage);
327                // More outgoing notifications: send them out and go back to
328                // waiting for the Client to Ack them all before further
329                // processing
330                return Ok(smsgs);
331            }
332            // Otherwise check_storage is finished
333            debug_assert!(!self.flags.check_storage);
334            debug_assert!(!self.flags.increment_storage);
335        }
336
337        // All Ack'd and finished checking/incrementing storage
338        debug_assert!(!self.ack_state.unacked_notifs());
339        let flags = &self.flags;
340        if flags.old_record_version {
341            debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
342            self.app_state
343                .metrics
344                .incr_with_tags(MetricName::UaExpiration)
345                .with_tag("reason", "old_record_version")
346                .send();
347            self.app_state.db.remove_user(&self.uaid).await?;
348            Err(SMErrorKind::UaidReset.into())
349        } else {
350            Ok(vec![])
351        }
352    }
353}