1#[cfg(feature = "reliable_report")]
2use autopush_common::reliability::{PushReliability, ReliabilityState};
3use autopush_common::{db::client::DbClient, MAX_FCM_NOTIFICATION_TTL};
4
5use crate::error::ApiResult;
6use crate::extractors::notification::Notification;
7use crate::extractors::router_data_input::RouterDataInput;
8use crate::routers::common::{build_message_data, handle_error, incr_success_metrics};
9use crate::routers::fcm::client::FcmClient;
10use crate::routers::fcm::error::FcmError;
11use crate::routers::fcm::settings::{FcmServerCredential, FcmSettings};
12use crate::routers::{Router, RouterError, RouterResponse};
13use async_trait::async_trait;
14use cadence::StatsdClient;
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::Arc;
18use url::Url;
19use uuid::Uuid;
20
21pub struct FcmRouter {
23 settings: FcmSettings,
24 endpoint_url: Url,
25 metrics: Arc<StatsdClient>,
26 db: Box<dyn DbClient>,
27 clients: HashMap<String, FcmClient>,
29 #[cfg(feature = "reliable_report")]
30 reliability: Arc<PushReliability>,
31}
32
33impl FcmRouter {
34 pub async fn new(
36 settings: FcmSettings,
37 endpoint_url: Url,
38 http: reqwest::Client,
39 metrics: Arc<StatsdClient>,
40 db: Box<dyn DbClient>,
41 #[cfg(feature = "reliable_report")] reliability: Arc<PushReliability>,
42 ) -> Result<Self, FcmError> {
43 let server_credentials = settings.credentials()?;
44 let clients = Self::create_clients(&settings, server_credentials, http.clone())
45 .await
46 .map_err(FcmError::OAuthClientBuild)?;
47 Ok(Self {
48 settings,
49 endpoint_url,
50 metrics,
51 db,
52 clients,
53 #[cfg(feature = "reliable_report")]
54 reliability,
55 })
56 }
57
58 async fn create_clients(
60 settings: &FcmSettings,
61 server_credentials: HashMap<String, FcmServerCredential>,
62 http: reqwest::Client,
63 ) -> std::io::Result<HashMap<String, FcmClient>> {
64 let mut clients = HashMap::new();
65
66 for (profile, server_credential) in server_credentials {
67 clients.insert(
68 profile,
69 FcmClient::new(settings, server_credential, http.clone()).await?,
70 );
71 }
72 trace!("Initialized {} FCM clients", clients.len());
73 Ok(clients)
74 }
75
76 pub fn active(&self) -> bool {
78 !self.clients.is_empty()
79 }
80
81 fn routing_info(
87 &self,
88 router_data: &HashMap<String, Value>,
89 uaid: &Uuid,
90 ) -> ApiResult<(String, String)> {
91 let creds = router_data.get("creds").and_then(Value::as_object);
92 let routing_token = match router_data.get("token").and_then(Value::as_str) {
98 Some(v) => v.to_owned(),
99 None => {
100 warn!("No Registration token found for user {}", uaid.to_string());
101 return Err(FcmError::NoRegistrationToken.into());
102 }
103 };
104 let app_id = match router_data.get("app_id").and_then(Value::as_str) {
105 Some(v) => v.to_owned(),
106 None => {
107 if creds.is_none() {
108 warn!("No App_id found for user {}", uaid.to_string());
109 return Err(FcmError::NoAppId.into());
110 }
111 match creds
112 .unwrap()
113 .get("senderID")
114 .map(|v| v.as_str())
115 .unwrap_or(None)
116 {
117 Some(v) => v.to_owned(),
118 None => return Err(FcmError::NoAppId.into()),
119 }
120 }
121 };
122 Ok((routing_token, app_id))
123 }
124}
125
126#[async_trait(?Send)]
127impl Router for FcmRouter {
128 fn register(
129 &self,
130 router_data_input: &RouterDataInput,
131 app_id: &str,
132 ) -> Result<HashMap<String, Value>, RouterError> {
133 if !self.clients.contains_key(app_id) {
134 return Err(FcmError::InvalidAppId(app_id.to_owned()).into());
135 }
136
137 let mut router_data = HashMap::new();
138 router_data.insert(
139 "token".to_string(),
140 serde_json::to_value(&router_data_input.token).unwrap(),
141 );
142 router_data.insert("app_id".to_string(), serde_json::to_value(app_id).unwrap());
143
144 Ok(router_data)
145 }
146
147 #[allow(unused_mut)]
148 async fn route_notification(
149 &self,
150 mut notification: Notification,
151 ) -> ApiResult<RouterResponse> {
152 debug!(
153 "Sending FCM notification to UAID {}",
154 notification.subscription.user.uaid
155 );
156 trace!("Notification = {:?}", notification);
157
158 let router_data = notification
159 .subscription
160 .user
161 .router_data
162 .as_ref()
163 .ok_or(FcmError::NoRegistrationToken)?;
164
165 let (routing_token, app_id) =
166 self.routing_info(router_data, ¬ification.subscription.user.uaid)?;
167 let ttl = (MAX_FCM_NOTIFICATION_TTL.num_seconds() as u64)
168 .min(self.settings.min_ttl.max(notification.headers.ttl as u64));
169
170 let client = self
172 .clients
173 .get(&app_id)
174 .ok_or_else(|| FcmError::InvalidAppId(app_id.clone()))?;
175
176 let message_data = build_message_data(¬ification)?;
177 let platform = "fcmv1";
178 trace!("Sending message to {platform}: [{:?}]", &app_id);
179 if let Err(e) = client.send(message_data, routing_token, ttl).await {
180 #[cfg(feature = "reliable_report")]
181 notification
182 .record_reliability(&self.reliability, ReliabilityState::Errored)
183 .await;
184 return Err(handle_error(
185 e,
186 &self.metrics,
187 self.db.as_ref(),
188 platform,
189 &app_id,
190 notification.subscription.user.uaid,
191 notification.subscription.vapid.clone(),
192 )
193 .await);
194 };
195 incr_success_metrics(&self.metrics, platform, &app_id, ¬ification);
196 #[cfg(feature = "reliable_report")]
197 notification
202 .record_reliability(&self.reliability, ReliabilityState::BridgeTransmitted)
203 .await;
204 trace!("Send request was successful");
206
207 Ok(RouterResponse::success(
208 self.endpoint_url
209 .join(&format!("/m/{}", notification.message_id))
210 .expect("Message ID is not URL-safe")
211 .to_string(),
212 notification.headers.ttl as usize,
213 ))
214 }
215}
216
#[cfg(test)]
mod tests {
    use crate::error::ApiErrorKind;
    use crate::extractors::routers::RouterType;
    use crate::routers::common::tests::{make_notification, CHANNEL_ID};
    use crate::routers::fcm::client::tests::{
        make_service_key, mock_fcm_endpoint_builder, mock_token_endpoint, GCM_PROJECT_ID,
        PROJECT_ID,
    };
    use crate::routers::fcm::error::FcmError;
    use crate::routers::fcm::router::FcmRouter;
    use crate::routers::fcm::settings::FcmSettings;
    use crate::routers::{Router, RouterError, RouterResponse};
    use autopush_common::db::client::DbClient;
    use autopush_common::db::mock::MockDbClient;
    #[cfg(feature = "reliable_report")]
    use autopush_common::{redis_util::MAX_TRANSACTION_LOOP, reliability::PushReliability};
    use std::sync::Arc;

