Pass the current epoch to the aggregate function of Weighting Interface #617
Closed
GiovanniCanali started this conversation in Ideas
Replies: 1 comment
We solved this in #620 by adding a link to the solver. Having a link to the solver also gives access to the trainer for free (…
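The reply's idea is that `aggregate` does not need a new argument if the weighting can reach the trainer through the solver. Below is a minimal, self-contained sketch of that idea. The classes `Trainer`, `Solver` and `LinkedWeighting`, the `solver` attribute and the `training_step` method are all hypothetical stand-ins, not the code merged in #620. The only borrowed convention is a `current_epoch` attribute on the trainer, modeled after the one PyTorch Lightning's `Trainer` exposes.

```python
class Trainer:
    """Hypothetical stand-in for a trainer that exposes `current_epoch`."""

    def __init__(self):
        self.current_epoch = 0


class Solver:
    """Hypothetical solver that owns a trainer and a weighting."""

    def __init__(self, weighting, trainer):
        self.trainer = trainer
        self.weighting = weighting
        # Give the weighting a back-reference (the "link") to the solver
        weighting.solver = self

    def training_step(self, losses):
        # The aggregate signature is unchanged: only the losses are passed
        return self.weighting.aggregate(losses)


class LinkedWeighting:
    """Hypothetical weighting that reads the epoch through the solver link."""

    def __init__(self, k=2):
        self.k = k
        self.solver = None
        self.updated_epochs = []

    def aggregate(self, losses):
        epoch = self.solver.trainer.current_epoch
        # Refresh at most once per qualifying epoch, however many batches it has
        already_done = bool(self.updated_epochs) and self.updated_epochs[-1] == epoch
        if epoch % self.k == 0 and not already_done:
            self.updated_epochs.append(epoch)  # an expensive update would go here
        return sum(losses.values())


trainer = Trainer()
solver = Solver(LinkedWeighting(k=2), trainer)
for epoch in range(5):
    trainer.current_epoch = epoch  # the trainer, not the weighting, tracks epochs
    for batch in range(3):
        solver.training_step({"physics": 0.5, "boundary": 0.25})

print(solver.weighting.updated_epochs)  # [0, 2, 4]
```

In this design the epoch has a single source of truth (the trainer), so the weighting cannot drift out of sync with it. The same link also exposes any other trainer state a future weighting scheme might need.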
Currently, the `aggregate` method of any class inheriting from `WeightingInterface` accepts only `losses`, a dictionary storing the loss per condition. In principle, this should be enough. However, some weighting schemes require updating the weights, which can be computationally expensive: for instance, self-adaptive weighting as described in "Simulating Three-dimensional Turbulence with Physics-informed Neural Networks" (see `PirateNetwork`). In such cases, it may be preferable to perform these computations only every `k` epochs using a simple modulus operation.

Unfortunately, this is not currently possible: there is no straightforward way for the `WeightingInterface` to be aware of the current epoch. Simple internal counters partially address the problem, but they fail when batching is involved, because `aggregate` is called once per batch and each epoch would therefore be counted multiple times.

Therefore, I propose adding the current epoch as an argument to the `aggregate` method of `WeightingInterface`, so that the solver can pass it in directly.
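To make the proposal concrete, here is a minimal sketch contrasting an internal call counter with an epoch-aware `aggregate(losses, epoch)`. Everything in it is hypothetical: the class names, the `k`, `num_updates` and `last_update_epoch` attributes, and the inverse-loss update rule (a placeholder, not the self-adaptive scheme cited above). Losses are plain floats instead of tensors so the example runs without any dependencies.

```python
class NaiveCounterWeighting:
    """Hypothetical weighting with an internal counter: it counts aggregate() calls."""

    def __init__(self):
        self.calls = 0

    def aggregate(self, losses):
        # Incremented once per batch, so it overcounts epochs when batching is used
        self.calls += 1
        return sum(losses.values())


class EpochAwareWeighting:
    """Hypothetical weighting that refreshes its weights only every `k` epochs."""

    def __init__(self, k=10):
        self.k = k
        self.weights = {}
        self.last_update_epoch = None
        self.num_updates = 0

    def _update_weights(self, losses):
        # Placeholder for an expensive update (e.g. self-adaptive weighting);
        # inverse-loss weights are used here purely for illustration
        self.weights = {name: 1.0 / max(value, 1e-12) for name, value in losses.items()}
        self.num_updates += 1

    def aggregate(self, losses, epoch):
        # The modulus check alone would fire on every batch of a qualifying epoch,
        # so also remember which epoch was last updated
        if epoch % self.k == 0 and epoch != self.last_update_epoch:
            self._update_weights(losses)
            self.last_update_epoch = epoch
        # Conditions without a weight yet default to 1.0
        return sum(self.weights.get(name, 1.0) * loss for name, loss in losses.items())


weighting = EpochAwareWeighting(k=2)
naive = NaiveCounterWeighting()
for epoch in range(5):
    for batch in range(4):  # four batches per epoch
        losses = {"physics": 0.5, "boundary": 0.25}
        weighting.aggregate(losses, epoch)
        naive.aggregate(losses)

print(weighting.num_updates)  # 3 (epochs 0, 2 and 4)
print(naive.calls)  # 20, i.e. batches rather than epochs
```

Note that `epoch % k == 0` on its own is not quite enough with batching, because every batch of a qualifying epoch passes the check. Remembering the last updated epoch keeps the expensive step to once per qualifying epoch.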