@@ -69,9 +69,9 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs):
69
69
metadata: metadata associated with compressed tensor.
70
70
"""
71
71
if require_lossless :
72
- compressed_nparray , metadata = self .lossless_pipeline .forward (data , ** kwargs )
72
+ data , metadata = self .lossless_pipeline .forward (data , ** kwargs )
73
73
else :
74
- compressed_nparray , metadata = self .compression_pipeline .forward (data , ** kwargs )
74
+ data , metadata = self .compression_pipeline .forward (data , ** kwargs )
75
75
# Define the compressed tensorkey that should be
76
76
# returned ('trained.delta'->'trained.delta.lossy_compressed')
77
77
tensor_name , origin , round_number , report , tags = tensor_key
@@ -80,7 +80,7 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs):
80
80
else :
81
81
new_tags = change_tags (tags , add_field = "lossy_compressed" )
82
82
compressed_tensor_key = TensorKey (tensor_name , origin , round_number , report , new_tags )
83
- return compressed_tensor_key , compressed_nparray , metadata
83
+ return compressed_tensor_key , data , metadata
84
84
85
85
def decompress (
86
86
self ,
@@ -121,13 +121,9 @@ def decompress(
121
121
assert "compressed" in tags , "Cannot losslessly decompress lossy tensor"
122
122
123
123
if require_lossless or "compressed" in tags :
124
- decompressed_nparray = self .lossless_pipeline .backward (
125
- data , transformer_metadata , ** kwargs
126
- )
124
+ data = self .lossless_pipeline .backward (data , transformer_metadata , ** kwargs )
127
125
else :
128
- decompressed_nparray = self .compression_pipeline .backward (
129
- data , transformer_metadata , ** kwargs
130
- )
126
+ data = self .compression_pipeline .backward (data , transformer_metadata , ** kwargs )
131
127
# Define the decompressed tensorkey that should be returned
132
128
if "lossy_compressed" in tags :
133
129
new_tags = change_tags (
@@ -144,7 +140,7 @@ def decompress(
144
140
else :
145
141
raise NotImplementedError ("Decompression is only supported on compressed data" )
146
142
147
- return decompressed_tensor_key , decompressed_nparray
143
+ return decompressed_tensor_key , data
148
144
149
145
@staticmethod
150
146
def generate_delta (tensor_key , nparray , base_model_nparray ):
0 commit comments