Skip to content

Conversation

@BilyHurington
Copy link
Contributor

@BilyHurington BilyHurington commented Aug 18, 2025

Description

1. Motivation and Context

This pull request introduces the Data Value Embedding (DVEmb) attributor, a method for calculating trajectory-specific data influence. Unlike existing methods that often overlook the temporal dynamics of model training, DVEmb captures how the influence of a data point evolves over time by creating epoch-specific embeddings, which allows for a more accurate analysis of data value.

2. Summary of the change

github issue

3. What tests have been added/updated for the change?

  • N/A: No test will be added (please justify)
  • Unit test: Typically, this should be included if you implemented a new function/fixed a bug.
  • Application test: If you wrote an example for the toolkit, this test should be added.
  • Document test: If you added an external API, then you should check if the document is correctly generated.

@tingwl0122 tingwl0122 linked an issue Aug 18, 2025 that may be closed by this pull request
@jiaqima
Copy link
Contributor

jiaqima commented Sep 21, 2025

Please address the failed tests.

@TheaperDeng
Copy link
Collaborator

Please address the failed tests.

We are still working on a performance test to reproduce the a small experiment's result presented on the paper.

@jiaqima
Copy link
Contributor

jiaqima commented Oct 19, 2025

Please add some tests.

@BilyHurington
Copy link
Contributor Author

Please add some tests.

I've added some tests. This is still a work in progress as we're currently adding the ghost product part.

Copy link
Collaborator

@TheaperDeng TheaperDeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have not fully completed the review.

def __init__(
self,
model: nn.Module,
loss_func: Callable,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to use AttributionTask here for model, loss_func's integration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

data_tensors: Tuple[Tensor, Tensor],
) -> Tensor:
inputs = data_tensors[0].unsqueeze(0)
targets = data_tensors[1].unsqueeze(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use AttributionTask (https://github.com/TRAIS-Lab/dattri/blob/main/dattri/task.py#L40) This will be a user defined function and we don't really need to assume user's data is a tuple of 2 tensors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

factorization_type: Type of gradient factorization to use. Options are
"none"(default),
"kronecker"(same with paper),
or "elementwise"(more memory-efficient).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is not directly related to memory since our behavior repect the proj_dim parameter as the main parameter to control the memory

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

loss_func: Callable,
device: str = "cpu",
proj_dim: Optional[int] = None,
factorization_type: str = "none",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may also have a layer_names parameter supported here so that we can allow users to define the layers they want the grad decomposition happens. Remember, not only nn.Linear is a linear layer :)

layer_names: Optional[
Union[str, List[str]]
] = None, # Maybe support layer class as input?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

project_input = self._generate_projector(
input_dim,
self.projection_dim,
)
Copy link
Collaborator

@TheaperDeng TheaperDeng Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the projector once #206 is merged.

)
projected_grads = per_sample_grads @ self.projector

scaling_factor = 1.0 / math.sqrt(self.projection_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a normalization in line 134 right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed. Now dvemb's score is essentially equal to groundtruth's without using factorization.

loss_func: Callable,
device: str = "cpu",
task: AttributionTask,
criterion: nn.Module,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you have already integrated the criterion into the loss_func of AttributionTask, it may be better to use get_loss_func to get it instead of redefining it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's hard to completely remove criterion. While the task does hold all the components, we need to access the functional loss_func and the stateful criterion for two different paths.
The functional task.get_grad_loss_func() is used for the full gradient path when factorization_type="none". However, when factorization is enabled, we must use _calculate_gradient_factors, which relies on backward hooks to capture the intermediate gradient factors. These hooks are only triggered by loss.backward(), requiring us to use the nn.Module directly after a standard model.forward(). Using the functional task.get_loss_func() in this case would bypass the hook system entirely.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your reply! You could use loss = task.get_loss_func(params, data) and then call loss.backward() to trigger the backward hook. (In some versions of PyTorch, torch.func.functional_call may bypass the forward hook, so you might want to be aware of that.)

By the way, you might consider optimizing calculate_gradient_factors — instead of performing an additional forward/backward pass, you could collect the factors during the parameter update step.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second that we should avoid the criterion argument. Factorization is also used in the LoGraAttributor, which did not use a separate criterion. Maybe take a look there about how to avoid this argument?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason is that we are using vmap and torch.func to calculate the gradient for non-factorization path and the autograd for factorization path of dvemb(we support both of them). And the loss function (target function) has a different formatting between these two. E.g.,

def f(params, data_target_pair):
image, label = data_target_pair
loss = nn.CrossEntropyLoss()
yhat = torch.func.functional_call(model_details["model"], params, image)
return loss(yhat, label.long())

def f(model, batch, device):
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
return nn.functional.cross_entropy(outputs, targets)

One we we can do is to use task and remove this criterion for now, and detect if a false format is given by the user and provide clear instructions in document and error messages.

)

