1use autopush_common::db::client::DbClient;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::{PushReliability, ReliabilityState};
4
5use crate::error::{ApiError, ApiResult};
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::apns::error::ApnsError;
9use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
10use crate::routers::common::{
11 build_message_data, incr_error_metric, incr_success_metrics, message_size_check,
12};
13use crate::routers::{Router, RouterError, RouterResponse};
14use a2::{
15 self, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, Priority,
16 Response,
17 request::payload::{Payload, PayloadLike},
18};
19use actix_web::http::StatusCode;
20use async_trait::async_trait;
21use cadence::StatsdClient;
22use futures::{StreamExt, TryStreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::HashMap;
26use std::sync::Arc;
27use url::Url;
28use uuid::Uuid;
29
/// Router that delivers push notifications to Apple devices through APNS.
pub struct ApnsRouter {
    /// One APNS client per release channel, keyed by the channel name.
    clients: HashMap<String, ApnsClientData>,
    /// APNS router settings (e.g. `max_data`, used to size-check payloads).
    settings: ApnsSettings,
    /// Base URL used to build the message location returned on success.
    endpoint_url: Url,
    /// Metrics client for success/error counters.
    metrics: Arc<StatsdClient>,
    /// Database handle; used to remove users APNS reports as gone (HTTP 410).
    db: Box<dyn DbClient>,
    /// Reliability tracker used to record per-message delivery state.
    #[cfg(feature = "reliable_report")]
    reliability: Arc<PushReliability>,
}
41
/// An APNS client together with the `apns-topic` value sent with every
/// notification routed through it.
struct ApnsClientData {
    /// The sender; a trait object so tests can substitute a mock.
    client: Box<dyn ApnsClient>,
    /// Topic for this release channel (defaults to `com.mozilla.org.<channel>`).
    topic: String,
}
46
/// Minimal abstraction over an APNS sender.
///
/// The production implementation is `a2::Client`; the unit tests provide a
/// mock implementation so no network traffic is needed.
#[async_trait]
trait ApnsClient: Send + Sync {
    /// Send one notification payload to APNS.
    async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
}
51
52#[async_trait]
53impl ApnsClient for a2::Client {
54 async fn send(&self, payload: Payload<'_>) -> Result<Response, a2::Error> {
55 self.send(payload).await
56 }
57}
58
/// Shape used to validate a client-supplied `aps` JSON string at registration.
///
/// Keys are matched in kebab-case (e.g. `content-available`, `mutable-content`,
/// `url-args`). Unknown keys are not rejected, since serde ignores them by
/// default. Absent fields are skipped when serializing.
///
/// NOTE(review): `sound` and `category` borrow `&'a str` from the input, and
/// serde_json cannot borrow a string that contains escape sequences. Such
/// strings therefore fail to deserialize, and registration reports them as
/// invalid APS data. Consider `Cow<'a, str>` with `#[serde(borrow)]`.
#[derive(Deserialize, Serialize, Default, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
#[allow(clippy::upper_case_acronyms)]
pub struct ApsDeser<'a> {
    /// `badge` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub badge: Option<u32>,

    /// `sound` key (borrowed from the input JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sound: Option<&'a str>,

    /// `content-available` key, as a small integer flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_available: Option<u8>,

    /// `category` key (borrowed from the input JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<&'a str>,

    /// `mutable-content` key, as a small integer flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mutable_content: Option<u8>,

    /// `url-args` key: a list of strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url_args: Option<Vec<String>>,
}
97
/// Owned storage for alert strings taken from a user's `aps` override.
///
/// `DefaultNotificationBuilder<'a>` only borrows its strings, so
/// `derive_aps` copies each override value in here and lends the builder a
/// reference. The holder must outlive the built payload.
#[derive(Default)]
pub struct ApsAlertHolder {
    title: String,
    subtitle: String,
    body: String,
    title_loc_key: String,
    title_loc_args: Vec<String>,
    action_loc_key: String,
    loc_key: String,
    loc_args: Vec<String>,
    launch_image: String,
}
112
113impl ApnsRouter {
114 pub async fn new(
117 settings: ApnsSettings,
118 endpoint_url: Url,
119 metrics: Arc<StatsdClient>,
120 db: Box<dyn DbClient>,
121 #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
122 ) -> Result<Self, ApnsError> {
123 let channels = settings.channels()?;
124
125 let defaults = ApnsSettings::default();
128 let request_timeout_secs = settings
129 .request_timeout_secs
130 .or(defaults.request_timeout_secs);
131 let pool_idle_timeout_secs = settings
132 .pool_idle_timeout_secs
133 .or(defaults.pool_idle_timeout_secs);
134
135 let clients: HashMap<String, ApnsClientData> = futures::stream::iter(channels)
136 .then(|(name, channel)| {
137 Self::create_client(name, channel, request_timeout_secs, pool_idle_timeout_secs)
138 })
139 .try_collect()
140 .await?;
141
142 trace!("Initialized {} APNs clients", clients.len());
143 Ok(Self {
144 clients,
145 settings,
146 endpoint_url,
147 metrics,
148 db,
149 #[cfg(feature = "reliable_report")]
150 reliability,
151 })
152 }
153
154 async fn create_client(
160 name: String,
161 channel: ApnsChannel,
162 request_timeout_secs: Option<u64>,
163 pool_idle_timeout_secs: Option<u64>,
164 ) -> Result<(String, ApnsClientData), ApnsError> {
165 let endpoint = if channel.sandbox {
166 Endpoint::Sandbox
167 } else {
168 Endpoint::Production
169 };
170 let config = a2::ClientConfig {
174 endpoint,
175 request_timeout_secs,
176 pool_idle_timeout_secs,
177 };
178 let client: Box<dyn ApnsClient> = match (&channel.key_id, &channel.team_id) {
179 (Some(key_id), Some(team_id)) => {
180 if !channel.cert.is_empty() {
181 warn!(
182 "APNS channel '{}' specifies `cert`, but it is ignored because \
183 `key_id`/`team_id` select token-based auth",
184 name
185 );
186 }
187 let key = Self::read_pem(&channel.key).await?;
188 Box::new(
189 a2::Client::token(key.as_slice(), key_id, team_id, config)
190 .map_err(ApnsError::ApnsClient)?,
191 )
192 }
193 (None, None) => {
194 let cert = Self::read_pem(&channel.cert).await?;
195 let key = Self::read_pem(&channel.key).await?;
196 Box::new(
197 a2::Client::certificate_parts(&cert, &key, config)
198 .map_err(ApnsError::ApnsClient)?,
199 )
200 }
201 _ => {
202 return Err(ApnsError::Config(
203 name,
204 "`key_id` and `team_id` must both be set for token-based auth".to_owned(),
205 ));
206 }
207 };
208 let client = ApnsClientData {
209 client,
210 topic: channel
211 .topic
212 .unwrap_or_else(|| format!("com.mozilla.org.{name}")),
213 };
214
215 Ok((name, client))
216 }
217
218 async fn read_pem(value: &str) -> Result<Vec<u8>, ApnsError> {
221 if value.starts_with('-') {
222 Ok(value.as_bytes().to_vec())
223 } else {
224 Ok(tokio::fs::read(value).await?)
225 }
226 }
227
228 fn default_aps<'a>() -> DefaultNotificationBuilder<'a> {
230 DefaultNotificationBuilder::new()
231 .set_title_loc_key("SentTab.NoTabArrivingNotification.title")
232 .set_loc_key("SentTab.NoTabArrivingNotification.body")
233 .set_mutable_content()
234 }
235
236 async fn handle_error(&self, error: a2::Error, uaid: Uuid, channel: &str) -> ApiError {
238 match &error {
239 a2::Error::ResponseError(response) => {
240 let reason = response
244 .error
245 .as_ref()
246 .map(|r| format!("{:?}", r.reason))
247 .unwrap_or_else(|| "Unknown".to_owned());
248 let code = StatusCode::from_u16(response.code).unwrap_or(StatusCode::BAD_GATEWAY);
249 incr_error_metric(&self.metrics, "apns", channel, &reason, code, None);
250 if response.code == 410 {
251 debug!("APNS recipient has been unregistered, removing user");
252 if let Err(e) = self.db.remove_user(&uaid).await {
253 warn!("Error while removing user due to APNS 410: {}", e);
254 }
255
256 return ApiError::from(ApnsError::Unregistered);
257 } else {
258 warn!("APNS error: {:?}", response.error);
259 }
260 }
261 a2::Error::ConnectionError(e) => {
262 error!("APNS connection error: {:?}", e);
263 incr_error_metric(
264 &self.metrics,
265 "apns",
266 channel,
267 "connection_unavailable",
268 StatusCode::SERVICE_UNAVAILABLE,
269 None,
270 );
271 }
272 _ => {
273 warn!("Unknown error while sending APNS request: {}", error);
274 incr_error_metric(
275 &self.metrics,
276 "apns",
277 channel,
278 "unknown",
279 StatusCode::BAD_GATEWAY,
280 None,
281 );
282 }
283 }
284
285 ApiError::from(ApnsError::ApnsUpstream(error))
286 }
287
288 pub fn active(&self) -> bool {
290 !self.clients.is_empty()
291 }
292
293 fn derive_aps<'a>(
298 &self,
299 replacement: Value,
300 holder: &'a mut ApsAlertHolder,
301 ) -> Result<DefaultNotificationBuilder<'a>, ApnsError> {
302 let mut aps = Self::default_aps();
303 if let Some(v) = replacement.get("title") {
312 if let Some(v) = v.as_str() {
313 v.clone_into(&mut holder.title);
314 aps = aps.set_title(&holder.title);
315 } else {
316 return Err(ApnsError::InvalidApsData);
317 }
318 }
319 if let Some(v) = replacement.get("subtitle") {
320 if let Some(v) = v.as_str() {
321 v.clone_into(&mut holder.subtitle);
322 aps = aps.set_subtitle(&holder.subtitle);
323 } else {
324 return Err(ApnsError::InvalidApsData);
325 }
326 }
327 if let Some(v) = replacement.get("body") {
328 if let Some(v) = v.as_str() {
329 v.clone_into(&mut holder.body);
330 aps = aps.set_body(&holder.body);
331 } else {
332 return Err(ApnsError::InvalidApsData);
333 }
334 }
335 if let Some(v) = replacement.get("title_loc_key") {
336 if let Some(v) = v.as_str() {
337 v.clone_into(&mut holder.title_loc_key);
338 aps = aps.set_title_loc_key(&holder.title_loc_key);
339 } else {
340 return Err(ApnsError::InvalidApsData);
341 }
342 }
343 if let Some(v) = replacement.get("title_loc_args") {
344 if let Some(v) = v.as_array() {
345 let mut args: Vec<String> = Vec::new();
346 for val in v {
347 if let Some(value) = val.as_str() {
348 args.push(value.to_owned())
349 } else {
350 return Err(ApnsError::InvalidApsData);
351 }
352 }
353 holder.title_loc_args = args;
354 aps = aps.set_title_loc_args(&holder.title_loc_args);
355 } else {
356 return Err(ApnsError::InvalidApsData);
357 }
358 }
359 if let Some(v) = replacement.get("action_loc_key") {
360 if let Some(v) = v.as_str() {
361 v.clone_into(&mut holder.action_loc_key);
362 aps = aps.set_action_loc_key(&holder.action_loc_key);
363 } else {
364 return Err(ApnsError::InvalidApsData);
365 }
366 }
367 if let Some(v) = replacement.get("loc_key") {
368 if let Some(v) = v.as_str() {
369 v.clone_into(&mut holder.loc_key);
370 aps = aps.set_loc_key(&holder.loc_key);
371 } else {
372 return Err(ApnsError::InvalidApsData);
373 }
374 }
375 if let Some(v) = replacement.get("loc_args") {
376 if let Some(v) = v.as_array() {
377 let mut args: Vec<String> = Vec::new();
378 for val in v {
379 if let Some(value) = val.as_str() {
380 args.push(value.to_owned())
381 } else {
382 return Err(ApnsError::InvalidApsData);
383 }
384 }
385 holder.loc_args = args;
386 aps = aps.set_loc_args(&holder.loc_args);
387 } else {
388 return Err(ApnsError::InvalidApsData);
389 }
390 }
391 if let Some(v) = replacement.get("launch_image") {
392 if let Some(v) = v.as_str() {
393 v.clone_into(&mut holder.launch_image);
394 aps = aps.set_launch_image(&holder.launch_image);
395 } else {
396 return Err(ApnsError::InvalidApsData);
397 }
398 }
399 if let Some(v) = replacement.get("mutable-content") {
403 if let Some(v) = v.as_i64() {
404 if v != 0 {
405 aps = aps.set_mutable_content();
406 }
407 } else {
408 return Err(ApnsError::InvalidApsData);
409 }
410 }
411 Ok(aps)
412 }
413}
414
415#[async_trait(?Send)]
416impl Router for ApnsRouter {
417 fn register(
418 &self,
419 router_input: &RouterDataInput,
420 app_id: &str,
421 ) -> Result<HashMap<String, Value>, RouterError> {
422 if !self.clients.contains_key(app_id) {
423 return Err(ApnsError::InvalidReleaseChannel.into());
424 }
425
426 let mut router_data = HashMap::new();
427 router_data.insert(
428 "token".to_string(),
429 serde_json::to_value(&router_input.token).unwrap(),
430 );
431 router_data.insert(
432 "rel_channel".to_string(),
433 serde_json::to_value(app_id).unwrap(),
434 );
435
436 if let Some(aps) = &router_input.aps {
437 if serde_json::from_str::<ApsDeser<'_>>(aps).is_err() {
438 return Err(ApnsError::InvalidApsData.into());
439 }
440 router_data.insert(
441 "aps".to_string(),
442 serde_json::to_value(aps.clone()).unwrap(),
443 );
444 }
445
446 Ok(router_data)
447 }
448
449 #[allow(unused_mut)]
450 async fn route_notification(
451 &self,
452 mut notification: Notification,
453 ) -> ApiResult<RouterResponse> {
454 debug!(
455 "Sending APNS notification to UAID {}",
456 notification.subscription.user.uaid
457 );
458 trace!("Notification = {:?}", notification);
459
460 let router_data = notification
462 .subscription
463 .user
464 .router_data
465 .as_ref()
466 .ok_or(ApnsError::NoDeviceToken)?
467 .clone();
468 let token = router_data
469 .get("token")
470 .and_then(Value::as_str)
471 .ok_or(ApnsError::NoDeviceToken)?;
472 let channel = router_data
473 .get("rel_channel")
474 .and_then(Value::as_str)
475 .ok_or(ApnsError::NoReleaseChannel)?;
476 let aps_json = router_data.get("aps").cloned();
477 let mut message_data = build_message_data(¬ification)?;
478 message_data.insert("ver", notification.message_id.clone());
479
480 let ApnsClientData { client, topic } = self
482 .clients
483 .get(channel)
484 .ok_or(ApnsError::InvalidReleaseChannel)?;
485
486 let mut holder = ApsAlertHolder::default();
489
490 let aps = if let Some(replacement) = aps_json {
493 self.derive_aps(replacement, &mut holder)?
494 } else {
495 Self::default_aps()
496 };
497
498 let mut payload = aps.build(
500 token,
501 NotificationOptions {
502 apns_id: None,
503 apns_priority: Some(Priority::High),
504 apns_topic: Some(topic),
505 apns_collapse_id: None,
506 apns_expiration: Some(notification.timestamp + notification.headers.ttl as u64),
507 ..Default::default()
508 },
509 );
510 let message_length = message_data.get("body").map(|s| s.len()).unwrap_or(0);
511 payload.data = message_data
512 .into_iter()
513 .map(|(k, v)| (k, Value::String(v)))
514 .collect();
515
516 let payload_json = payload
518 .clone()
519 .to_json_string()
520 .map_err(|_| RouterError::TooMuchData(message_length))?;
521 message_size_check(payload_json.as_bytes(), self.settings.max_data)?;
522
523 trace!("Sending message to APNS: {:?}", payload);
525 if let Err(e) = client.send(payload).await {
526 #[cfg(feature = "reliable_report")]
527 notification
528 .record_reliability(
529 &self.reliability,
530 autopush_common::reliability::ReliabilityState::Errored,
531 )
532 .await;
533 return Err(self
534 .handle_error(e, notification.subscription.user.uaid, channel)
535 .await);
536 }
537
538 trace!("APNS request was successful");
539 incr_success_metrics(&self.metrics, "apns", channel, ¬ification);
540 #[cfg(feature = "reliable_report")]
541 {
542 notification
547 .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
548 .await;
549 }
550
551 Ok(RouterResponse::success(
552 self.endpoint_url
553 .join(&format!("/m/{}", notification.message_id))
554 .expect("Message ID is not URL-safe")
555 .to_string(),
556 notification.headers.ttl as usize,
557 ))
558 }
559}
560
// Unit tests for the APNS router. A mock `ApnsClient` replaces `a2::Client`,
// so no network access is needed. The DB is a `MockDbClient`.
#[cfg(test)]
mod tests {
    use crate::error::ApiErrorKind;
    use crate::extractors::routers::RouterType;
    use crate::routers::apns::error::ApnsError;
    use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
    use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
    use crate::routers::common::tests::{CHANNEL_ID, make_notification};
    use crate::routers::{Router, RouterError, RouterResponse};
    use a2::request::payload::Payload;
    use a2::{Error, Response};
    use async_trait::async_trait;
    use autopush_common::db::client::DbClient;
    use autopush_common::db::mock::MockDbClient;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
    use cadence::StatsdClient;
    use mockall::predicate;
    use std::collections::HashMap;
    use std::sync::Arc;
    use url::Url;

