Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
} from '@prisma-next/contract/types';
import type { SqlStorage, StorageTable } from '@prisma-next/sql-contract/types';
import { castAs } from '@prisma-next/utils/casts';
import type { RelationCardinalityTag } from './types';
import type { IncludeThroughDescriptor, RelationCardinalityTag } from './types';

type ModelStorageFields = Record<string, { column?: string }>;
type ModelEntry = {
Expand Down Expand Up @@ -221,6 +221,7 @@ export interface ResolvedIncludeRelation {
readonly targetColumn: string;
readonly localColumn: string;
readonly cardinality: RelationCardinalityTag | undefined;
readonly through?: IncludeThroughDescriptor;
}

export function resolveIncludeRelation(
Expand All @@ -245,12 +246,27 @@ export function resolveIncludeRelation(
const localColumn = resolveFieldToColumn(contract, modelName, localField);
const targetColumn = resolveFieldToColumn(contract, relation.to, targetField);

let through: IncludeThroughDescriptor | undefined;
if (relation.through !== undefined) {
const parentLocalColumns = relation.on.localFields.map((field) =>
resolveFieldToColumn(contract, modelName, field),
);
through = {
table: relation.through.table,
parentColumns: relation.through.parentColumns,
childColumns: relation.through.childColumns,
targetColumns: relation.through.targetColumns,
parentLocalColumns,
};
}

return {
relatedModelName: relation.to,
relatedTableName,
targetColumn,
localColumn,
cardinality: relation.cardinality,
...(through !== undefined ? { through } : {}),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ function dispatchWithIncludes<Row>(options: {
const generator = async function* (): AsyncGenerator<Row, void, unknown> {
const { scope, release } = await acquireRuntimeScope(runtime);
try {
const parentJoinColumns = state.includes.map((include) => include.localColumn);
const parentJoinColumns = state.includes.flatMap((include) =>
include.through !== undefined ? include.through.parentLocalColumns : [include.localColumn],
);
const { selectedForQuery: parentSelectedForQuery, hiddenColumns: hiddenParentColumns } =
augmentSelectionForJoinColumns(state.selectedFields, parentJoinColumns);
const compiled = compileSelectWithIncludes(
Expand Down
1 change: 1 addition & 0 deletions packages/3-extensions/sql-orm-client/src/collection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ export class Collection<
targetColumn: relation.targetColumn,
localColumn: relation.localColumn,
cardinality: relation.cardinality,
...(relation.through !== undefined ? { through: relation.through } : {}),
nested: nestedState,
scalar: scalarSelector,
combine: combineBranches,
Expand Down
84 changes: 78 additions & 6 deletions packages/3-extensions/sql-orm-client/src/query-plan-select.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
} from '@prisma-next/sql-relational-core/ast';
import { codecRefForStorageColumn } from '@prisma-next/sql-relational-core/codec-descriptor-registry';
import type { SqlQueryPlan } from '@prisma-next/sql-relational-core/plan';
import { castAs } from '@prisma-next/utils/casts';
import { ifDefined } from '@prisma-next/utils/defined';
import {
type PolymorphismInfo,
Expand Down Expand Up @@ -292,6 +293,53 @@ function buildNestedIncludeArtifacts(
return { projections };
}

/**
* Build the correlated WHERE and junction JOIN artifacts for a many-to-many
* include. The resulting WHERE correlates the junction to the parent rows
* (AND-ed across all column pairs for composite keys). The junction JOIN
* connects child rows to the junction via the child columns.
*/
function buildManyToManyJunctionArtifacts(
parentTableName: string,
childTableRef: string,
through: NonNullable<IncludeExpr['through']>,
): {
readonly whereExpr: AnyExpression;
readonly junctionJoin: JoinAst;
} {
const {
table: junctionTable,
parentColumns,
childColumns,
targetColumns,
parentLocalColumns,
} = through;

const joinOnPairs = childColumns.map((junctionCol, i) =>
BinaryExpr.eq(
ColumnRef.of(junctionTable, junctionCol),
ColumnRef.of(childTableRef, targetColumns[i] ?? junctionCol),
),
);
const joinOn: AnyExpression =
joinOnPairs.length === 1 ? castAs<AnyExpression>(joinOnPairs[0]!) : AndExpr.of(joinOnPairs);

const correlationPairs = parentColumns.map((junctionCol, i) =>
BinaryExpr.eq(
ColumnRef.of(junctionTable, junctionCol),
ColumnRef.of(parentTableName, parentLocalColumns[i] ?? junctionCol),
),
);
const whereExpr: AnyExpression =
correlationPairs.length === 1
? castAs<AnyExpression>(correlationPairs[0]!)
: AndExpr.of(correlationPairs);

const junctionJoin = JoinAst.inner(TableSource.named(junctionTable), joinOn, false);

return { whereExpr, junctionJoin };
}

function buildIncludeChildRowsSelect(
contract: Contract<SqlStorage>,
parentTableName: string,
Expand Down Expand Up @@ -327,11 +375,25 @@ function buildIncludeChildRowsSelect(
const childWhere = buildStateWhere(contract, childTableRef, childState, {
filterTableName: include.relatedTableName,
});
const joinExpr = BinaryExpr.eq(
ColumnRef.of(childTableRef, include.targetColumn),
ColumnRef.of(parentTableName, include.localColumn),
);
const whereExpr = childWhere ? AndExpr.of([joinExpr, childWhere]) : joinExpr;

let whereExpr: AnyExpression;
let junctionJoins: JoinAst[] = [];

if (include.through !== undefined) {
const artifacts = buildManyToManyJunctionArtifacts(
parentTableName,
childTableRef,
include.through,
);
whereExpr = childWhere ? AndExpr.of([artifacts.whereExpr, childWhere]) : artifacts.whereExpr;
junctionJoins = [artifacts.junctionJoin];
} else {
const joinExpr = BinaryExpr.eq(
ColumnRef.of(childTableRef, include.targetColumn),
ColumnRef.of(parentTableName, include.localColumn),
);
whereExpr = childWhere ? AndExpr.of([joinExpr, childWhere]) : joinExpr;
}

// `distinct()` on a non-leaf include cannot be lowered as
// `SELECT DISTINCT <scalars>, json_agg(<grandchild>) FROM ...`:
Expand Down Expand Up @@ -359,6 +421,7 @@ function buildIncludeChildRowsSelect(
hiddenOrderProjection,
aggregateOrderBy,
whereExpr,
junctionJoins,
});
}

Expand Down Expand Up @@ -392,6 +455,10 @@ function buildIncludeChildRowsSelect(
.withProjection([...childProjection, ...hiddenOrderProjection])
.withWhere(whereExpr);

if (junctionJoins.length > 0) {
childRows = childRows.withJoins(junctionJoins);
}

if (childState.distinctOn && childState.distinctOn.length > 0) {
childRows = childRows.withDistinctOn(
childState.distinctOn.map((column) => ColumnRef.of(childTableRef, column)),
Expand Down Expand Up @@ -454,6 +521,7 @@ function buildDistinctNonLeafChildRowsSelect(options: {
readonly hiddenOrderProjection: ReadonlyArray<ProjectionItem>;
readonly aggregateOrderBy: ReadonlyArray<OrderByItem> | undefined;
readonly whereExpr: AnyExpression;
readonly junctionJoins: ReadonlyArray<JoinAst>;
}): {
readonly childRows: SelectAst;
readonly childProjection: ReadonlyArray<ProjectionItem>;
Expand All @@ -470,6 +538,7 @@ function buildDistinctNonLeafChildRowsSelect(options: {
hiddenOrderProjection,
aggregateOrderBy,
whereExpr,
junctionJoins,
} = options;
const childState = include.nested;

Expand Down Expand Up @@ -511,9 +580,12 @@ function buildDistinctNonLeafChildRowsSelect(options: {
selectedForQuery,
childTableRef,
);
const baseInner = SelectAst.from(TableSource.named(include.relatedTableName, childTableAlias))
let baseInner = SelectAst.from(TableSource.named(include.relatedTableName, childTableAlias))
.withProjection([...innerScalarProjection, ...hiddenOrderProjection])
.withWhere(whereExpr);
if (junctionJoins.length > 0) {
baseInner = baseInner.withJoins(junctionJoins);
}

// `childState.distinct` is non-empty by the `isDistinctNonLeaf` guard
// at the only caller (`buildIncludeChildRowsSelect`); assert here so
Expand Down
13 changes: 13 additions & 0 deletions packages/3-extensions/sql-orm-client/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,26 @@ export interface IncludeCombine<ResultShape extends Record<string, unknown>>
readonly branches: Readonly<Record<string, IncludeCombineBranch>>;
}

export interface IncludeThroughDescriptor {
readonly table: string;
/** FK columns in the junction table that point to the parent. */
readonly parentColumns: readonly string[];
/** FK columns in the junction table that point to the target (child). */
readonly childColumns: readonly string[];
/** PK columns in the target table that the junction's childColumns reference. */
readonly targetColumns: readonly string[];
/** Resolved column names in the parent table that junction.parentColumns reference. */
readonly parentLocalColumns: readonly string[];
}

export interface IncludeExpr {
readonly relationName: string;
readonly relatedModelName: string;
readonly relatedTableName: string;
readonly targetColumn: string;
readonly localColumn: string;
readonly cardinality: RelationCardinalityTag | undefined;
readonly through?: IncludeThroughDescriptor;
readonly nested: CollectionState;
readonly scalar: IncludeScalar<unknown> | undefined;
readonly combine: Readonly<Record<string, IncludeCombineBranch>> | undefined;
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading