Skip to content

Commit e3629fa

Browse files
authored
feat: improve A_Expr inference (#387)
1 parent f44a45a commit e3629fa

File tree

8 files changed

+175
-356
lines changed

8 files changed

+175
-356
lines changed

.changeset/gentle-lizards-jump.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@ts-safeql/generate": patch
3+
"@ts-safeql/shared": patch
4+
---
5+
6+
improved A_Expr inference

packages/generate/src/ast-describe.ts

Lines changed: 97 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { defaultTypeExprMapping, fmap, normalizeIndent } from "@ts-safeql/shared";
1+
import { fmap, normalizeIndent } from "@ts-safeql/shared";
22
import * as LibPgQueryAST from "@ts-safeql/sql-ast";
33
import {
44
isColumnStarRef,
@@ -16,6 +16,7 @@ type ASTDescriptionOptions = {
1616
parsed: LibPgQueryAST.ParseResult;
1717
relations: FlattenedRelationWithJoins[];
1818
typesMap: Map<string, { override: boolean; value: string }>;
19+
typeExprMap: Map<string, Map<string, Map<string, string>>>;
1920
overridenColumnTypesMap: Map<string, Map<string, string>>;
2021
nonNullableColumns: Set<string>;
2122
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
@@ -274,61 +275,119 @@ function getDescribedAExpr({
274275

275276
if (column === undefined) return null;
276277

277-
if (column.type.kind === "array") {
278-
return { value: "array", nullable: false };
279-
}
278+
const getFromType = (
279+
type: ASTDescribedColumnType,
280+
): { value: string; array: boolean; nullable: boolean } | null => {
281+
switch (true) {
282+
case type.kind === "type":
283+
return { value: type.base ?? type.type, array: false, nullable: false };
280284

281-
if (column.type.kind === "type") {
282-
return { value: column.type.base ?? column.type.type, nullable: false };
283-
}
285+
case type.kind === "literal" && type.base.kind === "type":
286+
return { value: type.base.type, array: false, nullable: false };
284287

285-
if (column.type.kind === "literal" && column.type.base.kind === "type") {
286-
return { value: column.type.base.type, nullable: false };
287-
}
288+
case type.kind === "union" && type.value.every((x) => x.kind === "literal"): {
289+
const resolved = getFromType(type.value[0].base);
290+
291+
if (resolved === null) return null;
292+
293+
return { value: resolved.value, nullable: false, array: false };
294+
}
295+
296+
case type.kind === "union" && isTuple(type.value): {
297+
let nullable = false;
298+
let value: string | undefined = undefined;
299+
300+
for (const valueType of type.value) {
301+
if (valueType.kind !== "type") return null;
302+
if (valueType.value === "null") nullable = true;
303+
if (valueType.value !== "null") value = valueType.type;
304+
}
305+
306+
if (value === undefined) return null;
288307

289-
if (column.type.kind === "union" && isTuple(column.type.value)) {
290-
let nullable = false;
291-
let value: string | undefined = undefined;
308+
return { value, nullable, array: false };
309+
}
292310

293-
for (const type of column.type.value) {
294-
if (type.kind !== "type") return null;
295-
if (type.value === "null") nullable = true;
296-
if (type.value !== "null") value = type.type;
311+
default:
312+
return null;
297313
}
314+
};
315+
316+
if (column.type.kind === "array") {
317+
const resolved = getFromType(column.type.value);
298318

299-
if (value === undefined) return null;
319+
if (!resolved) return null;
300320

301-
return { value, nullable };
321+
return { value: resolved.value, nullable: resolved.nullable, array: true };
302322
}
303323

304-
return null;
324+
return getFromType(column.type);
305325
};
306326

307327
const lnode = getResolvedNullableValueOrNull(node.lexpr);
308328
const rnode = getResolvedNullableValueOrNull(node.rexpr);
329+
const operator = concatStringNodes(node.name);
309330

310331
if (lnode === null || rnode === null) {
311332
return [];
312333
}
313334

314-
const operator = concatStringNodes(node.name);
315-
const resolved: string | undefined =
316-
defaultTypeExprMapping[`${lnode.value} ${operator} ${rnode.value}`];
335+
const downcast = () => {
336+
const left = lnode.array ? `_${lnode.value}` : lnode.value;
337+
const right = rnode.array ? `_${rnode.value}` : rnode.value;
338+
339+
const overrides: Record<string, [string, string, string]> = {
340+
"int4 ^ int4": ["float8", "^", "float8"],
341+
};
342+
343+
if (overrides[`${left} ${operator} ${right}`]) {
344+
return overrides[`${left} ${operator} ${right}`];
345+
}
346+
347+
const adjust = (value: string) => (value === "varchar" ? "text" : value);
348+
349+
return [adjust(left), operator, adjust(right)];
350+
};
317351

318-
if (resolved === undefined) {
352+
const getType = (): ASTDescribedColumnType | undefined => {
353+
const nullable = !context.nonNullableColumns.has(name) && (lnode.nullable || rnode.nullable);
354+
const [dleft, doperator, dright] = downcast();
355+
356+
const type =
357+
context.typeExprMap.get(dleft)?.get(doperator)?.get(dright) ??
358+
context.typeExprMap.get("anycompatiblearray")?.get(operator)?.get("anycompatiblearray") ??
359+
context.typeExprMap.get("anyarray")?.get(operator)?.get("anyarray") ??
360+
context.typeExprMap.get(lnode.value)?.get(operator)?.values().next().value;
361+
362+
if (type === undefined) {
363+
return;
364+
}
365+
366+
if (type === "anycompatiblearray") {
367+
return {
368+
kind: "array",
369+
value: resolveType({
370+
context,
371+
nullable,
372+
type: context.toTypeScriptType({ name: lnode.value }),
373+
}),
374+
};
375+
}
376+
377+
return resolveType({
378+
context,
379+
nullable,
380+
type: context.toTypeScriptType({ name: type }),
381+
});
382+
};
383+
384+
const type = getType();
385+
386+
if (type === undefined) {
319387
return [];
320388
}
321389

322-
return [
323-
{
324-
name: name,
325-
type: resolveType({
326-
context: context,
327-
nullable: !context.nonNullableColumns.has(name) && (lnode.nullable || rnode.nullable),
328-
type: context.toTypeScriptType({ name: resolved }),
329-
}),
330-
},
331-
];
390+
return [{ name, type }];
332391
}
333392

334393
function getDescribedNullTest({
@@ -815,14 +874,14 @@ function getColumnRefOrigins({
815874
// lookup in cte
816875
context.select.withClause?.ctes
817876
.find((cte) => cte.CommonTableExpr?.ctename === source)
818-
?.CommonTableExpr?.ctequery?.SelectStmt?.targetList?.map((x) => x.ResTarget)
819-
.find((x) => x?.name === column)?.val ??
877+
?.CommonTableExpr?.ctequery?.SelectStmt?.targetList?.find(
878+
(x) => x.ResTarget?.name === column,
879+
) ??
820880
// lookup in subselect
821881
context.select.fromClause
822882
?.map((from) => from.RangeSubselect)
823883
.find((subselect) => subselect?.alias?.aliasname === source)
824-
?.subquery?.SelectStmt?.targetList?.map((x) => x.ResTarget)
825-
.find((x) => x?.name === column)?.val;
884+
?.subquery?.SelectStmt?.targetList?.find((x) => x.ResTarget?.name === column);
826885

827886
if (!origin) return undefined;
828887

packages/generate/src/generate.test.ts

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ const testQuery = async (params: {
220220
({ output, unknownColumns }) => {
221221
assert.deepEqual(output?.value ?? null, params.expected);
222222

223-
if (unknownColumns.length > 0) {
223+
if (unknownColumns.length > 0 || params.unknownColumns) {
224224
assert.deepEqual(unknownColumns, params.unknownColumns);
225225
}
226226
},
@@ -737,7 +737,6 @@ test("select from subselect with an alias", async () => {
737737
await testQuery({
738738
query: `SELECT subselect.id FROM (SELECT * FROM caregiver) AS subselect`,
739739
expected: [["id", { kind: "type", value: "number", type: "int4" }]],
740-
unknownColumns: ["id"],
741740
});
742741
});
743742

@@ -1971,46 +1970,48 @@ test("ARRAY[2, 3] <@ ARRAY[1, 2, 3] => boolean", async () => {
19711970
test("ARRAY[1, 2] || ARRAY[3, 4] => array", async () => {
19721971
await testQuery({
19731972
query: `SELECT ARRAY[1, 2] || ARRAY[3, 4]`,
1974-
expected: [["?column?", { kind: "type", value: "array", type: "array" }]],
1973+
expected: [
1974+
["?column?", { kind: "array", value: { kind: "type", type: "int4", value: "number" } }],
1975+
],
19751976
});
19761977
});
19771978

1978-
test("'{\"key\": \"value\"}'::jsonb ? 'key' => boolean", async () => {
1979+
test(`'{"key": "value"}'::jsonb ? 'key' => boolean`, async () => {
19791980
await testQuery({
19801981
query: `SELECT '{"key": "value"}'::jsonb ? 'key'`,
19811982
expected: [["?column?", { kind: "type", value: "boolean", type: "bool" }]],
19821983
});
19831984
});
19841985

1985-
test("'{\"a\": 1, \"b\": 2}'::jsonb ?| array['a', 'c'] => boolean", async () => {
1986+
test(`'{"a": 1, "b": 2}'::jsonb ?| array['a', 'c'] => boolean`, async () => {
19861987
await testQuery({
19871988
query: `SELECT '{"a": 1, "b": 2}'::jsonb ?| array['a', 'c']`,
19881989
expected: [["?column?", { kind: "type", value: "boolean", type: "bool" }]],
19891990
});
19901991
});
19911992

1992-
test("'{\"a\": 1, \"b\": 2}'::jsonb ?& array['a', 'b'] => boolean", async () => {
1993+
test(`'{"a": 1, "b": 2}'::jsonb ?& array['a', 'b'] => boolean`, async () => {
19931994
await testQuery({
19941995
query: `SELECT '{"a": 1, "b": 2}'::jsonb ?& array['a', 'b']`,
19951996
expected: [["?column?", { kind: "type", value: "boolean", type: "bool" }]],
19961997
});
19971998
});
19981999

1999-
test("'{\"a\": {\"b\": 1}}'::jsonb -> 'a' => jsonb", async () => {
2000+
test(`'{"a": {"b": 1}}'::jsonb -> 'a' => jsonb`, async () => {
20002001
await testQuery({
20012002
query: `SELECT '{"a": {"b": 1}}'::jsonb -> 'a'`,
20022003
expected: [["?column?", { kind: "type", type: "jsonb", value: "any" }]],
20032004
});
20042005
});
20052006

2006-
test("'{\"a\": {\"b\": 1}}'::jsonb ->> 'a' => string", async () => {
2007+
test(`'{"a": {"b": 1}}'::jsonb ->> 'a' => string`, async () => {
20072008
await testQuery({
20082009
query: `SELECT '{"a": {"b": 1}}'::jsonb ->> 'a'`,
20092010
expected: [["?column?", { kind: "type", value: "string", type: "text" }]],
20102011
});
20112012
});
20122013

2013-
test("'{\"a\": 1, \"b\": 2}'::jsonb #- '{a}' => jsonb", async () => {
2014+
test(`'{"a": 1, "b": 2}'::jsonb #- '{a}' => jsonb`, async () => {
20142015
await testQuery({
20152016
query: `SELECT '{"a": 1, "b": 2}'::jsonb #- '{a}'`,
20162017
expected: [["?column?", { kind: "type", value: "any", type: "jsonb" }]],
@@ -2126,7 +2127,6 @@ test("with select from inner join and left join", async () => {
21262127
INNER JOIN caregiver_agency ON x.id = caregiver_agency.caregiver_id
21272128
LEFT JOIN agency ON caregiver_agency.agency_id = agency.id
21282129
`,
2129-
unknownColumns: ["id"],
21302130
expected: [
21312131
["id", { kind: "type", type: "int4", value: "number" }],
21322132
[
@@ -2170,7 +2170,6 @@ test("select colref and const from left joined using col", async () => {
21702170
},
21712171
],
21722172
],
2173-
unknownColumns: ["value"], // TODO: `ast-get-source` needs to be refactored to handle this case
21742173
});
21752174
});
21762175

@@ -2276,3 +2275,19 @@ test("select col.tbl from cte with array agg and col filter", async () => {
22762275
],
22772276
});
22782277
});
2278+
2279+
test("varchar not like expr", async () => {
2280+
await testQuery({
2281+
schema: `CREATE TABLE tbl (email varchar(80) NOT NULL)`,
2282+
query: `SELECT jsonb_build_object('key', tbl.email NOT LIKE '%@example.com') AS col FROM tbl`,
2283+
expected: [
2284+
[
2285+
"col",
2286+
{
2287+
kind: "object",
2288+
value: [["key", { kind: "type", type: "bool", value: "boolean" }]],
2289+
},
2290+
],
2291+
],
2292+
});
2293+
});

packages/generate/src/generate.ts

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ type Cache = {
7272
CacheKey,
7373
{
7474
pgTypes: PgTypesMap;
75-
pgCols: PgColRow[];
75+
pgCols: postgres.RowList<PgColRow[]>;
7676
pgEnums: PgEnumsMaps;
7777
pgColsByTableOidCache: Map<number, PgColRow[]>;
7878
pgColsBySchemaAndTableName: Map<string, Map<string, PgColRow[]>>;
7979
pgFnsByName: Map<string, PgFnRow[]>;
80+
pgTypeExprMap: Map<string, Map<string, Map<string, string>>>;
8081
}
8182
>;
8283
overrides: {
@@ -128,13 +129,19 @@ async function generate(
128129
): Promise<either.Either<GenerateError, GenerateResult>> {
129130
const { sql, query, cacheKey, cacheMetadata = true } = params;
130131

131-
const { pgColsByTableOidCache, pgColsBySchemaAndTableName, pgTypes, pgEnums, pgFnsByName } =
132-
await getOrSetFromMapWithEnabled({
133-
shouldCache: cacheMetadata,
134-
map: cache.base,
135-
key: cacheKey,
136-
value: () => getDatabaseMetadata(sql),
137-
});
132+
const {
133+
pgColsByTableOidCache,
134+
pgColsBySchemaAndTableName,
135+
pgTypes,
136+
pgEnums,
137+
pgFnsByName,
138+
pgTypeExprMap,
139+
} = await getOrSetFromMapWithEnabled({
140+
shouldCache: cacheMetadata,
141+
map: cache.base,
142+
key: cacheKey,
143+
value: () => getDatabaseMetadata(sql),
144+
});
138145

139146
const typesMap = await getOrSetFromMapWithEnabled({
140147
shouldCache: cacheMetadata,
@@ -266,6 +273,7 @@ async function generate(
266273
pgTypes: pgTypes,
267274
pgEnums: pgEnums,
268275
pgFns: functionsMap,
276+
typeExprMap: pgTypeExprMap,
269277
});
270278

271279
const columns = result.columns.map((col, position): ColumnAnalysisResult => {
@@ -329,6 +337,7 @@ async function getDatabaseMetadata(sql: Sql) {
329337
const pgColsByTableOidCache = groupBy(pgCols, "tableOid");
330338
const pgColsBySchemaAndTableName = groupBy(pgCols, "schemaName", "tableName");
331339
const pgFnsByName = groupBy(pgFns, "name");
340+
const pgTypeExprMap = await getPgTypeExprMap(sql);
332341

333342
return {
334343
pgTypes,
@@ -337,6 +346,7 @@ async function getDatabaseMetadata(sql: Sql) {
337346
pgColsByTableOidCache,
338347
pgColsBySchemaAndTableName,
339348
pgFnsByName,
349+
pgTypeExprMap,
340350
};
341351
}
342352

@@ -696,3 +706,30 @@ async function getPgFunctions(sql: Sql) {
696706
returnType: row.returnType,
697707
}));
698708
}
709+
710+
async function getPgTypeExprMap(sql: Sql): Promise<Map<string, Map<string, Map<string, string>>>> {
711+
const rows = await sql<{ left: string; operator: string; right: string; result: string }[]>`
712+
SELECT
713+
l.typname as left,
714+
o.oprname as operator,
715+
r.typname as right,
716+
ret.typname AS result
717+
FROM pg_operator o
718+
JOIN pg_type l ON l.oid = o.oprleft
719+
JOIN pg_type r ON r.oid = o.oprright
720+
JOIN pg_type ret ON ret.oid = o.oprresult
721+
WHERE o.oprleft <> 0 AND o.oprright <> 0
722+
`;
723+
724+
const map = new Map<string, Map<string, Map<string, string>>>();
725+
726+
for (const row of rows) {
727+
const leftMap = map.get(row.left) ?? new Map<string, Map<string, string>>();
728+
const rightMap = leftMap.get(row.operator) ?? new Map<string, string>();
729+
rightMap.set(row.right, row.result);
730+
leftMap.set(row.operator, rightMap);
731+
map.set(row.left, leftMap);
732+
}
733+
734+
return map;
735+
}

0 commit comments

Comments
 (0)