diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml
index b8bec2e5..3c747ad3 100644
--- a/.github/workflows/deploy.yml
+++ b/.github/workflows/deploy.yml
@@ -3,6 +3,7 @@ on:
   push:
     branches:
       - main
+      - 'v*'
     tags:
       - 'v*'
 
@@ -34,7 +35,7 @@ jobs:
         run: nix build --print-build-logs
 
       - name: Create release 🚀
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: mongodb-connector
           path: result/bin/mongodb-connector
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e8b7cf02..bd68958e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@ This changelog documents the changes between release versions.
 
 ## [Unreleased]
 
+### Changed
+
+- Query plans for remote joins no longer use a `$lookup` stage if there is exactly one incoming variable set; this allows use of `$vectorSearch` in native queries in remote joins in certain circumstances ([#147](https://github.com/hasura/ndc-mongodb/pull/147))
+
 ## [1.6.0] - 2025-01-17
 
 ### Added
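A sketch of what the optimization above produces on the wire, using the `title_word_frequency` native query that the integration tests below exercise (an illustrative command shape, not the connector's literal output):

```rust
use mongodb::bson::{doc, Document};

// Hypothetical single-variable-set command; stage contents are simplified.
fn sketch_single_set_command() -> Document {
    doc! {
        "aggregate": "title_word_frequency",
        "pipeline": [
            // `$$count` resolves to the command-level `let` binding below;
            // query-style operators need `$expr` to see pipeline variables.
            { "$match": { "$expr": { "$eq": ["$count", "$$count"] } } },
            { "$sort": { "_id": 1 } },
            { "$limit": 20 },
        ],
        "cursor": {},
        "let": { "count": 3 },
    }
}
```

With the query pipeline no longer nested inside a `$lookup`, top-level-only stages such as `$vectorSearch` become legal in it.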
diff --git a/Cargo.lock b/Cargo.lock
index 9f8de50b..f776ecf1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1032,9 +1032,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
 
 [[package]]
 name = "hickory-proto"
-version = "0.24.2"
+version = "0.24.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "447afdcdb8afb9d0a852af6dc65d9b285ce720ed7a59e42a8bf2e931c67bc1b5"
+checksum = "2ad3d6d98c648ed628df039541a5577bee1a7c83e9e16fe3dbedeea4cdfeb971"
 dependencies = [
  "async-trait",
  "cfg-if",
diff --git a/crates/integration-tests/src/tests/native_query.rs b/crates/integration-tests/src/tests/native_query.rs
index 6865b5fe..9b476843 100644
--- a/crates/integration-tests/src/tests/native_query.rs
+++ b/crates/integration-tests/src/tests/native_query.rs
@@ -62,3 +62,65 @@ async fn runs_native_query_with_variable_sets() -> anyhow::Result<()> {
     );
     Ok(())
 }
+
+#[tokio::test]
+async fn runs_native_query_with_a_single_variable_set() -> anyhow::Result<()> {
+    assert_yaml_snapshot!(
+        run_connector_query(
+            Connector::SampleMflix,
+            query_request()
+                .variables([[("count", 3)]])
+                .collection("title_word_frequency")
+                .query(
+                    query()
+                        .predicate(binop("_eq", target!("count"), variable!(count)))
+                        .order_by([asc!("_id")])
+                        .limit(20)
+                        .fields([field!("_id"), field!("count")]),
+                )
+        )
+        .await?
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn runs_native_query_without_input_collection_with_variable_sets() -> anyhow::Result<()> {
+    assert_yaml_snapshot!(
+        run_connector_query(
+            Connector::SampleMflix,
+            query_request()
+                .variables([[("type", "decimal")], [("type", "date")]])
+                .collection("extended_json_test_data")
+                .query(
+                    query()
+                        .predicate(binop("_eq", target!("type"), variable!(type)))
+                        .order_by([asc!("type"), asc!("value")])
+                        .fields([field!("type"), field!("value")]),
+                )
+        )
+        .await?
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn runs_native_query_without_input_collection_with_single_variable_set() -> anyhow::Result<()>
+{
+    assert_yaml_snapshot!(
+        run_connector_query(
+            Connector::SampleMflix,
+            query_request()
+                .variables([[("type", "decimal")]])
+                .collection("extended_json_test_data")
+                .query(
+                    query()
+                        .predicate(binop("_eq", target!("type"), variable!(type)))
+                        .order_by([asc!("type"), asc!("value")])
+                        .fields([field!("type"), field!("value")]),
+                )
+        )
+        .await?
+    );
+    Ok(())
+}
diff --git a/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_with_a_single_variable_set.snap b/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_with_a_single_variable_set.snap
new file mode 100644
index 00000000..0f2a99ef
--- /dev/null
+++ b/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_with_a_single_variable_set.snap
@@ -0,0 +1,45 @@
+---
+source: crates/integration-tests/src/tests/native_query.rs
+expression: "run_connector_query(Connector::SampleMflix,\nquery_request().variables([[(\"count\",\n3)]]).collection(\"title_word_frequency\").query(query().predicate(binop(\"_eq\",\ntarget!(\"count\"),\nvariable!(count))).order_by([asc!(\"_id\")]).limit(20).fields([field!(\"_id\"),\nfield!(\"count\")]),)).await?"
+---
+- rows:
+    - _id: "#1"
+      count: 3
+    - _id: "'n"
+      count: 3
+    - _id: "'n'"
+      count: 3
+    - _id: (Not)
+      count: 3
+    - _id: "100"
+      count: 3
+    - _id: 10th
+      count: 3
+    - _id: "15"
+      count: 3
+    - _id: "174"
+      count: 3
+    - _id: "23"
+      count: 3
+    - _id: 3-D
+      count: 3
+    - _id: "42"
+      count: 3
+    - _id: "420"
+      count: 3
+    - _id: "72"
+      count: 3
+    - _id: Abandoned
+      count: 3
+    - _id: Abendland
+      count: 3
+    - _id: Absence
+      count: 3
+    - _id: Absent
+      count: 3
+    - _id: Abu
+      count: 3
+    - _id: Accident
+      count: 3
+    - _id: Accidental
+      count: 3
diff --git a/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_without_input_collection_with_single_variable_set.snap b/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_without_input_collection_with_single_variable_set.snap
new file mode 100644
index 00000000..cbece735
--- /dev/null
+++ b/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_without_input_collection_with_single_variable_set.snap
@@ -0,0 +1,11 @@
+---
+source: crates/integration-tests/src/tests/native_query.rs
+expression: "run_connector_query(Connector::SampleMflix,\nquery_request().variables([[(\"type\",\n\"decimal\")]]).collection(\"extended_json_test_data\").query(query().predicate(binop(\"_eq\",\ntarget!(\"type\"),\nvariable!(type))).order_by([asc!(\"type\"),\nasc!(\"value\")]).fields([field!(\"type\"), field!(\"value\")]),)).await?"
---
+- rows:
+    - type: decimal
+      value:
+        $numberDecimal: "1"
+    - type: decimal
+      value:
+        $numberDecimal: "2"
diff --git a/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_without_input_collection_with_variable_sets.snap b/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_without_input_collection_with_variable_sets.snap
new file mode 100644
index 00000000..a3eb8e6a
--- /dev/null
+++ b/crates/integration-tests/src/tests/snapshots/integration_tests__tests__native_query__runs_native_query_without_input_collection_with_variable_sets.snap
@@ -0,0 +1,20 @@
+---
+source: crates/integration-tests/src/tests/native_query.rs
+expression: "run_connector_query(Connector::SampleMflix,\nquery_request().variables([[(\"type\", \"decimal\")],\n[(\"type\",\n\"date\")]]).collection(\"extended_json_test_data\").query(query().predicate(binop(\"_eq\",\ntarget!(\"type\"),\nvariable!(type))).order_by([asc!(\"type\"),\nasc!(\"value\")]).fields([field!(\"type\"), field!(\"value\")]),)).await?"
+---
+- rows:
+    - type: decimal
+      value:
+        $numberDecimal: "1"
+    - type: decimal
+      value:
+        $numberDecimal: "2"
+- rows:
+    - type: date
+      value:
+        $date:
+          $numberLong: "1637571600000"
+    - type: date
+      value:
+        $date:
+          $numberLong: "1724164680000"
diff --git a/crates/mongodb-agent-common/src/explain.rs b/crates/mongodb-agent-common/src/explain.rs
index 0b504da4..d23a5edf 100644
--- a/crates/mongodb-agent-common/src/explain.rs
+++ b/crates/mongodb-agent-common/src/explain.rs
@@ -1,13 +1,12 @@
 use std::collections::BTreeMap;
 
 use mongodb::bson::{doc, to_bson, Bson};
+use mongodb_support::aggregate::AggregateCommand;
 use ndc_models::{ExplainResponse, QueryRequest};
 use ndc_query_plan::plan_for_query_request;
 
 use crate::{
-    interface_types::MongoAgentError,
-    mongo_query_plan::MongoConfiguration,
-    query::{self, QueryTarget},
+    interface_types::MongoAgentError, mongo_query_plan::MongoConfiguration, query,
     state::ConnectorState,
 };
 
@@ -19,19 +18,28 @@ pub async fn explain_query(
     let db = state.database();
     let query_plan = plan_for_query_request(config, query_request)?;
 
-    let pipeline = query::pipeline_for_query_request(config, &query_plan)?;
+    let AggregateCommand {
+        collection,
+        pipeline,
+        let_vars,
+    } = query::command_for_query_request(config, &query_plan)?;
     let pipeline_bson = to_bson(&pipeline)?;
 
-    let target = QueryTarget::for_request(config, &query_plan);
-    let aggregate_target = match (target.input_collection(), query_plan.has_variables()) {
-        (Some(collection_name), false) => Bson::String(collection_name.to_string()),
-        _ => Bson::Int32(1),
+    let aggregate_target = match collection {
+        Some(collection_name) => Bson::String(collection_name.to_string()),
+        None => Bson::Int32(1),
     };
 
-    let query_command = doc! {
-        "aggregate": aggregate_target,
-        "pipeline": pipeline_bson,
-        "cursor": {},
+    let query_command = {
+        let mut cmd = doc! {
+            "aggregate": aggregate_target,
+            "pipeline": pipeline_bson,
+            "cursor": {},
+        };
+        if let Some(let_vars) = let_vars {
+            cmd.insert("let", let_vars);
+        }
+        cmd
     };
 
     let explain_command = doc! {
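The hunk ends where the explain wrapper begins. MongoDB's explain command simply wraps the aggregate command built above, so for a single-variable-set query the wrapped command now carries the `let` field. A minimal sketch; the verbosity value is an assumption, since the rest of the wrapper is outside this hunk:

```rust
use mongodb::bson::{doc, Document};

// Hypothetical explain wrapper; "verbosity" is an assumption here - the
// value the connector actually requests is not shown in this diff.
fn sketch_explain_command(query_command: Document) -> Document {
    doc! {
        "explain": query_command,
        "verbosity": "allPlansExecution",
    }
}
```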
diff --git a/crates/mongodb-agent-common/src/query/execute_query_request.rs b/crates/mongodb-agent-common/src/query/execute_query_request.rs
index aa1b4551..68d0c850 100644
--- a/crates/mongodb-agent-common/src/query/execute_query_request.rs
+++ b/crates/mongodb-agent-common/src/query/execute_query_request.rs
@@ -1,17 +1,16 @@
 use futures::Stream;
 use futures_util::TryStreamExt as _;
-use mongodb::bson;
-use mongodb_support::aggregate::Pipeline;
+use mongodb::{bson, options::AggregateOptions};
+use mongodb_support::aggregate::AggregateCommand;
 use ndc_models::{QueryRequest, QueryResponse};
 use ndc_query_plan::plan_for_query_request;
 use tracing::{instrument, Instrument};
 
-use super::{pipeline::pipeline_for_query_request, response::serialize_query_response};
+use super::{pipeline::command_for_query_request, response::serialize_query_response};
 use crate::{
     interface_types::MongoAgentError,
     mongo_query_plan::{MongoConfiguration, QueryPlan},
     mongodb::{CollectionTrait as _, DatabaseTrait},
-    query::QueryTarget,
 };
 
 type Result<T> = std::result::Result<T, MongoAgentError>;
@@ -31,8 +30,8 @@ pub async fn execute_query_request(
     );
     let query_plan = preprocess_query_request(config, query_request)?;
     tracing::debug!(?query_plan, "abstract query plan");
-    let pipeline = pipeline_for_query_request(config, &query_plan)?;
-    let documents = execute_query_pipeline(database, config, &query_plan, pipeline).await?;
+    let command = command_for_query_request(config, &query_plan)?;
+    let documents = execute_query_command(database, command).await?;
     let response = serialize_query_response(config.extended_json_mode(), &query_plan, documents)?;
     Ok(response)
 }
@@ -47,19 +46,24 @@ fn preprocess_query_request(
 }
 
 #[instrument(name = "Execute Query Pipeline", skip_all, fields(internal.visibility = "user"))]
-async fn execute_query_pipeline(
+async fn execute_query_command(
     database: impl DatabaseTrait,
-    config: &MongoConfiguration,
-    query_plan: &QueryPlan,
-    pipeline: Pipeline,
+    AggregateCommand {
+        collection,
+        pipeline,
+        let_vars,
+    }: AggregateCommand,
 ) -> Result<Vec<bson::Document>> {
-    let target = QueryTarget::for_request(config, query_plan);
     tracing::debug!(
-        ?target,
+        ?collection,
         pipeline = %serde_json::to_string(&pipeline).unwrap(),
+        let_vars = %serde_json::to_string(&let_vars).unwrap(),
        "executing query"
     );
+
+    let aggregate_options =
+        let_vars.map(|let_vars| AggregateOptions::builder().let_vars(let_vars).build());
+
     // The target of a query request might be a collection, or it might be a native query. In the
     // latter case there is no collection to perform the aggregation against. So instead of sending
     // the MongoDB API call `db.<collection>.aggregate` we instead call `db.aggregate`.
     //
     // If the query request includes variable sets then instead of specifying the target collection
     // up front that is deferred until the `$lookup` stage of the aggregation pipeline. That is
     // another case where we call `db.aggregate` instead of `db.<collection>.aggregate`.
-    let documents = match (target.input_collection(), query_plan.has_variables()) {
-        (Some(collection_name), false) => {
+    let documents = match collection {
+        Some(collection_name) => {
             let collection = database.collection(collection_name.as_str());
             collect_response_documents(
                 collection
-                    .aggregate(pipeline, None)
+                    .aggregate(pipeline, aggregate_options)
                     .instrument(tracing::info_span!(
                         "MongoDB Aggregate Command",
                         internal.visibility = "user"
                     ))
                     .await?,
             )
             .await
         }
-        _ => {
+        None => {
             collect_response_documents(
                 database
-                    .aggregate(pipeline, None)
+                    .aggregate(pipeline, aggregate_options)
                     .instrument(tracing::info_span!(
                         "MongoDB Aggregate Command",
                         internal.visibility = "user"
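To make the dispatch above concrete, here is a self-contained sketch of the two call shapes against the mongodb 2.x driver API this crate uses; the `run` helper, its arguments, and the pipeline contents are hypothetical:

```rust
use mongodb::{
    bson::{doc, Document},
    options::AggregateOptions,
    Database,
};

// Hypothetical helper mirroring the dispatch in execute_query_command.
async fn run(db: &Database, collection: Option<&str>) -> mongodb::error::Result<()> {
    let pipeline = vec![doc! {
        // `$$count` resolves to the `let` binding below; query-style
        // operators need `$expr` to see pipeline variables.
        "$match": { "$expr": { "$eq": ["$count", "$$count"] } }
    }];
    let options = AggregateOptions::builder()
        .let_vars(doc! { "count": 3 })
        .build();
    let mut cursor = match collection {
        // db.<collection>.aggregate - the common case
        Some(name) => db.collection::<Document>(name).aggregate(pipeline, options).await?,
        // db.aggregate - native queries with no input collection
        None => db.aggregate(pipeline, options).await?,
    };
    while cursor.advance().await? {
        println!("{}", cursor.deserialize_current()?);
    }
    Ok(())
}
```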
diff --git a/crates/mongodb-agent-common/src/query/foreach.rs b/crates/mongodb-agent-common/src/query/foreach.rs
index 4995eb40..cf76404b 100644
--- a/crates/mongodb-agent-common/src/query/foreach.rs
+++ b/crates/mongodb-agent-common/src/query/foreach.rs
@@ -1,7 +1,7 @@
 use anyhow::anyhow;
 use itertools::Itertools as _;
 use mongodb::bson::{self, doc, Bson};
-use mongodb_support::aggregate::{Pipeline, Selection, Stage};
+use mongodb_support::aggregate::{AggregateCommand, Pipeline, Selection, Stage};
 use ndc_query_plan::VariableSet;
 
 use super::pipeline::pipeline_for_non_foreach;
@@ -15,16 +15,39 @@ use crate::mongo_query_plan::{MongoConfiguration, QueryPlan, Type, VariableTypes
 type Result<T> = std::result::Result<T, MongoAgentError>;
 
 /// Produces a complete MongoDB pipeline for a query request that includes variable sets.
-pub fn pipeline_for_foreach(
+pub fn command_for_foreach(
     request_variable_sets: &[VariableSet],
     config: &MongoConfiguration,
     query_request: &QueryPlan,
-) -> Result<Pipeline> {
+) -> Result<AggregateCommand> {
     let target = QueryTarget::for_request(config, query_request);
     let variable_sets =
         variable_sets_to_bson(request_variable_sets, &query_request.variable_types)?;
 
+    let query_pipeline = pipeline_for_non_foreach(config, query_request, QueryLevel::Top)?;
+
+    // If there are multiple variable sets we need to use sub-pipelines to fork the query for each
+    // set. So we start the pipeline with a `$documents` stage to inject variable sets, and join
+    // the target collection with a `$lookup` stage with the query pipeline as a sub-pipeline. But
+    // if there is exactly one variable set then we can optimize away the `$lookup`. This is useful
+    // because some aggregation operations, like `$vectorSearch`, are not allowed in sub-pipelines.
+    Ok(if variable_sets.len() == 1 {
+        // safety: we just checked the length of variable_sets
+        let single_set = variable_sets.into_iter().next().unwrap();
+        command_for_single_variable_set(single_set, target, query_pipeline)
+    } else {
+        command_for_multiple_variable_sets(query_request, variable_sets, target, query_pipeline)
+    })
+}
+
+// Where "multiple" means either zero or more than 1
+fn command_for_multiple_variable_sets(
+    query_request: &QueryPlan,
+    variable_sets: Vec<bson::Document>,
+    target: QueryTarget<'_>,
+    query_pipeline: Pipeline,
+) -> AggregateCommand {
     let variable_names = variable_sets
         .iter()
         .flat_map(|variable_set| variable_set.keys());
@@ -34,8 +57,6 @@ pub fn pipeline_for_foreach(
 
     let variable_sets_stage = Stage::Documents(variable_sets);
 
-    let query_pipeline = pipeline_for_non_foreach(config, query_request, QueryLevel::Top)?;
-
     let lookup_stage = Stage::Lookup {
         from: target.input_collection().map(ToString::to_string),
         local_field: None,
@@ -61,9 +82,25 @@ pub fn pipeline_for_foreach(
     };
     let selection_stage = Stage::ReplaceWith(Selection::new(selection));
 
-    Ok(Pipeline {
-        stages: vec![variable_sets_stage, lookup_stage, selection_stage],
-    })
+    AggregateCommand {
+        collection: None,
+        pipeline: Pipeline {
+            stages: vec![variable_sets_stage, lookup_stage, selection_stage],
+        },
+        let_vars: None,
+    }
+}
+
+fn command_for_single_variable_set(
+    variable_set: bson::Document,
+    target: QueryTarget<'_>,
+    query_pipeline: Pipeline,
+) -> AggregateCommand {
+    AggregateCommand {
+        collection: target.input_collection().map(ToString::to_string),
+        pipeline: query_pipeline,
+        let_vars: Some(variable_set),
+    }
 }
 
 fn variable_sets_to_bson(
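For contrast with the single-set branch, a rough sketch of the plan that `command_for_multiple_variable_sets` builds; the binding and output field names here are illustrative, and the connector's real selection logic is more involved:

```rust
use mongodb::bson::{doc, Document};

// Hypothetical multiple-variable-set plan: `$documents` injects one document
// per variable set, `$lookup` forks the real query once per set as a
// sub-pipeline, and `$replaceWith` reshapes each joined result into a row set.
fn sketch_multi_set_pipeline() -> Vec<Document> {
    vec![
        doc! { "$documents": [{ "type": "decimal" }, { "type": "date" }] },
        doc! { "$lookup": {
            "from": "extended_json_test_data",
            "let": { "type": "$type" },
            "pipeline": [
                // the query pipeline; `$$type` is bound per variable set
                { "$match": { "$expr": { "$eq": ["$type", "$$type"] } } },
            ],
            "as": "query",
        } },
        doc! { "$replaceWith": { "rows": "$query" } },
    ]
}
```

Because the query pipeline sits inside `$lookup`, stages restricted to the top level of a pipeline, `$vectorSearch` among them, cannot appear in it; the single-variable-set branch exists precisely to avoid this.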
diff --git a/crates/mongodb-agent-common/src/query/make_selector/mod.rs b/crates/mongodb-agent-common/src/query/make_selector/mod.rs
index 2f28b1d0..3612542c 100644
--- a/crates/mongodb-agent-common/src/query/make_selector/mod.rs
+++ b/crates/mongodb-agent-common/src/query/make_selector/mod.rs
@@ -33,7 +33,7 @@ pub fn make_selector(expr: &Expression) -> Result<Document> {
 mod tests {
     use configuration::MongoScalarType;
     use mongodb::bson::{self, bson, doc};
-    use mongodb_support::BsonScalarType;
+    use mongodb_support::{aggregate::AggregateCommand, BsonScalarType};
     use ndc_models::UnaryComparisonOperator;
     use ndc_query_plan::{plan_for_query_request, Scope};
     use ndc_test_helpers::{
@@ -47,7 +47,7 @@ mod tests {
         mongo_query_plan::{
             ComparisonTarget, ComparisonValue, ExistsInCollection, Expression, Type,
         },
-        query::pipeline_for_query_request,
+        query::command_for_query_request,
         test_helpers::{chinook_config, chinook_relationships},
     };
 
@@ -172,7 +172,7 @@
     }
 
     #[test]
-    fn root_column_reference_refereces_column_of_nearest_query() -> anyhow::Result<()> {
+    fn root_column_reference_references_column_of_nearest_query() -> anyhow::Result<()> {
         let request = query_request()
             .collection("Artist")
             .query(
@@ -193,7 +193,11 @@
 
         let config = chinook_config();
         let plan = plan_for_query_request(&config, request)?;
-        let pipeline = pipeline_for_query_request(&config, &plan)?;
+        let AggregateCommand {
+            collection: _,
+            pipeline,
+            let_vars,
+        } = command_for_query_request(&config, &plan)?;
 
         let expected_pipeline = bson!([
             {
@@ -256,6 +260,7 @@
         ]);
 
         assert_eq!(bson::to_bson(&pipeline).unwrap(), expected_pipeline);
+        assert_eq!(let_vars, None);
 
         Ok(())
     }
diff --git a/crates/mongodb-agent-common/src/query/mod.rs b/crates/mongodb-agent-common/src/query/mod.rs
index d6094ca6..7ad27d10 100644
--- a/crates/mongodb-agent-common/src/query/mod.rs
+++ b/crates/mongodb-agent-common/src/query/mod.rs
@@ -19,7 +19,7 @@ use self::execute_query_request::execute_query_request;
 pub use self::{
     make_selector::make_selector,
     make_sort::make_sort_stages,
-    pipeline::{is_response_faceted, pipeline_for_non_foreach, pipeline_for_query_request},
+    pipeline::{command_for_query_request, is_response_faceted, pipeline_for_non_foreach},
     query_target::QueryTarget,
     response::QueryResponseError,
 };
diff --git a/crates/mongodb-agent-common/src/query/pipeline.rs b/crates/mongodb-agent-common/src/query/pipeline.rs
index f89d2c8f..cc036adb 100644
--- a/crates/mongodb-agent-common/src/query/pipeline.rs
+++ b/crates/mongodb-agent-common/src/query/pipeline.rs
@@ -4,7 +4,7 @@ use configuration::MongoScalarType;
 use itertools::Itertools;
 use mongodb::bson::{self, doc, Bson};
 use mongodb_support::{
-    aggregate::{Accumulator, Pipeline, Selection, Stage},
+    aggregate::{Accumulator, AggregateCommand, Pipeline, Selection, Stage},
     BsonScalarType,
 };
 use ndc_models::FieldName;
@@ -24,12 +24,13 @@ use crate::{
 use super::{
     column_ref::ColumnRef,
     constants::{RESULT_FIELD, ROWS_FIELD},
-    foreach::pipeline_for_foreach,
+    foreach::command_for_foreach,
     make_selector,
     make_sort::make_sort_stages,
     native_query::pipeline_for_native_query,
     query_level::QueryLevel,
     relations::pipeline_for_relations,
+    QueryTarget,
 };
 
 /// A query that includes aggregates will be run using a $facet pipeline stage, while a query
@@ -44,14 +45,20 @@ pub fn is_response_faceted(query: &Query) -> bool {
 
 /// Shared logic to produce a MongoDB aggregation pipeline for a query request.
 #[instrument(name = "Build Query Pipeline", skip_all, fields(internal.visibility = "user"))]
-pub fn pipeline_for_query_request(
+pub fn command_for_query_request(
     config: &MongoConfiguration,
     query_plan: &QueryPlan,
-) -> Result<Pipeline> {
+) -> Result<AggregateCommand> {
     if let Some(variable_sets) = &query_plan.variables {
-        pipeline_for_foreach(variable_sets, config, query_plan)
+        command_for_foreach(variable_sets, config, query_plan)
     } else {
-        pipeline_for_non_foreach(config, query_plan, QueryLevel::Top)
+        let target = QueryTarget::for_request(config, query_plan);
+        let pipeline = pipeline_for_non_foreach(config, query_plan, QueryLevel::Top)?;
+        Ok(AggregateCommand {
+            collection: target.input_collection().map(ToString::to_string),
+            pipeline,
+            let_vars: None,
+        })
     }
 }
 
diff --git a/crates/mongodb-agent-common/src/query/response.rs b/crates/mongodb-agent-common/src/query/response.rs
index cec6f1b8..b761f5de 100644
--- a/crates/mongodb-agent-common/src/query/response.rs
+++ b/crates/mongodb-agent-common/src/query/response.rs
@@ -56,7 +56,11 @@ pub fn serialize_query_response(
 ) -> Result<QueryResponse> {
     let collection_name = &query_plan.collection;
 
-    let row_sets = if query_plan.has_variables() {
+    let row_sets = if query_plan
+        .variables
+        .as_ref()
+        .is_some_and(|variables| variables.len() != 1)
+    {
         response_documents
             .into_iter()
             .map(|document| {
diff --git a/crates/mongodb-support/src/aggregate/command.rs b/crates/mongodb-support/src/aggregate/command.rs
new file mode 100644
index 00000000..b9cc24db
--- /dev/null
+++ b/crates/mongodb-support/src/aggregate/command.rs
@@ -0,0 +1,12 @@
+use mongodb::bson::Document;
+
+use super::Pipeline;
+
+/// Aggregate command used with, e.g., `db.<collection>.aggregate()`
+///
+/// This is not a complete implementation - only the fields needed by the connector are listed.
+pub struct AggregateCommand {
+    pub collection: Option<String>,
+    pub pipeline: Pipeline,
+    pub let_vars: Option<Document>,
+}
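Pulling the pieces together, a single-variable-set query rendered as this struct might look like the following; values are borrowed from the integration tests and the empty pipeline is a placeholder:

```rust
use mongodb::bson::doc;
use mongodb_support::aggregate::{AggregateCommand, Pipeline};

// Illustrative only - stages omitted for brevity.
fn sketch_command() -> AggregateCommand {
    AggregateCommand {
        // Some(..) -> db.<collection>.aggregate, None -> db.aggregate
        collection: Some("title_word_frequency".to_string()),
        pipeline: Pipeline { stages: vec![] },
        // forwarded to the driver as AggregateOptions::let_vars
        let_vars: Some(doc! { "count": 3 }),
    }
}
```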
diff --git a/crates/mongodb-support/src/aggregate/mod.rs b/crates/mongodb-support/src/aggregate/mod.rs
index dfab9856..afd65bb3 100644
--- a/crates/mongodb-support/src/aggregate/mod.rs
+++ b/crates/mongodb-support/src/aggregate/mod.rs
@@ -1,10 +1,12 @@
 mod accumulator;
+mod command;
 mod pipeline;
 mod selection;
 mod sort_document;
 mod stage;
 
 pub use self::accumulator::Accumulator;
+pub use self::command::AggregateCommand;
 pub use self::pipeline::Pipeline;
 pub use self::selection::Selection;
 pub use self::sort_document::SortDocument;