Skip to content

Commit f9fe941

Browse files
Misc: variable reuse
Signed-off-by: Shah, Karan <[email protected]>
1 parent 15d6226 commit f9fe941

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

openfl/pipelines/tensor_codec.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs):
6969
metadata: metadata associated with compressed tensor.
7070
"""
7171
if require_lossless:
72-
compressed_nparray, metadata = self.lossless_pipeline.forward(data, **kwargs)
72+
data, metadata = self.lossless_pipeline.forward(data, **kwargs)
7373
else:
74-
compressed_nparray, metadata = self.compression_pipeline.forward(data, **kwargs)
74+
data, metadata = self.compression_pipeline.forward(data, **kwargs)
7575
# Define the compressed tensorkey that should be
7676
# returned ('trained.delta'->'trained.delta.lossy_compressed')
7777
tensor_name, origin, round_number, report, tags = tensor_key
@@ -80,7 +80,7 @@ def compress(self, tensor_key, data, require_lossless=False, **kwargs):
8080
else:
8181
new_tags = change_tags(tags, add_field="lossy_compressed")
8282
compressed_tensor_key = TensorKey(tensor_name, origin, round_number, report, new_tags)
83-
return compressed_tensor_key, compressed_nparray, metadata
83+
return compressed_tensor_key, data, metadata
8484

8585
def decompress(
8686
self,
@@ -121,13 +121,9 @@ def decompress(
121121
assert "compressed" in tags, "Cannot losslessly decompress lossy tensor"
122122

123123
if require_lossless or "compressed" in tags:
124-
decompressed_nparray = self.lossless_pipeline.backward(
125-
data, transformer_metadata, **kwargs
126-
)
124+
data = self.lossless_pipeline.backward(data, transformer_metadata, **kwargs)
127125
else:
128-
decompressed_nparray = self.compression_pipeline.backward(
129-
data, transformer_metadata, **kwargs
130-
)
126+
data = self.compression_pipeline.backward(data, transformer_metadata, **kwargs)
131127
# Define the decompressed tensorkey that should be returned
132128
if "lossy_compressed" in tags:
133129
new_tags = change_tags(
@@ -144,7 +140,7 @@ def decompress(
144140
else:
145141
raise NotImplementedError("Decompression is only supported on compressed data")
146142

147-
return decompressed_tensor_key, decompressed_nparray
143+
return decompressed_tensor_key, data
148144

149145
@staticmethod
150146
def generate_delta(tensor_key, nparray, base_model_nparray):

0 commit comments

Comments
 (0)