autopush_common/db/postgres/
mod.rs

/* Postgres DbClient implementation.
 * As noted elsewhere, autopush was originally designed to work with NoSQL-type databases.
 * This implementation was done partially as an experiment. Postgres allows for limited
 * NoSQL-like functionality. The author, however, has VERY limited knowledge of Postgres,
 * and there are likely many inefficiencies in this implementation.
 *
 * PRs are always welcome.
 */
9
10use std::collections::HashSet;
11use std::str::FromStr;
12use std::sync::Arc;
13#[cfg(feature = "reliable_report")]
14use std::time::{Duration, SystemTime};
15
16#[cfg(feature = "reliable_report")]
17use chrono::TimeDelta;
18use deadpool_postgres::{Pool, Runtime};
19use serde_json::json;
20use tokio_postgres::{types::ToSql, Row}; // Client is sync.
21
22use async_trait::async_trait;
23use cadence::StatsdClient;
24use serde::{Deserialize, Serialize};
25use uuid::Uuid;
26
27use crate::db::client::DbClient;
28use crate::db::error::{DbError, DbResult};
29use crate::db::{DbSettings, User};
30use crate::notification::Notification;
31use crate::{util, MAX_ROUTER_TTL_SECS};
32
33use super::client::FetchMessageResponse;
34
/// Time-to-live for reliability tracking data: 60 days.
/// NOTE(review): not referenced in the visible part of this file; presumably consumed by
/// the reliability-report code further down — confirm.
#[cfg(feature = "reliable_report")]
const RELIABLE_LOG_TTL: TimeDelta = TimeDelta::days(60);
37
38#[derive(Debug, Clone, Deserialize, Serialize)]
39#[serde(default)]
40pub struct PostgresDbSettings {
41    pub schema: Option<String>,    // Optional DB Schema
42    pub router_table: String,      // Routing info
43    pub message_table: String,     // Message storage info
44    pub meta_table: String,        // Channels and meta info
45    pub reliability_table: String, // Channels and meta info
46    max_router_ttl: u64,           // Max time for router records to live.
47                                   // #[serde(default)]
48                                   // pub use_tls: bool // Should you use a TLS connection to the db.
49}
50
51impl Default for PostgresDbSettings {
52    fn default() -> Self {
53        Self {
54            schema: None,
55            router_table: "router".to_owned(),
56            message_table: "message".to_owned(),
57            meta_table: "meta".to_owned(),
58            reliability_table: "reliability".to_owned(),
59            max_router_ttl: MAX_ROUTER_TTL_SECS,
60            // use_tls: false,
61        }
62    }
63}
64
65impl TryFrom<&str> for PostgresDbSettings {
66    type Error = DbError;
67    fn try_from(setting_string: &str) -> Result<Self, Self::Error> {
68        if setting_string.trim().is_empty() {
69            return Ok(PostgresDbSettings::default());
70        }
71        serde_json::from_str(setting_string).map_err(|e| {
72            DbError::General(format!(
73                "Could not parse configuration db_settings: {:?}",
74                e
75            ))
76        })
77    }
78}
79
80#[derive(Clone)]
81pub struct PgClientImpl {
82    _metrics: Arc<StatsdClient>,
83    db_settings: PostgresDbSettings,
84    pool: Pool,
85    /// Cached fully-qualified table names to avoid repeated format!/clone per query
86    cached_router_table: String,
87    cached_message_table: String,
88    cached_meta_table: String,
89    #[cfg(feature = "reliable_report")]
90    cached_reliability_table: String,
91}
92
93impl PgClientImpl {
94    /// Create a new Postgres Client.
95    ///
96    /// This uses the `settings.db_dsn`. to try and connect to the postgres database.
97    /// See https://docs.rs/tokio-postgres/latest/tokio_postgres/config/struct.Config.html
98    /// for parameter details and requirements.
99    /// Example DSN: postgresql://user:password@host/database?option=val
100    /// e.g. (postgresql://scott:tiger@dbhost/autopush?connect_timeout=10&keepalives_idle=3600)
101    pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
102        let db_settings = PostgresDbSettings::try_from(settings.db_settings.as_ref())?;
103        info!(
104            "📮 Initializing Postgres DB Client with settings: {:?} from {:?}",
105            db_settings, &settings.db_settings
106        );
107        // TODO: If required, add the TlsConnect<Stream> wrapper here.
108        let tls_flag = tokio_postgres::NoTls;
109        if let Some(dsn) = settings.dsn.clone() {
110            trace!("📮 Postgres Connect {}", &dsn);
111
112            let pool = deadpool_postgres::Config {
113                url: Some(dsn.clone()),
114                ..Default::default()
115            }
116            .create_pool(Some(Runtime::Tokio1), tls_flag)
117            .map_err(|e| DbError::General(e.to_string()))?;
118            let cached_router_table = if let Some(schema) = &db_settings.schema {
119                format!("{}.{}", schema, db_settings.router_table)
120            } else {
121                db_settings.router_table.clone()
122            };
123            let cached_message_table = if let Some(schema) = &db_settings.schema {
124                format!("{}.{}", schema, db_settings.message_table)
125            } else {
126                db_settings.message_table.clone()
127            };
128            let cached_meta_table = if let Some(schema) = &db_settings.schema {
129                format!("{}.{}", schema, db_settings.meta_table)
130            } else {
131                db_settings.meta_table.clone()
132            };
133            #[cfg(feature = "reliable_report")]
134            let cached_reliability_table = if let Some(schema) = &db_settings.schema {
135                format!("{}.{}", schema, db_settings.reliability_table)
136            } else {
137                db_settings.reliability_table.clone()
138            };
139            return Ok(Self {
140                _metrics: metrics,
141                db_settings,
142                pool,
143                cached_router_table,
144                cached_message_table,
145                cached_meta_table,
146                #[cfg(feature = "reliable_report")]
147                cached_reliability_table,
148            });
149        };
150        Err(DbError::ConnectionError("No DSN specified".to_owned()))
151    }
152
153    /// Does the given table exist
154    async fn table_exists(&self, table_name: &str) -> DbResult<bool> {
155        let (schema, table_name) = if table_name.contains('.') {
156            let mut parts = table_name.splitn(2, '.');
157            (
158                parts.next().unwrap_or("public").to_owned(),
159                // If we are in a situation where someone specified a table name as
160                // `whatever.`, then we should absolutely panic here.
161                parts.next().unwrap().to_owned(),
162            )
163        } else {
164            ("public".to_owned(), table_name.to_owned())
165        };
166        let rows = self
167            .pool
168            .get()
169            .await
170            .map_err(DbError::PgPoolError)?
171            .query(
172                "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname=$1 AND tablename=$2);",
173                &[&schema, &table_name],
174            )
175            .await
176            .map_err(DbError::PgError)?;
177        let val: bool = rows[0].try_get(0)?;
178        Ok(val)
179    }
180
181    /// Return the router's expiration timestamp
182    fn router_expiry(&self) -> u64 {
183        util::sec_since_epoch() + self.db_settings.max_router_ttl
184    }
185
186    /// The router table contains how to route messages to the recipient UAID.
187    pub(crate) fn router_table(&self) -> &str {
188        &self.cached_router_table
189    }
190
191    /// The message table contains stored messages for UAIDs.
192    pub(crate) fn message_table(&self) -> &str {
193        &self.cached_message_table
194    }
195
196    /// The meta table contains channel and other metadata for UAIDs.
197    /// With traditional "No-Sql" databases, this would be rolled into the
198    /// router table.
199    pub(crate) fn meta_table(&self) -> &str {
200        &self.cached_meta_table
201    }
202
203    /// The reliability table contains message delivery reliability states.
204    /// This is optional and should only be used to track internally generated
205    /// and consumed messages based on the VAPID public key signature.
206    #[cfg(feature = "reliable_report")]
207    pub(crate) fn reliability_table(&self) -> &str {
208        &self.cached_reliability_table
209    }
210
211    pub(crate) fn error_to_string(e: &tokio_postgres::Error) -> String {
212        e.as_db_error()
213            .map(|e| e.message().to_owned()) // Some errors have a useful message.
214            .unwrap_or_else(|| {
215                // Others are best to just convert to a string.
216                e.to_string()
217            })
218    }
219}
220
221#[async_trait]
222impl DbClient for PgClientImpl {
223    /// add user to router_table if not exists uaid
224    async fn add_user(&self, user: &User) -> DbResult<()> {
225        self.pool.get().await.map_err(DbError::PgPoolError)?.execute(
226            &format!("
227            INSERT INTO {tablename} (uaid, connected_at, router_type, router_data, node_id, record_version, version, last_update, priv_channels, expiry)
228             VALUES($1, $2::BIGINT, $3, $4, $5, $6::BIGINT, $7, $8::BIGINT, $9, $10::BIGINT)
229             ON CONFLICT (uaid) DO
230                UPDATE SET connected_at=EXCLUDED.connected_at,
231                    router_type=EXCLUDED.router_type,
232                    router_data=EXCLUDED.router_data,
233                    node_id=EXCLUDED.node_id,
234                    record_version=EXCLUDED.record_version,
235                    version=EXCLUDED.version,
236                    last_update=EXCLUDED.last_update,
237                    priv_channels=EXCLUDED.priv_channels,
238                    expiry=EXCLUDED.expiry
239                    ;
240            ", tablename=self.router_table()),
241            &[&user.uaid.simple().to_string(),              // 1
242            &(user.connected_at as i64),                    // 2
243            &user.router_type,                              // 3
244            &json!(user.router_data).to_string(),           // 4
245            &user.node_id,                                  // 5
246            &user.record_version.map(|i| i as i64),    // 6
247            &(user.version.map(|v| v.simple().to_string())),   // 7
248            &user.current_timestamp.map(|i| i as i64), // 8
249            &user.priv_channels.iter().map(|v| v.to_string()).collect::<Vec<String>>(),
250            &(self.router_expiry() as i64),                 // 10
251            ]
252        ).await.map_err(|e| {
253            DbError::PgDbError(Self::error_to_string(&e))
254        })?;
255        Ok(())
256    }
257
258    /// update user record in router_table at user.uaid
259    async fn update_user(&self, user: &mut User) -> DbResult<bool> {
260        let cmd = format!(
261            "UPDATE {tablename} SET connected_at=$2::BIGINT,
262                router_type=$3,
263                router_data=$4,
264                node_id=$5,
265                record_version=$6::BIGINT,
266                version=$7,
267                last_update=$8::BIGINT,
268                priv_channels=$9,
269                expiry=$10::BIGINT
270            WHERE
271                uaid = $1 AND connected_at < $2::BIGINT;
272            ",
273            tablename = self.router_table()
274        );
275        let result = self
276            .pool
277            .get()
278            .await
279            .map_err(DbError::PgPoolError)?
280            .execute(
281                &cmd,
282                &[
283                    &user.uaid.simple().to_string(),                 // 1
284                    &(user.connected_at as i64),                     // 2
285                    &user.router_type,                               // 3
286                    &json!(user.router_data).to_string(),            //4
287                    &user.node_id,                                   // 5
288                    &(user.record_version.map(|i| i as i64)),        // 6
289                    &(user.version.map(|v| v.simple().to_string())), // 7
290                    &user.current_timestamp.map(|i| i as i64),       //8
291                    &user
292                        .priv_channels
293                        .iter()
294                        .map(|v| v.to_string())
295                        .collect::<Vec<String>>(),
296                    &(self.router_expiry() as i64), // 10
297                ],
298            )
299            .await
300            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
301        Ok(result > 0)
302    }
303
304    /// fetch user information from router_table for uaid.
305    async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
306        let row = self.pool.get().await.map_err(DbError::PgPoolError)?
307        .query_opt(
308            &format!(
309                "SELECT connected_at, router_type, router_data, node_id, record_version, last_update, version, priv_channels
310                 FROM {tablename}
311                 WHERE uaid = $1",
312                tablename=self.router_table()
313            ),
314            &[&uaid.simple().to_string()]
315        )
316        .await
317        .map_err(|e| {
318            DbError::PgDbError(Self::error_to_string(&e))
319        })?;
320
321        let Some(row) = row else {
322            return Ok(None);
323        };
324
325        // I was tempted to make this a From impl, but realized that it would mean making autopush-common require a dependency.
326        // Maybe make this a deserialize?
327        let priv_channels = if let Ok(Some(channels)) =
328            row.try_get::<&str, Option<Vec<String>>>("priv_channels")
329        {
330            let mut priv_channels = HashSet::new();
331            for channel in channels.iter() {
332                let uuid = Uuid::from_str(channel).map_err(|e| DbError::General(e.to_string()))?;
333                priv_channels.insert(uuid);
334            }
335            priv_channels
336        } else {
337            HashSet::new()
338        };
339        let resp = User {
340            uaid: *uaid,
341            connected_at: row
342                .try_get::<&str, i64>("connected_at")
343                .map_err(DbError::PgError)? as u64,
344            router_type: row
345                .try_get::<&str, String>("router_type")
346                .map_err(DbError::PgError)?,
347            router_data: serde_json::from_str(
348                row.try_get::<&str, &str>("router_data")
349                    .map_err(DbError::PgError)?,
350            )
351            .map_err(|e| DbError::General(e.to_string()))?,
352            node_id: row
353                .try_get::<&str, Option<String>>("node_id")
354                .map_err(DbError::PgError)?,
355            record_version: row
356                .try_get::<&str, Option<i64>>("record_version")
357                .map_err(DbError::PgError)?
358                .map(|v| v as u64),
359            current_timestamp: row
360                .try_get::<&str, Option<i64>>("last_update")
361                .map_err(DbError::PgError)?
362                .map(|v| v as u64),
363            version: row
364                .try_get::<&str, Option<String>>("version")
365                .map_err(DbError::PgError)?
366                // An invalid UUID here is a data integrity error.
367                .map(|v| {
368                    Uuid::from_str(&v).map_err(|e| {
369                        DbError::Integrity("Invalid UUID found".to_owned(), Some(e.to_string()))
370                    })
371                })
372                .transpose()?,
373            priv_channels,
374        };
375        Ok(Some(resp))
376    }
377
378    /// delete a user at uaid from router_table
379    async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
380        self.pool
381            .get()
382            .await?
383            .execute(
384                &format!(
385                    "DELETE FROM {tablename}
386                     WHERE uaid = $1",
387                    tablename = self.router_table()
388                ),
389                &[&uaid.simple().to_string()],
390            )
391            .await
392            .map_err(DbError::PgError)?;
393        Ok(())
394    }
395
396    /// update list of channel_ids for uaid in meta table
397    /// Note: a conflicting channel_id is ignored, since it's already registered.
398    /// This should probably be optimized into the router table as a set value,
399    /// however I'm not familiar enough with Postgres to do so at this time.
400    /// Channels can be somewhat ephemeral, and we also want to limit the potential of
401    /// race conditions when adding or removing channels, particularly for mobile devices.
402    /// For some efficiency (mostly around the mobile "daily refresh" call), I've broken
403    /// the channels out by UAID into this table.
404    async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
405        self.pool
406            .get()
407            .await
408            .map_err(DbError::PgPoolError)?
409            .execute(
410                &format!(
411                    "INSERT
412                     INTO {tablename} (uaid, channel_id) VALUES ($1, $2)
413                     ON CONFLICT DO NOTHING",
414                    tablename = self.meta_table()
415                ),
416                &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
417            )
418            .await
419            .map_err(DbError::PgError)?;
420        Ok(())
421    }
422
423    /// Save all channels in a list
424    async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
425        if channels.is_empty() {
426            trace!("📮 No channels to save.");
427            return Ok(());
428        };
429        let uaid_str = uaid.simple().to_string();
430        // TODO: REVISIT THIS!!! May be possible without the gross hack.
431        // tokio-postgres doesn't store tuples as values, so you can't just construct
432        // the query as `INSERT into ... (a, b) VALUES (?,?), (?,?)`
433        // It does accept them as numerically specified values.
434        // The following is a gross hack that does basically that.
435        // The other option would be to just repeatedly call `self.add_channel()`
436        // but that seems far worse.
437        //
438        // first, collect the values into a flat fector. We force the type in
439        // the first item so that the second one is assumed.
440
441        let mut params = Vec::<&(dyn ToSql + Sync)>::new();
442        let mut new_channels = Vec::<String>::new();
443        for channel in channels {
444            new_channels.push(channel.simple().to_string());
445        }
446
447        // This is a bit cheesy, but we want capture the reference, not the value since
448        // it will not live long enough. Clippy will complain that this is a needless
449        // iterator, but it's not, really.
450        #[allow(clippy::needless_range_loop)]
451        for i in 0..new_channels.len() {
452            params.push(&uaid_str);
453            params.push(&new_channels[i]);
454        }
455
456        // Now construct the statement, iterate over the parameters we've got
457        // and redistribute them into tuples.
458        // (Remember, an existing channel_id is ignored during this insert since it's already registered)
459        let statement = format!(
460            "INSERT
461                INTO {tablename} (uaid, channel_id)
462                VALUES {vars}
463                ON CONFLICT DO NOTHING",
464            tablename = self.meta_table(),
465            // Postgres variables are 1-indexed.
466            vars = Vec::from_iter((1..params.len() + 1).step_by(2).map(|v| format!(
467                "(${}, ${})",
468                v,
469                v + 1
470            )))
471            .join(",")
472        );
473        // finally, do the insert.
474        self.pool.get().await?.execute(&statement, &params).await?;
475        Ok(())
476    }
477
478    /// get all channels for uaid from meta table
479    async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
480        let mut result = HashSet::new();
481        let rows = self
482            .pool
483            .get()
484            .await
485            .map_err(DbError::PgPoolError)?
486            .query(
487                &format!(
488                    "SELECT distinct channel_id FROM {tablename} WHERE uaid = $1;",
489                    tablename = self.meta_table()
490                ),
491                &[&uaid.simple().to_string()],
492            )
493            .await
494            .map_err(DbError::PgError)?;
495        for row in rows.iter() {
496            let s = row
497                .try_get::<&str, &str>("channel_id")
498                .map_err(DbError::PgError)?;
499            result.insert(Uuid::from_str(s).map_err(|e| DbError::General(e.to_string()))?);
500        }
501        Ok(result)
502    }
503
504    /// remove an individual channel for a given uaid from meta table
505    async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
506        let cmd = format!(
507            "DELETE FROM {tablename}
508             WHERE uaid = $1 AND channel_id = $2;",
509            tablename = self.meta_table()
510        );
511        let result = self
512            .pool
513            .get()
514            .await
515            .map_err(DbError::PgPoolError)?
516            .execute(
517                &cmd,
518                &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
519            )
520            .await?;
521        // We sometimes want to know if the channel existed previously.
522        Ok(result > 0)
523    }
524
525    /// remove node info for a uaid from router table
526    async fn remove_node_id(
527        &self,
528        uaid: &Uuid,
529        node_id: &str,
530        connected_at: u64,
531        version: &Option<Uuid>,
532    ) -> DbResult<bool> {
533        let Some(version) = version else {
534            return Err(DbError::General("Expected a user version field".to_owned()));
535        };
536        self.pool
537            .get()
538            .await
539            .map_err(DbError::PgPoolError)?
540            .execute(
541                &format!(
542                    "UPDATE {tablename}
543                        SET node_id = null
544                        WHERE uaid=$1 AND node_id = $2 AND connected_at = $3 AND version= $4;",
545                    tablename = self.router_table()
546                ),
547                &[
548                    &uaid.simple().to_string(),
549                    &node_id,
550                    &(connected_at as i64),
551                    &version.simple().to_string(),
552                ],
553            )
554            .await
555            .map_err(DbError::PgError)?;
556        Ok(true)
557    }
558
559    /// write a message to message table
560    async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
561        // fun fact: serde_postgres exists, but only deserializes (as of 0.2)
562        // (This is mutable if `reliable_report` enabled)
563        #[allow(unused_mut)]
564        let mut fields = vec![
565            "uaid",
566            "channel_id",
567            "chid_message_id",
568            "version",
569            "ttl",
570            "expiry",
571            "topic",
572            "timestamp",
573            "data",
574            "sortkey_timestamp",
575            "headers",
576        ];
577        // (This is mutable if `reliable_report` enabled)
578        #[allow(unused_mut)]
579        let mut inputs = vec![
580            "$1", "$2", "$3", "$4", "$5", "$6", "$7", "$8", "$9", "$10", "$11",
581        ];
582        #[cfg(feature = "reliable_report")]
583        {
584            fields.append(&mut ["reliability_id"].to_vec());
585            inputs.append(&mut ["$12"].to_vec());
586        }
587        let cmd = format!(
588            "INSERT INTO {tablename}
589                ({fields})
590                VALUES
591                ({inputs}) ON CONFLICT (chid_message_id) DO UPDATE SET
592                    uaid=EXCLUDED.uaid,
593                    channel_id=EXCLUDED.channel_id,
594                    version=EXCLUDED.version,
595                    ttl=EXCLUDED.ttl,
596                    expiry=EXCLUDED.expiry,
597                    topic=EXCLUDED.topic,
598                    timestamp=EXCLUDED.timestamp,
599                    data=EXCLUDED.data,
600                    sortkey_timestamp=EXCLUDED.sortkey_timestamp,
601                    headers=EXCLUDED.headers",
602            tablename = &self.message_table(),
603            fields = fields.join(","),
604            inputs = inputs.join(",")
605        );
606        self.pool
607            .get()
608            .await
609            .map_err(DbError::PgPoolError)?
610            .execute(
611                &cmd,
612                &[
613                    &uaid.simple().to_string(),
614                    &message.channel_id.simple().to_string(),
615                    &message.chidmessageid(),
616                    &message.version,
617                    &(message.ttl as i64), // Postgres has no auto TTL.
618                    &(util::sec_since_epoch() as i64 + message.ttl as i64),
619                    &message.topic.as_ref().filter(|v| !v.is_empty()),
620                    &(message.timestamp as i64),
621                    &message.data.as_deref().unwrap_or_default(),
622                    &message.sortkey_timestamp.map(|v| v as i64),
623                    &json!(message.headers).to_string(),
624                    #[cfg(feature = "reliable_report")]
625                    &message.reliability_id,
626                ],
627            )
628            .await
629            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
630        Ok(())
631    }
632
633    /// remove a given message from the message table
634    async fn remove_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<()> {
635        debug!(
636            "📮 Removing message for user {} with chid_message_id {}",
637            uaid.simple(),
638            chidmessageid
639        );
640        let result = self
641            .pool
642            .get()
643            .await
644            .map_err(DbError::PgPoolError)?
645            .execute(
646                &format!(
647                    "DELETE FROM {tablename}
648                     WHERE uaid=$1 AND chid_message_id = $2;",
649                    tablename = self.message_table()
650                ),
651                &[&uaid.simple().to_string(), &chidmessageid],
652            )
653            .await
654            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
655        debug!(
656            "📮 Deleted {} rows for user {} with chid_message_id {}",
657            result,
658            uaid.simple(),
659            chidmessageid
660        );
661        Ok(())
662    }
663
664    async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
665        if messages.is_empty() {
666            return Ok(());
667        }
668        if messages.len() == 1 {
669            // Fast path: single message doesn't need batch construction
670            return self
671                .save_message(uaid, messages.into_iter().next().unwrap())
672                .await;
673        }
674
675        #[cfg(not(feature = "reliable_report"))]
676        let fields_per_row: usize = 11;
677        #[cfg(feature = "reliable_report")]
678        let fields_per_row: usize = 12;
679
680        let field_names = {
681            #[allow(unused_mut)]
682            let mut fields = vec![
683                "uaid",
684                "channel_id",
685                "chid_message_id",
686                "version",
687                "ttl",
688                "expiry",
689                "topic",
690                "timestamp",
691                "data",
692                "sortkey_timestamp",
693                "headers",
694            ];
695            #[cfg(feature = "reliable_report")]
696            fields.push("reliability_id");
697            fields.join(",")
698        };
699
700        // Build parameterized value rows: ($1,$2,...,$11), ($12,$13,...,$22), ...
701        let mut value_rows = Vec::with_capacity(messages.len());
702        for i in 0..messages.len() {
703            let base = i * fields_per_row;
704            let params: Vec<String> = (1..=fields_per_row)
705                .map(|j| format!("${}", base + j))
706                .collect();
707            value_rows.push(format!("({})", params.join(",")));
708        }
709
710        let cmd = format!(
711            "INSERT INTO {tablename}
712                ({field_names})
713                VALUES
714                {values} ON CONFLICT (chid_message_id) DO UPDATE SET
715                    uaid=EXCLUDED.uaid,
716                    channel_id=EXCLUDED.channel_id,
717                    version=EXCLUDED.version,
718                    ttl=EXCLUDED.ttl,
719                    expiry=EXCLUDED.expiry,
720                    topic=EXCLUDED.topic,
721                    timestamp=EXCLUDED.timestamp,
722                    data=EXCLUDED.data,
723                    sortkey_timestamp=EXCLUDED.sortkey_timestamp,
724                    headers=EXCLUDED.headers",
725            tablename = &self.message_table(),
726            values = value_rows.join(",")
727        );
728
729        // Build parameter list: we need owned values for strings
730        let uaid_str = uaid.simple().to_string();
731        let now = util::sec_since_epoch() as i64;
732
733        // Pre-compute owned values for each message
734        struct MessageParams {
735            channel_id: String,
736            chidmessageid: String,
737            version: String,
738            ttl: i64,
739            expiry: i64,
740            topic: Option<String>,
741            timestamp: i64,
742            data: String,
743            sortkey_timestamp: Option<i64>,
744            headers: String,
745            #[cfg(feature = "reliable_report")]
746            reliability_id: Option<String>,
747        }
748
749        let msg_params: Vec<MessageParams> = messages
750            .into_iter()
751            .map(|m| {
752                let topic = m.topic.as_ref().filter(|v| !v.is_empty()).cloned();
753                MessageParams {
754                    channel_id: m.channel_id.simple().to_string(),
755                    chidmessageid: m.chidmessageid(),
756                    version: m.version,
757                    ttl: m.ttl as i64,
758                    expiry: now + m.ttl as i64,
759                    topic,
760                    timestamp: m.timestamp as i64,
761                    data: m.data.as_deref().unwrap_or_default().to_owned(),
762                    sortkey_timestamp: m.sortkey_timestamp.map(|v| v as i64),
763                    headers: json!(m.headers).to_string(),
764                    #[cfg(feature = "reliable_report")]
765                    reliability_id: m.reliability_id,
766                }
767            })
768            .collect();
769
770        let mut params: Vec<&(dyn ToSql + Sync)> =
771            Vec::with_capacity(msg_params.len() * fields_per_row);
772        for mp in &msg_params {
773            params.push(&uaid_str);
774            params.push(&mp.channel_id);
775            params.push(&mp.chidmessageid);
776            params.push(&mp.version);
777            params.push(&mp.ttl);
778            params.push(&mp.expiry);
779            params.push(&mp.topic);
780            params.push(&mp.timestamp);
781            params.push(&mp.data);
782            params.push(&mp.sortkey_timestamp);
783            params.push(&mp.headers);
784            #[cfg(feature = "reliable_report")]
785            params.push(&mp.reliability_id);
786        }
787
788        self.pool
789            .get()
790            .await
791            .map_err(DbError::PgPoolError)?
792            .execute(cmd.as_str(), &params)
793            .await
794            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
795        Ok(())
796    }
797
798    /// fetch topic messages for the user up to {limit}
799    /// Topic messages are auto-replacing singleton messages for a given user.
800    async fn fetch_topic_messages(
801        &self,
802        uaid: &Uuid,
803        limit: usize,
804    ) -> DbResult<FetchMessageResponse> {
805        let messages: Vec<Notification> = self
806            .pool
807            .get()
808            .await
809            .map_err(DbError::PgPoolError)?
810            .query(
811                &format!(
812                "SELECT channel_id, version, ttl, topic, timestamp, data, sortkey_timestamp, headers
813                 FROM {tablename}
814                 WHERE uaid=$1 AND expiry >= $2 AND (topic IS NOT NULL AND topic != '')
815                 ORDER BY timestamp DESC
816                 LIMIT $3",
817                tablename=&self.message_table(),
818            ),
819                &[
820                    &uaid.simple().to_string(),
821                    &(util::sec_since_epoch() as i64),
822                    &(limit as i64),
823                ],
824            )
825            .await
826            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?
827            .iter()
828            .map(|row: &Row| row.try_into())
829            .collect::<Result<Vec<Notification>, DbError>>()?;
830
831        if messages.is_empty() {
832            Ok(Default::default())
833        } else {
834            Ok(FetchMessageResponse {
835                timestamp: Some(messages[0].timestamp),
836                messages,
837            })
838        }
839    }
840
841    /// Fetch messages for a user on or after a given timestamp up to {limit}
842    async fn fetch_timestamp_messages(
843        &self,
844        uaid: &Uuid,
845        timestamp: Option<u64>,
846        limit: usize,
847    ) -> DbResult<FetchMessageResponse> {
848        let uaid = uaid.simple().to_string();
849        let response: Vec<Row> = if let Some(ts) = timestamp {
850            trace!("📮 Fetching messages for user {} since {}", &uaid, ts);
851            self.pool
852                .get()
853                .await
854                .map_err(DbError::PgPoolError)?
855                .query(
856                    &format!(
857                        "SELECT * FROM {}
858                         WHERE uaid = $1 AND timestamp > $2 AND expiry >= $3
859                         ORDER BY timestamp
860                         LIMIT $4",
861                        self.message_table()
862                    ),
863                    &[
864                        &uaid,
865                        &(ts as i64),
866                        &(util::sec_since_epoch() as i64),
867                        &(limit as i64),
868                    ],
869                )
870                .await
871        } else {
872            trace!("📮 Fetching messages for user {}", &uaid);
873            self.pool
874                .get()
875                .await
876                .map_err(DbError::PgPoolError)?
877                .query(
878                    &format!(
879                        "SELECT *
880                         FROM {}
881                         WHERE uaid = $1
882                         AND expiry >= $2
883                         LIMIT $3",
884                        self.message_table()
885                    ),
886                    &[&uaid, &(util::sec_since_epoch() as i64), &(limit as i64)],
887                )
888                .await
889        }
890        .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
891        let messages: Vec<Notification> = response
892            .iter()
893            .map(|row: &Row| row.try_into())
894            .collect::<Result<Vec<Notification>, DbError>>()?;
895        let timestamp = if !messages.is_empty() {
896            Some(messages[0].timestamp)
897        } else {
898            None
899        };
900
901        Ok(FetchMessageResponse {
902            timestamp,
903            messages,
904        })
905    }
906
907    /// Convenience function to check if the router table exists
908    async fn router_table_exists(&self) -> DbResult<bool> {
909        self.table_exists(self.router_table()).await
910    }
911
912    /// Convenience function to check if the message table exists
913    async fn message_table_exists(&self) -> DbResult<bool> {
914        self.table_exists(self.message_table()).await
915    }
916
917    #[cfg(feature = "reliable_report")]
918    async fn log_report(
919        &self,
920        reliability_id: &str,
921        new_state: crate::reliability::ReliabilityState,
922    ) -> DbResult<()> {
923        let timestamp =
924            SystemTime::now() + Duration::from_secs(RELIABLE_LOG_TTL.num_seconds() as u64);
925        debug!("📮 Logging report for {reliability_id} as {new_state}");
926        /*
927            INSERT INTO {tablename} (id, states) VALUES ({reliability_id}, json_build_object({state}, {timestamp}) )
928            ON CONFLICT(id)
929            UPDATE {tablename} SET states = jsonb_set(states, array[{state}], to_jsonb({timestamp}));
930        */
931
932        let tablename = &self.reliability_table();
933        let state = new_state.to_string();
934        self.pool
935            .get()
936            .await?
937            .execute(
938                &format!(
939                "INSERT INTO {tablename} (id, states, last_update_timestamp) VALUES ($1, json_build_object($2, $3), $3)
940                 ON CONFLICT (id) DO
941                 UPDATE SET states = EXCLUDED.states,
942                 last_update_timestamp = EXCLUDED.last_update_timestamp;",
943                    tablename = tablename
944                ),
945                &[&reliability_id, &state, &timestamp],
946            )
947            .await
948            .map_err(|e|{DbError::PgDbError(Self::error_to_string(&e))})?;
949        Ok(())
950    }
951
952    async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
953        debug!("📮 Updating {uaid} current_timestamp:{timestamp}");
954        let tablename = &self.router_table();
955
956        trace!("📮 Purging git{uaid} for < {timestamp}");
957        let mut pool = self.pool.get().await.map_err(DbError::PgPoolError)?;
958        let transaction = pool.transaction().await?;
959        // Try to garbage collect old messages first.
960        transaction
961            .execute(
962                &format!(
963                    "DELETE FROM {} WHERE uaid = $1 and expiry < $2",
964                    &self.message_table()
965                ),
966                &[
967                    &uaid.simple().to_string(),
968                    &(util::sec_since_epoch() as i64),
969                ],
970            )
971            .await
972            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
973        // Now, delete messages that we've already delivered.
974        transaction
975            .execute(
976                &format!(
977                    "DELETE FROM {} WHERE uaid = $1 AND timestamp IS NOT NULL AND timestamp < $2",
978                    &self.message_table()
979                ),
980                &[&uaid.simple().to_string(), &(timestamp as i64)],
981            )
982            .await
983            .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
984        transaction.execute(
985                &format!(
986                    "UPDATE {tablename} SET last_update = $2::BIGINT, expiry= $3::BIGINT WHERE uaid = $1"
987                ),
988                &[
989                    &uaid.simple().to_string(),
990                    &(timestamp as i64),
991                    &(self.router_expiry() as i64),
992                ],
993            )
994            .await.map_err(|e|{DbError::PgDbError(Self::error_to_string(&e))})?;
995        transaction.commit().await?;
996        Ok(())
997    }
998
999    fn name(&self) -> String {
1000        "Postgres".to_owned()
1001    }
1002
1003    async fn health_check(&self) -> DbResult<bool> {
1004        // Replace this with a proper health check.
1005        let client = self.pool.get().await.map_err(DbError::PgPoolError);
1006        let row = client?.query_one("select true", &[]).await;
1007        if !row?.try_get::<_, bool>(0)? {
1008            error!("📮 Failed to fetch from database");
1009            return Ok(false);
1010        }
1011        if !self.router_table_exists().await? {
1012            error!("📮 Router table does not exist");
1013            return Ok(false);
1014        }
1015        if !self.message_table_exists().await? {
1016            error!("📮 Message table does not exist");
1017            return Ok(false);
1018        }
1019        Ok(true)
1020    }
1021
1022    /// Convenience function to return self as a Boxed DbClient
1023    fn box_clone(&self) -> Box<dyn DbClient> {
1024        Box::new(self.clone())
1025    }
1026}
1027
1028/* Note:
1029 * For preliminary testing, you will need to start a local postgres instance (see
1030 * https://www.docker.com/blog/how-to-use-the-postgres-docker-official-image/) and initialize the
1031 * database with `schema.psql`.
1032 * Once you have, you can define the environment variable `POSTGRES_HOST` to point to the
1033 * appropriate host (e.g. `postgres:post_pass@localhost:/autopush`). `new_client` will add the
1034 * `postgres://` prefix automatically.
1035 *
1036 * TODO: Really should move the bulk of the tests to a higher level and add backend specific
1037 * versions of `new_client`.
1038 *
1039 */
#[cfg(test)]
mod tests {
    use crate::util::sec_since_epoch;
    use crate::{logging::init_test_logging, util::ms_since_epoch};
    use rand::prelude::*;
    use serde_json::json;
    use std::env;

