// autoendpoint/extractors/notification_headers.rs

1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::headers::crypto_key::CryptoKeyHeader;
3use crate::headers::util::{get_header, get_owned_header};
4use actix_web::HttpRequest;
5use autopush_common::util::InsertOpt;
6use lazy_static::lazy_static;
7use regex::Regex;
8use std::cmp::min;
9use std::collections::HashMap;
10use std::time::Duration;
11use validator::Validate;
12use validator_derive::Validate;
13
14lazy_static! {
15    static ref VALID_BASE64_URL: Regex = Regex::new(r"^[0-9A-Za-z\-_]+=*$").unwrap();
16    static ref STRIP_PADDING: Regex =
17        Regex::new(r"(?P<head>[0-9A-Za-z\-_]+)=+(?P<tail>[,;]|$)").unwrap();
18}
19
/// Extractor and validator for notification headers
///
/// Constructed by `NotificationHeaders::from_request`. `ttl` and `topic` are
/// checked by the derived `Validate` impl; the encryption-related fields are
/// checked separately (by `validate_encryption`) because their rules depend
/// on the content encoding.
#[derive(Clone, Debug, Eq, PartialEq, Validate)]
pub struct NotificationHeaders {
    /// Time-to-live in seconds, read from the `TTL` header and capped at the
    /// configured maximum by `from_request`.
    // TTL is a signed value so that validation can catch negative inputs
    // NOTE(review): `min = 0` accepts a TTL of exactly 0, so the message
    // "greater than 0" overstates the rule. The message is asserted by the
    // tests and may be client-visible, so it is left as-is here.
    #[validate(range(min = 0, message = "TTL must be greater than 0", code = "114"))]
    pub ttl: i64,

