1use autopush_common::db::client::DbClient;
2#[cfg(feature = "reliable_report")]
3use autopush_common::reliability::{PushReliability, ReliabilityState};
4
5use crate::error::{ApiError, ApiResult};
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::apns::error::ApnsError;
9use crate::routers::apns::settings::{ApnsChannel, ApnsSettings};
10use crate::routers::common::{
11 build_message_data, incr_error_metric, incr_success_metrics, message_size_check,
12};
13use crate::routers::{Router, RouterError, RouterResponse};
14use a2::{
15 self,
16 request::payload::{Payload, PayloadLike},
17 DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, Priority,
18 Response,
19};
20use actix_web::http::StatusCode;
21use async_trait::async_trait;
22use cadence::StatsdClient;
23use futures::{StreamExt, TryStreamExt};
24use serde::{Deserialize, Serialize};
25use serde_json::Value;
26use std::collections::HashMap;
27use std::sync::Arc;
28use url::Url;
29use uuid::Uuid;
30
/// Push router that delivers notifications to Apple devices through APNS.
pub struct ApnsRouter {
    /// APNS clients keyed by release-channel name, i.e. the `rel_channel`
    /// value saved in a user's router data at registration.
    clients: HashMap<String, ApnsClientData>,
    /// Router-wide APNS settings; `max_data` caps the encoded payload size.
    settings: ApnsSettings,
    /// Base URL joined with `/m/<message_id>` to form the location returned
    /// for a routed message.
    endpoint_url: Url,
    /// Metrics client used for the success/error counters.
    metrics: Arc<StatsdClient>,
    /// Database handle; used to remove a user when APNS answers 410.
    db: Box<dyn DbClient>,
    /// Reliability tracker that records per-message states (e.g. errored,
    /// bridge-transmitted).
    #[cfg(feature = "reliable_report")]
    reliability: Arc<PushReliability>,
}
42
43struct ApnsClientData {
44 client: Box<dyn ApnsClient>,
45 topic: String,
46}
47
48#[async_trait]
49trait ApnsClient: Send + Sync {
50 async fn send(&self, payload: Payload<'_>) -> Result<a2::Response, a2::Error>;
51}
52
53#[async_trait]
54impl ApnsClient for a2::Client {
55 async fn send(&self, payload: Payload<'_>) -> Result<Response, a2::Error> {
56 self.send(payload).await
57 }
58}
59
60#[derive(Deserialize, Serialize, Default, Debug, Clone)]
63#[serde(rename_all = "kebab-case")]
64#[allow(clippy::upper_case_acronyms)]
65pub struct ApsDeser<'a> {
66 #[serde(skip_serializing_if = "Option::is_none")]
73 pub badge: Option<u32>,
74
75 #[serde(skip_serializing_if = "Option::is_none")]
77 pub sound: Option<&'a str>,
78
79 #[serde(skip_serializing_if = "Option::is_none")]
81 pub content_available: Option<u8>,
82
83 #[serde(skip_serializing_if = "Option::is_none")]
86 pub category: Option<&'a str>,
87
88 #[serde(skip_serializing_if = "Option::is_none")]
91 pub mutable_content: Option<u8>,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
94 pub url_args: Option<Vec<String>>,
97}
98
/// Owned backing storage for alert strings taken from a replacement `aps`
/// object.
///
/// `DefaultNotificationBuilder<'a>` only borrows its string arguments, so
/// values parsed out of JSON are parked here. The holder must outlive the
/// builder that borrows from it.
#[derive(Default)]
pub struct ApsAlertHolder {
    // Literal alert text.
    title: String,
    subtitle: String,
    body: String,
    launch_image: String,
    // Localization keys and their format arguments.
    title_loc_key: String,
    title_loc_args: Vec<String>,
    action_loc_key: String,
    loc_key: String,
    loc_args: Vec<String>,
}
113
114impl ApnsRouter {
115 pub async fn new(
118 settings: ApnsSettings,
119 endpoint_url: Url,
120 metrics: Arc<StatsdClient>,
121 db: Box<dyn DbClient>,
122 #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
123 ) -> Result<Self, ApnsError> {
124 let channels = settings.channels()?;
125
126 let clients: HashMap<String, ApnsClientData> = futures::stream::iter(channels)
127 .then(|(name, settings)| Self::create_client(name, settings))
128 .try_collect()
129 .await?;
130
131 trace!("Initialized {} APNs clients", clients.len());
132 Ok(Self {
133 clients,
134 settings,
135 endpoint_url,
136 metrics,
137 db,
138 #[cfg(feature = "reliable_report")]
139 reliability,
140 })
141 }
142
143 async fn create_client(
145 name: String,
146 settings: ApnsChannel,
147 ) -> Result<(String, ApnsClientData), ApnsError> {
148 let endpoint = if settings.sandbox {
149 Endpoint::Sandbox
150 } else {
151 Endpoint::Production
152 };
153 let cert = if !settings.cert.starts_with('-') {
154 tokio::fs::read(settings.cert).await?
155 } else {
156 settings.cert.as_bytes().to_vec()
157 };
158 let key = if !settings.key.starts_with('-') {
159 tokio::fs::read(settings.key).await?
160 } else {
161 settings.key.as_bytes().to_vec()
162 };
163 let apns_settings = ApnsSettings::default();
168 let config = a2::ClientConfig {
169 endpoint,
170 request_timeout_secs: apns_settings.request_timeout_secs,
171 pool_idle_timeout_secs: apns_settings.pool_idle_timeout_secs,
172 };
173 let client = ApnsClientData {
174 client: Box::new(
175 a2::Client::certificate_parts(&cert, &key, config)
176 .map_err(ApnsError::ApnsClient)?,
177 ),
178 topic: settings
179 .topic
180 .unwrap_or_else(|| format!("com.mozilla.org.{name}")),
181 };
182
183 Ok((name, client))
184 }
185
186 fn default_aps<'a>() -> DefaultNotificationBuilder<'a> {
188 DefaultNotificationBuilder::new()
189 .set_title_loc_key("SentTab.NoTabArrivingNotification.title")
190 .set_loc_key("SentTab.NoTabArrivingNotification.body")
191 .set_mutable_content()
192 }
193
194 async fn handle_error(&self, error: a2::Error, uaid: Uuid, channel: &str) -> ApiError {
196 match &error {
197 a2::Error::ResponseError(response) => {
198 let reason = response
202 .error
203 .as_ref()
204 .map(|r| format!("{:?}", r.reason))
205 .unwrap_or_else(|| "Unknown".to_owned());
206 let code = StatusCode::from_u16(response.code).unwrap_or(StatusCode::BAD_GATEWAY);
207 incr_error_metric(&self.metrics, "apns", channel, &reason, code, None);
208 if response.code == 410 {
209 debug!("APNS recipient has been unregistered, removing user");
210 if let Err(e) = self.db.remove_user(&uaid).await {
211 warn!("Error while removing user due to APNS 410: {}", e);
212 }
213
214 return ApiError::from(ApnsError::Unregistered);
215 } else {
216 warn!("APNS error: {:?}", response.error);
217 }
218 }
219 a2::Error::ConnectionError(e) => {
220 error!("APNS connection error: {:?}", e);
221 incr_error_metric(
222 &self.metrics,
223 "apns",
224 channel,
225 "connection_unavailable",
226 StatusCode::SERVICE_UNAVAILABLE,
227 None,
228 );
229 }
230 _ => {
231 warn!("Unknown error while sending APNS request: {}", error);
232 incr_error_metric(
233 &self.metrics,
234 "apns",
235 channel,
236 "unknown",
237 StatusCode::BAD_GATEWAY,
238 None,
239 );
240 }
241 }
242
243 ApiError::from(ApnsError::ApnsUpstream(error))
244 }
245
246 pub fn active(&self) -> bool {
248 !self.clients.is_empty()
249 }
250
251 fn derive_aps<'a>(
256 &self,
257 replacement: Value,
258 holder: &'a mut ApsAlertHolder,
259 ) -> Result<DefaultNotificationBuilder<'a>, ApnsError> {
260 let mut aps = Self::default_aps();
261 if let Some(v) = replacement.get("title") {
270 if let Some(v) = v.as_str() {
271 v.clone_into(&mut holder.title);
272 aps = aps.set_title(&holder.title);
273 } else {
274 return Err(ApnsError::InvalidApsData);
275 }
276 }
277 if let Some(v) = replacement.get("subtitle") {
278 if let Some(v) = v.as_str() {
279 v.clone_into(&mut holder.subtitle);
280 aps = aps.set_subtitle(&holder.subtitle);
281 } else {
282 return Err(ApnsError::InvalidApsData);
283 }
284 }
285 if let Some(v) = replacement.get("body") {
286 if let Some(v) = v.as_str() {
287 v.clone_into(&mut holder.body);
288 aps = aps.set_body(&holder.body);
289 } else {
290 return Err(ApnsError::InvalidApsData);
291 }
292 }
293 if let Some(v) = replacement.get("title_loc_key") {
294 if let Some(v) = v.as_str() {
295 v.clone_into(&mut holder.title_loc_key);
296 aps = aps.set_title_loc_key(&holder.title_loc_key);
297 } else {
298 return Err(ApnsError::InvalidApsData);
299 }
300 }
301 if let Some(v) = replacement.get("title_loc_args") {
302 if let Some(v) = v.as_array() {
303 let mut args: Vec<String> = Vec::new();
304 for val in v {
305 if let Some(value) = val.as_str() {
306 args.push(value.to_owned())
307 } else {
308 return Err(ApnsError::InvalidApsData);
309 }
310 }
311 holder.title_loc_args = args;
312 aps = aps.set_title_loc_args(&holder.title_loc_args);
313 } else {
314 return Err(ApnsError::InvalidApsData);
315 }
316 }
317 if let Some(v) = replacement.get("action_loc_key") {
318 if let Some(v) = v.as_str() {
319 v.clone_into(&mut holder.action_loc_key);
320 aps = aps.set_action_loc_key(&holder.action_loc_key);
321 } else {
322 return Err(ApnsError::InvalidApsData);
323 }
324 }
325 if let Some(v) = replacement.get("loc_key") {
326 if let Some(v) = v.as_str() {
327 v.clone_into(&mut holder.loc_key);
328 aps = aps.set_loc_key(&holder.loc_key);
329 } else {
330 return Err(ApnsError::InvalidApsData);
331 }
332 }
333 if let Some(v) = replacement.get("loc_args") {
334 if let Some(v) = v.as_array() {
335 let mut args: Vec<String> = Vec::new();
336 for val in v {
337 if let Some(value) = val.as_str() {
338 args.push(value.to_owned())
339 } else {
340 return Err(ApnsError::InvalidApsData);
341 }
342 }
343 holder.loc_args = args;
344 aps = aps.set_loc_args(&holder.loc_args);
345 } else {
346 return Err(ApnsError::InvalidApsData);
347 }
348 }
349 if let Some(v) = replacement.get("launch_image") {
350 if let Some(v) = v.as_str() {
351 v.clone_into(&mut holder.launch_image);
352 aps = aps.set_launch_image(&holder.launch_image);
353 } else {
354 return Err(ApnsError::InvalidApsData);
355 }
356 }
357 if let Some(v) = replacement.get("mutable-content") {
361 if let Some(v) = v.as_i64() {
362 if v != 0 {
363 aps = aps.set_mutable_content();
364 }
365 } else {
366 return Err(ApnsError::InvalidApsData);
367 }
368 }
369 Ok(aps)
370 }
371}
372
373#[async_trait(?Send)]
374impl Router for ApnsRouter {
375 fn register(
376 &self,
377 router_input: &RouterDataInput,
378 app_id: &str,
379 ) -> Result<HashMap<String, Value>, RouterError> {
380 if !self.clients.contains_key(app_id) {
381 return Err(ApnsError::InvalidReleaseChannel.into());
382 }
383
384 let mut router_data = HashMap::new();
385 router_data.insert(
386 "token".to_string(),
387 serde_json::to_value(&router_input.token).unwrap(),
388 );
389 router_data.insert(
390 "rel_channel".to_string(),
391 serde_json::to_value(app_id).unwrap(),
392 );
393
394 if let Some(aps) = &router_input.aps {
395 if serde_json::from_str::<ApsDeser<'_>>(aps).is_err() {
396 return Err(ApnsError::InvalidApsData.into());
397 }
398 router_data.insert(
399 "aps".to_string(),
400 serde_json::to_value(aps.clone()).unwrap(),
401 );
402 }
403
404 Ok(router_data)
405 }
406
407 #[allow(unused_mut)]
408 async fn route_notification(
409 &self,
410 mut notification: Notification,
411 ) -> ApiResult<RouterResponse> {
412 debug!(
413 "Sending APNS notification to UAID {}",
414 notification.subscription.user.uaid
415 );
416 trace!("Notification = {:?}", notification);
417
418 let router_data = notification
420 .subscription
421 .user
422 .router_data
423 .as_ref()
424 .ok_or(ApnsError::NoDeviceToken)?
425 .clone();
426 let token = router_data
427 .get("token")
428 .and_then(Value::as_str)
429 .ok_or(ApnsError::NoDeviceToken)?;
430 let channel = router_data
431 .get("rel_channel")
432 .and_then(Value::as_str)
433 .ok_or(ApnsError::NoReleaseChannel)?;
434 let aps_json = router_data.get("aps").cloned();
435 let mut message_data = build_message_data(¬ification)?;
436 message_data.insert("ver", notification.message_id.clone());
437
438 let ApnsClientData { client, topic } = self
440 .clients
441 .get(channel)
442 .ok_or(ApnsError::InvalidReleaseChannel)?;
443
444 let mut holder = ApsAlertHolder::default();
447
448 let aps = if let Some(replacement) = aps_json {
451 self.derive_aps(replacement, &mut holder)?
452 } else {
453 Self::default_aps()
454 };
455
456 let mut payload = aps.build(
458 token,
459 NotificationOptions {
460 apns_id: None,
461 apns_priority: Some(Priority::High),
462 apns_topic: Some(topic),
463 apns_collapse_id: None,
464 apns_expiration: Some(notification.timestamp + notification.headers.ttl as u64),
465 ..Default::default()
466 },
467 );
468 payload.data = message_data
469 .into_iter()
470 .map(|(k, v)| (k, Value::String(v)))
471 .collect();
472
473 let payload_json = payload
475 .clone()
476 .to_json_string()
477 .map_err(ApnsError::SizeLimit)?;
478 message_size_check(payload_json.as_bytes(), self.settings.max_data)?;
479
480 trace!("Sending message to APNS: {:?}", payload);
482 if let Err(e) = client.send(payload).await {
483 #[cfg(feature = "reliable_report")]
484 notification
485 .record_reliability(
486 &self.reliability,
487 autopush_common::reliability::ReliabilityState::Errored,
488 )
489 .await;
490 return Err(self
491 .handle_error(e, notification.subscription.user.uaid, channel)
492 .await);
493 }
494
495 trace!("APNS request was successful");
496 incr_success_metrics(&self.metrics, "apns", channel, ¬ification);
497 #[cfg(feature = "reliable_report")]
498 {
499 notification
504 .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
505 .await;
506 }
507
508 Ok(RouterResponse::success(
509 self.endpoint_url
510 .join(&format!("/m/{}", notification.message_id))
511 .expect("Message ID is not URL-safe")
512 .to_string(),
513 notification.headers.ttl as usize,
514 ))
515 }
516}
517
#[cfg(test)]
mod tests {
    use crate::error::ApiErrorKind;
    use crate::extractors::routers::RouterType;
    use crate::routers::apns::error::ApnsError;
    use crate::routers::apns::router::{ApnsClient, ApnsClientData, ApnsRouter};
    use crate::routers::apns::settings::ApnsSettings;
    use crate::routers::common::tests::{make_notification, CHANNEL_ID};
    use crate::routers::{Router, RouterError, RouterResponse};
    use a2::request::payload::Payload;
    use a2::{Error, Response};
    use async_trait::async_trait;
    use autopush_common::db::client::DbClient;
    use autopush_common::db::mock::MockDbClient;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
    use cadence::StatsdClient;
    use mockall::predicate;
    use std::collections::HashMap;
    use std::sync::Arc;
    use url::Url;

