
Commit 237f38f

chandrasekhard2 authored and recml authors committed
This CL introduces an end-to-end Flax implementation of the DLRM-HSTU model, along with individual unit tests for the various components.

- Converted all necessary PyTorch modules to Flax/Linen
- Added a full training-loop test for the end-to-end model
- Added unit tests for individual components (e.g. `ActionEncoder`, `ContentEncoder`, `STU` modules)
- Ensured the model runs correctly on TPUs
- Verified shape-correctness of all modules and parameters

Reverts changelist 793734230

PiperOrigin-RevId: 789375174
1 parent 847628b commit 237f38f

20 files changed: +3711 −16 lines

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]
 
     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
     k_sequence = k_offset + jax.lax.broadcasted_iota(
         jnp.int32, (k_slice.size, bq), 0
     )
-    q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+    q_sequence = q_sequence_ref[:1, :]  # [1, bq]
     q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
 
     assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(
 
   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
       masks.append(q_ids == kv_ids)
 
   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
 
     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )
 
     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]
 
     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
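
Note on the hstu_ops.py change: pl.load(ref, (idx, slice(None))) and NumPy-style indexing directly on a Pallas ref (ref[idx, :]) read the same block of data; the diff simply moves to the terser spelling. Below is a standalone toy kernel sketching the equivalence; it is not part of this commit, and the shapes are chosen to fit a single TPU tile.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def copy_rows_kernel(x_ref, o_ref):
  # Older spelling: explicit pl.load with a dynamic slice.
  rows_a = pl.load(x_ref, (pl.ds(0, 4), slice(None)))
  # Newer spelling used in the diff: index the ref directly.
  rows_b = x_ref[0:4, :]
  o_ref[0:4, :] = rows_a
  o_ref[4:8, :] = rows_b  # Both reads return the same values.


x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    copy_rows_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)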

recml/core/training/partitioning.py

Lines changed: 5 additions & 5 deletions
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +117,7 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
 
     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
         return state
@@ -130,15 +130,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)
 
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )
 
     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)
 
     return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh
 
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
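
Note on the partitioning.py change: jax.sharding.use_mesh is the older context-manager spelling; the code now uses jax.set_mesh, which newer JAX releases provide and which (as in the diff above) can also be used as a context manager. A minimal sketch of that usage, with a toy mesh and function and assuming a JAX version that ships jax.set_mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


@jax.jit
def scale(x):
  return x * 2.0


n = jax.device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

# Replaces `with jax.sharding.use_mesh(mesh):` from the old code.
with jax.set_mesh(mesh):
  x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
  y = scale(x)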
recml/examples/DLRM_HSTU/action_encoder.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# Copyright 2024 RecML authors <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX implementation of the ActionEncoder module."""

from typing import Dict, List, Optional, Tuple

import flax.linen as nn
from flax.linen import initializers
import jax
import jax.numpy as jnp


class ActionEncoder(nn.Module):
  """Encodes categorical actions and continuous watch times into a fixed-size embedding.

  Assumes dense tensors of shape (batch_size, sequence_length) for all inputs.
  """

  action_embedding_dim: int
  action_feature_name: str
  action_weights: List[int]
  watchtime_feature_name: str = ""
  watchtime_to_action_thresholds_and_weights: Optional[
      List[Tuple[int, int]]
  ] = None

  def setup(self):
    """Initializes parameters and constants for the module."""
    wt_thresholds_and_weights = (
        self.watchtime_to_action_thresholds_and_weights or []
    )

    self.combined_action_weights = jnp.array(
        list(self.action_weights) + [w for _, w in wt_thresholds_and_weights]
    )

    self.num_action_types: int = (
        len(self.action_weights) + len(wt_thresholds_and_weights)
    )

    self.action_embedding_table = self.param(
        "action_embedding_table",
        initializers.normal(stddev=0.1),
        (self.num_action_types, self.action_embedding_dim),
    )

    self.target_action_embedding_table = self.param(
        "target_action_embedding_table",
        initializers.normal(stddev=0.1),
        (1, self.output_embedding_dim),
    )

  @property
  def output_embedding_dim(self) -> int:
    """The dimension of the final output embedding."""
    num_watchtime_actions = (
        len(self.watchtime_to_action_thresholds_and_weights)
        if self.watchtime_to_action_thresholds_and_weights
        else 0
    )
    num_action_types = len(self.action_weights) + num_watchtime_actions
    return self.action_embedding_dim * num_action_types

  def __call__(
      self,
      seq_payloads: Dict[str, jax.Array],
      is_target_mask: jax.Array,
  ) -> jax.Array:
    """Processes a batch of sequences to generate action embeddings.

    Args:
      seq_payloads: A dictionary of feature names to dense tensors of shape
        `(batch_size, sequence_length)`.
      is_target_mask: A boolean tensor of shape `(batch_size,
        sequence_length)` where `True` indicates a target item.

    Returns:
      A dense tensor of action embeddings of shape
      `(batch_size, sequence_length, output_embedding_dim)`.
    """
    seq_actions = seq_payloads[self.action_feature_name]

    wt_thresholds_and_weights = (
        self.watchtime_to_action_thresholds_and_weights or []
    )
    if wt_thresholds_and_weights:
      watchtimes = seq_payloads[self.watchtime_feature_name]
      for threshold, weight in wt_thresholds_and_weights:
        watch_action = (watchtimes >= threshold).astype(jnp.int64) * weight
        seq_actions = jnp.bitwise_or(seq_actions, watch_action)

    exploded_actions = (
        jnp.bitwise_and(seq_actions[..., None], self.combined_action_weights)
        > 0
    )

    history_embeddings = (
        exploded_actions[..., None] * self.action_embedding_table
    ).reshape(*seq_actions.shape, -1)

    target_embeddings = jnp.broadcast_to(
        self.target_action_embedding_table, history_embeddings.shape
    )

    final_embeddings = jnp.where(
        is_target_mask[..., None],
        target_embeddings,
        history_embeddings,
    )

    return final_embeddings
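
For orientation, a minimal usage sketch of the new module. The feature names, weights, and input values here are illustrative only, mirroring the unit test in this commit; the real training pipeline supplies these through the DLRM-HSTU input preprocessing.

import jax
import jax.numpy as jnp

# Three categorical action bits (weights 1, 2, 4) plus one watch-time derived
# action: watchtime >= 30 contributes the bit with weight 8.
encoder = ActionEncoder(
    action_embedding_dim=16,
    action_feature_name="actions",
    action_weights=[1, 2, 4],
    watchtime_feature_name="watchtimes",
    watchtime_to_action_thresholds_and_weights=[(30, 8)],
)

# Dense (batch=2, seq_len=3) inputs; actions are bitmasks over action_weights.
seq_payloads = {
    "actions": jnp.array([[1, 3, 0], [4, 0, 0]]),
    "watchtimes": jnp.array([[45, 10, 0], [5, 0, 0]]),
}
is_target_mask = jnp.array([[False, False, True], [False, True, False]])

variables = encoder.init(jax.random.PRNGKey(0), seq_payloads, is_target_mask)
out = encoder.apply(variables, seq_payloads, is_target_mask)
# out.shape == (2, 3, encoder.output_embedding_dim) == (2, 3, 16 * 4)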
Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
# Copyright 2024 RecML authors <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
# from third_party.py.pybase import googletest
from absl.testing import absltest
from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder


class ActionEncoderJaxTest(absltest.TestCase):

  def test_forward_and_backward(self) -> None:
    """Tests the ActionEncoder's forward pass logic and differentiability."""
    batch_size = 2
    max_seq_len = 6
    action_embedding_dim = 32
    action_weights = [1, 2, 4, 8, 16]
    watchtime_to_action_thresholds_and_weights = [
        (30, 32), (60, 64), (100, 128),
    ]
    num_action_types = len(action_weights) + len(
        watchtime_to_action_thresholds_and_weights
    )
    output_dim = action_embedding_dim * num_action_types
    combined_action_weights = action_weights + [
        w for _, w in watchtime_to_action_thresholds_and_weights
    ]

    enabled_actions = [
        [0],           # Seq 1, Item 1
        [0, 1],        # Seq 1, Item 2
        [1, 3, 4],     # Seq 1, Item 3
        [1, 2, 3, 4],  # Seq 1, Item 4
        [1, 2],        # Seq 2, Item 1
        [2],           # Seq 2, Item 2
    ]
    watchtimes_flat = [40, 20, 110, 31, 26, 55]

    # Add actions based on watchtime thresholds.
    for i, wt in enumerate(watchtimes_flat):
      for j, (threshold, _) in enumerate(
          watchtime_to_action_thresholds_and_weights
      ):
        if wt > threshold:
          enabled_actions[i].append(j + len(action_weights))

    actions_flat = [
        sum([combined_action_weights[t] for t in x]) for x in enabled_actions
    ]

    padded_actions = np.zeros((batch_size, max_seq_len), dtype=np.int64)
    padded_watchtimes = np.zeros((batch_size, max_seq_len), dtype=np.int64)

    padded_actions[0, :4] = actions_flat[0:4]
    padded_actions[1, :2] = actions_flat[4:6]
    padded_watchtimes[0, :4] = watchtimes_flat[0:4]
    padded_watchtimes[1, :2] = watchtimes_flat[4:6]

    is_target_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
    is_target_mask[0, 4:6] = True
    is_target_mask[1, 2] = True

    padding_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
    padding_mask[0, :6] = True
    padding_mask[1, :3] = True

    seq_payloads = {
        "watchtimes": jnp.array(padded_watchtimes),
        "actions": jnp.array(padded_actions),
    }

    encoder = ActionEncoder(
        watchtime_feature_name="watchtimes",
        action_feature_name="actions",
        action_weights=action_weights,
        watchtime_to_action_thresholds_and_weights=(
            watchtime_to_action_thresholds_and_weights
        ),
        action_embedding_dim=action_embedding_dim,
    )

    key = jax.random.PRNGKey(0)
    variables = encoder.init(key, seq_payloads, is_target_mask)
    params = variables["params"]

    action_embeddings = encoder.apply(variables, seq_payloads, is_target_mask)

    self.assertEqual(
        action_embeddings.shape, (batch_size, max_seq_len, output_dim)
    )

    action_table = params["action_embedding_table"]
    target_table_flat = params["target_action_embedding_table"]
    target_table = target_table_flat.reshape(num_action_types, -1)

    history_item_idx = 0
    for b in range(batch_size):
      for s in range(max_seq_len):
        if not padding_mask[b, s]:
          npt.assert_allclose(action_embeddings[b, s], 0, atol=1e-6)
          continue

        embedding = action_embeddings[b, s].reshape(num_action_types, -1)

        if is_target_mask[b, s]:
          npt.assert_allclose(embedding, target_table, atol=1e-6)
        else:
          current_enabled = enabled_actions[history_item_idx]
          for atype in range(num_action_types):
            if atype in current_enabled:
              npt.assert_allclose(
                  embedding[atype], action_table[atype], atol=1e-6
              )
            else:
              npt.assert_allclose(
                  embedding[atype], jnp.zeros_like(embedding[atype]), atol=1e-6
              )
          history_item_idx += 1

    def loss_fn(p):
      return encoder.apply({"params": p}, seq_payloads, is_target_mask).sum()

    grads = jax.grad(loss_fn)(params)
    self.assertIsNotNone(grads)
    self.assertFalse(np.all(np.isclose(grads["action_embedding_table"], 0)))
    self.assertFalse(
        np.all(np.isclose(grads["target_action_embedding_table"], 0))
    )


if __name__ == "__main__":
  absltest.main()
