autoendpoint/
settings.rs

1//! Application settings
2use std::time::Duration;
3
4use actix_http::header::HeaderMap;
5use config::{Config, ConfigError, Environment, File};
6use fernet::{Fernet, MultiFernet};
7use serde::Deserialize;
8use serde_with::serde_as;
9use url::Url;
10
11use autopush_common::{util, MAX_NOTIFICATION_TTL_SECS};
12
13use crate::headers::vapid::VapidHeaderWithKey;
14use crate::routers::apns::settings::ApnsSettings;
15use crate::routers::fcm::settings::FcmSettings;
16#[cfg(feature = "stub")]
17use crate::routers::stub::settings::StubSettings;
18
/// Prefix for environment-variable overrides. Combined with the `__` separator used by
/// `Settings::with_env_and_config_file`, a setting such as `port` is read from `AUTOEND__PORT`.
pub const ENV_PREFIX: &str = "autoend";
20
21#[serde_as]
22#[derive(Clone, Debug, Deserialize)]
23#[serde(default)]
24pub struct Settings {
25    pub scheme: String,
26    pub host: String,
27    pub port: u16,
28    pub endpoint_url: String,
29
30    /// The DSN to connect to the storage engine (Used to select between storage systems)
31    pub db_dsn: Option<String>,
32    /// JSON set of specific database settings (See data storage engines)
33    pub db_settings: String,
34
35    pub router_table_name: String,
36    pub message_table_name: String,
37
38    /// A stringified JSON list of VAPID public keys which should be tracked internally.
39    /// This should ONLY include Mozilla generated and consumed messages (e.g. "SendToTab", etc.)
40    /// These keys should be specified in stripped, b64encoded, X962 format (e.g. a single line of
41    /// base64 encoded data without padding).
42    /// You can use `scripts/convert_pem_to_x962.py` to easily convert EC Public keys stored in
43    /// PEM format into appropriate x962 format.
44    pub tracking_keys: String,
45
46    pub max_data_bytes: usize,
47    pub crypto_keys: String,
48    pub auth_keys: String,
49    pub human_logs: bool,
50
51    pub connection_timeout_millis: u64,
52    pub request_timeout_millis: u64,
53
54    pub statsd_host: Option<String>,
55    pub statsd_port: u16,
56    pub statsd_label: String,
57
58    pub fcm: FcmSettings,
59    pub apns: ApnsSettings,
60    #[cfg(feature = "stub")]
61    /// "Stub" is a predictable Mock bridge that allows us to "send" data and return an expected
62    /// result.
63    pub stub: StubSettings,
64    #[cfg(feature = "reliable_report")]
65    /// The DNS for the reliability data store. This is normally a Redis compatible
66    /// storage system. See [Connection Parameters](https://docs.rs/redis/latest/redis/#connection-parameters)
67    /// for details.
68    pub reliability_dsn: Option<String>,
69    #[cfg(feature = "reliable_report")]
70    /// Max number of retries for retries for Redis transactions
71    pub reliability_retry_count: usize,
72    /// Max Notification Lifespan
73    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
74    pub max_notification_ttl: Duration,
75}
76
77impl Default for Settings {
78    fn default() -> Settings {
79        Settings {
80            scheme: "http".to_string(),
81            host: "127.0.0.1".to_string(),
82            endpoint_url: "".to_string(),
83            port: 8000,
84            db_dsn: None,
85            db_settings: "".to_owned(),
86            router_table_name: "router".to_string(),
87            message_table_name: "message".to_string(),
88            // max data is a bit hard to figure out, due to encryption. Using something
89            // like pywebpush, if you encode a block of 4096 bytes, you'll get a
90            // 4216 byte data block. Since we're going to be receiving this, we have to
91            // presume base64 encoding, so we can bump things up to 5630 bytes max.
92            max_data_bytes: 5630,
93            crypto_keys: format!("[{}]", Fernet::generate_key()),
94            auth_keys: r#"["AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB="]"#.to_string(),
95            tracking_keys: r#"[]"#.to_string(),
96            human_logs: false,
97            connection_timeout_millis: 1000,
98            request_timeout_millis: 3000,
99            statsd_host: None,
100            statsd_port: 8125,
101            statsd_label: "autoendpoint".to_string(),
102            fcm: FcmSettings::default(),
103            apns: ApnsSettings::default(),
104            #[cfg(feature = "stub")]
105            stub: StubSettings::default(),
106            #[cfg(feature = "reliable_report")]
107            reliability_dsn: None,
108            #[cfg(feature = "reliable_report")]
109            reliability_retry_count: autopush_common::redis_util::MAX_TRANSACTION_LOOP,
110            max_notification_ttl: Duration::from_secs(MAX_NOTIFICATION_TTL_SECS),
111        }
112    }
113}
114
115impl Settings {
116    /// Load the settings from the config file if supplied, then the environment.
117    pub fn with_env_and_config_file(filename: &Option<String>) -> Result<Self, ConfigError> {
118        let mut config = Config::builder();
119
120        // Merge the config file if supplied
121        if let Some(config_filename) = filename {
122            config = config.add_source(File::with_name(config_filename));
123        }
124
125        // Merge the environment overrides
126        // Note: Specify the separator here so that the shell can properly pass args
127        // down to the sub structures.
128        config = config.add_source(Environment::with_prefix(ENV_PREFIX).separator("__"));
129
130        let built: Self = config.build()?.try_deserialize::<Self>().map_err(|error| {
131            match error {
132                // Configuration errors are not very sysop friendly, Try to make them
133                // a bit more 3AM useful.
134                ConfigError::Message(error_msg) => {
135                    println!("Bad configuration: {:?}", &error_msg);
136                    println!("Please set in config file or use environment variable.");
137                    println!(
138                        "For example to set `database_url` use env var `{}_DATABASE_URL`\n",
139                        ENV_PREFIX.to_uppercase()
140                    );
141                    error!("Configuration error: Value undefined {:?}", &error_msg);
142                    ConfigError::NotFound(error_msg)
143                }
144                _ => {
145                    error!("Configuration error: Other: {:?}", &error);
146                    error
147                }
148            }
149        })?;
150
151        Ok(built)
152    }
153
154    /// Convert a string like `[item1,item2]` into a iterator over `item1` and `item2`.
155    /// Panics with a custom message if the string is not in the expected form.
156    fn read_list_from_str<'list>(
157        list_str: &'list str,
158        panic_msg: &'static str,
159    ) -> impl Iterator<Item = &'list str> {
160        if !(list_str.starts_with('[') && list_str.ends_with(']')) {
161            panic!("{}", panic_msg);
162        }
163
164        let items = &list_str[1..list_str.len() - 1];
165        items.split(',')
166    }
167
168    /// Initialize the fernet encryption instance
169    pub fn make_fernet(&self) -> MultiFernet {
170        let keys = &self.crypto_keys.replace(['"', ' '], "");
171        let fernets = Self::read_list_from_str(keys, "Invalid AUTOEND_CRYPTO_KEYS")
172            .map(|key| {
173                debug!("🔐 Fernet keys: {:?}", &key);
174                Fernet::new(key).expect("Invalid AUTOEND_CRYPTO_KEYS")
175            })
176            .collect();
177        MultiFernet::new(fernets)
178    }
179
180    /// Get the list of auth hash keys
181    pub fn auth_keys(&self) -> Vec<String> {
182        let keys = &self.auth_keys.replace(['"', ' '], "");
183        Self::read_list_from_str(keys, "Invalid AUTOEND_AUTH_KEYS")
184            .map(|v| v.to_owned())
185            .collect()
186    }
187
188    /// Get the list of tracking public keys converted to raw, x962 format byte arrays.
189    /// (This avoids problems with formatting, padding, and other concerns. x962 precedes the
190    /// EC key pair with a `\04` byte. We'll keep that value in place for now, since the value we
191    /// are comparing against will also have the same prefix.)
192    pub fn tracking_keys(&self) -> Result<Vec<Vec<u8>>, ConfigError> {
193        let keys = &self.tracking_keys.replace(['"', ' '], "");
194        // I'm sure there's a more clever way to do this. I don't care. I want simple.
195        let mut result = Vec::new();
196        for v in Self::read_list_from_str(keys, "Invalid AUTOEND_TRACKING_KEYS") {
197            result.push(
198                util::b64_decode(v)
199                    .map_err(|e| ConfigError::Message(format!("Invalid tracking key: {e:?}")))?,
200            );
201        }
202        trace!("🔍 tracking_keys: {result:?}");
203        Ok(result)
204    }
205
206    /// Get the URL for this endpoint server
207    pub fn endpoint_url(&self) -> Url {
208        let endpoint = if self.endpoint_url.is_empty() {
209            format!("{}://{}:{}", self.scheme, self.host, self.port)
210        } else {
211            self.endpoint_url.clone()
212        };
213        Url::parse(&endpoint).expect("Invalid endpoint URL")
214    }
215}
216
217#[derive(Clone, Debug)]
218pub struct VapidTracker(pub Vec<Vec<u8>>);
219impl VapidTracker {
220    /// Very simple string check to see if the Public Key specified in the Vapid header
221    /// matches the set of trackable keys.
222    pub fn is_trackable(&self, vapid: &VapidHeaderWithKey) -> bool {
223        // ideally, [Settings.with_env_and_config_file()] does the work of pre-populating
224        // the Settings.tracking_vapid_pubs cache, but we can't rely on that.
225
226        let key = match util::b64_decode(&vapid.public_key) {
227            Ok(v) => v,
228            Err(e) => {
229                // This error is not fatal, and should not happen often. During preliminary
230                // runs, however, we do want to try and spot them.
231                warn!("🔍 VAPID: tracker failure {e}");
232                return false;
233            }
234        };
235        let result = self.0.contains(&key);
236
237        debug!("🔍 Checking {:?} {}", &vapid.public_key, {
238            if result {
239                "Match!"
240            } else {
241                "no match"
242            }
243        });
244        result
245    }
246
247    /// Extract the message Id from the headers (if present), otherwise just make one up.
248    pub fn get_id(&self, headers: &HeaderMap) -> String {
249        headers
250            .get("X-MessageId")
251            .and_then(|v|
252                // TODO: we should convert the public key string to a bitarray
253                // this would prevent any formatting errors from falsely rejecting
254                // the key. We're ok with comparing strings because we currently
255                // have access to the same public key value string that is being
256                // used, but that may not always be the case.
257                v.to_str().ok())
258            .map(|v| v.to_owned())
259            .unwrap_or_else(|| uuid::Uuid::new_v4().as_simple().to_string())
260    }
261}
262
#[cfg(test)]
mod tests {
    use actix_http::header::{HeaderMap, HeaderName, HeaderValue};