    const DEVICE_TOKEN: &str = "test-token";
    const APNS_ID: &str = "deadbeef-4f5e-4403-be8f-35d0251655f5";

    /// [`ApnsClient`] test double that hands each payload to a closure
    /// instead of contacting APNS.
    // The boxed closure type is spelled out once, here; an alias adds nothing.
    #[allow(clippy::type_complexity)]
    struct MockApnsClient {
        send_fn: Box<dyn Fn(Payload<'_>) -> Result<a2::Response, a2::Error> + Send + Sync>,
    }

    #[async_trait]
    impl ApnsClient for MockApnsClient {
        async fn send(&self, payload: Payload<'_>) -> Result<Response, Error> {
            (self.send_fn)(payload)
        }
    }

    impl MockApnsClient {
        fn new<F>(send_fn: F) -> Self
        where
            F: Fn(Payload<'_>) -> Result<a2::Response, a2::Error>,
            F: Send + Sync + 'static,
        {
            Self {
                send_fn: Box::new(send_fn),
            }
        }
    }

    /// A successful (200) APNS response.
    fn apns_success_response() -> a2::Response {
        a2::Response {
            error: None,
            apns_id: Some(APNS_ID.to_string()),
            code: 200,
        }
    }

    /// Build a router whose only release channel, "test-channel", is served by `client`.
    fn make_router(client: MockApnsClient, db: Box<dyn DbClient>) -> ApnsRouter {
        // One metrics client is shared by the router and (with
        // `reliable_report`) the reliability tracker. This also keeps the
        // binding used in every feature combination.
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());

