1mod app_state;
2
3extern crate slog;
4#[macro_use]
5extern crate slog_scope;
6extern crate serde_derive;
7
8#[allow(unused_imports)]
10use std::env;
11use std::{io, net::ToSocketAddrs, time::Duration};
12
13use config::{Config, ConfigError, Environment, File};
14use fernet::Fernet;
15use lazy_static::lazy_static;
16use serde::{Deserialize, Deserializer};
17
18use autopush_common::util::deserialize_u32_to_duration;
19#[allow(unused_imports)]
21use serde_json::json;
22
23pub use app_state::AppState;
24
/// Prefix for environment-variable overrides of [`Settings`] fields.
///
/// It is upper-cased and joined to the field name with `__`, so e.g.
/// `AUTOCONNECT__PORT` overrides `port`.
pub const ENV_PREFIX: &str = "autoconnect";
26
27lazy_static! {
28 static ref HOSTNAME: String = mozsvc_common::get_hostname()
29 .expect("Couldn't get_hostname")
30 .into_string()
31 .expect("Couldn't convert get_hostname");
32 static ref RESOLVED_HOSTNAME: String = resolve_ip(&HOSTNAME)
33 .unwrap_or_else(|_| panic!("Failed to resolve hostname: {}", *HOSTNAME));
34}
35
/// Resolves `hostname` and returns the IP of the first address found, as a string.
///
/// If the lookup succeeds but yields no addresses, `hostname` is returned
/// unchanged. IP literals (e.g. `"127.0.0.1"`, `"::1"`) are returned as-is
/// without a DNS lookup.
///
/// # Errors
///
/// Propagates the I/O error from the address lookup.
fn resolve_ip(hostname: &str) -> io::Result<String> {
    let mut candidates = (hostname, 0).to_socket_addrs()?;
    let resolved = match candidates.next() {
        Some(first) => first.ip().to_string(),
        None => hostname.to_owned(),
    };
    Ok(resolved)
}
43
/// Returns `true` when `port` must appear in a URL for `scheme`.
///
/// The port is omitted only when it is the scheme's default (`80` for `http`,
/// `443` for `https`). Scheme matching is case-sensitive, and unknown schemes
/// always include the port.
fn include_port(scheme: &str, port: u16) -> bool {
    let default_port = match scheme {
        "http" => Some(80),
        "https" => Some(443),
        _ => None,
    };
    default_port != Some(port)
}
48
/// Runtime configuration for autoconnect.
///
/// Built by [`Settings::with_env_and_config_files`] from config files plus
/// `AUTOCONNECT__*` environment variables. Because of `#[serde(default)]`,
/// any field not supplied falls back to the value from [`Settings::default`].
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct Settings {
    /// Listening port (default `8080`).
    /// NOTE(review): presumably the client-facing port — confirm with the server setup.
    pub port: u16,
    /// Explicit hostname used to build the router URL when `router_hostname`
    /// is unset. `None` means the machine's own hostname is used.
    pub hostname: Option<String>,
    /// When `true`, the hostname chosen for the router URL (`hostname`, or the
    /// machine hostname) is first resolved to an IP address (default `false`).
    pub resolve_hostname: bool,
    /// Port appended to [`Settings::router_url`]. It is omitted when `80`, since
    /// the router scheme is always `http` (default `8081`).
    pub router_port: u16,
    /// Host for [`Settings::router_url`]. When `None`, the `hostname` /
    /// `resolve_hostname` logic is used instead.
    pub router_hostname: Option<String>,
    /// Auto-ping interval, configured in (possibly fractional) seconds.
    /// Default 300s; must be non-zero (see [`Settings::validate`]).
    #[serde(deserialize_with = "deserialize_f64_to_duration")]
    pub auto_ping_interval: Duration,
    /// Auto-ping timeout, configured in (possibly fractional) seconds.
    /// Default 4s; must be non-zero (see [`Settings::validate`]).
    #[serde(deserialize_with = "deserialize_f64_to_duration")]
    pub auto_ping_timeout: Duration,
    /// Opening-handshake timeout (default 5s).
    /// NOTE(review): parsed by `deserialize_u32_to_duration`, presumably whole seconds — confirm.
    #[serde(deserialize_with = "deserialize_u32_to_duration")]
    pub open_handshake_timeout: Duration,
    /// Scheme of [`Settings::endpoint_url`] (default `"http"`).
    pub endpoint_scheme: String,
    /// Host of [`Settings::endpoint_url`] (default `"localhost"`).
    pub endpoint_hostname: String,
    /// Port of [`Settings::endpoint_url`]. It is omitted when it is the scheme's
    /// default (`80` for http, `443` for https). Default `8082`.
    pub endpoint_port: u16,
    /// Fernet key material wrapped in square brackets, e.g. `"[<key>]"`.
    /// Defaults to a freshly generated key.
    /// NOTE(review): presumably the brackets may hold several keys — confirm with the consumer.
    pub crypto_key: String,
    /// StatsD host (default `Some("localhost")`).
    /// NOTE(review): presumably `None` disables metrics — confirm.
    pub statsd_host: Option<String>,
    /// StatsD port (default `8125`).
    pub statsd_port: u16,
    /// StatsD metric label (default `"autoconnect"`).
    pub statsd_label: String,
    /// Presumably disables Sentry reporting when `true` (default `false`); consumed outside this file.
    pub disable_sentry: bool,
    /// Storage DSN, e.g. `grpc://…`, `redis://…` or `postgres://…` as built by
    /// [`Settings::test_settings`] (default `None`).
    pub db_dsn: Option<String>,
    /// Backend-specific storage settings as a string. The Bigtable test
    /// settings pass a JSON object here. Default empty.
    pub db_settings: String,
    /// Megaphone API URL (default `None`).
    pub megaphone_api_url: Option<String>,
    /// Megaphone API token (default `None`).
    pub megaphone_api_token: Option<String>,
    /// Megaphone polling interval. Default 30s; must be non-zero (see
    /// [`Settings::validate`]).
    /// NOTE(review): parsed by `deserialize_u32_to_duration`, presumably whole seconds — confirm.
    #[serde(deserialize_with = "deserialize_u32_to_duration")]
    pub megaphone_poll_interval: Duration,
    /// Presumably selects human-readable (vs. structured) logs; default `false`.
    pub human_logs: bool,
    /// Message limit (default `150`).
    /// NOTE(review): exact scope of the limit is not visible here — confirm.
    pub msg_limit: u32,
    /// Client channel capacity (default `128`).
    pub client_channel_capacity: usize,
    /// Actix maximum connections. `None` (the default) leaves it unset here.
    pub actix_max_connections: Option<usize>,
    /// Actix worker count. `None` (the default) leaves it unset here.
    pub actix_workers: Option<usize>,
    /// Max idle connections per host for a connection pool (default `10`).
    /// NOTE(review): presumably the outbound HTTP client's pool — confirm.
    pub pool_max_idle_per_host: usize,
    /// Connection-pool idle timeout in seconds (default `30`).
    pub pool_idle_timeout_secs: u64,
    /// Reliability-report DSN (default `None`); only with `reliable_report`.
    #[cfg(feature = "reliable_report")]
    pub reliability_dsn: Option<String>,
    /// Reliability-report retry count. Defaults to
    /// `autopush_common::redis_util::MAX_TRANSACTION_LOOP`; only with `reliable_report`.
    #[cfg(feature = "reliable_report")]
    pub reliability_retry_count: usize,
}
131impl Default for Settings {
134 fn default() -> Self {
135 Self {
136 port: 8080,
137 hostname: None,
138 resolve_hostname: false,
139 router_port: 8081,
140 router_hostname: None,
141 auto_ping_interval: Duration::from_secs(300),
142 auto_ping_timeout: Duration::from_secs(4),
143 open_handshake_timeout: Duration::from_secs(5),
144 endpoint_scheme: "http".to_owned(),
145 endpoint_hostname: "localhost".to_owned(),
146 endpoint_port: 8082,
147 crypto_key: format!("[{}]", Fernet::generate_key()),
148 statsd_host: Some("localhost".to_owned()),
149 statsd_label: "autoconnect".to_owned(),
151 statsd_port: 8125,
152 disable_sentry: false,
153 db_dsn: None,
154 db_settings: "".to_owned(),
155 megaphone_api_url: None,
156 megaphone_api_token: None,
157 megaphone_poll_interval: Duration::from_secs(30),
158 human_logs: false,
159 msg_limit: 150,
160 client_channel_capacity: 128,
161 actix_max_connections: None,
162 actix_workers: None,
163 pool_max_idle_per_host: 10,
164 pool_idle_timeout_secs: 30,
165 #[cfg(feature = "reliable_report")]
166 reliability_dsn: None,
167 #[cfg(feature = "reliable_report")]
168 reliability_retry_count: autopush_common::redis_util::MAX_TRANSACTION_LOOP,
169 }
170 }
171}
172
173impl Settings {
174 pub fn with_env_and_config_files(filenames: &[String]) -> Result<Self, ConfigError> {
176 let mut s = Config::builder();
177
178 for filename in filenames {
180 s = s.add_source(File::with_name(filename));
181 }
182
183 s = s.add_source(Environment::with_prefix(&ENV_PREFIX.to_uppercase()).separator("__"));
185
186 let built = s.build()?;
187 let s = built.try_deserialize::<Settings>()?;
188 s.validate()?;
189 Ok(s)
190 }
191
192 pub fn router_url(&self) -> String {
193 let router_scheme = "http";
194 let url = format!(
195 "{}://{}",
196 router_scheme,
197 self.router_hostname
198 .as_ref()
199 .map_or_else(|| self.get_hostname(), String::clone),
200 );
201 if include_port(router_scheme, self.router_port) {
202 format!("{}:{}", url, self.router_port)
203 } else {
204 url
205 }
206 }
207
208 pub fn endpoint_url(&self) -> String {
209 let url = format!("{}://{}", self.endpoint_scheme, self.endpoint_hostname,);
210 if include_port(&self.endpoint_scheme, self.endpoint_port) {
211 format!("{}:{}", url, self.endpoint_port)
212 } else {
213 url
214 }
215 }
216
217 fn get_hostname(&self) -> String {
218 if let Some(ref hostname) = self.hostname {
219 if self.resolve_hostname {
220 resolve_ip(hostname)
221 .unwrap_or_else(|_| panic!("Failed to resolve provided hostname: {hostname}"))
222 } else {
223 hostname.clone()
224 }
225 } else if self.resolve_hostname {
226 RESOLVED_HOSTNAME.clone()
227 } else {
228 HOSTNAME.clone()
229 }
230 }
231
232 pub fn validate(&self) -> Result<(), ConfigError> {
233 let non_zero = |val: Duration, name| {
234 if val.is_zero() {
235 return Err(ConfigError::Message(format!(
236 "Invalid {ENV_PREFIX}_{name}: cannot be 0"
237 )));
238 }
239 Ok(())
240 };
241 non_zero(self.megaphone_poll_interval, "MEGAPHONE_POLL_INTERVAL")?;
242 non_zero(self.auto_ping_interval, "AUTO_PING_INTERVAL")?;
243 non_zero(self.auto_ping_timeout, "AUTO_PING_TIMEOUT")?;
244 Ok(())
245 }
246
247 pub fn test_settings() -> Self {
248 #[cfg(all(feature = "bigtable", feature = "redis", feature = "postgres"))]
251 {
252 Self::default()
253 }
254 #[cfg(all(
255 feature = "bigtable",
256 not(any(feature = "redis", feature = "postgres"))
257 ))]
258 {
259 let host = env::var("BIGTABLE_EMULATOR_HOST").unwrap_or("localhost:8086".to_owned());
260 let db_dsn = Some(format!("grpc://{}", host));
261 let db_settings = json!({
263 "table_name":"projects/test/instances/test/tables/autopush",
264 "message_family":"message",
265 "router_family":"router",
266 "message_topic_family":"message_topic",
267 })
268 .to_string();
269 Self {
270 db_dsn,
271 db_settings,
272 ..Default::default()
273 }
274 }
275 #[cfg(all(
276 feature = "redis",
277 not(any(feature = "bigtable", feature = "postgres"))
278 ))]
279 {
280 let host = env::var("REDIS_HOST").unwrap_or("localhost:6379".to_owned());
281 let db_dsn = Some(format!("redis://{}", host));
282 let db_settings = "".to_string();
283 Self {
284 db_dsn,
285 db_settings,
286 ..Default::default()
287 }
288 }
289 #[cfg(all(
290 feature = "postgres",
291 not(any(feature = "bigtable", feature = "redis"))
292 ))]
293 {
294 let host = env::var("POSTGRES_HOST").unwrap_or("localhost:5432".to_owned());
295 let db_dsn = Some(format!("postgres://{}", host));
296 let db_settings = "".to_string();
297 Self {
298 db_dsn,
299 db_settings,
300 ..Default::default()
301 }
302 }
303 }
304}
305
306fn deserialize_f64_to_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
307where
308 D: Deserializer<'de>,
309{
310 let seconds: f64 = Deserialize::deserialize(deserializer)?;
311 Ok(Duration::new(
312 seconds as u64,
313 (seconds.fract() * 1_000_000_000.0) as u32,
314 ))
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "unsafe")]
    use slog_scope::trace;

    /// The router URL (always `http`) omits the port only when it is `80`.
    #[test]
    fn test_router_url() {
        let mut settings = Settings {
            router_hostname: Some("testname".to_string()),
            router_port: 80,
            ..Default::default()
        };
        let url = settings.router_url();
        assert_eq!("http://testname", url);

        settings.router_port = 8080;
        let url = settings.router_url();
        assert_eq!("http://testname:8080", url);
    }

    /// The endpoint URL omits the port only when it is the scheme's default.
    #[test]
    fn test_endpoint_url() {
        let mut settings = Settings {
            endpoint_hostname: "testname".to_string(),
            endpoint_port: 80,
            endpoint_scheme: "http".to_string(),
            ..Default::default()
        };
        let url = settings.endpoint_url();
        assert_eq!("http://testname", url);

        settings.endpoint_port = 8080;
        let url = settings.endpoint_url();
        assert_eq!("http://testname:8080", url);

        settings.endpoint_port = 443;
        settings.endpoint_scheme = "https".to_string();
        let url = settings.endpoint_url();
        assert_eq!("https://testname", url);

        settings.endpoint_port = 8080;
        let url = settings.endpoint_url();
        assert_eq!("https://testname:8080", url);
    }

    /// Sets an environment variable for the guard's lifetime. On drop it
    /// restores the variable's previous value, or removes it if it was unset.
    /// Restoration also happens when an assertion unwinds, so overrides never
    /// leak into other tests.
    #[cfg(feature = "unsafe")]
    struct EnvVarGuard {
        key: String,
        previous: Option<std::ffi::OsString>,
    }

    #[cfg(feature = "unsafe")]
    impl EnvVarGuard {
        fn set(key: String, value: &str) -> Self {
            // `var_os` also preserves non-UTF-8 values, which `var` would drop.
            let previous = std::env::var_os(&key);
            // SAFETY: mutating the process environment is only sound while no
            // other thread reads or writes it; tests using this guard are
            // gated behind the `unsafe` feature for that reason.
            unsafe { std::env::set_var(&key, value) };
            Self { key, previous }
        }
    }

    #[cfg(feature = "unsafe")]
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => {
                    trace!("Resetting {}", &self.key);
                    // SAFETY: same single-threaded requirement as `EnvVarGuard::set`.
                    unsafe { std::env::set_var(&self.key, value) };
                }
                // SAFETY: same single-threaded requirement as `EnvVarGuard::set`.
                None => unsafe { std::env::remove_var(&self.key) },
            }
        }
    }

    /// `AUTOCONNECT__*` environment variables override the built-in defaults.
    ///
    /// Gated behind the `unsafe` feature because it mutates the process
    /// environment, which races with any test reading it concurrently.
    #[cfg(feature = "unsafe")]
    #[test]
    fn test_default_settings() {
        let port = format!("{ENV_PREFIX}__PORT").to_uppercase();
        let msg_limit = format!("{ENV_PREFIX}__MSG_LIMIT").to_uppercase();
        let fernet = format!("{ENV_PREFIX}__CRYPTO_KEY").to_uppercase();

        // All three variables are restored when these guards drop. Previously
        // CRYPTO_KEY was always removed, and nothing was restored on panic.
        let _port_guard = EnvVarGuard::set(port, "9123");
        let _msg_limit_guard = EnvVarGuard::set(msg_limit, "123");
        let _fernet_guard =
            EnvVarGuard::set(fernet, "[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]");

        let settings = Settings::with_env_and_config_files(&Vec::new()).unwrap();
        assert_eq!(settings.endpoint_hostname, "localhost".to_owned());
        assert_eq!(&settings.port, &9123);
        assert_eq!(&settings.msg_limit, &123);
        assert_eq!(
            &settings.crypto_key,
            "[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]"
        );
        assert_eq!(settings.open_handshake_timeout, Duration::from_secs(5));
    }
}