|
3 | 3 | use crate::scripting::cass_error::{CassError, CassErrorKind};
|
4 | 4 | use crate::scripting::cql_types::Uuid;
|
5 | 5 | use chrono::{NaiveDate, NaiveTime};
|
| 6 | +use once_cell::sync::Lazy; |
| 7 | +use regex::Regex; |
6 | 8 | use rune::{Any, ToValue, Value};
|
7 | 9 | use scylla::_macro_internal::ColumnType;
|
8 | 10 | use scylla::frame::response::result::{CollectionType, ColumnSpec, NativeType};
|
9 | 11 | use scylla::response::query_result::ColumnSpecs;
|
10 |
| -use scylla::value::{CqlDate, CqlTime, CqlTimeuuid, CqlValue, CqlVarint}; |
| 12 | +use scylla::value::{CqlDate, CqlDuration, CqlTime, CqlTimeuuid, CqlValue, CqlVarint}; |
| 13 | +use std::collections::HashMap; |
11 | 14 | use std::net::IpAddr;
|
12 | 15 | use std::str::FromStr;
|
13 | 16 |
|
14 | 17 | use itertools::*;
|
15 | 18 |
|
| 19 | +static DURATION_REGEX: Lazy<Regex> = Lazy::new(|| { |
| 20 | + Regex::new(concat!( |
| 21 | + r"(?P<years>\d+)y|", |
| 22 | + r"(?P<months>\d+)mo|", |
| 23 | + r"(?P<weeks>\d+)w|", |
| 24 | + r"(?P<days>\d+)d|", |
| 25 | + r"(?P<hours>\d+)h|", |
| 26 | + r"(?P<seconds>\d+)s|", |
| 27 | + r"(?P<millis>\d+)ms|", |
| 28 | + r"(?P<micros>\d+)us|", |
| 29 | + r"(?P<nanoseconds>\d+)ns|", |
| 30 | + r"(?P<minutes>\d+)m|", // must be after 'mo' and 'ms' matchers |
| 31 | + r"(?P<invalid>.+)", // must be last, used for all incorrect matches |
| 32 | + )) |
| 33 | + .unwrap() |
| 34 | +}); |
| 35 | + |
16 | 36 | fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result<CqlValue, CassError> {
|
17 |
| - // TODO: add support for the following CQL data types: |
18 |
| - // 'duration' |
19 | 37 | match (v, typ) {
|
20 | 38 | (Value::Bool(v), ColumnType::Native(NativeType::Boolean)) => Ok(CqlValue::Boolean(*v)),
|
21 | 39 |
|
@@ -105,6 +123,97 @@ fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result<CqlValue, CassError> {
|
105 | 123 | let cql_time = CqlTime::try_from(naive_time)?;
|
106 | 124 | Ok(CqlValue::Time(cql_time))
|
107 | 125 | }
|
| 126 | + (Value::String(s), ColumnType::Native(NativeType::Duration)) => { |
| 127 | + // TODO: add support for the following 'ISO 8601' format variants: |
| 128 | + // - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W |
| 129 | + // - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss] |
| 130 | + // See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations |
| 131 | + let duration_str = s.borrow_ref().unwrap(); |
| 132 | + if duration_str.is_empty() { |
| 133 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 134 | + format!("{:?}", v), |
| 135 | + "NativeType::Duration".to_string(), |
| 136 | + Some("Duration cannot be empty".to_string()), |
| 137 | + ))); |
| 138 | + } |
| 139 | + // NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics. |
| 140 | + // It stores only months, days and nanoseconds. |
| 141 | + // So, we do not translate days to months and hours to days because those are ambiguous |
| 142 | + let (mut months, mut days, mut nanoseconds) = (0, 0, 0); |
| 143 | + let mut matches_counter = HashMap::from([ |
| 144 | + ("y", 0), |
| 145 | + ("mo", 0), |
| 146 | + ("w", 0), |
| 147 | + ("d", 0), |
| 148 | + ("h", 0), |
| 149 | + ("m", 0), |
| 150 | + ("s", 0), |
| 151 | + ("ms", 0), |
| 152 | + ("us", 0), |
| 153 | + ("ns", 0), |
| 154 | + ]); |
| 155 | + for cap in DURATION_REGEX.captures_iter(&duration_str) { |
| 156 | + if let Some(m) = cap.name("years") { |
| 157 | + months += m.as_str().parse::<i32>().unwrap() * 12; |
| 158 | + *matches_counter.entry("y").or_insert(1) += 1; |
| 159 | + } else if let Some(m) = cap.name("months") { |
| 160 | + months += m.as_str().parse::<i32>().unwrap(); |
| 161 | + *matches_counter.entry("mo").or_insert(1) += 1; |
| 162 | + } else if let Some(m) = cap.name("weeks") { |
| 163 | + days += m.as_str().parse::<i32>().unwrap() * 7; |
| 164 | + *matches_counter.entry("w").or_insert(1) += 1; |
| 165 | + } else if let Some(m) = cap.name("days") { |
| 166 | + days += m.as_str().parse::<i32>().unwrap(); |
| 167 | + *matches_counter.entry("d").or_insert(1) += 1; |
| 168 | + } else if let Some(m) = cap.name("hours") { |
| 169 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 3_600_000_000_000; |
| 170 | + *matches_counter.entry("h").or_insert(1) += 1; |
| 171 | + } else if let Some(m) = cap.name("minutes") { |
| 172 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 60_000_000_000; |
| 173 | + *matches_counter.entry("m").or_insert(1) += 1; |
| 174 | + } else if let Some(m) = cap.name("seconds") { |
| 175 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000_000_000; |
| 176 | + *matches_counter.entry("s").or_insert(1) += 1; |
| 177 | + } else if let Some(m) = cap.name("millis") { |
| 178 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000_000; |
| 179 | + *matches_counter.entry("ms").or_insert(1) += 1; |
| 180 | + } else if let Some(m) = cap.name("micros") { |
| 181 | + nanoseconds += m.as_str().parse::<i64>().unwrap() * 1_000; |
| 182 | + *matches_counter.entry("us").or_insert(1) += 1; |
| 183 | + } else if let Some(m) = cap.name("nanoseconds") { |
| 184 | + nanoseconds += m.as_str().parse::<i64>().unwrap(); |
| 185 | + *matches_counter.entry("ns").or_insert(1) += 1; |
| 186 | + } else if cap.name("invalid").is_some() { |
| 187 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 188 | + format!("{:?}", v), |
| 189 | + "NativeType::Duration".to_string(), |
| 190 | + Some("Got invalid duration value".to_string()), |
| 191 | + ))); |
| 192 | + } |
| 193 | + } |
| 194 | + let duplicated_units: Vec<&str> = matches_counter |
| 195 | + .iter() |
| 196 | + .filter(|&(_, &count)| count > 1) |
| 197 | + .map(|(&unit, _)| unit) |
| 198 | + .collect(); |
| 199 | + if !duplicated_units.is_empty() { |
| 200 | + return Err(CassError(CassErrorKind::QueryParamConversion( |
| 201 | + format!("{:?}", v), |
| 202 | + "NativeType::Duration".to_string(), |
| 203 | + Some(format!( |
| 204 | + "Got multiple matches for time unit(s): {}", |
| 205 | + duplicated_units.join(", ") |
| 206 | + )), |
| 207 | + ))); |
| 208 | + } |
| 209 | + let cql_duration = CqlDuration { |
| 210 | + months, |
| 211 | + days, |
| 212 | + nanoseconds, |
| 213 | + }; |
| 214 | + Ok(CqlValue::Duration(cql_duration)) |
| 215 | + } |
| 216 | + |
108 | 217 | (Value::String(s), ColumnType::Native(NativeType::Varint)) => {
|
109 | 218 | let varint_str = s.borrow_ref().unwrap();
|
110 | 219 | if !varint_str.chars().all(|c| c.is_ascii_digit()) {
|
|
0 commit comments