Skip to content

Commit b55f7a1

Browse files
committed
Merge branch main into colin/hide-series-lit
2 parents 341f4fb + d58ce8e commit b55f7a1

File tree

14 files changed

+1303
-89
lines changed

14 files changed

+1303
-89
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ build-release: check-toolchain frontend .venv ## Compile and install a faster D
6464

6565
.PHONY: test
6666
test: .venv build ## Run tests
67-
HYPOTHESIS_MAX_EXAMPLES=$(HYPOTHESIS_MAX_EXAMPLES) $(VENV_BIN)/pytest --hypothesis-seed=$(HYPOTHESIS_SEED)
67+
HYPOTHESIS_MAX_EXAMPLES=$(HYPOTHESIS_MAX_EXAMPLES) $(VENV_BIN)/pytest --hypothesis-seed=$(HYPOTHESIS_SEED) --ignore tests/integration
6868

6969
.PHONY: doctests
7070
doctests:

daft/expressions/expressions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,6 @@ def lit(value: object) -> Expression:
113113
sign, digits, exponent = value.as_tuple()
114114
assert isinstance(exponent, int)
115115
lit_value = _decimal_lit(sign == 1, digits, exponent)
116-
elif isinstance(value, ImageFormat):
117-
lit_value = _lit(str(value))
118-
elif isinstance(value, ImageMode):
119-
lit_value = _lit(str(value))
120116
else:
121117
lit_value = _lit(value)
122118
return Expression._from_pyexpr(lit_value)

src/common/io-config/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[dependencies]
22
aws-credential-types = {version = "0.55.3"}
3-
chrono = {workspace = true}
3+
chrono = {workspace = true, features = ["serde"]}
44
common-error = {path = "../error", default-features = false}
55
common-py-serde = {path = "../py-serde", default-features = false}
66
derivative = {workspace = true}

src/daft-connect/src/execute.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use common_error::DaftResult;
44
use common_file_formats::{FileFormat, WriteMode};
55
use daft_catalog::TableSource;
66
use daft_context::get_context;
7-
use daft_dsl::LiteralValue;
7+
use daft_dsl::{literals_to_series, LiteralValue};
88
use daft_logical_plan::LogicalPlanBuilder;
99
use daft_micropartition::MicroPartition;
1010
use daft_recordbatch::RecordBatch;
@@ -366,9 +366,7 @@ impl ConnectSession {
366366
let tbl = RecordBatch::concat(&tbls)?;
367367
let output = tbl.to_comfy_table(None).to_string();
368368

369-
let s = LiteralValue::Utf8(output)
370-
.into_single_value_series()?
371-
.rename("show_string");
369+
let s = literals_to_series(&[LiteralValue::Utf8(output)])?.rename("show_string");
372370

373371
let tbl = RecordBatch::from_nonempty_columns(vec![s])?;
374372
response_builder.arrow_batch_response(&tbl)

src/daft-dsl/Cargo.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
bincode = {workspace = true}
33
common-error = {path = "../common/error", default-features = false}
44
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
5+
common-io-config = {path = "../common/io-config", default-features = false}
56
common-py-serde = {path = "../common/py-serde", default-features = false}
67
common-resource-request = {path = "../common/resource-request", default-features = false}
78
common-treenode = {path = "../common/treenode", default-features = false}
@@ -10,12 +11,21 @@ daft-sketch = {path = "../daft-sketch", default-features = false}
1011
derive_more = {workspace = true}
1112
indexmap = {workspace = true}
1213
itertools = {workspace = true}
14+
num-traits = {workspace = true}
1315
pyo3 = {workspace = true, optional = true}
1416
serde = {workspace = true}
1517
typetag = {workspace = true}
1618

1719
[features]
18-
python = ["dep:pyo3", "common-error/python", "daft-core/python", "common-treenode/python", "common-py-serde/python", "common-resource-request/python"]
20+
python = [
21+
"dep:pyo3",
22+
"common-error/python",
23+
"daft-core/python",
24+
"common-io-config/python",
25+
"common-treenode/python",
26+
"common-py-serde/python",
27+
"common-resource-request/python"
28+
]
1929
test-utils = []
2030

2131
[lints]

src/daft-dsl/src/functions/function_args.rs

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
use std::sync::Arc;
22

33
use common_error::{DaftError, DaftResult};
4+
use daft_core::series::Series;
45
use serde::{Deserialize, Serialize};
56

7+
use crate::{lit::FromLiteral, ExprRef, LiteralValue};
8+
69
/// Wrapper around T to hold either a named or an unnamed argument.
710
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)]
811
pub enum FunctionArg<T> {
@@ -388,11 +391,106 @@ impl<T> FunctionArgs<T> {
388391
}
389392
}
390393

