// autoendpoint/routers/webpush.rs

1use async_trait::async_trait;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::PushReliability;
4use cadence::{Counted, StatsdClient, Timed};
5use reqwest::{Response, StatusCode};
6use serde_json::Value;
7use std::collections::{hash_map::RandomState, HashMap};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Instant;
11use url::Url;
12use uuid::Uuid;
13
14use crate::error::{ApiError, ApiErrorKind, ApiResult};
15use crate::extractors::{notification::Notification, router_data_input::RouterDataInput};
16use crate::headers::vapid::VapidHeaderWithKey;
17use crate::routers::{Router, RouterError, RouterResponse};
18
19use autopush_common::db::{client::DbClient, User};
20use autopush_common::metric_name::MetricName;
21use autopush_common::metrics::StatsdClientExt;
22
23/// The router for desktop user agents.
24///
25/// These agents are connected via an Autopush connection server. The correct
26/// server is located via the database routing table. If the server is busy or
27/// not available, the notification is stored in the database.
28pub struct WebPushRouter {
29    pub db: Box<dyn DbClient>,
30    pub metrics: Arc<StatsdClient>,
31    pub http: reqwest::Client,
32    pub endpoint_url: Url,
33    pub in_flight_requests: Arc<AtomicUsize>,
34    #[cfg(feature = "reliable_report")]
35    pub reliability: Arc<PushReliability>,
36}
37
38#[async_trait(?Send)]
39impl Router for WebPushRouter {
40    fn register(
41        &self,
42        _router_input: &RouterDataInput,
43        _app_id: &str,
44    ) -> Result<HashMap<String, Value, RandomState>, RouterError> {
45        // WebPush registration happens through the connection server
46        Ok(HashMap::new())
47    }
48
49    async fn route_notification(&self, notification: Notification) -> ApiResult<RouterResponse> {
50        let route_start = Instant::now();
51        let result = self.route_notification_inner(notification).await;
52        self.metrics
53            .time_with_tags(
54                MetricName::NotificationRouteTime.as_ref(),
55                route_start.elapsed().as_millis() as u64,
56            )
57            .with_tag("outcome", if result.is_ok() { "ok" } else { "error" })
58            .send();
59        result
60    }
61}
62
63impl WebPushRouter {
64    async fn route_notification_inner(
65        &self,
66        mut notification: Notification,
67    ) -> ApiResult<RouterResponse> {
68        // The notification contains the original subscription information.
69        // Extract user fields upfront to avoid borrow conflicts with
70        // record_reliability's &mut self requirement.
71        let notif_user = &notification.subscription.user;
72        let uaid = notif_user.uaid;
73        let node_id = notif_user.node_id.clone();
74        debug!(
75            "✉ Routing WebPush notification to UAID {} :: {:?}",
76            uaid, notification.subscription.reliability_id,
77        );
78        trace!("✉ Notification = {:?}", notification);
79
80        // Check if there is a node connected to the client
81        if let Some(node_id) = node_id {
82            trace!(
83                "✉ User has a node ID, sending notification to node: {}",
84                &node_id
85            );
86
87            #[cfg(feature = "reliable_report")]
88            let revert_state = notification.reliable_state;
89            #[cfg(feature = "reliable_report")]
90            notification
91                .record_reliability(
92                    &self.reliability,
93                    autopush_common::reliability::ReliabilityState::IntTransmitted,
94                )
95                .await;
96            let send_start = Instant::now();
97            match self.send_notification(&notification, &node_id).await {
98                Ok(response) => {
99                    let elapsed = send_start.elapsed().as_millis() as u64;
100                    let status = response.status().as_u16();
101                    self.metrics
102                        .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
103                        .with_tag("status", &status.to_string())
104                        .send();
105                    self.metrics
106                        .incr_with_tags(MetricName::DirectDeliveryStatus)
107                        .with_tag("status", &status.to_string())
108                        .send();
109                    // The node might be busy, make sure it accepted the notification
110                    if status == 200 {
111                        // The node has received the notification
112                        trace!("✉ Node received notification");
113                        return Ok(self.make_delivered_response(&notification));
114                    }
115                    trace!(
116                        "✉ Node did not receive the notification, response = {:?}",
117                        response
118                    );
119                }
120                Err(error) => {
121                    let elapsed = send_start.elapsed().as_millis() as u64;
122                    let status_tag = if let ApiErrorKind::ReqwestError(error) = &error.kind {
123                        if error.is_timeout() {
124                            self.metrics.incr(MetricName::ErrorNodeTimeout)?;
125                            "timeout"
126                        } else if error.is_connect() {
127                            self.metrics.incr(MetricName::ErrorNodeConnect)?;
128                            "connect_error"
129                        } else {
130                            "error"
131                        }
132                    } else {
133                        "error"
134                    };
135                    self.metrics
136                        .time_with_tags(MetricName::DirectDeliveryTime.as_ref(), elapsed)
137                        .with_tag("status", status_tag)
138                        .send();
139                    self.metrics
140                        .incr_with_tags(MetricName::DirectDeliveryStatus)
141                        .with_tag("status", status_tag)
142                        .send();
143                    debug!("✉ Error while sending webpush notification: {}", error);
144                    self.remove_node_id(&notification.subscription.user, &node_id)
145                        .await?
146                }
147            }
148
149            #[cfg(feature = "reliable_report")]
150            // Couldn't send the message! So revert to the prior state if we have one
151            if let Some(revert_state) = revert_state {
152                trace!(
153                    "🔎⚠️ Revert {:?} from {:?} to {:?}",
154                    &notification.reliability_id,
155                    &notification.reliable_state,
156                    revert_state
157                );
158                notification
159                    .record_reliability(&self.reliability, revert_state)
160                    .await;
161            }
162        }
163
164        if notification.headers.ttl == 0 {
165            let topic = notification.headers.topic.is_some().to_string();
166            trace!(
167                "✉ Notification has a TTL of zero and was not successfully \
168                 delivered, dropping it"
169            );
170            self.metrics
171                .incr_with_tags(MetricName::NotificationMessageExpired)
172                // TODO: include `internal` if meta is set.
173                .with_tag("topic", &topic)
174                .send();
175            #[cfg(feature = "reliable_report")]
176            notification
177                .record_reliability(
178                    &self.reliability,
179                    autopush_common::reliability::ReliabilityState::Expired,
180                )
181                .await;
182            return Ok(self.make_delivered_response(&notification));
183        }
184
185        // Save notification, node is not present or busy
186        trace!("✉ Node is not present or busy, storing notification");
187        let store_start = Instant::now();
188        self.store_notification(&mut notification).await?;
189        self.metrics
190            .time_with_tags(
191                MetricName::StorageSaveTime.as_ref(),
192                store_start.elapsed().as_millis() as u64,
193            )
194            .send();
195
196        // Retrieve the user data again, they may have reconnected or the node
197        // is no longer busy.
198        let user = match self.db.get_user(&uaid).await {
199            Ok(Some(user)) => user,
200            Ok(None) => {
201                trace!("✉ No user found, must have been deleted");
202                return Err(self.handle_error(
203                    ApiErrorKind::Router(RouterError::UserWasDeleted),
204                    notification.subscription.vapid.clone(),
205                ));
206            }
207            Err(e) => {
208                // Database error, but we already stored the message so it's ok
209                debug!("✉ Database error while re-fetching user: {}", e);
210                return Ok(self.make_stored_response(&notification));
211            }
212        };
213
214        // Try to notify the node the user is currently connected to
215        let node_id = match &user.node_id {
216            Some(id) => id,
217            // The user is not connected to a node, nothing more to do
218            None => {
219                trace!("✉ User is not connected to a node, returning stored response");
220                return Ok(self.make_stored_response(&notification));
221            }
222        };
223
224        // Notify the node to check for messages
225        trace!("✉ Notifying node to check for messages");
226        match self.trigger_notification_check(&user.uaid, node_id).await {
227            Ok(response) => {
228                trace!("Response = {:?}", response);
229                if response.status() == 200 {
230                    trace!("✉ Node has delivered the message");
231                    self.metrics
232                        .time_with_tags(
233                            MetricName::NotificationTotalRequestTime.as_ref(),
234                            (notification.timestamp - autopush_common::util::sec_since_epoch())
235                                * 1000,
236                        )
237                        .with_tag("platform", "websocket")
238                        .with_tag("app_id", "direct")
239                        .send();
240
241                    Ok(self.make_delivered_response(&notification))
242                } else {
243                    trace!("✉ Node has not delivered the message, returning stored response");
244                    Ok(self.make_stored_response(&notification))
245                }
246            }
247            Err(error) => {
248                // Can't communicate with the node, attempt to stop using it
249                debug!("✉ Error while triggering notification check: {}", error);
250                self.remove_node_id(&user, node_id).await?;
251                Ok(self.make_stored_response(&notification))
252            }
253        }
254    }
255
256    /// Use the same sort of error chokepoint that all the mobile clients use.
257    fn handle_error(&self, error: ApiErrorKind, vapid: Option<VapidHeaderWithKey>) -> ApiError {
258        let mut err = ApiError::from(error);
259        if let Some(Ok(claims)) = vapid.map(|v| v.vapid.claims()) {
260            let mut extras = err.extras.unwrap_or_default();
261            if let Some(sub) = claims.sub {
262                extras.extend([("sub".to_owned(), sub)]);
263            }
264            err.extras = Some(extras);
265        };
266        err
267    }
268
269    /// Consume and send the notification to the node
270    async fn send_notification(
271        &self,
272        notification: &Notification,
273        node_id: &str,
274    ) -> ApiResult<Response> {
275        let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
276
277        let notification_out = notification.serialize_for_delivery()?;
278
279        trace!(
280            "⏩ out: Notification: {}, channel_id: {} :: {:?}",
281            &notification.subscription.user.uaid,
282            &notification.subscription.channel_id,
283            &notification_out,
284        );
285        self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
286        let result = self.http.put(&url).json(&notification_out).send().await;
287        self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
288        Ok(result?)
289    }
290
291    /// Notify the node to check for notifications for the user
292    async fn trigger_notification_check(
293        &self,
294        uaid: &Uuid,
295        node_id: &str,
296    ) -> Result<Response, reqwest::Error> {
297        let url = format!("{node_id}/notif/{uaid}");
298
299        self.in_flight_requests.fetch_add(1, Ordering::Relaxed);
300        let result = self.http.put(&url).send().await;
301        self.in_flight_requests.fetch_sub(1, Ordering::Relaxed);
302        result
303    }
304
305    /// Store a notification in the database
306    async fn store_notification(&self, notification: &mut Notification) -> ApiResult<()> {
307        let result = self
308            .db
309            .save_message(
310                &notification.subscription.user.uaid,
311                autopush_common::notification::Notification::from(&*notification),
312            )
313            .await
314            .map_err(|e| {
315                self.handle_error(
316                    ApiErrorKind::Router(RouterError::SaveDb(
317                        e,
318                        // try to extract the `sub` from the VAPID claims.
319                        notification.subscription.vapid.as_ref().map(|vapid| {
320                            vapid
321                                .vapid
322                                .claims()
323                                .ok()
324                                .and_then(|c| c.sub)
325                                .unwrap_or_default()
326                        }),
327                    )),
328                    notification.subscription.vapid.clone(),
329                )
330            });
331        #[cfg(feature = "reliable_report")]
332        notification
333            .record_reliability(
334                &self.reliability,
335                autopush_common::reliability::ReliabilityState::Stored,
336            )
337            .await;
338        result
339    }
340
341    /// Remove the node ID from a user. This is done if the user is no longer
342    /// connected to the node.
343    async fn remove_node_id(&self, user: &User, node_id: &str) -> ApiResult<()> {
344        self.metrics.incr(MetricName::UpdatesClientHostGone).ok();
345        let removed = self
346            .db
347            .remove_node_id(&user.uaid, node_id, user.connected_at, &user.version)
348            .await?;
349        if !removed {
350            debug!("✉ The node id was not removed");
351        }
352        Ok(())
353    }
354
355    /// Update metrics and create a response for when a notification has been directly forwarded to
356    /// an autopush server.
357    fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
358        self.make_response(notification, "Direct", StatusCode::CREATED)
359    }
360
361    /// Update metrics and create a response for when a notification has been stored in the database
362    /// for future transmission.
363    fn make_stored_response(&self, notification: &Notification) -> RouterResponse {
364        self.make_response(notification, "Stored", StatusCode::CREATED)
365    }
366
367    /// Update metrics and create a response after routing a notification
368    fn make_response(
369        &self,
370        notification: &Notification,
371        destination_tag: &str,
372        status: StatusCode,
373    ) -> RouterResponse {
374        self.metrics
375            .count_with_tags(
376                MetricName::NotificationMessageData.as_ref(),
377                notification.data.as_ref().map(String::len).unwrap_or(0) as i64,
378            )
379            .with_tag("destination", destination_tag)
380            .send();
381
382        RouterResponse {
383            status: actix_http::StatusCode::from_u16(status.as_u16()).unwrap_or_default(),
384            headers: {
385                let mut map = HashMap::new();
386                map.insert(
387                    "Location",
388                    self.endpoint_url
389                        .join(&format!("/m/{}", notification.message_id))
390                        .expect("Message ID is not URL-safe")
391                        .to_string(),
392                );
393                map.insert("TTL", notification.headers.ttl.to_string());
394                map
395            },
396            body: None,
397        }
398    }
399}
400
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::extractors::subscription::tests::{make_vapid, PUB_KEY};
    use crate::headers::vapid::VapidClaims;
    use autopush_common::errors::ReportableError;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};

    use super::*;
    use autopush_common::db::mock::MockDbClient;

    /// Build a `WebPushRouter` over `db` with no-op metrics, a default HTTP client and a
    /// localhost endpoint URL. The reliability tracker (feature-gated) shares a clone of `db`.
    fn make_router(db: Box<dyn DbClient>) -> WebPushRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        #[cfg(feature = "reliable_report")]
        let reliability = Arc::new(
            PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
        );
        WebPushRouter {
            db,
            metrics,
            http: reqwest::Client::new(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            in_flight_requests: Arc::new(AtomicUsize::new(0)),
            #[cfg(feature = "reliable_report")]
            reliability,
        }
    }

    /// `handle_error` must copy the VAPID `sub` claim into the error's extras.
    #[tokio::test]
    async fn pass_extras() {
        let router = make_router(MockDbClient::new().into_boxed_arc());
        let subject = "foo@example.com";
        let vapid_header = make_vapid(
            subject,
            "https://push.services.mozilla.org",
            VapidClaims::default_exp(),
            PUB_KEY.to_owned(),
        );

        let api_error = router.handle_error(ApiErrorKind::LogCheck, Some(vapid_header));
        let expected = ("sub", subject.to_owned());
        assert!(api_error.extras().contains(&expected));
    }
}