@@ -141,19 +141,19 @@ class AquaModelApp(AquaApp):
141
141
@telemetry (entry_point = "plugin=model&action=create" , name = "aqua" )
142
142
def create (
143
143
self ,
144
- model_id : Union [str , AquaMultiModelRef ],
144
+ model : Union [str , AquaMultiModelRef ],
145
145
project_id : Optional [str ] = None ,
146
146
compartment_id : Optional [str ] = None ,
147
147
freeform_tags : Optional [Dict ] = None ,
148
148
defined_tags : Optional [Dict ] = None ,
149
149
** kwargs ,
150
- ) -> DataScienceModel :
150
+ ) -> Union [ DataScienceModel , DataScienceModelGroup ] :
151
151
"""
152
- Creates a custom Aqua model from a service model.
152
+ Creates a custom Aqua model or model group from a service model.
153
153
154
154
Parameters
155
155
----------
156
- model_id : Union[str, AquaMultiModelRef]
156
+ model : Union[str, AquaMultiModelRef]
157
157
The model ID as a string or a AquaMultiModelRef instance to be deployed.
158
158
project_id : Optional[str]
159
159
The project ID for the custom model.
@@ -167,28 +167,18 @@ def create(
167
167
168
168
Returns
169
169
-------
170
- DataScienceModel
171
- The instance of DataScienceModel.
170
+ Union[ DataScienceModel, DataScienceModelGroup]
171
+ The instance of DataScienceModel or DataScienceModelGroup .
172
172
"""
173
- model_id = (
174
- model_id .model_id if isinstance (model_id , AquaMultiModelRef ) else model_id
175
- )
176
- service_model = DataScienceModel .from_id (model_id )
173
+ fine_tune_weights = []
174
+ if isinstance (model , AquaMultiModelRef ):
175
+ fine_tune_weights = model .fine_tune_weights
176
+ model = model .model_id
177
+
178
+ service_model = DataScienceModel .from_id (model )
177
179
target_project = project_id or PROJECT_OCID
178
180
target_compartment = compartment_id or COMPARTMENT_OCID
179
181
180
- # Skip model copying if it is registered model or fine-tuned model
181
- if (
182
- service_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ) is not None
183
- or service_model .freeform_tags .get (Tags .AQUA_FINE_TUNED_MODEL_TAG )
184
- is not None
185
- ):
186
- logger .info (
187
- f"Aqua Model { model_id } already exists in the user's compartment."
188
- "Skipped copying."
189
- )
190
- return service_model
191
-
192
182
# combine tags
193
183
combined_freeform_tags = {
194
184
** (service_model .freeform_tags or {}),
@@ -199,29 +189,112 @@ def create(
199
189
** (defined_tags or {}),
200
190
}
201
191
192
+ custom_model = None
193
+ if fine_tune_weights :
194
+ custom_model = self ._create_model_group (
195
+ model_id = model ,
196
+ compartment_id = target_compartment ,
197
+ project_id = target_project ,
198
+ freeform_tags = combined_freeform_tags ,
199
+ defined_tags = combined_defined_tags ,
200
+ fine_tune_weights = fine_tune_weights ,
201
+ service_model = service_model ,
202
+ )
203
+
204
+ logger .info (
205
+ f"Aqua Model Group { custom_model .id } created with the service model { model } ."
206
+ )
207
+ else :
208
+ # Skip model copying if it is registered model or fine-tuned model
209
+ if (
210
+ Tags .BASE_MODEL_CUSTOM in service_model .freeform_tags
211
+ or Tags .AQUA_FINE_TUNED_MODEL_TAG in service_model .freeform_tags
212
+ ):
213
+ logger .info (
214
+ f"Aqua Model { model } already exists in the user's compartment."
215
+ "Skipped copying."
216
+ )
217
+ return service_model
218
+
219
+ custom_model = self ._create_model (
220
+ compartment_id = target_compartment ,
221
+ project_id = target_project ,
222
+ freeform_tags = combined_freeform_tags ,
223
+ defined_tags = combined_defined_tags ,
224
+ service_model = service_model ,
225
+ ** kwargs ,
226
+ )
227
+ logger .info (
228
+ f"Aqua Model { custom_model .id } created with the service model { model } ."
229
+ )
230
+
231
+ # Track unique models that were created in the user's compartment
232
+ self .telemetry .record_event_async (
233
+ category = "aqua/service/model" ,
234
+ action = "create" ,
235
+ detail = service_model .display_name ,
236
+ )
237
+
238
+ return custom_model
239
+
240
+ def _create_model (
241
+ self ,
242
+ compartment_id : str ,
243
+ project_id : str ,
244
+ freeform_tags : Dict ,
245
+ defined_tags : Dict ,
246
+ service_model : DataScienceModel ,
247
+ ** kwargs ,
248
+ ):
249
+ """Creates a data science model by reference."""
202
250
custom_model = (
203
251
DataScienceModel ()
204
- .with_compartment_id (target_compartment )
205
- .with_project_id (target_project )
252
+ .with_compartment_id (compartment_id )
253
+ .with_project_id (project_id )
206
254
.with_model_file_description (json_dict = service_model .model_file_description )
207
255
.with_display_name (service_model .display_name )
208
256
.with_description (service_model .description )
209
- .with_freeform_tags (** combined_freeform_tags )
210
- .with_defined_tags (** combined_defined_tags )
257
+ .with_freeform_tags (** freeform_tags )
258
+ .with_defined_tags (** defined_tags )
211
259
.with_custom_metadata_list (service_model .custom_metadata_list )
212
260
.with_defined_metadata_list (service_model .defined_metadata_list )
213
261
.with_provenance_metadata (service_model .provenance_metadata )
214
262
.create (model_by_reference = True , ** kwargs )
215
263
)
216
- logger .info (
217
- f"Aqua Model { custom_model .id } created with the service model { model_id } ."
218
- )
219
264
220
- # Track unique models that were created in the user's compartment
221
- self .telemetry .record_event_async (
222
- category = "aqua/service/model" ,
223
- action = "create" ,
224
- detail = service_model .display_name ,
265
+ return custom_model
266
+
267
+ def _create_model_group (
268
+ self ,
269
+ model_id : str ,
270
+ compartment_id : str ,
271
+ project_id : str ,
272
+ freeform_tags : Dict ,
273
+ defined_tags : Dict ,
274
+ fine_tune_weights : List ,
275
+ service_model : DataScienceModel ,
276
+ ):
277
+ """Creates a data science model group."""
278
+ custom_model = (
279
+ DataScienceModelGroup ()
280
+ .with_compartment_id (compartment_id )
281
+ .with_project_id (project_id )
282
+ .with_display_name (service_model .display_name )
283
+ .with_description (service_model .description )
284
+ .with_freeform_tags (** freeform_tags )
285
+ .with_defined_tags (** defined_tags )
286
+ .with_custom_metadata_list (service_model .custom_metadata_list )
287
+ .with_base_model_id (model_id )
288
+ .with_member_models (
289
+ [
290
+ {
291
+ "inference_key" : fine_tune_weight .model_name ,
292
+ "model_id" : fine_tune_weight .model_id ,
293
+ }
294
+ for fine_tune_weight in fine_tune_weights
295
+ ]
296
+ )
297
+ .create ()
225
298
)
226
299
227
300
return custom_model
@@ -271,6 +344,16 @@ def create_multi(
271
344
DataScienceModelGroup
272
345
Instance of DataScienceModelGroup object.
273
346
"""
347
+ member_model_ids = [{"model_id" : model .model_id } for model in models ]
348
+ for model in models :
349
+ if model .fine_tune_weights :
350
+ member_model_ids .extend (
351
+ [
352
+ {"model_id" : fine_tune_model .model_id }
353
+ for fine_tune_model in model .fine_tune_weights
354
+ ]
355
+ )
356
+
274
357
custom_model_group = (
275
358
DataScienceModelGroup ()
276
359
.with_compartment_id (compartment_id )
@@ -281,7 +364,7 @@ def create_multi(
281
364
.with_defined_tags (** (defined_tags or {}))
282
365
.with_custom_metadata_list (model_custom_metadata )
283
366
# TODO: add member model inference key
284
- .with_member_models ([{ "model_id" : model . model_id for model in models }] )
367
+ .with_member_models (member_model_ids )
285
368
)
286
369
custom_model_group .create ()
287
370
0 commit comments