    use cadence::StatsdClient;
    use mockall::predicate;
    use std::collections::HashMap;
    use url::Url;

    const FCM_TOKEN: &str = "test-token";

    /// Build a router pointed at `server`, with a "dev" FCM profile and a
    /// GCM-flagged profile keyed by `GCM_PROJECT_ID`.
    async fn make_router(
        server: &mut mockito::ServerGuard,
        fcm_credential: String,
        gcm_credential: String,
        db: Box<dyn DbClient>,
    ) -> FcmRouter {
        let url = &server.url();
        // One no-op metrics client shared by the router and (when enabled) the
        // reliability reporter, so it is used in every feature configuration
        // and never triggers an `unused_variables` warning.
        let metrics = Arc::new(StatsdClient::from_sink("autopush", cadence::NopMetricSink));

        FcmRouter::new(
            FcmSettings {
                base_url: Url::parse(url).unwrap(),
                server_credentials: serde_json::json!({
                    "dev": {
                        "project_id": PROJECT_ID,
                        "credential": fcm_credential
                    },
                    GCM_PROJECT_ID: {
                        "project_id": GCM_PROJECT_ID,
                        "credential": gcm_credential,
                        "is_gcm": true,
                    }
                })
                .to_string(),
                ..Default::default()
            },
            Url::parse("http://localhost:8080/").unwrap(),
            reqwest::Client::new(),
            Arc::clone(&metrics),
            db.clone(),
            #[cfg(feature = "reliable_report")]
            Arc::new(
                PushReliability::new(&None, db.clone(), &metrics, MAX_TRANSACTION_LOOP).unwrap(),
            ),
        )
        .await
        .unwrap()
    }

