Skip to content

Commit 1414663

Browse files
committed
finishing tidying up example code
1 parent 8415e7e commit 1414663

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

bootcamp/tutorials/quickstart/onnx_example/java/app/src/main/java/org/example/App.java

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,62 +7,68 @@
77
import java.util.Map;
88

99
import ai.onnxruntime.*;
10-
// import ai.onnxruntime.extensions.*;
1110
import ai.onnxruntime.extensions.*;
1211

1312
@SuppressWarnings("unused")
1413
public class App {
15-
public String getGreeting() {
16-
return "Hello World 2!";
17-
}
18-
1914
public static void main(String[] args) {
2015
var env = OrtEnvironment.getEnvironment();
2116

2217
try {
2318
var sess_opt = new OrtSession.SessionOptions();
19+
20+
// NOTE: ONNXRuntimeExtensions for Java on Apple Silicon isn't currently
21+
// available
22+
// Hence I'll comment out these lines for my machine
2423
// sess_opt.registerCustomOpLibrary(OrtxPackage.getLibraryPath());
24+
25+
// Depending on how you call App within Visual Studio, may need to add app/ to
26+
// filenames below
2527
System.out.println(System.getProperty("user.dir"));
2628

27-
// var tokenizer = env.createSession("app/tokenizer.onnx", new
28-
// OrtSession.SessionOptions());
29+
// Try out tokenizer
30+
// var tokenizer = env.createSession("tokenizer.onnx", sess_opt);
2931

32+
// Get input and output node names for tokenizer
3033
// var inputName = tokenizer.getInputNames().iterator().next();
3134
// var outputName = tokenizer.getOutputNames().iterator().next();
32-
3335
// System.out.println(inputName);
3436
// System.out.println(outputName);
3537

36-
var session = env.createSession("app/model.onnx", sess_opt);
38+
// Try out embedding model
39+
var session = env.createSession("model.onnx", sess_opt);
3740

3841
// Get input and output names
39-
4042
var inputName = session.getInputNames().iterator().next();
4143
var outputName = session.getOutputNames().iterator().next();
44+
System.out.println(inputName);
45+
System.out.println(outputName);
4246

43-
// System.out.println(inputName);
44-
// System.out.println(outputName);
45-
47+
// Since I wasn't able to run tokenizer via ONNX on Apple Silicon, hardcode
48+
// token ids from Python
49+
// This is for "The quick brown fox..."
4650
long[][] tokens = { { 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 1012, 102 } };
4751
long[][] masks = { { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
4852
long[][] token_type_ids = { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } };
4953

54+
// Wrap native Java types in OnnxTensor's and put in input map
5055
var test_tokens = OnnxTensor.createTensor(env, tokens);
5156
var test_mask = OnnxTensor.createTensor(env, masks);
5257
var test_token_type_ids = OnnxTensor.createTensor(env, token_type_ids);
53-
5458
var inputs = Map.of("input_ids", test_tokens, "attention_mask", test_mask, "token_type_ids",
5559
test_token_type_ids);
56-
var results = session.run(inputs).get("embeddings");
5760

58-
// System.out.println("type");
59-
// System.out.println(results.get().getType());
60-
61-
// float[][][] embeddings = (float[][][]) results.get().getValue();
61+
// Run embedding model on tokens and convert back to native Java type
62+
var results = session.run(inputs).get("embeddings");
6263
float[][] embeddings = (float[][]) results.get().getValue();
6364

65+
// Print the first 16 dimensions of the resulting embedding
6466
var result = Arrays.toString(Arrays.copyOfRange(embeddings[0], 0, 16));
6567
System.out.println(result);
68+
69+
// Comparing this to our Python notebook, we see it is identical! I.e. calling
70+
// the ONNXRuntime on the same model from Python and Java produces same results
71+
// (up to numerical precision etc.)
6672
} catch (Exception e) {
6773
System.out.println(e);
6874
}

bootcamp/tutorials/quickstart/onnx_example/python/export_hf_to_onnx.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,9 @@
688688
"id": "987319b9",
689689
"metadata": {},
690690
"source": [
691-
"The embeddings between the ONNX model and the saved HuggingFace weights will differ as they're from different trainings."
691+
"The embeddings between the ONNX model and the saved HuggingFace weights will differ as they're from different trainings.\n",
692+
"\n",
693+
"In the corresponding Java app, you can compare the embedding below to the one produced via Java."
692694
]
693695
},
694696
{

0 commit comments

Comments
 (0)