1use std::{collections::HashMap, fmt, sync::Arc};
2
3use cadence::Histogrammed;
4use uuid::Uuid;
5
6use autoconnect_common::{
7 broadcast::{Broadcast, BroadcastSubs, BroadcastSubsInit},
8 protocol::{BroadcastValue, ClientMessage, MessageType, ServerMessage},
9};
10use autoconnect_settings::{AppState, Settings};
11use autopush_common::{
12 db::{USER_RECORD_VERSION, User},
13 metric_name::MetricName,
14 metrics::StatsdClientExt,
15 util::{ms_since_epoch, ms_utc_midnight},
16};
17
18use crate::{
19 error::{SMError, SMErrorKind},
20 identified::{ClientFlags, WebPushClient},
21};
22
/// A client that has connected but has not yet been identified.
///
/// It is consumed by `UnidentifiedClient::on_client_msg`, which turns it into
/// a `WebPushClient` once a `Hello` message has been processed.
pub struct UnidentifiedClient {
    /// User-agent string given to `UnidentifiedClient::new`; it is handed on
    /// to `WebPushClient::new` and is the only field shown by `Debug`.
    ua: String,
    /// Shared application state (settings, db, metrics, broadcaster, router URL).
    app_state: Arc<AppState>,
}
30
31impl fmt::Debug for UnidentifiedClient {
32 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33 fmt.debug_struct("UnidentifiedClient")
34 .field("ua", &self.ua)
35 .finish()
36 }
37}
38
39impl UnidentifiedClient {
40 pub fn new(ua: String, app_state: Arc<AppState>) -> Self {
41 UnidentifiedClient { ua, app_state }
42 }
43
44 pub fn app_settings(&self) -> &Settings {
46 &self.app_state.settings
47 }
48
49 pub async fn on_client_msg(
54 self,
55 msg: ClientMessage,
56 ) -> Result<(WebPushClient, impl IntoIterator<Item = ServerMessage>), SMError> {
57 trace!("❓UnidentifiedClient::on_client_msg");
58 SMError::validate_message_type(MessageType::Hello, &msg)?;
60
61 let ClientMessage::Hello {
63 uaid,
64 broadcasts,
65 _channel_ids,
66 } = msg
67 else {
68 return Err(SMError::expected_message_type(MessageType::Hello));
70 };
71 debug!(
72 "👋UnidentifiedClient::on_client_msg {} from uaid?: {:?}",
73 MessageType::Hello.as_ref(),
74 uaid
75 );
76
77 let original_uaid = uaid.as_deref().and_then(|uaid| Uuid::try_parse(uaid).ok());
79
80 let GetOrCreateUser {
81 user,
82 existing_user,
83 flags,
84 } = self.get_or_create_user(original_uaid).await?;
85 let uaid = user.uaid;
86 debug!(
87 "💬UnidentifiedClient::on_client_msg {}! uaid: {} existing_user: {}",
88 MessageType::Hello.as_ref(),
89 uaid,
90 existing_user,
91 );
92 self.app_state
93 .metrics
94 .incr_with_tags(MetricName::UaCommandHello)
95 .with_tag("uaid", {
96 if existing_user {
97 "existing"
98 } else if original_uaid.unwrap_or(uaid) != uaid {
99 "reassigned"
100 } else {
101 "new"
102 }
103 })
104 .send();
105
106 if flags.emit_channel_metrics {
108 self.app_state
112 .metrics
113 .histogram_with_tags("ua.connection.channel_count", user.channel_count() as u64)
114 .with_tag_value("desktop")
115 .send();
116 }
117
118 let (broadcast_subs, broadcasts) = self
119 .broadcast_init(&Broadcast::from_hashmap(broadcasts.unwrap_or_default()))
120 .await;
121 let (wpclient, check_storage_smsgs) = WebPushClient::new(
122 uaid,
123 self.ua,
124 broadcast_subs,
125 flags,
126 user.connected_at,
127 user.current_timestamp,
128 (!existing_user).then_some(user),
129 self.app_state,
130 )
131 .await?;
132
133 let smsg = ServerMessage::Hello {
134 uaid: uaid.as_simple().to_string(),
135 use_webpush: true,
136 status: 200,
137 broadcasts,
138 };
139 let smsgs = std::iter::once(smsg).chain(check_storage_smsgs);
140 Ok((wpclient, smsgs))
141 }
142
143 async fn get_or_create_user(&self, uaid: Option<Uuid>) -> Result<GetOrCreateUser, SMError> {
145 trace!(
146 "❓UnidentifiedClient::get_or_create_user for {}",
147 MessageType::Hello.as_ref()
148 );
149 let connected_at = ms_since_epoch();
150
151 if let Some(uaid) = uaid
152 && let Some(mut user) = self.app_state.db.get_user(&uaid).await?
153 {
154 let flags = ClientFlags {
155 check_storage: true,
156 old_record_version: user
157 .record_version
158 .is_none_or(|rec_ver| rec_ver < USER_RECORD_VERSION),
159 emit_channel_metrics: user.connected_at < ms_utc_midnight(),
160 ..Default::default()
161 };
162 user.node_id = Some(self.app_state.router_url.to_owned());
163 if user.connected_at > connected_at {
164 let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
165 return Err(SMErrorKind::AlreadyConnected.into());
166 }
167 user.connected_at = connected_at;
168 if !self.app_state.db.update_user(&mut user).await? {
169 let _ = self.app_state.metrics.incr(MetricName::UaAlreadyConnected);
170 return Err(SMErrorKind::AlreadyConnected.into());
171 }
172 return Ok(GetOrCreateUser {
173 user,
174 existing_user: true,
175 flags,
176 });
177
178 }
183
184 let user = User::builder()
185 .node_id(self.app_state.router_url.to_owned())
186 .connected_at(connected_at)
187 .build()
188 .map_err(|e| SMErrorKind::Internal(format!("User::builder error: {e}")))?;
189 Ok(GetOrCreateUser {
190 user,
191 existing_user: false,
192 flags: Default::default(),
193 })
194 }
195
196 async fn broadcast_init(
198 &self,
199 broadcasts: &[Broadcast],
200 ) -> (BroadcastSubs, HashMap<String, BroadcastValue>) {
201 trace!(
202 "UnidentifiedClient::broadcast_init for {}",
203 MessageType::Hello.as_ref()
204 );
205 let bc = self.app_state.broadcaster.read().await;
206 let BroadcastSubsInit(broadcast_subs, delta) = bc.broadcast_delta(broadcasts);
207 let mut response = Broadcast::vec_into_hashmap(delta);
208 let missing = bc.missing_broadcasts(broadcasts);
209 if !missing.is_empty() {
210 response.insert(
211 "errors".to_owned(),
212 BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
213 );
214 }
215 (broadcast_subs, response)
216 }
217}
218
/// Result of `UnidentifiedClient::get_or_create_user`.
struct GetOrCreateUser {
    /// The user record: either loaded from the database or newly built.
    user: User,
    /// `true` when `user` was loaded from the database, `false` when it was
    /// newly built.
    existing_user: bool,
    /// Flags for this connection; passed on to `WebPushClient::new`.
    flags: ClientFlags,
}
225
#[cfg(test)]
mod tests {
    use std::str::FromStr;
    use std::sync::Arc;

