1#![forbid(unsafe_code)]
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::time::Duration;
6
7use actix_cors::Cors;
8use actix_web::{
9 App, HttpServer, dev, http::StatusCode, middleware::ErrorHandlers, web, web::Data,
10};
11use cadence::StatsdClient;
12use fernet::MultiFernet;
13use serde_json::json;
14
15#[cfg(feature = "bigtable")]
16use autopush_common::db::bigtable::BigTableClientImpl;
17#[cfg(feature = "postgres")]
18use autopush_common::db::postgres::PgClientImpl;
19#[cfg(feature = "redis")]
20use autopush_common::db::redis::RedisClientImpl;
21#[cfg(feature = "reliable_report")]
22use autopush_common::reliability::PushReliability;
23use autopush_common::{
24 db::{DbSettings, StorageType, client::DbClient, spawn_pool_periodic_reporter},
25 metric_name::MetricName,
26 middleware::sentry::SentryWrapper,
27};
28
29use crate::error::{ApiError, ApiErrorKind, ApiResult};
30use crate::metrics;
31#[cfg(feature = "stub")]
32use crate::routers::stub::router::StubRouter;
33use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter};
34use crate::routes::{
35 health::{health_route, lb_heartbeat_route, log_check, status_route, version_route},
36 registration::{
37 get_channels_route, new_channel_route, register_uaid_route, unregister_channel_route,
38 unregister_user_route, update_token_route,
39 },
40 webpush::{delete_notification_route, webpush_route},
41};
42use crate::settings::Settings;
43use crate::settings::VapidTracker;
44
/// Shared state made available to every actix-web handler through `Data<AppState>`.
///
/// `Server::with_settings` clones this value into each `App` built by the
/// `HttpServer` factory closure. The `Arc`-wrapped fields are shared between
/// those clones rather than duplicated.
#[derive(Clone)]
pub struct AppState {
    /// StatsD client used for application metrics (also handed to the
    /// storage backend, the routers and the Sentry middleware).
    pub metrics: Arc<StatsdClient>,
    /// Settings this state was built from; e.g. `max_data_bytes` bounds
    /// request payload sizes and `disable_sentry` toggles Sentry reporting.
    pub settings: Settings,
    /// Fernet key set produced by `Settings::make_fernet`.
    // NOTE(review): presumably used to encrypt/decrypt endpoint tokens — confirm at call sites.
    pub fernet: MultiFernet,
    /// Storage backend selected from the configured DSN (BigTable, Postgres
    /// or Redis, depending on enabled features).
    pub db: Box<dyn DbClient>,
    /// Outbound HTTP client (configured with timeouts and pool limits in
    /// `Server::with_settings`).
    pub http: reqwest::Client,
    /// Number of requests currently in flight; periodically reported as the
    /// `InFlightNodeRequests` gauge.
    // NOTE(review): the increment/decrement sites are outside this file — confirm.
    pub in_flight_requests: Arc<AtomicUsize>,
    /// Counter of subscription updates currently being processed.
    // NOTE(review): neither writers nor readers are visible here — confirm usage.
    pub in_process_subscription_updates: Arc<AtomicUsize>,
    /// Router for Firebase Cloud Messaging bridged subscriptions.
    pub fcm_router: Arc<FcmRouter>,
    /// Router for Apple Push Notification service bridged subscriptions.
    pub apns_router: Arc<ApnsRouter>,
    /// Stub router, compiled in only with the `stub` feature.
    #[cfg(feature = "stub")]
    pub stub_router: Arc<StubRouter>,
    /// Push reliability reporter, compiled in only with the `reliable_report` feature.
    #[cfg(feature = "reliable_report")]
    pub reliability: Arc<PushReliability>,
    /// VAPID keys whose messages should be tracked, taken from
    /// `Settings::tracking_keys`.
    pub reliability_filter: VapidTracker,
}
65
#[cfg(test)]
impl AppState {
    /// Build an `AppState` for unit tests around the given mock database.
    ///
    /// Uses a fixed Fernet auth key, a sink metrics client, default FCM/APNS
    /// (and, when enabled, stub) router settings pointing at
    /// `https://example.com`, and fresh zeroed request counters. Any router
    /// construction failure panics, which is acceptable for test setup.
    pub(crate) async fn test_default(mock_db: autopush_common::db::mock::MockDbClient) -> Self {
        // A fixed, valid key so `make_fernet` behaves deterministically in tests.
        let settings = Settings {
            auth_keys: "HJVPy4ZwF4Yz_JdvXTL8hRcwIhv742vC60Tg5Ycrvw8=".to_owned(),
            ..Default::default()
        };
        let metrics = Arc::new(crate::metrics::Metrics::sink());
        let db = mock_db.into_boxed_arc();
        #[cfg(feature = "reliable_report")]
        let reliability =
            Arc::new(PushReliability::new(&None, db.box_clone(), &metrics, 0).unwrap());
        let fernet = settings.make_fernet();
        let http = reqwest::Client::new();
        let example_url = url::Url::parse("https://example.com").unwrap();

        // Each router receives its own handles; the routers never see `settings`.
        let fcm = FcmRouter::new(
            crate::routers::fcm::settings::FcmSettings::default(),
            example_url.clone(),
            reqwest::Client::new(),
            metrics.clone(),
            db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability.clone(),
        )
        .await
        .unwrap();
        let apns = ApnsRouter::new(
            crate::routers::apns::settings::ApnsSettings::default(),
            example_url,
            metrics.clone(),
            db.clone(),
            #[cfg(feature = "reliable_report")]
            reliability.clone(),
        )
        .await
        .unwrap();
        #[cfg(feature = "stub")]
        let stub =
            StubRouter::new(crate::routers::stub::settings::StubSettings::default()).unwrap();

        Self {
            metrics,
            settings,
            fernet,
            db,
            http,
            in_flight_requests: Arc::new(AtomicUsize::new(0)),
            in_process_subscription_updates: Arc::new(AtomicUsize::new(0)),
            fcm_router: Arc::new(fcm),
            apns_router: Arc::new(apns),
            #[cfg(feature = "stub")]
            stub_router: Arc::new(stub),
            #[cfg(feature = "reliable_report")]
            reliability,
            reliability_filter: VapidTracker(Vec::new()),
        }
    }
}
123
/// Builder for the push endpoint HTTP server; see [`Server::with_settings`].
pub struct Server;
125
126impl Server {
127 pub async fn with_settings(settings: Settings) -> ApiResult<dev::Server> {
128 let metrics = Arc::new(metrics::metrics_from_settings(&settings)?);
129 let bind_address = format!("{}:{}", settings.host, settings.port);
130 let fernet = settings.make_fernet();
131 let endpoint_url = settings.endpoint_url();
132 let reliability_filter = VapidTracker(
133 settings
134 .tracking_keys()
135 .map_err(|e| ApiErrorKind::General(format!("Configuration Error: {e}")))?,
136 );
137 let db_settings = DbSettings {
138 dsn: settings.db_dsn.clone(),
139 db_settings: if settings.db_settings.is_empty() {
140 warn!("❗ Using obsolete message_table and router_table args");
141 json!({"message_table": settings.message_table_name, "router_table":settings.router_table_name}).to_string()
143 } else {
144 settings.db_settings.clone()
145 },
146 };
147 let db: Box<dyn DbClient> = match StorageType::from_dsn(&db_settings.dsn) {
148 #[cfg(feature = "bigtable")]
149 StorageType::BigTable => {
150 debug!("Using BigTable");
151 let client = BigTableClientImpl::new(metrics.clone(), &db_settings)?;
152 client.spawn_sweeper(Duration::from_secs(30));
153 Box::new(client)
154 }
155 #[cfg(feature = "postgres")]
156 StorageType::Postgres => {
157 debug!("Using Postgres");
158 let client = PgClientImpl::new(metrics.clone(), &db_settings)?;
159 Box::new(client)
161 }
162 #[cfg(feature = "redis")]
163 StorageType::Redis => Box::new(RedisClientImpl::new(metrics.clone(), &db_settings)?),
164 _ => {
165 debug!("No idea what {:?} is", &db_settings.dsn);
166 return Err(ApiErrorKind::General(
167 "Invalid or Unsupported DSN specified".to_owned(),
168 )
169 .into());
170 }
171 };
172 #[cfg(feature = "reliable_report")]
173 let reliability = Arc::new(
174 PushReliability::new(
175 &settings.reliability_dsn,
176 db.clone(),
177 &metrics,
178 settings.reliability_retry_count,
179 )
180 .map_err(|e| {
181 ApiErrorKind::General(format!("Could not initialize Reliability Report: {e:?}"))
182 })?,
183 );
184 let http = reqwest::ClientBuilder::new()
185 .connect_timeout(Duration::from_millis(settings.connection_timeout_millis))
186 .timeout(Duration::from_millis(settings.request_timeout_millis))
187 .pool_max_idle_per_host(settings.pool_max_idle_per_host)
188 .pool_idle_timeout(Duration::from_secs(settings.pool_idle_timeout_secs))
189 .build()
190 .expect("Could not generate request client");
191 let fcm_router = Arc::new(
192 FcmRouter::new(
193 settings.fcm.clone(),
194 endpoint_url.clone(),
195 http.clone(),
196 metrics.clone(),
197 db.clone(),
198 #[cfg(feature = "reliable_report")]
199 reliability.clone(),
200 )
201 .await?,
202 );
203 let apns_router = Arc::new(
204 ApnsRouter::new(
205 settings.apns.clone(),
206 endpoint_url.clone(),
207 metrics.clone(),
208 db.clone(),
209 #[cfg(feature = "reliable_report")]
210 reliability.clone(),
211 )
212 .await?,
213 );
214 #[cfg(feature = "stub")]
215 let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?);
216 let in_flight_requests = Arc::new(AtomicUsize::new(0));
217 let in_process_subscription_updates = Arc::new(AtomicUsize::new(0));
218 let app_state = AppState {
219 metrics: metrics.clone(),
220 settings,
221 fernet,
222 db,
223 http,
224 in_flight_requests: in_flight_requests.clone(),
225 in_process_subscription_updates: in_process_subscription_updates.clone(),
226 fcm_router,
227 apns_router,
228 #[cfg(feature = "stub")]
229 stub_router,
230 #[cfg(feature = "reliable_report")]
231 reliability,
232 reliability_filter,
233 };
234
235 spawn_pool_periodic_reporter(
236 Duration::from_secs(10),
237 app_state.db.clone(),
238 app_state.metrics.clone(),
239 );
240
241 {
243 let metrics = app_state.metrics.clone();
244 let in_flight = in_flight_requests;
245 tokio::spawn(async move {
246 let mut interval = tokio::time::interval(Duration::from_secs(10));
247 loop {
248 interval.tick().await;
249 let count = in_flight.load(Ordering::Relaxed);
250 if let Err(e) = cadence::Gauged::gauge(
251 metrics.as_ref(),
252 MetricName::InFlightNodeRequests.as_ref(),
253 count as u64,
254 ) {
255 debug!("Failed to report in-flight metric: {}", e);
256 }
257 }
258 });
259 }
260
261 let server = HttpServer::new(move || {
262 let cors = Cors::default()
264 .allow_any_origin()
265 .allow_any_header()
266 .allowed_methods(vec![
267 actix_web::http::Method::DELETE,
268 actix_web::http::Method::GET,
269 actix_web::http::Method::POST,
270 actix_web::http::Method::PUT,
271 ])
272 .max_age(3600);
273 let app = App::new()
274 .app_data(Data::new(app_state.clone()))
276 .app_data(web::PayloadConfig::new(app_state.settings.max_data_bytes))
278 .app_data(web::JsonConfig::default().limit(app_state.settings.max_data_bytes))
279 .wrap(ErrorHandlers::new().handler(StatusCode::NOT_FOUND, ApiError::render_404))
281 .wrap(SentryWrapper::<ApiError>::new(
283 metrics.clone(),
284 "api_error".to_owned(),
285 app_state.settings.disable_sentry,
286 ))
287 .wrap(cors)
288 .service(
290 web::resource(["/wpush/{api_version}/{token}", "/wpush/{token}"])
291 .route(web::post().to(webpush_route)),
292 )
293 .service(
294 web::resource("/m/{message_id}")
295 .route(web::delete().to(delete_notification_route)),
296 )
297 .service(
298 web::resource("/v1/{router_type}/{app_id}/registration")
299 .route(web::post().to(register_uaid_route)),
300 )
301 .service(
302 web::resource("/v1/{router_type}/{app_id}/registration/{uaid}")
303 .route(web::put().to(update_token_route))
304 .route(web::get().to(get_channels_route))
305 .route(web::delete().to(unregister_user_route)),
306 )
307 .service(
308 web::resource("/v1/{router_type}/{app_id}/registration/{uaid}/subscription")
309 .route(web::post().to(new_channel_route)),
310 )
311 .service(
312 web::resource(
313 "/v1/{router_type}/{app_id}/registration/{uaid}/subscription/{chid}",
314 )
315 .route(web::delete().to(unregister_channel_route)),
316 )
317 .service(web::resource("/status").route(web::get().to(status_route)))
319 .service(web::resource("/health").route(web::get().to(health_route)))
320 .service(web::resource("/v1/err").route(web::get().to(log_check)))
322 .service(web::resource("/__error__").route(web::get().to(log_check)))
324 .service(web::resource("/__heartbeat__").route(web::get().to(health_route)))
326 .service(web::resource("/__lbheartbeat__").route(web::get().to(lb_heartbeat_route)))
327 .service(web::resource("/__version__").route(web::get().to(version_route)));
328 #[cfg(feature = "reliable_report")]
329 let app = app.service(
333 web::resource("/metrics")
334 .route(web::get().to(crate::routes::reliability::report_handler)),
335 );
336 app
337 })
338 .bind(bind_address)?
339 .run();
340
341 Ok(server)
342 }
343}