1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use std::{error::Error, fmt};

use actix_ws::CloseCode;
use backtrace::Backtrace;

use autopush_common::{db::error::DbError, errors::ApcError, errors::ReportableError};

/// WebSocket state machine errors
///
/// Pairs the specific failure ([`SMErrorKind`]) with an optional
/// [`Backtrace`]. Because `backtrace` is private, code outside this module
/// builds an `SMError` through its `From` conversion (or
/// `SMError::invalid_message`), which decides whether a backtrace is captured.
#[derive(Debug)]
pub struct SMError {
    /// The specific failure this error represents.
    pub kind: SMErrorKind,
    /// Backtrace taken at construction. The `From` impl fills it only for
    /// kinds that are Sentry events and opt into backtrace capture; it is
    /// `None` otherwise.
    backtrace: Option<Backtrace>,
}

impl fmt::Display for SMError {
    /// Renders only the wrapped [`SMErrorKind`]'s message; the backtrace is
    /// never part of the display text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.kind))
    }
}

impl Error for SMError {
    /// Hands back the wrapped kind's own source. The kind itself is not
    /// returned as a link in the chain, because `Display` already shows its
    /// message.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Error::source(&self.kind)
    }
}

// Forward From impls to SMError from SMErrorKind. Because From is reflexive,
// this impl also takes care of From<SMErrorKind>.
impl<T> From<T> for SMError
where
    SMErrorKind: From<T>,
{
    fn from(item: T) -> Self {
        let kind = SMErrorKind::from(item);
        let backtrace = (kind.is_sentry_event() && kind.capture_backtrace()).then(Backtrace::new);
        Self { kind, backtrace }
    }
}

impl SMError {
    pub fn close_code(&self) -> actix_ws::CloseCode {
        match self.kind {
            SMErrorKind::UaidReset => CloseCode::Normal,
            _ => CloseCode::Error,
        }
    }

    pub fn invalid_message(description: String) -> Self {
        SMErrorKind::InvalidMessage(description).into()
    }
}

impl ReportableError for SMError {
    /// The wrapped error that is itself reportable, if any.
    ///
    /// Only the `Database` and `MakeEndpoint` kinds carry one.
    fn reportable_source(&self) -> Option<&(dyn ReportableError + 'static)> {
        match self.kind {
            SMErrorKind::MakeEndpoint(ref apc) => Some(apc),
            SMErrorKind::Database(ref db) => Some(db),
            _ => None,
        }
    }

    /// The backtrace captured when this error was constructed, if any.
    fn backtrace(&self) -> Option<&Backtrace> {
        Option::as_ref(&self.backtrace)
    }

    /// Defers entirely to the wrapped kind's Sentry policy.
    fn is_sentry_event(&self) -> bool {
        SMErrorKind::is_sentry_event(&self.kind)
    }

    /// Metric label borrowed from the wrapped database/endpoint error;
    /// other kinds have none.
    fn metric_label(&self) -> Option<&'static str> {
        match self.kind {
            SMErrorKind::MakeEndpoint(ref apc) => apc.metric_label(),
            SMErrorKind::Database(ref db) => db.metric_label(),
            _ => None,
        }
    }
}

/// The specific ways the WebSocket state machine can fail.
///
/// Display messages come from the `#[error]` attributes below.
#[derive(thiserror::Error, Debug)]
pub enum SMErrorKind {
    /// A storage-layer failure; converts automatically from [`DbError`].
    #[error("Database error: {0}")]
    Database(#[from] DbError),

    /// The client sent a message that could not be accepted; the string
    /// describes why (see `SMError::invalid_message`).
    #[error("Invalid WebPush message: {0}")]
    InvalidMessage(String),

    /// An internal failure described by the string. Always a Sentry event.
    #[error("Internal error: {0}")]
    Internal(String),

    /// An HTTP client failure; converts automatically from `reqwest::Error`.
    /// Always a Sentry event.
    #[error("Reqwest error: {0}")]
    Reqwest(#[from] reqwest::Error),

    /// The UAID was dropped. This is the only kind that maps to a normal
    /// WebSocket close code.
    #[error("UAID dropped")]
    UaidReset,

    /// The client is already connected to another node.
    #[error("Already connected to another node")]
    AlreadyConnected,

    /// A newer client with the same UAID connected to this node.
    #[error("New Client with the same UAID has connected to this node")]
    Ghost,

    /// Endpoint generation failed. The [`ApcError`] is exposed as the error
    /// source. There is deliberately no `#[from]`, so the variant must be
    /// built explicitly, and no backtrace is captured for it.
    #[error("Failed to generate endpoint: {0}")]
    MakeEndpoint(#[source] ApcError),

    /// The client exceeded the allowed ping rate.
    #[error("Client sent too many pings too often")]
    ExcessivePing,
}

impl SMErrorKind {
    /// Whether this error is reported to Sentry.
    ///
    /// Wrapped database and endpoint errors defer to their inner error's
    /// policy. Reqwest and internal errors are always reported, and every
    /// other kind never is.
    fn is_sentry_event(&self) -> bool {
        if let SMErrorKind::Database(db) = self {
            return db.is_sentry_event();
        }
        if let SMErrorKind::MakeEndpoint(apc) = self {
            return apc.is_sentry_event();
        }
        matches!(self, SMErrorKind::Reqwest(_) | SMErrorKind::Internal(_))
    }

    /// Whether a `Backtrace` should be captured for this variant.
    ///
    /// Some variants have an obvious call site, or a more useful backtrace
    /// inside their source, and skip capture. This is only consulted for
    /// variants where [Self::is_sentry_event] is true.
    fn capture_backtrace(&self) -> bool {
        let defers_to_source = matches!(self, SMErrorKind::MakeEndpoint(_));
        !defers_to_source
    }
}

/// Return a [SMErrorKind::Reqwest] [SMError] for tests
#[cfg(debug_assertions)]
pub async fn __test_sm_reqwest_error() -> SMError {
    // An https-only client refuses a plain-http URL, which gives an easily
    // constructed reqwest::Error.
    let client = reqwest::Client::builder()
        .https_only(true)
        .build()
        .unwrap();
    let request = client.get("http://example.com");
    let err = request.send().await.unwrap_err();
    SMErrorKind::Reqwest(err).into()
}