autoendpoint/routers/apns/router.rs

1use autopush_common::db::client::DbClient;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::{PushReliability, ReliabilityState};
4
5use crate::error::{ApiError, ApiResult};
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::apns::error::ApnsError;
9use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
10use crate::routers::common::{
11    build_message_data, incr_error_metric, incr_success_metrics, message_size_check,
12};
13use crate::routers::{Router, RouterError, RouterResponse};
14use a2::{
15    self,
16    request::payload::{Payload, PayloadLike},
17    DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, Priority,
18    Response,
19};
20use actix_web::http::StatusCode;
21use async_trait::async_trait;
22use cadence::StatsdClient;
23use futures::{StreamExt, TryStreamExt};
24use serde::{Deserialize, Serialize};
25use serde_json::Value;
26use std::collections::HashMap;
27use std::sync::Arc;
28use url::Url;
29use uuid::Uuid;
30
31/// Apple Push Notification Service router
32pub struct ApnsRouter {
33    /// A map from release channel to APNS client
34    clients: HashMap<String, ApnsClientData>,
35    settings: ApnsSettings,
36    endpoint_url: Url,
37    metrics: Arc<StatsdClient>,
38    db: Box<dyn DbClient>,
39    #[cfg(feature = "reliable_report")]
40    reliability: Arc<PushReliability>,
41}
42
43struct ApnsClientData {
44    client: Box<dyn ApnsClient>,
45    topic: String,
46}
47
48#[async_trait]
49trait ApnsClient: Send + Sync {
50    async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
51}
52
53#[async_trait]
54impl ApnsClient for a2::Client {
55    async fn send(&self, payload: Payload<'_>) -> Result<Response, a2::Error> {
56        self.send(payload).await
57    }
58}
59
60/// a2 does not allow for Deserialization of the APS structure.
61/// this is copied from that library
62#[derive(Deserialize, Serialize, Default, Debug, Clone)]
63#[serde(rename_all = "kebab-case")]
64#[allow(clippy::upper_case_acronyms)]
65pub struct ApsDeser<'a> {
66    // The notification content. Can be empty for silent notifications.
67    // Note, we overwrite this value, but it's copied and commented here
68    // so that future development notes the change.
69    // #[serde(skip_serializing_if = "Option::is_none")]
70    //pub alert: Option<a2::request::payload::APSAlert<'a>>,
71    /// A number shown on top of the app icon.
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub badge: Option<u32>,
74
75    /// The name of the sound file to play when user receives the notification.
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub sound: Option<&'a str>,
78
79    /// Set to one for silent notifications.
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub content_available: Option<u8>,
82
83    /// When a notification includes the category key, the system displays the
84    /// actions for that category as buttons in the banner or alert interface.
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub category: Option<&'a str>,
87
88    /// If set to one, the app can change the notification content before
89    /// displaying it to the user.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub mutable_content: Option<u8>,
92
93    #[serde(skip_serializing_if = "Option::is_none")]
94    // Converted for Deserialization
95    // pub url_args: Option<&'a [&'a str]>,
96    pub url_args: Option<Vec<String>>,
97}
98
/// Owned storage for the alert strings a `DefaultNotificationBuilder` borrows.
///
/// Mirrors `a2::request::notification::DefaultAlert`, but owns its data so the
/// builder's `&str` references have something to point at. Keep the holder in
/// the same scope that later calls `build()` on the builder.
#[derive(Default)]
pub struct ApsAlertHolder {
    title: String,
    subtitle: String,
    body: String,
    title_loc_key: String,
    title_loc_args: Vec<String>,
    action_loc_key: String,
    loc_key: String,
    loc_args: Vec<String>,
    launch_image: String,
}
113
114impl ApnsRouter {
115    /// Create a new APNS router. APNS clients will be initialized for each
116    /// channel listed in the settings.
117    pub async fn new(
118        settings: ApnsSettings,
119        endpoint_url: Url,
120        metrics: Arc<StatsdClient>,
121        db: Box<dyn DbClient>,
122        #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
123    ) -> Result<Self, ApnsError> {
124        let channels = settings.channels()?;
125
126        let clients: HashMap<String, ApnsClientData> = futures::stream::iter(channels)
127            .then(|(name, settings)| Self::create_client(name, settings))
128            .try_collect()
129            .await?;
130
131        trace!("Initialized {} APNs clients", clients.len());
132        Ok(Self {
133            clients,
134            settings,
135            endpoint_url,
136            metrics,
137            db,
138            #[cfg(feature = "reliable_report")]
139            reliability,
140        })
141    }
142
143    /// Create an APNS client for the channel
144    async fn create_client(
145        name: String,
146        settings: ApnsChannel,
147    ) -> Result<(String, ApnsClientData), ApnsError> {
148        let endpoint = if settings.sandbox {
149            Endpoint::Sandbox
150        } else {
151            Endpoint::Production
152        };
153        let cert = if !settings.cert.starts_with('-') {
154            tokio::fs::read(settings.cert).await?
155        } else {
156            settings.cert.as_bytes().to_vec()
157        };
158        let key = if !settings.key.starts_with('-') {
159            tokio::fs::read(settings.key).await?
160        } else {
161            settings.key.as_bytes().to_vec()
162        };
163        // Timeouts defined in ApnsSettings settings.rs config and can be modified.
164        // We define them to prevent possible a2 library changes that could
165        // create unexpected behavior if timeouts are altered.
166        // They currently map to values matching the detaults in the a2 lib v0.10.
167        let apns_settings = ApnsSettings::default();
168        let config = a2::ClientConfig {
169            endpoint,
170            request_timeout_secs: apns_settings.request_timeout_secs,
171            pool_idle_timeout_secs: apns_settings.pool_idle_timeout_secs,
172        };
173        let client = ApnsClientData {
174            client: Box::new(
175                a2::Client::certificate_parts(&cert, &key, config)
176                    .map_err(ApnsError::ApnsClient)?,
177            ),
178            topic: settings
179                .topic
180                .unwrap_or_else(|| format!("com.mozilla.org.{name}")),
181        };
182
183        Ok((name, client))
184    }
185
186    /// The default APS data for a notification
187    fn default_aps<'a>() -> DefaultNotificationBuilder<'a> {
188        DefaultNotificationBuilder::new()
189            .set_title_loc_key("SentTab.NoTabArrivingNotification.title")
190            .set_loc_key("SentTab.NoTabArrivingNotification.body")
191            .set_mutable_content()
192    }
193
194    /// Handle an error by logging, updating metrics, etc
195    async fn handle_error(&self, error: a2::Error, uaid: Uuid, channel: &str) -> ApiError {
196        match &error {
197            a2::Error::ResponseError(response) => {
198                // capture the APNs error as a metric response. This allows us to spot trends.
199                // While APNS can return a number of errors (see a2::response::ErrorReason) we
200                // shouldn't encounter many of those.
201                let reason = response
202                    .error
203                    .as_ref()
204                    .map(|r| format!("{:?}", r.reason))
205                    .unwrap_or_else(|| "Unknown".to_owned());
206                let code = StatusCode::from_u16(response.code).unwrap_or(StatusCode::BAD_GATEWAY);
207                incr_error_metric(&self.metrics, "apns", channel, &reason, code, None);
208                if response.code == 410 {
209                    debug!("APNS recipient has been unregistered, removing user");
210                    if let Err(e) = self.db.remove_user(&uaid).await {
211                        warn!("Error while removing user due to APNS 410: {}", e);
212                    }
213
214                    return ApiError::from(ApnsError::Unregistered);
215                } else {
216                    warn!("APNS error: {:?}", response.error);
217                }
218            }
219            a2::Error::ConnectionError(e) => {
220                error!("APNS connection error: {:?}", e);
221                incr_error_metric(
222                    &self.metrics,
223                    "apns",
224                    channel,
225                    "connection_unavailable",
226                    StatusCode::SERVICE_UNAVAILABLE,
227                    None,
228                );
229            }
230            _ => {
231                warn!("Unknown error while sending APNS request: {}", error);
232                incr_error_metric(
233                    &self.metrics,
234                    "apns",
235                    channel,
236                    "unknown",
237                    StatusCode::BAD_GATEWAY,
238                    None,
239                );
240            }
241        }
242
243        ApiError::from(ApnsError::ApnsUpstream(error))
244    }
245
246    /// if we have any clients defined, this connection is "active"
247    pub fn active(&self) -> bool {
248        !self.clients.is_empty()
249    }
250
251    /// Derive an APS message from the replacement JSON block.
252    ///
253    /// This requires an external "holder" that contains the data that APS will refer to.
254    /// The holder should live in the same context as the `aps.build()` method.
255    fn derive_aps<'a>(
256        &self,
257        replacement: Value,
258        holder: &'a mut ApsAlertHolder,
259    ) -> Result<DefaultNotificationBuilder<'a>, ApnsError> {
260        let mut aps = Self::default_aps();
261        // a2 does not have a way to bulk replace these values, so do them by hand.
262        // these could probably be turned into a macro, but hopefully, this is
263        // more one off and I didn't want to fight with the macro generator.
264        // This whole thing was included as a byproduct of
265        // https://bugzilla.mozilla.org/show_bug.cgi?id=1364403 which was put
266        // in place to help debug the iOS build. It was supposed to be temporary,
267        // but apparently bit-lock set in and now no one is super sure if it's
268        // still needed or used. (I want to get rid of this.)
269        if let Some(v) = replacement.get("title") {
270            if let Some(v) = v.as_str() {
271                v.clone_into(&mut holder.title);
272                aps = aps.set_title(&holder.title);
273            } else {
274                return Err(ApnsError::InvalidApsData);
275            }
276        }
277        if let Some(v) = replacement.get("subtitle") {
278            if let Some(v) = v.as_str() {
279                v.clone_into(&mut holder.subtitle);
280                aps = aps.set_subtitle(&holder.subtitle);
281            } else {
282                return Err(ApnsError::InvalidApsData);
283            }
284        }
285        if let Some(v) = replacement.get("body") {
286            if let Some(v) = v.as_str() {
287                v.clone_into(&mut holder.body);
288                aps = aps.set_body(&holder.body);
289            } else {
290                return Err(ApnsError::InvalidApsData);
291            }
292        }
293        if let Some(v) = replacement.get("title_loc_key") {
294            if let Some(v) = v.as_str() {
295                v.clone_into(&mut holder.title_loc_key);
296                aps = aps.set_title_loc_key(&holder.title_loc_key);
297            } else {
298                return Err(ApnsError::InvalidApsData);
299            }
300        }
301        if let Some(v) = replacement.get("title_loc_args") {
302            if let Some(v) = v.as_array() {
303                let mut args: Vec<String> = Vec::new();
304                for val in v {
305                    if let Some(value) = val.as_str() {
306                        args.push(value.to_owned())
307                    } else {
308                        return Err(ApnsError::InvalidApsData);
309                    }
310                }
311                holder.title_loc_args = args;
312                aps = aps.set_title_loc_args(&holder.title_loc_args);
313            } else {
314                return Err(ApnsError::InvalidApsData);
315            }
316        }
317        if let Some(v) = replacement.get("action_loc_key") {
318            if let Some(v) = v.as_str() {
319                v.clone_into(&mut holder.action_loc_key);
320                aps = aps.set_action_loc_key(&holder.action_loc_key);
321            } else {
322                return Err(ApnsError::InvalidApsData);
323            }
324        }
325        if let Some(v) = replacement.get("loc_key") {
326            if let Some(v) = v.as_str() {
327                v.clone_into(&mut holder.loc_key);
328                aps = aps.set_loc_key(&holder.loc_key);
329            } else {
330                return Err(ApnsError::InvalidApsData);
331            }
332        }
333        if let Some(v) = replacement.get("loc_args") {
334            if let Some(v) = v.as_array() {
335                let mut args: Vec<String> = Vec::new();
336                for val in v {
337                    if let Some(value) = val.as_str() {
338                        args.push(value.to_owned())
339                    } else {
340                        return Err(ApnsError::InvalidApsData);
341                    }
342                }
343                holder.loc_args = args;
344                aps = aps.set_loc_args(&holder.loc_args);
345            } else {
346                return Err(ApnsError::InvalidApsData);
347            }
348        }
349        if let Some(v) = replacement.get("launch_image") {
350            if let Some(v) = v.as_str() {
351                v.clone_into(&mut holder.launch_image);
352                aps = aps.set_launch_image(&holder.launch_image);
353            } else {
354                return Err(ApnsError::InvalidApsData);
355            }
356        }
357        // Honestly, we should just check to see if this is present
358        // we don't really care what the value is since we'll never
359        // use
360        if let Some(v) = replacement.get("mutable-content") {
361            if let Some(v) = v.as_i64() {
362                if v != 0 {
363                    aps = aps.set_mutable_content();
364                }
365            } else {
366                return Err(ApnsError::InvalidApsData);
367            }
368        }
369        Ok(aps)
370    }
371}
372
373#[async_trait(?Send)]
374impl Router for ApnsRouter {
375    fn register(
376        &self,
377        router_input: &RouterDataInput,
378        app_id: &str,
379    ) -> Result<HashMap<String, Value>, RouterError> {
380        if !self.clients.contains_key(app_id) {
381            return Err(ApnsError::InvalidReleaseChannel.into());
382        }
383
384        let mut router_data = HashMap::new();
385        router_data.insert(
386            "token".to_string(),
387            serde_json::to_value(&router_input.token).unwrap(),
388        );
389        router_data.insert(
390            "rel_channel".to_string(),
391            serde_json::to_value(app_id).unwrap(),
392        );
393
394        if let Some(aps) = &router_input.aps {
395            if serde_json::from_str::<ApsDeser<'_>>(aps).is_err() {
396                return Err(ApnsError::InvalidApsData.into());
397            }
398            router_data.insert(
399                "aps".to_string(),
400                serde_json::to_value(aps.clone()).unwrap(),
401            );
402        }
403
404        Ok(router_data)
405    }
406
407    #[allow(unused_mut)]
408    async fn route_notification(
409        &self,
410        mut notification: Notification,
411    ) -> ApiResult<RouterResponse> {
412        debug!(
413            "Sending APNS notification to UAID {}",
414            notification.subscription.user.uaid
415        );
416        trace!("Notification = {:?}", notification);
417
418        // Build message data
419        let router_data = notification
420            .subscription
421            .user
422            .router_data
423            .as_ref()
424            .ok_or(ApnsError::NoDeviceToken)?
425            .clone();
426        let token = router_data
427            .get("token")
428            .and_then(Value::as_str)
429            .ok_or(ApnsError::NoDeviceToken)?;
430        let channel = router_data
431            .get("rel_channel")
432            .and_then(Value::as_str)
433            .ok_or(ApnsError::NoReleaseChannel)?;
434        let aps_json = router_data.get("aps").cloned();
435        let mut message_data = build_message_data(&notification)?;
436        message_data.insert("ver", notification.message_id.clone());
437
438        // Get client and build payload
439        let ApnsClientData { client, topic } = self
440            .clients
441            .get(channel)
442            .ok_or(ApnsError::InvalidReleaseChannel)?;
443
444        // A simple bucket variable so that I don't have to deal with fun lifetime issues if we need
445        // to derive.
446        let mut holder = ApsAlertHolder::default();
447
448        // If we are provided a replacement APS block, derive an APS message from it, otherwise
449        // start with a blank APS message.
450        let aps = if let Some(replacement) = aps_json {
451            self.derive_aps(replacement, &mut holder)?
452        } else {
453            Self::default_aps()
454        };
455
456        // Finalize the APS object.
457        let mut payload = aps.build(
458            token,
459            NotificationOptions {
460                apns_id: None,
461                apns_priority: Some(Priority::High),
462                apns_topic: Some(topic),
463                apns_collapse_id: None,
464                apns_expiration: Some(notification.timestamp + notification.headers.ttl as u64),
465                ..Default::default()
466            },
467        );
468        payload.data = message_data
469            .into_iter()
470            .map(|(k, v)| (k, Value::String(v)))
471            .collect();
472
473        // Check size limit
474        let payload_json = payload
475            .clone()
476            .to_json_string()
477            .map_err(ApnsError::SizeLimit)?;
478        message_size_check(payload_json.as_bytes(), self.settings.max_data)?;
479
480        // Send to APNS
481        trace!("Sending message to APNS: {:?}", payload);
482        if let Err(e) = client.send(payload).await {
483            #[cfg(feature = "reliable_report")]
484            notification
485                .record_reliability(
486                    &self.reliability,
487                    autopush_common::reliability::ReliabilityState::Errored,
488                )
489                .await;
490            return Err(self
491                .handle_error(e, notification.subscription.user.uaid, channel)
492                .await);
493        }
494
495        trace!("APNS request was successful");
496        incr_success_metrics(&self.metrics, "apns", channel, &notification);
497        #[cfg(feature = "reliable_report")]
498        {
499            // Record that we've sent the message out to APNS.
500            // We can't set the state here because the notification isn't
501            // mutable, but we are also essentially consuming the
502            // notification nothing else should modify it.
503            notification
504                .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
505                .await;
506        }
507
508        Ok(RouterResponse::success(
509            self.endpoint_url
510                .join(&format!("/m/{}", notification.message_id))
511                .expect("Message ID is not URL-safe")
512                .to_string(),
513            notification.headers.ttl as usize,
514        ))
515    }
516}
517
#[cfg(test)]
mod tests {
    use crate::error::ApiErrorKind;
    use crate::extractors::routers::RouterType;
    use crate::routers::apns::error::ApnsError;
    use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
    use crate::routers::apns::settings::ApnsSettings;
    use crate::routers::common::tests::{make_notification, CHANNEL_ID};
    use crate::routers::{Router, RouterError, RouterResponse};
    use a2::request::payload::Payload;
    use a2::{Error, Response};
    use async_trait::async_trait;
    use autopush_common::db::client::DbClient;
    use autopush_common::db::mock::MockDbClient;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
    use cadence::StatsdClient;
    use mockall::predicate;
    use std::collections::HashMap;
    use std::sync::Arc;
    use url::Url;

