diff --git a/mplang/core/cluster.py b/mplang/core/cluster.py index 2beeb868..010deb58 100644 --- a/mplang/core/cluster.py +++ b/mplang/core/cluster.py @@ -20,6 +20,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from functools import cached_property from typing import Any @@ -169,6 +170,16 @@ def to_dict(self) -> dict[str, Any]: }, } + @cached_property + def endpoints(self) -> list[str]: + eps: list[str] = [] + for n in sorted( + self.nodes.values(), + key=lambda x: x.rank, # type: ignore[attr-defined] + ): + eps.append(n.endpoint) + return eps + @classmethod def from_dict(cls, config: dict[str, Any]) -> ClusterSpec: """Parses a raw config dictionary and returns a validated ClusterSpec.""" diff --git a/mplang/core/comm.py b/mplang/core/comm.py index 35ab4d95..e395e1be 100644 --- a/mplang/core/comm.py +++ b/mplang/core/comm.py @@ -48,6 +48,10 @@ def send(self, to: int, key: str, data: Any) -> None: def recv(self, frm: int, key: str) -> Any: """Receive data from peer with the given key""" + @abstractmethod + def onSent(self, frm: int, key: str, data: Any) -> None: + """Called when a key is sent to self""" + class ICollective(ABC): """Interface for collective communication""" diff --git a/mplang/runtime/server.py b/mplang/runtime/server.py index ac6b4263..7e63d771 100644 --- a/mplang/runtime/server.py +++ b/mplang/runtime/server.py @@ -31,6 +31,7 @@ from pydantic import BaseModel from mplang.core import IrReader, TableType, TensorType +from mplang.core.cluster import ClusterSpec from mplang.kernels.base import KernelContext from mplang.kernels.value import Value, decode_value, encode_value from mplang.protos.v1alpha1 import mpir_pb2 @@ -40,6 +41,7 @@ Computation, Session, Symbol, + create_session_from_spec, ) logger = logging.getLogger(__name__) @@ -260,15 +262,12 @@ def list_session_computations(session_name: str) -> ComputationListResponse: def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse: 
validate_name(session_name, "session") # Delegate cluster spec parsing & session construction to resource layer - from mplang.core.cluster import ClusterSpec # local import to avoid cycles if session_name in _sessions: sess = _sessions[session_name] else: spec = ClusterSpec.from_dict(request.cluster_spec) - if len(spec.get_devices_by_kind("SPU")) == 0: - raise InvalidRequestError("No SPU device found in cluster_spec for session") - sess = Session(name=session_name, rank=request.rank, cluster_spec=spec) + sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec) _sessions[session_name] = sess return SessionResponse(name=sess.name) diff --git a/mplang/runtime/session.py b/mplang/runtime/session.py index 114a231c..e598e716 100644 --- a/mplang/runtime/session.py +++ b/mplang/runtime/session.py @@ -34,6 +34,8 @@ import spu.libspu as libspu +from mplang.core.cluster import ClusterSpec +from mplang.core.comm import ICommunicator from mplang.core.expr.ast import Expr from mplang.core.expr.evaluator import IEvaluator, create_evaluator from mplang.core.mask import Mask @@ -96,19 +98,25 @@ class SessionState: class Session: """Represents the per-rank execution context. - Immutable config: name, rank, cluster_spec. + Immutable config: name, rank, cluster_spec, communicator. Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party. Mutable: state (runtime object, symbols, computations, seeded flag). + + Note: communicator is assumed to be initialized with cluster spec info (e.g. endpoints). 
""" - def __init__(self, name: str, rank: int, cluster_spec: ClusterSpec): + def __init__( + self, + name: str, + rank: int, + cluster_spec: ClusterSpec, + communicator: ICommunicator, + ): self.name = name self.rank = rank self.cluster_spec = cluster_spec self.state = SessionState() - self.communicator = HttpCommunicator( - session_name=name, rank=rank, endpoints=self.endpoints - ) + self.communicator = communicator # --- Derived topology --- @cached_property @@ -119,18 +127,9 @@ def node(self) -> Node: def runtime_info(self) -> RuntimeInfo: return self.node.runtime_info - @cached_property + @property def endpoints(self) -> list[str]: - eps: list[str] = [] - for n in sorted( - self.cluster_spec.nodes.values(), - key=lambda x: x.rank, # type: ignore[attr-defined] - ): - ep = n.endpoint - if not ep.startswith(("http://", "https://")): - ep = f"http://{ep}" - eps.append(ep) - return eps + return self.cluster_spec.endpoints @cached_property def spu_device(self): # type: ignore @@ -191,10 +190,11 @@ def ensure_spu_env(self) -> None: if self.is_spu_party: # Build SPU address list across all endpoints for ranks in mask spu_addrs: list[str] = [] - for r, addr in enumerate(self.communicator.endpoints): + for r, addr in enumerate(self.cluster_spec.endpoints): if r in self.spu_mask: - if "//" not in addr: - addr = f"//{addr}" + # TODO(oeqqwq): addr may use another URL scheme such as grpc:// + if not addr.startswith(("http://", "https://")): + addr = f"http://{addr}" parsed = urlparse(addr) assert isinstance(parsed.port, int) new_addr = f"{parsed.hostname}:{parsed.port + SPU_PORT_OFFSET}" @@ -281,12 +281,17 @@ def execute( ) self.add_symbol(Symbol(name=name, mptype={}, data=val)) - # --- Convenience constructor --- - @classmethod - def from_cluster_spec_dict(cls, name: str, rank: int, spec_dict: dict) -> Session: - from mplang.core.cluster import ClusterSpec # local import to avoid cycles - spec = ClusterSpec.from_dict(spec_dict) - if len(spec.get_devices_by_kind("SPU")) == 0: -
raise RuntimeError("No SPU device found in cluster_spec") - return cls(name=name, rank=rank, cluster_spec=spec) +# --- Convenience constructor using HttpCommunicator --- +def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session: + if len(spec.get_devices_by_kind("SPU")) == 0: + raise RuntimeError("No SPU device found in cluster_spec") + + # Create HttpCommunicator for the session + communicator = HttpCommunicator( + session_name=name, + rank=rank, + endpoints=spec.endpoints, + ) + + return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator)