diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 243e82d4..a923e300 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -6,6 +6,7 @@ from __future__ import annotations import os +import warnings from dataclasses import asdict from torch.utils.tensorboard import SummaryWriter @@ -25,17 +26,19 @@ def __init__(self, log_dir: str, flush_secs: int, cfg): run_name = os.path.split(log_dir)[-1] try: - project = cfg["wandb_project"] + project = cfg["wandb_kwargs"]["project"] except KeyError: raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") - try: - entity = os.environ["WANDB_USERNAME"] + entity = cfg["wandb_kwargs"]["entity"] except KeyError: entity = None + warnings.warn("No entity specified for wandb logging. Defaulting to None.") + if "name" not in cfg["wandb_kwargs"]: + cfg["wandb_kwargs"]["name"] = run_name + warnings.warn(f"No name specified for wandb logging. Defaulting to {run_name}.") - # Initialize wandb - wandb.init(project=project, entity=entity, name=run_name) + wandb.init(**cfg["wandb_kwargs"]) # Add log directory to wandb wandb.config.update({"log_dir": log_dir})