1use autopush_common::db::client::DbClient;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::{PushReliability, ReliabilityState};
4
5use crate::error::{ApiError, ApiResult};
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::apns::error::ApnsError;
9use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
10use crate::routers::common::{
11 build_message_data, incr_error_metric, incr_success_metrics, message_size_check,
12};
13use crate::routers::{Router, RouterError, RouterResponse};
14use a2::{
15 self, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, Priority,
16 Response,
17 request::payload::{Payload, PayloadLike},
18};
19use actix_web::http::StatusCode;
20use async_trait::async_trait;
21use cadence::StatsdClient;
22use futures::{StreamExt, TryStreamExt};
23use serde::{Deserialize, Serialize};
24use serde_json::Value;
25use std::collections::HashMap;
26use std::sync::Arc;
27use url::Url;
28use uuid::Uuid;
29
/// Router that sends push notifications to Apple devices via APNS.
pub struct ApnsRouter {
    /// One APNS client (plus topic) per release channel, keyed by channel name.
    clients: HashMap<String, ApnsClientData>,
    /// Router settings; `max_data` is used as the payload size limit.
    settings: ApnsSettings,
    /// Base URL joined with `/m/<message_id>` to build the returned message location.
    endpoint_url: Url,
    /// Metrics client for success/error counters.
    metrics: Arc<StatsdClient>,
    /// Database handle; used to remove a user when APNS answers 410 (unregistered).
    db: Box<dyn DbClient>,
    /// Reliability tracker updated when a notification is transmitted or errors.
    #[cfg(feature = "reliable_report")]
    reliability: Arc<PushReliability>,
}
41
/// An APNS client together with the topic its notifications are sent under.
struct ApnsClientData {
    /// Sending client; a trait object so tests can substitute a mock client.
    client: Box<dyn ApnsClient>,
    /// Value passed as the APNS topic in each notification's options.
    topic: String,
}
46
/// Minimal sending interface over an APNS client.
///
/// Implemented for [`a2::Client`] in production; tests provide a mock implementation.
#[async_trait]
trait ApnsClient: Send + Sync {
    /// Send one notification payload and return the APNS response or error.
    async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
}
51
52#[async_trait]
53impl ApnsClient for a2::Client {
54 async fn send(&self, payload: Payload<'_>) -> Result<Response, a2::Error> {
55 self.send(payload).await
56 }
57}
58
/// Accepted shape of the optional `aps` registration data.
///
/// In this file it is only used by `register`, to check that the client-supplied
/// `aps` JSON string parses. JSON keys are kebab-case (e.g. `content-available`,
/// `mutable-content`, `url-args`). Unknown keys are not rejected (there is no
/// `deny_unknown_fields`).
///
/// NOTE(review): the borrowed `&'a str` fields can only deserialize JSON strings
/// without escape sequences; an escaped `sound`/`category` value makes validation
/// fail — confirm this is intended.
#[derive(Deserialize, Serialize, Default, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
#[allow(clippy::upper_case_acronyms)]
pub struct ApsDeser<'a> {
    /// Mirrors the APNS `aps` `badge` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub badge: Option<u32>,

    /// Mirrors the APNS `aps` `sound` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sound: Option<&'a str>,

    /// Mirrors the APNS `aps` `content-available` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_available: Option<u8>,

    /// Mirrors the APNS `aps` `category` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<&'a str>,

    /// Mirrors the APNS `aps` `mutable-content` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mutable_content: Option<u8>,

    /// Mirrors the APNS `aps` `url-args` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url_args: Option<Vec<String>>,
}
97
/// Owned storage for `aps` alert override strings.
///
/// `DefaultNotificationBuilder` only borrows its strings, so `derive_aps` copies
/// each override into this holder and lends it to the builder. The holder must
/// therefore outlive the builder and the payload built from it.
#[derive(Default)]
pub struct ApsAlertHolder {
    title: String,
    subtitle: String,
    body: String,
    title_loc_key: String,
    title_loc_args: Vec<String>,
    action_loc_key: String,
    loc_key: String,
    loc_args: Vec<String>,
    launch_image: String,
}
112
113impl ApnsRouter {
114 pub async fn new(
117 settings: ApnsSettings,
118 endpoint_url: Url,
119 metrics: Arc<StatsdClient>,
120 db: Box<dyn DbClient>,
121 #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
122 ) -> Result<Self, ApnsError> {
123 let channels = settings.channels()?;
124
125 let clients: HashMap<String, ApnsClientData> = futures::stream::iter(channels)
126 .then(|(name, settings)| Self::create_client(name, settings))
127 .try_collect()
128 .await?;
129
130 trace!("Initialized {} APNs clients", clients.len());
131 Ok(Self {
132 clients,
133 settings,
134 endpoint_url,
135 metrics,
136 db,
137 #[cfg(feature = "reliable_report")]
138 reliability,
139 })
140 }
141
142 async fn create_client(
144 name: String,
145 settings: ApnsChannel,
146 ) -> Result<(String, ApnsClientData), ApnsError> {
147 let endpoint = if settings.sandbox {
148 Endpoint::Sandbox
149 } else {
150 Endpoint::Production
151 };
152 let cert = if !settings.cert.starts_with('-') {
153 tokio::fs::read(settings.cert).await?
154 } else {
155 settings.cert.as_bytes().to_vec()
156 };
157 let key = if !settings.key.starts_with('-') {
158 tokio::fs::read(settings.key).await?
159 } else {
160 settings.key.as_bytes().to_vec()
161 };
162 let apns_settings = ApnsSettings::default();
167 let config = a2::ClientConfig {
168 endpoint,
169 request_timeout_secs: apns_settings.request_timeout_secs,
170 pool_idle_timeout_secs: apns_settings.pool_idle_timeout_secs,
171 };
172 let client = ApnsClientData {
173 client: Box::new(
174 a2::Client::certificate_parts(&cert, &key, config)
175 .map_err(ApnsError::ApnsClient)?,
176 ),
177 topic: settings
178 .topic
179 .unwrap_or_else(|| format!("com.mozilla.org.{name}")),
180 };
181
182 Ok((name, client))
183 }
184
185 fn default_aps<'a>() -> DefaultNotificationBuilder<'a> {
187 DefaultNotificationBuilder::new()
188 .set_title_loc_key("SentTab.NoTabArrivingNotification.title")
189 .set_loc_key("SentTab.NoTabArrivingNotification.body")
190 .set_mutable_content()
191 }
192
193 async fn handle_error(&self, error: a2::Error, uaid: Uuid, channel: &str) -> ApiError {
195 match &error {
196 a2::Error::ResponseError(response) => {
197 let reason = response
201 .error
202 .as_ref()
203 .map(|r| format!("{:?}", r.reason))
204 .unwrap_or_else(|| "Unknown".to_owned());
205 let code = StatusCode::from_u16(response.code).unwrap_or(StatusCode::BAD_GATEWAY);
206 incr_error_metric(&self.metrics, "apns", channel, &reason, code, None);
207 if response.code == 410 {
208 debug!("APNS recipient has been unregistered, removing user");
209 if let Err(e) = self.db.remove_user(&uaid).await {
210 warn!("Error while removing user due to APNS 410: {}", e);
211 }
212
213 return ApiError::from(ApnsError::Unregistered);
214 } else {
215 warn!("APNS error: {:?}", response.error);
216 }
217 }
218 a2::Error::ConnectionError(e) => {
219 error!("APNS connection error: {:?}", e);
220 incr_error_metric(
221 &self.metrics,
222 "apns",
223 channel,
224 "connection_unavailable",
225 StatusCode::SERVICE_UNAVAILABLE,
226 None,
227 );
228 }
229 _ => {
230 warn!("Unknown error while sending APNS request: {}", error);
231 incr_error_metric(
232 &self.metrics,
233 "apns",
234 channel,
235 "unknown",
236 StatusCode::BAD_GATEWAY,
237 None,
238 );
239 }
240 }
241
242 ApiError::from(ApnsError::ApnsUpstream(error))
243 }
244
245 pub fn active(&self) -> bool {
247 !self.clients.is_empty()
248 }
249
250 fn derive_aps<'a>(
255 &self,
256 replacement: Value,
257 holder: &'a mut ApsAlertHolder,
258 ) -> Result<DefaultNotificationBuilder<'a>, ApnsError> {
259 let mut aps = Self::default_aps();
260 if let Some(v) = replacement.get("title") {
269 if let Some(v) = v.as_str() {
270 v.clone_into(&mut holder.title);
271 aps = aps.set_title(&holder.title);
272 } else {
273 return Err(ApnsError::InvalidApsData);
274 }
275 }
276 if let Some(v) = replacement.get("subtitle") {
277 if let Some(v) = v.as_str() {
278 v.clone_into(&mut holder.subtitle);
279 aps = aps.set_subtitle(&holder.subtitle);
280 } else {
281 return Err(ApnsError::InvalidApsData);
282 }
283 }
284 if let Some(v) = replacement.get("body") {
285 if let Some(v) = v.as_str() {
286 v.clone_into(&mut holder.body);
287 aps = aps.set_body(&holder.body);
288 } else {
289 return Err(ApnsError::InvalidApsData);
290 }
291 }
292 if let Some(v) = replacement.get("title_loc_key") {
293 if let Some(v) = v.as_str() {
294 v.clone_into(&mut holder.title_loc_key);
295 aps = aps.set_title_loc_key(&holder.title_loc_key);
296 } else {
297 return Err(ApnsError::InvalidApsData);
298 }
299 }
300 if let Some(v) = replacement.get("title_loc_args") {
301 if let Some(v) = v.as_array() {
302 let mut args: Vec<String> = Vec::new();
303 for val in v {
304 if let Some(value) = val.as_str() {
305 args.push(value.to_owned())
306 } else {
307 return Err(ApnsError::InvalidApsData);
308 }
309 }
310 holder.title_loc_args = args;
311 aps = aps.set_title_loc_args(&holder.title_loc_args);
312 } else {
313 return Err(ApnsError::InvalidApsData);
314 }
315 }
316 if let Some(v) = replacement.get("action_loc_key") {
317 if let Some(v) = v.as_str() {
318 v.clone_into(&mut holder.action_loc_key);
319 aps = aps.set_action_loc_key(&holder.action_loc_key);
320 } else {
321 return Err(ApnsError::InvalidApsData);
322 }
323 }
324 if let Some(v) = replacement.get("loc_key") {
325 if let Some(v) = v.as_str() {
326 v.clone_into(&mut holder.loc_key);
327 aps = aps.set_loc_key(&holder.loc_key);
328 } else {
329 return Err(ApnsError::InvalidApsData);
330 }
331 }
332 if let Some(v) = replacement.get("loc_args") {
333 if let Some(v) = v.as_array() {
334 let mut args: Vec<String> = Vec::new();
335 for val in v {
336 if let Some(value) = val.as_str() {
337 args.push(value.to_owned())
338 } else {
339 return Err(ApnsError::InvalidApsData);
340 }
341 }
342 holder.loc_args = args;
343 aps = aps.set_loc_args(&holder.loc_args);
344 } else {
345 return Err(ApnsError::InvalidApsData);
346 }
347 }
348 if let Some(v) = replacement.get("launch_image") {
349 if let Some(v) = v.as_str() {
350 v.clone_into(&mut holder.launch_image);
351 aps = aps.set_launch_image(&holder.launch_image);
352 } else {
353 return Err(ApnsError::InvalidApsData);
354 }
355 }
356 if let Some(v) = replacement.get("mutable-content") {
360 if let Some(v) = v.as_i64() {
361 if v != 0 {
362 aps = aps.set_mutable_content();
363 }
364 } else {
365 return Err(ApnsError::InvalidApsData);
366 }
367 }
368 Ok(aps)
369 }
370}
371
372#[async_trait(?Send)]
373impl Router for ApnsRouter {
374 fn register(
375 &self,
376 router_input: &RouterDataInput,
377 app_id: &str,
378 ) -> Result<HashMap<String, Value>, RouterError> {
379 if !self.clients.contains_key(app_id) {
380 return Err(ApnsError::InvalidReleaseChannel.into());
381 }
382
383 let mut router_data = HashMap::new();
384 router_data.insert(
385 "token".to_string(),
386 serde_json::to_value(&router_input.token).unwrap(),
387 );
388 router_data.insert(
389 "rel_channel".to_string(),
390 serde_json::to_value(app_id).unwrap(),
391 );
392
393 if let Some(aps) = &router_input.aps {
394 if serde_json::from_str::<ApsDeser<'_>>(aps).is_err() {
395 return Err(ApnsError::InvalidApsData.into());
396 }
397 router_data.insert(
398 "aps".to_string(),
399 serde_json::to_value(aps.clone()).unwrap(),
400 );
401 }
402
403 Ok(router_data)
404 }
405
406 #[allow(unused_mut)]
407 async fn route_notification(
408 &self,
409 mut notification: Notification,
410 ) -> ApiResult<RouterResponse> {
411 debug!(
412 "Sending APNS notification to UAID {}",
413 notification.subscription.user.uaid
414 );
415 trace!("Notification = {:?}", notification);
416
417 let router_data = notification
419 .subscription
420 .user
421 .router_data
422 .as_ref()
423 .ok_or(ApnsError::NoDeviceToken)?
424 .clone();
425 let token = router_data
426 .get("token")
427 .and_then(Value::as_str)
428 .ok_or(ApnsError::NoDeviceToken)?;
429 let channel = router_data
430 .get("rel_channel")
431 .and_then(Value::as_str)
432 .ok_or(ApnsError::NoReleaseChannel)?;
433 let aps_json = router_data.get("aps").cloned();
434 let mut message_data = build_message_data(¬ification)?;
435 message_data.insert("ver", notification.message_id.clone());
436
437 let ApnsClientData { client, topic } = self
439 .clients
440 .get(channel)
441 .ok_or(ApnsError::InvalidReleaseChannel)?;
442
443 let mut holder = ApsAlertHolder::default();
446
447 let aps = if let Some(replacement) = aps_json {
450 self.derive_aps(replacement, &mut holder)?
451 } else {
452 Self::default_aps()
453 };
454
455 let mut payload = aps.build(
457 token,
458 NotificationOptions {
459 apns_id: None,
460 apns_priority: Some(Priority::High),
461 apns_topic: Some(topic),
462 apns_collapse_id: None,
463 apns_expiration: Some(notification.timestamp + notification.headers.ttl as u64),
464 ..Default::default()
465 },
466 );
467 let message_length = message_data.get("body").map(|s| s.len()).unwrap_or(0);
468 payload.data = message_data
469 .into_iter()
470 .map(|(k, v)| (k, Value::String(v)))
471 .collect();
472
473 let payload_json = payload
475 .clone()
476 .to_json_string()
477 .map_err(|_| RouterError::TooMuchData(message_length))?;
478 message_size_check(payload_json.as_bytes(), self.settings.max_data)?;
479
480 trace!("Sending message to APNS: {:?}", payload);
482 if let Err(e) = client.send(payload).await {
483 #[cfg(feature = "reliable_report")]
484 notification
485 .record_reliability(
486 &self.reliability,
487 autopush_common::reliability::ReliabilityState::Errored,
488 )
489 .await;
490 return Err(self
491 .handle_error(e, notification.subscription.user.uaid, channel)
492 .await);
493 }
494
495 trace!("APNS request was successful");
496 incr_success_metrics(&self.metrics, "apns", channel, ¬ification);
497 #[cfg(feature = "reliable_report")]
498 {
499 notification
504 .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
505 .await;
506 }
507
508 Ok(RouterResponse::success(
509 self.endpoint_url
510 .join(&format!("/m/{}", notification.message_id))
511 .expect("Message ID is not URL-safe")
512 .to_string(),
513 notification.headers.ttl as usize,
514 ))
515 }
516}
517
/// Unit tests for `ApnsRouter`, using a mock APNS client and a mock database.
#[cfg(test)]
mod tests {
    use crate::error::ApiErrorKind;
    use crate::extractors::routers::RouterType;
    use crate::routers::apns::error::ApnsError;
    use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
    use crate::routers::apns::settings::ApnsSettings;
    use crate::routers::common::tests::{CHANNEL_ID, make_notification};
    use crate::routers::{Router, RouterError, RouterResponse};
    use a2::request::payload::Payload;
    use a2::{Error, Response};
    use async_trait::async_trait;
    use autopush_common::db::client::DbClient;
    use autopush_common::db::mock::MockDbClient;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
    use cadence::StatsdClient;
    use mockall::predicate;
    use std::collections::HashMap;
    use std::sync::Arc;
    use url::Url;

