autoendpoint/
server.rs

1//! Main application server
2#![forbid(unsafe_code)]
3use std::sync::Arc;
4use std::time::Duration;
5
6use actix_cors::Cors;
7use actix_web::{
8    dev, http::StatusCode, middleware::ErrorHandlers, web, web::Data, App, HttpServer,
9};
10use cadence::StatsdClient;
11use fernet::MultiFernet;
12use serde_json::json;
13
14#[cfg(feature = "bigtable")]
15use autopush_common::db::bigtable::BigTableClientImpl;
16#[cfg(feature = "postgres")]
17use autopush_common::db::postgres::PgClientImpl;
18#[cfg(feature = "redis")]
19use autopush_common::db::redis::RedisClientImpl;
20#[cfg(feature = "reliable_report")]
21use autopush_common::reliability::PushReliability;
22use autopush_common::{
23    db::{client::DbClient, spawn_pool_periodic_reporter, DbSettings, StorageType},
24    middleware::sentry::SentryWrapper,
25};
26
27use crate::error::{ApiError, ApiErrorKind, ApiResult};
28use crate::metrics;
29#[cfg(feature = "stub")]
30use crate::routers::stub::router::StubRouter;
31use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter};
32use crate::routes::{
33    health::{health_route, lb_heartbeat_route, log_check, status_route, version_route},
34    registration::{
35        get_channels_route, new_channel_route, register_uaid_route, unregister_channel_route,
36        unregister_user_route, update_token_route,
37    },
38    webpush::{delete_notification_route, webpush_route},
39};
40use crate::settings::Settings;
41use crate::settings::VapidTracker;
42
/// Application state shared with every request handler.
///
/// `Server::with_settings` builds a single instance and registers a clone of it with each
/// actix-web worker (`Data::new(app_state.clone())`), which is why the type derives `Clone`.
#[derive(Clone)]
pub struct AppState {
    /// StatsD metrics client, shared via `Arc`.
    pub metrics: Arc<StatsdClient>,
    /// The application configuration this state was built from (e.g. `max_data_bytes`
    /// is read from here to configure payload/JSON extractor limits).
    pub settings: Settings,
    /// Fernet key set produced by `Settings::make_fernet`.
    /// NOTE(review): presumably used to encrypt/decrypt endpoint tokens — confirm in handlers.
    pub fernet: MultiFernet,
    /// Storage backend selected from the configured DSN (BigTable, Postgres or Redis,
    /// depending on which cargo features are enabled).
    pub db: Box<dyn DbClient>,
    /// Outbound HTTP client built with the connect/request timeouts from `Settings`.
    pub http: reqwest::Client,
    /// FCM router, constructed from `settings.fcm`.
    pub fcm_router: Arc<FcmRouter>,
    /// APNS router, constructed from `settings.apns`.
    pub apns_router: Arc<ApnsRouter>,
    /// Stub router, constructed from `settings.stub` (only with the `stub` feature).
    #[cfg(feature = "stub")]
    pub stub_router: Arc<StubRouter>,
    /// Push reliability reporter (only with the `reliable_report` feature).
    #[cfg(feature = "reliable_report")]
    pub reliability: Arc<PushReliability>,
    /// Wrapper around the keys returned by `Settings::tracking_keys`.
    /// NOTE(review): presumably selects which VAPID-signed messages get reliability
    /// tracking — confirm against `VapidTracker`.
    pub reliability_filter: VapidTracker,
}
59
60pub struct Server;
61
62impl Server {
63    pub async fn with_settings(settings: Settings) -> ApiResult<dev::Server> {
64        let metrics = Arc::new(metrics::metrics_from_settings(&settings)?);
65        let bind_address = format!("{}:{}", settings.host, settings.port);
66        let fernet = settings.make_fernet();
67        let endpoint_url = settings.endpoint_url();
68        let reliability_filter = VapidTracker(
69            settings
70                .tracking_keys()
71                .map_err(|e| ApiErrorKind::General(format!("Configuration Error: {e}")))?,
72        );
73        let db_settings = DbSettings {
74            dsn: settings.db_dsn.clone(),
75            db_settings: if settings.db_settings.is_empty() {
76                warn!("❗ Using obsolete message_table and router_table args");
77                // backfill from the older arguments.
78                json!({"message_table": settings.message_table_name, "router_table":settings.router_table_name}).to_string()
79            } else {
80                settings.db_settings.clone()
81            },
82        };
83        let db: Box<dyn DbClient> = match StorageType::from_dsn(&db_settings.dsn) {
84            #[cfg(feature = "bigtable")]
85            StorageType::BigTable => {
86                debug!("Using BigTable");
87                let client = BigTableClientImpl::new(metrics.clone(), &db_settings)?;
88                client.spawn_sweeper(Duration::from_secs(30));
89                Box::new(client)
90            }
91            #[cfg(feature = "postgres")]
92            StorageType::Postgres => {
93                debug!("Using Postgres");
94                let client = PgClientImpl::new(metrics.clone(), &db_settings)?;
95                // client.spawn_sweeper(Duration::from_secs(30));
96                Box::new(client)
97            }
98            #[cfg(feature = "redis")]
99            StorageType::Redis => Box::new(RedisClientImpl::new(metrics.clone(), &db_settings)?),
100            _ => {
101                debug!("No idea what {:?} is", &db_settings.dsn);
102                return Err(ApiErrorKind::General(
103                    "Invalid or Unsupported DSN specified".to_owned(),
104                )
105                .into());
106            }
107        };
108        #[cfg(feature = "reliable_report")]
109        let reliability = Arc::new(
110            PushReliability::new(
111                &settings.reliability_dsn,
112                db.clone(),
113                &metrics,
114                settings.reliability_retry_count,
115            )
116            .map_err(|e| {
117                ApiErrorKind::General(format!("Could not initialize Reliability Report: {e:?}"))
118            })?,
119        );
120        let http = reqwest::ClientBuilder::new()
121            .connect_timeout(Duration::from_millis(settings.connection_timeout_millis))
122            .timeout(Duration::from_millis(settings.request_timeout_millis))
123            .build()
124            .expect("Could not generate request client");
125        let fcm_router = Arc::new(
126            FcmRouter::new(
127                settings.fcm.clone(),
128                endpoint_url.clone(),
129                http.clone(),
130                metrics.clone(),
131                db.clone(),
132                #[cfg(feature = "reliable_report")]
133                reliability.clone(),
134            )
135            .await?,
136        );
137        let apns_router = Arc::new(
138            ApnsRouter::new(
139                settings.apns.clone(),
140                endpoint_url.clone(),
141                metrics.clone(),
142                db.clone(),
143                #[cfg(feature = "reliable_report")]
144                reliability.clone(),
145            )
146            .await?,
147        );
148        #[cfg(feature = "stub")]
149        let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
150        let app_state = AppState {
151            metrics: metrics.clone(),
152            settings,
153            fernet,
154            db,
155            http,
156            fcm_router,
157            apns_router,
158            #[cfg(feature = "stub")]
159            stub_router,
160            #[cfg(feature = "reliable_report")]
161            reliability,
162            reliability_filter,
163        };
164
165        spawn_pool_periodic_reporter(
166            Duration::from_secs(10),
167            app_state.db.clone(),
168            app_state.metrics.clone(),
169        );
170
171        let server = HttpServer::new(move || {
172            // These have a bad habit of being reset. Specify them explicitly.
173            let cors = Cors::default()
174                .allow_any_origin()
175                .allow_any_header()
176                .allowed_methods(vec![
177                    actix_web::http::Method::DELETE,
178                    actix_web::http::Method::GET,
179                    actix_web::http::Method::POST,
180                    actix_web::http::Method::PUT,
181                ])
182                .max_age(3600);
183            let app = App::new()
184                // Actix 4 recommends wrapping structures wtih web::Data (internally an Arc)
185                .app_data(Data::new(app_state.clone()))
186                // Extractor configuration
187                .app_data(web::PayloadConfig::new(app_state.settings.max_data_bytes))
188                .app_data(web::JsonConfig::default().limit(app_state.settings.max_data_bytes))
189                // Middleware
190                .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404))
191                // Our modified Sentry wrapper which does some blocking of non-reportable errors.
192                .wrap(SentryWrapper::<ApiError>::new(
193                    metrics.clone(),
194                    "api_error".to_owned(),
195                ))
196                .wrap(cors)
197                // Endpoints
198                .service(
199                    web::resource(["/wpush/{api_version}/{token}", "/wpush/{token}"])
200                        .route(web::post().to(webpush_route)),
201                )
202                .service(
203                    web::resource("/m/{message_id}")
204                        .route(web::delete().to(delete_notification_route)),
205                )
206                .service(
207                    web::resource("/v1/{router_type}/{app_id}/registration")
208                        .route(web::post().to(register_uaid_route)),
209                )
210                .service(
211                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}")
212                        .route(web::put().to(update_token_route))
213                        .route(web::get().to(get_channels_route))
214                        .route(web::delete().to(unregister_user_route)),
215                )
216                .service(
217                    web::resource("/v1/{router_type}/{app_id}/registration/{uaid}/subscription")
218                        .route(web::post().to(new_channel_route)),
219                )
220                .service(
221                    web::resource(
222                        "/v1/{router_type}/{app_id}/registration/{uaid}/subscription/{chid}",
223                    )
224                    .route(web::delete().to(unregister_channel_route)),
225                )
226                // Health checks
227                .service(web::resource("/status").route(web::get().to(status_route)))
228                .service(web::resource("/health").route(web::get().to(health_route)))
229                // legacy
230                .service(web::resource("/v1/err").route(web::get().to(log_check)))
231                // standardized
232                .service(web::resource("/__error__").route(web::get().to(log_check)))
233                // Dockerflow
234                .service(web::resource("/__heartbeat__").route(web::get().to(health_route)))
235                .service(web::resource("/__lbheartbeat__").route(web::get().to(lb_heartbeat_route)))
236                .service(web::resource("/__version__").route(web::get().to(version_route)));
237            #[cfg(feature = "reliable_report")]
238            // Note: Only the endpoint returns the Prometheus "/metrics" collection report. This report contains all metrics for both
239            // connection and endpoint, inclusive. It is served here mostly for simplicity's sake (since the endpoint handles more
240            // HTTP requests than the connection server, and this will simplify metric collection and reporting.)
241            let app = app.service(
242                web::resource("/metrics")
243                    .route(web::get().to(crate::routes::reliability::report_handler)),
244            );
245            app
246        })
247        .bind(bind_address)?
248        .run();
249
250        Ok(server)
251    }
252}