autoendpoint/
server.rs

1//! Main application server
2#![forbid(unsafe_code)]
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::time::Duration;
6
7use actix_cors::Cors;
8use actix_web::{
9    App, HttpServer, dev, http::StatusCode, middleware::ErrorHandlers, web, web::Data,
10};
11use cadence::StatsdClient;
12use fernet::MultiFernet;
13use serde_json::json;
14
15#[cfg(feature = "bigtable")]
16use autopush_common::db::bigtable::BigTableClientImpl;
17#[cfg(feature = "postgres")]
18use autopush_common::db::postgres::PgClientImpl;
19#[cfg(feature = "redis")]
20use autopush_common::db::redis::RedisClientImpl;
21#[cfg(feature = "reliable_report")]
22use autopush_common::reliability::PushReliability;
23use autopush_common::{
24    db::{DbSettings, StorageType, client::DbClient, spawn_pool_periodic_reporter},
25    metric_name::MetricName,
26    middleware::sentry::SentryWrapper,
27};
28
29use crate::error::{ApiError, ApiErrorKind, ApiResult};
30use crate::metrics;
31#[cfg(feature = "stub")]
32use crate::routers::stub::router::StubRouter;
33use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter};
34use crate::routes::{
35    health::{health_route, lb_heartbeat_route, log_check, status_route, version_route},
36    registration::{
37        get_channels_route, new_channel_route, register_uaid_route, unregister_channel_route,
38        unregister_user_route, update_token_route,
39    },
40    webpush::{delete_notification_route, webpush_route},
41};
42use crate::settings::Settings;
43use crate::settings::VapidTracker;
44
45#[derive(Clone)]
46pub struct AppState {
47    /// Server Data
48    pub metrics: Arc<StatsdClient>,
49    pub settings: Settings,
50    pub fernet: MultiFernet,
51    pub db: Box<dyn DbClient>,
52    pub http: reqwest::Client,
53    // requests that are currently being exchanged between nodes.
54    pub in_flight_requests: Arc<AtomicUsize>,
55    // subscription updates that are currently being processed.
56    pub in_process_subscription_updates: Arc<AtomicUsize>,
57    pub fcm_router: Arc<FcmRouter>,
58    pub apns_router: Arc<ApnsRouter>,
59    #[cfg(feature = "stub")]
60    pub stub_router: Arc<StubRouter>,
61    #[cfg(feature = "reliable_report")]
62    pub reliability: Arc<PushReliability>,
63    pub reliability_filter: VapidTracker,
64}
65
#[cfg(test)]
impl AppState {
    /// Assemble an [`AppState`] for unit tests.
    ///
    /// The state wraps the supplied `mock_db`, uses `Settings::default()` with a fixed
    /// `auth_keys` value, sends metrics to a sink, and builds every router from its
    /// default settings against `https://example.com`.
    ///
    /// # Panics
    /// Panics if any of the routers (or, with `reliable_report`, the reliability
    /// reporter) fails to initialize.
    pub(crate) async fn test_default(mock_db: autopush_common::db::mock::MockDbClient) -> Self {
        use crate::routers::{apns::settings::ApnsSettings, fcm::settings::FcmSettings};

        let settings = Settings {
            auth_keys: "HJVPy4ZwF4Yz_JdvXTL8hRcwIhv742vC60Tg5Ycrvw8=".to_owned(),
            ..Default::default()
        };
        let fernet = settings.make_fernet();
        let metrics = Arc::new(crate::metrics::Metrics::sink());
        let db = mock_db.into_boxed_arc();
        #[cfg(feature = "reliable_report")]
        let reliability =
            Arc::new(PushReliability::new(&None, db.box_clone(), &metrics.clone(), 0).unwrap());

        // Both bridge routers are pointed at the same placeholder endpoint.
        let endpoint_url = url::Url::parse("https://example.com").unwrap();

        let fcm_router = FcmRouter::new(
            FcmSettings::default(),
            endpoint_url.clone(),
            reqwest::Client::new(),
            metrics.clone(),
            db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability.clone(),
        )
        .await
        .unwrap();

        let apns_router = ApnsRouter::new(
            ApnsSettings::default(),
            endpoint_url,
            metrics.clone(),
            db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability.clone(),
        )
        .await
        .unwrap();

        #[cfg(feature = "stub")]
        let stub_router =
            StubRouter::new(crate::routers::stub::settings::StubSettings::default()).unwrap();

        Self {
            metrics,
            settings,
            fernet,
            db,
            http: reqwest::Client::new(),
            in_flight_requests: Arc::new(AtomicUsize::new(0)),
            in_process_subscription_updates: Arc::new(AtomicUsize::new(0)),
            fcm_router: Arc::new(fcm_router),
            apns_router: Arc::new(apns_router),
            #[cfg(feature = "stub")]
            stub_router: Arc::new(stub_router),
            #[cfg(feature = "reliable_report")]
            reliability,
            reliability_filter: VapidTracker(Vec::new()),
        }
    }
}
123
124pub struct Server;
125
126impl Server {
127    pub async fn with_settings(settings: Settings) -> ApiResult<dev::Server> {
128        let metrics = Arc::new(metrics::metrics_from_settings(&settings)?);
129        let bind_address = format!("{}:{}", settings.host, settings.port);
130        let fernet = settings.make_fernet();
131        let endpoint_url = settings.endpoint_url();
132        let reliability_filter = VapidTracker(
133            settings
134                .tracking_keys()
135                .map_err(|e| ApiErrorKind::General(format!("Configuration Error: {e}")))?,
136        );
137        let db_settings = DbSettings {
138            dsn: settings.db_dsn.clone(),
139            db_settings: if settings.db_settings.is_empty() {
140                warn!("❗ Using obsolete message_table and router_table args");
141                // backfill from the older arguments.
142                json!({"message_table": settings.message_table_name, "router_table":settings.router_table_name}).to_string()
143            } else {
144                settings.db_settings.clone()
145            },
146        };
147        let db: Box<dyn DbClient> = match StorageType::from_dsn(&db_settings.dsn) {
148            #[cfg(feature = "bigtable")]
149            StorageType::BigTable => {
150                debug!("Using BigTable");
151                let client = BigTableClientImpl::new(metrics.clone(), &db_settings)?;
152                client.spawn_sweeper(Duration::from_secs(30));
153                Box::new(client)
154            }
155            #[cfg(feature = "postgres")]
156            StorageType::Postgres => {
157                debug!("Using Postgres");
158                let client = PgClientImpl::new(metrics.clone(), &db_settings)?;
159                // client.spawn_sweeper(Duration::from_secs(30));
160                Box::new(client)
161            }
162            #[cfg(feature = "redis")]
163            StorageType::Redis => Box::new(RedisClientImpl::new(metrics.clone(), &db_settings)?),
164            _ => {
165                debug!("No idea what {:?} is", &db_settings.dsn);
166                return Err(ApiErrorKind::General(
167                    "Invalid or Unsupported DSN specified".to_owned(),
168                )
169                .into());
170            }
171        };
172        #[cfg(feature = "reliable_report")]
173        let reliability = Arc::new(
174            PushReliability::new(
175                &settings.reliability_dsn,
176                db.clone(),
177                &metrics,
178                settings.reliability_retry_count,
179            )
180            .map_err(|e| {
181                ApiErrorKind::General(format!("Could not initialize Reliability Report: {e:?}"))
182            })?,
183        );
184        let http = reqwest::ClientBuilder::new()
185            .connect_timeout(Duration::from_millis(settings.connection_timeout_millis))
186            .timeout(Duration::from_millis(settings.request_timeout_millis))
187            .pool_max_idle_per_host(settings.pool_max_idle_per_host)
188            .pool_idle_timeout(Duration::from_secs(settings.pool_idle_timeout_secs))
189            .build()
190            .expect("Could not generate request client");
191        let fcm_router = Arc::new(
192            FcmRouter::new(
193                settings.fcm.clone(),
194                endpoint_url.clone(),
195                http.clone(),
196                metrics.clone(),
197                db.clone(),
198                #[cfg(feature = "reliable_report")]
199                reliability.clone(),
200            )
201            .await?,
202        );
203        let apns_router = Arc::new(
204            ApnsRouter::new(
205                settings.apns.clone(),
206                endpoint_url.clone(),
207                metrics.clone(),
208                db.clone(),
209                #[cfg(feature = "reliable_report")]
210                reliability.clone(),
211            )
212            .await?,
213        );
214        #[cfg(feature = "stub")]
215        let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
216        let in_flight_requests = Arc::new(AtomicUsize::new(0));
217        let in_process_subscription_updates = Arc::new(AtomicUsize::new(0));
218        let app_state = AppState {
219            metrics: metrics.clone(),
220            settings,
221            fernet,
222            db,
223            http,
224            in_flight_requests: in_flight_requests.clone(),
225            in_process_subscription_updates: in_process_subscription_updates.clone(),
226            fcm_router,
227            apns_router,
228            #[cfg(feature = "stub")]
229            stub_router,
230            #[cfg(feature = "reliable_report")]
231            reliability,
232            reliability_filter,
233        };
234
235        spawn_pool_periodic_reporter(
236            Duration::from_secs(10),
237            app_state.db.clone(),
238            app_state.metrics.clone(),
239        );
240
241        // Periodically report in-flight request gauge
242        {
243            let metrics = app_state.metrics.clone();
244            let in_flight = in_flight_requests;
245            tokio::spawn(async move {
246                let mut interval = tokio::time::interval(Duration::from_secs(10));
247                loop {
248                    interval.tick().await;
249                    let count = in_flight.load(Ordering::Relaxed);
250                    if let Err(e) = cadence::Gauged::gauge(
251                        metrics.as_ref(),
252                        MetricName::InFlightNodeRequests.as_ref(),
253                        count as u64,
254                    ) {
255                        debug!("Failed to report in-flight metric: {}", e);
256                    }
257                }
258            });
259        }
260
261        let server = HttpServer::new(move || {
262            // These have a bad habit of being reset. Specify them explicitly.
263            let cors = Cors::default()
264                .allow_any_origin()
265                .allow_any_header()
266                .allowed_methods(vec![
267                    actix_web::http::Method::DELETE,
268                    actix_web::http::Method::GET,
269                    actix_web::http::Method::POST,
270                    actix_web::http::Method::PUT,
271                ])
272                .max_age(3600);
273            let app = App::new()
274                // Actix 4 recommends wrapping structures wtih web::Data (internally an Arc)
275                .app_data(Data::new(app_state.clone()))
276                // Extractor configuration
277                .app_data(web::PayloadConfig::new(app_state.settings.max_data_bytes))
278                .app_data(web::JsonConfig::default().limit(app_state.settings.max_data_bytes))
279                // Middleware
280                .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404))
281                // Our modified Sentry wrapper which does some blocking of non-reportable errors.
282                .wrap(SentryWrapper::<ApiError>::new(
283                    metrics.clone(),
284                    "api_error".to_owned(),
285                    app_state.settings.disable_sentry,
286                ))
287                .wrap(cors)
288                // Endpoints
289                .service(
290                    web::resource(["/wpush/{api_version}/{token}", "/wpush/{token}"])
291                        .route(web::post().to(webpush_route)),
292                )
293                .service(
294                    web::resource("/m/{message_id}")
295                        .route(web::delete().to(delete_notification_route)),
296                )
297                .service(
298                    web::resource("/v1/{router_type}/{app_id}/registration")
299                        .route(web::post().to(register_uaid_route)),
300                )
301                .service(
302                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}")
303                        .route(web::put().to(update_token_route))
304                        .route(web::get().to(get_channels_route))
305                        .route(web::delete().to(unregister_user_route)),
306                )
307                .service(
308                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}/subscription")
309                        .route(web::post().to(new_channel_route)),
310                )
311                .service(
312                    web::resource(
313                        "/v1/{router_type}/{app_id}/registration/{uaid}/subscription/{chid}",
314                    )
315                    .route(web::delete().to(unregister_channel_route)),
316                )
317                // Health checks
318                .service(web::resource("/status").route(web::get().to(status_route)))
319                .service(web::resource("/health").route(web::get().to(health_route)))
320                // legacy
321                .service(web::resource("/v1/err").route(web::get().to(log_check)))
322                // standardized
323                .service(web::resource("/__error__").route(web::get().to(log_check)))
324                // Dockerflow
325                .service(web::resource("/__heartbeat__").route(web::get().to(health_route)))
326                .service(web::resource("/__lbheartbeat__").route(web::get().to(lb_heartbeat_route)))
327                .service(web::resource("/__version__").route(web::get().to(version_route)));
328            #[cfg(feature = "reliable_report")]
329            // Note: Only the endpoint returns the Prometheus "/metrics" collection report. This report contains all metrics for both
330            // connection and endpoint, inclusive. It is served here mostly for simplicity's sake (since the endpoint handles more
331            // HTTP requests than the connection server, and this will simplify metric collection and reporting.)
332            let app = app.service(
333                web::resource("/metrics")
334                    .route(web::get().to(crate::routes::reliability::report_handler)),
335            );
336            app
337        })
338        .bind(bind_address)?
339        .run();
340
341        Ok(server)
342    }
343}