autoconnect_ws_sm/identified/
on_client_msg.rs

1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6    broadcast::Broadcast,
7    protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10    endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11    util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18    /// Handle a WebPush `ClientMessage` sent from the user agent over the
19    /// WebSocket for this user
20    pub async fn on_client_msg(
21        &mut self,
22        msg: ClientMessage,
23    ) -> Result<Vec<ServerMessage>, SMError> {
24        match msg {
25            ClientMessage::Hello { .. } => {
26                Err(SMError::invalid_message("Already Hello'd".to_owned()))
27            }
28            ClientMessage::Register { channel_id, key } => {
29                Ok(vec![self.register(channel_id, key).await?])
30            }
31            ClientMessage::Unregister { channel_id, code } => {
32                Ok(vec![self.unregister(channel_id, code).await?])
33            }
34            ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35                .broadcast_subscribe(broadcasts)
36                .await?
37                .map_or_else(Vec::new, |smsg| vec![smsg])),
38            ClientMessage::Ack { updates } => self.ack(&updates).await,
39            ClientMessage::Nack { code, .. } => {
40                self.nack(code);
41                Ok(vec![])
42            }
43            ClientMessage::Ping => Ok(vec![self.ping()?]),
44        }
45    }
46
47    /// Register a new Push subscription
48    async fn register(
49        &mut self,
50        channel_id_str: String,
51        key: Option<String>,
52    ) -> Result<ServerMessage, SMError> {
53        let uaid_str = self.uaid.to_string();
54        trace!("WebPushClient:register";
55               "uaid" => &uaid_str,
56               "channel_id" => &channel_id_str,
57               "key" => &key,
58               "message_type" => MessageType::Register.as_ref(),
59        );
60        let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
61            SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
62        })?;
63        if channel_id.as_hyphenated().to_string() != channel_id_str {
64            return Err(SMError::invalid_message(format!(
65                "Invalid UUID format, not lower-case/dashed: {channel_id}",
66            )));
67        }
68
69        let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
70            Ok(endpoint) => {
71                let _ = self.app_state.metrics.incr(MetricName::UaCommandRegister);
72                self.stats.registers += 1;
73                (200, endpoint)
74            }
75            Err(SMErrorKind::MakeEndpoint(msg)) => {
76                error!("WebPushClient::register make_endpoint failed: {}", msg);
77                (400, "Failed to generate endpoint".to_owned())
78            }
79            Err(e) => {
80                error!("WebPushClient::register failed: {}", e);
81                (500, "".to_owned())
82            }
83        };
84        Ok(ServerMessage::Register {
85            channel_id,
86            status,
87            push_endpoint,
88        })
89    }
90
91    async fn do_register(
92        &mut self,
93        channel_id: &Uuid,
94        key: Option<String>,
95    ) -> Result<String, SMErrorKind> {
96        if let Some(user) = &self.deferred_add_user {
97            debug!(
98                "💬WebPushClient::register: User not yet registered: {}",
99                &user.uaid
100            );
101            self.app_state.db.add_user(user).await?;
102            self.deferred_add_user = None;
103        }
104
105        let endpoint = make_endpoint(
106            &self.uaid,
107            channel_id,
108            key.as_deref(),
109            &self.app_state.endpoint_url,
110            &self.app_state.fernet,
111        )
112        .map_err(SMErrorKind::MakeEndpoint)?;
113        self.app_state
114            .db
115            .add_channel(&self.uaid, channel_id)
116            .await?;
117        Ok(endpoint)
118    }
119
120    /// Unregister an existing Push subscription
121    async fn unregister(
122        &mut self,
123        channel_id: Uuid,
124        code: Option<u32>,
125    ) -> Result<ServerMessage, SMError> {
126        let uaid_str = self.uaid.to_string();
127        let chid_str = channel_id.to_string();
128        trace!("WebPushClient:unregister";
129               "uaid" => &uaid_str,
130               "channel_id" => &chid_str,
131               "code" => &code,
132               "message_type" => MessageType::Unregister.as_ref(),
133        );
134        // TODO: (copied from previous state machine) unregister should check
135        // the format of channel_id like register does
136
137        let result = self
138            .app_state
139            .db
140            .remove_channel(&self.uaid, &channel_id)
141            .await;
142        let status = match result {
143            Ok(_) => {
144                self.app_state
145                    .metrics
146                    .incr_with_tags(MetricName::UaCommandUnregister)
147                    .with_tag("code", &code.unwrap_or(200).to_string())
148                    .send();
149                self.stats.unregisters += 1;
150                200
151            }
152            Err(e) => {
153                error!("WebPushClient::unregister failed: {}", e);
154                500
155            }
156        };
157        Ok(ServerMessage::Unregister { channel_id, status })
158    }
159
160    /// Subscribe to a new set of Broadcasts
161    async fn broadcast_subscribe(
162        &mut self,
163        broadcasts: HashMap<String, String>,
164    ) -> Result<Option<ServerMessage>, SMError> {
165        trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
166        let broadcasts = Broadcast::from_hashmap(broadcasts);
167        let mut response: HashMap<String, BroadcastValue> = HashMap::new();
168
169        let bc = self.app_state.broadcaster.read().await;
170        if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
171            response.extend(Broadcast::vec_into_hashmap(delta));
172        };
173        let missing = bc.missing_broadcasts(&broadcasts);
174        if !missing.is_empty() {
175            response.insert(
176                "errors".to_owned(),
177                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
178            );
179        }
180
181        Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
182            broadcasts: response,
183        }))
184    }
185
186    /// Acknowledge receipt of one or more Push Notifications
187    async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
188        trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
189        let _ = self.app_state.metrics.incr(MetricName::UaCommandAck);
190
191        for notif in updates {
192            let key = notif.version.clone();
193            // Check the map of unacked "direct" (unstored) notifications.
194            if self.ack_state.unacked_direct_notifs.remove(&key).is_some() {
195                debug!("✅ Ack (Direct)";
196                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
197                       "version" => &notif.version
198                );
199                self.stats.direct_acked += 1;
200                continue;
201            };
202
203            // Now, check the map of stored notifications
204            #[allow(unused_mut)]
205            if let Some(mut acked_notification) = self.ack_state.unacked_stored_notifs.remove(&key)
206            {
207                debug!(
208                    "✅ Ack (Stored)";
209                       "channel_id" => notif.channel_id.as_hyphenated().to_string(),
210                       "version" => &notif.version,
211                       "message_type" => MessageType::Ack.as_ref()
212                );
213                // Some storage engines may set this to "".
214                let is_topic = acked_notification
215                    .topic
216                    .as_ref()
217                    .map(|t| !t.is_empty())
218                    .unwrap_or(false);
219                debug!("✅ Ack notif: {:?}", &acked_notification);
220                // Only force delete Topic messages, since they don't have a timestamp.
221                // Other messages persist in the database, to be, eventually, cleaned up by their
222                // TTL. We will need to update the `CurrentTimestamp` field for the channel
223                // record. Use that field to set the baseline timestamp for when to pull messages
224                // in the future.
225                if is_topic {
226                    let chid_message_id = &acked_notification.chidmessageid();
227                    debug!(
228                        "✅🗑 WebPushClient:ack removing Stored, sort_key: {}, version: {}",
229                        &chid_message_id, &acked_notification.version
230                    );
231                    self.app_state
232                        .db
233                        .remove_message(&self.uaid, chid_message_id)
234                        .await?;
235                    // NOTE: timestamp messages may still be in state of flux: they're not fully
236                    // ack'd (removed/unable to be resurrected) until increment_storage is called,
237                    // so their reliability is recorded there
238                    #[cfg(feature = "reliable_report")]
239                    acked_notification
240                        .record_reliability(&self.app_state.reliability, notif.reliability_state())
241                        .await;
242                }
243                #[cfg(feature = "reliable_report")]
244                if !is_topic {
245                    self.ack_state
246                        .acked_stored_timestamp_notifs
247                        .push(acked_notification);
248                }
249                self.stats.stored_acked += 1;
250                continue;
251            };
252        }
253
254        if self.ack_state.unacked_notifs() {
255            // Wait for the Client to Ack all notifications before further
256            // processing
257            Ok(vec![])
258        } else {
259            self.post_process_all_acked().await
260        }
261    }
262
263    /// Negative Acknowledgement (a Client error occurred) of one or more Push
264    /// Notifications
265    fn nack(&mut self, code: Option<i32>) {
266        trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
267        // only metric codes expected from the client (or 0)
268        let code = code
269            .and_then(|code| (301..=303).contains(&code).then_some(code))
270            .unwrap_or(0);
271        self.app_state
272            .metrics
273            .incr_with_tags(MetricName::UaCommandNack)
274            .with_tag("code", &code.to_string())
275            .send();
276        self.stats.nacks += 1;
277    }
278
279    /// Handle a WebPush Ping
280    ///
281    /// Note this is the WebPush Protocol level's Ping: this differs from the
282    /// lower level WebSocket Ping frame (handled by the `webpush_ws` handler).
283    fn ping(&mut self) -> Result<ServerMessage, SMError> {
284        trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
285        // TODO: why is this 45 vs the comment describing a minute? and 45
286        // should be a setting
287        // Clients shouldn't ping > than once per minute or we disconnect them
288        if sec_since_epoch() - self.last_ping >= 45 {
289            trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
290            self.last_ping = sec_since_epoch();
291            Ok(ServerMessage::Ping)
292        } else {
293            Err(SMErrorKind::ExcessivePing.into())
294        }
295    }
296
297    /// Post process the Client succesfully Ack'ing all Push Notifications it's
298    /// been sent.
299    ///
300    /// Notifications are read in small batches (approximately 10). We wait for
301    /// the Client to Ack every Notification in that batch (invoking this
302    /// method) before proceeding to read the next batch (or potential other
303    /// actions such as `reset_uaid`).
304    async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
305        trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
306        let flags = &self.flags;
307        if flags.check_storage {
308            if flags.increment_storage {
309                debug!(
310                    "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
311                    "message_type" => MessageType::Notification.as_ref()
312                );
313                self.increment_storage().await?;
314            }
315
316            debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
317            let smsgs = self.check_storage_loop().await?;
318            if !smsgs.is_empty() {
319                debug_assert!(self.flags.check_storage);
320                // More outgoing notifications: send them out and go back to
321                // waiting for the Client to Ack them all before further
322                // processing
323                return Ok(smsgs);
324            }
325            // Otherwise check_storage is finished
326            debug_assert!(!self.flags.check_storage);
327            debug_assert!(!self.flags.increment_storage);
328        }
329
330        // All Ack'd and finished checking/incrementing storage
331        debug_assert!(!self.ack_state.unacked_notifs());
332        let flags = &self.flags;
333        if flags.old_record_version {
334            debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
335            self.app_state
336                .metrics
337                .incr_with_tags(MetricName::UaExpiration)
338                .with_tag("reason", "old_record_version")
339                .send();
340            self.app_state.db.remove_user(&self.uaid).await?;
341            Err(SMErrorKind::UaidReset.into())
342        } else {
343            Ok(vec![])
344        }
345    }
346}