autoendpoint/extractors/notification_headers.rs

1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::headers::crypto_key::CryptoKeyHeader;
3use crate::headers::util::{get_header, get_owned_header};
4use actix_web::HttpRequest;
5use autopush_common::{util::InsertOpt, MAX_NOTIFICATION_TTL};
6use lazy_static::lazy_static;
7use regex::Regex;
8use std::cmp::min;
9use std::collections::HashMap;
10use validator::Validate;
11use validator_derive::Validate;
12
13lazy_static! {
14    static ref VALID_BASE64_URL: Regex = Regex::new(r"^[0-9A-Za-z\-_]+=*$").unwrap();
15    static ref STRIP_PADDING: Regex =
16        Regex::new(r"(?P<head>[0-9A-Za-z\-_]+)=+(?P<tail>[,;]|$)").unwrap();
17}
18
/// Extractor and validator for notification headers.
///
/// Built by [`NotificationHeaders::from_request`]. The `ttl` and `topic`
/// fields are checked by the derived `Validate` impl; the encryption-related
/// fields are checked separately because their rules depend on the
/// `Content-Encoding` value.
#[derive(Clone, Debug, Eq, PartialEq, Validate)]
pub struct NotificationHeaders {
    /// Time-to-live in seconds; `from_request` clamps it to
    /// `MAX_NOTIFICATION_TTL`.
    // TTL is a signed value so that validation can catch negative inputs
    // NOTE(review): `min = 0` is inclusive, so a TTL of 0 is accepted even
    // though the message reads "greater than 0". The message and code are left
    // unchanged because tests (and possibly clients) match on them.
    #[validate(range(min = 0, message = "TTL must be greater than 0", code = "114"))]
    pub ttl: i64,

