|
| 1 | +#!/usr/bin/env python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
"""
This module defines a custom callback, PrintAndSaveStats, that monitors and saves statistics during the training of
an MLP model. The callback logs information such as epoch timings, accuracy, loss, and metrics like precision and
recall. It also computes aggregates such as the total training time and the best validation accuracy achieved.
Additionally, it appends these statistics to a CSV history file and logs them for TensorBoard visualization. The
get_callbacks function generates the list of callbacks — early stopping, model checkpointing, the custom
PrintAndSaveStats, and TensorBoard logging — tailored for a specific model with the given parameters.
"""
| 12 | + |
| 13 | +from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard |
| 14 | +import tensorflow as tf |
| 15 | +import datetime |
| 16 | +import time |
| 17 | +from parameters import get_tensorboard_path |
| 18 | + |
| 19 | + |
class PrintAndSaveStats(tf.keras.callbacks.Callback):
    """Keras callback that prints per-epoch statistics, appends them to a
    CSV history file, and writes TensorBoard scalar summaries.

    Tracked aggregates: total training time, first / last / best validation
    accuracy, and the epochs at which they occurred (see get_stats()).
    """

    # Column separator used in the CSV history file.
    SEPARATOR = ";"

    def __init__(self, model_name):
        """model_name: base name used to derive the history-file path."""
        super().__init__()  # initialize base Callback state (was missing)
        self.epoch_time_start = None   # wall-clock start of current epoch
        self.model_name = model_name
        self.total_time = 0            # cumulative seconds over all epochs
        self.last_epoch = 1
        self.best_acc = 0
        self.best_epoch = 1
        self.first_acc = 0             # val_accuracy after epoch 1
        self.last_acc = 0
        self.last_loss = 0
        self.last_f1_micro = 0
        self.last_f1_macro = 0
        self.last_precision = 0
        self.last_recall = 0
        # Created lazily on first use and reused: building a new summary
        # writer every epoch leaks resources and fragments the event files.
        self._summary_writer = None

    def on_epoch_begin(self, batch, logs=None):
        # NOTE: mutable default argument ({}) removed; Keras passes logs
        # explicitly, so None is a safe, conventional default.
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Print, persist and summarize the metrics Keras hands us in `logs`.

        Expects logs to contain: accuracy, val_accuracy, loss, val_loss,
        val_precision, val_recall (i.e. the model was compiled with
        precision/recall metrics — otherwise a KeyError is raised).
        """
        logs = logs or {}
        epoch += 1  # Keras epochs are 0-based; report them 1-based
        if epoch == 1:
            self.first_acc = logs["val_accuracy"]
        print('Epoch {} finished at {}'.format(epoch, datetime.datetime.now().time()))
        print(f"Printing log object:\n{logs}")
        elapsed_time = int(time.time() - self.epoch_time_start)
        print(f"Elapsed time: {elapsed_time}")  # fixed typo "Elaspsed"
        # Guard against division by zero on degenerate losses/accuracies.
        if logs["loss"] != 0:
            print("val/train loss: {:.2f}".format(logs["val_loss"] / logs["loss"]))
        if logs["accuracy"] != 0:
            print("val/train acc: {:.2f}".format(logs["val_accuracy"] / logs["accuracy"]))
        # `with` guarantees the file handle is closed even if a write fails
        # (the original leaked the handle on exception).
        with open(get_history_path(self.model_name), "a") as history_file:
            history_file.write(self.SEPARATOR.join([
                str(epoch),
                str(datetime.datetime.now().time()),
                str(elapsed_time),
                str(logs["accuracy"]),
                str(logs["val_accuracy"]),
                str(logs["loss"]),
                str(logs["val_loss"]),
            ]) + "\n")
        self.compute_aggregates(elapsed_time, logs["val_accuracy"], epoch)

        self.last_acc = logs["val_accuracy"]
        self.last_loss = logs["val_loss"]
        # self.last_f1_micro = logs["val_f1_micro"]
        # self.last_f1_macro = logs["val_f1_macro"]
        self.last_precision = logs["val_precision"]
        self.last_recall = logs["val_recall"]
        if self._summary_writer is None:
            self._summary_writer = tf.summary.create_file_writer(get_tensorboard_path())
        with self._summary_writer.as_default():
            tf.summary.scalar("val_accuracy", logs["val_accuracy"], step=epoch)
            tf.summary.scalar("val_loss", logs["val_loss"], step=epoch)
            tf.summary.scalar("train_accuracy", logs["accuracy"], step=epoch)
            tf.summary.scalar("train_loss", logs["loss"], step=epoch)
            tf.summary.scalar("time", elapsed_time, step=epoch)
            tf.summary.scalar("precision", logs["val_precision"], step=epoch)
            tf.summary.scalar("recall", logs["val_recall"], step=epoch)
            # tf.summary.scalar("f1_macro", logs["val_f1_macro"], step=epoch)
            # tf.summary.scalar("f1_micro", logs["val_f1_micro"], step=epoch)

    def compute_aggregates(self, elapsed_time: int, val_acc, epoch: int):
        """Accumulate total time and track the best validation accuracy."""
        self.total_time += elapsed_time
        self.last_epoch = epoch
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_epoch = epoch

    def get_stats(self):
        """Return [avg epoch time, first acc, best acc, best epoch, last epoch]."""
        return [int(self.total_time / self.last_epoch), self.first_acc, self.best_acc, self.best_epoch, self.last_epoch]
| 87 | + |
| 88 | + |
def get_history_path(model_name: str):
    """Return the CSV history-file path derived from *model_name*."""
    return f"{model_name}_history.csv"
| 91 | + |
| 92 | + |
def get_best_model_path(model_name: str):
    """Return the checkpoint (.h5) path derived from *model_name*."""
    return f"{model_name}_checkpoint.h5"
| 95 | + |
| 96 | + |
def get_callbacks(model_name: str, early_patience: int) -> list:
    """Assemble the training callbacks for *model_name*.

    Returns, in order: best-model checkpointing, the custom
    PrintAndSaveStats logger, early stopping (patience =
    *early_patience* on val_loss), and TensorBoard logging.
    """
    return [
        ModelCheckpoint(get_best_model_path(model_name), save_best_only=True,
                        monitor="val_loss", verbose=1),
        PrintAndSaveStats(model_name),
        EarlyStopping(monitor="val_loss", mode="min", patience=early_patience,
                      restore_best_weights=True, verbose=1),
        TensorBoard(get_tensorboard_path()),
    ]
0 commit comments