Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 188 additions & 51 deletions crates/core/src/operations/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ use crate::operations::datafusion_utils::Expression;
use crate::protocol::DeltaOperation;
use crate::table::Constraint;
use crate::{DeltaResult, DeltaTable, DeltaTableError};
use std::collections::HashMap;

/// Build a constraint to add to a table
pub struct ConstraintBuilder {
/// A snapshot of the table's state
snapshot: Option<EagerSnapshot>,
/// Name of the constraint
name: Option<String>,
/// Constraint expression
expr: Option<Expression>,
/// Hashmap containing an name of the constraint and expression
check_constraints: HashMap<String, Expression>,
/// Delta object store for handling data files
log_store: LogStoreRef,
/// Datafusion session state relevant for executing the input plan
Expand All @@ -54,8 +53,7 @@ impl ConstraintBuilder {
/// Create a new builder
pub(crate) fn new(log_store: LogStoreRef, snapshot: Option<EagerSnapshot>) -> Self {
Self {
name: None,
expr: None,
check_constraints: Default::default(),
snapshot,
log_store,
session: None,
Expand All @@ -70,8 +68,21 @@ impl ConstraintBuilder {
name: S,
expression: E,
) -> Self {
self.name = Some(name.into());
self.expr = Some(expression.into());
self.check_constraints
.insert(name.into(), expression.into());
self
}

/// Specify multiple constraints to be added
pub fn with_constraints<S: Into<String>, E: Into<Expression>>(
mut self,
constraints: HashMap<S, E>,
) -> Self {
self.check_constraints.extend(
constraints
.into_iter()
.map(|(name, expr)| (name.into(), expr.into())),
);
self
}

Expand Down Expand Up @@ -108,23 +119,29 @@ impl std::future::IntoFuture for ConstraintBuilder {
let operation_id = this.get_operation_id();
this.pre_execute(operation_id).await?;

let name = match this.name {
Some(v) => v,
None => return Err(DeltaTableError::Generic("No name provided".to_string())),
};

let expr = this
.expr
.ok_or_else(|| DeltaTableError::Generic("No Expression provided".to_string()))?;
if this.check_constraints.is_empty() {
return Err(DeltaTableError::Generic(
"No check constraint (Name and Expression) provided".to_string(),
));
}

let mut metadata = snapshot.metadata().clone();
let configuration_key = format!("delta.constraints.{name}");

if metadata.configuration().contains_key(&configuration_key) {
return Err(DeltaTableError::Generic(format!(
"Constraint with name: {name} already exists"
)));
}
let configuration_key_mapper: HashMap<String, String> = HashMap::from_iter(
this.check_constraints
.iter()
.map(|(name, _)| (name.clone(), format!("delta.constraints.{name}"))),
);

// Hold all the conflicted constraints
let preexisting_constraints =
configuration_key_mapper
.iter()
.filter(|(_, configuration_key)| {
metadata
.configuration()
.contains_key(configuration_key.as_str())
});

let session = this
.session
Expand All @@ -136,13 +153,36 @@ impl std::future::IntoFuture for ConstraintBuilder {
.await?;

let schema = scan.schema().to_dfschema()?;
let expr = into_expr(expr, &schema, session.as_ref())?;
let expr_str = fmt_expr_to_sql(&expr)?;

// Checker built here with the one time constraint to check.
let checker =
DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr_str)]);
// Create an Hashmap of the name to the processed expression
let mut constraints_sql_mapper = HashMap::with_capacity(this.check_constraints.len());
for (name, _) in configuration_key_mapper.iter() {
let converted_expr = into_expr(
this.check_constraints[name].clone(),
&schema,
session.as_ref(),
)?;
let constraint_sql = fmt_expr_to_sql(&converted_expr)?;
constraints_sql_mapper.insert(name, constraint_sql);
}

for (name, configuration_key) in preexisting_constraints {
// when the expression is different in the conflicted constraint --> error out due not knowing how to resolve it
if !metadata.configuration()[configuration_key].eq(&constraints_sql_mapper[name]) {
return Err(DeltaTableError::Generic(format!(
"Cannot add constraint '{name}': a constraint with this name already exists with a different expression. Existing: '{}', New: '{}'",
metadata.configuration()[configuration_key],constraints_sql_mapper[name]
)));
}
tracing::warn!("Skipping constraint '{name}': identical constraint already exists with expression '{}'",constraints_sql_mapper[name]);
}
let constraints_checker: Vec<Constraint> = constraints_sql_mapper
.iter()
.map(|(_, sql)| Constraint::new("*", sql))
.collect();

// Checker built here with the one time constraint to check.
let checker = DeltaDataChecker::new_with_constraints(constraints_checker);
let plan: Arc<dyn ExecutionPlan> = Arc::new(scan);
let mut tasks = vec![];
for p in 0..plan.properties().output_partitioning().partition_count() {
Expand Down Expand Up @@ -170,8 +210,12 @@ impl std::future::IntoFuture for ConstraintBuilder {

// We have validated the table passes it's constraints, now to add the constraint to
// the table.
metadata =
metadata.add_config_key(format!("delta.constraints.{name}"), expr_str.clone())?;
for (name, configuration_key) in configuration_key_mapper.iter() {
metadata = metadata.add_config_key(
configuration_key.to_string(),
constraints_sql_mapper[&name].clone(),
)?;
}

let old_protocol = snapshot.protocol();
let protocol = ProtocolInner {
Expand Down Expand Up @@ -199,10 +243,12 @@ impl std::future::IntoFuture for ConstraintBuilder {
},
}
.as_kernel();

// Put all the constraint into one commit
let operation = DeltaOperation::AddConstraint {
name: name.clone(),
expr: expr_str.clone(),
constraints: constraints_sql_mapper
.into_iter()
.map(|(name, sql)| Constraint::new(name, &sql))
.collect(),
};

let actions = vec![metadata.into(), protocol.into()];
Expand Down Expand Up @@ -234,6 +280,7 @@ mod tests {
use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
use datafusion::logical_expr::{col, lit};
use std::collections::HashMap;

use crate::table::config::TablePropertiesExt as _;
use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch};
Expand All @@ -250,17 +297,28 @@ mod tests {
.unwrap()
}

async fn get_constraint_op_params(table: &mut DeltaTable) -> String {
async fn get_constraint_op_params(table: &mut DeltaTable) -> HashMap<String, String> {
let last_commit = table.last_commit().await.unwrap();
last_commit
let constraints_str = last_commit
.operation_parameters
.as_ref()
.unwrap()
.get("expr")
.get("constraints")
.unwrap()
.as_str()
.unwrap();

let constraints: serde_json::Value = serde_json::from_str(constraints_str).unwrap();
constraints
.as_array()
.unwrap()
.to_owned()
.iter()
.map(|value| {
let name = value.get("name").unwrap().as_str().unwrap().to_owned();
let expr = value.get("expr").unwrap().as_str().unwrap().to_owned();
(name, expr)
})
.collect()
}

#[tokio::test]
Expand Down Expand Up @@ -290,7 +348,7 @@ mod tests {
}

#[tokio::test]
async fn add_constraint_with_invalid_data() -> DeltaResult<()> {
async fn test_add_constraint_with_invalid_data() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
Expand All @@ -306,7 +364,7 @@ mod tests {
}

#[tokio::test]
async fn add_valid_constraint() -> DeltaResult<()> {
async fn test_add_valid_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
Expand All @@ -320,17 +378,53 @@ mod tests {
let version = table.version();
assert_eq!(version, Some(1));

let expected_expr = "value < 1000";
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
let expected_expr = vec!["value < 1000"];
assert_eq!(
get_constraint(&table, "delta.constraints.id"),
get_constraint_op_params(&mut table)
.await
.into_values()
.collect::<Vec<String>>(),
expected_expr
);
assert_eq!(
get_constraint(&table, "delta.constraints.id"),
expected_expr[0]
);
Ok(())
}

#[tokio::test]
async fn add_constraint_datafusion() -> DeltaResult<()> {
async fn test_add_valid_multiple_constraints() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;
let table = DeltaOps(write);

let constraints = HashMap::from([("id", "value < 1000"), ("id2", "value < 20")]);

let mut table = table.add_constraint().with_constraints(constraints).await?;
let version = table.version();
assert_eq!(version, Some(1));

let expected_exprs = HashMap::from([
("id".to_string(), "value < 1000".to_string()),
("id2".to_string(), "value < 20".to_string()),
]);
assert_eq!(get_constraint_op_params(&mut table).await, expected_exprs);
assert_eq!(
get_constraint(&table, "delta.constraints.id"),
expected_exprs["id"]
);
assert_eq!(
get_constraint(&table, "delta.constraints.id2"),
expected_exprs["id2"]
);
Ok(())
}

#[tokio::test]
async fn test_add_constraint_datafusion() -> DeltaResult<()> {
// Add constraint by providing a datafusion expression.
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
Expand All @@ -345,12 +439,18 @@ mod tests {
let version = table.version();
assert_eq!(version, Some(1));

let expected_expr = "value < 1000";
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
let expected_expr = vec!["value < 1000"];
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
get_constraint_op_params(&mut table)
.await
.into_values()
.collect::<Vec<String>>(),
expected_expr
);
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
expected_expr[0]
);

Ok(())
}
Expand Down Expand Up @@ -386,18 +486,24 @@ mod tests {
let version = table.version();
assert_eq!(version, Some(1));

let expected_expr = "\"vAlue\" < 1000"; // spellchecker:disable-line
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
let expected_expr = vec!["\"vAlue\" < 1000"]; // spellchecker:disable-line
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
get_constraint_op_params(&mut table)
.await
.into_values()
.collect::<Vec<String>>(),
expected_expr
);
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
expected_expr[0]
);

Ok(())
}

#[tokio::test]
async fn add_conflicting_named_constraint() -> DeltaResult<()> {
async fn test_add_conflicting_named_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
Expand All @@ -419,7 +525,7 @@ mod tests {
}

#[tokio::test]
async fn write_data_that_violates_constraint() -> DeltaResult<()> {
async fn test_write_data_that_violates_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
Expand All @@ -440,9 +546,40 @@ mod tests {
assert!(err.is_err());
Ok(())
}
#[tokio::test]
async fn test_write_data_that_violates_multiple_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
.await?;

let table = DeltaOps(write)
.add_constraint()
.with_constraints(HashMap::from([
("id", "value > 0"),
("custom_cons", "value < 30"),
]))
.await?;
let table = DeltaOps(table);
let invalid_values: Vec<Arc<dyn Array>> = vec![
Arc::new(StringArray::from(vec!["A"])),
Arc::new(Int32Array::from(vec![-10])),
Arc::new(StringArray::from(vec!["2021-02-02"])),
];
let invalid_values_2: Vec<Arc<dyn Array>> = vec![
Arc::new(StringArray::from(vec!["B"])),
Arc::new(Int32Array::from(vec![30])),
Arc::new(StringArray::from(vec!["2021-02-02"])),
];
let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?;
let batch2 = RecordBatch::try_new(get_arrow_schema(&None), invalid_values_2)?;
let err = table.write(vec![batch, batch2]).await;
assert!(err.is_err());
Ok(())
}

#[tokio::test]
async fn write_data_that_does_not_violate_constraint() -> DeltaResult<()> {
async fn test_write_data_that_does_not_violate_constraint() -> DeltaResult<()> {
let batch = get_record_batch(None, false);
let write = DeltaOps(create_bare_table())
.write(vec![batch.clone()])
Expand Down
2 changes: 2 additions & 0 deletions crates/core/src/operations/drop_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ impl std::future::IntoFuture for DropConstraintBuilder {
#[cfg(feature = "datafusion")]
#[cfg(test)]
mod tests {
use std::collections::HashMap;

use crate::writer::test_utils::{create_bare_table, get_record_batch};
use crate::{DeltaOps, DeltaResult, DeltaTable};

Expand Down
Loading
Loading