1use std::collections::HashSet;
11use std::str::FromStr;
12use std::sync::Arc;
13#[cfg(feature = "reliable_report")]
14use std::time::{Duration, SystemTime};
15
16#[cfg(feature = "reliable_report")]
17use chrono::TimeDelta;
18use deadpool_postgres::{Pool, Runtime};
19use serde_json::json;
20use tokio_postgres::{types::ToSql, Row}; use async_trait::async_trait;
23use cadence::StatsdClient;
24use serde::{Deserialize, Serialize};
25use uuid::Uuid;
26
27use crate::db::client::DbClient;
28use crate::db::error::{DbError, DbResult};
29use crate::db::{DbSettings, User};
30use crate::notification::Notification;
31use crate::{util, MAX_ROUTER_TTL_SECS};
32
33use super::client::FetchMessageResponse;
34
/// Offset added to "now" by `log_report` to compute the timestamp it stores
/// for a reliability record (both in `states` and `last_update_timestamp`):
/// 60 days. Only compiled with the `reliable_report` feature.
#[cfg(feature = "reliable_report")]
const RELIABLE_LOG_TTL: TimeDelta = TimeDelta::days(60);
37
38#[derive(Debug, Clone, Deserialize, Serialize)]
39pub struct PostgresDbSettings {
40 #[serde(default)]
41 pub schema: Option<String>, #[serde(default)]
43 pub router_table: String, #[serde(default)]
45 pub message_table: String, #[serde(default)]
47 pub meta_table: String, #[serde(default)]
49 pub reliability_table: String, #[serde(default)]
51 max_router_ttl: u64, }
55
56impl Default for PostgresDbSettings {
57 fn default() -> Self {
58 Self {
59 schema: None,
60 router_table: "router".to_owned(),
61 message_table: "message".to_owned(),
62 meta_table: "meta".to_owned(),
63 reliability_table: "reliability".to_owned(),
64 max_router_ttl: MAX_ROUTER_TTL_SECS,
65 }
67 }
68}
69
70impl TryFrom<&str> for PostgresDbSettings {
71 type Error = DbError;
72 fn try_from(setting_string: &str) -> Result<Self, Self::Error> {
73 if setting_string.trim().is_empty() {
74 return Ok(PostgresDbSettings::default());
75 }
76 serde_json::from_str(setting_string).map_err(|e| {
77 DbError::General(format!(
78 "Could not parse configuration db_settings: {:?}",
79 e
80 ))
81 })
82 }
83}
84
85#[derive(Clone)]
86pub struct PgClientImpl {
87 _metrics: Arc<StatsdClient>,
88 db_settings: PostgresDbSettings,
89 pool: Pool,
90}
91
92impl PgClientImpl {
93 pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
101 let db_settings = PostgresDbSettings::try_from(settings.db_settings.as_ref())?;
102 let tls_flag = tokio_postgres::NoTls;
104 if let Some(dsn) = settings.dsn.clone() {
105 trace!("📮 Postgres Connect {}", &dsn);
106
107 let pool = deadpool_postgres::Config {
108 url: Some(dsn.clone()),
109 ..Default::default()
110 }
111 .create_pool(Some(Runtime::Tokio1), tls_flag)
112 .map_err(|e| DbError::General(e.to_string()))?;
113 return Ok(Self {
114 _metrics: metrics,
115 db_settings,
116 pool,
117 });
118 };
119 Err(DbError::ConnectionError("No DSN specified".to_owned()))
120 }
121
122 async fn table_exists(&self, table_name: String) -> DbResult<bool> {
124 let (schema, table_name) = if table_name.contains('.') {
125 let mut parts = table_name.splitn(2, '.');
126 (
127 parts.next().unwrap_or("public").to_owned(),
128 parts.next().unwrap().to_owned(),
131 )
132 } else {
133 ("public".to_owned(), table_name)
134 };
135 let rows = self
136 .pool
137 .get()
138 .await
139 .map_err(DbError::PgPoolError)?
140 .query(
141 "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname=$1 AND tablename=$2);",
142 &[&schema, &table_name],
143 )
144 .await
145 .map_err(DbError::PgError)?;
146 let val: &str = rows[0].get(0);
147 Ok(val.to_lowercase().starts_with('t'))
148 }
149
150 fn router_expiry(&self) -> u64 {
152 util::sec_since_epoch() + self.db_settings.max_router_ttl
153 }
154
155 pub(crate) fn router_table(&self) -> String {
157 if let Some(schema) = &self.db_settings.schema {
158 format!("{}.{}", schema, self.db_settings.router_table)
159 } else {
160 self.db_settings.router_table.clone()
161 }
162 }
163
164 pub(crate) fn message_table(&self) -> String {
166 if let Some(schema) = &self.db_settings.schema {
167 format!("{}.{}", schema, self.db_settings.message_table)
168 } else {
169 self.db_settings.message_table.clone()
170 }
171 }
172
173 pub(crate) fn meta_table(&self) -> String {
177 if let Some(schema) = &self.db_settings.schema {
178 format!("{}.{}", schema, self.db_settings.meta_table)
179 } else {
180 self.db_settings.meta_table.clone()
181 }
182 }
183
184 #[cfg(feature = "reliable_report")]
188 pub(crate) fn reliability_table(&self) -> String {
189 if let Some(schema) = &self.db_settings.schema {
190 format!("{}.{}", schema, self.db_settings.reliability_table)
191 } else {
192 self.db_settings.reliability_table.clone()
193 }
194 }
195}
196
197#[async_trait]
198impl DbClient for PgClientImpl {
199 async fn add_user(&self, user: &User) -> DbResult<()> {
201 self.pool.get().await.map_err(DbError::PgPoolError)?.execute(
202 &format!("
203 INSERT INTO {tablename} (uaid, connected_at, router_type, router_data, node_id, record_version, version, last_update, priv_channels, expiry)
204 VALUES($1, $2::BIGINT, $3, $4, $5, $6::BIGINT, $7, $8::BIGINT, $9, $10::BIGINT)
205 ON CONFLICT (uaid) DO
206 UPDATE SET connected_at=EXCLUDED.connected_at,
207 router_type=EXCLUDED.router_type,
208 router_data=EXCLUDED.router_data,
209 node_id=EXCLUDED.node_id,
210 record_version=EXCLUDED.record_version,
211 version=EXCLUDED.version,
212 last_update=EXCLUDED.last_update,
213 priv_channels=EXCLUDED.priv_channels,
214 expiry=EXCLUDED.expiry
215 ;
216 ", tablename=self.router_table()),
217 &[&user.uaid.simple().to_string(), &(user.connected_at as i64), &user.router_type, &json!(user.router_data).to_string(), &user.node_id, &user.record_version.map(|i| i as i64), &(user.version.map(|v| v.simple().to_string())), &user.current_timestamp.map(|i| i as i64), &user.priv_channels.iter().map(|v| v.to_string()).collect::<Vec<String>>(),
226 &(self.router_expiry() as i64), ]
228 ).await.map_err( DbError::PgError)?;
229 Ok(())
230 }
231
232 async fn update_user(&self, user: &mut User) -> DbResult<bool> {
234 let cmd = format!(
235 "UPDATE {tablename} SET connected_at=$2::BIGINT,
236 router_type=$3,
237 router_data=$4,
238 node_id=$5,
239 record_version=$6::BIGINT,
240 version=$7,
241 last_update=$8::BIGINT,
242 priv_channels=$9,
243 expiry=$10::BIGINT
244 WHERE
245 uaid = $1 AND connected_at < $2::BIGINT;
246 ",
247 tablename = self.router_table()
248 );
249 let result = self
250 .pool
251 .get()
252 .await
253 .map_err(DbError::PgPoolError)?
254 .execute(
255 &cmd,
256 &[
257 &user.uaid.simple().to_string(), &(user.connected_at as i64), &user.router_type, &json!(user.router_data).to_string(), &user.node_id, &(user.record_version.map(|i| i as i64)), &(user.version.map(|v| v.simple().to_string())), &user.current_timestamp.map(|i| i as i64), &user
266 .priv_channels
267 .iter()
268 .map(|v| v.to_string())
269 .collect::<Vec<String>>(),
270 &(self.router_expiry() as i64), ],
272 )
273 .await
274 .map_err(DbError::PgError)?;
275 Ok(result > 0)
276 }
277
278 async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
280 let row = self.pool.get().await.map_err(DbError::PgPoolError)?
281 .query_opt(
282 &format!(
283 "SELECT connected_at, router_type, router_data, node_id, record_version, last_update, version, priv_channels
284 FROM {tablename}
285 WHERE uaid = $1",
286 tablename=self.router_table()
287 ),
288 &[&uaid.simple().to_string()]
289 )
290 .await
291 .map_err(DbError::PgError)?;
292
293 let Some(row) = row else {
294 return Ok(None);
295 };
296
297 let priv_channels = if let Ok(Some(channels)) =
300 row.try_get::<&str, Option<Vec<String>>>("priv_channels")
301 {
302 let mut priv_channels = HashSet::new();
303 for channel in channels.iter() {
304 let uuid = Uuid::from_str(channel).map_err(|e| DbError::General(e.to_string()))?;
305 priv_channels.insert(uuid);
306 }
307 priv_channels
308 } else {
309 HashSet::new()
310 };
311 let resp = User {
312 uaid: *uaid,
313 connected_at: row
314 .try_get::<&str, i64>("connected_at")
315 .map_err(DbError::PgError)? as u64,
316 router_type: row
317 .try_get::<&str, String>("router_type")
318 .map_err(DbError::PgError)?,
319 router_data: serde_json::from_str(
320 row.try_get::<&str, &str>("router_data")
321 .map_err(DbError::PgError)?,
322 )
323 .map_err(|e| DbError::General(e.to_string()))?,
324 node_id: row
325 .try_get::<&str, Option<String>>("node_id")
326 .map_err(DbError::PgError)?,
327 record_version: row
328 .try_get::<&str, Option<i64>>("record_version")
329 .map_err(DbError::PgError)?
330 .map(|v| v as u64),
331 current_timestamp: row
332 .try_get::<&str, Option<i64>>("last_update")
333 .map_err(DbError::PgError)?
334 .map(|v| v as u64),
335 version: row
336 .try_get::<&str, Option<String>>("version")
337 .map_err(DbError::PgError)?
338 .map(|v| {
340 Uuid::from_str(&v).map_err(|e| {
341 DbError::Integrity("Invalid UUID found".to_owned(), Some(e.to_string()))
342 })
343 })
344 .transpose()?,
345 priv_channels,
346 };
347 Ok(Some(resp))
348 }
349
350 async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
352 self.pool
353 .get()
354 .await?
355 .execute(
356 &format!(
357 "DELETE FROM {tablename}
358 WHERE uaid = $1",
359 tablename = self.router_table()
360 ),
361 &[&uaid.simple().to_string()],
362 )
363 .await
364 .map_err(DbError::PgError)?;
365 Ok(())
366 }
367
368 async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
377 self.pool
378 .get()
379 .await
380 .map_err(DbError::PgPoolError)?
381 .execute(
382 &format!(
383 "INSERT
384 INTO {tablename} (uaid, channel_id) VALUES ($1, $2)
385 ON CONFLICT DO NOTHING",
386 tablename = self.meta_table()
387 ),
388 &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
389 )
390 .await
391 .map_err(DbError::PgError)?;
392 Ok(())
393 }
394
395 async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
397 if channels.is_empty() {
398 trace!("📮 No channels to save.");
399 return Ok(());
400 };
401 let uaid_str = uaid.simple().to_string();
402 let mut params = Vec::<&(dyn ToSql + Sync)>::new();
414 let mut new_channels = Vec::<String>::new();
415 for channel in channels {
416 new_channels.push(channel.simple().to_string());
417 }
418
419 #[allow(clippy::needless_range_loop)]
423 for i in 0..new_channels.len() {
424 params.push(&uaid_str);
425 params.push(&new_channels[i]);
426 }
427
428 let statement = format!(
432 "INSERT
433 INTO {tablename} (uaid, channel_id)
434 VALUES {vars}
435 ON CONFLICT DO NOTHING",
436 tablename = self.meta_table(),
437 vars = Vec::from_iter((1..params.len() + 1).step_by(2).map(|v| format!(
439 "(${}, ${})",
440 v,
441 v + 1
442 )))
443 .join(",")
444 );
445 self.pool.get().await?.execute(&statement, ¶ms).await?;
447 Ok(())
448 }
449
450 async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
452 let mut result = HashSet::new();
453 let rows = self
454 .pool
455 .get()
456 .await
457 .map_err(DbError::PgPoolError)?
458 .query(
459 &format!(
460 "SELECT distinct channel_id FROM {tablename} WHERE uaid = $1;",
461 tablename = self.meta_table()
462 ),
463 &[&uaid.simple().to_string()],
464 )
465 .await
466 .map_err(DbError::PgError)?;
467 for row in rows.iter() {
468 let s = row
469 .try_get::<&str, &str>("channel_id")
470 .map_err(DbError::PgError)?;
471 result.insert(Uuid::from_str(s).map_err(|e| DbError::General(e.to_string()))?);
472 }
473 Ok(result)
474 }
475
476 async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
478 let cmd = format!(
479 "DELETE FROM {tablename}
480 WHERE uaid = $1 AND channel_id = $2;",
481 tablename = self.meta_table()
482 );
483 let result = self
484 .pool
485 .get()
486 .await
487 .map_err(DbError::PgPoolError)?
488 .execute(
489 &cmd,
490 &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
491 )
492 .await?;
493 Ok(result > 0)
495 }
496
497 async fn remove_node_id(
499 &self,
500 uaid: &Uuid,
501 node_id: &str,
502 connected_at: u64,
503 version: &Option<Uuid>,
504 ) -> DbResult<bool> {
505 let Some(version) = version else {
506 return Err(DbError::General("Expected a user version field".to_owned()));
507 };
508 self.pool
509 .get()
510 .await
511 .map_err(DbError::PgPoolError)?
512 .execute(
513 &format!(
514 "UPDATE {tablename}
515 SET node_id = null
516 WHERE uaid=$1 AND node_id = $2 AND connected_at = $3 AND version= $4;",
517 tablename = self.router_table()
518 ),
519 &[
520 &uaid.simple().to_string(),
521 &node_id,
522 &(connected_at as i64),
523 &version.simple().to_string(),
524 ],
525 )
526 .await
527 .map_err(DbError::PgError)?;
528 Ok(true)
529 }
530
531 async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
533 #[allow(unused_mut)]
536 let mut fields = vec![
537 "uaid",
538 "channel_id",
539 "chid_message_id",
540 "version",
541 "ttl",
542 "expiry",
543 "topic",
544 "timestamp",
545 "data",
546 "sortkey_timestamp",
547 "headers",
548 ];
549 #[allow(unused_mut)]
551 let mut inputs = vec![
552 "$1", "$2", "$3", "$4", "$5", "$6", "$7", "$8", "$9", "$10", "$11",
553 ];
554 #[cfg(feature = "reliable_report")]
555 {
556 fields.append(&mut ["reliability_id"].to_vec());
557 inputs.append(&mut ["$12"].to_vec());
558 }
559 let cmd = format!(
560 "INSERT INTO {tablename}
561 ({fields})
562 VALUES
563 ({inputs}) ON CONFLICT (chid_message_id) DO UPDATE SET
564 uaid=EXCLUDED.uaid,
565 channel_id=EXCLUDED.channel_id,
566 version=EXCLUDED.version,
567 ttl=EXCLUDED.ttl,
568 expiry=EXCLUDED.expiry,
569 topic=EXCLUDED.topic,
570 timestamp=EXCLUDED.timestamp,
571 data=EXCLUDED.data,
572 sortkey_timestamp=EXCLUDED.sortkey_timestamp,
573 headers=EXCLUDED.headers",
574 tablename = &self.message_table(),
575 fields = fields.join(","),
576 inputs = inputs.join(",")
577 );
578 self.pool
579 .get()
580 .await
581 .map_err(DbError::PgPoolError)?
582 .execute(
583 &cmd,
584 &[
585 &uaid.simple().to_string(),
586 &message.channel_id.simple().to_string(),
587 &message.chidmessageid(),
588 &message.version,
589 &(message.ttl as i64), &(util::sec_since_epoch() as i64 + message.ttl as i64),
591 &message.topic,
592 &(message.timestamp as i64),
593 &message.data.unwrap_or_default(),
594 &message.sortkey_timestamp.map(|v| v as i64),
595 &json!(message.headers).to_string(),
596 #[cfg(feature = "reliable_report")]
597 &message.reliability_id,
598 ],
599 )
600 .await
601 .map_err(DbError::PgError)?;
602 Ok(())
603 }
604
605 async fn remove_message(&self, uaid: &Uuid, sort_key: &str) -> DbResult<()> {
607 self.pool
608 .get()
609 .await
610 .map_err(DbError::PgPoolError)?
611 .execute(
612 &format!(
613 "DELETE FROM {tablename}
614 WHERE uaid=$1 AND chid_message_id = $2;",
615 tablename = self.message_table()
616 ),
617 &[&uaid.simple().to_string(), &(sort_key.to_owned())],
618 )
619 .await
620 .map_err(DbError::PgError)?;
621
622 Ok(())
623 }
624
625 async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
626 for message in messages {
627 self.save_message(uaid, message).await?;
628 }
629 Ok(())
630 }
631
632 async fn fetch_topic_messages(
635 &self,
636 uaid: &Uuid,
637 limit: usize,
638 ) -> DbResult<FetchMessageResponse> {
639 let messages: Vec<Notification> = self
640 .pool
641 .get()
642 .await
643 .map_err(DbError::PgPoolError)?
644 .query(
645 &format!(
646 "SELECT channel_id, version, ttl, topic, timestamp, data, sortkey_timestamp, headers
647 FROM {tablename}
648 WHERE uaid=$1 AND expiry >= $2
649 ORDER BY timestamp DESC
650 LIMIT $3",
651 tablename=&self.message_table(),
652 ),
653 &[&uaid.simple().to_string(),&(util::sec_since_epoch() as i64), &(limit as i64)],
654 )
655 .await
656 .map_err(DbError::PgError)?
657 .iter()
658 .map(|row: &Row| row.try_into())
659 .collect::<Result<Vec<Notification>, DbError>>()?;
660
661 if messages.is_empty() {
662 Ok(Default::default())
663 } else {
664 Ok(FetchMessageResponse {
665 timestamp: Some(messages[0].timestamp),
666 messages,
667 })
668 }
669 }
670
671 async fn fetch_timestamp_messages(
673 &self,
674 uaid: &Uuid,
675 timestamp: Option<u64>,
676 limit: usize,
677 ) -> DbResult<FetchMessageResponse> {
678 let uaid = uaid.simple().to_string();
679 let response: Vec<Row> = if let Some(ts) = timestamp {
680 trace!("📮 Fetching messages for user {} since {}", &uaid, ts);
681 self.pool
682 .get()
683 .await
684 .map_err(DbError::PgPoolError)?
685 .query(
686 &format!(
687 "SELECT * FROM {}
688 WHERE uaid = $1 AND timestamp > $2 AND expiry >= $3
689 ORDER BY timestamp
690 LIMIT $3",
691 self.message_table()
692 ),
693 &[
694 &uaid,
695 &(ts as i64),
696 &(limit as i64),
697 &(util::sec_since_epoch() as i64),
698 ],
699 )
700 .await
701 } else {
702 trace!("📮 Fetching messages for user {}", &uaid);
703 self.pool
704 .get()
705 .await
706 .map_err(DbError::PgPoolError)?
707 .query(
708 &format!(
709 "SELECT *
710 FROM {}
711 WHERE uaid = $1
712 AND expiry >= $2
713 LIMIT $2",
714 self.message_table()
715 ),
716 &[&uaid, &(limit as i64), &(util::sec_since_epoch() as i64)],
717 )
718 .await
719 }?;
720
721 let messages: Vec<Notification> = response
722 .iter()
723 .map(|row: &Row| row.try_into())
724 .collect::<Result<Vec<Notification>, DbError>>()?;
725 let timestamp = if !messages.is_empty() {
726 Some(messages[0].timestamp)
727 } else {
728 None
729 };
730
731 Ok(FetchMessageResponse {
732 timestamp,
733 messages,
734 })
735 }
736
737 async fn router_table_exists(&self) -> DbResult<bool> {
739 self.table_exists(self.router_table()).await
740 }
741
742 async fn message_table_exists(&self) -> DbResult<bool> {
744 self.table_exists(self.message_table()).await
745 }
746
747 #[cfg(feature = "reliable_report")]
748 async fn log_report(
749 &self,
750 reliability_id: &str,
751 new_state: crate::reliability::ReliabilityState,
752 ) -> DbResult<()> {
753 let timestamp =
754 SystemTime::now() + Duration::from_secs(RELIABLE_LOG_TTL.num_seconds() as u64);
755 debug!("📮 Logging report for {reliability_id} as {new_state}");
756 let tablename = &self.reliability_table();
763 let state = new_state.to_string();
764 self.pool
765 .get()
766 .await?
767 .execute(
768 &format!(
769 "INSERT INTO {tablename} (id, states, last_update_timestamp) VALUES ($1, json_build_object($2, $3), $3)
770 ON CONFLICT (id) DO
771 UPDATE SET states = EXCLUDED.states,
772 last_update_timestamp = EXCLUDED.last_update_timestamp;",
773 tablename = tablename
774 ),
775 &[&reliability_id, &state, ×tamp],
776 )
777 .await?;
778 Ok(())
779 }
780
781 async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
782 debug!("📮 Updating {uaid} current_timestamp:{timestamp}");
783 let tablename = &self.router_table();
784
785 trace!("📮 Purging git{uaid} for < {timestamp}");
786 let mut pool = self.pool.get().await.map_err(DbError::PgPoolError)?;
787 let transaction = pool.transaction().await?;
788 transaction
790 .execute(
791 &format!(
792 "DELETE FROM {} WHERE uaid = $1 and expiry < $2",
793 &self.message_table()
794 ),
795 &[
796 &uaid.simple().to_string(),
797 &(util::sec_since_epoch() as i64),
798 ],
799 )
800 .await?;
801 transaction
803 .execute(
804 &format!(
805 "DELETE FROM {} WHERE uaid = $1 AND timestamp IS NOT NULL AND timestamp < $2",
806 &self.message_table()
807 ),
808 &[&uaid.simple().to_string(), &(timestamp as i64)],
809 )
810 .await?;
811 transaction.execute(
812 &format!(
813 "UPDATE {tablename} SET last_update = $2::BIGINT, expiry= $3::BIGINT WHERE uaid = $1"
814 ),
815 &[
816 &uaid.simple().to_string(),
817 &(timestamp as i64),
818 &(self.router_expiry() as i64),
819 ],
820 )
821 .await?;
822 transaction.commit().await?;
823 Ok(())
824 }
825
826 fn name(&self) -> String {
827 "Postgres".to_owned()
828 }
829
830 async fn health_check(&self) -> DbResult<bool> {
831 let client = self.pool.get().await.map_err(DbError::PgPoolError);
833 let row = client?.query_one("select true", &[]).await;
834 Ok(!row?.is_empty())
835 }
836
837 fn box_clone(&self) -> Box<dyn DbClient> {
839 Box::new(self.clone())
840 }
841}
842
// Integration tests for `PgClientImpl`. Each test issues real queries, so a
// reachable Postgres server is required (host from `POSTGRES_HOST`, default
// `localhost`) with tables under the `autopush` schema.
#[cfg(test)]
mod tests {
    use crate::util::sec_since_epoch;
    use crate::{logging::init_test_logging, util::ms_since_epoch};
    use rand::prelude::*;
    use serde_json::json;
    use std::env;

