Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/swerex/deployment/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,54 @@ def get_deployment(self) -> AbstractDeployment:
return DummyDeployment.from_config(self)


class VeFaasDeploymentConfig(BaseModel):
    """Configuration for a deployment that runs inside a Volcengine veFaaS sandbox."""

    model_config = ConfigDict(extra="forbid")

    image: str = "python:3.11"
    """Container image the sandbox is created from."""
    platform: str | None = None
    """Target platform; may also be supplied as ``--platform`` inside ``docker_args``
    (see `validate_platform_args`)."""
    startup_timeout: float = 180.0
    """How long to wait for the runtime to become alive after the sandbox is created."""
    ak: str | None = None
    """Volcengine access key."""
    sk: str | None = None
    """Volcengine secret key."""
    region: str | None = None
    """Volcengine region passed to the SDK configuration."""
    apigateway_service_id: str | None = None
    """ID of the API gateway service whose domain is used as the runtime host."""
    function_id: str | None = None
    """ID of the veFaaS function the sandbox is created under."""
    request_timeout: int = 300
    """Request timeout forwarded to the create-sandbox request."""
    client_side_validation: bool = True

    @model_validator(mode="before")
    @classmethod
    def validate_platform_args(cls, data: dict) -> dict:
        """Move a ``--platform`` entry of ``docker_args`` into the ``platform`` field.

        Both ``--platform=value`` and ``--platform value`` are accepted. The
        input mapping is never modified; a shallow copy is returned when
        something has to change.

        NOTE(review): ``docker_args`` is not a declared field and the model
        forbids extra keys, so input that still carries ``docker_args`` after
        this step is presumably rejected later -- confirm whether this
        validator is needed for this config at all.

        Raises:
            ValueError: If the platform is given both ways, or ``--platform``
                is the last argument and has no value.
        """
        if not isinstance(data, dict):
            return data

        docker_args = data.get("docker_args", [])
        platform_arg_idx = next((i for i, arg in enumerate(docker_args) if arg.startswith("--platform")), -1)
        if platform_arg_idx == -1:
            return data

        if data.get("platform") is not None:
            msg = "Cannot specify platform both via 'platform' field and '--platform' in docker_args"
            raise ValueError(msg)

        platform_arg = docker_args[platform_arg_idx]
        if "=" in platform_arg:
            # --platform=value: a single entry carries flag and value.
            platform_value = platform_arg.split("=", 1)[1]
            consumed = 1
        elif platform_arg_idx + 1 < len(docker_args):
            # --platform value: the value is the following entry.
            platform_value = docker_args[platform_arg_idx + 1]
            consumed = 2
        else:
            msg = "--platform argument must be followed by a value"
            raise ValueError(msg)

        # Work on a copy so a dict handed to model_validate() is not mutated.
        updated = dict(data)
        updated["platform"] = platform_value
        updated["docker_args"] = docker_args[:platform_arg_idx] + docker_args[platform_arg_idx + consumed :]
        return updated

    def get_deployment(self) -> AbstractDeployment:
        """Build the `VeFaasDeployment` described by this configuration."""
        # Imported lazily: swerex.deployment.vefaas imports this module.
        from swerex.deployment.vefaas import VeFaasDeployment

        return VeFaasDeployment.from_config(self)


class DaytonaDeploymentConfig(BaseModel):
"""Configuration for Daytona deployment."""

Expand All @@ -211,6 +259,7 @@ def get_deployment(self) -> AbstractDeployment:
| RemoteDeploymentConfig
| DummyDeploymentConfig
| DaytonaDeploymentConfig
| VeFaasDeploymentConfig
)
"""Union of all deployment configurations. Useful for type hints."""

Expand Down
202 changes: 202 additions & 0 deletions src/swerex/deployment/vefaas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import logging
import os
import time
import uuid
from typing import Any

from typing_extensions import Self
from volcenginesdkapig import APIGApi
from volcenginesdkapig.models import (
    GetGatewayServiceRequest,
    GetGatewayServiceResponse,
)
from volcenginesdkcore import ApiClient, Configuration
from volcenginesdkvefaas import (
    CreateSandboxRequest,
    CreateSandboxResponse,
    InstanceImageInfoForCreateSandboxInput,
    KillSandboxRequest,
    KillSandboxResponse,
    VEFAASApi,
)