    use super::*;
    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";

    /// Build a client for the local test database.
    ///
    /// The DSN is `postgres://` followed by `$POSTGRES_HOST` (default
    /// `localhost`). The `autopush` schema is used with otherwise default
    /// settings, and metrics go to a no-op sink.
    fn new_client() -> DbResult<PgClientImpl> {
        // Use an environment variable to potentially override the default storage test host.
        let host = env::var("POSTGRES_HOST").unwrap_or("localhost".into());
        let env_dsn = format!("postgres://{host}");
        debug!("📮 Connecting to {env_dsn}");
        let settings = DbSettings {
            dsn: Some(env_dsn),
            db_settings: json!(PostgresDbSettings {
                schema: Some("autopush".to_owned()),
                ..Default::default()
            })
            .to_string(),
        };
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        PgClientImpl::new(metrics, &settings)
    }

    /// Produce a UUID string for a throwaway test user.
    ///
    /// Combines a random byte with the current epoch second. The random byte is
    /// rendered in decimal (at most 255), so that group contains only digits and
    /// the result still parses as a UUID.
    fn gen_test_user() -> String {
        // Create a semi-unique test user to avoid conflicting test values.
        let mut rng = rand::rng();
        let test_num = rng.random::<u8>();
        format!(
            "DEADBEEF-0000-0000-{:04}-{:012}",
            test_num,
            sec_since_epoch()
        )
    }

