|
7 | 7 | import java.util.Map; |
8 | 8 |
|
9 | 9 | import ai.onnxruntime.*; |
10 | | -// import ai.onnxruntime.extensions.*; |
11 | 10 | import ai.onnxruntime.extensions.*; |
12 | 11 |
|
13 | 12 | @SuppressWarnings("unused") |
14 | 13 | public class App { |
15 | | - public String getGreeting() { |
16 | | - return "Hello World 2!"; |
17 | | - } |
18 | | - |
19 | 14 | public static void main(String[] args) { |
20 | 15 | var env = OrtEnvironment.getEnvironment(); |
21 | 16 |
|
22 | 17 | try { |
23 | 18 | var sess_opt = new OrtSession.SessionOptions(); |
| 19 | + |
| 20 | + // NOTE: ONNXRuntimeExtensions for Java isn't currently available on Apple |
| 21 | + // Silicon, so the custom-op library registration below is commented out on |
| 22 | + // my machine |
24 | 23 | // sess_opt.registerCustomOpLibrary(OrtxPackage.getLibraryPath()); |
| 24 | + |
| 25 | + // Depending on how you call App within Visual Studio, you may need to |
| 26 | + // prepend app/ to the filenames below |
25 | 27 | System.out.println(System.getProperty("user.dir")); |
26 | 28 |
|
27 | | - // var tokenizer = env.createSession("app/tokenizer.onnx", new |
28 | | - // OrtSession.SessionOptions()); |
| 29 | + // Try out tokenizer |
| 30 | + // var tokenizer = env.createSession("tokenizer.onnx", sess_opt); |
29 | 31 |
|
| 32 | + // Get input and output node names for tokenizer |
30 | 33 | // var inputName = tokenizer.getInputNames().iterator().next(); |
31 | 34 | // var outputName = tokenizer.getOutputNames().iterator().next(); |
32 | | - |
33 | 35 | // System.out.println(inputName); |
34 | 36 | // System.out.println(outputName); |
35 | 37 |
|
36 | | - var session = env.createSession("app/model.onnx", sess_opt); |
| 38 | + // Try out embedding model |
| 39 | + var session = env.createSession("model.onnx", sess_opt); |
37 | 40 |
|
38 | 41 | // Get input and output names |
39 | | - |
40 | 42 | var inputName = session.getInputNames().iterator().next(); |
41 | 43 | var outputName = session.getOutputNames().iterator().next(); |
| 44 | + System.out.println(inputName); |
| 45 | + System.out.println(outputName); |
42 | 46 |
|
43 | | - // System.out.println(inputName); |
44 | | - // System.out.println(outputName); |
45 | | - |
| 47 | + // Since I wasn't able to run tokenizer via ONNX on Apple Silicon, hardcode |
| 48 | + // token ids from Python |
| 49 | + // This is for "The quick brown fox..." |
46 | 50 | long[][] tokens = { { 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 1012, 102 } }; |
47 | 51 | long[][] masks = { { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 } }; |
48 | 52 | long[][] token_type_ids = { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } }; |
49 | 53 |
|
| 54 | + // Wrap native Java types in OnnxTensors and put them in the input map |
50 | 55 | var test_tokens = OnnxTensor.createTensor(env, tokens); |
51 | 56 | var test_mask = OnnxTensor.createTensor(env, masks); |
52 | 57 | var test_token_type_ids = OnnxTensor.createTensor(env, token_type_ids); |
53 | | - |
54 | 58 | var inputs = Map.of("input_ids", test_tokens, "attention_mask", test_mask, "token_type_ids", |
55 | 59 | test_token_type_ids); |
56 | | - var results = session.run(inputs).get("embeddings"); |
57 | 60 |
|
58 | | - // System.out.println("type"); |
59 | | - // System.out.println(results.get().getType()); |
60 | | - |
61 | | - // float[][][] embeddings = (float[][][]) results.get().getValue(); |
| 61 | + // Run embedding model on tokens and convert back to native Java type |
| 62 | + var results = session.run(inputs).get("embeddings"); |
62 | 63 | float[][] embeddings = (float[][]) results.get().getValue(); |
63 | 64 |
|
| 65 | + // Print the first 16 dimensions of the resulting embedding |
64 | 66 | var result = Arrays.toString(Arrays.copyOfRange(embeddings[0], 0, 16)); |
65 | 67 | System.out.println(result); |
| 68 | + |
| 69 | + // Comparing this to our Python notebook, we see the output is identical, i.e. |
| 70 | + // calling the ONNXRuntime on the same model from Python and Java produces the |
| 71 | + // same results (up to numerical precision etc.) |
66 | 72 | } catch (Exception e) { |
67 | 73 | System.out.println(e); |
68 | 74 | } |
|
0 commit comments