autoconnect_ws_sm/identified/
on_client_msg.rs

1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6    broadcast::Broadcast,
7    protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10    endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11    util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18    /// Handle a WebPush `ClientMessage` sent from the user agent over the
19    /// WebSocket for this user
20    pub async fn on_client_msg(
21        &mut self,
22        msg: ClientMessage,
23    ) -> Result<Vec<ServerMessage>, SMError> {
24        match msg {
25            ClientMessage::Hello { .. } => {
26                Err(SMError::invalid_message("Already Hello'd".to_owned()))
27            }
28            ClientMessage::Register { channel_id, key } => {
29                Ok(vec![self.register(channel_id, key).await?])
30            }
31            ClientMessage::Unregister { channel_id, code } => {
32                Ok(vec![self.unregister(channel_id, code).await?])
33            }
34            ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35                .broadcast_subscribe(broadcasts)
36                .await?
37                .map_or_else(Vec::new, |smsg| vec![smsg])),
38            ClientMessage::Ack { updates } => self.ack(&updates).await,
39            ClientMessage::Nack { code, .. } => {
40                self.nack(code);
41                Ok(vec![])
42            }
43            ClientMessage::Ping => Ok(vec![self.ping()?]),
44        }
45    }
46
47    /// Register a new Push subscription
48    async fn register(
49        &mut self,
50        channel_id_str: String,
51        key: Option<String>,
52    ) -> Result<ServerMessage, SMError> {
53        trace!("WebPushClient:register";
54               "uaid" => &self.uaid.to_string(),
55               "channel_id" => &channel_id_str,
56               "key" => &key,
57               "message_type" => MessageType::Register.as_ref(),
58        );
59        let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
60            SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
61        })?;
62        if channel_id.as_hyphenated().to_string() != channel_id_str {
63            return Err(SMError::invalid_message(format!(
64                "Invalid UUID format, not lower-case/dashed: {channel_id}",
65            )));
66        }
67
68        let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
69            Ok(endpoint) => {
70                let _ = self
71                    .app_state
72                    .metrics
73                    .incr(MetricName::UaCommand(MessageType::Register.to_string()));
74                self.stats.registers += 1;
75                (200, endpoint)
76            }
77            Err(SMErrorKind::MakeEndpoint(msg)) => {
78                error!("WebPushClient::register make_endpoint failed: {}", msg);
79                (400, "Failed to generate endpoint".to_owned())
80            }
81            Err(e) => {
82                error!("WebPushClient::register failed: {}", e);
83                (500, "".to_owned())
84            }
85        };
86        Ok(ServerMessage::Register {
87            channel_id,
88            status,
89            push_endpoint,
90        })
91    }
92
93    async fn do_register(
94        &mut self,
95        channel_id: &Uuid,
96        key: Option<String>,
97    ) -> Result<String, SMErrorKind> {
98        if let Some(user) = &self.deferred_add_user {
99            debug!(
100                "💬WebPushClient::register: User not yet registered: {}",
101                &user.uaid
102            );
103            self.app_state.db.add_user(user).await?;
104            self.deferred_add_user = None;
105        }
106
107        let endpoint = make_endpoint(
108            &self.uaid,
109            channel_id,
110            key.as_deref(),
111            &self.app_state.endpoint_url,
112            &self.app_state.fernet,
113        )
114        .map_err(SMErrorKind::MakeEndpoint)?;
115        self.app_state
116            .db
117            .add_channel(&self.uaid, channel_id)
118            .await?;
119        Ok(endpoint)
120    }
121
122    /// Unregister an existing Push subscription
123    async fn unregister(
124        &mut self,
125        channel_id: Uuid,
126        code: Option<u32>,
127    ) -> Result<ServerMessage, SMError> {
128        trace!("WebPushClient:unregister";
129               "uaid" => &self.uaid.to_string(),
130               "channel_id" => &channel_id.to_string(),
131               "code" => &code,
132               "message_type" => MessageType::Unregister.as_ref(),
133        );
134        // TODO: (copied from previous state machine) unregister should check
135        // the format of channel_id like register does
136
137        let result = self
138            .app_state
139            .db
140            .remove_channel(&self.uaid, &channel_id)
141            .await;
142        let status = match result {
143            Ok(_) => {
144                self.app_state
145                    .metrics
146                    .incr_with_tags(MetricName::UaCommand(MessageType::Unregister.to_string()))
147                    .with_tag("code", &code.unwrap_or(200).to_string())
148                    .send();
149                self.stats.unregisters += 1;
150                200
151            }
152            Err(e) => {
153                error!("WebPushClient::unregister failed: {}", e);
154                500
155            }
156        };
157        Ok(ServerMessage::Unregister { channel_id, status })
158    }
159
160    /// Subscribe to a new set of Broadcasts
161    async fn broadcast_subscribe(
162        &mut self,
163        broadcasts: HashMap<String, String>,
164    ) -> Result<Option<ServerMessage>, SMError> {
165        trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
166        let broadcasts = Broadcast::from_hashmap(broadcasts);
167        let mut response: HashMap<String, BroadcastValue> = HashMap::new();
168
169        let bc = self.app_state.broadcaster.read().await;
170        if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
171            response.extend(Broadcast::vec_into_hashmap(delta));
172        };
173        let missing = bc.missing_broadcasts(&broadcasts);
174        if !missing.is_empty() {
175            response.insert(
176                "errors".to_owned(),
177                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
178            );
179        }
180
181        Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
182            broadcasts: response,
183        }))
184    }
185
186    /// Acknowledge receipt of one or more Push Notifications
187    async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
188        trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
189        let _ = self
190            .app_state
191            .metrics
192            .incr(MetricName::UaCommand(MessageType::Ack.to_string()));
193
194        for notif in updates {
195            // Check the list of unacked "direct" (unstored) notifications. We only want to
196            // ack messages we've not yet seen and we have the right version, otherwise we could
197            // have gotten an older, inaccurate ACK.
198            let pos = self
199                .ack_state
200                .unacked_direct_notifs
201                .iter()
202                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
203            // We found one, so delete it from our list of unacked messages
204            if let Some(pos) = pos {
205                debug!("✅ Ack (Direct)";
206                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
207                       "version" => &notif.version
208                );
209                self.ack_state.unacked_direct_notifs.remove(pos);
210                self.stats.direct_acked += 1;
211                continue;
212            };
213
214            // Now, check the list of stored notifications
215            let pos = self
216                .ack_state
217                .unacked_stored_notifs
218                .iter()
219                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
220            if let Some(pos) = pos {
221                debug!(
222                    "✅ Ack (Stored)";
223                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
224                       "version" => &notif.version,
225                       "message_type" => MessageType::Ack.as_ref()
226                );
227                // Get the stored notification record.
228                let acked_notification = &mut self.ack_state.unacked_stored_notifs[pos];
229                let is_topic = acked_notification.topic.is_some();
230                debug!("✅ Ack notif: {:?}", &acked_notification);
231                // Only force delete Topic messages, since they don't have a timestamp.
232                // Other messages persist in the database, to be, eventually, cleaned up by their
233                // TTL. We will need to update the `CurrentTimestamp` field for the channel
234                // record. Use that field to set the baseline timestamp for when to pull messages
235                // in the future.
236                if is_topic {
237                    debug!(
238                        "✅ WebPushClient:ack removing Stored, sort_key: {}",
239                        &acked_notification.chidmessageid()
240                    );
241                    self.app_state
242                        .db
243                        .remove_message(&self.uaid, &acked_notification.chidmessageid())
244                        .await?;
245                    // NOTE: timestamp messages may still be in state of flux: they're not fully
246                    // ack'd (removed/unable to be resurrected) until increment_storage is called,
247                    // so their reliability is recorded there
248                    #[cfg(feature = "reliable_report")]
249                    acked_notification
250                        .record_reliability(&self.app_state.reliability, notif.reliability_state())
251                        .await;
252                }
253                let n = self.ack_state.unacked_stored_notifs.remove(pos);
254                #[cfg(feature = "reliable_report")]
255                if !is_topic {
256                    self.ack_state.acked_stored_timestamp_notifs.push(n);
257                }
258                self.stats.stored_acked += 1;
259                continue;
260            };
261        }
262
263        if self.ack_state.unacked_notifs() {
264            // Wait for the Client to Ack all notifications before further
265            // processing
266            Ok(vec![])
267        } else {
268            self.post_process_all_acked().await
269        }
270    }
271
272    /// Negative Acknowledgement (a Client error occurred) of one or more Push
273    /// Notifications
274    fn nack(&mut self, code: Option<i32>) {
275        trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
276        // only metric codes expected from the client (or 0)
277        let code = code
278            .and_then(|code| (301..=303).contains(&code).then_some(code))
279            .unwrap_or(0);
280        self.app_state
281            .metrics
282            .incr_with_tags(MetricName::UaCommand(MessageType::Nack.to_string()))
283            .with_tag("code", &code.to_string())
284            .send();
285        self.stats.nacks += 1;
286    }
287
288    /// Handle a WebPush Ping
289    ///
290    /// Note this is the WebPush Protocol level's Ping: this differs from the
291    /// lower level WebSocket Ping frame (handled by the `webpush_ws` handler).
292    fn ping(&mut self) -> Result<ServerMessage, SMError> {
293        trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
294        // TODO: why is this 45 vs the comment describing a minute? and 45
295        // should be a setting
296        // Clients shouldn't ping > than once per minute or we disconnect them
297        if sec_since_epoch() - self.last_ping >= 45 {
298            trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
299            self.last_ping = sec_since_epoch();
300            Ok(ServerMessage::Ping)
301        } else {
302            Err(SMErrorKind::ExcessivePing.into())
303        }
304    }
305
306    /// Post process the Client succesfully Ack'ing all Push Notifications it's
307    /// been sent.
308    ///
309    /// Notifications are read in small batches (approximately 10). We wait for
310    /// the Client to Ack every Notification in that batch (invoking this
311    /// method) before proceeding to read the next batch (or potential other
312    /// actions such as `reset_uaid`).
313    async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
314        trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
315        let flags = &self.flags;
316        if flags.check_storage {
317            if flags.increment_storage {
318                debug!(
319                    "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
320                    "message_type" => MessageType::Notification.as_ref()
321                );
322                self.increment_storage().await?;
323            }
324
325            debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
326            let smsgs = self.check_storage_loop().await?;
327            if !smsgs.is_empty() {
328                debug_assert!(self.flags.check_storage);
329                // More outgoing notifications: send them out and go back to
330                // waiting for the Client to Ack them all before further
331                // processing
332                return Ok(smsgs);
333            }
334            // Otherwise check_storage is finished
335            debug_assert!(!self.flags.check_storage);
336            debug_assert!(!self.flags.increment_storage);
337        }
338
339        // All Ack'd and finished checking/incrementing storage
340        debug_assert!(!self.ack_state.unacked_notifs());
341        let flags = &self.flags;
342        if flags.old_record_version {
343            debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
344            self.app_state
345                .metrics
346                .incr_with_tags(MetricName::UaExpiration)
347                .with_tag("reason", "old_record_version")
348                .send();
349            self.app_state.db.remove_user(&self.uaid).await?;
350            Err(SMErrorKind::UaidReset.into())
351        } else {
352            Ok(vec![])
353        }
354    }
355}