autoconnect_ws_sm/
unidentified.rs

1use std::{collections::HashMap, fmt, sync::Arc};
2
3use cadence::Histogrammed;
4use uuid::Uuid;
5
6use autoconnect_common::{
7    broadcast::{Broadcast, BroadcastSubs, BroadcastSubsInit},
8    protocol::{BroadcastValue, ClientMessage, MessageType, ServerMessage},
9};
10use autoconnect_settings::{AppState, Settings};
11use autopush_common::{
12    db::{USER_RECORD_VERSION, User},
13    metric_name::MetricName,
14    metrics::StatsdClientExt,
15    util::{ms_since_epoch, ms_utc_midnight},
16};
17
18use crate::{
19    error::{SMError, SMErrorKind},
20    identified::{ClientFlags, WebPushClient},
21};
22
23/// Represents a Client waiting for (or yet to process) a Hello message
24/// identifying itself
25pub struct UnidentifiedClient {
26    /// Client's User-Agent header
27    ua: String,
28    app_state: Arc<AppState>,
29}
30
31impl fmt::Debug for UnidentifiedClient {
32    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33        fmt.debug_struct("UnidentifiedClient")
34            .field("ua", &self.ua)
35            .finish()
36    }
37}
38
39impl UnidentifiedClient {
40    pub fn new(ua: String, app_state: Arc<AppState>) -> Self {
41        UnidentifiedClient { ua, app_state }
42    }
43
44    /// Return a reference to `AppState`'s `Settings`
45    pub fn app_settings(&self) -> &Settings {
46        &self.app_state.settings
47    }
48
49    /// Handle a WebPush `ClientMessage` sent from the user agent over the
50    /// WebSocket for this user
51    ///
52    /// Anything but a Hello message is rejected as an Error
53    pub async fn on_client_msg(
54        self,
55        msg: ClientMessage,
56    ) -> Result<(WebPushClient, impl IntoIterator<Item = ServerMessage>), SMError> {
57        trace!("❓UnidentifiedClient::on_client_msg");
58        // Validate we received a Hello message before proceeding
59        SMError::validate_message_type(MessageType::Hello, &msg)?;
60
61        // Extract fields from the Hello message
62        let ClientMessage::Hello {
63            uaid,
64            broadcasts,
65            _channel_ids,
66        } = msg
67        else {
68            // This should never happen due to the validate_message_type check above
69            return Err(SMError::expected_message_type(MessageType::Hello));
70        };
71        debug!(
72            "👋UnidentifiedClient::on_client_msg {} from uaid?: {:?}",
73            MessageType::Hello.as_ref(),
74            uaid
75        );
76
77        // Ignore invalid uaids (treat as None) so they'll be issued a new one
78        let original_uaid = uaid.as_deref().and_then(|uaid| Uuid::try_parse(uaid).ok());
79
80        let GetOrCreateUser {
81            user,
82            existing_user,
83            flags,
84        } = self.get_or_create_user(original_uaid).await?;
85        let uaid = user.uaid;
86        debug!(
87            "💬UnidentifiedClient::on_client_msg {}! uaid: {} existing_user: {}",
88            MessageType::Hello.as_ref(),
89            uaid,
90            existing_user,
91        );
92        self.app_state
93            .metrics
94            .incr_with_tags(MetricName::UaCommandHello)
95            .with_tag("uaid", {
96                if existing_user {
97                    "existing"
98                } else if original_uaid.unwrap_or(uaid) != uaid {
99                    "reassigned"
100                } else {
101                    "new"
102                }
103            })
104            .send();
105
106        // This is the first time that the user has connected "today".
107        if flags.emit_channel_metrics {
108            // Return the number of channels for the user using the internal channel_count.
109            // NOTE: this metric can be approximate since we're sampling to determine the
110            // approximate average of channels per user for business reasons.
111            self.app_state
112                .metrics
113                .histogram_with_tags("ua.connection.channel_count", user.channel_count() as u64)
114                .with_tag_value("desktop")
115                .send();
116        }
117
118        let (broadcast_subs, broadcasts) = self
119            .broadcast_init(&Broadcast::from_hashmap(broadcasts.unwrap_or_default()))
120            .await;
121        let (wpclient, check_storage_smsgs) = WebPushClient::new(
122            uaid,
123            self.ua,
124            broadcast_subs,
125            flags,
126            user.connected_at,
127            user.current_timestamp,
128            (!existing_user).then_some(user),
129            self.app_state,
130        )
131        .await?;
132
133        let smsg = ServerMessage::Hello {
134            uaid: uaid.as_simple().to_string(),
135            use_webpush: true,
136            status: 200,
137            broadcasts,
138        };
139        let smsgs = std::iter::once(smsg).chain(check_storage_smsgs);
140        Ok((wpclient, smsgs))
141    }
142
143    /// Lookup a User or return a new User record if the lookup failed
144    async fn get_or_create_user(&self, uaid: Option<Uuid>) -> Result<GetOrCreateUser, SMError> {
145        trace!(
146            "❓UnidentifiedClient::get_or_create_user for {}",
147            MessageType::Hello.as_ref()
148        );
149        let connected_at = ms_since_epoch();
150
151        if let Some(uaid) = uaid
152            && let Some(mut user) = self.app_state.db.get_user(&uaid).await?
153        {
154            let flags = ClientFlags {
155                check_storage: true,
156                old_record_version: user
157                    .record_version
158                    .is_none_or(|rec_ver| rec_ver < USER_RECORD_VERSION),
159                emit_channel_metrics: user.connected_at < ms_utc_midnight(),
160                ..Default::default()
161            };
162            user.node_id = Some(self.app_state.router_url.to_owned());
163            if user.connected_at > connected_at {
164                let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
165                return Err(SMErrorKind::AlreadyConnected.into());
166            }
167            user.connected_at = connected_at;
168            if !self.app_state.db.update_user(&mut user).await? {
169                let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
170                return Err(SMErrorKind::AlreadyConnected.into());
171            }
172            return Ok(GetOrCreateUser {
173                user,
174                existing_user: true,
175                flags,
176            });
177
178            // NOTE: when the client's specified a uaid but get_user returns
179            // None (or process_existing_user dropped the user record due to it
180            // being invalid) we're now deferring the db.add_user call (a
181            // change from the previous state machine impl)
182        }
183
184        let user = User::builder()
185            .node_id(self.app_state.router_url.to_owned())
186            .connected_at(connected_at)
187            .build()
188            .map_err(|e| SMErrorKind::Internal(format!("User::builder error: {e}")))?;
189        Ok(GetOrCreateUser {
190            user,
191            existing_user: false,
192            flags: Default::default(),
193        })
194    }
195
196    /// Initialize `Broadcast`s for a new `WebPushClient`
197    async fn broadcast_init(
198        &self,
199        broadcasts: &[Broadcast],
200    ) -> (BroadcastSubs, HashMap<String, BroadcastValue>) {
201        trace!(
202            "UnidentifiedClient::broadcast_init for {}",
203            MessageType::Hello.as_ref()
204        );
205        let bc = self.app_state.broadcaster.read().await;
206        let BroadcastSubsInit(broadcast_subs, delta) = bc.broadcast_delta(broadcasts);
207        let mut response = Broadcast::vec_into_hashmap(delta);
208        let missing = bc.missing_broadcasts(broadcasts);
209        if !missing.is_empty() {
210            response.insert(
211                "errors".to_owned(),
212                BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
213            );
214        }
215        (broadcast_subs, response)
216    }
217}
218
219/// Result of a User lookup for a Hello message
220struct GetOrCreateUser {
221    user: User,
222    existing_user: bool,
223    flags: ClientFlags,
224}
225
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use std::sync::Arc;

    use autoconnect_common::{
        protocol::{ClientMessage, MessageType},
        test_support::{DUMMY_CHID, DUMMY_UAID, UA, hello_again_db, hello_again_json, hello_db},
    };
    use autoconnect_settings::AppState;

    use crate::error::SMErrorKind;

    use super::UnidentifiedClient;

    #[ctor::ctor]
    fn init_test_logging() {
        autopush_common::logging::init_test_logging();
    }

    /// Build an `UnidentifiedClient` with the test User-Agent and the given state
    fn uclient(app_state: AppState) -> UnidentifiedClient {
        UnidentifiedClient::new(UA.to_owned(), Arc::new(app_state))
    }

    #[tokio::test]
    async fn reject_not_hello() {
        // Test with Ping message
        let client = uclient(Default::default());
        let err = client
            .on_client_msg(ClientMessage::Ping)
            .await
            .err()
            .unwrap();
        assert!(matches!(err.kind, SMErrorKind::InvalidMessage(_)));
        // Verify error message contains expected message type
        assert!(format!("{err}").contains(MessageType::Hello.as_ref()));

        // Test with Register message
        let client = uclient(Default::default());
        let err = client
            .on_client_msg(ClientMessage::Register {
                channel_id: DUMMY_CHID.to_string(),
                key: None,
            })
            .await
            .err()
            .unwrap();
        assert!(matches!(err.kind, SMErrorKind::InvalidMessage(_)));
        // Verify error message contains expected message type
        assert!(format!("{err}").contains(MessageType::Hello.as_ref()));
    }

    #[tokio::test]
    async fn hello_existing_user() {
        let client = uclient(AppState {
            db: hello_again_db(DUMMY_UAID).into_boxed_arc(),
            ..Default::default()
        });
        // Use hello_again_json helper which properly uses MessageType enum
        let js = hello_again_json();
        let msg: ClientMessage = serde_json::from_str(&js).unwrap();
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    #[tokio::test]
    async fn hello_new_user() {
        let client = uclient(AppState {
            // Simple hello_db ensures no writes to the db
            db: hello_db().into_boxed_arc(),
            ..Default::default()
        });
        // Ensure that we do not need to pass the "use_webpush" flag.
        // (yes, this could just be passing the string, but I want to be
        // very explicit here.)
        let json = serde_json::json!({"messageType":"hello"});
        let raw = json.to_string();
        let msg = ClientMessage::from_str(&raw).unwrap();
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    #[tokio::test]
    async fn hello_empty_uaid() {
        let client = uclient(Default::default());
        let msg = ClientMessage::Hello {
            uaid: Some("".to_owned()),
            _channel_ids: None,
            broadcasts: None,
        };
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    #[tokio::test]
    async fn hello_invalid_uaid() {
        let client = uclient(Default::default());
        let msg = ClientMessage::Hello {
            uaid: Some("invalid".to_owned()),
            _channel_ids: None,
            broadcasts: None,
        };
        client.on_client_msg(msg).await.expect("Hello failed");
    }
}