        ApnsRouter {
            clients: {
                let mut map = HashMap::new();
                map.insert(
                    "test-channel".to_string(),
                    ApnsClientData {
                        client: Box::new(client),
                        topic: "test-topic".to_string(),
                    },
                );
                map
            },
            settings: ApnsSettings::default(),
            endpoint_url: Url::parse("http://localhost:8080/").unwrap(),
            metrics: Arc::clone(&metrics),
            db: db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability: Arc::new(
                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        }
    }

    /// Router data pointing at the mock client's channel.
    fn default_router_data() -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(
            "token".to_string(),
            serde_json::to_value(DEVICE_TOKEN).unwrap(),
        );
        map.insert(
            "rel_channel".to_string(),
            serde_json::to_value("test-channel").unwrap(),
        );
        map
    }

    /// A notification without data is sent with the default aps and only the
    /// channel id and version in its data.
    #[tokio::test]
    async fn successful_routing_no_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(
                serde_json::to_value(payload.aps).unwrap(),
                serde_json::to_value(built.aps).unwrap()
            );
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// A notification with data carries the encrypted body and its headers.
    #[tokio::test]
    async fn successful_routing_with_data() {
        use a2::NotificationBuilder;

        let client = MockApnsClient::new(|payload| {
            let built = ApnsRouter::default_aps().build(DEVICE_TOKEN, Default::default());
            assert_eq!(
                serde_json::to_value(payload.aps).unwrap(),
                serde_json::to_value(built.aps).unwrap()
            );
            assert_eq!(payload.device_token, DEVICE_TOKEN);
            assert_eq!(payload.options.apns_topic, Some("test-topic"));
            assert_eq!(
                serde_json::to_value(payload.data).unwrap(),
                serde_json::json!({
                    "chid": CHANNEL_ID,
                    "ver": "test-message-id",
                    "body": "test-data",
                    "con": "test-encoding",
                    "enc": "test-encryption",
                    "cryptokey": "test-crypto-key",
                    "enckey": "test-encryption-key"
                })
            );

            Ok(apns_success_response())
        });
        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let router = make_router(client, db);
        let data = "test-data".to_string();
        let notification = make_notification(default_router_data(), Some(data), RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
    }

    /// An unknown release channel is rejected without sending.
    #[tokio::test]
    async fn missing_client() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "rel_channel".to_string(),
            serde_json::to_value("unknown-app-id").unwrap(),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidReleaseChannel))
            ),
            "result = {result:?}"
        );
    }

    /// A 410 from APNS removes the user and reports `Unregistered`.
    #[tokio::test]
    async fn user_not_found() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::Unregistered,
                    timestamp: Some(0),
                }),
                apns_id: None,
                code: 410,
            }))
        });
        let notification = make_notification(default_router_data(), None, RouterType::APNS);
        let mut db = MockDbClient::new();
        db.expect_remove_user()
            .with(predicate::eq(notification.subscription.user.uaid))
            .times(1)
            .return_once(|_| Ok(()));
        let router = make_router(client, db.into_boxed_arc());

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::Unregistered))
            ),
            "result = {result:?}"
        );
    }

    /// Other APNS error responses are surfaced as `ApnsUpstream`.
    #[tokio::test]
    async fn upstream_error() {
        let client = MockApnsClient::new(|_| {
            Err(a2::Error::ResponseError(a2::Response {
                error: Some(a2::ErrorBody {
                    reason: a2::ErrorReason::BadCertificate,
                    timestamp: None,
                }),
                apns_id: None,
                code: 403,
            }))
        });
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let notification = make_notification(default_router_data(), None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::ApnsUpstream(
                    a2::Error::ResponseError(a2::Response {
                        error: Some(a2::ErrorBody {
                            reason: a2::ErrorReason::BadCertificate,
                            timestamp: None,
                        }),
                        apns_id: None,
                        code: 403,
                    })
                )))
            ),
            "result = {result:?}"
        );
    }

    /// A wrongly-typed `mutable-content` value is rejected before sending.
    #[tokio::test]
    async fn invalid_aps_data() {
        let client = MockApnsClient::new(|_| panic!("The notification should not be sent"));
        let db = MockDbClient::new().into_boxed_arc();
        let router = make_router(client, db);
        let mut router_data = default_router_data();
        router_data.insert(
            "aps".to_string(),
            serde_json::json!({"mutable-content": "should be a number"}),
        );
        let notification = make_notification(router_data, None, RouterType::APNS);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Apns(ApnsError::InvalidApsData))
            ),
            "result = {result:?}"
        );
    }
}