1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import io
17+ import json
1618from dataclasses import dataclass
1719from typing import Any , Callable , Dict
1820
21+ import msgpack
22+ import numpy as np
1923import zmq
2024
21- import gr00t .utils . serialization as serialization
25+ from gr00t .data . dataset import ModalityConfig
2226
2327
24- class TorchSerializer :
28+ class MsgSerializer :
2529 @staticmethod
2630 def to_bytes (data : dict ) -> bytes :
27- return serialization . dumps (data )
31+ return msgpack . packb (data , default = MsgSerializer . encode_custom_classes )
2832
2933 @staticmethod
3034 def from_bytes (data : bytes ) -> dict :
31- return serialization .loads (data )
35+ return msgpack .unpackb (data , object_hook = MsgSerializer .decode_custom_classes )
36+
37+ @staticmethod
38+ def decode_custom_classes (obj ):
39+ if "__ModalityConfig_class__" in obj :
40+ obj = ModalityConfig (** json .loads (obj ["as_json" ]))
41+ if "__ndarray_class__" in obj :
42+ obj = np .load (io .BytesIO (obj ["as_npy" ]), allow_pickle = False )
43+ return obj
44+
45+ @staticmethod
46+ def encode_custom_classes (obj ):
47+ if isinstance (obj , ModalityConfig ):
48+ return {"__ModalityConfig_class__" : True , "as_json" : obj .model_dump_json ()}
49+ if isinstance (obj , np .ndarray ):
50+ output = io .BytesIO ()
51+ np .save (output , obj , allow_pickle = False )
52+ return {"__ndarray_class__" : True , "as_npy" : output .getvalue ()}
53+ return obj
3254
3355
3456@dataclass
@@ -92,12 +114,12 @@ def run(self):
92114 while self .running :
93115 try :
94116 message = self .socket .recv ()
95- request = TorchSerializer .from_bytes (message )
117+ request = MsgSerializer .from_bytes (message )
96118
97119 # Validate token before processing request
98120 if not self ._validate_token (request ):
99121 self .socket .send (
100- TorchSerializer .to_bytes ({"error" : "Unauthorized: Invalid API token" })
122+ MsgSerializer .to_bytes ({"error" : "Unauthorized: Invalid API token" })
101123 )
102124 continue
103125
@@ -112,13 +134,13 @@ def run(self):
112134 if handler .requires_input
113135 else handler .handler ()
114136 )
115- self .socket .send (TorchSerializer .to_bytes (result ))
137+ self .socket .send (MsgSerializer .to_bytes (result ))
116138 except Exception as e :
117139 print (f"Error in server: { e } " )
118140 import traceback
119141
120142 print (traceback .format_exc ())
121- self .socket .send (TorchSerializer .to_bytes ({"error" : str (e )}))
143+ self .socket .send (MsgSerializer .to_bytes ({"error" : str (e )}))
122144
123145
124146class BaseInferenceClient :
@@ -172,9 +194,9 @@ def call_endpoint(
172194 if self .api_token :
173195 request ["api_token" ] = self .api_token
174196
175- self .socket .send (TorchSerializer .to_bytes (request ))
197+ self .socket .send (MsgSerializer .to_bytes (request ))
176198 message = self .socket .recv ()
177- response = TorchSerializer .from_bytes (message )
199+ response = MsgSerializer .from_bytes (message )
178200
179201 if "error" in response :
180202 raise RuntimeError (f"Server error: { response ['error' ]} " )
0 commit comments