Skip to content

Commit 16a9d1e

Browse files
committed
JNA callback should check inputs and outputs for null.
This addresses #27 once we merge extism/extism#760 and release a new version of libextism. This PR simply check the parameters for null and the counts to be valid, in preparation for extism/extism#760 which otherwise would cause a NullPointerException when outputs and inputs are empty. Signed-off-by: Edoardo Vacchi <[email protected]>
1 parent 1448042 commit 16a9d1e

File tree

3 files changed

+119
-75
lines changed

3 files changed

+119
-75
lines changed

Diff for: src/main/java/org/extism/sdk/HostFunction.java

+43-24
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,12 @@ public class HostFunction<T extends HostUserData> {
2020

2121
public final LibExtism.ExtismValType[] returns;
2222

23-
public final Optional<T> userData;
24-
2523
public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.ExtismValType[] returns, ExtismFunction f, Optional<T> userData) {
2624
this.freed = false;
2725
this.name = name;
2826
this.params = params;
2927
this.returns = returns;
30-
this.userData = userData;
31-
this.callback = (Pointer currentPlugin,
32-
LibExtism.ExtismVal inputs,
33-
int nInputs,
34-
LibExtism.ExtismVal outs,
35-
int nOutputs,
36-
Pointer data) -> {
37-
38-
LibExtism.ExtismVal[] outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs);
39-
40-
f.invoke(
41-
new ExtismCurrentPlugin(currentPlugin),
42-
(LibExtism.ExtismVal[]) inputs.toArray(nInputs),
43-
outputs,
44-
userData
45-
);
46-
47-
for (LibExtism.ExtismVal output : outputs) {
48-
convertOutput(output, output);
49-
}
50-
};
28+
this.callback = new Callback(f, userData);
5129

5230
this.pointer = LibExtism.INSTANCE.extism_function_new(
5331
this.name,
@@ -61,7 +39,7 @@ public HostFunction(String name, LibExtism.ExtismValType[] params, LibExtism.Ext
6139
);
6240
}
6341

64-
void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
42+
static void convertOutput(LibExtism.ExtismVal original, LibExtism.ExtismVal fromHostFunction) {
6543
if (fromHostFunction.t != original.t)
6644
throw new ExtismException(String.format("Output type mismatch, got %d but expected %d", fromHostFunction.t, original.t));
6745

@@ -103,4 +81,45 @@ public void free() {
10381
this.freed = true;
10482
}
10583
}
84+
85+
static class Callback<T> implements LibExtism.InternalExtismFunction {
86+
private final ExtismFunction f;
87+
private final Optional<T> userData;
88+
89+
public Callback(ExtismFunction f, Optional<T> userData) {
90+
this.f = f;
91+
this.userData = userData;
92+
}
93+
94+
@Override
95+
public void invoke(Pointer currentPlugin, LibExtism.ExtismVal ins, int nInputs, LibExtism.ExtismVal outs, int nOutputs, Pointer data) {
96+
97+
LibExtism.ExtismVal[] inputs;
98+
LibExtism.ExtismVal[] outputs;
99+
100+
if (outs == null) {
101+
if (nOutputs > 0) {
102+
throw new ExtismException("Output array is null but nOutputs is greater than 0");
103+
}
104+
outputs = new LibExtism.ExtismVal[0];
105+
} else {
106+
outputs = (LibExtism.ExtismVal[]) outs.toArray(nOutputs);
107+
}
108+
109+
if (ins == null) {
110+
if (nInputs > 0) {
111+
throw new ExtismException("Input array is null but nInputs is greater than 0");
112+
}
113+
inputs = new LibExtism.ExtismVal[0];
114+
} else {
115+
inputs = (LibExtism.ExtismVal[]) ins.toArray(nInputs);
116+
}
117+
118+
f.invoke(new ExtismCurrentPlugin(currentPlugin), inputs, outputs, userData);
119+
120+
for (LibExtism.ExtismVal output : outputs) {
121+
convertOutput(output, output);
122+
}
123+
}
124+
}
106125
}

Diff for: src/test/java/org/extism/sdk/HostFunctionTests.java

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package org.extism.sdk;
2+
3+
import com.sun.jna.Pointer;
4+
import org.junit.jupiter.api.Test;
5+
6+
import static org.junit.jupiter.api.Assertions.assertThrows;
7+
8+
public class HostFunctionTests {
9+
@Test
10+
public void callbackShouldAcceptNullParameters() {
11+
var callback = new HostFunction.Callback<>(
12+
(plugin, params, returns, userData) -> {/* NOOP */}, null);
13+
callback.invoke(Pointer.NULL, null, 0, null, 0, Pointer.NULL);
14+
}
15+
16+
@Test
17+
public void callbackShouldThrowOnNullParametersAndNonzeroCounts() {
18+
var callback = new HostFunction.Callback<>(
19+
(plugin, params, returns, userData) -> {/* NOOP */}, null);
20+
assertThrows(ExtismException.class, () ->
21+
callback.invoke(Pointer.NULL, null, 1, null, 0, Pointer.NULL));
22+
assertThrows(ExtismException.class, () ->
23+
callback.invoke(Pointer.NULL, null, 0, null, 1, Pointer.NULL));
24+
}
25+
}

