1use std::collections::HashSet;
11use std::str::FromStr;
12use std::sync::Arc;
13#[cfg(feature = "reliable_report")]
14use std::time::{Duration, SystemTime};
15
16#[cfg(feature = "reliable_report")]
17use chrono::TimeDelta;
18use deadpool_postgres::{Pool, Runtime};
19use serde_json::json;
20use tokio_postgres::{types::ToSql, Row}; use async_trait::async_trait;
23use cadence::StatsdClient;
24use serde::{Deserialize, Serialize};
25use uuid::Uuid;
26
27use crate::db::client::DbClient;
28use crate::db::error::{DbError, DbResult};
29use crate::db::{DbSettings, User};
30use crate::notification::Notification;
31use crate::{util, MAX_ROUTER_TTL_SECS};
32
33use super::client::FetchMessageResponse;
34
/// Offset that `log_report` adds to the current time to compute the timestamp
/// it stores for a reliability record (60 days).
#[cfg(feature = "reliable_report")]
const RELIABLE_LOG_TTL: TimeDelta = TimeDelta::days(60);
37
38#[derive(Debug, Clone, Deserialize, Serialize)]
39#[serde(default)]
40pub struct PostgresDbSettings {
41 pub schema: Option<String>, pub router_table: String, pub message_table: String, pub meta_table: String, pub reliability_table: String, max_router_ttl: u64, }
50
51impl Default for PostgresDbSettings {
52 fn default() -> Self {
53 Self {
54 schema: None,
55 router_table: "router".to_owned(),
56 message_table: "message".to_owned(),
57 meta_table: "meta".to_owned(),
58 reliability_table: "reliability".to_owned(),
59 max_router_ttl: MAX_ROUTER_TTL_SECS,
60 }
62 }
63}
64
65impl TryFrom<&str> for PostgresDbSettings {
66 type Error = DbError;
67 fn try_from(setting_string: &str) -> Result<Self, Self::Error> {
68 if setting_string.trim().is_empty() {
69 return Ok(PostgresDbSettings::default());
70 }
71 serde_json::from_str(setting_string).map_err(|e| {
72 DbError::General(format!(
73 "Could not parse configuration db_settings: {:?}",
74 e
75 ))
76 })
77 }
78}
79
80#[derive(Clone)]
81pub struct PgClientImpl {
82 _metrics: Arc<StatsdClient>,
83 db_settings: PostgresDbSettings,
84 pool: Pool,
85 cached_router_table: String,
87 cached_message_table: String,
88 cached_meta_table: String,
89 #[cfg(feature = "reliable_report")]
90 cached_reliability_table: String,
91}
92
93impl PgClientImpl {
94 pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
102 let db_settings = PostgresDbSettings::try_from(settings.db_settings.as_ref())?;
103 info!(
104 "📮 Initializing Postgres DB Client with settings: {:?} from {:?}",
105 db_settings, &settings.db_settings
106 );
107 let tls_flag = tokio_postgres::NoTls;
109 if let Some(dsn) = settings.dsn.clone() {
110 trace!("📮 Postgres Connect {}", &dsn);
111
112 let pool = deadpool_postgres::Config {
113 url: Some(dsn.clone()),
114 ..Default::default()
115 }
116 .create_pool(Some(Runtime::Tokio1), tls_flag)
117 .map_err(|e| DbError::General(e.to_string()))?;
118 let cached_router_table = if let Some(schema) = &db_settings.schema {
119 format!("{}.{}", schema, db_settings.router_table)
120 } else {
121 db_settings.router_table.clone()
122 };
123 let cached_message_table = if let Some(schema) = &db_settings.schema {
124 format!("{}.{}", schema, db_settings.message_table)
125 } else {
126 db_settings.message_table.clone()
127 };
128 let cached_meta_table = if let Some(schema) = &db_settings.schema {
129 format!("{}.{}", schema, db_settings.meta_table)
130 } else {
131 db_settings.meta_table.clone()
132 };
133 #[cfg(feature = "reliable_report")]
134 let cached_reliability_table = if let Some(schema) = &db_settings.schema {
135 format!("{}.{}", schema, db_settings.reliability_table)
136 } else {
137 db_settings.reliability_table.clone()
138 };
139 return Ok(Self {
140 _metrics: metrics,
141 db_settings,
142 pool,
143 cached_router_table,
144 cached_message_table,
145 cached_meta_table,
146 #[cfg(feature = "reliable_report")]
147 cached_reliability_table,
148 });
149 };
150 Err(DbError::ConnectionError("No DSN specified".to_owned()))
151 }
152
153 async fn table_exists(&self, table_name: &str) -> DbResult<bool> {
155 let (schema, table_name) = if table_name.contains('.') {
156 let mut parts = table_name.splitn(2, '.');
157 (
158 parts.next().unwrap_or("public").to_owned(),
159 parts.next().unwrap().to_owned(),
162 )
163 } else {
164 ("public".to_owned(), table_name.to_owned())
165 };
166 let rows = self
167 .pool
168 .get()
169 .await
170 .map_err(DbError::PgPoolError)?
171 .query(
172 "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname=$1 AND tablename=$2);",
173 &[&schema, &table_name],
174 )
175 .await
176 .map_err(DbError::PgError)?;
177 let val: bool = rows[0].try_get(0)?;
178 Ok(val)
179 }
180
181 fn router_expiry(&self) -> u64 {
183 util::sec_since_epoch() + self.db_settings.max_router_ttl
184 }
185
186 pub(crate) fn router_table(&self) -> &str {
188 &self.cached_router_table
189 }
190
191 pub(crate) fn message_table(&self) -> &str {
193 &self.cached_message_table
194 }
195
196 pub(crate) fn meta_table(&self) -> &str {
200 &self.cached_meta_table
201 }
202
203 #[cfg(feature = "reliable_report")]
207 pub(crate) fn reliability_table(&self) -> &str {
208 &self.cached_reliability_table
209 }
210
211 pub(crate) fn error_to_string(e: &tokio_postgres::Error) -> String {
212 e.as_db_error()
213 .map(|e| e.message().to_owned()) .unwrap_or_else(|| {
215 e.to_string()
217 })
218 }
219}
220
221#[async_trait]
222impl DbClient for PgClientImpl {
223 async fn add_user(&self, user: &User) -> DbResult<()> {
225 self.pool.get().await.map_err(DbError::PgPoolError)?.execute(
226 &format!("
227 INSERT INTO {tablename} (uaid, connected_at, router_type, router_data, node_id, record_version, version, last_update, priv_channels, expiry)
228 VALUES($1, $2::BIGINT, $3, $4, $5, $6::BIGINT, $7, $8::BIGINT, $9, $10::BIGINT)
229 ON CONFLICT (uaid) DO
230 UPDATE SET connected_at=EXCLUDED.connected_at,
231 router_type=EXCLUDED.router_type,
232 router_data=EXCLUDED.router_data,
233 node_id=EXCLUDED.node_id,
234 record_version=EXCLUDED.record_version,
235 version=EXCLUDED.version,
236 last_update=EXCLUDED.last_update,
237 priv_channels=EXCLUDED.priv_channels,
238 expiry=EXCLUDED.expiry
239 ;
240 ", tablename=self.router_table()),
241 &[&user.uaid.simple().to_string(), &(user.connected_at as i64), &user.router_type, &json!(user.router_data).to_string(), &user.node_id, &user.record_version.map(|i| i as i64), &(user.version.map(|v| v.simple().to_string())), &user.current_timestamp.map(|i| i as i64), &user.priv_channels.iter().map(|v| v.to_string()).collect::<Vec<String>>(),
250 &(self.router_expiry() as i64), ]
252 ).await.map_err(|e| {
253 DbError::PgDbError(Self::error_to_string(&e))
254 })?;
255 Ok(())
256 }
257
258 async fn update_user(&self, user: &mut User) -> DbResult<bool> {
260 let cmd = format!(
261 "UPDATE {tablename} SET connected_at=$2::BIGINT,
262 router_type=$3,
263 router_data=$4,
264 node_id=$5,
265 record_version=$6::BIGINT,
266 version=$7,
267 last_update=$8::BIGINT,
268 priv_channels=$9,
269 expiry=$10::BIGINT
270 WHERE
271 uaid = $1 AND connected_at < $2::BIGINT;
272 ",
273 tablename = self.router_table()
274 );
275 let result = self
276 .pool
277 .get()
278 .await
279 .map_err(DbError::PgPoolError)?
280 .execute(
281 &cmd,
282 &[
283 &user.uaid.simple().to_string(), &(user.connected_at as i64), &user.router_type, &json!(user.router_data).to_string(), &user.node_id, &(user.record_version.map(|i| i as i64)), &(user.version.map(|v| v.simple().to_string())), &user.current_timestamp.map(|i| i as i64), &user
292 .priv_channels
293 .iter()
294 .map(|v| v.to_string())
295 .collect::<Vec<String>>(),
296 &(self.router_expiry() as i64), ],
298 )
299 .await
300 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
301 Ok(result > 0)
302 }
303
304 async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
306 let row = self.pool.get().await.map_err(DbError::PgPoolError)?
307 .query_opt(
308 &format!(
309 "SELECT connected_at, router_type, router_data, node_id, record_version, last_update, version, priv_channels
310 FROM {tablename}
311 WHERE uaid = $1",
312 tablename=self.router_table()
313 ),
314 &[&uaid.simple().to_string()]
315 )
316 .await
317 .map_err(|e| {
318 DbError::PgDbError(Self::error_to_string(&e))
319 })?;
320
321 let Some(row) = row else {
322 return Ok(None);
323 };
324
325 let priv_channels = if let Ok(Some(channels)) =
328 row.try_get::<&str, Option<Vec<String>>>("priv_channels")
329 {
330 let mut priv_channels = HashSet::new();
331 for channel in channels.iter() {
332 let uuid = Uuid::from_str(channel).map_err(|e| DbError::General(e.to_string()))?;
333 priv_channels.insert(uuid);
334 }
335 priv_channels
336 } else {
337 HashSet::new()
338 };
339 let resp = User {
340 uaid: *uaid,
341 connected_at: row
342 .try_get::<&str, i64>("connected_at")
343 .map_err(DbError::PgError)? as u64,
344 router_type: row
345 .try_get::<&str, String>("router_type")
346 .map_err(DbError::PgError)?,
347 router_data: serde_json::from_str(
348 row.try_get::<&str, &str>("router_data")
349 .map_err(DbError::PgError)?,
350 )
351 .map_err(|e| DbError::General(e.to_string()))?,
352 node_id: row
353 .try_get::<&str, Option<String>>("node_id")
354 .map_err(DbError::PgError)?,
355 record_version: row
356 .try_get::<&str, Option<i64>>("record_version")
357 .map_err(DbError::PgError)?
358 .map(|v| v as u64),
359 current_timestamp: row
360 .try_get::<&str, Option<i64>>("last_update")
361 .map_err(DbError::PgError)?
362 .map(|v| v as u64),
363 version: row
364 .try_get::<&str, Option<String>>("version")
365 .map_err(DbError::PgError)?
366 .map(|v| {
368 Uuid::from_str(&v).map_err(|e| {
369 DbError::Integrity("Invalid UUID found".to_owned(), Some(e.to_string()))
370 })
371 })
372 .transpose()?,
373 priv_channels,
374 };
375 Ok(Some(resp))
376 }
377
378 async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
380 self.pool
381 .get()
382 .await?
383 .execute(
384 &format!(
385 "DELETE FROM {tablename}
386 WHERE uaid = $1",
387 tablename = self.router_table()
388 ),
389 &[&uaid.simple().to_string()],
390 )
391 .await
392 .map_err(DbError::PgError)?;
393 Ok(())
394 }
395
396 async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
405 self.pool
406 .get()
407 .await
408 .map_err(DbError::PgPoolError)?
409 .execute(
410 &format!(
411 "INSERT
412 INTO {tablename} (uaid, channel_id) VALUES ($1, $2)
413 ON CONFLICT DO NOTHING",
414 tablename = self.meta_table()
415 ),
416 &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
417 )
418 .await
419 .map_err(DbError::PgError)?;
420 Ok(())
421 }
422
423 async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
425 if channels.is_empty() {
426 trace!("📮 No channels to save.");
427 return Ok(());
428 };
429 let uaid_str = uaid.simple().to_string();
430 let mut params = Vec::<&(dyn ToSql + Sync)>::new();
442 let mut new_channels = Vec::<String>::new();
443 for channel in channels {
444 new_channels.push(channel.simple().to_string());
445 }
446
447 #[allow(clippy::needless_range_loop)]
451 for i in 0..new_channels.len() {
452 params.push(&uaid_str);
453 params.push(&new_channels[i]);
454 }
455
456 let statement = format!(
460 "INSERT
461 INTO {tablename} (uaid, channel_id)
462 VALUES {vars}
463 ON CONFLICT DO NOTHING",
464 tablename = self.meta_table(),
465 vars = Vec::from_iter((1..params.len() + 1).step_by(2).map(|v| format!(
467 "(${}, ${})",
468 v,
469 v + 1
470 )))
471 .join(",")
472 );
473 self.pool.get().await?.execute(&statement, ¶ms).await?;
475 Ok(())
476 }
477
478 async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
480 let mut result = HashSet::new();
481 let rows = self
482 .pool
483 .get()
484 .await
485 .map_err(DbError::PgPoolError)?
486 .query(
487 &format!(
488 "SELECT distinct channel_id FROM {tablename} WHERE uaid = $1;",
489 tablename = self.meta_table()
490 ),
491 &[&uaid.simple().to_string()],
492 )
493 .await
494 .map_err(DbError::PgError)?;
495 for row in rows.iter() {
496 let s = row
497 .try_get::<&str, &str>("channel_id")
498 .map_err(DbError::PgError)?;
499 result.insert(Uuid::from_str(s).map_err(|e| DbError::General(e.to_string()))?);
500 }
501 Ok(result)
502 }
503
504 async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
506 let cmd = format!(
507 "DELETE FROM {tablename}
508 WHERE uaid = $1 AND channel_id = $2;",
509 tablename = self.meta_table()
510 );
511 let result = self
512 .pool
513 .get()
514 .await
515 .map_err(DbError::PgPoolError)?
516 .execute(
517 &cmd,
518 &[&uaid.simple().to_string(), &channel_id.simple().to_string()],
519 )
520 .await?;
521 Ok(result > 0)
523 }
524
525 async fn remove_node_id(
527 &self,
528 uaid: &Uuid,
529 node_id: &str,
530 connected_at: u64,
531 version: &Option<Uuid>,
532 ) -> DbResult<bool> {
533 let Some(version) = version else {
534 return Err(DbError::General("Expected a user version field".to_owned()));
535 };
536 self.pool
537 .get()
538 .await
539 .map_err(DbError::PgPoolError)?
540 .execute(
541 &format!(
542 "UPDATE {tablename}
543 SET node_id = null
544 WHERE uaid=$1 AND node_id = $2 AND connected_at = $3 AND version= $4;",
545 tablename = self.router_table()
546 ),
547 &[
548 &uaid.simple().to_string(),
549 &node_id,
550 &(connected_at as i64),
551 &version.simple().to_string(),
552 ],
553 )
554 .await
555 .map_err(DbError::PgError)?;
556 Ok(true)
557 }
558
559 async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
561 #[allow(unused_mut)]
564 let mut fields = vec![
565 "uaid",
566 "channel_id",
567 "chid_message_id",
568 "version",
569 "ttl",
570 "expiry",
571 "topic",
572 "timestamp",
573 "data",
574 "sortkey_timestamp",
575 "headers",
576 ];
577 #[allow(unused_mut)]
579 let mut inputs = vec![
580 "$1", "$2", "$3", "$4", "$5", "$6", "$7", "$8", "$9", "$10", "$11",
581 ];
582 #[cfg(feature = "reliable_report")]
583 {
584 fields.append(&mut ["reliability_id"].to_vec());
585 inputs.append(&mut ["$12"].to_vec());
586 }
587 let cmd = format!(
588 "INSERT INTO {tablename}
589 ({fields})
590 VALUES
591 ({inputs}) ON CONFLICT (chid_message_id) DO UPDATE SET
592 uaid=EXCLUDED.uaid,
593 channel_id=EXCLUDED.channel_id,
594 version=EXCLUDED.version,
595 ttl=EXCLUDED.ttl,
596 expiry=EXCLUDED.expiry,
597 topic=EXCLUDED.topic,
598 timestamp=EXCLUDED.timestamp,
599 data=EXCLUDED.data,
600 sortkey_timestamp=EXCLUDED.sortkey_timestamp,
601 headers=EXCLUDED.headers",
602 tablename = &self.message_table(),
603 fields = fields.join(","),
604 inputs = inputs.join(",")
605 );
606 self.pool
607 .get()
608 .await
609 .map_err(DbError::PgPoolError)?
610 .execute(
611 &cmd,
612 &[
613 &uaid.simple().to_string(),
614 &message.channel_id.simple().to_string(),
615 &message.chidmessageid(),
616 &message.version,
617 &(message.ttl as i64), &(util::sec_since_epoch() as i64 + message.ttl as i64),
619 &message.topic.as_ref().filter(|v| !v.is_empty()),
620 &(message.timestamp as i64),
621 &message.data.as_deref().unwrap_or_default(),
622 &message.sortkey_timestamp.map(|v| v as i64),
623 &json!(message.headers).to_string(),
624 #[cfg(feature = "reliable_report")]
625 &message.reliability_id,
626 ],
627 )
628 .await
629 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
630 Ok(())
631 }
632
633 async fn remove_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<()> {
635 debug!(
636 "📮 Removing message for user {} with chid_message_id {}",
637 uaid.simple(),
638 chidmessageid
639 );
640 let result = self
641 .pool
642 .get()
643 .await
644 .map_err(DbError::PgPoolError)?
645 .execute(
646 &format!(
647 "DELETE FROM {tablename}
648 WHERE uaid=$1 AND chid_message_id = $2;",
649 tablename = self.message_table()
650 ),
651 &[&uaid.simple().to_string(), &chidmessageid],
652 )
653 .await
654 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
655 debug!(
656 "📮 Deleted {} rows for user {} with chid_message_id {}",
657 result,
658 uaid.simple(),
659 chidmessageid
660 );
661 Ok(())
662 }
663
664 async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
665 if messages.is_empty() {
666 return Ok(());
667 }
668 if messages.len() == 1 {
669 return self
671 .save_message(uaid, messages.into_iter().next().unwrap())
672 .await;
673 }
674
675 #[cfg(not(feature = "reliable_report"))]
676 let fields_per_row: usize = 11;
677 #[cfg(feature = "reliable_report")]
678 let fields_per_row: usize = 12;
679
680 let field_names = {
681 #[allow(unused_mut)]
682 let mut fields = vec![
683 "uaid",
684 "channel_id",
685 "chid_message_id",
686 "version",
687 "ttl",
688 "expiry",
689 "topic",
690 "timestamp",
691 "data",
692 "sortkey_timestamp",
693 "headers",
694 ];
695 #[cfg(feature = "reliable_report")]
696 fields.push("reliability_id");
697 fields.join(",")
698 };
699
700 let mut value_rows = Vec::with_capacity(messages.len());
702 for i in 0..messages.len() {
703 let base = i * fields_per_row;
704 let params: Vec<String> = (1..=fields_per_row)
705 .map(|j| format!("${}", base + j))
706 .collect();
707 value_rows.push(format!("({})", params.join(",")));
708 }
709
710 let cmd = format!(
711 "INSERT INTO {tablename}
712 ({field_names})
713 VALUES
714 {values} ON CONFLICT (chid_message_id) DO UPDATE SET
715 uaid=EXCLUDED.uaid,
716 channel_id=EXCLUDED.channel_id,
717 version=EXCLUDED.version,
718 ttl=EXCLUDED.ttl,
719 expiry=EXCLUDED.expiry,
720 topic=EXCLUDED.topic,
721 timestamp=EXCLUDED.timestamp,
722 data=EXCLUDED.data,
723 sortkey_timestamp=EXCLUDED.sortkey_timestamp,
724 headers=EXCLUDED.headers",
725 tablename = &self.message_table(),
726 values = value_rows.join(",")
727 );
728
729 let uaid_str = uaid.simple().to_string();
731 let now = util::sec_since_epoch() as i64;
732
733 struct MessageParams {
735 channel_id: String,
736 chidmessageid: String,
737 version: String,
738 ttl: i64,
739 expiry: i64,
740 topic: Option<String>,
741 timestamp: i64,
742 data: String,
743 sortkey_timestamp: Option<i64>,
744 headers: String,
745 #[cfg(feature = "reliable_report")]
746 reliability_id: Option<String>,
747 }
748
749 let msg_params: Vec<MessageParams> = messages
750 .into_iter()
751 .map(|m| {
752 let topic = m.topic.as_ref().filter(|v| !v.is_empty()).cloned();
753 MessageParams {
754 channel_id: m.channel_id.simple().to_string(),
755 chidmessageid: m.chidmessageid(),
756 version: m.version,
757 ttl: m.ttl as i64,
758 expiry: now + m.ttl as i64,
759 topic,
760 timestamp: m.timestamp as i64,
761 data: m.data.as_deref().unwrap_or_default().to_owned(),
762 sortkey_timestamp: m.sortkey_timestamp.map(|v| v as i64),
763 headers: json!(m.headers).to_string(),
764 #[cfg(feature = "reliable_report")]
765 reliability_id: m.reliability_id,
766 }
767 })
768 .collect();
769
770 let mut params: Vec<&(dyn ToSql + Sync)> =
771 Vec::with_capacity(msg_params.len() * fields_per_row);
772 for mp in &msg_params {
773 params.push(&uaid_str);
774 params.push(&mp.channel_id);
775 params.push(&mp.chidmessageid);
776 params.push(&mp.version);
777 params.push(&mp.ttl);
778 params.push(&mp.expiry);
779 params.push(&mp.topic);
780 params.push(&mp.timestamp);
781 params.push(&mp.data);
782 params.push(&mp.sortkey_timestamp);
783 params.push(&mp.headers);
784 #[cfg(feature = "reliable_report")]
785 params.push(&mp.reliability_id);
786 }
787
788 self.pool
789 .get()
790 .await
791 .map_err(DbError::PgPoolError)?
792 .execute(cmd.as_str(), ¶ms)
793 .await
794 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
795 Ok(())
796 }
797
798 async fn fetch_topic_messages(
801 &self,
802 uaid: &Uuid,
803 limit: usize,
804 ) -> DbResult<FetchMessageResponse> {
805 let messages: Vec<Notification> = self
806 .pool
807 .get()
808 .await
809 .map_err(DbError::PgPoolError)?
810 .query(
811 &format!(
812 "SELECT channel_id, version, ttl, topic, timestamp, data, sortkey_timestamp, headers
813 FROM {tablename}
814 WHERE uaid=$1 AND expiry >= $2 AND (topic IS NOT NULL AND topic != '')
815 ORDER BY timestamp DESC
816 LIMIT $3",
817 tablename=&self.message_table(),
818 ),
819 &[
820 &uaid.simple().to_string(),
821 &(util::sec_since_epoch() as i64),
822 &(limit as i64),
823 ],
824 )
825 .await
826 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?
827 .iter()
828 .map(|row: &Row| row.try_into())
829 .collect::<Result<Vec<Notification>, DbError>>()?;
830
831 if messages.is_empty() {
832 Ok(Default::default())
833 } else {
834 Ok(FetchMessageResponse {
835 timestamp: Some(messages[0].timestamp),
836 messages,
837 })
838 }
839 }
840
841 async fn fetch_timestamp_messages(
843 &self,
844 uaid: &Uuid,
845 timestamp: Option<u64>,
846 limit: usize,
847 ) -> DbResult<FetchMessageResponse> {
848 let uaid = uaid.simple().to_string();
849 let response: Vec<Row> = if let Some(ts) = timestamp {
850 trace!("📮 Fetching messages for user {} since {}", &uaid, ts);
851 self.pool
852 .get()
853 .await
854 .map_err(DbError::PgPoolError)?
855 .query(
856 &format!(
857 "SELECT * FROM {}
858 WHERE uaid = $1 AND timestamp > $2 AND expiry >= $3
859 ORDER BY timestamp
860 LIMIT $4",
861 self.message_table()
862 ),
863 &[
864 &uaid,
865 &(ts as i64),
866 &(util::sec_since_epoch() as i64),
867 &(limit as i64),
868 ],
869 )
870 .await
871 } else {
872 trace!("📮 Fetching messages for user {}", &uaid);
873 self.pool
874 .get()
875 .await
876 .map_err(DbError::PgPoolError)?
877 .query(
878 &format!(
879 "SELECT *
880 FROM {}
881 WHERE uaid = $1
882 AND expiry >= $2
883 LIMIT $3",
884 self.message_table()
885 ),
886 &[&uaid, &(util::sec_since_epoch() as i64), &(limit as i64)],
887 )
888 .await
889 }
890 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
891 let messages: Vec<Notification> = response
892 .iter()
893 .map(|row: &Row| row.try_into())
894 .collect::<Result<Vec<Notification>, DbError>>()?;
895 let timestamp = if !messages.is_empty() {
896 Some(messages[0].timestamp)
897 } else {
898 None
899 };
900
901 Ok(FetchMessageResponse {
902 timestamp,
903 messages,
904 })
905 }
906
907 async fn router_table_exists(&self) -> DbResult<bool> {
909 self.table_exists(self.router_table()).await
910 }
911
912 async fn message_table_exists(&self) -> DbResult<bool> {
914 self.table_exists(self.message_table()).await
915 }
916
/// Upsert the reliability record `reliability_id` with `new_state`.
///
/// `states` is set to the one-entry JSON object `{new_state: timestamp}` and
/// `last_update_timestamp` to `timestamp`, where `timestamp` is now +
/// `RELIABLE_LOG_TTL`. On conflict both columns are replaced, so only the
/// latest state is kept.
///
/// NOTE(review): `json_build_object` takes `VARIADIC "any"`, so Postgres may be
/// unable to infer the types of the untyped `$2`/`$3` parameters when preparing
/// this statement ("could not determine data type of parameter"); consider
/// explicit casts once the column types are confirmed.
/// NOTE(review): replacing `states` wholesale discards earlier states — confirm
/// that merging them is not intended.
#[cfg(feature = "reliable_report")]
async fn log_report(
    &self,
    reliability_id: &str,
    new_state: crate::reliability::ReliabilityState,
) -> DbResult<()> {
    let timestamp =
        SystemTime::now() + Duration::from_secs(RELIABLE_LOG_TTL.num_seconds() as u64);
    debug!("📮 Logging report for {reliability_id} as {new_state}");
    let tablename = self.reliability_table();
    let state = new_state.to_string();
    self.pool
        .get()
        .await?
        .execute(
            &format!(
                "INSERT INTO {tablename} (id, states, last_update_timestamp) VALUES ($1, json_build_object($2, $3), $3)
                ON CONFLICT (id) DO
                UPDATE SET states = EXCLUDED.states,
                last_update_timestamp = EXCLUDED.last_update_timestamp;"
            ),
            &[&reliability_id, &state, &timestamp],
        )
        .await
        .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
    Ok(())
}
951
952 async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
953 debug!("📮 Updating {uaid} current_timestamp:{timestamp}");
954 let tablename = &self.router_table();
955
956 trace!("📮 Purging git{uaid} for < {timestamp}");
957 let mut pool = self.pool.get().await.map_err(DbError::PgPoolError)?;
958 let transaction = pool.transaction().await?;
959 transaction
961 .execute(
962 &format!(
963 "DELETE FROM {} WHERE uaid = $1 and expiry < $2",
964 &self.message_table()
965 ),
966 &[
967 &uaid.simple().to_string(),
968 &(util::sec_since_epoch() as i64),
969 ],
970 )
971 .await
972 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
973 transaction
975 .execute(
976 &format!(
977 "DELETE FROM {} WHERE uaid = $1 AND timestamp IS NOT NULL AND timestamp < $2",
978 &self.message_table()
979 ),
980 &[&uaid.simple().to_string(), &(timestamp as i64)],
981 )
982 .await
983 .map_err(|e| DbError::PgDbError(Self::error_to_string(&e)))?;
984 transaction.execute(
985 &format!(
986 "UPDATE {tablename} SET last_update = $2::BIGINT, expiry= $3::BIGINT WHERE uaid = $1"
987 ),
988 &[
989 &uaid.simple().to_string(),
990 &(timestamp as i64),
991 &(self.router_expiry() as i64),
992 ],
993 )
994 .await.map_err(|e|{DbError::PgDbError(Self::error_to_string(&e))})?;
995 transaction.commit().await?;
996 Ok(())
997 }
998
999 fn name(&self) -> String {
1000 "Postgres".to_owned()
1001 }
1002
1003 async fn health_check(&self) -> DbResult<bool> {
1004 let client = self.pool.get().await.map_err(DbError::PgPoolError);
1006 let row = client?.query_one("select true", &[]).await;
1007 if !row?.try_get::<_, bool>(0)? {
1008 error!("📮 Failed to fetch from database");
1009 return Ok(false);
1010 }
1011 if !self.router_table_exists().await? {
1012 error!("📮 Router table does not exist");
1013 return Ok(false);
1014 }
1015 if !self.message_table_exists().await? {
1016 error!("📮 Message table does not exist");
1017 return Ok(false);
1018 }
1019 Ok(true)
1020 }
1021
1022 fn box_clone(&self) -> Box<dyn DbClient> {
1024 Box::new(self.clone())
1025 }
1026}
1027
1028#[cfg(test)]
1041mod tests {
1042 use crate::util::sec_since_epoch;
1043 use crate::{logging::init_test_logging, util::ms_since_epoch};
1044 use rand::prelude::*;
1045 use serde_json::json;
1046 use std::env;
1047
1048 use super::*;
1049 const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
1050 const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";
1051
1052 fn new_client() -> DbResult<PgClientImpl> {
1053 let host = env::var("POSTGRES_HOST").unwrap_or("localhost".into());
1055 let env_dsn = format!("postgres://{host}");
1056 debug!("📮 Connecting to {env_dsn}");
1057 let settings = DbSettings {
1058 dsn: Some(env_dsn),
1059 db_settings: json!(PostgresDbSettings {
1060 schema: Some("autopush".to_owned()),
1061 ..Default::default()
1062 })
1063 .to_string(),
1064 };
1065 let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
1066 PgClientImpl::new(metrics, &settings)
1067 }
1068
1069 fn gen_test_user() -> String {
1070 let mut rng = rand::rng();
1072 let test_num = rng.random::<u8>();
1073 format!(
1074 "DEADBEEF-0000-0000-{:04}-{:012}",
1075 test_num,
1076 sec_since_epoch()
1077 )
1078 }
1079
1080 #[actix_rt::test]
1081 async fn health_check() {
1082 let client = new_client().unwrap();
1083
1084 let result = client.health_check().await;
1085 assert!(result.is_ok());
1086 assert!(result.unwrap());
1087 }
1088
1089 #[actix_rt::test]
1091 async fn wipe_expired() -> DbResult<()> {
1092 init_test_logging();
1093 let client = new_client()?;
1094
1095 let connected_at = ms_since_epoch();
1096
1097 let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
1098 let chid = Uuid::parse_str(TEST_CHID).unwrap();
1099
1100 let node_id = "test_node".to_owned();
1101
1102 let _ = client.remove_user(&uaid).await;
1104
1105 let test_user = User {
1106 uaid,
1107 router_type: "webpush".to_owned(),
1108 connected_at,
1109 router_data: None,
1110 node_id: Some(node_id.clone()),
1111 ..Default::default()
1112 };
1113
1114 let _ = client.remove_user(&uaid).await;
1117
1118 let timestamp = sec_since_epoch();
1120 client.add_user(&test_user).await?;
1121 let test_notification = crate::db::Notification {
1122 channel_id: chid,
1123 version: "test".to_owned(),
1124 ttl: 1,
1125 timestamp,
1126 data: Some("Encrypted".into()),
1127 sortkey_timestamp: Some(timestamp),
1128 ..Default::default()
1129 };
1130 client.save_message(&uaid, test_notification).await?;
1131 client.increment_storage(&uaid, timestamp + 1).await?;
1132 let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1133 assert_eq!(msgs.messages.len(), 0);
1134 assert!(client.remove_user(&uaid).await.is_ok());
1135 Ok(())
1136 }
1137
1138 #[actix_rt::test]
1141 async fn run_gauntlet() -> DbResult<()> {
1142 init_test_logging();
1143 let client = new_client()?;
1144
1145 let connected_at = ms_since_epoch();
1146
1147 let user_id = &gen_test_user();
1148 let uaid = Uuid::parse_str(user_id).unwrap();
1149 let chid = Uuid::parse_str(TEST_CHID).unwrap();
1150 let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();
1151
1152 let node_id = "test_node".to_owned();
1153
1154 let _ = client.remove_user(&uaid).await;
1156
1157 let test_user = User {
1158 uaid,
1159 router_type: "webpush".to_owned(),
1160 connected_at,
1161 router_data: None,
1162 node_id: Some(node_id.clone()),
1163 ..Default::default()
1164 };
1165
1166 let _ = client.remove_user(&uaid).await;
1169
1170 trace!("📮 Adding user {}", &user_id);
1172 client.add_user(&test_user).await?;
1173 let fetched = client.get_user(&uaid).await?;
1174 assert!(fetched.is_some());
1175 let fetched = fetched.unwrap();
1176 assert_eq!(fetched.router_type, "webpush".to_owned());
1177
1178 trace!("📮 Adding channel {} to user {}", &chid, &user_id);
1180 client.add_channel(&uaid, &chid).await?;
1181 let channels = client.get_channels(&uaid).await?;
1182 assert!(channels.contains(&chid));
1183
1184 let mut new_channels: HashSet<Uuid> = HashSet::new();
1186 trace!("📮 Adding multiple channels to user {}", &user_id);
1187 new_channels.insert(chid);
1188 for _ in 1..10 {
1189 new_channels.insert(uuid::Uuid::new_v4());
1190 }
1191 let chid_to_remove = uuid::Uuid::new_v4();
1192 trace!(
1193 "📮 Adding removable channel {} to user {}",
1194 &chid_to_remove,
1195 &user_id
1196 );
1197 new_channels.insert(chid_to_remove);
1198 client.add_channels(&uaid, new_channels.clone()).await?;
1199 let channels = client.get_channels(&uaid).await?;
1200 assert_eq!(channels, new_channels);
1201
1202 trace!(
1204 "📮 Removing channel {} from user {}",
1205 &chid_to_remove,
1206 &user_id
1207 );
1208 assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
1209 trace!(
1210 "📮 retrying Removing channel {} from user {}",
1211 &chid_to_remove,
1212 &user_id
1213 );
1214 assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
1215 new_channels.remove(&chid_to_remove);
1216 let channels = client.get_channels(&uaid).await?;
1217 assert_eq!(channels, new_channels);
1218
1219 let mut updated = User {
1223 connected_at,
1224 ..test_user.clone()
1225 };
1226 trace!(
1227 "📮 Attempting to update user {} with old connected_at: {}",
1228 &user_id,
1229 &updated.connected_at
1230 );
1231 let result = client.update_user(&mut updated).await;
1232 assert!(result.is_ok());
1233 assert!(!result.unwrap());
1234
1235 let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
1237 assert_eq!(fetched.connected_at, fetched2.connected_at);
1238
1239 let mut updated = User {
1241 connected_at: fetched.connected_at + 300,
1242 ..fetched2
1243 };
1244 trace!(
1245 "📮 Attempting to update user {} with new connected_at",
1246 &user_id
1247 );
1248 let result = client.update_user(&mut updated).await;
1249 assert!(result.is_ok());
1250 assert!(result.unwrap());
1251 assert_ne!(
1252 fetched2.connected_at,
1253 client.get_user(&uaid).await?.unwrap().connected_at
1254 );
1255 trace!("📮 Incrementing storage timestamp for user {}", &user_id);
1257 client
1258 .increment_storage(&fetched.uaid, sec_since_epoch())
1259 .await?;
1260
1261 let test_data = "An_encrypted_pile_of_crap".to_owned();
1262 let timestamp = sec_since_epoch();
1263 let sort_key = sec_since_epoch();
1264 let fetch_timestamp = timestamp;
1265 let test_notification = crate::db::Notification {
1267 channel_id: chid,
1268 version: "test".to_owned(),
1269 ttl: 300,
1270 timestamp,
1271 data: Some(test_data.clone()),
1272 sortkey_timestamp: Some(sort_key),
1273 ..Default::default()
1274 };
1275 trace!("📮 Saving message for user {}", &user_id);
1276 let res = client.save_message(&uaid, test_notification.clone()).await;
1277 assert!(res.is_ok());
1278
1279 trace!("📮 Fetching all messages for user {}", &user_id);
1280 let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1281 assert_ne!(fetched.messages.len(), 0);
1282 let fm = fetched.messages.pop().unwrap();
1283 assert_eq!(fm.channel_id, test_notification.channel_id);
1284 assert_eq!(fm.data, Some(test_data));
1285
1286 trace!(
1288 "📮 Fetching messages for user {} within the past 10 seconds",
1289 &user_id
1290 );
1291 let fetched = client
1292 .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
1293 .await?;
1294 assert_ne!(fetched.messages.len(), 0);
1295
1296 trace!(
1298 "📮 Fetching messages for user {} 10 seconds in the future",
1299 &user_id
1300 );
1301 let fetched = client
1302 .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
1303 .await?;
1304 assert_eq!(fetched.messages.len(), 0);
1305
1306 trace!(
1308 "📮 Removing message for user {} :: {}",
1309 &user_id,
1310 &test_notification.chidmessageid()
1311 );
1312 assert!(client
1313 .remove_message(&uaid, &test_notification.chidmessageid())
1314 .await
1315 .is_ok());
1316
1317 trace!("📮 Removing channel for user {}", &user_id);
1318 assert!(client.remove_channel(&uaid, &chid).await.is_ok());
1319
1320 trace!("📮 Making sure no messages remain for user {}", &user_id);
1321 let msgs = client
1322 .fetch_timestamp_messages(&uaid, None, 999)
1323 .await?
1324 .messages;
1325 assert!(msgs.is_empty());
1326
1327 client.add_channel(&uaid, &topic_chid).await?;
1331 let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
1332 let timestamp = sec_since_epoch();
1333 let sort_key = sec_since_epoch();
1334
1335 let test_notification_0 = crate::db::Notification {
1337 channel_id: topic_chid,
1338 version: "version0".to_owned(),
1339 ttl: 300,
1340 topic: Some("topic".to_owned()),
1341 timestamp,
1342 data: Some(test_data.clone()),
1343 sortkey_timestamp: Some(sort_key),
1344 ..Default::default()
1345 };
1346 assert!(client
1347 .save_message(&uaid, test_notification_0.clone())
1348 .await
1349 .is_ok());
1350
1351 let test_notification = crate::db::Notification {
1352 timestamp: sec_since_epoch(),
1353 version: "version1".to_owned(),
1354 sortkey_timestamp: Some(sort_key + 10),
1355 ..test_notification_0
1356 };
1357
1358 assert!(client
1359 .save_message(&uaid, test_notification.clone())
1360 .await
1361 .is_ok());
1362
1363 let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1364 assert_eq!(fetched.messages.len(), 1);
1365 let fm = fetched.messages.pop().unwrap();
1366 assert_eq!(fm.channel_id, test_notification.channel_id);
1367 assert_eq!(fm.data, Some(test_data));
1368
1369 let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
1371 assert_ne!(fetched.messages.len(), 0);
1372
1373 assert!(client
1375 .remove_message(&uaid, &test_notification.chidmessageid())
1376 .await
1377 .is_ok());
1378
1379 assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());
1380
1381 let msgs = client
1382 .fetch_timestamp_messages(&uaid, None, 999)
1383 .await?
1384 .messages;
1385 assert!(msgs.is_empty());
1386
1387 let fetched = client.get_user(&uaid).await?.unwrap();
1388 assert!(client
1389 .remove_node_id(&uaid, &node_id, fetched.connected_at, &fetched.version)
1390 .await
1391 .is_ok());
1392 let fetched = client.get_user(&uaid).await?.unwrap();
1394 assert_eq!(fetched.node_id, None);
1395
1396 assert!(client.remove_user(&uaid).await.is_ok());
1397
1398 assert!(client.get_user(&uaid).await?.is_none());
1399 Ok(())
1400 }
1401
1402 #[actix_rt::test]
1403 async fn test_expiry() -> DbResult<()> {
1404 init_test_logging();
1406 let client = new_client()?;
1407
1408 let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
1409 let chid = Uuid::parse_str(TEST_CHID).unwrap();
1410 let now = sec_since_epoch();
1411
1412 let test_notification = crate::db::Notification {
1413 channel_id: chid,
1414 version: "test".to_owned(),
1415 ttl: 2,
1416 timestamp: now,
1417 data: Some("SomeData".into()),
1418 sortkey_timestamp: Some(now),
1419 ..Default::default()
1420 };
1421 client
1422 .add_user(&User {
1423 uaid,
1424 router_type: "test".to_owned(),
1425 connected_at: ms_since_epoch(),
1426 ..Default::default()
1427 })
1428 .await?;
1429 client.add_channel(&uaid, &chid).await?;
1430 debug!("🧪Writing test notif");
1431 client
1432 .save_message(&uaid, test_notification.clone())
1433 .await?;
1434 let key = uaid.simple().to_string();
1435 debug!("🧪Checking {}...", &key);
1436 let msg = client
1437 .fetch_timestamp_messages(&uaid, None, 1)
1438 .await?
1439 .messages
1440 .pop();
1441 assert!(msg.is_some());
1442 debug!("🧪Purging...");
1443 client.increment_storage(&uaid, now + 2).await?;
1444 debug!("🧪Checking for empty {}...", &key);
1445 let cc = client
1446 .fetch_timestamp_messages(&uaid, None, 1)
1447 .await?
1448 .messages
1449 .pop();
1450 assert!(cc.is_none());
1451 assert!(client.remove_user(&uaid).await.is_ok());
1453 Ok(())
1454 }
1455}