Skip to main content

autoendpoint/routers/apns/
router.rs

1use autopush_common::db::client::DbClient;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::{PushReliability, ReliabilityState};
4
5use crate::error::{ApiError, ApiResult};
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::apns::error::ApnsError;
9use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
10use crate::routers::common::{
11    build_message_data, incr_error_metric, incr_success_metrics, message_size_check,
12};
13use crate::routers::{Router, RouterError, RouterResponse};
14use a2::{
15    self, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, Priority,
16    Response,
17    request::payload::{Payload, PayloadLike},
18};
19use actix_web::http::StatusCode;
20use async_trait::async_trait;
21use cadence::StatsdClient;
22use futures::{StreamExt, TryStreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::HashMap;
26use std::sync::Arc;
27use url::Url;
28use uuid::Uuid;
29
/// Apple Push Notification Service router
///
/// Holds one APNS client per configured release channel together with the
/// shared handles (settings, metrics, database) used while routing.
pub struct ApnsRouter {
    /// A map from release channel to APNS client
    clients: HashMap<String, ApnsClientData>,
    /// Router settings; `max_data` is used as the payload size limit.
    settings: ApnsSettings,
    /// Base URL that `/m/{message_id}` is joined onto for the success response.
    endpoint_url: Url,
    /// Metrics client used for the success and error counters.
    metrics: Arc<StatsdClient>,
    /// Database client; used to remove a user when APNS answers with a 410.
    db: Box<dyn DbClient>,
    /// Reliability tracker that notification states are recorded against.
    #[cfg(feature = "reliable_report")]
    reliability: Arc<PushReliability>,
}
41
/// An APNS client paired with the topic that is sent with each notification.
struct ApnsClientData {
    /// The client used to send payloads for this release channel.
    client: Box<dyn ApnsClient>,
    /// Value passed as `apns_topic` when the notification is built.
    topic: String,
}
46
/// Abstraction over the APNS client's `send` call. `ApnsClientData` stores the
/// client as a boxed trait object, so any implementation of this trait can be
/// plugged into the router.
#[async_trait]
trait ApnsClient: Send + Sync {
    /// Send `payload` to APNS, returning the APNS response or the `a2` error.
    async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
}
51
52#[async_trait]
53impl ApnsClient for a2::Client {
54    async fn send(&self, payload: Payload<'_>) -> Result<Response, a2::Error> {
55        self.send(payload).await
56    }
57}
58
/// Deserializable copy of the APS structure.
///
/// a2 does not allow for Deserialization of the APS structure, so this is
/// copied from that library. Within this file it is only used by `register`
/// to check that the supplied `aps` data parses.
///
/// NOTE(review): the borrowed `&'a str` fields can only be deserialized
/// zero-copy; a JSON string containing escape sequences presumably fails to
/// parse into them — confirm whether rejecting such input is intended.
#[derive(Deserialize, Serialize, Default, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
#[allow(clippy::upper_case_acronyms)]
pub struct ApsDeser<'a> {
    // The notification content. Can be empty for silent notifications.
    // Note, we overwrite this value, but it's copied and commented here
    // so that future development notes the change.
    // #[serde(skip_serializing_if = "Option::is_none")]
    //pub alert: Option<a2::request::payload::APSAlert<'a>>,
    /// A number shown on top of the app icon.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub badge: Option<u32>,

    /// The name of the sound file to play when the user receives the
    /// notification.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sound: Option<&'a str>,

    /// Set to one for silent notifications.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_available: Option<u8>,

    /// When a notification includes the category key, the system displays the
    /// actions for that category as buttons in the banner or alert interface.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<&'a str>,

    /// If set to one, the app can change the notification content before
    /// displaying it to the user.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mutable_content: Option<u8>,

    #[serde(skip_serializing_if = "Option::is_none")]
    // Converted to an owned type for Deserialization; the a2 original is:
    // pub url_args: Option<&'a [&'a str]>,
    pub url_args: Option<Vec<String>>,
}
97
/// Owned storage for the alert values of a replacement APS block.
///
/// Replicates a2::request::notification::DefaultAlert for lifetime reasons:
/// `derive_aps` copies the replacement values into this holder and hands the
/// notification builder references to them.
#[derive(Default)]
pub struct ApsAlertHolder {
    // Each field backs the builder setter of the same name in `derive_aps`.
    title: String,
    subtitle: String,
    body: String,
    title_loc_key: String,
    title_loc_args: Vec<String>,
    action_loc_key: String,
    loc_key: String,
    loc_args: Vec<String>,
    launch_image: String,
}
112
113impl ApnsRouter {
114    /// Create a new APNS router. APNS clients will be initialized for each
115    /// channel listed in the settings.
116    pub async fn new(
117        settings: ApnsSettings,
118        endpoint_url: Url,
119        metrics: Arc<StatsdClient>,
120        db: Box<dyn DbClient>,
121        #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
122    ) -> Result<Self, ApnsError> {
123        let channels = settings.channels()?;
124
125        // An explicitly null timeout in the config would otherwise mean "no
126        // timeout" in a2, so backstop with the documented defaults.
127        let defaults = ApnsSettings::default();
128        let request_timeout_secs = settings
129            .request_timeout_secs
130            .or(defaults.request_timeout_secs);
131        let pool_idle_timeout_secs = settings
132            .pool_idle_timeout_secs
133            .or(defaults.pool_idle_timeout_secs);
134
135        let clients: HashMap<String, ApnsClientData> = futures::stream::iter(channels)
136            .then(|(name, channel)| {
137                Self::create_client(name, channel, request_timeout_secs, pool_idle_timeout_secs)
138            })
139            .try_collect()
140            .await?;
141
142        trace!("Initialized {} APNs clients", clients.len());
143        Ok(Self {
144            clients,
145            settings,
146            endpoint_url,
147            metrics,
148            db,
149            #[cfg(feature = "reliable_report")]
150            reliability,
151        })
152    }
153
154    /// Create an APNS client for the channel.
155    ///
156    /// If `key_id` and `team_id` are set, the client uses token-based
157    /// authentication with the `.p8` provider auth key in `key`. Otherwise it
158    /// falls back to certificate-based authentication with `cert`/`key`.
159    async fn create_client(
160        name: String,
161        channel: ApnsChannel,
162        request_timeout_secs: Option<u64>,
163        pool_idle_timeout_secs: Option<u64>,
164    ) -> Result<(String, ApnsClientData), ApnsError> {
165        let endpoint = if channel.sandbox {
166            Endpoint::Sandbox
167        } else {
168            Endpoint::Production
169        };
170        // Timeouts come from the `ApnsSettings` config; `ApnsRouter::new`
171        // backstops explicitly-null values with the defaults, so these are
172        // always set here.
173        let config = a2::ClientConfig {
174            endpoint,
175            request_timeout_secs,
176            pool_idle_timeout_secs,
177        };
178        let client: Box<dyn ApnsClient> = match (&channel.key_id, &channel.team_id) {
179            (Some(key_id), Some(team_id)) => {
180                if !channel.cert.is_empty() {
181                    warn!(
182                        "APNS channel '{}' specifies `cert`, but it is ignored because \
183                         `key_id`/`team_id` select token-based auth",
184                        name
185                    );
186                }
187                let key = Self::read_pem(&channel.key).await?;
188                Box::new(
189                    a2::Client::token(key.as_slice(), key_id, team_id, config)
190                        .map_err(ApnsError::ApnsClient)?,
191                )
192            }
193            (None, None) => {
194                let cert = Self::read_pem(&channel.cert).await?;
195                let key = Self::read_pem(&channel.key).await?;
196                Box::new(
197                    a2::Client::certificate_parts(&cert, &key, config)
198                        .map_err(ApnsError::ApnsClient)?,
199                )
200            }
201            _ => {
202                return Err(ApnsError::Config(
203                    name,
204                    "`key_id` and `team_id` must both be set for token-based auth".to_owned(),
205                ));
206            }
207        };
208        let client = ApnsClientData {
209            client,
210            topic: channel
211                .topic
212                .unwrap_or_else(|| format!("com.mozilla.org.{name}")),
213        };
214
215        Ok((name, client))
216    }
217
218    /// Read PEM material that is either an inline value (starting with "-",
219    /// e.g. `-----BEGIN PRIVATE KEY-----`) or a path to a file
220    async fn read_pem(value: &str) -> Result<Vec<u8>, ApnsError> {
221        if value.starts_with('-') {
222            Ok(value.as_bytes().to_vec())
223        } else {
224            Ok(tokio::fs::read(value).await?)
225        }
226    }
227
228    /// The default APS data for a notification
229    fn default_aps<'a>() -> DefaultNotificationBuilder<'a> {
230        DefaultNotificationBuilder::new()
231            .set_title_loc_key("SentTab.NoTabArrivingNotification.title")
232            .set_loc_key("SentTab.NoTabArrivingNotification.body")
233            .set_mutable_content()
234    }
235
236    /// Handle an error by logging, updating metrics, etc
237    async fn handle_error(&self, error: a2::Error, uaid: Uuid, channel: &str) -> ApiError {
238        match &error {
239            a2::Error::ResponseError(response) => {
240                // capture the APNs error as a metric response. This allows us to spot trends.
241                // While APNS can return a number of errors (see a2::response::ErrorReason) we
242                // shouldn't encounter many of those.
243                let reason = response
244                    .error
245                    .as_ref()
246                    .map(|r| format!("{:?}", r.reason))
247                    .unwrap_or_else(|| "Unknown".to_owned());
248                let code = StatusCode::from_u16(response.code).unwrap_or(StatusCode::BAD_GATEWAY);
249                incr_error_metric(&self.metrics, "apns", channel, &reason, code, None);
250                if response.code == 410 {
251                    debug!("APNS recipient has been unregistered, removing user");
252                    if let Err(e) = self.db.remove_user(&uaid).await {
253                        warn!("Error while removing user due to APNS 410: {}", e);
254                    }
255
256                    return ApiError::from(ApnsError::Unregistered);
257                } else {
258                    warn!("APNS error: {:?}", response.error);
259                }
260            }
261            a2::Error::ConnectionError(e) => {
262                error!("APNS connection error: {:?}", e);
263                incr_error_metric(
264                    &self.metrics,
265                    "apns",
266                    channel,
267                    "connection_unavailable",
268                    StatusCode::SERVICE_UNAVAILABLE,
269                    None,
270                );
271            }
272            _ => {
273                warn!("Unknown error while sending APNS request: {}", error);
274                incr_error_metric(
275                    &self.metrics,
276                    "apns",
277                    channel,
278                    "unknown",
279                    StatusCode::BAD_GATEWAY,
280                    None,
281                );
282            }
283        }
284
285        ApiError::from(ApnsError::ApnsUpstream(error))
286    }
287
    /// Whether this router is "active": true when at least one APNS client
    /// is defined (i.e. the `clients` map is not empty).
    pub fn active(&self) -> bool {
        !self.clients.is_empty()
    }
292
293    /// Derive an APS message from the replacement JSON block.
294    ///
295    /// This requires an external "holder" that contains the data that APS will refer to.
296    /// The holder should live in the same context as the `aps.build()` method.
297    fn derive_aps<'a>(
298        &self,
299        replacement: Value,
300        holder: &'a mut ApsAlertHolder,
301    ) -> Result<DefaultNotificationBuilder<'a>, ApnsError> {
302        let mut aps = Self::default_aps();
303        // a2 does not have a way to bulk replace these values, so do them by hand.
304        // these could probably be turned into a macro, but hopefully, this is
305        // more one off and I didn't want to fight with the macro generator.
306        // This whole thing was included as a byproduct of
307        // https://bugzilla.mozilla.org/show_bug.cgi?id=1364403 which was put
308        // in place to help debug the iOS build. It was supposed to be temporary,
309        // but apparently bit-lock set in and now no one is super sure if it's
310        // still needed or used. (I want to get rid of this.)
311        if let Some(v) = replacement.get("title") {
312            if let Some(v) = v.as_str() {
313                v.clone_into(&mut holder.title);
314                aps = aps.set_title(&holder.title);
315            } else {
316                return Err(ApnsError::InvalidApsData);
317            }
318        }
319        if let Some(v) = replacement.get("subtitle") {
320            if let Some(v) = v.as_str() {
321                v.clone_into(&mut holder.subtitle);
322                aps = aps.set_subtitle(&holder.subtitle);
323            } else {
324                return Err(ApnsError::InvalidApsData);
325            }
326        }
327        if let Some(v) = replacement.get("body") {
328            if let Some(v) = v.as_str() {
329                v.clone_into(&mut holder.body);
330                aps = aps.set_body(&holder.body);
331            } else {
332                return Err(ApnsError::InvalidApsData);
333            }
334        }
335        if let Some(v) = replacement.get("title_loc_key") {
336            if let Some(v) = v.as_str() {
337                v.clone_into(&mut holder.title_loc_key);
338                aps = aps.set_title_loc_key(&holder.title_loc_key);
339            } else {
340                return Err(ApnsError::InvalidApsData);
341            }
342        }
343        if let Some(v) = replacement.get("title_loc_args") {
344            if let Some(v) = v.as_array() {
345                let mut args: Vec<String> = Vec::new();
346                for val in v {
347                    if let Some(value) = val.as_str() {
348                        args.push(value.to_owned())
349                    } else {
350                        return Err(ApnsError::InvalidApsData);
351                    }
352                }
353                holder.title_loc_args = args;
354                aps = aps.set_title_loc_args(&holder.title_loc_args);
355            } else {
356                return Err(ApnsError::InvalidApsData);
357            }
358        }
359        if let Some(v) = replacement.get("action_loc_key") {
360            if let Some(v) = v.as_str() {
361                v.clone_into(&mut holder.action_loc_key);
362                aps = aps.set_action_loc_key(&holder.action_loc_key);
363            } else {
364                return Err(ApnsError::InvalidApsData);
365            }
366        }
367        if let Some(v) = replacement.get("loc_key") {
368            if let Some(v) = v.as_str() {
369                v.clone_into(&mut holder.loc_key);
370                aps = aps.set_loc_key(&holder.loc_key);
371            } else {
372                return Err(ApnsError::InvalidApsData);
373            }
374        }
375        if let Some(v) = replacement.get("loc_args") {
376            if let Some(v) = v.as_array() {
377                let mut args: Vec<String> = Vec::new();
378                for val in v {
379                    if let Some(value) = val.as_str() {
380                        args.push(value.to_owned())
381                    } else {
382                        return Err(ApnsError::InvalidApsData);
383                    }
384                }
385                holder.loc_args = args;
386                aps = aps.set_loc_args(&holder.loc_args);
387            } else {
388                return Err(ApnsError::InvalidApsData);
389            }
390        }
391        if let Some(v) = replacement.get("launch_image") {
392            if let Some(v) = v.as_str() {
393                v.clone_into(&mut holder.launch_image);
394                aps = aps.set_launch_image(&holder.launch_image);
395            } else {
396                return Err(ApnsError::InvalidApsData);
397            }
398        }
399        // Honestly, we should just check to see if this is present
400        // we don't really care what the value is since we'll never
401        // use
402        if let Some(v) = replacement.get("mutable-content") {
403            if let Some(v) = v.as_i64() {
404                if v != 0 {
405                    aps = aps.set_mutable_content();
406                }
407            } else {
408                return Err(ApnsError::InvalidApsData);
409            }
410        }
411        Ok(aps)
412    }
413}
414
415#[async_trait(?Send)]
416impl Router for ApnsRouter {
417    fn register(
418        &self,
419        router_input: &RouterDataInput,
420        app_id: &str,
421    ) -> Result<HashMap<String, Value>, RouterError> {
422        if !self.clients.contains_key(app_id) {
423            return Err(ApnsError::InvalidReleaseChannel.into());
424        }
425
426        let mut router_data = HashMap::new();
427        router_data.insert(
428            "token".to_string(),
429            serde_json::to_value(&router_input.token).unwrap(),
430        );
431        router_data.insert(
432            "rel_channel".to_string(),
433            serde_json::to_value(app_id).unwrap(),
434        );
435
436        if let Some(aps) = &router_input.aps {
437            if serde_json::from_str::<ApsDeser<'_>>(aps).is_err() {
438                return Err(ApnsError::InvalidApsData.into());
439            }
440            router_data.insert(
441                "aps".to_string(),
442                serde_json::to_value(aps.clone()).unwrap(),
443            );
444        }
445
446        Ok(router_data)
447    }
448
449    #[allow(unused_mut)]
450    async fn route_notification(
451        &self,
452        mut notification: Notification,
453    ) -> ApiResult<RouterResponse> {
454        debug!(
455            "Sending APNS notification to UAID {}",
456            notification.subscription.user.uaid
457        );
458        trace!("Notification = {:?}", notification);
459
460        // Build message data
461        let router_data = notification
462            .subscription
463            .user
464            .router_data
465            .as_ref()
466            .ok_or(ApnsError::NoDeviceToken)?
467            .clone();
468        let token = router_data
469            .get("token")
470            .and_then(Value::as_str)
471            .ok_or(ApnsError::NoDeviceToken)?;
472        let channel = router_data
473            .get("rel_channel")
474            .and_then(Value::as_str)
475            .ok_or(ApnsError::NoReleaseChannel)?;
476        let aps_json = router_data.get("aps").cloned();
477        let mut message_data = build_message_data(&notification)?;
478        message_data.insert("ver", notification.message_id.clone());
479
480        // Get client and build payload
481        let ApnsClientData { client, topic } = self
482            .clients
483            .get(channel)
484            .ok_or(ApnsError::InvalidReleaseChannel)?;
485
486        // A simple bucket variable so that I don't have to deal with fun lifetime issues if we need
487        // to derive.
488        let mut holder = ApsAlertHolder::default();
489
490        // If we are provided a replacement APS block, derive an APS message from it, otherwise
491        // start with a blank APS message.
492        let aps = if let Some(replacement) = aps_json {
493            self.derive_aps(replacement, &mut holder)?
494        } else {
495            Self::default_aps()
496        };
497
498        // Finalize the APS object.
499        let mut payload = aps.build(
500            token,
501            NotificationOptions {
502                apns_id: None,
503                apns_priority: Some(Priority::High),
504                apns_topic: Some(topic),
505                apns_collapse_id: None,
506                apns_expiration: Some(notification.timestamp + notification.headers.ttl as u64),
507                ..Default::default()
508            },
509        );
510        let message_length = message_data.get("body").map(|s| s.len()).unwrap_or(0);
511        payload.data = message_data
512            .into_iter()
513            .map(|(k, v)| (k, Value::String(v)))
514            .collect();
515
516        // Check size limit
517        let payload_json = payload
518            .clone()
519            .to_json_string()
520            .map_err(|_| RouterError::TooMuchData(message_length))?;
521        message_size_check(payload_json.as_bytes(), self.settings.max_data)?;
522
523        // Send to APNS
524        trace!("Sending message to APNS: {:?}", payload);
525        if let Err(e) = client.send(payload).await {
526            #[cfg(feature = "reliable_report")]
527            notification
528                .record_reliability(
529                    &self.reliability,
530                    autopush_common::reliability::ReliabilityState::Errored,
531                )
532                .await;
533            return Err(self
534                .handle_error(e, notification.subscription.user.uaid, channel)
535                .await);
536        }
537
538        trace!("APNS request was successful");
539        incr_success_metrics(&self.metrics, "apns", channel, &notification);
540        #[cfg(feature = "reliable_report")]
541        {
542            // Record that we've sent the message out to APNS.
543            // We can't set the state here because the notification isn't
544            // mutable, but we are also essentially consuming the
545            // notification nothing else should modify it.
546            notification
547                .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
548                .await;
549        }
550
551        Ok(RouterResponse::success(
552            self.endpoint_url
553                .join(&format!("/m/{}", notification.message_id))
554                .expect("Message ID is not URL-safe")
555                .to_string(),
556            notification.headers.ttl as usize,
557        ))
558    }
559}
560
561#[cfg(test)]
562mod tests {
563    use crate::error::ApiErrorKind;
564    use crate::extractors::routers::RouterType;
565    use crate::routers::apns::error::ApnsError;
566    use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
567    use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
568    use crate::routers::common::tests::{CHANNEL_ID, make_notification};
569    use crate::routers::{Router, RouterError, RouterResponse};
570    use a2::request::payload::Payload;
571    use a2::{Error, Response};
572    use async_trait::async_trait;
573    use autopush_common::db::client::DbClient;
574    use autopush_common::db::mock::MockDbClient;
575    #[cfg(feature = "reliable_report")]
576    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
577    use cadence::StatsdClient;
578    use mockall::predicate;
579    use std::collections::HashMap;
580    use std::sync::Arc;
581    use url::Url;
582
583    const DEVICE_TOKEN: &str = "test-token";
584    const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";
585
586    /// Generate a throwaway P-256 private key in PKCS#8 PEM, the same shape
587    /// as an Apple `.p8` auth key. Generated at test time so no PEM block
588    /// (even a fake one) lives in the source tree to trip secret scanners.
589    fn test_p8_key() -> String {
590        use openssl::{
591            ec::{EcGroup, EcKey},
592            nid::Nid,
593            pkey::PKey,
594        };
595
596        let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
597        let key = EcKey::generate(&group).unwrap();
598        let pem = PKey::from_ec_key(key).unwrap();
599        String::from_utf8(pem.private_key_to_pem_pkcs8().unwrap()).unwrap()
600    }
601
602    #[allow(clippy::type_complexity)]
603    /// A mock APNS client which allows one to supply a custom APNS response/error
604    struct MockApnsClient {
605        send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
606    }
607
608    #[async_trait]
609    impl ApnsClient for MockApnsClient {
610        async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
611            (self.send_fn)(payload)
612        }
613    }
614
615    impl MockApnsClient {
616        fn new<F>(send_fn: F) -> Self
617        where
618            F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
619            F: Send + Sync + 'static,
620        {
621            Self {
622                send_fn: Box::new(send_fn),
623            }
624        }
625    }
626
627    /// Create a successful APNS response
628    fn apns_success_response() -> a2::Response {
629        a2::Response {
630            error: None,
631            apns_id: Some(APNS_ID.to_string()),
632            code: 200,
633        }
634    }
635
636    /// Create a router for testing, using the given APNS client
637    fn make_router(client: MockApnsClient, db: Box<dyn DbClient>) -> ApnsRouter {
638        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
639
640        ApnsRouter {
641            clients: {
642                let mut map = HashMap::new();
643                map.insert(
644                    "test-channel".to_string(),
645                    ApnsClientData {
646                        client: Box::new(client),
647                        topic: "test-topic".to_string(),
648                    },
649                );
650                map
651            },
652            settings: ApnsSettings::default(),
653            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
654            metrics: metrics.clone(),
655            db: db.clone(),
656            #[cfg(feature = "reliable_report")]
657            reliability: Arc::new(
658                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
659            ),
660        }
661    }
662
663    /// Create default user router data
664    fn default_router_data() -> HashMap<String, serde_json::Value> {
665        let mut map = HashMap::new();
666        map.insert(
667            "token".to_string(),
668            serde_json::to_value(DEVICE_TOKEN).unwrap(),
669        );
670        map.insert(
671            "rel_channel".to_string(),
672            serde_json::to_value("test-channel").unwrap(),
673        );
674        map
675    }
676
677    /// A channel with `key_id` and `team_id` creates a token-auth client
678    #[tokio::test]
679    async fn create_client_token_auth() {
680        let channel = ApnsChannel {
681            key: test_p8_key(),
682            key_id: Some("ABC123DEFG".to_string()),
683            team_id: Some("TEAMID1234".to_string()),
684            topic: Some("com.mozilla.org.Firefox".to_string()),
685            ..Default::default()
686        };
687
688        let (name, client) = ApnsRouter::create_client("test".to_string(), channel, None, None)
689            .await
690            .expect("token client creation failed");
691        assert_eq!(name, "test");
692        assert_eq!(client.topic, "com.mozilla.org.Firefox");
693    }
694
695    /// Setting only one of `key_id`/`team_id` is a config error
696    #[tokio::test]
697    async fn create_client_partial_token_auth() {
698        let channel = ApnsChannel {
699            key: test_p8_key(),
700            key_id: Some("ABC123DEFG".to_string()),
701            ..Default::default()
702        };
703
704        let result = ApnsRouter::create_client("test".to_string(), channel, None, None).await;
705        assert!(
706            matches!(result, Err(ApnsError::Config(ref name, _)) if name == "test"),
707            "expected ApnsError::Config"
708        );
709    }
710
711    /// A notification with no data is packaged correctly and sent to APNS
712    #[tokio::test]
713    async fn successful_routing_no_data() {
714        use a2::NotificationBuilder;
715
716        let client = MockApnsClient::new(|payload| {
717            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
718            assert_eq!(
719                serde_json::to_value(payload.aps).unwrap(),
720                serde_json::to_value(built.aps).unwrap()
721            );
722            assert_eq!(payload.device_token, DEVICE_TOKEN);
723            assert_eq!(payload.options.apns_topic, Some("test-topic"));
724            assert_eq!(
725                serde_json::to_value(payload.data).unwrap(),
726                serde_json::json!({
727                    "chid": CHANNEL_ID,
728                    "ver": "test-message-id"
729                })
730            );
731
732            Ok(apns_success_response())
733        });
734        let mdb = MockDbClient::new();
735        let db = mdb.into_boxed_arc();
736        let router = make_router(client, db);
737        let notification = make_notification(default_router_data(), None, RouterType::APNS);
738
739        let result = router.route_notification(notification).await;
740        assert!(result.is_ok(), "result = {result:?}");
741        assert_eq!(
742            result.unwrap(),
743            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
744        );
745    }
746
747    /// A notification with data is packaged correctly and sent to APNS
748    #[tokio::test]
749    async fn successful_routing_with_data() {
750        use a2::NotificationBuilder;
751
752        let client = MockApnsClient::new(|payload| {
753            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
754            assert_eq!(serde_json::json!(payload.aps), serde_json::json!(built.aps));
755            assert_eq!(payload.device_token, DEVICE_TOKEN);
756            assert_eq!(payload.options.apns_topic, Some("test-topic"));
757            assert_eq!(
758                serde_json::to_value(payload.data).unwrap(),
759                serde_json::json!({
760                    "chid": CHANNEL_ID,
761                    "ver": "test-message-id",
762                    "body": "test-data",
763                    "con": "test-encoding",
764                    "enc": "test-encryption",
765                    "cryptokey": "test-crypto-key",
766                    "enckey": "test-encryption-key"
767                })
768            );
769
770            Ok(apns_success_response())
771        });
772        let mdb = MockDbClient::new();
773        let db = mdb.into_boxed_arc();
774        let router = make_router(client, db);
775        let data = "test-data".to_string();
776        let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);
777
778        let result = router.route_notification(notification).await;
779        assert!(result.is_ok(), "result = {result:?}");
780        assert_eq!(
781            result.unwrap(),
782            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
783        );
784    }
785
786    /// If there is no client for the user's release channel, an error is
787    /// returned and the APNS request is not sent.
788    #[tokio::test]
789    async fn missing_client() {
790        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
791        let db = MockDbClient::new().into_boxed_arc();
792        let router = make_router(client, db);
793        let mut router_data = default_router_data();
794        router_data.insert(
795            "rel_channel".to_string(),
796            serde_json::to_value("unknown-app-id").unwrap(),
797        );
798        let notification = make_notification(router_data, None, RouterType::APNS);
799
800        let result = router.route_notification(notification).await;
801        assert!(result.is_err());
802        assert!(
803            matches!(
804                result.as_ref().unwrap_err().kind,
805                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
806            ),
807            "result = {result:?}"
808        );
809    }
810
811    /// If APNS says the user doesn't exist anymore, we return a specific error
812    /// and remove the user from the database.
813    #[tokio::test]
814    async fn user_not_found() {
815        let client = MockApnsClient::new(|_| {
816            Err(a2::Error::ResponseError(a2::Response {
817                error: Some(a2::ErrorBody {
818                    reason: a2::ErrorReason::Unregistered,
819                    timestamp: Some(0),
820                }),
821                apns_id: None,
822                code: 410,
823            }))
824        });
825        let notification = make_notification(default_router_data(), None, RouterType::APNS);
826        let mut db = MockDbClient::new();
827        db.expect_remove_user()
828            .with(predicate::eq(notification.subscription.user.uaid))
829            .times(1)
830            .return_once(|_| Ok(()));
831        let router = make_router(client, db.into_boxed_arc());
832
833        let result = router.route_notification(notification).await;
834        assert!(result.is_err());
835        assert!(
836            matches!(
837                result.as_ref().unwrap_err().kind,
838                ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
839            ),
840            "result = {result:?}"
841        );
842    }
843
844    /// APNS errors (other than Unregistered) are wrapped and returned
845    #[tokio::test]
846    async fn upstream_error() {
847        let client = MockApnsClient::new(|_| {
848            Err(a2::Error::ResponseError(a2::Response {
849                error: Some(a2::ErrorBody {
850                    reason: a2::ErrorReason::BadCertificate,
851                    timestamp: None,
852                }),
853                apns_id: None,
854                code: 403,
855            }))
856        });
857        let db = MockDbClient::new().into_boxed_arc();
858        let router = make_router(client, db);
859        let notification = make_notification(default_router_data(), None, RouterType::APNS);
860
861        let result = router.route_notification(notification).await;
862        assert!(result.is_err());
863        assert!(
864            matches!(
865                result.as_ref().unwrap_err().kind,
866                ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
867                    a2::Error::ResponseError(a2::Response {
868                        error: Some(a2::ErrorBody {
869                            reason: a2::ErrorReason::BadCertificate,
870                            timestamp: None,
871                        }),
872                        apns_id: None,
873                        code: 403,
874                    })
875                )))
876            ),
877            "result = {result:?}"
878        );
879    }
880
881    /// An error is returned if the user's APS data is invalid
882    #[tokio::test]
883    async fn invalid_aps_data() {
884        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
885        let db = MockDbClient::new().into_boxed_arc();
886        let router = make_router(client, db);
887        let mut router_data = default_router_data();
888        router_data.insert(
889            "aps".to_string(),
890            serde_json::json!({"mutable-content": "should be a number"}),
891        );
892        let notification = make_notification(router_data, None, RouterType::APNS);
893
894        let result = router.route_notification(notification).await;
895        assert!(result.is_err());
896        assert!(
897            matches!(
898                result.as_ref().unwrap_err().kind,
899                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
900            ),
901            "result = {result:?}"
902        );
903    }
904}