Skip to content

feat: add gumbel_softmax #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions scripts/configs/hpl/discrete/gym.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ algorithm:
reg_coef: 0.0
discrete: true
discrete_group: 8
stoc_encoding: false # for hopper-medium-expert, try true
gumbel_softmax: true
temperature: 1.0
stoc_encoding: false # for hopper-medium-replay, try true
rm_label: true

checkpoint: null
Expand All @@ -29,7 +31,7 @@ wandb:
entity: null
project: null

env: hopper-medium-expert-v2
env: hopper-medium-replay-v2
env_kwargs:
env_wrapper:
env_wrapper_kwargs:
Expand Down Expand Up @@ -79,17 +81,17 @@ network:

rm_dataset:
- class: D4RLOfflineDataset
env: hopper-medium-expert-v2
env: hopper-medium-replay-v2
batch_size: 64 # [64, 128]
mode: trajectory
segment_length: 100
padding_mode: none
- class: IPLComparisonOfflineDataset
env: hopper-medium-expert-v2
env: hopper-medium-replay-v2
batch_size: 8
mode: human
- class: D4RLOfflineDataset
env: hopper-medium-expert-v2
env: hopper-medium-replay-v2
batch_size: 512
mode: transition
rm_dataloader:
Expand All @@ -98,7 +100,7 @@ rm_dataloader:

rl_dataset:
- class: D4RLOfflineDataset
env: hopper-medium-expert-v2
env: hopper-medium-replay-v2
batch_size: 512
mode: transition
rl_dataloader:
Expand All @@ -118,7 +120,7 @@ rm_eval:
function: eval_reward_model
eval_dataset_kwargs:
class: IPLComparisonOfflineDataset
env: hopper-medium-expert-v2
env: hopper-medium-replay-v2
batch_size: 32
mode: human
eval: false
Expand Down
2 changes: 2 additions & 0 deletions scripts/configs/hpl/discrete/metaworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ algorithm:
reg_coef: 0.0001
discrete: true
discrete_group: 8
gumbel_softmax: true
temperature: 1.0
stoc_encoding: true
rm_label: true

Expand Down
19 changes: 15 additions & 4 deletions wiserl/algorithm/hpl/hpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from wiserl.module.net.mlp import MLP
from wiserl.utils.functional import expectile_regression
from wiserl.utils.misc import make_target, sync_target
from wiserl.utils.distributions import OneHotCategoricalSTGumbelSoftmax


class Decoder(nn.Module):
Expand Down Expand Up @@ -70,6 +71,8 @@ def __init__(
stoc_encoding: bool = True,
discrete: bool = True,
discrete_group: int = 8,
gumbel_softmax: bool = False,
temperature: float = 1.0,
**kwargs
):
self.expectile = expectile
Expand All @@ -89,6 +92,8 @@ def __init__(
self.stoc_encoding = stoc_encoding
self.discrete = discrete
self.discrete_group = discrete_group
self.gumbel_softmax = gumbel_softmax
self.temperature = temperature
self.rm_label = rm_label
super().__init__(*args, **kwargs)
# define the attention mask for future prediction
Expand Down Expand Up @@ -200,10 +205,16 @@ def select_reward(self, batch, deterministic=False):
def get_z_distribution(self, logits):
if self.discrete:
logits = logits.reshape(*logits.shape[:-1], self.discrete_group, -1)
return torch.distributions.Independent(
torch.distributions.OneHotCategoricalStraightThrough(logits=logits),
reinterpreted_batch_ndims=1
)
if self.gumbel_softmax:
return torch.distributions.Independent(
OneHotCategoricalSTGumbelSoftmax(self.temperature, logits=logits),
reinterpreted_batch_ndims=1
)
else:
return torch.distributions.Independent(
torch.distributions.OneHotCategoricalStraightThrough(logits=logits),
reinterpreted_batch_ndims=1
)
else:
mean, logstd = logits.chunk(2, dim=-1)
return torch.distributions.Independent(
Expand Down
33 changes: 31 additions & 2 deletions wiserl/utils/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import numpy as np
import torch
from torch.distributions import Normal

from torch.distributions import Normal, OneHotCategorical
from torch.distributions.utils import clamp_probs

class TanhNormal(Normal):
def __init__(self,
Expand Down Expand Up @@ -40,3 +40,32 @@ def entropy(self):
@property
def tanh_mean(self):
return torch.tanh(self.mean)


class OneHotCategoricalSTGumbelSoftmax(OneHotCategorical):
r"""
Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
through gumbel-softmax estimator from [1].

[1] Categorical Reparametrization with Gumbel-Softmax
(Jang et al, 2017)
"""
has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
super().__init__(probs, logits, validate_args=validate_args)
self.temperature = temperature

def gumbel_softmax_sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(
torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
)
gumbels = -((-(uniforms.log())).log())
y = self.logits + gumbels
return torch.nn.functional.softmax(y / self.temperature, dim=-1)

def rsample(self, sample_shape=torch.Size()):
samples = self.sample(sample_shape)
gumbel_softmax_samples = self.gumbel_softmax_sample(sample_shape)
return samples + (gumbel_softmax_samples - gumbel_softmax_samples.detach())