    use autoconnect_common::{
        protocol::{ClientMessage, MessageType},
        test_support::{DUMMY_CHID, DUMMY_UAID, UA, hello_again_db, hello_again_json, hello_db},
    };
    use autoconnect_settings::AppState;

    use crate::error::SMErrorKind;

    use super::UnidentifiedClient;

    #[ctor::ctor]
    fn init_test_logging() {
        autopush_common::logging::init_test_logging();
    }

    /// Build an `UnidentifiedClient` with the test user agent around `app_state`.
    fn uclient(app_state: AppState) -> UnidentifiedClient {
        let shared = Arc::new(app_state);
        UnidentifiedClient::new(UA.to_owned(), shared)
    }

    /// Send a `Hello` carrying `uaid` (no channel ids, no broadcasts) to a
    /// client built on the default `AppState` and expect it to succeed.
    async fn hello_with_uaid(uaid: &str) {
        let msg = ClientMessage::Hello {
            uaid: Some(uaid.to_owned()),
            _channel_ids: None,
            broadcasts: None,
        };
        uclient(AppState::default())
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    /// Any message other than `Hello` is rejected as an invalid message whose
    /// text mentions the expected `Hello` type.
    #[tokio::test]
    async fn reject_not_hello() {
        let not_hello = [
            ClientMessage::Ping,
            ClientMessage::Register {
                channel_id: DUMMY_CHID.to_string(),
                key: None,
            },
        ];
        for msg in not_hello {
            let rejection = uclient(AppState::default())
                .on_client_msg(msg)
                .await
                .err()
                .unwrap();
            assert!(matches!(rejection.kind, SMErrorKind::InvalidMessage(_)));
            assert!(rejection.to_string().contains(MessageType::Hello.as_ref()));
        }
    }

    /// A `Hello` from a uaid known to the (mock) database succeeds.
    #[tokio::test]
    async fn hello_existing_user() {
        let state = AppState {
            db: hello_again_db(DUMMY_UAID).into_boxed_arc(),
            ..Default::default()
        };
        let raw = hello_again_json();
        let msg: ClientMessage = serde_json::from_str(&raw).unwrap();
        uclient(state)
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    /// A `Hello` carrying nothing but its message type parses and succeeds.
    #[tokio::test]
    async fn hello_new_user() {
        let state = AppState {
            db: hello_db().into_boxed_arc(),
            ..Default::default()
        };
        // Go through the JSON text so that message parsing is exercised too.
        let raw = serde_json::json!({"messageType":"hello"}).to_string();
        let msg = ClientMessage::from_str(&raw).unwrap();
        uclient(state)
            .on_client_msg(msg)
            .await
            .expect("Hello failed");
    }

    /// An empty uaid string is accepted.
    #[tokio::test]
    async fn hello_empty_uaid() {
        hello_with_uaid("").await;
    }

    /// A uaid that is not a valid UUID is accepted.
    #[tokio::test]
    async fn hello_invalid_uaid() {
        hello_with_uaid("invalid").await;
    }

    // NOTE(review): placeholder with no assertions; it currently always passes.
    #[tokio::test]
    async fn hello_bad_user() {}
}