autoendpoint/
server.rs

1//! Main application server
2#![forbid(unsafe_code)]
3use std::sync::Arc;
4use std::time::Duration;
5
6use actix_cors::Cors;
7use actix_web::{
8    dev, http::StatusCode, middleware::ErrorHandlers, web, web::Data, App, HttpServer,
9};
10use cadence::StatsdClient;
11use fernet::MultiFernet;
12use serde_json::json;
13
14#[cfg(feature = "bigtable")]
15use autopush_common::db::bigtable::BigTableClientImpl;
16#[cfg(feature = "redis")]
17use autopush_common::db::redis::RedisClientImpl;
18#[cfg(feature = "reliable_report")]
19use autopush_common::reliability::PushReliability;
20use autopush_common::{
21    db::{client::DbClient, spawn_pool_periodic_reporter, DbSettings, StorageType},
22    middleware::sentry::SentryWrapper,
23};
24
25use crate::error::{ApiError, ApiErrorKind, ApiResult};
26use crate::metrics;
27#[cfg(feature = "stub")]
28use crate::routers::stub::router::StubRouter;
29use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter};
30use crate::routes::{
31    health::{health_route, lb_heartbeat_route, log_check, status_route, version_route},
32    registration::{
33        get_channels_route, new_channel_route, register_uaid_route, unregister_channel_route,
34        unregister_user_route, update_token_route,
35    },
36    webpush::{delete_notification_route, webpush_route},
37};
38use crate::settings::Settings;
39use crate::settings::VapidTracker;
40
41#[derive(Clone)]
42pub struct AppState {
43    /// Server Data
44    pub metrics: Arc<StatsdClient>,
45    pub settings: Settings,
46    pub fernet: MultiFernet,
47    pub db: Box<dyn DbClient>,
48    pub http: reqwest::Client,
49    pub fcm_router: Arc<FcmRouter>,
50    pub apns_router: Arc<ApnsRouter>,
51    #[cfg(feature = "stub")]
52    pub stub_router: Arc<StubRouter>,
53    #[cfg(feature = "reliable_report")]
54    pub reliability: Arc<PushReliability>,
55    pub reliability_filter: VapidTracker,
56}
57
58pub struct Server;
59
60impl Server {
61    pub async fn with_settings(settings: Settings) -> ApiResult<dev::Server> {
62        let metrics = Arc::new(metrics::metrics_from_settings(&settings)?);
63        let bind_address = format!("{}:{}", settings.host, settings.port);
64        let fernet = settings.make_fernet();
65        let endpoint_url = settings.endpoint_url();
66        let reliability_filter = VapidTracker(
67            settings
68                .tracking_keys()
69                .map_err(|e| ApiErrorKind::General(format!("Configuration Error: {e}")))?,
70        );
71        let db_settings = DbSettings {
72            dsn: settings.db_dsn.clone(),
73            db_settings: if settings.db_settings.is_empty() {
74                warn!("❗ Using obsolete message_table and router_table args");
75                // backfill from the older arguments.
76                json!({"message_table": settings.message_table_name, "router_table":settings.router_table_name}).to_string()
77            } else {
78                settings.db_settings.clone()
79            },
80        };
81        let db: Box<dyn DbClient> = match StorageType::from_dsn(&db_settings.dsn) {
82            #[cfg(feature = "bigtable")]
83            StorageType::BigTable => {
84                debug!("Using BigTable");
85                let client = BigTableClientImpl::new(metrics.clone(), &db_settings)?;
86                client.spawn_sweeper(Duration::from_secs(30));
87                Box::new(client)
88            }
89            #[cfg(feature = "redis")]
90            StorageType::Redis => Box::new(RedisClientImpl::new(metrics.clone(), &db_settings)?),
91            _ => {
92                debug!("No idea what {:?} is", &db_settings.dsn);
93                return Err(ApiErrorKind::General(
94                    "Invalid or Unsupported DSN specified".to_owned(),
95                )
96                .into());
97            }
98        };
99        #[cfg(feature = "reliable_report")]
100        let reliability = Arc::new(
101            PushReliability::new(
102                &settings.reliability_dsn,
103                db.clone(),
104                &metrics,
105                settings.reliability_retry_count,
106            )
107            .map_err(|e| {
108                ApiErrorKind::General(format!("Could not initialize Reliability Report: {e:?}"))
109            })?,
110        );
111        let http = reqwest::ClientBuilder::new()
112            .connect_timeout(Duration::from_millis(settings.connection_timeout_millis))
113            .timeout(Duration::from_millis(settings.request_timeout_millis))
114            .build()
115            .expect("Could not generate request client");
116        let fcm_router = Arc::new(
117            FcmRouter::new(
118                settings.fcm.clone(),
119                endpoint_url.clone(),
120                http.clone(),
121                metrics.clone(),
122                db.clone(),
123                #[cfg(feature = "reliable_report")]
124                reliability.clone(),
125            )
126            .await?,
127        );
128        let apns_router = Arc::new(
129            ApnsRouter::new(
130                settings.apns.clone(),
131                endpoint_url.clone(),
132                metrics.clone(),
133                db.clone(),
134                #[cfg(feature = "reliable_report")]
135                reliability.clone(),
136            )
137            .await?,
138        );
139        #[cfg(feature = "stub")]
140        let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
141        let app_state = AppState {
142            metrics: metrics.clone(),
143            settings,
144            fernet,
145            db,
146            http,
147            fcm_router,
148            apns_router,
149            #[cfg(feature = "stub")]
150            stub_router,
151            #[cfg(feature = "reliable_report")]
152            reliability,
153            reliability_filter,
154        };
155
156        spawn_pool_periodic_reporter(
157            Duration::from_secs(10),
158            app_state.db.clone(),
159            app_state.metrics.clone(),
160        );
161
162        let server = HttpServer::new(move || {
163            // These have a bad habit of being reset. Specify them explicitly.
164            let cors = Cors::default()
165                .allow_any_origin()
166                .allow_any_header()
167                .allowed_methods(vec![
168                    actix_web::http::Method::DELETE,
169                    actix_web::http::Method::GET,
170                    actix_web::http::Method::POST,
171                    actix_web::http::Method::PUT,
172                ])
173                .max_age(3600);
174            let app = App::new()
175                // Actix 4 recommends wrapping structures wtih web::Data (internally an Arc)
176                .app_data(Data::new(app_state.clone()))
177                // Extractor configuration
178                .app_data(web::PayloadConfig::new(app_state.settings.max_data_bytes))
179                .app_data(web::JsonConfig::default().limit(app_state.settings.max_data_bytes))
180                // Middleware
181                .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404))
182                // Our modified Sentry wrapper which does some blocking of non-reportable errors.
183                .wrap(SentryWrapper::<ApiError>::new(
184                    metrics.clone(),
185                    "api_error".to_owned(),
186                ))
187                .wrap(cors)
188                // Endpoints
189                .service(
190                    web::resource(["/wpush/{api_version}/{token}", "/wpush/{token}"])
191                        .route(web::post().to(webpush_route)),
192                )
193                .service(
194                    web::resource("/m/{message_id}")
195                        .route(web::delete().to(delete_notification_route)),
196                )
197                .service(
198                    web::resource("/v1/{router_type}/{app_id}/registration")
199                        .route(web::post().to(register_uaid_route)),
200                )
201                .service(
202                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}")
203                        .route(web::put().to(update_token_route))
204                        .route(web::get().to(get_channels_route))
205                        .route(web::delete().to(unregister_user_route)),
206                )
207                .service(
208                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}/subscription")
209                        .route(web::post().to(new_channel_route)),
210                )
211                .service(
212                    web::resource(
213                        "/v1/{router_type}/{app_id}/registration/{uaid}/subscription/{chid}",
214                    )
215                    .route(web::delete().to(unregister_channel_route)),
216                )
217                // Health checks
218                .service(web::resource("/status").route(web::get().to(status_route)))
219                .service(web::resource("/health").route(web::get().to(health_route)))
220                // legacy
221                .service(web::resource("/v1/err").route(web::get().to(log_check)))
222                // standardized
223                .service(web::resource("/__error__").route(web::get().to(log_check)))
224                // Dockerflow
225                .service(web::resource("/__heartbeat__").route(web::get().to(health_route)))
226                .service(web::resource("/__lbheartbeat__").route(web::get().to(lb_heartbeat_route)))
227                .service(web::resource("/__version__").route(web::get().to(version_route)));
228            #[cfg(feature = "reliable_report")]
229            // Note: Only the endpoint returns the Prometheus "/metrics" collection report. This report contains all metrics for both
230            // connection and endpoint, inclusive. It is served here mostly for simplicity's sake (since the endpoint handles more
231            // HTTP requests than the connection server, and this will simplify metric collection and reporting.)
232            let app = app.service(
233                web::resource("/metrics")
234                    .route(web::get().to(crate::routes::reliability::report_handler)),
235            );
236            app
237        })
238        .bind(bind_address)?
239        .run();
240
241        Ok(server)
242    }
243}