Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 259 additions & 6 deletions packages/3-extensions/sql-orm-client/src/mutation-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
} from '@prisma-next/sql-relational-core/ast';
import type { ExecutionContext } from '@prisma-next/sql-relational-core/query-lane-context';
import type { RuntimeScope } from '@prisma-next/sql-relational-core/types';
import { blindCast } from '@prisma-next/utils/casts';
import {
getColumnToFieldMap,
resolveFieldToColumn,
Expand All @@ -23,6 +24,8 @@ import {
import { executeQueryPlan } from './execute-query-plan';
import { and, shorthandToWhereExpr } from './filters';
import {
compileDeleteCount,
compileInsertCount,
compileInsertReturning,
compileSelect,
compileUpdateCount,
Expand All @@ -44,13 +47,30 @@ import type {
} from './types';
import { emptyState } from './types';

interface JunctionThrough {
readonly table: string;
readonly parentColumns: readonly string[];
readonly childColumns: readonly string[];
readonly targetColumns: readonly string[];
readonly requiredPayloadColumns: readonly string[];
}

interface RelationDefinition {
readonly relationName: string;
readonly relatedModelName: string;
readonly relatedTableName: string;
readonly cardinality: RelationCardinalityTag | undefined;
readonly localColumns: readonly string[];
readonly targetColumns: readonly string[];
readonly through: JunctionThrough | undefined;
}

interface JunctionRelationDefinition extends RelationDefinition {
readonly through: JunctionThrough;
}

function hasThrough(relation: RelationDefinition): relation is JunctionRelationDefinition {
return relation.through !== undefined;
}

interface ParsedRelationMutation {
Expand Down Expand Up @@ -164,7 +184,7 @@ async function createGraph(
): Promise<Record<string, unknown>> {
const contract = context.contract;
const parsed = parseMutationInput(contract, modelName, input);
const { parentOwned, childOwned } = partitionByOwnership(parsed.relationMutations);
const { parentOwned, childOwned, junctionOwned } = partitionByOwnership(parsed.relationMutations);

const scalarData = { ...parsed.scalarData };

Expand Down Expand Up @@ -200,6 +220,21 @@ async function createGraph(
);
}

for (const relationMutation of junctionOwned) {
if (relationMutation.mutation.kind === 'disconnect') {
throw new Error('disconnect() is only supported in update() nested mutations');
}

await applyJunctionOwnedMutation(
scope,
context,
modelName,
parentRow,
relationMutation.relation,
relationMutation.mutation,
);
}

return parentRow;
}

Expand All @@ -217,7 +252,7 @@ async function updateFirstGraph(
}

const parsed = parseMutationInput(contract, modelName, input as Record<string, unknown>);
const { parentOwned, childOwned } = partitionByOwnership(parsed.relationMutations);
const { parentOwned, childOwned, junctionOwned } = partitionByOwnership(parsed.relationMutations);

const scalarData = { ...parsed.scalarData };

Expand Down Expand Up @@ -284,6 +319,17 @@ async function updateFirstGraph(
);
}

for (const relationMutation of junctionOwned) {
await applyJunctionOwnedMutation(
scope,
context,
modelName,
parentRow,
relationMutation.relation,
relationMutation.mutation,
);
}

return parentRow;
}

Expand Down Expand Up @@ -335,21 +381,31 @@ function parseMutationInput(
};
}

interface JunctionParsedRelationMutation extends ParsedRelationMutation {
readonly relation: JunctionRelationDefinition;
}