    /// The health check should pass against an initialized test database.
    #[actix_rt::test]
    async fn health_check() {
        let client = new_client().unwrap();

        let result = client.health_check().await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    /// Test that [increment_storage] purges a stored message once the storage
    /// timestamp moves past it (the message also carries a 1 second TTL).
    #[actix_rt::test]
    async fn wipe_expired() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // purge the user record if it exists.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // purge the old user (if present)
        // in case a prior test failed for whatever reason.
        let _ = client.remove_user(&uaid).await;

        // can we add the user?
        let timestamp = sec_since_epoch();
        client.add_user(&test_user).await?;
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 1,
            timestamp,
            data: Some("Encrypted".into()),
            sortkey_timestamp: Some(timestamp),
            ..Default::default()
        };
        client.save_message(&uaid, test_notification).await?;
        // Moving the storage timestamp past the message's timestamp should remove it.
        client.increment_storage(&uaid, timestamp + 1).await?;
        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(msgs.messages.len(), 0);
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }

    /// run a gauntlet of testing. These are a bit linear because they need
    /// to run in sequence.
    ///
    /// Covers user add/update/remove, channel add/remove, message save/fetch/
    /// remove (with and without a topic), and node id removal.
    #[actix_rt::test]
    async fn run_gauntlet() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let user_id = &gen_test_user();
        let uaid = Uuid::parse_str(user_id).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // purge the user record if it exists.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // purge the old user (if present)
        // in case a prior test failed for whatever reason.
        let _ = client.remove_user(&uaid).await;

        // can we add the user?
        trace!("📮 Adding user {}", &user_id);
        client.add_user(&test_user).await?;
        let fetched = client.get_user(&uaid).await?;
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.router_type, "webpush".to_owned());

        // can we add channels?
        trace!("📮 Adding channel {} to user {}", &chid, &user_id);
        client.add_channel(&uaid, &chid).await?;
        let channels = client.get_channels(&uaid).await?;
        assert!(channels.contains(&chid));

        // can we add lots of channels?
        let mut new_channels: HashSet<Uuid> = HashSet::new();
        trace!("📮 Adding multiple channels to user {}", &user_id);
        new_channels.insert(chid);
        for _ in 1..10 {
            new_channels.insert(uuid::Uuid::new_v4());
        }
        let chid_to_remove = uuid::Uuid::new_v4();
        trace!(
            "📮 Adding removable channel {} to user {}",
            &chid_to_remove,
            &user_id
        );
        new_channels.insert(chid_to_remove);
        client.add_channels(&uaid, new_channels.clone()).await?;
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // can we remove a channel?
        trace!(
            "📮 Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
        trace!(
            "📮 retrying Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        // A second removal of the same channel should report that nothing was removed.
        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
        new_channels.remove(&chid_to_remove);
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // now ensure that we can update a user that's after the time we set
        // prior. first ensure that we can't update a user that's before the
        // time we set prior to the last write
        let mut updated = User {
            connected_at,
            ..test_user.clone()
        };
        trace!(
            "📮 Attempting to update user {} with old connected_at: {}",
            &user_id,
            &updated.connected_at
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());

        // Make sure that the `connected_at` wasn't modified
        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
        assert_eq!(fetched.connected_at, fetched2.connected_at);

        // and make sure we can update a record with a later connected_at time.
        let mut updated = User {
            connected_at: fetched.connected_at + 300,
            ..fetched2
        };
        trace!(
            "📮 Attempting to update user {} with new connected_at",
            &user_id
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_ne!(
            fetched2.connected_at,
            client.get_user(&uaid).await?.unwrap().connected_at
        );
        // can we increment the storage for the user?
        trace!("📮 Incrementing storage timestamp for user {}", &user_id);
        client
            .increment_storage(&fetched.uaid, sec_since_epoch())
            .await?;

        let test_data = "An_encrypted_pile_of_crap".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();
        let fetch_timestamp = timestamp;
        // Can we store a message?
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 300,
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        trace!("📮 Saving message for user {}", &user_id);
        let res = client.save_message(&uaid, test_notification.clone()).await;
        assert!(res.is_ok());

        trace!("📮 Fetching all messages for user {}", &user_id);
        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        // Grab the message that was submitted within the past 10 seconds.
        trace!(
            "📮 Fetching messages for user {} within the past 10 seconds",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
            .await?;
        assert_ne!(fetched.messages.len(), 0);

        // Try grabbing a message for 10 seconds from now.
        trace!(
            "📮 Fetching messages for user {} 10 seconds in the future",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
            .await?;
        assert_eq!(fetched.messages.len(), 0);

        // can we clean up our toys?
        trace!(
            "📮 Removing message for user {} :: {}",
            &user_id,
            &test_notification.chidmessageid()
        );
        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        trace!("📮 Removing channel for user {}", &user_id);
        assert!(client.remove_channel(&uaid, &chid).await.is_ok());

        trace!("📮 Making sure no messages remain for user {}", &user_id);
        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        // Now, can we do all that with topic messages
        // This test does not call [fetch_topic_messages]: topic messages live in the
        // same message table and are also returned by [fetch_timestamp_messages]
        // (which does not filter on topic), so that is what is checked here.
        client.add_channel(&uaid, &topic_chid).await?;
        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();

        // We store 2 messages, with a single topic; the test expects the second
        // save to replace the first, leaving only one stored message.
        let test_notification_0 = crate::db::Notification {
            channel_id: topic_chid,
            version: "version0".to_owned(),
            ttl: 300,
            topic: Some("topic".to_owned()),
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        assert!(client
            .save_message(&uaid, test_notification_0.clone())
            .await
            .is_ok());

        let test_notification = crate::db::Notification {
            timestamp: sec_since_epoch(),
            version: "version1".to_owned(),
            sortkey_timestamp: Some(sort_key + 10),
            ..test_notification_0
        };

        assert!(client
            .save_message(&uaid, test_notification.clone())
            .await
            .is_ok());

        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(fetched.messages.len(), 1);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        // Grab the message that was submitted.
        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);

        // can we clean up our toys?
        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());

        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        let fetched = client.get_user(&uaid).await?.unwrap();
        assert!(client
            .remove_node_id(&uaid, &node_id, fetched.connected_at, &fetched.version)
            .await
            .is_ok());
        // did we remove it?
        let fetched = client.get_user(&uaid).await?.unwrap();
        assert_eq!(fetched.node_id, None);

        assert!(client.remove_user(&uaid).await.is_ok());

        assert!(client.get_user(&uaid).await?.is_none());
        Ok(())
    }

    /// A message with a 2 second TTL should be gone after [increment_storage] is
    /// called with a storage timestamp 2 seconds past the message's timestamp.
    #[actix_rt::test]
    async fn test_expiry() -> DbResult<()> {
        // Make sure that we really are purging messages correctly
        init_test_logging();
        let client = new_client()?;

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let now = sec_since_epoch();

        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 2,
            timestamp: now,
            data: Some("SomeData".into()),
            sortkey_timestamp: Some(now),
            ..Default::default()
        };
        client
            .add_user(&User {
                uaid,
                router_type: "test".to_owned(),
                connected_at: ms_since_epoch(),
                ..Default::default()
            })
            .await?;
        client.add_channel(&uaid, &chid).await?;
        debug!("🧪Writing test notif");
        client
            .save_message(&uaid, test_notification.clone())
            .await?;
        let key = uaid.simple().to_string();
        debug!("🧪Checking {}...", &key);
        let msg = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(msg.is_some());
        debug!("🧪Purging...");
        client.increment_storage(&uaid, now + 2).await?;
        debug!("🧪Checking for empty {}...", &key);
        let cc = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(cc.is_none());
        // clean up after the test.
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }
}