// autoconnect_ws_sm/identified/on_client_msg.rs

1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6    broadcast::Broadcast,
7    protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10    endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11    util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18    /// Handle a WebPush `ClientMessage` sent from the user agent over the
19    /// WebSocket for this user
20    pub async fn on_client_msg(
21        &mut self,
22        msg: ClientMessage,
23    ) -> Result<Vec<ServerMessage>, SMError> {
24        match msg {
25            ClientMessage::Hello { .. } => {
26                Err(SMError::invalid_message("Already Hello'd".to_owned()))
27            }
28            ClientMessage::Register { channel_id, key } => {
29                Ok(vec![self.register(channel_id, key).await?])
30            }
31            ClientMessage::Unregister { channel_id, code } => {
32                Ok(vec![self.unregister(channel_id, code).await?])
33            }
34            ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35                .broadcast_subscribe(broadcasts)
36                .await?
37                .map_or_else(Vec::new, |smsg| vec![smsg])),
38            ClientMessage::Ack { updates } => self.ack(&updates).await,
39            ClientMessage::Nack { code, .. } => {
40                self.nack(code);
41                Ok(vec![])
42            }
43            ClientMessage::Ping => Ok(vec![self.ping()?]),
44        }
45    }
46
47    /// Register a new Push subscription
48    async fn register(
49        &mut self,
50        channel_id_str: String,
51        key: Option<String>,
52    ) -> Result<ServerMessage, SMError> {
53        let uaid_str = self.uaid.to_string();
54        trace!("WebPushClient:register";
55               "uaid" => &uaid_str,
56               "channel_id" => &channel_id_str,
57               "key" => &key,
58               "message_type" => MessageType::Register.as_ref(),
59        );
60        let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
61            SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
62        })?;
63        if channel_id.as_hyphenated().to_string() != channel_id_str {
64            return Err(SMError::invalid_message(format!(
65                "Invalid UUID format, not lower-case/dashed: {channel_id}",
66            )));
67        }
68
69        let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
70            Ok(endpoint) => {
71                let _ = self.app_state.metrics.incr(MetricName::UaCommandRegister);
72                self.stats.registers += 1;
73                (200, endpoint)
74            }
75            Err(SMErrorKind::MakeEndpoint(msg)) => {
76                error!("WebPushClient::register make_endpoint failed: {}", msg);
77                (400, "Failed to generate endpoint".to_owned())
78            }
79            Err(e) => {
80                error!("WebPushClient::register failed: {}", e);
81                (500, "".to_owned())
82            }
83        };
84        Ok(ServerMessage::Register {
85            channel_id,
86            status,
87            push_endpoint,
88        })
89    }
90
91    async fn do_register(
92        &mut self,
93        channel_id: &Uuid,
94        key: Option<String>,
95    ) -> Result<String, SMErrorKind> {
96        if let Some(user) = &self.deferred_add_user {
97            debug!(
98                "💬WebPushClient::register: User not yet registered: {}",
99                &user.uaid
100            );
101            self.app_state.db.add_user(user).await?;
102            self.deferred_add_user = None;
103        }
104
105        let endpoint = make_endpoint(
106            &self.uaid,
107            channel_id,
108            key.as_deref(),
109            &self.app_state.endpoint_url,
110            &self.app_state.fernet,
111        )
112        .map_err(SMErrorKind::MakeEndpoint)?;
113        self.app_state
114            .db
115            .add_channel(&self.uaid, channel_id)
116            .await?;
117        Ok(endpoint)
118    }
119
120    /// Unregister an existing Push subscription
121    async fn unregister(
122        &mut self,
123        channel_id: Uuid,
124        code: Option<u32>,
125    ) -> Result<ServerMessage, SMError> {
126        let uaid_str = self.uaid.to_string();
127        let chid_str = channel_id.to_string();
128        trace!("WebPushClient:unregister";
129               "uaid" => &uaid_str,
130               "channel_id" => &chid_str,
131               "code" => &code,
132               "message_type" => MessageType::Unregister.as_ref(),
133        );
134        // TODO: (copied from previous state machine) unregister should check
135        // the format of channel_id like register does
136
137        let result = self
138            .app_state
139            .db
140            .remove_channel(&self.uaid, &channel_id)
141            .await;
142        let status = match result {
143            Ok(_) => {
144                self.app_state
145                    .metrics
146                    .incr_with_tags(MetricName::UaCommandUnregister)
147                    .with_tag("code", &code.unwrap_or(200).to_string())
148                    .send();
149                self.stats.unregisters += 1;
150                200
151            }
152            Err(e) => {
153                error!("WebPushClient::unregister failed: {}", e);
154                500
155            }
156        };
157        Ok(ServerMessage::Unregister { channel_id, status })
158    }
159
160    /// Subscribe to a new set of Broadcasts
161    async fn broadcast_subscribe(
162        &mut self,
163        broadcasts: HashMap<String, String>,
164    ) -> Result<Option<ServerMessage>, SMError> {
165        trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
166        let broadcasts = Broadcast::from_hashmap(broadcasts);
167        let mut response: HashMap<String, BroadcastValue> = HashMap::new();
168
169        let bc = self.app_state.broadcaster.read().await;
170        if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
171            response.extend(Broadcast::vec_into_hashmap(delta));
172        };
173        let missing = bc.missing_broadcasts(&broadcasts);
174        if !missing.is_empty() {
175            response.insert(
176                "errors".to_owned(),
177                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
178            );
179        }
180
181        Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
182            broadcasts: response,
183        }))
184    }
185
186    /// Acknowledge receipt of one or more Push Notifications
187    async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
188        trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
189        let _ = self.app_state.metrics.incr(MetricName::UaCommandAck);
190
191        for notif in updates {
192            // Check the list of unacked "direct" (unstored) notifications. We only want to
193            // ack messages we've not yet seen and we have the right version, otherwise we could
194            // have gotten an older, inaccurate ACK.
195            let pos = self
196                .ack_state
197                .unacked_direct_notifs
198                .iter()
199                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
200            // We found one, so delete it from our list of unacked messages
201            if let Some(pos) = pos {
202                debug!("✅ Ack (Direct)";
203                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
204                       "version" => &notif.version
205                );
206                self.ack_state.unacked_direct_notifs.remove(pos);
207                self.stats.direct_acked += 1;
208                continue;
209            };
210
211            // Now, check the list of stored notifications
212            let pos = self
213                .ack_state
214                .unacked_stored_notifs
215                .iter()
216                .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
217            if let Some(pos) = pos {
218                debug!(
219                    "✅ Ack (Stored)";
220                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
221                       "version" => &notif.version,
222                       "message_type" => MessageType::Ack.as_ref()
223                );
224                // Get the stored notification record.
225                let acked_notification = &mut self.ack_state.unacked_stored_notifs[pos];
226                let is_topic = acked_notification.topic.is_some();
227                debug!("✅ Ack notif: {:?}", &acked_notification);
228                // Only force delete Topic messages, since they don't have a timestamp.
229                // Other messages persist in the database, to be, eventually, cleaned up by their
230                // TTL. We will need to update the `CurrentTimestamp` field for the channel
231                // record. Use that field to set the baseline timestamp for when to pull messages
232                // in the future.
233                if is_topic {
234                    let chid = &acked_notification.chidmessageid();
235                    debug!("✅ WebPushClient:ack removing Stored, sort_key: {}", &chid);
236                    self.app_state.db.remove_message(&self.uaid, chid).await?;
237                    // NOTE: timestamp messages may still be in state of flux: they're not fully
238                    // ack'd (removed/unable to be resurrected) until increment_storage is called,
239                    // so their reliability is recorded there
240                    #[cfg(feature = "reliable_report")]
241                    acked_notification
242                        .record_reliability(&self.app_state.reliability, notif.reliability_state())
243                        .await;
244                }
245                let _n = self.ack_state.unacked_stored_notifs.remove(pos);
246                #[cfg(feature = "reliable_report")]
247                if !is_topic {
248                    self.ack_state.acked_stored_timestamp_notifs.push(_n);
249                }
250                self.stats.stored_acked += 1;
251                continue;
252            };
253        }
254
255        if self.ack_state.unacked_notifs() {
256            // Wait for the Client to Ack all notifications before further
257            // processing
258            Ok(vec![])
259        } else {
260            self.post_process_all_acked().await
261        }
262    }
263
264    /// Negative Acknowledgement (a Client error occurred) of one or more Push
265    /// Notifications
266    fn nack(&mut self, code: Option<i32>) {
267        trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
268        // only metric codes expected from the client (or 0)
269        let code = code
270            .and_then(|code| (301..=303).contains(&code).then_some(code))
271            .unwrap_or(0);
272        self.app_state
273            .metrics
274            .incr_with_tags(MetricName::UaCommandNack)
275            .with_tag("code", &code.to_string())
276            .send();
277        self.stats.nacks += 1;
278    }
279
280    /// Handle a WebPush Ping
281    ///
282    /// Note this is the WebPush Protocol level's Ping: this differs from the
283    /// lower level WebSocket Ping frame (handled by the `webpush_ws` handler).
284    fn ping(&mut self) -> Result<ServerMessage, SMError> {
285        trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
286        // TODO: why is this 45 vs the comment describing a minute? and 45
287        // should be a setting
288        // Clients shouldn't ping > than once per minute or we disconnect them
289        if sec_since_epoch() - self.last_ping >= 45 {
290            trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
291            self.last_ping = sec_since_epoch();
292            Ok(ServerMessage::Ping)
293        } else {
294            Err(SMErrorKind::ExcessivePing.into())
295        }
296    }
297
298    /// Post process the Client succesfully Ack'ing all Push Notifications it's
299    /// been sent.
300    ///
301    /// Notifications are read in small batches (approximately 10). We wait for
302    /// the Client to Ack every Notification in that batch (invoking this
303    /// method) before proceeding to read the next batch (or potential other
304    /// actions such as `reset_uaid`).
305    async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
306        trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
307        let flags = &self.flags;
308        if flags.check_storage {
309            if flags.increment_storage {
310                debug!(
311                    "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
312                    "message_type" => MessageType::Notification.as_ref()
313                );
314                self.increment_storage().await?;
315            }
316
317            debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
318            let smsgs = self.check_storage_loop().await?;
319            if !smsgs.is_empty() {
320                debug_assert!(self.flags.check_storage);
321                // More outgoing notifications: send them out and go back to
322                // waiting for the Client to Ack them all before further
323                // processing
324                return Ok(smsgs);
325            }
326            // Otherwise check_storage is finished
327            debug_assert!(!self.flags.check_storage);
328            debug_assert!(!self.flags.increment_storage);
329        }
330
331        // All Ack'd and finished checking/incrementing storage
332        debug_assert!(!self.ack_state.unacked_notifs());
333        let flags = &self.flags;
334        if flags.old_record_version {
335            debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
336            self.app_state
337                .metrics
338                .incr_with_tags(MetricName::UaExpiration)
339                .with_tag("reason", "old_record_version")
340                .send();
341            self.app_state.db.remove_user(&self.uaid).await?;
342            Err(SMErrorKind::UaidReset.into())
343        } else {
344            Ok(vec![])
345        }
346    }
347}