|
3 | 3 | use crate::scripting::cass_error::{CassError, CassErrorKind};
|
4 | 4 | use crate::scripting::cql_types::Uuid;
|
5 | 5 | use chrono::{NaiveDate, NaiveTime};
|
| 6 | +use once_cell::sync::Lazy; |
| 7 | +use regex::Regex; |
6 | 8 | use rune::{Any, ToValue, Value};
|
7 | 9 | use scylla::_macro_internal::ColumnType;
|
8 | 10 | use scylla::frame::response::result::{CollectionType, ColumnSpec, NativeType};
|
9 | 11 | use scylla::response::query_result::ColumnSpecs;
|
10 |
| -use scylla::value::{CqlDate, CqlTime, CqlTimeuuid, CqlValue, CqlVarint}; |
| 12 | +use scylla::value::{CqlDate, CqlDuration, CqlTime, CqlTimeuuid, CqlValue, CqlVarint}; |
| 13 | +use std::collections::HashMap; |
11 | 14 | use std::net::IpAddr;
|
12 | 15 | use std::str::FromStr;
|
13 | 16 |
|
14 | 17 | use itertools::*;
|
15 | 18 |
|
| 19 | +static DURATION_REGEX: Lazy<Regex> = Lazy::new(|| { |
| 20 | + Regex::new(concat!( |
| 21 | + r"(?P<years>\d+)y|", |
| 22 | + r"(?P<months>\d+)mo|", |
| 23 | + r"(?P<weeks>\d+)w|", |
| 24 | + r"(?P<days>\d+)d|", |
| 25 | + r"(?P<hours>\d+)h|", |
| 26 | + r"(?P<seconds>\d+)s|", |
| 27 | + r"(?P<millis>\d+)ms|", |
| 28 | + r"(?P<micros>\d+)us|", |
| 29 | + r"(?P<nanoseconds>\d+)ns|", |
| 30 | + r"(?P<minutes>\d+)m|", // must be after 'mo' and 'ms' matchers |
| 31 | + r"(?P<invalid>.+)", // must be last, used for all incorrect matches |
| 32 | + )) |
| 33 | + .unwrap() |
| 34 | +}); |
| 35 | + |
16 | 36 | fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result<CqlValue, CassError> {
|
17 |
| - // TODO: add support for the following CQL data types: |
18 |
| - // 'duration' |
19 | 37 | match (v, typ) {
|
20 | 38 | (Value::Bool(v), ColumnType::Native(NativeType::Boolean)) => Ok(CqlValue::Boolean(*v)),
|
21 | 39 |
|
@@ -105,6 +123,104 @@ fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result<CqlValue, CassError> {
|
105 | 123 | let cql_time = CqlTime::try_from(naive_time)?;
|
106 | 124 | Ok(CqlValue::Time(cql_time))
|
107 | 125 | }
|
| 126 | + (Value::String(s), ColumnType::Native(NativeType::Duration)) => { |
| 127 | + // TODO: add support for the following 'ISO 8601' format variants: |
| 128 | + // - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W |
| 129 | + // - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss] |
| 130 | + // See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations |
| 131 | + let duration_str = s.borrow_ref().unwrap(); |
| 132 | + if duration_str.is_empty() { |
| 133 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 134 | + format!("{:?}", v), |
| 135 | + "NativeType::Duration".to_string(), |
| 136 | + Some("Duration cannot be empty".to_string()), |
| 137 | + ))); |
| 138 | + } |
| 139 | + // NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics. |
| 140 | + // It stores only months, days and nanoseconds. |
| 141 | + // So, we do not translate days to months and hours to days because those are ambiguous |
| 142 | + let (mut months, mut days, mut nanoseconds) = (0, 0, 0); |
| 143 | + let mut matches_counter = HashMap::from([ |
| 144 | + ("y", 0), |
| 145 | + ("mo", 0), |
| 146 | + ("w", 0), |
| 147 | + ("d", 0), |
| 148 | + ("h", 0), |
| 149 | + ("m", 0), |
| 150 | + ("s", 0), |
| 151 | + ("ms", 0), |
| 152 | + ("us", 0), |
| 153 | + ("ns", 0), |
| 154 | + ]); |
| 155 | + for cap in DURATION_REGEX.captures_iter(&duration_str) { |
| 156 | + if let Some(m) = cap.name("years") { |
| 157 | + months += m.as_str().parse::<i32>().unwrap() * 12; |
| 158 | + *matches_counter.entry("y").or_insert(1) += 1; |
| 159 | + } else if let Some(m) = cap.name("months") { |
| 160 | + months += m.as_str().parse::<i32>().unwrap(); |
| 161 | + *matches_counter.entry("mo").or_insert(1) += 1; |
| 162 | + } else if let Some(m) = cap.name("weeks") { |
| 163 | + days += m.as_str().parse::<i32>().unwrap() * 7; |
| 164 | + *matches_counter.entry("w").or_insert(1) += 1; |
| 165 | + } else if let Some(m) = cap.name("days") { |
| 166 | + days += m.as_str().parse::<i32>().unwrap(); |
| 167 | + *matches_counter.entry("d").or_insert(1) += 1; |
| 168 | + } else if let Some(m) = cap.name("hours") { |
| 169 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 3_600_000_000_000; |
| 170 | + *matches_counter.entry("h").or_insert(1) += 1; |
| 171 | + } else if let Some(m) = cap.name("minutes") { |
| 172 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 60_000_000_000; |
| 173 | + *matches_counter.entry("m").or_insert(1) += 1; |
| 174 | + } else if let Some(m) = cap.name("seconds") { |
| 175 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000_000_000; |
| 176 | + *matches_counter.entry("s").or_insert(1) += 1; |
| 177 | + } else if let Some(m) = cap.name("millis") { |
| 178 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000_000; |
| 179 | + *matches_counter.entry("ms").or_insert(1) += 1; |
| 180 | + } else if let Some(m) = cap.name("micros") { |
| 181 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000; |
| 182 | + *matches_counter.entry("us").or_insert(1) += 1; |
| 183 | + } else if let Some(m) = cap.name("nanoseconds") { |
| 184 | + nanoseconds += m.as_str().parse::<i64>().unwrap(); |
| 185 | + *matches_counter.entry("ns").or_insert(1) += 1; |
| 186 | + } else if cap.name("invalid").is_some() { |
| 187 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 188 | + format!("{:?}", v), |
| 189 | + "NativeType::Duration".to_string(), |
| 190 | + Some("Got invalid duration value".to_string()), |
| 191 | + ))); |
| 192 | + } |
| 193 | + } |
| 194 | + if matches_counter.values().all(|&v| v == 0) { |
| 195 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 196 | + format!("{:?}", v), |
| 197 | + "NativeType::Duration".to_string(), |
| 198 | + Some("None time units were found".to_string()), |
| 199 | + ))); |
| 200 | + } |
| 201 | + let duplicated_units: Vec<&str> = matches_counter |
| 202 | + .iter() |
| 203 | + .filter(|&(_, &count)| count > 1) |
| 204 | + .map(|(&unit, _)| unit) |
| 205 | + .collect(); |
| 206 | + if !duplicated_units.is_empty() { |
| 207 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 208 | + format!("{:?}", v), |
| 209 | + "NativeType::Duration".to_string(), |
| 210 | + Some(format!( |
| 211 | + "Got multiple matches for time unit(s): {}", |
| 212 | + duplicated_units.join(", ") |
| 213 | + )), |
| 214 | + ))); |
| 215 | + } |
| 216 | + let cql_duration = CqlDuration { |
| 217 | + months, |
| 218 | + days, |
| 219 | + nanoseconds, |
| 220 | + }; |
| 221 | + Ok(CqlValue::Duration(cql_duration)) |
| 222 | + } |
| 223 | + |
108 | 224 | (Value::String(s), ColumnType::Native(NativeType::Varint)) => {
|
109 | 225 | let varint_str = s.borrow_ref().unwrap();
|
110 | 226 | if !varint_str.chars().all(|c| c.is_ascii_digit()) {
|
@@ -416,3 +532,78 @@ fn read_fields<'a, 'b>(
|
416 | 532 | }
|
417 | 533 | Ok(values)
|
418 | 534 | }
|
| 535 | + |
| 536 | +#[cfg(test)] |
| 537 | +mod tests { |
| 538 | + use super::*; |
| 539 | + |
| 540 | + use rstest::rstest; |
| 541 | + use rune::alloc::String as RuneString; |
| 542 | + use rune::runtime::Shared; |
| 543 | + |
| 544 | + const NS_MULT: i64 = 1_000_000_000; |
| 545 | + |
| 546 | + #[rstest] |
| 547 | + #[case("45ns", 0, 0, 45)] |
| 548 | + #[case("32us", 0, 0, 32 * 1_000)] |
| 549 | + #[case("22ms", 0, 0, 22 * 1_000_000)] |
| 550 | + #[case("15s", 0, 0, 15 * NS_MULT)] |
| 551 | + #[case("2m", 0, 0, 2 * 60 * NS_MULT)] |
| 552 | + #[case("4h", 0, 0, 4 * 3_600 * NS_MULT)] |
| 553 | + #[case("3d", 0, 3, 0)] |
| 554 | + #[case("1w", 0, 7, 0)] |
| 555 | + #[case("1mo", 1, 0, 0)] |
| 556 | + #[case("1y", 12, 0, 0)] |
| 557 | + #[case("45m1s", 0, 0, (45 * 60 + 1) * NS_MULT)] |
| 558 | + #[case("3d21h13m", 0, 3, (21 * 3_600 + 13 * 60) * NS_MULT)] |
| 559 | + #[case("1y3mo2w6d13h14m23s", 15, 20, (13 * 3_600 + 14 * 60 + 23) * NS_MULT)] |
| 560 | + fn test_to_scylla_value_duration_pos( |
| 561 | + #[case] input: String, |
| 562 | + #[case] mo: i32, |
| 563 | + #[case] d: i32, |
| 564 | + #[case] ns: i64, |
| 565 | + ) { |
| 566 | + let expected = format!("{:?}mo{:?}d{:?}ns", mo, d, ns); |
| 567 | + let duration_rune_str = Value::String( |
| 568 | + Shared::new(RuneString::try_from(input).expect("Failed to create RuneString")) |
| 569 | + .expect("Failed to create Shared RuneString"), |
| 570 | + ); |
| 571 | + let actual = to_scylla_value( |
| 572 | + &duration_rune_str, |
| 573 | + &ColumnType::Native(NativeType::Duration), |
| 574 | + ); |
| 575 | + assert_eq!(actual.unwrap().to_string(), expected); |
| 576 | + } |
| 577 | + |
| 578 | + #[rstest] |
| 579 | + #[case("")] |
| 580 | + #[case(" ")] |
| 581 | + #[case("\n")] |
| 582 | + #[case("1")] |
| 583 | + #[case("m1")] |
| 584 | + #[case("1mm")] |
| 585 | + #[case("1mom")] |
| 586 | + #[case("fake")] |
| 587 | + #[case("1d2h3m4h")] |
| 588 | + fn test_to_scylla_value_duration_neg(#[case] input: String) { |
| 589 | + let duration_rune_str = Value::String( |
| 590 | + Shared::new(RuneString::try_from(input.clone()).expect("Failed to create RuneString")) |
| 591 | + .expect("Failed to create Shared RuneString"), |
| 592 | + ); |
| 593 | + let actual = to_scylla_value( |
| 594 | + &duration_rune_str, |
| 595 | + &ColumnType::Native(NativeType::Duration), |
| 596 | + ); |
| 597 | + assert!( |
| 598 | + matches!( |
| 599 | + actual, |
| 600 | + Err(CassError(CassErrorKind::QueryParamConversion(_, _, _))) |
| 601 | + ), |
| 602 | + "{}", |
| 603 | + format!( |
| 604 | + "Error was not raised for the {:?} input. Result: {:?}", |
| 605 | + input, actual |
| 606 | + ) |
| 607 | + ); |
| 608 | + } |
| 609 | +} |
0 commit comments