
Commit 6af7f96
Merge branch 'main' into tolga/WebDatasetUpdates
2 parents: 8582e67 + 63b6759
40 files changed: +1472 −302 lines

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -34,6 +34,7 @@ var/
 *.egg
 .eggs/
 *.egg-info
+build_*/

 # postgresql
 postgres-data/
```

README.md

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@

<div align="center">
<h1>Mixtera</h1>

---

[![GitHub Workflow Status](https://github.com/eth-easl/mixtera/actions/workflows/workflow.yaml/badge.svg)](https://github.com/eth-easl/mixtera/actions/workflows/workflow.yaml)
[![License](https://img.shields.io/github/license/eth-easl/mixtera)](https://img.shields.io/github/license/eth-easl/mixtera)

Mixtera is an open-source, data-centric training data plane built for modern LLM/VLM training. It enables ML engineers to declaratively filter, mix, and distribute large-scale training datasets on the fly, while supporting dynamic adjustment based on model feedback. Learn more in our [paper](https://mboether.com/assets/pdf/bother2024mixtera.pdf).

</div>
## ⚡️ Quickstart

Mixtera can run as a server or, for single-GPU training, in-process. In both cases, you need to install the dependencies and Mixtera itself into your environment, for example as follows:

```bash
# In case you don't have micromamba yet
# macOS:
brew install micromamba
# alternatively:
"${SHELL}" <(curl -L micro.mamba.pm/install.sh)

# Start here if you have micromamba already
micromamba env create -f ./environment.yml
micromamba activate mixtera
pip install -e .
pip install -r dev-requirements.txt
```

The Mixtera server can then be started using the `mixtera-server` command.
## 🔁 What is Mixtera used for?

Modern large language and vision models rely on training datasets with fine-grained properties such as language, source, topic, or license. Traditionally, ML engineers have managed these datasets manually using ad hoc scripts and directory structures, which is time-consuming, tedious, and error-prone. Mixtera addresses these issues with a lightweight, declarative data plane that lets you seamlessly filter and dynamically mix data on the fly, without the overhead of redundant data processing.

Whether you need to enforce fixed data ratios (say, 70% JavaScript code and 30% Python) or want to adjust proportions during training using feedback-driven algorithms like [ADO](https://arxiv.org/abs/2410.11820), Mixtera offers a flexible interface for both static and dynamic mixing. Beyond efficiency, Mixtera ensures that distributed training jobs receive identical, reproducible data inputs across all nodes, which is crucial for consistency and accurate model results.
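To make the static-ratio idea concrete, here is a minimal, self-contained Python sketch of fixed-ratio sampling over two labeled pools. This is an illustration only, not Mixtera's API; the names (`mix_samples`, `pools`, `ratios`) are hypothetical.

```python
import random


def mix_samples(pools, ratios, n, seed=0):
    """Draw n samples from labeled pools according to fixed mixing ratios.

    pools:  dict mapping a domain name (e.g. "JavaScript") to a list of samples
    ratios: dict mapping the same domain names to weights that sum to 1
    """
    rng = random.Random(seed)  # fixed seed, so every node draws the same mixture
    domains = list(ratios)
    weights = [ratios[d] for d in domains]
    # First pick a domain per slot according to the ratios, then a sample from it.
    return [
        (domain, rng.choice(pools[domain]))
        for domain in rng.choices(domains, weights=weights, k=n)
    ]


# 70% JavaScript / 30% Python, as in the example above.
pools = {"JavaScript": ["js-0", "js-1"], "Python": ["py-0", "py-1"]}
mixed = mix_samples(pools, {"JavaScript": 0.7, "Python": 0.3}, n=1000)
js_share = sum(1 for domain, _ in mixed if domain == "JavaScript") / len(mixed)
```

Because the sampler is seeded, repeated runs (e.g., on different training nodes) yield the same sequence, which is the property the reproducibility claim above relies on.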
Mixtera is a centralized sample management layer built on DuckDB. It abstracts away the complexities of file-system-based data management and supports samples stored in various formats (e.g., jsonl, parquet, webdataset), letting users focus on model research rather than data wrangling.
## 🚀 Usage

Using Mixtera typically consists of (1) registering your data and (2) running queries/trainings on top of it. We maintain several [examples](https://github.com/eth-easl/mixtera/blob/main/examples/) of how to use Mixtera. A good first read is the [local-only example](https://github.com/eth-easl/mixtera/blob/main/examples/client_local_example.py), which walks you through the basics of registering data in Mixtera and running a query on it. Afterwards, the [server example](https://github.com/eth-easl/mixtera/blob/main/examples/client_server_example.py) shows you how to run a server with the `mixtera-server` command, and how to register data and query it via client-server interaction.
We provide a [full guide](examples/torchtitan.md) on how to run a training with Mixtera and torchtitan, covering how to run the server, register the dataset, and start training jobs, for both bare-metal and Slurm (e.g., SwissAI/CSCS/Alps/Clariden) deployments.
## ✨ Mixtera’s System Overview

<div align="center">
<img src="img/system.png" height=300 alt="Mixtera system design"/>
</div>

Mixtera follows a server-client model. During training, the server runs on one node, and each training node runs client instances. The query is executed at the server in two phases. First, Mixtera applies the static filters from the query (e.g., English-only) to obtain all samples we could train on. This gives us a [QueryResult](https://github.com/eth-easl/mixtera/blob/main/mixtera/core/query/query_result.py). Second, during training, the server distributes [chunks](https://github.com/eth-easl/mixtera/blob/main/mixtera/core/query/result_chunk.py) of that query result to the client(s). A chunk is a collection of pointers to samples in files. These pointers tell the receiving client which samples in which file to load (e.g., sample 10 in file `wikipedia.jsonl.zst`).
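As a rough mental model, resolving such a chunk of pointers against jsonl files could look like the sketch below. This is not Mixtera's actual `ResultChunk` implementation; the `load_chunk` helper and the dict-of-row-indices layout are assumptions for illustration.

```python
import json
import tempfile
from pathlib import Path


def load_chunk(chunk: dict[str, list[int]], data_dir: Path) -> list[str]:
    """Resolve a chunk of (file, row-index) pointers into sample payloads."""
    samples = []
    for filename, indices in chunk.items():
        lines = (data_dir / filename).read_text().splitlines()
        wanted = set(indices)
        # Keep only the rows the chunk points at, in file order.
        samples.extend(
            json.loads(line)["text"] for i, line in enumerate(lines) if i in wanted
        )
    return samples


# Demo: a tiny jsonl file with five samples; the chunk points at rows 1 and 3.
with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "wikipedia.jsonl"
    path.write_text("\n".join(json.dumps({"text": f"sample-{i}"}) for i in range(5)))
    samples = load_chunk({"wikipedia.jsonl": [1, 3]}, Path(d))
```

Because a chunk carries only pointers, the server never ships sample payloads itself; each client reads the referenced rows directly from shared storage.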
## ✉️ About

Mixtera is being developed at the [Efficient Architectures and Systems Lab (EASL)](https://anakli.inf.ethz.ch/#Group) at the [ETH Zurich Systems Group](https://systems.ethz.ch/). Please reach out to `mboether [at] inf [dot] ethz [dot] ch` or open an issue on GitHub if you have any questions or inquiries related to Mixtera and its usage.

cmake/dependencies.cmake

Lines changed: 16 additions & 15 deletions
```diff
@@ -45,21 +45,6 @@ FetchContent_Declare(
 FetchContent_MakeAvailable(indicators)
 target_compile_options(indicators INTERFACE -Wno-zero-as-null-pointer-constant -Wno-sign-compare)

-################### abseil ####################
-
-message(STATUS "Making abseil available.")
-
-FetchContent_Declare(
-  absl
-  GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
-  GIT_TAG 20240722.0
-)
-FetchContent_MakeAvailable(absl)
-
-# Required for GCC
-target_compile_options(absl_flat_hash_map INTERFACE -Wno-pedantic)
-target_compile_options(absl_base INTERFACE -Wno-pedantic)

 ################### Arrow ####################

@@ -104,3 +89,19 @@ else()
 endif()

 target_compile_options(Arrow::arrow_shared INTERFACE -Wno-redundant-move)
+
+################### abseil ####################
+
+# Abseil needs to be loaded after arrow, otherwise we run into issues on the alps/clariden cluster.
+message(STATUS "Making abseil available.")
+
+FetchContent_Declare(
+  absl
+  GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
+  GIT_TAG 20240722.0
+)
+FetchContent_MakeAvailable(absl)
+
+# Required for GCC
+target_compile_options(absl_flat_hash_map INTERFACE -Wno-pedantic)
+target_compile_options(absl_base INTERFACE -Wno-pedantic)
```

examples/clariden/Dockerfile

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@

```dockerfile
FROM nvcr.io/nvidia/pytorch:25.01-py3

RUN apt-get update && apt-get upgrade -y && apt-get install ca-certificates lsb-release wget python3-pip neovim autoconf build-essential gdb software-properties-common curl unzip cmake gzip protobuf-compiler libtool zstd liblz4-dev lz4 -y

RUN wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
RUN apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
RUN apt update
RUN apt install -y -V libparquet-glib-dev libparquet-dev libarrow-dataset-glib-dev libarrow-dataset-dev libarrow-glib-dev libarrow-dev

RUN pip install pip==24.*

# If you encounter pyarrow issues, ensure the version here matches the version downloaded above!!
RUN pip install tqdm loguru psutil numpy==1.26.4 dill datasets transformers pyarrow==19.* xxhash xopen scipy tenacity
RUN pip install duckdb polars==1.15 pillow pybind11 pytest flake8 mypy pylint autopep8 isort black tensorboard tiktoken blobfile tabulate wandb torchdata>=0.8.0 tomli>=1.1.0 dacite pyyaml packaging safetensors sentencepiece jupyter seaborn webdataset lz4 git+https://github.com/tmbdev/[email protected] mosaicml-streaming grain
RUN pip install lm_eval typer # for evaluation

# Test torch nightly
RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

RUN git clone --recurse-submodules -b v1.64.3 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
    cd grpc && mkdir -p cmake/build && cd cmake/build && \
    cmake -DgRPC_PROTOBUF_PROVIDER=module -DABSL_ENABLE_INSTALL=On -DgRPC_BUILD_CSHARP_EXT=Off -DABSL_BUILD_TESTING=Off -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release ../.. && \
    make -j64 && make install && cd ../../

RUN bash -c "cp /usr/local/lib/libutf8* /usr/lib"

## For nanotron
RUN pip uninstall -y ninja && pip install ninja
RUN MAX_JOBS=12 numactl --membind=0-3 pip install flash-attn --no-build-isolation
```

examples/client_local_example.py

Lines changed: 53 additions & 19 deletions
```diff
@@ -47,17 +47,31 @@ class TestMetadataParser(MetadataParser):
     def get_properties(cls) -> list[MetadataProperty]:
         return [
             MetadataProperty(
-                name="language", dtype="ENUM", multiple=False, nullable=False, enum_options={"JavaScript", "HTML"}
+                name="language",
+                dtype="ENUM",
+                multiple=False,
+                nullable=False,
+                enum_options={"JavaScript", "HTML"},
             ),
             MetadataProperty(
-                name="license", dtype="STRING", multiple=False, nullable=False, enum_options={"CC", "MIT"}
+                name="license",
+                dtype="STRING",
+                multiple=False,
+                nullable=False,
+                enum_options={"CC", "MIT"},
             ),  # Could be ENUM but we are using string to test
             MetadataProperty(
-                name="doublelanguage", dtype="ENUM", multiple=True, nullable=False, enum_options={"JavaScript", "HTML"}
+                name="doublelanguage",
+                dtype="ENUM",
+                multiple=True,
+                nullable=False,
+                enum_options={"JavaScript", "HTML"},
             ),
         ]

-    def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]) -> None:
+    def parse(
+        self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any]]
+    ) -> None:
         metadata = payload["meta"]
         self.add_metadata(
             sample_id=line_number,
@@ -69,49 +83,69 @@ def parse(self, line_number: int, payload: Any, **kwargs: Optional[dict[Any, Any

 def parsing_func(sample):
     import json
+
     return json.loads(sample)["text"]

+
 def setup_local_client(directory: Path):
     # Writing JSONL data to the directory, which simulates the dataset.
     write_jsonl(directory / "testd.jsonl")

     # Instantiating a client from a local directory to interact with the datasets locally.
     client = MixteraClient.from_directory(directory)

     # Register the metadata parser.
     client.register_metadata_parser("TEST_PARSER", TestMetadataParser)

     # Registering the dataset with the client.
-    client.register_dataset(
-        "local_integrationtest_dataset", directory / "testd.jsonl", JSONLDataset, parsing_func, "TEST_PARSER"
-    )
-
+    if not client.register_dataset(
+        "local_integrationtest_dataset",
+        directory / "testd.jsonl",
+        JSONLDataset,
+        parsing_func,
+        "TEST_PARSER",
+    ):
+        raise RuntimeError("Error while registering dataset!")
+
     return client

+
 def run_query(client: MixteraClient, chunk_size: int):
-    job_id = str(round(time.time() * 1000))  # Get some job ID based on current timestamp
-    query = Query.for_job(job_id).select(("language", "==", "JavaScript"))  # In our example, we want to query all samples tagged JavaScript
+    job_id = str(
+        round(time.time() * 1000)
+    )  # Get some job ID based on current timestamp
+    query = Query.for_job(job_id).select(
+        ("language", "==", "JavaScript")
+    )  # In our example, we want to query all samples tagged JavaScript

     mixture = ArbitraryMixture(chunk_size=chunk_size)
     qea = QueryExecutionArgs(mixture=mixture)
     client.execute_query(query, qea)
+    client.wait_for_execution(job_id)

     rsa = ResultStreamingArgs(job_id=job_id)
     result_samples = list(client.stream_results(rsa))

     # Checking the number of results and their validity.
-    assert len(result_samples) == 500, f"Got {len(result_samples)} samples instead of the expected 500!"
-    for _, sample in result_samples:  # The first argument is the index in the current chunk, needed for state recovery
+    assert (
+        len(result_samples) == 500
+    ), f"Got {len(result_samples)} samples instead of the expected 500!"
+    for (
+        _,
+        _,
+        sample,
+    ) in result_samples:  # The first argument is the index in the current chunk, needed for state recovery. The second argument is the domain id.
         assert int(sample) % 2 == 0, f"Sample {sample} should not appear for JavaScript"

+
 def main():
     with tempfile.TemporaryDirectory() as temp_dir:
         # Setup the local client with a temporary directory.
         # This also populates the database with a dummy dataset, where 50% of data is tagged HTML and 50% is tagged JavaScript.
         client = setup_local_client(Path(temp_dir))
-        chunk_size = 42 # Size of the result chunks of the query
+        chunk_size = 42  # Size of the result chunks of the query
         run_query(client, chunk_size)

 if __name__ == "__main__":
-    main()
+    main()
```
