1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::headers::crypto_key::CryptoKeyHeader;
3use crate::headers::util::{get_header, get_owned_header};
4use actix_web::HttpRequest;
5use autopush_common::util::InsertOpt;
6use lazy_static::lazy_static;
7use regex::Regex;
8use std::cmp::min;
9use std::collections::HashMap;
10use std::time::Duration;
11use validator::Validate;
12use validator_derive::Validate;
13
14lazy_static! {
15 static ref VALID_BASE64_URL: Regex = Regex::new(r"^[0-9A-Za-z\-_]+=*$").unwrap();
16 static ref STRIP_PADDING: Regex =
17 Regex::new(r"(?P<head>[0-9A-Za-z\-_]+)=+(?P<tail>[,;]|$)").unwrap();
18}
19
20#[derive(Clone, Debug, Eq, PartialEq, Validate)]
22pub struct NotificationHeaders {
23 #[validate(range(min = 0, message = "TTL must be greater than 0", code = "114"))]
25 pub ttl: i64,
26
27 #[validate(
28 length(
29 max = 32,
30 message = "Topic must be no greater than 32 characters",
31 code = "113"
32 ),
33 regex(
34 path = *VALID_BASE64_URL,
35 message = "Topic must be URL and Filename safe Base64 alphabet",
36 code = "113"
37 )
38 )]
39 pub topic: Option<String>,
40
41 pub encoding: Option<String>,
44 pub encryption: Option<String>,
45 pub encryption_key: Option<String>,
46 pub crypto_key: Option<String>,
47}
48
49impl From<NotificationHeaders> for HashMap<String, String> {
50 fn from(headers: NotificationHeaders) -> Self {
51 let mut map = HashMap::new();
52
53 map.insert_opt("encoding", headers.encoding);
54 map.insert_opt("encryption", headers.encryption);
55 map.insert_opt("encryption_key", headers.encryption_key);
56 map.insert_opt("crypto_key", headers.crypto_key);
57
58 map
59 }
60}
61
62impl NotificationHeaders {
63 pub fn from_request(
68 req: &HttpRequest,
69 has_data: bool,
70 max_notification_ttl: Duration,
71 ) -> ApiResult<Self> {
72 let ttl: i64 = get_header(req, "ttl")
74 .and_then(|ttl| ttl.parse().ok())
75 .map(|ttl: i64| min(ttl, max_notification_ttl.as_secs() as i64))
79 .ok_or(ApiErrorKind::NoTTL)?;
80
81 let topic = get_owned_header(req, "topic");
82
83 let headers = if has_data {
84 NotificationHeaders {
85 ttl,
86 topic,
87 encoding: get_owned_header(req, "content-encoding"),
88 encryption: get_owned_header(req, "encryption").map(Self::strip_header),
89 encryption_key: get_owned_header(req, "encryption-key"),
90 crypto_key: get_owned_header(req, "crypto-key").map(Self::strip_header),
91 }
92 } else {
93 NotificationHeaders {
95 ttl,
96 topic,
97 encoding: None,
98 encryption: None,
99 encryption_key: None,
100 crypto_key: None,
101 }
102 };
103
104 if has_data {
106 headers.validate_encryption()?;
107 }
108
109 match headers.validate() {
111 Ok(_) => Ok(headers),
112 Err(e) => Err(ApiError::from(e)),
113 }
114 }
115
116 fn strip_header(header: String) -> String {
118 let header = header.replace('"', "");
119 STRIP_PADDING.replace_all(&header, "$head$tail").to_string()
120 }
121
122 fn validate_encryption(&self) -> ApiResult<()> {
125 let encoding = self.encoding.as_deref().ok_or_else(|| {
126 ApiErrorKind::InvalidEncryption("Missing Content-Encoding header".to_string())
127 })?;
128
129 match encoding {
130 "aesgcm" => self.validate_encryption_04_rules()?,
131 "aes128gcm" => self.validate_encryption_06_rules()?,
132 _ => {
133 return Err(ApiErrorKind::InvalidEncryption(
134 "Unknown Content-Encoding header".to_string(),
135 )
136 .into());
137 }
138 }
139
140 Ok(())
141 }
142
143 fn validate_encryption_04_rules(&self) -> ApiResult<()> {
146 Self::assert_base64_item_exists("Encryption", self.encryption.as_deref(), "salt")?;
147
148 if self.encryption_key.is_some() {
149 return Err(ApiErrorKind::InvalidEncryption(
150 "Encryption-Key header is not valid for webpush draft 02 or later".to_string(),
151 )
152 .into());
153 }
154
155 if self.crypto_key.is_some() {
156 Self::assert_base64_item_exists("Crypto-Key", self.crypto_key.as_deref(), "dh")?;
157 }
158
159 Ok(())
160 }
161
162 fn validate_encryption_06_rules(&self) -> ApiResult<()> {
167 Self::assert_not_exists("aes128gcm Encryption", self.encryption.as_deref(), "salt")?;
168 Self::assert_not_exists("aes128gcm Crypto-Key", self.crypto_key.as_deref(), "dh")?;
169
170 Ok(())
171 }
172
173 fn assert_base64_item_exists(
175 header_name: &str,
176 header: Option<&str>,
177 key: &str,
178 ) -> ApiResult<()> {
179 let header = header.ok_or_else(|| {
180 ApiErrorKind::InvalidEncryption(format!("Missing {header_name} header"))
181 })?;
182 let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
183 ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
184 })?;
185 let value = header_data.get_by_key(key).ok_or_else(|| {
186 ApiErrorKind::InvalidEncryption(format!("Missing {key} value in {header_name} header"))
187 })?;
188
189 if !VALID_BASE64_URL.is_match(value) {
190 return Err(ApiErrorKind::InvalidEncryption(format!(
191 "Invalid {key} value in {header_name} header",
192 ))
193 .into());
194 }
195
196 Ok(())
197 }
198
199 fn assert_not_exists(header_name: &str, header: Option<&str>, key: &str) -> ApiResult<()> {
201 let header = match header {
202 Some(header) => header,
203 None => return Ok(()),
204 };
205
206 let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
207 ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
208 })?;
209
210 if header_data.get_by_key(key).is_some() {
211 return Err(ApiErrorKind::InvalidEncryption(format!(
212 "Do not include '{key}' header in {header_name} header"
213 ))
214 .into());
215 }
216
217 Ok(())
218 }
219}
220
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::NotificationHeaders;
    use crate::error::{ApiErrorKind, ApiResult};
    use actix_web::test::TestRequest;
    use autopush_common::MAX_NOTIFICATION_TTL_SECS;

    /// Start a POST request that already carries a `TTL: 10` header.
    fn request_with_ttl() -> TestRequest {
        TestRequest::post().insert_header(("TTL", "10"))
    }

    /// Run the extractor on `req` using the default maximum notification TTL.
    fn extract(req: TestRequest, has_data: bool) -> ApiResult<NotificationHeaders> {
        NotificationHeaders::from_request(
            &req.to_http_request(),
            has_data,
            Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
        )
    }

    /// Assert that `result` is a validation error whose JSON form equals
    /// `expected_json`.
    fn assert_validation_error(
        result: ApiResult<NotificationHeaders>,
        expected_json: serde_json::Value,
    ) {
        assert!(result.is_err());
        match result.unwrap_err().kind {
            ApiErrorKind::Validation(errors) => {
                assert_eq!(serde_json::to_value(errors).unwrap(), expected_json);
            }
            _ => panic!("Expected a validation error"),
        }
    }

    /// Assert that `result` is an encryption error with exactly the message
    /// `expected_error`.
    fn assert_encryption_error(result: ApiResult<NotificationHeaders>, expected_error: &str) {
        assert!(result.is_err());
        match result.unwrap_err().kind {
            ApiErrorKind::InvalidEncryption(error) => assert_eq!(error, expected_error),
            _ => panic!("Expected an encryption error"),
        }
    }

    /// A plain, in-range TTL is accepted unchanged.
    #[test]
    fn valid_ttl() {
        let result = extract(request_with_ttl(), false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, 10);
    }

    /// A negative TTL is parsed and then rejected by validation.
    #[test]
    fn negative_ttl() {
        let req = TestRequest::post().insert_header(("TTL", "-1"));
        let expected = serde_json::json!({
            "ttl": [{
                "code": "114",
                "message": "TTL must be greater than 0",
                "params": {
                    "min": 0,
                    "value": -1
                }
            }]
        });

        assert_validation_error(extract(req, false), expected);
    }

    /// A TTL above the maximum is lowered to the maximum, not rejected.
    #[test]
    fn maximum_ttl() {
        let too_large = (MAX_NOTIFICATION_TTL_SECS + 1).to_string();
        let req = TestRequest::post().insert_header(("TTL", too_large));
        let result = extract(req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, MAX_NOTIFICATION_TTL_SECS as i64);
    }

    /// A topic of exactly 32 URL-safe characters is accepted.
    #[test]
    fn valid_topic() {
        let topic = "a-test-topic-which-is-just-right";
        let result = extract(request_with_ttl().insert_header(("TOPIC", topic)), false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().topic, Some(topic.to_string()));
    }

    /// A topic of 33 characters is rejected by the length rule.
    #[test]
    fn too_long_topic() {
        let req = request_with_ttl().insert_header(("TOPIC", "test-topic-which-is-too-long-1234"));
        let expected = serde_json::json!({
            "topic": [{
                "code": "113",
                "message": "Topic must be no greater than 32 characters",
                "params": {
                    "max": 32,
                    "value": "test-topic-which-is-too-long-1234"
                }
            }]
        });

        assert_validation_error(extract(req, false), expected);
    }

    /// A payload without a Content-Encoding header is an encryption error.
    #[test]
    fn payload_without_content_encoding() {
        let result = extract(request_with_ttl(), true);

        assert_encryption_error(result, "Missing Content-Encoding header");
    }

    /// `aesgcm` with a salt and a dh value passes and is kept verbatim.
    #[test]
    fn valid_04_encryption() {
        let req = request_with_ttl()
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "dh=bar"));
        let result = extract(req, true);

        let expected = NotificationHeaders {
            ttl: 10,
            topic: None,
            encoding: Some("aesgcm".to_string()),
            encryption: Some("salt=foo".to_string()),
            encryption_key: None,
            crypto_key: Some("dh=bar".to_string()),
        };
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    /// `aes128gcm` passes when the headers carry neither `salt` nor `dh`.
    #[test]
    fn valid_06_encryption() {
        let req = request_with_ttl()
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "notsalt=foo"))
            .insert_header(("Crypto-Key", "notdh=bar"));
        let result = extract(req, true);

        let expected = NotificationHeaders {
            ttl: 10,
            topic: None,
            encoding: Some("aes128gcm".to_string()),
            encryption: Some("notsalt=foo".to_string()),
            encryption_key: None,
            crypto_key: Some("notdh=bar".to_string()),
        };
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }

    /// Quotes and base64 padding are removed from the encryption headers.
    #[test]
    fn strip_headers() {
        let req = request_with_ttl()
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=\"foo\""))
            .insert_header(("Crypto-Key", "keyid=\"p256dh\";dh=\"deadbeef==\""));
        let result = extract(req, true);

        let expected = NotificationHeaders {
            ttl: 10,
            topic: None,
            encoding: Some("aesgcm".to_string()),
            encryption: Some("salt=foo".to_string()),
            encryption_key: None,
            crypto_key: Some("keyid=p256dh;dh=deadbeef".to_string()),
        };
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }
}
465 }