1use async_trait::async_trait;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::PushReliability;
4use cadence::{Counted, StatsdClient, Timed};
5use reqwest::{Response, StatusCode};
6use serde_json::Value;
7use std::collections::{hash_map::RandomState, HashMap};
8use std::sync::Arc;
9use url::Url;
10use uuid::Uuid;
11
12use crate::error::{ApiError, ApiErrorKind, ApiResult};
13use crate::extractors::{notification::Notification, router_data_input::RouterDataInput};
14use crate::headers::vapid::VapidHeaderWithKey;
15use crate::routers::{Router, RouterError, RouterResponse};
16
17use autopush_common::db::{client::DbClient, User};
18use autopush_common::metric_name::MetricName;
19use autopush_common::metrics::StatsdClientExt;
20
/// Router for WebPush notifications.
///
/// Tries to hand a notification directly to the connection node recorded on
/// the user; when that is not possible the notification is stored in the
/// database instead.
pub struct WebPushRouter {
    /// Database used to save notifications, re-read user records and remove
    /// stale node ids.
    pub db: Box<dyn DbClient>,
    /// StatsD client used to emit routing metrics.
    pub metrics: Arc<StatsdClient>,
    /// HTTP client used to contact connection nodes (`/push/` and `/notif/`
    /// endpoints).
    pub http: reqwest::Client,
    /// Base URL used to build the `Location` header (`/m/{message_id}`) of
    /// routing responses.
    pub endpoint_url: Url,
    /// Tracker used to record message reliability state transitions.
    #[cfg(feature = "reliable_report")]
    pub reliability: Arc<PushReliability>,
}
34
35#[async_trait(?Send)]
36impl Router for WebPushRouter {
37 fn register(
38 &self,
39 _router_input: &RouterDataInput,
40 _app_id: &str,
41 ) -> Result<HashMap<String, Value, RandomState>, RouterError> {
42 Ok(HashMap::new())
44 }
45
46 async fn route_notification(
47 &self,
48 mut notification: Notification,
49 ) -> ApiResult<RouterResponse> {
50 let user = ¬ification.subscription.user.clone();
52 debug!(
55 "✉ Routing WebPush notification to UAID {} :: {:?}",
56 notification.subscription.user.uaid, notification.subscription.reliability_id,
57 );
58 trace!("✉ Notification = {:?}", notification);
59
60 if let Some(node_id) = &user.node_id {
62 trace!(
63 "✉ User has a node ID, sending notification to node: {}",
64 &node_id
65 );
66
67 #[cfg(feature = "reliable_report")]
68 let revert_state = notification.reliable_state;
69 #[cfg(feature = "reliable_report")]
70 notification
71 .record_reliability(
72 &self.reliability,
73 autopush_common::reliability::ReliabilityState::IntTransmitted,
74 )
75 .await;
76 match self.send_notification(¬ification, node_id).await {
77 Ok(response) => {
78 if response.status() == 200 {
80 trace!("✉ Node received notification");
82 return Ok(self.make_delivered_response(¬ification));
83 }
84 trace!(
85 "✉ Node did not receive the notification, response = {:?}",
86 response
87 );
88 }
89 Err(error) => {
90 if let ApiErrorKind::ReqwestError(error) = &error.kind {
91 if error.is_timeout() {
92 self.metrics.incr(MetricName::ErrorNodeTimeout)?;
93 };
94 if error.is_connect() {
95 self.metrics.incr(MetricName::ErrorNodeConnect)?;
96 };
97 };
98 debug!("✉ Error while sending webpush notification: {}", error);
99 self.remove_node_id(user, node_id).await?
100 }
101 }
102
103 #[cfg(feature = "reliable_report")]
104 if let Some(revert_state) = revert_state {
106 trace!(
107 "🔎⚠️ Revert {:?} from {:?} to {:?}",
108 ¬ification.reliability_id,
109 ¬ification.reliable_state,
110 revert_state
111 );
112 notification
113 .record_reliability(&self.reliability, revert_state)
114 .await;
115 }
116 }
117
118 if notification.headers.ttl == 0 {
119 let topic = notification.headers.topic.is_some().to_string();
120 trace!(
121 "✉ Notification has a TTL of zero and was not successfully \
122 delivered, dropping it"
123 );
124 self.metrics
125 .incr_with_tags(MetricName::NotificationMessageExpired)
126 .with_tag("topic", &topic)
128 .send();
129 #[cfg(feature = "reliable_report")]
130 notification
131 .record_reliability(
132 &self.reliability,
133 autopush_common::reliability::ReliabilityState::Expired,
134 )
135 .await;
136 return Ok(self.make_delivered_response(¬ification));
137 }
138
139 trace!("✉ Node is not present or busy, storing notification");
141 self.store_notification(&mut notification).await?;
142
143 let user = match self.db.get_user(&user.uaid).await {
146 Ok(Some(user)) => user,
147 Ok(None) => {
148 trace!("✉ No user found, must have been deleted");
149 return Err(self.handle_error(
150 ApiErrorKind::Router(RouterError::UserWasDeleted),
151 notification.subscription.vapid.clone(),
152 ));
153 }
154 Err(e) => {
155 debug!("✉ Database error while re-fetching user: {}", e);
157 return Ok(self.make_stored_response(¬ification));
158 }
159 };
160
161 let node_id = match &user.node_id {
163 Some(id) => id,
164 None => {
166 trace!("✉ User is not connected to a node, returning stored response");
167 return Ok(self.make_stored_response(¬ification));
168 }
169 };
170
171 trace!("✉ Notifying node to check for messages");
173 match self.trigger_notification_check(&user.uaid, node_id).await {
174 Ok(response) => {
175 trace!("Response = {:?}", response);
176 if response.status() == 200 {
177 trace!("✉ Node has delivered the message");
178 self.metrics
179 .time_with_tags(
180 MetricName::NotificationTotalRequestTime.as_ref(),
181 (notification.timestamp - autopush_common::util::sec_since_epoch())
182 * 1000,
183 )
184 .with_tag("platform", "websocket")
185 .with_tag("app_id", "direct")
186 .send();
187
188 Ok(self.make_delivered_response(¬ification))
189 } else {
190 trace!("✉ Node has not delivered the message, returning stored response");
191 Ok(self.make_stored_response(¬ification))
192 }
193 }
194 Err(error) => {
195 debug!("✉ Error while triggering notification check: {}", error);
197 self.remove_node_id(&user, node_id).await?;
198 Ok(self.make_stored_response(¬ification))
199 }
200 }
201 }
202}
203
204impl WebPushRouter {
205 fn handle_error(&self, error: ApiErrorKind, vapid: Option<VapidHeaderWithKey>) -> ApiError {
207 let mut err = ApiError::from(error);
208 if let Some(Ok(claims)) = vapid.map(|v| v.vapid.claims()) {
209 let mut extras = err.extras.unwrap_or_default();
210 if let Some(sub) = claims.sub {
211 extras.extend([("sub".to_owned(), sub)]);
212 }
213 err.extras = Some(extras);
214 };
215 err
216 }
217
218 async fn send_notification(
220 &self,
221 notification: &Notification,
222 node_id: &str,
223 ) -> ApiResult<Response> {
224 let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
225
226 let notification_out = notification.serialize_for_delivery()?;
227
228 trace!(
229 "⏩ out: Notification: {}, channel_id: {} :: {:?}",
230 ¬ification.subscription.user.uaid,
231 ¬ification.subscription.channel_id,
232 ¬ification_out,
233 );
234 Ok(self.http.put(&url).json(¬ification_out).send().await?)
235 }
236
237 async fn trigger_notification_check(
239 &self,
240 uaid: &Uuid,
241 node_id: &str,
242 ) -> Result<Response, reqwest::Error> {
243 let url = format!("{node_id}/notif/{uaid}");
244
245 self.http.put(&url).send().await
246 }
247
248 async fn store_notification(&self, notification: &mut Notification) -> ApiResult<()> {
250 let result = self
251 .db
252 .save_message(
253 ¬ification.subscription.user.uaid,
254 notification.clone().into(),
255 )
256 .await
257 .map_err(|e| {
258 self.handle_error(
259 ApiErrorKind::Router(RouterError::SaveDb(
260 e,
261 notification.subscription.vapid.as_ref().map(|vapid| {
263 vapid
264 .vapid
265 .claims()
266 .ok()
267 .and_then(|c| c.sub)
268 .unwrap_or_default()
269 }),
270 )),
271 notification.subscription.vapid.clone(),
272 )
273 });
274 #[cfg(feature = "reliable_report")]
275 notification
276 .record_reliability(
277 &self.reliability,
278 autopush_common::reliability::ReliabilityState::Stored,
279 )
280 .await;
281 result
282 }
283
284 async fn remove_node_id(&self, user: &User, node_id: &str) -> ApiResult<()> {
287 self.metrics.incr(MetricName::UpdatesClientHostGone).ok();
288 let removed = self
289 .db
290 .remove_node_id(&user.uaid, node_id, user.connected_at, &user.version)
291 .await?;
292 if !removed {
293 debug!("✉ The node id was not removed");
294 }
295 Ok(())
296 }
297
298 fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
301 self.make_response(notification, "Direct", StatusCode::CREATED)
302 }
303
304 fn make_stored_response(&self, notification: &Notification) -> RouterResponse {
307 self.make_response(notification, "Stored", StatusCode::CREATED)
308 }
309
310 fn make_response(
312 &self,
313 notification: &Notification,
314 destination_tag: &str,
315 status: StatusCode,
316 ) -> RouterResponse {
317 self.metrics
318 .count_with_tags(
319 MetricName::NotificationMessageData.as_ref(),
320 notification.data.as_ref().map(String::len).unwrap_or(0) as i64,
321 )
322 .with_tag("destination", destination_tag)
323 .send();
324
325 RouterResponse {
326 status: actix_http::StatusCode::from_u16(status.as_u16()).unwrap_or_default(),
327 headers: {
328 let mut map = HashMap::new();
329 map.insert(
330 "Location",
331 self.endpoint_url
332 .join(&format!("/m/{}", notification.message_id))
333 .expect("Message ID is not URL-safe")
334 .to_string(),
335 );
336 map.insert("TTL", notification.headers.ttl.to_string());
337 map
338 },
339 body: None,
340 }
341 }
342}
343
#[cfg(test)]
mod test {
    use std::boxed::Box;
    use std::sync::Arc;

