Skip to content

Commit 382d9e5

Browse files
ion-elgreco
authored and
Liam Brannigan
committed
feat: write metrics extension planner
Signed-off-by: Ion Koutsouris <[email protected]>
Signed-off-by: Liam Brannigan <[email protected]>
1 parent 09f05fd commit 382d9e5

File tree

8 files changed

+102
-52
lines changed

8 files changed

+102
-52
lines changed

crates/core/src/operations/delete.rs

-2
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ async fn execute_non_empty_expr(
259259
None,
260260
writer_properties.clone(),
261261
writer_stats_config.clone(),
262-
None,
263262
)
264263
.await?;
265264

@@ -296,7 +295,6 @@ async fn execute_non_empty_expr(
296295
None,
297296
writer_properties,
298297
writer_stats_config,
299-
None,
300298
)
301299
.await?;
302300
actions.extend(cdc_actions)

crates/core/src/operations/update.rs

-2
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,6 @@ async fn execute(
399399
None,
400400
writer_properties.clone(),
401401
writer_stats_config.clone(),
402-
None,
403402
)
404403
.await?;
405404

@@ -462,7 +461,6 @@ async fn execute(
462461
None,
463462
writer_properties,
464463
writer_stats_config,
465-
None,
466464
)
467465
.await?;
468466
actions.extend(cdc_actions);

crates/core/src/operations/write/execution.rs

+3-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::sync::Arc;
22
use std::vec;
33

4-
use arrow_array::RecordBatch;
54
use arrow_schema::SchemaRef as ArrowSchemaRef;
65
use datafusion::datasource::provider_as_source;
76
use datafusion::execution::context::{SessionState, TaskContext};
@@ -25,7 +24,6 @@ use crate::operations::writer::{DeltaWriter, WriterConfig};
2524
use crate::storage::ObjectStoreRef;
2625
use crate::table::state::DeltaTableState;
2726
use crate::table::Constraint as DeltaConstraint;
28-
use tokio::sync::mpsc::Sender;
2927

3028
use super::configs::WriterStatsConfig;
3129
use super::WriteError;
@@ -42,7 +40,6 @@ pub(crate) async fn write_execution_plan_with_predicate(
4240
write_batch_size: Option<usize>,
4341
writer_properties: Option<WriterProperties>,
4442
writer_stats_config: WriterStatsConfig,
45-
sender: Option<Sender<RecordBatch>>,
4643
) -> DeltaResult<Vec<Action>> {
4744
// We always take the plan Schema since the data may contain Large/View arrow types,
4845
// the schema and batches were prior constructed with this in mind.
@@ -81,33 +78,21 @@ pub(crate) async fn write_execution_plan_with_predicate(
8178
);
8279
let mut writer = DeltaWriter::new(object_store.clone(), config);
8380
let checker_stream = checker.clone();
84-
let sender_stream = sender.clone();
8581
let mut stream = inner_plan.execute(i, task_ctx)?;
8682

87-
let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> = tokio::task::spawn(
88-
async move {
89-
let sendable = sender_stream.clone();
83+
let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> =
84+
tokio::task::spawn(async move {
9085
while let Some(maybe_batch) = stream.next().await {
9186
let batch = maybe_batch?;
92-
9387
checker_stream.check_batch(&batch).await?;
94-
95-
if let Some(s) = sendable.as_ref() {
96-
if let Err(e) = s.send(batch.clone()).await {
97-
error!("Failed to send data to observer: {e:#?}");
98-
}
99-
} else {
100-
debug!("write_execution_plan_with_predicate did not send any batches, no sender.");
101-
}
10288
writer.write(&batch).await?;
10389
}
10490
let add_actions = writer.close().await;
10591
match add_actions {
10692
Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::<Vec<_>>()),
10793
Err(err) => Err(err),
10894
}
109-
},
110-
);
95+
});
11196

