Skip to content

Commit fec03bc

Browse files
authored
Adds a shouldUseTriton method that decides whether to use Triton for ONNX inference. It is used for both the Triton sidecar and TritonOnnxRuntime, making the connection between them more explicit. (#35784)
1 parent d65cdc1 commit fec03bc

2 files changed

Lines changed: 64 additions & 41 deletions

File tree

config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,15 +248,18 @@ private void addClusterContent(ApplicationContainerCluster cluster, Element spec
248248

249249
addParameterStoreValidationHandler(cluster, deployState);
250250
}
251-
252-
private List<SidecarSpec> getSidecars(ApplicationContainerCluster cluster, DeployState deployState, NodesSpecification nodesSpecification) {
253-
var sidecars = new ArrayList<SidecarSpec>();
254-
251+
252+
private boolean shouldUseTriton(ApplicationContainerCluster cluster, DeployState deployState) {
255253
var isPublicCloud = deployState.zone().system().isPublicCloudLike();
256254
var hasOnnxModels = !cluster.onnxModelCostCalculator().models().isEmpty();
257-
var useTritonFlagValue = deployState.featureFlags().useTriton();
255+
var useTritonFeatureFlagValue = deployState.featureFlags().useTriton();
256+
return useTritonFeatureFlagValue && isPublicCloud && hasOnnxModels;
257+
}
258258

259-
if (useTritonFlagValue && isPublicCloud && hasOnnxModels) {
259+
private List<SidecarSpec> getSidecars(ApplicationContainerCluster cluster, DeployState deployState, NodesSpecification nodesSpecification) {
260+
var sidecars = new ArrayList<SidecarSpec>();
261+
262+
if (shouldUseTriton(cluster, deployState)) {
260263
var hasGpu = !nodesSpecification.minResources().nodeResources().gpuResources().isZero();
261264

262265
// Hardcoded values for changes to be reviewed and tested
@@ -974,7 +977,7 @@ protected void addModelEvaluationRuntime(DeployState deployState, ApplicationCon
974977
cluster.addPlatformBundle(ContainerModelEvaluation.MODEL_INTEGRATION_BUNDLE_FILE);
975978
cluster.addPlatformBundle(ContainerModelEvaluation.ONNXRUNTIME_BUNDLE_FILE);
976979
/* The ONNX runtime is always available for injection to any component */
977-
if (deployState.featureFlags().useTriton()) {
980+
if (shouldUseTriton(cluster, deployState)) {
978981
cluster.addSimpleComponent(
979982
ContainerModelEvaluation.TRITON_ONNX_RUNTIME_CLASS, null, ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME);
980983
} else {

config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForTritonOnnxRuntimeValidatorTest.java

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33

44
import com.yahoo.config.model.api.ConfigChangeAction;
55
import com.yahoo.config.model.api.ConfigChangeRestartAction.ConfigChange;
6+
import com.yahoo.config.model.api.OnnxModelCost;
7+
import com.yahoo.config.model.api.OnnxModelOptions;
68
import com.yahoo.config.model.deploy.DeployState;
79
import com.yahoo.config.model.deploy.TestProperties;
10+
import com.yahoo.config.provision.Environment;
11+
import com.yahoo.config.provision.RegionName;
12+
import com.yahoo.config.provision.SystemName;
13+
import com.yahoo.config.provision.Zone;
814
import com.yahoo.vespa.model.VespaModel;
915
import com.yahoo.vespa.model.application.validation.ValidationTester;
1016
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg;
1117
import org.junit.jupiter.api.Test;
1218

1319
import java.util.List;
20+
import java.util.Map;
1421

1522
import static org.junit.jupiter.api.Assertions.assertEquals;
1623
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -21,39 +28,39 @@
2128
public class RestartOnDeployForTritonOnnxRuntimeValidatorTest {
2229

2330
private static final String SERVICES_XML_WITH_ONE_CLUSTER = """
24-
<services version='1.0'>
25-
<container id='cluster1' version='1.0'>
26-
</container>
27-
</services>
28-
""";
31+
<services version='1.0'>
32+
<container id='cluster1' version='1.0'>
33+
</container>
34+
</services>
35+
""";
2936

3037
// Clusters must have different ports
3138
private static final String SERVICES_XML_TWO_CLUSTERS = """
32-
<services version='1.0'>
33-
<container id='cluster1' version='1.0'>
34-
<http>
35-
<server id='server1' port='8080'/>
36-
</http>
37-
</container>
38-
<container id='other-cluster' version='1.0'>
39-
<http>
40-
<server id='server2' port='8081'/>
41-
</http>
42-
</container>
43-
</services>
44-
""";
39+
<services version='1.0'>
40+
<container id='cluster1' version='1.0'>
41+
<http>
42+
<server id='server1' port='8080'/>
43+
</http>
44+
</container>
45+
<container id='other-cluster' version='1.0'>
46+
<http>
47+
<server id='server2' port='8081'/>
48+
</http>
49+
</container>
50+
</services>
51+
""";
4552

4653
private static final String SERVICES_XML_OTHER_CLUSTER = """
47-
<services version='1.0'>
48-
<container id='other-cluster' version='1.0'>
49-
</container>
50-
</services>
51-
""";
54+
<services version='1.0'>
55+
<container id='other-cluster' version='1.0'>
56+
</container>
57+
</services>
58+
""";
5259

5360
@Test
5461
void restart_when_triton_runtime_enabled() {
55-
var previous = createModel(SERVICES_XML_WITH_ONE_CLUSTER,false);
56-
var next = createModel(SERVICES_XML_WITH_ONE_CLUSTER,true);
62+
var previous = createModel(SERVICES_XML_WITH_ONE_CLUSTER, false);
63+
var next = createModel(SERVICES_XML_WITH_ONE_CLUSTER, true);
5764
var result = validateModel(previous, next);
5865

5966
assertEquals(1, result.size());
@@ -66,7 +73,7 @@ void restart_when_triton_runtime_enabled() {
6673

6774
@Test
6875
void restart_when_triton_runtime_disabled() {
69-
var previous = createModel(SERVICES_XML_WITH_ONE_CLUSTER,true);
76+
var previous = createModel(SERVICES_XML_WITH_ONE_CLUSTER, true);
7077
var next = createModel(SERVICES_XML_WITH_ONE_CLUSTER, false);
7178
var result = validateModel(previous, next);
7279

@@ -89,7 +96,7 @@ void no_restart_when_triton_runtime_remains_enabled() {
8996

9097
@Test
9198
void no_restart_when_triton_runtime_remains_disabled() {
92-
var previous = createModel(SERVICES_XML_WITH_ONE_CLUSTER,false);
99+
var previous = createModel(SERVICES_XML_WITH_ONE_CLUSTER, false);
93100
var next = createModel(SERVICES_XML_WITH_ONE_CLUSTER, false);
94101
var result = validateModel(previous, next);
95102

@@ -115,19 +122,32 @@ void no_restart_when_cluster_with_triton_runtime_is_removed() {
115122
}
116123

117124
private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) {
118-
return ValidationTester.validateChanges(new RestartOnDeployForTritonOnnxRuntimeValidator(),
119-
next,
120-
deployStateBuilder(false).previousModel(current).build());
125+
return ValidationTester.validateChanges(
126+
new RestartOnDeployForTritonOnnxRuntimeValidator(),
127+
next,
128+
deployStateBuilder(false).previousModel(current).build());
121129
}
122-
130+
123131
private static VespaModel createModel(String servicesXml, boolean useTriton) {
124132
var builder = deployStateBuilder(useTriton);
125133
return new VespaModelCreatorWithMockPkg(null, servicesXml).create(builder);
126134
}
127135

128136
private static DeployState.Builder deployStateBuilder(boolean useTriton) {
129-
return new DeployState.Builder()
130-
.properties(new TestProperties().setUseTriton(useTriton));
137+
var deployStateBuilder = new DeployState.Builder().properties(new TestProperties().setUseTriton(useTriton));
138+
139+
// Triton is only enabled when there is at least one ONNX model and the zone is in a public-cloud-like system.
140+
if (useTriton) {
141+
var mockModelCost = new OnnxModelCost.DisabledOnnxModelCost() {
142+
@Override
143+
public Map<String, ModelInfo> models() {
144+
return Map.of("modernbert", new ModelInfo("modernbert", 1, 1, OnnxModelOptions.empty()));
145+
}
146+
};
147+
deployStateBuilder.onnxModelCost(mockModelCost);
148+
deployStateBuilder.zone(new Zone(SystemName.PublicCd, Environment.dev, RegionName.defaultName()));
149+
}
150+
151+
return deployStateBuilder;
131152
}
132-
133153
}

0 commit comments

Comments
 (0)