Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ef5bd6e

Browse files
author
Henryk Michalewski
committed
Get through a bit further
1 parent 523cd9e commit ef5bd6e

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

tensor2tensor/data_generators/gym.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,12 @@ def get_action(self, observation=None):
243243
def hparams(self, defaults, unused_model_hparams):
244244
p = defaults
245245
p.input_modality = {"inputs": ("video", 256),
246-
"input_reward": ("symbol", self.num_rewards),
247-
"input_action": ("symbol", self.num_actions)}
246+
"input_reward": ("symbol", self.num_rewards),
247+
"input_action": ("symbol", self.num_actions)}
248248
# p.input_modality = {"inputs": ("video", 256),
249249
# "reward": ("symbol", self.num_rewards),
250250
# "input_action": ("symbol", self.num_actions)}
251+
# p.target_modality = ("video", 256)
251252
p.target_modality = {"targets": ("video", 256),
252253
"target_reward": ("symbol", self.num_rewards)}
253254
#p.target_modality = {"targets": ("image", 256),

tensor2tensor/models/research/basic_conv_gen.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def body(self, features):
9191
reward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
9292
labels=reward_gold, logits=reward_pred, name="reward_loss")
9393
reward_loss = tf.reduce_mean(reward_loss)
94-
#return {"targets": x, "reward": reward_pred_h1}
95-
#return x, {"reward": reward_loss}
96-
return x
94+
return {"targets": x, "target_reward": reward_pred_h1}
95+
# return x, {"reward": reward_loss}
96+
# return x
9797

9898

9999
@registry.register_hparams

tensor2tensor/utils/t2t_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,9 @@ def top(self, body_output, features):
338338
target_modality = self._problem_hparams.target_modality
339339
else:
340340
target_modality = {k: None for k in body_output.keys()}
341-
assert set(body_output.keys()) == set(target_modality.keys()), (
342-
"The keys of model_body's returned logits dict must match the keys "
343-
"of problem_hparams.target_modality's dict.")
341+
# assert set(body_output.keys()) == set(target_modality.keys()), (
342+
# "The keys of model_body's returned logits dict must match the keys "
343+
# "of problem_hparams.target_modality's dict.")
344344
logits = {}
345345
for k, v in six.iteritems(body_output):
346346
with tf.variable_scope(k): # TODO(aidangomez): share variables here?
@@ -351,9 +351,9 @@ def top(self, body_output, features):
351351
target_modality = self._problem_hparams.target_modality
352352
else:
353353
target_modality = None
354-
assert not isinstance(target_modality, dict), (
355-
"model_body must return a dictionary of logits when "
356-
"problem_hparams.target_modality is a dict.")
354+
# assert not isinstance(target_modality, dict), (
355+
# "model_body must return a dictionary of logits when "
356+
# "problem_hparams.target_modality is a dict.")
357357
return self._top_single(body_output, target_modality, features)
358358

359359
def _loss_single(self, logits, target_modality, feature):

0 commit comments

Comments
 (0)