autoendpoint/routers/apns/
router.rs

1use autopush_common::db::client::DbClient;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::{PushReliability, ReliabilityState};
4
5use crate::error::{ApiError, ApiResult};
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::apns::error::ApnsError;
9use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
10use crate::routers::common::{
11    build_message_data, incr_error_metric, incr_success_metrics, message_size_check,
12};
13use crate::routers::{Router, RouterError, RouterResponse};
14use a2::{
15    self, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, Priority,
16    Response,
17    request::payload::{Payload, PayloadLike},
18};
19use actix_web::http::StatusCode;
20use async_trait::async_trait;
21use cadence::StatsdClient;
22use futures::{StreamExt, TryStreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::HashMap;
26use std::sync::Arc;
27use url::Url;
28use uuid::Uuid;
29
/// Apple Push Notification Service router.
///
/// Holds one APNS client per configured release channel and sends each
/// notification through the client matching the user's `rel_channel`.
pub struct ApnsRouter {
    /// A map from release channel to APNS client (and its topic).
    clients: HashMap<String, ApnsClientData>,
    /// Router settings; `max_data` bounds the size of outgoing payloads.
    settings: ApnsSettings,
    /// Base URL used to build the message location returned on success.
    endpoint_url: Url,
    /// Metrics client for success/error counters.
    metrics: Arc<StatsdClient>,
    /// Database handle, used to remove users that APNS reports as gone (410).
    db: Box<dyn DbClient>,
    /// Reliability tracker used to record delivery state transitions.
    #[cfg(feature = "reliable_report")]
    reliability: Arc<PushReliability>,
}
41
/// Per-release-channel APNS state.
struct ApnsClientData {
    /// Client used to send notifications for this channel.
    client: Box<dyn ApnsClient>,
    /// Passed as `apns_topic` in the notification options; `create_client`
    /// falls back to `com.mozilla.org.{channel}` when none is configured.
    topic: String,
}
46
/// Abstraction over the APNS client so tests can substitute a mock
/// implementation (see `MockApnsClient` in the tests module).
#[async_trait]
trait ApnsClient: Send + Sync {
    /// Send a single notification payload to APNS.
    async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
}
51
52#[async_trait]
53impl ApnsClient for a2::Client {
54    async fn send(&self, payload: Payload<'_>) -> Result<Response, a2::Error> {
55        self.send(payload).await
56    }
57}
58
/// Deserializable mirror of a2's APS structure.
///
/// a2 does not allow for deserialization of its APS structure, so the fields
/// below are copied from that library. In this file it is used to validate the
/// `aps` block supplied at registration. Keys are kebab-case on the wire
/// (e.g. `content-available`, `mutable-content`, `url-args`), and unknown keys
/// are ignored rather than rejected (no `deny_unknown_fields`).
///
/// NOTE(review): `sound` and `category` are borrowed `&'a str`. serde_json can
/// only hand out a borrowed string when the JSON text has no escape sequences,
/// so an otherwise-valid value like `"sound": "a\"b"` fails to deserialize.
/// Consider `Cow<'a, str>` with `#[serde(borrow)]` if that matters.
#[derive(Deserialize, Serialize, Default, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
#[allow(clippy::upper_case_acronyms)]
pub struct ApsDeser<'a> {
    // The notification content. Can be empty for silent notifications.
    // Note: we overwrite this value, but it is copied and commented out here
    // so that future development notes the change.
    // #[serde(skip_serializing_if = "Option::is_none")]
    //pub alert: Option<a2::request::payload::APSAlert<'a>>,
    /// A number shown on top of the app icon.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub badge: Option<u32>,

    /// The name of the sound file to play when user receives the notification.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sound: Option<&'a str>,

    /// Set to one for silent notifications.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_available: Option<u8>,

    /// When a notification includes the category key, the system displays the
    /// actions for that category as buttons in the banner or alert interface.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<&'a str>,

    /// If set to one, the app can change the notification content before
    /// displaying it to the user.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mutable_content: Option<u8>,

    /// URL arguments; owned here (a2 uses borrowed slices) so it can be deserialized.
    #[serde(skip_serializing_if = "Option::is_none")]
    // Converted for Deserialization
    // pub url_args: Option<&'a [&'a str]>,
    pub url_args: Option<Vec<String>>,
}
97
/// Owned backing storage for the alert strings a `DefaultNotificationBuilder`
/// borrows.
///
/// This replicates `a2::request::notification::DefaultAlert` for lifetime
/// reasons: `derive_aps` copies replacement values in here so the builder can
/// keep borrowing them for as long as the holder lives.
#[derive(Default)]
pub struct ApsAlertHolder {
    // Plain alert text.
    title: String,
    subtitle: String,
    body: String,
    launch_image: String,
    // Localization keys and their format arguments.
    title_loc_key: String,
    title_loc_args: Vec<String>,
    loc_key: String,
    loc_args: Vec<String>,
    action_loc_key: String,
}
112
113impl ApnsRouter {
114    /// Create a new APNS router. APNS clients will be initialized for each
115    /// channel listed in the settings.
116    pub async fn new(
117        settings: ApnsSettings,
118        endpoint_url: Url,
119        metrics: Arc<StatsdClient>,
120        db: Box<dyn DbClient>,
121        #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
122    ) -> Result<Self, ApnsError> {
123        let channels = settings.channels()?;
124
125        let clients: HashMap<String, ApnsClientData> = futures::stream::iter(channels)
126            .then(|(name, settings)| Self::create_client(name, settings))
127            .try_collect()
128            .await?;
129
130        trace!("Initialized {} APNs clients", clients.len());
131        Ok(Self {
132            clients,
133            settings,
134            endpoint_url,
135            metrics,
136            db,
137            #[cfg(feature = "reliable_report")]
138            reliability,
139        })
140    }
141
142    /// Create an APNS client for the channel
143    async fn create_client(
144        name: String,
145        settings: ApnsChannel,
146    ) -> Result<(String, ApnsClientData), ApnsError> {
147        let endpoint = if settings.sandbox {
148            Endpoint::Sandbox
149        } else {
150            Endpoint::Production
151        };
152        let cert = if !settings.cert.starts_with('-') {
153            tokio::fs::read(settings.cert).await?
154        } else {
155            settings.cert.as_bytes().to_vec()
156        };
157        let key = if !settings.key.starts_with('-') {
158            tokio::fs::read(settings.key).await?
159        } else {
160            settings.key.as_bytes().to_vec()
161        };
162        // Timeouts defined in ApnsSettings settings.rs config and can be modified.
163        // We define them to prevent possible a2 library changes that could
164        // create unexpected behavior if timeouts are altered.
165        // They currently map to values matching the detaults in the a2 lib v0.10.
166        let apns_settings = ApnsSettings::default();
167        let config = a2::ClientConfig {
168            endpoint,
169            request_timeout_secs: apns_settings.request_timeout_secs,
170            pool_idle_timeout_secs: apns_settings.pool_idle_timeout_secs,
171        };
172        let client = ApnsClientData {
173            client: Box::new(
174                a2::Client::certificate_parts(&cert, &key, config)
175                    .map_err(ApnsError::ApnsClient)?,
176            ),
177            topic: settings
178                .topic
179                .unwrap_or_else(|| format!("com.mozilla.org.{name}")),
180        };
181
182        Ok((name, client))
183    }
184
185    /// The default APS data for a notification
186    fn default_aps<'a>() -> DefaultNotificationBuilder<'a> {
187        DefaultNotificationBuilder::new()
188            .set_title_loc_key("SentTab.NoTabArrivingNotification.title")
189            .set_loc_key("SentTab.NoTabArrivingNotification.body")
190            .set_mutable_content()
191    }
192
193    /// Handle an error by logging, updating metrics, etc
194    async fn handle_error(&self, error: a2::Error, uaid: Uuid, channel: &str) -> ApiError {
195        match &error {
196            a2::Error::ResponseError(response) => {
197                // capture the APNs error as a metric response. This allows us to spot trends.
198                // While APNS can return a number of errors (see a2::response::ErrorReason) we
199                // shouldn't encounter many of those.
200                let reason = response
201                    .error
202                    .as_ref()
203                    .map(|r| format!("{:?}", r.reason))
204                    .unwrap_or_else(|| "Unknown".to_owned());
205                let code = StatusCode::from_u16(response.code).unwrap_or(StatusCode::BAD_GATEWAY);
206                incr_error_metric(&self.metrics, "apns", channel, &reason, code, None);
207                if response.code == 410 {
208                    debug!("APNS recipient has been unregistered, removing user");
209                    if let Err(e) = self.db.remove_user(&uaid).await {
210                        warn!("Error while removing user due to APNS 410: {}", e);
211                    }
212
213                    return ApiError::from(ApnsError::Unregistered);
214                } else {
215                    warn!("APNS error: {:?}", response.error);
216                }
217            }
218            a2::Error::ConnectionError(e) => {
219                error!("APNS connection error: {:?}", e);
220                incr_error_metric(
221                    &self.metrics,
222                    "apns",
223                    channel,
224                    "connection_unavailable",
225                    StatusCode::SERVICE_UNAVAILABLE,
226                    None,
227                );
228            }
229            _ => {
230                warn!("Unknown error while sending APNS request: {}", error);
231                incr_error_metric(
232                    &self.metrics,
233                    "apns",
234                    channel,
235                    "unknown",
236                    StatusCode::BAD_GATEWAY,
237                    None,
238                );
239            }
240        }
241
242        ApiError::from(ApnsError::ApnsUpstream(error))
243    }
244
245    /// if we have any clients defined, this connection is "active"
246    pub fn active(&self) -> bool {
247        !self.clients.is_empty()
248    }
249
250    /// Derive an APS message from the replacement JSON block.
251    ///
252    /// This requires an external "holder" that contains the data that APS will refer to.
253    /// The holder should live in the same context as the `aps.build()` method.
254    fn derive_aps<'a>(
255        &self,
256        replacement: Value,
257        holder: &'a mut ApsAlertHolder,
258    ) -> Result<DefaultNotificationBuilder<'a>, ApnsError> {
259        let mut aps = Self::default_aps();
260        // a2 does not have a way to bulk replace these values, so do them by hand.
261        // these could probably be turned into a macro, but hopefully, this is
262        // more one off and I didn't want to fight with the macro generator.
263        // This whole thing was included as a byproduct of
264        // https://bugzilla.mozilla.org/show_bug.cgi?id=1364403 which was put
265        // in place to help debug the iOS build. It was supposed to be temporary,
266        // but apparently bit-lock set in and now no one is super sure if it's
267        // still needed or used. (I want to get rid of this.)
268        if let Some(v) = replacement.get("title") {
269            if let Some(v) = v.as_str() {
270                v.clone_into(&mut holder.title);
271                aps = aps.set_title(&holder.title);
272            } else {
273                return Err(ApnsError::InvalidApsData);
274            }
275        }
276        if let Some(v) = replacement.get("subtitle") {
277            if let Some(v) = v.as_str() {
278                v.clone_into(&mut holder.subtitle);
279                aps = aps.set_subtitle(&holder.subtitle);
280            } else {
281                return Err(ApnsError::InvalidApsData);
282            }
283        }
284        if let Some(v) = replacement.get("body") {
285            if let Some(v) = v.as_str() {
286                v.clone_into(&mut holder.body);
287                aps = aps.set_body(&holder.body);
288            } else {
289                return Err(ApnsError::InvalidApsData);
290            }
291        }
292        if let Some(v) = replacement.get("title_loc_key") {
293            if let Some(v) = v.as_str() {
294                v.clone_into(&mut holder.title_loc_key);
295                aps = aps.set_title_loc_key(&holder.title_loc_key);
296            } else {
297                return Err(ApnsError::InvalidApsData);
298            }
299        }
300        if let Some(v) = replacement.get("title_loc_args") {
301            if let Some(v) = v.as_array() {
302                let mut args: Vec<String> = Vec::new();
303                for val in v {
304                    if let Some(value) = val.as_str() {
305                        args.push(value.to_owned())
306                    } else {
307                        return Err(ApnsError::InvalidApsData);
308                    }
309                }
310                holder.title_loc_args = args;
311                aps = aps.set_title_loc_args(&holder.title_loc_args);
312            } else {
313                return Err(ApnsError::InvalidApsData);
314            }
315        }
316        if let Some(v) = replacement.get("action_loc_key") {
317            if let Some(v) = v.as_str() {
318                v.clone_into(&mut holder.action_loc_key);
319                aps = aps.set_action_loc_key(&holder.action_loc_key);
320            } else {
321                return Err(ApnsError::InvalidApsData);
322            }
323        }
324        if let Some(v) = replacement.get("loc_key") {
325            if let Some(v) = v.as_str() {
326                v.clone_into(&mut holder.loc_key);
327                aps = aps.set_loc_key(&holder.loc_key);
328            } else {
329                return Err(ApnsError::InvalidApsData);
330            }
331        }
332        if let Some(v) = replacement.get("loc_args") {
333            if let Some(v) = v.as_array() {
334                let mut args: Vec<String> = Vec::new();
335                for val in v {
336                    if let Some(value) = val.as_str() {
337                        args.push(value.to_owned())
338                    } else {
339                        return Err(ApnsError::InvalidApsData);
340                    }
341                }
342                holder.loc_args = args;
343                aps = aps.set_loc_args(&holder.loc_args);
344            } else {
345                return Err(ApnsError::InvalidApsData);
346            }
347        }
348        if let Some(v) = replacement.get("launch_image") {
349            if let Some(v) = v.as_str() {
350                v.clone_into(&mut holder.launch_image);
351                aps = aps.set_launch_image(&holder.launch_image);
352            } else {
353                return Err(ApnsError::InvalidApsData);
354            }
355        }
356        // Honestly, we should just check to see if this is present
357        // we don't really care what the value is since we'll never
358        // use
359        if let Some(v) = replacement.get("mutable-content") {
360            if let Some(v) = v.as_i64() {
361                if v != 0 {
362                    aps = aps.set_mutable_content();
363                }
364            } else {
365                return Err(ApnsError::InvalidApsData);
366            }
367        }
368        Ok(aps)
369    }
370}
371
372#[async_trait(?Send)]
373impl Router for ApnsRouter {
374    fn register(
375        &self,
376        router_input: &RouterDataInput,
377        app_id: &str,
378    ) -> Result<HashMap<String, Value>, RouterError> {
379        if !self.clients.contains_key(app_id) {
380            return Err(ApnsError::InvalidReleaseChannel.into());
381        }
382
383        let mut router_data = HashMap::new();
384        router_data.insert(
385            "token".to_string(),
386            serde_json::to_value(&router_input.token).unwrap(),
387        );
388        router_data.insert(
389            "rel_channel".to_string(),
390            serde_json::to_value(app_id).unwrap(),
391        );
392
393        if let Some(aps) = &router_input.aps {
394            if serde_json::from_str::<ApsDeser<'_>>(aps).is_err() {
395                return Err(ApnsError::InvalidApsData.into());
396            }
397            router_data.insert(
398                "aps".to_string(),
399                serde_json::to_value(aps.clone()).unwrap(),
400            );
401        }
402
403        Ok(router_data)
404    }
405
406    #[allow(unused_mut)]
407    async fn route_notification(
408        &self,
409        mut notification: Notification,
410    ) -> ApiResult<RouterResponse> {
411        debug!(
412            "Sending APNS notification to UAID {}",
413            notification.subscription.user.uaid
414        );
415        trace!("Notification = {:?}", notification);
416
417        // Build message data
418        let router_data = notification
419            .subscription
420            .user
421            .router_data
422            .as_ref()
423            .ok_or(ApnsError::NoDeviceToken)?
424            .clone();
425        let token = router_data
426            .get("token")
427            .and_then(Value::as_str)
428            .ok_or(ApnsError::NoDeviceToken)?;
429        let channel = router_data
430            .get("rel_channel")
431            .and_then(Value::as_str)
432            .ok_or(ApnsError::NoReleaseChannel)?;
433        let aps_json = router_data.get("aps").cloned();
434        let mut message_data = build_message_data(&notification)?;
435        message_data.insert("ver", notification.message_id.clone());
436
437        // Get client and build payload
438        let ApnsClientData { client, topic } = self
439            .clients
440            .get(channel)
441            .ok_or(ApnsError::InvalidReleaseChannel)?;
442
443        // A simple bucket variable so that I don't have to deal with fun lifetime issues if we need
444        // to derive.
445        let mut holder = ApsAlertHolder::default();
446
447        // If we are provided a replacement APS block, derive an APS message from it, otherwise
448        // start with a blank APS message.
449        let aps = if let Some(replacement) = aps_json {
450            self.derive_aps(replacement, &mut holder)?
451        } else {
452            Self::default_aps()
453        };
454
455        // Finalize the APS object.
456        let mut payload = aps.build(
457            token,
458            NotificationOptions {
459                apns_id: None,
460                apns_priority: Some(Priority::High),
461                apns_topic: Some(topic),
462                apns_collapse_id: None,
463                apns_expiration: Some(notification.timestamp + notification.headers.ttl as u64),
464                ..Default::default()
465            },
466        );
467        let message_length = message_data.get("body").map(|s| s.len()).unwrap_or(0);
468        payload.data = message_data
469            .into_iter()
470            .map(|(k, v)| (k, Value::String(v)))
471            .collect();
472
473        // Check size limit
474        let payload_json = payload
475            .clone()
476            .to_json_string()
477            .map_err(|_| RouterError::TooMuchData(message_length))?;
478        message_size_check(payload_json.as_bytes(), self.settings.max_data)?;
479
480        // Send to APNS
481        trace!("Sending message to APNS: {:?}", payload);
482        if let Err(e) = client.send(payload).await {
483            #[cfg(feature = "reliable_report")]
484            notification
485                .record_reliability(
486                    &self.reliability,
487                    autopush_common::reliability::ReliabilityState::Errored,
488                )
489                .await;
490            return Err(self
491                .handle_error(e, notification.subscription.user.uaid, channel)
492                .await);
493        }
494
495        trace!("APNS request was successful");
496        incr_success_metrics(&self.metrics, "apns", channel, &notification);
497        #[cfg(feature = "reliable_report")]
498        {
499            // Record that we've sent the message out to APNS.
500            // We can't set the state here because the notification isn't
501            // mutable, but we are also essentially consuming the
502            // notification nothing else should modify it.
503            notification
504                .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
505                .await;
506        }
507
508        Ok(RouterResponse::success(
509            self.endpoint_url
510                .join(&format!("/m/{}", notification.message_id))
511                .expect("Message ID is not URL-safe")
512                .to_string(),
513            notification.headers.ttl as usize,
514        ))
515    }
516}
517
// Unit tests for `ApnsRouter`, driven by a mock APNS client and a mock
// database so no network access is required.
#[cfg(test)]
mod tests {
    use crate::error::ApiErrorKind;
    use crate::extractors::routers::RouterType;
    use crate::routers::apns::error::ApnsError;
    use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
    use crate::routers::apns::settings::ApnsSettings;
    use crate::routers::common::tests::{CHANNEL_ID, make_notification};
    use crate::routers::{Router, RouterError, RouterResponse};
    use a2::request::payload::Payload;
    use a2::{Error, Response};
    use async_trait::async_trait;
    use autopush_common::db::client::DbClient;
    use autopush_common::db::mock::MockDbClient;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
    use cadence::StatsdClient;
    use mockall::predicate;
    use std::collections::HashMap;
    use std::sync::Arc;
    use url::Url;

