autoendpoint/
settings.rs

1//! Application settings
2
3use actix_http::header::HeaderMap;
4use config::{Config, ConfigError, Environment, File};
5use fernet::{Fernet, MultiFernet};
6use serde::Deserialize;
7use url::Url;
8
9use autopush_common::util;
10
11use crate::headers::vapid::VapidHeaderWithKey;
12use crate::routers::apns::settings::ApnsSettings;
13use crate::routers::fcm::settings::FcmSettings;
14#[cfg(feature = "stub")]
15use crate::routers::stub::settings::StubSettings;
16
/// Prefix for environment variables that override settings.
///
/// [`Settings::with_env_and_config_file`] joins this prefix and any nested field names with
/// `__`, so overrides look like `AUTOEND__PORT` or `AUTOEND__FCM__TIMEOUT`.
pub const ENV_PREFIX: &str = "autoend";
18
/// Configuration for the autoendpoint server.
///
/// Built by [`Settings::with_env_and_config_file`]. Because of `#[serde(default)]`, any field
/// not supplied by the config file or environment takes its value from [`Settings::default`].
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct Settings {
    /// URL scheme used to build the endpoint URL when `endpoint_url` is empty.
    pub scheme: String,
    /// Host used to build the endpoint URL when `endpoint_url` is empty.
    pub host: String,
    /// Port used to build the endpoint URL when `endpoint_url` is empty.
    pub port: u16,
    /// Explicit URL for this endpoint server. When empty, [`Settings::endpoint_url`] builds
    /// `{scheme}://{host}:{port}` instead.
    pub endpoint_url: String,

    /// The DSN to connect to the storage engine (Used to select between storage systems)
    pub db_dsn: Option<String>,
    /// JSON set of specific database settings (See data storage engines)
    pub db_settings: String,

    /// Name of the router table.
    pub router_table_name: String,
    /// Name of the message table.
    pub message_table_name: String,

    /// A stringified JSON list of VAPID public keys which should be tracked internally.
    /// This should ONLY include Mozilla generated and consumed messages (e.g. "SendToTab", etc.)
    /// These keys should be specified in stripped, b64encoded, X962 format (e.g. a single line of
    /// base64 encoded data without padding).
    /// You can use `scripts/convert_pem_to_x962.py` to easily convert EC Public keys stored in
    /// PEM format into appropriate x962 format.
    pub tracking_keys: String,

    /// Maximum size, in bytes, of incoming message data.
    pub max_data_bytes: usize,
    /// Stringified list of Fernet keys, e.g. `[key1,key2]`, parsed by [`Settings::make_fernet`].
    pub crypto_keys: String,
    /// Stringified list of auth hash keys, parsed by [`Settings::auth_keys`].
    pub auth_keys: String,
    /// NOTE(review): presumably selects human-readable rather than structured logs — confirm.
    pub human_logs: bool,

    /// Connection timeout, in milliseconds.
    pub connection_timeout_millis: u64,
    /// Request timeout, in milliseconds.
    pub request_timeout_millis: u64,

    /// StatsD host for metrics. NOTE(review): `None` presumably disables metrics — confirm.
    pub statsd_host: Option<String>,
    /// StatsD port.
    pub statsd_port: u16,
    /// Label for emitted metrics (presumably used as the metric prefix — confirm).
    pub statsd_label: String,

    /// Settings for the FCM router.
    pub fcm: FcmSettings,
    /// Settings for the APNS router.
    pub apns: ApnsSettings,
    #[cfg(feature = "stub")]
    /// "Stub" is a predictable Mock bridge that allows us to "send" data and return an expected
    /// result.
    pub stub: StubSettings,
    #[cfg(feature = "reliable_report")]
    /// The DSN for the reliability data store. This is normally a Redis compatible
    /// storage system. See [Connection Parameters](https://docs.rs/redis/latest/redis/#connection-parameters)
    /// for details.
    pub reliability_dsn: Option<String>,
    #[cfg(feature = "reliable_report")]
    /// Max number of retries for Redis transactions
    pub reliability_retry_count: usize,
}
70
71impl Default for Settings {
72    fn default() -> Settings {
73        Settings {
74            scheme: "http".to_string(),
75            host: "127.0.0.1".to_string(),
76            endpoint_url: "".to_string(),
77            port: 8000,
78            db_dsn: None,
79            db_settings: "".to_owned(),
80            router_table_name: "router".to_string(),
81            message_table_name: "message".to_string(),
82            // max data is a bit hard to figure out, due to encryption. Using something
83            // like pywebpush, if you encode a block of 4096 bytes, you'll get a
84            // 4216 byte data block. Since we're going to be receiving this, we have to
85            // presume base64 encoding, so we can bump things up to 5630 bytes max.
86            max_data_bytes: 5630,
87            crypto_keys: format!("[{}]", Fernet::generate_key()),
88            auth_keys: r#"["AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB="]"#.to_string(),
89            tracking_keys: r#"[]"#.to_string(),
90            human_logs: false,
91            connection_timeout_millis: 1000,
92            request_timeout_millis: 3000,
93            statsd_host: None,
94            statsd_port: 8125,
95            statsd_label: "autoendpoint".to_string(),
96            fcm: FcmSettings::default(),
97            apns: ApnsSettings::default(),
98            #[cfg(feature = "stub")]
99            stub: StubSettings::default(),
100            #[cfg(feature = "reliable_report")]
101            reliability_dsn: None,
102            #[cfg(feature = "reliable_report")]
103            reliability_retry_count: autopush_common::redis_util::MAX_TRANSACTION_LOOP,
104        }
105    }
106}
107
108impl Settings {
109    /// Load the settings from the config file if supplied, then the environment.
110    pub fn with_env_and_config_file(filename: &Option<String>) -> Result<Self, ConfigError> {
111        let mut config = Config::builder();
112
113        // Merge the config file if supplied
114        if let Some(config_filename) = filename {
115            config = config.add_source(File::with_name(config_filename));
116        }
117
118        // Merge the environment overrides
119        // Note: Specify the separator here so that the shell can properly pass args
120        // down to the sub structures.
121        config = config.add_source(Environment::with_prefix(ENV_PREFIX).separator("__"));
122
123        let built: Self = config.build()?.try_deserialize::<Self>().map_err(|error| {
124            match error {
125                // Configuration errors are not very sysop friendly, Try to make them
126                // a bit more 3AM useful.
127                ConfigError::Message(error_msg) => {
128                    println!("Bad configuration: {:?}", &error_msg);
129                    println!("Please set in config file or use environment variable.");
130                    println!(
131                        "For example to set `database_url` use env var `{}_DATABASE_URL`\n",
132                        ENV_PREFIX.to_uppercase()
133                    );
134                    error!("Configuration error: Value undefined {:?}", &error_msg);
135                    ConfigError::NotFound(error_msg)
136                }
137                _ => {
138                    error!("Configuration error: Other: {:?}", &error);
139                    error
140                }
141            }
142        })?;
143
144        Ok(built)
145    }
146
147    /// Convert a string like `[item1,item2]` into a iterator over `item1` and `item2`.
148    /// Panics with a custom message if the string is not in the expected form.
149    fn read_list_from_str<'list>(
150        list_str: &'list str,
151        panic_msg: &'static str,
152    ) -> impl Iterator<Item = &'list str> {
153        if !(list_str.starts_with('[') && list_str.ends_with(']')) {
154            panic!("{}", panic_msg);
155        }
156
157        let items = &list_str[1..list_str.len() - 1];
158        items.split(',')
159    }
160
161    /// Initialize the fernet encryption instance
162    pub fn make_fernet(&self) -> MultiFernet {
163        let keys = &self.crypto_keys.replace(['"', ' '], "");
164        let fernets = Self::read_list_from_str(keys, "Invalid AUTOEND_CRYPTO_KEYS")
165            .map(|key| {
166                debug!("🔐 Fernet keys: {:?}", &key);
167                Fernet::new(key).expect("Invalid AUTOEND_CRYPTO_KEYS")
168            })
169            .collect();
170        MultiFernet::new(fernets)
171    }
172
173    /// Get the list of auth hash keys
174    pub fn auth_keys(&self) -> Vec<String> {
175        let keys = &self.auth_keys.replace(['"', ' '], "");
176        Self::read_list_from_str(keys, "Invalid AUTOEND_AUTH_KEYS")
177            .map(|v| v.to_owned())
178            .collect()
179    }
180
181    /// Get the list of tracking public keys converted to raw, x962 format byte arrays.
182    /// (This avoids problems with formatting, padding, and other concerns. x962 precedes the
183    /// EC key pair with a `\04` byte. We'll keep that value in place for now, since the value we
184    /// are comparing against will also have the same prefix.)
185    pub fn tracking_keys(&self) -> Result<Vec<Vec<u8>>, ConfigError> {
186        let keys = &self.tracking_keys.replace(['"', ' '], "");
187        // I'm sure there's a more clever way to do this. I don't care. I want simple.
188        let mut result = Vec::new();
189        for v in Self::read_list_from_str(keys, "Invalid AUTOEND_TRACKING_KEYS") {
190            result.push(
191                util::b64_decode(v)
192                    .map_err(|e| ConfigError::Message(format!("Invalid tracking key: {e:?}")))?,
193            );
194        }
195        trace!("🔍 tracking_keys: {result:?}");
196        Ok(result)
197    }
198
199    /// Get the URL for this endpoint server
200    pub fn endpoint_url(&self) -> Url {
201        let endpoint = if self.endpoint_url.is_empty() {
202            format!("{}://{}:{}", self.scheme, self.host, self.port)
203        } else {
204            self.endpoint_url.clone()
205        };
206        Url::parse(&endpoint).expect("Invalid endpoint URL")
207    }
208}
209
210#[derive(Clone, Debug)]
211pub struct VapidTracker(pub Vec<Vec<u8>>);
212impl VapidTracker {
213    /// Very simple string check to see if the Public Key specified in the Vapid header
214    /// matches the set of trackable keys.
215    pub fn is_trackable(&self, vapid: &VapidHeaderWithKey) -> bool {
216        // ideally, [Settings.with_env_and_config_file()] does the work of pre-populating
217        // the Settings.tracking_vapid_pubs cache, but we can't rely on that.
218
219        let key = match util::b64_decode(&vapid.public_key) {
220            Ok(v) => v,
221            Err(e) => {
222                // This error is not fatal, and should not happen often. During preliminary
223                // runs, however, we do want to try and spot them.
224                warn!("🔍 VAPID: tracker failure {e}");
225                return false;
226            }
227        };
228        let result = self.0.contains(&key);
229
230        debug!("🔍 Checking {:?} {}", &vapid.public_key, {
231            if result {
232                "Match!"
233            } else {
234                "no match"
235            }
236        });
237        result
238    }
239
240    /// Extract the message Id from the headers (if present), otherwise just make one up.
241    pub fn get_id(&self, headers: &HeaderMap) -> String {
242        headers
243            .get("X-MessageId")
244            .and_then(|v|
245                // TODO: we should convert the public key string to a bitarray
246                // this would prevent any formatting errors from falsely rejecting
247                // the key. We're ok with comparing strings because we currently
248                // have access to the same public key value string that is being
249                // used, but that may not always be the case.
250                v.to_str().ok())
251            .map(|v| v.to_owned())
252            .unwrap_or_else(|| uuid::Uuid::new_v4().as_simple().to_string())
253    }
254}
255
#[cfg(test)]
mod tests {
    use actix_http::header::{HeaderMap, HeaderName, HeaderValue};

