Skip to content
Open
6 changes: 4 additions & 2 deletions examples/download_slim_pajama.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#!/usr/bin/env python3
import os
import argparse
import requests
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
Comment on lines -2 to +6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary changes, no?



def download_file(url, target_path):
"""Attempt to download a file from 'url' to 'target_path' up to 3 tries."""
tries = 3
Expand Down
22 changes: 20 additions & 2 deletions mixtera/core/datacollection/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@
from .dataset_type import DatasetType # noqa: F401
from .jsonl_dataset import JSONLDataset # noqa: F401
from .parquet_dataset import ParquetDataset # noqa: F401
from .web_dataset import WebDataset
from .web_dataset import ( # noqa: F401
CC12MDataset,
COYO700MDataset,
DomainNetDataset,
LAION400MDataset,
MSCOCODataset,
WebDataset,
)

__all__ = ["Dataset", "DatasetType", "JSONLDataset", "ParquetDataset", "WebDataset"]
__all__ = [
"Dataset",
"DatasetType",
"JSONLDataset",
"ParquetDataset",
"WebDataset",
"CC12MDataset",
"MSCOCODataset",
"LAION400MDataset",
"COYO700MDataset",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, this must change

"DomainNetDataset",
]
30 changes: 29 additions & 1 deletion mixtera/core/datacollection/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Dataset(ABC):
type: DatasetType = DatasetType.GENERIC_DATASET