from swerex.deployment.abstract import AbstractDeployment
from swerex.deployment.config import VeFaasDeploymentConfig
from swerex.deployment.hooks.abstract import CombinedDeploymentHook, DeploymentHook
from swerex.exceptions import DeploymentNotStartedError, DeploymentStartupError
from swerex.runtime.abstract import IsAliveResponse
from swerex.runtime.config import RemoteRuntimeConfig
from swerex.runtime.remote import RemoteRuntime
from swerex.utils.log import get_logger
from swerex.utils.wait import _wait_until_alive


class VeFaasDeployment(AbstractDeployment):
    """Deployment that runs the runtime inside a Volcengine veFaaS sandbox.

    `start` creates a sandbox for the configured function, resolves the domain
    of the configured API gateway service and connects a `RemoteRuntime` to it.
    `stop` closes the runtime and kills the sandbox.
    """

    def __init__(self, *, logger: logging.Logger | None = None, **kwargs: Any):
        """
        Args:
            logger: Logger to use; defaults to the ``rex-deploy`` logger.
            **kwargs: Keyword arguments used to build a `VeFaasDeploymentConfig`.
        """
        self._config = VeFaasDeploymentConfig(**kwargs)
        self._runtime: RemoteRuntime | None = None
        self._container_name: str | None = None
        self._hooks = CombinedDeploymentHook()
        self.logger = logger or get_logger("rex-deploy")
        # Timeout for individual runtime requests (also used per is-alive probe).
        self._runtime_timeout = 0.15
        # Lazily created by _get_api_client().
        self._api_client: ApiClient | None = None
        # Empty string means "no sandbox currently owned by this deployment".
        self._sandbox_id = ""

    @classmethod
    def from_config(cls, config: VeFaasDeploymentConfig) -> Self:
        """Create a deployment from an existing configuration object."""
        return cls(**config.model_dump())

    def add_hook(self, hook: DeploymentHook):
        """Register a deployment hook."""
        self._hooks.add_hook(hook)

    def _get_token(self) -> str:
        """Return a fresh random token used to authenticate against the runtime."""
        return str(uuid.uuid4())

    async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse:
        """Checks if the runtime is alive. The return value can be
        tested with bool().

        Raises:
            RuntimeError: If the deployment was not started.
        """
        if self._runtime is None:
            msg = "Runtime not started"
            raise RuntimeError(msg)
        return await self._runtime.is_alive(timeout=timeout)

    async def _wait_until_alive(self, timeout: float = 10.0):
        """Poll `is_alive` until it succeeds; stop the deployment on timeout.

        Raises:
            TimeoutError: If the runtime did not come up within ``timeout``.
        """
        try:
            return await _wait_until_alive(self.is_alive, timeout=timeout, function_timeout=self._runtime_timeout)
        except TimeoutError:
            self.logger.error("Runtime did not start within timeout. Stopping the sandbox.")
            await self.stop()
            raise

    def _get_domain(self, apigs_id):
        """Return a domain of the API gateway service ``apigs_id``.

        A domain starting with ``https://`` is preferred; otherwise the first
        listed domain is returned, or ``None`` if the service has no domains.

        Raises:
            DeploymentStartupError: If the gateway service lookup does not
                return the expected response type.
        """
        api_instance = APIGApi(self._get_api_client())
        response = api_instance.get_gateway_service(GetGatewayServiceRequest(id=apigs_id))
        if not isinstance(response, GetGatewayServiceResponse):
            msg = f"Failed to look up API gateway service {apigs_id}: {response}"
            raise DeploymentStartupError(msg)

        gateway_service = response.gateway_service
        # Tolerate a missing service / domain list and entries without a domain.
        listed = (gateway_service.domains if gateway_service is not None else None) or []
        domains = [d.domain for d in listed if d.domain]

        for domain in domains:
            if domain.startswith("https://"):
                return domain
        return domains[0] if domains else None

    def _get_container_name(self) -> str:
        """Return a display name derived from the image name.

        The last path component of the image is taken, ``_``, ``:`` and ``.``
        are replaced by ``-`` and the last 14 characters are dropped. The
        result is only used for logging.

        NOTE(review): the name is not unique and becomes empty for short
        image names (e.g. the default ``python:3.11``) -- confirm the intent
        of the ``[:-14]`` truncation.
        """
        image_name = self._config.image.split("/")[-1]
        for char in ("_", ":", "."):
            image_name = image_name.replace(char, "-")
        return image_name[:-14]

    def _get_api_client(self) -> ApiClient:
        """Return the (cached) Volcengine API client.

        Credentials come from the configuration (``ak`` / ``sk``); if they are
        not set there, the ``VOLCENGINE_ACCESS_KEY`` / ``VOLCENGINE_SECRET_KEY``
        environment variables are used.

        Raises:
            DeploymentStartupError: If no access key or secret key is available.
        """
        if self._api_client is not None:
            return self._api_client

        access_key = self._config.ak or os.environ.get("VOLCENGINE_ACCESS_KEY")
        secret_key = self._config.sk or os.environ.get("VOLCENGINE_SECRET_KEY")

        if not access_key or not secret_key:
            emsg = "VOLCENGINE_ACCESS_KEY and VOLCENGINE_SECRET_KEY must be set"
            raise DeploymentStartupError(emsg)

        config = Configuration()
        config.ak = access_key
        config.sk = secret_key
        config.region = self._config.region

        self._api_client = ApiClient(config)
        return self._api_client

    async def create_sandbox(self, function_id, image, cmd, request_timeout) -> str:
        """Create a sandbox and return its ID.

        Args:
            function_id: ID of the veFaaS function to create the sandbox under.
            image: Container image for the sandbox instance.
            cmd: Command run by the sandbox instance.
            request_timeout: Request timeout forwarded to the create request.

        Raises:
            DeploymentStartupError: If the response has an unexpected type or
                carries no sandbox ID.
        """
        # NOTE(review): the SDK call below is synchronous, so it blocks the
        # event loop while the request is in flight.
        client = VEFAASApi(self._get_api_client())

        instance_image_info = InstanceImageInfoForCreateSandboxInput(image=image, port=8000, command=cmd)
        response = client.create_sandbox(
            CreateSandboxRequest(
                function_id=function_id,
                instance_image_info=instance_image_info,
                request_timeout=request_timeout,
                timeout=120,
            )
        )
        if not isinstance(response, CreateSandboxResponse):
            emsg = "Failed to create sandbox"
            raise DeploymentStartupError(emsg)
        if not response.sandbox_id:
            emsg = "Failed to create sandbox: no sandbox id"
            raise DeploymentStartupError(emsg)
        return response.sandbox_id

    async def kill_sandbox(self) -> None:
        """Kill the sandbox owned by this deployment, if there is one.

        Does nothing (and needs no credentials) when no sandbox was created.
        An unexpected response type is logged as a warning; the stored
        sandbox ID is cleared either way.
        """
        if not self._sandbox_id:
            return

        client = VEFAASApi(self._get_api_client())
        response = client.kill_sandbox(
            KillSandboxRequest(function_id=self._config.function_id, sandbox_id=self._sandbox_id)
        )
        if not isinstance(response, KillSandboxResponse):
            self.logger.warning(f"Kill Sandbox {self._sandbox_id} Failed")
        self._sandbox_id = ""

    async def start(self):
        """Start Faas runtime

        Creates the sandbox, resolves the gateway domain, connects the remote
        runtime and waits until it is alive. If any step after the name has
        been assigned fails, the deployment is stopped again so that no
        sandbox is left behind.

        Raises:
            DeploymentStartupError: If required configuration is missing, the
                sandbox cannot be created or no gateway domain is found.
            TimeoutError: If the runtime does not come up within
                ``startup_timeout``.
        """
        assert self._container_name is None

        function_id = self._config.function_id
        apigs_id = self._config.apigateway_service_id
        if not function_id or not apigs_id:
            emsg = "'function_id' and 'apigateway_service_id' must be set"
            raise DeploymentStartupError(emsg)

        self._container_name = self._get_container_name()
        self.logger.info(f"Starting container {self._container_name}")

        # Gen swe-rex command
        token = self._get_token()
        cmd = f"curl -fsSL https://vefaas-swe.tos-cn-beijing.volces.com/swe-rex/install_1.4.0.sh | bash -s -- {token}"

        t0 = time.time()
        try:
            # create sandbox
            self._sandbox_id = await self.create_sandbox(
                function_id, self._config.image, cmd, self._config.request_timeout
            )

            domain = self._get_domain(apigs_id)
            if not domain:
                emsg = f"No domain found for API gateway service {apigs_id}"
                raise DeploymentStartupError(emsg)

            self._runtime = RemoteRuntime.from_config(
                RemoteRuntimeConfig(
                    host=domain, timeout=self._runtime_timeout, auth_token=token, faas_instance_name=self._sandbox_id
                )
            )

            await self._wait_until_alive(timeout=self._config.startup_timeout)
        except Exception:
            # Release the sandbox / runtime created so far; stop() is idempotent.
            await self.stop()
            raise
        self.logger.info(f"Runtime started in {time.time() - t0:.2f}s")

    async def stop(self):
        """Stop the runtime"""
        try:
            if self._runtime is not None:
                await self._runtime.close()
        finally:
            self._runtime = None
            # Allow start() to be called again after a stop.
            self._container_name = None
            # kill sandbox (even if closing the runtime failed)
            await self.kill_sandbox()

    @property
    def runtime(self) -> RemoteRuntime:
        """Returns the runtime if running.

        Raises:
            DeploymentNotStartedError: If the deployment was not started.
        """
        if self._runtime is None:
            raise DeploymentNotStartedError()
        return self._runtime