    /// Device token placed in the default router data.
    const DEVICE_TOKEN: &str = "test-token";
    /// APNS id returned by the mocked successful response.
    const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";

    #[allow(clippy::type_complexity)]
    /// A mock APNS client which allows one to supply a custom APNS response/error
    struct MockApnsClient {
        send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
    }

    #[async_trait]
    impl ApnsClient for MockApnsClient {
        // Delegate to the supplied closure, which may also assert on the payload.
        async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
            (self.send_fn)(payload)
        }
    }

    impl MockApnsClient {
        /// Wrap `send_fn` as the behavior of every `send` call.
        fn new<F>(send_fn: F) -> Self
        where
            F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
            F: Send + Sync + 'static,
        {
            Self {
                send_fn: Box::new(send_fn),
            }
        }
    }

    /// Create a successful APNS response
    fn apns_success_response() -> a2::Response {
        a2::Response {
            error: None,
            apns_id: Some(APNS_ID.to_string()),
            code: 200,
        }
    }

    /// Create a router for testing, using the given APNS client.
    ///
    /// The client is registered under the "test-channel" release channel with
    /// topic "test-topic"; metrics go to a no-op sink.
    fn make_router(client: MockApnsClient, db: Box<dyn DbClient>) -> ApnsRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());

        ApnsRouter {
            clients: {
                let mut map = HashMap::new();
                map.insert(
                    "test-channel".to_string(),
                    ApnsClientData {
                        client: Box::new(client),
                        topic: "test-topic".to_string(),
                    },
                );
                map
            },
            settings: ApnsSettings::default(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            metrics: metrics.clone(),
            db: db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// Create default user router data
    fn default_router_data() -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(
            "token".to_string(),
            serde_json::to_value(DEVICE_TOKEN).unwrap(),
        );
        map.insert(
            "rel_channel".to_string(),
            serde_json::to_value("test-channel").unwrap(),
        );
        map
    }

    /// A notification with no data is packaged correctly and sent to APNS
    #[tokio::test]
    async fn successful_routing_no_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            // With no replacement APS block the default APS must be sent.
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(
                serde_json::to_value(payload.aps).unwrap(),
                serde_json::to_value(built.aps).unwrap()
            );
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// A notification with data is packaged correctly and sent to APNS
    #[tokio::test]
    async fn successful_routing_with_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(serde_json::json!(payload.aps), serde_json::json!(built.aps));
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            // The encrypted body and its crypto headers ride along in `data`.
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id",
                    "body": "test-data",
                    "con": "test-encoding",
                    "enc": "test-encryption",
                    "cryptokey": "test-crypto-key",
                    "enckey": "test-encryption-key"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let data = "test-data".to_string();
        let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// If there is no client for the user's release channel, an error is
    /// returned and the APNS request is not sent.
    #[tokio::test]
    async fn missing_client() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "rel_channel".to_string(),
            serde_json::to_value("unknown-app-id").unwrap(),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
            ),
            "result = {result:?}"
        );
    }

    /// If APNS says the user doesn't exist anymore, we return a specific error
    /// and remove the user from the database.
    #[tokio::test]
    async fn user_not_found() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::Unregistered,
                    timestamp: Some(0),
                }),
                apns_id: None,
                code: 410,
            }))
        });
        let notification = make_notification(default_router_data(), None, RouterType::APNS);
        // The 410 must trigger exactly one removal of this notification's user.
        let mut db = MockDbClient::new();
        db.expect_remove_user()
            .with(predicate::eq(notification.subscription.user.uaid))
            .times(1)
            .return_once(|_| Ok(()));
        let router = make_router(client, db.into_boxed_arc());

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
            ),
            "result = {result:?}"
        );
    }

    /// APNS errors (other than Unregistered) are wrapped and returned
    #[tokio::test]
    async fn upstream_error() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::BadCertificate,
                    timestamp: None,
                }),
                apns_id: None,
                code: 403,
            }))
        });
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
                    a2::Error::ResponseError(a2::Response {
                        error: Some(a2::ErrorBody {
                            reason: a2::ErrorReason::BadCertificate,
                            timestamp: None,
                        }),
                        apns_id: None,
                        code: 403,
                    })
                )))
            ),
            "result = {result:?}"
        );
    }

    /// An error is returned if the user's APS data is invalid
    #[tokio::test]
    async fn invalid_aps_data() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        // `mutable-content` must be an integer; a string is rejected.
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"mutable-content": "should be a number"}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
            ),
            "result = {result:?}"
        );
    }
}