diff --git a/scripts/configs/hpl/discrete/gym.yaml b/scripts/configs/hpl/discrete/gym.yaml
index a9fcf1d..80f77d8 100644
--- a/scripts/configs/hpl/discrete/gym.yaml
+++ b/scripts/configs/hpl/discrete/gym.yaml
@@ -16,7 +16,9 @@ algorithm:
   reg_coef: 0.0
   discrete: true
   discrete_group: 8
-  stoc_encoding: false  # for hopper-medium-expert, try true
+  gumbel_softmax: true
+  temperature: 1.0
+  stoc_encoding: false  # for hopper-medium-replay, try true
   rm_label: true
 
 checkpoint: null
@@ -29,7 +31,7 @@ wandb:
   entity: null
   project: null
 
-env: hopper-medium-expert-v2
+env: hopper-medium-replay-v2
 env_kwargs:
 env_wrapper:
 env_wrapper_kwargs:
@@ -79,17 +81,17 @@ network:
 
 rm_dataset:
   - class: D4RLOfflineDataset
-    env: hopper-medium-expert-v2
+    env: hopper-medium-replay-v2
     batch_size: 64  # [64, 128]
     mode: trajectory
     segment_length: 100
     padding_mode: none
   - class: IPLComparisonOfflineDataset
-    env: hopper-medium-expert-v2
+    env: hopper-medium-replay-v2
     batch_size: 8
     mode: human
   - class: D4RLOfflineDataset
-    env: hopper-medium-expert-v2
+    env: hopper-medium-replay-v2
     batch_size: 512
     mode: transition
 rm_dataloader:
@@ -98,7 +100,7 @@ rm_dataloader:
 
 rl_dataset:
   - class: D4RLOfflineDataset
-    env: hopper-medium-expert-v2
+    env: hopper-medium-replay-v2
     batch_size: 512
     mode: transition
 rl_dataloader:
@@ -118,7 +120,7 @@ rm_eval:
   function: eval_reward_model
   eval_dataset_kwargs:
     class: IPLComparisonOfflineDataset
-    env: hopper-medium-expert-v2
+    env: hopper-medium-replay-v2
     batch_size: 32
     mode: human
 eval: false
diff --git a/scripts/configs/hpl/discrete/metaworld.yaml b/scripts/configs/hpl/discrete/metaworld.yaml
index 3c60ed9..61a678b 100644
--- a/scripts/configs/hpl/discrete/metaworld.yaml
+++ b/scripts/configs/hpl/discrete/metaworld.yaml
@@ -16,6 +16,8 @@ algorithm:
   reg_coef: 0.0001
   discrete: true
   discrete_group: 8
+  gumbel_softmax: true
+  temperature: 1.0
   stoc_encoding: true
   rm_label: true
diff --git a/wiserl/algorithm/hpl/hpl.py b/wiserl/algorithm/hpl/hpl.py
index 557efd7..b28828f 100644
--- a/wiserl/algorithm/hpl/hpl.py
+++ b/wiserl/algorithm/hpl/hpl.py
@@ -16,6 +16,7 @@ from wiserl.module.net.mlp import MLP
 from wiserl.utils.functional import expectile_regression
 from wiserl.utils.misc import make_target, sync_target
+from wiserl.utils.distributions import OneHotCategoricalSTGumbelSoftmax
 
 
 class Decoder(nn.Module):
@@ -70,6 +71,8 @@ def __init__(
         stoc_encoding: bool = True,
         discrete: bool = True,
         discrete_group: int = 8,
+        gumbel_softmax: bool = False,
+        temperature: float = 1.0,
         **kwargs
     ):
         self.expectile = expectile
@@ -89,6 +92,8 @@ def __init__(
         self.stoc_encoding = stoc_encoding
         self.discrete = discrete
         self.discrete_group = discrete_group
+        self.gumbel_softmax = gumbel_softmax
+        self.temperature = temperature
         self.rm_label = rm_label
         super().__init__(*args, **kwargs)
         # define the attention mask for future prediction
@@ -200,10 +205,16 @@ def select_reward(self, batch, deterministic=False):
     def get_z_distribution(self, logits):
         if self.discrete:
             logits = logits.reshape(*logits.shape[:-1], self.discrete_group, -1)
-            return torch.distributions.Independent(
-                torch.distributions.OneHotCategoricalStraightThrough(logits=logits),
-                reinterpreted_batch_ndims=1
-            )
+            if self.gumbel_softmax:
+                return torch.distributions.Independent(
+                    OneHotCategoricalSTGumbelSoftmax(self.temperature, logits=logits),
+                    reinterpreted_batch_ndims=1
+                )
+            else:
+                return torch.distributions.Independent(
+                    torch.distributions.OneHotCategoricalStraightThrough(logits=logits),
+                    reinterpreted_batch_ndims=1
+                )
         else:
             mean, logstd = logits.chunk(2, dim=-1)
             return torch.distributions.Independent(
diff --git a/wiserl/utils/distributions.py b/wiserl/utils/distributions.py
index 141ecad..dea4c41 100644
--- a/wiserl/utils/distributions.py
+++ b/wiserl/utils/distributions.py
@@ -3,8 +3,8 @@
 
 import numpy as np
 import torch
-from torch.distributions import Normal
-
+from torch.distributions import Normal, OneHotCategorical
+from torch.distributions.utils import clamp_probs
 
 class TanhNormal(Normal):
     def __init__(self,
@@ -40,3 +40,32 @@ def entropy(self):
     @property
     def tanh_mean(self):
         return torch.tanh(self.mean)
+
+
+class OneHotCategoricalSTGumbelSoftmax(OneHotCategorical):
+    r"""
+    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the
+    straight-through Gumbel-softmax estimator from [1].
+
+    [1] Categorical Reparameterization with Gumbel-Softmax
+        (Jang et al., 2017)
+    """
+    has_rsample = True
+
+    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
+        super().__init__(probs, logits, validate_args=validate_args)
+        self.temperature = temperature
+
+    def gumbel_softmax_sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        uniforms = clamp_probs(
+            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
+        )
+        gumbels = -((-(uniforms.log())).log())
+        y = self.logits + gumbels
+        return torch.nn.functional.softmax(y / self.temperature, dim=-1)
+
+    def rsample(self, sample_shape=torch.Size()):
+        samples = self.sample(sample_shape)
+        gumbel_softmax_samples = self.gumbel_softmax_sample(sample_shape)
+        return samples + (gumbel_softmax_samples - gumbel_softmax_samples.detach())
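A quick, illustrative sanity check (not part of the patch): the sketch below exercises the new `OneHotCategoricalSTGumbelSoftmax` the same way `get_z_distribution` wraps it in `Independent`. The tensor shapes and the dummy loss are assumptions chosen for the example; `temperature=1.0` mirrors the updated configs. The forward value of `rsample()` should be a hard one-hot vector per group, while gradients reach the logits through the soft Gumbel-softmax relaxation.

```python
# Illustrative sketch only; shapes and the dummy loss are assumptions, not repo code.
import torch

from wiserl.utils.distributions import OneHotCategoricalSTGumbelSoftmax

# e.g. a batch of 4 latents, 8 groups (discrete_group), 16 categories per group
logits = torch.randn(4, 8, 16, requires_grad=True)
dist = torch.distributions.Independent(
    OneHotCategoricalSTGumbelSoftmax(1.0, logits=logits),  # temperature=1.0 as in the configs
    reinterpreted_batch_ndims=1,
)

z = dist.rsample()
# Forward value is a hard one-hot sample per group ...
assert torch.all((z == 0) | (z == 1)) and torch.all(z.sum(dim=-1) == 1)
# ... while gradients flow to the logits through the soft Gumbel-softmax relaxation.
(z * torch.randn_like(z)).sum().backward()
assert logits.grad is not None
```

One design note on `rsample()`: the hard value comes from `sample()`, which draws its own noise independent of the Gumbel noise used by `gumbel_softmax_sample()`. This differs from `torch.nn.functional.gumbel_softmax(..., hard=True)`, which hardens the same relaxed sample; here only the gradient path uses the Gumbel-softmax relaxation, and the forward sample is decoupled from it.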