    use super::{Settings, VapidTracker};
    use crate::{
        error::ApiResult,
        headers::vapid::{VapidHeader, VapidHeaderWithKey},
    };

    #[test]
    fn test_auth_keys() -> ApiResult<()> {
        let success: Vec<String> = vec![
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=".to_owned(),
            "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC=".to_owned(),
        ];
        // Try with quoted, JSON-style strings.
        let settings = Settings {
            auth_keys: r#"["AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC="]"#.to_owned(),
            ..Default::default()
        };
        let result = settings.auth_keys();
        assert_eq!(result, success);

        // Try with unquoted, non-JSON compliant strings.
        let settings = Settings {
            auth_keys: r#"[AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=,AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC=]"#.to_owned(),
            ..Default::default()
        };
        let result = settings.auth_keys();
        assert_eq!(result, success);
        Ok(())
    }

    #[test]
    fn test_endpoint_url() -> ApiResult<()> {
        // An explicit endpoint_url is used verbatim.
        let example = "https://example.org/";
        let settings = Settings {
            endpoint_url: example.to_owned(),
            ..Default::default()
        };

        assert_eq!(settings.endpoint_url(), url::Url::parse(example).unwrap());

        // Otherwise the URL is built from scheme, host and port.
        let settings = Settings {
            ..Default::default()
        };

        assert_eq!(
            settings.endpoint_url(),
            url::Url::parse(&format!(
                "{}://{}:{}",
                settings.scheme, settings.host, settings.port
            ))
            .unwrap()
        );
        Ok(())
    }

