Skip to content
31 changes: 23 additions & 8 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,9 @@ def get_config(

return ModelConfigResult(config=config, model_details=oci_model)

def get_container_image(self, container_type: str = None) -> str:
def get_container_image(
self, container_type: str = None, container_tag: str = None
) -> str:
"""
Gets the latest smc container complete image name from the given container type.

Expand All @@ -463,6 +465,9 @@ def get_container_image(self, container_type: str = None) -> str:
container_type: str
type of container, can be either odsc-vllm-serving, odsc-llm-fine-tuning, odsc-llm-evaluate

container_tag: str
tag of container, ex: 0.8.5.post1.1

Returns
-------
str:
Expand All @@ -476,13 +481,23 @@ def get_container_image(self, container_type: str = None) -> str:
)
if not container:
raise AquaValueError(f"Invalid container type : {container_type}")
container_image = (
SERVICE_MANAGED_CONTAINER_URI_SCHEME
+ container.container_name
+ ":"
+ container.tag
)
return container_image

if container_tag:
container_image = (
SERVICE_MANAGED_CONTAINER_URI_SCHEME
+ container.container_name
+ ":"
+ container_tag
)
return container_image
else:
container_image = (
SERVICE_MANAGED_CONTAINER_URI_SCHEME
+ container.container_name
+ ":"
+ container.tag
)
return container_image

@cached(cache=TTLCache(maxsize=20, ttl=timedelta(minutes=30), timer=datetime.now))
def list_service_containers(self) -> List[ContainerSummary]:
Expand Down
22 changes: 16 additions & 6 deletions ads/aqua/config/container_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class AquaContainerConfigItem(Serializable):
platforms (Optional[List[str]]): Supported platforms.
model_formats (Optional[List[str]]): Supported model formats.
spec (Optional[AquaContainerConfigSpec]): Container specification details.
lifecycle_state (Optional[str]): Lifecycle state of the container
"""

name: Optional[str] = Field(
Expand Down Expand Up @@ -101,6 +102,9 @@ class AquaContainerConfigItem(Serializable):
usages: Optional[List[str]] = Field(
default_factory=list, description="Supported usages."
)
lifecycle_state: Optional[str] = Field(
default=None, description="Lifecycle state of the container (e.g., 'ACTIVE', 'INACTIVE')."
)

class Config:
extra = "allow"
Expand All @@ -118,7 +122,7 @@ class AquaContainerConfig(Serializable):
evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items.
"""

inference: Dict[str, AquaContainerConfigItem] = Field(
inference: Dict[str, List[AquaContainerConfigItem]] = Field(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will break in a few places where AquaContainerConfig is used, likely in deployment.py. Can you check whether that class needs changes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will check, thanks.

default_factory=dict, description="Inference container configuration items."
)
finetune: Dict[str, AquaContainerConfigItem] = Field(
Expand All @@ -130,7 +134,9 @@ class AquaContainerConfig(Serializable):

def to_dict(self):
    """Serialize the container configuration into one list per usage.

    Returns
    -------
    dict
        A dictionary with three keys:

        * ``"inference"`` -- a single flat list made by concatenating every
          list stored as a value of ``self.inference``.
        * ``"finetune"`` -- a list of the values of ``self.finetune``.
        * ``"evaluate"`` -- a list of the values of ``self.evaluate``.
    """
    # Each inference entry holds a list of items, so flatten one level.
    # The key is emitted exactly once: a duplicated key in a dict display
    # silently keeps only the last value.
    inference_items = [
        item for sublist in self.inference.values() for item in sublist
    ]
    return {
        "inference": inference_items,
        "finetune": list(self.finetune.values()),
        "evaluate": list(self.evaluate.values()),
    }
Expand All @@ -149,18 +155,20 @@ def from_service_config(
-------
AquaContainerConfig: The constructed container configuration.
"""

inference_items: Dict[str, AquaContainerConfigItem] = {}
inference_items: Dict[str, List[AquaContainerConfigItem]] = {}
finetune_items: Dict[str, AquaContainerConfigItem] = {}
evaluate_items: Dict[str, AquaContainerConfigItem] = {}
for container in service_containers:
if not container.is_latest:
if getattr(container, "lifecycle_state", "").upper() != "ACTIVE":
continue
if "INFERENCE" not in container.usages and not container.is_latest:
continue
container_item = AquaContainerConfigItem(
name=SERVICE_MANAGED_CONTAINER_URI_SCHEME + container.container_name,
version=container.tag,
display_name=container.display_name,
family=container.family_name,
lifecycle_state=container.lifecycle_state,
usages=container.usages,
platforms=[],
model_formats=[],
Expand Down Expand Up @@ -242,7 +250,9 @@ def from_service_config(
)

if "INFERENCE" in usages or "MULTI_MODEL" in usages:
inference_items[container_type] = container_item
if container_type not in inference_items:
inference_items[container_type] = []
inference_items[container_type].append(container_item)
if "FINE_TUNE" in usages:
finetune_items[container_type] = container_item
if "EVALUATION" in usages:
Expand Down
11 changes: 9 additions & 2 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def create(
health_check_port (Optional[int]): Health check port for the Docker container image.
env_var (Optional[Dict[str, str]]): Environment variables for deployment.
container_family (Optional[str]): Image family of the model deployment container runtime.
container_tag (Optional[str]): Image tag of the model deployment container runtime
memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape.
ocpus (Optional[float]): OCPU count for the selected shape.
model_file (Optional[str]): File used for model deployment.
Expand Down Expand Up @@ -436,7 +437,10 @@ def _create(

container_image_uri = (
create_deployment_details.container_image_uri
or self.get_container_image(container_type=container_type_key)
or self.get_container_image(
container_type=container_type_key,
container_tag=create_deployment_details.container_tag,
)
)
if not container_image_uri:
try:
Expand Down Expand Up @@ -665,7 +669,10 @@ def _create_multi(

container_image_uri = (
create_deployment_details.container_image_uri
or self.get_container_image(container_type=container_type_key)
or self.get_container_image(
container_type=container_type_key,
container_tag=create_deployment_details.container_tag,
)
)
server_port = create_deployment_details.server_port or (
container_spec.server_port if container_spec else None
Expand Down
3 changes: 3 additions & 0 deletions ads/aqua/modeldeployment/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ class CreateModelDeploymentDetails(BaseModel):
container_family: Optional[str] = Field(
None, description="Image family of the model deployment container runtime."
)
container_tag: Optional[str] = Field(
None, description="Image tag of the model deployment container runtime."
)
memory_in_gbs: Optional[float] = Field(
None, description="Memory (in GB) for the selected shape."
)
Expand Down
Loading