1
1
"""Utility functions for manipulating JAX embedding tables and inputs."""
2
2
3
3
import collections
4
- import dataclasses
5
- import typing
6
4
from typing import Any , Mapping , NamedTuple , Sequence , TypeAlias , TypeVar
7
5
8
6
import jax
@@ -35,12 +33,6 @@ class ShardedCooMatrix(NamedTuple):
35
33
values : ArrayLike
36
34
37
35
38
- class InputStatsPerTable (NamedTuple ):
39
- max_ids_per_partition : int
40
- max_unique_ids_per_partition : int
41
- required_buffer_size_per_device : int
42
-
43
-
44
36
def _round_up_to_multiple (value : int , multiple : int ) -> int :
45
37
return ((value + multiple - 1 ) // multiple ) * multiple
46
38
@@ -303,122 +295,6 @@ def unshard_and_unstack_tables(
303
295
return output
304
296
305
297
306
- def get_table_specs (feature_specs : Nested [FeatureSpec ]) -> dict [str , TableSpec ]:
307
- table_spec_map : dict [str , TableSpec ] = {}
308
- flat_feature_specs , _ = jax .tree .flatten (feature_specs )
309
- for feature_spec in flat_feature_specs :
310
- table_spec = feature_spec .table_spec
311
- table_spec_map [table_spec .name ] = table_spec
312
- return table_spec_map
313
-
314
-
315
- def get_table_stacks (
316
- table_specs : Nested [TableSpec ],
317
- ) -> dict [str , list [TableSpec ]]:
318
- """Extracts lists of tables that are stacked together.
319
-
320
- Args:
321
- table_specs: Nested collection of table specifications.
322
-
323
- Returns:
324
- A mapping of stacked table names to lists of table specifications for
325
- each stack.
326
- """
327
- stacked_table_specs : dict [str , list [TableSpec ]] = collections .defaultdict (
328
- list
329
- )
330
- flat_table_specs , _ = jax .tree .flatten (table_specs )
331
- for table_spec in flat_table_specs :
332
- table_spec = typing .cast (TableSpec , table_spec )
333
- stacked_table_spec = table_spec .stacked_table_spec
334
- if stacked_table_spec is not None :
335
- stacked_table_specs [stacked_table_spec .stack_name ].append (
336
- table_spec
337
- )
338
- else :
339
- stacked_table_specs [table_spec .name ].append (table_spec )
340
-
341
- return stacked_table_specs
342
-
343
-
344
- def get_stacked_table_stats (
345
- feature_specs : Nested [FeatureSpec ],
346
- ) -> dict [str , InputStatsPerTable ]:
347
- """Extracts the stacked-table input statistics from the feature specs.
348
-
349
- Args:
350
- feature_specs: Feature specs from which to extracts the statistics.
351
-
352
- Returns:
353
- A mapping of stacked table names to input statistics per table.
354
- """
355
- stacked_table_specs : dict [str , StackedTableSpec ] = {}
356
- for feature_spec in jax .tree .flatten (feature_specs )[0 ]:
357
- feature_spec = typing .cast (FeatureSpec , feature_spec )
358
- stacked_table_spec = typing .cast (
359
- StackedTableSpec , feature_spec .table_spec .stacked_table_spec
360
- )
361
- stacked_table_specs [stacked_table_spec .stack_name ] = stacked_table_spec
362
-
363
- stats : dict [str , InputStatsPerTable ] = {}
364
- for stacked_table_spec in stacked_table_specs .values ():
365
- buffer_size = stacked_table_spec .suggested_coo_buffer_size_per_device
366
- buffer_size = buffer_size or 0
367
- stats [stacked_table_spec .stack_name ] = InputStatsPerTable (
368
- max_ids_per_partition = stacked_table_spec .max_ids_per_partition ,
369
- max_unique_ids_per_partition = stacked_table_spec .max_unique_ids_per_partition ,
370
- required_buffer_size_per_device = buffer_size ,
371
- )
372
-
373
- return stats
374
-
375
-
376
- def update_stacked_table_stats (
377
- feature_specs : Nested [FeatureSpec ],
378
- stats : Mapping [str , InputStatsPerTable ],
379
- ) -> None :
380
- """Updates stacked-table input properties in the supplied feature specs.
381
-
382
- Args:
383
- feature_specs: Feature specs to update in-place.
384
- stats: Per-stacked-table input statistics.
385
- """
386
- # Collect table specs and stacked table specs.
387
- table_specs : dict [str , TableSpec ] = {}
388
- for feature_spec in jax .tree .flatten (feature_specs )[0 ]:
389
- feature_spec = typing .cast (FeatureSpec , feature_spec )
390
- table_specs [feature_spec .table_spec .name ] = feature_spec .table_spec
391
-
392
- stacked_table_specs : dict [str , StackedTableSpec ] = {}
393
- for table_spec in table_specs .values ():
394
- stacked_table_spec = typing .cast (
395
- StackedTableSpec , table_spec .stacked_table_spec
396
- )
397
- stacked_table_specs [stacked_table_spec .stack_name ] = stacked_table_spec
398
-
399
- # Replace fields in the stacked_table_specs.
400
- stack_names = stacked_table_specs .keys ()
401
- for stack_name in stack_names :
402
- stack_stats = stats [stack_name ]
403
- stacked_table_spec = stacked_table_specs [stack_name ]
404
- buffer_size = stack_stats .required_buffer_size_per_device or None
405
- stacked_table_specs [stack_name ] = dataclasses .replace (
406
- stacked_table_spec ,
407
- max_ids_per_partition = stack_stats .max_ids_per_partition ,
408
- max_unique_ids_per_partition = stack_stats .max_unique_ids_per_partition ,
409
- suggested_coo_buffer_size_per_device = buffer_size ,
410
- )
411
-
412
- # Insert new stacked tables into tables.
413
- for table_spec in table_specs .values ():
414
- stacked_table_spec = typing .cast (
415
- StackedTableSpec , table_spec .stacked_table_spec
416
- )
417
- table_spec .stacked_table_spec = stacked_table_specs [
418
- stacked_table_spec .stack_name
419
- ]
420
-
421
-
422
298
def convert_to_numpy (
423
299
ragged_or_dense : np .ndarray [Any , Any ] | Sequence [Sequence [Any ]] | Any ,
424
300
dtype : Any ,
@@ -483,7 +359,7 @@ def ones_like(
483
359
484
360
Args:
485
361
ragged_or_dense: The ragged or dense input whose shape and data-type
486
- define these same attributes of the returned array.
362
+ define these same attributes of the returned array.
487
363
dtype: The data-type of the returned array.
488
364
489
365
Returns:
@@ -567,7 +443,7 @@ def stack_and_shard_samples(
567
443
global_device_count : int ,
568
444
num_sc_per_device : int ,
569
445
static_buffer_size : int | Mapping [str , int ] | None = None ,
570
- ) -> tuple [dict [str , ShardedCooMatrix ], dict [ str , InputStatsPerTable ] ]:
446
+ ) -> tuple [dict [str , ShardedCooMatrix ], embedding . SparseDenseMatmulInputStats ]:
571
447
"""Prepares input samples for use in embedding lookups.
572
448
573
449
Args:
@@ -612,7 +488,6 @@ def collect_tokens_and_weights(
612
488
)
613
489
614
490
out : dict [str , ShardedCooMatrix ] = {}
615
- out_stats : dict [str , InputStatsPerTable ] = {}
616
491
tables_names = preprocessed_inputs .lhs_row_pointers .keys ()
617
492
for table_name in tables_names :
618
493
shard_ends = preprocessed_inputs .lhs_row_pointers [table_name ]
@@ -626,17 +501,5 @@ def collect_tokens_and_weights(
626
501
row_ids = preprocessed_inputs .lhs_sample_ids [table_name ],
627
502
values = preprocessed_inputs .lhs_gains [table_name ],
628
503
)
629
- out_stats [table_name ] = InputStatsPerTable (
630
- max_ids_per_partition = np .max (
631
- stats .max_ids_per_partition [table_name ]
632
- ),
633
- max_unique_ids_per_partition = np .max (
634
- stats .max_unique_ids_per_partition [table_name ]
635
- ),
636
- required_buffer_size_per_device = np .max (
637
- stats .required_buffer_size_per_sc [table_name ]
638
- )
639
- * num_sc_per_device ,
640
- )
641
504
642
- return out , out_stats
505
+ return out , stats
0 commit comments