1use crate::error::{ApiError, ApiErrorKind, ApiResult};
2use crate::headers::crypto_key::CryptoKeyHeader;
3use crate::headers::util::{get_header, get_owned_header};
4use actix_web::HttpRequest;
5use autopush_common::{util::InsertOpt, MAX_NOTIFICATION_TTL};
6use lazy_static::lazy_static;
7use regex::Regex;
8use std::cmp::min;
9use std::collections::HashMap;
10use validator::Validate;
11use validator_derive::Validate;
12
lazy_static! {
    /// Matches a non-empty string drawn entirely from the URL- and filename-safe
    /// Base64 alphabet (`0-9`, `A-Z`, `a-z`, `-`, `_`), optionally followed by
    /// trailing `=` padding.
    static ref VALID_BASE64_URL: Regex = Regex::new(r"^[0-9A-Za-z\-_]+=*$").unwrap();
    /// Matches a Base64url token followed by one or more `=` padding characters
    /// that end at a `,`, a `;`, or the end of the string. `head` captures the
    /// token and `tail` the terminator, so replacing with `$head$tail` drops the
    /// padding while keeping the separator.
    static ref STRIP_PADDING: Regex =
        Regex::new(r"(?P<head>[0-9A-Za-z\-_]+)=+(?P<tail>[,;]|$)").unwrap();
}
18
/// Headers of a push-notification request, as extracted and validated by
/// [`NotificationHeaders::from_request`].
#[derive(Clone, Debug, Eq, PartialEq, Validate)]
pub struct NotificationHeaders {
    /// Time-to-live in seconds, taken from the `TTL` header. `from_request` clamps
    /// it to `MAX_NOTIFICATION_TTL`; a negative value fails validation (code 114).
    // NOTE(review): `min = 0` accepts a TTL of 0, so "greater than 0" in the
    // message is inexact; the text is kept because callers/tests match on it.
    #[validate(range(min = 0, message = "TTL must be greater than 0", code = "114"))]
    pub ttl: i64,

