-
Notifications
You must be signed in to change notification settings - Fork 190
Glulx game logger #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
66d1776
f8de962
336354b
2522015
862cba1
ad262f8
a4ca143
eb4e0c5
5c498d5
2cd2d92
1c99564
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT license. | ||
|
||
from typing import Tuple, List, Optional, Iterable, Union, Sized, Any, Mapping | ||
|
||
from textworld.core import Environment, GameState, Wrapper | ||
from textworld.envs.glulx.git_glulx_ml import GitGlulxMLEnvironment, GlulxGameState | ||
|
||
|
||
class GlulxLogger(Wrapper): | ||
def __init__(self, env: GitGlulxMLEnvironment) -> None: | ||
""" | ||
Wrap around a TextWorld GitGlulxML environment to provide logging capabilities. | ||
|
||
Parameters | ||
---------- | ||
:param env: | ||
The GitGlulxML environment to wrap. | ||
""" | ||
super().__init__(env) | ||
self.activate_state_tracking() | ||
|
||
self.serialized_game = env.game.serialize() | ||
self._gamefile = env.gamefile | ||
|
||
|
||
def step(self, command: str) -> Tuple[GlulxGameState, float, bool]: | ||
""" | ||
Take a step in the environment, save needed information. | ||
:param command: | ||
input string for taking an action | ||
:return: | ||
GlulxGameState, score and done. | ||
""" | ||
self._logs.append(self._current) | ||
self._current = {'optional': []} | ||
|
||
self._current['command'] = command | ||
|
||
game_state, score, done = super().step(command) | ||
self._current['feedback'] = game_state.feedback | ||
self._current['score'] = score | ||
self._current['done'] = done | ||
self._current['action'] = game_state.action.serialize() | ||
self._current['state'] = game_state.state.serialize() | ||
|
||
return game_state, score, done | ||
|
||
def reset(self) -> GameState: | ||
""" | ||
Reset the environment. | ||
Also clears logs. | ||
|
||
""" | ||
self._logs = [] | ||
|
||
game_state = super().reset() | ||
self._current = {'optional': []} | ||
self._current['done'] = False | ||
self._current['state'] = game_state.state.serialize() | ||
|
||
return game_state | ||
|
||
def add_commands(self, commands: List[str], scores: Optional[Union[Iterable[float], Sized]]=None) -> None: | ||
|
||
""" | ||
Add custom commands to the logger. Optionally add scores for each command. | ||
:param commands: | ||
|
||
A list of commands. | ||
:param scores: | ||
scores for each command. Must be same size as commands if provided. | ||
:return: | ||
""" | ||
command_mapping = commands | ||
if scores is not None: | ||
assert len(scores) == len(commands) | ||
command_mapping = {a: p for a, p in zip(commands, scores)} | ||
|
||
self._current['command_distribution'] = command_mapping | ||
|
||
|
||
def add(self, info: Any) -> None: | ||
""" | ||
Add any additional information you want to log. | ||
:param info: | ||
Additional information to log for the current game state. | ||
""" | ||
self._current['optional'].append(info) | ||
|
||
@property | ||
def current(self) -> Mapping: | ||
return self._current | ||
|
||
@property | ||
def logs(self) -> List[Mapping]: | ||
""" | ||
Get all logs | ||
:return: List of all logs | ||
""" | ||
logs = self._logs[:] | ||
|
||
logs.append(self._current) | ||
return logs | ||
|
||
@property | ||
def gamefile(self): | ||
|
||
return self._gamefile | ||
|
||
def __getitem__(self, index: int) -> Mapping: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure we need this anymore. |
||
""" | ||
Get a certain log at a given index. | ||
:param index: | ||
index of log to get. | ||
:return: | ||
log at index. | ||
""" | ||
assert index <= len(self._logs) | ||
|
||
if index < len(self._logs) - 1: | ||
return self._logs[index] | ||
return self._current | ||
|
||
|
||
def __str__(self) -> str: | ||
return str(self.logs) | ||
|
||
def serialize(self) -> List[Mapping]: | ||
""" | ||
Get serialized mappings of logs. | ||
:return: List of serialized mappings. | ||
""" | ||
return self.logs | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT license. | ||
|
||
import textworld | ||
import numpy as np | ||
|
||
from textworld.envs.wrappers import GlulxLogger | ||
from textworld.utils import make_temp_directory | ||
from textworld.generator import compile_game | ||
from textworld import g_rng | ||
|
||
|
||
def test_glulx_logger(): | ||
num_nodes = 3 | ||
num_items = 10 | ||
g_rng.set_seed(1234) | ||
grammar_flags = {"theme": "house", "include_adj": True} | ||
game = textworld.generator.make_game(world_size=num_nodes, nb_objects=num_items, quest_length=3, grammar_flags=grammar_flags) | ||
|
||
game_name = "test_glulx_logger" | ||
with make_temp_directory(prefix=game_name) as tmpdir: | ||
game_file = compile_game(game, game_name, games_folder=tmpdir) | ||
|
||
env = textworld.start(game_file) | ||
env = GlulxLogger(env) | ||
env.activate_state_tracking() | ||
game_state = env.reset() | ||
|
||
# test reset | ||
assert 'state' in env.current | ||
|
||
# test step | ||
options = game_state.admissible_commands | ||
game_state, score, done = env.step(options[0]) | ||
assert len(env.logs) > 1 | ||
assert 'action' in env.current | ||
assert 'state' in env.current | ||
assert 'feedback' in env.current | ||
|
||
# test add_commands | ||
option_scores = np.array([0.1] * len(options)) | ||
env.add_commands(options, option_scores) | ||
assert len(env.current['command_distribution'].values()) == len(options) | ||
|
||
# test add | ||
additional_info = {'scores': option_scores} | ||
env.add(additional_info) | ||
assert len(env.current['optional']) > 0 | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert unneeded change.