29
29
import org .elasticsearch .xcontent .XContentParser ;
30
30
import org .elasticsearch .xpack .core .ml .MlConfigVersion ;
31
31
import org .elasticsearch .xpack .core .ml .inference .TrainedModelConfig ;
32
+ import org .elasticsearch .xpack .core .ml .inference .assignment .AdaptiveAllocationsSettings ;
32
33
import org .elasticsearch .xpack .core .ml .inference .assignment .AllocationStatus ;
33
34
import org .elasticsearch .xpack .core .ml .inference .assignment .Priority ;
35
+ import org .elasticsearch .xpack .core .ml .inference .assignment .TrainedModelAssignment ;
34
36
import org .elasticsearch .xpack .core .ml .job .messages .Messages ;
35
37
import org .elasticsearch .xpack .core .ml .utils .ExceptionsHelper ;
36
38
import org .elasticsearch .xpack .core .ml .utils .MlTaskParams ;
40
42
import java .util .Optional ;
41
43
import java .util .concurrent .TimeUnit ;
42
44
43
- import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
44
45
import static org .elasticsearch .xpack .core .ml .MlTasks .trainedModelAssignmentTaskDescription ;
45
46
46
47
public class StartTrainedModelDeploymentAction extends ActionType <CreateTrainedModelAssignmentAction .Response > {
@@ -99,6 +100,7 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
99
100
public static final ParseField QUEUE_CAPACITY = TaskParams .QUEUE_CAPACITY ;
100
101
public static final ParseField CACHE_SIZE = TaskParams .CACHE_SIZE ;
101
102
public static final ParseField PRIORITY = TaskParams .PRIORITY ;
103
+ public static final ParseField ADAPTIVE_ALLOCATIONS = TrainedModelAssignment .ADAPTIVE_ALLOCATIONS ;
102
104
103
105
public static final ObjectParser <Request , Void > PARSER = new ObjectParser <>(NAME , Request ::new );
104
106
@@ -117,6 +119,12 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
117
119
ObjectParser .ValueType .VALUE
118
120
);
119
121
PARSER .declareString (Request ::setPriority , PRIORITY );
122
+ PARSER .declareObjectOrNull (
123
+ Request ::setAdaptiveAllocationsSettings ,
124
+ (p , c ) -> AdaptiveAllocationsSettings .PARSER .parse (p , c ).build (),
125
+ null ,
126
+ ADAPTIVE_ALLOCATIONS
127
+ );
120
128
}
121
129
122
130
public static Request parseRequest (String modelId , String deploymentId , XContentParser parser ) {
@@ -140,7 +148,8 @@ public static Request parseRequest(String modelId, String deploymentId, XContent
140
148
private TimeValue timeout = DEFAULT_TIMEOUT ;
141
149
private AllocationStatus .State waitForState = DEFAULT_WAITFOR_STATE ;
142
150
private ByteSizeValue cacheSize ;
143
- private int numberOfAllocations = DEFAULT_NUM_ALLOCATIONS ;
151
+ private Integer numberOfAllocations ;
152
+ private AdaptiveAllocationsSettings adaptiveAllocationsSettings = null ;
144
153
private int threadsPerAllocation = DEFAULT_NUM_THREADS ;
145
154
private int queueCapacity = DEFAULT_QUEUE_CAPACITY ;
146
155
private Priority priority = DEFAULT_PRIORITY ;
@@ -160,7 +169,11 @@ public Request(StreamInput in) throws IOException {
160
169
modelId = in .readString ();
161
170
timeout = in .readTimeValue ();
162
171
waitForState = in .readEnum (AllocationStatus .State .class );
163
- numberOfAllocations = in .readVInt ();
172
+ if (in .getTransportVersion ().onOrAfter (TransportVersions .INFERENCE_ADAPTIVE_ALLOCATIONS )) {
173
+ numberOfAllocations = in .readOptionalVInt ();
174
+ } else {
175
+ numberOfAllocations = in .readVInt ();
176
+ }
164
177
threadsPerAllocation = in .readVInt ();
165
178
queueCapacity = in .readVInt ();
166
179
if (in .getTransportVersion ().onOrAfter (TransportVersions .V_8_4_0 )) {
@@ -171,12 +184,16 @@ public Request(StreamInput in) throws IOException {
171
184
} else {
172
185
this .priority = Priority .NORMAL ;
173
186
}
174
-
175
187
if (in .getTransportVersion ().onOrAfter (TransportVersions .V_8_8_0 )) {
176
188
this .deploymentId = in .readString ();
177
189
} else {
178
190
this .deploymentId = modelId ;
179
191
}
192
+ if (in .getTransportVersion ().onOrAfter (TransportVersions .INFERENCE_ADAPTIVE_ALLOCATIONS )) {
193
+ this .adaptiveAllocationsSettings = in .readOptionalWriteable (AdaptiveAllocationsSettings ::new );
194
+ } else {
195
+ this .adaptiveAllocationsSettings = null ;
196
+ }
180
197
}
181
198
182
199
public final void setModelId (String modelId ) {
@@ -212,14 +229,34 @@ public Request setWaitForState(AllocationStatus.State waitForState) {
212
229
return this ;
213
230
}
214
231
215
- public int getNumberOfAllocations () {
232
+ public Integer getNumberOfAllocations () {
216
233
return numberOfAllocations ;
217
234
}
218
235
219
- public void setNumberOfAllocations (int numberOfAllocations ) {
236
+ public int computeNumberOfAllocations () {
237
+ if (numberOfAllocations != null ) {
238
+ return numberOfAllocations ;
239
+ } else {
240
+ if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings .getMinNumberOfAllocations () == null ) {
241
+ return DEFAULT_NUM_ALLOCATIONS ;
242
+ } else {
243
+ return adaptiveAllocationsSettings .getMinNumberOfAllocations ();
244
+ }
245
+ }
246
+ }
247
+
248
+ public void setNumberOfAllocations (Integer numberOfAllocations ) {
220
249
this .numberOfAllocations = numberOfAllocations ;
221
250
}
222
251
252
+ public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings () {
253
+ return adaptiveAllocationsSettings ;
254
+ }
255
+
256
+ public void setAdaptiveAllocationsSettings (AdaptiveAllocationsSettings adaptiveAllocationsSettings ) {
257
+ this .adaptiveAllocationsSettings = adaptiveAllocationsSettings ;
258
+ }
259
+
223
260
public int getThreadsPerAllocation () {
224
261
return threadsPerAllocation ;
225
262
}
@@ -258,7 +295,11 @@ public void writeTo(StreamOutput out) throws IOException {
258
295
out .writeString (modelId );
259
296
out .writeTimeValue (timeout );
260
297
out .writeEnum (waitForState );
261
- out .writeVInt (numberOfAllocations );
298
+ if (out .getTransportVersion ().onOrAfter (TransportVersions .INFERENCE_ADAPTIVE_ALLOCATIONS )) {
299
+ out .writeOptionalVInt (numberOfAllocations );
300
+ } else {
301
+ out .writeVInt (numberOfAllocations );
302
+ }
262
303
out .writeVInt (threadsPerAllocation );
263
304
out .writeVInt (queueCapacity );
264
305
if (out .getTransportVersion ().onOrAfter (TransportVersions .V_8_4_0 )) {
@@ -270,6 +311,9 @@ public void writeTo(StreamOutput out) throws IOException {
270
311
if (out .getTransportVersion ().onOrAfter (TransportVersions .V_8_8_0 )) {
271
312
out .writeString (deploymentId );
272
313
}
314
+ if (out .getTransportVersion ().onOrAfter (TransportVersions .INFERENCE_ADAPTIVE_ALLOCATIONS )) {
315
+ out .writeOptionalWriteable (adaptiveAllocationsSettings );
316
+ }
273
317
}
274
318
275
319
@ Override
@@ -279,7 +323,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
279
323
builder .field (DEPLOYMENT_ID .getPreferredName (), deploymentId );
280
324
builder .field (TIMEOUT .getPreferredName (), timeout .getStringRep ());
281
325
builder .field (WAIT_FOR .getPreferredName (), waitForState );
282
- builder .field (NUMBER_OF_ALLOCATIONS .getPreferredName (), numberOfAllocations );
326
+ if (numberOfAllocations != null ) {
327
+ builder .field (NUMBER_OF_ALLOCATIONS .getPreferredName (), numberOfAllocations );
328
+ }
329
+ if (adaptiveAllocationsSettings != null ) {
330
+ builder .field (ADAPTIVE_ALLOCATIONS .getPreferredName (), adaptiveAllocationsSettings );
331
+ }
283
332
builder .field (THREADS_PER_ALLOCATION .getPreferredName (), threadsPerAllocation );
284
333
builder .field (QUEUE_CAPACITY .getPreferredName (), queueCapacity );
285
334
if (cacheSize != null ) {
@@ -301,12 +350,25 @@ public ActionRequestValidationException validate() {
301
350
+ Strings .arrayToCommaDelimitedString (VALID_WAIT_STATES )
302
351
);
303
352
}
304
- if (numberOfAllocations < 1 ) {
305
- validationException .addValidationError ("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer" );
353
+ if (numberOfAllocations != null ) {
354
+ if (numberOfAllocations < 1 ) {
355
+ validationException .addValidationError ("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer" );
356
+ }
357
+ if (adaptiveAllocationsSettings != null && adaptiveAllocationsSettings .getEnabled ()) {
358
+ validationException .addValidationError (
359
+ "[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled"
360
+ );
361
+ }
306
362
}
307
363
if (threadsPerAllocation < 1 ) {
308
364
validationException .addValidationError ("[" + THREADS_PER_ALLOCATION + "] must be a positive integer" );
309
365
}
366
+ ActionRequestValidationException autoscaleException = adaptiveAllocationsSettings == null
367
+ ? null
368
+ : adaptiveAllocationsSettings .validate ();
369
+ if (autoscaleException != null ) {
370
+ validationException .addValidationErrors (autoscaleException .validationErrors ());
371
+ }
310
372
if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2 (threadsPerAllocation ) == false ) {
311
373
validationException .addValidationError (
312
374
"[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION
@@ -322,7 +384,7 @@ public ActionRequestValidationException validate() {
322
384
validationException .addValidationError ("[" + TIMEOUT + "] must be positive" );
323
385
}
324
386
if (priority == Priority .LOW ) {
325
- if (numberOfAllocations > 1 ) {
387
+ if (numberOfAllocations != null && numberOfAllocations > 1 ) {
326
388
validationException .addValidationError ("[" + NUMBER_OF_ALLOCATIONS + "] must be 1 when [" + PRIORITY + "] is low" );
327
389
}
328
390
if (threadsPerAllocation > 1 ) {
@@ -344,6 +406,7 @@ public int hashCode() {
344
406
timeout ,
345
407
waitForState ,
346
408
numberOfAllocations ,
409
+ adaptiveAllocationsSettings ,
347
410
threadsPerAllocation ,
348
411
queueCapacity ,
349
412
cacheSize ,
@@ -365,7 +428,8 @@ public boolean equals(Object obj) {
365
428
&& Objects .equals (timeout , other .timeout )
366
429
&& Objects .equals (waitForState , other .waitForState )
367
430
&& Objects .equals (cacheSize , other .cacheSize )
368
- && numberOfAllocations == other .numberOfAllocations
431
+ && Objects .equals (numberOfAllocations , other .numberOfAllocations )
432
+ && Objects .equals (adaptiveAllocationsSettings , other .adaptiveAllocationsSettings )
369
433
&& threadsPerAllocation == other .threadsPerAllocation
370
434
&& queueCapacity == other .queueCapacity
371
435
&& priority == other .priority ;
@@ -430,7 +494,7 @@ public static boolean mayAssignToNode(@Nullable DiscoveryNode node) {
430
494
PARSER .declareInt (ConstructingObjectParser .optionalConstructorArg (), THREADS_PER_ALLOCATION );
431
495
PARSER .declareInt (ConstructingObjectParser .constructorArg (), QUEUE_CAPACITY );
432
496
PARSER .declareField (
433
- optionalConstructorArg (),
497
+ ConstructingObjectParser . optionalConstructorArg (),
434
498
(p , c ) -> ByteSizeValue .parseBytesSizeValue (p .text (), CACHE_SIZE .getPreferredName ()),
435
499
CACHE_SIZE ,
436
500
ObjectParser .ValueType .VALUE
0 commit comments