autopush_common/db/redis/redis_client/
mod.rs

1use std::collections::HashSet;
2use std::fmt;
3use std::fmt::Display;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::{Duration, SystemTime};
7
8use async_trait::async_trait;
9use cadence::{CountedExt, StatsdClient};
10use redis::aio::MultiplexedConnection;
11use redis::{AsyncCommands, SetExpiry, SetOptions};
12use tokio::sync::OnceCell;
13use uuid::Uuid;
14
15use crate::db::redis::StorableNotification;
16use crate::db::{
17    client::{DbClient, FetchMessageResponse},
18    error::{DbError, DbResult},
19    DbSettings, Notification, User,
20};
21use crate::util::{ms_since_epoch, sec_since_epoch};
22
23use super::RedisDbSettings;
24
/// Current wall-clock time, in whole seconds since the UNIX epoch.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the epoch.
fn now_secs() -> u64 {
    // `elapsed()` on the epoch is `SystemTime::now().duration_since(UNIX_EPOCH)`.
    let since_epoch = SystemTime::UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs()
}
32
33/// Semi convenience wrapper to ensure that the UAID is formatted and displayed consistently.
34struct Uaid<'a>(&'a Uuid);
35
36impl<'a> Display for Uaid<'a> {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        write!(f, "{}", self.0.as_hyphenated())
39    }
40}
41
42impl<'a> From<Uaid<'a>> for String {
43    fn from(uaid: Uaid) -> String {
44        uaid.0.as_hyphenated().to_string()
45    }
46}
47
48struct ChannelId<'a>(&'a Uuid);
49
50impl<'a> Display for ChannelId<'a> {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{}", self.0.as_hyphenated())
53    }
54}
55
56impl<'a> From<ChannelId<'a>> for String {
57    fn from(chid: ChannelId) -> String {
58        chid.0.as_hyphenated().to_string()
59    }
60}
61
/// Wrapper for the Redis connection
///
/// Holds the Redis client, a lazily-opened shared connection, and the default
/// `SET` options applied to router and notification records.
#[derive(Clone)]
pub struct RedisClientImpl {
    /// Redis client built from the configured DSN; used to open the connection.
    pub client: redis::Client,
    /// Connection opened on first use by `connection()`; later calls return a clone of it.
    pub conn: OnceCell<MultiplexedConnection>,
    /// Optional connection timeout; `None` or a zero duration leaves the driver default.
    pub(crate) timeout: Option<Duration>,
    /// Metrics client
    metrics: Arc<StatsdClient>,
    /// `SET` options (expiry = router TTL) used for the user record, the
    /// last-connection key and the storage timestamp key.
    router_opts: SetOptions,
    /// Default notification options (mostly TTL).
    /// `save_message` replaces the expiry with a per-message absolute `EXAT`.
    notification_opts: SetOptions,
}
75
76impl RedisClientImpl {
77    pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
78        debug!("🐰 New redis client");
79        let dsn = settings.dsn.clone().ok_or(DbError::General(
80            "Redis DSN not configured. Set `db_dsn` to `redis://HOST:PORT` in settings.".to_owned(),
81        ))?;
82        let client = redis::Client::open(dsn)?;
83        let db_settings = RedisDbSettings::try_from(settings.db_settings.as_ref())?;
84        info!("🐰 {:#?}", db_settings);
85        let router_ttl_secs = db_settings.router_ttl.unwrap_or_default().as_secs();
86        let notification_ttl_secs = db_settings.notification_ttl.unwrap_or_default().as_secs();
87        // We specify different TTLs for router vs message.
88        Ok(Self {
89            client,
90            conn: OnceCell::new(),
91            timeout: db_settings.timeout,
92            metrics,
93            router_opts: SetOptions::default().with_expiration(SetExpiry::EX(router_ttl_secs)),
94            notification_opts: SetOptions::default()
95                .with_expiration(SetExpiry::EX(notification_ttl_secs)),
96        })
97    }
98
99    /// Return a [ConnectionLike], which implement redis [Commands] and can be
100    /// used in pipes.
101    ///
102    /// Pools also return a ConnectionLike, so we can add support for pools later.
103    async fn connection(&self) -> DbResult<redis::aio::MultiplexedConnection> {
104        let config = if self.timeout.is_some_and(|t| !t.is_zero()) {
105            redis::AsyncConnectionConfig::new().set_connection_timeout(self.timeout.unwrap())
106        } else {
107            redis::AsyncConnectionConfig::new()
108        };
109        Ok(self
110            .conn
111            .get_or_try_init(|| async {
112                self.client
113                    .get_multiplexed_async_connection_with_config(&config.clone())
114                    .await
115            })
116            .await
117            .inspect_err(|e| error!("No Redis Connection available: {}", e.to_string()))?
118            .clone())
119    }
120
121    fn user_key(&self, uaid: &Uaid) -> String {
122        format!("autopush/user/{}", uaid)
123    }
124
125    /// This store the last connection record, but doesn't update User
126    fn last_co_key(&self, uaid: &Uaid) -> String {
127        format!("autopush/co/{}", uaid)
128    }
129
130    /// This store the last timestamp incremented by the server once messages are ACK'ed
131    fn storage_timestamp_key(&self, uaid: &Uaid) -> String {
132        format!("autopush/timestamp/{}", uaid)
133    }
134
135    fn channel_list_key(&self, uaid: &Uaid) -> String {
136        format!("autopush/channels/{}", uaid)
137    }
138
139    fn message_list_key(&self, uaid: &Uaid) -> String {
140        format!("autopush/msgs/{}", uaid)
141    }
142
143    fn message_exp_list_key(&self, uaid: &Uaid) -> String {
144        format!("autopush/msgs_exp/{}", uaid)
145    }
146
147    fn message_key(&self, uaid: &Uaid, chidmessageid: &str) -> String {
148        format!("autopush/msg/{}/{}", uaid, chidmessageid)
149    }
150
151    #[cfg(feature = "reliable_report")]
152    fn reliability_key(
153        &self,
154        reliability_id: &str,
155        state: &crate::reliability::ReliabilityState,
156    ) -> String {
157        format!("autopush/reliability/{}/{}", reliability_id, state)
158    }
159
160    #[cfg(test)]
161    /// Return a single "raw" message (used by testing and validation)
162    async fn fetch_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<Option<String>> {
163        let message_key = self.message_key(&Uaid(uaid), chidmessageid);
164        let mut con = self.connection().await?;
165        debug!("🐰 Fetching message from {}", &message_key);
166        let message = con.get::<String, Option<String>>(message_key).await?;
167        Ok(message)
168    }
169}
170
171#[async_trait]
172impl DbClient for RedisClientImpl {
173    /// add user to the database
174    async fn add_user(&self, user: &User) -> DbResult<()> {
175        let uaid = Uaid(&user.uaid);
176        let user_key = self.user_key(&uaid);
177        let mut con = self.connection().await?;
178        let co_key = self.last_co_key(&uaid);
179        trace!("🐰 Adding user {} at {}:{}", &user.uaid, &user_key, &co_key);
180        trace!("🐰 Logged at {}", &user.connected_at);
181        redis::pipe()
182            .set_options(co_key, ms_since_epoch(), self.router_opts)
183            .set_options(user_key, serde_json::to_string(user)?, self.router_opts)
184            .exec_async(&mut con)
185            .await?;
186        Ok(())
187    }
188
189    /// To update the TTL of the Redis entry we just have to SET again, with the new expiry
190    ///
191    /// NOTE: This function is called by mobile during the daily
192    /// [autoendpoint::routes::update_token_route] handling, and by desktop
193    /// [autoconnect-ws-sm::get_or_create_user]` which is called
194    /// during the `HELLO` handler. This should be enough to ensure that the ROUTER records
195    /// are properly refreshed for "lively" clients.
196    ///
197    /// NOTE: There is some, very small, potential risk that a desktop client that can
198    /// somehow remain connected the duration of MAX_ROUTER_TTL, may be dropped as not being
199    /// "lively".
200    async fn update_user(&self, user: &mut User) -> DbResult<bool> {
201        trace!("🐰 Updating user");
202        let mut con = self.connection().await?;
203        let co_key = self.last_co_key(&Uaid(&user.uaid));
204        let last_co: Option<u64> = con.get(&co_key).await?;
205        if last_co.is_some_and(|c| c < user.connected_at) {
206            trace!(
207                "🐰 Was connected at {}, now at {}",
208                last_co.unwrap(),
209                &user.connected_at
210            );
211            self.add_user(user).await?;
212            Ok(true)
213        } else {
214            Ok(false)
215        }
216    }
217
218    async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
219        let mut con = self.connection().await?;
220        let user_key = self.user_key(&Uaid(uaid));
221        let user: Option<User> = con
222            .get::<&str, Option<String>>(&user_key)
223            .await?
224            .and_then(|s| serde_json::from_str(s.as_ref()).ok());
225        if user.is_some() {
226            trace!("🐰 Found a record for {}", &uaid);
227        }
228        Ok(user)
229    }
230
231    async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
232        let uaid = Uaid(uaid);
233        let mut con = self.connection().await?;
234        let user_key = self.user_key(&uaid);
235        let co_key = self.last_co_key(&uaid);
236        let chan_list_key = self.channel_list_key(&uaid);
237        let msg_list_key = self.message_list_key(&uaid);
238        let exp_list_key = self.message_exp_list_key(&uaid);
239        let timestamp_key = self.storage_timestamp_key(&uaid);
240        redis::pipe()
241            .del(&user_key)
242            .del(&co_key)
243            .del(&chan_list_key)
244            .del(&msg_list_key)
245            .del(&exp_list_key)
246            .del(&timestamp_key)
247            .exec_async(&mut con)
248            .await?;
249        Ok(())
250    }
251
252    async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
253        let uaid = Uaid(uaid);
254        let mut con = self.connection().await?;
255        let co_key = self.last_co_key(&uaid);
256        let chan_list_key = self.channel_list_key(&uaid);
257
258        let _: () = redis::pipe()
259            .rpush(chan_list_key, channel_id.as_hyphenated().to_string())
260            .set_options(co_key, ms_since_epoch(), self.router_opts)
261            .exec_async(&mut con)
262            .await?;
263        Ok(())
264    }
265
266    /// Add channels in bulk (used mostly during migration)
267    async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
268        let uaid = Uaid(uaid);
269        // channel_ids are stored a list within a single redis key
270        let mut con = self.connection().await?;
271        let co_key = self.last_co_key(&uaid);
272        let chan_list_key = self.channel_list_key(&uaid);
273        redis::pipe()
274            .set_options(co_key, ms_since_epoch(), self.router_opts)
275            .rpush(
276                chan_list_key,
277                channels
278                    .into_iter()
279                    .map(|c| c.as_hyphenated().to_string())
280                    .collect::<Vec<String>>(),
281            )
282            .exec_async(&mut con)
283            .await?;
284        Ok(())
285    }
286
287    async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
288        let uaid = Uaid(uaid);
289        let mut con = self.connection().await?;
290        let chan_list_key = self.channel_list_key(&uaid);
291        let channels: HashSet<Uuid> = con
292            .lrange::<&str, HashSet<String>>(&chan_list_key, 0, -1)
293            .await?
294            .into_iter()
295            .filter_map(|s| Uuid::from_str(&s).ok())
296            .collect();
297        trace!("🐰 Found {} channels for {}", channels.len(), &uaid);
298        Ok(channels)
299    }
300
301    /// Delete the channel. Does not delete its associated pending messages.
302    async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
303        let uaid = Uaid(uaid);
304        let channel_id = ChannelId(channel_id);
305        let mut con = self.connection().await?;
306        let co_key = self.last_co_key(&uaid);
307        let chan_list_key = self.channel_list_key(&uaid);
308        // Remove {channel_id} from autopush/channel/{auid}
309        trace!("🐰 Removing channel {}", channel_id);
310        let (status,): (bool,) = redis::pipe()
311            .set_options(co_key, ms_since_epoch(), self.router_opts)
312            .ignore()
313            .lrem(&chan_list_key, 1, channel_id.to_string())
314            .query_async(&mut con)
315            .await?;
316        Ok(status)
317    }
318
319    /// Remove the node_id
320    async fn remove_node_id(
321        &self,
322        uaid: &Uuid,
323        _node_id: &str,
324        _connected_at: u64,
325        _version: &Option<Uuid>,
326    ) -> DbResult<bool> {
327        if let Some(mut user) = self.get_user(uaid).await? {
328            user.node_id = None;
329            self.update_user(&mut user).await?;
330        }
331        Ok(true)
332    }
333
334    /// Write the notification to storage.
335    ///
336    /// If the message contains a topic, we remove the old message
337    async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
338        let uaid = Uaid(uaid);
339        let mut con = self.connection().await?;
340        let msg_list_key = self.message_list_key(&uaid);
341        let exp_list_key = self.message_exp_list_key(&uaid);
342        let msg_id = &message.chidmessageid();
343        let msg_key = self.message_key(&uaid, msg_id);
344        let storable: StorableNotification = message.into();
345
346        debug!("🐰 Saving message {} :: {:?}", &msg_key, &storable);
347        trace!(
348            "🐰 timestamp: {:?}",
349            &storable.timestamp.to_be_bytes().to_vec()
350        );
351
352        // Remember, `timestamp` is effectively the time to kill the message, not the
353        // current time.
354        let expiry = now_secs() + storable.ttl;
355        trace!("🐰 Message Expiry {}, currently:{} ", expiry, now_secs());
356
357        let mut pipe = redis::pipe();
358
359        // If this is a topic message:
360        // zadd(msg_list_key) and zadd(exp_list_key) will replace their old entry
361        // in the sorted set if one already exists
362        // and set(msg_key, message) will override it too: nothing to do.
363        let is_topic = storable.topic.is_some();
364
365        // notification_ttl is already min(headers.ttl, MAX_NOTIFICATION_TTL)
366        // see autoendpoint/src/extractors/notification_headers.rs
367        let notif_opts = self
368            .notification_opts
369            .with_expiration(SetExpiry::EXAT(expiry));
370
371        // Store notification record in autopush/msg/{uaid}/{chidmessageid}
372        // And store {chidmessageid} in autopush/msgs/{uaid}
373        debug!("🐰 Saving to {}", &msg_key);
374        pipe.set_options(msg_key, serde_json::to_string(&storable)?, notif_opts)
375            // The function [fetch_timestamp_messages] takes a timestamp in input,
376            // here we use the timestamp of the record
377            .zadd(&exp_list_key, msg_id, expiry)
378            .zadd(&msg_list_key, msg_id, sec_since_epoch());
379
380        let _: () = pipe.exec_async(&mut con).await?;
381        self.metrics
382            .incr_with_tags("notification.message.stored")
383            .with_tag("topic", &is_topic.to_string())
384            .with_tag("database", &self.name())
385            .send();
386        Ok(())
387    }
388
389    /// Save a batch of messages to the database.
390    ///
391    /// Currently just iterating through the list and saving one at a time. There's a bulk way
392    /// to save messages, but there are other considerations (e.g. mutation limits)
393    async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
394        // A plate simple way of solving this:
395        for message in messages {
396            self.save_message(uaid, message).await?;
397        }
398        Ok(())
399    }
400
401    /// Delete expired messages
402    async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
403        let uaid = Uaid(uaid);
404        debug!("🐰🔥 Incrementing storage to {}", timestamp);
405        let msg_list_key = self.message_list_key(&uaid);
406        let exp_list_key = self.message_exp_list_key(&uaid);
407        let storage_timestamp_key = self.storage_timestamp_key(&uaid);
408        let mut con = self.connection().await?;
409        trace!("🐇 SEARCH: increment: {:?} - {}", &exp_list_key, timestamp);
410        let exp_id_list: Vec<String> = con.zrangebyscore(&exp_list_key, 0, timestamp).await?;
411        if !exp_id_list.is_empty() {
412            // Remember, we store just the message_ids in the exp and msg lists, but need to convert back to
413            // the full message keys for deletion.
414            let delete_msg_keys: Vec<String> = exp_id_list
415                .clone()
416                .into_iter()
417                .map(|msg_id| self.message_key(&uaid, &msg_id))
418                .collect();
419
420            trace!(
421                "🐰🔥:rem: Deleting {} : [{:?}]",
422                msg_list_key,
423                &delete_msg_keys
424            );
425            trace!("🐰🔥:rem: Deleting {} : [{:?}]", exp_list_key, &exp_id_list);
426            redis::pipe()
427                .set_options::<_, _>(&storage_timestamp_key, timestamp, self.router_opts)
428                .del(&delete_msg_keys)
429                .zrem(&msg_list_key, &exp_id_list)
430                .zrem(&exp_list_key, &exp_id_list)
431                .exec_async(&mut con)
432                .await?;
433        } else {
434            con.set_options::<_, _, ()>(&storage_timestamp_key, timestamp, self.router_opts)
435                .await?;
436        }
437        Ok(())
438    }
439
440    /// Delete the notification from storage.
441    async fn remove_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<()> {
442        let uaid = Uaid(uaid);
443        trace!(
444            "🐰 attemping to delete {:?} :: {:?}",
445            uaid.to_string(),
446            chidmessageid
447        );
448        let msg_key = self.message_key(&uaid, chidmessageid);
449        let msg_list_key = self.message_list_key(&uaid);
450        let exp_list_key = self.message_exp_list_key(&uaid);
451        debug!("🐰🔥 Deleting message {}", &msg_key);
452        let mut con = self.connection().await?;
453        // We remove the id from the exp list at the end, to be sure
454        // it can't be removed from the list before the message is removed
455        trace!(
456            "🐰🔥:remsg: Deleting {} : {:?}",
457            msg_list_key,
458            &chidmessageid
459        );
460        trace!(
461            "🐰🔥:remsg: Deleting {} : {:?}",
462            exp_list_key,
463            &chidmessageid
464        );
465        redis::pipe()
466            .del(&msg_key)
467            .zrem(&msg_list_key, chidmessageid)
468            .zrem(&exp_list_key, chidmessageid)
469            .exec_async(&mut con)
470            .await?;
471        self.metrics
472            .incr_with_tags("notification.message.deleted")
473            .with_tag("database", &self.name())
474            .send();
475        Ok(())
476    }
477
478    /// Topic messages are handled as other messages with redis, we return nothing.
479    async fn fetch_topic_messages(
480        &self,
481        _uaid: &Uuid,
482        _limit: usize,
483    ) -> DbResult<FetchMessageResponse> {
484        Ok(FetchMessageResponse {
485            messages: vec![],
486            timestamp: None,
487        })
488    }
489
490    /// Return [`limit`] messages pending for a [`uaid`] that have a record timestamp
491    /// after [`timestamp`] (secs).
492    ///
493    /// If [`limit`] = 0, we fetch all messages after [`timestamp`].
494    ///
495    /// This can return expired messages, following bigtables behavior
496    async fn fetch_timestamp_messages(
497        &self,
498        uaid: &Uuid,
499        timestamp: Option<u64>,
500        limit: usize,
501    ) -> DbResult<FetchMessageResponse> {
502        let uaid = Uaid(uaid);
503        trace!("🐰 Fetching {} messages since {:?}", limit, timestamp);
504        let mut con = self.connection().await?;
505        let msg_list_key = self.message_list_key(&uaid);
506        let timestamp = if let Some(timestamp) = timestamp {
507            timestamp
508        } else {
509            let storage_timestamp_key = self.storage_timestamp_key(&uaid);
510            con.get(&storage_timestamp_key).await.unwrap_or(0)
511        };
512        // ZRANGE Key (x) +inf LIMIT 0 limit
513        trace!(
514            "🐇 SEARCH: zrangebyscore {:?} {} +inf withscores limit 0 {:?}",
515            &msg_list_key,
516            timestamp,
517            limit,
518        );
519        let results = con
520            .zrangebyscore_limit_withscores::<&str, &str, &str, Vec<(String, u64)>>(
521                &msg_list_key,
522                &timestamp.to_string(),
523                "+inf",
524                0,
525                limit as isize,
526            )
527            .await?;
528        let (messages_id, mut scores): (Vec<String>, Vec<u64>) = results
529            .into_iter()
530            .map(|(id, s): (String, u64)| (self.message_key(&uaid, &id), s))
531            .unzip();
532        if messages_id.is_empty() {
533            trace!("🐰 No message found");
534            return Ok(FetchMessageResponse {
535                messages: vec![],
536                timestamp: None,
537            });
538        }
539        let messages: Vec<Notification> = con
540            .mget::<&Vec<String>, Vec<Option<String>>>(&messages_id)
541            .await?
542            .into_iter()
543            .filter_map(|opt: Option<String>| {
544                if let Some(m) = opt {
545                    serde_json::from_str(&m)
546                        .inspect_err(|e| {
547                            // Since we can't raise the error here, at least record it
548                            // so that it's not lost.
549                            // Mind you, if there is an error here, it's probably due to
550                            // some developmental issue since the unit and integration tests
551                            // should fail.
552                            error!("🐰 ERROR parsing entry: {:?}", e);
553                        })
554                        .ok()
555                } else {
556                    None
557                }
558            })
559            .collect();
560        if messages.is_empty() {
561            trace!("🐰 No Valid messages found");
562            return Ok(FetchMessageResponse {
563                timestamp: None,
564                messages: vec![],
565            });
566        }
567        let timestamp = scores.pop();
568        trace!("🐰 Found {} messages until {:?}", messages.len(), timestamp);
569        Ok(FetchMessageResponse {
570            messages,
571            timestamp,
572        })
573    }
574
575    #[cfg(feature = "reliable_report")]
576    async fn log_report(
577        &self,
578        reliability_id: &str,
579        state: crate::reliability::ReliabilityState,
580    ) -> DbResult<()> {
581        use crate::MAX_NOTIFICATION_TTL_SECS;
582
583        trace!("🐰 Logging reliability report");
584        let mut con = self.connection().await?;
585        // TODO: Should this be a hash key per reliability_id?
586        let reliability_key = self.reliability_key(reliability_id, &state);
587        // Reports should last about as long as the notifications they're tied to.
588        let expiry = MAX_NOTIFICATION_TTL_SECS;
589        let opts = SetOptions::default().with_expiration(SetExpiry::EX(expiry));
590        let mut pipe = redis::pipe();
591        pipe.set_options(reliability_key, sec_since_epoch(), opts)
592            .exec_async(&mut con)
593            .await?;
594        Ok(())
595    }
596
597    async fn health_check(&self) -> DbResult<bool> {
598        let mut con = self.connection().await?;
599        let _: () = con.ping().await?;
600        Ok(true)
601    }
602
603    /// Returns true, because there's no table in Redis
604    async fn router_table_exists(&self) -> DbResult<bool> {
605        Ok(true)
606    }
607
608    /// Returns true, because there's no table in Redis
609    async fn message_table_exists(&self) -> DbResult<bool> {
610        Ok(true)
611    }
612
613    fn box_clone(&self) -> Box<dyn DbClient> {
614        Box::new(self.clone())
615    }
616
617    fn name(&self) -> String {
618        "Redis".to_owned()
619    }
620
621    fn pool_status(&self) -> Option<deadpool::Status> {
622        None
623    }
624}
625
626#[cfg(test)]
627mod tests {
628    use crate::{logging::init_test_logging, util::ms_since_epoch};
629    use rand::prelude::*;
630    use std::env;
631
632    use super::*;
633    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
634    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";
635
636    fn new_client() -> DbResult<RedisClientImpl> {
637        // Use an environment variable to potentially override the default redis test host.
638        let host = env::var("REDIS_HOST").unwrap_or("localhost".into());
639        let env_dsn = format!("redis://{host}");
640        debug!("🐰 Connecting to {env_dsn}");
641        let settings = DbSettings {
642            dsn: Some(env_dsn),
643            db_settings: "".into(),
644        };
645        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
646        RedisClientImpl::new(metrics, &settings)
647    }
648
649    fn gen_test_user() -> String {
650        // Create a semi-unique test user to avoid conflicting test values.
651        let mut rng = rand::rng();
652        let test_num = rng.random::<u8>();
653        format!("DEADBEEF-0000-0000-{:04}-{:012}", test_num, now_secs())
654    }
655
656    #[actix_rt::test]
657    async fn health_check() {
658        let client = new_client().unwrap();
659
660        let result = client.health_check().await;
661        assert!(result.is_ok());
662        assert!(result.unwrap());
663    }
664
665    /// Test if [increment_storage] correctly wipe expired messages
666    #[actix_rt::test]
667    async fn wipe_expired() -> DbResult<()> {
668        init_test_logging();
669        let client = new_client()?;
670
671        let connected_at = ms_since_epoch();
672
673        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
674        let chid = Uuid::parse_str(TEST_CHID).unwrap();
675
676        let node_id = "test_node".to_owned();
677
678        // purge the user record if it exists.
679        let _ = client.remove_user(&uaid).await;
680
681        let test_user = User {
682            uaid,
683            router_type: "webpush".to_owned(),
684            connected_at,
685            router_data: None,
686            node_id: Some(node_id.clone()),
687            ..Default::default()
688        };
689
690        // purge the old user (if present)
691        // in case a prior test failed for whatever reason.
692        let _ = client.remove_user(&uaid).await;
693
694        // can we add the user?
695        let timestamp = now_secs();
696        client.add_user(&test_user).await?;
697        let test_notification = crate::db::Notification {
698            channel_id: chid,
699            version: "test".to_owned(),
700            ttl: 1,
701            timestamp,
702            data: Some("Encrypted".into()),
703            sortkey_timestamp: Some(timestamp),
704            ..Default::default()
705        };
706        client.save_message(&uaid, test_notification).await?;
707        client.increment_storage(&uaid, timestamp + 1).await?;
708        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
709        assert_eq!(msgs.messages.len(), 0);
710        assert!(client.remove_user(&uaid).await.is_ok());
711        Ok(())
712    }
713
714    /// run a gauntlet of testing. These are a bit linear because they need
715    /// to run in sequence.
716    #[actix_rt::test]
717    async fn run_gauntlet() -> DbResult<()> {
718        init_test_logging();
719        let client = new_client()?;
720
721        let connected_at = ms_since_epoch();
722
723        let user_id = &gen_test_user();
724        let uaid = Uuid::parse_str(user_id).unwrap();
725        let chid = Uuid::parse_str(TEST_CHID).unwrap();
726        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();
727
728        let node_id = "test_node".to_owned();
729
730        // purge the user record if it exists.
731        let _ = client.remove_user(&uaid).await;
732
733        let test_user = User {
734            uaid,
735            router_type: "webpush".to_owned(),
736            connected_at,
737            router_data: None,
738            node_id: Some(node_id.clone()),
739            ..Default::default()
740        };
741
742        // purge the old user (if present)
743        // in case a prior test failed for whatever reason.
744        let _ = client.remove_user(&uaid).await;
745
746        // can we add the user?
747        client.add_user(&test_user).await?;
748        let fetched = client.get_user(&uaid).await?;
749        assert!(fetched.is_some());
750        let fetched = fetched.unwrap();
751        assert_eq!(fetched.router_type, "webpush".to_owned());
752
753        // Simulate a connected_at occuring before the following writes
754        let connected_at = ms_since_epoch();
755
756        // can we add channels?
757        client.add_channel(&uaid, &chid).await?;
758        let channels = client.get_channels(&uaid).await?;
759        assert!(channels.contains(&chid));
760
761        // can we add lots of channels?
762        let mut new_channels: HashSet<Uuid> = HashSet::new();
763        new_channels.insert(chid);
764        for _ in 1..10 {
765            new_channels.insert(uuid::Uuid::new_v4());
766        }
767        let chid_to_remove = uuid::Uuid::new_v4();
768        new_channels.insert(chid_to_remove);
769        client.add_channels(&uaid, new_channels.clone()).await?;
770        let channels = client.get_channels(&uaid).await?;
771        assert_eq!(channels, new_channels);
772
773        // can we remove a channel?
774        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
775        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
776        new_channels.remove(&chid_to_remove);
777        let channels = client.get_channels(&uaid).await?;
778        assert_eq!(channels, new_channels);
779
780        // now ensure that we can update a user that's after the time we set
781        // prior. first ensure that we can't update a user that's before the
782        // time we set prior to the last write
783        let mut updated = User {
784            connected_at,
785            ..test_user.clone()
786        };
787        let result = client.update_user(&mut updated).await;
788        assert!(result.is_ok());
789        assert!(!result.unwrap());
790
791        // Make sure that the `connected_at` wasn't modified
792        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
793        assert_eq!(fetched.connected_at, fetched2.connected_at);
794
795        // and make sure we can update a record with a later connected_at time.
796        let mut updated = User {
797            connected_at: fetched.connected_at + 300,
798            ..fetched2
799        };
800        let result = client.update_user(&mut updated).await;
801        assert!(result.is_ok());
802        assert!(result.unwrap());
803        assert_ne!(
804            fetched2.connected_at,
805            client.get_user(&uaid).await?.unwrap().connected_at
806        );
807
808        // can we increment the storage for the user?
809        client
810            .increment_storage(
811                &fetched.uaid,
812                SystemTime::now()
813                    .duration_since(SystemTime::UNIX_EPOCH)
814                    .unwrap()
815                    .as_secs(),
816            )
817            .await?;
818
819        let test_data = "An_encrypted_pile_of_crap".to_owned();
820        let timestamp = now_secs();
821        let sort_key = now_secs();
822        let fetch_timestamp = timestamp;
823        // Can we store a message?
824        let test_notification = crate::db::Notification {
825            channel_id: chid,
826            version: "test".to_owned(),
827            ttl: 300,
828            timestamp,
829            data: Some(test_data.clone()),
830            sortkey_timestamp: Some(sort_key),
831            ..Default::default()
832        };
833        let res = client.save_message(&uaid, test_notification.clone()).await;
834        assert!(res.is_ok());
835
836        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
837        assert_ne!(fetched.messages.len(), 0);
838        let fm = fetched.messages.pop().unwrap();
839        assert_eq!(fm.channel_id, test_notification.channel_id);
840        assert_eq!(fm.data, Some(test_data));
841
842        // Grab all 1 of the messages that were submitted within the past 10 seconds.
843        let fetched = client
844            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
845            .await?;
846        assert_ne!(fetched.messages.len(), 0);
847
848        // Try grabbing a message for 10 seconds from now.
849        let fetched = client
850            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
851            .await?;
852        assert_eq!(fetched.messages.len(), 0);
853
854        // can we clean up our toys?
855        assert!(client
856            .remove_message(&uaid, &test_notification.chidmessageid())
857            .await
858            .is_ok());
859
860        assert!(client.remove_channel(&uaid, &chid).await.is_ok());
861
862        let msgs = client
863            .fetch_timestamp_messages(&uaid, None, 999)
864            .await?
865            .messages;
866        assert!(msgs.is_empty());
867
868        // Now, can we do all that with topic messages
869        // Unlike bigtable, we don't use [fetch_topic_messages]: it always return None:
870        // they are handled as usuals messages.
871        client.add_channel(&uaid, &topic_chid).await?;
872        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
873        let timestamp = now_secs();
874        let sort_key = now_secs();
875
876        // We store 2 messages, with a single topic
877        let test_notification_0 = crate::db::Notification {
878            channel_id: topic_chid,
879            version: "version0".to_owned(),
880            ttl: 300,
881            topic: Some("topic".to_owned()),
882            timestamp,
883            data: Some(test_data.clone()),
884            sortkey_timestamp: Some(sort_key),
885            ..Default::default()
886        };
887        assert!(client
888            .save_message(&uaid, test_notification_0.clone())
889            .await
890            .is_ok());
891
892        let test_notification = crate::db::Notification {
893            timestamp: now_secs(),
894            version: "version1".to_owned(),
895            sortkey_timestamp: Some(sort_key + 10),
896            ..test_notification_0
897        };
898
899        assert!(client
900            .save_message(&uaid, test_notification.clone())
901            .await
902            .is_ok());
903
904        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
905        assert_eq!(fetched.messages.len(), 1);
906        let fm = fetched.messages.pop().unwrap();
907        assert_eq!(fm.channel_id, test_notification.channel_id);
908        assert_eq!(fm.data, Some(test_data));
909
910        // Grab the message that was submitted.
911        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
912        assert_ne!(fetched.messages.len(), 0);
913
914        // can we clean up our toys?
915        assert!(client
916            .remove_message(&uaid, &test_notification.chidmessageid())
917            .await
918            .is_ok());
919
920        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());
921
922        let msgs = client
923            .fetch_timestamp_messages(&uaid, None, 999)
924            .await?
925            .messages;
926        assert!(msgs.is_empty());
927
928        let fetched = client.get_user(&uaid).await?.unwrap();
929        assert!(client
930            .remove_node_id(&uaid, &node_id, connected_at, &fetched.version)
931            .await
932            .is_ok());
933        // did we remove it?
934        let fetched = client.get_user(&uaid).await?.unwrap();
935        assert_eq!(fetched.node_id, None);
936
937        assert!(client.remove_user(&uaid).await.is_ok());
938
939        assert!(client.get_user(&uaid).await?.is_none());
940        Ok(())
941    }
942
943    #[actix_rt::test]
944    async fn test_expiry() -> DbResult<()> {
945        // Make sure that we really are purging messages correctly
946        init_test_logging();
947        let client = new_client()?;
948
949        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
950        let chid = Uuid::parse_str(TEST_CHID).unwrap();
951        let now = now_secs();
952
953        let test_notification = crate::db::Notification {
954            channel_id: chid,
955            version: "test".to_owned(),
956            ttl: 2,
957            timestamp: now,
958            data: Some("SomeData".into()),
959            sortkey_timestamp: Some(now),
960            ..Default::default()
961        };
962        debug!("Writing test notif");
963        let res = client.save_message(&uaid, test_notification.clone()).await;
964        assert!(res.is_ok());
965        let key = client.message_key(&Uaid(&uaid), &test_notification.chidmessageid());
966        debug!("Checking {}...", &key);
967        let msg = client
968            .fetch_message(&uaid, &test_notification.chidmessageid())
969            .await?;
970        assert!(!msg.unwrap().is_empty());
971        debug!("Purging...");
972        client.increment_storage(&uaid, now + 2).await?;
973        debug!("Checking {}...", &key);
974        let cc = client
975            .fetch_message(&uaid, &test_notification.chidmessageid())
976            .await;
977        assert_eq!(cc.unwrap(), None);
978        // clean up after the test.
979        assert!(client.remove_user(&uaid).await.is_ok());
980        Ok(())
981    }
982}