    use super::{Settings, VapidTracker};
    use crate::{
        error::ApiResult,
        headers::vapid::{VapidHeader, VapidHeaderWithKey},
    };

    #[test]
    fn test_auth_keys() -> ApiResult<()> {
        let success: Vec<String> = vec![
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=".to_owned(),
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC=".to_owned(),
        ];
        // Try with quoted strings
        let settings = Settings {
            auth_keys: r#"["AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC="]"#.to_owned(),
            ..Default::default()
        };
        let result = settings.auth_keys();
        assert_eq!(result, success);

        // try with unquoted, non-JSON compliant strings.
        let settings = Settings {
            auth_keys: r#"[AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=,AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC=]"#.to_owned(),
            ..Default::default()
        };
        let result = settings.auth_keys();
        assert_eq!(result, success);
        Ok(())
    }

    #[test]
    fn test_endpoint_url() -> ApiResult<()> {
        // An explicit endpoint URL is used verbatim.
        let example = "https://example.org/";
        let settings = Settings {
            endpoint_url: example.to_owned(),
            ..Default::default()
        };
        assert_eq!(settings.endpoint_url(), url::Url::parse(example).unwrap());

        // Without one, the URL is built from scheme, host and port.
        let settings = Settings::default();
        assert_eq!(
            settings.endpoint_url(),
            url::Url::parse(&format!(
                "{}://{}:{}",
                settings.scheme, settings.host, settings.port
            ))
            .unwrap()
        );
        Ok(())
    }

