diff --git a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java
index 7c046706ebd6..901545e3bdf6 100644
--- a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java
+++ b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/WasmModule.java
@@ -28,9 +28,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.SequencedMap;
+import java.util.Set;
import com.oracle.svm.hosted.webimage.wasm.ast.id.WasmId;
import com.oracle.svm.hosted.webimage.wasmgc.ast.RecursiveGroup;
@@ -65,6 +67,16 @@ public class WasmModule {
protected final ActiveData activeData = new ActiveData();
+ /**
+ * Functions that need to be declared in a declarative element segment.
+ *
+ * Per the WebAssembly spec, any function referenced by {@code ref.func} outside of an
+ * active/passive element segment must be declared in a declarative element segment.
+ * This is required for validation by strict validators like {@code wasm-tools validate}
+ * and {@code wasmtime}.
+ */
+ protected final Set declarativeFuncRefs = new LinkedHashSet<>();
+
protected StartFunction startFunction = null;
public void addFunction(Function fun) {
@@ -161,6 +173,17 @@ public void addActiveData(long offset, byte[] data) {
activeData.addData(offset, data);
}
+ /**
+ * Declares a function reference for a declarative element segment.
+ */
+ public void addDeclarativeFuncRef(WasmId.Func func) {
+ declarativeFuncRefs.add(func);
+ }
+
+ public Set getDeclarativeFuncRefs() {
+ return Collections.unmodifiableSet(declarativeFuncRefs);
+ }
+
public void constructActiveDataSegments() {
// Limit the number of data segments so that we don't exceet MAX_DATA_SEGMENTS.
activeData.constructDataSegments(MAX_DATA_SEGMENTS - this.dataSegments.size()).forEach(this::addData);
diff --git a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java
index e984ae8009d6..48ef1d1a37b6 100644
--- a/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java
+++ b/web-image/src/com.oracle.svm.hosted.webimage/src/com/oracle/svm/hosted/webimage/wasm/ast/visitors/WasmPrinter.java
@@ -27,6 +27,8 @@
import java.io.IOException;
import java.io.Writer;
+import java.util.LinkedHashSet;
+import java.util.Set;
import com.oracle.svm.hosted.webimage.options.WebImageOptions;
import com.oracle.svm.hosted.webimage.wasm.WebImageWasmOptions;
@@ -375,10 +377,13 @@ private void printExtensionSuffix(WasmUtil.Extension extension) {
@Override
@SuppressWarnings("try")
public void visitModule(WasmModule m) {
+ collectDeclarativeFuncRefs(m);
+
parenOpen("module");
space();
try (var ignored = new Indenter()) {
super.visitModule(m);
+ emitDeclarativeFuncRefs(m);
}
newline();
@@ -386,6 +391,67 @@ public void visitModule(WasmModule m) {
newline();
}
+ /**
+ * Emits declarative element segments for all functions referenced by {@code ref.func}.
+ *
+ * In WAT: {@code (elem declare func $f1 $f2 ...)}
+ */
+ private void emitDeclarativeFuncRefs(WasmModule m) {
+ Set funcRefs = m.getDeclarativeFuncRefs();
+ if (funcRefs.isEmpty()) {
+ return;
+ }
+
+ newline();
+ newline();
+ printComment("Declarative element segment for ref.func declarations");
+ newline();
+ parenOpen("elem declare func");
+ for (WasmId.Func func : funcRefs) {
+ space();
+ printId(func);
+ }
+ parenClose();
+ }
+
+ /**
+ * Scans all functions and globals for {@code ref.func} instructions and registers them
+ * as declarative function references in the module.
+ *
+ * Per the WebAssembly spec, functions referenced by {@code ref.func} outside of active
+ * or passive element segments must be declared in a declarative element segment
+ * ({@code (elem declare func ...)}).
+ */
+ private void collectDeclarativeFuncRefs(WasmModule m) {
+ Set funcRefs = new LinkedHashSet<>();
+
+ // Collect from function bodies
+ RefFuncCollector collector = new RefFuncCollector(funcRefs);
+ for (Function func : m.getFunctions()) {
+ collector.visitFunction(func);
+ }
+
+ // Collect from global initializers
+ for (Global global : m.getGlobals().sequencedValues()) {
+ collector.visitInstruction(global.init);
+ }
+
+ // Collect from table element initializers
+ for (Table table : m.getTables()) {
+ if (table.elements != null) {
+ for (Instruction elem : table.elements) {
+ collector.visitInstruction(elem);
+ }
+ }
+ }
+
+ // Functions already in active table elements don't need declarative declaration,
+ // but including them is harmless and simpler than filtering.
+ for (WasmId.Func func : funcRefs) {
+ m.addDeclarativeFuncRef(func);
+ }
+ }
+
@Override
public void visitModuleField(ModuleField f) {
newline();
@@ -1483,4 +1549,22 @@ public void visitAnyExternConversion(Instruction.AnyExternConversion inst) {
}
newline();
}
+
+ /**
+ * Visitor that collects all function IDs referenced by {@code ref.func} instructions.
+ */
+ private static class RefFuncCollector extends WasmVisitor {
+
+ private final Set funcRefs;
+
+ RefFuncCollector(Set funcRefs) {
+ this.funcRefs = funcRefs;
+ }
+
+ @Override
+ public void visitRefFunc(Instruction.RefFunc inst) {
+ funcRefs.add(inst.func);
+ super.visitRefFunc(inst);
+ }
+ }
}