@staticmethod
def from_type_id(type_id: int) -> "Type[Dataset]":
def from_type_id(type_id: int) -> "Type[Dataset]": # pylint: disable=too-many-return-statements
"""
This method instantiates a dataset from an integer type ID (e.g., stored in a DB).

Expand All @@ -32,6 +32,34 @@ def from_type_id(type_id: int) -> "Type[Dataset]":
from mixtera.core.datacollection.datasets import WebDataset # pylint: disable=import-outside-toplevel

return WebDataset
if dataset_type == DatasetType.CC12M_DATASET:
from mixtera.core.datacollection.datasets import CC12MDataset # pylint: disable=import-outside-toplevel

return CC12MDataset
if dataset_type == DatasetType.MSCOCO_DATASET:
from mixtera.core.datacollection.datasets import ( # pylint: disable=import-outside-toplevel
MSCOCODataset,
)

return MSCOCODataset
if dataset_type == DatasetType.LAION400M_DATASET:
from mixtera.core.datacollection.datasets import ( # pylint: disable=import-outside-toplevel
LAION400MDataset,
)

return LAION400MDataset
if dataset_type == DatasetType.COYO700M_DATASET:
from mixtera.core.datacollection.datasets import ( # pylint: disable=import-outside-toplevel
COYO700MDataset,
)

return COYO700MDataset
if dataset_type == DatasetType.DOMAINNET_DATASET:
from mixtera.core.datacollection.datasets import ( # pylint: disable=import-outside-toplevel
DomainNetDataset,
)

return DomainNetDataset
if dataset_type == DatasetType.PARQUET_DATASET:
from mixtera.core.datacollection.datasets import ( # pylint: disable=import-outside-toplevel
ParquetDataset,
Expand Down
5 changes: 5 additions & 0 deletions mixtera/core/datacollection/datasets/dataset_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ class DatasetType(IntEnum):
CROISSANT_DATASET = auto()
WEB_DATASET = auto()
PARQUET_DATASET = auto()
CC12M_DATASET = auto()
MSCOCO_DATASET = auto()
LAION400M_DATASET = auto()
COYO700M_DATASET = auto()
DOMAINNET_DATASET = auto()
65 changes: 61 additions & 4 deletions mixtera/core/datacollection/datasets/web_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Callable, Iterable, Optional
from typing import Callable, ClassVar, Iterable, Optional

from loguru import logger
from mixtera.core.datacollection.datasets import Dataset, DatasetType
Expand All @@ -11,6 +11,7 @@

class WebDataset(Dataset):
type: DatasetType = DatasetType.WEB_DATASET
dataset_name: ClassVar[str] = "WebDataset"

@staticmethod
def iterate_files(loc: str) -> Iterable[str]:
Expand All @@ -23,17 +24,27 @@ def iterate_files(loc: str) -> Iterable[str]:

@classmethod
def inform_metadata_parser(cls, loc: Path, metadata_parser: MetadataParser) -> None:
    """Parse metadata from a WebDataset tar file.

    Iterates over every sample in the tar archive at ``loc`` and forwards it to
    ``metadata_parser.parse``, using the running index as the sample's line number.

    Was a ``@staticmethod`` with ``cls = WebDataset`` hard-coded, which silently
    broke subclasses: CC12MDataset etc. inherited this method but always resolved
    ``dataset_name`` to ``"WebDataset"`` and therefore always passed ``None``.
    As a ``@classmethod``, ``cls`` binds to the actual (sub)class, so each
    subclass's ``dataset_name`` is reported correctly. Call syntax for existing
    callers (``SomeDataset.inform_metadata_parser(loc, parser)``) is unchanged.

    Args:
        loc: Path to the WebDataset tar file.
        metadata_parser: Parser that receives one ``parse`` call per sample.
    """
    dataset_name = getattr(cls, "dataset_name", cls.__name__)

    samples = IndexedTarSamples(str(loc))
    try:
        logger.info(f"Starting to iterate over samples ({cls.__name__}) in folder: {loc}")
        for idx, sample in enumerate(samples):
            metadata_parser.parse(
                line_number=idx,
                payload=sample,
                # The plain WebDataset base has no meaningful dataset name, so it
                # reports None; subclasses report their own dataset_name.
                dataset_name=dataset_name if dataset_name != "WebDataset" else None,
            )
    finally:
        # Ensure the tar handle is released even if a parse call raises.
        samples.close()

@staticmethod
def read_ranges_from_files(
ranges_per_file: dict[str, list[tuple[int, int]]],
parsing_func: Callable[[str | dict], str], # Will not necessarily take a string?
# Will not necessarily take a string?
parsing_func: Callable[[str | dict], str],
server_connection: Optional[ServerConnection],
) -> Iterable[str | dict]:
for file, range_list in ranges_per_file.items():
Expand Down Expand Up @@ -61,3 +72,49 @@ def _read_ranges_from_file( # pylint: disable=contextmanager-generator-missing-
yield from (parsing_func(samples[line]) for line in range(r_start, r_end))

last_line_read = r_end


class CC12MDataset(WebDataset):
    """WebDataset specialization for the CC12M dataset.

    Behaves exactly like WebDataset; it only carries its own type ID and
    dataset name so samples are tagged with the correct source dataset.
    """

    # Persisted type ID; resolved back to this class by Dataset.from_type_id.
    type: DatasetType = DatasetType.CC12M_DATASET
    # Name forwarded to the metadata parser as `dataset_name` for each sample.
    dataset_name: ClassVar[str] = "CC12M"


class MSCOCODataset(WebDataset):
    """WebDataset specialization for the MSCOCO dataset.

    Behaves exactly like WebDataset; it only carries its own type ID and
    dataset name so samples are tagged with the correct source dataset.
    """

    # Persisted type ID; resolved back to this class by Dataset.from_type_id.
    type: DatasetType = DatasetType.MSCOCO_DATASET
    # Name forwarded to the metadata parser as `dataset_name` for each sample.
    dataset_name: ClassVar[str] = "MSCOCO"


class LAION400MDataset(WebDataset):
    """WebDataset specialization for the LAION-400M dataset.

    Behaves exactly like WebDataset; it only carries its own type ID and
    dataset name so samples are tagged with the correct source dataset.
    """

    # Persisted type ID; resolved back to this class by Dataset.from_type_id.
    type: DatasetType = DatasetType.LAION400M_DATASET
    # Name forwarded to the metadata parser as `dataset_name` for each sample.
    dataset_name: ClassVar[str] = "LAION400M"


class COYO700MDataset(WebDataset):
    """WebDataset specialization for the COYO-700M dataset.

    Behaves exactly like WebDataset; it only carries its own type ID and
    dataset name so samples are tagged with the correct source dataset.
    """

    # Persisted type ID; resolved back to this class by Dataset.from_type_id.
    type: DatasetType = DatasetType.COYO700M_DATASET
    # Name forwarded to the metadata parser as `dataset_name` for each sample.
    dataset_name: ClassVar[str] = "COYO700M"


class DomainNetDataset(WebDataset):
    """WebDataset specialization for DomainNet.

    Unlike the other WebDataset subclasses, DomainNet samples carry extra
    per-sample fields ("cls" and "domain"), so it overrides
    inform_metadata_parser to forward those to the metadata parser as well.
    """

    # Persisted type ID; resolved back to this class by Dataset.from_type_id.
    type: DatasetType = DatasetType.DOMAINNET_DATASET
    # Name forwarded to the metadata parser as `dataset_name` for each sample.
    dataset_name: ClassVar[str] = "DomainNet"

    @staticmethod
    def inform_metadata_parser(loc: Path, metadata_parser: MetadataParser) -> None:
        """Feed each sample in the tar at ``loc`` to ``metadata_parser``,
        including the DomainNet-specific class and domain fields."""
        tar_samples = IndexedTarSamples(str(loc))

        logger.info(f"Starting to iterate over samples (DomainNet) in folder: {loc}")
        sample_index = 0
        for record in tar_samples:
            metadata_parser.parse(
                line_number=sample_index,
                payload=record,
                dataset_name=DomainNetDataset.dataset_name,
                # Per-sample labels stored inside the tar entry itself.
                class_name=record["cls"],
                domain=record["domain"],
            )
            sample_index += 1

        tar_samples.close()
9 changes: 8 additions & 1 deletion mixtera/core/datacollection/index/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
"""

