[Algo] Added sac codebase #5
base: dev
_has_functorch = False
class SACLoss(LossModule):
How does that differ from the TorchRL SAC exactly? If there's an extra feature I'd prefer to add it to torchrl directly, wdyt?
I used the local sac_loss because TorchRL's sac.py requires you to pass three networks: actor, qvalue, and value. As far as I remember, SAC doesn't have a value function.
I implemented that by rigorously following what the paper presented, but if it works better with only one net, we can add that as an option.
This is SAC-v1. I think SAC-v2 (https://arxiv.org/abs/1812.05905) is more commonly used. Check out Section 4.2 in the paper.
In my opinion, it is worth adding the v2 implementation of SACLoss.
I agree. My point is mainly that rather than coding up a new SAC, we should simply add the v2 to the SAC loss. As it is now, we're sort of saying "TorchRL has everything you need... but they got SAC wrong so here's a patch"
Can you have a look at pytorch/rl#864?
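For readers following this thread, here is a minimal sketch of the distinction being discussed. It is not the TorchRL implementation or the code in this PR. It assumes the usual formulation in which SAC-v1 bootstraps the critic from a separate state-value network, while SAC-v2 drops that network and bootstraps from the minimum of two target Q-values minus an entropy term. The function and parameter names (`sac_v1_q_target`, `sac_v2_q_target`, `alpha`, and so on) are hypothetical and operate on plain floats for a single transition.

```python
def sac_v1_q_target(reward, done, gamma, next_state_value):
    """Critic target when a separate value network V(s') is available.

    Hypothetical helper: next_state_value stands in for the output of
    the (target) value network evaluated at the next state.
    """
    not_done = 0.0 if done else 1.0
    return reward + gamma * not_done * next_state_value


def sac_v2_q_target(reward, done, gamma, alpha, next_q1, next_q2, next_log_prob):
    """Critic target without a value network.

    Hypothetical helper: next_q1 and next_q2 stand in for two target
    Q-network outputs at (s', a'), where a' is sampled from the current
    policy and next_log_prob is log pi(a' | s').
    """
    not_done = 0.0 if done else 1.0
    # Soft value estimate: pessimistic Q minus the entropy penalty.
    soft_value = min(next_q1, next_q2) - alpha * next_log_prob
    return reward + gamma * not_done * soft_value


# Example transition (made-up numbers, for illustration only).
v1_target = sac_v1_q_target(reward=1.0, done=False, gamma=0.9, next_state_value=2.0)
v2_target = sac_v2_q_target(
    reward=1.0, done=False, gamma=0.9, alpha=0.2,
    next_q1=2.0, next_q2=3.0, next_log_prob=-1.0,
)
```

Under these assumptions, only the v1 path needs the third network, which is why a v2 option on the existing loss would take just the actor and the Q-value networks.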
_has_tv = False
class _RRLNet(Transform):
I don't really see why we need a new env for this. We could create R3M with download=False and load the state dict from torchvision, no?
I am not 100% sure if the architecture of R3M is different from ResNet torchvision module. Plus, I think this is a cleaner way to do it, but we can switch to loading weights if you prefer.
"I am not 100% sure if the architecture of R3M is different from ResNet torchvision module"
What would be different? The only thing that pretrained=True does is load a state_dict; the architecture is 100% the same.
Have a look at my PR on torchrl.
Oh cool. I have never tested the R3M backbone against the ResNet backbone, but they might be exactly the same. Thanks! I will take a look and update the code.


