// autoendpoint/extractors/notification.rs

1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::extractors::{
3    message_id::MessageId, notification_headers::NotificationHeaders, subscription::Subscription,
4};
5use crate::server::AppState;
6use actix_web::{FromRequest, HttpRequest, dev::Payload, web};
7use autopush_common::util::{b64_encode_url, ms_since_epoch, sec_since_epoch};
8use cadence::CountedExt;
9use fernet::MultiFernet;
10use futures::{FutureExt, future};
11use serde::Serialize;
12use std::collections::HashMap;
13use std::sync::{
14    Arc,
15    atomic::{AtomicUsize, Ordering},
16};
17use uuid::Uuid;
18
/// Wire format for delivering notifications to connection servers.
/// Uses a single serialization pass instead of building a HashMap of serde_json::Values.
///
/// Instances borrow their string data from the `Notification` they were built
/// from (see `Notification::serialize_for_delivery`), so they cannot outlive it.
/// serde's derive serializes fields in declaration order, so the field order
/// below is also the JSON key order on the wire.
#[derive(Debug, Serialize)]
pub struct TransportNotification<'a> {
    /// Subscription channel id; serialized under the key `channelID`.
    #[serde(rename = "channelID")]
    pub channel_id: uuid::Uuid,
    /// The notification's encrypted message id.
    pub version: &'a str,
    /// Time-to-live taken from the notification headers.
    pub ttl: i64,
    /// Optional topic. Always emitted, as `null` when the notification has no topic.
    pub topic: Option<&'a str>,
    /// UNIX timestamp in seconds at which the notification was received.
    pub timestamp: u64,
    /// Reliability tracking id; omitted from the output when `None`.
    #[cfg(feature = "reliable_report")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reliability_id: Option<&'a str>,
    /// String form of the current reliability state; omitted when `None`.
    #[cfg(feature = "reliable_report")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reliable_state: Option<String>,
    /// Base64-encoded payload; omitted when the notification has no body.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<&'a str>,
    /// Notification headers as a string map. Only populated when `data` is
    /// present, and omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
