// autoendpoint/routers/webpush.rs

1use async_trait::async_trait;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::PushReliability;
4use cadence::{Counted, StatsdClient, Timed};
5use reqwest::{Response, StatusCode};
6use serde_json::Value;
7use std::collections::{HashMap, hash_map::RandomState};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::time::Instant;
11use url::Url;
12use uuid::Uuid;
13
14use crate::error::{ApiError, ApiErrorKind, ApiResult};
15use crate::extractors::{notification::Notification, router_data_input::RouterDataInput};
16use crate::headers::vapid::VapidHeaderWithKey;
17use crate::routers::{Router, RouterError, RouterResponse};
18
19use autopush_common::db::{User, client::DbClient};
20use autopush_common::metric_name::MetricName;
21use autopush_common::metrics::StatsdClientExt;
22
23/// The router for desktop user agents.
24///
25/// These agents are connected via an Autopush connection server. The correct
26/// server is located via the database routing table. If the server is busy or
27/// not available, the notification is stored in the database.
28pub struct WebPushRouter {
29    pub db: Box<dyn DbClient>,
30    pub metrics: Arc<StatsdClient>,
31    pub http: reqwest::Client,
32    pub endpoint_url: Url,
33    pub in_flight_requests: Arc<AtomicUsize>,
34    #[cfg(feature = "reliable_report")]
35    pub reliability: Arc<PushReliability>,
36}
37
38#[async_trait(?Send)]
39impl Router for WebPushRouter {
40    fn register(
41        &self,
42        _router_input: &RouterDataInput,
43        _app_id: &str,
44    ) -> Result<HashMap<String, Value, RandomState>, RouterError> {
45        // WebPush registration happens through the connection server
46        Ok(HashMap::new())
47    }
48
49    async fn route_notification(&self, notification: Notification) -> ApiResult<RouterResponse> {
50        let route_start = Instant::now();
51        let result = self.route_notification_inner(notification).await;
52        self.metrics
53            .time_with_tags(
54                MetricName::NotificationRouteTime.as_ref(),
55                route_start.elapsed().as_millis() as u64,
56            )
57            .with_tag("outcome", if result.is_ok() { "ok" } else { "error" })
58            .send();
59        result
60    }
61}
62
63impl WebPushRouter {
64    async fn route_notification_inner(
65        &self,
66        mut notification: Notification,
67    ) -> ApiResult<RouterResponse> {
68        // The notification contains the original subscription information.
69        // Extract user fields upfront to avoid borrow conflicts with
70        // record_reliability's &mut self requirement.
71        let notif_user = &notification.subscription.user;
72        let uaid = notif_user.uaid;
73        let node_id = notif_user.node_id.clone();
74        debug!(
75            "✉ Routing WebPush notification to UAID {} :: {:?}",
76            uaid, notification.subscription.reliability_id,
77        );
78        trace!("✉ Notification = {:?}", notification);
79
80        // Check if there is a node connected to the client
81        if let Some(node_id) = node_id {
82            trace!(
83                "✉ User has a node ID, sending notification to node: {}",
84                &node_id
85            );
86
87            #[cfg(feature = "reliable_report")]
88            let revert_state = notification.reliable_state;
89            #[cfg(feature = "reliable_report")]
90            notification
91                .record_reliability(
92                    &self.reliability,
93                    autopush_common::reliability::ReliabilityState::IntTransmitted,
94                )
95                .await;
96            let send_start = Instant::now();
97            match self.send_notification(&notification, &node_id).await {
98                Ok(response) => {
99                    let elapsed = send_start.elapsed().as_millis() as u64;
100                    let status = response.status().as_u16();
101                    self.metrics
102                        .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
103                        .with_tag("status", &status.to_string())
104                        .send();
105                    self.metrics
106                        .incr_with_tags(MetricName::DirectDeliveryStatus)
107                        .with_tag("status", &status.to_string())
108                        .send();
109                    // The node might be busy, make sure it accepted the notification
110                    if status == 200 {
111                        // The node has received the notification
112                        trace!("✉ Node received notification");
113                        return Ok(self.make_delivered_response(&notification));
114                    }
115                    trace!(
116                        "✉ Node did not receive the notification, response = {:?}",
117                        response
118                    );
119                }
120                Err(error) => {
121                    let elapsed = send_start.elapsed().as_millis() as u64;
122                    let status_tag = if let ApiErrorKind::ReqwestError(error) = &error.kind {
123                        if error.is_timeout() {
124                            self.metrics.incr(MetricName::ErrorNodeTimeout)?;
125                            "timeout"
126                        } else if error.is_connect() {
127                            self.metrics.incr(MetricName::ErrorNodeConnect)?;
128                            "connect_error"
129                        } else {
130                            "error"
131                        }
132                    } else {
133                        "error"
134                    };
135                    self.metrics
136                        .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
137                        .with_tag("status", status_tag)
138                        .send();
139                    self.metrics
140                        .incr_with_tags(MetricName::DirectDeliveryStatus)
141                        .with_tag("status", status_tag)
142                        .send();
143                    debug!("✉ Error while sending webpush notification: {}", error);
144                    self.remove_node_id(&notification.subscription.user, &node_id)
145                        .await?
146                }
147            }
148
149            #[cfg(feature = "reliable_report")]
150            // Couldn't send the message! So revert to the prior state if we have one
151            if let Some(revert_state) = revert_state {
152                trace!(
153                    "🔎⚠️ Revert {:?} from {:?} to {:?}",
154                    &notification.reliability_id, &notification.reliable_state, revert_state
155                );
156                notification
157                    .record_reliability(&self.reliability, revert_state)
158                    .await;
159            }
160        }
161
162        if notification.headers.ttl == 0 {
163            let topic = notification.headers.topic.is_some().to_string();
164            trace!(
165                "✉ Notification has a TTL of zero and was not successfully \
166                 delivered, dropping it"
167            );
168            self.metrics
169                .incr_with_tags(MetricName::NotificationMessageExpired)
170                // TODO: include `internal` if meta is set.
171                .with_tag("topic", &topic)
172                .send();
173            #[cfg(feature = "reliable_report")]
174            notification
175                .record_reliability(
176                    &self.reliability,
177                    autopush_common::reliability::ReliabilityState::Expired,
178                )
179                .await;
180            return Ok(self.make_delivered_response(&notification));
181        }
182
183        // Save notification, node is not present or busy
184        trace!("✉ Node is not present or busy, storing notification");
185        let store_start = Instant::now();
186        self.store_notification(&mut notification).await?;
187        self.metrics
188            .time_with_tags(
189                MetricName::StorageSaveTime.as_ref(),
190                store_start.elapsed().as_millis() as u64,
191            )
192            .send();
193
194        // Retrieve the user data again, they may have reconnected or the node
195        // is no longer busy.
196        let user = match self.db.get_user(&uaid).await {
197            Ok(Some(user)) => user,
198            Ok(None) => {
199                trace!("✉ No user found, must have been deleted");
200                return Err(self.handle_error(
201                    ApiErrorKind::Router(RouterError::UserWasDeleted),
202                    notification.subscription.vapid.clone(),
203                ));
204            }
205            Err(e) => {
206                // Database error, but we already stored the message so it's ok
207                debug!("✉ Database error while re-fetching user: {}", e);
208                return Ok(self.make_stored_response(&notification));
209            }
210        };
211
212        // Try to notify the node the user is currently connected to
213        let node_id = match &user.node_id {
214            Some(id) => id,
215            // The user is not connected to a node, nothing more to do
216            None => {
217                trace!("✉ User is not connected to a node, returning stored response");
218                return Ok(self.make_stored_response(&notification));
219            }
220        };
221
222        // Notify the node to check for messages
223        trace!("✉ Notifying node to check for messages");
224        match self.trigger_notification_check(&user.uaid, node_id).await {
225            Ok(response) => {
226                trace!("Response = {:?}", response);
227                if response.status() == 200 {
228                    trace!("✉ Node has delivered the message");
229                    self.metrics
230                        .time_with_tags(
231                            MetricName::NotificationTotalRequestTime.as_ref(),
232                            (notification.timestamp - autopush_common::util::sec_since_epoch())
233                                * 1000,
234                        )
235                        .with_tag("platform", "websocket")
236                        .with_tag("app_id", "direct")
237                        .send();
238
239                    Ok(self.make_delivered_response(&notification))
240                } else {
241                    trace!("✉ Node has not delivered the message, returning stored response");
242                    Ok(self.make_stored_response(&notification))
243                }
244            }
245            Err(error) => {
246                // Can't communicate with the node, attempt to stop using it
247                debug!("✉ Error while triggering notification check: {}", error);
248                self.remove_node_id(&user, node_id).await?;
249                Ok(self.make_stored_response(&notification))
250            }
251        }
252    }
253
254    /// Use the same sort of error chokepoint that all the mobile clients use.
255    fn handle_error(&self, error: ApiErrorKind, vapid: Option<VapidHeaderWithKey>) -> ApiError {
256        let mut err = ApiError::from(error);
257        if let Some(Ok(claims)) = vapid.map(|v| v.vapid.claims()) {
258            let mut extras = err.extras.unwrap_or_default();
259            if let Some(sub) = claims.sub {
260                extras.extend([("sub".to_owned(), sub)]);
261            }
262            err.extras = Some(extras);
263        };
264        err
265    }
266
267    /// Consume and send the notification to the node
268    async fn send_notification(
269        &self,
270        notification: &Notification,
271        node_id: &str,
272    ) -> ApiResult<Response> {
273        let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
274
275        let notification_out = notification.serialize_for_delivery()?;
276
277        trace!(
278            "⏩ out: Notification: {}, channel_id: {} :: {:?}",
279            &notification.subscription.user.uaid,
280            &notification.subscription.channel_id,
281            &notification_out,
282        );
283        self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
284        let result = self.http.put(&url).json(&notification_out).send().await;
285        self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
286        Ok(result?)
287    }
288
289    /// Notify the node to check for notifications for the user
290    async fn trigger_notification_check(
291        &self,
292        uaid: &Uuid,
293        node_id: &str,
294    ) -> Result<Response, reqwest::Error> {
295        let url = format!("{node_id}/notif/{uaid}");
296
297        self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
298        let result = self.http.put(&url).send().await;
299        self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
300        result
301    }
302
303    /// Store a notification in the database
304    async fn store_notification(&self, notification: &mut Notification) -> ApiResult<()> {
305        let result = self
306            .db
307            .save_message(
308                &notification.subscription.user.uaid,
309                autopush_common::notification::Notification::from(&*notification),
310            )
311            .await
312            .map_err(|e| {
313                self.handle_error(
314                    ApiErrorKind::Router(RouterError::SaveDb(
315                        e,
316                        // try to extract the `sub` from the VAPID claims.
317                        notification.subscription.vapid.as_ref().map(|vapid| {
318                            vapid
319                                .vapid
320                                .claims()
321                                .ok()
322                                .and_then(|c| c.sub)
323                                .unwrap_or_default()
324                        }),
325                    )),
326                    notification.subscription.vapid.clone(),
327                )
328            });
329        #[cfg(feature = "reliable_report")]
330        notification
331            .record_reliability(
332                &self.reliability,
333                autopush_common::reliability::ReliabilityState::Stored,
334            )
335            .await;
336        result
337    }
338
339    /// Remove the node ID from a user. This is done if the user is no longer
340    /// connected to the node.
341    async fn remove_node_id(&self, user: &User, node_id: &str) -> ApiResult<()> {
342        self.metrics.incr(MetricName::UpdatesClientHostGone).ok();
343        let removed = self
344            .db
345            .remove_node_id(&user.uaid, node_id, user.connected_at, &user.version)
346            .await?;
347        if !removed {
348            debug!("✉ The node id was not removed");
349        }
350        Ok(())
351    }
352
353    /// Update metrics and create a response for when a notification has been directly forwarded to
354    /// an autopush server.
355    fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
356        self.make_response(notification, "Direct", StatusCode::CREATED)
357    }
358
359    /// Update metrics and create a response for when a notification has been stored in the database
360    /// for future transmission.
361    fn make_stored_response(&self, notification: &Notification) -> RouterResponse {
362        self.make_response(notification, "Stored", StatusCode::CREATED)
363    }
364
365    /// Update metrics and create a response after routing a notification
366    fn make_response(
367        &self,
368        notification: &Notification,
369        destination_tag: &str,
370        status: StatusCode,
371    ) -> RouterResponse {
372        self.metrics
373            .count_with_tags(
374                MetricName::NotificationMessageData.as_ref(),
375                notification.data.as_ref().map(String::len).unwrap_or(0) as i64,
376            )
377            .with_tag("destination", destination_tag)
378            .send();
379
380        RouterResponse {
381            status: actix_http::StatusCode::from_u16(status.as_u16()).unwrap_or_default(),
382            headers: {
383                let mut map = HashMap::new();
384                map.insert(
385                    "Location",
386                    self.endpoint_url
387                        .join(&format!("/m/{}", notification.message_id))
388                        .expect("Message ID is not URL-safe")
389                        .to_string(),
390                );
391                map.insert("TTL", notification.headers.ttl.to_string());
392                map
393            },
394            body: None,
395        }
396    }
397}
398
#[cfg(test)]
mod test {
    use std::boxed::Box;
    use std::sync::Arc;