    #[test]
    fn test_default_settings() {
        // Environment overrides (prefix and nested fields joined with `__`) are applied on
        // top of the defaults, while untouched fields keep their default values.
        let port = format!("{}__PORT", super::ENV_PREFIX).to_uppercase();
        let timeout = format!("{}__FCM__TIMEOUT", super::ENV_PREFIX).to_uppercase();

        use std::env;
        let v1 = env::var(&port);
        let v2 = env::var(&timeout);
        env::set_var(&port, "9123");
        env::set_var(&timeout, "123");

        let settings = Settings::with_env_and_config_file(&None).unwrap();
        assert_eq!(&settings.port, &9123);
        assert_eq!(&settings.fcm.timeout, &123);
        assert_eq!(settings.host, "127.0.0.1".to_owned());
        // Restore the previous environment so other tests are unaffected.
        if let Ok(p) = v1 {
            trace!("Resetting {}", &port);
            env::set_var(&port, p);
        } else {
            env::remove_var(&port);
        }
        if let Ok(p) = v2 {
            trace!("Resetting {}", &timeout);
            env::set_var(&timeout, p);
        } else {
            env::remove_var(&timeout);
        }
    }

    #[test]
    fn test_tracking_keys() -> ApiResult<()> {
        // The configured key uses the standard base64 alphabet (`+`, `/`, no padding) while the
        // header carries the URL-safe alphabet (`-`, `_`) with trailing `=` padding. They must
        // still match because keys are compared after decoding.
        let settings = Settings {
            tracking_keys: r#"["BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon+H/boHTzMtMoNHsAGDlDB6X"]"#.to_owned(),
            ..Default::default()
        };

        let test_header = VapidHeaderWithKey {
            vapid: VapidHeader {
                scheme: "".to_owned(),
                token: "".to_owned(),
                version_data: crate::headers::vapid::VapidVersionData::Version1,
            },
            public_key: "BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon-H_boHTzMtMoNHsAGDlDB6X==".to_owned()
        };

        let key_set = settings.tracking_keys().unwrap();
        assert!(!key_set.is_empty());

        let reliability = VapidTracker(key_set);
        assert!(reliability.is_trackable(&test_header));

        Ok(())
    }