    use crate::extractors::subscription::tests::{make_vapid, PUB_KEY};
    use crate::headers::vapid::VapidClaims;
    use autopush_common::errors::ReportableError;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};

    use super::*;
    use autopush_common::db::mock::MockDbClient;

    /// Build a router over `db` with no-op metrics, a default HTTP client and
    /// a localhost endpoint URL.
    fn make_router(db: Box<dyn DbClient>) -> WebPushRouter {
        // One no-op metrics client, shared by the router and (when enabled)
        // the reliability tracker. Using it unconditionally for the router
        // keeps the binding used when `reliable_report` is disabled.
        let metrics = Arc::new(StatsdClient::from_sink("autopush", cadence::NopMetricSink));
        WebPushRouter {
            db: db.clone(),
            metrics: Arc::clone(&metrics),
            http: reqwest::Client::new(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db, &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// The VAPID `sub` claim is copied into the error's extras.
    #[tokio::test]
    async fn pass_extras() {
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(db);
        let sub = "foo@example.com";
        let vapid = make_vapid(
            sub,
            "https://push.services.mozilla.org",
            VapidClaims::default_exp(),
            PUB_KEY.to_owned(),
        );

        let err = router.handle_error(ApiErrorKind::LogCheck, Some(vapid));
        assert!(err.extras().contains(&("sub", sub.to_owned())));
    }

    /// Without VAPID data no `sub` extra is attached.
    #[tokio::test]
    async fn no_vapid_no_sub_extra() {
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(db);

        let err = router.handle_error(ApiErrorKind::LogCheck, None);
        assert!(!err.extras().iter().any(|(key, _)| *key == "sub"));
    }
}