from .metadata_parser import MetadataParser, MetadataProperty # noqa: F401
from .parser_collection import MetadataParserFactory, RedPajamaMetadataParser # noqa: F401
from .parser_collection import ( # noqa: F401
DomainNetMetadataParser,
GenericMetadataParser,
MetadataParserFactory,
RedPajamaMetadataParser,
)

__all__ = [
"DomainNetMetadataParser",
"GenericMetadataParser",
"MetadataParser",
"MetadataProperty",
"RedPajamaMetadataParser",
Expand Down
67 changes: 66 additions & 1 deletion mixtera/core/datacollection/index/parser/parser_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ class FineWebMetadataParser(MetadataParser):
def get_properties(cls) -> list[MetadataProperty]:
return [
MetadataProperty(name="dump", dtype="STRING", multiple=False, nullable=False),
MetadataProperty(name="language", dtype="ENUM", enum_options=["en"], multiple=False, nullable=False),
MetadataProperty(
name="language",
dtype="ENUM",
enum_options=["en"],
multiple=False,
nullable=False,
),
]

def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
Expand Down Expand Up @@ -168,6 +174,63 @@ def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any
self.add_metadata(sample_id=line_number, pile_set_name=pile_set_name)


class GenericMetadataParser(MetadataParser):
    """Minimal metadata parser: records only which source dataset each sample came from."""

    @classmethod
    def get_properties(cls) -> list[MetadataProperty]:
        # A single string-valued, single-occurrence, non-nullable "dataset" property.
        dataset_property = MetadataProperty(name="dataset", dtype="STRING", multiple=False, nullable=False)
        return [dataset_property]

    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
        # The calling dataset supplies dataset_name via kwargs; the payload itself is unused.
        self.add_metadata(sample_id=line_number, dataset=kwargs.get("dataset_name"))


class DomainNetMetadataParser(MetadataParser):
    """Metadata parser for the DomainNet dataset.

    Tracks three per-sample properties: the visual domain, the class label,
    and the name of the source dataset.
    """

    @classmethod
    def get_properties(cls) -> list[MetadataProperty]:
        # All three properties share the same shape: STRING, single-valued, non-nullable.
        return [
            MetadataProperty(name=prop_name, dtype="STRING", multiple=False, nullable=False)
            for prop_name in ("domain", "class_name", "dataset")
        ]

    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
        # Values are supplied via kwargs by DomainNetDataset.inform_metadata_parser.
        self.add_metadata(
            sample_id=line_number,
            domain=kwargs.get("domain"),
            class_name=kwargs.get("class_name"),
            dataset=kwargs.get("dataset_name"),
        )


class MetadataParserFactory:
"""Handles the creation of metadata parsers."""

Expand All @@ -180,6 +243,8 @@ def __init__(self) -> None:
"FINEWEB": FineWebMetadataParser,
"MSCOCO": MsCocoParser,
"PILE": PileaMetadataParser,
"GENERIC": GenericMetadataParser,
"DOMAINNET": DomainNetMetadataParser,
}

