1use std::{collections::HashMap, fmt, sync::Arc};
2
3use cadence::Histogrammed;
4use uuid::Uuid;
5
6use autoconnect_common::{
7 broadcast::{Broadcast, BroadcastSubs, BroadcastSubsInit},
8 protocol::{BroadcastValue, ClientMessage, MessageType, ServerMessage},
9};
10use autoconnect_settings::{AppState, Settings};
11use autopush_common::{
12 db::{User, USER_RECORD_VERSION},
13 metric_name::MetricName,
14 metrics::StatsdClientExt,
15 util::{ms_since_epoch, ms_utc_midnight},
16};
17
18use crate::{
19 error::{SMError, SMErrorKind},
20 identified::{ClientFlags, WebPushClient},
21};
22
23pub struct UnidentifiedClient {
26 ua: String,
28 app_state: Arc<AppState>,
29}
30
31impl fmt::Debug for UnidentifiedClient {
32 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33 fmt.debug_struct("UnidentifiedClient")
34 .field("ua", &self.ua)
35 .finish()
36 }
37}
38
39impl UnidentifiedClient {
40 pub fn new(ua: String, app_state: Arc<AppState>) -> Self {
41 UnidentifiedClient { ua, app_state }
42 }
43
44 pub fn app_settings(&self) -> &Settings {
46 &self.app_state.settings
47 }
48
49 pub async fn on_client_msg(
54 self,
55 msg: ClientMessage,
56 ) -> Result<(WebPushClient, impl IntoIterator<Item = ServerMessage>), SMError> {
57 trace!("❓UnidentifiedClient::on_client_msg");
58 SMError::validate_message_type(MessageType::Hello, &msg)?;
60
61 let ClientMessage::Hello {
63 uaid,
64 broadcasts,
65 _channel_ids,
66 } = msg
67 else {
68 return Err(SMError::expected_message_type(MessageType::Hello));
70 };
71 debug!(
72 "👋UnidentifiedClient::on_client_msg {} from uaid?: {:?}",
73 MessageType::Hello.as_ref(),
74 uaid
75 );
76
77 let original_uaid = uaid.as_deref().and_then(|uaid| Uuid::try_parse(uaid).ok());
79
80 let GetOrCreateUser {
81 user,
82 existing_user,
83 flags,
84 } = self.get_or_create_user(original_uaid).await?;
85 let uaid = user.uaid;
86 debug!(
87 "💬UnidentifiedClient::on_client_msg {}! uaid: {} existing_user: {}",
88 MessageType::Hello.as_ref(),
89 uaid,
90 existing_user,
91 );
92 self.app_state
93 .metrics
94 .incr_with_tags(MetricName::UaCommand(MessageType::Hello.to_string()))
95 .with_tag("uaid", {
96 if existing_user {
97 "existing"
98 } else if original_uaid.unwrap_or(uaid) != uaid {
99 "reassigned"
100 } else {
101 "new"
102 }
103 })
104 .send();
105
106 if flags.emit_channel_metrics {
108 self.app_state
112 .metrics
113 .histogram_with_tags("ua.connection.channel_count", user.channel_count() as u64)
114 .with_tag_value("desktop")
115 .send();
116 }
117
118 let (broadcast_subs, broadcasts) = self
119 .broadcast_init(&Broadcast::from_hashmap(broadcasts.unwrap_or_default()))
120 .await;
121 let (wpclient, check_storage_smsgs) = WebPushClient::new(
122 uaid,
123 self.ua,
124 broadcast_subs,
125 flags,
126 user.connected_at,
127 user.current_timestamp,
128 (!existing_user).then_some(user),
129 self.app_state,
130 )
131 .await?;
132
133 let smsg = ServerMessage::Hello {
134 uaid: uaid.as_simple().to_string(),
135 use_webpush: true,
136 status: 200,
137 broadcasts,
138 };
139 let smsgs = std::iter::once(smsg).chain(check_storage_smsgs);
140 Ok((wpclient, smsgs))
141 }
142
143 async fn get_or_create_user(&self, uaid: Option<Uuid>) -> Result<GetOrCreateUser, SMError> {
145 trace!(
146 "❓UnidentifiedClient::get_or_create_user for {}",
147 MessageType::Hello.as_ref()
148 );
149 let connected_at = ms_since_epoch();
150
151 if let Some(uaid) = uaid {
152 if let Some(mut user) = self.app_state.db.get_user(&uaid).await? {
153 let flags = ClientFlags {
154 check_storage: true,
155 old_record_version: user
156 .record_version
157 .is_none_or(|rec_ver| rec_ver < USER_RECORD_VERSION),
158 emit_channel_metrics: user.connected_at < ms_utc_midnight(),
159 ..Default::default()
160 };
161 user.node_id = Some(self.app_state.router_url.to_owned());
162 if user.connected_at > connected_at {
163 let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
164 return Err(SMErrorKind::AlreadyConnected.into());
165 }
166 user.connected_at = connected_at;
167 if !self.app_state.db.update_user(&mut user).await? {
168 let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
169 return Err(SMErrorKind::AlreadyConnected.into());
170 }
171 return Ok(GetOrCreateUser {
172 user,
173 existing_user: true,
174 flags,
175 });
176 }
177 }
182
183 let user = User::builder()
184 .node_id(self.app_state.router_url.to_owned())
185 .connected_at(connected_at)
186 .build()
187 .map_err(|e| SMErrorKind::Internal(format!("User::builder error: {e}")))?;
188 Ok(GetOrCreateUser {
189 user,
190 existing_user: false,
191 flags: Default::default(),
192 })
193 }
194
195 async fn broadcast_init(
197 &self,
198 broadcasts: &[Broadcast],
199 ) -> (BroadcastSubs, HashMap<String, BroadcastValue>) {
200 trace!(
201 "UnidentifiedClient::broadcast_init for {}",
202 MessageType::Hello.as_ref()
203 );
204 let bc = self.app_state.broadcaster.read().await;
205 let BroadcastSubsInit(broadcast_subs, delta) = bc.broadcast_delta(broadcasts);
206 let mut response = Broadcast::vec_into_hashmap(delta);
207 let missing = bc.missing_broadcasts(broadcasts);
208 if !missing.is_empty() {
209 response.insert(
210 "errors".to_owned(),
211 BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
212 );
213 }
214 (broadcast_subs, response)
215 }
216}
217
218struct GetOrCreateUser {
220 user: User,
221 existing_user: bool,
222 flags: ClientFlags,
223}
224
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use std::sync::Arc;

    use autoconnect_common::{
        protocol::{ClientMessage, MessageType},
        test_support::{hello_again_db, hello_again_json, hello_db, DUMMY_CHID, DUMMY_UAID, UA},
    };
    use autoconnect_settings::AppState;

    use crate::error::SMErrorKind;

    use super::UnidentifiedClient;

    #[ctor::ctor]
    fn init_test_logging() {
        autopush_common::logging::init_test_logging();
    }

    /// Builds an `UnidentifiedClient` with the test User-Agent over `app_state`.
    fn uclient(app_state: AppState) -> UnidentifiedClient {
        UnidentifiedClient::new(UA.to_owned(), Arc::new(app_state))
    }

    /// A first message other than `Hello` is rejected as `InvalidMessage`,
    /// and the error text names the expected `Hello` type.
    #[tokio::test]
    async fn reject_not_hello() {
        let client = uclient(Default::default());
        let err = client
            .on_client_msg(ClientMessage::Ping)
            .await
            .err()
            .unwrap();
        assert!(matches!(err.kind, SMErrorKind::InvalidMessage(_)));
        assert!(format!("{err}").contains(MessageType::Hello.as_ref()));

        let client = uclient(Default::default());
        let err = client
            .on_client_msg(ClientMessage::Register {
                channel_id: DUMMY_CHID.to_string(),
                key: None,
            })
            .await
            .err()
            .unwrap();
        assert!(matches!(err.kind, SMErrorKind::InvalidMessage(_)));
        assert!(format!("{err}").contains(MessageType::Hello.as_ref()));
    }

    /// A `Hello` carrying a uaid that the mock db knows about succeeds.
    #[tokio::test]
    async fn hello_existing_user() {
        let client = uclient(AppState {
            db: hello_again_db(DUMMY_UAID).into_boxed_arc(),
            ..Default::default()
        });
        let js = hello_again_json();
        let msg: ClientMessage = serde_json::from_str(&js).unwrap();
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    /// A `Hello` without a uaid succeeds and creates a new user.
    #[tokio::test]
    async fn hello_new_user() {
        let client = uclient(AppState {
            db: hello_db().into_boxed_arc(),
            ..Default::default()
        });
        let json = serde_json::json!({"messageType":"hello"});
        let raw = json.to_string();
        let msg = ClientMessage::from_str(&raw).unwrap();
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    /// An empty uaid string is treated like no uaid.
    #[tokio::test]
    async fn hello_empty_uaid() {
        let client = uclient(Default::default());
        let msg = ClientMessage::Hello {
            uaid: Some("".to_owned()),
            _channel_ids: None,
            broadcasts: None,
        };
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    /// An unparseable uaid is treated like no uaid.
    #[tokio::test]
    async fn hello_invalid_uaid() {
        let client = uclient(Default::default());
        let msg = ClientMessage::Hello {
            uaid: Some("invalid".to_owned()),
            _channel_ids: None,
            broadcasts: None,
        };
        client.on_client_msg(msg).await.expect("Hello failed");
    }

    /// Placeholder for a `Hello` against an unusable stored user record.
    ///
    /// The body is still empty. It is ignored so that it is not reported as
    /// a passing test and does not imply coverage that does not exist.
    #[tokio::test]
    #[ignore = "TODO: hello_bad_user is not implemented yet"]
    async fn hello_bad_user() {}
}