Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions src/generator-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type {
TemplatePass2,
TemplatePass1,
GenerateOptions,
ParsedLine,
} from "./types.js";

interface FunctionCallState {
Expand All @@ -26,7 +27,8 @@ interface FunctionCallState {
}

interface GeneratorState {
readonly pass1: readonly string[];
readonly repo: TemplateRepository<TemplatePass1>;
readonly pass1: readonly ParsedLine[];
readonly codeGenerator: CodeGenerator;

filePath: string;
Expand Down Expand Up @@ -156,8 +158,9 @@ function generateImpl(generatorState: GeneratorState, options: GenerateOptions)
}
};

let previousLineWasEmpty = true;
for (let i = 0; i < pass1.length; i++) {
const line = pass1[i];
const line = pass1[i].line;
currentLine = i;
currentColumn = 0;

Expand Down Expand Up @@ -387,7 +390,26 @@ function generateImpl(generatorState: GeneratorState, options: GenerateOptions)
const maxLineNumber = pass1.length;
const lineNumberWidth = String(maxLineNumber).length;
const paddedLineNumber = String(currentLine + 1).padStart(lineNumberWidth, " ");
output("raw", `// ${paddedLineNumber} | ${line}\n`);
const sourcePath = pass1[i].codeReference.filePath;
const sourceLine = generatorState.repo.templates.get(sourcePath)!.raw[pass1[i].codeReference.lineNumber - 1];
output("raw", `// ${paddedLineNumber} | ${sourceLine}\n`);
}

if (line === "") {
if (i === pass1.length - 1) {
// If this is the last line and it's empty, we can skip it
continue;
}

if (previousLineWasEmpty) {
// If we are ignoring empty lines, skip this line
continue;
}

// When previous line was not empty, will output the current empty line but reset the flag
previousLineWasEmpty = true;
} else {
previousLineWasEmpty = false;
}

if (line.startsWith("#")) {
Expand Down Expand Up @@ -662,6 +684,7 @@ function generateImpl(generatorState: GeneratorState, options: GenerateOptions)
);
}
}
previousLineWasEmpty = true; // Reset the empty line flag after processing a preprocessor directive
} else {
processCurrentLine();
}
Expand All @@ -682,6 +705,7 @@ const generate = (
}