40
/// Extracts notification data from `Subscription` and request data
///
/// Instances produced by the `FromRequest` implementation increment the shared
/// `AppState::in_process_subscription_updates` counter. The `Drop`
/// implementation decrements it again, never going below zero.
// Note: Because we introduced the `in_process_counter` field and added the `Drop`
// implementation, be very careful if adding `Clone` or `Copy` traits.
#[derive(Debug)]
pub struct Notification {
    /// Unique message_id for this notification
    /// (a Fernet-encrypted id produced by `generate_message_id`).
    pub message_id: String,
    /// The subscription information block
    pub subscription: Subscription,
    /// Set of associated crypto headers (also carries the TTL, the optional
    /// topic and the content encoding).
    pub headers: NotificationHeaders,
    /// UNIX timestamp in seconds, taken when the request was extracted
    pub timestamp: u64,
    /// UNIX timestamp in milliseconds. Feeds the topic-less message id and is
    /// used as the storage sort key for messages without a topic.
    pub sort_key_timestamp: u64,
    /// The encrypted notification body, base64 encoded via `b64_encode_url`.
    /// `None` when the request body was empty.
    pub data: Option<String>,
    #[cfg(feature = "reliable_report")]
    /// The current state the message was in (if tracked)
    pub reliable_state: Option<autopush_common::reliability::ReliabilityState>,
    #[cfg(feature = "reliable_report")]
    /// Reliability tracking id copied from the subscription, if any.
    pub reliability_id: Option<String>,
    /// Internal reference to the appstate count of in process notifications.
    /// _Note:_ This only tracks notifications that have been delivered
    /// from a subscription provider. This does not track notifications that may
    /// have been retrieved from storage.
    /// This counter is in response to an incident where a large number of
    /// valid, inbound notifications caused a cascade impact on our storage
    /// engine, which resulted in a node OOM error killing the process. This
    /// metric will be reported as part of the health check to allow the load
    /// balancer to make informed routing decisions.
    pub(crate) in_process_counter: Arc<AtomicUsize>,
}
74
75impl Drop for Notification {
76    fn drop(&mut self) {
77        // `Clone` or `Copy` can cause the counter to decrement too many times.
78        // We'll set a floor for now.
79        let _ =
80            self.in_process_counter
81                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
82                    current.checked_sub(1)
83                });
84        trace!(
85            "🧹 Dropping notification with message_id: {}",
86            self.message_id
87        );
88    }
89}
90
91impl FromRequest for Notification {
92    type Error = ApiError;
93    type Future = future::LocalBoxFuture<'static, Result<Self, Self::Error>>;
94
95    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
96        let req = req.clone();
97        let mut payload = payload.take();
98
99        async move {
100            let subscription = Subscription::extract(&req).await?;
101            let app_state = web::Data::<AppState>::extract(&req)
102                .await
103                .expect("No server state found");
104
105            let max_notification_ttl_secs = app_state.settings.max_notification_ttl;
106            // Read data
107            let data = web::Bytes::from_request(&req, &mut payload)
108                .await
109                .map_err(|e| {
110                    debug!("▶▶ Request read payload error: {:?}", &e);
111                    ApiErrorKind::PayloadError(e)
112                })?;
113
114            // Convert data to base64
115            let data = if data.is_empty() {
116                None
117            } else {
118                Some(b64_encode_url(&data.to_vec()))
119            };
120
121            let headers =
122                NotificationHeaders::from_request(&req, data.is_some(), max_notification_ttl_secs)?;
123            let timestamp = sec_since_epoch();
124            let sort_key_timestamp = ms_since_epoch();
125            let message_id = Self::generate_message_id(
126                &app_state.fernet,
127                subscription.user.uaid,
128                subscription.channel_id,
129                headers.topic.as_deref(),
130                sort_key_timestamp,
131            );
132
133            #[cfg(feature = "reliable_report")]
134            let reliability_id = subscription.reliability_id.clone();
135
136            #[allow(unused_mut)]
137            let mut notif = Notification {
138                message_id,
139                subscription,
140                headers,
141                timestamp,
142                sort_key_timestamp,
143                data,
144                #[cfg(feature = "reliable_report")]
145                reliable_state: None,
146                #[cfg(feature = "reliable_report")]
147                reliability_id,
148                in_process_counter: app_state.in_process_subscription_updates.clone(),
149            };
150
151            #[cfg(feature = "reliable_report")]
152            // Brand new notification, so record it as "Received"
153            notif
154                .record_reliability(
155                    &app_state.reliability,
156                    autopush_common::reliability::ReliabilityState::Received,
157                )
158                .await;
159
160            // record that we have a notification in process
161            notif.in_process_counter.fetch_add(1, Ordering::Relaxed);
162
163            // Record the encoding if we have an encrypted payload
164            if let Some(encoding) = &notif.headers.encoding
165                && notif.data.is_some()
166            {
167                app_state
168                    .metrics
169                    .incr(&format!("updates.notification.encoding.{encoding}"))
170                    .ok();
171            }
172
173            Ok(notif)
174        }
175        .boxed_local()
176    }
177}
178
179impl From<&Notification> for autopush_common::notification::Notification {
180    fn from(notification: &Notification) -> Self {
181        let topic = notification.headers.topic.clone();
182        let sortkey_timestamp = topic.is_none().then_some(notification.sort_key_timestamp);
183        autopush_common::notification::Notification {
184            channel_id: notification.subscription.channel_id,
185            version: notification.message_id.clone(),
186            ttl: notification.headers.ttl as u64,
187            topic,
188            timestamp: notification.timestamp,
189            data: notification.data.clone(),
190            sortkey_timestamp,
191            #[cfg(feature = "reliable_report")]
192            reliability_id: notification.subscription.reliability_id.clone(),
193            headers: {
194                let headers: HashMap<String, String> = notification.headers.clone().into();
195                if headers.is_empty() {
196                    None
197                } else {
198                    Some(headers)
199                }
200            },
201            #[cfg(feature = "reliable_report")]
202            reliable_state: notification.reliable_state,
203        }
204    }
205}
206
207impl From<Notification> for autopush_common::notification::Notification {
208    fn from(notification: Notification) -> Self {
209        // Delegate to the borrowing impl to avoid duplication
210        autopush_common::notification::Notification::from(&notification)
211    }
212}
213
214impl Notification {
215    /// Generate a message-id suitable for accessing the message
216    ///
217    /// For topic messages, a sort_key version of 01 is used, and the topic
218    /// is included for reference:
219    ///
220    ///     Encrypted('01' : uaid.hex : channel_id.hex : topic)
221    ///
222    /// For non-topic messages, a sort_key version of 02 is used:
223    ///
224    ///     Encrypted('02' : uaid.hex : channel_id.hex : timestamp)
225    fn generate_message_id(
226        fernet: &MultiFernet,
227        uaid: Uuid,
228        channel_id: Uuid,
229        topic: Option<&str>,
230        timestamp: u64,
231    ) -> String {
232        let message_id = if let Some(topic) = topic {
233            MessageId::WithTopic {
234                uaid,
235                channel_id,
236                topic: topic.to_string(),
237            }
238        } else {
239            MessageId::WithoutTopic {
240                uaid,
241                channel_id,
242                timestamp,
243            }
244        };
245
246        message_id.encrypt(fernet)
247    }
248
249    pub fn has_topic(&self) -> bool {
250        self.headers.topic.is_some()
251    }
252
253    /// Serialize the notification for delivery to the connection server. Some
254    /// fields in `autopush_common`'s `Notification` are marked with
255    /// `#[serde(skip_serializing)]` so they are not shown to the UA. These
256    /// fields are still required when delivering to the connection server, so
257    /// we can't simply convert this notification type to that one and serialize
258    /// via serde.
259    pub fn serialize_for_delivery(&self) -> ApiResult<TransportNotification<'_>> {
260        let headers = self.data.as_ref().map(|_| {
261            let h: HashMap<String, String> = self.headers.clone().into();
262            h
263        });
264        Ok(TransportNotification {
265            channel_id: self.subscription.channel_id,
266            version: &self.message_id,
267            ttl: self.headers.ttl,
268            topic: self.headers.topic.as_deref(),
269            timestamp: self.timestamp,
270            #[cfg(feature = "reliable_report")]
271            reliability_id: self.subscription.reliability_id.as_deref(),
272            #[cfg(feature = "reliable_report")]
273            reliable_state: self.reliable_state.map(|s| s.to_string()),
274            data: self.data.as_deref(),
275            headers,
276        })
277    }
278
279    #[cfg(feature = "reliable_report")]
280    pub async fn record_reliability(
281        &mut self,
282        reliability: &autopush_common::reliability::PushReliability,
283        state: autopush_common::reliability::ReliabilityState,
284    ) {
285        self.reliable_state = reliability
286            .record(
287                &self.reliability_id,
288                state,
289                &self.reliable_state,
290                Some(self.timestamp + self.headers.ttl as u64),
291            )
292            .await
293            .inspect_err(|e| {
294                warn!("🔍⚠️ Unable to record reliability state log: {:?}", e);
295            })
296            .unwrap_or(Some(state))
297    }
298}
299
#[cfg(test)]
mod tests {
    use autopush_common::endpoint::make_endpoint;

