diff --git a/recml/core/ops/hstu_ops.py b/recml/core/ops/hstu_ops.py index 3a8df11..59fd7bd 100644 --- a/recml/core/ops/hstu_ops.py +++ b/recml/core/ops/hstu_ops.py @@ -125,9 +125,9 @@ def _apply_mask( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] snm = jnp.where(should_not_mask, 1, 0) masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0) @@ -156,7 +156,7 @@ def _apply_mask( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -170,7 +170,7 @@ def _apply_mask( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -181,9 +181,9 @@ def _apply_mask( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) if masks: @@ -228,7 +228,7 @@ def body(kv_compute_index, _): slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) q = q_ref[...] 
- k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] qk = jax.lax.dot_general( q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32 ) @@ -256,7 +256,7 @@ def body(kv_compute_index, _): ) sv_dims = NN_DIM_NUMBERS - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] to_float32 = lambda x: x.astype(jnp.float32) v = to_float32(v) diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py index 4dc3b76..eabce4a 100644 --- a/recml/core/training/partitioning.py +++ b/recml/core/training/partitioning.py @@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array: def partition_init( self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None ) -> CreateStateFn: - with jax.sharding.use_mesh(self.mesh): + with jax.set_mesh(self.mesh): if abstract_batch is not None: abstract_state = jax.eval_shape(init_fn, abstract_batch) specs = nn.get_partition_spec(abstract_state) @@ -117,7 +117,7 @@ def partition_init( init_fn = jax.jit(init_fn, out_shardings=self.state_sharding) def _wrapped_init(batch: PyTree) -> State: - with jax.sharding.use_mesh(self.mesh): + with jax.set_mesh(self.mesh): state = init_fn(batch) state = _maybe_unbox_state(state) return state @@ -130,7 +130,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: jit_kws["out_shardings"] = (self.state_sharding, None) jit_kws["donate_argnums"] = (1,) - with jax.sharding.use_mesh(self.mesh): + with jax.set_mesh(self.mesh): step_fn = jax.jit( fn, in_shardings=(self.data_sharding, self.state_sharding), @@ -138,7 +138,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: ) def _wrapped_step(batch: PyTree, state: State) -> Any: - with jax.sharding.use_mesh(self.mesh): + with jax.set_mesh(self.mesh): return step_fn(batch, state) return _wrapped_step @@ -217,7 +217,7 @@ def __init__( def mesh_context_manager( self, ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]: - return jax.sharding.use_mesh + return jax.set_mesh def shard_inputs(self, inputs: PyTree) -> PyTree: def _shard(x: np.ndarray) -> jax.Array: diff --git a/recml/examples/DLRM_HSTU/action_encoder.py b/recml/examples/DLRM_HSTU/action_encoder.py new file mode 100644 index 0000000..1caed5b --- /dev/null +++ b/recml/examples/DLRM_HSTU/action_encoder.py @@ -0,0 +1,123 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of the ActionEncoder module.""" + +from typing import Dict, List, Optional, Tuple + +import flax.linen as nn +from flax.linen import initializers +import jax +import jax.numpy as jnp + + +class ActionEncoder(nn.Module): + """Encodes categorical actions and continuous watch times into a fixed-size embedding. + + assumes dense tensors of shape (batch_size, sequence_length) for all inputs. 
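+
+  Illustrative example (assumed values): with `action_weights=[1, 2, 4]`, an
+  item whose action bitmask is `5` (binary `101`) enables action types 0 and
+  2, so its output is the concatenation of embedding-table rows 0 and 2 with
+  zeros in slot 1. Positions where `is_target_mask` is True receive the
+  learned target-action embedding instead of the decoded history actions.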
+ """ + + action_embedding_dim: int + action_feature_name: str + action_weights: List[int] + watchtime_feature_name: str = "" + watchtime_to_action_thresholds_and_weights: Optional[ + List[Tuple[int, int]] + ] = None + + def setup(self): + """Initializes parameters and constants for the module.""" + wt_thresholds_and_weights = ( + self.watchtime_to_action_thresholds_and_weights or [] + ) + + self.combined_action_weights = jnp.array( + list(self.action_weights) + [w for _, w in wt_thresholds_and_weights] + ) + + self.num_action_types: int = ( + len(self.action_weights) + len(wt_thresholds_and_weights) + ) + + self.action_embedding_table = self.param( + "action_embedding_table", + initializers.normal(stddev=0.1), + (self.num_action_types, self.action_embedding_dim), + ) + + self.target_action_embedding_table = self.param( + "target_action_embedding_table", + initializers.normal(stddev=0.1), + (1, self.output_embedding_dim), + ) + + @property + def output_embedding_dim(self) -> int: + """The dimension of the final output embedding.""" + num_watchtime_actions = ( + len(self.watchtime_to_action_thresholds_and_weights) + if self.watchtime_to_action_thresholds_and_weights + else 0 + ) + num_action_types = len(self.action_weights) + num_watchtime_actions + return self.action_embedding_dim * num_action_types + + def __call__( + self, + seq_payloads: Dict[str, jax.Array], + is_target_mask: jax.Array, + ) -> jax.Array: + """Processes a batch of sequences to generate action embeddings. + + Args: + seq_payloads: A dictionary of feature names to dense tensors of shape + `(batch_size, sequence_length)`. + is_target_mask: A boolean tensor of shape `(batch_size, + sequence_length)` where `True` indicates a target item. + + Returns: + A dense tensor of action embeddings of shape + `(batch_size, sequence_length, output_embedding_dim)`. + """ + + seq_actions = seq_payloads[self.action_feature_name] + + wt_thresholds_and_weights = ( + self.watchtime_to_action_thresholds_and_weights or [] + ) + if wt_thresholds_and_weights: + watchtimes = seq_payloads[self.watchtime_feature_name] + for threshold, weight in wt_thresholds_and_weights: + watch_action = (watchtimes >= threshold).astype(jnp.int64) * weight + seq_actions = jnp.bitwise_or(seq_actions, watch_action) + + exploded_actions = ( + jnp.bitwise_and(seq_actions[..., None], self.combined_action_weights) + > 0 + ) + + history_embeddings = ( + exploded_actions[..., None] * self.action_embedding_table + ).reshape(*seq_actions.shape, -1) + + target_embeddings = jnp.broadcast_to( + self.target_action_embedding_table, history_embeddings.shape + ) + + final_embeddings = jnp.where( + is_target_mask[..., None], + target_embeddings, + history_embeddings, + ) + + return final_embeddings diff --git a/recml/examples/DLRM_HSTU/action_encoder_test.py b/recml/examples/DLRM_HSTU/action_encoder_test.py new file mode 100644 index 0000000..6706bd2 --- /dev/null +++ b/recml/examples/DLRM_HSTU/action_encoder_test.py @@ -0,0 +1,147 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import jax
+import jax.numpy as jnp
+import numpy as np
+import numpy.testing as npt
+from absl.testing import absltest
+from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder
+
+
+class ActionEncoderJaxTest(absltest.TestCase):
+  def test_forward_and_backward(self) -> None:
+    """Tests the ActionEncoder's forward pass logic and differentiability."""
+
+    batch_size = 2
+    max_seq_len = 6
+    action_embedding_dim = 32
+    action_weights = [1, 2, 4, 8, 16]
+    watchtime_to_action_thresholds_and_weights = [
+        (30, 32), (60, 64), (100, 128),
+    ]
+    num_action_types = len(action_weights) + len(
+        watchtime_to_action_thresholds_and_weights
+    )
+    output_dim = action_embedding_dim * num_action_types
+    combined_action_weights = action_weights + [
+        w for _, w in watchtime_to_action_thresholds_and_weights
+    ]
+
+    enabled_actions = [
+        [0],  # Seq 1, Item 1
+        [0, 1],  # Seq 1, Item 2
+        [1, 3, 4],  # Seq 1, Item 3
+        [1, 2, 3, 4],  # Seq 1, Item 4
+        [1, 2],  # Seq 2, Item 1
+        [2],  # Seq 2, Item 2
+    ]
+    watchtimes_flat = [40, 20, 110, 31, 26, 55]
+
+    # Add actions based on watchtime thresholds, mirroring the encoder's
+    # `watchtimes >= threshold` comparison.
+    for i, wt in enumerate(watchtimes_flat):
+      for j, (threshold, _) in enumerate(
+          watchtime_to_action_thresholds_and_weights
+      ):
+        if wt >= threshold:
+          enabled_actions[i].append(j + len(action_weights))
+
+    actions_flat = [
+        sum([combined_action_weights[t] for t in x]) for x in enabled_actions
+    ]
+
+    padded_actions = np.zeros((batch_size, max_seq_len), dtype=np.int64)
+    padded_watchtimes = np.zeros((batch_size, max_seq_len), dtype=np.int64)
+
+    padded_actions[0, :4] = actions_flat[0:4]
+    padded_actions[1, :2] = actions_flat[4:6]
+    padded_watchtimes[0, :4] = watchtimes_flat[0:4]
+    padded_watchtimes[1, :2] = watchtimes_flat[4:6]
+
+    is_target_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
+    is_target_mask[0, 4:6] = True
+    is_target_mask[1, 2] = True
+
+    padding_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
+    padding_mask[0, :6] = True
+    padding_mask[1, :3] = True
+
+    seq_payloads = {
+        "watchtimes": jnp.array(padded_watchtimes),
+        "actions": jnp.array(padded_actions),
+    }
+
+    encoder = ActionEncoder(
+        watchtime_feature_name="watchtimes",
+        action_feature_name="actions",
+        action_weights=action_weights,
+        watchtime_to_action_thresholds_and_weights=(
+            watchtime_to_action_thresholds_and_weights
+        ),
+        action_embedding_dim=action_embedding_dim,
+    )
+
+    key = jax.random.PRNGKey(0)
+    variables = encoder.init(key, seq_payloads, is_target_mask)
+    params = variables["params"]
+
+    action_embeddings = encoder.apply(
+        variables, seq_payloads, is_target_mask
+    )
+
+    self.assertEqual(
+        action_embeddings.shape, (batch_size, max_seq_len, output_dim)
+    )
+
+    action_table = params["action_embedding_table"]
+    target_table_flat = params["target_action_embedding_table"]
+    target_table = target_table_flat.reshape(num_action_types, -1)
+
+    history_item_idx = 0
+    for b in range(batch_size):
+      for s in range(max_seq_len):
+        if not padding_mask[b, s]:
+          npt.assert_allclose(action_embeddings[b, s], 0, atol=1e-6)
+          continue
+
+        embedding = action_embeddings[b, s].reshape(num_action_types, -1)
+
+        if is_target_mask[b, s]:
+          npt.assert_allclose(embedding, target_table, atol=1e-6)
+        else:
+          current_enabled = enabled_actions[history_item_idx]
+          for atype in range(num_action_types):
+            if atype in current_enabled:
+              npt.assert_allclose(
+                  embedding[atype], action_table[atype],
+                  
atol=1e-6 + ) + else: + npt.assert_allclose(embedding[atype], + jnp.zeros_like(embedding[atype]), + atol=1e-6) + history_item_idx += 1 + + def loss_fn(p): + return encoder.apply({"params": p}, seq_payloads, is_target_mask).sum() + + grads = jax.grad(loss_fn)(params) + self.assertIsNotNone(grads) + self.assertFalse(np.all(np.isclose(grads["action_embedding_table"], 0))) + self.assertFalse(np.all( + np.isclose(grads["target_action_embedding_table"], 0) + )) + + +if __name__ == "__main__": + absltest.main() diff --git a/recml/examples/DLRM_HSTU/content_encoder.py b/recml/examples/DLRM_HSTU/content_encoder.py new file mode 100644 index 0000000..c66a2bb --- /dev/null +++ b/recml/examples/DLRM_HSTU/content_encoder.py @@ -0,0 +1,159 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX/Flax implementation of ContentEncoder for dense tensors.""" + +from typing import Dict, List, Optional + +import flax.linen as nn +from jax import numpy as jnp + + +class ContentEncoder(nn.Module): + """JAX/Flax implementation of ContentEncoder for dense tensors. + + This module concatenates input embeddings with additional features. It + handles two types of features: + 1. `additional_content_features`: Features available for the entire + sequence. + 2. `target_enrich_features`: Features available only for target items, with + a learned dummy embedding used as a placeholder for history items. + """ + input_embedding_dim: int + additional_content_features: Optional[Dict[str, int]] = None + target_enrich_features: Optional[Dict[str, int]] = None + + def setup(self) -> None: + self._additional_content_features_internal: Dict[str, int] = ( + self.additional_content_features + if self.additional_content_features is not None + else {} + ) + self._target_enrich_features_internal: Dict[str, int] = ( + self.target_enrich_features + if self.target_enrich_features is not None + else {} + ) + + self._target_enrich_dummy_embeddings = { + name: self.param( + f"target_enrich_dummy_param_{name}", + nn.initializers.normal(stddev=0.1), + (1, dim), # Shape is (1, feature_dim) for broadcasting + ) + for name, dim in self._target_enrich_features_internal.items() + } + + @property + def output_embedding_dim(self) -> int: + """The total dimension of the output embeddings after concatenation.""" + additional_dim = sum( + self.additional_content_features.values() + if self.additional_content_features + else [] + ) + enrich_dim = sum( + self.target_enrich_features.values() + if self.target_enrich_features + else [] + ) + return self.input_embedding_dim + additional_dim + enrich_dim + + @nn.compact + def __call__( + self, + max_uih_len: int, + seq_embeddings: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Forward pass for the ContentEncoder. + + Args: + max_uih_len: The length of the user interaction history (non-target part) + in the padded sequence. + seq_embeddings: The base embeddings for the sequence with shape + (batch_size, seq_len, input_embedding_dim). 
+ seq_payloads: A dictionary mapping feature names to their tensors. - For + `additional_content_features`, shape is (batch_size, seq_len, + feature_dim). - For `target_enrich_features`, shape is (batch_size, + max_targets, feature_dim). + + Returns: + The concatenated content embeddings. + Shape: (batch_size, seq_len, output_embedding_dim). + """ + content_embeddings_list: List[jnp.ndarray] = [seq_embeddings] + + if self._additional_content_features_internal: + for x in self._additional_content_features_internal.keys(): + content_embeddings_list.append( + seq_payloads[x].astype(seq_embeddings.dtype) + ) + + if self._target_enrich_dummy_embeddings: + batch_size = seq_embeddings.shape[0] + + for name, param in self._target_enrich_dummy_embeddings.items(): + # If a feature is used for both additional content and target + # enrichment, the payload will contain the full sequence. We need to + # slice the target part. + if name in self._additional_content_features_internal: + full_sequence_feature = seq_payloads[name] + enrich_embeddings_target = full_sequence_feature[ + :, max_uih_len:, : + ].astype(seq_embeddings.dtype) + else: + # Otherwise, the payload contains only the target features. + enrich_embeddings_target = seq_payloads[name].astype( + seq_embeddings.dtype + ) + + enrich_embeddings_uih = jnp.broadcast_to( + param, (batch_size, max_uih_len, param.shape[-1]) + ).astype(seq_embeddings.dtype) + + # Pad targets if necessary to match sequence length + num_targets = enrich_embeddings_target.shape[1] + num_history = max_uih_len + if num_history + num_targets < seq_embeddings.shape[1]: + padding_needed = seq_embeddings.shape[1] - ( + num_history + num_targets + ) + padding = jnp.zeros( + ( + batch_size, + padding_needed, + enrich_embeddings_target.shape[-1], + ), + dtype=enrich_embeddings_target.dtype, + ) + enrich_embeddings_target = jnp.concatenate( + [enrich_embeddings_target, padding], axis=1 + ) + + enrich_embeddings = jnp.concatenate( + [enrich_embeddings_uih, enrich_embeddings_target], axis=1 + ) + content_embeddings_list.append(enrich_embeddings) + + if ( + not self._additional_content_features_internal + and not self._target_enrich_features_internal + ): + return seq_embeddings + else: + content_embeddings = jnp.concatenate( + content_embeddings_list, + axis=-1, + ) + return content_embeddings diff --git a/recml/examples/DLRM_HSTU/content_encoder_test.py b/recml/examples/DLRM_HSTU/content_encoder_test.py new file mode 100644 index 0000000..0c13e3e --- /dev/null +++ b/recml/examples/DLRM_HSTU/content_encoder_test.py @@ -0,0 +1,106 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
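+"""Tests for the dense-tensor ContentEncoder module."""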
+from absl.testing import absltest +import jax +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.content_encoder import ContentEncoder + + +class ContentEncoderTest(absltest.TestCase): + """Tests for JAX ContentEncoder.""" + + def test_forward_and_backward_pass(self) -> None: + """Verifies that the model's forward and backward passes execute without error.""" + batch_size = 2 + seq_len = 6 + num_targets = 2 + max_uih_len = seq_len - num_targets + input_embedding_dim = 32 + additional_embedding_dim = 64 + enrich_embedding_dim = 16 + + encoder = ContentEncoder( + input_embedding_dim=input_embedding_dim, + additional_content_features={ + "a0": additional_embedding_dim, + "a1": additional_embedding_dim, + }, + target_enrich_features={ + "t0": enrich_embedding_dim, + "t1": enrich_embedding_dim, + }, + ) + + key = jax.random.PRNGKey(42) + key, data_key, init_key = jax.random.split(key, 3) + + seq_embeddings = jax.random.normal( + data_key, (batch_size, seq_len, input_embedding_dim) + ) + seq_payloads = { + "a0": jax.random.normal( + data_key, (batch_size, seq_len, additional_embedding_dim) + ), + "a1": jax.random.normal( + data_key, (batch_size, seq_len, additional_embedding_dim) + ), + "t0": jax.random.normal( + data_key, (batch_size, num_targets, enrich_embedding_dim) + ), + "t1": jax.random.normal( + data_key, (batch_size, num_targets, enrich_embedding_dim) + ), + } + + params = encoder.init( + init_key, + max_uih_len, + seq_embeddings, + seq_payloads, + )["params"] + + content_embeddings = encoder.apply( + {"params": params}, + max_uih_len, + seq_embeddings, + seq_payloads, + ) + + expected_dim = ( + input_embedding_dim + + sum(encoder.additional_content_features.values()) + + sum(encoder.target_enrich_features.values()) + ) + self.assertEqual( + content_embeddings.shape, (batch_size, seq_len, expected_dim) + ) + + def loss_fn(p): + output = encoder.apply( + {"params": p}, + max_uih_len, + seq_embeddings, + seq_payloads, + ) + return jnp.sum(output) + + grads = jax.grad(loss_fn)(params) + + self.assertIsNotNone(grads) + self.assertIn("target_enrich_dummy_param_t0", grads) + self.assertIn("target_enrich_dummy_param_t1", grads) + + +if __name__ == "__main__": + absltest.main() diff --git a/recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py b/recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py new file mode 100644 index 0000000..fad1bd3 --- /dev/null +++ b/recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py @@ -0,0 +1,164 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""JAX/Flax implementation of ContextualInterleavePreprocessor.""" + +from typing import Callable, Dict, Tuple + +from flax import linen as nn +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder +from recml.examples.DLRM_HSTU.content_encoder import ContentEncoder +from recml.examples.DLRM_HSTU.contextualize_mlps import ContextualizedMLP +from recml.examples.DLRM_HSTU.preprocessors import get_contextual_input_embeddings +from recml.examples.DLRM_HSTU.preprocessors import InputPreprocessor + + +class ContextualInterleavePreprocessor(InputPreprocessor): + """A JAX/Flax implementation of the ContextualInterleavePreprocessor. + + This preprocessor orchestrates content encoding, action encoding, and + contextualization using parameterized MLPs, working on dense, padded tensors. + """ + + input_embedding_dim: int + output_embedding_dim: int + contextual_feature_to_max_length: Dict[str, int] + contextual_feature_to_min_uih_length: Dict[str, int] + content_encoder: ContentEncoder + content_contextualize_mlp_fn: Callable[[], ContextualizedMLP] + action_encoder: ActionEncoder + action_contextualize_mlp_fn: Callable[[], ContextualizedMLP] + pmlp_contextual_dropout_ratio: float = 0.0 + enable_interleaving: bool = False + + def setup(self): + self._max_contextual_seq_len = sum( + self.contextual_feature_to_max_length.values() + ) + + self._content_embedding_mlp = self.content_contextualize_mlp_fn() + self._action_embedding_mlp = self.action_contextualize_mlp_fn() + + if self._max_contextual_seq_len > 0: + self._batched_contextual_linear_weights = self.param( + "batched_contextual_linear_weights", + nn.initializers.xavier_uniform(), + ( + self._max_contextual_seq_len, + self.input_embedding_dim, + self.output_embedding_dim, + ), + ) + self._batched_contextual_linear_bias = self.param( + "batched_contextual_linear_bias", + nn.initializers.zeros, + (self._max_contextual_seq_len, self.output_embedding_dim), + ) + self._pmlp_dropout = nn.Dropout(rate=self.pmlp_contextual_dropout_ratio) + + def __call__( + self, + max_uih_len: int, + seq_embeddings: jnp.ndarray, + seq_mask: jnp.ndarray, + seq_timestamps: jnp.ndarray, + num_targets: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + *, + deterministic: bool, + ) -> Tuple[ + jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray] + ]: + batch_size, max_seq_len, _ = seq_embeddings.shape + + pmlp_contextual_embeddings = None + contextual_embeddings = None + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = get_contextual_input_embeddings( + seq_mask=seq_mask, + seq_payloads=seq_payloads, + contextual_feature_to_max_length=self.contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=self.contextual_feature_to_min_uih_length, + dtype=seq_embeddings.dtype, + ) + + pmlp_contextual_embeddings = self._pmlp_dropout( + contextual_input_embeddings, deterministic=deterministic + ) + + contextual_embeddings = jnp.einsum( + "bci,cio->bco", + contextual_input_embeddings.reshape( + batch_size, self._max_contextual_seq_len, self.input_embedding_dim + ), + self._batched_contextual_linear_weights, + ) + jnp.expand_dims(self._batched_contextual_linear_bias, axis=0) + + # Content Embeddings + content_embeddings = self.content_encoder( + max_uih_len=max_uih_len, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + content_embeddings = self._content_embedding_mlp( + seq_embeddings=content_embeddings, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + # Action 
Embeddings + seq_lengths = jnp.sum(seq_mask, axis=1, dtype=jnp.int32) + indices = jnp.arange(max_seq_len) + start_target_idx = jnp.expand_dims(seq_lengths - num_targets, axis=1) + is_target_mask = (indices >= start_target_idx) & seq_mask + + action_embeddings = self.action_encoder( + seq_payloads=seq_payloads, + is_target_mask=is_target_mask, + ) + action_embeddings = self._action_embedding_mlp( + seq_embeddings=action_embeddings, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + # Combine + output_seq_embeddings = content_embeddings + action_embeddings + output_seq_embeddings *= jnp.expand_dims(seq_mask, axis=-1) + output_mask = seq_mask + output_timestamps = seq_timestamps + + # Prepend contextual embeddings + if self._max_contextual_seq_len > 0: + output_seq_embeddings = jnp.concatenate( + [contextual_embeddings, output_seq_embeddings], axis=1 + ) + contextual_mask = jnp.ones( + (batch_size, self._max_contextual_seq_len), dtype=jnp.bool_ + ) + output_mask = jnp.concatenate([contextual_mask, seq_mask], axis=1) + + contextual_timestamps = jnp.zeros( + (batch_size, self._max_contextual_seq_len), + dtype=seq_timestamps.dtype, + ) + output_timestamps = jnp.concatenate( + [contextual_timestamps, seq_timestamps], axis=1 + ) + + return ( + output_seq_embeddings, + output_mask, + output_timestamps, + num_targets, + seq_payloads, + ) diff --git a/recml/examples/DLRM_HSTU/contextualize_mlps.py b/recml/examples/DLRM_HSTU/contextualize_mlps.py new file mode 100644 index 0000000..d694824 --- /dev/null +++ b/recml/examples/DLRM_HSTU/contextualize_mlps.py @@ -0,0 +1,179 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains Flax modules for contextualized MLPs used in DLRM-HSTU.""" + +from typing import Optional + +from flax import linen as nn +import jax.numpy as jnp + + +class SwishLayerNorm(nn.Module): + """Custom module for Swish(LayerNorm(x)) which is x * sigmoid(LayerNorm(x)). + + This mimics the SwishLayerNorm class in the PyTorch implementation. + """ + + epsilon: float = 1e-5 + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Computes Swish(LayerNorm(x)). + + Args: + x: Input tensor. + + Returns: + The output tensor. + """ + normed_x = nn.LayerNorm(epsilon=self.epsilon, name="layernorm")(x) + return x * nn.sigmoid(normed_x) + + +class ContextualizedMLP(nn.Module): + """Abstract base class for contextualized MLPs. + + JAX/Flax doesn't strictly require this, but it is included for structural + parity with the PyTorch version. + + This module assumes dense inputs, where ragged tensors have been padded. + """ + + def __call__( + self, + seq_embeddings: jnp.ndarray, + contextual_embeddings: Optional[jnp.ndarray], + ) -> jnp.ndarray: + """Forward pass for contextualized MLPs. + + Args: + seq_embeddings: Dense tensor of shape (B, N, D_in). + contextual_embeddings: Dense tensor of shape (B, D_ctx). + + Returns: + Output tensor. 
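+      Typically (B, N, D_out), with D_out fixed by the concrete subclass.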
+ """ + raise NotImplementedError() + + +class SimpleContextualizedMLP(ContextualizedMLP): + """A simple MLP applied to sequential embeddings, ignoring contextual ones. + + This module is analogous to the PyTorch version and works on dense tensors. + """ + + sequential_input_dim: int + sequential_output_dim: int + hidden_dim: int + + @nn.compact + def __call__( + self, + seq_embeddings: jnp.ndarray, + contextual_embeddings: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + """Applies a simple MLP to the sequence embeddings. + + Args: + seq_embeddings: Dense tensor of shape (B, N, sequential_input_dim). + contextual_embeddings: Ignored. + + Returns: + Output tensor of shape (B, N, sequential_output_dim). + """ + x = nn.Dense( + features=self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="mlp_0", + )(seq_embeddings) + x = SwishLayerNorm(name="mlp_1")(x) + + x = nn.Dense( + features=self.sequential_output_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="mlp_2", + )(x) + x = nn.LayerNorm(name="mlp_3")(x) + return x + + +class ParameterizedContextualizedMLP(ContextualizedMLP): + """An MLP whose weights are parameterized by contextual embeddings. + + This module is analogous to the PyTorch version and works on dense tensors. + """ + + contextual_embedding_dim: int + sequential_input_dim: int + sequential_output_dim: int + hidden_dim: int + + @nn.compact + def __call__( + self, + seq_embeddings: jnp.ndarray, + contextual_embeddings: Optional[jnp.ndarray], + ) -> jnp.ndarray: + """Applies a parameterized MLP to the sequence embeddings. + + Args: + seq_embeddings: Dense tensor of shape (B, N, sequential_input_dim). + contextual_embeddings: Dense tensor of shape + (B, contextual_embedding_dim). + + Returns: + Output tensor of shape (B, N, sequential_output_dim). + """ + if contextual_embeddings is None: + raise ValueError( + "contextual_embeddings cannot be None for " + "ParameterizedContextualizedMLP" + ) + + shared_input = nn.Dense( + features=self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="dense_features_compress" + )(contextual_embeddings) + + attn_raw_weights_flat = nn.Dense( + features=self.sequential_input_dim * self.sequential_output_dim, + name="attn_raw_weights_0" + )(shared_input) + + batch_size = contextual_embeddings.shape[0] + attn_weights_unnorm = attn_raw_weights_flat.reshape( + batch_size, self.sequential_input_dim, self.sequential_output_dim + ) + + attn_weights = nn.LayerNorm( + feature_axes=(-2, -1), + name="attn_weights_norm" + )(attn_weights_unnorm) + + res_x = nn.Dense( + features=self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="res_weights_0" + )(shared_input) + res_x = SwishLayerNorm(name="res_weights_1")(res_x) + bias = nn.Dense( + features=self.sequential_output_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="res_weights_2" + )(res_x) + + bmm_out = jnp.matmul(seq_embeddings, attn_weights) + bias_broadcast = jnp.expand_dims(bias, axis=1) + return bmm_out + bias_broadcast diff --git a/recml/examples/DLRM_HSTU/dlrm_hstu.py b/recml/examples/DLRM_HSTU/dlrm_hstu.py new file mode 100644 index 0000000..b750309 --- /dev/null +++ b/recml/examples/DLRM_HSTU/dlrm_hstu.py @@ -0,0 +1,544 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX/Flax implementation of DLRM-HSTU.""" + +from dataclasses import dataclass +from dataclasses import field +from functools import partial +import logging +from typing import Any, Dict, List, Optional, Tuple + +import flax.linen as nn +from flax.linen.initializers import xavier_uniform +from flax.linen.initializers import zeros +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder +from recml.examples.DLRM_HSTU.content_encoder import ContentEncoder +from recml.examples.DLRM_HSTU.contextual_interleave_preprocessor import ContextualInterleavePreprocessor +from recml.examples.DLRM_HSTU.contextualize_mlps import ContextualizedMLP +from recml.examples.DLRM_HSTU.contextualize_mlps import ParameterizedContextualizedMLP +from recml.examples.DLRM_HSTU.contextualize_mlps import SimpleContextualizedMLP +from recml.examples.DLRM_HSTU.hstu_transducer import HSTUTransducer +from recml.examples.DLRM_HSTU.multitask_module import DefaultMultitaskModule +from recml.examples.DLRM_HSTU.multitask_module import MultitaskTaskType +from recml.examples.DLRM_HSTU.multitask_module import TaskConfig +from recml.examples.DLRM_HSTU.positional_encoder import HSTUPositionalEncoder +from recml.examples.DLRM_HSTU.postprocessors import L2NormPostprocessor +from recml.examples.DLRM_HSTU.postprocessors import LayerNormPostprocessor +from recml.examples.DLRM_HSTU.postprocessors import TimestampLayerNormPostprocessor +from recml.examples.DLRM_HSTU.preprocessors import SwishLayerNorm +from recml.examples.DLRM_HSTU.stu import STULayer +from recml.examples.DLRM_HSTU.stu import STULayerConfig +from recml.examples.DLRM_HSTU.stu import STUStack + + +logger = logging.getLogger(__name__) + +Dtype = Any +Array = jnp.ndarray + + +@dataclass +class EmbeddingConfig: + """Simplified embedding config for JAX.""" + + name: str + num_embeddings: int + embedding_dim: int + + +@dataclass +class DlrmHSTUConfig: + """Configuration for DLRM-HSTU model.""" + + max_seq_len: int = 2056 + max_num_candidates: int = 10 + max_num_candidates_inference: int = 5 + hstu_num_heads: int = 1 + hstu_attn_linear_dim: int = 256 + hstu_attn_qk_dim: int = 128 + hstu_attn_num_layers: int = 12 + hstu_embedding_table_dim: int = 192 + hstu_preprocessor_hidden_dim: int = 256 + hstu_transducer_embedding_dim: int = 256 # changed from 0 + hstu_group_norm: bool = False + hstu_input_dropout_ratio: float = 0.2 + hstu_linear_dropout_rate: float = 0.2 + hstu_max_attn_len: int = 0 + contextual_feature_to_max_length: Dict[str, int] = field(default_factory=dict) + contextual_feature_to_min_uih_length: Dict[str, int] = field( + default_factory=dict + ) + additional_content_features: Optional[Dict[str, int]] = None + target_enrich_features: Optional[Dict[str, int]] = None + pmlp_contextual_dropout_ratio: float = 0.0 + candidates_weight_feature_name: str = "" + candidates_watchtime_feature_name: str = "" + candidates_querytime_feature_name: str = "" + watchtime_feature_name: str = "" + causal_multitask_weights: float = 0.2 + multitask_configs: List[TaskConfig] = field(default_factory=list) + user_embedding_feature_names: List[str] = 
field(default_factory=list) + item_embedding_feature_names: List[str] = field(default_factory=list) + uih_post_id_feature_name: str = "" + uih_action_time_feature_name: str = "" + uih_weight_feature_name: str = "" + hstu_uih_feature_names: List[str] = field(default_factory=list) + hstu_candidate_feature_names: List[str] = field(default_factory=list) + merge_uih_candidate_feature_mapping: List[Tuple[str, str]] = field( + default_factory=list + ) + action_weights: Optional[List[int]] = None + watchtime_to_action_thresholds_and_weights: Optional[ + List[Tuple[int, int]] + ] = None + enable_postprocessor: bool = True + use_layer_norm_postprocessor: bool = False + + +def _get_supervision_labels_and_weights( + supervision_bitmasks: jnp.ndarray, + watchtime_sequence: jnp.ndarray, + task_configs: List[TaskConfig], +) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]: + """Computes supervision labels and weights for multitask learning.""" + supervision_labels: Dict[str, jnp.ndarray] = {} + supervision_weights: Dict[str, jnp.ndarray] = {} + for task in task_configs: + if task.task_type == MultitaskTaskType.REGRESSION: + supervision_labels[task.task_name] = watchtime_sequence.astype( + jnp.float32 + ) + elif task.task_type == MultitaskTaskType.BINARY_CLASSIFICATION: + supervision_labels[task.task_name] = ( + jnp.bitwise_and(supervision_bitmasks, task.task_weight) > 0 + ).astype(jnp.float32) + else: + raise RuntimeError("Unsupported MultitaskTaskType") + return supervision_labels, supervision_weights + + +class EmbeddingCollection(nn.Module): + """A module to hold and query multiple embedding tables.""" + + embedding_configs: Dict[str, EmbeddingConfig] + + def setup(self): + self.embeddings = { + name: nn.Embed( + num_embeddings=cfg.num_embeddings, + features=cfg.embedding_dim, + name=name, + ) + for name, cfg in self.embedding_configs.items() + } + + def __call__( + self, features: Dict[str, jnp.ndarray] + ) -> Dict[str, jnp.ndarray]: + """Looks up embeddings for features given as dense ID tensors.""" + return { + name: self.embeddings[name](ids) + for name, ids in features.items() + if name in self.embeddings + } + + +class PredictionMLP(nn.Module): + """MLP for multitask prediction head.""" + + hidden_dim: int + num_tasks: int + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, x: Array) -> Array: + x = nn.Dense( + features=self.hidden_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + x = SwishLayerNorm(dtype=self.dtype)(x) + x = nn.Dense( + features=self.num_tasks, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + return x + + +class ItemMLP(nn.Module): + """MLP for processing item embeddings.""" + + hidden_dim: int + output_dim: int + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, x: Array) -> Array: + x = nn.Dense( + features=self.hidden_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + x = SwishLayerNorm(dtype=self.dtype)(x) + x = nn.Dense( + features=self.output_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + x = nn.LayerNorm(dtype=self.dtype)(x) + return x + + +class DlrmHSTU(nn.Module): + """JAX/Flax implementation of DLRM with HSTU user encoder. + + Operates on dense tensors. 
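+
+  A minimal usage sketch (hypothetical variable names, mirroring the unit
+  test):
+
+    model = DlrmHSTU(hstu_configs=cfg, embedding_tables=tables)
+    variables = model.init(
+        {"params": params_key, "dropout": dropout_key},
+        uih_features, candidate_features, uih_lengths, num_candidates,
+        deterministic=True,
+    )
+    user_emb, item_emb, aux_losses, preds, labels, weights = model.apply(
+        variables, uih_features, candidate_features, uih_lengths,
+        num_candidates, deterministic=True,
+    )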
+ """ + + hstu_configs: DlrmHSTUConfig + embedding_tables: Dict[str, EmbeddingConfig] + dtype: Dtype = jnp.float32 + + def setup(self): + self._embedding_collection = EmbeddingCollection(self.embedding_tables) + self._multitask_configs: List[TaskConfig] = ( + self.hstu_configs.multitask_configs + ) + + self._multitask_module = DefaultMultitaskModule( + task_configs=self._multitask_configs, + embedding_dim=self.hstu_configs.hstu_transducer_embedding_dim, + prediction_fn=lambda in_dim, num_tasks: PredictionMLP( + hidden_dim=512, num_tasks=num_tasks, dtype=self.dtype + ), + causal_multitask_weights=self.hstu_configs.causal_multitask_weights, + ) + + hstu_config = self.hstu_configs + + content_encoder = ContentEncoder( + input_embedding_dim=hstu_config.hstu_embedding_table_dim, + additional_content_features=hstu_config.additional_content_features, + target_enrich_features=hstu_config.target_enrich_features, + ) + + action_encoder = ActionEncoder( + action_embedding_dim=hstu_config.hstu_transducer_embedding_dim, + action_feature_name=hstu_config.uih_weight_feature_name, + action_weights=hstu_config.action_weights, + watchtime_feature_name=hstu_config.watchtime_feature_name, + watchtime_to_action_thresholds_and_weights=hstu_config.watchtime_to_action_thresholds_and_weights, + ) + + contextual_embedding_dim = sum( + hstu_config.contextual_feature_to_max_length.values() + ) * hstu_config.hstu_embedding_table_dim + + def mlp_fn( + sequential_input_dim: int, + ) -> ContextualizedMLP: + if contextual_embedding_dim > 0: + return ParameterizedContextualizedMLP( + contextual_embedding_dim=contextual_embedding_dim, + sequential_input_dim=sequential_input_dim, + sequential_output_dim=hstu_config.hstu_transducer_embedding_dim, + hidden_dim=hstu_config.hstu_preprocessor_hidden_dim, + ) + else: + return SimpleContextualizedMLP( + sequential_input_dim=sequential_input_dim, + sequential_output_dim=hstu_config.hstu_transducer_embedding_dim, + hidden_dim=hstu_config.hstu_preprocessor_hidden_dim, + ) + + preprocessor = ContextualInterleavePreprocessor( + input_embedding_dim=hstu_config.hstu_embedding_table_dim, + output_embedding_dim=hstu_config.hstu_transducer_embedding_dim, + contextual_feature_to_max_length=hstu_config.contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=hstu_config.contextual_feature_to_min_uih_length, + content_encoder=content_encoder, + content_contextualize_mlp_fn=partial( + mlp_fn, sequential_input_dim=content_encoder.output_embedding_dim + ), + action_encoder=action_encoder, + action_contextualize_mlp_fn=partial( + mlp_fn, sequential_input_dim=action_encoder.output_embedding_dim + ), + pmlp_contextual_dropout_ratio=hstu_config.pmlp_contextual_dropout_ratio, + ) + + contextual_seq_len = sum( + hstu_config.contextual_feature_to_max_length.values() + ) + positional_encoder = HSTUPositionalEncoder( + num_position_buckets=8192, + num_time_buckets=2048, + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + contextual_seq_len=contextual_seq_len, + ) + + if hstu_config.enable_postprocessor: + if hstu_config.use_layer_norm_postprocessor: + postproc_cls = partial( + LayerNormPostprocessor, + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + eps=1e-5, + dtype=self.dtype, + ) + else: + postproc_cls = partial( + TimestampLayerNormPostprocessor, + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + time_duration_features=[(60 * 60, 24), (24 * 60 * 60, 7)], + eps=1e-5, + dtype=self.dtype, + ) + else: + postproc_cls = L2NormPostprocessor + + stu_layers = [] + 
for _ in range(hstu_config.hstu_attn_num_layers): + stu_layer_config = STULayerConfig( + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + num_heads=hstu_config.hstu_num_heads, + hidden_dim=hstu_config.hstu_attn_linear_dim, + attention_dim=hstu_config.hstu_attn_qk_dim, + output_dropout_ratio=hstu_config.hstu_linear_dropout_rate, + causal=True, + target_aware=True, + use_group_norm=hstu_config.hstu_group_norm, + contextual_seq_len=contextual_seq_len, + max_attn_len=hstu_config.hstu_max_attn_len, + ) + stu_layers.append(STULayer(config=stu_layer_config)) + stu_module = STUStack( + stu_layers=stu_layers + ) + + self._hstu_transducer = HSTUTransducer( + stu_module=stu_module, + input_preprocessor=preprocessor, + output_postprocessor_cls=postproc_cls, + input_dropout_ratio=hstu_config.hstu_input_dropout_ratio, + positional_encoder=positional_encoder, + return_full_embeddings=False, + listwise=False, + ) + + self._item_embedding_mlp = ItemMLP( + hidden_dim=512, + output_dim=hstu_config.hstu_transducer_embedding_dim, + dtype=self.dtype, + ) + + def _concat_features( + self, uih_tensor: Array, cand_tensor: Array + ) -> Array: + """Concatenates dense UIH and candidate tensors along sequence dim.""" + return jnp.concatenate([uih_tensor, cand_tensor], axis=1) + + def _construct_payload( + self, + uih_features: Dict[str, Array], + cand_features: Dict[str, Array], + uih_embeddings: Dict[str, Array], + cand_embeddings: Dict[str, Array], + ) -> Dict[str, Array]: + """Constructs payload dictionary for HSTUTransducer.""" + payload = {} + for name in self.hstu_configs.contextual_feature_to_max_length: + if name in uih_embeddings: + payload[name] = uih_embeddings[name] + elif name in cand_embeddings: + payload[name] = cand_embeddings[name] + elif name in uih_features and name in self.embedding_tables: + payload[name] = self._embedding_collection({name: uih_features[name]})[ + name + ] + elif name in cand_features and name in self.embedding_tables: + payload[name] = self._embedding_collection({name: cand_features[name]})[ + name + ] + elif name in uih_features: # non-embedding contextual feature + payload[name] = uih_features[name] + elif name in cand_features: + payload[name] = cand_features[name] + + for ( + uih_name, + cand_name, + ) in self.hstu_configs.merge_uih_candidate_feature_mapping: + # Handle non-embedding features that need to be merged. + if uih_name in uih_features and cand_name in cand_features: + if uih_name not in self.embedding_tables and uih_name not in payload: + payload[uih_name] = self._concat_features( + uih_features[uih_name], cand_features[cand_name] + ) + # Handle embedding features that need to be in the payload. + if uih_name in uih_embeddings and cand_name in cand_embeddings: + if uih_name not in payload: + payload[uih_name] = self._concat_features( + uih_embeddings[uih_name], cand_embeddings[cand_name] + ) + + # Handle features that only exist for candidates (for target enrichment) + if self.hstu_configs.target_enrich_features: + for feat_name in self.hstu_configs.target_enrich_features: + if feat_name in cand_embeddings and feat_name not in payload: + payload[feat_name] = cand_embeddings[feat_name] + return payload + + def __call__( + self, + uih_features: Dict[str, Array], + candidate_features: Dict[str, Array], + uih_lengths: Array, + num_candidates: Array, + *, + deterministic: bool, + ) -> Tuple[ + Array, + Array, + Dict[str, Array], + Optional[Array], + Optional[Array], + Optional[Array], + ]: + """Forward pass for DLRM-HSTU. 
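+
+    Embeds the UIH and candidate features, runs the HSTU transducer over the
+    merged sequence, and applies the multitask head at candidate positions.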
+ + Args: + uih_features: Dict of dense UIH feature tensors (B, max_uih_len, ...). + candidate_features: Dict of dense candidate feature tensors + (B, max_cand, ...). + uih_lengths: Length of UIH sequences (B,). + num_candidates: Number of candidates per example (B,). + deterministic: If true, disable dropout. + + Returns: + Tuple of (user_embeddings, item_embeddings, aux_losses, + preds, labels, weights). + """ + max_uih_len = uih_features[ + self.hstu_configs.uih_post_id_feature_name + ].shape[1] + cand_key = self.hstu_configs.item_embedding_feature_names[0] + max_candidates = candidate_features[cand_key].shape[1] + + uih_embeddings = self._embedding_collection(uih_features) + cand_embeddings = self._embedding_collection(candidate_features) + + merged_embeddings: Dict[str, Array] = {} + for ( + uih_name, + cand_name, + ) in self.hstu_configs.merge_uih_candidate_feature_mapping: + if uih_name in uih_embeddings and cand_name in cand_embeddings: + merged_embeddings[uih_name] = self._concat_features( + uih_embeddings[uih_name], cand_embeddings[cand_name] + ) + if self.hstu_configs.uih_post_id_feature_name not in merged_embeddings: + raise ValueError( + "Post ID feature " + f"{self.hstu_configs.uih_post_id_feature_name} not found in " + "merged embeddings." + ) + cand_item_embeddings_for_mlp = jnp.concatenate( + [ + cand_embeddings[k] + for k in self.hstu_configs.item_embedding_feature_names + ], + axis=-1, + ) + item_embeddings_candidates = self._item_embedding_mlp( + cand_item_embeddings_for_mlp + ) + + payload = self._construct_payload( + uih_features, candidate_features, uih_embeddings, cand_embeddings + ) + hstu_seq_lengths = uih_lengths + num_candidates + hstu_seq_embeddings = merged_embeddings[ + self.hstu_configs.uih_post_id_feature_name + ] + candidate_querytime_feature_name = ( + self.hstu_configs.candidates_querytime_feature_name + ) + hstu_seq_timestamps = self._concat_features( + uih_features[self.hstu_configs.uih_action_time_feature_name], + candidate_features[candidate_querytime_feature_name], + ) + + user_embeddings_candidates, _ = self._hstu_transducer( + max_uih_len=max_uih_len, + max_targets=max_candidates, + total_uih_len=0, # Not used in dense tensor implementation + total_targets=0, # Not used in dense tensor implementation + seq_lengths=hstu_seq_lengths, + seq_embeddings=hstu_seq_embeddings, + seq_timestamps=hstu_seq_timestamps, + num_targets=num_candidates, + seq_payloads=payload, + deterministic=deterministic, + ) + + supervision_bitmasks = candidate_features[ + self.hstu_configs.candidates_weight_feature_name + ] + watchtime_sequence = candidate_features[ + self.hstu_configs.candidates_watchtime_feature_name + ] + supervision_labels, supervision_weights = ( + _get_supervision_labels_and_weights( + supervision_bitmasks, + watchtime_sequence, + self._multitask_configs, + ) + ) + + # The HSTU transducer returns embeddings for the full sequence, with + # non-candidate parts masked. We need to slice out the candidate parts + # to match the shape of the item embeddings. 
+ user_embeddings_candidates = user_embeddings_candidates[ + :, -max_candidates:, : + ] + + mt_target_preds, mt_target_labels, mt_target_weights, mt_losses = ( + self._multitask_module( + encoded_user_embeddings=user_embeddings_candidates, + item_embeddings=item_embeddings_candidates, + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + deterministic=deterministic, + ) + ) + + aux_losses: Dict[str, Array] = {} + if not deterministic and mt_losses is not None: + for i, task in enumerate(self._multitask_configs): + aux_losses[task.task_name] = mt_losses[i] + + return ( + user_embeddings_candidates, + item_embeddings_candidates, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) diff --git a/recml/examples/DLRM_HSTU/dlrm_hstu_test.py b/recml/examples/DLRM_HSTU/dlrm_hstu_test.py new file mode 100644 index 0000000..df8ff73 --- /dev/null +++ b/recml/examples/DLRM_HSTU/dlrm_hstu_test.py @@ -0,0 +1,245 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import tree_util +import jax.numpy as jnp +import optax +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTU +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTUConfig +from recml.examples.DLRM_HSTU.dlrm_hstu import EmbeddingConfig +from recml.examples.DLRM_HSTU.multitask_module import MultitaskTaskType +from recml.examples.DLRM_HSTU.multitask_module import TaskConfig + + +class DlrmHstuTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.batch_size = 16 + self.max_uih_len = 4 + self.max_candidates = 2 + self.embed_dim = 32 + self.hstu_dim = 128 + self.post_id_vocab = 20 + self.cat_feat_vocab = 10 + + self.config = DlrmHSTUConfig( + max_seq_len=self.max_uih_len + self.max_candidates, + hstu_embedding_table_dim=self.embed_dim, + hstu_transducer_embedding_dim=self.hstu_dim, + hstu_preprocessor_hidden_dim=8, + hstu_attn_num_layers=1, + hstu_attn_linear_dim=8, + hstu_attn_qk_dim=8, + item_embedding_feature_names=['post_id', 'cat_feat'], + user_embedding_feature_names=['post_id'], + uih_post_id_feature_name='post_id', + uih_action_time_feature_name='action_time', + candidates_querytime_feature_name='query_time', + candidates_weight_feature_name='weight', + candidates_watchtime_feature_name='watch_time', + uih_weight_feature_name='action_type', + action_weights=[1, 2, 4], + merge_uih_candidate_feature_mapping=[ + ('post_id', 'post_id'), + ('cat_feat', 'cat_feat'), + ('action_type', 'action_type'), + ], + multitask_configs=[ + TaskConfig('CTR', 1, MultitaskTaskType.BINARY_CLASSIFICATION), + TaskConfig('WT', 2, MultitaskTaskType.REGRESSION), + ], + contextual_feature_to_max_length={}, + additional_content_features={'cat_feat': self.embed_dim}, + target_enrich_features={'cat_feat': self.embed_dim}, + ) + self.tables = { + 'post_id': EmbeddingConfig( + 'post_id', self.post_id_vocab, self.embed_dim + ), + 'cat_feat': 
EmbeddingConfig( + 'cat_feat', self.cat_feat_vocab, self.embed_dim + ), + } + + def _get_mock_data(self, key): + k1, k2, k3, k4, k5, k6, k7, k8, k9 = jax.random.split(key, 9) + uih_features = { + 'post_id': jax.random.randint( + k1, (self.batch_size, self.max_uih_len), 0, self.post_id_vocab + ), + 'cat_feat': jax.random.randint( + k2, (self.batch_size, self.max_uih_len), 0, self.cat_feat_vocab + ), + 'action_time': jax.random.randint( + k3, (self.batch_size, self.max_uih_len), 0, 1000 + ), + 'action_type': jax.random.randint( + k7, (self.batch_size, self.max_uih_len), 0, 8 + ), + } + candidate_features = { + 'post_id': jax.random.randint( + k1, (self.batch_size, self.max_candidates), 0, self.post_id_vocab + ), + 'cat_feat': jax.random.randint( + k2, (self.batch_size, self.max_candidates), 0, self.cat_feat_vocab + ), + 'query_time': jax.random.randint( + k4, (self.batch_size, self.max_candidates), 1000, 2000 + ), + 'weight': jax.random.randint( + k5, (self.batch_size, self.max_candidates), 0, 2 + ), # for CTR bitmask + 'watch_time': jax.random.randint( + k6, (self.batch_size, self.max_candidates), 0, 100 + ), # for WT regression + 'action_type': jax.random.randint( + k7, (self.batch_size, self.max_candidates), 0, 8 + ), + } + uih_lengths = jax.random.randint( + k8, (self.batch_size,), 1, self.max_uih_len + 1 + ).astype(jnp.int32) + num_candidates = jax.random.randint( + k9, (self.batch_size,), 1, self.max_candidates + 1 + ).astype(jnp.int32) + return uih_features, candidate_features, uih_lengths, num_candidates + + @parameterized.named_parameters( + ('train', False), + ('eval', True), + ) + def test_forward_pass(self, deterministic): + key = jax.random.PRNGKey(0) + prng_keys = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, embedding_tables=self.tables) + uih_features, candidate_features, uih_lengths, num_candidates = ( + self._get_mock_data(key) + ) + + variables = model.init( + {'params': prng_keys[0], 'dropout': prng_keys[1]}, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + ) + + user_emb, item_emb, aux_losses, preds, labels, weights = model.apply( + variables, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + rngs={'dropout': prng_keys[2]} if not deterministic else None, + ) + + num_tasks = len(self.config.multitask_configs) + expected_user_emb_shape = ( + self.batch_size, + self.max_candidates, + self.hstu_dim, + ) + self.assertEqual(user_emb.shape, expected_user_emb_shape) + expected_item_emb_shape = ( + self.batch_size, + self.max_candidates, + self.hstu_dim, + ) + self.assertEqual(item_emb.shape, expected_item_emb_shape) + self.assertEqual( + preds.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + if not deterministic: + self.assertNotEmpty(aux_losses) + self.assertEqual( + labels.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + self.assertEqual( + weights.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + else: + self.assertEmpty(aux_losses) + self.assertIsNone(labels) + self.assertIsNone(weights) + + def test_backward_pass_and_training(self): + key = jax.random.PRNGKey(1) + init_key, data_key, train_key = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, embedding_tables=self.tables) + uih_features, candidate_features, uih_lengths, num_candidates = ( + self._get_mock_data(data_key) + ) + + variables = model.init( + {'params': init_key, 'dropout': train_key}, + uih_features, + 
candidate_features, + uih_lengths, + num_candidates, + deterministic=False, + ) + params = variables['params'] + + logging.info( + 'Model parameter shapes: %s', + tree_util.tree_map(lambda x: x.shape, params), + ) + logging.info('Model parameters: %s', params) + + optimizer = optax.adam(learning_rate=1e-3) + opt_state = optimizer.init(params) + + def loss_fn(params, dropout_key): + user_emb, item_emb, aux_losses, preds, labels, weights = model.apply( + {'params': params}, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=False, + rngs={'dropout': dropout_key}, + ) + return ( + user_emb.sum() + + item_emb.sum() + + preds.sum() + + sum(val.sum() for val in aux_losses.values()) + ) + + @jax.jit + def train_step(params, opt_state, dropout_key): + loss, grads = jax.value_and_grad(loss_fn)(params, dropout_key) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss + + logging.info('Starting training loop...') + for i in range(10): + step_key, train_key = jax.random.split(train_key) + params, opt_state, loss = train_step(params, opt_state, step_key) + logging.info('Step %d, Loss: %f', i, loss) + + self.assertIsNotNone(params) + + +if __name__ == '__main__': + absltest.main() diff --git a/recml/examples/DLRM_HSTU/hstu_transducer.py b/recml/examples/DLRM_HSTU/hstu_transducer.py new file mode 100644 index 0000000..63f2d51 --- /dev/null +++ b/recml/examples/DLRM_HSTU/hstu_transducer.py @@ -0,0 +1,234 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX/Flax implementation of HSTUTransducer for dense tensors.""" + +import logging +from typing import Dict, Optional, Tuple, Type + +import flax.linen as nn +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.positional_encoder import HSTUPositionalEncoder +from recml.examples.DLRM_HSTU.postprocessors import L2NormPostprocessor +from recml.examples.DLRM_HSTU.postprocessors import OutputPostprocessor +from recml.examples.DLRM_HSTU.preprocessors import InputPreprocessor +from recml.examples.DLRM_HSTU.stu import STUStack + + +logger = logging.getLogger(__name__) + + +class HSTUTransducer(nn.Module): + """JAX/Flax implementation of the HSTU Transducer module, using dense tensors. + + This implementation mirrors structure but replaces jagged tensor operations + with dense tensor operations using masking. 
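+
+  The dense pipeline is: input preprocessing (contextual, content, and action
+  encoding), optional positional encoding, input dropout, the STU stack, and
+  output postprocessing. A boolean `seq_mask` derived from `seq_lengths`
+  stands in for jagged offsets throughout.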
+ """ + + stu_module: STUStack + input_preprocessor: InputPreprocessor + output_postprocessor_cls: Type[OutputPostprocessor] = L2NormPostprocessor + input_dropout_ratio: float = 0.0 + positional_encoder: Optional[HSTUPositionalEncoder] = None + return_full_embeddings: bool = False + listwise: bool = False + + def setup(self): + self._output_postprocessor: OutputPostprocessor = ( + self.output_postprocessor_cls() + ) + self._input_dropout = nn.Dropout(rate=self.input_dropout_ratio) + + def _preprocess( + self, + max_uih_len: int, + seq_embeddings: jnp.ndarray, + seq_mask: jnp.ndarray, + seq_timestamps: jnp.ndarray, + num_targets: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + is_training: bool, + ) -> Tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + Dict[str, jnp.ndarray], + ]: + """Preprocesses the input sequence embeddings.""" + ( + output_seq_embeddings, + output_seq_mask, + output_seq_timestamps, + output_num_targets, + output_seq_payloads, + ) = self.input_preprocessor( + max_uih_len=max_uih_len, + seq_embeddings=seq_embeddings, + seq_mask=seq_mask, + seq_timestamps=seq_timestamps, + num_targets=num_targets, + seq_payloads=seq_payloads, + deterministic=not is_training, + ) + + output_seq_lengths = jnp.sum(output_seq_mask, axis=1, dtype=jnp.int32) + + if self.positional_encoder is not None: + output_seq_embeddings = self.positional_encoder( + max_seq_len=output_seq_embeddings.shape[1], + seq_lengths=output_seq_lengths, + seq_timestamps=output_seq_timestamps, + seq_embeddings=output_seq_embeddings, + num_targets=( + None if self.listwise and is_training else output_num_targets + ), + ) + + output_seq_embeddings = self._input_dropout( + output_seq_embeddings, deterministic=not is_training + ) + + return ( + output_seq_embeddings, + output_seq_mask, + output_seq_timestamps, + output_num_targets, + output_seq_payloads, + ) + + def _hstu_compute( + self, + seq_embeddings: jnp.ndarray, + num_targets: jnp.ndarray, + is_training: bool, + ) -> jnp.ndarray: + """Computes the HSTU embeddings.""" + seq_embeddings = self.stu_module( + x=seq_embeddings, + num_targets=None if self.listwise and is_training else num_targets, + deterministic=not is_training, + ) + return seq_embeddings + + def _postprocess( + self, + seq_embeddings: jnp.ndarray, + seq_mask: jnp.ndarray, + seq_timestamps: jnp.ndarray, + num_targets: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + ) -> Tuple[Optional[jnp.ndarray], jnp.ndarray]: + """Postprocesses the output sequence embeddings.""" + if self.return_full_embeddings: + seq_embeddings = self._output_postprocessor( + seq_embeddings=seq_embeddings, + seq_timestamps=seq_timestamps, + seq_payloads=seq_payloads, + ) + + batch_size, max_seq_len, embedding_dim = seq_embeddings.shape + seq_lengths = jnp.sum(seq_mask, axis=1, dtype=jnp.int32) + indices = jnp.arange(max_seq_len) + start_target_idx = seq_lengths - num_targets + candidate_mask = (indices >= start_target_idx[:, jnp.newaxis]) & ( + indices < seq_lengths[:, jnp.newaxis] + ) + + candidate_embeddings_masked = ( + seq_embeddings * candidate_mask[..., jnp.newaxis] + ) + candidate_timestamps_masked = seq_timestamps * candidate_mask + + if self.input_preprocessor.interleave_targets(): + raise NotImplementedError( + "Interleaved targets not supported in dense post-processing yet." 
+ )
+
+ if not self.return_full_embeddings:
+ candidate_embeddings = self._output_postprocessor(
+ seq_embeddings=candidate_embeddings_masked,
+ seq_timestamps=candidate_timestamps_masked,
+ seq_payloads=seq_payloads,
+ )
+ candidate_embeddings = (
+ candidate_embeddings * candidate_mask[..., jnp.newaxis]
+ )
+ else:
+ candidate_embeddings = candidate_embeddings_masked
+
+ return (
+ seq_embeddings if self.return_full_embeddings else None,
+ candidate_embeddings,
+ )
+
+ def __call__(
+ self,
+ max_uih_len: int,
+ max_targets: int,
+ total_uih_len: int,
+ total_targets: int,
+ seq_lengths: jnp.ndarray,
+ seq_embeddings: jnp.ndarray,
+ seq_timestamps: jnp.ndarray,
+ num_targets: jnp.ndarray,
+ seq_payloads: Dict[str, jnp.ndarray],
+ *,
+ deterministic: Optional[bool] = None,
+ ) -> Tuple[
+ jnp.ndarray,
+ Optional[jnp.ndarray],
+ ]:
+ """Forward pass for HSTUTransducer."""
+ # Treat `deterministic=None` as training mode (dropout enabled).
+ is_training = True if deterministic is None else not deterministic
+
+ _, max_len, _ = seq_embeddings.shape
+ seq_mask = jnp.arange(max_len)[None, :] < seq_lengths[:, None]
+
+ (
+ processed_seq_embeddings,
+ processed_seq_mask,
+ processed_seq_timestamps,
+ processed_num_targets,
+ processed_seq_payloads,
+ ) = self._preprocess(
+ max_uih_len=max_uih_len,
+ seq_embeddings=seq_embeddings,
+ seq_mask=seq_mask,
+ seq_timestamps=seq_timestamps,
+ num_targets=num_targets,
+ seq_payloads=seq_payloads,
+ is_training=is_training,
+ )
+
+ encoded_embeddings = self._hstu_compute(
+ seq_embeddings=processed_seq_embeddings,
+ num_targets=processed_num_targets,
+ is_training=is_training,
+ )
+
+ encoded_embeddings = (
+ encoded_embeddings * processed_seq_mask[..., jnp.newaxis]
+ )
+
+ full_embeddings, candidate_embeddings = self._postprocess(
+ seq_embeddings=encoded_embeddings,
+ seq_mask=processed_seq_mask,
+ seq_timestamps=processed_seq_timestamps,
+ num_targets=processed_num_targets,
+ seq_payloads=processed_seq_payloads,
+ )
+
+ return candidate_embeddings, full_embeddings
diff --git a/recml/examples/DLRM_HSTU/movielens_dataloader.py b/recml/examples/DLRM_HSTU/movielens_dataloader.py
new file mode 100644
index 0000000..414da4d
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/movielens_dataloader.py
@@ -0,0 +1,178 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
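[Review note] The dense `_postprocess` above replaces jagged-tensor slicing with a boolean candidate mask over padded sequences. A minimal, self-contained sketch of that mask construction (the toy lengths and target counts below are assumptions, not values from the module):

```python
import jax.numpy as jnp

# Two sequences padded to length 6, with true lengths [4, 6] and
# [2, 3] trailing candidate (target) positions respectively.
seq_mask = jnp.array([[1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1, 1]])
num_targets = jnp.array([2, 3])

seq_lengths = jnp.sum(seq_mask, axis=1)        # [4, 6]
indices = jnp.arange(seq_mask.shape[1])        # [0, 1, 2, 3, 4, 5]
start_target_idx = seq_lengths - num_targets   # [2, 3]
candidate_mask = (indices >= start_target_idx[:, None]) & (
    indices < seq_lengths[:, None]
)
# Row 0 selects positions 2-3 and row 1 selects positions 3-5; multiplying
# embeddings by candidate_mask[..., None] zeroes every other position.
print(candidate_mask.astype(jnp.int32))
```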
+"""Dataloader for MovieLens dataset using jax_recommenders.""" + +import jax.numpy as jnp +import pandas as pd + +USER_ID = 'user_id' +ITEM_ID = 'item_id' +TIMESTAMP = 'timestamp' +USER_RATING = 'user_rating' + + +class MovieLensDataLoader: + """Dataloader for MovieLens dataset.""" + + def __init__( + self, + batch_size, + max_uih_len, + max_candidates, + raw_df: pd.DataFrame, + ): + self.batch_size = batch_size + self.max_uih_len = max_uih_len + self.max_candidates = max_candidates + + if raw_df is None: + raise ValueError('raw_df must be provided') + self.raw_df = raw_df + self._create_vocabs() + self.processed_data = self._preprocess_data() + + def _create_vocabs(self): + """Creates vocabularies from the raw dataframe.""" + self.user_vocab = sorted(self.raw_df[USER_ID].unique()) + self.movie_vocab = sorted(self.raw_df[ITEM_ID].unique()) + # Movielens ratings are 0.5 to 5.0. We can map them to 0-9 + self.rating_vocab = sorted(self.raw_df[USER_RATING].unique()) + + self.user_map = {name: i for i, name in enumerate(self.user_vocab)} + self.movie_map = {name: i for i, name in enumerate(self.movie_vocab)} + self.rating_map = {name: i for i, name in enumerate(self.rating_vocab)} + + self.user_vocab_size = len(self.user_vocab) + self.movie_vocab_size = len(self.movie_vocab) + self.rating_vocab_size = len(self.rating_vocab) + self.genre_vocab_size = 1 # Genre not directly used in this simple version + + def _pad_seq(self, seq, max_len, pad_value=0): + """Pads a sequence to max_len.""" + if len(seq) > max_len: + return seq[:max_len] + return seq + [pad_value] * (max_len - len(seq)) + + def _preprocess_data(self): + """Preprocesses the raw data into batches of UIH and candidates.""" + df = self.raw_df.copy() + df[USER_ID] = df[USER_ID].map(self.user_map) + df[ITEM_ID] = df[ITEM_ID].map(self.movie_map) + df[USER_RATING] = df[USER_RATING].map(self.rating_map) + + df = df.sort_values(by=[USER_ID, TIMESTAMP]) + grouped = df.groupby(USER_ID) + + batched_data = [] + current_batch = [] + + for user_id, user_df in grouped: + history = user_df[:-self.max_candidates] + candidates = user_df[-self.max_candidates:] + + if len(history) < 1 or len(candidates) < 1: + continue + + uih_len = min(len(history), self.max_uih_len) + num_cands = len(candidates) + + uih_features = { + 'user_id': self._pad_seq( + [user_id] * uih_len, self.max_uih_len, pad_value=-1 + ), + 'movie_id': self._pad_seq( + history[ITEM_ID].tolist(), self.max_uih_len + ), + 'rating': self._pad_seq( + history[USER_RATING].tolist(), self.max_uih_len + ), + 'action_time': self._pad_seq( + history[TIMESTAMP].tolist(), self.max_uih_len + ), + 'uih_weight': self._pad_seq([1] * uih_len, self.max_uih_len, 0), + 'uih_watch_time': self._pad_seq( + history[USER_RATING].tolist(), self.max_uih_len, 0 + ), + } + + candidate_features = { + 'movie_id': self._pad_seq( + candidates[ITEM_ID].tolist(), self.max_candidates + ), + 'query_time': self._pad_seq( + candidates[TIMESTAMP].tolist(), self.max_candidates + ), + # candidates_weight is used as a mask for valid candidates in the loss + # calculation. + 'candidates_weight': self._pad_seq( + [1] * num_cands, self.max_candidates, 0 + ), + # candidates_watch_time carries the true rating values for the + # candidate items, which are used as labels for the regression task + # in the MultitaskModule. 
+ 'candidates_watch_time': self._pad_seq(
+ candidates[USER_RATING].tolist(), self.max_candidates, 0
+ ),
+ }
+
+ current_batch.append({
+ 'uih_features': uih_features,
+ 'candidate_features': candidate_features,
+ 'uih_lengths': uih_len,
+ 'num_candidates': num_cands,
+ })
+
+ if len(current_batch) == self.batch_size:
+ batched_data.append(self._collate_batch(current_batch))
+ current_batch = []
+
+ # Drop any trailing partial batch so that every batch in the test data
+ # has a static shape of exactly `batch_size` examples.
+ return batched_data
+
+ def _collate_batch(self, batch):
+ """Collates a list of samples into a single batch of JAX arrays."""
+ collated = {}
+ if not batch:
+ return collated
+
+ keys = batch[0].keys()
+
+ for key in keys:
+ example_value = batch[0][key]
+ if isinstance(example_value, dict):
+ collated[key] = {}
+ sub_keys = example_value.keys()
+ for sub_key in sub_keys:
+ collated[key][sub_key] = jnp.array(
+ [sample[key][sub_key] for sample in batch]
+ )
+ elif isinstance(example_value, int):
+ collated[key] = jnp.array([sample[key] for sample in batch])
+ else:
+ # Other value types are not produced by _preprocess_data and are
+ # silently ignored.
+ pass
+ return collated
+
+ def get_batch(self, idx):
+ """Returns a single batch by index."""
+ if idx >= len(self.processed_data):
+ raise IndexError('Batch index out of range')
+ return self.processed_data[idx]
+
+ def __len__(self):
+ return len(self.processed_data)
diff --git a/recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py b/recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py
new file mode 100644
index 0000000..2da3a39
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py
@@ -0,0 +1,233 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
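[Review note] A minimal usage sketch of `MovieLensDataLoader`; the toy DataFrame below is hypothetical and only needs each user to have more than `max_candidates` interactions so that both the history and the candidate slice are non-empty:

```python
import pandas as pd

from recml.examples.DLRM_HSTU.movielens_dataloader import MovieLensDataLoader

# Two users with five interactions each (illustrative values).
raw_df = pd.DataFrame({
    'user_id': ['u1'] * 5 + ['u2'] * 5,
    'item_id': ['m1', 'm2', 'm3', 'm4', 'm5'] * 2,
    'user_rating': [3.0, 4.0, 5.0, 2.0, 1.0] * 2,
    'timestamp': list(range(0, 5000, 1000)) * 2,
})

loader = MovieLensDataLoader(
    batch_size=2, max_uih_len=8, max_candidates=2, raw_df=raw_df
)
assert len(loader) == 1  # one full batch; trailing partial batches are dropped

batch = loader.get_batch(0)
print(batch['uih_features']['movie_id'].shape)        # (2, 8), zero-padded
print(batch['candidate_features']['movie_id'].shape)  # (2, 2)
print(batch['uih_lengths'])                           # [3 3]
```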
+"""Tests for DLRM with MovieLens dataset.""" + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import jax +import numpy as np +import optax +import pandas as pd +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTU +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTUConfig +from recml.examples.DLRM_HSTU.dlrm_hstu import EmbeddingConfig +from recml.examples.DLRM_HSTU.movielens_dataloader import MovieLensDataLoader +from recml.examples.DLRM_HSTU.multitask_module import MultitaskTaskType +from recml.examples.DLRM_HSTU.multitask_module import TaskConfig + + +USER_ID = 'user_id' +ITEM_ID = 'item_id' +TIMESTAMP = 'timestamp' +USER_RATING = 'user_rating' + + +def create_dummy_movielens_df(num_users, num_items, num_events): + user_ids = np.random.randint(0, num_users, num_events) + item_ids = np.random.randint(0, num_items, num_events) + ratings = np.random.uniform(0.5, 5.0, num_events).round(1) + timestamps = np.arange(num_events) * 1000 # Increasing timestamps + df = pd.DataFrame({ + USER_ID: [f'user_{u}' for u in user_ids], + ITEM_ID: [f'item_{i}' for i in item_ids], + USER_RATING: ratings, + TIMESTAMP: timestamps, + }) + return df + + +class DlrmMovielensTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.batch_size = 16 + self.max_uih_len = 16 + self.max_candidates = 4 + self.embed_dim = 32 + self.hstu_dim = 64 + + dummy_df = create_dummy_movielens_df( + num_users=50, num_items=100, num_events=500 + ) + + self.dataloader = MovieLensDataLoader( + self.batch_size, + self.max_uih_len, + self.max_candidates, + raw_df=dummy_df, + ) + + self.user_vocab = self.dataloader.user_vocab_size + self.movie_vocab = self.dataloader.movie_vocab_size + self.rating_vocab = self.dataloader.rating_vocab_size + self.genre_vocab = self.dataloader.genre_vocab_size + + self.config = DlrmHSTUConfig( + max_seq_len=self.max_uih_len + self.max_candidates, + hstu_embedding_table_dim=self.embed_dim, + hstu_transducer_embedding_dim=self.hstu_dim, + hstu_preprocessor_hidden_dim=16, + hstu_attn_num_layers=2, + hstu_attn_linear_dim=16, + hstu_attn_qk_dim=16, + item_embedding_feature_names=['movie_id'], + user_embedding_feature_names=['user_id'], + uih_post_id_feature_name='movie_id', + uih_action_time_feature_name='action_time', + candidates_querytime_feature_name='query_time', + candidates_weight_feature_name='candidates_weight', + candidates_watchtime_feature_name='candidates_watch_time', + uih_weight_feature_name='uih_weight', + action_weights=[1], + merge_uih_candidate_feature_mapping=[ + ('movie_id', 'movie_id'), + ('rating', 'rating'), + ('action_time', 'query_time'), + ('user_id', 'user_id'), + ('uih_weight', 'candidates_weight'), + ('uih_watch_time', 'candidates_watch_time'), + ], + multitask_configs=[ + TaskConfig('RatingPrediction', 1, MultitaskTaskType.REGRESSION), + ], + contextual_feature_to_max_length={}, + additional_content_features={}, + target_enrich_features={}, + ) + self.tables = { + 'user_id': EmbeddingConfig('user_id', self.user_vocab, self.embed_dim), + 'movie_id': EmbeddingConfig( + 'movie_id', self.movie_vocab, self.embed_dim + ), + 'rating': EmbeddingConfig('rating', self.rating_vocab, self.embed_dim), + } + + @parameterized.named_parameters( + ('train', False), + ('eval', True), + ) + def test_forward_pass(self, deterministic): + key = jax.random.PRNGKey(0) + prng_keys = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, embedding_tables=self.tables) + + if not self.dataloader: + self.skipTest( + 
'No batches were created, potentially too few users or max_candidates' + ' too high for debug data.' + ) + + batch = self.dataloader.get_batch(0) + uih_features = batch['uih_features'] + candidate_features = batch['candidate_features'] + uih_lengths = batch['uih_lengths'] + num_candidates = batch['num_candidates'] + + variables = model.init( + {'params': prng_keys[0], 'dropout': prng_keys[1]}, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + ) + + user_emb, item_emb, aux_losses, preds, labels, weights = model.apply( + variables, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + rngs={'dropout': prng_keys[2]} if not deterministic else None, + ) + + num_tasks = len(self.config.multitask_configs) + self.assertEqual( + user_emb.shape, (self.batch_size, self.max_candidates, self.hstu_dim) + ) + self.assertEqual( + item_emb.shape, (self.batch_size, self.max_candidates, self.hstu_dim) + ) + self.assertEqual( + preds.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + if not deterministic: + self.assertNotEmpty(aux_losses) + self.assertEqual( + labels.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + self.assertEqual( + weights.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + def test_backward_pass(self): + key = jax.random.PRNGKey(1) + init_key, data_key, train_key = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, embedding_tables=self.tables) + + if not self.dataloader: + self.skipTest( + 'No batches were created, potentially too few users or max_candidates' + ' too high for debug data.' + ) + + batch = self.dataloader.get_batch(0) + uih_features = batch['uih_features'] + candidate_features = batch['candidate_features'] + uih_lengths = batch['uih_lengths'] + num_candidates = batch['num_candidates'] + + variables = model.init( + {'params': init_key, 'dropout': train_key}, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=False, + ) + params = variables['params'] + + optimizer = optax.adam(learning_rate=1e-3) + opt_state = optimizer.init(params) + + def loss_fn(params, dropout_key): + _, _, aux_losses, _, _, _ = model.apply( + {'params': params}, + uih_features, + candidate_features, + uih_lengths, + num_candidates, + deterministic=False, + rngs={'dropout': dropout_key}, + ) + return sum(val.sum() for val in aux_losses.values()) + + @jax.jit + def train_step(params, opt_state, dropout_key): + loss, grads = jax.value_and_grad(loss_fn)(params, dropout_key) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss + + step_key, train_key = jax.random.split(train_key) + params, opt_state, loss = train_step(params, opt_state, step_key) + logging.info('MovieLens Test Loss: %f', loss) + self.assertIsNotNone(params) + + +if __name__ == '__main__': + absltest.main() diff --git a/recml/examples/DLRM_HSTU/multitask_module.py b/recml/examples/DLRM_HSTU/multitask_module.py new file mode 100644 index 0000000..510aeb6 --- /dev/null +++ b/recml/examples/DLRM_HSTU/multitask_module.py @@ -0,0 +1,273 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains modules and functions for handling multitask predictions and losses."""
+
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import Callable, Dict, List, Optional, Tuple
+
+from flax import linen as nn
+import jax.numpy as jnp
+import numpy as np
+import optax
+
+
+# These data classes are pure Python and can be used directly.
+class MultitaskTaskType(IntEnum):
+ BINARY_CLASSIFICATION = 0
+ REGRESSION = 1
+
+
+@dataclass
+class TaskConfig:
+ task_name: str
+ task_weight: int
+ task_type: MultitaskTaskType
+
+
+class MultitaskModule(nn.Module):
+ """Abstract base class for multitask modules in Flax."""
+
+ def __call__(
+ self,
+ encoded_user_embeddings: jnp.ndarray,
+ item_embeddings: jnp.ndarray,
+ supervision_labels: Dict[str, jnp.ndarray],
+ supervision_weights: Dict[str, jnp.ndarray],
+ deterministic: bool,
+ ) -> Tuple[
+ jnp.ndarray,
+ Optional[jnp.ndarray],
+ Optional[jnp.ndarray],
+ Optional[jnp.ndarray],
+ ]:
+ """Computes multi-task predictions.
+
+ Args:
+ encoded_user_embeddings: (B, N, D) float array.
+ item_embeddings: (B, N, D) float array.
+ supervision_labels: Dictionary of (B, N) float or int arrays.
+ supervision_weights: Dictionary of (B, N) float or int arrays.
+ deterministic: If True, losses are not computed (inference mode).
+
+ Returns:
+ A tuple of (predictions, labels, weights, losses).
+ Predictions are of shape (num_tasks, B, N).
+ """
+ raise NotImplementedError
+
+
+def _compute_pred_and_logits(
+ prediction_module: nn.Module,
+ encoded_user_embeddings: jnp.ndarray,
+ item_embeddings: jnp.ndarray,
+ task_offsets: List[int],
+ has_multiple_task_types: bool,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+ """Computes predictions and raw logits from user and item embeddings."""
+ # Logits are computed by applying the prediction module to the
+ # element-wise product.
+ # Input shape: (B, N, D), Output shape: (B, N, num_tasks)
+ mt_logits_untransposed = prediction_module(
+ encoded_user_embeddings * item_embeddings
+ )
+ # Transpose to (num_tasks, B, N) to match the PyTorch logic.
+ mt_logits = jnp.transpose(mt_logits_untransposed, (2, 0, 1))
+
+ mt_preds_list: List[jnp.ndarray] = []
+ for task_type in MultitaskTaskType:
+ start_offset, end_offset = (
+ task_offsets[task_type],
+ task_offsets[task_type + 1],
+ )
+ if end_offset > start_offset:
+ task_logits = mt_logits[start_offset:end_offset, ...]
+ if task_type == MultitaskTaskType.REGRESSION:
+ # For regression, predictions are the raw logits.
+ mt_preds_list.append(task_logits)
+ else:
+ # For classification, predictions are the sigmoid of the logits.
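+ # Worked example (illustrative): with one BINARY_CLASSIFICATION task
+ # followed by two REGRESSION tasks, task_offsets == [0, 1, 3], so this
+ # branch slices rows [0:1] of mt_logits while the REGRESSION branch
+ # above slices rows [1:3].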
+ mt_preds_list.append(nn.sigmoid(task_logits)) + + mt_preds = ( + jnp.concatenate(mt_preds_list, axis=0) + if has_multiple_task_types + else mt_preds_list[0] + ) + + return mt_preds, mt_logits + + +def _compute_labels_and_weights( + supervision_labels: Dict[str, jnp.ndarray], + supervision_weights: Dict[str, jnp.ndarray], + task_configs: List[TaskConfig], +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Aggregates label and weight tensors from input dictionaries.""" + # Get a sample tensor to determine shape and dtype for the default weight. + first_label = next(iter(supervision_labels.values())) + default_supervision_weight = jnp.ones_like(first_label) + + mt_labels_list: List[jnp.ndarray] = [] + mt_weights_list: List[jnp.ndarray] = [] + for task in task_configs: + mt_labels_list.append(supervision_labels[task.task_name]) + mt_weights_list.append( + supervision_weights.get(task.task_name, default_supervision_weight) + ) + + # Stack along a new 'task' dimension. + mt_labels = jnp.stack(mt_labels_list, axis=0) + mt_weights = jnp.stack(mt_weights_list, axis=0) + + return mt_labels, mt_weights + + +def _compute_loss( + task_offsets: List[int], + causal_multitask_weights: float, + mt_logits: jnp.ndarray, + mt_labels: jnp.ndarray, + mt_weights: jnp.ndarray, + has_multiple_task_types: bool, +) -> jnp.ndarray: + """Computes the final loss across all tasks.""" + mt_losses_list: List[jnp.ndarray] = [] + for task_type in MultitaskTaskType: + start_offset, end_offset = ( + task_offsets[task_type], + task_offsets[task_type + 1], + ) + if end_offset > start_offset: + task_logits = mt_logits[start_offset:end_offset, ...] + task_labels = mt_labels[start_offset:end_offset, ...] + task_weights = mt_weights[start_offset:end_offset, ...] + + if task_type == MultitaskTaskType.REGRESSION: + # Equivalent to mse_loss with reduction='none'. + task_losses = (task_logits - task_labels) ** 2 + else: + # Equivalent to binary_cross_entropy_with_logits with reduction='none'. + task_losses = optax.sigmoid_binary_cross_entropy( + task_logits, task_labels + ) + + # Apply task-specific weights. + mt_losses_list.append(task_losses * task_weights) + + mt_losses = ( + jnp.concatenate(mt_losses_list, axis=0) + if has_multiple_task_types + else mt_losses_list[0] + ) + + # Normalize loss per task by the sum of weights for that task. + # Sum over the item dimension (axis=-1). + sum_losses = mt_losses.sum(axis=-1) + sum_weights = mt_weights.sum(axis=-1) + + # Clamp sum_weights to avoid division by zero for empty examples. + normalized_losses = sum_losses / jnp.maximum(sum_weights, 1.0) + + # Apply a global weight for this entire multitask head. + return normalized_losses * causal_multitask_weights + + +class DefaultMultitaskModule(MultitaskModule): + """ + JAX/Flax implementation of the default multitask module. + + Attributes: + task_configs: A list of TaskConfig objects, which must be pre-sorted + by task_type. + embedding_dim: The dimensionality of the input embeddings. + prediction_fn: A function that returns a Flax module for predictions, + e.g., a simple MLP. It takes embedding_dim and num_tasks as input. + causal_multitask_weights: A global weight for the final computed loss. + """ + task_configs: List[TaskConfig] + embedding_dim: int + prediction_fn: Callable[[int, int], nn.Module] + causal_multitask_weights: float + + def setup(self): + if not self.task_configs: + raise ValueError("task_configs must be non-empty.") + + # Check if tasks are sorted by type, as required by the original logic. 
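+ # For example, [BINARY_CLASSIFICATION, REGRESSION] is accepted while
+ # [REGRESSION, BINARY_CLASSIFICATION] raises below, since the
+ # offset-based slicing in _compute_pred_and_logits and _compute_loss
+ # assumes one contiguous block of rows per task type.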
+ is_sorted = all( + self.task_configs[i].task_type <= self.task_configs[i + 1].task_type + for i in range(len(self.task_configs) - 1) + ) + if not is_sorted: + raise ValueError("task_configs must be sorted by task_type.") + + # Calculate offsets for slicing tensors based on task type. + task_offsets_list = [0] * (len(MultitaskTaskType) + 1) + for task in self.task_configs: + task_offsets_list[task.task_type + 1] += 1 + + self._has_multiple_task_types: bool = ( + task_offsets_list.count(0) < len(MultitaskTaskType) + ) + self._task_offsets: List[int] = np.cumsum(task_offsets_list).tolist() + + # Instantiate the prediction module. + self._prediction_module = self.prediction_fn( + self.embedding_dim, len(self.task_configs) + ) + + def __call__( + self, + encoded_user_embeddings: jnp.ndarray, + item_embeddings: jnp.ndarray, + supervision_labels: Dict[str, jnp.ndarray], + supervision_weights: Dict[str, jnp.ndarray], + deterministic: bool, + ) -> Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], + Optional[jnp.ndarray], + ]: + + mt_preds, mt_logits = _compute_pred_and_logits( + prediction_module=self._prediction_module, + encoded_user_embeddings=encoded_user_embeddings, + item_embeddings=item_embeddings, + task_offsets=self._task_offsets, + has_multiple_task_types=self._has_multiple_task_types, + ) + + mt_labels: Optional[jnp.ndarray] = None + mt_weights: Optional[jnp.ndarray] = None + mt_losses: Optional[jnp.ndarray] = None + + if not deterministic: + mt_labels, mt_weights = _compute_labels_and_weights( + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + task_configs=self.task_configs, + ) + mt_losses = _compute_loss( + task_offsets=self._task_offsets, + causal_multitask_weights=self.causal_multitask_weights, + mt_logits=mt_logits, + mt_labels=mt_labels, + mt_weights=mt_weights, + has_multiple_task_types=self._has_multiple_task_types, + ) + + return mt_preds, mt_labels, mt_weights, mt_losses diff --git a/recml/examples/DLRM_HSTU/positional_encoder.py b/recml/examples/DLRM_HSTU/positional_encoder.py new file mode 100644 index 0000000..34c8f99 --- /dev/null +++ b/recml/examples/DLRM_HSTU/positional_encoder.py @@ -0,0 +1,242 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of positional and timestamp encoding for sequences.""" + +from math import sqrt +from typing import Optional + +from flax import linen as nn +import jax +import jax.numpy as jnp + + +def _get_col_indices( + max_seq_len: int, + max_contextual_seq_len: int, + max_pos_ind: int, + seq_lengths: jnp.ndarray, + num_targets: Optional[jnp.ndarray], + interleave_targets: bool, +) -> jnp.ndarray: + """Calculates the positional indices for each element in the sequence. + + JAX translation of `_get_col_indices` from pt_position.py. + + Args: + max_seq_len: The maximum sequence length. + max_contextual_seq_len: The maximum length of the contextual prefix. + max_pos_ind: The maximum positional index. 
seq_lengths: A 1D tensor of shape (batch_size,) with the true length of
+ each sequence.
+ num_targets: An optional 1D tensor of shape (batch_size,) indicating the
+ number of target items at the end of each sequence.
+ interleave_targets: A boolean indicating whether to interleave targets.
+
+ Returns:
+ A 2D tensor of shape (batch_size, max_seq_len) containing the positional
+ indices for each element in the sequence.
+ """
+ batch_size = seq_lengths.shape[0]
+ col_indices = jnp.tile(
+ jnp.arange(max_seq_len, dtype=jnp.int32), (batch_size, 1)
+ )
+
+ if num_targets is not None:
+ if interleave_targets:
+ high_inds = seq_lengths - num_targets * 2
+ else:
+ high_inds = seq_lengths - num_targets
+
+ col_indices = jnp.minimum(col_indices, high_inds[:, jnp.newaxis])
+ col_indices = high_inds[:, jnp.newaxis] - col_indices
+ else:
+ col_indices = seq_lengths[:, jnp.newaxis] - col_indices
+
+ col_indices = col_indices + max_contextual_seq_len
+ # Positional bounds: the a_min/a_max keyword arguments were removed from
+ # jnp.clip in recent JAX releases.
+ col_indices = jnp.clip(col_indices, 0, max_pos_ind - 1)
+
+ if max_contextual_seq_len > 0:
+ contextual_indices = jnp.arange(max_contextual_seq_len, dtype=jnp.int32)[
+ jnp.newaxis, :
+ ]
+ col_indices = col_indices.at[:, :max_contextual_seq_len].set(
+ contextual_indices
+ )
+
+ return col_indices
+
+
+def add_timestamp_positional_embeddings(
+ seq_embeddings: jnp.ndarray,
+ pos_embeddings: jnp.ndarray,
+ ts_embeddings: jnp.ndarray,
+ timestamps: jnp.ndarray,
+ max_seq_len: int,
+ max_contextual_seq_len: int,
+ seq_lengths: jnp.ndarray,
+ num_targets: Optional[jnp.ndarray],
+ interleave_targets: bool,
+ time_bucket_fn: str,
+) -> jnp.ndarray:
+ """Adds timestamp and positional embeddings to sequence embeddings.
+
+ JAX translation of `pytorch_add_timestamp_positional_embeddings`. Assumes
+ inputs are padded dense tensors.
+
+ Args:
+ seq_embeddings: A 3D padded tensor of shape (batch_size, max_seq_len,
+ embedding_dim) containing the input item embeddings.
+ pos_embeddings: The learned positional embedding weights.
+ ts_embeddings: The learned timestamp embedding weights.
+ timestamps: A 2D padded tensor of shape (batch_size, max_seq_len)
+ containing timestamps for each item.
+ max_seq_len: The maximum sequence length for padding.
+ max_contextual_seq_len: The maximum length of the contextual prefix.
+ seq_lengths: A 1D tensor of shape (batch_size,) with the true length of
+ each sequence.
+ num_targets: An optional 1D tensor of shape (batch_size,) indicating the
+ number of target items at the end of each sequence.
+ interleave_targets: A boolean indicating whether to interleave targets.
+ time_bucket_fn: The function to use for time bucketing ("log" or "sqrt").
+
+ Returns:
+ A 3D tensor of the same shape as `seq_embeddings` with positional
+ and time embeddings added.
+ """ + # Position encoding + max_pos_ind = pos_embeddings.shape[0] + pos_inds = _get_col_indices( + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + max_pos_ind=max_pos_ind, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + ) + position_embeddings = pos_embeddings[pos_inds] + + # Timestamp encoding + batch_size = seq_lengths.shape[0] + num_time_buckets = ts_embeddings.shape[0] - 1 + time_bucket_increments = 60.0 + time_bucket_divisor = 1.0 + time_delta = 0 + + # Get the last valid timestamp from each padded sequence for query_time + query_indices = jnp.maximum(0, seq_lengths - 1) + query_time = timestamps[jnp.arange(batch_size), query_indices][:, jnp.newaxis] + + ts = query_time - timestamps + ts = ts + time_delta + ts = jnp.maximum(ts, 1e-6) / time_bucket_increments + + if time_bucket_fn == "log": + ts = jnp.log(ts) + elif time_bucket_fn == "sqrt": + ts = jnp.sqrt(ts) + else: + raise ValueError(f"Unsupported time_bucket_fn: {time_bucket_fn}") + + ts = (ts / time_bucket_divisor).clip(min=0).astype(jnp.int32) + ts = jnp.clip(ts, a_min=0, a_max=num_time_buckets) + + time_embeddings = ts_embeddings[ts] + + # Combine embeddings + added_embeddings = (position_embeddings + time_embeddings).astype( + seq_embeddings.dtype + ) + + # The original op implies addition to only the valid (non-padded) parts. + # In a dense representation, this is equivalent to masking the added + # embeddings. + mask = ( + jnp.arange(max_seq_len, dtype=jnp.int32)[jnp.newaxis, :] + < seq_lengths[:, jnp.newaxis] + ) + masked_added_embeddings = added_embeddings * mask[..., jnp.newaxis] + + return seq_embeddings + masked_added_embeddings + + +class HSTUPositionalEncoder(nn.Module): + """JAX implementation of HSTUPositionalEncoder. + + This module computes and adds positional and timestamp-based embeddings + to a sequence of input embeddings. + + Attributes: + num_position_buckets: The total number of position buckets. + num_time_buckets: The total number of time buckets. + embedding_dim: The dimensionality of the embeddings. + contextual_seq_len: The length of the contextual prefix in sequences. + """ + + num_position_buckets: int + num_time_buckets: int + embedding_dim: int + contextual_seq_len: int + + @nn.compact + def __call__( + self, + max_seq_len: int, + seq_lengths: jnp.ndarray, + seq_timestamps: jnp.ndarray, + seq_embeddings: jnp.ndarray, + num_targets: Optional[jnp.ndarray], + ) -> jnp.ndarray: + """Adds positional and timestamp embeddings to the input sequence embeddings. + + Args: + max_seq_len: The maximum sequence length for padding. + seq_lengths: A 1D tensor of shape (batch_size,) with the true length of + each sequence. + seq_timestamps: A 2D padded tensor of shape (batch_size, max_seq_len) + containing timestamps for each item. + seq_embeddings: A 3D padded tensor of shape (batch_size, max_seq_len, + embedding_dim) containing the input item embeddings. + num_targets: An optional 1D tensor of shape (batch_size,) indicating the + number of target items at the end of each sequence. + + Returns: + A 3D tensor of the same shape as `seq_embeddings` with positional + and time embeddings added. 
+ """ + position_embeddings_weight = self.param( + "_position_embeddings_weight", + nn.initializers.uniform(scale=sqrt(1.0 / self.num_position_buckets)), + (self.num_position_buckets, self.embedding_dim), + ) + timestamp_embeddings_weight = self.param( + "_timestamp_embeddings_weight", + nn.initializers.uniform(scale=sqrt(1.0 / self.num_time_buckets)), + (self.num_time_buckets + 1, self.embedding_dim), + ) + + scaled_seq_embeddings = seq_embeddings * sqrt(self.embedding_dim) + + final_embeddings = add_timestamp_positional_embeddings( + seq_embeddings=scaled_seq_embeddings, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=seq_timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=self.contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=False, + time_bucket_fn="sqrt", + ) + return final_embeddings diff --git a/recml/examples/DLRM_HSTU/postprocessors.py b/recml/examples/DLRM_HSTU/postprocessors.py new file mode 100644 index 0000000..59696e0 --- /dev/null +++ b/recml/examples/DLRM_HSTU/postprocessors.py @@ -0,0 +1,171 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Postprocessors for user embeddings after HSTU layers.""" + +import math +from typing import Any, Dict, List, Tuple + +import flax.linen as nn +from flax.linen.initializers import xavier_uniform +from flax.linen.initializers import zeros +import jax.numpy as jnp + + +Array = jnp.ndarray +Dtype = Any + + +class OutputPostprocessor(nn.Module): + """An abstract class for post-processing user embeddings after HSTU layers.""" + + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + """Processes the final sequence embeddings. + + Args: + seq_embeddings: (B, N, D) or (L, D) final embeddings from the model. + seq_timestamps: (B, N) or (L,) corresponding timestamps. + seq_payloads: A dictionary of other features. + + Returns: + The post-processed sequence embeddings. 
+ """ + raise NotImplementedError + + +class L2NormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with L2 normalization.""" + epsilon: float = 1e-6 + + @nn.compact + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + norm = jnp.linalg.norm(seq_embeddings, ord=2, axis=-1, keepdims=True) + # Prevent division by zero + safe_norm = jnp.maximum(norm, self.epsilon) + return seq_embeddings / safe_norm + + +class LayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with LayerNorm.""" + embedding_dim: int + eps: float = 1e-5 + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + ln = nn.LayerNorm(epsilon=self.eps, dtype=self.dtype) + return ln(seq_embeddings) + + +class TimestampLayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with a timestamp-based MLP and LayerNorm.""" + embedding_dim: int + time_duration_features: List[Tuple[int, int]] + eps: float = 1e-5 + dtype: Dtype = jnp.float32 + + def setup(self): + self._layer_norm = nn.LayerNorm(epsilon=self.eps, dtype=self.dtype) + + num_time_features = len(self.time_duration_features) + combiner_input_dim = self.embedding_dim + 2 * num_time_features + + self._time_feature_combiner = nn.Dense( + features=self.embedding_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + ) + + # Store time feature constants directly. No need for buffers in Flax. + self._period_units = jnp.array( + [f[0] for f in self.time_duration_features], dtype=self.dtype + ) + self._units_per_period = jnp.array( + [f[1] for f in self.time_duration_features], dtype=self.dtype + ) + + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + """Processes sequence embeddings with timestamp features and LayerNorm. + + Creates circular time features, concatenates them to the embeddings, + processes through an MLP, and applies LayerNorm. + + Args: + seq_embeddings: (B, N, D) or (L, D) final embeddings from the model. + seq_timestamps: (B, N) or (L,) corresponding timestamps. + seq_payloads: A dictionary of other features. + + Returns: + The post-processed sequence embeddings. + """ + + # 1. Create circular time features from timestamps. + # Ensure timestamps have a feature dimension for broadcasting. + if seq_timestamps.ndim != seq_embeddings.ndim: + timestamps = jnp.expand_dims(seq_timestamps, axis=-1) + else: + timestamps = seq_timestamps + + # Ensure correct broadcast shape for time constants. + # Original shape: (num_features,) -> (1, ..., 1, num_features) + broadcast_shape = (1,) * (timestamps.ndim - 1) + (-1,) + period_units = self._period_units.reshape(broadcast_shape) + units_per_period = self._units_per_period.reshape(broadcast_shape) + + # Calculate the phase angle for the circular representation. + units_since_epoch = jnp.floor(timestamps / period_units) + remainder = jnp.remainder(units_since_epoch, units_per_period) + angle = (remainder / units_per_period) * 2 * math.pi + + # Create sin/cos features. Cast to float32 for precision if needed. 
+ original_dtype = angle.dtype + if original_dtype != jnp.float32: + angle = angle.astype(jnp.float32) + + cos_features = jnp.cos(angle) + sin_features = jnp.sin(angle) + + time_features = jnp.stack([cos_features, sin_features], axis=-1) + + # New shape will have a final dimension of num_time_features * 2 + final_shape = seq_embeddings.shape[:-1] + (-1,) + time_features = time_features.reshape(final_shape).astype(original_dtype) + # 2. Concatenate with sequence embeddings. + combined_embeddings = jnp.concatenate( + [seq_embeddings, time_features], axis=-1 + ) + # 3. Process through the MLP and LayerNorm. + user_embeddings = self._time_feature_combiner(combined_embeddings) + final_embeddings = self._layer_norm(user_embeddings) + return final_embeddings diff --git a/recml/examples/DLRM_HSTU/preprocessors.py b/recml/examples/DLRM_HSTU/preprocessors.py new file mode 100644 index 0000000..17e5f15 --- /dev/null +++ b/recml/examples/DLRM_HSTU/preprocessors.py @@ -0,0 +1,131 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/usr/bin/env python3 + +"""Input preprocessors for HSTU models.""" + +from typing import Any, Dict, Tuple + +import flax.linen as nn +from flax.linen.initializers import xavier_uniform +from flax.linen.initializers import zeros +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder + + +Array = jnp.ndarray +Dtype = Any + + +class SwishLayerNorm(nn.Module): + """JAX/Flax implementation of SwishLayerNorm. + + Corresponds to generative_recommenders/ops/layer_norm.py -> SwishLayerNorm + The PyTorch implementation is: x * sigmoid(layer_norm(x)) + """ + + epsilon: float = 1e-5 + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, x: Array) -> Array: + """Applies swish layer normalization to the input.""" + ln = nn.LayerNorm( + epsilon=self.epsilon, + use_bias=True, + use_scale=True, + dtype=self.dtype, + ) + normed_x = ln(x) + return x * nn.sigmoid(normed_x) + + +class InputPreprocessor(nn.Module): + """An abstract class for pre-processing sequence embeddings before HSTU layers.""" + + def __call__( + self, + max_uih_len: int, + seq_embeddings: Array, + seq_mask: Array, + seq_timestamps: Array, + num_targets: Array, + seq_payloads: Dict[str, Array], + *, + deterministic: bool, + ) -> Tuple[Array, Array, Array, Array, Dict[str, Array]]: + """Processes input sequences and their features. + + Args: + max_uih_len: Maximum length of the user item history. + seq_embeddings: (B, N, D) Padded sequence embeddings. + seq_mask: (B, N) Boolean mask for seq_embeddings. + seq_timestamps: (B, N) Padded timestamps. + num_targets: (B,) Number of targets for each sequence. + seq_payloads: Dict of other features, also as padded tensors with + masks. + deterministic: Controls dropout behavior. + + Returns: + A tuple containing the processed ( + output_embeddings, + output_mask, + output_timestamps, + output_num_targets, + output_payloads + ). 
+ """ + raise NotImplementedError + + def interleave_targets(self) -> bool: + return False + + +def get_contextual_input_embeddings( + seq_mask: Array, + seq_payloads: Dict[str, Array], + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + dtype: Dtype, +) -> Array: + """Constructs the input for contextual embeddings from dense tensors. + + Args: + seq_mask: Boolean mask for the sequence. + seq_payloads: Dictionary of all feature tensors. + contextual_feature_to_max_length: Maps feature names to their max length. + contextual_feature_to_min_uih_length: Maps features to a min uih length + for them to be active. + dtype: Data type for the output. + + Returns: + A dense tensor of shape (batch_size, sum_of_dims). + """ + padded_values = [] + seq_lengths = jnp.sum(seq_mask, axis=1, dtype=jnp.int32) + + for key, max_len in contextual_feature_to_max_length.items(): + # Assuming the payload is already a dense tensor of shape (B, L, D) + v = seq_payloads[key].astype(dtype) + + min_uih_length = contextual_feature_to_min_uih_length.get(key, 0) + if min_uih_length > 0: + # Create a mask to zero out embeddings for sequences that are too short + mask = (seq_lengths >= min_uih_length).reshape(-1, 1, 1) + v *= mask + + # Flatten the feature dimension + padded_values.append(v.reshape(v.shape[0], -1)) + + return jnp.concatenate(padded_values, axis=1) diff --git a/recml/examples/DLRM_HSTU/stu.py b/recml/examples/DLRM_HSTU/stu.py new file mode 100644 index 0000000..8935c78 --- /dev/null +++ b/recml/examples/DLRM_HSTU/stu.py @@ -0,0 +1,282 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Self-Targeting Unit (STU) module. + +This module implements the STU layer and a stack of STU +layers. The STU layer is designed to capture long-range dependencies in +sequential data by incorporating self-attention mechanisms with a gating +mechanism. +""" +import dataclasses +from typing import Optional, Sequence + +from flax import linen as nn +import jax.numpy as jnp + +dataclass = dataclasses.dataclass + + +@dataclass +class STULayerConfig: + """Configuration for the STU layer. + + Attributes: + embedding_dim: Input embedding dimension. + num_heads: Number of attention heads. + hidden_dim: Hidden dimension of the STU layer. + attention_dim: Dimension of the attention projections. + output_dropout_ratio: Dropout ratio for the output. + causal: Whether to use causal attention. + target_aware: Whether to use target-aware attention. + max_attn_len: Maximum attention length. + attn_alpha: Scaling factor for attention scores. + use_group_norm: Whether to use group normalization. + recompute_normed_x: Whether to recompute normalized input. + recompute_uvqk: Whether to recompute u, v, q, k projections. + recompute_y: Whether to recompute the output. + sort_by_length: Whether to sort sequences by length. + contextual_seq_len: Contextual sequence length. + min_full_attn_seq_len: Minimum sequence length to apply full attention. 
norm_epsilon: Epsilon value for normalization.
+ deterministic: Whether to apply dropout in deterministic mode.
+ """
+ embedding_dim: int
+ num_heads: int
+ hidden_dim: int
+ attention_dim: int
+ output_dropout_ratio: float = 0.0
+ causal: bool = True
+ target_aware: bool = True
+ max_attn_len: Optional[int] = None
+ attn_alpha: Optional[float] = None
+ use_group_norm: bool = False
+ recompute_normed_x: bool = True
+ recompute_uvqk: bool = True
+ recompute_y: bool = True
+ sort_by_length: bool = True
+ contextual_seq_len: int = 0
+ norm_epsilon: float = 1e-6
+ min_full_attn_seq_len: int = 0
+ deterministic: bool = True
+
+
+class STULayer(nn.Module):
+ """Sequential Transduction Unit (STU) layer.
+
+ Attributes:
+ config: STULayerConfig, configuration of the STU layer.
+ """
+
+ config: STULayerConfig
+
+ def setup(self):
+ self.num_heads: int = self.config.num_heads
+ self.embedding_dim: int = self.config.embedding_dim
+ self.hidden_dim: int = self.config.hidden_dim
+ self.attention_dim: int = self.config.attention_dim
+ self.output_dropout_ratio: float = self.config.output_dropout_ratio
+ self.target_aware: bool = self.config.target_aware
+ self.causal: bool = self.config.causal
+ self.max_attn_len: int = self.config.max_attn_len or 0
+ self.attn_alpha: float = self.config.attn_alpha or 1.0 / (
+ self.attention_dim**0.5
+ )
+ self.use_group_norm: bool = self.config.use_group_norm
+ self.norm_epsilon: float = self.config.norm_epsilon
+ self.contextual_seq_len: int = self.config.contextual_seq_len
+ self.min_full_attn_seq_len: int = self.config.min_full_attn_seq_len
+
+ self.uvqk_weight = self.param(
+ '_uvqk_weight',
+ nn.initializers.xavier_normal(),
+ (
+ self.embedding_dim,
+ (self.hidden_dim * 2 + self.attention_dim * 2) * self.num_heads,
+ ),
+ )
+ self.uvqk_beta = self.param(
+ '_uvqk_beta',
+ nn.initializers.zeros,
+ (self.hidden_dim * 2 + self.attention_dim * 2) * self.num_heads,
+ )
+
+ self.output_weight = self.param(
+ '_output_weight',
+ nn.initializers.xavier_uniform(),
+ (self.hidden_dim * self.num_heads * 3, self.embedding_dim),
+ )
+
+ self.dropout_layer = nn.Dropout(rate=self.output_dropout_ratio)
+ self.silu_layer = nn.silu
+ self.group_norm_layer = nn.GroupNorm(
+ num_groups=self.num_heads,
+ use_scale=True,
+ use_bias=True,
+ epsilon=self.norm_epsilon,
+ )
+ self.input_norm_layer = nn.LayerNorm(
+ use_scale=True, use_bias=True, epsilon=self.norm_epsilon
+ )
+ self.output_norm_layer = nn.LayerNorm(
+ use_scale=True, use_bias=True, epsilon=self.norm_epsilon
+ )
+
+ def _get_valid_attn_mask(self, x, num_targets: Optional[jnp.ndarray]):
+ batch_size, seq_len, _ = x.shape
+ seq_lengths = jnp.full((batch_size,), seq_len, dtype=jnp.int32)
+ ids = jnp.arange(seq_len)[None, :]
+ max_ids = seq_lengths[:, None, None]
+
+ if self.contextual_seq_len > 0:
+ ids = ids - self.contextual_seq_len + 1
+ ids = jnp.maximum(ids, 0)
+ max_ids = max_ids - self.contextual_seq_len + 1
+
+ if num_targets is not None:
+ max_ids = (max_ids - num_targets[:, None, None]).squeeze(axis=-1)
+ ids = jnp.minimum(ids, max_ids)
+ row_ids = ids[:, :, None]
+ col_ids = ids[:, None, :]
+ else:
+ row_ids_base = jnp.arange(seq_len)[None, :, None]
+ col_ids_base = jnp.arange(seq_len)[None, None, :]
+ row_ids = jnp.broadcast_to(row_ids_base, (1, seq_len, seq_len))
+ col_ids = jnp.broadcast_to(col_ids_base, (1, seq_len, seq_len))
+
+ row_col_dist = row_ids - col_ids
+ valid_attn_mask = jnp.eye(seq_len, dtype=jnp.bool_)[None, :, :]
+ if not self.causal:
+ row_col_dist = jnp.abs(row_col_dist)
+ valid_attn_mask = 
jnp.logical_or(valid_attn_mask, row_col_dist > 0) + if self.max_attn_len > 0: + if self.min_full_attn_seq_len > 0: + valid_attn_mask = jnp.logical_and( + valid_attn_mask, + jnp.logical_or( + row_col_dist <= self.max_attn_len, + row_ids >= max_ids - self.min_full_attn_seq_len, + ), + ) + else: + valid_attn_mask = jnp.logical_and( + valid_attn_mask, row_col_dist <= self.max_attn_len + ) + if self.contextual_seq_len > 0: + valid_attn_mask = jnp.logical_or( + valid_attn_mask, jnp.logical_and(row_ids == 0, col_ids < max_ids) + ) + + return valid_attn_mask + + def hstu_compute_output(self, attn, u, x, deterministic: bool): + """Computes the output of the STU layer with corrected logic.""" + if self.use_group_norm: + norm_input = attn.reshape( + x.shape[0], x.shape[1], self.num_heads, self.hidden_dim + ) + normed_attn = self.group_norm_layer(norm_input).reshape( + x.shape[0], x.shape[1], -1 + ) + else: + normed_attn = self.output_norm_layer(attn) + + gated_attn = u * normed_attn + proj_input = jnp.concatenate([u, attn, gated_attn], axis=-1) + projected_output = proj_input @ self.output_weight + dropped_out = self.dropout_layer( + projected_output, deterministic=deterministic + ) + return x + dropped_out + + def hstu_preprocess_and_attention( + self, + x: jnp.ndarray, + num_targets: Optional[jnp.ndarray], + deterministic: bool, + ): + """Replicated STU preprocess and attention.""" + normed_x = self.input_norm_layer(x) + uvqk = normed_x @ self.uvqk_weight + self.uvqk_beta + u_proj, v_proj, q_proj, k_proj = jnp.split( + uvqk, + [ + self.hidden_dim * self.num_heads, + self.hidden_dim * self.num_heads * 2, + self.hidden_dim * self.num_heads * 2 + + self.attention_dim * self.num_heads, + ], + axis=-1, + ) + + u = self.silu_layer(u_proj) + batch_size, seq_len, _ = x.shape + q = q_proj.reshape( + batch_size, seq_len, self.num_heads, self.attention_dim + ).transpose(0, 2, 1, 3) + k = k_proj.reshape( + batch_size, seq_len, self.num_heads, self.attention_dim + ).transpose(0, 2, 1, 3) + v = v_proj.reshape( + batch_size, seq_len, self.num_heads, self.hidden_dim + ).transpose(0, 2, 1, 3) + + qk_attn = jnp.einsum('bhqd,bhkd->bhqk', q, k) * self.attn_alpha + qk_attn = self.silu_layer(qk_attn) / seq_len + valid_attn_mask = self._get_valid_attn_mask(x, num_targets) + qk_attn = qk_attn * valid_attn_mask[:, None, :, :].astype(jnp.float32) + qk_attn = self.dropout_layer(qk_attn, deterministic=deterministic) + attn_dense = jnp.einsum('bhqv,bhvd->bhqd', qk_attn, v) + attn_output = attn_dense.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len, -1 + ) + return u, attn_output, k_proj, v_proj + + def __call__( + self, + x: jnp.ndarray, + num_targets: Optional[jnp.ndarray] = None, + deterministic: bool = True, + ): + """Computes the STU layer.""" + actual_num_targets = num_targets if self.target_aware else None + u, attn_output, _, _ = self.hstu_preprocess_and_attention( + x, actual_num_targets, deterministic=deterministic + ) + final_output = self.hstu_compute_output( + attn=attn_output, u=u, x=x, deterministic=deterministic + ) + return final_output + + +class STUStack(nn.Module): + """STU stack. + + This module creates a stack of STU layers. + + Attributes: + stu_layers: A sequence of STU layers. 
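+
+ Example (illustrative; `cfg` is an STULayerConfig and `x` is shaped
+ (batch, seq_len, cfg.embedding_dim))::
+
+ layers = [STULayer(config=cfg, name=f'stu_layer_{i}') for i in range(2)]
+ stack = STUStack(stu_layers=layers)
+ variables = stack.init(jax.random.PRNGKey(0), x)
+ y = stack.apply(variables, x) # output has the same shape as x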
+ """ + + stu_layers: Sequence[STULayer] + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + num_targets: Optional[jnp.ndarray] = None, + deterministic: bool = True, + ): + for stu in self.stu_layers: + x = stu(x, num_targets, deterministic) + return x diff --git a/recml/examples/DLRM_HSTU/stu_test.py b/recml/examples/DLRM_HSTU/stu_test.py new file mode 100644 index 0000000..a0f28aa --- /dev/null +++ b/recml/examples/DLRM_HSTU/stu_test.py @@ -0,0 +1,283 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl import logging +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from recml.examples.DLRM_HSTU.stu import STULayer +from recml.examples.DLRM_HSTU.stu import STULayerConfig +from recml.examples.DLRM_HSTU.stu import STUStack + + +def get_test_configs(): + """Generates a list of test configurations.""" + test_params = [] + test_params.append(( + "basic_config", + { + "num_layers": 2, + "num_heads": 2, + "batch_size": 4, + "max_len": 32, + "embedding_dim": 16, + "attention_dim": 8, + "hidden_dim": 24, + "use_group_norm": False, + "target_aware": True, + }, + )) + test_params.append(( + "group_norm", + { + "num_layers": 1, + "num_heads": 4, + "batch_size": 2, + "max_len": 16, + "embedding_dim": 32, + "attention_dim": 16, + "hidden_dim": 20, + "use_group_norm": True, + "target_aware": True, + }, + )) + test_params.append(( + "not_target_aware", + { + "num_layers": 1, + "num_heads": 1, + "batch_size": 8, + "max_len": 64, + "embedding_dim": 8, + "attention_dim": 4, + "hidden_dim": 12, + "use_group_norm": False, + "target_aware": False, + }, + )) + test_params.append(( + "sliding_window_attention", + { + "num_layers": 1, + "num_heads": 2, + "batch_size": 2, + "max_len": 20, + "embedding_dim": 16, + "attention_dim": 8, + "hidden_dim": 16, + "use_group_norm": False, + "target_aware": True, + "max_attn_len": 5, + }, + )) + return test_params + + +class StuJaxTest(parameterized.TestCase): + """Unit tests for the JAX STU implementation.""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + logging.info("Available devices: %s", jax.devices()) + # Assert that TPUs are available + assert any( + d.platform == "tpu" for d in jax.devices() + ), "No TPU devices found." 
+
+  def setUp(self):
+    """Sets up a base key and device mesh for all tests."""
+    super().setUp()
+    self.key = jax.random.PRNGKey(42)
+    self.devices = jax.devices()
+    self.num_devices = len(self.devices)
+    self.mesh = Mesh(np.array(self.devices), ("data",))
+    logging.info("Using device mesh: %s", self.mesh)
+
+    self.batch_sharding = NamedSharding(self.mesh, PartitionSpec("data"))
+    self.replicated_sharding = NamedSharding(self.mesh, PartitionSpec())
+
+  @parameterized.named_parameters(get_test_configs())
+  def test_output_shape_and_gradients(self, config_dict):
+    """Tests STUStack for output shape and valid gradients.
+
+    This test verifies that the STUStack runs, produces the correct output
+    shape, and that gradients can be computed without errors (e.g., NaNs).
+
+    Args:
+      config_dict: A dictionary containing the configuration parameters for the
+        STUStack.
+    """
+    self.assertEqual(jax.devices()[0].platform, "tpu")
+
+    config = STULayerConfig(
+        embedding_dim=config_dict["embedding_dim"],
+        num_heads=config_dict["num_heads"],
+        hidden_dim=config_dict["hidden_dim"],
+        attention_dim=config_dict["attention_dim"],
+        target_aware=config_dict["target_aware"],
+        use_group_norm=config_dict["use_group_norm"],
+        max_attn_len=config_dict.get("max_attn_len", 0),
+    )
+
+    stu_layers = [
+        STULayer(config=config, name=f"stu_layer_{i}")
+        for i in range(config_dict["num_layers"])
+    ]
+    model = STUStack(stu_layers=stu_layers)
+
+    batch_size, max_len = config_dict["batch_size"], config_dict["max_len"]
+
+    if batch_size % self.num_devices != 0:
+      batch_size = (batch_size // self.num_devices + 1) * self.num_devices
+      logging.warning("Adjusted batch size to %d for sharding", batch_size)
+
+    # Split into independent keys so no PRNG key is reused across draws.
+    init_key, data_key, targets_key, dropout_key = jax.random.split(
+        self.key, 4
+    )
+
+    dummy_x = jax.random.normal(
+        data_key, (batch_size, max_len, config.embedding_dim)
+    )
+    dummy_num_targets = jax.random.randint(
+        targets_key, (batch_size,), minval=1, maxval=5
+    )
+
+    dummy_x = jax.device_put(dummy_x, self.batch_sharding)
+    dummy_num_targets = jax.device_put(dummy_num_targets, self.batch_sharding)
+
+    params = model.init(
+        {"params": init_key, "dropout": dropout_key},
+        x=dummy_x,
+        num_targets=dummy_num_targets,
+    )["params"]
+    params = jax.device_put(params, self.replicated_sharding)
+
+    @jax.jit
+    def loss_fn(p, x, num_targets, rng_key):
+      y = model.apply(
+          {"params": p},
+          x,
+          num_targets=num_targets,
+          deterministic=False,  # Enable dropout so the rng is actually used.
+          rngs={"dropout": rng_key},
+      )
+      return jnp.sum(y**2)
+
+    # Jitted apply function
+    apply_fn = jax.jit(
+        lambda p, x, num_targets: model.apply(
+            {"params": p}, x, num_targets=num_targets
+        ),
+        out_shardings=self.batch_sharding,
+    )
+
+    output = apply_fn(params, dummy_x, dummy_num_targets)
+    self.assertEqual(output.shape, dummy_x.shape)
+    self.assertEqual(output.sharding, self.batch_sharding)
+
+    grads = jax.grad(loss_fn)(params, dummy_x, dummy_num_targets, dropout_key)
+
+    grad_leaves, _ = jax.tree_util.tree_flatten(grads)
+    self.assertNotEmpty(grad_leaves)
+    for g in grad_leaves:
+      self.assertFalse(jnp.any(jnp.isnan(g)), "Found NaNs in gradients")
+      self.assertFalse(jnp.all(g == 0), "Found all-zero gradients")
+      self.assertEqual(g.sharding, self.replicated_sharding)
+
+  def test_target_invariance(self):
+    """Tests output equivariance under swaps within the target section.
+
+    This test checks that swapping items within the target section of the
+    input sequences results in an equivalently swapped output.
+ """ + self.assertEqual(jax.devices()[0].platform, "tpu") + + batch_size, max_len, embedding_dim = 4, 32, 16 + # Adjust batch size to be divisible by the number of devices + if batch_size % self.num_devices != 0: + batch_size = (batch_size // self.num_devices + 1) * self.num_devices + logging.warning("Adjusted batch size to %d for sharding", batch_size) + + config = STULayerConfig( + embedding_dim=embedding_dim, + num_heads=2, + hidden_dim=24, + attention_dim=8, + target_aware=True, + causal=True, + ) + model = STUStack(stu_layers=[STULayer(config, name="stu_layer_0")]) + + init_key, data_key = jax.random.split(self.key) + x = jax.random.normal(data_key, (batch_size, max_len, embedding_dim)) + num_targets = jax.random.randint( + data_key, (batch_size,), minval=2, maxval=10 + ) + + # Shard inputs + x = jax.device_put(x, self.batch_sharding) + num_targets = jax.device_put(num_targets, self.batch_sharding) + + swap_from_offset = jnp.zeros((batch_size,), dtype=jnp.int32) + swap_to_offset = jnp.ones((batch_size,), dtype=jnp.int32) + swap_from_offset = jax.device_put(swap_from_offset, self.batch_sharding) + swap_to_offset = jax.device_put(swap_to_offset, self.batch_sharding) + + swap_from_idx = max_len - 1 - swap_from_offset + swap_to_idx = max_len - 1 - swap_to_offset + + params = model.init( + {"params": init_key, "dropout": data_key}, + x=x, + num_targets=num_targets, + )["params"] + params = jax.device_put(params, self.replicated_sharding) + + apply_fn = jax.jit( + lambda p, x, num_targets: model.apply( + {"params": p}, + x, + num_targets=num_targets, + ), + out_shardings=self.batch_sharding, + ) + + output_original = apply_fn(params, x, num_targets) + self.assertEqual(output_original.sharding, self.batch_sharding) + + def swap_rows(arr, idx1, idx2): + val1 = arr[idx1] + val2 = arr[idx2] + return arr.at[idx1].set(val2).at[idx2].set(val1) + + swapped_x = jax.vmap(swap_rows)(x, swap_from_idx, swap_to_idx) + self.assertEqual(swapped_x.sharding, self.batch_sharding) + output_swapped_input = apply_fn(params, swapped_x, num_targets) + self.assertEqual(output_swapped_input.sharding, self.batch_sharding) + + output_swapped_restored = jax.vmap(swap_rows)( + output_swapped_input, swap_from_idx, swap_to_idx + ) + self.assertEqual(output_swapped_restored.sharding, self.batch_sharding) + + np.testing.assert_allclose( + output_original, output_swapped_restored, rtol=1e-5, atol=1e-5 + ) + + +if __name__ == "__main__": + absltest.main() + diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py index a908ab8..3849425 100644 --- a/recml/layers/linen/sparsecore.py +++ b/recml/layers/linen/sparsecore.py @@ -334,7 +334,7 @@ def _to_np(x: Any) -> np.ndarray: weights[key] = np.reshape(weights[key], (-1, 1)) self._batch_number += 1 - csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( features=features, features_weights=weights, feature_specs=self.sparsecore_config.feature_specs, @@ -345,6 +345,7 @@ def _to_np(x: Any) -> np.ndarray: allow_id_dropping=self.sparsecore_config.allow_id_dropping, batch_number=self._batch_number, ) + csr_inputs = preprocessed_inputs.sparse_dense_matmul_input processed_inputs = { k: v for k, v in inputs.items() if k not in sparse_features @@ -362,7 +363,7 @@ class SparsecoreEmbed(nn.Module): Attributes: sparsecore_config: A sparsecore config specifying how to create the tables. mesh: The mesh to use for the embedding layer. 
-      mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an
+      mesh set by `jax.set_mesh` will be used. If neither is set, an
       error will be raised.
   """
 
@@ -375,7 +376,7 @@ def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
     abstract_mesh = jax.sharding.get_abstract_mesh()
     if not abstract_mesh.shape_tuple:
       raise ValueError(
-          'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
+          'No abstract mesh shape was set with `jax.set_mesh`. Make'
          ' sure to set the mesh when calling the sparsecore module.'
       )
     return abstract_mesh
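
Note (not part of the diff): a minimal sketch of the `jax.set_mesh` pattern
this change standardizes on. The mesh construction and the "data" axis name
below are illustrative, mirroring stu_test.py:

    import jax
    import jax.numpy as jnp
    import numpy as np

    mesh = jax.sharding.Mesh(np.array(jax.devices()), ("data",))

    # jax.set_mesh is used as a context manager here, exactly as the updated
    # partitioning.py uses it. Inside the context,
    # jax.sharding.get_abstract_mesh() reports a non-empty shape, so
    # SparsecoreEmbed.get_mesh() resolves without raising.
    with jax.set_mesh(mesh):
      sharding = jax.sharding.NamedSharding(
          mesh, jax.sharding.PartitionSpec("data")
      )
      x = jax.device_put(jnp.ones((8, 16)), sharding)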