Skip to content

Commit 1cb54bb

Browse files
authored
Merge pull request #37252 from vespa-engine/arnej/configurable-onnx-optimize
Make ONNX graph optimization configurable per model (default off)
2 parents 9dbf5f7 + ea92484 commit 1cb54bb

20 files changed

Lines changed: 118 additions & 11 deletions

File tree

config-model/src/main/java/com/yahoo/schema/OnnxModel.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ public class OnnxModel extends DistributableResource implements Cloneable {
3030

3131
// Runtime options
3232
private OnnxModelOptions onnxModelOptions = OnnxModelOptions.empty();
33+
private boolean optimizeModel = false;
3334

3435
public OnnxModel(String name) {
3536
super(name);
@@ -174,4 +175,12 @@ public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() {
174175

175176
public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; }
176177

178+
/**
 * Sets whether ONNX graph optimization is enabled for this model.
 *
 * @param optimizeModel true to enable optimization, false to leave it disabled
 */
public void setOptimizeModel(boolean optimizeModel) {
179+
this.optimizeModel = optimizeModel;
180+
}
181+
182+
/**
 * Returns whether ONNX graph optimization is enabled for this model.
 * The backing field is initialized to false, so this returns false unless
 * setOptimizeModel(true) has been called.
 */
public boolean getOptimizeModel() {
183+
return optimizeModel;
184+
}
185+
177186
}

config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ public FileDistributedOnnxModels clone() {
4949
private static OnnxModelsConfig.Model.Builder toConfig(OnnxModel model) {
5050
OnnxModelsConfig.Model.Builder builder = new OnnxModelsConfig.Model.Builder();
5151
builder.dry_run_on_setup(true);
52+
builder.optimize_model(model.getOptimizeModel());
5253
builder.name(model.getName());
5354
builder.fileref(model.getFileReference());
5455
model.getInputMap().forEach((name, source) -> builder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));

config-model/src/main/javacc/SchemaParser.jj

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -201,6 +201,7 @@ TOKEN :
201201
| < INTEROP_THREADS: "interop-threads">
202202
| < GPU_DEVICE: "gpu-device">
203203
| < EXECUTION_MODE: "execution-mode">
204+
| < OPTIMIZE_MODEL: "optimize-model">
204205
| < PARALLEL: "parallel">
205206
| < SEQUENTIAL: "sequential">
206207
| < MODEL: "model" >
@@ -1735,6 +1736,7 @@ void onnxModelItem(OnnxModel onnxModel) :
17351736
{
17361737
String path = null;
17371738
int num;
1739+
boolean enable;
17381740
}
17391741
{
17401742
(
@@ -1745,6 +1747,7 @@ void onnxModelItem(OnnxModel onnxModel) :
17451747
<INTEROP_THREADS> <COLON> num = integer() { onnxModel.setStatelessInterOpThreads(num); } |
17461748
<EXECUTION_MODE> <COLON> ( <PARALLEL> { onnxModel.setStatelessExecutionMode("parallel"); }
17471749
| <SEQUENTIAL> { onnxModel.setStatelessExecutionMode("sequential"); } ) |
1750+
<OPTIMIZE_MODEL> <COLON> enable = bool() { onnxModel.setOptimizeModel(enable); } |
17481751
(<ONNX_INPUT_SL>) {
17491752
String name = token.image.substring(5, token.image.lastIndexOf(":")).trim();
17501753
if (name.startsWith("\"")) { name = name.substring(1, name.length() - 1); }

config-model/src/test/derived/globalphase_onnx_inside/onnx-models.cfg

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ model[].input[].source "constant(xx)"
1010
model[].output[].name "vector_Y"
1111
model[].output[].as "out"
1212
model[].dry_run_on_setup true
13+
model[].optimize_model false
1314
model[].stateless_execution_mode ""
1415
model[].stateless_interop_threads -1
1516
model[].stateless_intraop_threads -1
@@ -26,6 +27,7 @@ model[].input[].source "rankingExpression(indirect_x)"
2627
model[].output[].name "vector_Y"
2728
model[].output[].as "foobar"
2829
model[].dry_run_on_setup true
30+
model[].optimize_model false
2931
model[].stateless_execution_mode "parallel"
3032
model[].stateless_interop_threads 5
3133
model[].stateless_intraop_threads 3
@@ -42,6 +44,7 @@ model[].input[].source "rankingExpression(indirect_x)"
4244
model[].output[].name "vector_Y"
4345
model[].output[].as "foobar"
4446
model[].dry_run_on_setup true
47+
model[].optimize_model false
4548
model[].stateless_execution_mode ""
4649
model[].stateless_interop_threads -1
4750
model[].stateless_intraop_threads -1
@@ -58,6 +61,7 @@ model[].input[].source "rankingExpression(indirect_x)"
5861
model[].output[].name "vector_Y"
5962
model[].output[].as "foobar"
6063
model[].dry_run_on_setup true
64+
model[].optimize_model false
6165
model[].stateless_execution_mode ""
6266
model[].stateless_interop_threads -1
6367
model[].stateless_intraop_threads -1

config-model/src/test/derived/globalphase_token_functions/onnx-models.cfg

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ model[].input[].source "rankingExpression(token_type_ids)"
1010
model[].output[].name "score"
1111
model[].output[].as "score"
1212
model[].dry_run_on_setup true
13+
model[].optimize_model false
1314
model[].stateless_execution_mode ""
1415
model[].stateless_interop_threads -1
1516
model[].stateless_intraop_threads -1

config-model/src/test/derived/vector_constant/onnx-models.cfg

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ model[].input[].source "constant(xx)"
1010
model[].output[].name "vector_Y"
1111
model[].output[].as "foobar"
1212
model[].dry_run_on_setup true
13+
model[].optimize_model false
1314
model[].stateless_execution_mode ""
1415
model[].stateless_interop_threads -1
1516
model[].stateless_intraop_threads -1

config-model/src/test/java/com/yahoo/schema/SchemaTestCase.java

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,33 @@ function vespa_input() {
125125
assertEquals("vespa_input", input);
126126
}
127127

128+
/**
 * Verifies that the 'optimize-model' setting of an onnx-model is parsed per model:
 * a model that does not specify it reports false, and a model that specifies
 * 'optimize-model: true' reports true.
 */
@Test
129+
void testOnnxModelOptimizeModel() throws Exception {
130+
String schema =
131+
"""
132+
schema test {
133+
document test {
134+
field id type string {
135+
indexing: summary | attribute
136+
}
137+
}
138+
onnx-model default_model {
139+
file: files/default.onnx
140+
}
141+
onnx-model optimized_model {
142+
file: files/optimized.onnx
143+
optimize-model: true
144+
}
145+
}""";
146+
ApplicationBuilder builder = new ApplicationBuilder(new DeployLoggerStub());
147+
builder.processorsToSkip().add(OnnxModelTypeResolver.class); // Avoid discovering the Onnx model referenced does not exist
148+
builder.addSchema(schema);
149+
var application = builder.build(true);
150+
var models = application.schemas().get("test").onnxModels();
151+
// default_model does not set optimize-model, so the flag is expected to be false
assertFalse(models.get("default_model").getOptimizeModel());
152+
// optimized_model sets optimize-model: true explicitly
assertTrue(models.get("optimized_model").getOptimizeModel());
153+
}
154+
128155
@Test
129156
void testSchemaInheritance() throws ParseException {
130157
String parentLines = joinLines(

config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import java.util.Map;
1919

2020
import static org.junit.jupiter.api.Assertions.assertEquals;
21+
import static org.junit.jupiter.api.Assertions.assertFalse;
2122
import static org.junit.jupiter.api.Assertions.assertTrue;
2223

2324
public class RankingExpressionWithOnnxModelTestCase {
@@ -75,6 +76,7 @@ private void assertGeneratedConfig(VespaModel vespaModel) {
7576
assertEquals(6, config.model().size());
7677
for (OnnxModelsConfig.Model model : config.model()) {
7778
assertTrue(model.dry_run_on_setup());
79+
assertFalse(model.optimize_model());
7880
}
7981

8082
OnnxModelsConfig.Model model = config.model(0);

configdefinitions/src/vespa/onnx-models.def

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ model[].input[].source string
1111
model[].output[].name string
1212
model[].output[].as string
1313
model[].dry_run_on_setup bool default=false
14+
model[].optimize_model bool default=false
1415
model[].stateless_execution_mode string default=""
1516
model[].stateless_interop_threads int default=-1
1617
model[].stateless_intraop_threads int default=-1

eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -501,6 +501,22 @@ TEST(OnnxModelCacheTest, share_and_evict_onnx_models) {
501501
EXPECT_EQ(OnnxModelCache::count_refs(), 0);
502502
}
503503

504+
// Checks that OnnxModelCache distinguishes entries by the Onnx::Optimize
// setting as well as by model file: loading the same file twice with the same
// setting yields the same underlying model, while loading it with a different
// setting yields a separate one.
TEST(OnnxModelCacheTest, optimize_setting_is_part_of_cache_key) {
505+
{
506+
auto enabled1 = OnnxModelCache::load(simple_model, Onnx::Optimize::ENABLE);
507+
auto enabled2 = OnnxModelCache::load(simple_model, Onnx::Optimize::ENABLE);
508+
auto disabled = OnnxModelCache::load(simple_model, Onnx::Optimize::DISABLE);
509+
// same file + same optimize setting shares the loaded model
510+
EXPECT_EQ(&(enabled1->get()), &(enabled2->get()));
511+
// same file but different optimize setting is a distinct cache entry
512+
EXPECT_NE(&(enabled1->get()), &(disabled->get()));
513+
// two distinct entries (ENABLE and DISABLE) referenced by three live handles
EXPECT_EQ(OnnxModelCache::num_cached(), 2);
514+
EXPECT_EQ(OnnxModelCache::count_refs(), 3);
515+
}
516+
// the handles above are out of scope here; the cache is expected to be empty again
EXPECT_EQ(OnnxModelCache::num_cached(), 0);
517+
EXPECT_EQ(OnnxModelCache::count_refs(), 0);
518+
}
519+
504520
TensorSpec val(const std::string& expr) {
505521
auto result = TensorSpec::from_expr(expr);
506522
EXPECT_FALSE(ValueType::from_spec(result.type()).is_error());

0 commit comments

Comments
 (0)