|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import unittest
|
| 11 | +from typing import Optional |
| 12 | +from unittest.mock import patch |
11 | 13 |
|
12 | 14 | import fbgemm_gpu
|
13 | 15 |
|
14 | 16 | import hypothesis.strategies as st
|
15 | 17 |
|
16 | 18 | import torch
|
| 19 | +from fbgemm_gpu.config import FeatureGateName |
17 | 20 | from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
|
18 | 21 | ComputeDevice,
|
19 | 22 | EmbeddingLocation,
|
|
38 | 41 | from hypothesis import given, settings
|
39 | 42 |
|
40 | 43 | from .. import common # noqa E402
|
| 44 | +from ..common import running_in_oss |
41 | 45 |
|
42 | 46 | try:
|
43 | 47 | # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
|
@@ -147,6 +151,104 @@ def test_report_stats(
|
147 | 151 | == tbeconfig.indices_params.offset_dtype
|
148 | 152 | ), "Extracted config does not match the original TBEDataConfig"
|
149 | 153 |
|
| 154 | + # pyre-ignore[56] |
| 155 | + @given( |
| 156 | + T=st.integers(1, 10), |
| 157 | + E=st.integers(100, 10000), |
| 158 | + D=st.sampled_from([32, 64, 128, 256]), |
| 159 | + L=st.integers(1, 10), |
| 160 | + B=st.integers(20, 100), |
| 161 | + ) |
| 162 | + @settings(max_examples=1, deadline=None) |
| 163 | + @unittest.skipIf(*running_in_oss) |
| 164 | + def test_report_fb_files( |
| 165 | + self, |
| 166 | + T: int, |
| 167 | + E: int, |
| 168 | + D: int, |
| 169 | + L: int, |
| 170 | + B: int, |
| 171 | + ) -> None: |
| 172 | + """ |
| 173 | + Test writing extrcted TBEDataConfig to FB FileStore |
| 174 | + """ |
| 175 | + from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore |
| 176 | + |
| 177 | + # Initialize the reporter |
| 178 | + bucket = "tlparse_reports" |
| 179 | + path_prefix = "tree/unit_tests/" |
| 180 | + |
| 181 | + # Generate a TBEDataConfig |
| 182 | + tbeconfig = TBEDataConfig( |
| 183 | + T=T, |
| 184 | + E=E, |
| 185 | + D=D, |
| 186 | + mixed_dim=False, |
| 187 | + weighted=False, |
| 188 | + batch_params=BatchParams(B=B), |
| 189 | + indices_params=IndicesParams( |
| 190 | + heavy_hitters=torch.tensor([]), |
| 191 | + zipf_q=0.1, |
| 192 | + zipf_s=0.1, |
| 193 | + index_dtype=torch.int64, |
| 194 | + offset_dtype=torch.int64, |
| 195 | + ), |
| 196 | + pooling_params=PoolingParams(L=L), |
| 197 | + use_cpu=not torch.cuda.is_available(), |
| 198 | + ) |
| 199 | + |
| 200 | + embedding_location = ( |
| 201 | + EmbeddingLocation.DEVICE |
| 202 | + if torch.cuda.is_available() |
| 203 | + else EmbeddingLocation.HOST |
| 204 | + ) |
| 205 | + |
| 206 | + # Generate the embedding dimension list |
| 207 | + _, Ds = generate_embedding_dims(tbeconfig) |
| 208 | + |
| 209 | + with patch( |
| 210 | + "torch.ops.fbgemm.check_feature_gate_key" |
| 211 | + ) as mock_check_feature_gate_key: |
| 212 | + # Mock the return value for TBE_REPORT_INPUT_PARAMS |
| 213 | + def side_effect(feature_name: str) -> Optional[bool]: |
| 214 | + if feature_name == FeatureGateName.TBE_REPORT_INPUT_PARAMS.name: |
| 215 | + return True |
| 216 | + |
| 217 | + mock_check_feature_gate_key.side_effect = side_effect |
| 218 | + |
| 219 | + # Generate the embedding operation |
| 220 | + embedding_op = SplitTableBatchedEmbeddingBagsCodegen( |
| 221 | + [ |
| 222 | + ( |
| 223 | + tbeconfig.E, |
| 224 | + D, |
| 225 | + embedding_location, |
| 226 | + ( |
| 227 | + ComputeDevice.CUDA |
| 228 | + if torch.cuda.is_available() |
| 229 | + else ComputeDevice.CPU |
| 230 | + ), |
| 231 | + ) |
| 232 | + for D in Ds |
| 233 | + ], |
| 234 | + ) |
| 235 | + |
| 236 | + embedding_op = embedding_op.to(get_device()) |
| 237 | + |
| 238 | + # Generate indices and offsets |
| 239 | + request = generate_requests(tbeconfig, 1)[0] |
| 240 | + |
| 241 | + # Execute the embedding operation with reporting flag enable |
| 242 | + embedding_op.forward(request.indices, request.offsets) |
| 243 | + |
| 244 | + # Check if the file was written to Manifold |
| 245 | + store = FileStore(bucket) |
| 246 | + path = f"{path_prefix}tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter_cpu.item()}.json" |
| 247 | + assert store.exists(path), f"{path} not exists" |
| 248 | + |
| 249 | + # Clenaup, delete the file |
| 250 | + store.remove(path) |
| 251 | + |
150 | 252 |
|
151 | 253 | if __name__ == "__main__":
|
152 | 254 | unittest.main()
|
0 commit comments