    use super::*;
    // Fixed channel ids shared by the tests below.
    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";

    // Build a client for `postgres://$POSTGRES_HOST` (default `localhost`),
    // with every table placed in the `autopush` schema.
    fn new_client() -> DbResult<PgClientImpl> {
        let host = env::var("POSTGRES_HOST").unwrap_or("localhost".into());
        let env_dsn = format!("postgres://{host}");
        debug!("📮 Connecting to {env_dsn}");
        let settings = DbSettings {
            dsn: Some(env_dsn),
            db_settings: json!(PostgresDbSettings {
                schema: Some("autopush".to_owned()),
                ..Default::default()
            })
            .to_string(),
        };
        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
        PgClientImpl::new(metrics, &settings)
    }

    // Produce a UUID-shaped string for a test user. The fourth group is a
    // random u8 printed as 4 decimal digits; the last group is the current
    // epoch seconds padded to 12 digits. Decimal digits are valid hex, so the
    // result parses as a UUID.
    fn gen_test_user() -> String {
        let mut rng = rand::rng();
        let test_num = rng.random::<u8>();
        format!(
            "DEADBEEF-0000-0000-{:04}-{:012}",
            test_num,
            sec_since_epoch()
        )
    }

    // `health_check` succeeds and reports `true` against a live server.
    #[actix_rt::test]
    async fn health_check() {
        let client = new_client().unwrap();

        let result = client.health_check().await;
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    // A stored message whose `timestamp` is older than the value passed to
    // `increment_storage` is purged and no longer fetched.
    #[actix_rt::test]
    async fn wipe_expired() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // Best-effort cleanup; the result is intentionally ignored.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        let _ = client.remove_user(&uaid).await;

        let timestamp = sec_since_epoch();
        client.add_user(&test_user).await?;
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 1,
            timestamp,
            data: Some("Encrypted".into()),
            sortkey_timestamp: Some(timestamp),
            ..Default::default()
        };
        client.save_message(&uaid, test_notification).await?;
        // `timestamp + 1` is newer than the message, so it gets deleted.
        client.increment_storage(&uaid, timestamp + 1).await?;
        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(msgs.messages.len(), 0);
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }

    // End-to-end walk through user, channel, message and node-id operations.
    // Steps run in sequence against shared DB state, so their order matters.
    #[actix_rt::test]
    async fn run_gauntlet() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let user_id = &gen_test_user();
        let uaid = Uuid::parse_str(user_id).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();

        let node_id = "test_node".to_owned();

        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        let _ = client.remove_user(&uaid).await;

        // --- user round-trip ---
        trace!("📮 Adding user {}", &user_id);
        client.add_user(&test_user).await?;
        let fetched = client.get_user(&uaid).await?;
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(fetched.router_type, "webpush".to_owned());

        // --- single and bulk channel registration ---
        trace!("📮 Adding channel {} to user {}", &chid, &user_id);
        client.add_channel(&uaid, &chid).await?;
        let channels = client.get_channels(&uaid).await?;
        assert!(channels.contains(&chid));

        let mut new_channels: HashSet<Uuid> = HashSet::new();
        trace!("📮 Adding multiple channels to user {}", &user_id);
        new_channels.insert(chid);
        for _ in 1..10 {
            new_channels.insert(uuid::Uuid::new_v4());
        }
        let chid_to_remove = uuid::Uuid::new_v4();
        trace!(
            "📮 Adding removable channel {} to user {}",
            &chid_to_remove,
            &user_id
        );
        new_channels.insert(chid_to_remove);
        client.add_channels(&uaid, new_channels.clone()).await?;
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // --- channel removal: first call deletes, second finds nothing ---
        trace!(
            "📮 Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
        trace!(
            "📮 retrying Removing channel {} from user {}",
            &chid_to_remove,
            &user_id
        );
        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
        new_channels.remove(&chid_to_remove);
        let channels = client.get_channels(&uaid).await?;
        assert_eq!(channels, new_channels);

        // --- update_user only applies with a newer connected_at ---
        let mut updated = User {
            connected_at,
            ..test_user.clone()
        };
        trace!(
            "📮 Attempting to update user {} with old connected_at: {}",
            &user_id,
            &updated.connected_at
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(!result.unwrap());

        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
        assert_eq!(fetched.connected_at, fetched2.connected_at);

        let mut updated = User {
            connected_at: fetched.connected_at + 300,
            ..fetched2
        };
        trace!(
            "📮 Attempting to update user {} with new connected_at",
            &user_id
        );
        let result = client.update_user(&mut updated).await;
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_ne!(
            fetched2.connected_at,
            client.get_user(&uaid).await?.unwrap().connected_at
        );
        trace!("📮 Incrementing storage timestamp for user {}", &user_id);
        client
            .increment_storage(&fetched.uaid, sec_since_epoch())
            .await?;

        // --- plain (non-topic) message: save, fetch, timestamp filter, remove ---
        let test_data = "An_encrypted_pile_of_crap".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();
        let fetch_timestamp = timestamp;
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 300,
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        trace!("📮 Saving message for user {}", &user_id);
        let res = client.save_message(&uaid, test_notification.clone()).await;
        assert!(res.is_ok());

        trace!("📮 Fetching all messages for user {}", &user_id);
        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        trace!(
            "📮 Fetching messages for user {} within the past 10 seconds",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
            .await?;
        assert_ne!(fetched.messages.len(), 0);

        trace!(
            "📮 Fetching messages for user {} 10 seconds in the future",
            &user_id
        );
        let fetched = client
            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
            .await?;
        assert_eq!(fetched.messages.len(), 0);

        trace!(
            "📮 Removing message for user {} :: {}",
            &user_id,
            &test_notification.chidmessageid()
        );
        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        trace!("📮 Removing channel for user {}", &user_id);
        assert!(client.remove_channel(&uaid, &chid).await.is_ok());

        trace!("📮 Making sure no messages remain for user {}", &user_id);
        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        // --- topic messages: a second save with the same channel and topic is
        // expected to replace the first, leaving exactly one stored message ---
        client.add_channel(&uaid, &topic_chid).await?;
        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
        let timestamp = sec_since_epoch();
        let sort_key = sec_since_epoch();

        let test_notification_0 = crate::db::Notification {
            channel_id: topic_chid,
            version: "version0".to_owned(),
            ttl: 300,
            topic: Some("topic".to_owned()),
            timestamp,
            data: Some(test_data.clone()),
            sortkey_timestamp: Some(sort_key),
            ..Default::default()
        };
        assert!(client
            .save_message(&uaid, test_notification_0.clone())
            .await
            .is_ok());

        let test_notification = crate::db::Notification {
            timestamp: sec_since_epoch(),
            version: "version1".to_owned(),
            sortkey_timestamp: Some(sort_key + 10),
            ..test_notification_0
        };

        assert!(client
            .save_message(&uaid, test_notification.clone())
            .await
            .is_ok());

        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(fetched.messages.len(), 1);
        let fm = fetched.messages.pop().unwrap();
        assert_eq!(fm.channel_id, test_notification.channel_id);
        assert_eq!(fm.data, Some(test_data));

        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_ne!(fetched.messages.len(), 0);

        assert!(client
            .remove_message(&uaid, &test_notification.chidmessageid())
            .await
            .is_ok());

        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());

        let msgs = client
            .fetch_timestamp_messages(&uaid, None, 999)
            .await?
            .messages;
        assert!(msgs.is_empty());

        // --- node id removal, then final user deletion ---
        let fetched = client.get_user(&uaid).await?.unwrap();
        assert!(client
            .remove_node_id(&uaid, &node_id, fetched.connected_at, &fetched.version)
            .await
            .is_ok());
        let fetched = client.get_user(&uaid).await?.unwrap();
        assert_eq!(fetched.node_id, None);

        assert!(client.remove_user(&uaid).await.is_ok());

        assert!(client.get_user(&uaid).await?.is_none());
        Ok(())
    }

    // A fetched message disappears after `increment_storage` is called with a
    // timestamp newer than the message's `timestamp`.
    #[actix_rt::test]
    async fn test_expiry() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();
        let now = sec_since_epoch();

        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 2,
            timestamp: now,
            data: Some("SomeData".into()),
            sortkey_timestamp: Some(now),
            ..Default::default()
        };
        client
            .add_user(&User {
                uaid,
                router_type: "test".to_owned(),
                connected_at: ms_since_epoch(),
                ..Default::default()
            })
            .await?;
        client.add_channel(&uaid, &chid).await?;
        debug!("🧪Writing test notif");
        client
            .save_message(&uaid, test_notification.clone())
            .await?;
        let key = uaid.simple().to_string();
        debug!("🧪Checking {}...", &key);
        let msg = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(msg.is_some());
        debug!("🧪Purging...");
        client.increment_storage(&uaid, now + 2).await?;
        debug!("🧪Checking for empty {}...", &key);
        let cc = client
            .fetch_timestamp_messages(&uaid, None, 1)
            .await?
            .messages
            .pop();
        assert!(cc.is_none());
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }
}