// autoconnect_ws_sm/identified/mod.rs

1use std::{collections::HashMap, fmt, mem, sync::Arc};
2
3use actix_web::rt;
4use cadence::Timed;
5use futures::channel::mpsc;
6use once_cell::sync::Lazy;
7use tokio::sync::Semaphore;
8use uuid::Uuid;
9
10/// Limit concurrent disconnect-cleanup tasks to prevent resource exhaustion
11/// during disconnect storms (e.g., pod scaling events).
12static DISCONNECT_SEMAPHORE: Lazy<Semaphore> = Lazy::new(|| Semaphore::new(64));
13
14use autoconnect_common::{
15    broadcast::{Broadcast, BroadcastSubs},
16    protocol::{ServerMessage, ServerNotification},
17};
18
19use autoconnect_settings::{AppState, Settings};
20use autopush_common::{
21    db::User,
22    metric_name::MetricName,
23    metrics::StatsdClientExt,
24    notification::Notification,
25    util::{ms_since_epoch, user_agent::UserAgentInfo},
26};
27
28use crate::error::{SMError, SMErrorKind};
29
30mod on_client_msg;
31mod on_server_notif;
32
33/// A WebPush Client that's successfully identified itself to the server via a
34/// Hello message.
35///
36/// The `webpush_ws` handler feeds input from both the WebSocket connection
37/// (`ClientMessage`) and the `ClientRegistry` (`ServerNotification`)
38/// triggered by autoendpoint to this type's `on_client_msg` and
39/// `on_server_notif` methods whose impls reside in their own modules.
40///
41/// Note the `check_storage` method (in the `on_server_notif` module) is
42/// triggered by both a `ServerNotification` and also the `new` constructor
43pub struct WebPushClient {
44    /// Push User Agent identifier. Each Push client recieves a unique UAID
45    pub uaid: Uuid,
46    /// Unique, local (to each autoconnect instance) identifier
47    pub uid: Uuid,
48    /// The User Agent information block derived from the User-Agent header
49    pub ua_info: UserAgentInfo,
50
51    /// Broadcast Subscriptions this Client is subscribed to
52    broadcast_subs: BroadcastSubs,
53
54    /// Set of session specific flags
55    flags: ClientFlags,
56    /// Notification Ack(knowledgement) related state
57    ack_state: AckState,
58    /// Count of messages sent from storage (for enforcing
59    /// `settings.msg_limit`). Resets to 0 when storage is emptied
60    sent_from_storage: u32,
61    /// Exists for new User records: these are not written to the db during
62    /// Hello, instead they're lazily added to the db on their first Register
63    /// message
64    deferred_add_user: Option<User>,
65
66    /// WebPush Session Statistics
67    stats: SessionStatistics,
68
69    /// Timestamp of when the UA connected (used by database lookup, thus u64)
70    connected_at: u64,
71    /// Timestamp of the last WebPush Ping message
72    last_ping: u64,
73    /// The last notification timestamp.
74    // TODO: RENAME THIS TO `last_notification_timestamp`
75    current_timestamp: Option<u64>,
76
77    app_state: Arc<AppState>,
78}
79
80impl fmt::Debug for WebPushClient {
81    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
82        fmt.debug_struct("WebPushClient")
83            .field("uaid", &self.uaid)
84            .field("uid", &self.uid)
85            .field("ua_info", &self.ua_info)
86            .field("broadcast_subs", &self.broadcast_subs)
87            .field("flags", &self.flags)
88            .field("ack_state", &self.ack_state)
89            .field("sent_from_storage", &self.sent_from_storage)
90            .field("deferred_add_user", &self.deferred_add_user)
91            .field("stats", &self.stats)
92            .field("connected_at", &self.connected_at)
93            .field("last_ping", &self.last_ping)
94            .finish()
95    }
96}
97
98impl WebPushClient {
99    #[allow(clippy::too_many_arguments)]
100    pub async fn new(
101        uaid: Uuid,
102        ua: String,
103        broadcast_subs: BroadcastSubs,
104        flags: ClientFlags,
105        connected_at: u64,
106        current_timestamp: Option<u64>,
107        deferred_add_user: Option<User>,
108        app_state: Arc<AppState>,
109    ) -> Result<(Self, Vec<ServerMessage>), SMError> {
110        trace!("👁‍🗨WebPushClient::new");
111        let stats = SessionStatistics {
112            existing_uaid: deferred_add_user.is_none(),
113            ..Default::default()
114        };
115        let mut client = WebPushClient {
116            uaid,
117            uid: Uuid::new_v4(),
118            ua_info: UserAgentInfo::from(ua.as_str()),
119            broadcast_subs,
120            flags,
121            ack_state: Default::default(),
122            sent_from_storage: Default::default(),
123            connected_at,
124            current_timestamp,
125            deferred_add_user,
126            last_ping: Default::default(),
127            stats,
128            app_state,
129        };
130
131        let smsgs = if client.flags.check_storage {
132            let smsgs = client.check_storage().await?;
133            debug!(
134                "WebPushClient::new: check_storage smsgs.len(): {}",
135                smsgs.len()
136            );
137            smsgs
138        } else {
139            vec![]
140        };
141        Ok((client, smsgs))
142    }
143
144    /// Return a reference to `AppState`'s `Settings`
145    pub fn app_settings(&self) -> &Settings {
146        &self.app_state.settings
147    }
148
149    #[cfg(feature = "reliable_report")]
150    pub fn app_reliability(&self) -> &autopush_common::reliability::PushReliability {
151        &self.app_state.reliability
152    }
153
154    /// Connect this `WebPushClient` to the `ClientRegistry`
155    ///
156    /// Returning a `Stream` of `ServerNotification`s from the `ClientRegistry`
157    pub fn registry_connect(&self) -> mpsc::Receiver<ServerNotification> {
158        self.app_state.clients.connect(self.uaid, self.uid)
159    }
160
161    /// Disconnect this `WebPushClient` from the `ClientRegistry`
162    pub fn registry_disconnect(&self) {
163        // Ignore disconnect (Client wasn't connected) Errors
164        let _ = self.app_state.clients.disconnect(&self.uaid, &self.uid);
165    }
166
167    /// Return the difference between the Client's Broadcast Subscriptions and
168    /// the this server's Broadcasts
169    pub async fn broadcast_delta(&mut self) -> Option<Vec<Broadcast>> {
170        self.app_state
171            .broadcaster
172            .read()
173            .await
174            .change_count_delta(&mut self.broadcast_subs)
175    }
176
177    /// Cleanup after the session has ended
178    pub fn shutdown(&mut self, reason: Option<String>) {
179        trace!("👁‍🗨WebPushClient::shutdown");
180        self.save_and_notify_unacked_direct_notifs();
181
182        let ua_info = &self.ua_info;
183        let stats = &self.stats;
184        let elapsed_sec = (ms_since_epoch() - self.connected_at) / 1_000;
185        self.app_state
186            .metrics
187            .time_with_tags("ua.connection.lifespan", elapsed_sec)
188            .with_tag("ua_os_family", &ua_info.metrics_os)
189            .with_tag("ua_browser_family", &ua_info.metrics_browser)
190            .send();
191
192        // Log out the final stats message
193        info!("Session";
194            "uaid_hash" => self.uaid.as_simple().to_string(),
195            "uaid_reset" => self.flags.old_record_version,
196            "existing_uaid" => stats.existing_uaid,
197            "connection_type" => "webpush",
198            "ua_name" => &ua_info.browser_name,
199            "ua_os_family" => &ua_info.metrics_os,
200            "ua_os_ver" => &ua_info.os_version,
201            "ua_browser_family" => &ua_info.metrics_browser,
202            "ua_browser_ver" => &ua_info.browser_version,
203            "ua_category" => &ua_info.category,
204            "connection_time" => elapsed_sec,
205            "direct_acked" => stats.direct_acked,
206            "direct_storage" => stats.direct_storage,
207            "stored_retrieved" => stats.stored_retrieved,
208            "stored_acked" => stats.stored_acked,
209            "nacks" => stats.nacks,
210            "registers" => stats.registers,
211            "unregisters" => stats.unregisters,
212            "disconnect_reason" => reason.unwrap_or_else(|| "".to_owned()),
213        );
214    }
215
216    /// Save any Direct unAck'd messages to the db (on shutdown)
217    ///
218    /// Direct messages are solely stored in memory until Ack'd by the Client,
219    /// so on shutdown, any not Ack'd are stored in the db to not be lost
220    fn save_and_notify_unacked_direct_notifs(&mut self) {
221        let notif_map = mem::take(&mut self.ack_state.unacked_direct_notifs);
222        trace!(
223            "👁‍🗨WebPushClient::save_and_notify_unacked_direct_notifs len: {}",
224            notif_map.len()
225        );
226        if notif_map.is_empty() {
227            return;
228        }
229
230        self.stats.direct_storage += notif_map.len() as i32;
231        // TODO: clarify this comment re the Python version
232        // Ensure we don't store these as legacy by setting a 0 as the
233        // sortkey_timestamp. This ensures the Python side doesn't mark it as
234        // legacy during conversion and still get the correct default us_time
235        // when saving
236        let mut notifs: Vec<Notification> = notif_map.into_values().collect();
237        for notif in &mut notifs {
238            notif.sortkey_timestamp = Some(0);
239        }
240
241        let app_state = Arc::clone(&self.app_state);
242        let uaid = self.uaid;
243        let connected_at = self.connected_at;
244        rt::spawn(async move {
245            let _permit = match DISCONNECT_SEMAPHORE.acquire().await {
246                Ok(permit) => permit,
247                Err(_) => {
248                    app_state
249                        .metrics
250                        .incr(MetricName::ErrorDisconnectSemaphoreFull)
251                        .ok();
252                    warn!("Disconnect semaphore full, skipping save of unacked direct notifs");
253                    return Ok(());
254                }
255            };
256            #[cfg(not(feature = "reliable_report"))]
257            app_state.db.save_messages(&uaid, notifs).await?;
258            #[cfg(feature = "reliable_report")]
259            {
260                app_state.db.save_messages(&uaid, notifs.clone()).await?;
261                for mut notif in notifs {
262                    notif
263                        .record_reliability(
264                            &app_state.reliability,
265                            autopush_common::reliability::ReliabilityState::Stored,
266                        )
267                        .await;
268                }
269            }
270            debug!("Finished saving unacked direct notifs, checking for reconnect");
271            let Some(user) = app_state.db.get_user(&uaid).await? else {
272                return Err(SMErrorKind::Internal(format!(
273                    "User not found for unacked direct notifs: {uaid}"
274                )));
275            };
276            if connected_at == user.connected_at {
277                return Ok(());
278            }
279            if let Some(node_id) = user.node_id {
280                app_state
281                    .http
282                    .put(format!("{}/notif/{}", node_id, uaid.as_simple()))
283                    .send()
284                    .await?
285                    .error_for_status()?;
286            }
287            Ok(())
288        });
289    }
290
291    /// Add User information and tags for this Client to a Sentry Event
292    pub fn add_sentry_info(self, event: &mut sentry::protocol::Event) {
293        event.user = Some(sentry::User {
294            id: Some(self.uaid.as_simple().to_string()),
295            ..Default::default()
296        });
297        let ua_info = self.ua_info;
298        event
299            .tags
300            .insert("ua_name".to_owned(), ua_info.browser_name);
301        event
302            .tags
303            .insert("ua_os_family".to_owned(), ua_info.metrics_os);
304        event
305            .tags
306            .insert("ua_os_ver".to_owned(), ua_info.os_version);
307        event
308            .tags
309            .insert("ua_browser_family".to_owned(), ua_info.metrics_browser);
310        event
311            .tags
312            .insert("ua_browser_ver".to_owned(), ua_info.browser_version);
313    }
314}
315
/// Per-session flags that steer how the Client's storage is processed.
#[derive(Debug)]
pub struct ClientFlags {
    /// Whether `check_storage` also queries for topic (not "timestamped")
    /// messages
    pub include_topic: bool,
    /// Whether the last-read timestamp for timestamped messages must be
    /// advanced in storage (via `increment_storage`)
    pub increment_storage: bool,
    /// Whether this client needs to check storage for messages
    pub check_storage: bool,
    /// Whether the user record needs to be dropped (logged as `uaid_reset`
    /// when the session ends)
    pub old_record_version: bool,
    /// Whether this is the user's first connection "today"
    pub emit_channel_metrics: bool,
}

