Skip to content

Commit 3b980f9

Browse files
authored
fix: improve type inference for array of enums (#437)
* fix: improve type inference for array of enums Enhances type inference for arrays of enums in the TypeScript generation logic, ensuring that it correctly expands to the enum's members instead of defaulting to `unknown[]`. Additionally, adds testing guidelines for running tests in the project. * lint
1 parent ac7a1b7 commit 3b980f9

File tree

5 files changed

+136
-15
lines changed

5 files changed

+136
-15
lines changed

.changeset/fancy-chefs-cough.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@ts-safeql/eslint-plugin": patch
3+
"@ts-safeql/generate": patch
4+
---
5+
6+
Fix inference for array of enums (it now expands to the enum's members instead of `unknown[]`).

.cursor/rules/testing.mdc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
---
2+
globs: *.test.ts
3+
alwaysApply: false
4+
---
5+
# Testing Rules
6+
7+
When running tests in this project, you must follow these steps:
8+
9+
1. **Navigate to the package directory**: `cd` into the specific package directory where the test is located (e.g., `cd packages/eslint-plugin`).
10+
2. **Run the test using pnpm and vitest**: Use the following command format:
11+
```bash
12+
pnpm vitest run [filename] -t "[testname]"
13+
```
14+
15+
Replace `[filename]` with the path to the test file relative to the package root (e.g., `src/rules/check-sql.test.ts`), and `[testname]` with the specific test name or a pattern matching the test name you want to run.
16+
17+
## Guidelines
18+
- Always `cd` to the package root before running `pnpm vitest`.
19+
- Use the `run` command for a single execution.
20+
- Use the `-t` flag to isolate specific tests, especially in large test files.
21+
22+
## Example
23+
24+
To run a specific test in `packages/eslint-plugin/src/rules/check-sql.test.ts`:
25+
26+
```bash
27+
cd packages/eslint-plugin
28+
pnpm vitest run src/rules/check-sql.test.ts -t "my test name"
29+
```

packages/generate/src/ast-describe.ts

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,14 @@ export function getASTDescription(params: ASTDescriptionOptions): {
114114
};
115115
}
116116

117+
const pgType = params.pgTypes.get(p.oid);
117118
const typeByOid = getTypeByOid(p.oid);
118119

119120
if (typeByOid.override) {
120121
const baseType: ASTDescribedColumnType = {
121122
kind: "type",
122123
value: typeByOid.value,
123-
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
124+
type: pgType?.name ?? "unknown",
124125
};
125126
return typeByOid.isArray ? { kind: "array", value: baseType } : baseType;
126127
}
@@ -136,25 +137,54 @@ export function getASTDescription(params: ASTDescriptionOptions): {
136137
return typeByBaseOid.isArray ? { kind: "array", value: baseType } : baseType;
137138
}
138139

139-
const enumValue = "oid" in p ? params.pgEnums.get(p.oid) : undefined;
140+
const getEnumByOid = (oid: number): ASTDescribedColumnType | undefined => {
141+
const pgEnum = params.pgEnums.get(oid);
142+
143+
if (pgEnum === undefined) {
144+
return undefined;
145+
}
140146

141-
if (enumValue !== undefined) {
142147
return {
143148
kind: "union",
144-
value: enumValue.values.map((value) => ({
149+
value: pgEnum.values.map((value) => ({
145150
kind: "type",
146151
value: `'${value}'`,
147-
type: enumValue.name,
152+
type: pgEnum.name,
148153
})),
149154
};
155+
};
156+
157+
const valueAsEnum = (() => {
158+
const enumType = getEnumByOid(p.oid);
159+
160+
if (enumType !== undefined) {
161+
return enumType;
162+
}
163+
164+
if (pgType?.typelem !== undefined && pgType.typelem !== 0) {
165+
const arrayEnumValue = getEnumByOid(pgType.typelem);
166+
167+
if (arrayEnumValue !== undefined) {
168+
return {
169+
kind: "array",
170+
value: arrayEnumValue,
171+
} satisfies ASTDescribedColumnType;
172+
}
173+
}
174+
175+
return undefined;
176+
})();
177+
178+
if (valueAsEnum !== undefined) {
179+
return valueAsEnum;
150180
}
151181

152182
const { isArray, value } = typeByBaseOid ?? typeByOid;
153183

154184
const type: ASTDescribedColumnType = {
155185
kind: "type",
156186
value: value,
157-
type: params.pgTypes.get(p.oid)?.name ?? "unknown",
187+
type: pgType?.name ?? "unknown",
158188
};
159189

160190
if (p.baseOid !== null) {

packages/generate/src/generate.test.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2482,3 +2482,29 @@ test("regression: wrong inference of nullable in aggregation", async () => {
24822482
],
24832483
});
24842484
});
2485+
2486+
test("select array of enums", async () => {
2487+
await testQuery({
2488+
schema: `
2489+
CREATE TYPE my_enum AS ENUM ('A', 'B', 'C');
2490+
CREATE TABLE test_array_enum (col my_enum[] NOT NULL);
2491+
`,
2492+
query: `SELECT col FROM test_array_enum`,
2493+
expected: [
2494+
[
2495+
"col",
2496+
{
2497+
kind: "array",
2498+
value: {
2499+
kind: "union",
2500+
value: [
2501+
{ kind: "type", value: "'A'", type: "my_enum" },
2502+
{ kind: "type", value: "'B'", type: "my_enum" },
2503+
{ kind: "type", value: "'C'", type: "my_enum" },
2504+
],
2505+
},
2506+
},
2507+
],
2508+
],
2509+
});
2510+
});

packages/generate/src/generate.ts

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import {
22
assertNever,
33
defaultTypesMap,
44
DuplicateColumnsError,
5-
fmap,
65
getOrSetFromMapWithEnabled,
76
groupBy,
87
IdentiferCase,
@@ -13,11 +12,11 @@ import {
1312
} from "@ts-safeql/shared";
1413
import * as LibPgQueryAST from "@ts-safeql/sql-ast";
1514
import { either } from "fp-ts";
15+
import * as parser from "libpg-query";
1616
import postgres from "postgres";
1717
import { ASTDescribedColumn, getASTDescription } from "./ast-describe";
1818
import { ColType } from "./utils/colTypes";
1919
import { FlattenedRelationWithJoins } from "./utils/get-relations-with-joins";
20-
import * as parser from "libpg-query";
2120
import { isParsedInsertResult, validateInsertResult } from "./utils/validate-insert";
2221

2322
type JSToPostgresTypeMap = Record<string, unknown>;
@@ -458,14 +457,44 @@ function getResolvedTargetEntry(params: {
458457
}
459458

460459
const pgTypeOid = params.col.introspected?.colBaseTypeOid ?? params.col.described.type;
460+
const pgType = params.context.pgTypes.get(pgTypeOid);
461+
462+
const getEnumByOid = (oid: number): ResolvedTarget | undefined => {
463+
const pgEnum = params.context.pgEnums.get(oid);
464+
465+
if (pgEnum === undefined) {
466+
return undefined;
467+
}
461468

462-
const valueAsEnum = fmap(
463-
params.context.pgEnums.get(pgTypeOid),
464-
({ values }): ResolvedTarget => ({
469+
return {
465470
kind: "union",
466-
value: values.map((x): ResolvedTarget => ({ kind: "type", value: `'${x}'`, type: "text" })),
467-
}),
468-
);
471+
value: pgEnum.values.map(
472+
(value): ResolvedTarget => ({
473+
kind: "type",
474+
value: `'${value}'`,
475+
type: pgEnum.name,
476+
}),
477+
),
478+
};
479+
};
480+
481+
const valueAsEnum = (() => {
482+
const enumTarget = getEnumByOid(pgTypeOid);
483+
484+
if (enumTarget !== undefined) {
485+
return enumTarget;
486+
}
487+
488+
if (pgType?.typelem !== undefined && pgType.typelem !== 0) {
489+
const elemEnumTarget = getEnumByOid(pgType.typelem);
490+
491+
if (elemEnumTarget !== undefined) {
492+
return { kind: "array", value: elemEnumTarget } satisfies ResolvedTarget;
493+
}
494+
}
495+
496+
return undefined;
497+
})();
469498

470499
const valueAsType = getTsTypeFromPgTypeOid({
471500
pgTypeOid: pgTypeOid,
@@ -599,13 +628,14 @@ async function getPgEnums(sql: Sql): Promise<PgEnumsMaps> {
599628
interface PgTypeRow {
600629
oid: number;
601630
name: ColType;
631+
typelem: number;
602632
}
603633

604634
export type PgTypesMap = Map<number, PgTypeRow>;
605635

606636
async function getPgTypes(sql: Sql): Promise<PgTypesMap> {
607637
const rows = await sql<PgTypeRow[]>`
608-
SELECT oid, typname as name FROM pg_type
638+
SELECT oid, typname as name, typelem FROM pg_type
609639
`;
610640

611641
const map = new Map<number, PgTypeRow>();

0 commit comments

Comments
 (0)