    const DEVICE_TOKEN: &str = "test-token";
    const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";

    #[allow(clippy::type_complexity)]
    /// A mock APNS client which allows one to supply a custom APNS response/error
    struct MockApnsClient {
        send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
    }

    #[async_trait]
    impl ApnsClient for MockApnsClient {
        async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
            (self.send_fn)(payload)
        }
    }

    impl MockApnsClient {
        fn new<F>(send_fn: F) -> Self
        where
            F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
            F: Send + Sync + 'static,
        {
            Self {
                send_fn: Box::new(send_fn),
            }
        }
    }

    /// Create a successful APNS response
    fn apns_success_response() -> a2::Response {
        a2::Response {
            error: None,
            apns_id: Some(APNS_ID.to_string()),
            code: 200,
        }
    }

    /// Create a router for testing, using the given APNS client
    fn make_router(client: MockApnsClient, db: Box<dyn DbClient>) -> ApnsRouter {
        // A single no-op metrics client, shared by the router and (when the
        // feature is enabled) the reliability tracker.
        let metrics = Arc::new(StatsdClient::from_sink("autopush", cadence::NopMetricSink));

        ApnsRouter {
            clients: HashMap::from([(
                "test-channel".to_string(),
                ApnsClientData {
                    client: Box::new(client),
                    topic: "test-topic".to_string(),
                },
            )]),
            settings: ApnsSettings::default(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            // Listed before `metrics`/`db` so it can borrow/clone them before
            // they are moved into the struct.
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
            metrics,
            db,
        }
    }

    /// Create default user router data
    fn default_router_data() -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(
            "token".to_string(),
            serde_json::to_value(DEVICE_TOKEN).unwrap(),
        );
        map.insert(
            "rel_channel".to_string(),
            serde_json::to_value("test-channel").unwrap(),
        );
        map
    }

    /// A notification with no data is packaged correctly and sent to APNS
    #[tokio::test]
    async fn successful_routing_no_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(
                serde_json::to_value(payload.aps).unwrap(),
                serde_json::to_value(built.aps).unwrap()
            );
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// A notification with data is packaged correctly and sent to APNS
    #[tokio::test]
    async fn successful_routing_with_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(serde_json::json!(payload.aps), serde_json::json!(built.aps));
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id",
                    "body": "test-data",
                    "con": "test-encoding",
                    "enc": "test-encryption",
                    "cryptokey": "test-crypto-key",
                    "enckey": "test-encryption-key"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let data = "test-data".to_string();
        let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// A replacement APS block stored with the user overrides the default
    /// alert fields.
    #[tokio::test]
    async fn successful_routing_with_aps_replacement() {
        let client = MockApnsClient::new(|payload| {
            let aps = serde_json::to_value(payload.aps).unwrap();
            assert_eq!(aps["alert"]["title"], "test-title");
            assert_eq!(aps["alert"]["body"], "test-body");

            Ok(apns_success_response())
        });
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"title": "test-title", "body": "test-body"}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
    }

    /// If there is no client for the user's release channel, an error is
    /// returned and the APNS request is not sent.
    #[tokio::test]
    async fn missing_client() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "rel_channel".to_string(),
            serde_json::to_value("unknown-app-id").unwrap(),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
            ),
            "result = {result:?}"
        );
    }

    /// If APNS says the user doesn't exist anymore, we return a specific error
    /// and remove the user from the database.
    #[tokio::test]
    async fn user_not_found() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::Unregistered,
                    timestamp: Some(0),
                }),
                apns_id: None,
                code: 410,
            }))
        });
        let notification = make_notification(default_router_data(), None, RouterType::APNS);
        let mut db = MockDbClient::new();
        db.expect_remove_user()
            .with(predicate::eq(notification.subscription.user.uaid))
            .times(1)
            .return_once(|_| Ok(()));
        let router = make_router(client, db.into_boxed_arc());

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
            ),
            "result = {result:?}"
        );
    }

    /// APNS errors (other than Unregistered) are wrapped and returned
    #[tokio::test]
    async fn upstream_error() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::BadCertificate,
                    timestamp: None,
                }),
                apns_id: None,
                code: 403,
            }))
        });
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
                    a2::Error::ResponseError(a2::Response {
                        error: Some(a2::ErrorBody {
                            reason: a2::ErrorReason::BadCertificate,
                            timestamp: None,
                        }),
                        apns_id: None,
                        code: 403,
                    })
                )))
            ),
            "result = {result:?}"
        );
    }

    /// An error is returned if the user's APS data is invalid
    #[tokio::test]
    async fn invalid_aps_data() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"mutable-content": "should be a number"}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
            ),
            "result = {result:?}"
        );
    }

    /// A replacement APS array containing a non-string element is rejected
    /// and the APNS request is not sent.
    #[tokio::test]
    async fn invalid_aps_array_element() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"loc_args": ["valid", 42]}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
            ),
            "result = {result:?}"
        );
    }
}