    /// Optional `Topic` header. Must be at most 32 characters and use only the
    /// URL- and filename-safe Base64 alphabet (code 113).
    #[validate(
        length(
            max = 32,
            message = "Topic must be no greater than 32 characters",
            code = "113"
        ),
        regex(
            path = *VALID_BASE64_URL,
            message = "Topic must be URL and Filename safe Base64 alphabet",
            code = "113"
        )
    )]
    pub topic: Option<String>,

    /// `Content-Encoding` header; only populated when the request has a payload.
    pub encoding: Option<String>,
    /// `Encryption` header with quotes and Base64 padding stripped; only
    /// populated when the request has a payload.
    pub encryption: Option<String>,
    /// `Encryption-Key` header as received; only populated when the request has
    /// a payload.
    pub encryption_key: Option<String>,
    /// `Crypto-Key` header with quotes and Base64 padding stripped; only
    /// populated when the request has a payload.
    pub crypto_key: Option<String>,
}
47
48impl From<NotificationHeaders> for HashMap<String, String> {
49 fn from(headers: NotificationHeaders) -> Self {
50 let mut map = HashMap::new();
51
52 map.insert_opt("encoding", headers.encoding);
53 map.insert_opt("encryption", headers.encryption);
54 map.insert_opt("encryption_key", headers.encryption_key);
55 map.insert_opt("crypto_key", headers.crypto_key);
56
57 map
58 }
59}
60
61impl NotificationHeaders {
62 pub fn from_request(req: &HttpRequest, has_data: bool) -> ApiResult<Self> {
67 let ttl = get_header(req, "ttl")
69 .and_then(|ttl| ttl.parse().ok())
70 .map(|ttl| min(ttl, MAX_NOTIFICATION_TTL.num_seconds()))
74 .ok_or(ApiErrorKind::NoTTL)?;
75 let topic = get_owned_header(req, "topic");
76
77 let headers = if has_data {
78 NotificationHeaders {
79 ttl,
80 topic,
81 encoding: get_owned_header(req, "content-encoding"),
82 encryption: get_owned_header(req, "encryption").map(Self::strip_header),
83 encryption_key: get_owned_header(req, "encryption-key"),
84 crypto_key: get_owned_header(req, "crypto-key").map(Self::strip_header),
85 }
86 } else {
87 NotificationHeaders {
89 ttl,
90 topic,
91 encoding: None,
92 encryption: None,
93 encryption_key: None,
94 crypto_key: None,
95 }
96 };
97
98 if has_data {
100 headers.validate_encryption()?;
101 }
102
103 match headers.validate() {
105 Ok(_) => Ok(headers),
106 Err(e) => Err(ApiError::from(e)),
107 }
108 }
109
110 fn strip_header(header: String) -> String {
112 let header = header.replace('"', "");
113 STRIP_PADDING.replace_all(&header, "$head$tail").to_string()
114 }
115
116 fn validate_encryption(&self) -> ApiResult<()> {
119 let encoding = self.encoding.as_deref().ok_or_else(|| {
120 ApiErrorKind::InvalidEncryption("Missing Content-Encoding header".to_string())
121 })?;
122
123 match encoding {
124 "aesgcm" => self.validate_encryption_04_rules()?,
125 "aes128gcm" => self.validate_encryption_06_rules()?,
126 _ => {
127 return Err(ApiErrorKind::InvalidEncryption(
128 "Unknown Content-Encoding header".to_string(),
129 )
130 .into());
131 }
132 }
133
134 Ok(())
135 }
136
137 fn validate_encryption_04_rules(&self) -> ApiResult<()> {
140 Self::assert_base64_item_exists("Encryption", self.encryption.as_deref(), "salt")?;
141
142 if self.encryption_key.is_some() {
143 return Err(ApiErrorKind::InvalidEncryption(
144 "Encryption-Key header is not valid for webpush draft 02 or later".to_string(),
145 )
146 .into());
147 }
148
149 if self.crypto_key.is_some() {
150 Self::assert_base64_item_exists("Crypto-Key", self.crypto_key.as_deref(), "dh")?;
151 }
152
153 Ok(())
154 }
155
156 fn validate_encryption_06_rules(&self) -> ApiResult<()> {
161 Self::assert_not_exists("aes128gcm Encryption", self.encryption.as_deref(), "salt")?;
162 Self::assert_not_exists("aes128gcm Crypto-Key", self.crypto_key.as_deref(), "dh")?;
163
164 Ok(())
165 }
166
167 fn assert_base64_item_exists(
169 header_name: &str,
170 header: Option<&str>,
171 key: &str,
172 ) -> ApiResult<()> {
173 let header = header.ok_or_else(|| {
174 ApiErrorKind::InvalidEncryption(format!("Missing {header_name} header"))
175 })?;
176 let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
177 ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
178 })?;
179 let value = header_data.get_by_key(key).ok_or_else(|| {
180 ApiErrorKind::InvalidEncryption(format!("Missing {key} value in {header_name} header"))
181 })?;
182
183 if !VALID_BASE64_URL.is_match(value) {
184 return Err(ApiErrorKind::InvalidEncryption(format!(
185 "Invalid {key} value in {header_name} header",
186 ))
187 .into());
188 }
189
190 Ok(())
191 }
192
193 fn assert_not_exists(header_name: &str, header: Option<&str>, key: &str) -> ApiResult<()> {
195 let header = match header {
196 Some(header) => header,
197 None => return Ok(()),
198 };
199
200 let header_data = CryptoKeyHeader::parse(header).ok_or_else(|| {
201 ApiErrorKind::InvalidEncryption(format!("Invalid {header_name} header"))
202 })?;
203
204 if header_data.get_by_key(key).is_some() {
205 return Err(ApiErrorKind::InvalidEncryption(format!(
206 "Do not include '{key}' header in {header_name} header"
207 ))
208 .into());
209 }
210
211 Ok(())
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::NotificationHeaders;
    use crate::error::{ApiErrorKind, ApiResult};
    use actix_web::test::TestRequest;
    use autopush_common::MAX_NOTIFICATION_TTL;
    use chrono::TimeDelta;

    /// Asserts that `result` failed with a validation error whose JSON
    /// serialization equals `expected_json`.
    fn assert_validation_error(
        result: ApiResult<NotificationHeaders>,
        expected_json: serde_json::Value,
    ) {
        assert!(result.is_err());
        let errors = match result.unwrap_err().kind {
            ApiErrorKind::Validation(errors) => errors,
            _ => panic!("Expected a validation error"),
        };

        assert_eq!(serde_json::to_value(errors).unwrap(), expected_json);
    }

    /// Asserts that `result` failed with an `InvalidEncryption` error carrying
    /// exactly `expected_error`.
    fn assert_encryption_error(result: ApiResult<NotificationHeaders>, expected_error: &str) {
        assert!(result.is_err());
        let error = match result.unwrap_err().kind {
            ApiErrorKind::InvalidEncryption(error) => error,
            _ => panic!("Expected an encryption error"),
        };

        assert_eq!(error, expected_error);
    }

    /// A well-formed TTL is accepted unchanged.
    #[test]
    fn valid_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, 10);
    }

    /// A TTL of zero is on the accepted boundary.
    #[test]
    fn zero_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "0"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, 0);
    }

    /// A missing TTL header is rejected.
    #[test]
    fn missing_ttl() {
        let req = TestRequest::post().to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(matches!(result.unwrap_err().kind, ApiErrorKind::NoTTL));
    }

    /// A TTL that is not an integer is treated like a missing one.
    #[test]
    fn non_numeric_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "ten"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(matches!(result.unwrap_err().kind, ApiErrorKind::NoTTL));
    }

    /// A negative TTL fails validation.
    #[test]
    fn negative_ttl() {
        let req = TestRequest::post()
            .insert_header(("TTL", "-1"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);
        assert_validation_error(
            result,
            serde_json::json!({
                "ttl": [{
                    "code": "114",
                    "message": "TTL must be greater than 0",
                    "params": {
                        "min": 0,
                        "value": -1
                    }
                }]
            }),
        );
    }

    /// A TTL above the maximum is clamped to the maximum.
    #[test]
    fn maximum_ttl() {
        let req = TestRequest::post()
            .insert_header((
                "TTL",
                (MAX_NOTIFICATION_TTL + TimeDelta::seconds(1))
                    .num_seconds()
                    .to_string(),
            ))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().ttl, MAX_NOTIFICATION_TTL.num_seconds());
    }

    /// A 32-character URL-safe topic is accepted.
    #[test]
    fn valid_topic() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("TOPIC", "a-test-topic-which-is-just-right"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().topic,
            Some("a-test-topic-which-is-just-right".to_string())
        );
    }

    /// A topic longer than 32 characters fails validation.
    #[test]
    fn too_long_topic() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("TOPIC", "test-topic-which-is-too-long-1234"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert_validation_error(
            result,
            serde_json::json!({
                "topic": [{
                    "code": "113",
                    "message": "Topic must be no greater than 32 characters",
                    "params": {
                        "max": 32,
                        "value": "test-topic-which-is-too-long-1234"
                    }
                }]
            }),
        );
    }

    /// Without a payload, encryption headers are neither read nor checked.
    #[test]
    fn no_data_ignores_encryption_headers() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "gzip"))
            .insert_header(("Encryption", "salt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, false);

        assert!(result.is_ok());
        let headers = result.unwrap();
        assert_eq!(headers.encoding, None);
        assert_eq!(headers.encryption, None);
    }

    /// A payload requires a Content-Encoding header.
    #[test]
    fn payload_without_content_encoding() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Missing Content-Encoding header");
    }

    /// Only aesgcm and aes128gcm encodings are accepted.
    #[test]
    fn unknown_content_encoding() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "gzip"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Unknown Content-Encoding header");
    }

    /// Well-formed aesgcm headers are accepted.
    #[test]
    fn valid_04_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aesgcm".to_string()),
                encryption: Some("salt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("dh=bar".to_string())
            }
        );
    }

    /// aesgcm requires an Encryption header.
    #[test]
    fn missing_04_encryption_header() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(result, "Missing Encryption header");
    }

    /// aesgcm rejects the legacy Encryption-Key header.
    #[test]
    fn encryption_key_with_04_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=foo"))
            .insert_header(("Encryption-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(
            result,
            "Encryption-Key header is not valid for webpush draft 02 or later",
        );
    }

    /// aes128gcm accepts Encryption/Crypto-Key headers that lack salt/dh.
    #[test]
    fn valid_06_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "notsalt=foo"))
            .insert_header(("Crypto-Key", "notdh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aes128gcm".to_string()),
                encryption: Some("notsalt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("notdh=bar".to_string())
            }
        );
    }

    /// aes128gcm rejects a salt in the Encryption header.
    #[test]
    fn salt_in_06_encryption() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Encryption", "salt=foo"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(
            result,
            "Do not include 'salt' header in aes128gcm Encryption header",
        );
    }

    /// aes128gcm rejects a dh in the Crypto-Key header.
    #[test]
    fn dh_in_06_crypto_key() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aes128gcm"))
            .insert_header(("Crypto-Key", "dh=bar"))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert_encryption_error(
            result,
            "Do not include 'dh' header in aes128gcm Crypto-Key header",
        );
    }

    /// Quotes and Base64 padding are stripped from Encryption and Crypto-Key.
    #[test]
    fn strip_headers() {
        let req = TestRequest::post()
            .insert_header(("TTL", "10"))
            .insert_header(("Content-Encoding", "aesgcm"))
            .insert_header(("Encryption", "salt=\"foo\""))
            .insert_header(("Crypto-Key", "keyid=\"p256dh\";dh=\"deadbeef==\""))
            .to_http_request();
        let result = NotificationHeaders::from_request(&req, true);

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            NotificationHeaders {
                ttl: 10,
                topic: None,
                encoding: Some("aesgcm".to_string()),
                encryption: Some("salt=foo".to_string()),
                encryption_key: None,
                crypto_key: Some("keyid=p256dh;dh=deadbeef".to_string())
            }
        );
    }
}