const generatorState: GeneratorState = {
repo,
pass1,
codeGenerator,
filePath,
Expand Down
125 changes: 52 additions & 73 deletions src/parser-impl.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import type { Parser, TemplateRepository, TemplatePass1, TemplatePass0 } from "./types.js";
import type { Parser, TemplateRepository, TemplatePass1, TemplatePass0, ParsedLine } from "./types.js";
import { WgslTemplateParseError } from "./errors.js";

type ParseState = Map<
string,
{
lines: readonly ParsedLine[];
includeProcessed: boolean;
}
>;

/**
* Parses raw content of a template file and remove comments.
*
* This function removes both single-line and multi-line comments
* while preserving the original line structure. Empty lines after comment removal
* are preserved to maintain line numbers for error reporting.
*/
function parseComments(raw: readonly string[]): string[] {
function parseComments(filePath: string, raw: readonly string[]): string[] {
const rawWithoutComments: string[] = [];
let inMultiLineComment = false;

Expand Down Expand Up @@ -47,6 +55,12 @@ function parseComments(raw: readonly string[]): string[] {
rawWithoutComments.push(processedLine.trimEnd());
}

if (inMultiLineComment) {
throw new WgslTemplateParseError("Unterminated multi-line comment detected in template", "comment-removal", {
filePath,
});
}

return rawWithoutComments;
}

Expand All @@ -56,17 +70,8 @@ function parseComments(raw: readonly string[]): string[] {
* @param includeStack represents the stack of currently processed includes
* @param parseState a map that stores the state of each file being parsed
*/
function parsePreprocessorIncludeDirectives(
includeStack: string[],
parseState: Map<
string,
{
lines: string[];
includeProcessed: boolean;
}
>
): void {
const lines: string[] = [];
function parsePreprocessorIncludeDirectives(includeStack: string[], parseState: ParseState): void {
const lines: ParsedLine[] = [];

const currentFile = includeStack[includeStack.length - 1];
const currentState = parseState.get(currentFile);
Expand All @@ -83,7 +88,8 @@ function parsePreprocessorIncludeDirectives(
}

for (let lineNumber = 0; lineNumber < currentState.lines.length; lineNumber++) {
const line = currentState.lines[lineNumber];
const parsedLine = currentState.lines[lineNumber];
const line = parsedLine.line;
// Process each line and extract include directives
const includeMatch = line.match(/^#include\s+(.+)$/);
if (includeMatch) {
Expand Down Expand Up @@ -121,8 +127,9 @@ function parsePreprocessorIncludeDirectives(
lines.push(...parseState.get(includePath)!.lines);
includeStack.pop();
} else {
// If no include directive, just add the line to the result
lines.push(line);
// If no include directive, just add the original line to the result
// (preserve comments in the output)
lines.push(parsedLine);
}
}

Expand All @@ -137,12 +144,13 @@ function parsePreprocessorIncludeDirectives(
* @param fileName Name of the file being processed (for error reporting)
* @returns Array of lines with macros defined and substituted
*/
function parseMacroDirectives(lines: string[], fileName: string): string[] {
function parseMacroDirectives(lines: ParsedLine[], fileName: string): ParsedLine[] {
const macros = new Map<string, string>();
const processedLines: string[] = [];
const processedLines: ParsedLine[] = [];

for (let lineNumber = 0; lineNumber < lines.length; lineNumber++) {
const line = lines[lineNumber];
const parsedLine = lines[lineNumber];
let line = parsedLine.line;

// Check for malformed #define directives
if (line.trim().startsWith("#define ")) {
Expand Down Expand Up @@ -236,19 +244,22 @@ function parseMacroDirectives(lines: string[], fileName: string): string[] {
}

macros.set(macroName, expandedValue);
// Don't include the #define line in output
continue;
}

// Apply macro substitutions to the current line
let processedLine = line;
for (const [macroName, macroValue] of macros) {
// Use word boundaries to ensure we only replace whole identifiers
const regex = new RegExp(`\\b${macroName}\\b`, "g");
processedLine = processedLine.replace(regex, macroValue);
// Clear the line since it's a macro definition
line = "";
} else {
// Apply macro substitutions to the current line
for (const [macroName, macroValue] of macros) {
// Use word boundaries to ensure we only replace whole identifiers
const regex = new RegExp(`\\b${macroName}\\b`, "g");
line = line.replace(regex, macroValue);
}
}

processedLines.push(processedLine);
processedLines.push({
line,
codeReference: parsedLine.codeReference,
});
}

return processedLines;
Expand All @@ -269,22 +280,20 @@ export const parser: Parser = {
*/ parse(repo: TemplateRepository<TemplatePass0>): TemplateRepository<TemplatePass1> {
const pass1Repo = new Map<string, TemplatePass1>();

const parseState = new Map<
string,
{
lines: string[];
includeProcessed: boolean;
}
>();
const parseState: ParseState = new Map();

// STEP.1. Parse comments. Segments now contains:
// - Raw segments
// - Comment segments
// STEP.1. Parse comments.
for (const [templateKey, template] of repo.templates) {
const rawWithoutComments = parseComments(template.raw);
const parsedLines = parseComments(templateKey, template.raw);

parseState.set(templateKey, {
lines: rawWithoutComments,
lines: parsedLines.map((line, index) => ({
line,
codeReference: {
filePath: templateKey,
lineNumber: index + 1, // Line numbers are 1-based
},
})),
includeProcessed: false,
});
}
Expand All @@ -295,6 +304,7 @@ export const parser: Parser = {

pass1Repo.set(templateKey, {
filePath: template.filePath,
raw: template.raw,
pass1: parseState.get(templateKey)!.lines,
});
}
Expand All @@ -304,42 +314,11 @@ export const parser: Parser = {
const processedLines = parseMacroDirectives([...template.pass1], templateKey);
pass1Repo.set(templateKey, {
filePath: template.filePath,
raw: template.raw,
pass1: processedLines,
});
}

// STEP.4. Deal with empty lines:
// - Collapse multiple empty lines to single empty line
// - Remove heading/trailing empty lines
for (const [templateKey, template] of pass1Repo) {
const lines = template.pass1;
const collapsedLines: string[] = [];
let lastLineEmpty = false;
for (const line of lines) {
const isEmpty = line.trim() === "";
if (isEmpty) {
if (!lastLineEmpty) {
collapsedLines.push(line);
lastLineEmpty = true;
}
} else {
collapsedLines.push(line);
lastLineEmpty = false;
}
}
// Remove leading/trailing empty lines
while (collapsedLines.length > 0 && collapsedLines[0].trim() === "") {
collapsedLines.shift();
}
while (collapsedLines.length > 0 && collapsedLines[collapsedLines.length - 1].trim() === "") {
collapsedLines.pop();
}
pass1Repo.set(templateKey, {
filePath: template.filePath,
pass1: collapsedLines,
});
}

return {
basePath: repo.basePath,
templates: pass1Repo,
Expand Down
12 changes: 10 additions & 2 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@ export type TemplatePass0 = TemplateBase & {
readonly raw: readonly string[];
};

export type TemplatePass1 = TemplateBase & {
export interface ParsedLine {
line: string;
codeReference: {
filePath: string;
lineNumber: number;
};
}

export type TemplatePass1 = TemplatePass0 & {
/**
* The content after pass 1 processing, including:
* - comments removal
* - #include expansion
* - #define expansion
*/
readonly pass1: readonly string[];
readonly pass1: readonly ParsedLine[];
};

export type TemplatePass2 = TemplateBase & {
Expand Down
2 changes: 1 addition & 1 deletion test/test-runner-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export async function runParserTest(testCase: TestCase, debug?: boolean): Promis

// Parse expected content into lines (normalize line endings)
const expectedLines = expectedContent.split(/\r?\n/);
const actualLines = Array.from(template.pass1);
const actualLines = template.pass1.map((line) => line.line);

// Compare line by line
if (actualLines.length !== expectedLines.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Status ApplyTemplate<"tensor/pad.wgsl.template">(ShaderHelper& shader_helper, Te
// Extract variables
auto& __var_output = *params.var_output;

(*ss_ptr) << "\n\n";
MainFunctionStart();
(*ss_ptr) << "\n ";
(*ss_ptr) << shader_helper.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
Expand All @@ -27,7 +26,6 @@ if (__param_is_float16) {
} else {
(*ss_ptr) << " bitcast<output_value_t>(uniforms.constant_value);\n";
}
(*ss_ptr) << "\n";
if (__param_dim_value_zero) {
(*ss_ptr) << " output[global_idx] = constant_value;\n";
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::string pass_as_string(T&& v) {

// Include template implementations

#include "generated/tensor/pad.h" // 6322dc19e1e013413b06fa369e2961e3683123a3f4a36e646798f3219b132f49
#include "generated/tensor/pad.h" // 0a0ab18c4abbbd85c08852d67af2741d77f44c28c604fe599c54a0d050ea354f

#pragma pop_macro("MainFunctionStart")
#pragma pop_macro("MainFunctionEnd")
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
emit("\n");
emit("\n");
MainFunctionStart();
emit("\n");
emit(" ");
Expand All @@ -12,7 +10,6 @@ emit(" bitcast<vec2<f16>>(uniforms.constant_value)[0];\n");
} else {
emit(" bitcast<output_value_t>(uniforms.constant_value);\n");
}
emit("\n");
if (param["dim_value_zero"]) {
emit(" output[global_idx] = constant_value;\n");
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
(*ss_ptr) << "\n\n";
MainFunctionStart();
(*ss_ptr) << "\n ";
(*ss_ptr) << shader_helper.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
Expand All @@ -8,7 +7,6 @@ if (__param_is_float16) {
} else {
(*ss_ptr) << " bitcast<output_value_t>(uniforms.constant_value);\n";
}
(*ss_ptr) << "\n";
if (__param_dim_value_zero) {
(*ss_ptr) << " output[global_idx] = constant_value;\n";
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
emit("\n");
emit("fn process_data() {\n");
emit(" return ");
emit(GetElementAt("buffer", `${param["CHANNEL"]} - 1u`, variable["buffer"].Rank()));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
emit("\n");
emit("fn compute_kernel() {\n");
emit(" let global_idx = get_global_id(0u);\n");
emit("\n");
Expand All @@ -21,4 +20,4 @@ emit(" ");
emit(variable["b"].SetByIndices("indices", "42.0"));
emit(";\n");
emit(" }\n");
emit("}\n");
emit("}\n");
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
emit("\n");
emit("fn my_func(value: f32) {\n");
if (param["ENABLE_FACTOR"]) {
emit(" return value + 0.1;\n");
Expand Down
Loading