Binary file added pickled_task_with_embs.pkl
Binary file not shown.
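The pickle itself is opaque in the diff, but the make_env() changes below imply its shape: a sequence of task specs (the commented-out d[1] access suggests tuple-like entries) that random.sample and make_team_tasks can consume. A minimal inspection sketch under just those assumptions (the file name comes from this PR; everything else should be verified locally):

import pickle as pkl

# Peek at the curriculum file added by this PR. Only the properties that
# make_env() below actually relies on are checked here; the element type
# itself is an assumption until inspected.
with open('pickled_task_with_embs.pkl', 'rb') as f:
    task_spec = pkl.load(f)

print(type(task_spec), len(task_spec))  # make_env() needs len() and sampling
print(type(task_spec[0]))               # first spec; tuple-like per the diff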
33 changes: 29 additions & 4 deletions train.py
@@ -8,6 +8,7 @@
 import pufferlib.emulation
 import pufferlib.frameworks.cleanrl
 import pufferlib.registry.nmmo
+from nmmo.task.task_api import make_team_tasks
 import torch
 
 import clean_pufferl
@@ -61,7 +62,7 @@
         help="reset on death (default: False)")
     parser.add_argument(
         "--env.num_maps", dest="num_maps", type=int, default=128,
-        help="number of maps to use for training (default: 1)")
+        help="number of maps to use for training (default: 128)")
     parser.add_argument(
         "--env.maps_path", dest="maps_path", type=str, default="maps/train/",
         help="path to maps to use for training (default: None)")
@@ -100,7 +101,7 @@
         help="number of cores to use for training (default: num_envs)")
     parser.add_argument(
         "--rollout.num_envs", dest="num_envs", type=int, default=4,
-        help="number of environments to use for training (default: 1)")
+        help="number of environments to use for training (default: 4)")
     parser.add_argument(
         "--rollout.num_buffers", dest="num_buffers", type=int, default=4,
         help="number of buffers to use for training (default: 4)")
@@ -142,7 +143,7 @@
     parser.add_argument(
         "--ppo.bptt_horizon", dest="bptt_horizon", type=int, default=8,
         help="train on bptt_horizon steps of a rollout at a time. "
-        "use this to reduce GPU memory (default: 16)")
+        "use this to reduce GPU memory (default: 8)")
 
     parser.add_argument(
         "--ppo.training_batch_size",
@@ -198,7 +199,31 @@
 )
 
 def make_env():
-    return nmmo.Env(config)
+    import os
+    import pickle as pkl
+    import random
+
+    # Load the curriculum of task specs (with precomputed task embeddings)
+    # shipped with this change as pickled_task_with_embs.pkl.
+    print('cwd', os.getcwd())  # debug: confirm the pickle path resolves
+    with open('./pickled_task_with_embs.pkl', 'rb') as f:
+        task_spec = pkl.load(f)
+
+    # Draw one task spec per team, without replacement. NOTE: this samples
+    # once, at environment construction, so every reset reuses the same
+    # task list (a per-reset variant is sketched after the diff).
+    teams = team_helper.teams
+    task_spec_sampled = random.sample(task_spec, len(teams))
+    tasks = make_team_tasks(teams, task_spec_sampled)
+    make_task_fn = lambda: tasks
+
+    class NMMOTaskWrapper(nmmo.Env):
+        """nmmo.Env that injects the sampled tasks on every reset."""
+
+        def reset(self, *args, **kwargs):
+            return super().reset(*args, make_task_fn=make_task_fn, **kwargs)
+
+    return NMMOTaskWrapper(config)
 # if args.model_type in ["realikun", "realikun-simplified"]:
 #   env = NMMOTeamEnv(
 #     config, team_helper, rewards_config, moves_only=args.moves_only)
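One behavior worth noting in review: make_env() draws task_spec_sampled once, when the environment is constructed, so every episode reuses the same team-to-task assignment. If a fresh draw per episode is wanted instead, a sketch of that variant follows (hypothetical, not what this PR implements; it assumes the same teams and task_spec objects as above):

import random

import nmmo
from nmmo.task.task_api import make_team_tasks

class ResamplingTaskWrapper(nmmo.Env):
    """Hypothetical variant of NMMOTaskWrapper that resamples per reset."""

    def __init__(self, config, teams, task_spec):
        super().__init__(config)
        self._teams = teams
        self._task_spec = task_spec

    def reset(self, *args, **kwargs):
        # Fresh task assignment each episode, not once at construction.
        sampled = random.sample(self._task_spec, len(self._teams))
        tasks = make_team_tasks(self._teams, sampled)
        return super().reset(*args, make_task_fn=lambda: tasks, **kwargs)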