autoendpoint/
server.rs

1//! Main application server
2#![forbid(unsafe_code)]
3use std::sync::Arc;
4use std::time::Duration;
5
6use actix_cors::Cors;
7use actix_web::{
8    dev, http::StatusCode, middleware::ErrorHandlers, web, web::Data, App, HttpServer,
9};
10use cadence::StatsdClient;
11use fernet::MultiFernet;
12use serde_json::json;
13
14#[cfg(feature = "bigtable")]
15use autopush_common::db::bigtable::BigTableClientImpl;
16#[cfg(feature = "reliable_report")]
17use autopush_common::reliability::PushReliability;
18use autopush_common::{
19    db::{client::DbClient, spawn_pool_periodic_reporter, DbSettings, StorageType},
20    middleware::sentry::SentryWrapper,
21};
22
23use crate::error::{ApiError, ApiErrorKind, ApiResult};
24use crate::metrics;
25#[cfg(feature = "stub")]
26use crate::routers::stub::router::StubRouter;
27use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter};
28use crate::routes::{
29    health::{health_route, lb_heartbeat_route, log_check, status_route, version_route},
30    registration::{
31        get_channels_route, new_channel_route, register_uaid_route, unregister_channel_route,
32        unregister_user_route, update_token_route,
33    },
34    webpush::{delete_notification_route, webpush_route},
35};
36use crate::settings::Settings;
37use crate::settings::VapidTracker;
38
39#[derive(Clone)]
40pub struct AppState {
41    /// Server Data
42    pub metrics: Arc<StatsdClient>,
43    pub settings: Settings,
44    pub fernet: MultiFernet,
45    pub db: Box<dyn DbClient>,
46    pub http: reqwest::Client,
47    pub fcm_router: Arc<FcmRouter>,
48    pub apns_router: Arc<ApnsRouter>,
49    #[cfg(feature = "stub")]
50    pub stub_router: Arc<StubRouter>,
51    #[cfg(feature = "reliable_report")]
52    pub reliability: Arc<PushReliability>,
53    pub reliability_filter: VapidTracker,
54}
55
56pub struct Server;
57
58impl Server {
59    pub async fn with_settings(settings: Settings) -> ApiResult<dev::Server> {
60        let metrics = Arc::new(metrics::metrics_from_settings(&settings)?);
61        let bind_address = format!("{}:{}", settings.host, settings.port);
62        let fernet = settings.make_fernet();
63        let endpoint_url = settings.endpoint_url();
64        let reliability_filter = VapidTracker(
65            settings
66                .tracking_keys()
67                .map_err(|e| ApiErrorKind::General(format!("Configuration Error: {e}")))?,
68        );
69        let db_settings = DbSettings {
70            dsn: settings.db_dsn.clone(),
71            db_settings: if settings.db_settings.is_empty() {
72                warn!("❗ Using obsolete message_table and router_table args");
73                // backfill from the older arguments.
74                json!({"message_table": settings.message_table_name, "router_table":settings.router_table_name}).to_string()
75            } else {
76                settings.db_settings.clone()
77            },
78        };
79        let db: Box<dyn DbClient> = match StorageType::from_dsn(&db_settings.dsn) {
80            #[cfg(feature = "bigtable")]
81            StorageType::BigTable => {
82                debug!("Using BigTable");
83                let client = BigTableClientImpl::new(metrics.clone(), &db_settings)?;
84                client.spawn_sweeper(Duration::from_secs(30));
85                Box::new(client)
86            }
87            _ => {
88                debug!("No idea what {:?} is", &db_settings.dsn);
89                return Err(ApiErrorKind::General(
90                    "Invalid or Unsupported DSN specified".to_owned(),
91                )
92                .into());
93            }
94        };
95        #[cfg(feature = "reliable_report")]
96        let reliability = Arc::new(
97            PushReliability::new(
98                &settings.reliability_dsn,
99                db.clone(),
100                &metrics,
101                settings.reliability_retry_count,
102            )
103            .map_err(|e| {
104                ApiErrorKind::General(format!("Could not initialize Reliability Report: {e:?}"))
105            })?,
106        );
107        let http = reqwest::ClientBuilder::new()
108            .connect_timeout(Duration::from_millis(settings.connection_timeout_millis))
109            .timeout(Duration::from_millis(settings.request_timeout_millis))
110            .build()
111            .expect("Could not generate request client");
112        let fcm_router = Arc::new(
113            FcmRouter::new(
114                settings.fcm.clone(),
115                endpoint_url.clone(),
116                http.clone(),
117                metrics.clone(),
118                db.clone(),
119                #[cfg(feature = "reliable_report")]
120                reliability.clone(),
121            )
122            .await?,
123        );
124        let apns_router = Arc::new(
125            ApnsRouter::new(
126                settings.apns.clone(),
127                endpoint_url.clone(),
128                metrics.clone(),
129                db.clone(),
130                #[cfg(feature = "reliable_report")]
131                reliability.clone(),
132            )
133            .await?,
134        );
135        #[cfg(feature = "stub")]
136        let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
137        let app_state = AppState {
138            metrics: metrics.clone(),
139            settings,
140            fernet,
141            db,
142            http,
143            fcm_router,
144            apns_router,
145            #[cfg(feature = "stub")]
146            stub_router,
147            #[cfg(feature = "reliable_report")]
148            reliability,
149            reliability_filter,
150        };
151
152        spawn_pool_periodic_reporter(
153            Duration::from_secs(10),
154            app_state.db.clone(),
155            app_state.metrics.clone(),
156        );
157
158        let server = HttpServer::new(move || {
159            // These have a bad habit of being reset. Specify them explicitly.
160            let cors = Cors::default()
161                .allow_any_origin()
162                .allow_any_header()
163                .allowed_methods(vec![
164                    actix_web::http::Method::DELETE,
165                    actix_web::http::Method::GET,
166                    actix_web::http::Method::POST,
167                    actix_web::http::Method::PUT,
168                ])
169                .max_age(3600);
170            let app = App::new()
171                // Actix 4 recommends wrapping structures wtih web::Data (internally an Arc)
172                .app_data(Data::new(app_state.clone()))
173                // Extractor configuration
174                .app_data(web::PayloadConfig::new(app_state.settings.max_data_bytes))
175                .app_data(web::JsonConfig::default().limit(app_state.settings.max_data_bytes))
176                // Middleware
177                .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404))
178                // Our modified Sentry wrapper which does some blocking of non-reportable errors.
179                .wrap(SentryWrapper::<ApiError>::new(
180                    metrics.clone(),
181                    "api_error".to_owned(),
182                ))
183                .wrap(cors)
184                // Endpoints
185                .service(
186                    web::resource(["/wpush/{api_version}/{token}", "/wpush/{token}"])
187                        .route(web::post().to(webpush_route)),
188                )
189                .service(
190                    web::resource("/m/{message_id}")
191                        .route(web::delete().to(delete_notification_route)),
192                )
193                .service(
194                    web::resource("/v1/{router_type}/{app_id}/registration")
195                        .route(web::post().to(register_uaid_route)),
196                )
197                .service(
198                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}")
199                        .route(web::put().to(update_token_route))
200                        .route(web::get().to(get_channels_route))
201                        .route(web::delete().to(unregister_user_route)),
202                )
203                .service(
204                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}/subscription")
205                        .route(web::post().to(new_channel_route)),
206                )
207                .service(
208                    web::resource(
209                        "/v1/{router_type}/{app_id}/registration/{uaid}/subscription/{chid}",
210                    )
211                    .route(web::delete().to(unregister_channel_route)),
212                )
213                // Health checks
214                .service(web::resource("/status").route(web::get().to(status_route)))
215                .service(web::resource("/health").route(web::get().to(health_route)))
216                // legacy
217                .service(web::resource("/v1/err").route(web::get().to(log_check)))
218                // standardized
219                .service(web::resource("/__error__").route(web::get().to(log_check)))
220                // Dockerflow
221                .service(web::resource("/__heartbeat__").route(web::get().to(health_route)))
222                .service(web::resource("/__lbheartbeat__").route(web::get().to(lb_heartbeat_route)))
223                .service(web::resource("/__version__").route(web::get().to(version_route)));
224            #[cfg(feature = "reliable_report")]
225            // Note: Only the endpoint returns the Prometheus "/metrics" collection report. This report contains all metrics for both
226            // connection and endpoint, inclusive. It is served here mostly for simplicity's sake (since the endpoint handles more
227            // HTTP requests than the connection server, and this will simplify metric collection and reporting.)
228            let app = app.service(
229                web::resource("/metrics")
230                    .route(web::get().to(crate::routes::reliability::report_handler)),
231            );
232            app
233        })
234        .bind(bind_address)?
235        .run();
236
237        Ok(server)
238    }
239}