autoendpoint/
error.rs

1//! Error types and transformations
2
3use crate::headers::vapid::VapidError;
4use crate::routers::RouterError;
5use actix_web::{
6    dev::ServiceResponse,
7    error::{JsonPayloadError, PayloadError, ResponseError},
8    http::header::{CacheControl, CacheDirective},
9    http::StatusCode,
10    middleware::ErrorHandlerResponse,
11    HttpResponse, Result,
12};
// Sentry uses the `backtrace` crate (see `use backtrace::Backtrace` below), not `std::backtrace`.
14use actix_http::header;
15use backtrace::Backtrace;
16use serde::ser::SerializeMap;
17use serde::{Serialize, Serializer};
18use std::error::Error;
19use std::fmt::{self, Display};
20use thiserror::Error;
21use validator::{ValidationErrors, ValidationErrorsKind};
22
23use autopush_common::{db::error::DbError, errors::ReportableError};
24
25/// Common `Result` type.
26pub type ApiResult<T> = Result<T, ApiError>;
27
/// Documentation link sent as `more_info` in serialized `ApiError` bodies.
const ERROR_URL: &str = "http://autopush.readthedocs.io/en/latest/http.html#error-codes";
/// `Retry-After` header value, in seconds, attached to 503 responses (two minutes).
const RETRY_AFTER_PERIOD: &str = "120";
31
/// The main error type.
///
/// Pairs an [`ApiErrorKind`] with a backtrace and optional extra context.
/// The blanket `From` conversions into this type capture the backtrace at
/// the conversion site and leave `extras` empty.
#[derive(Debug)]
pub struct ApiError {
    /// What went wrong; drives the HTTP status, errno, metric label and
    /// Sentry policy of this error.
    pub kind: ApiErrorKind,
    /// Backtrace captured when the error was constructed; exposed through
    /// `ReportableError::backtrace`.
    pub backtrace: Backtrace,
    /// Optional extra (key, value) pairs, included in
    /// `ReportableError::extras` ahead of any kind-specific extras.
    pub extras: Option<Vec<(String, String)>>,
}
39
40impl ApiError {
41    /// Render a 404 response
42    // wrapper during the move. this should switch to autopush-common's impl.
43    pub fn render_404<B>(res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
44        //TODO: remove unwrap here.
45        Ok(autopush_common::errors::render_404(res).unwrap())
46    }
47}
48
49/// The possible errors this application could encounter
50#[derive(Debug, Error)]
51pub enum ApiErrorKind {
52    #[error(transparent)]
53    Io(#[from] std::io::Error),
54
55    #[error(transparent)]
56    Metrics(#[from] cadence::MetricError),
57
58    #[error(transparent)]
59    Validation(#[from] validator::ValidationErrors),
60
61    #[error(transparent)]
62    PayloadError(actix_web::Error),
63
64    #[error(transparent)]
65    VapidError(#[from] VapidError),
66
67    #[error(transparent)]
68    Router(#[from] RouterError),
69
70    #[error(transparent)]
71    Jwt(#[from] jsonwebtoken::errors::Error),
72
73    #[error(transparent)]
74    Serde(#[from] serde_json::Error),
75
76    #[error(transparent)]
77    ReqwestError(#[from] reqwest::Error),
78
79    #[error("Error while validating token")]
80    TokenHashValidation(#[source] openssl::error::ErrorStack),
81
82    #[error("Error while creating secret")]
83    RegistrationSecretHash(#[source] openssl::error::ErrorStack),
84
85    #[error("Error while creating endpoint URL: {0}")]
86    EndpointUrl(#[source] autopush_common::errors::ApcError),
87
88    #[error("Database error: {0}")]
89    Database(#[from] DbError),
90
91    #[error("Conditional database operation failed: {0}")]
92    Conditional(String),
93
94    #[error("Invalid token")]
95    InvalidToken,
96
97    #[error("UAID not found")]
98    NoUser,
99
100    #[error("No such subscription")]
101    NoSubscription,
102
103    /// A specific issue with the encryption headers
104    #[error("{0}")]
105    InvalidEncryption(String),
106
107    /// Used if the API version given is not v1 or v2
108    #[error("Invalid API version")]
109    InvalidApiVersion,
110
111    #[error("Missing TTL value")]
112    NoTTL,
113
114    #[error("Invalid router type")]
115    InvalidRouterType,
116
117    #[error("Invalid router token")]
118    InvalidRouterToken,
119
120    #[error("Invalid message ID")]
121    InvalidMessageId,
122
123    #[error("Invalid Authentication")]
124    InvalidAuthentication,
125
126    #[error("Invalid Local Auth {0}")]
127    InvalidLocalAuth(String),
128
129    #[error("General error {0}")]
130    General(String),
131
132    #[error("ERROR:Success")]
133    LogCheck,
134}
135
136impl ApiErrorKind {
137    /// Get the associated HTTP status code
138    pub fn status(&self) -> StatusCode {
139        match self {
140            ApiErrorKind::PayloadError(e) => e.as_response_error().status_code(),
141            ApiErrorKind::Router(e) => e.status(),
142
143            ApiErrorKind::Validation(_)
144            | ApiErrorKind::InvalidEncryption(_)
145            | ApiErrorKind::NoTTL
146            | ApiErrorKind::InvalidRouterType
147            | ApiErrorKind::InvalidRouterToken
148            | ApiErrorKind::InvalidMessageId => StatusCode::BAD_REQUEST,
149
150            ApiErrorKind::VapidError(_)
151            | ApiErrorKind::Jwt(_)
152            | ApiErrorKind::Serde(_)
153            | ApiErrorKind::TokenHashValidation(_)
154            | ApiErrorKind::InvalidAuthentication
155            | ApiErrorKind::InvalidLocalAuth(_) => StatusCode::UNAUTHORIZED,
156
157            ApiErrorKind::InvalidToken | ApiErrorKind::InvalidApiVersion => StatusCode::NOT_FOUND,
158
159            ApiErrorKind::NoUser | ApiErrorKind::NoSubscription => StatusCode::GONE,
160
161            ApiErrorKind::LogCheck => StatusCode::IM_A_TEAPOT,
162
163            ApiErrorKind::Conditional(_) => StatusCode::SERVICE_UNAVAILABLE,
164
165            ApiErrorKind::Database(e) => e.status(),
166
167            ApiErrorKind::General(_)
168            | ApiErrorKind::Io(_)
169            | ApiErrorKind::Metrics(_)
170            | ApiErrorKind::EndpointUrl(_)
171            | ApiErrorKind::RegistrationSecretHash(_)
172            | ApiErrorKind::ReqwestError(_) => StatusCode::INTERNAL_SERVER_ERROR,
173        }
174    }
175
176    /// Specify the label to use for metrics reporting.
177    pub fn metric_label(&self) -> Option<&'static str> {
178        Some(match self {
179            ApiErrorKind::PayloadError(_) => "payload_error",
180            ApiErrorKind::Router(e) => return e.metric_label(),
181
182            ApiErrorKind::Validation(_) => "validation",
183            ApiErrorKind::InvalidEncryption(_) => "invalid_encryption",
184            ApiErrorKind::NoTTL => "no_ttl",
185            ApiErrorKind::InvalidRouterType => "invalid_router_type",
186            ApiErrorKind::InvalidRouterToken => "invalid_router_token",
187            ApiErrorKind::InvalidMessageId => "invalid_message_id",
188
189            ApiErrorKind::VapidError(_) => "vapid_error",
190            ApiErrorKind::Jwt(_) | ApiErrorKind::Serde(_) => "jwt",
191            ApiErrorKind::TokenHashValidation(_) => "token_hash_validation",
192            ApiErrorKind::InvalidAuthentication => "invalid_authentication",
193            ApiErrorKind::InvalidLocalAuth(_) => "invalid_local_auth",
194
195            ApiErrorKind::InvalidToken => "invalid_token",
196            ApiErrorKind::InvalidApiVersion => "invalid_api_version",
197
198            ApiErrorKind::NoUser => "no_user",
199            ApiErrorKind::NoSubscription => "no_subscription",
200
201            ApiErrorKind::LogCheck => "log_check",
202
203            ApiErrorKind::General(_) => "general",
204            ApiErrorKind::Io(_) => "io",
205            ApiErrorKind::Metrics(_) => "metrics",
206            ApiErrorKind::Database(e) => return e.metric_label(),
207            ApiErrorKind::Conditional(_) => "conditional",
208            ApiErrorKind::EndpointUrl(e) => return e.metric_label(),
209            ApiErrorKind::RegistrationSecretHash(_) => "registration_secret_hash",
210            ApiErrorKind::ReqwestError(_) => "reqwest",
211        })
212    }
213
214    /// Don't report all errors to sentry
215    pub fn is_sentry_event(&self) -> bool {
216        match self {
217            // ignore selected validation errors.
218            ApiErrorKind::Router(e) => e.is_sentry_event(),
219            ApiErrorKind::Database(e) => e.is_sentry_event(),
220            // Ignore common webpush errors
221            ApiErrorKind::NoTTL | ApiErrorKind::InvalidEncryption(_) |
222            // Ignore common VAPID erros
223            ApiErrorKind::VapidError(_)
224                | ApiErrorKind::Jwt(_)
225                | ApiErrorKind::TokenHashValidation(_)
226                | ApiErrorKind::InvalidAuthentication
227                | ApiErrorKind::InvalidLocalAuth(_) |
228            // Ignore missing or invalid user errors
229            ApiErrorKind::NoUser | ApiErrorKind::NoSubscription |
230            // Ignore oversized payload.
231            ApiErrorKind::PayloadError(_) |
232            ApiErrorKind::Validation(_) |
233            ApiErrorKind::Conditional(_) |
234            ApiErrorKind::ReqwestError(_) => false,
235            _ => true,
236        }
237    }
238
239    /// Get the associated error number
240    pub fn errno(&self) -> Option<usize> {
241        match self {
242            ApiErrorKind::Router(e) => e.errno(),
243
244            ApiErrorKind::Validation(e) => errno_from_validation_errors(e),
245
246            ApiErrorKind::InvalidToken | ApiErrorKind::InvalidApiVersion => Some(102),
247
248            ApiErrorKind::NoUser => Some(103),
249
250            ApiErrorKind::PayloadError(error)
251                if matches!(error.as_error(), Some(PayloadError::Overflow))
252                    || matches!(error.as_error(), Some(JsonPayloadError::Overflow { .. })) =>
253            {
254                Some(104)
255            }
256
257            ApiErrorKind::NoSubscription => Some(106),
258
259            ApiErrorKind::InvalidRouterType => Some(108),
260
261            ApiErrorKind::VapidError(_)
262            | ApiErrorKind::TokenHashValidation(_)
263            | ApiErrorKind::Jwt(_)
264            | ApiErrorKind::Serde(_)
265            | ApiErrorKind::InvalidAuthentication
266            | ApiErrorKind::InvalidLocalAuth(_) => Some(109),
267
268            ApiErrorKind::InvalidEncryption(_) => Some(110),
269
270            ApiErrorKind::NoTTL => Some(111),
271
272            ApiErrorKind::LogCheck => Some(999),
273
274            ApiErrorKind::General(_)
275            | ApiErrorKind::Io(_)
276            | ApiErrorKind::Metrics(_)
277            | ApiErrorKind::Database(_)
278            | ApiErrorKind::Conditional(_)
279            | ApiErrorKind::PayloadError(_)
280            | ApiErrorKind::InvalidRouterToken
281            | ApiErrorKind::RegistrationSecretHash(_)
282            | ApiErrorKind::EndpointUrl(_)
283            | ApiErrorKind::InvalidMessageId
284            | ApiErrorKind::ReqwestError(_) => None,
285        }
286    }
287}
288
289impl Display for ApiError {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        self.kind.fmt(f)
292    }
293}
294
295impl Error for ApiError {
296    fn source(&self) -> Option<&(dyn Error + 'static)> {
297        self.kind.source()
298    }
299}
300
301// Forward From impls to ApiError from ApiErrorKind. Because From is reflexive,
302// this impl also takes care of From<ApiErrorKind>.
303impl<T> From<T> for ApiError
304where
305    ApiErrorKind: From<T>,
306{
307    fn from(item: T) -> Self {
308        ApiError {
309            kind: ApiErrorKind::from(item),
310            backtrace: Backtrace::new(),
311            extras: None,
312        }
313    }
314}
315
316impl ResponseError for ApiError {
317    fn status_code(&self) -> StatusCode {
318        self.kind.status()
319    }
320
321    fn error_response(&self) -> HttpResponse {
322        let mut builder = HttpResponse::build(self.kind.status());
323
324        match self.status_code() {
325            StatusCode::GONE => {
326                builder.insert_header(CacheControl(vec![CacheDirective::MaxAge(86400)]));
327            }
328            StatusCode::SERVICE_UNAVAILABLE => {
329                builder.insert_header((header::RETRY_AFTER, RETRY_AFTER_PERIOD));
330            }
331            _ => {}
332        }
333
334        builder.json(self)
335    }
336}
337
338impl Serialize for ApiError {
339    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
340    where
341        S: Serializer,
342    {
343        let status = self.kind.status();
344        let mut map = serializer.serialize_map(Some(5))?;
345
346        map.serialize_entry("code", &status.as_u16())?;
347        map.serialize_entry("errno", &self.kind.errno())?;
348        map.serialize_entry("error", &status.canonical_reason())?;
349        map.serialize_entry("message", &self.kind.to_string())?;
350        map.serialize_entry("more_info", ERROR_URL)?;
351        map.end()
352    }
353}
354
355impl ReportableError for ApiError {
356    fn reportable_source(&self) -> Option<&(dyn ReportableError + 'static)> {
357        match &self.kind {
358            ApiErrorKind::EndpointUrl(e) => Some(e),
359            ApiErrorKind::Database(e) => Some(e),
360            _ => None,
361        }
362    }
363
364    fn backtrace(&self) -> Option<&Backtrace> {
365        Some(&self.backtrace)
366    }
367
368    fn is_sentry_event(&self) -> bool {
369        self.kind.is_sentry_event()
370    }
371
372    fn metric_label(&self) -> Option<&'static str> {
373        self.kind.metric_label()
374    }
375
376    fn extras(&self) -> Vec<(&str, String)> {
377        let mut extras: Vec<(&str, String)> = match &self.extras {
378            Some(extras) => extras.iter().map(|e| (e.0.as_str(), e.1.clone())).collect(),
379            None => Default::default(),
380        };
381
382        match &self.kind {
383            ApiErrorKind::Router(e) => extras.extend(e.extras()),
384            ApiErrorKind::LogCheck => extras.extend(vec![("coffee", "Unsupported".to_owned())]),
385            _ => {}
386        };
387        extras
388    }
389}
390
391/// Get the error number from validation errors. If multiple errors are present,
392/// the first one with a valid error code is used.
393fn errno_from_validation_errors(e: &ValidationErrors) -> Option<usize> {
394    // Build an iterator over the error numbers, then get the first one
395    e.errors()
396        .values()
397        .flat_map(|error| match error {
398            ValidationErrorsKind::Struct(inner_errors) => {
399                Box::new(errno_from_validation_errors(inner_errors).into_iter())
400                    as Box<dyn Iterator<Item = usize>>
401            }
402            ValidationErrorsKind::List(indexed_errors) => Box::new(
403                indexed_errors
404                    .values()
405                    .filter_map(|errors| errno_from_validation_errors(errors)),
406            )
407                as Box<dyn Iterator<Item = usize>>,
408            ValidationErrorsKind::Field(errors) => {
409                Box::new(errors.iter().filter_map(|error| error.code.parse().ok()))
410                    as Box<dyn Iterator<Item = usize>>
411            }
412        })
413        .next()
414}
415
#[cfg(test)]
mod tests {
    use autopush_common::{db::error::DbError, sentry::event_from_error};

    use super::{ApiError, ApiErrorKind};
    use crate::error::ReportableError;
    use crate::routers::RouterError;

    /// A wrapped database error yields a two-level Sentry event (inner error
    /// first) that carries the inner error's `row` extra.
    #[test]
    fn sentry_event_with_extras() {
        let db_err = DbError::Integrity("foo".to_owned(), Some("bar".to_owned()));
        let api_err: ApiError = ApiErrorKind::Database(db_err).into();

        let event = event_from_error(&api_err);

        assert_eq!(event.exception.len(), 2);
        assert_eq!(event.exception[0].ty, "Integrity");
        assert_eq!(event.exception[1].ty, "ApiError");
        assert_eq!(event.extra.get("row"), Some(&"bar".into()));
    }

    /// Ensure that Pool error metric labels are specified and that they return a 503 status code.
    #[cfg(feature = "bigtable")]
    #[test]
    fn test_label_for_metrics() {
        // Specifically a timeout while creating a pool entry.
        let timeout = autopush_common::db::bigtable::BigTableError::PoolTimeout(
            deadpool::managed::TimeoutType::Create,
        );
        let api_err: ApiError = ApiErrorKind::Database(DbError::BTError(timeout)).into();

        // Remember, `autoendpoint` is prefixed to this metric label.
        assert_eq!(
            api_err.kind.metric_label(),
            Some("storage.bigtable.error.pool_timeout")
        );

        // "Retry-After" is applied on any 503 response (See ApiError::error_response)
        assert_eq!(
            api_err.kind.status(),
            actix_http::StatusCode::SERVICE_UNAVAILABLE
        )
    }

    /// Ensure that extras set on a given error are included in the ApiError.extras() call.
    #[tokio::test]
    async fn pass_extras() {
        let attached = vec![("foo".to_owned(), "bar".to_owned())];

        // Router errors: attached extras survive alongside the router's own.
        {
            let mut api_err = ApiError::from(RouterError::NotFound);
            api_err.extras = Some(attached.clone());

            let reported: Vec<(&str, String)> = api_err.extras();
            assert!(reported.contains(&("foo", "bar".to_owned())));
        }

        // LogCheck: attached extras plus the kind-specific "coffee" entry.
        {
            let mut api_err = ApiError::from(ApiErrorKind::LogCheck);
            api_err.extras = Some(attached.clone());

            let reported: Vec<(&str, String)> = api_err.extras();
            assert!(reported.contains(&("foo", "bar".to_owned())));
            assert!(reported.contains(&("coffee", "Unsupported".to_owned())));
        }
    }
}