1use async_trait::async_trait;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::PushReliability;
4use cadence::{Counted, StatsdClient, Timed};
5use reqwest::{Response, StatusCode};
6use serde_json::Value;
7use std::collections::{hash_map::RandomState, HashMap};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Instant;
11use url::Url;
12use uuid::Uuid;
13
14use crate::error::{ApiError, ApiErrorKind, ApiResult};
15use crate::extractors::{notification::Notification, router_data_input::RouterDataInput};
16use crate::headers::vapid::VapidHeaderWithKey;
17use crate::routers::{Router, RouterError, RouterResponse};
18
19use autopush_common::db::{client::DbClient, User};
20use autopush_common::metric_name::MetricName;
21use autopush_common::metrics::StatsdClientExt;
22
/// Router for WebPush subscriptions.
///
/// When the user record names a connection node, the notification is first
/// sent directly to that node over HTTP; otherwise (or if that fails) it is
/// stored in the database and the user's node, if any, is asked to check for
/// stored messages.
pub struct WebPushRouter {
    /// Database handle used to save messages, re-fetch users and clear stale
    /// node IDs.
    pub db: Box<dyn DbClient>,
    /// Metrics client used for routing/delivery timers and counters.
    pub metrics: Arc<StatsdClient>,
    /// HTTP client used to `PUT` to connection nodes.
    pub http: reqwest::Client,
    /// Base URL joined with `/m/{message_id}` to build the `Location` header.
    pub endpoint_url: Url,
    /// Number of HTTP requests to connection nodes currently outstanding.
    pub in_flight_requests: Arc<AtomicUsize>,
    /// Reliability tracker that notification state transitions are recorded to.
    #[cfg(feature = "reliable_report")]
    pub reliability: Arc<PushReliability>,
}
37
38#[async_trait(?Send)]
39impl Router for WebPushRouter {
40 fn register(
41 &self,
42 _router_input: &RouterDataInput,
43 _app_id: &str,
44 ) -> Result<HashMap<String, Value, RandomState>, RouterError> {
45 Ok(HashMap::new())
47 }
48
49 async fn route_notification(&self, notification: Notification) -> ApiResult<RouterResponse> {
50 let route_start = Instant::now();
51 let result = self.route_notification_inner(notification).await;
52 self.metrics
53 .time_with_tags(
54 MetricName::NotificationRouteTime.as_ref(),
55 route_start.elapsed().as_millis() as u64,
56 )
57 .with_tag("outcome", if result.is_ok() { "ok" } else { "error" })
58 .send();
59 result
60 }
61}
62
63impl WebPushRouter {
64 async fn route_notification_inner(
65 &self,
66 mut notification: Notification,
67 ) -> ApiResult<RouterResponse> {
68 let notif_user = ¬ification.subscription.user;
72 let uaid = notif_user.uaid;
73 let node_id = notif_user.node_id.clone();
74 debug!(
75 "✉ Routing WebPush notification to UAID {} :: {:?}",
76 uaid, notification.subscription.reliability_id,
77 );
78 trace!("✉ Notification = {:?}", notification);
79
80 if let Some(node_id) = node_id {
82 trace!(
83 "✉ User has a node ID, sending notification to node: {}",
84 &node_id
85 );
86
87 #[cfg(feature = "reliable_report")]
88 let revert_state = notification.reliable_state;
89 #[cfg(feature = "reliable_report")]
90 notification
91 .record_reliability(
92 &self.reliability,
93 autopush_common::reliability::ReliabilityState::IntTransmitted,
94 )
95 .await;
96 let send_start = Instant::now();
97 match self.send_notification(¬ification, &node_id).await {
98 Ok(response) => {
99 let elapsed = send_start.elapsed().as_millis() as u64;
100 let status = response.status().as_u16();
101 self.metrics
102 .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
103 .with_tag("status", &status.to_string())
104 .send();
105 self.metrics
106 .incr_with_tags(MetricName::DirectDeliveryStatus)
107 .with_tag("status", &status.to_string())
108 .send();
109 if status == 200 {
111 trace!("✉ Node received notification");
113 return Ok(self.make_delivered_response(¬ification));
114 }
115 trace!(
116 "✉ Node did not receive the notification, response = {:?}",
117 response
118 );
119 }
120 Err(error) => {
121 let elapsed = send_start.elapsed().as_millis() as u64;
122 let status_tag = if let ApiErrorKind::ReqwestError(error) = &error.kind {
123 if error.is_timeout() {
124 self.metrics.incr(MetricName::ErrorNodeTimeout)?;
125 "timeout"
126 } else if error.is_connect() {
127 self.metrics.incr(MetricName::ErrorNodeConnect)?;
128 "connect_error"
129 } else {
130 "error"
131 }
132 } else {
133 "error"
134 };
135 self.metrics
136 .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
137 .with_tag("status", status_tag)
138 .send();
139 self.metrics
140 .incr_with_tags(MetricName::DirectDeliveryStatus)
141 .with_tag("status", status_tag)
142 .send();
143 debug!("✉ Error while sending webpush notification: {}", error);
144 self.remove_node_id(¬ification.subscription.user, &node_id)
145 .await?
146 }
147 }
148
149 #[cfg(feature = "reliable_report")]
150 if let Some(revert_state) = revert_state {
152 trace!(
153 "🔎⚠️ Revert {:?} from {:?} to {:?}",
154 ¬ification.reliability_id,
155 ¬ification.reliable_state,
156 revert_state
157 );
158 notification
159 .record_reliability(&self.reliability, revert_state)
160 .await;
161 }
162 }
163
164 if notification.headers.ttl == 0 {
165 let topic = notification.headers.topic.is_some().to_string();
166 trace!(
167 "✉ Notification has a TTL of zero and was not successfully \
168 delivered, dropping it"
169 );
170 self.metrics
171 .incr_with_tags(MetricName::NotificationMessageExpired)
172 .with_tag("topic", &topic)
174 .send();
175 #[cfg(feature = "reliable_report")]
176 notification
177 .record_reliability(
178 &self.reliability,
179 autopush_common::reliability::ReliabilityState::Expired,
180 )
181 .await;
182 return Ok(self.make_delivered_response(¬ification));
183 }
184
185 trace!("✉ Node is not present or busy, storing notification");
187 let store_start = Instant::now();
188 self.store_notification(&mut notification).await?;
189 self.metrics
190 .time_with_tags(
191 MetricName::StorageSaveTime.as_ref(),
192 store_start.elapsed().as_millis() as u64,
193 )
194 .send();
195
196 let user = match self.db.get_user(&uaid).await {
199 Ok(Some(user)) => user,
200 Ok(None) => {
201 trace!("✉ No user found, must have been deleted");
202 return Err(self.handle_error(
203 ApiErrorKind::Router(RouterError::UserWasDeleted),
204 notification.subscription.vapid.clone(),
205 ));
206 }
207 Err(e) => {
208 debug!("✉ Database error while re-fetching user: {}", e);
210 return Ok(self.make_stored_response(¬ification));
211 }
212 };
213
214 let node_id = match &user.node_id {
216 Some(id) => id,
217 None => {
219 trace!("✉ User is not connected to a node, returning stored response");
220 return Ok(self.make_stored_response(¬ification));
221 }
222 };
223
224 trace!("✉ Notifying node to check for messages");
226 match self.trigger_notification_check(&user.uaid, node_id).await {
227 Ok(response) => {
228 trace!("Response = {:?}", response);
229 if response.status() == 200 {
230 trace!("✉ Node has delivered the message");
231 self.metrics
232 .time_with_tags(
233 MetricName::NotificationTotalRequestTime.as_ref(),
234 (notification.timestamp - autopush_common::util::sec_since_epoch())
235 * 1000,
236 )
237 .with_tag("platform", "websocket")
238 .with_tag("app_id", "direct")
239 .send();
240
241 Ok(self.make_delivered_response(¬ification))
242 } else {
243 trace!("✉ Node has not delivered the message, returning stored response");
244 Ok(self.make_stored_response(¬ification))
245 }
246 }
247 Err(error) => {
248 debug!("✉ Error while triggering notification check: {}", error);
250 self.remove_node_id(&user, node_id).await?;
251 Ok(self.make_stored_response(¬ification))
252 }
253 }
254 }
255
256 fn handle_error(&self, error: ApiErrorKind, vapid: Option<VapidHeaderWithKey>) -> ApiError {
258 let mut err = ApiError::from(error);
259 if let Some(Ok(claims)) = vapid.map(|v| v.vapid.claims()) {
260 let mut extras = err.extras.unwrap_or_default();
261 if let Some(sub) = claims.sub {
262 extras.extend([("sub".to_owned(), sub)]);
263 }
264 err.extras = Some(extras);
265 };
266 err
267 }
268
269 async fn send_notification(
271 &self,
272 notification: &Notification,
273 node_id: &str,
274 ) -> ApiResult<Response> {
275 let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
276
277 let notification_out = notification.serialize_for_delivery()?;
278
279 trace!(
280 "⏩ out: Notification: {}, channel_id: {} :: {:?}",
281 ¬ification.subscription.user.uaid,
282 ¬ification.subscription.channel_id,
283 ¬ification_out,
284 );
285 self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
286 let result = self.http.put(&url).json(¬ification_out).send().await;
287 self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
288 Ok(result?)
289 }
290
291 async fn trigger_notification_check(
293 &self,
294 uaid: &Uuid,
295 node_id: &str,
296 ) -> Result<Response, reqwest::Error> {
297 let url = format!("{node_id}/notif/{uaid}");
298
299 self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
300 let result = self.http.put(&url).send().await;
301 self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
302 result
303 }
304
305 async fn store_notification(&self, notification: &mut Notification) -> ApiResult<()> {
307 let result = self
308 .db
309 .save_message(
310 ¬ification.subscription.user.uaid,
311 autopush_common::notification::Notification::from(&*notification),
312 )
313 .await
314 .map_err(|e| {
315 self.handle_error(
316 ApiErrorKind::Router(RouterError::SaveDb(
317 e,
318 notification.subscription.vapid.as_ref().map(|vapid| {
320 vapid
321 .vapid
322 .claims()
323 .ok()
324 .and_then(|c| c.sub)
325 .unwrap_or_default()
326 }),
327 )),
328 notification.subscription.vapid.clone(),
329 )
330 });
331 #[cfg(feature = "reliable_report")]
332 notification
333 .record_reliability(
334 &self.reliability,
335 autopush_common::reliability::ReliabilityState::Stored,
336 )
337 .await;
338 result
339 }
340
341 async fn remove_node_id(&self, user: &User, node_id: &str) -> ApiResult<()> {
344 self.metrics.incr(MetricName::UpdatesClientHostGone).ok();
345 let removed = self
346 .db
347 .remove_node_id(&user.uaid, node_id, user.connected_at, &user.version)
348 .await?;
349 if !removed {
350 debug!("✉ The node id was not removed");
351 }
352 Ok(())
353 }
354
355 fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
358 self.make_response(notification, "Direct", StatusCode::CREATED)
359 }
360
361 fn make_stored_response(&self, notification: &Notification) -> RouterResponse {
364 self.make_response(notification, "Stored", StatusCode::CREATED)
365 }
366
367 fn make_response(
369 &self,
370 notification: &Notification,
371 destination_tag: &str,
372 status: StatusCode,
373 ) -> RouterResponse {
374 self.metrics
375 .count_with_tags(
376 MetricName::NotificationMessageData.as_ref(),
377 notification.data.as_ref().map(String::len).unwrap_or(0) as i64,
378 )
379 .with_tag("destination", destination_tag)
380 .send();
381
382 RouterResponse {
383 status: actix_http::StatusCode::from_u16(status.as_u16()).unwrap_or_default(),
384 headers: {
385 let mut map = HashMap::new();
386 map.insert(
387 "Location",
388 self.endpoint_url
389 .join(&format!("/m/{}", notification.message_id))
390 .expect("Message ID is not URL-safe")
391 .to_string(),
392 );
393 map.insert("TTL", notification.headers.ttl.to_string());
394 map
395 },
396 body: None,
397 }
398 }
399}
400
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::extractors::subscription::tests::{make_vapid, PUB_KEY};
    use crate::headers::vapid::VapidClaims;
    use autopush_common::errors::ReportableError;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};

    use super::*;
    use autopush_common::db::mock::MockDbClient;

    /// Build a router over `db` with a no-op metrics sink, a fresh HTTP
    /// client and a localhost endpoint URL.
    fn make_router(db: Box<dyn DbClient>) -> WebPushRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        let router_db = db.clone();
        let router_metrics = metrics.clone();
        let endpoint_url = Url::parse("http://localhost:8080/").unwrap();
        WebPushRouter {
            db: router_db,
            metrics: router_metrics,
            http: reqwest::Client::new(),
            endpoint_url,
            in_flight_requests: Arc::new(AtomicUsize::new(0)),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db, &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// `handle_error` must copy the VAPID `sub` claim into the error extras.
    #[tokio::test]
    async fn pass_extras() {
        let router = make_router(MockDbClient::new().into_boxed_arc());
        let subject = "foo@example.com";
        let vapid = make_vapid(
            subject,
            "https://push.services.mozilla.org",
            VapidClaims::default_exp(),
            PUB_KEY.to_owned(),
        );

        let error = router.handle_error(ApiErrorKind::LogCheck, Some(vapid));
        let expected = ("sub", subject.to_owned());
        assert!(error.extras().contains(&expected));
    }
}