Fix: CQL ood regularization #7
I'm pretty sure the CQL OOD loss is not implemented correctly:
a) The softmax is not backpropagated through, so no penalty is actually applied to OOD actions.
b) The softmax reduces over the ensemble dimension rather than over the batch.

Additionally, I'm not entirely sure whether next_pi_q should really be evaluated at the next observations. While that technically makes more sense, CORL evaluates the actions at batch.obs, and they seem to have been able to replicate the CQL performance (albeit with some difficulty - tinkoff-ai/CORL#14).
I think taking the next_actions when estimating the logsumexp for (s, a) is just something that helped during training. It's a little odd, sure, but I think we should keep batch.obs there.
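For reference, here's a minimal sketch of how I'd expect the conservative penalty to look. The tensor names and shapes (`q_ood`, `q_data`, an `(ensemble, batch, n_samples)` layout) are hypothetical and won't match this repo's code exactly; the point is that the logsumexp runs over the sampled-action dimension, not the ensemble one, and that nothing is detached so gradients reach the critic:

```python
import torch

def cql_ood_penalty(q_ood: torch.Tensor, q_data: torch.Tensor) -> torch.Tensor:
    """Conservative penalty: push down Q at sampled OOD actions, push up at data actions.

    q_ood:  (ensemble, batch, n_samples) critic values at sampled OOD actions
    q_data: (ensemble, batch) critic values at dataset actions
    """
    # logsumexp over the sampled-action dimension (dim=-1), NOT the ensemble
    # dimension, and without .detach(), so the penalty is backpropagated.
    lse = torch.logsumexp(q_ood, dim=-1)  # (ensemble, batch)
    return (lse - q_data).mean()

# Toy check that gradients actually flow to the OOD Q-values.
torch.manual_seed(0)
q_ood = torch.randn(2, 4, 10, requires_grad=True)  # 2 critics, batch 4, 10 samples
q_data = torch.randn(2, 4)
loss = cql_ood_penalty(q_ood, q_data)
loss.backward()
print(q_ood.grad.abs().sum().item() > 0)  # True: the penalty reaches the critic
```

With the current code, reducing over the ensemble dimension would mix the critics' estimates instead of comparing the sampled actions against each other, which is what makes the penalty conservative per state.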
To illustrate, here's a simple 1D bandit benchmark with the estimated Q-values before/after the fixes.
Thank you for the great library!