33
44import com .yahoo .config .model .api .ConfigChangeAction ;
55import com .yahoo .config .model .api .ConfigChangeRestartAction .ConfigChange ;
6+ import com .yahoo .config .model .api .OnnxModelCost ;
7+ import com .yahoo .config .model .api .OnnxModelOptions ;
68import com .yahoo .config .model .deploy .DeployState ;
79import com .yahoo .config .model .deploy .TestProperties ;
10+ import com .yahoo .config .provision .Environment ;
11+ import com .yahoo .config .provision .RegionName ;
12+ import com .yahoo .config .provision .SystemName ;
13+ import com .yahoo .config .provision .Zone ;
814import com .yahoo .vespa .model .VespaModel ;
915import com .yahoo .vespa .model .application .validation .ValidationTester ;
1016import com .yahoo .vespa .model .test .utils .VespaModelCreatorWithMockPkg ;
1117import org .junit .jupiter .api .Test ;
1218
1319import java .util .List ;
20+ import java .util .Map ;
1421
1522import static org .junit .jupiter .api .Assertions .assertEquals ;
1623import static org .junit .jupiter .api .Assertions .assertTrue ;
2128public class RestartOnDeployForTritonOnnxRuntimeValidatorTest {
2229
2330 private static final String SERVICES_XML_WITH_ONE_CLUSTER = """
24- <services version='1.0'>
25- <container id='cluster1' version='1.0'>
26- </container>
27- </services>
28- """ ;
31+ <services version='1.0'>
32+ <container id='cluster1' version='1.0'>
33+ </container>
34+ </services>
35+ """ ;
2936
3037 // Clusters must have different ports
3138 private static final String SERVICES_XML_TWO_CLUSTERS = """
32- <services version='1.0'>
33- <container id='cluster1' version='1.0'>
34- <http>
35- <server id='server1' port='8080'/>
36- </http>
37- </container>
38- <container id='other-cluster' version='1.0'>
39- <http>
40- <server id='server2' port='8081'/>
41- </http>
42- </container>
43- </services>
44- """ ;
39+ <services version='1.0'>
40+ <container id='cluster1' version='1.0'>
41+ <http>
42+ <server id='server1' port='8080'/>
43+ </http>
44+ </container>
45+ <container id='other-cluster' version='1.0'>
46+ <http>
47+ <server id='server2' port='8081'/>
48+ </http>
49+ </container>
50+ </services>
51+ """ ;
4552
4653 private static final String SERVICES_XML_OTHER_CLUSTER = """
47- <services version='1.0'>
48- <container id='other-cluster' version='1.0'>
49- </container>
50- </services>
51- """ ;
54+ <services version='1.0'>
55+ <container id='other-cluster' version='1.0'>
56+ </container>
57+ </services>
58+ """ ;
5259
5360 @ Test
5461 void restart_when_triton_runtime_enabled () {
55- var previous = createModel (SERVICES_XML_WITH_ONE_CLUSTER ,false );
56- var next = createModel (SERVICES_XML_WITH_ONE_CLUSTER ,true );
62+ var previous = createModel (SERVICES_XML_WITH_ONE_CLUSTER , false );
63+ var next = createModel (SERVICES_XML_WITH_ONE_CLUSTER , true );
5764 var result = validateModel (previous , next );
5865
5966 assertEquals (1 , result .size ());
@@ -66,7 +73,7 @@ void restart_when_triton_runtime_enabled() {
6673
6774 @ Test
6875 void restart_when_triton_runtime_disabled () {
69- var previous = createModel (SERVICES_XML_WITH_ONE_CLUSTER ,true );
76+ var previous = createModel (SERVICES_XML_WITH_ONE_CLUSTER , true );
7077 var next = createModel (SERVICES_XML_WITH_ONE_CLUSTER , false );
7178 var result = validateModel (previous , next );
7279
@@ -89,7 +96,7 @@ void no_restart_when_triton_runtime_remains_enabled() {
8996
9097 @ Test
9198 void no_restart_when_triton_runtime_remains_disabled () {
92- var previous = createModel (SERVICES_XML_WITH_ONE_CLUSTER ,false );
99+ var previous = createModel (SERVICES_XML_WITH_ONE_CLUSTER , false );
93100 var next = createModel (SERVICES_XML_WITH_ONE_CLUSTER , false );
94101 var result = validateModel (previous , next );
95102
@@ -115,19 +122,32 @@ void no_restart_when_cluster_with_triton_runtime_is_removed() {
115122 }
116123
117124 private static List <ConfigChangeAction > validateModel (VespaModel current , VespaModel next ) {
118- return ValidationTester .validateChanges (new RestartOnDeployForTritonOnnxRuntimeValidator (),
119- next ,
120- deployStateBuilder (false ).previousModel (current ).build ());
125+ return ValidationTester .validateChanges (
126+ new RestartOnDeployForTritonOnnxRuntimeValidator (),
127+ next ,
128+ deployStateBuilder (false ).previousModel (current ).build ());
121129 }
122-
130+
123131 private static VespaModel createModel (String servicesXml , boolean useTriton ) {
124132 var builder = deployStateBuilder (useTriton );
125133 return new VespaModelCreatorWithMockPkg (null , servicesXml ).create (builder );
126134 }
127135
128136 private static DeployState .Builder deployStateBuilder (boolean useTriton ) {
129- return new DeployState .Builder ()
130- .properties (new TestProperties ().setUseTriton (useTriton ));
137+ var deployStateBuilder = new DeployState .Builder ().properties (new TestProperties ().setUseTriton (useTriton ));
138+
139+ // Need a model and cloud to enable Triton.
140+ if (useTriton ) {
141+ var mockModelCost = new OnnxModelCost .DisabledOnnxModelCost () {
142+ @ Override
143+ public Map <String , ModelInfo > models () {
144+ return Map .of ("modernbert" , new ModelInfo ("modernbert" , 1 , 1 , OnnxModelOptions .empty ()));
145+ }
146+ };
147+ deployStateBuilder .onnxModelCost (mockModelCost );
148+ deployStateBuilder .zone (new Zone (SystemName .PublicCd , Environment .dev , RegionName .defaultName ()));
149+ }
150+
151+ return deployStateBuilder ;
131152 }
132-
133153}
0 commit comments