autoendpoint/routers/
webpush.rs

1use async_trait::async_trait;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::PushReliability;
4use cadence::{Counted, StatsdClient, Timed};
5use reqwest::{Response, StatusCode};
6use serde_json::Value;
7use std::collections::{hash_map::RandomState, HashMap};
8use std::sync::Arc;
9use url::Url;
10use uuid::Uuid;
11
12use crate::error::{ApiError, ApiErrorKind, ApiResult};
13use crate::extractors::{notification::Notification, router_data_input::RouterDataInput};
14use crate::headers::vapid::VapidHeaderWithKey;
15use crate::routers::{Router, RouterError, RouterResponse};
16
17use autopush_common::db::{client::DbClient, User};
18use autopush_common::metric_name::MetricName;
19use autopush_common::metrics::StatsdClientExt;
20
21/// The router for desktop user agents.
22///
23/// These agents are connected via an Autopush connection server. The correct
24/// server is located via the database routing table. If the server is busy or
25/// not available, the notification is stored in the database.
26pub struct WebPushRouter {
27    pub db: Box<dyn DbClient>,
28    pub metrics: Arc<StatsdClient>,
29    pub http: reqwest::Client,
30    pub endpoint_url: Url,
31    #[cfg(feature = "reliable_report")]
32    pub reliability: Arc<PushReliability>,
33}
34
35#[async_trait(?Send)]
36impl Router for WebPushRouter {
37    fn register(
38        &self,
39        _router_input: &RouterDataInput,
40        _app_id: &str,
41    ) -> Result<HashMap<String, Value, RandomState>, RouterError> {
42        // WebPush registration happens through the connection server
43        Ok(HashMap::new())
44    }
45
46    async fn route_notification(
47        &self,
48        mut notification: Notification,
49    ) -> ApiResult<RouterResponse> {
50        // The notification contains the original subscription information
51        let user = &notification.subscription.user.clone();
52        // A clone of the notification used only for the responses
53        // The canonical Notification is consumed by the various functions.
54        debug!(
55            "✉ Routing WebPush notification to UAID {} :: {:?}",
56            notification.subscription.user.uaid, notification.subscription.reliability_id,
57        );
58        trace!("✉ Notification = {:?}", notification);
59
60        // Check if there is a node connected to the client
61        if let Some(node_id) = &user.node_id {
62            trace!(
63                "✉ User has a node ID, sending notification to node: {}",
64                &node_id
65            );
66
67            #[cfg(feature = "reliable_report")]
68            let revert_state = notification.reliable_state;
69            #[cfg(feature = "reliable_report")]
70            notification
71                .record_reliability(
72                    &self.reliability,
73                    autopush_common::reliability::ReliabilityState::IntTransmitted,
74                )
75                .await;
76            match self.send_notification(&notification, node_id).await {
77                Ok(response) => {
78                    // The node might be busy, make sure it accepted the notification
79                    if response.status() == 200 {
80                        // The node has received the notification
81                        trace!("✉ Node received notification");
82                        return Ok(self.make_delivered_response(&notification));
83                    }
84                    trace!(
85                        "✉ Node did not receive the notification, response = {:?}",
86                        response
87                    );
88                }
89                Err(error) => {
90                    if let ApiErrorKind::ReqwestError(error) = &error.kind {
91                        if error.is_timeout() {
92                            self.metrics.incr(MetricName::ErrorNodeTimeout)?;
93                        };
94                        if error.is_connect() {
95                            self.metrics.incr(MetricName::ErrorNodeConnect)?;
96                        };
97                    };
98                    debug!("✉ Error while sending webpush notification: {}", error);
99                    self.remove_node_id(user, node_id).await?
100                }
101            }
102
103            #[cfg(feature = "reliable_report")]
104            // Couldn't send the message! So revert to the prior state if we have one
105            if let Some(revert_state) = revert_state {
106                trace!(
107                    "🔎⚠️ Revert {:?} from {:?} to {:?}",
108                    &notification.reliability_id,
109                    &notification.reliable_state,
110                    revert_state
111                );
112                notification
113                    .record_reliability(&self.reliability, revert_state)
114                    .await;
115            }
116        }
117
118        if notification.headers.ttl == 0 {
119            let topic = notification.headers.topic.is_some().to_string();
120            trace!(
121                "✉ Notification has a TTL of zero and was not successfully \
122                 delivered, dropping it"
123            );
124            self.metrics
125                .incr_with_tags(MetricName::NotificationMessageExpired)
126                // TODO: include `internal` if meta is set.
127                .with_tag("topic", &topic)
128                .send();
129            #[cfg(feature = "reliable_report")]
130            notification
131                .record_reliability(
132                    &self.reliability,
133                    autopush_common::reliability::ReliabilityState::Expired,
134                )
135                .await;
136            return Ok(self.make_delivered_response(&notification));
137        }
138
139        // Save notification, node is not present or busy
140        trace!("✉ Node is not present or busy, storing notification");
141        self.store_notification(&mut notification).await?;
142
143        // Retrieve the user data again, they may have reconnected or the node
144        // is no longer busy.
145        let user = match self.db.get_user(&user.uaid).await {
146            Ok(Some(user)) => user,
147            Ok(None) => {
148                trace!("✉ No user found, must have been deleted");
149                return Err(self.handle_error(
150                    ApiErrorKind::Router(RouterError::UserWasDeleted),
151                    notification.subscription.vapid.clone(),
152                ));
153            }
154            Err(e) => {
155                // Database error, but we already stored the message so it's ok
156                debug!("✉ Database error while re-fetching user: {}", e);
157                return Ok(self.make_stored_response(&notification));
158            }
159        };
160
161        // Try to notify the node the user is currently connected to
162        let node_id = match &user.node_id {
163            Some(id) => id,
164            // The user is not connected to a node, nothing more to do
165            None => {
166                trace!("✉ User is not connected to a node, returning stored response");
167                return Ok(self.make_stored_response(&notification));
168            }
169        };
170
171        // Notify the node to check for messages
172        trace!("✉ Notifying node to check for messages");
173        match self.trigger_notification_check(&user.uaid, node_id).await {
174            Ok(response) => {
175                trace!("Response = {:?}", response);
176                if response.status() == 200 {
177                    trace!("✉ Node has delivered the message");
178                    self.metrics
179                        .time_with_tags(
180                            MetricName::NotificationTotalRequestTime.as_ref(),
181                            (notification.timestamp - autopush_common::util::sec_since_epoch())
182                                * 1000,
183                        )
184                        .with_tag("platform", "websocket")
185                        .with_tag("app_id", "direct")
186                        .send();
187
188                    Ok(self.make_delivered_response(&notification))
189                } else {
190                    trace!("✉ Node has not delivered the message, returning stored response");
191                    Ok(self.make_stored_response(&notification))
192                }
193            }
194            Err(error) => {
195                // Can't communicate with the node, attempt to stop using it
196                debug!("✉ Error while triggering notification check: {}", error);
197                self.remove_node_id(&user, node_id).await?;
198                Ok(self.make_stored_response(&notification))
199            }
200        }
201    }
202}
203
204impl WebPushRouter {
205    /// Use the same sort of error chokepoint that all the mobile clients use.
206    fn handle_error(&self, error: ApiErrorKind, vapid: Option<VapidHeaderWithKey>) -> ApiError {
207        let mut err = ApiError::from(error);
208        if let Some(Ok(claims)) = vapid.map(|v| v.vapid.claims()) {
209            let mut extras = err.extras.unwrap_or_default();
210            if let Some(sub) = claims.sub {
211                extras.extend([("sub".to_owned(), sub)]);
212            }
213            err.extras = Some(extras);
214        };
215        err
216    }
217
218    /// Consume and send the notification to the node
219    async fn send_notification(
220        &self,
221        notification: &Notification,
222        node_id: &str,
223    ) -> ApiResult<Response> {
224        let url = format!("{}/push/{}", node_id, notification.subscription.user.uaid);
225
226        let notification_out = notification.serialize_for_delivery()?;
227
228        trace!(
229            "⏩ out: Notification: {}, channel_id: {} :: {:?}",
230            &notification.subscription.user.uaid,
231            &notification.subscription.channel_id,
232            &notification_out,
233        );
234        Ok(self.http.put(&url).json(&notification_out).send().await?)
235    }
236
237    /// Notify the node to check for notifications for the user
238    async fn trigger_notification_check(
239        &self,
240        uaid: &Uuid,
241        node_id: &str,
242    ) -> Result<Response, reqwest::Error> {
243        let url = format!("{node_id}/notif/{uaid}");
244
245        self.http.put(&url).send().await
246    }
247
248    /// Store a notification in the database
249    async fn store_notification(&self, notification: &mut Notification) -> ApiResult<()> {
250        let result = self
251            .db
252            .save_message(
253                &notification.subscription.user.uaid,
254                notification.clone().into(),
255            )
256            .await
257            .map_err(|e| {
258                self.handle_error(
259                    ApiErrorKind::Router(RouterError::SaveDb(
260                        e,
261                        // try to extract the `sub` from the VAPID claims.
262                        notification.subscription.vapid.as_ref().map(|vapid| {
263                            vapid
264                                .vapid
265                                .claims()
266                                .ok()
267                                .and_then(|c| c.sub)
268                                .unwrap_or_default()
269                        }),
270                    )),
271                    notification.subscription.vapid.clone(),
272                )
273            });
274        #[cfg(feature = "reliable_report")]
275        notification
276            .record_reliability(
277                &self.reliability,
278                autopush_common::reliability::ReliabilityState::Stored,
279            )
280            .await;
281        result
282    }
283
284    /// Remove the node ID from a user. This is done if the user is no longer
285    /// connected to the node.
286    async fn remove_node_id(&self, user: &User, node_id: &str) -> ApiResult<()> {
287        self.metrics.incr(MetricName::UpdatesClientHostGone).ok();
288        let removed = self
289            .db
290            .remove_node_id(&user.uaid, node_id, user.connected_at, &user.version)
291            .await?;
292        if !removed {
293            debug!("✉ The node id was not removed");
294        }
295        Ok(())
296    }
297
298    /// Update metrics and create a response for when a notification has been directly forwarded to
299    /// an autopush server.
300    fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
301        self.make_response(notification, "Direct", StatusCode::CREATED)
302    }
303
304    /// Update metrics and create a response for when a notification has been stored in the database
305    /// for future transmission.
306    fn make_stored_response(&self, notification: &Notification) -> RouterResponse {
307        self.make_response(notification, "Stored", StatusCode::CREATED)
308    }
309
310    /// Update metrics and create a response after routing a notification
311    fn make_response(
312        &self,
313        notification: &Notification,
314        destination_tag: &str,
315        status: StatusCode,
316    ) -> RouterResponse {
317        self.metrics
318            .count_with_tags(
319                MetricName::NotificationMessageData.as_ref(),
320                notification.data.as_ref().map(String::len).unwrap_or(0) as i64,
321            )
322            .with_tag("destination", destination_tag)
323            .send();
324
325        RouterResponse {
326            status: actix_http::StatusCode::from_u16(status.as_u16()).unwrap_or_default(),
327            headers: {
328                let mut map = HashMap::new();
329                map.insert(
330                    "Location",
331                    self.endpoint_url
332                        .join(&format!("/m/{}", notification.message_id))
333                        .expect("Message ID is not URL-safe")
334                        .to_string(),
335                );
336                map.insert("TTL", notification.headers.ttl.to_string());
337                map
338            },
339            body: None,
340        }
341    }
342}
343
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::extractors::subscription::tests::{make_vapid, PUB_KEY};
    use crate::headers::vapid::VapidClaims;
    // Presumably the trait that provides `extras()` on `ApiError` — confirm.
    use autopush_common::errors::ReportableError;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};

    use super::*;
    use autopush_common::db::mock::MockDbClient;

    /// Build a router around `db` with a no-op metrics sink and a local
    /// endpoint URL.
    ///
    /// A single metrics client is shared by the router and (when the
    /// `reliable_report` feature is on) the reliability tracker, so the
    /// binding is used in every feature configuration.
    fn make_router(db: Box<dyn DbClient>) -> WebPushRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        WebPushRouter {
            db: db.clone(),
            metrics: metrics.clone(),
            http: reqwest::Client::new(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db, &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// `handle_error` must copy the VAPID `sub` claim into the error extras.
    #[tokio::test]
    async fn pass_extras() {
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(db);
        let sub = "foo@example.com";
        let vapid = make_vapid(
            sub,
            "https://push.services.mozilla.org",
            VapidClaims::default_exp(),
            PUB_KEY.to_owned(),
        );

        let err = router.handle_error(ApiErrorKind::LogCheck, Some(vapid));
        assert!(err.extras().contains(&("sub", sub.to_owned())));
    }
}