autoconnect_ws_sm/
error.rs

1use std::{error::Error, fmt};
2
3use actix_ws::CloseCode;
4use backtrace::Backtrace;
5
6use autoconnect_common::protocol::{ClientMessage, MessageType, ServerMessage};
7use autopush_common::{db::error::DbError, errors::ApcError, errors::ReportableError};
8
9/// Trait for types that can provide a MessageType
10pub trait MessageTypeProvider {
11    /// Returns the message type of this object
12    fn message_type(&self) -> MessageType;
13}
14
15impl MessageTypeProvider for ClientMessage {
16    fn message_type(&self) -> MessageType {
17        self.message_type()
18    }
19}
20
21impl MessageTypeProvider for ServerMessage {
22    fn message_type(&self) -> MessageType {
23        self.message_type()
24    }
25}
26
27/// WebSocket state machine errors
28#[derive(Debug)]
29pub struct SMError {
30    pub kind: SMErrorKind,
31    backtrace: Option<Backtrace>,
32}
33
34impl fmt::Display for SMError {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        write!(f, "{}", self.kind)
37    }
38}
39
40impl Error for SMError {
41    fn source(&self) -> Option<&(dyn Error + 'static)> {
42        self.kind.source()
43    }
44}
45
46// Forward From impls to SMError from SMErrorKind. Because From is reflexive,
47// this impl also takes care of From<SMErrorKind>.
48impl<T> From<T> for SMError
49where
50    SMErrorKind: From<T>,
51{
52    fn from(item: T) -> Self {
53        let kind = SMErrorKind::from(item);
54        let backtrace = (kind.is_sentry_event() && kind.capture_backtrace()).then(Backtrace::new);
55        Self { kind, backtrace }
56    }
57}
58
59impl SMError {
60    pub fn close_code(&self) -> actix_ws::CloseCode {
61        match self.kind {
62            SMErrorKind::UaidReset => CloseCode::Normal,
63            _ => CloseCode::Error,
64        }
65    }
66
67    pub fn invalid_message(description: String) -> Self {
68        SMErrorKind::InvalidMessage(description).into()
69    }
70
71    /// Creates an invalid message error for an expected message type
72    pub fn expected_message_type(expected: MessageType) -> Self {
73        SMErrorKind::InvalidMessage(expected.expected_msg()).into()
74    }
75
76    /// Validates a message is of the expected type, returning an error if not
77    pub fn validate_message_type<T>(expected: MessageType, msg: &T) -> Result<(), Self>
78    where
79        T: MessageTypeProvider,
80    {
81        if msg.message_type() == expected {
82            Ok(())
83        } else {
84            Err(Self::expected_message_type(expected))
85        }
86    }
87}
88
89impl ReportableError for SMError {
90    fn reportable_source(&self) -> Option<&(dyn ReportableError + 'static)> {
91        match &self.kind {
92            SMErrorKind::MakeEndpoint(e) => Some(e),
93            SMErrorKind::Database(e) => Some(e),
94            _ => None,
95        }
96    }
97
98    fn backtrace(&self) -> Option<&Backtrace> {
99        self.backtrace.as_ref()
100    }
101
102    fn is_sentry_event(&self) -> bool {
103        self.kind.is_sentry_event()
104    }
105
106    fn metric_label(&self) -> Option<&'static str> {
107        match &self.kind {
108            SMErrorKind::Database(e) => e.metric_label(),
109            SMErrorKind::MakeEndpoint(e) => e.metric_label(),
110            _ => None,
111        }
112    }
113}
114
115#[derive(thiserror::Error, Debug)]
116pub enum SMErrorKind {
117    #[error("Database error: {0}")]
118    Database(#[from] DbError),
119
120    #[error("Invalid WebPush message: {0}")]
121    InvalidMessage(String),
122
123    #[error("Internal error: {0}")]
124    Internal(String),
125
126    #[error("Reqwest error: {0}")]
127    Reqwest(#[from] reqwest::Error),
128
129    #[error("UAID dropped")]
130    UaidReset,
131
132    #[error("Already connected to another node")]
133    AlreadyConnected,
134
135    #[error("New Client with the same UAID has connected to this node")]
136    Ghost,
137
138    #[error("Failed to generate endpoint: {0}")]
139    MakeEndpoint(#[source] ApcError),
140
141    #[error("Client sent too many pings too often")]
142    ExcessivePing,
143}
144
145impl SMErrorKind {
146    /// Whether this error is reported to Sentry
147    fn is_sentry_event(&self) -> bool {
148        match self {
149            SMErrorKind::Database(e) => e.is_sentry_event(),
150            SMErrorKind::MakeEndpoint(e) => e.is_sentry_event(),
151            SMErrorKind::Reqwest(_) | SMErrorKind::Internal(_) => true,
152            _ => false,
153        }
154    }
155
156    /// Whether this variant has a `Backtrace` captured
157    ///
158    /// Some Error variants have obvious call sites or more relevant backtraces
159    /// in their sources and thus don't need a `Backtrace`. Furthermore
160    /// backtraces are only captured for variants returning true from
161    /// [Self::is_sentry_event].
162    fn capture_backtrace(&self) -> bool {
163        !matches!(self, SMErrorKind::MakeEndpoint(_))
164    }
165}
166
167#[cfg(debug_assertions)]
168/// Return a [SMErrorKind::Reqwest] [SMError] for tests
169pub async fn __test_sm_reqwest_error() -> SMError {
170    // An easily constructed reqwest::Error
171    let e = reqwest::Client::builder()
172        .https_only(true)
173        .build()
174        .unwrap()
175        .get("http://example.com")
176        .send()
177        .await
178        .unwrap_err();
179    SMErrorKind::Reqwest(e).into()
180}