diff --git a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java index 7c046706ebd6..901545e3bdf6 100644 --- a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java +++ b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java @@ -28,9 +28,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; import java.util.SequencedMap; +import java.util.Set; import com.oracle.svm.hosted.webimage.wasm.ast.id.WasmId; import com.oracle.svm.hosted.webimage.wasmgc.ast.RecursiveGroup; @@ -65,6 +67,16 @@ public class WasmModule { protected final ActiveData activeData = new ActiveData(); + /** + * Functions that need to be declared in a declarative element segment. + *

+ * Per the WebAssembly spec, any function referenced by {@code ref.func} outside of an + * active/passive element segment must be declared in a declarative element segment. + * This is required for validation by strict validators like {@code wasm-tools validate} + * and {@code wasmtime}. + */ + protected final Set declarativeFuncRefs = new LinkedHashSet<>(); + protected StartFunction startFunction = null; public void addFunction(Function fun) { @@ -161,6 +173,17 @@ public void addActiveData(long offset, byte[] data) { activeData.addData(offset, data); } + /** + * Declares a function reference for a declarative element segment. + */ + public void addDeclarativeFuncRef(WasmId.Func func) { + declarativeFuncRefs.add(func); + } + + public Set getDeclarativeFuncRefs() { + return Collections.unmodifiableSet(declarativeFuncRefs); + } + public void constructActiveDataSegments() { // Limit the number of data segments so that we don't exceet MAX_DATA_SEGMENTS. activeData.constructDataSegments(MAX_DATA_SEGMENTS - this.dataSegments.size()).forEach(this::addData); diff --git a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java index e984ae8009d6..48ef1d1a37b6 100644 --- a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java +++ b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java @@ -27,6 +27,8 @@ import java.io.IOException; import java.io.Writer; +import java.util.LinkedHashSet; +import java.util.Set; import com.oracle.svm.hosted.webimage.options.WebImageOptions; import com.oracle.svm.hosted.webimage.wasm.WebImageWasmOptions; @@ -375,10 +377,13 @@ private void printExtensionSuffix(WasmUtil.Extension extension) { @Override @SuppressWarnings("try") public void visitModule(WasmModule m) { + collectDeclarativeFuncRefs(m); + parenOpen("module"); space(); try (var ignored = new Indenter()) { super.visitModule(m); + emitDeclarativeFuncRefs(m); } newline(); @@ -386,6 +391,67 @@ public void visitModule(WasmModule m) { newline(); } + /** + * Emits declarative element segments for all functions referenced by {@code ref.func}. + *

+ * In WAT: {@code (elem declare func $f1 $f2 ...)} + */ + private void emitDeclarativeFuncRefs(WasmModule m) { + Set funcRefs = m.getDeclarativeFuncRefs(); + if (funcRefs.isEmpty()) { + return; + } + + newline(); + newline(); + printComment("Declarative element segment for ref.func declarations"); + newline(); + parenOpen("elem declare func"); + for (WasmId.Func func : funcRefs) { + space(); + printId(func); + } + parenClose(); + } + + /** + * Scans all functions and globals for {@code ref.func} instructions and registers them + * as declarative function references in the module. + *

+ * Per the WebAssembly spec, functions referenced by {@code ref.func} outside of active + * or passive element segments must be declared in a declarative element segment + * ({@code (elem declare func ...)}). + */ + private void collectDeclarativeFuncRefs(WasmModule m) { + Set funcRefs = new LinkedHashSet<>(); + + // Collect from function bodies + RefFuncCollector collector = new RefFuncCollector(funcRefs); + for (Function func : m.getFunctions()) { + collector.visitFunction(func); + } + + // Collect from global initializers + for (Global global : m.getGlobals().sequencedValues()) { + collector.visitInstruction(global.init); + } + + // Collect from table element initializers + for (Table table : m.getTables()) { + if (table.elements != null) { + for (Instruction elem : table.elements) { + collector.visitInstruction(elem); + } + } + } + + // Functions already in active table elements don't need declarative declaration, + // but including them is harmless and simpler than filtering. + for (WasmId.Func func : funcRefs) { + m.addDeclarativeFuncRef(func); + } + } + @Override public void visitModuleField(ModuleField f) { newline(); @@ -1483,4 +1549,22 @@ public void visitAnyExternConversion(Instruction.AnyExternConversion inst) { } newline(); } + + /** + * Visitor that collects all function IDs referenced by {@code ref.func} instructions. + */ + private static class RefFuncCollector extends WasmVisitor { + + private final Set funcRefs; + + RefFuncCollector(Set funcRefs) { + this.funcRefs = funcRefs; + } + + @Override + public void visitRefFunc(Instruction.RefFunc inst) { + funcRefs.add(inst.func); + super.visitRefFunc(inst); + } + } }