def _run_dvemb_simulation(self, attributor: DVEmbAttributor):
"""A generic simulation runner for any configured DVEmbAttributor."""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be fine to collect gradients for the training dataset just to verify that the algorithm runs correctly without updating the model parameters in the unit test, but it would be good to add a comment clarifying that this isn’t the correct way to collect gradients.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I've added the comment clarifying that.

@TheaperDeng TheaperDeng changed the title [WIP] DVEmbAttributor DVEmbAttributor Nov 16, 2025

from dattri.benchmark.load import load_benchmark

"""This file define the MLP model."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring about this file should be put at the top. Please elaborate a bit more about this file. "This file defines the MLP model" is inaccurate about this file.

loss_func: Callable,
device: str = "cpu",
task: AttributionTask,
criterion: nn.Module,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second that we should avoid the criterion argument. Factorization is also used in the LoGraAttributor, which did not use a separate criterion. Maybe take a look there about how to avoid this argument?

@@ -0,0 +1,199 @@
"""Example code to compute Leave-One-Out (LOO) scores on MLP trained on MNIST dataset.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update this docstrng.

Copy link
Collaborator

@TheaperDeng TheaperDeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have updated the API structure and added descriptive in-line comments.

DVEmbAttributor(
task=self.task_eager,
factorization_type="none",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a unit test to verify the validation of invalid loss function formats.

If None, uses all Linear layers.
You can check the names using model.named_modules().
Hooks will be registered on these layers to collect gradients.
Only available when factorization_type is not "none".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Will only be used when ..."

factorization_type: Type of gradient factorization to use. Options are
"none" (default),
"kronecker" (same as in the paper),
or "elementwise" (better performance while
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe elaborate a little bit about what "elementwise" does?

ValueError: If embeddings for the specified `epoch` are not found.
"""
if not self.embeddings:
msg = "Embeddings not computed. Call compute_embeddings first."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call compute_embeddings in self.cache() for consistency with other attributors? And in this error message we can ask the user to call cache()

Comment on lines 214 to 231
def fwd_hook(idx: int) -> Callable:
def _hook(
layer: nn.Module,
inputs: Tuple[Tensor, ...],
_output: Tensor,
) -> None:
a = inputs[0].detach()

if a.dim() > 2: # noqa: PLR2004
a = a.reshape(a.size(0), -1)
caches[idx]["A"] = a
caches[idx]["has_bias"] = layer.bias is not None

return _hook
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Mixed Precision, this hook forces the retention of fp32 input activations. Since Autograd saves a separate bf16 copy for the backward pass, this hook prevents the fp32 tensor from being freed, significantly increasing memory usage. To fix this, we need to use saved_tensors_hooks instead, which is non-trivial (see my recent revision to GhostSuite). For now, I suggest to make a note here and in the README.

@TheaperDeng TheaperDeng merged commit f7d41d1 into TRAIS-Lab:main Jan 27, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implementation of Data Value Embedding (DVEmb)

5 participants