Skip to content

Java and Python examples for ONNX #1544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions bootcamp/tutorials/quickstart/onnx_example/java/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
##############################
## Java
##############################
.mtj.tmp/
*.class
*.jar
*.war
*.ear
*.nar
hs_err_pid*

##############################
## Maven
##############################
target/
pom.xml.tag
pom.xml.releaseBackup
pom.xml.versionsBackup
pom.xml.next
pom.xml.bak
release.properties
dependency-reduced-pom.xml
buildNumber.properties
.mvn/timing.properties
.mvn/wrapper/maven-wrapper.jar

##############################
## Gradle
##############################
bin/
build/
.gradle
.gradletasknamecache
gradle-app.setting
!gradle-wrapper.jar

##############################
## IntelliJ
##############################
out/
.idea/
.idea_modules/
*.iml
*.ipr
*.iws

##############################
## Eclipse
##############################
.settings/
bin/
tmp/
.metadata
.classpath
.project
*.tmp
*.bak
*.swp
*~.nib
local.properties
.loadpath
.factorypath

##############################
## NetBeans
##############################
nbproject/private/
build/
nbbuild/
dist/
nbdist/
nbactions.xml
nb-configuration.xml

##############################
## Visual Studio Code
##############################
.vscode/
.code-workspace

##############################
## OS X
##############################
.DS_Store
50 changes: 50 additions & 0 deletions bootcamp/tutorials/quickstart/onnx_example/java/app/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* This file was generated by the Gradle 'init' task.
*
* This generated file contains a sample Java application project to get you started.
* For more details on building Java & JVM projects, please refer to https://docs.gradle.org/8.13/userguide/building_java_projects.html in the Gradle documentation.
*/

plugins {
// Apply the application plugin to add support for building a CLI application in Java.
id 'application'
}

repositories {
// Use Maven Central for resolving dependencies.
mavenCentral()
}

dependencies {
// Use JUnit Jupiter for testing.
testImplementation libs.junit.jupiter

testRuntimeOnly 'org.junit.platform:junit-platform-launcher'

// This dependency is used by the application.
implementation libs.guava

// onnxruntime full package
implementation 'com.microsoft.onnxruntime:onnxruntime:latest.release'
// onnxruntime-extensions package
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions:latest.release'

implementation 'io.milvus:milvus-sdk-java:2.5.7'
}

// Apply a specific Java toolchain to ease working on different environments.
java {
toolchain {
languageVersion = JavaLanguageVersion.of(21)
}
}

application {
// Define the main class for the application.
mainClass = 'org.example.App'
}

tasks.named('test') {
// Use JUnit Platform for unit tests.
useJUnitPlatform()
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* This source file was generated by the Gradle 'init' task
*/
package org.example;

import java.util.Arrays;
import java.util.Map;

import ai.onnxruntime.*;
import ai.onnxruntime.extensions.*;

@SuppressWarnings("unused")
public class App {
public static void main(String[] args) {
var env = OrtEnvironment.getEnvironment();

try {
var sess_opt = new OrtSession.SessionOptions();

// NOTE: ONNXRuntimeExtensions for Java on Apple Silicon isn't currently
// available
// Hence I'll comment out these lines for my machine
// sess_opt.registerCustomOpLibrary(OrtxPackage.getLibraryPath());

// Depending on how you call App within Visual Studio, may need to add app/ to
// filenames below
System.out.println(System.getProperty("user.dir"));

// Try out tokenizer
// var tokenizer = env.createSession("tokenizer.onnx", sess_opt);

// Get input and output node names for tokenizer
// var inputName = tokenizer.getInputNames().iterator().next();
// var outputName = tokenizer.getOutputNames().iterator().next();
// System.out.println(inputName);
// System.out.println(outputName);

// Try out embedding model
var session = env.createSession("model.onnx", sess_opt);

// Get input and output names
var inputName = session.getInputNames().iterator().next();
var outputName = session.getOutputNames().iterator().next();
System.out.println(inputName);
System.out.println(outputName);

// Since I wasn't able to run tokenizer via ONNX on Apple Silicon, hardcode
// token ids from Python
// This is for "The quick brown fox..."
long[][] tokens = { { 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 1012, 102 } };
long[][] masks = { { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
long[][] token_type_ids = { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } };

// Wrap native Java types in OnnxTensor's and put in input map
var test_tokens = OnnxTensor.createTensor(env, tokens);
var test_mask = OnnxTensor.createTensor(env, masks);
var test_token_type_ids = OnnxTensor.createTensor(env, token_type_ids);
var inputs = Map.of("input_ids", test_tokens, "attention_mask", test_mask, "token_type_ids",
test_token_type_ids);

// Run embedding model on tokens and convert back to native Java type
var results = session.run(inputs).get("embeddings");
float[][] embeddings = (float[][]) results.get().getValue();

// Print the first 16 dimensions of the resulting embedding
var result = Arrays.toString(Arrays.copyOfRange(embeddings[0], 0, 16));
System.out.println(result);

// Comparing this to our Python notebook, we see it is identical! I.e. calling
// the ONNXRuntime on the same model from Python and Java produces same results
// (up to numerical precision etc.)
} catch (Exception e) {
System.out.println(e);
}
}
}

// 0.146325, 0.32853213, 0.266175, 0.5182375, 0.20214303, -0.17958449,
// 0.15232176, -0.39807054, -0.037162323, -0.057262924, 0.12987728, 0.13251846
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* This source file was generated by the Gradle 'init' task
*/
package org.example;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

class AppTest {
@Test void appHasAGreeting() {
App classUnderTest = new App();
assertNotNull(classUnderTest.getGreeting(), "app should have a greeting");
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# This file was generated by the Gradle 'init' task.
# https://docs.gradle.org/current/userguide/build_environment.html#sec:gradle_configuration_properties

org.gradle.configuration-cache=true

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file was generated by the Gradle 'init' task.
# https://docs.gradle.org/current/userguide/platforms.html#sub::toml-dependencies-format

[versions]
guava = "33.3.1-jre"
junit-jupiter = "5.11.3"

[libraries]
guava = { module = "com.google.guava:guava", version.ref = "guava" }
junit-jupiter = { module = "org.junit.jupiter:junit-jupiter", version.ref = "junit-jupiter" }
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.13-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
Loading