autoendpoint/extractors/notification.rs

1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::extractors::{
3    message_id::MessageId, notification_headers::NotificationHeaders, subscription::Subscription,
4};
5use crate::server::AppState;
6use actix_web::{dev::Payload, web, FromRequest, HttpRequest};
7use autopush_common::util::{b64_encode_url, ms_since_epoch, sec_since_epoch};
8use cadence::CountedExt;
9use fernet::MultiFernet;
10use futures::{future, FutureExt};
11use std::collections::HashMap;
12use uuid::Uuid;
13
14/// Extracts notification data from `Subscription` and request data
15#[derive(Clone, Debug)]
16pub struct Notification {
17    /// Unique message_id for this notification
18    pub message_id: String,
19    /// The subscription information block
20    pub subscription: Subscription,
21    /// Set of associated crypto headers
22    pub headers: NotificationHeaders,
23    /// UNIX timestamp in seconds
24    pub timestamp: u64,
25    /// UNIX timestamp in milliseconds
26    pub sort_key_timestamp: u64,
27    /// The encrypted notification body
28    pub data: Option<String>,
29    #[cfg(feature = "reliable_report")]
30    /// The current state the message was in (if tracked)
31    pub reliable_state: Option<autopush_common::reliability::ReliabilityState>,
32    #[cfg(feature = "reliable_report")]
33    pub reliability_id: Option<String>,
34}
35
36impl FromRequest for Notification {
37    type Error = ApiError;
38    type Future = future::LocalBoxFuture<'static, Result<Self, Self::Error>>;
39
40    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
41        let req = req.clone();
42        let mut payload = payload.take();
43
44        async move {
45            let subscription = Subscription::extract(&req).await?;
46            let app_state = web::Data::<AppState>::extract(&req)
47                .await
48                .expect("No server state found");
49
50            let max_notification_ttl_secs = app_state.settings.max_notification_ttl;
51            // Read data
52            let data = web::Bytes::from_request(&req, &mut payload)
53                .await
54                .map_err(|e| {
55                    debug!("▶▶ Request read payload error: {:?}", &e);
56                    ApiErrorKind::PayloadError(e)
57                })?;
58
59            // Convert data to base64
60            let data = if data.is_empty() {
61                None
62            } else {
63                Some(b64_encode_url(&data.to_vec()))
64            };
65
66            let headers =
67                NotificationHeaders::from_request(&req, data.is_some(), max_notification_ttl_secs)?;
68            let timestamp = sec_since_epoch();
69            let sort_key_timestamp = ms_since_epoch();
70            let message_id = Self::generate_message_id(
71                &app_state.fernet,
72                subscription.user.uaid,
73                subscription.channel_id,
74                headers.topic.as_deref(),
75                sort_key_timestamp,
76            );
77
78            #[cfg(feature = "reliable_report")]
79            let reliability_id = subscription.reliability_id.clone();
80
81            #[allow(unused_mut)]
82            let mut notif = Notification {
83                message_id,
84                subscription,
85                headers,
86                timestamp,
87                sort_key_timestamp,
88                data,
89                #[cfg(feature = "reliable_report")]
90                reliable_state: None,
91                #[cfg(feature = "reliable_report")]
92                reliability_id,
93            };
94
95            #[cfg(feature = "reliable_report")]
96            // Brand new notification, so record it as "Received"
97            notif
98                .record_reliability(
99                    &app_state.reliability,
100                    autopush_common::reliability::ReliabilityState::Received,
101                )
102                .await;
103
104            // Record the encoding if we have an encrypted payload
105            if let Some(encoding) = &notif.headers.encoding {
106                if notif.data.is_some() {
107                    app_state
108                        .metrics
109                        .incr(&format!("updates.notification.encoding.{encoding}"))
110                        .ok();
111                }
112            }
113
114            Ok(notif)
115        }
116        .boxed_local()
117    }
118}
119
120impl From<Notification> for autopush_common::notification::Notification {
121    fn from(notification: Notification) -> Self {
122        let topic = notification.headers.topic.clone();
123        let sortkey_timestamp = topic.is_none().then_some(notification.sort_key_timestamp);
124        autopush_common::notification::Notification {
125            channel_id: notification.subscription.channel_id,
126            version: notification.message_id,
127            ttl: notification.headers.ttl as u64,
128            topic,
129            timestamp: notification.timestamp,
130            data: notification.data,
131            sortkey_timestamp,
132            #[cfg(feature = "reliable_report")]
133            reliability_id: notification.subscription.reliability_id,
134            headers: {
135                let headers: HashMap<String, String> = notification.headers.into();
136                if headers.is_empty() {
137                    None
138                } else {
139                    Some(headers)
140                }
141            },
142            #[cfg(feature = "reliable_report")]
143            reliable_state: notification.reliable_state,
144        }
145    }
146}
147
148impl Notification {
149    /// Generate a message-id suitable for accessing the message
150    ///
151    /// For topic messages, a sort_key version of 01 is used, and the topic
152    /// is included for reference:
153    ///
154    ///     Encrypted('01' : uaid.hex : channel_id.hex : topic)
155    ///
156    /// For non-topic messages, a sort_key version of 02 is used:
157    ///
158    ///     Encrypted('02' : uaid.hex : channel_id.hex : timestamp)
159    fn generate_message_id(
160        fernet: &MultiFernet,
161        uaid: Uuid,
162        channel_id: Uuid,
163        topic: Option<&str>,
164        timestamp: u64,
165    ) -> String {
166        let message_id = if let Some(topic) = topic {
167            MessageId::WithTopic {
168                uaid,
169                channel_id,
170                topic: topic.to_string(),
171            }
172        } else {
173            MessageId::WithoutTopic {
174                uaid,
175                channel_id,
176                timestamp,
177            }
178        };
179
180        message_id.encrypt(fernet)
181    }
182
183    pub fn has_topic(&self) -> bool {
184        self.headers.topic.is_some()
185    }
186
187    /// Serialize the notification for delivery to the connection server. Some
188    /// fields in `autopush_common`'s `Notification` are marked with
189    /// `#[serde(skip_serializing)]` so they are not shown to the UA. These
190    /// fields are still required when delivering to the connection server, so
191    /// we can't simply convert this notification type to that one and serialize
192    /// via serde.
193    pub fn serialize_for_delivery(&self) -> ApiResult<HashMap<&'static str, serde_json::Value>> {
194        let mut map = HashMap::new();
195
196        map.insert(
197            "channelID",
198            serde_json::to_value(self.subscription.channel_id)?,
199        );
200        map.insert("version", serde_json::to_value(&self.message_id)?);
201        map.insert("ttl", serde_json::to_value(self.headers.ttl)?);
202        map.insert("topic", serde_json::to_value(&self.headers.topic)?);
203        map.insert("timestamp", serde_json::to_value(self.timestamp)?);
204        #[cfg(feature = "reliable_report")]
205        {
206            if let Some(reliability_id) = &self.subscription.reliability_id {
207                map.insert("reliability_id", serde_json::to_value(reliability_id)?);
208            }
209            if let Some(reliable_state) = self.reliable_state {
210                map.insert(
211                    "reliable_state",
212                    serde_json::to_value(reliable_state.to_string())?,
213                );
214            }
215        }
216        if let Some(data) = &self.data {
217            map.insert("data", serde_json::to_value(data)?);
218
219            let headers: HashMap<_, _> = self.headers.clone().into();
220            map.insert("headers", serde_json::to_value(headers)?);
221        }
222
223        Ok(map)
224    }
225
226    #[cfg(feature = "reliable_report")]
227    pub async fn record_reliability(
228        &mut self,
229        reliability: &autopush_common::reliability::PushReliability,
230        state: autopush_common::reliability::ReliabilityState,
231    ) {
232        self.reliable_state = reliability
233            .record(
234                &self.reliability_id,
235                state,
236                &self.reliable_state,
237                Some(self.timestamp + self.headers.ttl as u64),
238            )
239            .await
240            .inspect_err(|e| {
241                warn!("🔍⚠️ Unable to record reliability state log: {:?}", e);
242            })
243            .unwrap_or(Some(state))
244    }
245}