    #[test]
    fn test_multi_tracking_keys() -> ApiResult<()> {
        // Multiple unquoted, URL-safe, unpadded keys in one list: the header matches the first
        // key even though the second one does not correspond to it.
        let settings = Settings {
            tracking_keys: r#"[BLbZTvXsQr0rdvLQr73ETRcseSpoof5xV83NiPK9U-Qi00DjNJct1N6EZtTBMD0uh-nNjtLAxik1XP9CZXrKtTg,BHDgfiL1hz4oIBFaxxS9jkzyAVing-W9jjt_7WUeFjWS5Invalid5EjC8TQKddJNP3iow7UW6u8JE3t7u_y3Plc]"#.to_owned(),
            ..Default::default()
        };

        let test_header = VapidHeaderWithKey {
            vapid: VapidHeader {
                scheme: "".to_owned(),
                token: "".to_owned(),
                version_data: crate::headers::vapid::VapidVersionData::Version1,
            },
            public_key: "BLbZTvXsQr0rdvLQr73ETRcseSpoof5xV83NiPK9U-Qi00DjNJct1N6EZtTBMD0uh-nNjtLAxik1XP9CZXrKtTg".to_owned()
        };

        let key_set = settings.tracking_keys().unwrap();
        assert!(!key_set.is_empty());

        let reliability = VapidTracker(key_set);
        assert!(reliability.is_trackable(&test_header));

        Ok(())
    }

    #[test]
    fn test_reliability_id() -> ApiResult<()> {
        let mut headers = HeaderMap::new();
        let keys = Vec::new();
        let reliability = VapidTracker(keys);

        // No header: an ID is generated.
        let key = reliability.get_id(&headers);
        assert!(!key.is_empty());

        // Header present: its value is used as-is.
        headers.insert(
            HeaderName::from_lowercase(b"x-messageid").unwrap(),
            HeaderValue::from_static("123foobar456"),
        );

        let key = reliability.get_id(&headers);
        assert_eq!(key, "123foobar456".to_owned());

        Ok(())
    }
}