autopush_common/db/redis/redis_client/
mod.rs

1use std::collections::HashSet;
2use std::fmt;
3use std::fmt::Display;
4use std::str::FromStr;
5use std::sync::Arc;
6use std::time::SystemTime;
7
8use async_trait::async_trait;
9use cadence::{CountedExt, StatsdClient};
10use deadpool_redis::redis::{pipe, AsyncCommands, SetExpiry, SetOptions};
11use deadpool_redis::Config;
12use uuid::Uuid;
13
14use crate::db::redis::StorableNotification;
15use crate::db::{
16    client::{DbClient, FetchMessageResponse},
17    error::{DbError, DbResult},
18    DbSettings, Notification, User,
19};
20use crate::util::{ms_since_epoch, sec_since_epoch};
21
22use super::RedisDbSettings;
23
/// Current wall-clock time as whole seconds since the UNIX epoch.
fn now_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap();
    since_epoch.as_secs()
}
31
/// Semi convenience wrapper to ensure that the UAID is formatted and displayed consistently.
/// Holds a borrowed [Uuid]; rendered in hyphenated form by its `Display` impl.
struct Uaid<'a>(&'a Uuid);
34
35impl<'a> Display for Uaid<'a> {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{}", self.0.as_hyphenated())
38    }
39}
40
41impl<'a> From<Uaid<'a>> for String {
42    fn from(uaid: Uaid) -> String {
43        uaid.0.as_hyphenated().to_string()
44    }
45}
46
/// Semi convenience wrapper to ensure that the channel ID is formatted and displayed
/// consistently (hyphenated form), mirroring [Uaid].
struct ChannelId<'a>(&'a Uuid);
48
49impl<'a> Display for ChannelId<'a> {
50    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51        write!(f, "{}", self.0.as_hyphenated())
52    }
53}
54
55impl<'a> From<ChannelId<'a>> for String {
56    fn from(chid: ChannelId) -> String {
57        chid.0.as_hyphenated().to_string()
58    }
59}
60
#[derive(Clone)]
/// Wrapper for the Redis connection
pub struct RedisClientImpl {
    /// Connection pool to the Redis server (built from the configured DSN)
    pub pool: deadpool_redis::Pool,
    /// Metrics client
    metrics: Arc<StatsdClient>,
    /// Default SET options for router/user records (carries the router TTL)
    router_opts: SetOptions,
    // Default notification options (mostly TTL)
    notification_opts: SetOptions,
}
72
73impl RedisClientImpl {
74    pub fn new(metrics: Arc<StatsdClient>, settings: &DbSettings) -> DbResult<Self> {
75        debug!("🐰 New redis client");
76        let dsn = settings.dsn.clone().ok_or(DbError::General(
77            "Redis DSN not configured. Set `db_dsn` to `redis://HOST:PORT` in settings.".to_owned(),
78        ))?;
79        let db_settings = RedisDbSettings::try_from(settings.db_settings.as_ref())?;
80        info!("🐰 {:#?}", db_settings);
81        let router_ttl_secs = db_settings.router_ttl.unwrap_or_default().as_secs();
82        let notification_ttl_secs = db_settings.notification_ttl.unwrap_or_default().as_secs();
83
84        let config = Config::from_url(dsn);
85        let pool = config
86            .builder()
87            .map_err(|e| DbError::General(format!("Could not create Redis pool: {:?}", e)))?
88            .create_timeout(db_settings.create_timeout)
89            .runtime(deadpool_redis::Runtime::Tokio1)
90            .build()
91            .map_err(|e| DbError::General(format!("Could not create Redis pool: {:?}", e)))?;
92        /* We have the option of using either a OneCell wrapped get_multiplexed_async_connection or
93         * a pool. Reliability already uses a pool, so for consistency we use a pool here as well.
94         */
95
96        // We specify different TTLs for router vs message.
97        Ok(Self {
98            pool,
99            metrics,
100            router_opts: SetOptions::default().with_expiration(SetExpiry::EX(router_ttl_secs)),
101            notification_opts: SetOptions::default()
102                .with_expiration(SetExpiry::EX(notification_ttl_secs)),
103        })
104    }
105
106    /// Return a [ConnectionLike], which implement redis [Commands] and can be
107    /// used in pipes.
108    ///
109    /// Pools also return a ConnectionLike, so we can add support for pools later.
110    async fn connection(&self) -> DbResult<deadpool_redis::Connection> {
111        self.pool.get().await.map_err(|e| {
112            DbError::RedisError(redis::RedisError::from((
113                redis::ErrorKind::IoError,
114                "Could not get Redis connection from pool",
115                format!("{:?}", e),
116            )))
117        })
118    }
119
120    fn user_key(&self, uaid: &Uaid) -> String {
121        format!("autopush/user/{}", uaid)
122    }
123
124    /// This store the last connection record, but doesn't update User
125    fn last_co_key(&self, uaid: &Uaid) -> String {
126        format!("autopush/co/{}", uaid)
127    }
128
129    /// This store the last timestamp incremented by the server once messages are ACK'ed
130    fn storage_timestamp_key(&self, uaid: &Uaid) -> String {
131        format!("autopush/timestamp/{}", uaid)
132    }
133
134    fn channel_list_key(&self, uaid: &Uaid) -> String {
135        format!("autopush/channels/{}", uaid)
136    }
137
138    fn message_list_key(&self, uaid: &Uaid) -> String {
139        format!("autopush/msgs/{}", uaid)
140    }
141
142    fn message_exp_list_key(&self, uaid: &Uaid) -> String {
143        format!("autopush/msgs_exp/{}", uaid)
144    }
145
146    fn message_key(&self, uaid: &Uaid, chidmessageid: &str) -> String {
147        format!("autopush/msg/{}/{}", uaid, chidmessageid)
148    }
149
150    #[cfg(feature = "reliable_report")]
151    fn reliability_key(
152        &self,
153        reliability_id: &str,
154        state: &crate::reliability::ReliabilityState,
155    ) -> String {
156        format!("autopush/reliability/{}/{}", reliability_id, state)
157    }
158
159    #[cfg(test)]
160    /// Return a single "raw" message (used by testing and validation)
161    async fn fetch_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<Option<String>> {
162        let message_key = self.message_key(&Uaid(uaid), chidmessageid);
163        let mut con = self.connection().await?;
164        debug!("🐰 Fetching message from {}", &message_key);
165        let message = con.get::<String, Option<String>>(message_key).await?;
166        Ok(message)
167    }
168}
169
#[async_trait]
impl DbClient for RedisClientImpl {
    /// Add the user to the database.
    ///
    /// Serializes the [User] record and refreshes the last-connection (`co`)
    /// marker in one pipeline; both keys carry the router TTL, so they expire
    /// together.
    async fn add_user(&self, user: &User) -> DbResult<()> {
        let uaid = Uaid(&user.uaid);
        let user_key = self.user_key(&uaid);
        let mut con = self.connection().await?;
        let co_key = self.last_co_key(&uaid);
        trace!("🐰 Adding user {} at {}:{}", &user.uaid, &user_key, &co_key);
        trace!("🐰 Logged at {}", &user.connected_at);
        pipe()
            .set_options(co_key, ms_since_epoch(), self.router_opts)
            .set_options(user_key, serde_json::to_string(user)?, self.router_opts)
            .exec_async(&mut con)
            .await?;
        Ok(())
    }

