Skip to content

Commit 74e3e50

Browse files
Shuangping Liufacebook-github-bot
authored andcommitted
Add validation logic for keys & weights of KJT
Summary: Introduces validations for the correctness of `keys` and `weights` within a `KeyedJaggedTensor` (KJT), ensuring that KJTs are properly formed before they are used in downstream computations. Added validation rules: - `keys` should NOT contain duplications - For non-VBE scenario, `lengths` size should be divisible by the number of keys. - If `keys` is empty, the KJT should be an empty one. - `weights` must match the size of `values`. VBE scenario will be addressed in follow-up changes. Reviewed By: TroyGarden Differential Revision: D77247081 fbshipit-source-id: bf0a9915a5a499e5504c912c5d720d8e80928026
1 parent cf49291 commit 74e3e50

File tree

2 files changed

+108
-7
lines changed

2 files changed

+108
-7
lines changed

torchrec/sparse/jagged_tensor_validator.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,21 @@ def validate_keyed_jagged_tensor(
1717
"""
1818
Validates the inputs that construct a KeyedJaggedTensor.
1919
20-
This function ensures that:
21-
- At least one of lengths or offsets is provided
22-
- If both are provided, they are consistent with each other
23-
- The dimensions of these tensors align with the values tensor
24-
2520
Any invalid input will result in a ValueError being thrown.
2621
"""
27-
# TODO: Add validation checks on keys, values, weights
2822
_validate_lengths_and_offsets(kjt)
23+
_validate_keys(kjt)
24+
_validate_weights(kjt)
2925

3026

3127
def _validate_lengths_and_offsets(kjt: KeyedJaggedTensor) -> None:
28+
"""
29+
Validates the lengths and offsets of a KJT.
30+
31+
- At least one of lengths or offsets is provided
32+
- If both are provided, they are consistent with each other
33+
- The dimensions of these tensors align with the values tensor
34+
"""
3235
lengths = kjt.lengths_or_none()
3336
offsets = kjt.offsets_or_none()
3437
if lengths is None and offsets is None:
@@ -76,3 +79,45 @@ def _validate_offsets(offsets: torch.Tensor, values: torch.Tensor) -> None:
7679
raise ValueError(
7780
f"The last element of offsets must equal to the number of values, but got {offsets[-1]} and {values.numel()}"
7881
)
82+
83+
84+
def _validate_keys(kjt: KeyedJaggedTensor) -> None:
85+
"""
86+
Validates KJT keys, assuming the lengths/offsets input are valid.
87+
88+
- keys must be unique
89+
- For non-VBE cases, the size of lengths is divisible by the number of keys
90+
"""
91+
keys = kjt.keys()
92+
93+
if len(set(keys)) != len(keys):
94+
raise ValueError("keys must be unique")
95+
96+
lengths = kjt.lengths_or_none()
97+
offsets = kjt.offsets_or_none()
98+
if lengths is not None:
99+
lengths_size = lengths.numel()
100+
else:
101+
assert offsets is not None
102+
lengths_size = offsets.numel() - 1
103+
104+
if len(keys) == 0 and lengths_size > 0:
105+
raise ValueError("keys is empty but lengths or offsets is not")
106+
elif len(keys) > 0:
107+
# TODO: Validate KJT for VBE cases
108+
if not kjt.variable_stride_per_key():
109+
if lengths_size % len(keys) != 0:
110+
raise ValueError(
111+
f"lengths size must be divisible by keys size, but got {lengths_size} and {len(keys)}"
112+
)
113+
114+
115+
def _validate_weights(kjt: KeyedJaggedTensor) -> None:
116+
"""
117+
Validates if the KJT weights has the same size as values.
118+
"""
119+
weights = kjt.weights_or_none()
120+
if weights is not None and weights.numel() != kjt.values().numel():
121+
raise ValueError(
122+
f"weights size must equal to values size, but got {weights.numel()} and {kjt.values().numel()}"
123+
)

torchrec/sparse/tests/test_jagged_tensor_validator.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,71 @@ class TestJaggedTensorValidator(unittest.TestCase):
122122
),
123123
]
124124

125-
@parameterized.expand(INVALID_LENGTHS_OFFSETS_CASES)
125+
INVALID_KEYS_CASES = [
126+
param(
127+
expected_error_msg="keys must be unique",
128+
keys=["f1", "f1"],
129+
values=torch.tensor([1, 2, 3, 4, 5]),
130+
lengths=torch.tensor([1, 2, 0, 2]),
131+
offsets=torch.tensor([0, 1, 3, 3, 5]),
132+
),
133+
param(
134+
expected_error_msg="keys is empty but lengths or offsets is not",
135+
keys=[],
136+
values=torch.tensor([1, 2, 3, 4, 5]),
137+
lengths=torch.tensor([1, 2, 0, 2]),
138+
offsets=torch.tensor([0, 1, 3, 3, 5]),
139+
),
140+
param(
141+
expected_error_msg="lengths size must be divisible by keys size",
142+
keys=["f1", "f2", "f3"],
143+
values=torch.tensor([1, 2, 3, 4, 5]),
144+
lengths=torch.tensor([1, 2, 0, 2]),
145+
offsets=torch.tensor([0, 1, 3, 3, 5]),
146+
),
147+
]
148+
149+
INVALID_WEIGHTS_CASES = [
150+
param(
151+
expected_error_msg="weights size must equal to values size",
152+
keys=["f1", "f2"],
153+
values=torch.tensor([1, 2, 3, 4, 5]),
154+
lengths=torch.tensor([1, 2, 0, 2]),
155+
offsets=torch.tensor([0, 1, 3, 3, 5]),
156+
weights=torch.tensor([0.1, 0.2, 0.3, 0.4]),
157+
),
158+
param(
159+
expected_error_msg="weights size must equal to values size",
160+
keys=["f1", "f2"],
161+
values=torch.tensor([1, 2, 3, 4, 5]),
162+
lengths=torch.tensor([1, 2, 0, 2]),
163+
offsets=torch.tensor([0, 1, 3, 3, 5]),
164+
weights=torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
165+
),
166+
]
167+
168+
@parameterized.expand(
169+
[
170+
*INVALID_LENGTHS_OFFSETS_CASES,
171+
*INVALID_KEYS_CASES,
172+
*INVALID_WEIGHTS_CASES,
173+
]
174+
)
126175
def test_invalid_keyed_jagged_tensor(
127176
self,
128177
expected_error_msg: str,
129178
keys: List[str],
130179
values: torch.Tensor,
131180
lengths: Optional[torch.Tensor],
132181
offsets: Optional[torch.Tensor],
182+
weights: Optional[torch.Tensor] = None,
133183
) -> None:
134184
kjt = KeyedJaggedTensor(
135185
keys=keys,
136186
values=values,
137187
lengths=lengths,
138188
offsets=offsets,
189+
weights=weights,
139190
)
140191

141192
with self.assertRaises(ValueError) as err:
@@ -181,3 +232,8 @@ def test_valid_kjt_from_offsets(
181232
)
182233

183234
validate_keyed_jagged_tensor(kjt)
235+
236+
def test_valid_empty_kjt(self) -> None:
237+
kjt = KeyedJaggedTensor.empty()
238+
239+
validate_keyed_jagged_tensor(kjt)

0 commit comments

Comments
 (0)