Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/ml_flashpoint/core/checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,63 @@
_LOGGER = get_logger(__name__)


# Allowlist of (module, name) globals that may be resolved while unpickling
# .metadata files. It backs the restricted unpickler used in place of a bare
# pickle.load(), so that files from untrusted sources (e.g., peer nodes, shared
# storage) cannot reference arbitrary callables.
# See CWE-502: Deserialization of Untrusted Data.
#
# Entries are only looked up when a pickle stream actually references them, so
# listing a name that a given torch version does not define is harmless.
_SAFE_METADATA_UNPICKLE_ALLOWLIST: frozenset[tuple[str, str]] = frozenset(
    [
        # torch.distributed.checkpoint.metadata classes
        ("torch.distributed.checkpoint.metadata", "Metadata"),
        ("torch.distributed.checkpoint.metadata", "MetadataIndex"),
        ("torch.distributed.checkpoint.metadata", "TensorStorageMetadata"),
        ("torch.distributed.checkpoint.metadata", "BytesStorageMetadata"),
        ("torch.distributed.checkpoint.metadata", "StorageMeta"),
        ("torch.distributed.checkpoint.metadata", "TensorProperties"),
        ("torch.distributed.checkpoint.metadata", "ChunkStorageMetadata"),
        ("torch.distributed.checkpoint.metadata", "_MEM_FORMAT_ENCODING"),
        # torch.serialization
        # NOTE(review): presumably the helper torch registers for rebuilding
        # pickled torch.layout values -- confirm against torch/serialization.py.
        ("torch.serialization", "_get_layout"),
        # torch types
        ("torch", "Size"),
        # Floating point dtypes
        ("torch", "float32"),
        ("torch", "float16"),
        ("torch", "float64"),
        ("torch", "bfloat16"),
        # Low-precision float dtypes (availability depends on the torch version)
        ("torch", "float8_e4m3fn"),
        ("torch", "float8_e4m3fnuz"),
        ("torch", "float8_e5m2"),
        ("torch", "float8_e5m2fnuz"),
        # Integer and boolean dtypes
        ("torch", "int8"),
        ("torch", "uint8"),
        ("torch", "int16"),
        ("torch", "uint16"),
        ("torch", "int32"),
        ("torch", "uint32"),
        ("torch", "int64"),
        ("torch", "uint64"),
        ("torch", "bool"),
        # Complex dtypes
        ("torch", "complex32"),
        ("torch", "complex64"),
        ("torch", "complex128"),
        # Layouts
        ("torch", "strided"),
        ("torch", "sparse_coo"),
        ("torch", "sparse_csr"),
        ("torch", "sparse_bsr"),
        ("torch", "sparse_csc"),
        ("torch", "sparse_bsc"),
        ("torch", "jagged"),
    ]
)


class _RestrictedMetadataUnpickler(pickle.Unpickler):
    """Unpickler that resolves only globals named in the metadata allowlist.

    Any global reference in the pickle stream whose ``(module, name)`` pair is
    not in ``_SAFE_METADATA_UNPICKLE_ALLOWLIST`` aborts loading, which prevents
    arbitrary code execution from malicious pickle payloads (CWE-502).
    """

    def find_class(self, module: str, name: str) -> object:
        """Resolve a global referenced by the pickle stream, if it is allowlisted.

        Args:
            module: Module name recorded in the pickle stream.
            name: Attribute name recorded in the pickle stream.

        Returns:
            The object the default unpickler resolves for ``(module, name)``.
            This is not necessarily a class: the allowlist also names a helper
            function and module-level constants.

        Raises:
            pickle.UnpicklingError: If ``(module, name)`` is not allowlisted.
        """
        # Exact-pair match only: a dotted ``name`` can never equal an allowlisted
        # entry, so attribute traversal through an allowed module is rejected too.
        if (module, name) in _SAFE_METADATA_UNPICKLE_ALLOWLIST:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(
            f"Unsafe deserialization blocked: ({module!r}, {name!r}) is not in the allowlist. "
            "Metadata files must only contain PyTorch checkpoint metadata structures."
        )

class MLFlashpointCheckpointLoader(abc.ABC):
"""
This is the main interface for loading checkpoints, providing functionality for the different
Expand Down Expand Up @@ -169,7 +226,7 @@ def read_metadata(
metadata_path = Path(checkpoint_id.data) / object_name
try:
with open(metadata_path, "rb") as f:
return pickle.load(f)
return _RestrictedMetadataUnpickler(f).load()
except Exception:
_LOGGER.exception("Error reading metadata from '%s'", metadata_path)
raise
Expand Down
25 changes: 25 additions & 0 deletions tests/core/test_checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ def test_read_metadata_invalid_format(self, checkpoint_directory):
object_name="invalid_metadata.pt",
)

def test_read_metadata_rejects_malicious_pickle(self, checkpoint_directory):
    """Security test: malicious pickle payloads must not execute arbitrary code (CWE-502).

    An attacker who controls a peer node or shared checkpoint storage could craft a
    malicious .metadata file. Reading it must fail with an UnpicklingError before
    the payload's callable (here, exec) is resolved or invoked.
    """
    metadata_path = Path(checkpoint_directory) / ".metadata"
    # Keep the would-be side effect inside the per-test directory so a stale file
    # in the working directory cannot cause a false failure, and a regression
    # does not leave artifacts behind in the repository.
    marker_path = Path(checkpoint_directory) / "pwned.txt"
    payload_source = f"open({str(marker_path)!r}, 'w').write('pwned')"

    class MaliciousPayload:
        def __reduce__(self):
            # Pickled as a call to builtins.exec with the payload source.
            return (exec, (payload_source,))

    with open(metadata_path, "wb") as f:
        pickle.dump(MaliciousPayload(), f)

    with pytest.raises(pickle.UnpicklingError, match="Unsafe deserialization blocked"):
        self.loader.read_metadata(
            CheckpointContainerId(checkpoint_directory),
            object_name=".metadata",
        )

    # Ensure no code execution occurred
    assert not marker_path.exists(), "Malicious pickle must not execute arbitrary code"


class TestReadTensor:
@pytest.fixture
Expand Down