Skip to content

Commit 029b2df

Browse files
authored
Fix server-client serialization nit (#365)
* fix serialization nit

  Signed-off-by: youliangt <youliangt@nvidia.com>

* fix style

  Signed-off-by: youliangt <youliangt@nvidia.com>

---------

Signed-off-by: youliangt <youliangt@nvidia.com>
1 parent 7f53666 commit 029b2df

File tree

2 files changed

+32
-54
lines changed

2 files changed

+32
-54
lines changed

gr00t/eval/service.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,44 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import io
17+
import json
1618
from dataclasses import dataclass
1719
from typing import Any, Callable, Dict
1820

21+
import msgpack
22+
import numpy as np
1923
import zmq
2024

21-
import gr00t.utils.serialization as serialization
25+
from gr00t.data.dataset import ModalityConfig
2226

2327

24-
class TorchSerializer:
28+
class MsgSerializer:
2529
@staticmethod
2630
def to_bytes(data: dict) -> bytes:
27-
return serialization.dumps(data)
31+
return msgpack.packb(data, default=MsgSerializer.encode_custom_classes)
2832

2933
@staticmethod
3034
def from_bytes(data: bytes) -> dict:
31-
return serialization.loads(data)
35+
return msgpack.unpackb(data, object_hook=MsgSerializer.decode_custom_classes)
36+
37+
@staticmethod
38+
def decode_custom_classes(obj):
39+
if "__ModalityConfig_class__" in obj:
40+
obj = ModalityConfig(**json.loads(obj["as_json"]))
41+
if "__ndarray_class__" in obj:
42+
obj = np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False)
43+
return obj
44+
45+
@staticmethod
46+
def encode_custom_classes(obj):
47+
if isinstance(obj, ModalityConfig):
48+
return {"__ModalityConfig_class__": True, "as_json": obj.model_dump_json()}
49+
if isinstance(obj, np.ndarray):
50+
output = io.BytesIO()
51+
np.save(output, obj, allow_pickle=False)
52+
return {"__ndarray_class__": True, "as_npy": output.getvalue()}
53+
return obj
3254

3355

3456
@dataclass
@@ -92,12 +114,12 @@ def run(self):
92114
while self.running:
93115
try:
94116
message = self.socket.recv()
95-
request = TorchSerializer.from_bytes(message)
117+
request = MsgSerializer.from_bytes(message)
96118

97119
# Validate token before processing request
98120
if not self._validate_token(request):
99121
self.socket.send(
100-
TorchSerializer.to_bytes({"error": "Unauthorized: Invalid API token"})
122+
MsgSerializer.to_bytes({"error": "Unauthorized: Invalid API token"})
101123
)
102124
continue
103125

@@ -112,13 +134,13 @@ def run(self):
112134
if handler.requires_input
113135
else handler.handler()
114136
)
115-
self.socket.send(TorchSerializer.to_bytes(result))
137+
self.socket.send(MsgSerializer.to_bytes(result))
116138
except Exception as e:
117139
print(f"Error in server: {e}")
118140
import traceback
119141

120142
print(traceback.format_exc())
121-
self.socket.send(TorchSerializer.to_bytes({"error": str(e)}))
143+
self.socket.send(MsgSerializer.to_bytes({"error": str(e)}))
122144

123145

124146
class BaseInferenceClient:
@@ -172,9 +194,9 @@ def call_endpoint(
172194
if self.api_token:
173195
request["api_token"] = self.api_token
174196

175-
self.socket.send(TorchSerializer.to_bytes(request))
197+
self.socket.send(MsgSerializer.to_bytes(request))
176198
message = self.socket.recv()
177-
response = TorchSerializer.from_bytes(message)
199+
response = MsgSerializer.from_bytes(message)
178200

179201
if "error" in response:
180202
raise RuntimeError(f"Server error: {response['error']}")

gr00t/utils/serialization.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

0 commit comments

Comments (0)