Skip to content

Commit 941169d

Browse files
authored
feat: add type notation to type imports (#264)
1 parent 38c795f commit 941169d

File tree

3 files changed

+80
-38
lines changed

3 files changed

+80
-38
lines changed

src/core/generate.test.ts

+36-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ describe("generate", () => {
326326
expect(getZodSchemasFile("./villain")).toMatchInlineSnapshot(`
327327
"// Generated by ts-to-zod
328328
import { z } from "zod";
329-
import { Villain, EvilPlan, EvilPlanDetails } from "./villain";
329+
import { type Villain, type EvilPlan, type EvilPlanDetails } from "./villain";
330330
331331
export const villainSchema: z.ZodSchema<Villain> = z.lazy(() => z.object({
332332
name: z.string(),
@@ -961,6 +961,41 @@ describe("generate", () => {
961961
`);
962962
});
963963
});
964+
965+
describe("with mixed imports", () => {
966+
const input = "./person";
967+
968+
const sourceText = `
969+
import { PersonEnum } from "${input}"
970+
971+
export interface Hero {
972+
id: number
973+
hero: PersonEnum.Hero
974+
parent: Hero
975+
}
976+
`;
977+
978+
const { getZodSchemasFile } = generate({
979+
sourceText,
980+
});
981+
982+
it("should add type notation to non-enum imports", () => {
983+
expect(getZodSchemasFile(input)).toMatchInlineSnapshot(`
984+
"// Generated by ts-to-zod
985+
import { z } from "zod";
986+
import { type Hero } from "${input}";
987+
988+
import { PersonEnum } from "${input}";
989+
990+
export const heroSchema: z.ZodSchema<Hero> = z.lazy(() => z.object({
991+
id: z.number(),
992+
hero: z.literal(PersonEnum.Hero),
993+
parent: heroSchema
994+
}));
995+
"
996+
`);
997+
});
998+
});
964999
});
9651000

9661001
describe("with input/output mappings to manage imports", () => {

src/core/generate.ts

+41-34
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ export function generate({
266266
varName,
267267
zodImportValue: "z",
268268
}),
269-
requiresImport: false,
269+
enumImport: false,
270270
typeName: importName,
271271
varName,
272272
};
@@ -282,8 +282,9 @@ export function generate({
282282
{ typeName: string; value: ts.VariableStatement }
283283
>();
284284

285-
// Keep track of types which need to be imported from the source file
285+
// Keep track of types/enums which need to be imported from the source file
286286
const sourceTypeImports: Set<string> = new Set();
287+
const sourceEnumImports: Set<string> = new Set();
287288

288289
// Zod schemas with direct or indirect dependencies that are not in `zodSchemas`, won't be generated
289290
const zodSchemasWithMissingDependencies = new Set<string>();
@@ -302,40 +303,38 @@ export function generate({
302303
!statements.has(varName) &&
303304
!zodSchemasWithMissingDependencies.has(varName)
304305
)
305-
.forEach(
306-
({ varName, dependencies, statement, typeName, requiresImport }) => {
307-
const isCircular = dependencies.includes(varName);
308-
const notGeneratedDependencies = dependencies
309-
.filter((dep) => dep !== varName)
310-
.filter((dep) => !statements.has(dep))
311-
.filter((dep) => !importedZodSchemas.has(dep));
312-
if (notGeneratedDependencies.length === 0) {
313-
done = false;
314-
if (isCircular) {
315-
sourceTypeImports.add(typeName);
316-
statements.set(varName, {
317-
value: transformRecursiveSchema("z", statement, typeName),
318-
typeName,
319-
});
320-
} else {
321-
if (requiresImport) {
322-
sourceTypeImports.add(typeName);
323-
}
324-
statements.set(varName, { value: statement, typeName });
306+
.forEach(({ varName, dependencies, statement, typeName, enumImport }) => {
307+
const isCircular = dependencies.includes(varName);
308+
const notGeneratedDependencies = dependencies
309+
.filter((dep) => dep !== varName)
310+
.filter((dep) => !statements.has(dep))
311+
.filter((dep) => !importedZodSchemas.has(dep));
312+
if (notGeneratedDependencies.length === 0) {
313+
done = false;
314+
if (isCircular) {
315+
sourceTypeImports.add(typeName);
316+
statements.set(varName, {
317+
value: transformRecursiveSchema("z", statement, typeName),
318+
typeName,
319+
});
320+
} else {
321+
if (enumImport) {
322+
sourceEnumImports.add(typeName);
325323
}
326-
} else if (
327-
// Check if every dependency is (in `zodSchemas` and not in `zodSchemasWithMissingDependencies`)
328-
!notGeneratedDependencies.every(
329-
(dep) =>
330-
zodSchemaNames.includes(dep) &&
331-
!zodSchemasWithMissingDependencies.has(dep)
332-
)
333-
) {
334-
done = false;
335-
zodSchemasWithMissingDependencies.add(varName);
324+
statements.set(varName, { value: statement, typeName });
336325
}
326+
} else if (
327+
// Check if every dependency is (in `zodSchemas` and not in `zodSchemasWithMissingDependencies`)
328+
!notGeneratedDependencies.every(
329+
(dep) =>
330+
zodSchemaNames.includes(dep) &&
331+
!zodSchemasWithMissingDependencies.has(dep)
332+
)
333+
) {
334+
done = false;
335+
zodSchemasWithMissingDependencies.add(varName);
337336
}
338-
);
337+
});
339338
}
340339

341340
// Generate remaining schemas, which have circular dependencies with loop of length > 1 like: A->B—>C->A
@@ -390,7 +389,15 @@ ${Array.from(zodSchemasWithMissingDependencies).join("\n")}`
390389
)
391390
);
392391

393-
const sourceTypeImportsValues = Array.from(sourceTypeImports.values());
392+
const sourceTypeImportsValues = [
393+
...sourceTypeImports.values(),
394+
...sourceEnumImports.values(),
395+
].map((name) => {
396+
return sourceEnumImports.has(name)
397+
? name // enum import, no type notation added
398+
: `type ${name}`;
399+
});
400+
394401
const getZodSchemasFile = (
395402
typesImportPath: string
396403
) => `// Generated by ts-to-zod

src/core/generateZodSchema.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ export function generateZodSchemaVariableStatement({
8484
| ts.PropertyAccessExpression
8585
| undefined;
8686
let dependencies: string[] = [];
87-
let requiresImport = false;
87+
let enumImport = false;
8888

8989
if (ts.isInterfaceDeclaration(node)) {
9090
let schemaExtensionClauses: SchemaExtensionClause[] | undefined;
@@ -177,7 +177,7 @@ export function generateZodSchemaVariableStatement({
177177

178178
if (ts.isEnumDeclaration(node)) {
179179
schema = buildZodSchema(zodImportValue, "nativeEnum", [node.name]);
180-
requiresImport = true;
180+
enumImport = true;
181181
}
182182

183183
return {
@@ -196,7 +196,7 @@ export function generateZodSchemaVariableStatement({
196196
ts.NodeFlags.Const
197197
)
198198
),
199-
requiresImport,
199+
enumImport,
200200
};
201201
}
202202

0 commit comments

Comments
 (0)