Skip to content
11 changes: 11 additions & 0 deletions mplang/core/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from functools import cached_property
from typing import Any


Expand Down Expand Up @@ -169,6 +170,16 @@ def to_dict(self) -> dict[str, Any]:
},
}

@cached_property
def endpoints(self) -> list[str]:
eps: list[str] = []
for n in sorted(
self.nodes.values(),
key=lambda x: x.rank, # type: ignore[attr-defined]
):
eps.append(n.endpoint)
return eps

@classmethod
def from_dict(cls, config: dict[str, Any]) -> ClusterSpec:
"""Parses a raw config dictionary and returns a validated ClusterSpec."""
Expand Down
4 changes: 4 additions & 0 deletions mplang/core/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def send(self, to: int, key: str, data: Any) -> None:
def recv(self, frm: int, key: str) -> Any:
"""Receive data from peer with the given key"""

@abstractmethod
def onSent(self, frm: int, key: str, data: Any) -> None:
"""Called when a key is sent to self"""


class ICollective(ABC):
"""Interface for collective communication"""
Expand Down
7 changes: 3 additions & 4 deletions mplang/runtime/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pydantic import BaseModel

from mplang.core import IrReader, TableType, TensorType
from mplang.core.cluster import ClusterSpec
from mplang.kernels.base import KernelContext
from mplang.kernels.value import Value, decode_value, encode_value
from mplang.protos.v1alpha1 import mpir_pb2
Expand All @@ -40,6 +41,7 @@
Computation,
Session,
Symbol,
create_session_from_spec,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -260,15 +262,12 @@ def list_session_computations(session_name: str) -> ComputationListResponse:
def create_session(session_name: str, request: CreateSessionRequest) -> SessionResponse:
validate_name(session_name, "session")
# Delegate cluster spec parsing & session construction to resource layer
from mplang.core.cluster import ClusterSpec # local import to avoid cycles

if session_name in _sessions:
sess = _sessions[session_name]
else:
spec = ClusterSpec.from_dict(request.cluster_spec)
if len(spec.get_devices_by_kind("SPU")) == 0:
raise InvalidRequestError("No SPU device found in cluster_spec for session")
sess = Session(name=session_name, rank=request.rank, cluster_spec=spec)
sess = create_session_from_spec(name=session_name, rank=request.rank, spec=spec)
_sessions[session_name] = sess
return SessionResponse(name=sess.name)

Expand Down
59 changes: 32 additions & 27 deletions mplang/runtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

import spu.libspu as libspu

from mplang.core.cluster import ClusterSpec
from mplang.core.comm import ICommunicator
from mplang.core.expr.ast import Expr
from mplang.core.expr.evaluator import IEvaluator, create_evaluator
from mplang.core.mask import Mask
Expand Down Expand Up @@ -96,19 +98,25 @@ class SessionState:
class Session:
"""Represents the per-rank execution context.

Immutable config: name, rank, cluster_spec.
Immutable config: name, rank, cluster_spec, communicator.
Derived: node, runtime_info, endpoints, spu_device, spu_mask, protocol/field, is_spu_party.
Mutable: state (runtime object, symbols, computations, seeded flag).

Note: communicator is assumed to be initialized with cluster spec info (e.g. endpoints).
"""

def __init__(self, name: str, rank: int, cluster_spec: ClusterSpec):
def __init__(
self,
name: str,
rank: int,
cluster_spec: ClusterSpec,
communicator: ICommunicator,
):
self.name = name
self.rank = rank
self.cluster_spec = cluster_spec
self.state = SessionState()
self.communicator = HttpCommunicator(
session_name=name, rank=rank, endpoints=self.endpoints
)
self.communicator = communicator

# --- Derived topology ---
@cached_property
Expand All @@ -119,18 +127,9 @@ def node(self) -> Node:
def runtime_info(self) -> RuntimeInfo:
return self.node.runtime_info

@cached_property
@property
def endpoints(self) -> list[str]:
eps: list[str] = []
for n in sorted(
self.cluster_spec.nodes.values(),
key=lambda x: x.rank, # type: ignore[attr-defined]
):
ep = n.endpoint
if not ep.startswith(("http://", "https://")):
ep = f"http://{ep}"
eps.append(ep)
return eps
return self.cluster_spec.endpoints

@cached_property
def spu_device(self): # type: ignore
Expand Down Expand Up @@ -191,10 +190,11 @@ def ensure_spu_env(self) -> None:
if self.is_spu_party:
# Build SPU address list across all endpoints for ranks in mask
spu_addrs: list[str] = []
for r, addr in enumerate(self.communicator.endpoints):
for r, addr in enumerate(self.cluster_spec.endpoints):
if r in self.spu_mask:
if "//" not in addr:
addr = f"//{addr}"
# TODO(oeqqwq): addr may contain other schema like grpc://
if not addr.startswith(("http://", "https://")):
addr = f"http://{addr}"
parsed = urlparse(addr)
assert isinstance(parsed.port, int)
new_addr = f"{parsed.hostname}:{parsed.port + SPU_PORT_OFFSET}"
Expand Down Expand Up @@ -281,12 +281,17 @@ def execute(
)
self.add_symbol(Symbol(name=name, mptype={}, data=val))

# --- Convenience constructor ---
@classmethod
def from_cluster_spec_dict(cls, name: str, rank: int, spec_dict: dict) -> Session:
from mplang.core.cluster import ClusterSpec # local import to avoid cycles

spec = ClusterSpec.from_dict(spec_dict)
if len(spec.get_devices_by_kind("SPU")) == 0:
raise RuntimeError("No SPU device found in cluster_spec")
return cls(name=name, rank=rank, cluster_spec=spec)
# --- Convenience constructor use HttpCommunicator---
def create_session_from_spec(name: str, rank: int, spec: ClusterSpec) -> Session:
if len(spec.get_devices_by_kind("SPU")) == 0:
raise RuntimeError("No SPU device found in cluster_spec")

# Create HttpCommunicator for the session
communicator = HttpCommunicator(
session_name=name,
rank=rank,
endpoints=spec.endpoints,
)

return Session(name=name, rank=rank, cluster_spec=spec, communicator=communicator)
Loading