diff --git a/examples/download_slim_pajama.py b/examples/download_slim_pajama.py
index 1f229a15..54de663a 100644
--- a/examples/download_slim_pajama.py
+++ b/examples/download_slim_pajama.py
@@ -1,9 +1,11 @@
 #!/usr/bin/env python3
-import os
 import argparse
-import requests
+import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
+import requests
+
+
 def download_file(url, target_path):
     """Attempt to download a file from 'url' to 'target_path' up to 3 tries."""
     tries = 3
diff --git a/mixtera/core/datacollection/datasets/__init__.py b/mixtera/core/datacollection/datasets/__init__.py
index df8554df..1590ef75 100644
--- a/mixtera/core/datacollection/datasets/__init__.py
+++ b/mixtera/core/datacollection/datasets/__init__.py
@@ -6,6 +6,24 @@ from .dataset_type import DatasetType  # noqa: F401
 from .jsonl_dataset import JSONLDataset  # noqa: F401
 from .parquet_dataset import ParquetDataset  # noqa: F401
-from .web_dataset import WebDataset
+from .web_dataset import (  # noqa: F401
+    CC12MDataset,
+    COYO700MDataset,
+    DomainNetDataset,
+    LAION400MDataset,
+    MSCOCODataset,
+    WebDataset,
+)
 
-__all__ = ["Dataset", "DatasetType", "JSONLDataset", "ParquetDataset", "WebDataset"]
+__all__ = [
+    "CC12MDataset",
+    "COYO700MDataset",
+    "Dataset",
+    "DatasetType",
+    "DomainNetDataset",
+    "JSONLDataset",
+    "LAION400MDataset",
+    "MSCOCODataset",
+    "ParquetDataset",
+    "WebDataset",
+]
diff --git a/mixtera/core/datacollection/datasets/dataset.py b/mixtera/core/datacollection/datasets/dataset.py
index ed7cb508..100dc68f 100644
--- a/mixtera/core/datacollection/datasets/dataset.py
+++ b/mixtera/core/datacollection/datasets/dataset.py
@@ -11,7 +11,12 @@ class Dataset(ABC):
     type: DatasetType = DatasetType.GENERIC_DATASET
 
     @staticmethod
-    def from_type_id(type_id: int) -> "Type[Dataset]":
+    def from_type_id(type_id: int) -> "Type[Dataset]":  # pylint: disable=too-many-return-statements
         """
-        This method instantiates a dataset from an integer type ID (e.g., stored in a DB).
+        This method returns the dataset class corresponding to an integer type ID (e.g., stored in a DB).
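+
+        Example (illustrative)::
+
+            dataset_cls = Dataset.from_type_id(int(DatasetType.CC12M_DATASET))
+            # dataset_cls is the CC12MDataset class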
@@ -32,6 +37,34 @@ def from_type_id(type_id: int) -> "Type[Dataset]":
             from mixtera.core.datacollection.datasets import WebDataset  # pylint: disable=import-outside-toplevel
 
             return WebDataset
+        if dataset_type == DatasetType.CC12M_DATASET:
+            from mixtera.core.datacollection.datasets import CC12MDataset  # pylint: disable=import-outside-toplevel
+
+            return CC12MDataset
+        if dataset_type == DatasetType.MSCOCO_DATASET:
+            from mixtera.core.datacollection.datasets import (  # pylint: disable=import-outside-toplevel
+                MSCOCODataset,
+            )
+
+            return MSCOCODataset
+        if dataset_type == DatasetType.LAION400M_DATASET:
+            from mixtera.core.datacollection.datasets import (  # pylint: disable=import-outside-toplevel
+                LAION400MDataset,
+            )
+
+            return LAION400MDataset
+        if dataset_type == DatasetType.COYO700M_DATASET:
+            from mixtera.core.datacollection.datasets import (  # pylint: disable=import-outside-toplevel
+                COYO700MDataset,
+            )
+
+            return COYO700MDataset
+        if dataset_type == DatasetType.DOMAINNET_DATASET:
+            from mixtera.core.datacollection.datasets import (  # pylint: disable=import-outside-toplevel
+                DomainNetDataset,
+            )
+
+            return DomainNetDataset
         if dataset_type == DatasetType.PARQUET_DATASET:
             from mixtera.core.datacollection.datasets import (  # pylint: disable=import-outside-toplevel
                 ParquetDataset,
diff --git a/mixtera/core/datacollection/datasets/dataset_type.py b/mixtera/core/datacollection/datasets/dataset_type.py
index 182a1c76..7d11edaf 100644
--- a/mixtera/core/datacollection/datasets/dataset_type.py
+++ b/mixtera/core/datacollection/datasets/dataset_type.py
@@ -7,3 +7,8 @@ class DatasetType(IntEnum):
     CROISSANT_DATASET = auto()
     WEB_DATASET = auto()
     PARQUET_DATASET = auto()
+    CC12M_DATASET = auto()
+    MSCOCO_DATASET = auto()
+    LAION400M_DATASET = auto()
+    COYO700M_DATASET = auto()
+    DOMAINNET_DATASET = auto()
diff --git a/mixtera/core/datacollection/datasets/web_dataset.py b/mixtera/core/datacollection/datasets/web_dataset.py
index e076e874..5a7882a4 100644
--- a/mixtera/core/datacollection/datasets/web_dataset.py
+++ b/mixtera/core/datacollection/datasets/web_dataset.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Callable, Iterable, Optional
+from typing import Callable, ClassVar, Iterable, Optional
 
 from loguru import logger
 from mixtera.core.datacollection.datasets import Dataset, DatasetType
@@ -11,6 +11,7 @@ class WebDataset(Dataset):
 
     type: DatasetType = DatasetType.WEB_DATASET
+    dataset_name: ClassVar[str] = "WebDataset"
 
     @staticmethod
     def iterate_files(loc: str) -> Iterable[str]:
@@ -23,17 +24,27 @@ def iterate_files(loc: str) -> Iterable[str]:
 
-    @staticmethod
-    def inform_metadata_parser(loc: Path, metadata_parser: MetadataParser) -> None:
-        samples = IndexedTarSamples(str(loc), decode_images=False)
+    @classmethod
+    def inform_metadata_parser(cls, loc: Path, metadata_parser: MetadataParser) -> None:
+        """Parse metadata from a WebDataset tar file."""
+        # A classmethod, so `cls` is the concrete subclass and its `dataset_name` is used.
+        dataset_name = getattr(cls, "dataset_name", cls.__name__)
+        samples = IndexedTarSamples(str(loc))
+
+        logger.info(f"Starting to iterate over samples ({cls.__name__}) in folder: {loc}")
 
         for idx, sample in enumerate(samples):
-            metadata_parser.parse(line_number=idx, payload=sample)
+            metadata_parser.parse(
+                line_number=idx,
+                payload=sample,
+                dataset_name=dataset_name if dataset_name != "WebDataset" else None,
+            )
 
         samples.close()
 
     @staticmethod
     def read_ranges_from_files(
         ranges_per_file: dict[str, list[tuple[int, int]]],
-        parsing_func: Callable[[str | dict], str],  # Will not necessarily take a string?
+        # The parsing callback may receive a dict (a WebDataset sample) rather than a string.
+        parsing_func: Callable[[str | dict], str],
         server_connection: Optional[ServerConnection],
     ) -> Iterable[str | dict]:
         for file, range_list in ranges_per_file.items():
@@ -61,3 +72,48 @@ def _read_ranges_from_file(  # pylint: disable=contextmanager-generator-missing-
             yield from (parsing_func(samples[line]) for line in range(r_start, r_end))
 
             last_line_read = r_end
+
+
+class CC12MDataset(WebDataset):
+    type: DatasetType = DatasetType.CC12M_DATASET
+    dataset_name: ClassVar[str] = "CC12M"
+
+
+class MSCOCODataset(WebDataset):
+    type: DatasetType = DatasetType.MSCOCO_DATASET
+    dataset_name: ClassVar[str] = "MSCOCO"
+
+
+class LAION400MDataset(WebDataset):
+    type: DatasetType = DatasetType.LAION400M_DATASET
+    dataset_name: ClassVar[str] = "LAION400M"
+
+
+class COYO700MDataset(WebDataset):
+    type: DatasetType = DatasetType.COYO700M_DATASET
+    dataset_name: ClassVar[str] = "COYO700M"
+
+
+class DomainNetDataset(WebDataset):
+    type: DatasetType = DatasetType.DOMAINNET_DATASET
+    dataset_name: ClassVar[str] = "DomainNet"
+
+    @classmethod
+    def inform_metadata_parser(cls, loc: Path, metadata_parser: MetadataParser) -> None:
+        samples = IndexedTarSamples(str(loc))
+
+        logger.info(f"Starting to iterate over samples (DomainNet) in folder: {loc}")
+        for idx, sample in enumerate(samples):
+            # Samples are not decoded, so the annotation fields arrive as raw bytes.
+            class_name = sample["cls"].decode("utf-8")
+            domain = sample["domain"].decode("utf-8")
+
+            metadata_parser.parse(
+                line_number=idx,
+                payload=sample,
+                dataset_name=cls.dataset_name,
+                class_name=class_name,
+                domain=domain,
+            )
+
+        samples.close()
diff --git a/mixtera/core/datacollection/index/parser/__init__.py b/mixtera/core/datacollection/index/parser/__init__.py
index 9460a40b..52cc6ee3 100644
--- a/mixtera/core/datacollection/index/parser/__init__.py
+++ b/mixtera/core/datacollection/index/parser/__init__.py
@@ -3,9 +3,16 @@
 """
 
 from .metadata_parser import MetadataParser, MetadataProperty  # noqa: F401
-from .parser_collection import MetadataParserFactory, RedPajamaMetadataParser  # noqa: F401
+from .parser_collection import (  # noqa: F401
+    DomainNetMetadataParser,
+    GenericMetadataParser,
+    MetadataParserFactory,
+    RedPajamaMetadataParser,
+)
 
 __all__ = [
+    "DomainNetMetadataParser",
+    "GenericMetadataParser",
     "MetadataParser",
     "MetadataProperty",
     "RedPajamaMetadataParser",
diff --git a/mixtera/core/datacollection/index/parser/parser_collection.py b/mixtera/core/datacollection/index/parser/parser_collection.py
index 83bfa9ae..c918ba5a 100644
--- a/mixtera/core/datacollection/index/parser/parser_collection.py
+++ b/mixtera/core/datacollection/index/parser/parser_collection.py
@@ -87,7 +87,13 @@ class FineWebMetadataParser(MetadataParser):
     def get_properties(cls) -> list[MetadataProperty]:
         return [
             MetadataProperty(name="dump", dtype="STRING", multiple=False, nullable=False),
-            MetadataProperty(name="language", dtype="ENUM", enum_options=["en"], multiple=False, nullable=False),
+            MetadataProperty(
+                name="language",
+                dtype="ENUM",
+                enum_options=["en"],
+                multiple=False,
+                nullable=False,
+            ),
         ]
 
     def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
@@ -168,6 +174,66 @@ def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any
         self.add_metadata(sample_id=line_number, pile_set_name=pile_set_name)
 
 
+class GenericMetadataParser(MetadataParser):
+    """
+    Metadata parser with only the source dataset name as a property.
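+
+    `parse` reads the dataset name from the `dataset_name` keyword argument,
+    as supplied by `WebDataset.inform_metadata_parser`.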
+    """
+
+    @classmethod
+    def get_properties(cls) -> list[MetadataProperty]:
+        return [
+            MetadataProperty(
+                name="dataset",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            )
+        ]
+
+    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
+        dataset_name = kwargs.get("dataset_name")
+        self.add_metadata(sample_id=line_number, dataset=dataset_name)
+
+
+class DomainNetMetadataParser(MetadataParser):
+    """
+    Metadata parser class for the DomainNet dataset.
+    """
+
+    @classmethod
+    def get_properties(cls) -> list[MetadataProperty]:
+        return [
+            MetadataProperty(
+                name="domain",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            ),
+            MetadataProperty(
+                name="class_name",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            ),
+            MetadataProperty(
+                name="dataset",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            ),
+        ]
+
+    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
+        dataset = kwargs.get("dataset_name")
+        domain = kwargs.get("domain")
+        class_name = kwargs.get("class_name")
+
+        self.add_metadata(sample_id=line_number, domain=domain, class_name=class_name, dataset=dataset)
+
+
 class MetadataParserFactory:
     """Handles the creation of metadata parsers."""
 
@@ -180,6 +246,8 @@ def __init__(self) -> None:
             "FINEWEB": FineWebMetadataParser,
             "MSCOCO": MsCocoParser,
             "PILE": PileaMetadataParser,
+            "GENERIC": GenericMetadataParser,
+            "DOMAINNET": DomainNetMetadataParser,
         }
 
     def add_parser(self, parser_name: str, parser: type[MetadataParser], overwrite: bool = False) -> bool:
diff --git a/mixtera/multimodal/__init__.py b/mixtera/multimodal/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mixtera/multimodal/webdataset/__init__.py b/mixtera/multimodal/webdataset/__init__.py
new file mode 100644
index 00000000..ae3b365a
--- /dev/null
+++ b/mixtera/multimodal/webdataset/__init__.py
@@ -0,0 +1,3 @@
+from .pipeline import MixteraDataPipeline
+
+__all__ = ["MixteraDataPipeline"]
diff --git a/mixtera/multimodal/webdataset/pipeline.py b/mixtera/multimodal/webdataset/pipeline.py
new file mode 100644
index 00000000..e91ca724
--- /dev/null
+++ b/mixtera/multimodal/webdataset/pipeline.py
@@ -0,0 +1,48 @@
+from typing import Any, Iterable
+
+import webdataset as wds
+
+from mixtera.core.client.mixtera_client import MixteraClient, QueryExecutionArgs, ResultStreamingArgs
+from mixtera.core.query.query import Query
+from mixtera.torch import MixteraTorchDataset
+
+
+class MixteraDataPipeline(wds.DataPipeline):
+    """
+    Supports building arbitrary webdataset pipelines with Mixtera's `MixteraTorchDataset` as the data source.
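+
+    Example (illustrative; the stage list below is an assumption, any
+    webdataset stages can be used):
+
+        pipeline = MixteraDataPipeline(
+            client=client,
+            query=query,
+            query_execution_args=query_execution_args,
+            result_streaming_args=result_streaming_args,
+            pipeline=[wds.decode("pil"), wds.to_tuple("jpg", "cls")],
+        )
+        for image, label in pipeline:
+            ...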
+    """
+
+    def __init__(
+        self,
+        client: MixteraClient,
+        query: Query,
+        query_execution_args: QueryExecutionArgs,
+        result_streaming_args: ResultStreamingArgs,
+        pipeline: Iterable[Any],
+    ):
+        super().__init__(*pipeline)
+        self.client = client
+        self.query = query
+        self.query_execution_args = query_execution_args
+        self.result_streaming_args = result_streaming_args
+
+        torch_dataset = MixteraTorchDataset(
+            client=client,
+            query=query,
+            query_execution_args=query_execution_args,
+            result_streaming_args=result_streaming_args,
+        )
+
+        self.pipeline.insert(0, torch_dataset)
diff --git a/mixtera/tests/core/datacollection/datasets/test_web_dataset.py b/mixtera/tests/core/datacollection/datasets/test_web_dataset.py
index 659fec8e..1299b565 100644
--- a/mixtera/tests/core/datacollection/datasets/test_web_dataset.py
+++ b/mixtera/tests/core/datacollection/datasets/test_web_dataset.py
@@ -41,19 +41,19 @@ def test_read_ranges_from_tar_e2e(self):
         }
 
         expected = [
-            {"__key__": "000001", ".cls": 1},
-            {"__key__": "000002", ".cls": 0},
-            {"__key__": "000003", ".cls": 1},
-            {"__key__": "000004", ".cls": 0},
-            {"__key__": "000005", ".cls": 1},
-            {"__key__": "000007", ".cls": 1},
-            {"__key__": "000008", ".cls": 0},
+            {"__key__": "000001", "cls": b"1"},
+            {"__key__": "000002", "cls": b"0"},
+            {"__key__": "000003", "cls": b"1"},
+            {"__key__": "000004", "cls": b"0"},
+            {"__key__": "000005", "cls": b"1"},
+            {"__key__": "000007", "cls": b"1"},
+            {"__key__": "000008", "cls": b"0"},
         ]
 
         result = list(WebDataset.read_ranges_from_files(ranges_per_file, lambda x: x, None))
 
-        assert all(".png" in sample for sample in result)
+        assert all("png" in sample for sample in result)
 
-        result = [{k: v for k, v in sample.items() if k in [".cls", "__key__"]} for sample in result]
+        result = [{k: v for k, v in sample.items() if k in ["cls", "__key__"]} for sample in result]
 
         self.assertEqual(result, expected)
diff --git a/mixtera/utils/webdataset_utils.py b/mixtera/utils/webdataset_utils.py
index 6c81af04..bb75e17a 100644
--- a/mixtera/utils/webdataset_utils.py
+++ b/mixtera/utils/webdataset_utils.py
@@ -1,13 +1,12 @@
 import gzip
 import io
-from functools import partial
 from typing import Any, Iterator
 
 from wids.wids import group_by_key, splitname
 from wids.wids_mmtar import MMIndexedTar
 
 
-def decode(sample: dict[str, Any], decode_image: bool = True) -> dict[str, Any]:
+def decode_sample(sample: dict[str, Any]) -> dict[str, Any]:
     """
     A utility function to decode the samples from the tar file for many common extensions.
     """
@@ -32,7 +31,7 @@ def decode(sample: dict[str, Any], decode_image: bool = True) -> dict[str, Any]:
     elif extension in ["cls", "cls2"]:
         value = stream.read()
         sample[key] = int(value.decode("utf-8"))
-    elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"] and decode_image:
+    elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
         import torchvision.transforms.functional as F  # pylint: disable=import-outside-toplevel
         from PIL import Image  # pylint: disable=import-outside-toplevel
 
@@ -54,8 +53,21 @@ def decode(sample: dict[str, Any], decode_image: bool = True) -> dict[str, Any]:
     return sample
 
 
+class MMIndexedTarRawBytes(MMIndexedTar):
+    """
+    A subclass of `MMIndexedTar` that returns the raw bytes instead of an `io.BytesIO` stream.
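+
+    Returning the raw (memory-mapped) bytes avoids wrapping every member file
+    in a stream; decoding, when requested, happens later in `decode_sample`.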
+    """
+
+    def get_file(self, i: int) -> tuple[str, bytes]:
+        filename, data = self.get_at_index(i)
+        return filename, data
+
+
 class IndexedTarSamples:
-    def __init__(self, path: str, decode_images: bool = True):
+    def __init__(self, path: str, decode: bool = False):
         """
         A class for efficient reading of tar files for web datasets.
 
@@ -64,9 +76,9 @@ def __init__(self, path: str, decode_images: bool = True):
         and with decoding integrated.
         """
         self.path = path
-        self.decoder = partial(decode, decode_image=decode_images)
-        self.stream = open(self.path, "rb")  # pylint: disable=consider-using-with
-        self.reader = MMIndexedTar(self.stream)
+        self.decoder = decode_sample
+        self.decode = decode
+        self.reader = MMIndexedTarRawBytes(path)
 
         all_files = self.reader.names()
         self.samples = group_by_key(all_files)
@@ -80,8 +92,6 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:  # type: ignore
     def close(self) -> None:
         if self.reader is not None:
             self.reader.close()
-        if self.stream is not None and not self.stream.closed:
-            self.stream.close()
 
     def __len__(self) -> int:
         return len(self.samples)
@@ -96,10 +106,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
             k, ext = splitname(fname)
             key = key or k
             assert key == k, "Inconsistent keys in the same sample"
-            sample[ext] = data
-        sample["__key__"] = key
-        return self.decoder(sample)
-        raise ValueError("Co")
+            sample[ext[1:]] = data
+        sample["__key__"] = key  # type: ignore
+        return self.decoder(sample) if self.decode else sample
+        raise ValueError("Error reading sample")
 
     def __iter__(self) -> Iterator[dict[str, Any]]:
         for idx in range(len(self)):
diff --git a/setup.py b/setup.py
index d43d5c20..de6b1d12 100644
--- a/setup.py
+++ b/setup.py
@@ -4,14 +4,14 @@
 import io
 import os
 import pathlib
+import shutil
+import socket
 import subprocess
 import sysconfig
-import socket
-import shutil
 
 from setuptools import Extension, find_packages, setup
-from setuptools.command.build_ext import build_ext
 from setuptools.command.build import build
+from setuptools.command.build_ext import build_ext
 
 # Package meta-data.
 NAME = "mixtera"