Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,27 @@ def _update(
labels, torch.Tensor
)

# Metrics such as TensorWeightedAvgMetric will have tensors that we also need to stack.
# Stack in task order: (n_tasks, batch_size)
if "required_inputs" in kwargs:
target_tensors: list[torch.Tensor] = []
for task in self._tasks:
if (
task.tensor_name
and task.tensor_name in kwargs["required_inputs"]
):
target_tensors.append(
kwargs["required_inputs"][task.tensor_name]
)

if target_tensors:
stacked_tensor = torch.stack(target_tensors)

# Reshape the stacked_tensor to size([len(self._tasks), self._batch_size])
stacked_tensor = stacked_tensor.view(len(self._tasks), -1)
assert isinstance(stacked_tensor, torch.Tensor)
kwargs["required_inputs"]["target_tensor"] = stacked_tensor

predictions = (
# Reshape the predictions to size([len(self._tasks), self._batch_size])
predictions.view(len(self._tasks), -1)
Expand Down
136 changes: 96 additions & 40 deletions torchrec/metrics/tensor_weighted_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,31 @@ class TensorWeightedAvgMetricComputation(RecMetricComputation):

It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor
passed in as a required input instead of the predictions tensor.

FUSED_TASKS_COMPUTATION:
This class requires the target tensors of all tasks to be stacked together in RecMetric._update().
During TensorWeightedAvgMetricComputation.update(), the stacked tensor is read from required_inputs["target_tensor"] and summed per task.
"""

def __init__(
self,
*args: Any,
tensor_name: Optional[str] = None,
weighted: bool = True,
description: Optional[str] = None,
tasks: List[RecTaskInfo],
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
if tensor_name is None:
raise RecMetricException(
f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got {tensor_name}"
)
self.tensor_name: str = tensor_name
self.weighted: bool = weighted
self.tasks: List[RecTaskInfo] = tasks

for task in self.tasks:
if task.tensor_name is None:
raise RecMetricException(
"TensorWeightedAvgMetricComputation expects all tasks to have tensor_name, but got None."
)

self.weighted_mask: torch.Tensor = torch.tensor(
[task.weighted for task in self.tasks], device=self.device
).unsqueeze(dim=-1)

self._add_state(
"weighted_sum",
torch.zeros(self._n_tasks, dtype=torch.double),
Expand All @@ -61,7 +69,6 @@ def __init__(
dist_reduce_fx="sum",
persistent=True,
)
self._description = description

def update(
self,
Expand All @@ -71,26 +78,60 @@ def update(
weights: Optional[torch.Tensor],
**kwargs: Dict[str, Any],
) -> None:
if (
"required_inputs" not in kwargs
or self.tensor_name not in kwargs["required_inputs"]
):

target_tensor: torch.Tensor

if "required_inputs" not in kwargs:
raise RecMetricException(
f"TensorWeightedAvgMetricComputation expects {self.tensor_name} in the required_inputs"
"TensorWeightedAvgMetricComputation expects 'required_inputs' to exist."
)
else:
if len(self.tasks) > 1:
# In FUSED mode, RecMetric._update() stores the stacked tensor under "target_tensor".
# Note that RecMetric._update() only stacks (and only sets "target_tensor") when at least one task's tensor_name exists in kwargs["required_inputs"].
target_tensor = cast(
torch.Tensor,
kwargs["required_inputs"]["target_tensor"],
)
elif len(self.tasks) == 1:
# UNFUSED_TASKS_COMPUTATION
tensor_name = self.tasks[0].tensor_name
if tensor_name not in kwargs["required_inputs"]:
raise RecMetricException(
f"TensorWeightedAvgMetricComputation expects required_inputs to contain target tensor {self.tasks[0].tensor_name}"
)
else:
target_tensor = cast(
torch.Tensor,
kwargs["required_inputs"][tensor_name],
)
else:
raise RecMetricException(
"TensorWeightedAvgMetricComputation expects at least one task."
)

num_samples = labels.shape[0]
target_tensor = cast(torch.Tensor, kwargs["required_inputs"][self.tensor_name])
weights = cast(torch.Tensor, weights)

# Vectorized computation using masks
weighted_values = torch.where(
self.weighted_mask, target_tensor * weights, target_tensor
)

weighted_counts = torch.where(
self.weighted_mask, weights, torch.ones_like(weights)
)

# Sum across batch dimension
weighted_sum = weighted_values.sum(dim=-1) # Shape: (n_tasks,)
weighted_num_samples = weighted_counts.sum(dim=-1) # Shape: (n_tasks,)

# Update states
states = {
"weighted_sum": (
target_tensor * weights if self.weighted else target_tensor
).sum(dim=-1),
"weighted_num_samples": (
weights.sum(dim=-1)
if self.weighted
else torch.ones(weights.shape).sum(dim=-1).to(device=weights.device)
),
"weighted_sum": weighted_sum,
"weighted_num_samples": weighted_num_samples,
}

for state_name, state_value in states.items():
state = getattr(self, state_name)
state += state_value
Expand All @@ -105,7 +146,6 @@ def _compute(self) -> List[MetricComputationReport]:
cast(torch.Tensor, self.weighted_sum),
cast(torch.Tensor, self.weighted_num_samples),
),
description=self._description,
),
MetricComputationReport(
name=MetricName.WEIGHTED_AVG,
Expand All @@ -114,7 +154,6 @@ def _compute(self) -> List[MetricComputationReport]:
self.get_window_state("weighted_sum"),
self.get_window_state("weighted_num_samples"),
),
description=self._description,
),
]

Expand All @@ -126,23 +165,40 @@ class TensorWeightedAvgMetric(RecMetric):
def _get_task_kwargs(
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
) -> Dict[str, Any]:
if not isinstance(task_config, RecTaskInfo):
raise RecMetricException(
f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
)
all_tasks = (
[task_config] if isinstance(task_config, RecTaskInfo) else task_config
)
return {
"tensor_name": task_config.tensor_name,
"weighted": task_config.weighted,
"tasks": all_tasks,
}

def _get_task_required_inputs(
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
) -> Set[str]:
if not isinstance(task_config, RecTaskInfo):
raise RecMetricException(
f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
)
required_inputs = set()
if task_config.tensor_name is not None:
required_inputs.add(task_config.tensor_name)
return required_inputs
"""
Returns the required inputs for the task.

FUSED_TASKS_COMPUTATION:
- Given two tasks with the same tensor_name, assume the same tensor reference
- For a given tensor_name, all tasks must have the same weighted flag; a mismatch raises RecMetricException
"""
all_tasks = (
[task_config] if isinstance(task_config, RecTaskInfo) else task_config
)

required_inputs: dict[str, bool] = {}
for task in all_tasks:
if task.tensor_name is not None:
if (
task.tensor_name in required_inputs
and task.weighted is not required_inputs[task.tensor_name]
):
existing_weighted_flag = required_inputs[task.tensor_name]
raise RecMetricException(
f"This target tensor was already registered as weighted={existing_weighted_flag}. "
f"This target tensor cannot be re-registered with weighted={task.weighted}"
)
else:
required_inputs[str(task.tensor_name)] = task.weighted

return set(required_inputs.keys())
Loading
Loading