Skip to content

Commit 81f27e9

Browse files
modify validation for custom base models
1 parent e14954e commit 81f27e9

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

ads/aqua/common/utils.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -668,23 +668,18 @@ def get_model_by_reference_paths(model_file_description: dict):
668668
fine_tune_output_path = UNKNOWN
669669
models = model_file_description["models"]
670670

671-
for model in models:
672-
namespace, bucket_name, prefix = (
673-
model["namespace"],
674-
model["bucketName"],
675-
model["prefix"],
676-
)
677-
bucket_uri = f"oci://{bucket_name}@{namespace}/{prefix}".rstrip("/")
678-
if bucket_name == AQUA_SERVICE_MODELS_BUCKET:
679-
base_model_path = bucket_uri
680-
else:
681-
fine_tune_output_path = bucket_uri
671+
if models:
672+
if len(models) == 1:
673+
base_model_artifact = models[0]
674+
base_model_path = f"oci://{base_model_artifact['bucketName']}@{base_model_artifact['namespace']}/{base_model_artifact['prefix']}".rstrip(
675+
"/"
676+
)
677+
if len(models) == 2:
678+
ft_model_artifact = models[1]
679+
fine_tune_output_path = f"oci://{ft_model_artifact['bucketName']}@{ft_model_artifact['namespace']}/{ft_model_artifact['prefix']}".rstrip(
680+
"/"
681+
)
682682

683-
if not base_model_path:
684-
raise AquaValueError(
685-
f"Base Model should come from the bucket {AQUA_SERVICE_MODELS_BUCKET}. "
686-
f"Other paths are not supported by Aqua."
687-
)
688683
return base_model_path, fine_tune_output_path
689684

690685

0 commit comments

Comments
 (0)