From 2d791bff6f1875aae6ed5355af2681174efc18ef Mon Sep 17 00:00:00 2001 From: Pysith Vanuptikul Date: Wed, 26 Mar 2025 18:29:41 +0000 Subject: [PATCH 1/2] Add support for w&b Currently behavior: passing `wandb_project` parameter to Trainer will enable wandb. The first run will prompt the user for an api_key: ``` wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information. wandb: (1) Create a W&B account wandb: (2) Use an existing W&B account wandb: (3) Don't visualize my results ``` Data sent to wandb: https://screenshot.googleplex.com/3wTVkYw23LKHYyX --- .gitignore | 1 + examples/singlehost/quick_start.py | 26 ++++---- kithara/callbacks/__init__.py | 3 +- kithara/callbacks/wandb.py | 69 ++++++++++++++++++++ kithara/trainer/trainer.py | 100 ++++++++++++++++++++--------- pyproject.toml | 3 +- 6 files changed, 156 insertions(+), 46 deletions(-) create mode 100644 kithara/callbacks/wandb.py diff --git a/.gitignore b/.gitignore index c8f0836..956675a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ docs/bin/ docs/lib/ docs/lib64 docs/pyvenv.cfg +wandb/ diff --git a/examples/singlehost/quick_start.py b/examples/singlehost/quick_start.py index 71f43af..7acf704 100644 --- a/examples/singlehost/quick_start.py +++ b/examples/singlehost/quick_start.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Quick Start Example @@ -53,6 +53,7 @@ "per_device_batch_size": 1, "max_eval_samples": 50, "learning_rate": 2e-4, + "wandb_project": "wandb-test", } @@ -113,11 +114,12 @@ def run_workload(): eval_steps_interval=config["eval_steps_interval"], max_eval_samples=config["max_eval_samples"], log_steps_interval=config["log_steps_interval"], + wandb_project=config["wandb_project"], ) # Start training trainer.train() - + print("Finished training. 
Prompting model...") # Test after tuning diff --git a/kithara/callbacks/__init__.py b/kithara/callbacks/__init__.py index b9c2f66..b320503 100644 --- a/kithara/callbacks/__init__.py +++ b/kithara/callbacks/__init__.py @@ -1,2 +1,3 @@ from kithara.callbacks.profiler import * -from kithara.callbacks.checkpointer import * \ No newline at end of file +from kithara.callbacks.checkpointer import * +from kithara.callbacks.wandb import * \ No newline at end of file diff --git a/kithara/callbacks/wandb.py b/kithara/callbacks/wandb.py new file mode 100644 index 0000000..632d707 --- /dev/null +++ b/kithara/callbacks/wandb.py @@ -0,0 +1,69 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from ctypes import cdll +import subprocess +import shutil +from keras.src.callbacks.callback import Callback +import jax +import os +import wandb + + +class Wandb(Callback): + """Callbacks to send data to Weights and Biases. + + Args: + wandb_project (str): Weights and Biases API project name. + learning_rate (float, optional): Training learning rate. Defaults to None. + epochs (int, optional): Training epochs. Defaults to None. + """ + + def __init__( + self, + wandb_project, + learning_rate=None, + epochs=None, + ): + super().__init__() + wandb.login() + config = {} + if learning_rate: + config["learning_rate"] = learning_rate + if epochs: + config["epochs"] = epochs + wandb.init( + project=wandb_project, + config=config, + ) + + def on_train_begin(self, logs=None): + return + + def on_train_end(self, logs=None): + wandb.finish() + + def on_train_batch_begin(self, batch, logs=None): + return + + def on_train_batch_end(self, batch, logs=None): + if logs != None: + entry = {} + if "loss" in logs.keys(): + entry["loss"] = logs["loss"] + if "acc" in logs.keys(): + entry["acc"] = logs["acc"] + wandb.log(entry) diff --git a/kithara/trainer/trainer.py b/kithara/trainer/trainer.py index bed96a4..bf1e050 100644 --- a/kithara/trainer/trainer.py +++ b/kithara/trainer/trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" import os @@ -32,12 +32,13 @@ from kithara.optimizers import convert_to_kithara_optimizer from kithara.model import Model from kithara.dataset import Dataloader -from kithara.callbacks import Profiler, Checkpointer +from kithara.callbacks import Profiler, Checkpointer, Wandb from kithara.distributed.sharding._data_sharding import DataSharding from keras.src.backend.common import global_state from typing import Any, Union, List, Tuple import jax.tree_util as jtu import numpy as np +import wandb class Trainer: @@ -67,6 +68,7 @@ class Trainer: tensorboard_dir (str, optional): The directory path for TensorBoard logs. Can be either a local directory or a Google Cloud Storage (GCS) path. Defaults to None. profiler (kithara.Profiler, optional): A profiler instance for monitoring performance metrics. Defaults to None. + wandb_project (str, optional): Name of Weights and Biases project. Defaults to None. When set to None, Weights and Biases wont be enabled. Methods: loss_fn: Returns a JAX-compatible callable that computes the loss value from logits and labels. @@ -80,7 +82,11 @@ class Trainer: def __init__( self, model: Model, - optimizer: keras.Optimizer | optax.GradientTransformation | optax.GradientTransformationExtraArgs, + optimizer: ( + keras.Optimizer + | optax.GradientTransformation + | optax.GradientTransformationExtraArgs + ), train_dataloader: Dataloader, eval_dataloader: Dataloader = None, steps=None, @@ -92,6 +98,7 @@ def __init__( tensorboard_dir=None, profiler: Profiler = None, checkpointer: Checkpointer = None, + wandb_project=None, ): if steps is None and epochs is None: epochs = 1 @@ -121,6 +128,7 @@ def __init__( self.global_batch_size = train_dataloader.global_batch_size self.profiler = profiler self.checkpointer = checkpointer + self.wandb_project = wandb_project self._validate_setup() # Initialize optimizer and callbacks @@ -128,7 +136,8 @@ def __init__( self.optimizer.build(self.model.trainable_variables) else: self.optimizer = convert_to_kithara_optimizer( - optimizer, self.model.trainable_variables) + optimizer, self.model.trainable_variables + ) self.callbacks = self._create_callbacks() if self.tensorboard_dir: @@ -213,7 +222,8 @@ def _train_step(self, state: Tuple[List[jax.Array]], data: dict): trainable_variables, non_trainable_variables, x, y ) trainable_variables, optimizer_variables = self.optimizer.stateless_apply( - optimizer_variables, grads, trainable_variables) + optimizer_variables, grads, trainable_variables + ) return ( loss, ( @@ -302,9 +312,7 @@ def train(self): "samples_per_second": round(samples_per_second, 2), "train_steps_per_second": round(1 / step_time, 2), "samples_seen": self.global_batch_size * self.step_count, - "learning_rate": (round(float(self.optimizer.learning_rate.value),7) - if self.optimizer.learning_rate is not None else None), - + "learning_rate": self._learning_rate(), } # Log progress @@ -483,6 +491,15 @@ def evaluate(self, state=None): return eval_loss + def _learning_rate(self): + return ( + ( + round(float(self.optimizer.learning_rate.value), 7) + if self.optimizer.learning_rate is not None + else None + ), + ) + def _make_train_step(self): return jax.jit(self._train_step, donate_argnums=(0,)) @@ -555,7 +572,8 @@ def _update_model_with_state(self, state): _ = jax.tree.map( lambda variable, value: variable.assign( - jax.lax.with_sharding_constraint(value, variable._layout)), + jax.lax.with_sharding_constraint(value, 
variable._layout) + ), self.optimizer.variables, optimizer_variables, ) @@ -593,7 +611,9 @@ def _print_run_summary(self): def _create_callbacks(self): callbacks = [] - if self.tensorboard_dir and isinstance(self.optimizer, keras.optimizers.Optimizer): + if self.tensorboard_dir and isinstance( + self.optimizer, keras.optimizers.Optimizer + ): callbacks.append( keras.callbacks.TensorBoard( log_dir=self.tensorboard_dir, @@ -601,6 +621,13 @@ def _create_callbacks(self): write_steps_per_second=True, ) ) + if self.wandb_project: + self.wanb = Wandb( + self.wandb_project, + learning_rate=self._learning_rate(), + epochs=self.epochs, + ) + callbacks.append(self.wanb) if self.profiler: callbacks.append(self.profiler) if self.checkpointer: @@ -649,15 +676,20 @@ def _validate_sharding_correctness(self, data, state): ) _ = jax.tree.map( - lambda variable, value: print( - f"Step {self.step_count}: optimizer variable is not sharded", - f"{get_size_in_mb(value)}mb", - variable.path, - value.shape, - value.sharding, - ) if is_not_sharded_and_is_large(value) else None, + lambda variable, value: ( + print( + f"Step {self.step_count}: optimizer variable is not sharded", + f"{get_size_in_mb(value)}mb", + variable.path, + value.shape, + value.sharding, + ) + if is_not_sharded_and_is_large(value) + else None + ), self.optimizer.variables, - state[2]) + state[2], + ) except Exception as e: print(f"Error during sharding correctness validation: {e}") @@ -676,8 +708,10 @@ def _validate_memory_usage(self): total_size += get_size_in_mb(v.value) total_size += jax.tree.reduce( - lambda agg, leaf: jax.numpy.add(agg, get_size_in_mb(leaf.value)), self.optimizer.variables, - initializer=0) + lambda agg, leaf: jax.numpy.add(agg, get_size_in_mb(leaf.value)), + self.optimizer.variables, + initializer=0, + ) live_arrays = jax.live_arrays() live_arrays_size = 0 @@ -699,9 +733,11 @@ def _validate_memory_usage(self): memory_info = jax.local_devices()[0].memory_stats() memory_per_device_mb = memory_info["bytes_limit"] / (1024**2) total_memory = memory_per_device_mb * jax.device_count() - print(f"Total memory available is {total_memory:.3f} MB, if you run into " + print( + f"Total memory available is {total_memory:.3f} MB, if you run into " "errors, check if your memory usage is close to the limit, and either " - "reduce your per-device batch size or sequence length.") + "reduce your per-device batch size or sequence length." + ) except Exception as e: # memory_info is not available on some TPUs pass diff --git a/pyproject.toml b/pyproject.toml index 4471940..f2c729b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ dependencies = [ "editdistance", "pyglove", "tensorflow_datasets", - "tfds-nightly" #"tfds-nightly==4.9.8.dev202503240044" + "tfds-nightly", #"tfds-nightly==4.9.8.dev202503240044" + "wandb" ] [project.optional-dependencies] From ebc72d4694d6d3d8fc8a6e861324656b8e69913e Mon Sep 17 00:00:00 2001 From: Pysith Vanuptikul Date: Thu, 27 Mar 2025 18:38:01 +0000 Subject: [PATCH 2/2] Update wandb interface to accept settings. Update observability documentation. Update wandb to log step_stats. 
---
 docs/source/observability.rst      | 30 +++++++++++++++++++++++++++++-
 examples/singlehost/quick_start.py |  5 +++--
 kithara/callbacks/wandb.py         | 13 ++++---------
 kithara/trainer/trainer.py         | 10 +++++-----
 4 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/docs/source/observability.rst b/docs/source/observability.rst
index 48eaeb9..7ce061e 100644
--- a/docs/source/observability.rst
+++ b/docs/source/observability.rst
@@ -3,7 +3,7 @@ Observability
 =============
 
-Kithara supports Tensorboard and (soon) Weights and Biases for observability.
+Kithara supports Tensorboard and Weights and Biases for observability.
 
 Tensorboard
 -----------
 
@@ -13,3 +13,31 @@ To use Tensorboard, simply specify the ``tensorboard_dir`` arg in the ``Trainer`` class.
 To track training and evaluation performance, launch the tensorboard server with::
 
     tensorboard --logdir=your_tensorboard_dir
+
+Weights and Biases
+------------------
+
+To use Weights and Biases, import the ``Settings`` class from ``wandb`` and pass it to the ``wandb_settings`` arg of the ``Trainer`` class. For example:
+
+.. code-block:: python
+
+    from wandb import Settings
+
+    Trainer(..., wandb_settings=Settings(project="Project name"))
+
+When running the Trainer, you will be prompted to provide an API key:
+
+.. code-block:: text
+
+    wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
+    wandb: (1) Create a W&B account
+    wandb: (2) Use an existing W&B account
+    wandb: (3) Don't visualize my results
+
+Alternatively, you can export the key as an environment variable before running the trainer:
+
+.. code-block:: bash
+
+    export WANDB_API_KEY=$YOUR_API_KEY
+
+After providing the key, you will be able to see your results at https://wandb.ai//
\ No newline at end of file
diff --git a/examples/singlehost/quick_start.py b/examples/singlehost/quick_start.py
index 7acf704..3b92ab3 100644
--- a/examples/singlehost/quick_start.py
+++ b/examples/singlehost/quick_start.py
@@ -40,6 +40,7 @@
     Trainer,
     SFTDataset,
 )
+from wandb import Settings
 
 config = {
     "model_handle": "hf://google/gemma-2-2b",
@@ -53,7 +54,7 @@
     "per_device_batch_size": 1,
     "max_eval_samples": 50,
     "learning_rate": 2e-4,
-    "wandb_project": "wandb-test",
+    "wandb_settings": Settings(project="project"),
 }
 
 
@@ -114,7 +115,7 @@ def run_workload():
         eval_steps_interval=config["eval_steps_interval"],
         max_eval_samples=config["max_eval_samples"],
         log_steps_interval=config["log_steps_interval"],
-        wandb_project=config["wandb_project"],
+        wandb_settings=config["wandb_settings"],
     )
 
     # Start training
diff --git a/kithara/callbacks/wandb.py b/kithara/callbacks/wandb.py
index 632d707..3a5089e 100644
--- a/kithara/callbacks/wandb.py
+++ b/kithara/callbacks/wandb.py
@@ -27,14 +27,14 @@ class Wandb(Callback):
     """Callbacks to send data to Weights and Biases.
 
     Args:
