@@ -8,12 +8,15 @@
 # pyre-strict

 import unittest
+from typing import Optional
+from unittest.mock import patch

 import fbgemm_gpu

 import hypothesis.strategies as st

 import torch
+from fbgemm_gpu.config import FeatureGateName
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     ComputeDevice,
     EmbeddingLocation,
@@ -38,6 +41,7 @@
 from hypothesis import given, settings

 from .. import common  # noqa E402
+from ..common import running_in_oss

 try:
     # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
@@ -147,6 +151,108 @@ def test_report_stats(
             == tbeconfig.indices_params.offset_dtype
         ), "Extracted config does not match the original TBEDataConfig"

+    # pyre-ignore[56]
+    @given(
+        T=st.integers(1, 10),
+        E=st.integers(100, 10000),
+        D=st.sampled_from([32, 64, 128, 256]),
+        L=st.integers(1, 10),
+        B=st.integers(20, 100),
+    )
+    @settings(max_examples=1, deadline=None)
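+    # This test depends on the FB-internal Manifold FileStore, so it is skipped in OSS builds.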
+    @unittest.skipIf(*running_in_oss)
+    def test_report_fb_files(
+        self,
+        T: int,
+        E: int,
+        D: int,
+        L: int,
+        B: int,
+    ) -> None:
+        """
+        Test writing the extracted TBEDataConfig to the FB FileStore
+        """
+        from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore

+        # Initialize the reporter
+        bucket = "tlparse_reports"
+        path_prefix = "tree/unit_tests/"
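+        # The reporter is expected to write its JSON config report under this bucket/prefix.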

+        # Generate a TBEDataConfig
+        tbeconfig = TBEDataConfig(
+            T=T,
+            E=E,
+            D=D,
+            mixed_dim=False,
+            weighted=False,
+            batch_params=BatchParams(B=B),
+            indices_params=IndicesParams(
+                heavy_hitters=torch.tensor([]),
+                zipf_q=0.1,
+                zipf_s=0.1,
+                index_dtype=torch.int64,
+                offset_dtype=torch.int64,
+            ),
+            pooling_params=PoolingParams(L=L),
+            use_cpu=not torch.cuda.is_available(),
+        )

+        embedding_location = (
+            EmbeddingLocation.DEVICE
+            if torch.cuda.is_available()
+            else EmbeddingLocation.HOST
+        )

+        # Generate the embedding dimension list
+        _, Ds = generate_embedding_dims(tbeconfig)

+        with patch(
+            "torch.ops.fbgemm.check_feature_gate_key"
+        ) as mock_check_feature_gate_key:
+            # Mock the return value for TBE_REPORT_INPUT_PARAMS
+            def side_effect(feature_name: str) -> Optional[bool]:
+                if feature_name == FeatureGateName.TBE_REPORT_INPUT_PARAMS.name:
+                    return True
+                return None  # All other feature gates keep their default behavior

+            mock_check_feature_gate_key.side_effect = side_effect
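+            # Everything below now sees TBE_REPORT_INPUT_PARAMS as enabled.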

+            # Generate the embedding operation
+            embedding_op = SplitTableBatchedEmbeddingBagsCodegen(
+                [
+                    (
+                        tbeconfig.E,
+                        D,
+                        embedding_location,
+                        (
+                            ComputeDevice.CUDA
+                            if torch.cuda.is_available()
+                            else ComputeDevice.CPU
+                        ),
+                    )
+                    for D in Ds
+                ],
+            )

+            embedding_op = embedding_op.to(get_device())

+            # Generate indices and offsets
+            request = generate_requests(tbeconfig, 1)[0]

+            # Execute the embedding operation with the reporting flag enabled
+            embedding_op.forward(request.indices, request.offsets)

+            # Check if the file was written to Manifold
+            store = FileStore(bucket)
+            path = f"{path_prefix}tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter_cpu.item()}.json"
+            assert store.exists(path), f"{path} does not exist"
+            # Cleanup: delete the file
+            store.remove(path)


 if __name__ == "__main__":
     unittest.main()