Skip to content

Commit 178825b

Browse files
committed
Add support for the Duration CQL data type
Example of the Duration CQL data type usage in a rune script: let d1 = "1y3mo2w6d13h14m22s33ms44us55ns"; let d2 = "55mo345d11ns"; let d3 = "1m5s3ms"; let d4 = "5s";
1 parent 48856ef commit 178825b

File tree

3 files changed

+198
-5
lines changed

3 files changed

+198
-5
lines changed

Cargo.lock

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ lazy_static = "1.4.0"
3131
metrohash = "1.0"
3232
more-asserts = "0.3"
3333
num_cpus = "1.13.0"
34+
once_cell = "1.21"
3435
openssl = "0.10.70"
3536
parse_duration = "2.1.1"
3637
pin-project = "1.1"

src/scripting/bind.rs

+194-3
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,37 @@
33
use crate::scripting::cass_error::{CassError, CassErrorKind};
44
use crate::scripting::cql_types::Uuid;
55
use chrono::{NaiveDate, NaiveTime};
6+
use once_cell::sync::Lazy;
7+
use regex::Regex;
68
use rune::{Any, ToValue, Value};
79
use scylla::_macro_internal::ColumnType;
810
use scylla::frame::response::result::{CollectionType, ColumnSpec, NativeType};
911
use scylla::response::query_result::ColumnSpecs;
10-
use scylla::value::{CqlDate, CqlTime, CqlTimeuuid, CqlValue, CqlVarint};
12+
use scylla::value::{CqlDate, CqlDuration, CqlTime, CqlTimeuuid, CqlValue, CqlVarint};
13+
use std::collections::HashMap;
1114
use std::net::IpAddr;
1215
use std::str::FromStr;
1316

1417
use itertools::*;
1518

