Open
Labels: question (User queries)
Description
In the LoRA implementation, the API is set up so that a `LoraArray` has to report an abstract value consistent with the dense array it stands in for:
quax/quax/examples/lora/_core.py, lines 89 to 90 in b72049d:

```python
def aval(self):
    return jax.core.ShapedArray(self.w.shape, self.w.dtype)
```
However, this becomes a nuisance when counting parameters later on. The PyTree hides the actual internal parameters of the `LoraArray`, so a straightforward count gives a large overestimate, and one has to resort to ugly workarounds to get the real number.
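To make the size of the overestimate concrete, here is a minimal sketch in plain NumPy rather than JAX. `HypotheticalLoraArray` and the three `count_*` helpers are hypothetical illustrations, not part of quax. The sketch only assumes that a LoRA-wrapped weight stores a frozen `w` plus low-rank factors `a` and `b`, and that it advertises the shape of `w`, as the `aval` in the excerpt does.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class HypotheticalLoraArray:
    """Hypothetical stand-in for a LoRA-wrapped weight (w + a @ b).

    Assumption: only `a` and `b` are trained; `w` is frozen.
    """

    w: np.ndarray  # frozen full weight, shape (rows, cols)
    a: np.ndarray  # low-rank factor, shape (rows, rank)
    b: np.ndarray  # low-rank factor, shape (rank, cols)

    @property
    def shape(self):
        # Mirrors the aval in the excerpt: advertise the shape of the full weight.
        return self.w.shape


def count_by_advertised_shape(x):
    # Treats the LoRA array as if it were a dense array of its advertised shape.
    return int(np.prod(x.shape))


def count_all_leaves(x):
    # Counts every stored array, frozen or not.
    return x.w.size + x.a.size + x.b.size


def count_trainable(x):
    # Counts only the low-rank factors.
    return x.a.size + x.b.size


rows, cols, rank = 1024, 1024, 8
x = HypotheticalLoraArray(
    w=np.zeros((rows, cols), dtype=np.float32),
    a=np.zeros((rows, rank), dtype=np.float32),
    b=np.zeros((rank, cols), dtype=np.float32),
)
print(count_by_advertised_shape(x))  # 1048576
print(count_all_leaves(x))           # 1064960
print(count_trainable(x))            # 16384
```

With these sizes, only 16,384 of the 1,064,960 stored values are trainable, so either naive count is off by a factor of roughly 64.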
How do you think this can be resolved?