-        wandb_project (str): Weights and Biases API project name.
+        settings (wandb.Settings): Settings used to initialize Weights and Biases.
         learning_rate (float, optional): Training learning rate. Defaults to None.
         epochs (int, optional): Training epochs. Defaults to None.
     """
 
     def __init__(
         self,
-        wandb_project,
+        settings,
         learning_rate=None,
         epochs=None,
     ):
@@ -46,7 +46,7 @@ def __init__(
         if epochs:
             config["epochs"] = epochs
         wandb.init(
-            project=wandb_project,
+            settings=settings,
             config=config,
         )
 
@@ -61,9 +61,4 @@ def on_train_batch_begin(self, batch, logs=None):
 
     def on_train_batch_end(self, batch, logs=None):
         if logs != None:
-            entry = {}
-            if "loss" in logs.keys():
-                entry["loss"] = logs["loss"]
-            if "acc" in logs.keys():
-                entry["acc"] = logs["acc"]
-            wandb.log(entry)
+            wandb.log(logs)
diff --git a/kithara/trainer/trainer.py b/kithara/trainer/trainer.py
index bf1e050..29ca038 100644
--- a/kithara/trainer/trainer.py
+++ b/kithara/trainer/trainer.py
@@ -68,7 +68,7 @@ class Trainer:
         tensorboard_dir (str, optional): The directory path for TensorBoard logs. Can be either a local directory or a Google Cloud Storage (GCS) path. Defaults to None.
         profiler (kithara.Profiler, optional): A profiler instance for monitoring performance metrics. Defaults to None.
-        wandb_project (str, optional): Name of Weights and Biases project. Defaults to None. When set to None, Weights and Biases wont be enabled.
+        wandb_settings (wandb.Settings, optional): Configuration for Weights and Biases. Defaults to None. When set to None, Weights and Biases won't be enabled.
 
     Methods:
         loss_fn: Returns a JAX-compatible callable that computes the loss value from logits and labels.
@@ -98,7 +98,7 @@ def __init__(
         tensorboard_dir=None,
         profiler: Profiler = None,
         checkpointer: Checkpointer = None,
-        wandb_project=None,
+        wandb_settings=None,
     ):
         if steps is None and epochs is None:
             epochs = 1
@@ -128,7 +128,7 @@ def __init__(
         self.global_batch_size = train_dataloader.global_batch_size
         self.profiler = profiler
         self.checkpointer = checkpointer
-        self.wandb_project = wandb_project
+        self.wandb_settings = wandb_settings
         self._validate_setup()
 
         # Initialize optimizer and callbacks
@@ -621,9 +621,9 @@ def _create_callbacks(self):
                     write_steps_per_second=True,
                 )
             )
-        if self.wandb_project:
+        if self.wandb_settings:
            self.wanb = Wandb(
-                self.wandb_project,
+                settings=self.wandb_settings,
                 learning_rate=self._learning_rate(),
                 epochs=self.epochs,
             )