def add_parser(self, parser_name: str, parser: type[MetadataParser], overwrite: bool = False) -> bool:
Expand Down
Empty file added mixtera/multimodal/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions mixtera/multimodal/webdataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .pipeline import MixteraDataPipeline

__all__ = ["MixteraDataPipeline"]
35 changes: 35 additions & 0 deletions mixtera/multimodal/webdataset/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Iterable

import webdataset as wds
from mixtera.core.client.mixtera_client import MixteraClient, QueryExecutionArgs, ResultStreamingArgs
from mixtera.core.query.query import Query
from mixtera.torch import MixteraTorchDataset


class MixteraDataPipeline(wds.DataPipeline):
    """
    Supports building arbitrary webdataset pipelines with Mixtera's `MixteraTorchDataset` as the data source.
    """

    def __init__(
        self,
        client: MixteraClient,
        query: Query,
        query_execution_args: QueryExecutionArgs,
        result_streaming_args: ResultStreamingArgs,
        pipeline: Iterable[Any],
    ):
        # Let wds.DataPipeline install the user-provided stages first.
        super().__init__(*pipeline)

        # Keep the configuration accessible to callers for introspection.
        self.client = client
        self.query = query
        self.query_execution_args = query_execution_args
        self.result_streaming_args = result_streaming_args

        # Build the Mixtera-backed source and prepend it so it feeds all later stages.
        source_dataset = MixteraTorchDataset(
            client=client,
            query=query,
            query_execution_args=query_execution_args,
            result_streaming_args=result_streaming_args,
        )
        self.pipeline.insert(0, source_dataset)
18 changes: 9 additions & 9 deletions mixtera/tests/core/datacollection/datasets/test_web_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ def test_read_ranges_from_tar_e2e(self):
}

expected = [
{"__key__": "000001", ".cls": 1},
{"__key__": "000002", ".cls": 0},
{"__key__": "000003", ".cls": 1},
{"__key__": "000004", ".cls": 0},
{"__key__": "000005", ".cls": 1},
{"__key__": "000007", ".cls": 1},
{"__key__": "000008", ".cls": 0},
{"__key__": "000001", "cls": b"1"},
{"__key__": "000002", "cls": b"0"},
{"__key__": "000003", "cls": b"1"},
{"__key__": "000004", "cls": b"0"},
{"__key__": "000005", "cls": b"1"},
{"__key__": "000007", "cls": b"1"},
{"__key__": "000008", "cls": b"0"},
]

result = list(WebDataset.read_ranges_from_files(ranges_per_file, lambda x: x, None))
assert all(".png" in sample for sample in result)
assert all("png" in sample for sample in result)

result = [{k: v for k, v in sample.items() if k in [".cls", "__key__"]} for sample in result]
result = [{k: v for k, v in sample.items() if k in ["cls", "__key__"]} for sample in result]
self.assertEqual(result, expected)


Expand Down
Loading
Loading