1 change: 1 addition & 0 deletions .gitignore
@@ -13,4 +13,5 @@ docs/bin/
docs/lib/
docs/lib64
docs/pyvenv.cfg
wandb/

30 changes: 29 additions & 1 deletion docs/source/observability.rst
@@ -3,7 +3,7 @@
Observability
=============

Kithara supports Tensorboard and (soon) Weights and Biases for observability.
Kithara supports Tensorboard and Weights and Biases for observability.

Tensorboard
-----------
@@ -13,3 +13,31 @@ To use Tensorboard, simply specify the ``tensorboard_dir`` arg in the ``Trainer``
To track training and evaluation performance, launch the tensorboard server with::

tensorboard --logdir=your_tensorboard_dir
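
For reference, a minimal sketch of passing the ``tensorboard_dir`` arg to the ``Trainer`` (the GCS bucket path below is only a placeholder; a local directory works as well)::

    Trainer(..., tensorboard_dir="gs://your-bucket/tensorboard")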

Weights and Biases
------------------

To use Weights and Biases, import the ``Settings`` class from ``wandb`` and pass an instance to the ``wandb_settings`` arg of the ``Trainer`` class. For example::

    from wandb import Settings

    Trainer(..., wandb_settings=Settings(project="Project name"))

When running the ``Trainer``, you will be prompted to provide an API key::

    wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
    wandb: (1) Create a W&B account
    wandb: (2) Use an existing W&B account
    wandb: (3) Don't visualize my results

Alternatively, you can export the key as an environment variable before running the trainer::

    export WANDB_API_KEY=$YOUR_API_KEY

After providing the key, you will be able to see your results at ``https://wandb.ai/<user>/<project>``.
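
For non-interactive environments (for example CI jobs or notebooks), the key can also be set from Python before the ``Trainer`` is constructed. A minimal sketch; the key value is a placeholder::

    import os

    # With WANDB_API_KEY already set, wandb skips the interactive prompt.
    os.environ["WANDB_API_KEY"] = "your-api-key"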
27 changes: 15 additions & 12 deletions examples/singlehost/quick_start.py
@@ -1,18 +1,18 @@
"""
Copyright 2025 Google LLC
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0
https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Quick Start Example

@@ -40,6 +40,7 @@
Trainer,
SFTDataset,
)
from wandb import Settings

config = {
"model_handle": "hf://google/gemma-2-2b",
@@ -53,6 +54,7 @@
"per_device_batch_size": 1,
"max_eval_samples": 50,
"learning_rate": 2e-4,
"wandb_settings": Settings(project="project"),
}


@@ -113,11 +115,12 @@ def run_workload():
eval_steps_interval=config["eval_steps_interval"],
max_eval_samples=config["max_eval_samples"],
log_steps_interval=config["log_steps_interval"],
wandb_settings=config["wandb_settings"],
)

# Start training
trainer.train()

print("Finished training. Prompting model...")

# Test after tuning
3 changes: 2 additions & 1 deletion kithara/callbacks/__init__.py
@@ -1,2 +1,3 @@
from kithara.callbacks.profiler import *
from kithara.callbacks.checkpointer import *
from kithara.callbacks.checkpointer import *
from kithara.callbacks.wandb import *
64 changes: 64 additions & 0 deletions kithara/callbacks/wandb.py
@@ -0,0 +1,64 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from keras.src.callbacks.callback import Callback
import wandb


class Wandb(Callback):
"""Callbacks to send data to Weights and Biases.

Args:
settings (wandb.Settings): Settings to init Weights and Biases with.
learning_rate (float, optional): Training learning rate. Defaults to None.
epochs (int, optional): Training epochs. Defaults to None.
"""

def __init__(
self,
settings,
learning_rate=None,
epochs=None,
):
super().__init__()
wandb.login()
config = {}
if learning_rate:
config["learning_rate"] = learning_rate
if epochs:
config["epochs"] = epochs
wandb.init(
settings=settings,
config=config,
)

def on_train_begin(self, logs=None):
return

def on_train_end(self, logs=None):
wandb.finish()

def on_train_batch_begin(self, batch, logs=None):
return

def on_train_batch_end(self, batch, logs=None):
if logs is not None:
wandb.log(logs)
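
As an illustrative aside (not part of this change): since the class above subclasses the standard Keras ``Callback``, it could in principle also be attached to a plain ``keras.Model.fit`` run; within Kithara, the ``Trainer`` wires it up automatically from ``wandb_settings``. A minimal sketch with a hypothetical toy model and a placeholder project name:

```python
import numpy as np
import keras
from wandb import Settings
from kithara.callbacks import Wandb

# Toy data and model, purely for illustration.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# The callback logs in (wandb.login() will prompt unless WANDB_API_KEY is set),
# calls wandb.init() with the given settings, logs the per-batch `logs` dict,
# and calls wandb.finish() when training ends.
wandb_cb = Wandb(settings=Settings(project="my-project"), learning_rate=1e-3, epochs=2)
model.fit(x, y, epochs=2, batch_size=16, callbacks=[wandb_cb])
```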
100 changes: 68 additions & 32 deletions kithara/trainer/trainer.py
@@ -1,18 +1,18 @@
"""
Copyright 2025 Google LLC
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0
https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os

@@ -32,12 +32,13 @@
from kithara.optimizers import convert_to_kithara_optimizer
from kithara.model import Model
from kithara.dataset import Dataloader
from kithara.callbacks import Profiler, Checkpointer
from kithara.callbacks import Profiler, Checkpointer, Wandb
from kithara.distributed.sharding._data_sharding import DataSharding
from keras.src.backend.common import global_state
from typing import Any, Union, List, Tuple
import jax.tree_util as jtu
import numpy as np
import wandb


class Trainer:
@@ -67,6 +68,7 @@ class Trainer:
tensorboard_dir (str, optional): The directory path for TensorBoard logs. Can be either a
local directory or a Google Cloud Storage (GCS) path. Defaults to None.
profiler (kithara.Profiler, optional): A profiler instance for monitoring performance metrics. Defaults to None.
wandb_settings (wandb.Settings, optional): Configuration for Weights and Biases. Defaults to None; when None, Weights and Biases is not enabled.

Methods:
loss_fn: Returns a JAX-compatible callable that computes the loss value from logits and labels.
@@ -80,7 +82,11 @@ class Trainer:
def __init__(
self,
model: Model,
optimizer: keras.Optimizer | optax.GradientTransformation | optax.GradientTransformationExtraArgs,
optimizer: (
keras.Optimizer
| optax.GradientTransformation
| optax.GradientTransformationExtraArgs
),
train_dataloader: Dataloader,
eval_dataloader: Dataloader = None,
steps=None,
@@ -92,6 +98,7 @@ def __init__(
tensorboard_dir=None,
profiler: Profiler = None,
checkpointer: Checkpointer = None,
wandb_settings=None,
):
if steps is None and epochs is None:
epochs = 1
@@ -121,14 +128,16 @@ def __init__(
self.global_batch_size = train_dataloader.global_batch_size
self.profiler = profiler
self.checkpointer = checkpointer
self.wandb_settings = wandb_settings
self._validate_setup()

# Initialize optimizer and callbacks
if isinstance(optimizer, keras.optimizers.Optimizer):
self.optimizer.build(self.model.trainable_variables)
else:
self.optimizer = convert_to_kithara_optimizer(
optimizer, self.model.trainable_variables)
optimizer, self.model.trainable_variables
)

self.callbacks = self._create_callbacks()
if self.tensorboard_dir:
@@ -213,7 +222,8 @@ def _train_step(self, state: Tuple[List[jax.Array]], data: dict):
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables)
optimizer_variables, grads, trainable_variables
)
return (
loss,
(
@@ -302,9 +312,7 @@ def train(self):
"samples_per_second": round(samples_per_second, 2),
"train_steps_per_second": round(1 / step_time, 2),
"samples_seen": self.global_batch_size * self.step_count,
"learning_rate": (round(float(self.optimizer.learning_rate.value),7)
if self.optimizer.learning_rate is not None else None),

"learning_rate": self._learning_rate(),
}

# Log progress
@@ -483,6 +491,15 @@ def evaluate(self, state=None):

return eval_loss

def _learning_rate(self):
    return (
        round(float(self.optimizer.learning_rate.value), 7)
        if self.optimizer.learning_rate is not None
        else None
    )

def _make_train_step(self):
return jax.jit(self._train_step, donate_argnums=(0,))

@@ -555,7 +572,8 @@ def _update_model_with_state(self, state):

_ = jax.tree.map(
lambda variable, value: variable.assign(
jax.lax.with_sharding_constraint(value, variable._layout)),
jax.lax.with_sharding_constraint(value, variable._layout)
),
self.optimizer.variables,
optimizer_variables,
)
@@ -593,14 +611,23 @@ def _print_run_summary(self):

def _create_callbacks(self):
callbacks = []
if self.tensorboard_dir and isinstance(self.optimizer, keras.optimizers.Optimizer):
if self.tensorboard_dir and isinstance(
self.optimizer, keras.optimizers.Optimizer
):
callbacks.append(
keras.callbacks.TensorBoard(
log_dir=self.tensorboard_dir,
update_freq="batch",
write_steps_per_second=True,
)
)
if self.wandb_settings:
self.wandb_callback = Wandb(
settings=self.wandb_settings,
learning_rate=self._learning_rate(),
epochs=self.epochs,
)
callbacks.append(self.wandb_callback)
if self.profiler:
callbacks.append(self.profiler)
if self.checkpointer:
@@ -649,15 +676,20 @@ def _validate_sharding_correctness(self, data, state):
)

_ = jax.tree.map(
lambda variable, value: print(
f"Step {self.step_count}: optimizer variable is not sharded",
f"{get_size_in_mb(value)}mb",
variable.path,
value.shape,
value.sharding,
) if is_not_sharded_and_is_large(value) else None,
lambda variable, value: (
print(
f"Step {self.step_count}: optimizer variable is not sharded",
f"{get_size_in_mb(value)}mb",
variable.path,
value.shape,
value.sharding,
)
if is_not_sharded_and_is_large(value)
else None
),
self.optimizer.variables,
state[2])
state[2],
)

except Exception as e:
print(f"Error during sharding correctness validation: {e}")
@@ -676,8 +708,10 @@ def _validate_memory_usage(self):
total_size += get_size_in_mb(v.value)

total_size += jax.tree.reduce(
lambda agg, leaf: jax.numpy.add(agg, get_size_in_mb(leaf.value)), self.optimizer.variables,
initializer=0)
lambda agg, leaf: jax.numpy.add(agg, get_size_in_mb(leaf.value)),
self.optimizer.variables,
initializer=0,
)

live_arrays = jax.live_arrays()
live_arrays_size = 0
@@ -699,9 +733,11 @@
memory_info = jax.local_devices()[0].memory_stats()
memory_per_device_mb = memory_info["bytes_limit"] / (1024**2)
total_memory = memory_per_device_mb * jax.device_count()
print(f"Total memory available is {total_memory:.3f} MB, if you run into "
print(
f"Total memory available is {total_memory:.3f} MB, if you run into "
"errors, check if your memory usage is close to the limit, and either "
"reduce your per-device batch size or sequence length.")
"reduce your per-device batch size or sequence length."
)
except Exception as e:
# memory_info is not available on some TPUs
pass