Diff for: src/test/java/org/extism/sdk/PluginTests.java

+51-51
Original file line numberDiff line numberDiff line change
@@ -53,57 +53,57 @@ public void shouldInvokeFunctionFromUrlWasmSource() {
5353
assertThat(output).isEqualTo("{\"count\":4,\"total\":4,\"vowels\":\"aeiouyAEIOUY\"}");
5454
}
5555

56-
// @Test
57-
// public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() {
58-
// var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm";
59-
// var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url)));
60-
//
61-
// // Our application KV store
62-
// // Pretend this is redis or a database :)
63-
// var kvStore = new HashMap<String, byte[]>();
64-
//
65-
// ExtismFunction kvWrite = (plugin, params, returns, data) -> {
66-
// System.out.println("Hello from Java Host Function!");
67-
// var key = plugin.inputString(params[0]);
68-
// var value = plugin.inputBytes(params[1]);
69-
// System.out.println("Writing to key " + key);
70-
// kvStore.put(key, value);
71-
// };
72-
//
73-
// ExtismFunction kvRead = (plugin, params, returns, data) -> {
74-
// System.out.println("Hello from Java Host Function!");
75-
// var key = plugin.inputString(params[0]);
76-
// System.out.println("Reading from key " + key);
77-
// var value = kvStore.get(key);
78-
// if (value == null) {
79-
// // default to zeroed bytes
80-
// var zero = new byte[]{0,0,0,0};
81-
// plugin.returnBytes(returns[0], zero);
82-
// } else {
83-
// plugin.returnBytes(returns[0], value);
84-
// }
85-
// };
86-
//
87-
// HostFunction kvWriteHostFn = new HostFunction<>(
88-
// "kv_write",
89-
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64},
90-
// new LibExtism.ExtismValType[0],
91-
// kvWrite,
92-
// Optional.empty()
93-
// );
94-
//
95-
// HostFunction kvReadHostFn = new HostFunction<>(
96-
// "kv_read",
97-
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
98-
// new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
99-
// kvRead,
100-
// Optional.empty()
101-
// );
102-
//
103-
// HostFunction[] functions = {kvWriteHostFn, kvReadHostFn};
104-
// var plugin = new Plugin(manifest, false, functions);
105-
// var output = plugin.call("count_vowels", "Hello, World!");
106-
// }
56+
@Test
57+
public void shouldInvokeFunctionFromUrlWasmSourceHostFuncs() {
58+
var url = "https://github.com/extism/plugins/releases/latest/download/count_vowels_kvstore.wasm";
59+
var manifest = new Manifest(List.of(UrlWasmSource.fromUrl(url)));
60+
61+
// Our application KV store
62+
// Pretend this is redis or a database :)
63+
var kvStore = new HashMap<String, byte[]>();
64+
65+
ExtismFunction kvWrite = (plugin, params, returns, data) -> {
66+
System.out.println("Hello from Java Host Function!");
67+
var key = plugin.inputString(params[0]);
68+
var value = plugin.inputBytes(params[1]);
69+
System.out.println("Writing to key " + key);
70+
kvStore.put(key, value);
71+
};
72+
73+
ExtismFunction kvRead = (plugin, params, returns, data) -> {
74+
System.out.println("Hello from Java Host Function!");
75+
var key = plugin.inputString(params[0]);
76+
System.out.println("Reading from key " + key);
77+
var value = kvStore.get(key);
78+
if (value == null) {
79+
// default to zeroed bytes
80+
var zero = new byte[]{0,0,0,0};
81+
plugin.returnBytes(returns[0], zero);
82+
} else {
83+
plugin.returnBytes(returns[0], value);
84+
}
85+
};
86+
87+
HostFunction kvWriteHostFn = new HostFunction<>(
88+
"kv_write",
89+
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64, LibExtism.ExtismValType.I64},
90+
new LibExtism.ExtismValType[0],
91+
kvWrite,
92+
Optional.empty()
93+
);
94+
95+
HostFunction kvReadHostFn = new HostFunction<>(
96+
"kv_read",
97+
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
98+
new LibExtism.ExtismValType[]{LibExtism.ExtismValType.I64},
99+
kvRead,
100+
Optional.empty()
101+
);
102+
103+
HostFunction[] functions = {kvWriteHostFn, kvReadHostFn};
104+
var plugin = new Plugin(manifest, false, functions);
105+
var output = plugin.call("count_vowels", "Hello, World!");
106+
}
107107

108108
@Test
109109
public void shouldInvokeFunctionFromByteArrayWasmSource() {

0 commit comments

Comments
 (0)