1use async_trait::async_trait;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::PushReliability;
4use cadence::{Counted, StatsdClient, Timed};
5use reqwest::{Response, StatusCode};
6use serde_json::Value;
7use std::collections::{HashMap, hash_map::RandomState};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Instant;
11use url::Url;
12use uuid::Uuid;
13
14use crate::error::{ApiError, ApiErrorKind, ApiResult};
15use crate::extractors::{notification::Notification, router_data_input::RouterDataInput};
16use crate::headers::vapid::VapidHeaderWithKey;
17use crate::routers::{Router, RouterError, RouterResponse};
18
19use autopush_common::db::{User, client::DbClient};
20use autopush_common::metric_name::MetricName;
21use autopush_common::metrics::StatsdClientExt;
22
23pub struct WebPushRouter {
29 pub db: Box<dyn DbClient>,
30 pub metrics: Arc<StatsdClient>,
31 pub http: reqwest::Client,
32 pub endpoint_url: Url,
33 pub in_flight_requests: Arc<AtomicUsize>,
34 #[cfg(feature = "reliable_report")]
35 pub reliability: Arc<PushReliability>,
36}
37
38#[async_trait(?Send)]
39impl Router for WebPushRouter {
40 fn register(
41 &self,
42 _router_input: &RouterDataInput,
43 _app_id: &str,
44 ) -> Result<HashMap<String, Value, RandomState>, RouterError> {
45 Ok(HashMap::new())
47 }
48
49 async fn route_notification(&self, notification: Notification) -> ApiResult<RouterResponse> {
50 let route_start = Instant::now();
51 let result = self.route_notification_inner(notification).await;
52 self.metrics
53 .time_with_tags(
54 MetricName::NotificationRouteTime.as_ref(),
55 route_start.elapsed().as_millis() as u64,
56 )
57 .with_tag("outcome", if result.is_ok() { "ok" } else { "error" })
58 .send();
59 result
60 }
61}
62
63impl WebPushRouter {
64 async fn route_notification_inner(
65 &self,
66 mut notification: Notification,
67 ) -> ApiResult<RouterResponse> {
68 let notif_user = ¬ification.subscription.user;
72 let uaid = notif_user.uaid;
73 let node_id = notif_user.node_id.clone();
74 debug!(
75 "✉ Routing WebPush notification to UAID {} :: {:?}",
76 uaid, notification.subscription.reliability_id,
77 );
78 trace!("✉ Notification = {:?}", notification);
79
80 if let Some(node_id) = node_id {
82 trace!(
83 "✉ User has a node ID, sending notification to node: {}",
84 &node_id
85 );
86
87 #[cfg(feature = "reliable_report")]
88 let revert_state = notification.reliable_state;
89 #[cfg(feature = "reliable_report")]
90 notification
91 .record_reliability(
92 &self.reliability,
93 autopush_common::reliability::ReliabilityState::IntTransmitted,
94 )
95 .await;
96 let send_start = Instant::now();
97 match self.send_notification(¬ification, &node_id).await {
98 Ok(response) => {
99 let elapsed = send_start.elapsed().as_millis() as u64;
100 let status = response.status().as_u16();
101 self.metrics
102 .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
103 .with_tag("status", &status.to_string())
104 .send();
105 self.metrics
106 .incr_with_tags(MetricName::DirectDeliveryStatus)
107 .with_tag("status", &status.to_string())
108 .send();
109 if status == 200 {
111 trace!("✉ Node received notification");
113 return Ok(self.make_delivered_response(¬ification));
114 }
115 trace!(
116 "✉ Node did not receive the notification, response = {:?}",
117 response
118 );
119 }
120 Err(error) => {
121 let elapsed = send_start.elapsed().as_millis() as u64;
122 let status_tag = if let ApiErrorKind::ReqwestError(error) = &error.kind {
123 if error.is_timeout() {
124 self.metrics.incr(MetricName::ErrorNodeTimeout)?;
125 "timeout"
126 } else if error.is_connect() {
127 self.metrics.incr(MetricName::ErrorNodeConnect)?;
128 "connect_error"
129 } else {
130 "error"
131 }
132 } else {
133 "error"
134 };
135 self.metrics
136 .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
137 .with_tag("status", status_tag)
138 .send();
139 self.metrics
140 .incr_with_tags(MetricName::DirectDeliveryStatus)
141 .with_tag("status", status_tag)
142 .send();
143 debug!("✉ Error while sending webpush notification: {}", error);
144 self.remove_node_id(¬ification.subscription.user, &node_id)
145 .await?
146 }
147 }
148
149 #[cfg(feature = "reliable_report")]
150 if let Some(revert_state) = revert_state {
152 trace!(
153 "🔎⚠️ Revert {:?} from {:?} to {:?}",
154 ¬ification.reliability_id, ¬ification.reliable_state, revert_state
155 );
156 notification
157 .record_reliability(&self.reliability, revert_state)
158 .await;
159 }
160 }
161
162 if notification.headers.ttl == 0 {
163 let topic = notification.headers.topic.is_some().to_string();
164 trace!(
165 "✉ Notification has a TTL of zero and was not successfully \
166 delivered, dropping it"
167 );
168 self.metrics
169 .incr_with_tags(MetricName::NotificationMessageExpired)
170 .with_tag("topic", &topic)
172 .send();
173 #[cfg(feature = "reliable_report")]
174 notification
175 .record_reliability(
176 &self.reliability,
177 autopush_common::reliability::ReliabilityState::Expired,
178 )
179 .await;
180 return Ok(self.make_delivered_response(¬ification));
181 }
182
183 trace!("✉ Node is not present or busy, storing notification");
185 let store_start = Instant::now();
186 self.store_notification(&mut notification).await?;
187 self.metrics
188 .time_with_tags(
189 MetricName::StorageSaveTime.as_ref(),
190 store_start.elapsed().as_millis() as u64,
191 )
192 .send();
193
194 let user = match self.db.get_user(&uaid).await {
197 Ok(Some(user)) => user,
198 Ok(None) => {
199 trace!("✉ No user found, must have been deleted");
200 return Err(self.handle_error(
201 ApiErrorKind::Router(RouterError::UserWasDeleted),
202 notification.subscription.vapid.clone(),
203 ));
204 }
205 Err(e) => {
206 debug!("✉ Database error while re-fetching user: {}", e);
208 return Ok(self.make_stored_response(¬ification));
209 }
210 };
211
212 let node_id = match &user.node_id {
214 Some(id) => id,
215 None => {
217 trace!("✉ User is not connected to a node, returning stored response");
218 return Ok(self.make_stored_response(¬ification));
219 }
220 };
221
222 trace!("✉ Notifying node to check for messages");
224 match self.trigger_notification_check(&user.uaid, node_id).await {
225 Ok(response) => {
226 trace!("Response = {:?}", response);
227 if response.status() == 200 {
228 trace!("✉ Node has delivered the message");
229 self.metrics
230 .time_with_tags(
231 MetricName::NotificationTotalRequestTime.as_ref(),
232 (notification.timestamp - autopush_common::util::sec_since_epoch())
233 * 1000,
234 )
235 .with_tag("platform", "websocket")
236 .with_tag("app_id", "direct")
237 .send();
238
239 Ok(self.make_delivered_response(¬ification))
240 } else {
241 trace!("✉ Node has not delivered the message, returning stored response");
242 Ok(self.make_stored_response(¬ification))
243 }
244 }
245 Err(error) => {
246 debug!("✉ Error while triggering notification check: {}", error);
248 self.remove_node_id(&user, node_id).await?;
249 Ok(self.make_stored_response(¬ification))
250 }
251 }
252 }
253
254 fn handle_error(&self, error: ApiErrorKind, vapid: Option<VapidHeaderWithKey>) -> ApiError {
256 let mut err = ApiError::from(error);
257 if let Some(Ok(claims)) = vapid.map(|v| v.vapid.claims()) {
258 let mut extras = err.extras.unwrap_or_default();
259 if let Some(sub) = claims.sub {
260 extras.extend([("sub".to_owned(), sub)]);
261 }
262 err.extras = Some(extras);
263 };
264 err
265 }
266
267 async fn send_notification(
269 &self,
270 notification: &Notification,
271 node_id: &str,
272 ) -> ApiResult<Response> {
273 let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
274
275 let notification_out = notification.serialize_for_delivery()?;
276
277 trace!(
278 "⏩ out: Notification: {}, channel_id: {} :: {:?}",
279 ¬ification.subscription.user.uaid,
280 ¬ification.subscription.channel_id,
281 ¬ification_out,
282 );
283 self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
284 let result = self.http.put(&url).json(¬ification_out).send().await;
285 self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
286 Ok(result?)
287 }
288
289 async fn trigger_notification_check(
291 &self,
292 uaid: &Uuid,
293 node_id: &str,
294 ) -> Result<Response, reqwest::Error> {
295 let url = format!("{node_id}/notif/{uaid}");
296
297 self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
298 let result = self.http.put(&url).send().await;
299 self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
300 result
301 }
302
303 async fn store_notification(&self, notification: &mut Notification) -> ApiResult<()> {
305 let result = self
306 .db
307 .save_message(
308 ¬ification.subscription.user.uaid,
309 autopush_common::notification::Notification::from(&*notification),
310 )
311 .await
312 .map_err(|e| {
313 self.handle_error(
314 ApiErrorKind::Router(RouterError::SaveDb(
315 e,
316 notification.subscription.vapid.as_ref().map(|vapid| {
318 vapid
319 .vapid
320 .claims()
321 .ok()
322 .and_then(|c| c.sub)
323 .unwrap_or_default()
324 }),
325 )),
326 notification.subscription.vapid.clone(),
327 )
328 });
329 #[cfg(feature = "reliable_report")]
330 notification
331 .record_reliability(
332 &self.reliability,
333 autopush_common::reliability::ReliabilityState::Stored,
334 )
335 .await;
336 result
337 }
338
339 async fn remove_node_id(&self, user: &User, node_id: &str) -> ApiResult<()> {
342 self.metrics.incr(MetricName::UpdatesClientHostGone).ok();
343 let removed = self
344 .db
345 .remove_node_id(&user.uaid, node_id, user.connected_at, &user.version)
346 .await?;
347 if !removed {
348 debug!("✉ The node id was not removed");
349 }
350 Ok(())
351 }
352
353 fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
356 self.make_response(notification, "Direct", StatusCode::CREATED)
357 }
358
359 fn make_stored_response(&self, notification: &Notification) -> RouterResponse {
362 self.make_response(notification, "Stored", StatusCode::CREATED)
363 }
364
365 fn make_response(
367 &self,
368 notification: &Notification,
369 destination_tag: &str,
370 status: StatusCode,
371 ) -> RouterResponse {
372 self.metrics
373 .count_with_tags(
374 MetricName::NotificationMessageData.as_ref(),
375 notification.data.as_ref().map(String::len).unwrap_or(0) as i64,
376 )
377 .with_tag("destination", destination_tag)
378 .send();
379
380 RouterResponse {
381 status: actix_http::StatusCode::from_u16(status.as_u16()).unwrap_or_default(),
382 headers: {
383 let mut map = HashMap::new();
384 map.insert(
385 "Location",
386 self.endpoint_url
387 .join(&format!("/m/{}", notification.message_id))
388 .expect("Message ID is not URL-safe")
389 .to_string(),
390 );
391 map.insert("TTL", notification.headers.ttl.to_string());
392 map
393 },
394 body: None,
395 }
396 }
397}
398
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::extractors::subscription::tests::{PUB_KEY, make_vapid};
    use crate::headers::vapid::VapidClaims;
    use autopush_common::errors::ReportableError;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};

    use super::*;
    use autopush_common::db::mock::MockDbClient;

    /// Builds a router over `db` with a no-op metrics sink, a default HTTP
    /// client and a localhost endpoint URL.
    fn make_router(db: Box<dyn DbClient>) -> WebPushRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        let router_db = db.clone();
        let router_metrics = Arc::clone(&metrics);
        let http = reqwest::Client::new();
        let endpoint_url = Url::parse("http://localhost:8080/").unwrap();
        let in_flight_requests = Arc::new(AtomicUsize::new(0));
        #[cfg(feature = "reliable_report")]
        let reliability =
            Arc::new(PushReliability::new(&None, db, &metrics, MAX_TRANSACTION_LOOP).unwrap());
        WebPushRouter {
            db: router_db,
            metrics: router_metrics,
            http,
            endpoint_url,
            in_flight_requests,
            #[cfg(feature = "reliable_report")]
            reliability,
        }
    }

    /// The VAPID `sub` claim must be carried into the error's extras.
    #[tokio::test]
    async fn pass_extras() {
        let mock_db = MockDbClient::new().into_boxed_arc();
        let router = make_router(mock_db);
        let expected_sub = "foo@example.com";
        let vapid_header = make_vapid(
            expected_sub,
            "https://push.services.mozilla.org",
            VapidClaims::default_exp(),
            PUB_KEY.to_owned(),
        );

        let api_error = router.handle_error(ApiErrorKind::LogCheck, Some(vapid_header));
        let extras = api_error.extras();
        assert!(extras.contains(&("sub", expected_sub.to_owned())));
    }
}