    /// Optional `Topic` header: at most 32 characters, drawn from the URL-
    /// and filename-safe Base64 alphabet (see `VALID_BASE64_URL`).
    #[validate(
        length(
            max = 32,
            message = "Topic must be no greater than 32 characters",
            code = "113"
        ),
        regex(
            path = *VALID_BASE64_URL,
            message = "Topic must be URL and Filename safe Base64 alphabet",
            code = "113"
        )
    )]
    pub topic: Option<String>,

    // These fields are validated separately, because the validation is complex
    // and based upon the content encoding. `from_request` leaves all four as
    // `None` when the request has no body.
    /// `Content-Encoding` header (e.g. `aesgcm` or `aes128gcm`).
    pub encoding: Option<String>,
    /// `Encryption` header, with double-quotes and Base64 padding stripped.
    pub encryption: Option<String>,
    /// `Encryption-Key` header, as received.
    pub encryption_key: Option<String>,
    /// `Crypto-Key` header, with double-quotes and Base64 padding stripped.
    pub crypto_key: Option<String>,
}
48
49impl From<NotificationHeaders> for HashMap<String, String> {
50    fn from(headers: NotificationHeaders) -> Self {
51        let mut map = HashMap::new();
52
53        map.insert_opt("encoding", headers.encoding);
54        map.insert_opt("encryption", headers.encryption);
55        map.insert_opt("encryption_key", headers.encryption_key);
56        map.insert_opt("crypto_key", headers.crypto_key);
57
58        map
59    }
60}
61
62impl NotificationHeaders {
63    /// Extract the notification headers from a request.
64    /// This can not be implemented as a `FromRequest` impl because we need to
65    /// know if the payload has data, without actually advancing the payload
66    /// stream.
67    pub fn from_request(
68        req: &HttpRequest,
69        has_data: bool,
70        max_notification_ttl: Duration,
71    ) -> ApiResult<Self> {
72        // Collect raw headers
73        let ttl: i64 = get_header(req, "ttl")
74            .and_then(|ttl| ttl.parse().ok())
75            // Enforce a maximum TTL, but don't error
76            // NOTE: In order to trap for negative TTLs, this should be a
77            // signed value, otherwise we will error out with NO_TTL.
78            .map(|ttl: i64| min(ttl, max_notification_ttl.as_secs() as i64))
79            .ok_or(ApiErrorKind::NoTTL)?;
80
81        let topic = get_owned_header(req, "topic");
82
83        let headers = if has_data {
84            NotificationHeaders {
85                ttl,
86                topic,
87                encoding: get_owned_header(req, "content-encoding"),
88                encryption: get_owned_header(req, "encryption").map(Self::strip_header),
89                encryption_key: get_owned_header(req, "encryption-key"),
90                crypto_key: get_owned_header(req, "crypto-key").map(Self::strip_header),
91            }
92        } else {
93            // Messages without a body shouldn't pass along unnecessary headers
94            NotificationHeaders {
95                ttl,
96                topic,
97                encoding: None,
98                encryption: None,
99                encryption_key: None,
100                crypto_key: None,
101            }
102        };
103
104        // Validate encryption if there is a message body
105        if has_data {
106            headers.validate_encryption()?;
107        }
108
109        // Validate the other headers
110        match headers.validate() {
111            Ok(_) => Ok(headers),
112            Err(e) => Err(ApiError::from(e)),
113        }
114    }
115
116    /// Remove Base64 padding and double-quotes
117    fn strip_header(header: String) -> String {
118        let header = header.replace('"', "");
119        STRIP_PADDING.replace_all(&header, "$head$tail").to_string()
120    }
121
122    /// Validate the encryption headers according to the various WebPush
123    /// standard versions
124    fn validate_encryption(&self) -> ApiResult<()> {
125        let encoding = self.encoding.as_deref().ok_or_else(|| {
126            ApiErrorKind::InvalidEncryption("Missing Content-Encoding header".to_string())
127        })?;
128
129        match encoding {
130            "aesgcm" => self.validate_encryption_04_rules()?,
131            "aes128gcm" => self.validate_encryption_06_rules()?,
132            _ => {
133                return Err(ApiErrorKind::InvalidEncryption(
134                    "Unknown Content-Encoding header".to_string(),
135                )
136                .into());
137            }
138        }
139
140        Ok(())
141    }
142
143    /// Validates encryption headers according to
144    /// draft-ietf-webpush-encryption-04
145    fn validate_encryption_04_rules(&self) -> ApiResult<()> {
146        Self::assert_base64_item_exists("Encryption", self.encryption.as_deref(), "salt")?;
147
148        if self.encryption_key.is_some() {
149            return Err(ApiErrorKind::InvalidEncryption(
150                "Encryption-Key header is not valid for webpush draft 02 or later".to_string(),
151            )
152            .into());
153        }
154
155        if self.crypto_key.is_some() {
156            Self::assert_base64_item_exists("Crypto-Key", self.crypto_key.as_deref(), "dh")?;
157        }
158
159        Ok(())
160    }
161
162    /// Validates encryption headers according to
163    /// draft-ietf-httpbis-encryption-encoding-06
164    /// (the encryption values are in the payload, so there shouldn't be any in
165    /// the headers)
166    fn validate_encryption_06_rules(&self) -> ApiResult<()> {
167        Self::assert_not_exists("aes128gcm Encryption", self.encryption.as_deref(), "salt")?;
168        Self::assert_not_exists("aes128gcm Crypto-Key", self.crypto_key.as_deref(), "dh")?;
169
170        Ok(())
171    }
172
173    /// Assert that the given item exists in the header and is valid base64.
174    fn assert_base64_item_exists(
175        header_name: &str,
176        header: Option<&str>,
177        key: &str,
178    ) -> ApiResult<()> {
179        let header = header.ok_or_else(|| {
180            ApiErrorKind::InvalidEncryption(format!("Missing {header_name} header"))
181        })?;
182        let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
183            ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
184        })?;
185        let value = header_data.get_by_key(key).ok_or_else(|| {
186            ApiErrorKind::InvalidEncryption(format!("Missing {key} value in {header_name} header"))
187        })?;
188
189        if !VALID_BASE64_URL.is_match(value) {
190            return Err(ApiErrorKind::InvalidEncryption(format!(
191                "Invalid {key} value in {header_name} header",
192            ))
193            .into());
194        }
195
196        Ok(())
197    }
198
199    /// Assert that the given key does not exist in the header.
200    fn assert_not_exists(header_name: &str, header: Option<&str>, key: &str) -> ApiResult<()> {
201        let header = match header {
202            Some(header) => header,
203            None => return Ok(()),
204        };
205
206        let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
207            ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
208        })?;
209
210        if header_data.get_by_key(key).is_some() {
211            return Err(ApiErrorKind::InvalidEncryption(format!(
212                "Do not include '{key}' header in {header_name} header"
213            ))
214            .into());
215        }
216
217        Ok(())
218    }
219}
220
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::NotificationHeaders;
    use crate::error::{ApiErrorKind, ApiResult};
    use actix_web::test::TestRequest;
    use autopush_common::MAX_NOTIFICATION_TTL_SECS;

    /// Assert that a result is a validation error and check its serialization
    /// against the JSON value.
    fn assert_validation_error(
        result: ApiResult<NotificationHeaders>,
        expected_json: serde_json::Value,
    ) {
        assert!(result.is_err());
        let errors = match result.unwrap_err().kind {
            ApiErrorKind::Validation(errors) => errors,
            _ => panic!("Expected a validation error"),
        };

        assert_eq!(serde_json::to_value(errors).unwrap(), expected_json);
    }

    /// Assert that a result is a specific encryption error
    fn assert_encryption_error(result: ApiResult<NotificationHeaders>, expected_error: &str) {
        assert!(result.is_err());
        let error = match result.unwrap_err().kind {
            ApiErrorKind::InvalidEncryption(error) => error,
            _ => panic!("Expected an encryption error"),
        };

        assert_eq!(error, expected_error);
    }

    /// A valid TTL results in no errors or adjustment
    #[test]
    fn valid_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, 10);
    }

    /// A missing TTL header is rejected with NoTTL
    #[test]
    fn missing_ttl() {
        let req = TestRequest::post().to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(matches!(result.unwrap_err().kind, ApiErrorKind::NoTTL));
    }

    /// A TTL that is not an integer is rejected with NoTTL
    #[test]
    fn non_numeric_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "ten"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(matches!(result.unwrap_err().kind, ApiErrorKind::NoTTL));
    }

    /// Negative TTL values are not allowed
    #[test]
    fn negative_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "-1"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );
        assert_validation_error(
            result,
            serde_json::json!({
                "ttl": [{
                    "code": "114",
                    "message": "TTL must be greater than 0",
                    "params": {
                        "min": 0,
                        "value": -1
                    }
                }]
            }),
        );
    }

    /// TTL values above the max are silently reduced to the max
    #[test]
    fn maximum_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", (MAX_NOTIFICATION_TTL_SECS + 1).to_string()))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, MAX_NOTIFICATION_TTL_SECS as i64);
    }

    /// A valid topic results in no errors
    #[test]
    fn valid_topic() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("TOPIC", "a-test-topic-which-is-just-right"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().topic,
            Some("a-test-topic-which-is-just-right".to_string())
        );
    }

    /// Topic names which are too long return an error
    #[test]
    fn too_long_topic() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("TOPIC", "test-topic-which-is-too-long-1234"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_validation_error(
            result,
            serde_json::json!({
                "topic": [{
                    "code": "113",
                    "message": "Topic must be no greater than 32 characters",
                    "params": {
                        "max": 32,
                        "value": "test-topic-which-is-too-long-1234"
                    }
                }]
            }),
        );
    }

    /// Without a payload, encryption headers are dropped rather than validated
    #[test]
    fn no_payload_drops_encryption_headers() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            false,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: None,
                encryption: None,
                encryption_key: None,
                crypto_key: None
            }
        );
    }

    /// If there is a payload, there must be a content encoding header
    #[test]
    fn payload_without_content_encoding() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(result, "Missing Content-Encoding header");
    }

    /// Unsupported content encodings are rejected
    #[test]
    fn unknown_content_encoding() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "gzip"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(result, "Unknown Content-Encoding header");
    }

    /// Valid 04 draft encryption passes validation
    #[test]
    fn valid_04_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aesgcm".to_string()),
                encryption: Some("salt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("dh=bar".to_string())
            }
        );
    }

    /// 04 draft encryption requires an Encryption header
    #[test]
    fn missing_04_encryption_header() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(result, "Missing Encryption header");
    }

    /// 04 draft encryption requires a salt in the Encryption header
    #[test]
    fn missing_04_salt() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "notsalt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(result, "Missing salt value in Encryption header");
    }

    /// 04 draft encryption does not allow an Encryption-Key header
    #[test]
    fn forbidden_04_encryption_key() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Encryption-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(
            result,
            "Encryption-Key header is not valid for webpush draft 02 or later",
        );
    }

    /// A Crypto-Key header sent with 04 draft encryption must include dh
    #[test]
    fn missing_04_dh() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "notdh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(result, "Missing dh value in Crypto-Key header");
    }

    /// Valid 06 draft encryption passes validation
    #[test]
    fn valid_06_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "notsalt=foo"))
            .insert_header(("Crypto-Key", "notdh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aes128gcm".to_string()),
                encryption: Some("notsalt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("notdh=bar".to_string())
            }
        );
    }

    /// 06 draft encryption must not carry a salt in the headers
    #[test]
    fn forbidden_06_salt() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "salt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(
            result,
            "Do not include 'salt' header in aes128gcm Encryption header",
        );
    }

    /// 06 draft encryption must not carry a dh value in the headers
    #[test]
    fn forbidden_06_dh() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert_encryption_error(
            result,
            "Do not include 'dh' header in aes128gcm Crypto-Key header",
        );
    }

    /// The encryption and crypto-key headers are stripped of Base64 padding and
    /// double-quotes.
    #[test]
    fn strip_headers() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=\"foo\""))
            .insert_header(("Crypto-Key", "keyid=\"p256dh\";dh=\"deadbeef==\""))
            .to_http_request();
        let result = NotificationHeaders::from_request(
            &req,
            true,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        );

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aesgcm".to_string()),
                encryption: Some("salt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("keyid=p256dh;dh=deadbeef".to_string())
            }
        );
    }

    // TODO: cover malformed (non-Base64) salt/dh values once the behaviour of
    // CryptoKeyHeader::parse on such input is pinned down.
}