autoconnect_ws_sm/
unidentified.rs

1use std::{collections::HashMap, fmt, sync::Arc};
2
3use cadence::Histogrammed;
4use uuid::Uuid;
5
6use autoconnect_common::{
7    broadcast::{Broadcast, BroadcastSubs, BroadcastSubsInit},
8    protocol::{BroadcastValue, ClientMessage, MessageType, ServerMessage},
9};
10use autoconnect_settings::{AppState, Settings};
11use autopush_common::{
12    db::{User, USER_RECORD_VERSION},
13    metric_name::MetricName,
14    metrics::StatsdClientExt,
15    util::{ms_since_epoch, ms_utc_midnight},
16};
17
18use crate::{
19    error::{SMError, SMErrorKind},
20    identified::{ClientFlags, WebPushClient},
21};
22
/// Represents a Client waiting for (or yet to process) a Hello message
/// identifying itself
pub struct UnidentifiedClient {
    /// Client's User-Agent header
    ua: String,
    /// Shared application state; this type reads its `settings`, `db`,
    /// `metrics`, `broadcaster` and `router_url` members
    app_state: Arc<AppState>,
}
30
31impl fmt::Debug for UnidentifiedClient {
32    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33        fmt.debug_struct("UnidentifiedClient")
34            .field("ua", &self.ua)
35            .finish()
36    }
37}
38
39impl UnidentifiedClient {
40    pub fn new(ua: String, app_state: Arc<AppState>) -> Self {
41        UnidentifiedClient { ua, app_state }
42    }
43
44    /// Return a reference to `AppState`'s `Settings`
45    pub fn app_settings(&self) -> &Settings {
46        &self.app_state.settings
47    }
48
49    /// Handle a WebPush `ClientMessage` sent from the user agent over the
50    /// WebSocket for this user
51    ///
52    /// Anything but a Hello message is rejected as an Error
53    pub async fn on_client_msg(
54        self,
55        msg: ClientMessage,
56    ) -> Result<(WebPushClient, impl IntoIterator<Item = ServerMessage>), SMError> {
57        trace!("❓UnidentifiedClient::on_client_msg");
58        // Validate we received a Hello message before proceeding
59        SMError::validate_message_type(MessageType::Hello, &msg)?;
60
61        // Extract fields from the Hello message
62        let ClientMessage::Hello {
63            uaid,
64            broadcasts,
65            _channel_ids,
66        } = msg
67        else {
68            // This should never happen due to the validate_message_type check above
69            return Err(SMError::expected_message_type(MessageType::Hello));
70        };
71        debug!(
72            "👋UnidentifiedClient::on_client_msg {} from uaid?: {:?}",
73            MessageType::Hello.as_ref(),
74            uaid
75        );
76
77        // Ignore invalid uaids (treat as None) so they'll be issued a new one
78        let original_uaid = uaid.as_deref().and_then(|uaid| Uuid::try_parse(uaid).ok());
79
80        let GetOrCreateUser {
81            user,
82            existing_user,
83            flags,
84        } = self.get_or_create_user(original_uaid).await?;
85        let uaid = user.uaid;
86        debug!(
87            "💬UnidentifiedClient::on_client_msg {}! uaid: {} existing_user: {}",
88            MessageType::Hello.as_ref(),
89            uaid,
90            existing_user,
91        );
92        self.app_state
93            .metrics
94            .incr_with_tags(MetricName::UaCommand(MessageType::Hello.to_string()))
95            .with_tag("uaid", {
96                if existing_user {
97                    "existing"
98                } else if original_uaid.unwrap_or(uaid) != uaid {
99                    "reassigned"
100                } else {
101                    "new"
102                }
103            })
104            .send();
105
106        // This is the first time that the user has connected "today".
107        if flags.emit_channel_metrics {
108            // Return the number of channels for the user using the internal channel_count.
109            // NOTE: this metric can be approximate since we're sampling to determine the
110            // approximate average of channels per user for business reasons.
111            self.app_state
112                .metrics
113                .histogram_with_tags("ua.connection.channel_count", user.channel_count() as u64)
114                .with_tag_value("desktop")
115                .send();
116        }
117
118        let (broadcast_subs, broadcasts) = self
119            .broadcast_init(&Broadcast::from_hashmap(broadcasts.unwrap_or_default()))
120            .await;
121        let (wpclient, check_storage_smsgs) = WebPushClient::new(
122            uaid,
123            self.ua,
124            broadcast_subs,
125            flags,
126            user.connected_at,
127            user.current_timestamp,
128            (!existing_user).then_some(user),
129            self.app_state,
130        )
131        .await?;
132
133        let smsg = ServerMessage::Hello {
134            uaid: uaid.as_simple().to_string(),
135            use_webpush: true,
136            status: 200,
137            broadcasts,
138        };
139        let smsgs = std::iter::once(smsg).chain(check_storage_smsgs);
140        Ok((wpclient, smsgs))
141    }
142
143    /// Lookup a User or return a new User record if the lookup failed
144    async fn get_or_create_user(&self, uaid: Option<Uuid>) -> Result<GetOrCreateUser, SMError> {
145        trace!(
146            "❓UnidentifiedClient::get_or_create_user for {}",
147            MessageType::Hello.as_ref()
148        );
149        let connected_at = ms_since_epoch();
150
151        if let Some(uaid) = uaid {
152            if let Some(mut user) = self.app_state.db.get_user(&uaid).await? {
153                let flags = ClientFlags {
154                    check_storage: true,
155                    old_record_version: user
156                        .record_version
157                        .is_none_or(|rec_ver| rec_ver < USER_RECORD_VERSION),
158                    emit_channel_metrics: user.connected_at < ms_utc_midnight(),
159                    ..Default::default()
160                };
161                user.node_id = Some(self.app_state.router_url.to_owned());
162                if user.connected_at > connected_at {
163                    let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
164                    return Err(SMErrorKind::AlreadyConnected.into());
165                }
166                user.connected_at = connected_at;
167                if !self.app_state.db.update_user(&mut user).await? {
168                    let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
169                    return Err(SMErrorKind::AlreadyConnected.into());
170                }
171                return Ok(GetOrCreateUser {
172                    user,
173                    existing_user: true,
174                    flags,
175                });
176            }
177            // NOTE: when the client's specified a uaid but get_user returns
178            // None (or process_existing_user dropped the user record due to it
179            // being invalid) we're now deferring the db.add_user call (a
180            // change from the previous state machine impl)
181        }
182
183        let user = User::builder()
184            .node_id(self.app_state.router_url.to_owned())
185            .connected_at(connected_at)
186            .build()
187            .map_err(|e| SMErrorKind::Internal(format!("User::builder error: {e}")))?;
188        Ok(GetOrCreateUser {
189            user,
190            existing_user: false,
191            flags: Default::default(),
192        })
193    }
194
195    /// Initialize `Broadcast`s for a new `WebPushClient`
196    async fn broadcast_init(
197        &self,
198        broadcasts: &[Broadcast],
199    ) -> (BroadcastSubs, HashMap<String, BroadcastValue>) {
200        trace!(
201            "UnidentifiedClient::broadcast_init for {}",
202            MessageType::Hello.as_ref()
203        );
204        let bc = self.app_state.broadcaster.read().await;
205        let BroadcastSubsInit(broadcast_subs, delta) = bc.broadcast_delta(broadcasts);
206        let mut response = Broadcast::vec_into_hashmap(delta);
207        let missing = bc.missing_broadcasts(broadcasts);
208        if !missing.is_empty() {
209            response.insert(
210                "errors".to_owned(),
211                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
212            );
213        }
214        (broadcast_subs, response)
215    }
216}
217
/// Result of a User lookup for a Hello message
struct GetOrCreateUser {
    /// The record fetched from the db, or a freshly built one when no
    /// existing record was found
    user: User,
    /// `true` when `user` came from the db lookup (and was written back with
    /// the new connection info); `false` when it was newly built
    existing_user: bool,
    /// Flags passed on to the resulting `WebPushClient`
    flags: ClientFlags,
}
224
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use std::sync::Arc;

    use autoconnect_common::{
        protocol::{ClientMessage, MessageType},
        test_support::{hello_again_db, hello_again_json, hello_db, DUMMY_CHID, DUMMY_UAID, UA},
    };
    use autoconnect_settings::AppState;

    use crate::error::SMErrorKind;

    use super::UnidentifiedClient;

    #[ctor::ctor]
    fn init_test_logging() {
        autopush_common::logging::init_test_logging();
    }

    /// Build an `UnidentifiedClient` with the test User-Agent
    fn uclient(app_state: AppState) -> UnidentifiedClient {
        UnidentifiedClient::new(UA.to_owned(), Arc::new(app_state))
    }

    /// Send `msg` to a fresh default client and assert it's rejected as an
    /// invalid message whose error text names the expected Hello type
    async fn assert_hello_required(msg: ClientMessage) {
        let rejection = uclient(AppState::default())
            .on_client_msg(msg)
            .await
            .err()
            .unwrap();
        assert!(matches!(rejection.kind, SMErrorKind::InvalidMessage(_)));
        assert!(rejection.to_string().contains(MessageType::Hello.as_ref()));
    }

    #[tokio::test]
    async fn reject_not_hello() {
        // A Ping before Hello
        assert_hello_required(ClientMessage::Ping).await;

        // A Register before Hello
        assert_hello_required(ClientMessage::Register {
            channel_id: DUMMY_CHID.to_string(),
            key: None,
        })
        .await;
    }

    #[tokio::test]
    async fn hello_existing_user() {
        let app_state = AppState {
            db: hello_again_db(DUMMY_UAID).into_boxed_arc(),
            ..Default::default()
        };
        // The Hello payload comes from the hello_again_json helper
        let msg: ClientMessage = serde_json::from_str(&hello_again_json()).unwrap();
        uclient(app_state)
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    #[tokio::test]
    async fn hello_new_user() {
        let app_state = AppState {
            // Simple hello_db ensures no writes to the db
            db: hello_db().into_boxed_arc(),
            ..Default::default()
        };
        // A Hello must be accepted without the "use_webpush" flag. The JSON
        // is built explicitly (rather than passed as a plain string) to make
        // its exact contents obvious.
        let raw = serde_json::json!({"messageType":"hello"}).to_string();
        let msg = ClientMessage::from_str(&raw).unwrap();
        uclient(app_state)
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    #[tokio::test]
    async fn hello_empty_uaid() {
        let msg = ClientMessage::Hello {
            uaid: Some(String::new()),
            _channel_ids: None,
            broadcasts: None,
        };
        uclient(AppState::default())
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    #[tokio::test]
    async fn hello_invalid_uaid() {
        let msg = ClientMessage::Hello {
            uaid: Some(String::from("invalid")),
            _channel_ids: None,
            broadcasts: None,
        };
        uclient(AppState::default())
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    // Placeholder: this test has no body yet and asserts nothing
    #[tokio::test]
    async fn hello_bad_user() {}
}