function partitionByOwnership(relationMutations: readonly ParsedRelationMutation[]): {
parentOwned: ParsedRelationMutation[];
childOwned: ParsedRelationMutation[];
junctionOwned: JunctionParsedRelationMutation[];
} {
const parentOwned: ParsedRelationMutation[] = [];
const childOwned: ParsedRelationMutation[] = [];
const junctionOwned: JunctionParsedRelationMutation[] = [];

for (const relationMutation of relationMutations) {
if (relationMutation.relation.cardinality === 'N:1') {
parentOwned.push(relationMutation);
if (hasThrough(relationMutation.relation)) {
junctionOwned.push({
relation: relationMutation.relation,
mutation: relationMutation.mutation,
});
continue;
}

if (relationMutation.relation.cardinality === 'N:M') {
throw new Error('N:M nested mutations are not supported yet');
if (relationMutation.relation.cardinality === 'N:1') {
parentOwned.push(relationMutation);
continue;
}

childOwned.push(relationMutation);
Expand All @@ -358,6 +414,7 @@ function partitionByOwnership(relationMutations: readonly ParsedRelationMutation
return {
parentOwned,
childOwned,
junctionOwned,
};
}

Expand Down Expand Up @@ -527,6 +584,193 @@ async function applyChildOwnedMutation(
}
}

async function applyJunctionOwnedMutation(
scope: RuntimeScope,
context: ExecutionContext,
parentModelName: string,
parentRow: Record<string, unknown>,
relation: JunctionRelationDefinition,
mutation: RelationMutation<Contract<SqlStorage>, string>,
): Promise<void> {
const contract = context.contract;
const through = relation.through;
const parentPkValues = readJunctionParentValues(contract, parentModelName, relation, parentRow);

if (mutation.kind === 'create' || mutation.kind === 'connect') {
if (through.requiredPayloadColumns.length > 0) {
const cols = through.requiredPayloadColumns.map((c) => `\`${c}\``).join(', ');
throw new Error(
`Cannot \`${mutation.kind}\` on relation \`${relation.relationName}\`: its junction \`${through.table}\` has required column(s) ${cols} the relation API can't populate. Use the \`${relation.relatedModelName}\` model directly or the SQL builder.`,
);
}
}

if (mutation.kind === 'create') {
for (const childInput of mutation.data) {
const relatedRow = await insertSingleRow(
scope,
context,
relation.relatedModelName,
blindCast<Record<string, unknown>, 'mutation create input is a plain object payload'>(
childInput,
),
);
const targetPkValues = readJunctionTargetValues(contract, relation, relatedRow);
await insertJunctionLink(scope, context, through, parentPkValues, targetPkValues);
}
return;
}

if (mutation.kind === 'connect') {
for (const criterion of mutation.criteria) {
const targetPkValues = await resolveJunctionTargetValues(
scope,
context,
relation,
'connect',
criterion,
);
await insertJunctionLink(scope, context, through, parentPkValues, targetPkValues);
}
return;
}

if (!mutation.criteria || mutation.criteria.length === 0) {
throw new Error(
`disconnect() nested mutation for relation "${relation.relationName}" requires criterion`,
);
}

for (const criterion of mutation.criteria) {
const targetPkValues = await resolveJunctionTargetValues(
scope,
context,
relation,
'disconnect',
criterion,
);
await deleteJunctionLink(scope, context, through, parentPkValues, targetPkValues);
}
}

async function resolveJunctionTargetValues(
scope: RuntimeScope,
context: ExecutionContext,
relation: JunctionRelationDefinition,
kind: 'connect' | 'disconnect',
criterion: unknown,
): Promise<Map<string, unknown>> {
const relatedRow = await findRowByCriterion(
scope,
context,
relation.relatedModelName,
blindCast<Record<string, unknown>, 'connect/disconnect criterion is a plain object'>(criterion),
);
if (!relatedRow) {
throw new Error(
`${kind}() nested mutation for relation "${relation.relationName}" did not find a matching row`,
);
}
return readJunctionTargetValues(context.contract, relation, relatedRow);
}

function readJunctionParentValues(
contract: Contract<SqlStorage>,
parentModelName: string,
relation: JunctionRelationDefinition,
parentRow: Record<string, unknown>,
): Map<string, unknown> {
const values = new Map<string, unknown>();

for (let i = 0; i < relation.through.parentColumns.length; i++) {
const junctionColumn = relation.through.parentColumns[i];
const parentColumn = relation.localColumns[i];
if (!junctionColumn || !parentColumn) {
continue;
}

const parentFieldName = toFieldName(contract, parentModelName, parentColumn);
const parentValue = parentRow[parentFieldName];
if (parentValue === undefined) {
throw new Error(
`Nested mutation requires parent field "${parentFieldName}" to be present in returned row`,
);
}

values.set(junctionColumn, parentValue);
}

return values;
}

function readJunctionTargetValues(
contract: Contract<SqlStorage>,
relation: JunctionRelationDefinition,
relatedRow: Record<string, unknown>,
): Map<string, unknown> {
const values = new Map<string, unknown>();

for (let i = 0; i < relation.through.childColumns.length; i++) {
const junctionColumn = relation.through.childColumns[i];
const targetColumn = relation.through.targetColumns[i];
if (!junctionColumn || !targetColumn) {
continue;
}

const targetFieldName = toFieldName(contract, relation.relatedModelName, targetColumn);
const targetValue = relatedRow[targetFieldName];
if (targetValue === undefined) {
throw new Error(
`Nested mutation requires target field "${targetFieldName}" to be present in returned row`,
);
}

values.set(junctionColumn, targetValue);
}

return values;
}

async function insertJunctionLink(
scope: RuntimeScope,
context: ExecutionContext,
through: JunctionThrough,
parentPkValues: Map<string, unknown>,
targetPkValues: Map<string, unknown>,
): Promise<void> {
const junctionRow: Record<string, unknown> = {};
for (const [column, value] of parentPkValues.entries()) {
junctionRow[column] = value;
}
for (const [column, value] of targetPkValues.entries()) {
junctionRow[column] = value;
}

const compiled = compileInsertCount(context.contract, through.table, [junctionRow]);
await executeQueryPlan<Record<string, unknown>>(scope, compiled).toArray();
}

async function deleteJunctionLink(
scope: RuntimeScope,
context: ExecutionContext,
through: JunctionThrough,
parentPkValues: Map<string, unknown>,
targetPkValues: Map<string, unknown>,
): Promise<void> {
const exprs: AnyExpression[] = [];
for (const [column, value] of parentPkValues.entries()) {
exprs.push(BinaryExpr.eq(ColumnRef.of(through.table, column), LiteralExpr.of(value)));
}
for (const [column, value] of targetPkValues.entries()) {
exprs.push(BinaryExpr.eq(ColumnRef.of(through.table, column), LiteralExpr.of(value)));
}

const first = exprs[0];
const where = exprs.length === 1 && first !== undefined ? first : and(...exprs);
const compiled = compileDeleteCount(context.contract, through.table, [where]);
await executeQueryPlan<Record<string, unknown>>(scope, compiled).toArray();
}

function readParentColumnValues(
contract: Contract<SqlStorage>,
parentModelName: string,
Expand Down Expand Up @@ -701,6 +945,15 @@ function getRelationDefinitions(
targetColumns: relation.on.targetFields.map((f) =>
resolveFieldToColumn(contract, relation.to, f),
),
through: relation.through
? {
table: relation.through.table,
parentColumns: relation.through.parentColumns,
childColumns: relation.through.childColumns,
targetColumns: relation.through.targetColumns,
requiredPayloadColumns: relation.through.requiredPayloadColumns,
}
: undefined,
}));

perContract.set(modelName, definitions);
Expand Down
Loading
Loading