67 changes: 34 additions & 33 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -441,7 +441,7 @@ def sparsecore_build(
)

# Collect all stacked tables.
table_specs = embedding_utils.get_table_specs(feature_specs)
table_specs = embedding.get_table_specs(feature_specs)
table_stacks = embedding_utils.get_table_stacks(table_specs)

# Create variables for all stacked tables and slot variables.
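Note: `get_table_specs` now comes from the upstream `embedding` module rather than the local `embedding_utils`; the local copy is deleted in embedding_utils.py below. A minimal sketch of the behavior, mirroring the deleted helper (the upstream signature is assumed to be equivalent):

    # Sketch: map each table name to its TableSpec by flattening the
    # nested feature specs (mirrors the deleted embedding_utils helper).
    import jax

    def get_table_specs_sketch(feature_specs):
        table_spec_map = {}
        for feature_spec in jax.tree.flatten(feature_specs)[0]:
            table_spec = feature_spec.table_spec
            table_spec_map[table_spec.name] = table_spec
        return table_spec_map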
@@ -515,9 +515,7 @@ def _sparsecore_symbolic_preprocess(
del inputs, weights, training

# Each stacked-table gets a ShardedCooMatrix.
table_specs = embedding_utils.get_table_specs(
self._config.feature_specs
)
table_specs = embedding.get_table_specs(self._config.feature_specs)
table_stacks = embedding_utils.get_table_stacks(table_specs)
stacked_table_specs = {
stack_name: stack[0].stacked_table_spec
@@ -600,40 +598,43 @@ def _sparsecore_preprocess(
if training:
# Synchronize input statistics across all devices and update the
# underlying stacked tables specs in the feature specs.
prev_stats = embedding_utils.get_stacked_table_stats(
self._config.feature_specs
)

# Take the maximum with existing stats.
stats = keras.tree.map_structure(max, prev_stats, stats)
# Aggregate stats across all processes/devices via pmax.
num_local_cpu_devices = jax.local_device_count("cpu")

# Flatten the stats so we can more efficiently transfer them
# between hosts. We use jax.tree because we will later need to
# unflatten.
flat_stats, stats_treedef = jax.tree.flatten(stats)
def pmax_aggregate(x: Any) -> Any:
if not hasattr(x, "ndim"):
x = np.array(x)
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
return jax.pmap(
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
axis_name="all_cpus",
backend="cpu",
)(tiled_x)[0]

# In the case of multiple local CPU devices per host, we need to
# replicate the stats to placate JAX collectives.
num_local_cpu_devices = jax.local_device_count("cpu")
tiled_stats = np.tile(
np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
)
full_stats = jax.tree.map(pmax_aggregate, stats)

# Aggregate variables across all processes/devices.
max_across_cpus = jax.pmap(
lambda x: jax.lax.pmax( # type: ignore[no-untyped-call]
x, "all_cpus"
),
axis_name="all_cpus",
backend="cpu",
# Check if stats changed enough to warrant action.
stacked_table_specs = embedding.get_stacked_table_specs(
self._config.feature_specs
)
changed = any(
np.max(full_stats.max_ids_per_partition[stack_name])
> spec.max_ids_per_partition
or np.max(full_stats.max_unique_ids_per_partition[stack_name])
> spec.max_unique_ids_per_partition
or (
np.max(full_stats.required_buffer_size_per_sc[stack_name])
* num_sc_per_device
)
> (spec.suggested_coo_buffer_size_per_device or 0)
for stack_name, spec in stacked_table_specs.items()
)
flat_stats = max_across_cpus(tiled_stats)[0].tolist()
stats = jax.tree.unflatten(stats_treedef, flat_stats)

# Update configuration and repeat preprocessing if stats changed.
if stats != prev_stats:
embedding_utils.update_stacked_table_stats(
self._config.feature_specs, stats
if changed:
embedding.update_preprocessing_parameters(
self._config.feature_specs, full_stats, num_sc_per_device
)

# Re-execute preprocessing with consistent input statistics.
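Note: the new code replaces the flatten/tile/pmap sequence with a per-leaf `jax.tree.map`. A self-contained sketch of the pmax-over-local-CPU-devices aggregation (standalone toy stats, not the layer's actual call site):

    # Replicate each stat across local CPU devices, then reduce with
    # jax.lax.pmax so every process holds the global maximum.
    import jax
    import numpy as np

    num_local_cpu_devices = jax.local_device_count("cpu")

    def pmax_aggregate(x):
        if not hasattr(x, "ndim"):
            x = np.array(x)
        # Prepend a device axis so pmap maps one copy per local CPU device.
        tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
        return jax.pmap(
            lambda y: jax.lax.pmax(y, "all_cpus"),
            axis_name="all_cpus",
            backend="cpu",
        )(tiled_x)[0]

    # Hypothetical per-table stats; in the layer this is a pytree of arrays.
    stats = {"table_a": np.array([3, 7]), "table_b": 5}
    full_stats = jax.tree.map(pmax_aggregate, stats)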
@@ -718,7 +719,7 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:

config = self._config
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
table_specs = embedding_utils.get_table_specs(config.feature_specs)
table_specs = embedding.get_table_specs(config.feature_specs)
sharded_tables = embedding_utils.stack_and_shard_tables(
table_specs,
tables,
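Note: as a concrete illustration of the shard count computed above (hypothetical topology), a mesh of 4 devices with 4 sparsecores per device gives `num_table_shards = 4 * 4 = 16`, so each stacked table is split row-wise into 16 shards before being distributed.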
@@ -750,7 +751,7 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:

config = self._config
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
table_specs = embedding_utils.get_table_specs(config.feature_specs)
table_specs = embedding.get_table_specs(config.feature_specs)

# Extract only the table variables, not the gradient slot variables.
table_variables = {
113 changes: 3 additions & 110 deletions keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -1,7 +1,6 @@
"""Utility functions for manipulating JAX embedding tables and inputs."""

import collections
import dataclasses
import typing
from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar

@@ -35,12 +34,6 @@ class ShardedCooMatrix(NamedTuple):
values: ArrayLike


class InputStatsPerTable(NamedTuple):
max_ids_per_partition: int
max_unique_ids_per_partition: int
required_buffer_size_per_device: int


def _round_up_to_multiple(value: int, multiple: int) -> int:
return ((value + multiple - 1) // multiple) * multiple

@@ -303,15 +296,6 @@ def unshard_and_unstack_tables(
return output


def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
table_spec_map: dict[str, TableSpec] = {}
flat_feature_specs, _ = jax.tree.flatten(feature_specs)
for feature_spec in flat_feature_specs:
table_spec = feature_spec.table_spec
table_spec_map[table_spec.name] = table_spec
return table_spec_map


def get_table_stacks(
table_specs: Nested[TableSpec],
) -> dict[str, list[TableSpec]]:
@@ -341,84 +325,6 @@ def get_table_stacks(
return stacked_table_specs


def get_stacked_table_stats(
feature_specs: Nested[FeatureSpec],
) -> dict[str, InputStatsPerTable]:
"""Extracts the stacked-table input statistics from the feature specs.

Args:
feature_specs: Feature specs from which to extract the statistics.

Returns:
A mapping of stacked table names to input statistics per table.
"""
stacked_table_specs: dict[str, StackedTableSpec] = {}
for feature_spec in jax.tree.flatten(feature_specs)[0]:
feature_spec = typing.cast(FeatureSpec, feature_spec)
stacked_table_spec = typing.cast(
StackedTableSpec, feature_spec.table_spec.stacked_table_spec
)
stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec

stats: dict[str, InputStatsPerTable] = {}
for stacked_table_spec in stacked_table_specs.values():
buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
buffer_size = buffer_size or 0
stats[stacked_table_spec.stack_name] = InputStatsPerTable(
max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
required_buffer_size_per_device=buffer_size,
)

return stats


def update_stacked_table_stats(
feature_specs: Nested[FeatureSpec],
stats: Mapping[str, InputStatsPerTable],
) -> None:
"""Updates stacked-table input properties in the supplied feature specs.

Args:
feature_specs: Feature specs to update in-place.
stats: Per-stacked-table input statistics.
"""
# Collect table specs and stacked table specs.
table_specs: dict[str, TableSpec] = {}
for feature_spec in jax.tree.flatten(feature_specs)[0]:
feature_spec = typing.cast(FeatureSpec, feature_spec)
table_specs[feature_spec.table_spec.name] = feature_spec.table_spec

stacked_table_specs: dict[str, StackedTableSpec] = {}
for table_spec in table_specs.values():
stacked_table_spec = typing.cast(
StackedTableSpec, table_spec.stacked_table_spec
)
stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec

# Replace fields in the stacked_table_specs.
stack_names = stacked_table_specs.keys()
for stack_name in stack_names:
stack_stats = stats[stack_name]
stacked_table_spec = stacked_table_specs[stack_name]
buffer_size = stack_stats.required_buffer_size_per_device or None
stacked_table_specs[stack_name] = dataclasses.replace(
stacked_table_spec,
max_ids_per_partition=stack_stats.max_ids_per_partition,
max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
suggested_coo_buffer_size_per_device=buffer_size,
)

# Insert new stacked tables into tables.
for table_spec in table_specs.values():
stacked_table_spec = typing.cast(
StackedTableSpec, table_spec.stacked_table_spec
)
table_spec.stacked_table_spec = stacked_table_specs[
stacked_table_spec.stack_name
]


def convert_to_numpy(
ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
dtype: Any,
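Note: the two deleted helpers above are superseded by `embedding.get_stacked_table_specs` and `embedding.update_preprocessing_parameters`, as used in distributed_embedding.py earlier in this diff. A hedged sketch of the consumer side, with field names taken from the new `changed` check (the upstream signatures themselves are assumed):

    # Sketch: read per-stack limits straight off the upstream stacked table
    # specs instead of the deleted InputStatsPerTable snapshots.
    stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
    for stack_name, spec in stacked_table_specs.items():
        print(
            stack_name,
            spec.max_ids_per_partition,
            spec.max_unique_ids_per_partition,
            spec.suggested_coo_buffer_size_per_device or 0,
        )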
@@ -483,7 +389,7 @@ def ones_like(

Args:
ragged_or_dense: The ragged or dense input whose shape and data-type
define these same attributes of the returned array.
define these same attributes of the returned array.
dtype: The data-type of the returned array.

Returns:
@@ -567,7 +473,7 @@ def stack_and_shard_samples(
global_device_count: int,
num_sc_per_device: int,
static_buffer_size: int | Mapping[str, int] | None = None,
) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]:
) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
"""Prepares input samples for use in embedding lookups.

Args:
@@ -612,7 +518,6 @@ def collect_tokens_and_weights(
)

out: dict[str, ShardedCooMatrix] = {}
out_stats: dict[str, InputStatsPerTable] = {}
tables_names = preprocessed_inputs.lhs_row_pointers.keys()
for table_name in tables_names:
shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
@@ -626,17 +531,5 @@
row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
values=preprocessed_inputs.lhs_gains[table_name],
)
out_stats[table_name] = InputStatsPerTable(
max_ids_per_partition=np.max(
stats.max_ids_per_partition[table_name]
),
max_unique_ids_per_partition=np.max(
stats.max_unique_ids_per_partition[table_name]
),
required_buffer_size_per_device=np.max(
stats.required_buffer_size_per_sc[table_name]
)
* num_sc_per_device,
)

return out, out_stats
return out, stats
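Note: `stack_and_shard_samples` now returns the raw `embedding.SparseDenseMatmulInputStats` instead of per-table `InputStatsPerTable` values. Callers that still want the old per-table maxima can reduce the raw stats themselves; a sketch assuming the same field layout the deleted code used:

    # Collapse raw per-shard stats into per-table maxima, as the deleted
    # InputStatsPerTable construction did.
    import numpy as np

    def per_table_maxima(stats, table_name, num_sc_per_device):
        return (
            np.max(stats.max_ids_per_partition[table_name]),
            np.max(stats.max_unique_ids_per_partition[table_name]),
            np.max(stats.required_buffer_size_per_sc[table_name])
            * num_sc_per_device,
        )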
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -64,7 +64,7 @@ known-first-party = ["keras_rs"]
[tool.mypy]
python_version = "3.10"
strict = "True"
exclude = ["_test\\.py$", "^examples/"]
exclude = ["_test\\.py$", "^examples/", "venv/"]
untyped_calls_exclude = ["ml_dtypes"]
disable_error_code = ["import-untyped", "unused-ignore"]
disallow_subclassing_any = "False"