    const DEVICE_TOKEN: &str = "test-token";
    const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";

    /// Generate a fresh EC private key on `X9_62_PRIME256V1`, as PKCS#8 PEM
    /// text, for the token-auth tests.
    fn test_p8_key() -> String {
        use openssl::{
            ec::{EcGroup, EcKey},
            nid::Nid,
            pkey::PKey,
        };

        let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1).unwrap();
        let key = EcKey::generate(&group).unwrap();
        let pem = PKey::from_ec_key(key).unwrap();
        String::from_utf8(pem.private_key_to_pem_pkcs8().unwrap()).unwrap()
    }

    /// Mock APNS sender whose behaviour is supplied by a closure.
    // The boxed `Fn` type trips clippy's type-complexity lint.
    #[allow(clippy::type_complexity)]
    struct MockApnsClient {
        send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
    }

    #[async_trait]
    impl ApnsClient for MockApnsClient {
        async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
            (self.send_fn)(payload)
        }
    }

    impl MockApnsClient {
        /// Wrap `send_fn`, which is invoked for every `send` call.
        fn new<F>(send_fn: F) -> Self
        where
            F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
            F: Send + Sync + 'static,
        {
            Self {
                send_fn: Box::new(send_fn),
            }
        }
    }

    /// A 200 response with no error body, as a successful send returns.
    fn apns_success_response() -> a2::Response {
        a2::Response {
            error: None,
            apns_id: Some(APNS_ID.to_string()),
            code: 200,
        }
    }

    /// Build a router with a single `test-channel` client (topic `test-topic`)
    /// backed by `client`, using no-op metrics and default settings.
    fn make_router(client: MockApnsClient, db: Box<dyn DbClient>) -> ApnsRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());

        ApnsRouter {
            clients: {
                let mut map = HashMap::new();
                map.insert(
                    "test-channel".to_string(),
                    ApnsClientData {
                        client: Box::new(client),
                        topic: "test-topic".to_string(),
                    },
                );
                map
            },
            settings: ApnsSettings::default(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            metrics: metrics.clone(),
            db: db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// Router data for a device registered on `test-channel` with no `aps` override.
    fn default_router_data() -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(
            "token".to_string(),
            serde_json::to_value(DEVICE_TOKEN).unwrap(),
        );
        map.insert(
            "rel_channel".to_string(),
            serde_json::to_value("test-channel").unwrap(),
        );
        map
    }

    /// Token auth (`key_id` + `team_id` + p8 key) builds a client and keeps
    /// the configured topic.
    #[tokio::test]
    async fn create_client_token_auth() {
        let channel = ApnsChannel {
            key: test_p8_key(),
            key_id: Some("ABC123DEFG".to_string()),
            team_id: Some("TEAMID1234".to_string()),
            topic: Some("com.mozilla.org.Firefox".to_string()),
            ..Default::default()
        };

        let (name, client) = ApnsRouter::create_client("test".to_string(), channel, None, None)
            .await
            .expect("token client creation failed");
        assert_eq!(name, "test");
        assert_eq!(client.topic, "com.mozilla.org.Firefox");
    }

    /// Setting only `key_id` (no `team_id`) is a configuration error naming
    /// the channel.
    #[tokio::test]
    async fn create_client_partial_token_auth() {
        let channel = ApnsChannel {
            key: test_p8_key(),
            key_id: Some("ABC123DEFG".to_string()),
            ..Default::default()
        };

        let result = ApnsRouter::create_client("test".to_string(), channel, None, None).await;
        assert!(
            matches!(result, Err(ApnsError::Config(ref name, _)) if name == "test"),
            "expected ApnsError::Config"
        );
    }

    /// Without message data, the default `aps` is sent and the custom data
    /// contains only `chid` and `ver`.
    #[tokio::test]
    async fn successful_routing_no_data() {
        // Needed for the `build` trait method used below.
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            // Compare against `default_aps` output so the test tracks it.
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(
                serde_json::to_value(payload.aps).unwrap(),
                serde_json::to_value(built.aps).unwrap()
            );
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        // Expected TTL of 0 presumably reflects make_notification's headers — confirm.
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// With message data, the encrypted body and its encoding/crypto headers
    /// are forwarded in the custom data alongside `chid` and `ver`.
    #[tokio::test]
    async fn successful_routing_with_data() {
        // Needed for the `build` trait method used below.
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(serde_json::json!(payload.aps), serde_json::json!(built.aps));
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id",
                    "body": "test-data",
                    "con": "test-encoding",
                    "enc": "test-encryption",
                    "cryptokey": "test-crypto-key",
                    "enckey": "test-encryption-key"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let data = "test-data".to_string();
        let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// A release channel with no configured client is rejected before any
    /// send is attempted.
    #[tokio::test]
    async fn missing_client() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "rel_channel".to_string(),
            serde_json::to_value("unknown-app-id").unwrap(),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
            ),
            "result = {result:?}"
        );
    }

    /// A 410 Unregistered response removes the user from the DB exactly once
    /// and surfaces `ApnsError::Unregistered`.
    #[tokio::test]
    async fn user_not_found() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::Unregistered,
                    timestamp: Some(0),
                }),
                apns_id: None,
                code: 410,
            }))
        });
        let notification = make_notification(default_router_data(), None, RouterType::APNS);
        let mut db = MockDbClient::new();
        db.expect_remove_user()
            .with(predicate::eq(notification.subscription.user.uaid))
            .times(1)
            .return_once(|_| Ok(()));
        let router = make_router(client, db.into_boxed_arc());

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
            ),
            "result = {result:?}"
        );
    }

    /// A non-410 APNS error response is returned unchanged inside
    /// `ApnsError::ApnsUpstream`.
    #[tokio::test]
    async fn upstream_error() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::BadCertificate,
                    timestamp: None,
                }),
                apns_id: None,
                code: 403,
            }))
        });
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
                    a2::Error::ResponseError(a2::Response {
                        error: Some(a2::ErrorBody {
                            reason: a2::ErrorReason::BadCertificate,
                            timestamp: None,
                        }),
                        apns_id: None,
                        code: 403,
                    })
                )))
            ),
            "result = {result:?}"
        );
    }

    /// A non-numeric `mutable-content` override is rejected with
    /// `InvalidApsData` before any send is attempted.
    #[tokio::test]
    async fn invalid_aps_data() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"mutable-content": "should be a number"}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
            ),
            "result = {result:?}"
        );
    }
}