diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py
index 81878f3a1..2a5a83d86 100644
--- a/torchrec/metrics/rec_metric.py
+++ b/torchrec/metrics/rec_metric.py
@@ -623,6 +623,27 @@ def _update(
             labels, torch.Tensor
         )
+        # Metrics such as TensorWeightedAvgMetric will have tensors that we also need to stack.
+        # Stack in task order: (n_tasks, batch_size)
+        if "required_inputs" in kwargs:
+            target_tensors: list[torch.Tensor] = []
+            for task in self._tasks:
+                if (
+                    task.tensor_name
+                    and task.tensor_name in kwargs["required_inputs"]
+                ):
+                    target_tensors.append(
+                        kwargs["required_inputs"][task.tensor_name]
+                    )
+
+            if target_tensors:
+                stacked_tensor = torch.stack(target_tensors)
+
+                # Reshape the stacked_tensor to size([len(self._tasks), self._batch_size])
+                stacked_tensor = stacked_tensor.view(len(self._tasks), -1)
+                assert isinstance(stacked_tensor, torch.Tensor)
+                kwargs["required_inputs"]["target_tensor"] = stacked_tensor
+
         predictions = (
             # Reshape the predictions to size([len(self._tasks), self._batch_size])
             predictions.view(len(self._tasks), -1)
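The hunk above establishes the stacking contract: each task's required-input tensor of shape `(batch_size,)` is stacked in task order into one `(n_tasks, batch_size)` tensor stored under the reserved key `"target_tensor"`. A minimal standalone sketch of that contract (the tensor names and values here are invented for illustration, not taken from this PR):

```python
import torch

# Hypothetical required_inputs for two tasks with batch_size = 4.
required_inputs = {
    "watch_time": torch.tensor([1.0, 2.0, 3.0, 4.0]),
    "dwell_time": torch.tensor([5.0, 6.0, 7.0, 8.0]),
}
task_tensor_names = ["watch_time", "dwell_time"]  # task order matters

# Mirrors the stacking logic added to RecMetric._update().
stacked = torch.stack([required_inputs[n] for n in task_tensor_names])
stacked = stacked.view(len(task_tensor_names), -1)  # (n_tasks, batch_size)
assert stacked.shape == (2, 4)
required_inputs["target_tensor"] = stacked
```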
diff --git a/torchrec/metrics/tensor_weighted_avg.py b/torchrec/metrics/tensor_weighted_avg.py
index 580432351..ccace53cf 100644
--- a/torchrec/metrics/tensor_weighted_avg.py
+++ b/torchrec/metrics/tensor_weighted_avg.py
@@ -30,23 +30,31 @@ class TensorWeightedAvgMetricComputation(RecMetricComputation):
     It is a sibling to WeightedAvgMetricComputation, but it computes the weighted
     average of a tensor passed in as a required input instead of the predictions
     tensor.
+
+    FUSED_TASKS_COMPUTATION:
+    This class requires all target tensors from tasks to be stacked together in
+    RecMetric._update(). During TensorWeightedAvgMetricComputation.update(), the
+    target tensor is sliced into the correct per-task rows, one row per task.
     """

     def __init__(
         self,
         *args: Any,
-        tensor_name: Optional[str] = None,
-        weighted: bool = True,
-        description: Optional[str] = None,
+        tasks: List[RecTaskInfo],
         **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
-        if tensor_name is None:
-            raise RecMetricException(
-                f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got {tensor_name}"
-            )
-        self.tensor_name: str = tensor_name
-        self.weighted: bool = weighted
+        self.tasks: List[RecTaskInfo] = tasks
+
+        for task in self.tasks:
+            if task.tensor_name is None:
+                raise RecMetricException(
+                    "TensorWeightedAvgMetricComputation expects all tasks to have tensor_name, but got None."
+                )
+
+        self.weighted_mask: torch.Tensor = torch.tensor(
+            [task.weighted for task in self.tasks], device=self.device
+        ).unsqueeze(dim=-1)
+
         self._add_state(
             "weighted_sum",
             torch.zeros(self._n_tasks, dtype=torch.double),
@@ -61,7 +69,6 @@ def __init__(
             dist_reduce_fx="sum",
             persistent=True,
         )
-        self._description = description

     def update(
         self,
@@ -71,26 +78,60 @@ def update(
         weights: Optional[torch.Tensor],
         **kwargs: Dict[str, Any],
     ) -> None:
-        if (
-            "required_inputs" not in kwargs
-            or self.tensor_name not in kwargs["required_inputs"]
-        ):
+
+        target_tensor: torch.Tensor
+
+        if "required_inputs" not in kwargs:
             raise RecMetricException(
-                f"TensorWeightedAvgMetricComputation expects {self.tensor_name} in the required_inputs"
+                "TensorWeightedAvgMetricComputation expects 'required_inputs' to exist."
             )
+        else:
+            if len(self.tasks) > 1:
+                # In FUSED mode, RecMetric._update() always creates "target_tensor" for the stacked tensor.
+                # Note that RecMetric._update() only stacks if the tensor_name exists in kwargs["required_inputs"].
+                target_tensor = cast(
+                    torch.Tensor,
+                    kwargs["required_inputs"]["target_tensor"],
+                )
+            elif len(self.tasks) == 1:
+                # UNFUSED_TASKS_COMPUTATION
+                tensor_name = self.tasks[0].tensor_name
+                if tensor_name not in kwargs["required_inputs"]:
+                    raise RecMetricException(
+                        f"TensorWeightedAvgMetricComputation expects required_inputs to contain target tensor {self.tasks[0].tensor_name}"
+                    )
+                else:
+                    target_tensor = cast(
+                        torch.Tensor,
+                        kwargs["required_inputs"][tensor_name],
+                    )
+            else:
+                raise RecMetricException(
+                    "TensorWeightedAvgMetricComputation expects at least one task."
+                )
+
         num_samples = labels.shape[0]
-        target_tensor = cast(torch.Tensor, kwargs["required_inputs"][self.tensor_name])
         weights = cast(torch.Tensor, weights)
+
+        # Vectorized computation using masks
+        weighted_values = torch.where(
+            self.weighted_mask, target_tensor * weights, target_tensor
+        )
+
+        weighted_counts = torch.where(
+            self.weighted_mask, weights, torch.ones_like(weights)
+        )
+
+        # Sum across the batch dimension
+        weighted_sum = weighted_values.sum(dim=-1)  # Shape: (n_tasks,)
+        weighted_num_samples = weighted_counts.sum(dim=-1)  # Shape: (n_tasks,)
+
+        # Update states
         states = {
-            "weighted_sum": (
-                target_tensor * weights if self.weighted else target_tensor
-            ).sum(dim=-1),
-            "weighted_num_samples": (
-                weights.sum(dim=-1)
-                if self.weighted
-                else torch.ones(weights.shape).sum(dim=-1).to(device=weights.device)
-            ),
+            "weighted_sum": weighted_sum,
+            "weighted_num_samples": weighted_num_samples,
         }
+
         for state_name, state_value in states.items():
             state = getattr(self, state_name)
             state += state_value
@@ -105,7 +146,6 @@ def _compute(self) -> List[MetricComputationReport]:
                     cast(torch.Tensor, self.weighted_sum),
                     cast(torch.Tensor, self.weighted_num_samples),
                 ),
-                description=self._description,
             ),
             MetricComputationReport(
                 name=MetricName.WEIGHTED_AVG,
@@ -114,7 +154,6 @@ def _compute(self) -> List[MetricComputationReport]:
                     self.get_window_state("weighted_sum"),
                     self.get_window_state("weighted_num_samples"),
                 ),
-                description=self._description,
             ),
         ]
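The vectorized update above replaces the old per-computation `if self.weighted` branch with an `(n_tasks, 1)` boolean mask that broadcasts across the batch dimension: weighted tasks accumulate `sum(t * w)` and `sum(w)`, unweighted tasks accumulate `sum(t)` and the raw sample count. A self-contained sketch with made-up numbers (not values from this PR):

```python
import torch

# Hypothetical setup: 3 tasks over a batch of 4 samples.
target_tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                              [2.0, 2.0, 2.0, 2.0],
                              [0.0, 1.0, 0.0, 1.0]])  # (n_tasks, batch_size)
weights = torch.tensor([[1.0, 1.0, 2.0, 2.0],
                        [1.0, 1.0, 1.0, 1.0],
                        [3.0, 1.0, 3.0, 1.0]])        # (n_tasks, batch_size)
weighted_mask = torch.tensor([True, False, True]).unsqueeze(-1)  # (n_tasks, 1)

# Rows with weighted=True use target * weights; other rows use target unchanged.
weighted_values = torch.where(weighted_mask, target_tensor * weights, target_tensor)
# Rows with weighted=True count sum(weights); other rows count one per sample.
weighted_counts = torch.where(weighted_mask, weights, torch.ones_like(weights))

avg = weighted_values.sum(dim=-1) / weighted_counts.sum(dim=-1)  # (n_tasks,)
# Task 0: 17/6 ≈ 2.833, task 1: 8/4 = 2.0, task 2: 2/8 = 0.25
```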
@@ -126,23 +165,40 @@ class TensorWeightedAvgMetric(RecMetric):
     def _get_task_kwargs(
         self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
     ) -> Dict[str, Any]:
-        if not isinstance(task_config, RecTaskInfo):
-            raise RecMetricException(
-                f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
-            )
+        all_tasks = (
+            [task_config] if isinstance(task_config, RecTaskInfo) else task_config
+        )
         return {
-            "tensor_name": task_config.tensor_name,
-            "weighted": task_config.weighted,
+            "tasks": all_tasks,
         }

     def _get_task_required_inputs(
         self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
     ) -> Set[str]:
-        if not isinstance(task_config, RecTaskInfo):
-            raise RecMetricException(
-                f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
-            )
-        required_inputs = set()
-        if task_config.tensor_name is not None:
-            required_inputs.add(task_config.tensor_name)
-        return required_inputs
+        """
+        Returns the required inputs for the task.
+
+        FUSED_TASKS_COMPUTATION:
+        - Given two tasks with the same tensor_name, assume the same tensor reference
+        - For a given tensor_name, assume all tasks have the same weighted flag
+        """
+        all_tasks = (
+            [task_config] if isinstance(task_config, RecTaskInfo) else task_config
+        )
+
+        required_inputs: dict[str, bool] = {}
+        for task in all_tasks:
+            if task.tensor_name is not None:
+                if (
+                    task.tensor_name in required_inputs
+                    and task.weighted is not required_inputs[task.tensor_name]
+                ):
+                    existing_weighted_flag = required_inputs[task.tensor_name]
+                    raise RecMetricException(
+                        f"This target tensor was already registered as weighted={existing_weighted_flag}. "
+                        f"This target tensor cannot be re-registered with weighted={task.weighted}"
+                    )
+                else:
+                    required_inputs[str(task.tensor_name)] = task.weighted
+
+        return set(required_inputs.keys())
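The dedup rule in `_get_task_required_inputs` boils down to: one entry per distinct `tensor_name`, and an error if the same `tensor_name` is registered with opposing `weighted` flags. A standalone sketch of just that rule (plain tuples stand in for `RecTaskInfo`, and the helper name is invented):

```python
def collect_required_inputs(task_specs: list[tuple[str, bool]]) -> set[str]:
    """Collect unique tensor names, rejecting conflicting weighted flags."""
    seen: dict[str, bool] = {}
    for tensor_name, weighted in task_specs:
        if tensor_name in seen and seen[tensor_name] != weighted:
            raise ValueError(
                f"{tensor_name} was already registered as weighted={seen[tensor_name]}; "
                f"it cannot be re-registered with weighted={weighted}"
            )
        seen[tensor_name] = weighted
    return set(seen)

# Duplicate names with consistent flags dedup cleanly...
assert collect_required_inputs([("a", True), ("a", True), ("b", False)]) == {"a", "b"}
# ...while conflicting flags raise:
# collect_required_inputs([("a", True), ("a", False)])  # -> ValueError
```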
diff --git a/torchrec/metrics/tests/test_tensor_weighted_avg.py b/torchrec/metrics/tests/test_tensor_weighted_avg.py
index 79041693d..74973be70 100644
--- a/torchrec/metrics/tests/test_tensor_weighted_avg.py
+++ b/torchrec/metrics/tests/test_tensor_weighted_avg.py
@@ -79,26 +79,22 @@ def test_tensor_weighted_avg_unfused(self) -> None:
             entry_point=metric_test_helper,
         )

-    def test_tensor_weighted_avg_fused_fails(self) -> None:
-        """Test that TensorWeightedAvgMetric fails with FUSED_TASKS_COMPUTATION as expected."""
-        # This test verifies the current limitation - FUSED mode should fail
-        with self.assertRaisesRegex(
-            RecMetricException, "expects task_config to be RecTaskInfo not"
-        ):
-            rec_metric_value_test_launcher(
-                target_clazz=TensorWeightedAvgMetric,
-                target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
-                test_clazz=TestTensorWeightedAvgMetric,
-                metric_name=METRIC_NAMESPACE,
-                task_names=["t1", "t2", "t3"],
-                fused_update_limit=0,
-                compute_on_all_ranks=False,
-                should_validate_update=False,
-                world_size=WORLD_SIZE,
-                entry_point=metric_test_helper,
-            )
+    def test_tensor_weighted_avg_fused(self) -> None:
+        """Test TensorWeightedAvgMetric with FUSED_TASKS_COMPUTATION."""
+        rec_metric_value_test_launcher(
+            target_clazz=TensorWeightedAvgMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            test_clazz=TestTensorWeightedAvgMetric,
+            metric_name=METRIC_NAMESPACE,
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )

-    def test_tensor_weighted_avg_single_task(self) -> None:
+    def test_tensor_weighted_avg_single_task_unfused(self) -> None:
         rec_metric_value_test_launcher(
             target_clazz=TensorWeightedAvgMetric,
             target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
@@ -112,17 +108,48 @@ def test_tensor_weighted_avg_single_task(self) -> None:
             entry_point=metric_test_helper,
         )

+    def test_tensor_weighted_avg_single_task_fused(self) -> None:
+        """Test TensorWeightedAvgMetric with single task in FUSED_TASKS_COMPUTATION mode."""
+        rec_metric_value_test_launcher(
+            target_clazz=TensorWeightedAvgMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            test_clazz=TestTensorWeightedAvgMetric,
+            metric_name=METRIC_NAMESPACE,
+            task_names=["single_task"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=WORLD_SIZE,
+            entry_point=metric_test_helper,
+        )
+

 class TensorWeightedAvgGPUSyncTest(unittest.TestCase):
     """GPU synchronization tests for TensorWeightedAvgMetric."""

-    def test_sync_tensor_weighted_avg(self) -> None:
+    def test_sync_tensor_weighted_avg_unfused(self) -> None:
         rec_metric_gpu_sync_test_launcher(
             target_clazz=TensorWeightedAvgMetric,
             target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
             test_clazz=TestTensorWeightedAvgMetric,
             metric_name=METRIC_NAMESPACE,
-            task_names=["t1"],
+            task_names=["t1", "t2", "t3"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )
+
+    def test_sync_tensor_weighted_avg_fused(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=TensorWeightedAvgMetric,
+            target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            test_clazz=TestTensorWeightedAvgMetric,
+            metric_name=METRIC_NAMESPACE,
+            task_names=["t1", "t2", "t3"],
             fused_update_limit=0,
             compute_on_all_ranks=False,
             should_validate_update=False,
@@ -161,8 +188,8 @@ def test_tensor_weighted_avg_basic_functionality(self) -> None:
         self.assertEqual(len(metric._metrics_computations), 1)

         computation = metric._metrics_computations[0]
-        self.assertEqual(computation.tensor_name, "test_tensor")
-        self.assertTrue(computation.weighted)
+        self.assertEqual(computation.tasks[0].tensor_name, "test_tensor")
+        self.assertTrue(computation.tasks[0].weighted)

     def test_tensor_weighted_avg_unweighted_task(self) -> None:
@@ -188,35 +215,65 @@ def test_tensor_weighted_avg_unweighted_task(self) -> None:
         )

         computation = metric._metrics_computations[0]
-        self.assertEqual(computation.tensor_name, "test_tensor")
-        self.assertFalse(computation.weighted)
+        self.assertEqual(computation.tasks[0].tensor_name, "test_tensor")
+        self.assertFalse(computation.tasks[0].weighted)

-    def test_tensor_weighted_avg_missing_tensor_name_throws_exception(self) -> None:
-
-        # Create task with None tensor_name
+    def test_tensor_weighted_avg_unfused_required_inputs_validation(self) -> None:
         tasks = [
             RecTaskInfo(
                 name="test_task",
                 label_name="test_label",
                 prediction_name="test_pred",
                 weight_name="test_weight",
-                tensor_name=None,
+                tensor_name="test_tensor",
                 weighted=True,
             )
         ]
+        metric = TensorWeightedAvgMetric(
+            world_size=1,
+            my_rank=0,
+            batch_size=2,
+            tasks=tasks,
+            compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            window_size=100,
+        )
+
+        # Test that required inputs are correctly identified
+        required_inputs = metric.get_required_inputs()
+        self.assertIn("test_tensor", required_inputs)
+
+        # Test update with missing required inputs should fail
+        with self.assertRaisesRegex(RecMetricException, "required_inputs"):
+            metric.update(
+                predictions={"test_task": torch.tensor([0.1, 0.2])},
+                labels={"test_task": torch.tensor([1.0, 0.0])},
+                weights={"test_task": torch.tensor([1.0, 2.0])},
+            )
+
+    def test_tensor_weighted_avg_unfused_missing_tensor_name_init_error(self) -> None:
         with self.assertRaisesRegex(RecMetricException, "tensor_name"):
             TensorWeightedAvgMetric(
                 world_size=1,
                 my_rank=0,
-                batch_size=4,
-                tasks=tasks,
+                batch_size=2,
+                tasks=[
+                    RecTaskInfo(
+                        name="test_task",
+                        label_name="test_label",
+                        prediction_name="test_pred",
+                        weight_name="test_weight",
+                        tensor_name=None,
+                        weighted=True,
+                    )
+                ],
                 compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
                 window_size=100,
             )

-    def test_tensor_weighted_avg_required_inputs_validation(self) -> None:
-        tasks = [
+    def test_tensor_weighted_avg_unfused_missing_tensor_name_update_error(self) -> None:
+        """Test error when tensor_name is missing in required_inputs for UNFUSED mode."""
+        single_task = [
             RecTaskInfo(
                 name="test_task",
                 label_name="test_label",
@@ -227,25 +284,51 @@ def test_tensor_weighted_avg_required_inputs_validation(self) -> None:
             )
         ]

-        metric = TensorWeightedAvgMetric(
+        single_metric = TensorWeightedAvgMetric(
             world_size=1,
             my_rank=0,
             batch_size=2,
-            tasks=tasks,
+            tasks=single_task,
             compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
             window_size=100,
         )

-        # Test that required inputs are correctly identified
-        required_inputs = metric.get_required_inputs()
-        self.assertIn("test_tensor", required_inputs)
-
-        # Test update with missing required inputs should fail
-        with self.assertRaisesRegex(RecMetricException, "required_inputs"):
-            metric.update(
+        with self.assertRaisesRegex(RecMetricException, "test_tensor"):
+            single_metric.update(
                 predictions={"test_task": torch.tensor([0.1, 0.2])},
                 labels={"test_task": torch.tensor([1.0, 0.0])},
                 weights={"test_task": torch.tensor([1.0, 2.0])},
+                required_inputs={"wrong_tensor_name": torch.tensor([1.0, 2.0])},
+            )
+
+    def test_tensor_weighted_avg_fused_conflicting_weighted_flags_error(self) -> None:
+        with self.assertRaisesRegex(
+            RecMetricException, "already registered as weighted"
+        ):
+            TensorWeightedAvgMetric(
+                world_size=1,
+                my_rank=0,
+                batch_size=2,
+                tasks=[
+                    RecTaskInfo(
+                        name="task1",
+                        label_name="test_label",
+                        prediction_name="test_pred",
+                        weight_name="test_weight",
+                        tensor_name="shared_tensor",
+                        weighted=True,
+                    ),
+                    RecTaskInfo(
+                        name="task2",
+                        label_name="test_label",
+                        prediction_name="test_pred",
+                        weight_name="test_weight",
+                        tensor_name="shared_tensor",
+                        weighted=False,
+                    ),
+                ],
+                compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+                window_size=100,
             )
@@ -389,7 +472,7 @@ def _test_tensor_weighted_avg_helper(
             msg=f"Actual: {cur_actual_tensor_weighted_avg}, Expected: {cur_expected_tensor_weighted_avg}",
         )

-    def test_tensor_weighted_avg_computation_correctness(self) -> None:
+    def test_tensor_weighted_avg_correctness(self) -> None:
         """Test tensor weighted average computation correctness with known values."""
         test_data = generate_tensor_model_outputs_cases()
         for inputs in test_data:
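Taken together, the fused path added by this PR can be exercised end to end roughly as follows. This is a sketch based only on the constructor and `update()` signatures visible in this diff; the import paths and the exact keys returned by `compute()` are assumptions, not verified against the package:

```python
import torch
from torchrec.metrics.rec_metric import RecComputeMode, RecTaskInfo  # assumed paths
from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric

tasks = [
    RecTaskInfo(
        name="t1",
        label_name="t1_label",
        prediction_name="t1_pred",
        weight_name="t1_weight",
        tensor_name="t1_tensor",
        weighted=True,
    ),
    RecTaskInfo(
        name="t2",
        label_name="t2_label",
        prediction_name="t2_pred",
        weight_name="t2_weight",
        tensor_name="t2_tensor",
        weighted=False,
    ),
]

metric = TensorWeightedAvgMetric(
    world_size=1,
    my_rank=0,
    batch_size=4,
    tasks=tasks,
    compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
    window_size=100,
)

# One update step: RecMetric._update() stacks t1_tensor/t2_tensor into a
# (2, 4) "target_tensor" before the fused computation consumes it.
metric.update(
    predictions={"t1": torch.rand(4), "t2": torch.rand(4)},
    labels={"t1": torch.rand(4), "t2": torch.rand(4)},
    weights={"t1": torch.rand(4), "t2": torch.rand(4)},
    required_inputs={"t1_tensor": torch.rand(4), "t2_tensor": torch.rand(4)},
)
print(metric.compute())
```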