394+
impl FunctionArgs<Series> {
395+
/// Uses serde to extract a scalar value from the function args.
396+
/// This will error if the series is not scalar. (len ==1) or if it is not deserializable to the provided type.
397+
/// It will also return an error if the value does not exist.
398+
pub fn extract<V: FromLiteral, Key: FunctionArgKey>(&self, position: Key) -> DaftResult<V> {
399+
let value = position.required(self).map_err(|_| {
400+
DaftError::ValueError(format!(
401+
"Expected a value for the required argument at position `{position:?}`"
402+
))
403+
})?;
404+
405+
let lit = LiteralValue::try_from_single_value_series(value)?;
406+
407+
let res = V::try_from_literal(&lit);
408+
res.map_err(|e| e.into())
409+
}
410+
411+
/// Uses serde to extract an optional scalar value from the function args.
412+
/// This will error if the series is not scalar. (len ==1) or if it is not deserializable to the provided type.
413+
/// if the value does not exist, None is returned.
414+
pub fn extract_optional<V: FromLiteral, Key: FunctionArgKey>(
415+
&self,
416+
position: Key,
417+
) -> DaftResult<Option<V>> {
418+
let value = position.optional(self).map_err(|_| {
419+
DaftError::ValueError(format!(
420+
"Expected a value for the optional argument at position `{position:?}`"
421+
))
422+
})?;
423+
424+
match value {
425+
Some(value) => {
426+
let lit = LiteralValue::try_from_single_value_series(value)?;
427+
let res = V::try_from_literal(&lit);
428+
res.map_err(|e| e.into()).map(Some)
429+
}
430+
None => Ok(None),
431+
}
432+
}
433+
}
434+
435+
impl FunctionArgs<ExprRef> {
436+
/// Uses serde to extract a scalar value from the function args.
437+
/// This will error if the the expr is not a literal, or if it is not deserializable to the provided type.
438+
/// It will also error if the value does not exist.
439+
pub fn extract<V: FromLiteral, Key: FunctionArgKey>(&self, position: Key) -> DaftResult<V> {
440+
let value = position.required(self).map_err(|_| {
441+
DaftError::ValueError(format!(
442+
"Expected a value for the required argument at position `{position:?}`"
443+
))
444+
})?;
445+
446+
match value.as_literal() {
447+
Some(lit) => {
448+
let res = V::try_from_literal(lit);
449+
res.map_err(|e| e.into())
450+
}
451+
None => Err(DaftError::ValueError(format!(
452+
"Expected a literal value for the optional argument at position `{position:?}`"
453+
))),
454+
}
455+
}
456+
/// Uses serde to extract an optional scalar value from the function args.
457+
/// This will error if the the expr is not a literal, or if it is not deserializable to the provided type.
458+
/// if the value does not exist, None is returned.
459+
pub fn extract_optional<V: FromLiteral, Key: FunctionArgKey>(
460+
&self,
461+
position: Key,
462+
) -> DaftResult<Option<V>> {
463+
let value = position.optional(self).map_err(|_| {
464+
DaftError::ValueError(format!(
465+
"Expected a value for the optional argument at position `{position:?}`"
466+
))
467+
})?;
468+
469+
match value {
470+
Some(value) => match value.as_literal() {
471+
Some(lit) => {
472+
let res = V::try_from_literal(lit);
473+
res.map_err(|e| e.into()).map(Some)
474+
}
475+
None => Err(DaftError::ValueError(format!(
476+
"Expected a literal value for the optional argument at position `{position:?}`"
477+
))),
478+
},
479+
None => Ok(None),
480+
}
481+
}
482+
}
483+
391484
#[cfg(test)]
392485
mod tests {
393486
use common_error::DaftResult;
487+
use common_io_config::IOConfig;
488+
use daft_core::prelude::CountMode;
394489

395-
use crate::functions::function_args::{FunctionArg, FunctionArgs};
490+
use crate::{
491+
functions::function_args::{FunctionArg, FunctionArgs},
492+
lit, Literal,
493+
};
396494
#[test]
397495
fn test_function_args_ordering() {
398496
let res = FunctionArgs::try_new(vec![
@@ -533,4 +631,26 @@ mod tests {
533631
assert!(args.is_empty());
534632
assert_eq!(args.len(), 0);
535633
}
634+
635+
#[test]
fn test_extract() -> DaftResult<()> {
    // Expected values for the two named arguments.
    let expected_io_config = IOConfig::default();
    let expected_count_mode = CountMode::Valid;

    // Two positional integer literals followed by two named literals.
    let function_args = FunctionArgs::new_unchecked(vec![
        FunctionArg::unnamed(100i64.lit()),
        FunctionArg::unnamed(222i32.lit()),
        FunctionArg::named("io_config", lit(expected_io_config.clone())),
        FunctionArg::named("arg2", lit(expected_count_mode.clone())),
    ]);

    // Positional integer literals of different widths both convert to `usize`.
    assert_eq!(function_args.extract::<usize, _>(0)?, 100);
    assert_eq!(function_args.extract::<usize, _>(1)?, 222);

    // Named arguments convert back to their original types.
    assert_eq!(
        function_args.extract::<IOConfig, _>("io_config")?,
        expected_io_config
    );
    assert_eq!(
        function_args.extract::<CountMode, _>("arg2")?,
        expected_count_mode
    );

    Ok(())
}
536656
}

src/daft-dsl/src/lit/conversions.rs

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
use common_error::{DaftError, DaftResult};
2+
use common_io_config::IOConfig;
3+
use daft_core::{
4+
datatypes::IntervalValue,
5+
prelude::{CountMode, ImageFormat, ImageMode},
6+
series::Series,
7+
};
8+
use num_traits::cast;
9+
use serde::de::DeserializeOwned;
10+
11+
use super::{deserializer, serializer, FromLiteral, Literal, LiteralValue};
12+
#[cfg(feature = "python")]
13+
use crate::pyobj_serde::PyObjectWrapper;
14+
15+
impl Literal for IntervalValue {
16+
fn literal_value(self) -> LiteralValue {
17+
LiteralValue::Interval(self)
18+
}
19+
}
20+
21+
impl Literal for String {
22+
fn literal_value(self) -> LiteralValue {
23+
LiteralValue::Utf8(self)
24+
}
25+
}
26+
27+
impl FromLiteral for String {
28+
fn try_from_literal(lit: &LiteralValue) -> DaftResult<Self> {
29+
match lit {
30+
LiteralValue::Utf8(s) => Ok(s.clone()),
31+
_ => Err(DaftError::TypeError(format!(
32+
"Cannot convert {:?} to String",
33+
lit
34+
))),
35+
}
36+
}
37+
}
38+
39+
impl Literal for &'_ str {
40+
fn literal_value(self) -> LiteralValue {
41+
LiteralValue::Utf8(self.to_owned())
42+
}
43+
}
44+
45+
impl Literal for &'_ [u8] {
46+
fn literal_value(self) -> LiteralValue {
47+
LiteralValue::Binary(self.to_vec())
48+
}
49+
}
50+
51+
impl Literal for Series {
52+
fn literal_value(self) -> LiteralValue {
53+
LiteralValue::Series(self)
54+
}
55+
}
56+
57+
/// A Python object is put behind an `Arc` and wrapped in `PyObjectWrapper` to form a
/// `Python` literal. Only compiled with the `python` feature.
#[cfg(feature = "python")]
impl Literal for pyo3::PyObject {
    fn literal_value(self) -> LiteralValue {
        let shared = std::sync::Arc::new(self);
        LiteralValue::Python(PyObjectWrapper(shared))
    }
}
63+
64+
impl<T> Literal for Option<T>
65+
where
66+
T: Literal,
67+
{
68+
fn literal_value(self) -> LiteralValue {
69+
match self {
70+
Some(val) => val.literal_value(),
71+
None => LiteralValue::Null,
72+
}
73+
}
74+
}
75+
76+
macro_rules! make_literal {
77+
($TYPE:ty, $SCALAR:ident) => {
78+
impl Literal for $TYPE {
79+
fn literal_value(self) -> LiteralValue {
80+
LiteralValue::$SCALAR(self)
81+
}
82+
}
83+
impl FromLiteral for $TYPE {
84+
fn try_from_literal(lit: &LiteralValue) -> DaftResult<Self> {
85+
match lit {
86+
LiteralValue::$SCALAR(v) => Ok(*v),
87+
_ => Err(DaftError::TypeError(format!(
88+
"Expected {} literal",
89+
stringify!($TYPE)
90+
))),
91+
}
92+
}
93+
}
94+
};
95+
}
96+
make_literal!(bool, Boolean);
97+
make_literal!(i8, Int8);
98+
make_literal!(u8, UInt8);
99+
make_literal!(i16, Int16);
100+
make_literal!(u16, UInt16);
101+
make_literal!(i32, Int32);
102+
make_literal!(u32, UInt32);
103+
make_literal!(i64, Int64);
104+
make_literal!(u64, UInt64);
105+
make_literal!(f64, Float64);
106+
107+
impl FromLiteral for usize {
108+
fn try_from_literal(lit: &LiteralValue) -> DaftResult<Self> {
109+
match lit {
110+
LiteralValue::Int8(i8) => cast(*i8),
111+
LiteralValue::UInt8(u8) => cast(*u8),
112+
LiteralValue::Int16(i16) => cast(*i16),
113+
LiteralValue::UInt16(u16) => cast(*u16),
114+
LiteralValue::Int32(i32) => cast(*i32),
115+
LiteralValue::UInt32(u32) => cast(*u32),
116+
LiteralValue::Int64(i64) => cast(*i64),
117+
LiteralValue::UInt64(u64) => cast(*u64),
118+
_ => None,
119+
}
120+
.ok_or_else(|| DaftError::ValueError("Unsupported literal type".to_string()))
121+
}
122+
}
123+
124+
/// Marker trait to allowlist what can be converted to a literal via serde
125+
trait SerializableLiteral: serde::Serialize {}
126+
127+
/// Marker trait to allowlist what can be converted from a literal via serde
128+
trait DeserializableLiteral: DeserializeOwned {}
129+
130+
impl SerializableLiteral for IOConfig {}
131+
impl DeserializableLiteral for IOConfig {}
132+
impl SerializableLiteral for ImageMode {}
133+
impl DeserializableLiteral for ImageMode {}
134+
impl SerializableLiteral for ImageFormat {}
135+
impl DeserializableLiteral for ImageFormat {}
136+
impl SerializableLiteral for CountMode {}
137+
impl DeserializableLiteral for CountMode {}
138+
139+
impl<D> FromLiteral for D
140+
where
141+
D: DeserializableLiteral,
142+
{
143+
fn try_from_literal(lit: &LiteralValue) -> DaftResult<Self> {
144+
let deserializer = deserializer::LiteralValueDeserializer { lit };
145+
D::deserialize(deserializer).map_err(|e| e.into())
146+
}
147+
}
148+
149+
impl<S> Literal for S
150+
where
151+
S: SerializableLiteral,
152+
{
153+
fn literal_value(self) -> LiteralValue {
154+
self.serialize(serializer::LiteralValueSerializer)
155+
.expect("serialization failed")
156+
}
157+
}

0 commit comments

Comments
 (0)