autoconnect_settings/
lib.rs

1mod app_state;
2
3extern crate slog;
4#[macro_use]
5extern crate slog_scope;
6extern crate serde_derive;
7
8use std::env;
9use std::{io, net::ToSocketAddrs, time::Duration};
10
11use config::{Config, ConfigError, Environment, File};
12use fernet::Fernet;
13use lazy_static::lazy_static;
14use serde::{Deserialize, Deserializer};
15
16use autopush_common::util::deserialize_u32_to_duration;
17#[cfg(feature = "bigtable")]
18use serde_json::json;
19
20pub use app_state::AppState;
21
/// Prefix for environment variable overrides.
///
/// The environment source upper-cases it and uses `__` as the separator, so a
/// setting such as `port` is overridden by `AUTOCONNECT__PORT`.
pub const ENV_PREFIX: &str = "autoconnect";
23
24lazy_static! {
25    static ref HOSTNAME: String = mozsvc_common::get_hostname()
26        .expect("Couldn't get_hostname")
27        .into_string()
28        .expect("Couldn't convert get_hostname");
29    static ref RESOLVED_HOSTNAME: String = resolve_ip(&HOSTNAME)
30        .unwrap_or_else(|_| panic!("Failed to resolve hostname: {}", *HOSTNAME));
31}
32
/// Resolve `hostname` to the textual form of its first IP address.
///
/// If the lookup succeeds but yields no addresses, `hostname` is returned
/// unchanged.
///
/// # Errors
///
/// Propagates the I/O error from the address lookup.
fn resolve_ip(hostname: &str) -> io::Result<String> {
    let mut candidates = (hostname, 0).to_socket_addrs()?;
    let resolved = match candidates.next() {
        Some(first) => first.ip().to_string(),
        None => hostname.to_owned(),
    };
    Ok(resolved)
}
40
/// Return `true` when `port` must appear in a URL for `scheme`.
///
/// Only the well-known defaults (`http` on 80, `https` on 443) may be left out;
/// every other scheme/port pairing keeps its port.
fn include_port(scheme: &str, port: u16) -> bool {
    let is_default_port = matches!((scheme, port), ("http", 80) | ("https", 443));
    !is_default_port
}
45
46/// The Applications settings, read from CLI, Environment or settings file, for the
47/// autoconnect application. These are later converted to
48/// [autoconnect::autoconnect-settings::AppState].
49#[derive(Clone, Debug, Deserialize)]
50#[serde(default)]
51pub struct Settings {
52    /// The application port to listen on
53    pub port: u16,
54    /// The DNS specified name of the application host to used for internal routing
55    pub hostname: Option<String>,
56    /// The override hostname to use for internal routing (NOTE: requires `hostname` to be set)
57    pub resolve_hostname: bool,
58    /// The internal webpush routing port
59    pub router_port: u16,
60    /// The DNS name to use for internal routing
61    pub router_hostname: Option<String>,
62    /// The server based ping interval (also used for Broadcast sends)
63    #[serde(deserialize_with = "deserialize_f64_to_duration")]
64    pub auto_ping_interval: Duration,
65    /// How long to wait for a response Pong before being timed out and connection drop
66    #[serde(deserialize_with = "deserialize_f64_to_duration")]
67    pub auto_ping_timeout: Duration,
68    /// How long to wait for the initial connection handshake.
69    #[serde(deserialize_with = "deserialize_u32_to_duration")]
70    pub open_handshake_timeout: Duration,
71    /// How long to wait while closing a connection for the response handshake.
72    #[serde(deserialize_with = "deserialize_u32_to_duration")]
73    pub close_handshake_timeout: Duration,
74    /// The URL scheme (http/https) for the endpoint URL
75    pub endpoint_scheme: String,
76    /// The host url for the endpoint URL (differs from `hostname` and `resolve_hostname`)
77    pub endpoint_hostname: String,
78    /// The optional port override for the endpoint URL
79    pub endpoint_port: u16,
80    /// The seed key to use for endpoint encryption
81    pub crypto_key: String,
82    /// The host name to send recorded metrics
83    pub statsd_host: Option<String>,
84    /// The port number to send recorded metrics
85    pub statsd_port: u16,
86    /// The root label to apply to metrics.
87    pub statsd_label: String,
88    /// The DSN to connect to the storage engine (Used to select between storage systems)
89    pub db_dsn: Option<String>,
90    /// JSON set of specific database settings (See data storage engines)
91    pub db_settings: String,
92    /// Server endpoint to pull Broadcast ID change values (Sent in Pings)
93    pub megaphone_api_url: Option<String>,
94    /// Broadcast token for authentication
95    pub megaphone_api_token: Option<String>,
96    /// How often to poll the server for new data
97    #[serde(deserialize_with = "deserialize_u32_to_duration")]
98    pub megaphone_poll_interval: Duration,
99    /// Use human readable (simplified, non-JSON)
100    pub human_logs: bool,
101    /// Maximum allowed number of backlogged messages. Exceeding this number will
102    /// trigger a user reset because the user may have been offline way too long.
103    pub msg_limit: u32,
104    /// Sets the maximum number of concurrent connections per actix-web worker.
105    ///
106    /// All socket listeners will stop accepting connections when this limit is
107    /// reached for each worker.
108    pub actix_max_connections: Option<usize>,
109    /// Sets number of actix-web workers to start (per bind address).
110    ///
111    /// By default, the number of available physical CPUs is used as the worker count.
112    pub actix_workers: Option<usize>,
113    #[cfg(feature = "reliable_report")]
114    /// The DNS for the reliability data store. This is normally a Redis compatible
115    /// storage system. See [Connection Parameters](https://docs.rs/redis/latest/redis/#connection-parameters)
116    /// for details.
117    pub reliability_dsn: Option<String>,
118    #[cfg(feature = "reliable_report")]
119    /// Max number of retries for retries for Redis transactions
120    pub reliability_retry_count: usize,
121}
122
123impl Default for Settings {
124    fn default() -> Self {
125        Self {
126            port: 8080,
127            hostname: None,
128            resolve_hostname: false,
129            router_port: 8081,
130            router_hostname: None,
131            auto_ping_interval: Duration::from_secs(300),
132            auto_ping_timeout: Duration::from_secs(4),
133            open_handshake_timeout: Duration::from_secs(5),
134            close_handshake_timeout: Duration::from_secs(0),
135            endpoint_scheme: "http".to_owned(),
136            endpoint_hostname: "localhost".to_owned(),
137            endpoint_port: 8082,
138            crypto_key: format!("[{}]", Fernet::generate_key()),
139            statsd_host: Some("localhost".to_owned()),
140            // Matches the legacy value
141            statsd_label: "autoconnect".to_owned(),
142            statsd_port: 8125,
143            db_dsn: None,
144            db_settings: "".to_owned(),
145            megaphone_api_url: None,
146            megaphone_api_token: None,
147            megaphone_poll_interval: Duration::from_secs(30),
148            human_logs: false,
149            msg_limit: 150,
150            actix_max_connections: None,
151            actix_workers: None,
152            #[cfg(feature = "reliable_report")]
153            reliability_dsn: None,
154            #[cfg(feature = "reliable_report")]
155            reliability_retry_count: autopush_common::redis_util::MAX_TRANSACTION_LOOP,
156        }
157    }
158}
159
160impl Settings {
161    /// Load the settings from the config files in order first then the environment.
162    pub fn with_env_and_config_files(filenames: &[String]) -> Result<Self, ConfigError> {
163        let mut s = Config::builder();
164
165        // Merge the configs from the files
166        for filename in filenames {
167            s = s.add_source(File::with_name(filename));
168        }
169
170        // Merge the environment overrides
171        s = s.add_source(Environment::with_prefix(&ENV_PREFIX.to_uppercase()).separator("__"));
172
173        let built = s.build()?;
174        let s = built.try_deserialize::<Settings>()?;
175        s.validate()?;
176        Ok(s)
177    }
178
179    pub fn router_url(&self) -> String {
180        let router_scheme = "http";
181        let url = format!(
182            "{}://{}",
183            router_scheme,
184            self.router_hostname
185                .as_ref()
186                .map_or_else(|| self.get_hostname(), String::clone),
187        );
188        if include_port(router_scheme, self.router_port) {
189            format!("{}:{}", url, self.router_port)
190        } else {
191            url
192        }
193    }
194
195    pub fn endpoint_url(&self) -> String {
196        let url = format!("{}://{}", self.endpoint_scheme, self.endpoint_hostname,);
197        if include_port(&self.endpoint_scheme, self.endpoint_port) {
198            format!("{}:{}", url, self.endpoint_port)
199        } else {
200            url
201        }
202    }
203
204    fn get_hostname(&self) -> String {
205        if let Some(ref hostname) = self.hostname {
206            if self.resolve_hostname {
207                resolve_ip(hostname)
208                    .unwrap_or_else(|_| panic!("Failed to resolve provided hostname: {hostname}"))
209            } else {
210                hostname.clone()
211            }
212        } else if self.resolve_hostname {
213            RESOLVED_HOSTNAME.clone()
214        } else {
215            HOSTNAME.clone()
216        }
217    }
218
219    pub fn validate(&self) -> Result<(), ConfigError> {
220        let non_zero = |val: Duration, name| {
221            if val.is_zero() {
222                return Err(ConfigError::Message(format!(
223                    "Invalid {ENV_PREFIX}_{name}: cannot be 0"
224                )));
225            }
226            Ok(())
227        };
228        non_zero(self.megaphone_poll_interval, "MEGAPHONE_POLL_INTERVAL")?;
229        non_zero(self.auto_ping_interval, "AUTO_PING_INTERVAL")?;
230        non_zero(self.auto_ping_timeout, "AUTO_PING_TIMEOUT")?;
231        Ok(())
232    }
233
234    #[cfg(feature = "bigtable")]
235    pub fn test_settings() -> Self {
236        let host = env::var("BIGTABLE_EMULATOR_HOST").unwrap_or("localhost:8086".to_owned());
237        let db_dsn = Some(format!("grpc://{}", host));
238        // BigTable DB_SETTINGS.
239        let db_settings = json!({
240            "table_name":"projects/test/instances/test/tables/autopush",
241            "message_family":"message",
242            "router_family":"router",
243            "message_topic_family":"message_topic",
244        })
245        .to_string();
246        Self {
247            db_dsn,
248            db_settings,
249            ..Default::default()
250        }
251    }
252
253    #[cfg(all(feature = "redis", not(feature = "bigtable")))]
254    pub fn test_settings() -> Self {
255        let host = env::var("REDIS_HOST").unwrap_or("localhost:6379".to_owned());
256        let db_dsn = Some(format!("redis://{}", host));
257        let db_settings = "".to_string();
258        Self {
259            db_dsn,
260            db_settings,
261            ..Default::default()
262        }
263    }
264}
265
266fn deserialize_f64_to_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
267where
268    D: Deserializer<'de>,
269{
270    let seconds: f64 = Deserialize::deserialize(deserializer)?;
271    Ok(Duration::new(
272        seconds as u64,
273        (seconds.fract() * 1_000_000_000.0) as u32,
274    ))
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "unsafe")]
    use slog_scope::trace;

    #[test]
    fn test_router_url() {
        let mut settings = Settings {
            router_hostname: Some("testname".to_string()),
            router_port: 80,
            ..Default::default()
        };
        let url = settings.router_url();
        assert_eq!("http://testname", url);

        settings.router_port = 8080;
        let url = settings.router_url();
        assert_eq!("http://testname:8080", url);
    }

    #[test]
    fn test_endpoint_url() {
        let mut settings = Settings {
            endpoint_hostname: "testname".to_string(),
            endpoint_port: 80,
            endpoint_scheme: "http".to_string(),
            ..Default::default()
        };
        let url = settings.endpoint_url();
        assert_eq!("http://testname", url);

        settings.endpoint_port = 8080;
        let url = settings.endpoint_url();
        assert_eq!("http://testname:8080", url);

        settings.endpoint_port = 443;
        settings.endpoint_scheme = "https".to_string();
        let url = settings.endpoint_url();
        assert_eq!("https://testname", url);

        settings.endpoint_port = 8080;
        let url = settings.endpoint_url();
        assert_eq!("https://testname:8080", url);
    }

    // `env::set_var` / `env::remove_var` are `unsafe` as of the Rust 2024
    // edition, so this test only runs when the opt-in `unsafe` feature is
    // enabled.
    #[cfg(feature = "unsafe")]
    #[test]
    fn test_default_settings() {
        // Test that the Config works the way we expect it to.
        use std::env;

        /// Put `key` back to `previous` (its value before the test), or remove
        /// it if it was unset.
        fn restore(key: &str, previous: Result<String, env::VarError>) {
            if let Ok(value) = previous {
                trace!("Resetting {}", key);
                // SAFETY: not proven. The test harness is multi-threaded, so this
                // relies on no other test touching the environment concurrently.
                // TODO: Audit that the environment access only happens in
                // single-threaded code.
                unsafe { env::set_var(key, value) };
            } else {
                // SAFETY: same caveat as `set_var` above.
                unsafe { env::remove_var(key) };
            }
        }

        let port = format!("{ENV_PREFIX}__PORT").to_uppercase();
        let msg_limit = format!("{ENV_PREFIX}__MSG_LIMIT").to_uppercase();
        let fernet = format!("{ENV_PREFIX}__CRYPTO_KEY").to_uppercase();

        // Remember the current values so they can be restored afterwards.
        let saved = [
            (port.clone(), env::var(&port)),
            (msg_limit.clone(), env::var(&msg_limit)),
            (fernet.clone(), env::var(&fernet)),
        ];
        // SAFETY: not proven. The test harness is multi-threaded, so this relies
        // on no other test touching the environment concurrently.
        // TODO: Audit that the environment access only happens in
        // single-threaded code.
        unsafe {
            env::set_var(&port, "9123");
            env::set_var(&msg_limit, "123");
            env::set_var(&fernet, "[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]");
        }
        let result = Settings::with_env_and_config_files(&[]);

        // Restore the environment before asserting, so a failing assertion
        // doesn't leak these overrides into other tests.
        for (key, previous) in saved {
            restore(&key, previous);
        }

        let settings = result.unwrap();
        assert_eq!(settings.endpoint_hostname, "localhost".to_owned());
        assert_eq!(&settings.port, &9123);
        assert_eq!(&settings.msg_limit, &123);
        assert_eq!(
            &settings.crypto_key,
            "[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]"
        );
        assert_eq!(settings.open_handshake_timeout, Duration::from_secs(5));
    }
}
370}