Purpose and Goal
The primary purpose of this script is to build, train, evaluate, and explain a hybrid deep learning model for a classification task. The model combines Convolutional Neural Networks (CNNs), Long Short-Term Memory networks (LSTMs), and an Attention mechanism.
The key goals are:
- Classification: Train a model to classify input data (presumably sequences or feature vectors) into one of
OUTPUT_SIZEcategories. - Hybrid Architecture: Leverage the strengths of different neural network types:
- CNNs: To automatically extract spatial or local features from the input sequence (treating the features as a 1D sequence).
- LSTMs: To model temporal dependencies or sequential patterns in the features extracted by the CNN (using a bidirectional LSTM captures patterns in both forward and backward directions).
- Attention: To allow the model to focus on the most relevant parts of the sequence processed by the LSTM when making the final classification decision.
- Explainability (XAI): Go beyond just prediction accuracy and understand why the model makes certain predictions. This is achieved using:
- SHAP (SHapley Additive exPlanations): To quantify the contribution of each input feature to the model's output for specific predictions and globally.
- Grad-CAM (Gradient-weighted Class Activation Mapping): To visualize which parts of the input sequence were most important for the CNN layers when classifying a specific sample.
- Attention Weights Visualization: To visualize the internal focus of the Attention layer within the sequence processed by the LSTM.
- Reproducibility & Reusability: Save the trained model (weights and hyperparameters) so it can be loaded and used later without retraining.
- Performance Monitoring: Track and visualize training and testing loss/accuracy over epochs to monitor the learning process.
Step-by-Step Implementation Process
Here's a breakdown of the script's sections with code snippets and explanations:
-
Imports:
- Standard libraries like
torch,torch.nn,torch.optim,DataLoader,matplotlib,pandas,numpy,seaborn,randomare imported for deep learning, data handling, plotting, and utility functions. shapis imported for explainability analysis.sklearn.metricsis used for evaluation metrics like confusion matrix and classification report.
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Subset import torch.nn.functional as F import matplotlib.pyplot as plt import pandas as pd import numpy as np import shap from sklearn.metrics import confusion_matrix, classification_report import seaborn as sns import random import os # Added for checking file existence in loading part # Assuming data_loader.py contains CustomDataset from data_loader import CustomDataset # Adjust import as needed
- Standard libraries like
-
Custom Dataset Assumption:
- The script assumes you have a
CustomDatasetclass (likely defined indata_loader.py) that inherits fromtorch.utils.data.Dataset. - This class is responsible for loading data (e.g., from a CSV), extracting features and labels, and returning them as tensors.
- Crucially, for classification with
nn.CrossEntropyLoss, the__getitem__method should return(data_tensor, label_tensor), wherelabel_tensoris a single integer (scalar tensor of typetorch.long) representing the class index (0, 1, 2,...). - The example provided shows how such a dataset might look, including optional attributes
feature_namesandclass_nameswhich are very useful for interpreting the explainability results.
- The script assumes you have a
-
Configuration:
- Sets up all the important parameters and constants for the experiment.
INPUT_SIZE: Number of features in each input sample.CNN_CHANNELS,LSTM_HIDDEN_SIZE,LSTM_NUM_LAYERS,NUM_HEADS_ATTENTION: Hyperparameters defining the model architecture.OUTPUT_SIZE: Number of distinct classes the model should predict.LEARNING_RATE,BATCH_SIZE,EPOCHS: Training parameters.DEVICE: Automatically selects GPU (cuda) if available, otherwise CPU.TRAIN_CSV_PATH,TEST_CSV_PATH: Paths to the data files.SAVED_MODEL_PATH,RESULTS_CSV_PATH,PLOT_SAVE_PATH: Paths for saving outputs.NUM_SHAP_BACKGROUND_SAMPLES,NUM_SAMPLES_TO_EXPLAIN: Parameters for the explainability analysis.- It attempts to load
FEATURE_NAMESandCLASS_NAMESfrom the dataset, falling back to generic names if the attributes don't exist. This is good practice for labeling plots and outputs later.
# --- Configuration --- INPUT_SIZE = 4 CNN_CHANNELS = 16 LSTM_HIDDEN_SIZE = 32 LSTM_NUM_LAYERS = 2 OUTPUT_SIZE = 4 NUM_HEADS_ATTENTION = 4 LEARNING_RATE = 0.005 BATCH_SIZE = 32 EPOCHS = 50 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") TRAIN_CSV_PATH = "train_set.csv" TEST_CSV_PATH = "test_set.csv" SAVED_MODEL_PATH = 'saved_model_hybrid_explained.pth' RESULTS_CSV_PATH = "train_results_hybrid_explained.csv" PLOT_SAVE_PATH = "train_plots_hybrid_explained.eps" NUM_SHAP_BACKGROUND_SAMPLES = 100 NUM_SAMPLES_TO_EXPLAIN = 5 # Define feature and class names try: temp_dataset = CustomDataset(TRAIN_CSV_PATH) FEATURE_NAMES = temp_dataset.feature_names CLASS_NAMES = temp_dataset.class_names # ... (rest of try block) except AttributeError: # ... (fallback names)
-
Model Definition (
HybridCNNLSTMAttention):- Defines the neural network architecture inheriting from
nn.Module. __init__:- Initializes CNN layers (
nn.Conv1d,nn.ReLU,nn.BatchNorm1d,nn.MaxPool1d). BatchNorm is added for better stability and faster training. MaxPool reduces the sequence length. - Dynamically calculates the sequence length output by the CNN (
cnn_output_seq_len) based onINPUT_SIZEand pooling layers. This is important for defining the subsequent LSTM layer correctly. - Initializes a bidirectional LSTM layer (
nn.LSTM).batch_first=Truemakes tensor handling easier.bidirectional=Truedoubles the effective hidden size. - Initializes a
nn.MultiheadAttentionlayer. It takes the LSTM outputs as query, key, and value (self-attention). - Initializes the final fully connected layer (
nn.Linear) to map the processed sequence representation to theOUTPUT_SIZElogits. - Initializes
feature_mapsandgradientstoNoneand definesactivations_hook. These are specifically for capturing intermediate results needed for Grad-CAM.
- Initializes CNN layers (
forward(self, x):- Defines the data flow through the layers.
- Input
xshape:[batch_size, input_size]. x.unsqueeze(1): Reshapes input forConv1dto[batch_size, 1, input_size].- Passes through CNN layers.
- Grad-CAM Hook: Registers the
activations_hookto the output of the last CNN layer's activation (x_cnn). This hook will store the gradients flowing back into this layer duringbackward(). The feature map itself is also stored. This is only done during evaluation (self.training is False) if gradients are needed. x_cnn.permute(0, 2, 1): Reshapes the CNN output to[batch_size, seq_len_after_cnn, cnn_channels*2]suitable for the LSTM layer (batch_first=True).- Passes through the bidirectional LSTM.
- Passes the LSTM output sequence through the
MultiheadAttentionlayer. - Context Aggregation: Aggregates the attention output sequence (
attn_output) into a single vector (context_vector) per sample using mean pooling (torch.mean(attn_output, dim=1)). This vector summarizes the sequence information weighted by attention. Other options like using the last hidden state or max pooling are commented out. - Passes the
context_vectorthrough the finalnn.Linearlayer to get classification logits. - Returns the final
outputlogits and theattn_weights(useful for visualization).
class HybridCNNLSTMAttention(nn.Module): def __init__(self, input_size, cnn_channels, lstm_hidden_size, lstm_num_layers, output_size, num_heads): super(HybridCNNLSTMAttention, self).__init__() # ... (CNN layers definition) ... self.cnn_output_seq_len = input_size // 4 lstm_input_features = cnn_channels * 2 # ... (LSTM layer definition) ... lstm_output_size = lstm_hidden_size * 2 # ... (Attention layer definition) ... self.fc = nn.Linear(lstm_output_size, output_size) # ... (Grad-CAM hooks init) ... def activations_hook(self, grad): self.gradients = grad def forward(self, x): # ... (CNN forward pass) ... if self.training is False and x_cnn.requires_grad: h = x_cnn.register_hook(self.activations_hook) self.feature_maps = x_cnn # ... (LSTM forward pass) ... # ... (Attention forward pass) ... context_vector = torch.mean(attn_output, dim=1) # Aggregate attention output output = self.fc(context_vector) return output, attn_weights
- Defines the neural network architecture inheriting from
-
Training Setup:
- Instantiates the
HybridCNNLSTMAttentionmodel. - Moves the model to the selected
DEVICE. - Defines the loss function:
nn.CrossEntropyLosscombines LogSoftmax and Negative Log-Likelihood loss, suitable for multi-class classification. - Defines the optimizer:
optim.Adamis a popular adaptive learning rate optimizer. - Defines a learning rate scheduler:
optim.lr_scheduler.StepLRreduces the learning rate by a factor (gamma=0.1) everystep_sizeepochs (here, every third of the total epochs). This helps in fine-tuning the model in later stages.
model = HybridCNNLSTMAttention(...) model.to(DEVICE) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=EPOCHS // 3, gamma=0.1)
- Instantiates the
-
Data Loading:
- Creates instances of the
CustomDatasetfor training and testing data using the specified CSV paths. - Creates
DataLoaderinstances, which handle batching, shuffling (for training data), and parallel data loading.
train_dataset = CustomDataset(TRAIN_CSV_PATH) test_dataset = CustomDataset(TEST_CSV_PATH) train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
- Creates instances of the
-
Training and Testing Loop:
- Initializes lists to store loss and accuracy values for plotting.
- Iterates through the specified number of
EPOCHS. - Training Phase:
- Sets model to training mode (
model.train()). This enables dropout, batch normalization updates, etc. - Iterates through batches from
train_data_loader. - Moves input
inputsandlabelsto theDEVICE. - Clears previous gradients (
optimizer.zero_grad()). - Performs a forward pass:
outputs, _ = model(inputs). The attention weights are ignored for loss calculation. - Calculates the loss using
criterion(outputs, labels). - Performs backpropagation (
loss.backward()) to compute gradients. - Updates model weights (
optimizer.step()). - Accumulates loss and calculates the number of correct predictions for the epoch.
- Calculates and stores average training loss and accuracy for the epoch.
- Sets model to training mode (
- Testing Phase:
- Sets model to evaluation mode (
model.eval()). This disables dropout and uses running stats for batch normalization. - Disables gradient calculation (
with torch.no_grad():) for efficiency and to prevent accidental updates. - Iterates through batches from
test_data_loader. - Performs a forward pass.
- Calculates the loss.
- Calculates the number of correct predictions.
- Stores all predictions (
all_preds) and true labels (all_labels) from the test set. These are needed later for the confusion matrix and classification report. - Calculates and stores average test loss and accuracy for the epoch.
- Sets model to evaluation mode (
- Updates the learning rate using
scheduler.step(). - Prints the progress (loss, accuracy, learning rate) periodically.
train_loss_values = [] # ... (other lists init) ... for epoch in range(EPOCHS): # --- Training Phase --- model.train() # ... (training loop over batches) ... train_loss_values.append(epoch_train_loss) train_accuracy_values.append(train_accuracy) # --- Testing Phase --- model.eval() all_preds = [] all_labels = [] with torch.no_grad(): # ... (testing loop over batches) ... all_preds.extend(predicted_test.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) # ... (store test metrics) ... scheduler.step() # ... (print progress) ...
-
Save Model:
- Creates a
hyperparametersdictionary containing the key architectural parameters used to build the model. Also includes feature/class names. - Creates a
checkpointdictionary containing:- The
hyperparameters. - The learned model weights (
model.state_dict()). - The last epoch number.
- Optionally, the state of the optimizer and scheduler (useful if you want to resume training later).
- The
- Saves this
checkpointdictionary to the specified.pthfile usingtorch.save(). This approach is better than saving the entire model object, as it's more portable across different code versions.
print(f"\nSaving model checkpoint to {SAVED_MODEL_PATH}...") hyperparameters = { ... } # Populate with config values checkpoint = { 'hyperparameters': hyperparameters, 'model_state_dict': model.state_dict(), # ... (optional optimizer/scheduler states) ... } torch.save(checkpoint, SAVED_MODEL_PATH)
- Creates a
-
Save Training Results:
- Creates a pandas DataFrame storing the epoch number, train/test loss, and train/test accuracy for each epoch.
- Saves this DataFrame to a CSV file.
train_info_df = pd.DataFrame({ 'epoch': range(1, EPOCHS + 1), # ... (other columns) ... }) train_info_df.to_csv(RESULTS_CSV_PATH, index=False)
-
Plotting:
- Uses
matplotlib.pyplotto create two subplots:- Loss vs. Epochs (Training and Testing)
- Accuracy vs. Epochs (Training and Testing)
- Saves the combined plot to an EPS file (vector format) and displays it.
plt.figure(figsize=(12, 5)) plt.subplot(1, 2, 1) # Loss plot # ... (plotting commands) ... plt.subplot(1, 2, 2) # Accuracy plot # ... (plotting commands) ... plt.tight_layout() plt.savefig(PLOT_SAVE_PATH, format='eps') plt.show()
- Uses
-
Explainability Analysis: This is a key part demonstrating XAI techniques.
- Ensures the model is in evaluation mode (
model.eval()). - a. Confusion Matrix & Classification Report:
- Uses the
all_labelsandall_predscollected during testing. - Calculates the confusion matrix using
sklearn.metrics.confusion_matrix. - Visualizes the matrix using
seaborn.heatmap. - Prints the
sklearn.metrics.classification_report, which provides precision, recall, F1-score, and support for each class.
- Uses the
- b. SHAP Analysis:
- Background Data: Selects a random subset of the training data to serve as a background distribution for SHAP's
DeepExplainer. This background helps establish a baseline expectation. - Wrapper Function: Defines
model_predictor. This is crucial because the model'sforwardmethod returns(output, attn_weights), but SHAP explainers typically expect a function that takes a tensor and returns only the model's output tensor (logits or probabilities). - Explainer: Creates a
shap.DeepExplainerinstance, passing it themodel_predictorand thebackground_data. - Calculate SHAP Values: Selects a subset of test samples and calculates their SHAP values using
explainer.shap_values(test_samples). This can be computationally intensive. The resultshap_valuesis typically a list (one element per output class) of arrays, where each array has shape[num_samples, num_features]. - Summary Plots:
- Generates a bar plot (
plot_type="bar") showing the mean absolute SHAP value for each feature across all samples and classes (global importance). - Generates dot summary plots (
plot_type="dot", the default) for each class individually, showing the distribution of SHAP values for each feature and how high/low feature values impact the prediction for that specific class.
- Generates a bar plot (
- Force Plots:
- For a few individual test samples:
- Gets the model's prediction and true label.
- Uses
shap.force_plotto visualize how each feature contributed to pushing the model's output away from the base value (average prediction over the background dataset, stored inexplainer.expected_value) towards the actual prediction for the predicted class. - Optionally, plots the force plot for the true class if it was different from the predicted one.
- For a few individual test samples:
- Background Data: Selects a random subset of the training data to serve as a background distribution for SHAP's
- c. Grad-CAM:
get_grad_camfunction:- Takes the model, the target layer's output (feature maps captured by the hook), the target class index, and the final model output.
- Performs backpropagation specifically for the
target_class_indexscore (output[:, target_class_index].backward(retain_graph=True)).retain_graph=Truemight be needed if backpropagating multiple times. - Retrieves the gradients stored by the hook (
model.gradients). - Calculates the importance weights (alpha) by averaging gradients across the sequence dimension (
torch.mean(gradients, dim=[2])). - Computes the weighted sum of feature map channels using these weights.
- Applies ReLU (
F.relu) to keep only positive contributions. - Normalizes the resulting heatmap for visualization.
- Visualization Loop:
- Selects a few test samples.
- For each sample:
- Ensures the input tensor requires gradients (
input_tensor.requires_grad = True). - Performs a forward pass to get the prediction and trigger the activation/gradient hooks (storing
model.feature_mapsand enablingmodel.gradientsupon backward). - Calls
get_grad_camto compute the heatmap for the predicted class. - Upsamples the heatmap (which has length
cnn_output_seq_len) to the originalINPUT_SIZEusingF.interpolate(linear interpolation) because the heatmap corresponds to the CNN's output sequence, but we want to visualize it relative to the input features. - Plots the original input features and overlays the upsampled heatmap using
plt.imshowwith transparency (alpha). - Cleans up gradients/hooks (good practice).
- Ensures the input tensor requires gradients (
- d. Attention Weights Visualization:
- Selects a few test samples.
- Performs a forward pass with no gradients (
torch.no_grad()) to get theattn_weightsreturned by the model. - Extracts and potentially processes the attention weights (e.g., squeezing batch dimension, averaging over heads if it's MultiheadAttention). The shape depends on the specific Attention layer implementation. The example assumes
[batch, seq_len, seq_len]after processing/squeezing. - Uses
seaborn.heatmapto visualize the attention matrix, showing which parts of the sequence (Query Time Steps) attended to which other parts (Key Time Steps) at the output of the LSTM.
- Ensures the model is in evaluation mode (
-
Load Model Example:
- Demonstrates how to reload the saved model.
- Checks if the
SAVED_MODEL_PATHexists. - Loads the
checkpointdictionary usingtorch.load.map_location=DEVICEensures the model is loaded onto the correct device, regardless of where it was saved. - Extracts the
hyperparametersfrom the checkpoint. - Re-instantiates the model using these loaded hyperparameters. This is crucial for ensuring the architecture matches the saved weights.
- Loads the learned weights using
loaded_model.load_state_dict(checkpoint['model_state_dict']). - Moves the loaded model to the
DEVICE. - Sets the model to evaluation mode (
loaded_model.eval()) immediately after loading, as it's typically used for inference. - Includes an example of using the loaded model to make a prediction on a sample from the test set and prints the true vs. predicted class.
print("\n--- Loading Model Example ---") if os.path.exists(SAVED_MODEL_PATH): checkpoint = torch.load(SAVED_MODEL_PATH, map_location=DEVICE) loaded_hyperparameters = checkpoint['hyperparameters'] loaded_model = HybridCNNLSTMAttention( input_size=loaded_hyperparameters['input_size'], # ... (rest of params from loaded_hyperparameters) ... ) loaded_model.load_state_dict(checkpoint['model_state_dict']) loaded_model.to(DEVICE) loaded_model.eval() print("Model loaded successfully.") # ... (Test prediction with loaded model) ... else: print(f"Error: Saved model file not found at {SAVED_MODEL_PATH}. Cannot load.")
This comprehensive script provides a solid template for building, training, and thoroughly analyzing a hybrid deep learning model, placing a strong emphasis on understanding the model's behavior through various explainability techniques.
Okay, let's delve deeper into the three explainability techniques used in the script: SHAP, Grad-CAM, and Attention Weights Visualization.
-
Purpose: SHAP aims to explain the output of any machine learning model by assigning an importance value (the "SHAP value") to each input feature for a particular prediction. It tells you how much each feature contributed, positively or negatively, to pushing the model's output away from a baseline (average) prediction towards the final prediction for that specific instance. It's based on Shapley values from cooperative game theory, providing a theoretically sound way to distribute the prediction "payout" among the features.
-
Concept (for DeepExplainer): Calculating exact Shapley values is computationally infeasible for complex models.
shap.DeepExplainer(used here as it's suitable for neural networks) provides an efficient approximation. It combines ideas from DeepLIFT (another attribution method using backpropagation) and Shapley values. It requires a background dataset to represent the expected distribution of features, which helps estimate the baseline prediction (the "expected value" orexplainer.expected_value). The SHAP values are calculated by comparing the model's output when a feature is present versus when it's "absent" (approximated using the background data), considering all possible feature orderings or subsets (approximated efficiently). -
Step-by-Step Implementation (from the script):
-
Select Background Data:
- Code:
background_indices = random.sample(range(len(train_dataset)), min(NUM_SHAP_BACKGROUND_SAMPLES, len(train_dataset))) background_data = torch.stack([train_dataset[i][0] for i in background_indices]).to(DEVICE)
- Explanation: A subset of the training data is chosen randomly. This data acts as the reference or baseline distribution.
DeepExplaineruses this to understand the "average" behavior of the model and to simulate the effect of features being "absent" or unobserved when calculating contributions. The number of samples (NUM_SHAP_BACKGROUND_SAMPLES) is a trade-off between representativeness and computational cost. The data is stacked into a single tensor and moved to the correct device.
- Code:
-
Define Model Predictor Wrapper:
- Code:
def model_predictor(data_tensor): output, _ = model(data_tensor) return output
- Explanation: The
HybridCNNLSTMAttentionmodel'sforwardmethod returns(output, attn_weights). However, the SHAP explainer needs a function that takes the input tensor and returns only the tensor representing the model's predictions (logits in this case). This wrapper function serves exactly that purpose, discarding the attention weights.
- Code:
-
Instantiate Explainer:
- Code:
explainer = shap.DeepExplainer(model_predictor, background_data)
- Explanation: Creates the
DeepExplainerobject, providing it with the wrapped model function (model_predictor) and the prepared background data tensor.
- Code:
-
Select Test Samples to Explain:
- Code:
test_indices = random.sample(range(len(test_dataset)), min(NUM_SAMPLES_TO_EXPLAIN * 5, len(test_dataset))) # Get more samples for summary test_samples = torch.stack([test_dataset[i][0] for i in test_indices]).to(DEVICE) # ... (get corresponding labels if needed) ...
- Explanation: A subset of the test data is selected for explanation. More samples might be selected than strictly needed for individual force plots (
NUM_SAMPLES_TO_EXPLAIN) to provide a better basis for the summary plots.
- Code:
-
Calculate SHAP Values:
- Code:
shap_values = explainer.shap_values(test_samples)
- Explanation: This is the core computation. The explainer calculates SHAP values for each feature, for each selected test sample, and for each output class. The output
shap_valuesis typically a list, whereshap_values[i]is a NumPy array containing the SHAP values for the i-th class. Each array has the shape(num_test_samples, num_features).
- Code:
-
Generate Summary Plots:
- Code (Bar Plot - Global Importance):
shap.summary_plot(shap_values, test_samples.cpu().numpy(), feature_names=FEATURE_NAMES, class_names=CLASS_NAMES, plot_type="bar", max_display=10) # ... (title, save, show) ...
- Explanation (Bar Plot): Aggregates importance across all classes and samples by taking the mean absolute SHAP value for each feature. The resulting bar chart shows the top
max_displayfeatures ranked by their overall impact on the model's predictions. - Code (Dot Plot - Per-Class Importance):
for i, class_name in enumerate(CLASS_NAMES): plt.figure() shap.summary_plot(shap_values[i], test_samples.cpu().numpy(), feature_names=FEATURE_NAMES, show=False, max_display=10) # ... (title, save, close) ...
- Explanation (Dot Plot): This plot is generated per class. For each feature (y-axis), it plots a point for each sample. The point's position on the x-axis represents the SHAP value (impact on that class's prediction), and its color represents the original value of that feature (high/low). This reveals not just which features are important for a class, but also how their values influence the prediction (e.g., high values of Feature_3 tend to increase the prediction score for Class_1).
- Code (Bar Plot - Global Importance):
-
Generate Force Plots (Individual Explanations):
- Code:
# ... (get base_values = explainer.expected_value) ... # ... (loop through selected indices_to_plot) ... if base_values is not None and len(base_values) == OUTPUT_SIZE: shap.force_plot(base_values[predicted_label], # Base value for the predicted class shap_values[predicted_label][i,:], # SHAP values for this sample & predicted class test_samples_np[i,:], # Feature values for this sample feature_names=FEATURE_NAMES, matplotlib=True, show=False) # ... (title, save, close) ... # ... (optional: plot for true_label if different) ...
- Explanation: For a single prediction:
base_value: The average prediction score for that class across the background dataset (explainer.expected_value[predicted_label]).shap_values[...]: The SHAP values calculated for this specific sample (i) and the class of interest (predicted_label).test_samples_np[i,:]: The actual feature values for this sample.- The plot visualizes features as forces pushing the output from the
base_valuetowards the final prediction score. Red features push the score higher (increase probability of that class), blue features push it lower. The size of the feature's block indicates the magnitude of its impact.matplotlib=Trueallows saving the plot. Plotting for both the predicted and true (if different) classes helps understand misclassifications.
- Code:
-
-
Purpose: Grad-CAM is designed to produce a coarse localization map highlighting the important regions in the input space or feature maps of a CNN that the model used to predict a specific target class. It answers: "Which parts of the sequence (as seen by the target CNN layer) were most important for deciding on class Y?". It's specific to Convolutional Neural Networks.
-
Concept: It leverages the spatial information preserved in the feature maps of convolutional layers.
- It calculates the gradient of the score for the target class with respect to the feature maps of a chosen (usually the last) convolutional layer. These gradients indicate how much a change in each feature map channel affects the class score.
- It computes the average of these gradients across the spatial dimensions (global average pooling) for each channel. This gives a weight (importance score,
alpha_k) for each feature map channel. - It computes a weighted combination of the forward activation feature maps, using the channel importance weights (
alpha_k) derived from the gradients. - It applies a ReLU function to this combination. This focuses on features that have a positive influence on the class of interest. The result is a heatmap indicating regions the CNN focused on for that class prediction.
-
Step-by-Step Implementation (from the script):
-
Model Hooks Setup:
- Code (in
__init__):self.feature_maps = None self.gradients = None def activations_hook(self, grad): self.gradients = grad
- Code (in
forward):# Hook applied to the output of the last CNN activation if self.training is False and x_cnn.requires_grad: h = x_cnn.register_hook(self.activations_hook) # Register hook to capture gradient self.feature_maps = x_cnn # Store the activation map itself
- Explanation: The
activations_hookfunction is defined to simply store the gradient that flows back into the layer it's attached to. In theforwardpass (only during evaluation when gradients might be needed for explanation), this hook is registered to the tensorx_cnn(output of the last CNN layer block). The activation tensorx_cnnitself is also stored inself.feature_maps. Whenloss.backward()(or specificallyoutput[:, target_class_index].backward()) is called later, PyTorch calculates gradients, and the hook automatically saves the gradient flowing intox_cnnintoself.gradients.
- Code (in
-
get_grad_camFunction:- Code:
def get_grad_cam(model, target_layer_output, target_class_index, output): if model.gradients is None or target_layer_output is None: return None # Check if hooks worked # Backpropagate the specific class score: output[:, target_class_index].backward(retain_graph=True) gradients = model.gradients # Shape: [batch, channels, seq_len] activations = target_layer_output # Shape: [batch, channels, seq_len] # Pool gradients (alpha_k calculation): pooled_gradients = torch.mean(gradients, dim=[2]) # Shape: [batch, channels] # Weight the channels (importance weighting): pooled_gradients = pooled_gradients.unsqueeze(-1) # Shape: [batch, channels, 1] for broadcasting heatmap = torch.sum(activations * pooled_gradients, dim=1) # Shape: [batch, seq_len] heatmap = F.relu(heatmap) # Apply ReLU # Normalize: # ... (normalization code) ... return heatmap.squeeze().cpu().numpy()
- Explanation:
- Takes the model, the captured feature maps (
target_layer_outputwhich ismodel.feature_maps), the index of the class to explain, and the model's final output tensor. output[:, target_class_index].backward(retain_graph=True): This is the key step to get gradients relevant to the target class. It backpropagates only from the score of the desired class.retain_graph=Truemight be needed if you perform multiple backward passes (e.g., explaining multiple classes for the same input).- Retrieves
gradientsandactivationsstored by the hooks. torch.mean(gradients, dim=[2]): Calculates the importance weight (alpha_k) for each channel by averaging the gradient values across the sequence dimension (dimension 2).activations * pooled_gradients: Weights the activation maps channel-wise using the importance weights (broadcasting handles the dimensions).torch.sum(..., dim=1): Sums the weighted channels to create the raw heatmap (collapsing the channel dimension).F.relu(heatmap): Keeps only positive contributions.- Normalization: Scales the heatmap values to be between 0 and 1 for easier visualization.
- Returns the heatmap as a NumPy array.
- Takes the model, the captured feature maps (
- Code:
-
Visualization Loop:
- Code:
for i, sample_idx in enumerate(grad_cam_indices): input_tensor = grad_cam_samples[i].unsqueeze(0) input_tensor.requires_grad = True # CRITICAL for gradient calculation model.zero_grad() # Clear any stale gradients # Forward pass to get output AND trigger hooks output, _ = model(input_tensor) predicted_class = torch.argmax(output, dim=1).item() # Calculate Grad-CAM (includes backward pass inside) heatmap = get_grad_cam(model, model.feature_maps, predicted_class, output) if heatmap is not None: # Upsample heatmap to original input size heatmap_tensor = torch.tensor(heatmap).unsqueeze(0).unsqueeze(0) upsampled_heatmap = F.interpolate(heatmap_tensor, size=INPUT_SIZE, mode='linear', align_corners=False) upsampled_heatmap = upsampled_heatmap.squeeze().numpy() # Plotting plt.figure(...) original_data = input_tensor.squeeze().detach().cpu().numpy() plt.plot(original_data, ...) plt.imshow(upsampled_heatmap[np.newaxis, :], cmap='viridis', aspect='auto', alpha=0.5, extent=[0, INPUT_SIZE -1, np.min(original_data)-0.1, np.max(original_data)+0.1]) # ... (colorbar, title, save, show) ... # Clean up model.gradients = None model.feature_maps = None input_tensor.requires_grad = False
- Explanation:
- Iterates through selected test samples.
- Sets
requires_grad=Trueon the input tensor because we need to compute gradients from the output back towards the intermediate CNN layer via the input. model.zero_grad()ensures gradients from previous iterations don't interfere.- The
model(input_tensor)call performs the forward pass, which gets the prediction and triggers the hooks to storemodel.feature_maps. get_grad_cam(...)is called. Inside this function, thebackward()call triggers the gradient calculation and the storage ofmodel.gradientsvia the hook.- Upsampling: The calculated
heatmaphas a length corresponding to the sequence length after the CNN pooling layers (cnn_output_seq_len). To visualize it relative to the original input features (lengthINPUT_SIZE),F.interpolateis used to resize the heatmap linearly. - Plotting: The original feature sequence is plotted. The
upsampled_heatmapis plotted over it usingplt.imshow.aspect='auto'stretches the heatmap horizontally.alpha=0.5makes it semi-transparent.extentis crucial:[0, INPUT_SIZE - 1, y_min, y_max]tellsimshowto map the heatmap columns to the range 0 toINPUT_SIZE - 1on the x-axis (aligning with the feature plot) and adjusts the y-limits for visual clarity. - Cleanup: Resetting hooks and
requires_gradprevents potential issues in subsequent iterations or model uses.
- Code:
-
-
Purpose: To visualize the internal mechanism of the Attention layer itself. In this model, it's a
MultiheadAttentionlayer operating on the sequence output by the LSTM. Visualizing the weights shows which elements (time steps or positions) in the LSTM output sequence the attention mechanism focused on relative to other elements when computing the weighted representation (attn_output) that gets aggregated into thecontext_vector. For self-attention, it answers "How much did sequence element i attend to sequence element j?". -
Concept: Attention mechanisms compute scores indicating the relevance between elements. For self-attention (
query,key, andvalueall come from the same sequence), the layer calculates how much each position ("query") should pay attention to every other position including itself ("key"). These scores are typically normalized (e.g., via Softmax) to become weights that sum to 1 for each query position. A higher weight means higher attention/focus. Visualizing these weights as a matrix (heatmap) reveals these focus patterns. -
Step-by-Step Implementation (from the script):
-
Get Attention Weights:
- Code:
with torch.no_grad(): # No gradients needed for this output, attn_weights = model(input_tensor) # attn_weights shape likely [batch, seq_len_after_cnn, seq_len_after_cnn] # or [batch, num_heads, seq_len_after_cnn, seq_len_after_cnn]
- Explanation: The model's
forwardpass is executed withintorch.no_grad()as only the forward output (specificallyattn_weights) is needed. Theattn_weightstensor is directly returned by themodel. Note: The exact shape depends on thenn.MultiheadAttentionimplementation details and how it's used. The script assumes the weights available for visualization are shaped[batch, seq_len, seq_len]. Ifnum_heads > 1, the script implicitly either uses an attention layer that averages heads internally or it might be visualizing only one head or an average (though the averaging code is commented out in the model definition but might be intended). Let's proceed assumingattn_weightshas the shape[batch, seq_len, seq_len].
- Code:
-
Process Weights:
- Code:
if attn_weights is not None: # Squeeze batch dim, move to CPU, convert to NumPy attn_map = attn_weights.squeeze(0).cpu().numpy() # Shape: [seq_len, seq_len]
- Explanation: Checks if weights were returned.
squeeze(0)removes the batch dimension (assuming batch size 1 for individual explanation)..cpu().numpy()moves the tensor off the GPU (if used) and converts it to a NumPy array suitable for plotting libraries like Seaborn/Matplotlib. The resultingattn_mapis a 2D array whereattn_map[i, j]represents the attention paid by query stepito key stepj.
- Code:
-
Visualize Heatmap:
- Code:
plt.figure(figsize=(7, 6)) sns.heatmap(attn_map, cmap="viridis", cbar=True) plt.title(f'Attention Map for Sample {sample_idx} (Predicted: {CLASS_NAMES[predicted_class]})') plt.xlabel('Key Time Steps (CNN Output Sequence)') plt.ylabel('Query Time Steps (CNN Output Sequence)') plt.savefig(f"attention_map_sample_{sample_idx}.png") plt.show()
- Explanation:
seaborn.heatmapis used to plot the 2Dattn_map.- The y-axis represents the "query" positions (the position generating the attention context).
- The x-axis represents the "key" positions (the positions being attended to).
- Both axes correspond to the sequence length after the CNN layers (i.e., the input sequence to the attention layer).
- Brighter colors indicate higher attention weights. A bright diagonal might indicate strong self-attention. Off-diagonal bright spots show attention paid to other positions in the sequence. Different patterns (e.g., focus on early/late steps, specific relative positions) can reveal how the model integrates information across the sequence.
- Code:
-
In summary:
- SHAP: Explains prediction based on input features. Answers "Which input features mattered most?".
- Grad-CAM: Explains prediction based on CNN spatial focus. Answers "Where did the CNN look?".
- Attention Viz: Explains prediction based on Attention layer's internal weighting. Answers "Which parts of the sequence representation did the Attention layer focus on?".