11297
tasks.push(handle);
11398
}
@@ -136,7 +121,6 @@ pub(crate) async fn write_execution_plan_cdc(
136121
write_batch_size: Option<usize>,
137122
writer_properties: Option<WriterProperties>,
138123
writer_stats_config: WriterStatsConfig,
139-
sender: Option<Sender<RecordBatch>>,
140124
) -> DeltaResult<Vec<Action>> {
141125
let cdc_store = Arc::new(PrefixStore::new(object_store, "_change_data"));
142126

@@ -150,7 +134,6 @@ pub(crate) async fn write_execution_plan_cdc(
150134
write_batch_size,
151135
writer_properties,
152136
writer_stats_config,
153-
sender,
154137
)
155138
.await?
156139
.into_iter()
@@ -185,7 +168,6 @@ pub(crate) async fn write_execution_plan(
185168
write_batch_size: Option<usize>,
186169
writer_properties: Option<WriterProperties>,
187170
writer_stats_config: WriterStatsConfig,
188-
sender: Option<Sender<RecordBatch>>,
189171
) -> DeltaResult<Vec<Action>> {
190172
write_execution_plan_with_predicate(
191173
None,
@@ -198,7 +180,6 @@ pub(crate) async fn write_execution_plan(
198180
write_batch_size,
199181
writer_properties,
200182
writer_stats_config,
201-
sender,
202183
)
203184
.await
204185
}
@@ -258,7 +239,6 @@ pub(crate) async fn execute_non_empty_expr(
258239
None,
259240
writer_properties.clone(),
260241
writer_stats_config.clone(),
261-
None,
262242
)
263243
.await?;
264244

@@ -330,7 +310,6 @@ pub(crate) async fn execute_non_empty_expr_cdc(
330310
None,
331311
writer_properties,
332312
writer_stats_config,
333-
None,
334313
)
335314
.await?;
336315
Ok(Some(cdc_actions))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use std::sync::Arc;
2+
3+
use async_trait::async_trait;
4+
use datafusion::{
5+
execution::SessionState,
6+
physical_planner::{ExtensionPlanner, PhysicalPlanner},
7+
};
8+
use datafusion_common::Result as DataFusionResult;
9+
use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
10+
use datafusion_physical_plan::{metrics::MetricBuilder, ExecutionPlan};
11+
12+
use crate::delta_datafusion::{logical::MetricObserver, physical::MetricObserverExec};
13+
14+
pub(crate) const SOURCE_COUNT_ID: &str = "write_source_count";
15+
pub(crate) const SOURCE_COUNT_METRIC: &str = "num_source_rows";
16+
17+
#[derive(Clone, Debug)]
18+
pub(crate) struct WriteMetricExtensionPlanner {}
19+
20+
#[async_trait]
21+
impl ExtensionPlanner for WriteMetricExtensionPlanner {
22+
async fn plan_extension(
23+
&self,
24+
_planner: &dyn PhysicalPlanner,
25+
node: &dyn UserDefinedLogicalNode,
26+
_logical_inputs: &[&LogicalPlan],
27+
physical_inputs: &[Arc<dyn ExecutionPlan>],
28+
_session_state: &SessionState,
29+
) -> DataFusionResult<Option<Arc<dyn ExecutionPlan>>> {
30+
if let Some(metric_observer) = node.as_any().downcast_ref::<MetricObserver>() {
31+
if metric_observer.id.eq(SOURCE_COUNT_ID) {
32+
return Ok(Some(MetricObserverExec::try_new(
33+
SOURCE_COUNT_ID.into(),
34+
physical_inputs,
35+
|batch, metrics| {
36+
MetricBuilder::new(metrics)
37+
.global_counter(SOURCE_COUNT_METRIC)
38+
.add(batch.num_rows());
39+
},
40+
)?));
41+
}
42+
}
43+
Ok(None)
44+
}
45+
}

crates/core/src/operations/write/mod.rs

+41-13
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
pub mod configs;
2727
pub(crate) mod execution;
2828
pub(crate) mod generated_columns;
29-
pub mod lazy;
29+
pub(crate) mod metrics;
3030
pub(crate) mod schema_evolution;
3131

3232
use arrow_schema::Schema;
3333
pub use configs::WriterStatsConfig;
3434
use datafusion::execution::SessionStateBuilder;
3535
use generated_columns::{add_generated_columns, add_missing_generated_columns};
36+
use metrics::{WriteMetricExtensionPlanner, SOURCE_COUNT_ID, SOURCE_COUNT_METRIC};
3637
use std::collections::HashMap;
3738
use std::str::FromStr;
3839
use std::sync::Arc;
@@ -45,7 +46,7 @@ use datafusion::datasource::MemTable;
4546
use datafusion::execution::context::{SessionContext, SessionState};
4647
use datafusion::prelude::DataFrame;
4748
use datafusion_common::{Column, DFSchema, Result, ScalarValue};
48-
use datafusion_expr::{cast, lit, Expr, LogicalPlan};
49+
use datafusion_expr::{cast, lit, try_cast, Expr, Extension, LogicalPlan};
4950
use execution::{prepare_predicate_actions, write_execution_plan_with_predicate};
5051
use futures::future::BoxFuture;
5152
use parquet::file::properties::WriterProperties;
@@ -58,6 +59,9 @@ use super::transaction::{CommitBuilder, CommitProperties, TableReference, PROTOC
5859
use super::{CreateBuilder, CustomExecuteHandler, Operation};
5960
use crate::delta_datafusion::expr::fmt_expr_to_sql;
6061
use crate::delta_datafusion::expr::parse_predicate_expression;
62+
use crate::delta_datafusion::logical::MetricObserver;
63+
use crate::delta_datafusion::physical::{find_metric_node, get_metric};
64+
use crate::delta_datafusion::planner::DeltaPlanner;
6165
use crate::delta_datafusion::register_store;
6266
use crate::delta_datafusion::DataFusionMixins;
6367
use crate::errors::{DeltaResult, DeltaTableError};
@@ -418,16 +422,25 @@ impl std::future::IntoFuture for WriteBuilder {
418422
let mut metrics = WriteMetrics::default();
419423
let exec_start = Instant::now();
420424

425+
let write_planner = DeltaPlanner::<WriteMetricExtensionPlanner> {
426+
extension_planner: WriteMetricExtensionPlanner {},
427+
};
428+
421429
// Create table actions to initialize table in case it does not yet exist
422430
// and should be created
423431
let mut actions = this.check_preconditions().await?;
424432

425433
let partition_columns = this.get_partition_columns()?;
426434

427435
let state = match this.state {
428-
Some(state) => state,
436+
Some(state) => SessionStateBuilder::new_from_existing(state.clone())
437+
.with_query_planner(Arc::new(write_planner))
438+
.build(),
429439
None => {
430-
let state = SessionStateBuilder::new().with_default_features().build();
440+
let state = SessionStateBuilder::new()
441+
.with_default_features()
442+
.with_query_planner(Arc::new(write_planner))
443+
.build();
431444
register_store(this.log_store.clone(), state.runtime_env().clone());
432445
state
433446
}
@@ -491,7 +504,8 @@ impl std::future::IntoFuture for WriteBuilder {
491504
for field in new_schema.fields() {
492505
// If field exist in source data, we cast to new datatype
493506
if source_schema.index_of(field.name()).is_ok() {
494-
let cast_expr = cast(
507+
let cast_fn = if this.safe_cast { try_cast } else { cast };
508+
let cast_expr = cast_fn(
495509
Expr::Column(Column::from_name(field.name())),
496510
// col(field.name()),
497511
field.data_type().clone(),
@@ -520,6 +534,16 @@ impl std::future::IntoFuture for WriteBuilder {
520534
&state,
521535
)?;
522536

537+
let source = LogicalPlan::Extension(Extension {
538+
node: Arc::new(MetricObserver {
539+
id: "write_source_count".into(),
540+
input: source.logical_plan().clone(),
541+
enable_pushdown: false,
542+
}),
543+
});
544+
545+
let source = DataFrame::new(state.clone(), source);
546+
523547
let schema = Arc::new(source.schema().as_arrow().clone());
524548

525549
// Maybe create schema action
@@ -576,21 +600,31 @@ impl std::future::IntoFuture for WriteBuilder {
576600
stats_columns,
577601
};
578602

603+
let source_plan = source.clone().create_physical_plan().await?;
604+
579605
// Here we need to validate if the new data conforms to a predicate if one is provided
580606
let add_actions = write_execution_plan_with_predicate(
581607
predicate.clone(),
582608
this.snapshot.as_ref(),
583609
state.clone(),
584-
source.clone().create_physical_plan().await?,
610+
source_plan.clone(),
585611
partition_columns.clone(),
586612
this.log_store.object_store(Some(operation_id)).clone(),
587613
target_file_size,
588614
this.write_batch_size,
589615
this.writer_properties.clone(),
590616
writer_stats_config.clone(),
591-
None,
592617
)
593618
.await?;
619+
620+
let source_count =
621+
find_metric_node(SOURCE_COUNT_ID, &source_plan).ok_or_else(|| {
622+
DeltaTableError::Generic("Unable to locate expected metric node".into())
623+
})?;
624+
let source_count_metrics = source_count.metrics().unwrap();
625+
let num_added_rows = get_metric(&source_count_metrics, SOURCE_COUNT_METRIC);
626+
metrics.num_added_rows = num_added_rows;
627+
594628
metrics.num_added_files = add_actions.len();
595629
actions.extend(add_actions);
596630

@@ -989,7 +1023,6 @@ mod tests {
9891023
assert_eq!(table.version(), 0);
9901024
assert_eq!(table.get_files_count(), 2);
9911025
let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
992-
assert!(write_metrics.num_partitions > 0);
9931026
assert_eq!(write_metrics.num_added_files, 2);
9941027
assert_common_write_metrics(write_metrics);
9951028

@@ -1003,7 +1036,6 @@ mod tests {
10031036
assert_eq!(table.get_files_count(), 4);
10041037

10051038
let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
1006-
assert!(write_metrics.num_partitions > 0);
10071039
assert_eq!(write_metrics.num_added_files, 4);
10081040
assert_common_write_metrics(write_metrics);
10091041
}
@@ -1093,7 +1125,6 @@ mod tests {
10931125
assert_eq!(table.version(), 0);
10941126

10951127
let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
1096-
assert!(write_metrics.num_partitions > 0);
10971128
assert_common_write_metrics(write_metrics);
10981129

10991130
let mut new_schema_builder = arrow_schema::SchemaBuilder::new();
@@ -1146,7 +1177,6 @@ mod tests {
11461177
assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions
11471178

11481179
let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
1149-
assert!(write_metrics.num_partitions > 0);
11501180
assert_common_write_metrics(write_metrics);
11511181
}
11521182

@@ -1668,7 +1698,6 @@ mod tests {
16681698
assert_eq!(table.version(), 1);
16691699
let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
16701700
assert_eq!(write_metrics.num_added_rows, 3);
1671-
assert!(write_metrics.num_partitions > 0);
16721701
assert_common_write_metrics(write_metrics);
16731702

16741703
let table = DeltaOps(table)
@@ -1680,7 +1709,6 @@ mod tests {
16801709
assert_eq!(table.version(), 2);
16811710
let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
16821711
assert_eq!(write_metrics.num_added_rows, 1);
1683-
assert!(write_metrics.num_partitions > 0);
16841712
assert!(write_metrics.num_removed_files > 0);
16851713
assert_common_write_metrics(write_metrics);
16861714

crates/core/tests/integration_datafusion.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,7 @@ mod local {
11191119
let _ = write_builder
11201120
.with_input_execution_plan(plan)
11211121
.with_save_mode(SaveMode::Overwrite)
1122+
.with_schema_mode(deltalake_core::operations::write::SchemaMode::Overwrite)
11221123
.await
11231124
.unwrap();
11241125

python/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod merge;
55
mod query;
66
mod schema;
77
mod utils;
8+
mod writer;
89

910
use std::cmp::min;
1011
use std::collections::{HashMap, HashSet};
@@ -2179,9 +2180,9 @@ fn write_to_deltalake(
21792180
);
21802181
builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
21812182
} else {
2183+
use crate::writer::to_lazy_table;
21822184
use deltalake::datafusion::datasource::provider_as_source;
21832185
use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
2184-
use deltalake::operations::write::lazy::to_lazy_table;
21852186
let table_provider = to_lazy_table(data.0).map_err(PythonError::from)?;
21862187

21872188
let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)

0 commit comments

Comments (0)