    /// Optional `Topic` header: at most 32 characters, all from the URL and
    /// filename safe Base64 alphabet (checked with `VALID_BASE64_URL`).
    #[validate(
        length(
            max = 32,
            message = "Topic must be no greater than 32 characters",
            code = "113"
        ),
        regex(
            path = *VALID_BASE64_URL,
            message = "Topic must be URL and Filename safe Base64 alphabet",
            code = "113"
        )
    )]
    pub topic: Option<String>,

    // These fields are validated separately, because the validation is complex
    // and based upon the content encoding. All four are `None` when the
    // request has no body.
    /// `Content-Encoding` header value, as received.
    pub encoding: Option<String>,
    /// `Encryption` header value, with double-quotes and Base64 padding removed.
    pub encryption: Option<String>,
    /// `Encryption-Key` header value, as received.
    pub encryption_key: Option<String>,
    /// `Crypto-Key` header value, with double-quotes and Base64 padding removed.
    pub crypto_key: Option<String>,
}
47
48impl From<NotificationHeaders> for HashMap<String, String> {
49    fn from(headers: NotificationHeaders) -> Self {
50        let mut map = HashMap::new();
51
52        map.insert_opt("encoding", headers.encoding);
53        map.insert_opt("encryption", headers.encryption);
54        map.insert_opt("encryption_key", headers.encryption_key);
55        map.insert_opt("crypto_key", headers.crypto_key);
56
57        map
58    }
59}
60
61impl NotificationHeaders {
62    /// Extract the notification headers from a request.
63    /// This can not be implemented as a `FromRequest` impl because we need to
64    /// know if the payload has data, without actually advancing the payload
65    /// stream.
66    pub fn from_request(req: &HttpRequest, has_data: bool) -> ApiResult<Self> {
67        // Collect raw headers
68        let ttl = get_header(req, "ttl")
69            .and_then(|ttl| ttl.parse().ok())
70            // Enforce a maximum TTL, but don't error
71            // NOTE: In order to trap for negative TTLs, this should be a
72            // signed value, otherwise we will error out with NO_TTL.
73            .map(|ttl| min(ttl, MAX_NOTIFICATION_TTL.num_seconds()))
74            .ok_or(ApiErrorKind::NoTTL)?;
75        let topic = get_owned_header(req, "topic");
76
77        let headers = if has_data {
78            NotificationHeaders {
79                ttl,
80                topic,
81                encoding: get_owned_header(req, "content-encoding"),
82                encryption: get_owned_header(req, "encryption").map(Self::strip_header),
83                encryption_key: get_owned_header(req, "encryption-key"),
84                crypto_key: get_owned_header(req, "crypto-key").map(Self::strip_header),
85            }
86        } else {
87            // Messages without a body shouldn't pass along unnecessary headers
88            NotificationHeaders {
89                ttl,
90                topic,
91                encoding: None,
92                encryption: None,
93                encryption_key: None,
94                crypto_key: None,
95            }
96        };
97
98        // Validate encryption if there is a message body
99        if has_data {
100            headers.validate_encryption()?;
101        }
102
103        // Validate the other headers
104        match headers.validate() {
105            Ok(_) => Ok(headers),
106            Err(e) => Err(ApiError::from(e)),
107        }
108    }
109
110    /// Remove Base64 padding and double-quotes
111    fn strip_header(header: String) -> String {
112        let header = header.replace('"', "");
113        STRIP_PADDING.replace_all(&header, "$head$tail").to_string()
114    }
115
116    /// Validate the encryption headers according to the various WebPush
117    /// standard versions
118    fn validate_encryption(&self) -> ApiResult<()> {
119        let encoding = self.encoding.as_deref().ok_or_else(|| {
120            ApiErrorKind::InvalidEncryption("Missing Content-Encoding header".to_string())
121        })?;
122
123        match encoding {
124            "aesgcm" => self.validate_encryption_04_rules()?,
125            "aes128gcm" => self.validate_encryption_06_rules()?,
126            _ => {
127                return Err(ApiErrorKind::InvalidEncryption(
128                    "Unknown Content-Encoding header".to_string(),
129                )
130                .into());
131            }
132        }
133
134        Ok(())
135    }
136
137    /// Validates encryption headers according to
138    /// draft-ietf-webpush-encryption-04
139    fn validate_encryption_04_rules(&self) -> ApiResult<()> {
140        Self::assert_base64_item_exists("Encryption", self.encryption.as_deref(), "salt")?;
141
142        if self.encryption_key.is_some() {
143            return Err(ApiErrorKind::InvalidEncryption(
144                "Encryption-Key header is not valid for webpush draft 02 or later".to_string(),
145            )
146            .into());
147        }
148
149        if self.crypto_key.is_some() {
150            Self::assert_base64_item_exists("Crypto-Key", self.crypto_key.as_deref(), "dh")?;
151        }
152
153        Ok(())
154    }
155
156    /// Validates encryption headers according to
157    /// draft-ietf-httpbis-encryption-encoding-06
158    /// (the encryption values are in the payload, so there shouldn't be any in
159    /// the headers)
160    fn validate_encryption_06_rules(&self) -> ApiResult<()> {
161        Self::assert_not_exists("aes128gcm Encryption", self.encryption.as_deref(), "salt")?;
162        Self::assert_not_exists("aes128gcm Crypto-Key", self.crypto_key.as_deref(), "dh")?;
163
164        Ok(())
165    }
166
167    /// Assert that the given item exists in the header and is valid base64.
168    fn assert_base64_item_exists(
169        header_name: &str,
170        header: Option<&str>,
171        key: &str,
172    ) -> ApiResult<()> {
173        let header = header.ok_or_else(|| {
174            ApiErrorKind::InvalidEncryption(format!("Missing {header_name} header"))
175        })?;
176        let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
177            ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
178        })?;
179        let value = header_data.get_by_key(key).ok_or_else(|| {
180            ApiErrorKind::InvalidEncryption(format!("Missing {key} value in {header_name} header"))
181        })?;
182
183        if !VALID_BASE64_URL.is_match(value) {
184            return Err(ApiErrorKind::InvalidEncryption(format!(
185                "Invalid {key} value in {header_name} header",
186            ))
187            .into());
188        }
189
190        Ok(())
191    }
192
193    /// Assert that the given key does not exist in the header.
194    fn assert_not_exists(header_name: &str, header: Option<&str>, key: &str) -> ApiResult<()> {
195        let header = match header {
196            Some(header) => header,
197            None => return Ok(()),
198        };
199
200        let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
201            ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
202        })?;
203
204        if header_data.get_by_key(key).is_some() {
205            return Err(ApiErrorKind::InvalidEncryption(format!(
206                "Do not include '{key}' header in {header_name} header"
207            ))
208            .into());
209        }
210
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::NotificationHeaders;
    use crate::error::{ApiErrorKind, ApiResult};
    use actix_web::test::TestRequest;
    use autopush_common::MAX_NOTIFICATION_TTL;
    use chrono::TimeDelta;

    /// Assert that a result is a validation error and check its serialization
    /// against the JSON value.
    fn assert_validation_error(
        result: ApiResult<NotificationHeaders>,
        expected_json: serde_json::Value,
    ) {
        assert!(result.is_err());
        let errors = match result.unwrap_err().kind {
            ApiErrorKind::Validation(errors) => errors,
            _ => panic!("Expected a validation error"),
        };

        assert_eq!(serde_json::to_value(errors).unwrap(), expected_json);
    }

    /// Assert that a result is a specific encryption error
    fn assert_encryption_error(result: ApiResult<NotificationHeaders>, expected_error: &str) {
        assert!(result.is_err());
        let error = match result.unwrap_err().kind {
            ApiErrorKind::InvalidEncryption(error) => error,
            _ => panic!("Expected an encryption error"),
        };

        assert_eq!(error, expected_error);
    }

    /// Assert that a result is a NoTTL error
    fn assert_no_ttl_error(result: ApiResult<NotificationHeaders>) {
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err().kind, ApiErrorKind::NoTTL));
    }

    /// A valid TTL results in no errors or adjustment
    #[test]
    fn valid_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, 10);
    }

    /// A TTL of zero is allowed (the minimum is inclusive)
    #[test]
    fn zero_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "0"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, 0);
    }

    /// A missing TTL header is rejected
    #[test]
    fn missing_ttl() {
        let req = TestRequest::post().to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert_no_ttl_error(result);
    }

    /// A TTL header that is not an integer is treated as missing
    #[test]
    fn non_numeric_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "ten"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert_no_ttl_error(result);
    }

    /// Negative TTL values are not allowed
    #[test]
    fn negative_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "-1"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);
        assert_validation_error(
            result,
            serde_json::json!({
                "ttl": [{
                    "code": "114",
                    "message": "TTL must be greater than 0",
                    "params": {
                        "min": 0,
                        "value": -1
                    }
                }]
            }),
        );
    }

    /// TTL values above the max are silently reduced to the max
    #[test]
    fn maximum_ttl() {
        let req = TestRequest::post()
            .insert_header((
                "TTL",
                (MAX_NOTIFICATION_TTL + TimeDelta::seconds(1))
                    .num_seconds()
                    .to_string(),
            ))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, MAX_NOTIFICATION_TTL.num_seconds());
    }

    /// A valid topic results in no errors
    #[test]
    fn valid_topic() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("TOPIC", "a-test-topic-which-is-just-right"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().topic,
            Some("a-test-topic-which-is-just-right".to_string())
        );
    }

    /// Topic names which are too long return an error
    #[test]
    fn too_long_topic() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("TOPIC", "test-topic-which-is-too-long-1234"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert_validation_error(
            result,
            serde_json::json!({
                "topic": [{
                    "code": "113",
                    "message": "Topic must be no greater than 32 characters",
                    "params": {
                        "max": 32,
                        "value": "test-topic-which-is-too-long-1234"
                    }
                }]
            }),
        );
    }

    /// Without a payload, encryption headers are dropped and not validated
    #[test]
    fn no_payload_drops_encryption_headers() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: None,
                encryption: None,
                encryption_key: None,
                crypto_key: None
            }
        );
    }

    /// If there is a payload, there must be a content encoding header
    #[test]
    fn payload_without_content_encoding() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Missing Content-Encoding header");
    }

    /// Content encodings other than aesgcm/aes128gcm are rejected
    #[test]
    fn unknown_content_encoding() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "gzip"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Unknown Content-Encoding header");
    }

    /// Valid 04 draft encryption passes validation
    #[test]
    fn valid_04_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aesgcm".to_string()),
                encryption: Some("salt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("dh=bar".to_string())
            }
        );
    }

    /// The Crypto-Key header is optional for 04 draft encryption
    #[test]
    fn valid_04_encryption_without_crypto_key() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().crypto_key, None);
    }

    /// 04 draft encryption requires an Encryption header
    #[test]
    fn invalid_04_missing_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Missing Encryption header");
    }

    /// 04 draft encryption requires a salt in the Encryption header
    #[test]
    fn invalid_04_missing_salt() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "notsalt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Missing salt value in Encryption header");
    }

    /// 04 draft encryption does not allow the Encryption-Key header
    #[test]
    fn invalid_04_encryption_key() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Encryption-Key", "foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(
            result,
            "Encryption-Key header is not valid for webpush draft 02 or later",
        );
    }

    /// A Crypto-Key header sent with 04 draft encryption must contain dh
    #[test]
    fn invalid_04_missing_dh() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "notdh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Missing dh value in Crypto-Key header");
    }

    /// Valid 06 draft encryption passes validation
    #[test]
    fn valid_06_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "notsalt=foo"))
            .insert_header(("Crypto-Key", "notdh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aes128gcm".to_string()),
                encryption: Some("notsalt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("notdh=bar".to_string())
            }
        );
    }

    /// 06 draft encryption must not carry a salt in the Encryption header
    #[test]
    fn invalid_06_salt_in_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "salt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(
            result,
            "Do not include 'salt' header in aes128gcm Encryption header",
        );
    }

    /// 06 draft encryption must not carry a dh in the Crypto-Key header
    #[test]
    fn invalid_06_dh_in_crypto_key() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(
            result,
            "Do not include 'dh' header in aes128gcm Crypto-Key header",
        );
    }

    /// The encryption and crypto-key headers are stripped of Base64 padding and
    /// double-quotes.
    #[test]
    fn strip_headers() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=\"foo\""))
            .insert_header(("Crypto-Key", "keyid=\"p256dh\";dh=\"deadbeef==\""))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aesgcm".to_string()),
                encryption: Some("salt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("keyid=p256dh;dh=deadbeef".to_string())
            }
        );
    }
}