4 changes: 3 additions & 1 deletion src/swerex/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(self, message="Deployment not started"):
super().__init__(message)


class DeploymentStartupError(SwerexException, RuntimeError): ...
class DeploymentStartupError(SwerexException, RuntimeError):
    """Error raised when a deployment cannot be started."""

    def __init__(self, message="Deployment startup error"):
        # A default message lets callers raise the error without arguments.
        super().__init__(message)


class DockerPullError(DeploymentStartupError): ...
Expand Down
3 changes: 3 additions & 0 deletions src/swerex/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class RemoteRuntimeConfig(BaseModel):
timeout: float = 0.15
"""The timeout for the runtime."""

faas_instance_name: str | None = None
"""For Vefaas instance."""

type: Literal["remote"] = "remote"
"""Discriminator for (de)serialization/CLI. Do not change."""

Expand Down
8 changes: 6 additions & 2 deletions src/swerex/runtime/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,13 @@ def _get_timeout(self, timeout: float | None = None) -> float:
@property
def _headers(self) -> dict[str, str]:
    """Request headers to use for authentication.

    ``X-API-Key`` is set when an auth token is configured and
    ``x-faas-instance-name`` when a FaaS instance name is configured;
    with neither configured the result is an empty dict.
    """
    headers: dict[str, str] = {}
    if self._config.auth_token:
        headers["X-API-Key"] = self._config.auth_token
    if self._config.faas_instance_name:
        headers["x-faas-instance-name"] = self._config.faas_instance_name
    return headers

