This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Changes from all commits (19 commits)
1 change: 1 addition & 0 deletions vissl/config/defaults.yaml
@@ -126,6 +126,7 @@ config:
# ----------------------------------------------------------------------------------- #
TENSORBOARD_SETUP:
# whether to use tensorboard for the visualization
VISUALIZATION_SAMPLE_PERIOD: -1
Review comment (Contributor):
nit: comment what this is.
nit: move comment below USE_TENSORBOARD
USE_TENSORBOARD: False
# log directory for tensorboard events
LOG_DIR: "."
5 changes: 1 addition & 4 deletions vissl/engines/train.py
@@ -89,16 +89,13 @@ def train_main(
print_cfg(cfg)
logging.info("System config:\n{}".format(collect_env_info()))

# get the hooks - these hooks are executed per replica
hooks = hook_generator(cfg)

# build the SSL trainer. The trainer first prepares a "task" object which
# acts as a container for various things needed in a training: datasets,
# dataloader, optimizers, losses, hooks, etc. "Task" will also have information
# about both phases (train, test). The trainer then sets up distributed
# training.
trainer = SelfSupervisionTrainer(
cfg, dist_run_id, checkpoint_path, checkpoint_folder, hooks
cfg, dist_run_id, checkpoint_path, checkpoint_folder, hook_generator,
)
trainer.train()
logging.info("All Done!")
4 changes: 2 additions & 2 deletions vissl/hooks/__init__.py
@@ -52,7 +52,7 @@ class SSLClassyHookFunctions(Enum):
on_end = auto()


def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
def default_hook_generator(cfg: AttrDict, event_storage) -> List[ClassyHook]:
Review comment (Contributor): nit: Type hint for event_storage.
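A minimal sketch of the hinted signature the nit asks for, assuming VisslEventStorage (introduced elsewhere in this PR) is the intended type:

```python
from typing import List

from classy_vision.hooks import ClassyHook
from vissl.config import AttrDict
from vissl.utils.events import VisslEventStorage


def default_hook_generator(
    cfg: AttrDict, event_storage: VisslEventStorage
) -> List[ClassyHook]:
    ...
```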

"""
The utility function that prepares all the hooks that will be used in training
based on user selection. Some basic hooks are used by default.
@@ -114,7 +114,7 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
"If using conda and you prefer conda install of tensorboard: "
"`conda install -c conda-forge tensorboard`"
)
tb_hook = get_tensorboard_hook(cfg)
tb_hook = get_tensorboard_hook(cfg, event_storage)
hooks.extend([tb_hook])
if cfg.MODEL.GRAD_CLIP.USE_GRAD_CLIP:
hooks.extend(
46 changes: 16 additions & 30 deletions vissl/hooks/log_hooks.py
@@ -5,8 +5,6 @@
"""

import atexit
import datetime
import json
import logging
import time
from typing import Optional
@@ -19,6 +17,7 @@
from fvcore.common.file_io import PathManager
from vissl.utils.checkpoint import CheckpointWriter, is_checkpoint_phase
from vissl.utils.env import get_machine_local_and_dist_rank
from vissl.utils.events import VisslEventStorage
from vissl.utils.io import save_file
from vissl.utils.logger import log_gpu_stats
from vissl.utils.perf_stats import PerfStats
@@ -165,6 +164,7 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
train_phase_idx = task.train_phase_idx
log_freq = task.config["LOG_FREQUENCY"]
iteration = task.iteration
evt_stg: VisslEventStorage = task.event_storage

if torch.cuda.is_available():
peak_mem_used = int(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
@@ -184,49 +184,35 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
avg_time = sum(batch_times) / len(batch_times)

eta_secs = avg_time * (task.max_iteration - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_secs)))
if isinstance(task.optimizer.options_view.lr, set):
lr_val = list(task.optimizer.options_view.lr)
else:
lr_val = round(task.optimizer.options_view.lr, 5)
batch_time = int(1000.0 * avg_time)
rank = get_rank()
log_data = {
"Rank": rank,
"ep": train_phase_idx,
"iter": iteration,
"lr": lr_val,
"loss": loss_val,
"btime(ms)": batch_time,
"eta": eta_string,
"peak_mem(M)": peak_mem_used,
}
evt_stg.put_scalars(
rank=rank,
epoch=train_phase_idx,
iteration=iteration,
lr=lr_val,
loss=loss_val,
batch_time=batch_time,
eta=eta_secs,
peak_mem_used=peak_mem_used,
)
if self.btime_freq and len(batch_times) >= self.btime_freq:
rolling_avg_time = (
sum(batch_times[-self.btime_freq :]) / self.btime_freq
)
rolling_eta_secs = int(
rolling_avg_time * (task.max_iteration - iteration)
)
rolling_eta_str = str(
datetime.timedelta(seconds=int(rolling_eta_secs))
)
rolling_btime = int(1000.0 * rolling_avg_time)
log_data[f"btime({self.btime_freq}iters)(ms)"] = rolling_btime
log_data["rolling_eta"] = rolling_eta_str

# to maintain the backwards compatibility with the log.txt
# logs, we convert the json to the previous format.
# the stdout.json can be used to use the json format of logs.
stdout_data = ""
for key, value in log_data.items():
stdout_data = (
f"{stdout_data}[{key}: {value}] "
if key == "ep"
else f"{stdout_data}{key}: {value}; "
evt_stg.put_scalars(
rolling_btime=rolling_btime, rolling_eta=rolling_eta_secs
)
Review comment (Contributor): nit: I prefer keeping all times in human readable string. We may want to add a VisslEventStorage#put_string or change #put_scalars to #put_values and make it type agnostic.
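A minimal sketch of the type-agnostic variant suggested above (hypothetical; VisslEventStorage's real interface is not shown in this diff):

```python
import datetime
from typing import Any, Dict


class EventStorageSketch:
    """Hypothetical stand-in for VisslEventStorage with a type-agnostic put_values."""

    def __init__(self) -> None:
        self._latest: Dict[str, Any] = {}

    def put_values(self, **kwargs: Any) -> None:
        # Unlike put_scalars, values are stored as-is, so strings are allowed.
        self._latest.update(kwargs)


# Usage matching the comment: keep ETA values human readable.
storage = EventStorageSketch()
storage.put_values(rolling_eta=str(datetime.timedelta(seconds=1234)))
```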
logging.info(stdout_data.strip())
Review comment (Contributor): So I believe we lose the logging to stdout here, which imo we should definitely keep -- super helpful for debugging.
We could create another EventStorageWriter called StdOutWriter and add to event_storage_writers.
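A minimal sketch of the proposed StdOutWriter (hypothetical; it assumes the storage exposes a latest() mapping of the most recently stored values, which this diff does not confirm):

```python
import logging

from vissl.utils.events import get_event_storage


class StdOutWriter:
    """Hypothetical event-storage writer that restores per-iteration stdout logging."""

    def write(self) -> None:
        storage = get_event_storage()
        # latest() is an assumed accessor for the most recently stored values.
        latest = storage.latest()
        line = " ".join(f"{key}: {value};" for key, value in latest.items())
        if line:
            logging.info(line)
```

It could then be appended to the list built in build_event_storage_writers alongside JsonWriter and TensorboardWriter.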

self.json_stdout_logger.write(json.dumps(log_data) + "\n")
for writer in task.event_storage_writers:
writer.write()


class LogLossMetricsCheckpointHook(ClassyHook):
47 changes: 20 additions & 27 deletions vissl/hooks/tensorboard_hook.py
@@ -11,6 +11,7 @@
ActivationStatisticsMonitor,
ActivationStatisticsObserver,
)
from vissl.utils.events import VisslEventStorage


try:
@@ -62,6 +63,7 @@ class SSLTensorboardHook(ClassyHook):
def __init__(
self,
tb_writer: SummaryWriter,
event_storage: VisslEventStorage,
log_params: bool = False,
log_params_every_n_iterations: int = -1,
log_params_gradients: bool = False,
@@ -86,6 +88,7 @@ def __init__(
)
logging.info("Setting up SSL Tensorboard Hook...")
self.tb_writer = tb_writer
Review comment (Contributor): Do we need tb_writer anymore in this class? Can we refactor ActivationStatisticsMonitor to use event_storage?
Review comment (Contributor): confirming: none of these methods write to tensorboard anymore and that it's only handled in TensorBoardWriter#Write.
self.event_storage = event_storage
self.log_params = log_params
self.log_params_every_n_iterations = log_params_every_n_iterations
self.log_params_gradients = log_params_gradients
@@ -137,9 +140,7 @@ def on_forward(self, task: "tasks.ClassyTask") -> None:
and task.iteration % self.log_params_every_n_iterations == 0
):
for name, parameter in task.base_model.named_parameters():
self.tb_writer.add_histogram(
f"Parameters/{name}", parameter, global_step=task.iteration
)
self.event_storage.put_histogram(f"Parameters/{name}", parameter)

def on_phase_start(self, task: "tasks.ClassyTask") -> None:
"""
@@ -153,9 +154,7 @@ def on_phase_start(self, task: "tasks.ClassyTask") -> None:
# log the parameters just once, before training starts
if is_primary() and task.train and task.train_phase_idx == 0:
for name, parameter in task.base_model.named_parameters():
self.tb_writer.add_histogram(
f"Parameters/{name}", parameter, global_step=-1
)
self.event_storage.put_histogram(f"Parameters/{name}", parameter)

def on_phase_end(self, task: "tasks.ClassyTask") -> None:
"""
@@ -172,7 +171,7 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
for top_n, accuracies in meter.value.items():
for i, acc in accuracies.items():
tag_name = f"{phase_type}/Accuracy_" f" {top_n}_Output_{i}"
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag=tag_name,
scalar_value=round(acc, 5),
global_step=task.train_phase_idx,
@@ -184,20 +183,16 @@ def on_phase_end(self, task: "tasks.ClassyTask") -> None:
# Log the weights and bias at the end of the epoch
if self.log_params:
for name, parameter in task.base_model.named_parameters():
self.tb_writer.add_histogram(
f"Parameters/{name}",
parameter,
global_step=task.train_phase_idx,
self.event_storage.put_histogram(
f"Parameters/{name}", parameter,
)
# Log the parameter gradients at the end of the epoch
if self.log_params_gradients:
for name, parameter in task.base_model.named_parameters():
if parameter.grad is not None:
try:
self.tb_writer.add_histogram(
f"Gradients/{name}",
parameter.grad,
global_step=task.train_phase_idx,
self.event_storage.put_histogram(
f"Gradients/{name}", parameter.grad,
)
except ValueError:
logging.info(
@@ -235,10 +230,8 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
for name, parameter in task.base_model.named_parameters():
if parameter.grad is not None:
try:
self.tb_writer.add_histogram(
f"Gradients/{name}",
parameter.grad,
global_step=task.iteration,
self.event_storage.put_histogram(
f"Gradients/{name}", parameter.grad,
)
except ValueError:
logging.info(
@@ -251,13 +244,13 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
iteration <= 100 and iteration % 5 == 0
):
logging.info(f"Logging metrics. Iteration {iteration}")
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Training/Loss",
scalar_value=round(task.last_batch.loss.data.cpu().item(), 5),
global_step=iteration,
)

self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Training/Learning_rate",
scalar_value=round(task.optimizer.options_view.lr, 5),
global_step=iteration,
@@ -270,7 +263,7 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
batch_times = [0]

batch_time_avg_s = sum(batch_times) / max(len(batch_times), 1)
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Speed/Batch_processing_time_ms",
scalar_value=int(1000.0 * batch_time_avg_s),
global_step=iteration,
@@ -285,7 +278,7 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
if batch_time_avg_s > 0
else 0.0
)
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Speed/img_per_sec_per_gpu",
scalar_value=pic_per_batch_per_gpu_per_sec,
global_step=iteration,
@@ -294,7 +287,7 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
# ETA
avg_time = sum(batch_times) / len(batch_times)
eta_secs = avg_time * (task.max_iteration - iteration)
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Speed/ETA_hours",
scalar_value=eta_secs / 3600.0,
global_step=iteration,
@@ -303,21 +296,21 @@ def on_update(self, task: "tasks.ClassyTask") -> None:
# GPU Memory
if torch.cuda.is_available():
# Memory actually being used
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Memory/Peak_GPU_Memory_allocated_MiB",
scalar_value=torch.cuda.max_memory_allocated() / BYTE_TO_MiB,
global_step=iteration,
)

# Memory reserved by PyTorch's memory allocator
self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Memory/Peak_GPU_Memory_reserved_MiB",
scalar_value=torch.cuda.max_memory_reserved()
/ BYTE_TO_MiB, # byte to MiB
global_step=iteration,
)

self.tb_writer.add_scalar(
self.event_storage.put_scalars(
tag="Memory/Current_GPU_Memory_reserved_MiB",
scalar_value=torch.cuda.memory_reserved()
/ BYTE_TO_MiB, # byte to MiB
1 change: 1 addition & 0 deletions vissl/models/trunks/regnet_fsdp.py
@@ -36,6 +36,7 @@
from classy_vision.models.regnet import RegNetParams
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP, auto_wrap_bn
from fairscale.nn.wrap import enable_wrap, wrap

from vissl.config import AttrDict
from vissl.models.model_helpers import (
Flatten,
7 changes: 7 additions & 0 deletions vissl/trainer/train_steps/standard_train_step.py
@@ -116,6 +116,13 @@ def standard_train_step(task):
sample = next(task.data_iterator)

sample = construct_sample_for_model(sample, task)
vis_period = task.config.HOOKS.TENSORBOARD_SETUP.VISUALIZATION_SAMPLE_PERIOD
Review comment (Contributor): nit: Can we comment this code block?
if vis_period > 0 and task.iteration % vis_period == 0:
storage = task.event_storage
storage.put_images()
name = f"Model input sample: iteration: {task.iteration_num}"
for idx, vis_img in enumerate(sample["input"]):
storage.put_image(name + f" ({idx})", vis_img)
Review comment (Contributor): thought: If batch size is large, this will upload a lot of images -- do we want to offer the option to only sample a small portion of the inputs?
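Addressing both review comments in this hunk, a hedged sketch of the block with comments and a cap on how many images are logged per batch (max_images is an illustrative constant, not an existing config option):

```python
# Periodically push raw input samples to the event storage so they can be
# visualized (e.g. in TensorBoard). A non-positive period disables this.
vis_period = task.config.HOOKS.TENSORBOARD_SETUP.VISUALIZATION_SAMPLE_PERIOD
if vis_period > 0 and task.iteration % vis_period == 0:
    storage = task.event_storage
    name = f"Model input sample: iteration: {task.iteration_num}"
    # Cap the number of images per batch to keep event files small.
    max_images = 8  # illustrative only; could be driven by a config option
    for idx, vis_img in enumerate(sample["input"][:max_images]):
        storage.put_image(name + f" ({idx})", vis_img)
```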

# Only need gradients during training
grad_context = torch.enable_grad() if task.train else torch.no_grad()
24 changes: 24 additions & 0 deletions vissl/trainer/train_task.py
@@ -12,6 +12,7 @@
from classy_vision.tasks.classification_task import AmpType, BroadcastBuffersMode
from fvcore.common.file_io import PathManager
from torch.cuda.amp import GradScaler as TorchGradScaler

from vissl.config import AttrDict
from vissl.data import (
build_dataset,
@@ -88,6 +89,7 @@ def __init__(self, config: AttrDict):
self.phase_idx = -1
# id of the current training phase training is at. Starts from 0
self.train_phase_idx = -1 # set by trainer
self._event_storage = None
# metrics stored during the training.
self.metrics = {} # set by the trainer
self.start_time = -1 # set by trainer
@@ -116,6 +118,28 @@ def __init__(self, config: AttrDict):
# communication as much as possible
self.set_ddp_bucket_cap_mb()
self.use_gpu = self.device.type == "cuda"
self.event_storage_writers = []

def initiate_vissl_event_storage(self):
from vissl.utils.events import create_event_storage, get_event_storage
Review comment (Contributor): thought: I prefer imports up top usually unless necessary.


create_event_storage()
self._event_storage = get_event_storage()

def build_event_storage_writers(self):
from vissl.utils.events import JsonWriter, TensorboardWriter

flush_secs = self.config.HOOKS.TENSORBOARD_SETUP.FLUSH_EVERY_N_MIN * 60
checkpoint_dir = self.config.CHECKPOINT.DIR

self.event_storage_writers = [
JsonWriter(f"{self.checkpoint_folder}/stdout.json"),
TensorboardWriter(checkpoint_dir, flush_secs),
]
Review comment (Contributor): From my reading of this, TensorBoardWriter will always write, regardless of USE_TENSORBOARD. Am I missing something?
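In response to the comment above, a minimal sketch of gating the TensorBoard writer on the existing USE_TENSORBOARD flag:

```python
def build_event_storage_writers(self):
    from vissl.utils.events import JsonWriter, TensorboardWriter

    self.event_storage_writers = [JsonWriter(f"{self.checkpoint_folder}/stdout.json")]
    # Only attach the TensorBoard writer when tensorboard logging is enabled.
    if self.config.HOOKS.TENSORBOARD_SETUP.USE_TENSORBOARD:
        flush_secs = self.config.HOOKS.TENSORBOARD_SETUP.FLUSH_EVERY_N_MIN * 60
        checkpoint_dir = self.config.CHECKPOINT.DIR
        self.event_storage_writers.append(
            TensorboardWriter(checkpoint_dir, flush_secs)
        )
```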

@property
def event_storage(self):
return self._event_storage
Review comment (Contributor): I think we can just use self.event_storage internally and externally.

def set_device(self):
"""