autopush_common/db/postgres/
mod.rs

1/* Postgres DbClient implementation.
2 * As noted elsewhere, autopush was originally designed to work with NoSql type databases.
3 * This implementation was done partially as an experiment. Postgres allows for limited
4 * NoSql-like functionality. The author, however, has VERY limited knowledge of postgres,
5 * and there are likely many inefficiencies in this implementation.
6 *
7 * PRs are always welcome.
8 */
9
10use std::collections::HashSet;
11use std::str::FromStr;
12use std::sync::Arc;
13#[cfg(feature = "reliable_report")]
14use std::time::{Duration, SystemTime};
15
16#[cfg(feature = "reliable_report")]
17use chrono::TimeDelta;
18use deadpool_postgres::{Pool, Runtime};
19use serde_json::json;
20use tokio_postgres::{types::ToSql, Row}; // Client is sync.
21
22use async_trait::async_trait;
23use cadence::StatsdClient;
24use serde::{Deserialize, Serialize};
25use uuid::Uuid;
26
27use crate::db::client::DbClient;
28use crate::db::error::{DbError, DbResult};
29use crate::db::{DbSettings, User};
30use crate::notification::Notification;
31use crate::{util, MAX_ROUTER_TTL_SECS};
32
33use super::client::FetchMessageResponse;
34
35#[cfg(feature = "reliable_report")]
36const RELIABLE_LOG_TTL: TimeDelta = TimeDelta::days(60);
37
38#[derive(Debug, Clone, Deserialize, Serialize)]
39pub struct PostgresDbSettings {
40    #[serde(default)]
41    pub schema: Option<String>, // Optional DB Schema
42    #[serde(default)]
43    pub router_table: String, // Routing info
44    #[serde(default)]
45    pub message_table: String, // Message storage info
46    #[serde(default)]
47    pub meta_table: String, // Channels and meta info
48    #[serde(default)]
49    pub reliability_table: String, // Channels and meta info
50    #[serde(default)]
51    max_router_ttl: u64, // Max time for router records to live.
52                         // #[serde(default)]
53                         // pub use_tls: bool // Should you use a TLS connection to the db.
54}
55
56impl Default for PostgresDbSettings {
57    fn default() -> Self {
58        Self {
59            schema: None,
60            router_table: "router".to_owned(),
61            message_table: "message".to_owned(),
62            meta_table: "meta".to_owned(),
63            reliability_table: "reliability".to_owned(),
64            max_router_ttl: MAX_ROUTER_TTL_SECS,
65            // use_tls: false,
66        }
67    }
68}
69
70impl TryFrom<&str> for PostgresDbSettings {
71    type Error = DbError;
72    fn try_from(setting_string: &str) -> Result<Self, Self::Error> {
73        if setting_string.trim().is_empty() {
74            return Ok(PostgresDbSettings::default());
75        }
76        serde_json::from_str(setting_string).map_err(|e| {
77            DbError::General(format!(
78                "Could not parse configuration db_settings: {:?}",
79                e
80            ))
81        })
82    }
83}
84
85#[derive(Clone)]
86pub struct PgClientImpl {
87    _metrics: Arc<StatsdClient>,
88    db_settings: PostgresDbSettings,
89    pool: Pool,
90}
91
92impl PgClientImpl {
93    /// Create a new Postgres Client.
94    ///
95    /// This uses the `settings.db_dsn`. to try and connect to the postgres database.
96    /// See https://docs.rs/tokio-postgres/latest/tokio_postgres/config/struct.Config.html
97    /// for parameter details and requirements.
98    /// Example DSN: postgresql://user:password@host/database?option=val
99    /// e.g. (postgresql://scott:tiger@dbhost/autopush?connect_timeout=10&keepalives_idle=3600)
100    pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
101        let db_settings = PostgresDbSettings::try_from(settings.db_settings.as_ref())?;
102        // TODO: If required, add the TlsConnect<Stream> wrapper here.
103        let tls_flag = tokio_postgres::NoTls;
104        if let Some(dsn) = settings.dsn.clone() {
105            trace!("📮 Postgres Connect {}", &dsn);
106
107            let pool = deadpool_postgres::Config {
108                url: Some(dsn.clone()),
109                ..Default::default()
110            }
111            .create_pool(Some(Runtime::Tokio1), tls_flag)
112            .map_err(|e| DbError::General(e.to_string()))?;
113            return Ok(Self {
114                _metrics: metrics,
115                db_settings,
116                pool,
117            });
118        };
119        Err(DbError::ConnectionError("No DSN specified".to_owned()))
120    }
121
122    /// Does the given table exist
123    async fn table_exists(&self, table_name: String) -> DbResult<bool> {
124        let (schema, table_name) = if table_name.contains('.') {
125            let mut parts = table_name.splitn(2, '.');
126            (
127                parts.next().unwrap_or("public").to_owned(),
128                // If we are in a situation where someone specified a table name as
129                // `whatever.`, then we should absolutely panic here.
130                parts.next().unwrap().to_owned(),
131            )
132        } else {
133            ("public".to_owned(), table_name)
134        };
135        let rows = self
136            .pool
137            .get()
138            .await
139            .map_err(DbError::PgPoolError)?
140            .query(
141                "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname=$1 AND tablename=$2);",
142                &[&schema, &table_name],
143            )
144            .await
145            .map_err(DbError::PgError)?;
146        let val: &str = rows[0].get(0);
147        Ok(val.to_lowercase().starts_with('t'))
148    }
149
150    /// Return the router's expiration timestamp
151    fn router_expiry(&self) -> u64 {
152        util::sec_since_epoch() + self.db_settings.max_router_ttl
153    }
154
155    /// The router table contains how to route messages to the recipient UAID.
156    pub(crate) fn router_table(&self) -> String {
157        if let Some(schema) = &self.db_settings.schema {
158            format!("{}.{}", schema, self.db_settings.router_table)
159        } else {
160            self.db_settings.router_table.clone()
161        }
162    }
163
164    /// The message table contains stored messages for UAIDs.
165    pub(crate) fn message_table(&self) -> String {
166        if let Some(schema) = &self.db_settings.schema {
167            format!("{}.{}", schema, self.db_settings.message_table)
168        } else {
169            self.db_settings.message_table.clone()
170        }
171    }
172
173    /// The meta table contains channel and other metadata for UAIDs.
174    /// With traditional "No-Sql" databases, this would be rolled into the
175    /// router table.
176    pub(crate) fn meta_table(&self) -> String {
177        if let Some(schema) = &self.db_settings.schema {
178            format!("{}.{}", schema, self.db_settings.meta_table)
179        } else {
180            self.db_settings.meta_table.clone()
181        }
182    }
183
184    /// The reliability table contains message delivery reliability states.
185    /// This is optional and should only be used to track internally generated
186    /// and consumed messages based on the VAPID public key signature.
187    #[cfg(feature = "reliable_report")]
188    pub(crate) fn reliability_table(&self) -> String {
189        if let Some(schema) = &self.db_settings.schema {
190            format!("{}.{}", schema, self.db_settings.reliability_table)
191        } else {
192            self.db_settings.reliability_table.clone()
193        }
194    }
195}
196
197#[async_trait]
198impl DbClient for PgClientImpl {
199    /// add user to router_table if not exists uaid
200    async fn add_user(&self, user: &User) -> DbResult<()> {
201        self.pool.get().await.map_err(DbError::PgPoolError)?.execute(
202            &format!("
203            INSERT INTO {tablename} (uaid, connected_at, router_type, router_data, node_id, record_version, version, last_update, priv_channels, expiry)
204             VALUES($1, $2::BIGINT, $3, $4, $5, $6::BIGINT, $7, $8::BIGINT, $9, $10::BIGINT)
205             ON CONFLICT (uaid) DO
206                UPDATE SET connected_at=EXCLUDED.connected_at,
207                    router_type=EXCLUDED.router_type,
208                    router_data=EXCLUDED.router_data,
209                    node_id=EXCLUDED.node_id,
210                    record_version=EXCLUDED.record_version,
211                    version=EXCLUDED.version,
212                    last_update=EXCLUDED.last_update,
213                    priv_channels=EXCLUDED.priv_channels,
214                    expiry=EXCLUDED.expiry
215                    ;
216            ", tablename=self.router_table()),
217            &[&user.uaid.simple().to_string(),              // 1 
218            &(user.connected_at as i64),                    // 2
219            &user.router_type,                              // 3    
220            &json!(user.router_data).to_string(),           // 4
221            &user.node_id,                                  // 5
222            &user.record_version.map(|i| i as i64),    // 6
223            &(user.version.map(|v| v.simple().to_string())),   // 7
224            &user.current_timestamp.map(|i| i as i64), // 8
225            &user.priv_channels.iter().map(|v| v.to_string()).collect::<Vec<String>>(),
226            &(self.router_expiry() as i64),                 // 10    
227            ]
228        ).await.map_err( DbError::PgError)?;
229        Ok(())
230    }
231
232    /// update user record in router_table at user.uaid
233    async fn update_user(&self, user: &mut User) -> DbResult<bool> {
234        let cmd = format!(
235            "UPDATE {tablename} SET connected_at=$2::BIGINT,
236                router_type=$3,
237                router_data=$4,
238                node_id=$5,
239                record_version=$6::BIGINT,
240                version=$7,
241                last_update=$8::BIGINT,
242                priv_channels=$9,
243                expiry=$10::BIGINT
244            WHERE
245                uaid = $1 AND connected_at < $2::BIGINT;
246            ",
247            tablename = self.router_table()
248        );
249        let result = self
250            .pool
251            .get()
252            .await
253            .map_err(DbError::PgPoolError)?
254            .execute(
255                &cmd,
256                &[
257                    &user.uaid.simple().to_string(),                 // 1
258                    &(user.connected_at as i64),                     // 2
259                    &user.router_type,                               // 3
260                    &json!(user.router_data).to_string(),            //4
261                    &user.node_id,                                   // 5
262                    &(user.record_version.map(|i| i as i64)),        // 6
263                    &(user.version.map(|v| v.simple().to_string())), // 7
264                    &user.current_timestamp.map(|i| i as i64),       //8
265                    &user
266                        .priv_channels
267                        .iter()
268                        .map(|v| v.to_string())
269                        .collect::<Vec<String>>(),
270                    &(self.router_expiry() as i64), // 10
271                ],
272            )
273            .await
274            .map_err(DbError::PgError)?;
275        Ok(result > 0)
276    }
277
278    /// fetch user information from router_table for uaid.
279    async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
280        let row = self.pool.get().await.map_err(DbError::PgPoolError)?
281        .query_opt(
282            &format!(
283                "SELECT connected_at, router_type, router_data, node_id, record_version, last_update, version, priv_channels 
284                 FROM {tablename} 
285                 WHERE uaid = $1",
286                tablename=self.router_table()
287            ),
288            &[&uaid.simple().to_string()]
289        )
290        .await
291        .map_err(DbError::PgError)?;
292
293        let Some(row) = row else {
294            return Ok(None);
295        };
296
297        // I was tempted to make this a From impl, but realized that it would mean making autopush-common require a dependency.
298        // Maybe make this a deserialize?
299        let priv_channels = if let Ok(Some(channels)) =
300            row.try_get::<&str, Option<Vec<String>>>("priv_channels")
301        {
302            let mut priv_channels = HashSet::new();
303            for channel in channels.iter() {
304                let uuid = Uuid::from_str(channel).map_err(|e| DbError::General(e.to_string()))?;
305                priv_channels.insert(uuid);
306            }
307            priv_channels
308        } else {
309            HashSet::new()
310        };
311        let resp = User {
312            uaid: *uaid,
313            connected_at: row
314                .try_get::<&str, i64>("connected_at")
315                .map_err(DbError::PgError)? as u64,
316            router_type: row
317                .try_get::<&str, String>("router_type")
318                .map_err(DbError::PgError)?,
319            router_data: serde_json::from_str(
320                row.try_get::<&str, &str>("router_data")
321                    .map_err(DbError::PgError)?,
322            )
323            .map_err(|e| DbError::General(e.to_string()))?,
324            node_id: row
325                .try_get::<&str, Option<String>>("node_id")
326                .map_err(DbError::PgError)?,
327            record_version: row
328                .try_get::<&str, Option<i64>>("record_version")
329                .map_err(DbError::PgError)?
330                .map(|v| v as u64),
331            current_timestamp: row
332                .try_get::<&str, Option<i64>>("last_update")
333                .map_err(DbError::PgError)?
334                .map(|v| v as u64),
335            version: row
336                .try_get::<&str, Option<String>>("version")
337                .map_err(DbError::PgError)?
338                // An invalid UUID here is a data integrity error.
339                .map(|v| {
340                    Uuid::from_str(&v).map_err(|e| {
341                        DbError::Integrity("Invalid UUID found".to_owned(), Some(e.to_string()))
342                    })
343                })
344                .transpose()?,
345            priv_channels,
346        };
347        Ok(Some(resp))
348    }
349
350    /// delete a user at uaid from router_table
351    async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
352        self.pool
353            .get()
354            .await?
355            .execute(
356                &format!(
357                    "DELETE FROM {tablename}
358                     WHERE uaid = $1",
359                    tablename = self.router_table()
360                ),
361                &[&uaid.simple().to_string()],
362            )
363            .await
364            .map_err(DbError::PgError)?;
365        Ok(())
366    }
367
368    /// update list of channel_ids for uaid in meta table
369    /// Note: a conflicting channel_id is ignored, since it's already registered.
370    /// This should probably be optimized into the router table as a set value,
371    /// however I'm not familiar enough with Postgres to do so at this time.
372    /// Channels can be somewhat ephemeral, and we also want to limit the potential of
373    /// race conditions when adding or removing channels, particularly for mobile devices.
374    /// For some efficiency (mostly around the mobile "daily refresh" call), I've broken
375    /// the channels out by UAID into this table.
376    async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
377        self.pool
378            .get()
379            .await
380            .map_err(DbError::PgPoolError)?
381            .execute(
382                &format!(
383                    "INSERT 
384                     INTO {tablename} (uaid, channel_id) VALUES ($1, $2) 
385                     ON CONFLICT DO NOTHING",
386                    tablename = self.meta_table()
387                ),
388                &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
389            )
390            .await
391            .map_err(DbError::PgError)?;
392        Ok(())
393    }
394
395    /// Save all channels in a list
396    async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
397        if channels.is_empty() {
398            trace!("📮 No channels to save.");
399            return Ok(());
400        };
401        let uaid_str = uaid.simple().to_string();
402        // TODO: REVISIT THIS!!! May be possible without the gross hack.
403        // tokio-postgres doesn't store tuples as values, so you can't just construct
404        // the query as `INSERT into ... (a, b) VALUES (?,?), (?,?)`
405        // It does accept them as numerically specified values.
406        // The following is a gross hack that does basically that.
407        // The other option would be to just repeatedly call `self.add_channel()`
408        // but that seems far worse.
409        //
410        // first, collect the values into a flat fector. We force the type in
411        // the first item so that the second one is assumed.
412
413        let mut params = Vec::<&(dyn ToSql + Sync)>::new();
414        let mut new_channels = Vec::<String>::new();
415        for channel in channels {
416            new_channels.push(channel.simple().to_string());
417        }
418
419        // This is a bit cheesy, but we want capture the reference, not the value since
420        // it will not live long enough. Clippy will complain that this is a needless
421        // iterator, but it's not, really.
422        #[allow(clippy::needless_range_loop)]
423        for i in 0..new_channels.len() {
424            params.push(&uaid_str);
425            params.push(&new_channels[i]);
426        }
427
428        // Now construct the statement, iterate over the parameters we've got
429        // and redistribute them into tuples.
430        // (Remember, an existing channel_id is ignored during this insert since it's already registered)
431        let statement = format!(
432            "INSERT 
433                INTO {tablename} (uaid, channel_id) 
434                VALUES {vars} 
435                ON CONFLICT DO NOTHING",
436            tablename = self.meta_table(),
437            // Postgres variables are 1-indexed.
438            vars = Vec::from_iter((1..params.len() + 1).step_by(2).map(|v| format!(
439                "(${}, ${})",
440                v,
441                v + 1
442            )))
443            .join(",")
444        );
445        // finally, do the insert.
446        self.pool.get().await?.execute(&statement, &params).await?;
447        Ok(())
448    }
449
450    /// get all channels for uaid from meta table
451    async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
452        let mut result = HashSet::new();
453        let rows = self
454            .pool
455            .get()
456            .await
457            .map_err(DbError::PgPoolError)?
458            .query(
459                &format!(
460                    "SELECT distinct channel_id FROM {tablename} WHERE uaid = $1;",
461                    tablename = self.meta_table()
462                ),
463                &[&uaid.simple().to_string()],
464            )
465            .await
466            .map_err(DbError::PgError)?;
467        for row in rows.iter() {
468            let s = row
469                .try_get::<&str, &str>("channel_id")
470                .map_err(DbError::PgError)?;
471            result.insert(Uuid::from_str(s).map_err(|e| DbError::General(e.to_string()))?);
472        }
473        Ok(result)
474    }
475
476    /// remove an individual channel for a given uaid from meta table
477    async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
478        let cmd = format!(
479            "DELETE FROM {tablename} 
480             WHERE uaid = $1 AND channel_id = $2;",
481            tablename = self.meta_table()
482        );
483        let result = self
484            .pool
485            .get()
486            .await
487            .map_err(DbError::PgPoolError)?
488            .execute(
489                &cmd,
490                &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
491            )
492            .await?;
493        // We sometimes want to know if the channel existed previously.
494        Ok(result > 0)
495    }
496
497    /// remove node info for a uaid from router table
498    async fn remove_node_id(
499        &self,
500        uaid: &Uuid,
501        node_id: &str,
502        connected_at: u64,
503        version: &Option<Uuid>,
504    ) -> DbResult<bool> {
505        let Some(version) = version else {
506            return Err(DbError::General("Expected a user version field".to_owned()));
507        };
508        self.pool
509            .get()
510            .await
511            .map_err(DbError::PgPoolError)?
512            .execute(
513                &format!(
514                    "UPDATE {tablename} 
515                        SET node_id = null 
516                        WHERE uaid=$1 AND node_id = $2 AND connected_at = $3 AND version= $4;",
517                    tablename = self.router_table()
518                ),
519                &[
520                    &uaid.simple().to_string(),
521                    &node_id,
522                    &(connected_at as i64),
523                    &version.simple().to_string(),
524                ],
525            )
526            .await
527            .map_err(DbError::PgError)?;
528        Ok(true)
529    }
530
531    /// write a message to message table
532    async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
533        // fun fact: serde_postgres exists, but only deserializes (as of 0.2)
534        // (This is mutable if `reliable_report` enabled)
535        #[allow(unused_mut)]
536        let mut fields = vec![
537            "uaid",
538            "channel_id",
539            "chid_message_id",
540            "version",
541            "ttl",
542            "expiry",
543            "topic",
544            "timestamp",
545            "data",
546            "sortkey_timestamp",
547            "headers",
548        ];
549        // (This is mutable if `reliable_report` enabled)
550        #[allow(unused_mut)]
551        let mut inputs = vec![
552            "$1", "$2", "$3", "$4", "$5", "$6", "$7", "$8", "$9", "$10", "$11",
553        ];
554        #[cfg(feature = "reliable_report")]
555        {
556            fields.append(&mut ["reliability_id"].to_vec());
557            inputs.append(&mut ["$12"].to_vec());
558        }
559        let cmd = format!(
560            "INSERT INTO {tablename}
561                ({fields})
562                VALUES
563                ({inputs}) ON CONFLICT (chid_message_id) DO UPDATE SET
564                    uaid=EXCLUDED.uaid,
565                    channel_id=EXCLUDED.channel_id,
566                    version=EXCLUDED.version,
567                    ttl=EXCLUDED.ttl,
568                    expiry=EXCLUDED.expiry,
569                    topic=EXCLUDED.topic,
570                    timestamp=EXCLUDED.timestamp,
571                    data=EXCLUDED.data,
572                    sortkey_timestamp=EXCLUDED.sortkey_timestamp,
573                    headers=EXCLUDED.headers",
574            tablename = &self.message_table(),
575            fields = fields.join(","),
576            inputs = inputs.join(",")
577        );
578        self.pool
579            .get()
580            .await
581            .map_err(DbError::PgPoolError)?
582            .execute(
583                &cmd,
584                &[
585                    &uaid.simple().to_string(),
586                    &message.channel_id.simple().to_string(),
587                    &message.chidmessageid(),
588                    &message.version,
589                    &(message.ttl as i64), // Postgres has no auto TTL.
590                    &(util::sec_since_epoch() as i64 + message.ttl as i64),
591                    &message.topic,
592                    &(message.timestamp as i64),
593                    &message.data.unwrap_or_default(),
594                    &message.sortkey_timestamp.map(|v| v as i64),
595                    &json!(message.headers).to_string(),
596                    #[cfg(feature = "reliable_report")]
597                    &message.reliability_id,
598                ],
599            )
600            .await
601            .map_err(DbError::PgError)?;
602        Ok(())
603    }
604
605    /// remove a given message from the message table
606    async fn remove_message(&self, uaid: &Uuid, sort_key: &str) -> DbResult<()> {
607        self.pool
608            .get()
609            .await
610            .map_err(DbError::PgPoolError)?
611            .execute(
612                &format!(
613                    "DELETE FROM {tablename} 
614                     WHERE uaid=$1 AND chid_message_id = $2;",
615                    tablename = self.message_table()
616                ),
617                &[&uaid.simple().to_string(), &(sort_key.to_owned())],
618            )
619            .await
620            .map_err(DbError::PgError)?;
621
622        Ok(())
623    }
624
625    async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
626        for message in messages {
627            self.save_message(uaid, message).await?;
628        }
629        Ok(())
630    }
631
632    /// fetch topic messages for the user up to {limit}
633    /// Topic messages are auto-replacing singleton messages for a given user.
634    async fn fetch_topic_messages(
635        &self,
636        uaid: &Uuid,
637        limit: usize,
638    ) -> DbResult<FetchMessageResponse> {
639        let messages: Vec<Notification> = self
640            .pool
641            .get()
642            .await
643            .map_err(DbError::PgPoolError)?
644            .query(
645                &format!(
646                "SELECT channel_id, version, ttl, topic, timestamp, data, sortkey_timestamp, headers 
647                 FROM {tablename} 
648                 WHERE uaid=$1 AND expiry >= $2
649                 ORDER BY timestamp DESC 
650                 LIMIT $3",
651                tablename=&self.message_table(),
652            ),
653                &[&uaid.simple().to_string(),&(util::sec_since_epoch() as i64), &(limit as i64)],
654            )
655            .await
656            .map_err(DbError::PgError)?
657            .iter()
658            .map(|row: &Row| row.try_into())
659            .collect::<Result<Vec<Notification>, DbError>>()?;
660
661        if messages.is_empty() {
662            Ok(Default::default())
663        } else {
664            Ok(FetchMessageResponse {
665                timestamp: Some(messages[0].timestamp),
666                messages,
667            })
668        }
669    }
670
671    /// Fetch messages for a user on or after a given timestamp up to {limit}
672    async fn fetch_timestamp_messages(
673        &self,
674        uaid: &Uuid,
675        timestamp: Option<u64>,
676        limit: usize,
677    ) -> DbResult<FetchMessageResponse> {
678        let uaid = uaid.simple().to_string();
679        let response: Vec<Row> = if let Some(ts) = timestamp {
680            trace!("📮 Fetching messages for user {} since {}", &uaid, ts);
681            self.pool
682                .get()
683                .await
684                .map_err(DbError::PgPoolError)?
685                .query(
686                    &format!(
687                        "SELECT * FROM {} 
688                         WHERE uaid = $1 AND timestamp > $2 AND expiry >= $3
689                         ORDER BY timestamp 
690                         LIMIT $3",
691                        self.message_table()
692                    ),
693                    &[
694                        &uaid,
695                        &(ts as i64),
696                        &(limit as i64),
697                        &(util::sec_since_epoch() as i64),
698                    ],
699                )
700                .await
701        } else {
702            trace!("📮 Fetching messages for user {}", &uaid);
703            self.pool
704                .get()
705                .await
706                .map_err(DbError::PgPoolError)?
707                .query(
708                    &format!(
709                        "SELECT * 
710                         FROM {} 
711                         WHERE uaid = $1 
712                         AND expiry >= $2
713                         LIMIT $2",
714                        self.message_table()
715                    ),
716                    &[&uaid, &(limit as i64), &(util::sec_since_epoch() as i64)],
717                )
718                .await
719        }?;
720
721        let messages: Vec<Notification> = response
722            .iter()
723            .map(|row: &Row| row.try_into())
724            .collect::<Result<Vec<Notification>, DbError>>()?;
725        let timestamp = if !messages.is_empty() {
726            Some(messages[0].timestamp)
727        } else {
728            None
729        };
730
731        Ok(FetchMessageResponse {
732            timestamp,
733            messages,
734        })
735    }
736
737    /// Convenience function to check if the router table exists
738    async fn router_table_exists(&self) -> DbResult<bool> {
739        self.table_exists(self.router_table()).await
740    }
741
742    /// Convenience function to check if the message table exists
743    async fn message_table_exists(&self) -> DbResult<bool> {
744        self.table_exists(self.message_table()).await
745    }
746
747    #[cfg(feature = "reliable_report")]
748    async fn log_report(
749        &self,
750        reliability_id: &str,
751        new_state: crate::reliability::ReliabilityState,
752    ) -> DbResult<()> {
753        let timestamp =
754            SystemTime::now() + Duration::from_secs(RELIABLE_LOG_TTL.num_seconds() as u64);
755        debug!("📮 Logging report for {reliability_id} as {new_state}");
756        /*
757            INSERT INTO {tablename} (id, states) VALUES ({reliability_id}, json_build_object({state}, {timestamp}) )
758            ON CONFLICT(id)
759            UPDATE {tablename} SET states = jsonb_set(states, array[{state}], to_jsonb({timestamp}));
760        */
761
762        let tablename = &self.reliability_table();
763        let state = new_state.to_string();
764        self.pool
765            .get()
766            .await?
767            .execute(
768                &format!(
769                "INSERT INTO {tablename} (id, states, last_update_timestamp) VALUES ($1, json_build_object($2, $3), $3)
770                 ON CONFLICT (id) DO
771                 UPDATE SET states = EXCLUDED.states, 
772                 last_update_timestamp = EXCLUDED.last_update_timestamp;",
773                    tablename = tablename
774                ),
775                &[&reliability_id, &state, &timestamp],
776            )
777            .await?;
778        Ok(())
779    }
780
781    async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
782        debug!("📮 Updating {uaid} current_timestamp:{timestamp}");
783        let tablename = &self.router_table();
784
785        trace!("📮 Purging git{uaid} for < {timestamp}");
786        let mut pool = self.pool.get().await.map_err(DbError::PgPoolError)?;
787        let transaction = pool.transaction().await?;
788        // Try to garbage collect old messages first.
789        transaction
790            .execute(
791                &format!(
792                    "DELETE FROM {} WHERE uaid = $1 and expiry < $2",
793                    &self.message_table()
794                ),
795                &[
796                    &uaid.simple().to_string(),
797                    &(util::sec_since_epoch() as i64),
798                ],
799            )
800            .await?;
801        // Now, delete messages that we've already delivered.
802        transaction
803            .execute(
804                &format!(
805                    "DELETE FROM {} WHERE uaid = $1 AND timestamp IS NOT NULL AND timestamp < $2",
806                    &self.message_table()
807                ),
808                &[&uaid.simple().to_string(), &(timestamp as i64)],
809            )
810            .await?;
811        transaction.execute(
812                &format!(
813                    "UPDATE {tablename} SET last_update = $2::BIGINT, expiry= $3::BIGINT WHERE uaid = $1"
814                ),
815                &[
816                    &uaid.simple().to_string(),
817                    &(timestamp as i64),
818                    &(self.router_expiry() as i64),
819                ],
820            )
821            .await?;
822        transaction.commit().await?;
823        Ok(())
824    }
825
826    fn name(&self) -> String {
827        "Postgres".to_owned()
828    }
829
830    async fn health_check(&self) -> DbResult<bool> {
831        // Replace this with a proper health check.
832        let client = self.pool.get().await.map_err(DbError::PgPoolError);
833        let row = client?.query_one("select true", &[]).await;
834        Ok(!row?.is_empty())
835    }
836
837    /// Convenience function to return self as a Boxed DbClient
838    fn box_clone(&self) -> Box<dyn DbClient> {
839        Box::new(self.clone())
840    }
841}
842
843/* Note:
844 * For preliminary testing, you will need to start a local postgres instance (see
845 * https://www.docker.com/blog/how-to-use-the-postgres-docker-official-image/) and initialize the
846 * database with `schema.psql`.
847 * Once you have, you can define the environment variable `POSTGRES_HOST` to point to the
848 * appropriate host (e.g. `postgres:post_pass@localhost:/autopush`). `new_client` will add the
849 * `postgres://` prefix automatically.
850 *
851 * TODO: Really should move the bulk of the tests to a higher level and add backend specific
852 * versions of `new_client`.
853 *
854 */
855#[cfg(test)]
856mod tests {
857    use crate::util::sec_since_epoch;
858    use crate::{logging::init_test_logging, util::ms_since_epoch};
859    use rand::prelude::*;
860    use serde_json::json;
861    use std::env;
862
863    use super::*;
864    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
865    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";
866
867    fn new_client() -> DbResult<PgClientImpl> {
868        // Use an environment variable to potentially override the default storage test host.
869        let host = env::var("POSTGRES_HOST").unwrap_or("localhost".into());
870        let env_dsn = format!("postgres://{host}");
871        debug!("📮 Connecting to {env_dsn}");
872        let settings = DbSettings {
873            dsn: Some(env_dsn),
874            db_settings: json!(PostgresDbSettings {
875                schema: Some("autopush".to_owned()),
876                ..Default::default()
877            })
878            .to_string(),
879        };
880        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
881        PgClientImpl::new(metrics, &settings)
882    }
883
884    fn gen_test_user() -> String {
885        // Create a semi-unique test user to avoid conflicting test values.
886        let mut rng = rand::rng();
887        let test_num = rng.random::<u8>();
888        format!(
889            "DEADBEEF-0000-0000-{:04}-{:012}",
890            test_num,
891            sec_since_epoch()
892        )
893    }
894
895    #[actix_rt::test]
896    async fn health_check() {
897        let client = new_client().unwrap();
898
899        let result = client.health_check().await;
900        assert!(result.is_ok());
901        assert!(result.unwrap());
902    }
903
904    /// Test if [increment_storage] correctly wipe expired messages
905    #[actix_rt::test]
906    async fn wipe_expired() -> DbResult<()> {
907        init_test_logging();
908        let client = new_client()?;
909
910        let connected_at = ms_since_epoch();
911
912        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
913        let chid = Uuid::parse_str(TEST_CHID).unwrap();
914
915        let node_id = "test_node".to_owned();
916
917        // purge the user record if it exists.
918        let _ = client.remove_user(&uaid).await;
919
920        let test_user = User {
921            uaid,
922            router_type: "webpush".to_owned(),
923            connected_at,
924            router_data: None,
925            node_id: Some(node_id.clone()),
926            ..Default::default()
927        };
928
929        // purge the old user (if present)
930        // in case a prior test failed for whatever reason.
931        let _ = client.remove_user(&uaid).await;
932
933        // can we add the user?
934        let timestamp = sec_since_epoch();
935        client.add_user(&test_user).await?;
936        let test_notification = crate::db::Notification {
937            channel_id: chid,
938            version: "test".to_owned(),
939            ttl: 1,
940            timestamp,
941            data: Some("Encrypted".into()),
942            sortkey_timestamp: Some(timestamp),
943            ..Default::default()
944        };
945        client.save_message(&uaid, test_notification).await?;
946        client.increment_storage(&uaid, timestamp + 1).await?;
947        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
948        assert_eq!(msgs.messages.len(), 0);
949        assert!(client.remove_user(&uaid).await.is_ok());
950        Ok(())
951    }
952
953    /// run a gauntlet of testing. These are a bit linear because they need
954    /// to run in sequence.
955    #[actix_rt::test]
956    async fn run_gauntlet() -> DbResult<()> {
957        init_test_logging();
958        let client = new_client()?;
959
960        let connected_at = ms_since_epoch();
961
962        let user_id = &gen_test_user();
963        let uaid = Uuid::parse_str(user_id).unwrap();
964        let chid = Uuid::parse_str(TEST_CHID).unwrap();
965        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();
966
967        let node_id = "test_node".to_owned();
968
969        // purge the user record if it exists.
970        let _ = client.remove_user(&uaid).await;
971
972        let test_user = User {
973            uaid,
974            router_type: "webpush".to_owned(),
975            connected_at,
976            router_data: None,
977            node_id: Some(node_id.clone()),
978            ..Default::default()
979        };
980
981        // purge the old user (if present)
982        // in case a prior test failed for whatever reason.
983        let _ = client.remove_user(&uaid).await;
984
985        // can we add the user?
986        trace!("📮 Adding user {}", &user_id);
987        client.add_user(&test_user).await?;
988        let fetched = client.get_user(&uaid).await?;
989        assert!(fetched.is_some());
990        let fetched = fetched.unwrap();
991        assert_eq!(fetched.router_type, "webpush".to_owned());
992
993        // can we add channels?
994        trace!("📮 Adding channel {} to user {}", &chid, &user_id);
995        client.add_channel(&uaid, &chid).await?;
996        let channels = client.get_channels(&uaid).await?;
997        assert!(channels.contains(&chid));
998
999        // can we add lots of channels?
1000        let mut new_channels: HashSet<Uuid> = HashSet::new();
1001        trace!("📮 Adding multiple channels to user {}", &user_id);
1002        new_channels.insert(chid);
1003        for _ in 1..10 {
1004            new_channels.insert(uuid::Uuid::new_v4());
1005        }
1006        let chid_to_remove = uuid::Uuid::new_v4();
1007        trace!(
1008            "📮 Adding removable channel {} to user {}",
1009            &chid_to_remove,
1010            &user_id
1011        );
1012        new_channels.insert(chid_to_remove);
1013        client.add_channels(&uaid, new_channels.clone()).await?;
1014        let channels = client.get_channels(&uaid).await?;
1015        assert_eq!(channels, new_channels);
1016
1017        // can we remove a channel?
1018        trace!(
1019            "📮 Removing channel {} from user {}",
1020            &chid_to_remove,
1021            &user_id
1022        );
1023        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
1024        trace!(
1025            "📮 retrying Removing channel {} from user {}",
1026            &chid_to_remove,
1027            &user_id
1028        );
1029        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
1030        new_channels.remove(&chid_to_remove);
1031        let channels = client.get_channels(&uaid).await?;
1032        assert_eq!(channels, new_channels);
1033
1034        // now ensure that we can update a user that's after the time we set
1035        // prior. first ensure that we can't update a user that's before the
1036        // time we set prior to the last write
1037        let mut updated = User {
1038            connected_at,
1039            ..test_user.clone()
1040        };
1041        trace!(
1042            "📮 Attempting to update user {} with old connected_at: {}",
1043            &user_id,
1044            &updated.connected_at
1045        );
1046        let result = client.update_user(&mut updated).await;
1047        assert!(result.is_ok());
1048        assert!(!result.unwrap());
1049
1050        // Make sure that the `connected_at` wasn't modified
1051        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
1052        assert_eq!(fetched.connected_at, fetched2.connected_at);
1053
1054        // and make sure we can update a record with a later connected_at time.
1055        let mut updated = User {
1056            connected_at: fetched.connected_at + 300,
1057            ..fetched2
1058        };
1059        trace!(
1060            "📮 Attempting to update user {} with new connected_at",
1061            &user_id
1062        );
1063        let result = client.update_user(&mut updated).await;
1064        assert!(result.is_ok());
1065        assert!(result.unwrap());
1066        assert_ne!(
1067            fetched2.connected_at,
1068            client.get_user(&uaid).await?.unwrap().connected_at
1069        );
1070        // can we increment the storage for the user?
1071        trace!("📮 Incrementing storage timestamp for user {}", &user_id);
1072        client
1073            .increment_storage(&fetched.uaid, sec_since_epoch())
1074            .await?;
1075
1076        let test_data = "An_encrypted_pile_of_crap".to_owned();
1077        let timestamp = sec_since_epoch();
1078        let sort_key = sec_since_epoch();
1079        let fetch_timestamp = timestamp;
1080        // Can we store a message?
1081        let test_notification = crate::db::Notification {
1082            channel_id: chid,
1083            version: "test".to_owned(),
1084            ttl: 300,
1085            timestamp,
1086            data: Some(test_data.clone()),
1087            sortkey_timestamp: Some(sort_key),
1088            ..Default::default()
1089        };
1090        trace!("📮 Saving message for user {}", &user_id);
1091        let res = client.save_message(&uaid, test_notification.clone()).await;
1092        assert!(res.is_ok());
1093
1094        trace!("📮 Fetching all messages for user {}", &user_id);
1095        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1096        assert_ne!(fetched.messages.len(), 0);
1097        let fm = fetched.messages.pop().unwrap();
1098        assert_eq!(fm.channel_id, test_notification.channel_id);
1099        assert_eq!(fm.data, Some(test_data));
1100
1101        // Grab all 1 of the messages that were submitted within the past 10 seconds.
1102        trace!(
1103            "📮 Fetching messages for user {} within the past 10 seconds",
1104            &user_id
1105        );
1106        let fetched = client
1107            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
1108            .await?;
1109        assert_ne!(fetched.messages.len(), 0);
1110
1111        // Try grabbing a message for 10 seconds from now.
1112        trace!(
1113            "📮 Fetching messages for user {} 10 seconds in the future",
1114            &user_id
1115        );
1116        let fetched = client
1117            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
1118            .await?;
1119        assert_eq!(fetched.messages.len(), 0);
1120
1121        // can we clean up our toys?
1122        trace!(
1123            "📮 Removing message for user {} :: {}",
1124            &user_id,
1125            &test_notification.chidmessageid()
1126        );
1127        assert!(client
1128            .remove_message(&uaid, &test_notification.chidmessageid())
1129            .await
1130            .is_ok());
1131
1132        trace!("📮 Removing channel for user {}", &user_id);
1133        assert!(client.remove_channel(&uaid, &chid).await.is_ok());
1134
1135        trace!("📮 Making sure no messages remain for user {}", &user_id);
1136        let msgs = client
1137            .fetch_timestamp_messages(&uaid, None, 999)
1138            .await?
1139            .messages;
1140        assert!(msgs.is_empty());
1141
1142        // Now, can we do all that with topic messages
1143        // Unlike bigtable, we don't use [fetch_topic_messages]: it always return None:
1144        // they are handled as usuals messages.
1145        client.add_channel(&uaid, &topic_chid).await?;
1146        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
1147        let timestamp = sec_since_epoch();
1148        let sort_key = sec_since_epoch();
1149
1150        // We store 2 messages, with a single topic
1151        let test_notification_0 = crate::db::Notification {
1152            channel_id: topic_chid,
1153            version: "version0".to_owned(),
1154            ttl: 300,
1155            topic: Some("topic".to_owned()),
1156            timestamp,
1157            data: Some(test_data.clone()),
1158            sortkey_timestamp: Some(sort_key),
1159            ..Default::default()
1160        };
1161        assert!(client
1162            .save_message(&uaid, test_notification_0.clone())
1163            .await
1164            .is_ok());
1165
1166        let test_notification = crate::db::Notification {
1167            timestamp: sec_since_epoch(),
1168            version: "version1".to_owned(),
1169            sortkey_timestamp: Some(sort_key + 10),
1170            ..test_notification_0
1171        };
1172
1173        assert!(client
1174            .save_message(&uaid, test_notification.clone())
1175            .await
1176            .is_ok());
1177
1178        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1179        assert_eq!(fetched.messages.len(), 1);
1180        let fm = fetched.messages.pop().unwrap();
1181        assert_eq!(fm.channel_id, test_notification.channel_id);
1182        assert_eq!(fm.data, Some(test_data));
1183
1184        // Grab the message that was submitted.
1185        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1186        assert_ne!(fetched.messages.len(), 0);
1187
1188        // can we clean up our toys?
1189        assert!(client
1190            .remove_message(&uaid, &test_notification.chidmessageid())
1191            .await
1192            .is_ok());
1193
1194        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());
1195
1196        let msgs = client
1197            .fetch_timestamp_messages(&uaid, None, 999)
1198            .await?
1199            .messages;
1200        assert!(msgs.is_empty());
1201
1202        let fetched = client.get_user(&uaid).await?.unwrap();
1203        assert!(client
1204            .remove_node_id(&uaid, &node_id, fetched.connected_at, &fetched.version)
1205            .await
1206            .is_ok());
1207        // did we remove it?
1208        let fetched = client.get_user(&uaid).await?.unwrap();
1209        assert_eq!(fetched.node_id, None);
1210
1211        assert!(client.remove_user(&uaid).await.is_ok());
1212
1213        assert!(client.get_user(&uaid).await?.is_none());
1214        Ok(())
1215    }
1216
1217    #[actix_rt::test]
1218    async fn test_expiry() -> DbResult<()> {
1219        // Make sure that we really are purging messages correctly
1220        init_test_logging();
1221        let client = new_client()?;
1222
1223        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
1224        let chid = Uuid::parse_str(TEST_CHID).unwrap();
1225        let now = sec_since_epoch();
1226
1227        let test_notification = crate::db::Notification {
1228            channel_id: chid,
1229            version: "test".to_owned(),
1230            ttl: 2,
1231            timestamp: now,
1232            data: Some("SomeData".into()),
1233            sortkey_timestamp: Some(now),
1234            ..Default::default()
1235        };
1236        client
1237            .add_user(&User {
1238                uaid,
1239                router_type: "test".to_owned(),
1240                connected_at: ms_since_epoch(),
1241                ..Default::default()
1242            })
1243            .await?;
1244        client.add_channel(&uaid, &chid).await?;
1245        debug!("🧪Writing test notif");
1246        client
1247            .save_message(&uaid, test_notification.clone())
1248            .await?;
1249        let key = uaid.simple().to_string();
1250        debug!("🧪Checking {}...", &key);
1251        let msg = client
1252            .fetch_timestamp_messages(&uaid, None, 1)
1253            .await?
1254            .messages
1255            .pop();
1256        assert!(msg.is_some());
1257        debug!("🧪Purging...");
1258        client.increment_storage(&uaid, now + 2).await?;
1259        debug!("🧪Checking for empty {}...", &key);
1260        let cc = client
1261            .fetch_timestamp_messages(&uaid, None, 1)
1262            .await?
1263            .messages
1264            .pop();
1265        assert!(cc.is_none());
1266        // clean up after the test.
1267        assert!(client.remove_user(&uaid).await.is_ok());
1268        Ok(())
1269    }
1270}