    const DEVICE_TOKEN: &str = "test-token";
    const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";

    /// APNS client stand-in whose `send` delegates to a test-provided closure,
    /// letting each test inspect the payload and choose the response.
    #[allow(clippy::type_complexity)]
    struct MockApnsClient {
        send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
    }

    #[async_trait]
    impl ApnsClient for MockApnsClient {
        async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
            (self.send_fn)(payload)
        }
    }

    impl MockApnsClient {
        /// Wrap `send_fn` as the mock's send behavior.
        fn new<F>(send_fn: F) -> Self
        where
            F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
            F: Send + Sync + 'static,
        {
            Self {
                send_fn: Box::new(send_fn),
            }
        }
    }

    /// A 200 APNS response with no error body.
    fn apns_success_response() -> a2::Response {
        a2::Response {
            error: None,
            apns_id: Some(APNS_ID.to_string()),
            code: 200,
        }
    }

    /// Build a router with a single `"test-channel"` client (topic `"test-topic"`)
    /// backed by `client`, default settings, a no-op metrics sink and `db`.
    fn make_router(client: MockApnsClient, db: Box<dyn DbClient>) -> ApnsRouter {
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());

        ApnsRouter {
            clients: {
                let mut map = HashMap::new();
                map.insert(
                    "test-channel".to_string(),
                    ApnsClientData {
                        client: Box::new(client),
                        topic: "test-topic".to_string(),
                    },
                );
                map
            },
            settings: ApnsSettings::default(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            metrics: metrics.clone(),
            db: db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// Router data pointing at `DEVICE_TOKEN` on `"test-channel"`, with no `aps`.
    fn default_router_data() -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(
            "token".to_string(),
            serde_json::to_value(DEVICE_TOKEN).unwrap(),
        );
        map.insert(
            "rel_channel".to_string(),
            serde_json::to_value("test-channel").unwrap(),
        );
        map
    }

    /// A notification without data is sent with the default `aps`, the channel's
    /// topic, and only `chid`/`ver` in the payload data.
    #[tokio::test]
    async fn successful_routing_no_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(
                serde_json::to_value(payload.aps).unwrap(),
                serde_json::to_value(built.aps).unwrap()
            );
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// A notification with data carries the body and encryption fields in the
    /// payload data alongside `chid`/`ver`.
    #[tokio::test]
    async fn successful_routing_with_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(serde_json::json!(payload.aps), serde_json::json!(built.aps));
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id",
                    "body": "test-data",
                    "con": "test-encoding",
                    "enc": "test-encryption",
                    "cryptokey": "test-crypto-key",
                    "enckey": "test-encryption-key"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let data = "test-data".to_string();
        let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// An unknown release channel fails with `InvalidReleaseChannel` before sending.
    #[tokio::test]
    async fn missing_client() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "rel_channel".to_string(),
            serde_json::to_value("unknown-app-id").unwrap(),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
            ),
            "result = {result:?}"
        );
    }

    /// A 410 response removes the user exactly once and yields `Unregistered`.
    #[tokio::test]
    async fn user_not_found() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::Unregistered,
                    timestamp: Some(0),
                }),
                apns_id: None,
                code: 410,
            }))
        });
        let notification = make_notification(default_router_data(), None, RouterType::APNS);
        let mut db = MockDbClient::new();
        db.expect_remove_user()
            .with(predicate::eq(notification.subscription.user.uaid))
            .times(1)
            .return_once(|_| Ok(()));
        let router = make_router(client, db.into_boxed_arc());

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
            ),
            "result = {result:?}"
        );
    }

    /// Any other error response is passed through as `ApnsUpstream`.
    #[tokio::test]
    async fn upstream_error() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::BadCertificate,
                    timestamp: None,
                }),
                apns_id: None,
                code: 403,
            }))
        });
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
                    a2::Error::ResponseError(a2::Response {
                        error: Some(a2::ErrorBody {
                            reason: a2::ErrorReason::BadCertificate,
                            timestamp: None,
                        }),
                        apns_id: None,
                        code: 403,
                    })
                )))
            ),
            "result = {result:?}"
        );
    }

    /// A non-numeric `mutable-content` override fails with `InvalidApsData`
    /// before anything is sent.
    #[tokio::test]
    async fn invalid_aps_data() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"mutable-content": "should be a number"}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
            ),
            "result = {result:?}"
        );
    }
}