Small updates for compatibility with WebDataset pipelines #171
base: main
Changes from all commits
The first change expands the dataset re-exports:

```diff
@@ -6,6 +6,24 @@
 from .dataset_type import DatasetType  # noqa: F401
 from .jsonl_dataset import JSONLDataset  # noqa: F401
 from .parquet_dataset import ParquetDataset  # noqa: F401
-from .web_dataset import WebDataset
+from .web_dataset import (  # noqa: F401
+    CC12MDataset,
+    COYO700MDataset,
+    DomainNetDataset,
+    LAION400MDataset,
+    MSCOCODataset,
+    WebDataset,
+)
 
-__all__ = ["Dataset", "DatasetType", "JSONLDataset", "ParquetDataset", "WebDataset"]
+__all__ = [
+    "Dataset",
+    "DatasetType",
+    "JSONLDataset",
+    "ParquetDataset",
+    "WebDataset",
+    "CC12MDataset",
+    "MSCOCODataset",
+    "LAION400MDataset",
+    "COYO700MDataset",
+    "DomainNetDataset",
+]
```

Contributor (on the expanded `__all__` list): As discussed, this must change.
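For illustration, a minimal import sketch of what the expanded re-exports enable. The package path `mixtera.core.datasets` is an assumption; the diff shows only the `__init__.py` contents, not the file's location.

```python
# Package path is assumed; the diff does not name the package.
from mixtera.core.datasets import CC12MDataset, DomainNetDataset, WebDataset
```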
The second file reformats an existing `MetadataProperty`, adds two new metadata parsers, and registers them with the factory:

```diff
@@ -87,7 +87,13 @@ class FineWebMetadataParser(MetadataParser):
     def get_properties(cls) -> list[MetadataProperty]:
         return [
             MetadataProperty(name="dump", dtype="STRING", multiple=False, nullable=False),
-            MetadataProperty(name="language", dtype="ENUM", enum_options=["en"], multiple=False, nullable=False),
+            MetadataProperty(
+                name="language",
+                dtype="ENUM",
+                enum_options=["en"],
+                multiple=False,
+                nullable=False,
+            ),
         ]
 
     def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
@@ -168,6 +174,63 @@ def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any
         self.add_metadata(sample_id=line_number, pile_set_name=pile_set_name)
 
 
+class GenericMetadataParser(MetadataParser):
+    """
+    Metadata parser with only the source dataset name as a property.
+    """
+
+    @classmethod
+    def get_properties(cls) -> list[MetadataProperty]:
+        return [
+            MetadataProperty(
+                name="dataset",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            )
+        ]
+
+    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
+        dataset_name = kwargs.get("dataset_name")
+        self.add_metadata(sample_id=line_number, dataset=dataset_name)
+
+
+class DomainNetMetadataParser(MetadataParser):
+    """
+    Metadata parser class for the DomainNet dataset.
+    """
+
+    @classmethod
+    def get_properties(cls) -> list[MetadataProperty]:
+        return [
+            MetadataProperty(
+                name="domain",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            ),
+            MetadataProperty(
+                name="class_name",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            ),
+            MetadataProperty(
+                name="dataset",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+            ),
+        ]
+
+    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
+        dataset = kwargs.get("dataset_name")
+        domain = kwargs.get("domain")
+        class_name = kwargs.get("class_name")
+
+        self.add_metadata(sample_id=line_number, domain=domain, class_name=class_name, dataset=dataset)
+
+
 class MetadataParserFactory:
     """Handles the creation of metadata parsers."""
 
@@ -180,6 +243,8 @@ def __init__(self) -> None:
             "FINEWEB": FineWebMetadataParser,
             "MSCOCO": MsCocoParser,
             "PILE": PileaMetadataParser,
+            "GENERIC": GenericMetadataParser,
+            "DOMAINNET": DomainNetMetadataParser,
         }
 
     def add_parser(self, parser_name: str, parser: type[MetadataParser], overwrite: bool = False) -> bool:
```

Contributor (on `DomainNetMetadataParser`): We need to clean up the datasets vs. metadata parsers here. I'm not sure why, in addition to the datasets you added, you also added this parser. Let's find a cleaner abstraction for this. You could have something like a "LlavaMetadataParser" that, e.g., infers from the path which dataset a sample comes from. Or you could have one MetadataParser per dataset. You can actually register multiple datasets within the same data collection, so that should not be an issue.
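To make the reviewer's suggestion concrete, here is a hedged sketch of the path-based variant. Only `MetadataParser`, `MetadataProperty`, `add_metadata`, `MetadataParserFactory`, and the `add_parser` signature are taken from the diff above; the class name `LlavaMetadataParser`, the `path` kwarg, and the hint-to-dataset mapping are all hypothetical.

```python
from typing import Any, Optional


class LlavaMetadataParser(MetadataParser):
    """Hypothetical parser that infers the source dataset from a sample's path."""

    # Assumed mapping from path substrings to dataset names; not part of this PR.
    _PATH_HINTS = {
        "coco": "MSCOCO",
        "cc12m": "CC12M",
        "laion": "LAION400M",
        "coyo": "COYO700M",
    }

    @classmethod
    def get_properties(cls) -> list[MetadataProperty]:
        return [
            MetadataProperty(name="dataset", dtype="STRING", multiple=False, nullable=False),
        ]

    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
        # Assumes the caller passes the sample's path via kwargs.
        path = str(kwargs.get("path", "")).lower()
        dataset = next(
            (name for hint, name in self._PATH_HINTS.items() if hint in path),
            "UNKNOWN",
        )
        self.add_metadata(sample_id=line_number, dataset=dataset)


# Registration reuses the add_parser signature shown in the diff.
factory = MetadataParserFactory()
factory.add_parser("LLAVA", LlavaMetadataParser, overwrite=False)
```

This keeps a single parser for all image-text sources while still recording per-sample provenance, in contrast to the one-parser-per-dataset alternative the reviewer also mentions.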
A new package `__init__.py` re-exports the pipeline class:

```diff
@@ -0,0 +1,3 @@
+from .pipeline import MixteraDataPipeline
+
+__all__ = ["MixteraDataPipeline"]
```
The new pipeline module wraps `MixteraTorchDataset` as the source stage of a webdataset pipeline:

```diff
@@ -0,0 +1,35 @@
+from typing import Any, Iterable
+
+import webdataset as wds
+from mixtera.core.client.mixtera_client import MixteraClient, QueryExecutionArgs, ResultStreamingArgs
+from mixtera.core.query.query import Query
+from mixtera.torch import MixteraTorchDataset
+
+
+class MixteraDataPipeline(wds.DataPipeline):
+    """
+    Supports building arbitrary webdataset pipelines with Mixtera's `MixteraTorchDataset` as the data source.
+    """
+
+    def __init__(
+        self,
+        client: MixteraClient,
+        query: Query,
+        query_execution_args: QueryExecutionArgs,
+        result_streaming_args: ResultStreamingArgs,
+        pipeline: Iterable[Any],
+    ):
+        super().__init__(*pipeline)
+        self.client = client
+        self.query = query
+        self.query_execution_args = query_execution_args
+        self.result_streaming_args = result_streaming_args
+
+        torch_dataset = MixteraTorchDataset(
+            client=client,
+            query=query,
+            query_execution_args=query_execution_args,
+            result_streaming_args=result_streaming_args,
+        )
+
+        self.pipeline.insert(0, torch_dataset)
```

Contributor (on lines +9 to +30): I don't understand what this is necessary for. The `MixteraTorchDataset` is a user-facing abstraction.
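For context, a hedged usage sketch of how such a pipeline might be composed. The import path for `MixteraDataPipeline` and the construction of `client`, `query`, and the args objects are assumptions (they are not shown in this diff); `wds.shuffle`, `wds.map`, and `wds.batched` are standard webdataset stages.

```python
import webdataset as wds

# Import path is an assumption; the diff only shows a package-relative
# `from .pipeline import MixteraDataPipeline`.
from mixtera.webdataset import MixteraDataPipeline

# `client`, `query`, `query_execution_args`, and `result_streaming_args` are
# assumed to be constructed elsewhere via Mixtera's client/query APIs.
pipeline = MixteraDataPipeline(
    client=client,
    query=query,
    query_execution_args=query_execution_args,
    result_streaming_args=result_streaming_args,
    pipeline=[
        wds.shuffle(1000),               # buffered shuffling of samples from the Mixtera source
        wds.map(lambda sample: sample),  # placeholder per-sample transform
        wds.batched(32),                 # collate samples into batches of 32
    ],
)

for batch in pipeline:
    ...  # feed batches into the training loop
```

Because `__init__` inserts the `MixteraTorchDataset` at position 0, the stages passed via `pipeline` consume its samples in order.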
Review comment: Unnecessary changes, no?