1mod app_state;
2
3extern crate slog;
4#[macro_use]
5extern crate slog_scope;
6extern crate serde_derive;
7
8#[allow(unused_imports)]
10use std::env;
11use std::{io, net::ToSocketAddrs, time::Duration};
12
13use config::{Config, ConfigError, Environment, File};
14use fernet::Fernet;
15use lazy_static::lazy_static;
16use serde::{Deserialize, Deserializer};
17
18use autopush_common::util::deserialize_u32_to_duration;
19#[allow(unused_imports)]
21use serde_json::json;
22
23pub use app_state::AppState;
24
25pub const ENV_PREFIX: &str = "autoconnect";
26
27lazy_static! {
28 static ref HOSTNAME: String = mozsvc_common::get_hostname()
29 .expect("Couldn't get_hostname")
30 .into_string()
31 .expect("Couldn't convert get_hostname");
32 static ref RESOLVED_HOSTNAME: String = resolve_ip(&HOSTNAME)
33 .unwrap_or_else(|_| panic!("Failed to resolve hostname: {}", *HOSTNAME));
34}
35
36fn resolve_ip(hostname: &str) -> io::Result<String> {
38 Ok((hostname, 0)
39 .to_socket_addrs()?
40 .next()
41 .map_or_else(|| hostname.to_owned(), |addr| addr.ip().to_string()))
42}
43
44fn include_port(scheme: &str, port: u16) -> bool {
46 !((scheme == "http" && port == 80) || (scheme == "https" && port == 443))
47}
48
49#[derive(Clone, Debug, Deserialize)]
53#[serde(default)]
54pub struct Settings {
55 pub port: u16,
57 pub hostname: Option<String>,
59 pub resolve_hostname: bool,
61 pub router_port: u16,
63 pub router_hostname: Option<String>,
65 #[serde(deserialize_with = "deserialize_f64_to_duration")]
67 pub auto_ping_interval: Duration,
68 #[serde(deserialize_with = "deserialize_f64_to_duration")]
70 pub auto_ping_timeout: Duration,
71 #[serde(deserialize_with = "deserialize_u32_to_duration")]
73 pub open_handshake_timeout: Duration,
74 #[serde(deserialize_with = "deserialize_u32_to_duration")]
76 pub close_handshake_timeout: Duration,
77 pub endpoint_scheme: String,
79 pub endpoint_hostname: String,
81 pub endpoint_port: u16,
83 pub crypto_key: String,
85 pub statsd_host: Option<String>,
87 pub statsd_port: u16,
89 pub statsd_label: String,
91 pub db_dsn: Option<String>,
93 pub db_settings: String,
95 pub megaphone_api_url: Option<String>,
97 pub megaphone_api_token: Option<String>,
99 #[serde(deserialize_with = "deserialize_u32_to_duration")]
101 pub megaphone_poll_interval: Duration,
102 pub human_logs: bool,
104 pub msg_limit: u32,
107 pub actix_max_connections: Option<usize>,
112 pub actix_workers: Option<usize>,
116 #[cfg(feature = "reliable_report")]
117 pub reliability_dsn: Option<String>,
121 #[cfg(feature = "reliable_report")]
122 pub reliability_retry_count: usize,
124}
125
126impl Default for Settings {
127 fn default() -> Self {
128 Self {
129 port: 8080,
130 hostname: None,
131 resolve_hostname: false,
132 router_port: 8081,
133 router_hostname: None,
134 auto_ping_interval: Duration::from_secs(300),
135 auto_ping_timeout: Duration::from_secs(4),
136 open_handshake_timeout: Duration::from_secs(5),
137 close_handshake_timeout: Duration::from_secs(0),
138 endpoint_scheme: "http".to_owned(),
139 endpoint_hostname: "localhost".to_owned(),
140 endpoint_port: 8082,
141 crypto_key: format!("[{}]", Fernet::generate_key()),
142 statsd_host: Some("localhost".to_owned()),
143 statsd_label: "autoconnect".to_owned(),
145 statsd_port: 8125,
146 db_dsn: None,
147 db_settings: "".to_owned(),
148 megaphone_api_url: None,
149 megaphone_api_token: None,
150 megaphone_poll_interval: Duration::from_secs(30),
151 human_logs: false,
152 msg_limit: 150,
153 actix_max_connections: None,
154 actix_workers: None,
155 #[cfg(feature = "reliable_report")]
156 reliability_dsn: None,
157 #[cfg(feature = "reliable_report")]
158 reliability_retry_count: autopush_common::redis_util::MAX_TRANSACTION_LOOP,
159 }
160 }
161}
162
163impl Settings {
164 pub fn with_env_and_config_files(filenames: &[String]) -> Result<Self, ConfigError> {
166 let mut s = Config::builder();
167
168 for filename in filenames {
170 s = s.add_source(File::with_name(filename));
171 }
172
173 s = s.add_source(Environment::with_prefix(&ENV_PREFIX.to_uppercase()).separator("__"));
175
176 let built = s.build()?;
177 let s = built.try_deserialize::<Settings>()?;
178 s.validate()?;
179 Ok(s)
180 }
181
182 pub fn router_url(&self) -> String {
183 let router_scheme = "http";
184 let url = format!(
185 "{}://{}",
186 router_scheme,
187 self.router_hostname
188 .as_ref()
189 .map_or_else(|| self.get_hostname(), String::clone),
190 );
191 if include_port(router_scheme, self.router_port) {
192 format!("{}:{}", url, self.router_port)
193 } else {
194 url
195 }
196 }
197
198 pub fn endpoint_url(&self) -> String {
199 let url = format!("{}://{}", self.endpoint_scheme, self.endpoint_hostname,);
200 if include_port(&self.endpoint_scheme, self.endpoint_port) {
201 format!("{}:{}", url, self.endpoint_port)
202 } else {
203 url
204 }
205 }
206
207 fn get_hostname(&self) -> String {
208 if let Some(ref hostname) = self.hostname {
209 if self.resolve_hostname {
210 resolve_ip(hostname)
211 .unwrap_or_else(|_| panic!("Failed to resolve provided hostname: {hostname}"))
212 } else {
213 hostname.clone()
214 }
215 } else if self.resolve_hostname {
216 RESOLVED_HOSTNAME.clone()
217 } else {
218 HOSTNAME.clone()
219 }
220 }
221
222 pub fn validate(&self) -> Result<(), ConfigError> {
223 let non_zero = |val: Duration, name| {
224 if val.is_zero() {
225 return Err(ConfigError::Message(format!(
226 "Invalid {ENV_PREFIX}_{name}: cannot be 0"
227 )));
228 }
229 Ok(())
230 };
231 non_zero(self.megaphone_poll_interval, "MEGAPHONE_POLL_INTERVAL")?;
232 non_zero(self.auto_ping_interval, "AUTO_PING_INTERVAL")?;
233 non_zero(self.auto_ping_timeout, "AUTO_PING_TIMEOUT")?;
234 Ok(())
235 }
236
237 pub fn test_settings() -> Self {
238 #[cfg(all(feature = "bigtable", feature = "redis", feature = "postgres"))]
241 {
242 Self::default()
243 }
244 #[cfg(all(
245 feature = "bigtable",
246 not(any(feature = "redis", feature = "postgres"))
247 ))]
248 {
249 let host = env::var("BIGTABLE_EMULATOR_HOST").unwrap_or("localhost:8086".to_owned());
250 let db_dsn = Some(format!("grpc://{}", host));
251 let db_settings = json!({
253 "table_name":"projects/test/instances/test/tables/autopush",
254 "message_family":"message",
255 "router_family":"router",
256 "message_topic_family":"message_topic",
257 })
258 .to_string();
259 Self {
260 db_dsn,
261 db_settings,
262 ..Default::default()
263 }
264 }
265 #[cfg(all(
266 feature = "redis",
267 not(any(feature = "bigtable", feature = "postgres"))
268 ))]
269 {
270 let host = env::var("REDIS_HOST").unwrap_or("localhost:6379".to_owned());
271 let db_dsn = Some(format!("redis://{}", host));
272 let db_settings = "".to_string();
273 Self {
274 db_dsn,
275 db_settings,
276 ..Default::default()
277 }
278 }
279 #[cfg(all(
280 feature = "postgres",
281 not(any(feature = "bigtable", feature = "redis"))
282 ))]
283 {
284 let host = env::var("POSTGRES_HOST").unwrap_or("localhost:5432".to_owned());
285 let db_dsn = Some(format!("postgres://{}", host));
286 let db_settings = "".to_string();
287 Self {
288 db_dsn,
289 db_settings,
290 ..Default::default()
291 }
292 }
293 }
294}
295
296fn deserialize_f64_to_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
297where
298 D: Deserializer<'de>,
299{
300 let seconds: f64 = Deserialize::deserialize(deserializer)?;
301 Ok(Duration::new(
302 seconds as u64,
303 (seconds.fract() * 1_000_000_000.0) as u32,
304 ))
305}
306
307#[cfg(test)]
308mod tests {
309 use super::*;
310 #[cfg(feature = "unsafe")]
311 use slog_scope::trace;
312
313 #[test]
314 fn test_router_url() {
315 let mut settings = Settings {
316 router_hostname: Some("testname".to_string()),
317 router_port: 80,
318 ..Default::default()
319 };
320 let url = settings.router_url();
321 assert_eq!("http://testname", url);
322
323 settings.router_port = 8080;
324 let url = settings.router_url();
325 assert_eq!("http://testname:8080", url);
326 }
327
328 #[test]
329 fn test_endpoint_url() {
330 let mut settings = Settings {
331 endpoint_hostname: "testname".to_string(),
332 endpoint_port: 80,
333 endpoint_scheme: "http".to_string(),
334 ..Default::default()
335 };
336 let url = settings.endpoint_url();
337 assert_eq!("http://testname", url);
338
339 settings.endpoint_port = 8080;
340 let url = settings.endpoint_url();
341 assert_eq!("http://testname:8080", url);
342
343 settings.endpoint_port = 443;
344 settings.endpoint_scheme = "https".to_string();
345 let url = settings.endpoint_url();
346 assert_eq!("https://testname", url);
347
348 settings.endpoint_port = 8080;
349 let url = settings.endpoint_url();
350 assert_eq!("https://testname:8080", url);
351 }
352
353 #[cfg(all(test, feature = "unsafe"))]
355 #[test]
356 fn test_default_settings() {
357 use std::env;
359 let port = format!("{ENV_PREFIX}__PORT").to_uppercase();
360 let msg_limit = format!("{ENV_PREFIX}__MSG_LIMIT").to_uppercase();
361 let fernet = format!("{ENV_PREFIX}__CRYPTO_KEY").to_uppercase();
362
363 let v1 = env::var(&port);
364 let v2 = env::var(&msg_limit);
365 unsafe {
366 env::set_var(&port, "9123");
367 env::set_var(&msg_limit, "123");
368 env::set_var(&fernet, "[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]");
369 }
370 let settings = Settings::with_env_and_config_files(&Vec::new()).unwrap();
371 assert_eq!(settings.endpoint_hostname, "localhost".to_owned());
372 assert_eq!(&settings.port, &9123);
373 assert_eq!(&settings.msg_limit, &123);
374 assert_eq!(
375 &settings.crypto_key,
376 "[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]"
377 );
378 assert_eq!(settings.open_handshake_timeout, Duration::from_secs(5));
379
380 if let Ok(p) = v1 {
382 trace!("Resetting {}", &port);
383 unsafe { env::set_var(&port, p) };
385 } else {
386 unsafe { env::remove_var(&port) };
388 }
389 if let Ok(p) = v2 {
390 trace!("Resetting {}", msg_limit);
391 unsafe { env::set_var(&msg_limit, p) };
393 } else {
394 unsafe { env::remove_var(&msg_limit) };
396 }
397 unsafe { env::remove_var(&fernet) };
399 }
400}