autoconnect_ws_sm/identified/
on_client_msg.rs1use std::collections::HashMap;
2
3use uuid::Uuid;
4
5use autoconnect_common::{
6 broadcast::Broadcast,
7 protocol::{BroadcastValue, ClientAck, ClientMessage, MessageType, ServerMessage},
8};
9use autopush_common::{
10 endpoint::make_endpoint, metric_name::MetricName, metrics::StatsdClientExt,
11 util::sec_since_epoch,
12};
13
14use super::WebPushClient;
15use crate::error::{SMError, SMErrorKind};
16
17impl WebPushClient {
18 pub async fn on_client_msg(
21 &mut self,
22 msg: ClientMessage,
23 ) -> Result<Vec<ServerMessage>, SMError> {
24 match msg {
25 ClientMessage::Hello { .. } => {
26 Err(SMError::invalid_message("Already Hello'd".to_owned()))
27 }
28 ClientMessage::Register { channel_id, key } => {
29 Ok(vec![self.register(channel_id, key).await?])
30 }
31 ClientMessage::Unregister { channel_id, code } => {
32 Ok(vec![self.unregister(channel_id, code).await?])
33 }
34 ClientMessage::BroadcastSubscribe { broadcasts } => Ok(self
35 .broadcast_subscribe(broadcasts)
36 .await?
37 .map_or_else(Vec::new, |smsg| vec![smsg])),
38 ClientMessage::Ack { updates } => self.ack(&updates).await,
39 ClientMessage::Nack { code, .. } => {
40 self.nack(code);
41 Ok(vec![])
42 }
43 ClientMessage::Ping => Ok(vec![self.ping()?]),
44 }
45 }
46
47 async fn register(
49 &mut self,
50 channel_id_str: String,
51 key: Option<String>,
52 ) -> Result<ServerMessage, SMError> {
53 let uaid_str = self.uaid.to_string();
54 trace!("WebPushClient:register";
55 "uaid" => &uaid_str,
56 "channel_id" => &channel_id_str,
57 "key" => &key,
58 "message_type" => MessageType::Register.as_ref(),
59 );
60 let channel_id = Uuid::try_parse(&channel_id_str).map_err(|_| {
61 SMError::invalid_message(format!("Invalid channelID: {channel_id_str}"))
62 })?;
63 if channel_id.as_hyphenated().to_string() != channel_id_str {
64 return Err(SMError::invalid_message(format!(
65 "Invalid UUID format, not lower-case/dashed: {channel_id}",
66 )));
67 }
68
69 let (status, push_endpoint) = match self.do_register(&channel_id, key).await {
70 Ok(endpoint) => {
71 let _ = self
72 .app_state
73 .metrics
74 .incr(MetricName::UaCommand(MessageType::Register.to_string()));
75 self.stats.registers += 1;
76 (200, endpoint)
77 }
78 Err(SMErrorKind::MakeEndpoint(msg)) => {
79 error!("WebPushClient::register make_endpoint failed: {}", msg);
80 (400, "Failed to generate endpoint".to_owned())
81 }
82 Err(e) => {
83 error!("WebPushClient::register failed: {}", e);
84 (500, "".to_owned())
85 }
86 };
87 Ok(ServerMessage::Register {
88 channel_id,
89 status,
90 push_endpoint,
91 })
92 }
93
94 async fn do_register(
95 &mut self,
96 channel_id: &Uuid,
97 key: Option<String>,
98 ) -> Result<String, SMErrorKind> {
99 if let Some(user) = &self.deferred_add_user {
100 debug!(
101 "💬WebPushClient::register: User not yet registered: {}",
102 &user.uaid
103 );
104 self.app_state.db.add_user(user).await?;
105 self.deferred_add_user = None;
106 }
107
108 let endpoint = make_endpoint(
109 &self.uaid,
110 channel_id,
111 key.as_deref(),
112 &self.app_state.endpoint_url,
113 &self.app_state.fernet,
114 )
115 .map_err(SMErrorKind::MakeEndpoint)?;
116 self.app_state
117 .db
118 .add_channel(&self.uaid, channel_id)
119 .await?;
120 Ok(endpoint)
121 }
122
123 async fn unregister(
125 &mut self,
126 channel_id: Uuid,
127 code: Option<u32>,
128 ) -> Result<ServerMessage, SMError> {
129 let uaid_str = self.uaid.to_string();
130 let chid_str = channel_id.to_string();
131 trace!("WebPushClient:unregister";
132 "uaid" => &uaid_str,
133 "channel_id" => &chid_str,
134 "code" => &code,
135 "message_type" => MessageType::Unregister.as_ref(),
136 );
137 let result = self
141 .app_state
142 .db
143 .remove_channel(&self.uaid, &channel_id)
144 .await;
145 let status = match result {
146 Ok(_) => {
147 self.app_state
148 .metrics
149 .incr_with_tags(MetricName::UaCommand(MessageType::Unregister.to_string()))
150 .with_tag("code", &code.unwrap_or(200).to_string())
151 .send();
152 self.stats.unregisters += 1;
153 200
154 }
155 Err(e) => {
156 error!("WebPushClient::unregister failed: {}", e);
157 500
158 }
159 };
160 Ok(ServerMessage::Unregister { channel_id, status })
161 }
162
163 async fn broadcast_subscribe(
165 &mut self,
166 broadcasts: HashMap<String, String>,
167 ) -> Result<Option<ServerMessage>, SMError> {
168 trace!("WebPushClient:broadcast_subscribe"; "message_type" => MessageType::BroadcastSubscribe.as_ref());
169 let broadcasts = Broadcast::from_hashmap(broadcasts);
170 let mut response: HashMap<String, BroadcastValue> = HashMap::new();
171
172 let bc = self.app_state.broadcaster.read().await;
173 if let Some(delta) = bc.subscribe_to_broadcasts(&mut self.broadcast_subs, &broadcasts) {
174 response.extend(Broadcast::vec_into_hashmap(delta));
175 };
176 let missing = bc.missing_broadcasts(&broadcasts);
177 if !missing.is_empty() {
178 response.insert(
179 "errors".to_owned(),
180 BroadcastValue::Nested(Broadcast::vec_into_hashmap(missing)),
181 );
182 }
183
184 Ok((!response.is_empty()).then_some(ServerMessage::Broadcast {
185 broadcasts: response,
186 }))
187 }
188
189 async fn ack(&mut self, updates: &[ClientAck]) -> Result<Vec<ServerMessage>, SMError> {
191 trace!("✅ WebPushClient:ack"; "message_type" => MessageType::Ack.as_ref());
192 let _ = self
193 .app_state
194 .metrics
195 .incr(MetricName::UaCommand(MessageType::Ack.to_string()));
196
197 for notif in updates {
198 let pos = self
202 .ack_state
203 .unacked_direct_notifs
204 .iter()
205 .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
206 if let Some(pos) = pos {
208 debug!("✅ Ack (Direct)";
209 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
210 "version" => ¬if.version
211 );
212 self.ack_state.unacked_direct_notifs.remove(pos);
213 self.stats.direct_acked += 1;
214 continue;
215 };
216
217 let pos = self
219 .ack_state
220 .unacked_stored_notifs
221 .iter()
222 .position(|n| n.channel_id == notif.channel_id && n.version == notif.version);
223 if let Some(pos) = pos {
224 debug!(
225 "✅ Ack (Stored)";
226 "channel_id" => notif.channel_id.as_hyphenated().to_string(),
227 "version" => ¬if.version,
228 "message_type" => MessageType::Ack.as_ref()
229 );
230 let acked_notification = &mut self.ack_state.unacked_stored_notifs[pos];
232 let is_topic = acked_notification.topic.is_some();
233 debug!("✅ Ack notif: {:?}", &acked_notification);
234 if is_topic {
240 let chid = &acked_notification.chidmessageid();
241 debug!("✅ WebPushClient:ack removing Stored, sort_key: {}", &chid);
242 self.app_state.db.remove_message(&self.uaid, chid).await?;
243 #[cfg(feature = "reliable_report")]
247 acked_notification
248 .record_reliability(&self.app_state.reliability, notif.reliability_state())
249 .await;
250 }
251 let n = self.ack_state.unacked_stored_notifs.remove(pos);
252 #[cfg(feature = "reliable_report")]
253 if !is_topic {
254 self.ack_state.acked_stored_timestamp_notifs.push(n);
255 }
256 self.stats.stored_acked += 1;
257 continue;
258 };
259 }
260
261 if self.ack_state.unacked_notifs() {
262 Ok(vec![])
265 } else {
266 self.post_process_all_acked().await
267 }
268 }
269
270 fn nack(&mut self, code: Option<i32>) {
273 trace!("WebPushClient:nack"; "message_type" => MessageType::Nack.as_ref());
274 let code = code
276 .and_then(|code| (301..=303).contains(&code).then_some(code))
277 .unwrap_or(0);
278 self.app_state
279 .metrics
280 .incr_with_tags(MetricName::UaCommand(MessageType::Nack.to_string()))
281 .with_tag("code", &code.to_string())
282 .send();
283 self.stats.nacks += 1;
284 }
285
286 fn ping(&mut self) -> Result<ServerMessage, SMError> {
291 trace!("WebPushClient:ping"; "message_type" => MessageType::Ping.as_ref());
292 if sec_since_epoch() - self.last_ping >= 45 {
296 trace!("🏓WebPushClient Got a WebPush Ping, sending WebPush Pong");
297 self.last_ping = sec_since_epoch();
298 Ok(ServerMessage::Ping)
299 } else {
300 Err(SMErrorKind::ExcessivePing.into())
301 }
302 }
303
304 async fn post_process_all_acked(&mut self) -> Result<Vec<ServerMessage>, SMError> {
312 trace!("▶️ WebPushClient:post_process_all_acked"; "message_type" => MessageType::Notification.as_ref());
313 let flags = &self.flags;
314 if flags.check_storage {
315 if flags.increment_storage {
316 debug!(
317 "▶️ WebPushClient:post_process_all_acked check_storage && increment_storage";
318 "message_type" => MessageType::Notification.as_ref()
319 );
320 self.increment_storage().await?;
321 }
322
323 debug!("▶️ WebPushClient:post_process_all_acked check_storage"; "message_type" => MessageType::Notification.as_ref());
324 let smsgs = self.check_storage_loop().await?;
325 if !smsgs.is_empty() {
326 debug_assert!(self.flags.check_storage);
327 return Ok(smsgs);
331 }
332 debug_assert!(!self.flags.check_storage);
334 debug_assert!(!self.flags.increment_storage);
335 }
336
337 debug_assert!(!self.ack_state.unacked_notifs());
339 let flags = &self.flags;
340 if flags.old_record_version {
341 debug!("▶️ WebPushClient:post_process_all_acked; resetting uaid"; "message_type" => MessageType::Notification.as_ref());
342 self.app_state
343 .metrics
344 .incr_with_tags(MetricName::UaExpiration)
345 .with_tag("reason", "old_record_version")
346 .send();
347 self.app_state.db.remove_user(&self.uaid).await?;
348 Err(SMErrorKind::UaidReset.into())
349 } else {
350 Ok(vec![])
351 }
352 }
353}