@property
def _api_url(self) -> str:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_faas_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from swerex.deployment.vefaas import VeFaasDeployment


async def test_faas_deployment():
    """Live round trip: start a veFaaS deployment, check it is alive, stop it.

    Needs real Volcengine credentials in ``VOLCENGINE_ACCESS_KEY`` and
    ``VOLCENGINE_SECRET_KEY``; the test is skipped when they are absent,
    because starting the deployment cannot succeed without them.
    """
    import os

    access_key = os.environ.get("VOLCENGINE_ACCESS_KEY", "")
    secret_key = os.environ.get("VOLCENGINE_SECRET_KEY", "")
    if not access_key or not secret_key:
        pytest.skip("VOLCENGINE_ACCESS_KEY / VOLCENGINE_SECRET_KEY not set")

    f = VeFaasDeployment(
        image="enterprise-public-cn-beijing.cr.volces.com/swe-bench/sweb.eval.x86_64.django_1776_django-15414:latest",
        ak=access_key,
        sk=secret_key,
        region="cn-beijing",
        function_id="awokjltn",
        apigateway_service_id="sd2on64i5ni4n75n9unpg",
    )
    # Before start() there is no runtime to query.
    with pytest.raises(RuntimeError):
        await f.is_alive()

    await f.start()
    try:
        assert await f.is_alive()
    finally:
        # Always release the remote sandbox, even if the assertion fails.
        await f.stop()