Skip to content

Commit dfe1a42

Browse files
authored
Merge pull request #1544 from stefanwebb/onnx_example
Java and Python examples for ONNX
2 parents 40f6c55 + 1414663 commit dfe1a42

File tree

15 files changed

+1723
-0
lines changed

15 files changed

+1723
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
##############################
2+
## Java
3+
##############################
4+
.mtj.tmp/
5+
*.class
6+
*.jar
7+
*.war
8+
*.ear
9+
*.nar
10+
hs_err_pid*
11+
12+
##############################
13+
## Maven
14+
##############################
15+
target/
16+
pom.xml.tag
17+
pom.xml.releaseBackup
18+
pom.xml.versionsBackup
19+
pom.xml.next
20+
pom.xml.bak
21+
release.properties
22+
dependency-reduced-pom.xml
23+
buildNumber.properties
24+
.mvn/timing.properties
25+
.mvn/wrapper/maven-wrapper.jar
26+
27+
##############################
28+
## Gradle
29+
##############################
30+
bin/
31+
build/
32+
.gradle
33+
.gradletasknamecache
34+
gradle-app.setting
35+
!gradle-wrapper.jar
36+
37+
##############################
38+
## IntelliJ
39+
##############################
40+
out/
41+
.idea/
42+
.idea_modules/
43+
*.iml
44+
*.ipr
45+
*.iws
46+
47+
##############################
48+
## Eclipse
49+
##############################
50+
.settings/
51+
bin/
52+
tmp/
53+
.metadata
54+
.classpath
55+
.project
56+
*.tmp
57+
*.bak
58+
*.swp
59+
*~.nib
60+
local.properties
61+
.loadpath
62+
.factorypath
63+
64+
##############################
65+
## NetBeans
66+
##############################
67+
nbproject/private/
68+
build/
69+
nbbuild/
70+
dist/
71+
nbdist/
72+
nbactions.xml
73+
nb-configuration.xml
74+
75+
##############################
76+
## Visual Studio Code
77+
##############################
78+
.vscode/
79+
.code-workspace
80+
81+
##############################
82+
## OS X
83+
##############################
84+
.DS_Store
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* This file was generated by the Gradle 'init' task.
3+
*
4+
* This generated file contains a sample Java application project to get you started.
5+
* For more details on building Java & JVM projects, please refer to https://docs.gradle.org/8.13/userguide/building_java_projects.html in the Gradle documentation.
6+
*/
7+
8+
plugins {
9+
// Apply the application plugin to add support for building a CLI application in Java.
10+
id 'application'
11+
}
12+
13+
repositories {
14+
// Use Maven Central for resolving dependencies.
15+
mavenCentral()
16+
}
17+
18+
dependencies {
19+
// Use JUnit Jupiter for testing.
20+
testImplementation libs.junit.jupiter
21+
22+
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
23+
24+
// This dependency is used by the application.
25+
implementation libs.guava
26+
27+
// onnxruntime full package
28+
implementation 'com.microsoft.onnxruntime:onnxruntime:latest.release'
29+
// onnxruntime-extensions package
30+
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions:latest.release'
31+
32+
implementation 'io.milvus:milvus-sdk-java:2.5.7'
33+
}
34+
35+
// Apply a specific Java toolchain to ease working on different environments.
36+
java {
37+
toolchain {
38+
languageVersion = JavaLanguageVersion.of(21)
39+
}
40+
}
41+
42+
application {
43+
// Define the main class for the application.
44+
mainClass = 'org.example.App'
45+
}
46+
47+
tasks.named('test') {
48+
// Use JUnit Platform for unit tests.
49+
useJUnitPlatform()
50+
}
Binary file not shown.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* This source file was generated by the Gradle 'init' task
3+
*/
4+
package org.example;
5+
6+
import java.util.Arrays;
7+
import java.util.Map;
8+
9+
import ai.onnxruntime.*;
10+
import ai.onnxruntime.extensions.*;
11+
12+
@SuppressWarnings("unused")
13+
public class App {
14+
public static void main(String[] args) {
15+
var env = OrtEnvironment.getEnvironment();
16+
17+
try {
18+
var sess_opt = new OrtSession.SessionOptions();
19+
20+
// NOTE: ONNXRuntimeExtensions for Java on Apple Silicon isn't currently
21+
// available
22+
// Hence I'll comment out these lines for my machine
23+
// sess_opt.registerCustomOpLibrary(OrtxPackage.getLibraryPath());
24+
25+
// Depending on how you call App within Visual Studio, may need to add app/ to
26+
// filenames below
27+
System.out.println(System.getProperty("user.dir"));
28+
29+
// Try out tokenizer
30+
// var tokenizer = env.createSession("tokenizer.onnx", sess_opt);
31+
32+
// Get input and output node names for tokenizer
33+
// var inputName = tokenizer.getInputNames().iterator().next();
34+
// var outputName = tokenizer.getOutputNames().iterator().next();
35+
// System.out.println(inputName);
36+
// System.out.println(outputName);
37+
38+
// Try out embedding model
39+
var session = env.createSession("model.onnx", sess_opt);
40+
41+
// Get input and output names
42+
var inputName = session.getInputNames().iterator().next();
43+
var outputName = session.getOutputNames().iterator().next();
44+
System.out.println(inputName);
45+
System.out.println(outputName);
46+
47+
// Since I wasn't able to run tokenizer via ONNX on Apple Silicon, hardcode
48+
// token ids from Python
49+
// This is for "The quick brown fox..."
50+
long[][] tokens = { { 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 1012, 102 } };
51+
long[][] masks = { { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
52+
long[][] token_type_ids = { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } };
53+
54+
// Wrap native Java types in OnnxTensor's and put in input map
55+
var test_tokens = OnnxTensor.createTensor(env, tokens);
56+
var test_mask = OnnxTensor.createTensor(env, masks);
57+
var test_token_type_ids = OnnxTensor.createTensor(env, token_type_ids);
58+
var inputs = Map.of("input_ids", test_tokens, "attention_mask", test_mask, "token_type_ids",
59+
test_token_type_ids);
60+
61+
// Run embedding model on tokens and convert back to native Java type
62+
var results = session.run(inputs).get("embeddings");
63+
float[][] embeddings = (float[][]) results.get().getValue();
64+
65+
// Print the first 16 dimensions of the resulting embedding
66+
var result = Arrays.toString(Arrays.copyOfRange(embeddings[0], 0, 16));
67+
System.out.println(result);
68+
69+
// Comparing this to our Python notebook, we see it is identical! I.e. calling
70+
// the ONNXRuntime on the same model from Python and Java produces same results
71+
// (up to numerical precision etc.)
72+
} catch (Exception e) {
73+
System.out.println(e);
74+
}
75+
}
76+
}
77+
78+
// 0.146325, 0.32853213, 0.266175, 0.5182375, 0.20214303, -0.17958449,
79+
// 0.15232176, -0.39807054, -0.037162323, -0.057262924, 0.12987728, 0.13251846
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* This source file was generated by the Gradle 'init' task
3+
*/
4+
package org.example;
5+
6+
import org.junit.jupiter.api.Test;
7+
import static org.junit.jupiter.api.Assertions.*;
8+
9+
class AppTest {
10+
@Test void appHasAGreeting() {
11+
App classUnderTest = new App();
12+
assertNotNull(classUnderTest.getGreeting(), "app should have a greeting");
13+
}
14+
}
Binary file not shown.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# This file was generated by the Gradle 'init' task.
2+
# https://docs.gradle.org/current/userguide/build_environment.html#sec:gradle_configuration_properties
3+
4+
org.gradle.configuration-cache=true
5+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# This file was generated by the Gradle 'init' task.
2+
# https://docs.gradle.org/current/userguide/platforms.html#sub::toml-dependencies-format
3+
4+
[versions]
5+
guava = "33.3.1-jre"
6+
junit-jupiter = "5.11.3"
7+
8+
[libraries]
9+
guava = { module = "com.google.guava:guava", version.ref = "guava" }
10+
junit-jupiter = { module = "org.junit.jupiter:junit-jupiter", version.ref = "junit-jupiter" }
Binary file not shown.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
distributionBase=GRADLE_USER_HOME
2+
distributionPath=wrapper/dists
3+
distributionUrl=https\://services.gradle.org/distributions/gradle-8.13-bin.zip
4+
networkTimeout=10000
5+
validateDistributionUrl=true
6+
zipStoreBase=GRADLE_USER_HOME
7+
zipStorePath=wrapper/dists

0 commit comments

Comments
 (0)