19+
static DURATION_REGEX: Lazy<Regex> = Lazy::new(|| {
20+
Regex::new(concat!(
21+
r"(?P<years>\d+)y|",
22+
r"(?P<months>\d+)mo|",
23+
r"(?P<weeks>\d+)w|",
24+
r"(?P<days>\d+)d|",
25+
r"(?P<hours>\d+)h|",
26+
r"(?P<seconds>\d+)s|",
27+
r"(?P<millis>\d+)ms|",
28+
r"(?P<micros>\d+)us|",
29+
r"(?P<nanoseconds>\d+)ns|",
30+
r"(?P<minutes>\d+)m|", // must be after 'mo' and 'ms' matchers
31+
r"(?P<invalid>.+)", // must be last, used for all incorrect matches
32+
))
33+
.unwrap()
34+
});
35+
1636
fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result<CqlValue, CassError> {
17-
// TODO: add support for the following CQL data types:
18-
// 'duration'
1937
match (v, typ) {
2038
(Value::Bool(v), ColumnType::Native(NativeType::Boolean)) => Ok(CqlValue::Boolean(*v)),
2139

@@ -105,6 +123,104 @@ fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result<CqlValue, CassError> {
105123
let cql_time = CqlTime::try_from(naive_time)?;
106124
Ok(CqlValue::Time(cql_time))
107125
}
126+
(Value::String(s), ColumnType::Native(NativeType::Duration)) => {
127+
// TODO: add support for the following 'ISO 8601' format variants:
128+
// - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W
129+
// - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss]
130+
// See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations
131+
let duration_str = s.borrow_ref().unwrap();
132+
if duration_str.is_empty() {
133+
return Err(CassError(CassErrorKind::QueryParamConversion(
134+
format!("{:?}", v),
135+
"NativeType::Duration".to_string(),
136+
Some("Duration cannot be empty".to_string()),
137+
)));
138+
}
139+
// NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics.
140+
// It stores only months, days and nanoseconds.
141+
// So, we do not translate days to months and hours to days because those are ambiguous
142+
let (mut months, mut days, mut nanoseconds) = (0, 0, 0);
143+
let mut matches_counter = HashMap::from([
144+
("y", 0),
145+
("mo", 0),
146+
("w", 0),
147+
("d", 0),
148+
("h", 0),
149+
("m", 0),
150+
("s", 0),
151+
("ms", 0),
152+
("us", 0),
153+
("ns", 0),
154+
]);
155+
for cap in DURATION_REGEX.captures_iter(&duration_str) {
156+
if let Some(m) = cap.name("years") {
157+
months += m.as_str().parse::<i32>().unwrap() * 12;
158+
*matches_counter.entry("y").or_insert(1) += 1;
159+
} else if let Some(m) = cap.name("months") {
160+
months += m.as_str().parse::<i32>().unwrap();
161+
*matches_counter.entry("mo").or_insert(1) += 1;
162+
} else if let Some(m) = cap.name("weeks") {
163+
days += m.as_str().parse::<i32>().unwrap() * 7;
164+
*matches_counter.entry("w").or_insert(1) += 1;
165+
} else if let Some(m) = cap.name("days") {
166+
days += m.as_str().parse::<i32>().unwrap();
167+
*matches_counter.entry("d").or_insert(1) += 1;
168+
} else if let Some(m) = cap.name("hours") {
169+
nanoseconds += m.as_str().parse::<i64>().unwrap() * 3_600_000_000_000;
170+
*matches_counter.entry("h").or_insert(1) += 1;
171+
} else if let Some(m) = cap.name("minutes") {
172+
nanoseconds += m.as_str().parse::<i64>().unwrap() * 60_000_000_000;
173+
*matches_counter.entry("m").or_insert(1) += 1;
174+
} else if let Some(m) = cap.name("seconds") {
175+
nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000_000_000;
176+
*matches_counter.entry("s").or_insert(1) += 1;
177+
} else if let Some(m) = cap.name("millis") {
178+
nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000_000;
179+
*matches_counter.entry("ms").or_insert(1) += 1;
180+
} else if let Some(m) = cap.name("micros") {
181+
nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000;
182+
*matches_counter.entry("us").or_insert(1) += 1;
183+
} else if let Some(m) = cap.name("nanoseconds") {
184+
nanoseconds += m.as_str().parse::<i64>().unwrap();
185+
*matches_counter.entry("ns").or_insert(1) += 1;
186+
} else if cap.name("invalid").is_some() {
187+
return Err(CassError(CassErrorKind::QueryParamConversion(
188+
format!("{:?}", v),
189+
"NativeType::Duration".to_string(),
190+
Some("Got invalid duration value".to_string()),
191+
)));
192+
}
193+
}
194+
if matches_counter.values().all(|&v| v == 0) {
195+
return Err(CassError(CassErrorKind::QueryParamConversion(
196+
format!("{:?}", v),
197+
"NativeType::Duration".to_string(),
198+
Some("None time units were found".to_string()),
199+
)));
200+
}
201+
let duplicated_units: Vec<&str> = matches_counter
202+
.iter()
203+
.filter(|&(_, &count)| count > 1)
204+
.map(|(&unit, _)| unit)
205+
.collect();
206+
if !duplicated_units.is_empty() {
207+
return Err(CassError(CassErrorKind::QueryParamConversion(
208+
format!("{:?}", v),
209+
"NativeType::Duration".to_string(),
210+
Some(format!(
211+
"Got multiple matches for time unit(s): {}",
212+
duplicated_units.join(", ")
213+
)),
214+
)));
215+
}
216+
let cql_duration = CqlDuration {
217+
months,
218+
days,
219+
nanoseconds,
220+
};
221+
Ok(CqlValue::Duration(cql_duration))
222+
}
223+
108224
(Value::String(s), ColumnType::Native(NativeType::Varint)) => {
109225
let varint_str = s.borrow_ref().unwrap();
110226
if !varint_str.chars().all(|c| c.is_ascii_digit()) {
@@ -416,3 +532,78 @@ fn read_fields<'a, 'b>(
416532
}
417533
Ok(values)
418534
}
535+
536+
#[cfg(test)]
537+
mod tests {
538+
use super::*;
539+
540+
use rstest::rstest;
541+
use rune::alloc::String as RuneString;
542+
use rune::runtime::Shared;
543+
544+
const NS_MULT: i64 = 1_000_000_000;
545+
546+
#[rstest]
547+
#[case("45ns", 0, 0, 45)]
548+
#[case("32us", 0, 0, 32 * 1_000)]
549+
#[case("22ms", 0, 0, 22 * 1_000_000)]
550+
#[case("15s", 0, 0, 15 * NS_MULT)]
551+
#[case("2m", 0, 0, 2 * 60 * NS_MULT)]
552+
#[case("4h", 0, 0, 4 * 3_600 * NS_MULT)]
553+
#[case("3d", 0, 3, 0)]
554+
#[case("1w", 0, 7, 0)]
555+
#[case("1mo", 1, 0, 0)]
556+
#[case("1y", 12, 0, 0)]
557+
#[case("45m1s", 0, 0, (45 * 60 + 1) * NS_MULT)]
558+
#[case("3d21h13m", 0, 3, (21 * 3_600 + 13 * 60) * NS_MULT)]
559+
#[case("1y3mo2w6d13h14m23s", 15, 20, (13 * 3_600 + 14 * 60 + 23) * NS_MULT)]
560+
fn test_to_scylla_value_duration_pos(
561+
#[case] input: String,
562+
#[case] mo: i32,
563+
#[case] d: i32,
564+
#[case] ns: i64,
565+
) {
566+
let expected = format!("{:?}mo{:?}d{:?}ns", mo, d, ns);
567+
let duration_rune_str = Value::String(
568+
Shared::new(RuneString::try_from(input).expect("Failed to create RuneString"))
569+
.expect("Failed to create Shared RuneString"),
570+
);
571+
let actual = to_scylla_value(
572+
&duration_rune_str,
573+
&ColumnType::Native(NativeType::Duration),
574+
);
575+
assert_eq!(actual.unwrap().to_string(), expected);
576+
}
577+
578+
#[rstest]
579+
#[case("")]
580+
#[case(" ")]
581+
#[case("\n")]
582+
#[case("1")]
583+
#[case("m1")]
584+
#[case("1mm")]
585+
#[case("1mom")]
586+
#[case("fake")]
587+
#[case("1d2h3m4h")]
588+
fn test_to_scylla_value_duration_neg(#[case] input: String) {
589+
let duration_rune_str = Value::String(
590+
Shared::new(RuneString::try_from(input.clone()).expect("Failed to create RuneString"))
591+
.expect("Failed to create Shared RuneString"),
592+
);
593+
let actual = to_scylla_value(
594+
&duration_rune_str,
595+
&ColumnType::Native(NativeType::Duration),
596+
);
597+
assert!(
598+
matches!(
599+
actual,
600+
Err(CassError(CassErrorKind::QueryParamConversion(_, _, _)))
601+
),
602+
"{}",
603+
format!(
604+
"Error was not raised for the {:?} input. Result: {:?}",
605+
input, actual
606+
)
607+
);
608+
}
609+
}

0 commit comments

Comments
 (0)