Commit 02d1bbc

gchalump authored and facebook-github-bot committed
Add TBE data configuration reporter to TBE forward (v3) (#4672)
Summary:

X-link: facebookresearch/FBGEMM#1703

Pull Request resolved: #4672

X-link: facebookresearch/FBGEMM#1516

Pull Request resolved: #4455

Re-land attempt of D75462895

# Add TBE data configuration reporter to TBE forward call.

The reporter reports the TBE data configuration at the `SplitTableBatchedEmbeddingBagsCodegen` ***forward*** call. The output is a `TBEDataConfig` object, which is written to one or more JSON files. The environment variables that configure it and an example of its usage are described below.

## Just Knobs for enablement

- fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS is added to enable the reporter (https://www.internalfb.com/intern/justknobs/?name=fbgemm_gpu%2Ffeatures)
- The default is `False`; set this flag to enable the reporter.
- To enable it locally, use:

```
jk canary set fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS --on --ttl 600
```

## Environment Variables

The reporter relies on several environment variables to control its behavior. Below is a description of each variable:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL**:
  - **Description**: Determines the interval at which reports are generated, specified as a number of iterations.
  - **Example Value**: `1` (report every iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START**:
  - **Description**: Specifies the start of the iteration range to capture reports. Default 0.
  - **Example Value**: `0` (start reporting from the first iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END**:
  - **Description**: Specifies the end of the iteration range to capture reports. Use `-1` to report until the last iteration. Default -1.
  - **Example Value**: `-1` (report until the last iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET**:
  - **Description**: Specifies the name of the Manifold bucket where the report data will be saved.
  - **Example Value**: `tlparse_reports`
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX**:
  - **Description**: Defines the path prefix where the report files will be stored. The path is created if it does not exist.
  - **Example Value**: `tree/tests/`

## Use Cases

- FileStore
  - General
    - Auto-create output directories if they do not exist.
  - fb-internal:
    - Only export to Manifold.
    - Raise an assertion error if the flag is set but the Manifold connection fails to initialize (the backend is missing or the Manifold bucket does not exist).
  - OSS
    - Use the local FileStore to store the output.

## Example Usage

Below is an example command demonstrating how to use the FBGEMM reporter with specific environment variable settings:

```
FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2 FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3 FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/ buck2 run mode/opt //deeplearning/fbgemm/fbgemm_gpu/bench:split_table_batched_embeddings -- device --iters 2
```

**Explanation**

The above setting will report `iter 3` and `iter 5`.

* **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2**: The reporter will generate a report every 2 iterations.
* **FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3**: The reporter will start generating reports from iteration 3.
* **FBGEMM_REPORT_INPUT_PARAMS_ITER_END=-1 (Default)**: The reporter will continue to generate reports until the last iteration.
* **FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports**: The reports will be saved in the `tlparse_reports` bucket.
* **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/**: The reports will be stored with the path prefix `tree/tests/`. For Manifold, make sure all folders within the path exist.

**Note on Benchmark example**

With the `--iters 2` option, the benchmark executes 6 forward calls in total: 3 calls (2 iterations plus 1 warmup) for the forward benchmark and another 3 calls (2 iterations plus 1 warmup) for the backward benchmark. Iterations are numbered from 0.

---

## Other changes included in this Diff:

- Update the build dependency of the tbe_data_config* files.
- Remove the `shutil` and `numpy.random` libraries, as they cause an incompatibility error.
- Add a non-OSS test that writes the extracted config data JSON file to Manifold.

Differential Revision: D79758603
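The iteration window described in the commit message can be sketched as a small predicate. This is a sketch under stated assumptions, not the reporter's actual code: the helper names `should_report` and `schedule_from_env` are hypothetical, the interval is assumed to be counted from the start iteration (which is what reproduces the `iter 3` and `iter 5` result of the example), the end iteration is assumed to be inclusive, and an unset interval is assumed to mean "disabled".

```python
import os
from typing import Dict, Mapping, Optional


def should_report(
    iteration: int,
    interval: int,
    iter_start: int = 0,
    iter_end: int = -1,
) -> bool:
    """Return True if `iteration` is a reporting step (assumed semantics)."""
    if interval <= 0:
        return False  # assumption: a non-positive interval disables reporting
    if iteration < iter_start:
        return False
    if iter_end != -1 and iteration > iter_end:
        return False  # assumption: iter_end is inclusive
    # Assumption: the interval is counted from iter_start.
    return (iteration - iter_start) % interval == 0


def schedule_from_env(env: Optional[Mapping[str, str]] = None) -> Dict[str, int]:
    """Read the three scheduling variables named in the commit message."""
    env = os.environ if env is None else env
    return {
        "interval": int(env.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "0")),
        "iter_start": int(env.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0")),
        "iter_end": int(env.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1")),
    }


# Settings from the example command; the benchmark runs iterations 0..5.
example_env = {
    "FBGEMM_REPORT_INPUT_PARAMS_INTERVAL": "2",
    "FBGEMM_REPORT_INPUT_PARAMS_ITER_START": "3",
}
schedule = schedule_from_env(example_env)
reported = [i for i in range(6) if should_report(i, **schedule)]
print(reported)  # [3, 5]
```

Passing the environment as a mapping keeps the sketch testable without touching the real process environment.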
1 parent 405d7e6 commit 02d1bbc

2 files changed: +144 -0 lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 42 additions & 0 deletions
@@ -1441,6 +1441,11 @@ def __init__( # noqa C901
             self._debug_print_input_stats_factory()
         )

+        # Get a reporter function pointer
+        self._report_input_params: Callable[..., None] = (
+            self.__report_input_params_factory()
+        )
+
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
@@ -1953,6 +1958,19 @@ def forward( # noqa: C901
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)

+        # Extract and write input stats if enabled
+        if self._report_input_params is not None:
+            self._report_input_params(
+                feature_rows=self.rows_per_table,
+                feature_dims=self.feature_dims,
+                iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
+                indices=indices,
+                offsets=offsets,
+                op_id=self.uuid,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+            )
+
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
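The guarded call in `forward` pairs with the `__report_input_params_factory` method added further down in this file, which returns either a bound `report_stats` method or `None`. The pattern can be sketched in isolation as follows. `FakeReporter`, `make_report_fn`, and `broken_create` are hypothetical names used only for illustration; the real reporter is `TBEBenchmarkParamsReporter`, and the gate check is reduced here to a plain boolean.

```python
from typing import Callable, List, Optional


class FakeReporter:
    """Hypothetical stand-in for TBEBenchmarkParamsReporter."""

    def __init__(self) -> None:
        self.iterations: List[int] = []

    def report_stats(self, iteration: int, **inputs: object) -> None:
        # The real reporter extracts a TBEDataConfig; this one only records the call.
        self.iterations.append(iteration)


def make_report_fn(
    gate_enabled: bool,
    create: Callable[[], FakeReporter],
) -> Optional[Callable[..., None]]:
    """Return a bound report function, or None when reporting is unavailable."""
    try:
        if gate_enabled:
            return create().report_stats
    except Exception:
        # As in the diff, a setup failure turns reporting off instead of raising.
        return None
    return None


def forward(report_fn: Optional[Callable[..., None]], iteration: int) -> None:
    # Same guard as the forward() hunk: skip reporting when the factory gave None.
    if report_fn is not None:
        report_fn(iteration=iteration, indices=[0, 1], offsets=[0, 2])


def broken_create() -> FakeReporter:
    raise RuntimeError("backend unavailable")


reporter = FakeReporter()
enabled_fn = make_report_fn(True, lambda: reporter)
disabled_fn = make_report_fn(False, lambda: reporter)
failed_fn = make_report_fn(True, broken_create)

for it in range(3):
    forward(enabled_fn, it)
    forward(disabled_fn, it)
    forward(failed_fn, it)

print(reporter.iterations)  # [0, 1, 2]
```

Resolving the callable once at construction keeps the per-call cost in `forward` to a single `None` check when the feature is off.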

@@ -3829,6 +3847,30 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null

+    @torch.jit.ignore
+    def __report_input_params_factory(
+        self,
+    ) -> Optional[Callable[..., None]]:
+        """
+        This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
+        - Reports input parameters (TBEDataConfig).
+        - Writes the output as a JSON file.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
+        """
+        try:
+            if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
+                from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
+
+                reporter = TBEBenchmarkParamsReporter.create()
+                return reporter.report_stats
+        except Exception:
+            return None
+
+        return None
+

 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """

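For the OSS case, the commit message says the output goes to a local FileStore and that missing output directories are created automatically. A minimal sketch of that behavior with the standard library is shown below. The `write_report` helper and the JSON payload are assumptions made for illustration; only the file name pattern `tbe-<op_id>-config-estimation-<iteration>.json` is taken from the test in this commit, and the real `TBEDataConfig` JSON schema is not shown here.

```python
import json
import os
import tempfile
from typing import Dict


def write_report(
    root: str,
    path_prefix: str,
    op_id: str,
    iteration: int,
    config: Dict[str, int],
) -> str:
    """Write one report as JSON under root/path_prefix and return its path."""
    out_dir = os.path.join(root, path_prefix)
    # Create the output directory tree if it does not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    name = f"tbe-{op_id}-config-estimation-{iteration}.json"
    path = os.path.join(out_dir, name)
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return path


# Hypothetical payload; T, E and D mirror TBEDataConfig field names.
payload = {"T": 2, "E": 1000, "D": 64}

with tempfile.TemporaryDirectory() as root:
    report_path = write_report(root, "tree/tests/", "example-op", 3, payload)
    report_name = os.path.basename(report_path)
    with open(report_path) as f:
        loaded = json.load(f)

print(report_name)  # tbe-example-op-config-estimation-3.json
```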
fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py

Lines changed: 102 additions & 0 deletions
@@ -8,12 +8,15 @@
 # pyre-strict

 import unittest
+from typing import Optional
+from unittest.mock import patch

 import fbgemm_gpu

 import hypothesis.strategies as st

 import torch
+from fbgemm_gpu.config import FeatureGateName
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     ComputeDevice,
     EmbeddingLocation,
@@ -38,6 +41,7 @@
 from hypothesis import given, settings

 from .. import common # noqa E402
+from ..common import running_in_oss

 try:
     # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
@@ -147,6 +151,104 @@ def test_report_stats(
             == tbeconfig.indices_params.offset_dtype
         ), "Extracted config does not match the original TBEDataConfig"

+    # pyre-ignore[56]
+    @given(
+        T=st.integers(1, 10),
+        E=st.integers(100, 10000),
+        D=st.sampled_from([32, 64, 128, 256]),
+        L=st.integers(1, 10),
+        B=st.integers(20, 100),
+    )
+    @settings(max_examples=1, deadline=None)
+    @unittest.skipIf(*running_in_oss)
+    def test_report_fb_files(
+        self,
+        T: int,
+        E: int,
+        D: int,
+        L: int,
+        B: int,
+    ) -> None:
+        """
+        Test writing extracted TBEDataConfig to FB FileStore
+        """
+        from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore
+
+        # Initialize the reporter
+        bucket = "tlparse_reports"
+        path_prefix = "tree/unit_tests/"
+
+        # Generate a TBEDataConfig
+        tbeconfig = TBEDataConfig(
+            T=T,
+            E=E,
+            D=D,
+            mixed_dim=False,
+            weighted=False,
+            batch_params=BatchParams(B=B),
+            indices_params=IndicesParams(
+                heavy_hitters=torch.tensor([]),
+                zipf_q=0.1,
+                zipf_s=0.1,
+                index_dtype=torch.int64,
+                offset_dtype=torch.int64,
+            ),
+            pooling_params=PoolingParams(L=L),
+            use_cpu=not torch.cuda.is_available(),
+        )
+
+        embedding_location = (
+            EmbeddingLocation.DEVICE
+            if torch.cuda.is_available()
+            else EmbeddingLocation.HOST
+        )
+
+        # Generate the embedding dimension list
+        _, Ds = generate_embedding_dims(tbeconfig)
+
+        with patch(
+            "torch.ops.fbgemm.check_feature_gate_key"
+        ) as mock_check_feature_gate_key:
+            # Mock the return value for TBE_REPORT_INPUT_PARAMS
+            def side_effect(feature_name: str) -> Optional[bool]:
+                if feature_name == FeatureGateName.TBE_REPORT_INPUT_PARAMS.name:
+                    return True
+
+            mock_check_feature_gate_key.side_effect = side_effect
+
+            # Generate the embedding operation
+            embedding_op = SplitTableBatchedEmbeddingBagsCodegen(
+                [
+                    (
+                        tbeconfig.E,
+                        D,
+                        embedding_location,
+                        (
+                            ComputeDevice.CUDA
+                            if torch.cuda.is_available()
+                            else ComputeDevice.CPU
+                        ),
+                    )
+                    for D in Ds
+                ],
+            )
+
+            embedding_op = embedding_op.to(get_device())
+
+            # Generate indices and offsets
+            request = generate_requests(tbeconfig, 1)[0]
+
+            # Execute the embedding operation with the reporting flag enabled
+            embedding_op.forward(request.indices, request.offsets)
+
+            # Check if the file was written to Manifold
+            store = FileStore(bucket)
+            path = f"{path_prefix}tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter_cpu.item()}.json"
+            assert store.exists(path), f"{path} not exists"
+
+            # Cleanup, delete the file
+            store.remove(path)
+

 if __name__ == "__main__":
     unittest.main()
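The test turns the feature gate on by patching the gate lookup and giving the mock a `side_effect` that answers `True` for a single key. The same technique can be sketched without FBGEMM installed. The `gates` namespace and `feature_is_enabled` helper below are hypothetical stand-ins for `torch.ops.fbgemm.check_feature_gate_key` and the module's gate check; only `unittest.mock.patch.object` and `side_effect` are real APIs.

```python
from types import SimpleNamespace
from typing import Optional
from unittest.mock import patch

# Hypothetical stand-in for the object that owns the gate lookup.
gates = SimpleNamespace(check_feature_gate_key=lambda name: False)


def feature_is_enabled(name: str) -> bool:
    # Look the function up at call time so that patching takes effect.
    return bool(gates.check_feature_gate_key(name))


def side_effect(feature_name: str) -> Optional[bool]:
    if feature_name == "TBE_REPORT_INPUT_PARAMS":
        return True
    return None  # every other gate reads as falsy


with patch.object(gates, "check_feature_gate_key") as mock_check:
    mock_check.side_effect = side_effect
    target_inside = feature_is_enabled("TBE_REPORT_INPUT_PARAMS")
    other_inside = feature_is_enabled("SOME_OTHER_GATE")
    calls = mock_check.call_count

# Leaving the with-block restores the original lookup.
target_outside = feature_is_enabled("TBE_REPORT_INPUT_PARAMS")

print(target_inside, other_inside, target_outside, calls)  # True False False 2
```

Because the patch is only active inside the `with` block, anything that consults the gate, such as constructing the module and calling `forward`, has to happen inside it.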
