autoendpoint/extractors/notification.rs

1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::extractors::{
3    message_id::MessageId, notification_headers::NotificationHeaders, subscription::Subscription,
4};
5use crate::server::AppState;
6use actix_web::{dev::Payload, web, FromRequest, HttpRequest};
7use autopush_common::util::{b64_encode_url, ms_since_epoch, sec_since_epoch};
8use cadence::CountedExt;
9use fernet::MultiFernet;
10use futures::{future, FutureExt};
11use serde::Serialize;
12use std::collections::HashMap;
13use uuid::Uuid;
14
15/// Wire format for delivering notifications to connection servers.
16/// Uses a single serialization pass instead of building a HashMap of serde_json::Values.
17#[derive(Debug, Serialize)]
18pub struct TransportNotification<'a> {
19    #[serde(rename = "channelID")]
20    pub channel_id: uuid::Uuid,
21    pub version: &'a str,
22    pub ttl: i64,
23    pub topic: Option<&'a str>,
24    pub timestamp: u64,
25    #[cfg(feature = "reliable_report")]
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub reliability_id: Option<&'a str>,
28    #[cfg(feature = "reliable_report")]
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub reliable_state: Option<String>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub data: Option<&'a str>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub headers: Option<HashMap<String, String>>,
35}
36
37/// Extracts notification data from `Subscription` and request data
38#[derive(Clone, Debug)]
39pub struct Notification {
40    /// Unique message_id for this notification
41    pub message_id: String,
42    /// The subscription information block
43    pub subscription: Subscription,
44    /// Set of associated crypto headers
45    pub headers: NotificationHeaders,
46    /// UNIX timestamp in seconds
47    pub timestamp: u64,
48    /// UNIX timestamp in milliseconds
49    pub sort_key_timestamp: u64,
50    /// The encrypted notification body
51    pub data: Option<String>,
52    #[cfg(feature = "reliable_report")]
53    /// The current state the message was in (if tracked)
54    pub reliable_state: Option<autopush_common::reliability::ReliabilityState>,
55    #[cfg(feature = "reliable_report")]
56    pub reliability_id: Option<String>,
57}
58
59impl FromRequest for Notification {
60    type Error = ApiError;
61    type Future = future::LocalBoxFuture<'static, Result<Self, Self::Error>>;
62
63    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
64        let req = req.clone();
65        let mut payload = payload.take();
66
67        async move {
68            let subscription = Subscription::extract(&req).await?;
69            let app_state = web::Data::<AppState>::extract(&req)
70                .await
71                .expect("No server state found");
72
73            let max_notification_ttl_secs = app_state.settings.max_notification_ttl;
74            // Read data
75            let data = web::Bytes::from_request(&req, &mut payload)
76                .await
77                .map_err(|e| {
78                    debug!("▶▶ Request read payload error: {:?}", &e);
79                    ApiErrorKind::PayloadError(e)
80                })?;
81
82            // Convert data to base64
83            let data = if data.is_empty() {
84                None
85            } else {
86                Some(b64_encode_url(&data.to_vec()))
87            };
88
89            let headers =
90                NotificationHeaders::from_request(&req, data.is_some(), max_notification_ttl_secs)?;
91            let timestamp = sec_since_epoch();
92            let sort_key_timestamp = ms_since_epoch();
93            let message_id = Self::generate_message_id(
94                &app_state.fernet,
95                subscription.user.uaid,
96                subscription.channel_id,
97                headers.topic.as_deref(),
98                sort_key_timestamp,
99            );
100
101            #[cfg(feature = "reliable_report")]
102            let reliability_id = subscription.reliability_id.clone();
103
104            #[allow(unused_mut)]
105            let mut notif = Notification {
106                message_id,
107                subscription,
108                headers,
109                timestamp,
110                sort_key_timestamp,
111                data,
112                #[cfg(feature = "reliable_report")]
113                reliable_state: None,
114                #[cfg(feature = "reliable_report")]
115                reliability_id,
116            };
117
118            #[cfg(feature = "reliable_report")]
119            // Brand new notification, so record it as "Received"
120            notif
121                .record_reliability(
122                    &app_state.reliability,
123                    autopush_common::reliability::ReliabilityState::Received,
124                )
125                .await;
126
127            // Record the encoding if we have an encrypted payload
128            if let Some(encoding) = &notif.headers.encoding {
129                if notif.data.is_some() {
130                    app_state
131                        .metrics
132                        .incr(&format!("updates.notification.encoding.{encoding}"))
133                        .ok();
134                }
135            }
136
137            Ok(notif)
138        }
139        .boxed_local()
140    }
141}
142
143impl From<&Notification> for autopush_common::notification::Notification {
144    fn from(notification: &Notification) -> Self {
145        let topic = notification.headers.topic.clone();
146        let sortkey_timestamp = topic.is_none().then_some(notification.sort_key_timestamp);
147        autopush_common::notification::Notification {
148            channel_id: notification.subscription.channel_id,
149            version: notification.message_id.clone(),
150            ttl: notification.headers.ttl as u64,
151            topic,
152            timestamp: notification.timestamp,
153            data: notification.data.clone(),
154            sortkey_timestamp,
155            #[cfg(feature = "reliable_report")]
156            reliability_id: notification.subscription.reliability_id.clone(),
157            headers: {
158                let headers: HashMap<String, String> = notification.headers.clone().into();
159                if headers.is_empty() {
160                    None
161                } else {
162                    Some(headers)
163                }
164            },
165            #[cfg(feature = "reliable_report")]
166            reliable_state: notification.reliable_state,
167        }
168    }
169}
170
171impl From<Notification> for autopush_common::notification::Notification {
172    fn from(notification: Notification) -> Self {
173        // Delegate to the borrowing impl to avoid duplication
174        autopush_common::notification::Notification::from(&notification)
175    }
176}
177
178impl Notification {
179    /// Generate a message-id suitable for accessing the message
180    ///
181    /// For topic messages, a sort_key version of 01 is used, and the topic
182    /// is included for reference:
183    ///
184    ///     Encrypted('01' : uaid.hex : channel_id.hex : topic)
185    ///
186    /// For non-topic messages, a sort_key version of 02 is used:
187    ///
188    ///     Encrypted('02' : uaid.hex : channel_id.hex : timestamp)
189    fn generate_message_id(
190        fernet: &MultiFernet,
191        uaid: Uuid,
192        channel_id: Uuid,
193        topic: Option<&str>,
194        timestamp: u64,
195    ) -> String {
196        let message_id = if let Some(topic) = topic {
197            MessageId::WithTopic {
198                uaid,
199                channel_id,
200                topic: topic.to_string(),
201            }
202        } else {
203            MessageId::WithoutTopic {
204                uaid,
205                channel_id,
206                timestamp,
207            }
208        };
209
210        message_id.encrypt(fernet)
211    }
212
213    pub fn has_topic(&self) -> bool {
214        self.headers.topic.is_some()
215    }
216
217    /// Serialize the notification for delivery to the connection server. Some
218    /// fields in `autopush_common`'s `Notification` are marked with
219    /// `#[serde(skip_serializing)]` so they are not shown to the UA. These
220    /// fields are still required when delivering to the connection server, so
221    /// we can't simply convert this notification type to that one and serialize
222    /// via serde.
223    pub fn serialize_for_delivery(&self) -> ApiResult<TransportNotification<'_>> {
224        let headers = self.data.as_ref().map(|_| {
225            let h: HashMap<String, String> = self.headers.clone().into();
226            h
227        });
228        Ok(TransportNotification {
229            channel_id: self.subscription.channel_id,
230            version: &self.message_id,
231            ttl: self.headers.ttl,
232            topic: self.headers.topic.as_deref(),
233            timestamp: self.timestamp,
234            #[cfg(feature = "reliable_report")]
235            reliability_id: self.subscription.reliability_id.as_deref(),
236            #[cfg(feature = "reliable_report")]
237            reliable_state: self.reliable_state.map(|s| s.to_string()),
238            data: self.data.as_deref(),
239            headers,
240        })
241    }
242
243    #[cfg(feature = "reliable_report")]
244    pub async fn record_reliability(
245        &mut self,
246        reliability: &autopush_common::reliability::PushReliability,
247        state: autopush_common::reliability::ReliabilityState,
248    ) {
249        self.reliable_state = reliability
250            .record(
251                &self.reliability_id,
252                state,
253                &self.reliable_state,
254                Some(self.timestamp + self.headers.ttl as u64),
255            )
256            .await
257            .inspect_err(|e| {
258                warn!("🔍⚠️ Unable to record reliability state log: {:?}", e);
259            })
260            .unwrap_or(Some(state))
261    }
262}