    use reqwest;

    use crate::extractors::subscription::tests::{PUB_KEY, make_vapid};
    use crate::headers::vapid::VapidClaims;
    use autopush_common::errors::ReportableError;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};

    use super::*;
    use autopush_common::db::mock::MockDbClient;

    /// Build a router around `db` that uses a no-op metrics sink and a
    /// localhost endpoint URL.
    fn make_router(db: Box<dyn DbClient>) -> WebPushRouter {
        let sink = cadence::NopMetricSink;
        let metrics = Arc::new(StatsdClient::builder("", sink).build());
        let endpoint_url = Url::parse("http://localhost:8080/").unwrap();
        WebPushRouter {
            db: db.clone(),
            metrics: metrics.clone(),
            http: reqwest::Client::new(),
            endpoint_url,
            in_flight_requests: Arc::new(AtomicUsize::new(0)),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db, &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// The VAPID `sub` claim must show up in the extras of an error built by
    /// `handle_error`.
    #[tokio::test]
    async fn pass_extras() {
        let router = make_router(MockDbClient::new().into_boxed_arc());
        let expected_sub = "foo@example.com";
        let vapid = make_vapid(
            expected_sub,
            "https://push.services.mozilla.org",
            VapidClaims::default_exp(),
            PUB_KEY.to_owned(),
        );

        let err = router.handle_error(ApiErrorKind::LogCheck, Some(vapid));
        let extras = err.extras();
        assert!(extras.contains(&("sub", expected_sub.to_owned())));
    }
}
445}