    /*
    // Disabled: since the Rust 2024 edition `env::set_var`/`env::remove_var` are `unsafe`,
    // because mutating the process environment can race with other threads (and the test
    // harness runs tests in parallel). Re-enable only behind an opt-in feature.
    #[cfg(feature = "unsafe")]
    #[test]
    fn test_default_settings() {
        // Test that the Config works the way we expect it to.
        let port = format!("{}__PORT", super::ENV_PREFIX).to_uppercase();
        let timeout = format!("{}__FCM__TIMEOUT", super::ENV_PREFIX).to_uppercase();

        use std::env;
        let v1 = env::var(&port);
        let v2 = env::var(&timeout);
        // TODO: Audit that the environment access only happens in single-threaded code.
        unsafe { env::set_var(&port, "9123") };
        // TODO: Audit that the environment access only happens in single-threaded code.
        unsafe { env::set_var(&timeout, "123") };

        let settings = Settings::with_env_and_config_file(&None).unwrap();
        assert_eq!(&settings.port, &9123);
        assert_eq!(&settings.fcm.timeout, &123);
        assert_eq!(settings.host, "127.0.0.1".to_owned());
        // reset (just in case)
        if let Ok(p) = v1 {
            trace!("Resetting {}", &port);
            // TODO: Audit that the environment access only happens in single-threaded code.
            unsafe { env::set_var(&port, p) };
        } else {
            // TODO: Audit that the environment access only happens in single-threaded code.
            unsafe { env::remove_var(&port) };
        }
        if let Ok(p) = v2 {
            trace!("Resetting {}", &timeout);
            // TODO: Audit that the environment access only happens in single-threaded code.
            unsafe { env::set_var(&timeout, p) };
        } else {
            // TODO: Audit that the environment access only happens in single-threaded code.
            unsafe { env::remove_var(&timeout) };
        }
    }
    // */