impl Default for ClientFlags {
    /// A fresh session includes topic messages when checking storage and
    /// starts with every other flag cleared.
    fn default() -> Self {
        ClientFlags {
            emit_channel_metrics: false,
            old_record_version: false,
            check_storage: false,
            increment_storage: false,
            include_topic: true,
        }
    }
}
341
/// WebPush Session Statistics
///
/// Tracks statistics about the session that are logged when the session is
/// closed
#[derive(Debug, Default)]
pub struct SessionStatistics {
    /// Number of directly sent (never stored) messages that were acknowledged
    direct_acked: i32,
    /// Number of directly sent messages that were saved to storage instead
    /// (e.g. because they were still unacknowledged at shutdown)
    direct_storage: i32,
    /// Number of messages taken from storage
    stored_retrieved: i32,
    /// Number of messages pulled from storage and acknowledged
    stored_acked: i32,
    /// Total number of messages that were not acknowledged
    nacks: i32,
    /// Number of unregister requests
    unregisters: i32,
    /// Number of register requests
    registers: i32,
    /// Whether this uaid was previously registered (i.e. its user record was
    /// not newly created during this session's Hello)
    existing_uaid: bool,
}
365
/// Key for looking up notifications in the ACK tracking maps.
///
/// The key is the `version` field, i.e. the `Notification.message_id`: a
/// fernet-encrypted composite of uaid, channel_id, and topic|timestamp.
/// Fernet embeds a freshly generated random IV in every token, so encrypting
/// the same composite twice yields different ids. That makes `version`
/// effectively unique on its own and sufficient as the key.
type AckKey = String;
371
372/// Record of Notifications sent to the Client.
373#[derive(Debug, Default)]
374struct AckState {
375    /// Map of unAck'd directly sent (never stored) notifications
376    unacked_direct_notifs: HashMap<AckKey, Notification>,
377    /// Map of unAck'd sent notifications from storage
378    unacked_stored_notifs: HashMap<AckKey, Notification>,
379    /// List of Ack'd timestamp notifications from storage, cleared
380    /// via `increment_storage`
381    #[cfg(feature = "reliable_report")]
382    acked_stored_timestamp_notifs: Vec<Notification>,
383    /// Either the `current_timestamp` value in storage (returned from
384    /// `fetch_messages`) or the last unAck'd timestamp Message's
385    /// `sortkey_timestamp` (returned from `fetch_timestamp_messages`).
386    ///
387    /// This represents the "pointer" to the beginning (more specifically the
388    /// record preceeding the beginning used in a Greater Than query) of the
389    /// next batch of timestamp Messages.
390    ///
391    /// Thus this value is:
392    ///
393    /// a) initially None, then
394    ///
395    /// b) retrieved from `current_timestamp` in storage then passed as the
396    /// `timestamp` to `fetch_timestamp_messages`. When all of those timestamp
397    /// Messages are Ack'd, this value's then
398    ///
399    /// c) written back to `current_timestamp` in storage via
400    /// `increment_storage`
401    unacked_stored_highest: Option<u64>,
402}
403
404impl AckState {
405    /// Whether the Client has outstanding notifications sent to it that it has
406    /// yet to Ack
407    fn unacked_notifs(&self) -> bool {
408        !self.unacked_stored_notifs.is_empty() || !self.unacked_direct_notifs.is_empty()
409    }
410}
411
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use uuid::Uuid;

    use autoconnect_common::{
        protocol::{ClientMessage, ServerMessage, ServerNotification},
        test_support::{DUMMY_CHID, DUMMY_UAID, UA},
    };
    use autoconnect_settings::AppState;
    use autopush_common::{
        db::{client::FetchMessageResponse, mock::MockDbClient},
        notification::Notification,
        util::{ms_since_epoch, sec_since_epoch},
    };

    use super::WebPushClient;

    /// Build a `WebPushClient` for `uaid` backed by `app_state`.
    ///
    /// The client uses the test User-Agent, default broadcast subscriptions
    /// and default `ClientFlags`, so `new` does not check storage. It is
    /// connected "now", with no `current_timestamp` and no deferred user
    /// (i.e. an existing UAID). Panics if construction fails.
    async fn wpclient(uaid: Uuid, app_state: AppState) -> (WebPushClient, Vec<ServerMessage>) {
        WebPushClient::new(
            uaid,
            UA.to_owned(),
            Default::default(),
            Default::default(),
            ms_since_epoch(),
            None,
            None,
            Arc::new(app_state),
        )
        .await
        .unwrap()
    }

    /// Generate a dummy timestamp `Notification` for `channel_id` with the
    /// given `ttl`, stamped with the current time
    fn new_timestamp_notif(channel_id: &Uuid, ttl: u64) -> Notification {
        Notification {
            channel_id: *channel_id,
            ttl,
            timestamp: sec_since_epoch(),
            sortkey_timestamp: Some(ms_since_epoch()),
            ..Default::default()
        }
    }

    /// A WebPush Ping from the Client is answered with exactly one Ping
    #[actix_rt::test]
    async fn webpush_ping() {
        let (mut client, _) = wpclient(DUMMY_UAID, Default::default()).await;
        let pong = client.on_client_msg(ClientMessage::Ping).await.unwrap();
        assert!(matches!(pong.as_slice(), [ServerMessage::Ping]));
    }

    /// Expired timestamp messages are filtered out of the `CheckStorage`
    /// response, yet `increment_storage` must still be called to advance
    /// past them. The mock expectations below are sequenced, so they must be
    /// hit in exactly this order.
    #[actix_rt::test]
    async fn expired_increments_storage() {
        let mut db = MockDbClient::new();
        let mut seq = mockall::Sequence::new();
        let timestamp = sec_since_epoch();
        // No topic messages
        db.expect_fetch_topic_messages()
            .times(1)
            .in_sequence(&mut seq)
            .return_once(move |_, _| {
                Ok(FetchMessageResponse {
                    timestamp: None,
                    messages: vec![],
                })
            });
        // Return expired notifs (default ttl of 0)
        db.expect_fetch_timestamp_messages()
            .times(1)
            .in_sequence(&mut seq)
            .withf(move |_, ts, _| ts.is_none())
            .return_once(move |_, _, _| {
                Ok(FetchMessageResponse {
                    timestamp: Some(timestamp),
                    messages: vec![
                        new_timestamp_notif(&DUMMY_CHID, 0),
                        new_timestamp_notif(&DUMMY_CHID, 0),
                    ],
                })
            });
        // EOF
        db.expect_fetch_timestamp_messages()
            .times(1)
            .in_sequence(&mut seq)
            .withf(move |_, ts, _| ts == &Some(timestamp))
            .return_once(|_, _, _| {
                Ok(FetchMessageResponse {
                    timestamp: None,
                    messages: vec![],
                })
            });
        // Ensure increment_storage's called to advance the timestamp messages
        // despite check_storage returning nothing (all filtered out as
        // expired)
        db.expect_increment_storage()
            .times(1)
            .in_sequence(&mut seq)
            .withf(move |_, ts| ts == &timestamp)
            .return_once(|_, _| Ok(()));

        // No check_storage called here (via default ClientFlags)
        let (mut client, _) = wpclient(
            DUMMY_UAID,
            AppState {
                db: db.into_boxed_arc(),
                ..Default::default()
            },
        )
        .await;

        // Trigger check_storage explicitly; every fetched message was expired,
        // so nothing is sent to the Client
        let smsgs = client
            .on_server_notif(ServerNotification::CheckStorage)
            .await
            .expect("CheckStorage failed");
        assert!(smsgs.is_empty())
    }
}