autoendpoint/
server.rs

1//! Main application server
2#![forbid(unsafe_code)]
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6
7use actix_cors::Cors;
8use actix_web::{
9    dev, http::StatusCode, middleware::ErrorHandlers, web, web::Data, App, HttpServer,
10};
11use cadence::StatsdClient;
12use fernet::MultiFernet;
13use serde_json::json;
14
15#[cfg(feature = "bigtable")]
16use autopush_common::db::bigtable::BigTableClientImpl;
17#[cfg(feature = "postgres")]
18use autopush_common::db::postgres::PgClientImpl;
19#[cfg(feature = "redis")]
20use autopush_common::db::redis::RedisClientImpl;
21#[cfg(feature = "reliable_report")]
22use autopush_common::reliability::PushReliability;
23use autopush_common::{
24    db::{client::DbClient, spawn_pool_periodic_reporter, DbSettings, StorageType},
25    metric_name::MetricName,
26    middleware::sentry::SentryWrapper,
27};
28
29use crate::error::{ApiError, ApiErrorKind, ApiResult};
30use crate::metrics;
31#[cfg(feature = "stub")]
32use crate::routers::stub::router::StubRouter;
33use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter};
34use crate::routes::{
35    health::{health_route, lb_heartbeat_route, log_check, status_route, version_route},
36    registration::{
37        get_channels_route, new_channel_route, register_uaid_route, unregister_channel_route,
38        unregister_user_route, update_token_route,
39    },
40    webpush::{delete_notification_route, webpush_route},
41};
42use crate::settings::Settings;
43use crate::settings::VapidTracker;
44
/// Server data shared by all request handlers.
///
/// `Server::with_settings` builds one instance and registers a clone of it as
/// `web::Data` on every Actix `App` the HTTP server factory creates.
#[derive(Clone)]
pub struct AppState {
    /// StatsD client used to emit application metrics.
    pub metrics: Arc<StatsdClient>,
    /// The settings the server was started with.
    pub settings: Settings,
    /// Fernet key set built by `Settings::make_fernet`.
    /// NOTE(review): presumably used to encrypt/decrypt endpoint tokens — confirm.
    pub fernet: MultiFernet,
    /// Storage backend chosen from the configured DSN (BigTable, Postgres or Redis,
    /// depending on enabled cargo features).
    pub db: Box<dyn DbClient>,
    /// Shared outbound HTTP client (timeouts and connection pooling come from settings).
    pub http: reqwest::Client,
    /// Number of requests currently in flight; its value is periodically reported as a
    /// gauge. NOTE(review): updated outside this file — confirm where.
    pub in_flight_requests: Arc<AtomicUsize>,
    /// Router for FCM-bridged subscriptions.
    pub fcm_router: Arc<FcmRouter>,
    /// Router for APNS-bridged subscriptions.
    pub apns_router: Arc<ApnsRouter>,
    /// Stub router, only present with the `stub` feature.
    #[cfg(feature = "stub")]
    pub stub_router: Arc<StubRouter>,
    /// Push reliability reporter, only present with the `reliable_report` feature.
    #[cfg(feature = "reliable_report")]
    pub reliability: Arc<PushReliability>,
    /// VAPID tracking filter built from `Settings::tracking_keys`.
    /// NOTE(review): presumably selects which subscriptions get reliability tracking — confirm.
    pub reliability_filter: VapidTracker,
}
62
63pub struct Server;
64
65impl Server {
66    pub async fn with_settings(settings: Settings) -> ApiResult<dev::Server> {
67        let metrics = Arc::new(metrics::metrics_from_settings(&settings)?);
68        let bind_address = format!("{}:{}", settings.host, settings.port);
69        let fernet = settings.make_fernet();
70        let endpoint_url = settings.endpoint_url();
71        let reliability_filter = VapidTracker(
72            settings
73                .tracking_keys()
74                .map_err(|e| ApiErrorKind::General(format!("Configuration Error: {e}")))?,
75        );
76        let db_settings = DbSettings {
77            dsn: settings.db_dsn.clone(),
78            db_settings: if settings.db_settings.is_empty() {
79                warn!("❗ Using obsolete message_table and router_table args");
80                // backfill from the older arguments.
81                json!({"message_table": settings.message_table_name, "router_table":settings.router_table_name}).to_string()
82            } else {
83                settings.db_settings.clone()
84            },
85        };
86        let db: Box<dyn DbClient> = match StorageType::from_dsn(&db_settings.dsn) {
87            #[cfg(feature = "bigtable")]
88            StorageType::BigTable => {
89                debug!("Using BigTable");
90                let client = BigTableClientImpl::new(metrics.clone(), &db_settings)?;
91                client.spawn_sweeper(Duration::from_secs(30));
92                Box::new(client)
93            }
94            #[cfg(feature = "postgres")]
95            StorageType::Postgres => {
96                debug!("Using Postgres");
97                let client = PgClientImpl::new(metrics.clone(), &db_settings)?;
98                // client.spawn_sweeper(Duration::from_secs(30));
99                Box::new(client)
100            }
101            #[cfg(feature = "redis")]
102            StorageType::Redis => Box::new(RedisClientImpl::new(metrics.clone(), &db_settings)?),
103            _ => {
104                debug!("No idea what {:?} is", &db_settings.dsn);
105                return Err(ApiErrorKind::General(
106                    "Invalid or Unsupported DSN specified".to_owned(),
107                )
108                .into());
109            }
110        };
111        #[cfg(feature = "reliable_report")]
112        let reliability = Arc::new(
113            PushReliability::new(
114                &settings.reliability_dsn,
115                db.clone(),
116                &metrics,
117                settings.reliability_retry_count,
118            )
119            .map_err(|e| {
120                ApiErrorKind::General(format!("Could not initialize Reliability Report: {e:?}"))
121            })?,
122        );
123        let http = reqwest::ClientBuilder::new()
124            .connect_timeout(Duration::from_millis(settings.connection_timeout_millis))
125            .timeout(Duration::from_millis(settings.request_timeout_millis))
126            .pool_max_idle_per_host(settings.pool_max_idle_per_host)
127            .pool_idle_timeout(Duration::from_secs(settings.pool_idle_timeout_secs))
128            .build()
129            .expect("Could not generate request client");
130        let fcm_router = Arc::new(
131            FcmRouter::new(
132                settings.fcm.clone(),
133                endpoint_url.clone(),
134                http.clone(),
135                metrics.clone(),
136                db.clone(),
137                #[cfg(feature = "reliable_report")]
138                reliability.clone(),
139            )
140            .await?,
141        );
142        let apns_router = Arc::new(
143            ApnsRouter::new(
144                settings.apns.clone(),
145                endpoint_url.clone(),
146                metrics.clone(),
147                db.clone(),
148                #[cfg(feature = "reliable_report")]
149                reliability.clone(),
150            )
151            .await?,
152        );
153        #[cfg(feature = "stub")]
154        let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
155        let in_flight_requests = Arc::new(AtomicUsize::new(0));
156        let app_state = AppState {
157            metrics: metrics.clone(),
158            settings,
159            fernet,
160            db,
161            http,
162            in_flight_requests: in_flight_requests.clone(),
163            fcm_router,
164            apns_router,
165            #[cfg(feature = "stub")]
166            stub_router,
167            #[cfg(feature = "reliable_report")]
168            reliability,
169            reliability_filter,
170        };
171
172        spawn_pool_periodic_reporter(
173            Duration::from_secs(10),
174            app_state.db.clone(),
175            app_state.metrics.clone(),
176        );
177
178        // Periodically report in-flight request gauge
179        {
180            let metrics = app_state.metrics.clone();
181            let in_flight = in_flight_requests;
182            tokio::spawn(async move {
183                let mut interval = tokio::time::interval(Duration::from_secs(10));
184                loop {
185                    interval.tick().await;
186                    let count = in_flight.load(Ordering::Relaxed);
187                    if let Err(e) = cadence::Gauged::gauge(
188                        metrics.as_ref(),
189                        MetricName::InFlightNodeRequests.as_ref(),
190                        count as u64,
191                    ) {
192                        debug!("Failed to report in-flight metric: {}", e);
193                    }
194                }
195            });
196        }
197
198        let server = HttpServer::new(move || {
199            // These have a bad habit of being reset. Specify them explicitly.
200            let cors = Cors::default()
201                .allow_any_origin()
202                .allow_any_header()
203                .allowed_methods(vec![
204                    actix_web::http::Method::DELETE,
205                    actix_web::http::Method::GET,
206                    actix_web::http::Method::POST,
207                    actix_web::http::Method::PUT,
208                ])
209                .max_age(3600);
210            let app = App::new()
211                // Actix 4 recommends wrapping structures wtih web::Data (internally an Arc)
212                .app_data(Data::new(app_state.clone()))
213                // Extractor configuration
214                .app_data(web::PayloadConfig::new(app_state.settings.max_data_bytes))
215                .app_data(web::JsonConfig::default().limit(app_state.settings.max_data_bytes))
216                // Middleware
217                .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404))
218                // Our modified Sentry wrapper which does some blocking of non-reportable errors.
219                .wrap(SentryWrapper::<ApiError>::new(
220                    metrics.clone(),
221                    "api_error".to_owned(),
222                    app_state.settings.disable_sentry,
223                ))
224                .wrap(cors)
225                // Endpoints
226                .service(
227                    web::resource(["/wpush/{api_version}/{token}", "/wpush/{token}"])
228                        .route(web::post().to(webpush_route)),
229                )
230                .service(
231                    web::resource("/m/{message_id}")
232                        .route(web::delete().to(delete_notification_route)),
233                )
234                .service(
235                    web::resource("/v1/{router_type}/{app_id}/registration")
236                        .route(web::post().to(register_uaid_route)),
237                )
238                .service(
239                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}")
240                        .route(web::put().to(update_token_route))
241                        .route(web::get().to(get_channels_route))
242                        .route(web::delete().to(unregister_user_route)),
243                )
244                .service(
245                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}/subscription")
246                        .route(web::post().to(new_channel_route)),
247                )
248                .service(
249                    web::resource(
250                        "/v1/{router_type}/{app_id}/registration/{uaid}/subscription/{chid}",
251                    )
252                    .route(web::delete().to(unregister_channel_route)),
253                )
254                // Health checks
255                .service(web::resource("/status").route(web::get().to(status_route)))
256                .service(web::resource("/health").route(web::get().to(health_route)))
257                // legacy
258                .service(web::resource("/v1/err").route(web::get().to(log_check)))
259                // standardized
260                .service(web::resource("/__error__").route(web::get().to(log_check)))
261                // Dockerflow
262                .service(web::resource("/__heartbeat__").route(web::get().to(health_route)))
263                .service(web::resource("/__lbheartbeat__").route(web::get().to(lb_heartbeat_route)))
264                .service(web::resource("/__version__").route(web::get().to(version_route)));
265            #[cfg(feature = "reliable_report")]
266            // Note: Only the endpoint returns the Prometheus "/metrics" collection report. This report contains all metrics for both
267            // connection and endpoint, inclusive. It is served here mostly for simplicity's sake (since the endpoint handles more
268            // HTTP requests than the connection server, and this will simplify metric collection and reporting.)
269            let app = app.service(
270                web::resource("/metrics")
271                    .route(web::get().to(crate::routes::reliability::report_handler)),
272            );
273            app
274        })
275        .bind(bind_address)?
276        .run();
277
278        Ok(server)
279    }
280}