    use super::*;
    use crate::server::AppState;

    /// Extracting a notification bumps the in-process counter by exactly one,
    /// and dropping it restores the original value.
    #[actix_rt::test]
    async fn test_notification_counter() {
        // If you're ever wondering how to set up a mock request, this is it.
        let chid = uuid::Uuid::new_v4();
        let channels = std::collections::HashSet::from([chid]);

        // Set up the mock database returning only the required calls.
        let mut mock_db = autopush_common::db::mock::MockDbClient::new();
        mock_db
            .expect_get_user()
            .returning(|_| Ok(Some(autopush_common::db::User::default())));
        mock_db
            .expect_get_channels()
            .returning(move |_| Ok(channels.clone()));
        let app_state = AppState::test_default(mock_db).await;

        // Now build the endpoint so that it passes all validation steps.
        // First, build a valid endpoint with dummy data.
        let endpoint = make_endpoint(
            &uuid::Uuid::new_v4(),
            &chid,
            None,
            "http://example.com/v2",
            &app_state.fernet,
        )
        .unwrap();
        // We have to extract the token from the endpoint. This is normally done by the
        // actix URL parser, but that's not available here.
        let token_str = endpoint.rsplit('/').next().unwrap().to_owned();
        let test_request = actix_web::test::TestRequest::with_uri(&endpoint)
            .param("api_version", "v1")
            .param("token", token_str)
            .insert_header((
                // Remember kids, gotta set the TTL.
                actix_http::header::HeaderName::from_static("ttl"),
                actix_http::header::HeaderValue::from_static("0"),
            ))
            .app_data(actix_web::web::Data::new(app_state.clone()))
            .to_http_request();
        // This has no payload. It's valid, but means no message encryption checks.
        let mut payload = actix_web::dev::Payload::None;

        // Begin the test.
        let initial_count = app_state
            .in_process_subscription_updates
            .load(Ordering::Relaxed);
        {
            // Let's get a notification. Check the extraction result first so a
            // rejected request reports its actual error.
            let _notif = Notification::from_request(&test_request, &mut payload)
                .await
                .expect("a valid notification request should be accepted");
            // Exactly one slot is taken per extracted notification.
            assert_eq!(
                initial_count + 1,
                app_state
                    .in_process_subscription_updates
                    .load(Ordering::Relaxed)
            );
        }
        // At this point, the notification is out of scope and `Drop` has been called.
        assert_eq!(
            initial_count,
            app_state
                .in_process_subscription_updates
                .load(Ordering::Relaxed)
        );
    }
}