Skip to content

Commit 31fa994

Browse files
Refactor DistributedEmbedding and embedding_utils to streamline table spec handling and statistics aggregation
1 parent cee2286 commit 31fa994

File tree

3 files changed

+40
-176
lines changed

3 files changed

+40
-176
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 36 additions & 35 deletions
Original file line number · Diff line number · Diff line change
@@ -441,8 +441,8 @@ def sparsecore_build(
441441
)
442442

443443
# Collect all stacked tables.
444-
table_specs = embedding_utils.get_table_specs(feature_specs)
445-
table_stacks = embedding_utils.get_table_stacks(table_specs)
444+
table_specs = embedding.get_table_specs(feature_specs)
445+
table_stacks = embedding.get_table_stacks(table_specs)
446446

447447
# Create variables for all stacked tables and slot variables.
448448
with sparsecore_distribution.scope():
@@ -515,10 +515,8 @@ def _sparsecore_symbolic_preprocess(
515515
del inputs, weights, training
516516

517517
# Each stacked-table gets a ShardedCooMatrix.
518-
table_specs = embedding_utils.get_table_specs(
519-
self._config.feature_specs
520-
)
521-
table_stacks = embedding_utils.get_table_stacks(table_specs)
518+
table_specs = embedding.get_table_specs(self._config.feature_specs)
519+
table_stacks = embedding.get_table_stacks(table_specs)
522520
stacked_table_specs = {
523521
stack_name: stack[0].stacked_table_spec
524522
for stack_name, stack in table_stacks.items()
@@ -600,40 +598,43 @@ def _sparsecore_preprocess(
600598
if training:
601599
# Synchronize input statistics across all devices and update the
602600
# underlying stacked tables specs in the feature specs.
603-
prev_stats = embedding_utils.get_stacked_table_stats(
604-
self._config.feature_specs
605-
)
606601

607-
# Take the maximum with existing stats.
608-
stats = keras.tree.map_structure(max, prev_stats, stats)
602+
# Aggregate stats across all processes/devices via pmax.
603+
num_local_cpu_devices = jax.local_device_count("cpu")
609604

610-
# Flatten the stats so we can more efficiently transfer them
611-
# between hosts. We use jax.tree because we will later need to
612-
# unflatten.
613-
flat_stats, stats_treedef = jax.tree.flatten(stats)
605+
def pmax_aggregate(x: Any) -> Any:
606+
if not hasattr(x, "ndim"):
607+
x = np.array(x)
608+
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
609+
return jax.pmap(
610+
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
611+
axis_name="all_cpus",
612+
backend="cpu",
613+
)(tiled_x)[0]
614614

615-
# In the case of multiple local CPU devices per host, we need to
616-
# replicate the stats to placate JAX collectives.
617-
num_local_cpu_devices = jax.local_device_count("cpu")
618-
tiled_stats = np.tile(
619-
np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
620-
)
615+
full_stats = jax.tree.map(pmax_aggregate, stats)
621616

622-
# Aggregate variables across all processes/devices.
623-
max_across_cpus = jax.pmap(
624-
lambda x: jax.lax.pmax( # type: ignore[no-untyped-call]
625-
x, "all_cpus"
626-
),
627-
axis_name="all_cpus",
628-
backend="cpu",
617+
# Check if stats changed enough to warrant action.
618+
stacked_table_specs = embedding.get_stacked_table_specs(
619+
self._config.feature_specs
620+
)
621+
changed = any(
622+
np.max(full_stats.max_ids_per_partition[stack_name])
623+
> spec.max_ids_per_partition
624+
or np.max(full_stats.max_unique_ids_per_partition[stack_name])
625+
> spec.max_unique_ids_per_partition
626+
or (
627+
np.max(full_stats.required_buffer_size_per_sc[stack_name])
628+
* num_sc_per_device
629+
)
630+
> (spec.suggested_coo_buffer_size_per_device or 0)
631+
for stack_name, spec in stacked_table_specs.items()
629632
)
630-
flat_stats = max_across_cpus(tiled_stats)[0].tolist()
631-
stats = jax.tree.unflatten(stats_treedef, flat_stats)
632633

633634
# Update configuration and repeat preprocessing if stats changed.
634-
if stats != prev_stats:
635-
embedding_utils.update_stacked_table_stats(
636-
self._config.feature_specs, stats
635+
if changed:
636+
embedding.update_preprocessing_parameters(
637+
self._config.feature_specs, full_stats, num_sc_per_device
637638
)
638639

639640
# Re-execute preprocessing with consistent input statistics.
@@ -718,7 +719,7 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
718719

719720
config = self._config
720721
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
721-
table_specs = embedding_utils.get_table_specs(config.feature_specs)
722+
table_specs = embedding.get_table_specs(config.feature_specs)
722723
sharded_tables = embedding_utils.stack_and_shard_tables(
723724
table_specs,
724725
tables,
@@ -750,7 +751,7 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
750751

751752
config = self._config
752753
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
753-
table_specs = embedding_utils.get_table_specs(config.feature_specs)
754+
table_specs = embedding.get_table_specs(config.feature_specs)
754755

755756
# Extract only the table variables, not the gradient slot variables.
756757
table_variables = {

keras_rs/src/layers/embedding/jax/embedding_utils.py

Lines changed: 3 additions & 140 deletions
Original file line number · Diff line number · Diff line change
@@ -1,8 +1,6 @@
11
"""Utility functions for manipulating JAX embedding tables and inputs."""
22

33
import collections
4-
import dataclasses
5-
import typing
64
from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
75

86
import jax
@@ -35,12 +33,6 @@ class ShardedCooMatrix(NamedTuple):
3533
values: ArrayLike
3634

3735

38-
class InputStatsPerTable(NamedTuple):
39-
max_ids_per_partition: int
40-
max_unique_ids_per_partition: int
41-
required_buffer_size_per_device: int
42-
43-
4436
def _round_up_to_multiple(value: int, multiple: int) -> int:
4537
return ((value + multiple - 1) // multiple) * multiple
4638

@@ -303,122 +295,6 @@ def unshard_and_unstack_tables(
303295
return output
304296

305297

306-
def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
307-
table_spec_map: dict[str, TableSpec] = {}
308-
flat_feature_specs, _ = jax.tree.flatten(feature_specs)
309-
for feature_spec in flat_feature_specs:
310-
table_spec = feature_spec.table_spec
311-
table_spec_map[table_spec.name] = table_spec
312-
return table_spec_map
313-
314-
315-
def get_table_stacks(
316-
table_specs: Nested[TableSpec],
317-
) -> dict[str, list[TableSpec]]:
318-
"""Extracts lists of tables that are stacked together.
319-
320-
Args:
321-
table_specs: Nested collection of table specifications.
322-
323-
Returns:
324-
A mapping of stacked table names to lists of table specifications for
325-
each stack.
326-
"""
327-
stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
328-
list
329-
)
330-
flat_table_specs, _ = jax.tree.flatten(table_specs)
331-
for table_spec in flat_table_specs:
332-
table_spec = typing.cast(TableSpec, table_spec)
333-
stacked_table_spec = table_spec.stacked_table_spec
334-
if stacked_table_spec is not None:
335-
stacked_table_specs[stacked_table_spec.stack_name].append(
336-
table_spec
337-
)
338-
else:
339-
stacked_table_specs[table_spec.name].append(table_spec)
340-
341-
return stacked_table_specs
342-
343-
344-
def get_stacked_table_stats(
345-
feature_specs: Nested[FeatureSpec],
346-
) -> dict[str, InputStatsPerTable]:
347-
"""Extracts the stacked-table input statistics from the feature specs.
348-
349-
Args:
350-
feature_specs: Feature specs from which to extracts the statistics.
351-
352-
Returns:
353-
A mapping of stacked table names to input statistics per table.
354-
"""
355-
stacked_table_specs: dict[str, StackedTableSpec] = {}
356-
for feature_spec in jax.tree.flatten(feature_specs)[0]:
357-
feature_spec = typing.cast(FeatureSpec, feature_spec)
358-
stacked_table_spec = typing.cast(
359-
StackedTableSpec, feature_spec.table_spec.stacked_table_spec
360-
)
361-
stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
362-
363-
stats: dict[str, InputStatsPerTable] = {}
364-
for stacked_table_spec in stacked_table_specs.values():
365-
buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
366-
buffer_size = buffer_size or 0
367-
stats[stacked_table_spec.stack_name] = InputStatsPerTable(
368-
max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
369-
max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
370-
required_buffer_size_per_device=buffer_size,
371-
)
372-
373-
return stats
374-
375-
376-
def update_stacked_table_stats(
377-
feature_specs: Nested[FeatureSpec],
378-
stats: Mapping[str, InputStatsPerTable],
379-
) -> None:
380-
"""Updates stacked-table input properties in the supplied feature specs.
381-
382-
Args:
383-
feature_specs: Feature specs to update in-place.
384-
stats: Per-stacked-table input statistics.
385-
"""
386-
# Collect table specs and stacked table specs.
387-
table_specs: dict[str, TableSpec] = {}
388-
for feature_spec in jax.tree.flatten(feature_specs)[0]:
389-
feature_spec = typing.cast(FeatureSpec, feature_spec)
390-
table_specs[feature_spec.table_spec.name] = feature_spec.table_spec
391-
392-
stacked_table_specs: dict[str, StackedTableSpec] = {}
393-
for table_spec in table_specs.values():
394-
stacked_table_spec = typing.cast(
395-
StackedTableSpec, table_spec.stacked_table_spec
396-
)
397-
stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
398-
399-
# Replace fields in the stacked_table_specs.
400-
stack_names = stacked_table_specs.keys()
401-
for stack_name in stack_names:
402-
stack_stats = stats[stack_name]
403-
stacked_table_spec = stacked_table_specs[stack_name]
404-
buffer_size = stack_stats.required_buffer_size_per_device or None
405-
stacked_table_specs[stack_name] = dataclasses.replace(
406-
stacked_table_spec,
407-
max_ids_per_partition=stack_stats.max_ids_per_partition,
408-
max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
409-
suggested_coo_buffer_size_per_device=buffer_size,
410-
)
411-
412-
# Insert new stacked tables into tables.
413-
for table_spec in table_specs.values():
414-
stacked_table_spec = typing.cast(
415-
StackedTableSpec, table_spec.stacked_table_spec
416-
)
417-
table_spec.stacked_table_spec = stacked_table_specs[
418-
stacked_table_spec.stack_name
419-
]
420-
421-
422298
def convert_to_numpy(
423299
ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
424300
dtype: Any,
@@ -483,7 +359,7 @@ def ones_like(
483359
484360
Args:
485361
ragged_or_dense: The ragged or dense input whose shape and data-type
486-
define these same attributes of the returned array.
362+
define these same attributes of the returned array.
487363
dtype: The data-type of the returned array.
488364
489365
Returns:
@@ -567,7 +443,7 @@ def stack_and_shard_samples(
567443
global_device_count: int,
568444
num_sc_per_device: int,
569445
static_buffer_size: int | Mapping[str, int] | None = None,
570-
) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]:
446+
) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
571447
"""Prepares input samples for use in embedding lookups.
572448
573449
Args:
@@ -612,7 +488,6 @@ def collect_tokens_and_weights(
612488
)
613489

614490
out: dict[str, ShardedCooMatrix] = {}
615-
out_stats: dict[str, InputStatsPerTable] = {}
616491
tables_names = preprocessed_inputs.lhs_row_pointers.keys()
617492
for table_name in tables_names:
618493
shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
@@ -626,17 +501,5 @@ def collect_tokens_and_weights(
626501
row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
627502
values=preprocessed_inputs.lhs_gains[table_name],
628503
)
629-
out_stats[table_name] = InputStatsPerTable(
630-
max_ids_per_partition=np.max(
631-
stats.max_ids_per_partition[table_name]
632-
),
633-
max_unique_ids_per_partition=np.max(
634-
stats.max_unique_ids_per_partition[table_name]
635-
),
636-
required_buffer_size_per_device=np.max(
637-
stats.required_buffer_size_per_sc[table_name]
638-
)
639-
* num_sc_per_device,
640-
)
641504

642-
return out, out_stats
505+
return out, stats

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -64,7 +64,7 @@ known-first-party = ["keras_rs"]
6464
[tool.mypy]
6565
python_version = "3.10"
6666
strict = "True"
67-
exclude = ["_test\\.py$", "^examples/"]
67+
exclude = ["_test\\.py$", "^examples/", "venv/"]
6868
untyped_calls_exclude = ["ml_dtypes"]
6969
disable_error_code = ["import-untyped", "unused-ignore"]
7070
disallow_subclassing_any = "False"

0 commit comments

Comments (0)