    /// Router data for a user registered with the "dev" app ID.
    fn default_router_data() -> HashMap<String, serde_json::Value> {
        let mut map = HashMap::new();
        map.insert(
            "token".to_string(),
            serde_json::to_value(FCM_TOKEN).unwrap(),
        );
        map.insert("app_id".to_string(), serde_json::to_value("dev").unwrap());
        map
    }

    /// A notification without payload is sent with only the channel ID.
    #[tokio::test]
    async fn successful_routing_no_data() {
        let mut server = mockito::Server::new_async().await;

        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let service_key = make_service_key(&server);
        let router = make_router(&mut server, service_key, "whatever".to_string(), db).await;
        assert!(router.active());
        let _token_mock = mock_token_endpoint(&mut server).await;
        let fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID)
            .match_body(
                serde_json::json!({
                    "message": {
                        "android": {
                            "data": {
                                "chid": CHANNEL_ID
                            },
                            "ttl": "60s"
                        },
                        "token": "test-token"
                    }
                })
                .to_string()
                .as_str(),
            )
            .create();
        let notification = make_notification(default_router_data(), None, RouterType::FCM);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
        fcm_mock.assert();
    }

    /// A notification with payload forwards the payload and encryption fields.
    #[tokio::test]
    async fn successful_routing_with_data() {
        let mut server = mockito::Server::new_async().await;

        let mdb = MockDbClient::new();
        let db = mdb.into_boxed_arc();
        let service_key = make_service_key(&server);
        let router = make_router(&mut server, service_key, "whatever".to_string(), db).await;
        let _token_mock = mock_token_endpoint(&mut server).await;
        let fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID)
            .match_body(
                serde_json::json!({
                    "message": {
                        "android": {
                            "data": {
                                "chid": CHANNEL_ID,
                                "body": "test-data",
                                "con": "test-encoding",
                                "enc": "test-encryption",
                                "cryptokey": "test-crypto-key",
                                "enckey": "test-encryption-key"
                            },
                            "ttl": "60s"
                        },
                        "token": "test-token"
                    }
                })
                .to_string()
                .as_str(),
            )
            .create();
        let data = "test-data".to_string();
        let notification = make_notification(default_router_data(), Some(data), RouterType::FCM);

        let result = router.route_notification(notification).await;
        assert!(result.is_ok(), "result = {result:?}");
        assert_eq!(
            result.unwrap(),
            RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0)
        );
        fcm_mock.assert();
    }

    /// An app ID with no configured client is rejected with `InvalidAppId`
    /// carrying that app ID, and nothing is sent to FCM.
    #[tokio::test]
    async fn missing_client() {
        let mut server = mockito::Server::new_async().await;

        let db = MockDbClient::new().into_boxed_arc();
        let service_key = make_service_key(&server);
        let router = make_router(&mut server, service_key, "whatever".to_string(), db).await;
        let _token_mock = mock_token_endpoint(&mut server).await;
        let fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID)
            .expect(0)
            .create_async()
            .await;
        let mut router_data = default_router_data();
        router_data.insert(
            "app_id".to_string(),
            serde_json::to_value("unknown-app-id").unwrap(),
        );
        let notification = make_notification(router_data, None, RouterType::FCM);

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        // The guard actually checks the rejected app ID; a bare binding in the
        // pattern would match any value.
        assert!(
            matches!(
                &result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::Fcm(FcmError::InvalidAppId(app_id)))
                    if app_id == "unknown-app-id"
            ),
            "result = {result:?}"
        );
        fcm_mock.assert();
    }

    /// A 404 from FCM removes the user from the database and yields `NotFound`.
    #[tokio::test]
    async fn no_fcm_user() {
        let mut server = mockito::Server::new_async().await;

        let notification = make_notification(default_router_data(), None, RouterType::FCM);
        let mut db = MockDbClient::new();
        db.expect_remove_user()
            .with(predicate::eq(notification.subscription.user.uaid))
            .times(1)
            .return_once(|_| Ok(()));

        let service_key = make_service_key(&server);
        let router = make_router(
            &mut server,
            service_key,
            "whatever".to_string(),
            db.into_boxed_arc(),
        )
        .await;
        let _token_mock = mock_token_endpoint(&mut server).await;
        let _fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID)
            .with_status(404)
            .with_body(r#"{"error":{"status":"NOT_FOUND","message":"test-message"}}"#)
            .create_async()
            .await;

        let result = router.route_notification(notification).await;
        assert!(result.is_err());
        assert!(
            matches!(
                result.as_ref().unwrap_err().kind,
                ApiErrorKind::Router(RouterError::NotFound)
            ),
            "result = {result:?}"
        );
    }
}