    #[test]
    fn test_tracking_keys() -> ApiResult<()> {
        // The configured key uses the standard base64 alphabet ('+', '/', no padding) while
        // the VAPID header uses the URL-safe alphabet ('-', '_') with padding. Both must
        // decode to the same bytes and match.
        let settings = Settings {
            tracking_keys: r#"["BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon+H/boHTzMtMoNHsAGDlDB6X"]"#.to_owned(),
            ..Default::default()
        };

        let test_header = VapidHeaderWithKey {
            vapid: VapidHeader {
                scheme: "".to_owned(),
                token: "".to_owned(),
                version_data: crate::headers::vapid::VapidVersionData::Version1,
            },
            public_key: "BLMymkOqvT6OZ1o9etCqV4jGPkvOXNz5FdBjsAR9zR5oeCV1x5CBKuSLTlHon-H_boHTzMtMoNHsAGDlDB6X==".to_owned()
        };

        let key_set = settings.tracking_keys().unwrap();
        assert!(!key_set.is_empty());

        let reliability = VapidTracker(key_set);
        assert!(reliability.is_trackable(&test_header));

        Ok(())
    }

    #[test]
    fn test_multi_tracking_keys() -> ApiResult<()> {
        // Several unquoted, comma-separated URL-safe keys; the header matches the first one.
        let settings = Settings {
            tracking_keys: r#"[BLbZTvXsQr0rdvLQr73ETRcseSpoof5xV83NiPK9U-Qi00DjNJct1N6EZtTBMD0uh-nNjtLAxik1XP9CZXrKtTg,BHDgfiL1hz4oIBFaxxS9jkzyAVing-W9jjt_7WUeFjWS5Invalid5EjC8TQKddJNP3iow7UW6u8JE3t7u_y3Plc]"#.to_owned(),
            ..Default::default()
        };

        let test_header = VapidHeaderWithKey {
            vapid: VapidHeader {
                scheme: "".to_owned(),
                token: "".to_owned(),
                version_data: crate::headers::vapid::VapidVersionData::Version1,
            },
            public_key: "BLbZTvXsQr0rdvLQr73ETRcseSpoof5xV83NiPK9U-Qi00DjNJct1N6EZtTBMD0uh-nNjtLAxik1XP9CZXrKtTg".to_owned()
        };

        let key_set = settings.tracking_keys().unwrap();
        assert!(!key_set.is_empty());

        let reliability = VapidTracker(key_set);
        assert!(reliability.is_trackable(&test_header));

        Ok(())
    }

    #[test]
    fn test_reliability_id() -> ApiResult<()> {
        let mut headers = HeaderMap::new();
        let keys = Vec::new();
        let reliability = VapidTracker(keys);

        // No header: an ID is generated.
        let key = reliability.get_id(&headers);
        assert!(!key.is_empty());

        headers.insert(
            HeaderName::from_lowercase(b"x-messageid").unwrap(),
            HeaderValue::from_static("123foobar456"),
        );

        // Header present: its value is used.
        let key = reliability.get_id(&headers);
        assert_eq!(key, "123foobar456".to_owned());

        Ok(())
    }
}