diff --git a/crates/core/src/operations/constraints.rs b/crates/core/src/operations/constraints.rs
index 0512f043e6..50916d3d46 100644
--- a/crates/core/src/operations/constraints.rs
+++ b/crates/core/src/operations/constraints.rs
@@ -23,15 +23,14 @@ use crate::operations::datafusion_utils::Expression;
 use crate::protocol::DeltaOperation;
 use crate::table::Constraint;
 use crate::{DeltaResult, DeltaTable, DeltaTableError};
+use std::collections::HashMap;

 /// Build a constraint to add to a table
 pub struct ConstraintBuilder {
     /// A snapshot of the table's state
     snapshot: Option<DeltaTableState>,
-    /// Name of the constraint
-    name: Option<String>,
-    /// Constraint expression
-    expr: Option<Expression>,
+    /// Map of constraint names to their check expressions
+    check_constraints: HashMap<String, Expression>,
     /// Delta object store for handling data files
     log_store: LogStoreRef,
     /// Datafusion session state relevant for executing the input plan
@@ -54,8 +53,7 @@ impl ConstraintBuilder {
     /// Create a new builder
     pub(crate) fn new(log_store: LogStoreRef, snapshot: Option<DeltaTableState>) -> Self {
         Self {
-            name: None,
-            expr: None,
+            check_constraints: Default::default(),
             snapshot,
             log_store,
             session: None,
@@ -70,8 +68,21 @@
         name: S,
         expression: E,
     ) -> Self {
-        self.name = Some(name.into());
-        self.expr = Some(expression.into());
+        self.check_constraints
+            .insert(name.into(), expression.into());
+        self
+    }
+
+    /// Specify multiple constraints to be added
+    pub fn with_constraints<S: Into<String>, E: Into<Expression>>(
+        mut self,
+        constraints: HashMap<S, E>,
+    ) -> Self {
+        self.check_constraints.extend(
+            constraints
+                .into_iter()
+                .map(|(name, expr)| (name.into(), expr.into())),
+        );
         self
     }

@@ -108,23 +119,29 @@ impl std::future::IntoFuture for ConstraintBuilder {
             let operation_id = this.get_operation_id();
             this.pre_execute(operation_id).await?;

-            let name = match this.name {
-                Some(v) => v,
-                None => return Err(DeltaTableError::Generic("No name provided".to_string())),
-            };
-
-            let expr = this
-                .expr
-                .ok_or_else(|| DeltaTableError::Generic("No Expression provided".to_string()))?;
+            if this.check_constraints.is_empty() {
+                return Err(DeltaTableError::Generic(
+                    "No check constraint (name and expression) provided".to_string(),
+                ));
+            }

             let mut metadata = snapshot.metadata().clone();
-            let configuration_key = format!("delta.constraints.{name}");
-            if metadata.configuration().contains_key(&configuration_key) {
-                return Err(DeltaTableError::Generic(format!(
-                    "Constraint with name: {name} already exists"
-                )));
-            }
+            let configuration_key_mapper: HashMap<String, String> = HashMap::from_iter(
+                this.check_constraints
+                    .iter()
+                    .map(|(name, _)| (name.clone(), format!("delta.constraints.{name}"))),
+            );
+
+            // Collect the constraints whose names already exist in the table configuration
+            let preexisting_constraints =
+                configuration_key_mapper
+                    .iter()
+                    .filter(|(_, configuration_key)| {
+                        metadata
+                            .configuration()
+                            .contains_key(configuration_key.as_str())
+                    });

             let session = this
                 .session
@@ -136,13 +153,36 @@
                 .await?;

             let schema = scan.schema().to_dfschema()?;
-            let expr = into_expr(expr, &schema, session.as_ref())?;
-            let expr_str = fmt_expr_to_sql(&expr)?;
-            // Checker built here with the one time constraint to check.
-            let checker =
-                DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr_str)]);
+            // Map each constraint name to its SQL-formatted expression
+            let mut constraints_sql_mapper = HashMap::with_capacity(this.check_constraints.len());
+            for (name, _) in configuration_key_mapper.iter() {
+                let converted_expr = into_expr(
+                    this.check_constraints[name].clone(),
+                    &schema,
+                    session.as_ref(),
+                )?;
+                let constraint_sql = fmt_expr_to_sql(&converted_expr)?;
+                constraints_sql_mapper.insert(name, constraint_sql);
+            }
+            for (name, configuration_key) in preexisting_constraints {
+                // If a constraint with this name already exists under a different expression, error out since the conflict cannot be resolved automatically
+                if !metadata.configuration()[configuration_key].eq(&constraints_sql_mapper[name]) {
+                    return Err(DeltaTableError::Generic(format!(
+                        "Cannot add constraint '{name}': a constraint with this name already exists with a different expression. Existing: '{}', New: '{}'",
+                        metadata.configuration()[configuration_key], constraints_sql_mapper[name]
+                    )));
+                }
+                tracing::warn!("Skipping constraint '{name}': an identical constraint already exists with expression '{}'", constraints_sql_mapper[name]);
+            }
+            let constraints_checker: Vec<Constraint> = constraints_sql_mapper
+                .iter()
+                .map(|(_, sql)| Constraint::new("*", sql))
+                .collect();
+
+            // Build the checker with all of the new constraints so existing data is validated against each of them.
+            let checker = DeltaDataChecker::new_with_constraints(constraints_checker);
             let plan: Arc<dyn ExecutionPlan> = Arc::new(scan);
             let mut tasks = vec![];
             for p in 0..plan.properties().output_partitioning().partition_count() {
@@ -170,8 +210,12 @@
             // We have validated the table passes it's constraints, now to add the constraint to
             // the table.
-            metadata =
-                metadata.add_config_key(format!("delta.constraints.{name}"), expr_str.clone())?;
+            for (name, configuration_key) in configuration_key_mapper.iter() {
+                metadata = metadata.add_config_key(
+                    configuration_key.to_string(),
+                    constraints_sql_mapper[&name].clone(),
+                )?;
+            }

             let old_protocol = snapshot.protocol();
             let protocol = ProtocolInner {
@@ -199,10 +243,12 @@
                 },
             }
             .as_kernel();
-
+            // Put all the constraints into one commit
             let operation = DeltaOperation::AddConstraint {
-                name: name.clone(),
-                expr: expr_str.clone(),
+                constraints: constraints_sql_mapper
+                    .into_iter()
+                    .map(|(name, sql)| Constraint::new(name, &sql))
+                    .collect(),
             };

             let actions = vec![metadata.into(), protocol.into()];
@@ -234,6 +280,7 @@ mod tests {
     use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
     use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
     use datafusion::logical_expr::{col, lit};
+    use std::collections::HashMap;

     use crate::table::config::TablePropertiesExt as _;
     use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch};
@@ -250,17 +297,28 @@
             .unwrap()
     }

-    async fn get_constraint_op_params(table: &mut DeltaTable) -> String {
+    async fn get_constraint_op_params(table: &mut DeltaTable) -> HashMap<String, String> {
         let last_commit = table.last_commit().await.unwrap();
-        last_commit
+        let constraints_str = last_commit
             .operation_parameters
             .as_ref()
             .unwrap()
-            .get("expr")
+            .get("constraints")
             .unwrap()
             .as_str()
+            .unwrap();
+
+        let constraints: serde_json::Value = serde_json::from_str(constraints_str).unwrap();
+        constraints
+            .as_array()
             .unwrap()
-            .to_owned()
+            .iter()
+            .map(|value| {
value.get("name").unwrap().as_str().unwrap().to_owned(); + let expr = value.get("expr").unwrap().as_str().unwrap().to_owned(); + (name, expr) + }) + .collect() } #[tokio::test] @@ -290,7 +348,7 @@ mod tests { } #[tokio::test] - async fn add_constraint_with_invalid_data() -> DeltaResult<()> { + async fn test_add_constraint_with_invalid_data() -> DeltaResult<()> { let batch = get_record_batch(None, false); let write = DeltaOps(create_bare_table()) .write(vec![batch.clone()]) @@ -306,7 +364,7 @@ mod tests { } #[tokio::test] - async fn add_valid_constraint() -> DeltaResult<()> { + async fn test_add_valid_constraint() -> DeltaResult<()> { let batch = get_record_batch(None, false); let write = DeltaOps(create_bare_table()) .write(vec![batch.clone()]) @@ -320,17 +378,53 @@ mod tests { let version = table.version(); assert_eq!(version, Some(1)); - let expected_expr = "value < 1000"; - assert_eq!(get_constraint_op_params(&mut table).await, expected_expr); + let expected_expr = vec!["value < 1000"]; assert_eq!( - get_constraint(&table, "delta.constraints.id"), + get_constraint_op_params(&mut table) + .await + .into_values() + .collect::>(), expected_expr ); + assert_eq!( + get_constraint(&table, "delta.constraints.id"), + expected_expr[0] + ); Ok(()) } #[tokio::test] - async fn add_constraint_datafusion() -> DeltaResult<()> { + async fn test_add_valid_multiple_constraints() -> DeltaResult<()> { + let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + let table = DeltaOps(write); + + let constraints = HashMap::from([("id", "value < 1000"), ("id2", "value < 20")]); + + let mut table = table.add_constraint().with_constraints(constraints).await?; + let version = table.version(); + assert_eq!(version, Some(1)); + + let expected_exprs = HashMap::from([ + ("id".to_string(), "value < 1000".to_string()), + ("id2".to_string(), "value < 20".to_string()), + ]); + assert_eq!(get_constraint_op_params(&mut table).await, expected_exprs); + assert_eq!( + get_constraint(&table, "delta.constraints.id"), + expected_exprs["id"] + ); + assert_eq!( + get_constraint(&table, "delta.constraints.id2"), + expected_exprs["id2"] + ); + Ok(()) + } + + #[tokio::test] + async fn test_add_constraint_datafusion() -> DeltaResult<()> { // Add constraint by providing a datafusion expression. 
         let batch = get_record_batch(None, false);
         let write = DeltaOps(create_bare_table())
             .write(vec![batch.clone()])
@@ -345,12 +439,18 @@
         let version = table.version();
         assert_eq!(version, Some(1));

-        let expected_expr = "value < 1000";
-        assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
+        let expected_expr = vec!["value < 1000"];
         assert_eq!(
-            get_constraint(&table, "delta.constraints.valid_values"),
+            get_constraint_op_params(&mut table)
+                .await
+                .into_values()
+                .collect::<Vec<_>>(),
             expected_expr
         );
+        assert_eq!(
+            get_constraint(&table, "delta.constraints.valid_values"),
+            expected_expr[0]
+        );
         Ok(())
     }

@@ -386,18 +486,24 @@
         let version = table.version();
         assert_eq!(version, Some(1));

-        let expected_expr = "\"vAlue\" < 1000"; // spellchecker:disable-line
-        assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
+        let expected_expr = vec!["\"vAlue\" < 1000"]; // spellchecker:disable-line
         assert_eq!(
-            get_constraint(&table, "delta.constraints.valid_values"),
+            get_constraint_op_params(&mut table)
+                .await
+                .into_values()
+                .collect::<Vec<_>>(),
             expected_expr
         );
+        assert_eq!(
+            get_constraint(&table, "delta.constraints.valid_values"),
+            expected_expr[0]
+        );
         Ok(())
     }

     #[tokio::test]
-    async fn add_conflicting_named_constraint() -> DeltaResult<()> {
+    async fn test_add_conflicting_named_constraint() -> DeltaResult<()> {
         let batch = get_record_batch(None, false);
         let write = DeltaOps(create_bare_table())
             .write(vec![batch.clone()])
@@ -419,7 +525,7 @@
     }

     #[tokio::test]
-    async fn write_data_that_violates_constraint() -> DeltaResult<()> {
+    async fn test_write_data_that_violates_constraint() -> DeltaResult<()> {
         let batch = get_record_batch(None, false);
         let write = DeltaOps(create_bare_table())
             .write(vec![batch.clone()])
@@ -440,9 +546,40 @@
         assert!(err.is_err());
         Ok(())
     }
+    #[tokio::test]
+    async fn test_write_data_that_violates_multiple_constraint() -> DeltaResult<()> {
+        let batch = get_record_batch(None, false);
+        let write = DeltaOps(create_bare_table())
+            .write(vec![batch.clone()])
+            .await?;
+
+        let table = DeltaOps(write)
+            .add_constraint()
+            .with_constraints(HashMap::from([
+                ("id", "value > 0"),
+                ("custom_cons", "value < 30"),
+            ]))
+            .await?;
+        let table = DeltaOps(table);
+        let invalid_values: Vec<Arc<dyn Array>> = vec![
+            Arc::new(StringArray::from(vec!["A"])),
+            Arc::new(Int32Array::from(vec![-10])),
+            Arc::new(StringArray::from(vec!["2021-02-02"])),
+        ];
+        let invalid_values_2: Vec<Arc<dyn Array>> = vec![
+            Arc::new(StringArray::from(vec!["B"])),
+            Arc::new(Int32Array::from(vec![30])),
+            Arc::new(StringArray::from(vec!["2021-02-02"])),
+        ];
+        let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?;
+        let batch2 = RecordBatch::try_new(get_arrow_schema(&None), invalid_values_2)?;
+        let err = table.write(vec![batch, batch2]).await;
+        assert!(err.is_err());
+        Ok(())
+    }

     #[tokio::test]
-    async fn write_data_that_does_not_violate_constraint() -> DeltaResult<()> {
+    async fn test_write_data_that_does_not_violate_constraint() -> DeltaResult<()> {
         let batch = get_record_batch(None, false);
         let write = DeltaOps(create_bare_table())
             .write(vec![batch.clone()])
diff --git a/crates/core/src/operations/drop_constraints.rs b/crates/core/src/operations/drop_constraints.rs
index 9e20208769..ac54ea816a 100644
--- a/crates/core/src/operations/drop_constraints.rs
+++ b/crates/core/src/operations/drop_constraints.rs
@@ -135,6 +135,8 @@ impl std::future::IntoFuture for DropConstraintBuilder {
 #[cfg(feature = "datafusion")]
 #[cfg(test)]
 mod tests {
+    use std::collections::HashMap;
+
     use crate::writer::test_utils::{create_bare_table, get_record_batch};
     use crate::{DeltaOps, DeltaResult, DeltaTable};

diff --git a/crates/core/src/operations/restore.rs b/crates/core/src/operations/restore.rs
index 8b8a17402b..c235304a03 100644
--- a/crates/core/src/operations/restore.rs
+++ b/crates/core/src/operations/restore.rs
@@ -385,6 +385,8 @@ mod tests {
     #[tokio::test]
     #[cfg(feature = "datafusion")]
     async fn test_simple_restore_constraints() -> DeltaResult<()> {
+        use std::collections::HashMap;
+
         use crate::table::config::TablePropertiesExt as _;

         let batch = get_record_batch(None, false);
diff --git a/crates/core/src/protocol/mod.rs b/crates/core/src/protocol/mod.rs
index 5ca640d4f8..fb3dcfdff4 100644
--- a/crates/core/src/protocol/mod.rs
+++ b/crates/core/src/protocol/mod.rs
@@ -2,6 +2,7 @@

 #![allow(non_camel_case_types)]

+use crate::table::Constraint;
 use std::borrow::Borrow;
 use std::collections::HashMap;
 use std::hash::{Hash, Hasher};
@@ -267,10 +268,8 @@ pub enum DeltaOperation {
     },
     /// Add constraints to a table
     AddConstraint {
-        /// Constraints name
-        name: String,
-        /// Expression to check against
-        expr: String,
+        /// Constraints, each with a name and a check expression
+        constraints: Vec<Constraint>,
     },

     /// Add table features to a table
diff --git a/crates/core/src/table/columns.rs b/crates/core/src/table/columns.rs
index b43b6f24e2..1c06857e1d 100644
--- a/crates/core/src/table/columns.rs
+++ b/crates/core/src/table/columns.rs
@@ -1,10 +1,12 @@
 //! Constraints and generated column mappings
+use serde::{Deserialize, Serialize};
+
 use crate::kernel::DataType;
 use crate::table::DataCheck;
 use std::any::Any;

 /// A constraint in a check constraint
-#[derive(Eq, PartialEq, Debug, Default, Clone)]
+#[derive(Eq, PartialEq, Debug, Default, Clone, Serialize, Deserialize)]
 pub struct Constraint {
     /// The full path to the field.
     pub name: String,
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index 6f1e722fc1..d5e2f4e12f 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -1794,12 +1794,6 @@ def add_constraint(
         {'delta.constraints.value_gt_5': 'value > 5'}
         ```
         """
-        if len(constraints.keys()) > 1:
-            raise ValueError(
-                """add_constraints is limited to a single constraint addition at once for now.
-                Please execute add_constraints multiple times with each time a different constraint."""
-            )
-
         self.table._table.add_constraints(
             constraints,
             commit_properties,
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 10c40a51be..02e6ad67f9 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -794,9 +794,7 @@ impl RawDeltaTable {
         let table = self._table.lock().map_err(to_rt_err)?.clone();
         let mut cmd = DeltaOps(table).add_constraint();

-        for (col_name, expression) in constraints {
-            cmd = cmd.with_constraint(col_name.clone(), expression.clone());
-        }
+        cmd = cmd.with_constraints(constraints);

         if let Some(commit_properties) =
             maybe_create_commit_properties(commit_properties, post_commithook_properties)
diff --git a/python/tests/test_alter.py b/python/tests/test_alter.py
index 8dff1b2488..b2378efa14 100644
--- a/python/tests/test_alter.py
+++ b/python/tests/test_alter.py
@@ -53,16 +53,6 @@ def test_add_constraint(tmp_path: pathlib.Path, sample_table: Table):
         write_deltalake(tmp_path, data, mode="append")


-def test_add_multiple_constraints(tmp_path: pathlib.Path, sample_table: Table):
-    write_deltalake(tmp_path, sample_table)
-
-    dt = DeltaTable(tmp_path)
-
-    with pytest.raises(ValueError):
-        dt.alter.add_constraint(
-            {"check_price": "price >= 0", "check_price2": "price >= 0"}
-        )
-

 def test_add_constraint_roundtrip_metadata(tmp_path: pathlib.Path, sample_table: Table):
     write_deltalake(tmp_path, sample_table, mode="append")

diff --git a/python/tests/test_constraint.py b/python/tests/test_constraint.py
index 0ac5d3d623..da6c1ceaa4 100644
--- a/python/tests/test_constraint.py
+++ b/python/tests/test_constraint.py
@@ -125,3 +125,36 @@ def test_add_constraint(tmp_path, sample_table: Table, sql_string: str):
         )

         write_deltalake(tmp_path, data, mode="append")
+
+
+def test_add_multiple_constraint(tmp_path, sample_table: Table):
+    write_deltalake(tmp_path, sample_table)
+
+    dt = DeltaTable(tmp_path)
+
+    dt.alter.add_constraint({"check_price": '"high price" >= 0', "min_price": '"high price" < 5'})
+
+    last_action = dt.history(1)[0]
+    assert last_action["operation"] == "ADD CONSTRAINT"
+    assert dt.version() == 1
+    assert dt.metadata().configuration == {
+        "delta.constraints.check_price": '"high price" >= 0',
+        "delta.constraints.min_price": '"high price" < 5',
+    }
+    assert dt.protocol().min_writer_version == 3
+
+    with pytest.raises(DeltaProtocolError):
+        data = Table(
+            {
+                "id": Array(["1"], DataType.string()),
+                "high price": Array([5], DataType.int64()),
+            },
+            schema=Schema(
+                fields=[
+                    Field("id", type=DataType.string(), nullable=True),
+                    Field("high price", type=DataType.int64(), nullable=True),
+                ]
+            ),
+        )
+
+        write_deltalake(tmp_path, data, mode="append")
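
Usage note (not part of the diff): a minimal sketch of how a caller could use the new multi-constraint builder API, mirroring the pattern in `test_add_valid_multiple_constraints` above. It assumes the `deltalake` meta-crate re-exports `DeltaOps`, `DeltaResult`, and `DeltaTable` from `deltalake-core`; the function name and the constraint names/expressions are illustrative only.

```rust
use std::collections::HashMap;

use deltalake::{DeltaOps, DeltaResult, DeltaTable};

// Adds two check constraints in a single ADD CONSTRAINT commit. Existing data is
// validated against both expressions before the table configuration is updated.
async fn add_multiple_constraints(table: DeltaTable) -> DeltaResult<DeltaTable> {
    let constraints = HashMap::from([
        ("id_positive", "id > 0"),         // hypothetical constraint name/expression
        ("value_bounded", "value < 1000"), // hypothetical constraint name/expression
    ]);
    DeltaOps(table)
        .add_constraint()
        .with_constraints(constraints)
        .await
}
```

The equivalent Python call, now that the single-constraint limitation is removed, is `dt.alter.add_constraint({"id_positive": "id > 0", "value_bounded": "value < 1000"})`, as exercised by `test_add_multiple_constraint` in `python/tests/test_constraint.py`.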