Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,6 @@ wandb/
.jpg
.tmp_listen*
.tmp_speak*
saved_datasets/
saved_datasets/
env.json
mbodied/agents/motion/rt_pali/checkpoints/*
4 changes: 4 additions & 0 deletions mbodied/agents/motion/rt_pali/__about__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: 2024-present Tilak1114 <[email protected]>
#
# SPDX-License-Identifier: MIT
__version__ = "0.0.1"
3 changes: 3 additions & 0 deletions mbodied/agents/motion/rt_pali/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present Tilak1114 <[email protected]>
#
# SPDX-License-Identifier: MIT
Empty file.
94 changes: 94 additions & 0 deletions mbodied/agents/motion/rt_pali/action_tokenizer/action_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import json

import torch


class ActionTokenizer:
    """Converts pose dictionaries to and from strings of action tokens.

    Each pose value is first mapped to a discrete level named ``ra_<index>``
    (``index`` in ``0 .. bins - 1``) and then translated to a vocabulary token
    through ``ra_to_token_map``.
    """

    # Order in which ``detokenize`` assigns decoded tokens to pose fields.
    ACTION_KEYS = ("terminated", "x", "y", "z", "roll", "pitch", "yaw", "grasp")

    # Fields that are thresholded to the lowest/highest bin instead of quantized.
    BINARY_KEYS = frozenset({"grasp", "terminated"})

    def __init__(self, token_map_path=None):
        """Initializes the ActionTokenizer.

        Loads the ra_to_token_map from a JSON file and sets the total number of bins for discretization.

        Args:
            token_map_path (str | os.PathLike | None): Optional path to the JSON file mapping
                ``ra_<index>`` names to tokens. When omitted, ``ra_to_token_map.json`` located
                next to this module is used, so loading does not depend on the current
                working directory.

        Raises:
            FileNotFoundError: If the mapping file does not exist.
        """
        from pathlib import Path

        self.bins = 256
        if token_map_path is None:
            token_map_path = Path(__file__).resolve().parent / "ra_to_token_map.json"
        with open(token_map_path, "r", encoding="utf-8") as f:
            self.ra_to_token_map = json.load(f)
        # Inverse lookup used by ``detokenize``; built once instead of on every call.
        self._token_to_ra_map = {token: ra for ra, token in self.ra_to_token_map.items()}

    def discretize_values(self, pose_data: dict) -> dict:
        """Discretizes continuous pose data into discrete bins.

        Args:
            pose_data (dict): Dictionary containing pose data with keys like 'grasp', 'terminated',
                and other positional and orientation values that need to be discretized.

        Returns:
            dict: Dictionary with the same keys (in the same order) whose values are
                ``"ra_<bin_index>"`` strings.
        """
        discrete_data = {}
        # Step between adjacent bins: bin 0 <-> 0.0 and bin ``bins - 1`` <-> 1.0.
        scale = 1 / (self.bins - 1)

        for key, scaled_value in pose_data.items():
            if key in self.BINARY_KEYS:
                # Ensure grasp/terminated are binary: lowest bin or highest bin.
                bin_index = (self.bins - 1) if scaled_value > 0.5 else 0
            else:
                # Quantize the scaled value to a bin index. quint8 saturates, so
                # the resulting index always lies in 0..255.
                quantized_tensor = torch.quantize_per_tensor(
                    torch.tensor([scaled_value], dtype=torch.float32),
                    scale=scale,
                    zero_point=0,
                    dtype=torch.quint8,
                )
                bin_index = int(quantized_tensor.int_repr().item())
            discrete_data[key] = f"ra_{bin_index}"

        return discrete_data

    def reverse_discretize_values(self, discrete_data: dict) -> dict:
        """Converts discrete values back to continuous values.

        Args:
            discrete_data (dict): Dictionary containing discretized pose data
                (``"ra_<bin_index>"`` strings).

        Returns:
            dict: Dictionary containing the continuous pose data, each value being
                ``bin_index / (bins - 1)``.
        """
        return {
            # The bin index is whatever follows the last underscore of the token.
            key: int(token.rsplit("_", 1)[1]) / (self.bins - 1)
            for key, token in discrete_data.items()
        }

    def tokenize(self, pose_data: dict) -> str:
        """Converts pose data into a string of action tokens.

        Tokens are emitted in the iteration order of ``pose_data``. For the result to
        round-trip through ``detokenize``, the keys must be supplied in ``ACTION_KEYS`` order.

        Args:
            pose_data (dict): Dictionary containing pose data to be tokenized.

        Returns:
            str: A space-separated string of action tokens.
        """
        discretized_data = self.discretize_values(pose_data)
        return " ".join(self.ra_to_token_map[ra_tkn] for ra_tkn in discretized_data.values())

    def detokenize(self, tokens: list) -> dict:
        """Converts a list of action tokens back into continuous pose data.

        Tokens are paired positionally with ``ACTION_KEYS``; surplus tokens are ignored
        and a shorter list yields only the leading keys.

        Args:
            tokens (list): List of action tokens.

        Returns:
            dict: Dictionary containing the continuous pose data.

        Raises:
            KeyError: If a token is not present in the token map.
        """
        action_tokens = [self._token_to_ra_map[tkn] for tkn in tokens]
        discretized_data = dict(zip(self.ACTION_KEYS, action_tokens))
        return self.reverse_discretize_values(discretized_data)
258 changes: 258 additions & 0 deletions mbodied/agents/motion/rt_pali/action_tokenizer/ra_to_token_map.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
{
"ra_0": "<unused0>",
"ra_1": "<unused1>",
"ra_2": "<unused2>",
"ra_3": "<unused3>",
"ra_4": "<unused4>",
"ra_5": "<unused5>",
"ra_6": "<unused6>",
"ra_7": "<unused7>",
"ra_8": "<unused8>",
"ra_9": "<unused9>",
"ra_10": "<unused10>",
"ra_11": "<unused11>",
"ra_12": "<unused12>",
"ra_13": "<unused13>",
"ra_14": "<unused14>",
"ra_15": "<unused15>",
"ra_16": "<unused16>",
"ra_17": "<unused17>",
"ra_18": "<unused18>",
"ra_19": "<unused19>",
"ra_20": "<unused20>",
"ra_21": "<unused21>",
"ra_22": "<unused22>",
"ra_23": "<unused23>",
"ra_24": "<unused24>",
"ra_25": "<unused25>",
"ra_26": "<unused26>",
"ra_27": "<unused27>",
"ra_28": "<unused28>",
"ra_29": "<unused29>",
"ra_30": "<unused30>",
"ra_31": "<unused31>",
"ra_32": "<unused32>",
"ra_33": "<unused33>",
"ra_34": "<unused34>",
"ra_35": "<unused35>",
"ra_36": "<unused36>",
"ra_37": "<unused37>",
"ra_38": "<unused38>",
"ra_39": "<unused39>",
"ra_40": "<unused40>",
"ra_41": "<unused41>",
"ra_42": "<unused42>",
"ra_43": "<unused43>",
"ra_44": "<unused44>",
"ra_45": "<unused45>",
"ra_46": "<unused46>",
"ra_47": "<unused47>",
"ra_48": "<unused48>",
"ra_49": "<unused49>",
"ra_50": "<unused50>",
"ra_51": "<unused51>",
"ra_52": "<unused52>",
"ra_53": "<unused53>",
"ra_54": "<unused54>",
"ra_55": "<unused55>",
"ra_56": "<unused56>",
"ra_57": "<unused57>",
"ra_58": "<unused58>",
"ra_59": "<unused59>",
"ra_60": "<unused60>",
"ra_61": "<unused61>",
"ra_62": "<unused62>",
"ra_63": "<unused63>",
"ra_64": "<unused64>",
"ra_65": "<unused65>",
"ra_66": "<unused66>",
"ra_67": "<unused67>",
"ra_68": "<unused68>",
"ra_69": "<unused69>",
"ra_70": "<unused70>",
"ra_71": "<unused71>",
"ra_72": "<unused72>",
"ra_73": "<unused73>",
"ra_74": "<unused74>",
"ra_75": "<unused75>",
"ra_76": "<unused76>",
"ra_77": "<unused77>",
"ra_78": "<unused78>",
"ra_79": "<unused79>",
"ra_80": "<unused80>",
"ra_81": "<unused81>",
"ra_82": "<unused82>",
"ra_83": "<unused83>",
"ra_84": "<unused84>",
"ra_85": "<unused85>",
"ra_86": "<unused86>",
"ra_87": "<unused87>",
"ra_88": "<unused88>",
"ra_89": "<unused89>",
"ra_90": "<unused90>",
"ra_91": "<unused91>",
"ra_92": "<unused92>",
"ra_93": "<unused93>",
"ra_94": "<unused94>",
"ra_95": "<unused95>",
"ra_96": "<unused96>",
"ra_97": "<unused97>",
"ra_98": "<unused98>",
"ra_99": "<start_of_turn>",
"ra_100": "<end_of_turn>",
"ra_101": "<table>",
"ra_102": "<caption>",
"ra_103": "<thead>",
"ra_104": "<tbody>",
"ra_105": "<tfoot>",
"ra_106": "<tr>",
"ra_107": "<th>",
"ra_108": "<td>",
"ra_109": "</table>",
"ra_110": "</caption>",
"ra_111": "</thead>",
"ra_112": "</tbody>",
"ra_113": "</tfoot>",
"ra_114": "</tr>",
"ra_115": "</th>",
"ra_116": "</td>",
"ra_117": "<h1>",
"ra_118": "<h2>",
"ra_119": "<h3>",
"ra_120": "<h4>",
"ra_121": "<h5>",
"ra_122": "<h6>",
"ra_123": "<blockquote>",
"ra_124": "</h1>",
"ra_125": "</h2>",
"ra_126": "</h3>",
"ra_127": "</h4>",
"ra_128": "</h5>",
"ra_129": "</h6>",
"ra_130": "</blockquote>",
"ra_131": "<strong>",
"ra_132": "<em>",
"ra_133": "<b>",
"ra_134": "<i>",
"ra_135": "<u>",
"ra_136": "<s>",
"ra_137": "<sub>",
"ra_138": "<sup>",
"ra_139": "<code>",
"ra_140": "</strong>",
"ra_141": "</em>",
"ra_142": "</b>",
"ra_143": "</i>",
"ra_144": "</u>",
"ra_145": "</s>",
"ra_146": "</sub>",
"ra_147": "</sup>",
"ra_148": "</code>",
"ra_149": "<0x00>",
"ra_150": "<0x01>",
"ra_151": "<0x02>",
"ra_152": "<0x03>",
"ra_153": "<0x04>",
"ra_154": "<0x05>",
"ra_155": "<0x06>",
"ra_156": "<0x07>",
"ra_157": "<0x08>",
"ra_158": "<0x0A>",
"ra_159": "<0x0B>",
"ra_160": "<0x0C>",
"ra_161": "<0x0D>",
"ra_162": "<0x0E>",
"ra_163": "<0x0F>",
"ra_164": "<0x10>",
"ra_165": "<0x11>",
"ra_166": "<0x12>",
"ra_167": "<0x13>",
"ra_168": "<0x14>",
"ra_169": "<0x15>",
"ra_170": "<0x16>",
"ra_171": "<0x17>",
"ra_172": "<0x18>",
"ra_173": "<0x19>",
"ra_174": "<0x1A>",
"ra_175": "<0x1B>",
"ra_176": "<0x1C>",
"ra_177": "<0x1D>",
"ra_178": "<0x1E>",
"ra_179": "<0x1F>",
"ra_180": "<0x20>",
"ra_181": "<0x21>",
"ra_182": "<0x22>",
"ra_183": "<0x23>",
"ra_184": "<0x24>",
"ra_185": "<0x25>",
"ra_186": "<0x26>",
"ra_187": "<0x27>",
"ra_188": "<0x28>",
"ra_189": "<0x29>",
"ra_190": "<0x2A>",
"ra_191": "<0x2B>",
"ra_192": "<0x2C>",
"ra_193": "<0x2D>",
"ra_194": "<0x2E>",
"ra_195": "<0x2F>",
"ra_196": "<0x30>",
"ra_197": "<0x31>",
"ra_198": "<0x32>",
"ra_199": "<0x33>",
"ra_200": "<0x34>",
"ra_201": "<0x35>",
"ra_202": "<0x36>",
"ra_203": "<0x37>",
"ra_204": "<0x38>",
"ra_205": "<0x39>",
"ra_206": "<0x3A>",
"ra_207": "<0x3B>",
"ra_208": "<0x3C>",
"ra_209": "<0x3D>",
"ra_210": "<0x3E>",
"ra_211": "<0x3F>",
"ra_212": "<0x40>",
"ra_213": "<0x41>",
"ra_214": "<0x42>",
"ra_215": "<0x43>",
"ra_216": "<0x44>",
"ra_217": "<0x45>",
"ra_218": "<0x46>",
"ra_219": "<0x47>",
"ra_220": "<0x48>",
"ra_221": "<0x49>",
"ra_222": "<0x4A>",
"ra_223": "<0x4B>",
"ra_224": "<0x4C>",
"ra_225": "<0x4D>",
"ra_226": "<0x4E>",
"ra_227": "<0x4F>",
"ra_228": "<0x50>",
"ra_229": "<0x51>",
"ra_230": "<0x52>",
"ra_231": "<0x53>",
"ra_232": "<0x54>",
"ra_233": "<0x55>",
"ra_234": "<0x56>",
"ra_235": "<0x57>",
"ra_236": "<0x58>",
"ra_237": "<0x59>",
"ra_238": "<0x5A>",
"ra_239": "<0x5B>",
"ra_240": "<0x5C>",
"ra_241": "<0x5D>",
"ra_242": "<0x5E>",
"ra_243": "<0x5F>",
"ra_244": "<0x60>",
"ra_245": "<0x61>",
"ra_246": "<0x62>",
"ra_247": "<0x63>",
"ra_248": "<0x64>",
"ra_249": "<0x65>",
"ra_250": "<0x66>",
"ra_251": "<0x67>",
"ra_252": "<0x68>",
"ra_253": "<0x69>",
"ra_254": "<0x6A>",
"ra_255": "<0x6B>"
}
Loading