    /// To update the TTL of the Redis entry we just have to SET again, with the new expiry
    ///
    /// NOTE: This function is called by mobile during the daily
    /// [autoendpoint::routes::update_token_route] handling, and by desktop
    /// [autoconnect-ws-sm::get_or_create_user]` which is called
    /// during the `HELLO` handler. This should be enough to ensure that the ROUTER records
    /// are properly refreshed for "lively" clients.
    ///
    /// NOTE: There is some, very small, potential risk that a desktop client that can
    /// somehow remain connected the duration of MAX_ROUTER_TTL, may be dropped as not being
    /// "lively".
    ///
    /// Returns `true` only when the stored record was actually rewritten.
    async fn update_user(&self, user: &mut User) -> DbResult<bool> {
        trace!("🐰 Updating user");
        let mut con = self.connection().await?;
        let co_key = self.last_co_key(&Uaid(&user.uaid));
        let last_co: Option<u64> = con.get(&co_key).await?;
        // NOTE(review): if the `co` key is missing (expired or never written),
        // this takes the `else` branch and the user is NOT (re)written —
        // confirm callers handle the `false` return in that case.
        if last_co.is_some_and(|c| c < user.connected_at) {
            trace!(
                "🐰 Was connected at {}, now at {}",
                last_co.unwrap(),
                &user.connected_at
            );
            self.add_user(user).await?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Fetch and deserialize the stored [User] record, if any.
    async fn get_user(&self, uaid: &Uuid) -> DbResult<Option<User>> {
        let mut con = self.connection().await?;
        let user_key = self.user_key(&Uaid(uaid));
        // A record that fails to deserialize is treated the same as a missing
        // record (the parse error is discarded by `.ok()`).
        let user: Option<User> = con
            .get::<&str, Option<String>>(&user_key)
            .await?
            .and_then(|s| serde_json::from_str(s.as_ref()).ok());
        if user.is_some() {
            trace!("🐰 Found a record for {}", &uaid);
        }
        Ok(user)
    }

    /// Delete every key associated with this UAID (user record, connection
    /// marker, channel list, message id lists, and storage timestamp).
    ///
    /// Note: the individual `autopush/msg/...` entries are not enumerated
    /// here; they are left to lapse via their own TTLs.
    async fn remove_user(&self, uaid: &Uuid) -> DbResult<()> {
        let uaid = Uaid(uaid);
        let mut con = self.connection().await?;
        let user_key = self.user_key(&uaid);
        let co_key = self.last_co_key(&uaid);
        let chan_list_key = self.channel_list_key(&uaid);
        let msg_list_key = self.message_list_key(&uaid);
        let exp_list_key = self.message_exp_list_key(&uaid);
        let timestamp_key = self.storage_timestamp_key(&uaid);
        pipe()
            .del(&user_key)
            .del(&co_key)
            .del(&chan_list_key)
            .del(&msg_list_key)
            .del(&exp_list_key)
            .del(&timestamp_key)
            .exec_async(&mut con)
            .await?;
        Ok(())
    }

    /// Append a channel id to the user's channel list and refresh the
    /// last-connection marker.
    async fn add_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<()> {
        let uaid = Uaid(uaid);
        let mut con = self.connection().await?;
        let co_key = self.last_co_key(&uaid);
        let chan_list_key = self.channel_list_key(&uaid);

        let _: () = pipe()
            .rpush(chan_list_key, channel_id.as_hyphenated().to_string())
            .set_options(co_key, ms_since_epoch(), self.router_opts)
            .exec_async(&mut con)
            .await?;
        Ok(())
    }

    /// Add channels in bulk (used mostly during migration)
    async fn add_channels(&self, uaid: &Uuid, channels: HashSet<Uuid>) -> DbResult<()> {
        let uaid = Uaid(uaid);
        // channel_ids are stored a list within a single redis key
        let mut con = self.connection().await?;
        let co_key = self.last_co_key(&uaid);
        let chan_list_key = self.channel_list_key(&uaid);
        // NOTE(review): an empty `channels` set would issue RPUSH with zero
        // values, which Redis rejects — confirm callers never pass an empty set.
        pipe()
            .set_options(co_key, ms_since_epoch(), self.router_opts)
            .rpush(
                chan_list_key,
                channels
                    .into_iter()
                    .map(|c| c.as_hyphenated().to_string())
                    .collect::<Vec<String>>(),
            )
            .exec_async(&mut con)
            .await?;
        Ok(())
    }

    /// Return the set of channel ids registered for this UAID, silently
    /// skipping any list entries that fail to parse as UUIDs.
    async fn get_channels(&self, uaid: &Uuid) -> DbResult<HashSet<Uuid>> {
        let uaid = Uaid(uaid);
        let mut con = self.connection().await?;
        let chan_list_key = self.channel_list_key(&uaid);
        let channels: HashSet<Uuid> = con
            .lrange::<&str, HashSet<String>>(&chan_list_key, 0, -1)
            .await?
            .into_iter()
            .filter_map(|s| Uuid::from_str(&s).ok())
            .collect();
        trace!("🐰 Found {} channels for {}", channels.len(), &uaid);
        Ok(channels)
    }

    /// Delete the channel. Does not delete its associated pending messages.
    ///
    /// Returns `true` when an entry was actually removed from the list.
    async fn remove_channel(&self, uaid: &Uuid, channel_id: &Uuid) -> DbResult<bool> {
        let uaid = Uaid(uaid);
        let channel_id = ChannelId(channel_id);
        let mut con = self.connection().await?;
        let co_key = self.last_co_key(&uaid);
        let chan_list_key = self.channel_list_key(&uaid);
        // Remove {channel_id} from autopush/channel/{auid}
        trace!("🐰 Removing channel {}", channel_id);
        // LREM returns the count of removed entries; the single-element tuple
        // coerces it to bool (0 -> false, >0 -> true). The co-key refresh is
        // `.ignore()`d so only the LREM reply is returned.
        let (status,): (bool,) = pipe()
            .set_options(co_key, ms_since_epoch(), self.router_opts)
            .ignore()
            .lrem(&chan_list_key, 1, channel_id.to_string())
            .query_async(&mut con)
            .await?;
        Ok(status)
    }

    /// Remove the node_id
    ///
    /// The extra arguments exist to satisfy the [DbClient] trait; this
    /// implementation ignores them and clears `node_id` unconditionally.
    async fn remove_node_id(
        &self,
        uaid: &Uuid,
        _node_id: &str,
        _connected_at: u64,
        _version: &Option<Uuid>,
    ) -> DbResult<bool> {
        if let Some(mut user) = self.get_user(uaid).await? {
            user.node_id = None;
            // NOTE(review): `update_user` only rewrites when the stored `co`
            // timestamp is older than `user.connected_at`; since `co` is
            // refreshed at write time this call may be a no-op — verify that
            // the node_id actually gets cleared.
            self.update_user(&mut user).await?;
        }
        // Always reports success, even when no user record was found.
        Ok(true)
    }

    /// Write the notification to storage.
    ///
    /// If the message contains a topic, we remove the old message
    async fn save_message(&self, uaid: &Uuid, message: Notification) -> DbResult<()> {
        let uaid = Uaid(uaid);
        let mut con = self.connection().await?;
        let msg_list_key = self.message_list_key(&uaid);
        let exp_list_key = self.message_exp_list_key(&uaid);
        let msg_id = &message.chidmessageid();
        let msg_key = self.message_key(&uaid, msg_id);
        let storable: StorableNotification = message.into();

        debug!("🐰 Saving message {} :: {:?}", &msg_key, &storable);
        trace!(
            "🐰 timestamp: {:?}",
            &storable.timestamp.to_be_bytes().to_vec()
        );

        // Remember, `timestamp` is effectively the time to kill the message, not the
        // current time.
        let expiry = now_secs() + storable.ttl;
        trace!("🐰 Message Expiry {}, currently:{} ", expiry, now_secs());

        let mut pipe = pipe();

        // If this is a topic message:
        // zadd(msg_list_key) and zadd(exp_list_key) will replace their old entry
        // in the sorted set if one already exists
        // and set(msg_key, message) will override it too: nothing to do.
        let is_topic = storable.topic.is_some();

        // notification_ttl is already min(headers.ttl, MAX_NOTIFICATION_TTL)
        // see autoendpoint/src/extractors/notification_headers.rs
        // The default relative (EX) expiry on `notification_opts` is replaced
        // here by an absolute (EXAT) expiry derived from the message's own TTL.
        let notif_opts = self
            .notification_opts
            .with_expiration(SetExpiry::EXAT(expiry));

        // Store notification record in autopush/msg/{uaid}/{chidmessageid}
        // And store {chidmessageid} in autopush/msgs/{uaid}
        debug!("🐰 Saving to {}", &msg_key);
        pipe.set_options(msg_key, serde_json::to_string(&storable)?, notif_opts)
            // The function [fetch_timestamp_messages] takes a timestamp in input,
            // here we use the timestamp of the record
            .zadd(&exp_list_key, msg_id, expiry)
            .zadd(&msg_list_key, msg_id, sec_since_epoch());

        let _: () = pipe.exec_async(&mut con).await?;
        self.metrics
            .incr_with_tags("notification.message.stored")
            .with_tag("topic", &is_topic.to_string())
            .with_tag("database", &self.name())
            .send();
        Ok(())
    }

    /// Save a batch of messages to the database.
    ///
    /// Currently just iterating through the list and saving one at a time. There's a bulk way
    /// to save messages, but there are other considerations (e.g. mutation limits)
    async fn save_messages(&self, uaid: &Uuid, messages: Vec<Notification>) -> DbResult<()> {
        // A plain, simple way of solving this:
        for message in messages {
            self.save_message(uaid, message).await?;
        }
        Ok(())
    }

    /// Delete expired messages and advance the stored storage timestamp.
    async fn increment_storage(&self, uaid: &Uuid, timestamp: u64) -> DbResult<()> {
        let uaid = Uaid(uaid);
        debug!("🐰🔥 Incrementing storage to {}", timestamp);
        let msg_list_key = self.message_list_key(&uaid);
        let exp_list_key = self.message_exp_list_key(&uaid);
        let storage_timestamp_key = self.storage_timestamp_key(&uaid);
        let mut con = self.connection().await?;
        trace!("🐇 SEARCH: increment: {:?} - {}", &exp_list_key, timestamp);
        // Everything scored at or below `timestamp` in the expiry set is dead.
        let exp_id_list: Vec<String> = con.zrangebyscore(&exp_list_key, 0, timestamp).await?;
        if !exp_id_list.is_empty() {
            // Remember, we store just the message_ids in the exp and msg lists, but need to convert back to
            // the full message keys for deletion.
            let delete_msg_keys: Vec<String> = exp_id_list
                .clone()
                .into_iter()
                .map(|msg_id| self.message_key(&uaid, &msg_id))
                .collect();

            trace!(
                "🐰🔥:rem: Deleting {} : [{:?}]",
                msg_list_key,
                &delete_msg_keys
            );
            trace!("🐰🔥:rem: Deleting {} : [{:?}]", exp_list_key, &exp_id_list);
            pipe()
                .set_options::<_, _>(&storage_timestamp_key, timestamp, self.router_opts)
                .del(&delete_msg_keys)
                .zrem(&msg_list_key, &exp_id_list)
                .zrem(&exp_list_key, &exp_id_list)
                .exec_async(&mut con)
                .await?;
        } else {
            // Nothing expired: still record the new storage timestamp.
            con.set_options::<_, _, ()>(&storage_timestamp_key, timestamp, self.router_opts)
                .await?;
        }
        Ok(())
    }

    /// Delete the notification from storage.
    async fn remove_message(&self, uaid: &Uuid, chidmessageid: &str) -> DbResult<()> {
        let uaid = Uaid(uaid);
        trace!(
            "🐰 attemping to delete {:?} :: {:?}",
            uaid.to_string(),
            chidmessageid
        );
        let msg_key = self.message_key(&uaid, chidmessageid);
        let msg_list_key = self.message_list_key(&uaid);
        let exp_list_key = self.message_exp_list_key(&uaid);
        debug!("🐰🔥 Deleting message {}", &msg_key);
        let mut con = self.connection().await?;
        // We remove the id from the exp list at the end, to be sure
        // it can't be removed from the list before the message is removed
        trace!(
            "🐰🔥:remsg: Deleting {} : {:?}",
            msg_list_key,
            &chidmessageid
        );
        trace!(
            "🐰🔥:remsg: Deleting {} : {:?}",
            exp_list_key,
            &chidmessageid
        );
        pipe()
            .del(&msg_key)
            .zrem(&msg_list_key, chidmessageid)
            .zrem(&exp_list_key, chidmessageid)
            .exec_async(&mut con)
            .await?;
        self.metrics
            .incr_with_tags("notification.message.deleted")
            .with_tag("database", &self.name())
            .send();
        Ok(())
    }

    /// Topic messages are handled as other messages with redis, we return nothing.
    async fn fetch_topic_messages(
        &self,
        _uaid: &Uuid,
        _limit: usize,
    ) -> DbResult<FetchMessageResponse> {
        Ok(FetchMessageResponse {
            messages: vec![],
            timestamp: None,
        })
    }

    /// Return [`limit`] messages pending for a [`uaid`] that have a record timestamp
    /// after [`timestamp`] (secs).
    ///
    /// If [`limit`] = 0, we fetch all messages after [`timestamp`].
    ///
    /// This can return expired messages, following bigtables behavior
    async fn fetch_timestamp_messages(
        &self,
        uaid: &Uuid,
        timestamp: Option<u64>,
        limit: usize,
    ) -> DbResult<FetchMessageResponse> {
        let uaid = Uaid(uaid);
        trace!("🐰 Fetching {} messages since {:?}", limit, timestamp);
        let mut con = self.connection().await?;
        let msg_list_key = self.message_list_key(&uaid);
        // Fall back to the server-side storage timestamp when the caller
        // doesn't supply one.
        // NOTE(review): `unwrap_or(0)` also swallows connection errors here,
        // treating them the same as a missing key — confirm that's intended.
        let timestamp = if let Some(timestamp) = timestamp {
            timestamp
        } else {
            let storage_timestamp_key = self.storage_timestamp_key(&uaid);
            con.get(&storage_timestamp_key).await.unwrap_or(0)
        };
        // ZRANGE Key (x) +inf LIMIT 0 limit
        // NOTE(review): the lower bound is inclusive, so a message scored
        // exactly at `timestamp` is returned again — verify this matches the
        // "after [`timestamp`]" contract documented above.
        trace!(
            "🐇 SEARCH: zrangebyscore {:?} {} +inf withscores limit 0 {:?}",
            &msg_list_key,
            timestamp,
            limit,
        );
        let results = con
            .zrangebyscore_limit_withscores::<&str, &str, &str, Vec<(String, u64)>>(
                &msg_list_key,
                &timestamp.to_string(),
                "+inf",
                0,
                limit as isize,
            )
            .await?;
        // Expand ids into full message keys, keeping the scores so the last
        // one can be returned as the next fetch bookmark.
        let (messages_id, mut scores): (Vec<String>, Vec<u64>) = results
            .into_iter()
            .map(|(id, s): (String, u64)| (self.message_key(&uaid, &id), s))
            .unzip();
        if messages_id.is_empty() {
            trace!("🐰 No message found");
            return Ok(FetchMessageResponse {
                messages: vec![],
                timestamp: None,
            });
        }
        let messages: Vec<Notification> = con
            .mget::<&Vec<String>, Vec<Option<String>>>(&messages_id)
            .await?
            .into_iter()
            .filter_map(|opt: Option<String>| {
                if let Some(m) = opt {
                    serde_json::from_str(&m)
                        .inspect_err(|e| {
                            // Since we can't raise the error here, at least record it
                            // so that it's not lost.
                            // Mind you, if there is an error here, it's probably due to
                            // some developmental issue since the unit and integration tests
                            // should fail.
                            error!("🐰 ERROR parsing entry: {:?}", e);
                        })
                        .ok()
                } else {
                    None
                }
            })
            .collect();
        if messages.is_empty() {
            trace!("🐰 No Valid messages found");
            return Ok(FetchMessageResponse {
                timestamp: None,
                messages: vec![],
            });
        }
        // Highest returned score == bookmark for the caller's next fetch.
        let timestamp = scores.pop();
        trace!("🐰 Found {} messages until {:?}", messages.len(), timestamp);
        Ok(FetchMessageResponse {
            messages,
            timestamp,
        })
    }

    #[cfg(feature = "reliable_report")]
    /// Record a reliability-state transition timestamp for a message.
    async fn log_report(
        &self,
        reliability_id: &str,
        state: crate::reliability::ReliabilityState,
    ) -> DbResult<()> {
        use crate::MAX_NOTIFICATION_TTL_SECS;

        trace!("🐰 Logging reliability report");
        let mut con = self.connection().await?;
        // TODO: Should this be a hash key per reliability_id?
        let reliability_key = self.reliability_key(reliability_id, &state);
        // Reports should last about as long as the notifications they're tied to.
        let expiry = MAX_NOTIFICATION_TTL_SECS;
        let opts = SetOptions::default().with_expiration(SetExpiry::EX(expiry));
        let mut pipe = pipe();
        pipe.set_options(reliability_key, sec_since_epoch(), opts)
            .exec_async(&mut con)
            .await?;
        Ok(())
    }

    /// Round-trip a PING through a pooled connection.
    async fn health_check(&self) -> DbResult<bool> {
        let _: () = self.connection().await?.ping().await?;
        Ok(true)
    }

    /// Returns true, because there's no table in Redis
    async fn router_table_exists(&self) -> DbResult<bool> {
        Ok(true)
    }

    /// Returns true, because there's no table in Redis
    async fn message_table_exists(&self) -> DbResult<bool> {
        Ok(true)
    }

    fn box_clone(&self) -> Box<dyn DbClient> {
        Box::new(self.clone())
    }

    fn name(&self) -> String {
        "Redis".to_owned()
    }

    // NOTE(review): deadpool exposes `Pool::status()`; returning `None` here
    // hides pool metrics from callers — confirm that's intentional.
    fn pool_status(&self) -> Option<deadpool::Status> {
        None
    }
}
623
624#[cfg(test)]
625mod tests {
626    use crate::{logging::init_test_logging, util::ms_since_epoch};
627    use rand::prelude::*;
628    use std::env;
629
630    use super::*;
631    const TEST_CHID: &str = "DECAFBAD-0000-0000-0000-0123456789AB";
632    const TOPIC_CHID: &str = "DECAFBAD-1111-0000-0000-0123456789AB";
633
634    fn new_client() -> DbResult<RedisClientImpl> {
635        // Use an environment variable to potentially override the default redis test host.
636        let host = env::var("REDIS_HOST").unwrap_or("localhost".into());
637        let env_dsn = format!("redis://{host}");
638        debug!("🐰 Connecting to {env_dsn}");
639        let settings = DbSettings {
640            dsn: Some(env_dsn),
641            db_settings: "".into(),
642        };
643        let metrics = Arc::new(StatsdClient::builder("", cadence::NopMetricSink).build());
644        RedisClientImpl::new(metrics, &settings)
645    }
646
647    fn gen_test_user() -> String {
648        // Create a semi-unique test user to avoid conflicting test values.
649        let mut rng = rand::rng();
650        let test_num = rng.random::<u8>();
651        format!("DEADBEEF-0000-0000-{:04}-{:012}", test_num, now_secs())
652    }
653
654    #[actix_rt::test]
655    async fn health_check() {
656        let client = new_client().unwrap();
657
658        let result = client.health_check().await;
659        assert!(result.is_ok());
660        assert!(result.unwrap());
661    }
662
    /// Test if [increment_storage] correctly wipe expired messages
    #[actix_rt::test]
    async fn wipe_expired() -> DbResult<()> {
        init_test_logging();
        let client = new_client()?;

        let connected_at = ms_since_epoch();

        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
        let chid = Uuid::parse_str(TEST_CHID).unwrap();

        let node_id = "test_node".to_owned();

        // purge the user record if it exists.
        let _ = client.remove_user(&uaid).await;

        let test_user = User {
            uaid,
            router_type: "webpush".to_owned(),
            connected_at,
            router_data: None,
            node_id: Some(node_id.clone()),
            ..Default::default()
        };

        // purge the old user (if present)
        // in case a prior test failed for whatever reason.
        // NOTE(review): duplicates the purge a few lines above — one of the
        // two could be dropped.
        let _ = client.remove_user(&uaid).await;

        // can we add the user?
        let timestamp = now_secs();
        client.add_user(&test_user).await?;
        // ttl of 1 second: incrementing storage to `timestamp + 1` below must
        // then treat this message as expired.
        let test_notification = crate::db::Notification {
            channel_id: chid,
            version: "test".to_owned(),
            ttl: 1,
            timestamp,
            data: Some("Encrypted".into()),
            sortkey_timestamp: Some(timestamp),
            ..Default::default()
        };
        client.save_message(&uaid, test_notification).await?;
        client.increment_storage(&uaid, timestamp + 1).await?;
        // The expired message must no longer be returned.
        let msgs = client.fetch_timestamp_messages(&uaid, None, 999).await?;
        assert_eq!(msgs.messages.len(), 0);
        assert!(client.remove_user(&uaid).await.is_ok());
        Ok(())
    }
711
712    /// run a gauntlet of testing. These are a bit linear because they need
713    /// to run in sequence.
714    #[actix_rt::test]
715    async fn run_gauntlet() -> DbResult<()> {
716        init_test_logging();
717        let client = new_client()?;
718
719        let connected_at = ms_since_epoch();
720
721        let user_id = &gen_test_user();
722        let uaid = Uuid::parse_str(user_id).unwrap();
723        let chid = Uuid::parse_str(TEST_CHID).unwrap();
724        let topic_chid = Uuid::parse_str(TOPIC_CHID).unwrap();
725
726        let node_id = "test_node".to_owned();
727
728        // purge the user record if it exists.
729        let _ = client.remove_user(&uaid).await;
730
731        let test_user = User {
732            uaid,
733            router_type: "webpush".to_owned(),
734            connected_at,
735            router_data: None,
736            node_id: Some(node_id.clone()),
737            ..Default::default()
738        };
739
740        // purge the old user (if present)
741        // in case a prior test failed for whatever reason.
742        let _ = client.remove_user(&uaid).await;
743
744        // can we add the user?
745        client.add_user(&test_user).await?;
746        let fetched = client.get_user(&uaid).await?;
747        assert!(fetched.is_some());
748        let fetched = fetched.unwrap();
749        assert_eq!(fetched.router_type, "webpush".to_owned());
750
751        // Simulate a connected_at occuring before the following writes
752        let connected_at = ms_since_epoch();
753
754        // can we add channels?
755        client.add_channel(&uaid, &chid).await?;
756        let channels = client.get_channels(&uaid).await?;
757        assert!(channels.contains(&chid));
758
759        // can we add lots of channels?
760        let mut new_channels: HashSet<Uuid> = HashSet::new();
761        new_channels.insert(chid);
762        for _ in 1..10 {
763            new_channels.insert(uuid::Uuid::new_v4());
764        }
765        let chid_to_remove = uuid::Uuid::new_v4();
766        new_channels.insert(chid_to_remove);
767        client.add_channels(&uaid, new_channels.clone()).await?;
768        let channels = client.get_channels(&uaid).await?;
769        assert_eq!(channels, new_channels);
770
771        // can we remove a channel?
772        assert!(client.remove_channel(&uaid, &chid_to_remove).await?);
773        assert!(!client.remove_channel(&uaid, &chid_to_remove).await?);
774        new_channels.remove(&chid_to_remove);
775        let channels = client.get_channels(&uaid).await?;
776        assert_eq!(channels, new_channels);
777
778        // now ensure that we can update a user that's after the time we set
779        // prior. first ensure that we can't update a user that's before the
780        // time we set prior to the last write
781        let mut updated = User {
782            connected_at,
783            ..test_user.clone()
784        };
785        let result = client.update_user(&mut updated).await;
786        assert!(result.is_ok());
787        assert!(!result.unwrap());
788
789        // Make sure that the `connected_at` wasn't modified
790        let fetched2 = client.get_user(&fetched.uaid).await?.unwrap();
791        assert_eq!(fetched.connected_at, fetched2.connected_at);
792
793        // and make sure we can update a record with a later connected_at time.
794        let mut updated = User {
795            connected_at: fetched.connected_at + 300,
796            ..fetched2
797        };
798        let result = client.update_user(&mut updated).await;
799        assert!(result.is_ok());
800        assert!(result.unwrap());
801        assert_ne!(
802            fetched2.connected_at,
803            client.get_user(&uaid).await?.unwrap().connected_at
804        );
805
806        // can we increment the storage for the user?
807        client
808            .increment_storage(
809                &fetched.uaid,
810                SystemTime::now()
811                    .duration_since(SystemTime::UNIX_EPOCH)
812                    .unwrap()
813                    .as_secs(),
814            )
815            .await?;
816
817        let test_data = "An_encrypted_pile_of_crap".to_owned();
818        let timestamp = now_secs();
819        let sort_key = now_secs();
820        let fetch_timestamp = timestamp;
821        // Can we store a message?
822        let test_notification = crate::db::Notification {
823            channel_id: chid,
824            version: "test".to_owned(),
825            ttl: 300,
826            timestamp,
827            data: Some(test_data.clone()),
828            sortkey_timestamp: Some(sort_key),
829            ..Default::default()
830        };
831        let res = client.save_message(&uaid, test_notification.clone()).await;
832        assert!(res.is_ok());
833
834        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
835        assert_ne!(fetched.messages.len(), 0);
836        let fm = fetched.messages.pop().unwrap();
837        assert_eq!(fm.channel_id, test_notification.channel_id);
838        assert_eq!(fm.data, Some(test_data));
839
840        // Grab all 1 of the messages that were submitted within the past 10 seconds.
841        let fetched = client
842            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp - 10), 999)
843            .await?;
844        assert_ne!(fetched.messages.len(), 0);
845
846        // Try grabbing a message for 10 seconds from now.
847        let fetched = client
848            .fetch_timestamp_messages(&uaid, Some(fetch_timestamp + 10), 999)
849            .await?;
850        assert_eq!(fetched.messages.len(), 0);
851
852        // can we clean up our toys?
853        assert!(client
854            .remove_message(&uaid, &test_notification.chidmessageid())
855            .await
856            .is_ok());
857
858        assert!(client.remove_channel(&uaid, &chid).await.is_ok());
859
860        let msgs = client
861            .fetch_timestamp_messages(&uaid, None, 999)
862            .await?
863            .messages;
864        assert!(msgs.is_empty());
865
        // Now, can we do all of that with topic messages?
        // Unlike bigtable, we don't use [fetch_topic_messages]: it always returns None;
        // topic messages are handled as usual messages.
869        client.add_channel(&uaid, &topic_chid).await?;
870        let test_data = "An_encrypted_pile_of_crap_with_a_topic".to_owned();
871        let timestamp = now_secs();
872        let sort_key = now_secs();
873
874        // We store 2 messages, with a single topic
875        let test_notification_0 = crate::db::Notification {
876            channel_id: topic_chid,
877            version: "version0".to_owned(),
878            ttl: 300,
879            topic: Some("topic".to_owned()),
880            timestamp,
881            data: Some(test_data.clone()),
882            sortkey_timestamp: Some(sort_key),
883            ..Default::default()
884        };
885        assert!(client
886            .save_message(&uaid, test_notification_0.clone())
887            .await
888            .is_ok());
889
890        let test_notification = crate::db::Notification {
891            timestamp: now_secs(),
892            version: "version1".to_owned(),
893            sortkey_timestamp: Some(sort_key + 10),
894            ..test_notification_0
895        };
896
897        assert!(client
898            .save_message(&uaid, test_notification.clone())
899            .await
900            .is_ok());
901
902        let mut fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
903        assert_eq!(fetched.messages.len(), 1);
904        let fm = fetched.messages.pop().unwrap();
905        assert_eq!(fm.channel_id, test_notification.channel_id);
906        assert_eq!(fm.data, Some(test_data));
907
908        // Grab the message that was submitted.
909        let fetched = client.fetch_timestamp_messages(&uaid, None, 999).await?;
910        assert_ne!(fetched.messages.len(), 0);
911
912        // can we clean up our toys?
913        assert!(client
914            .remove_message(&uaid, &test_notification.chidmessageid())
915            .await
916            .is_ok());
917
918        assert!(client.remove_channel(&uaid, &topic_chid).await.is_ok());
919
920        let msgs = client
921            .fetch_timestamp_messages(&uaid, None, 999)
922            .await?
923            .messages;
924        assert!(msgs.is_empty());
925
926        let fetched = client.get_user(&uaid).await?.unwrap();
927        assert!(client
928            .remove_node_id(&uaid, &node_id, connected_at, &fetched.version)
929            .await
930            .is_ok());
931        // did we remove it?
932        let fetched = client.get_user(&uaid).await?.unwrap();
933        assert_eq!(fetched.node_id, None);
934
935        assert!(client.remove_user(&uaid).await.is_ok());
936
937        assert!(client.get_user(&uaid).await?.is_none());
938        Ok(())
939    }
940
941    #[actix_rt::test]
942    async fn test_expiry() -> DbResult<()> {
943        // Make sure that we really are purging messages correctly
944        init_test_logging();
945        let client = new_client()?;
946
947        let uaid = Uuid::parse_str(&gen_test_user()).unwrap();
948        let chid = Uuid::parse_str(TEST_CHID).unwrap();
949        let now = now_secs();
950
951        let test_notification = crate::db::Notification {
952            channel_id: chid,
953            version: "test".to_owned(),
954            ttl: 2,
955            timestamp: now,
956            data: Some("SomeData".into()),
957            sortkey_timestamp: Some(now),
958            ..Default::default()
959        };
960        debug!("Writing test notif");
961        let res = client.save_message(&uaid, test_notification.clone()).await;
962        assert!(res.is_ok());
963        let key = client.message_key(&Uaid(&uaid), &test_notification.chidmessageid());
964        debug!("Checking {}...", &key);
965        let msg = client
966            .fetch_message(&uaid, &test_notification.chidmessageid())
967            .await?;
968        assert!(!msg.unwrap().is_empty());
969        debug!("Purging...");
970        client.increment_storage(&uaid, now + 2).await?;
971        debug!("Checking {}...", &key);
972        let cc = client
973            .fetch_message(&uaid, &test_notification.chidmessageid())
974            .await;
975        assert_eq!(cc.unwrap(), None);
976        // clean up after the test.
977        assert!(client.remove_user(&uaid).await.is_ok());
978        Ok(())
979    }
980}