-
-
Notifications
You must be signed in to change notification settings - Fork 20
Add GNN-Based Predictor with DAG Preprocessing #430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…into gnn-branch
…into gnn-branch
…into gnn-branch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @antotu , thanks for your continued efforts!
I still didn't manage to get fully through, so here is another preliminary batch of feedback.
src/mqt/predictor/ml/gnn.py
Outdated
# 2) Global pooling | ||
return global_mean_pool(x, batch) | ||
|
||
# 3) MLP head |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# 3) MLP head |
pyproject.toml
Outdated
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/mqt/predictor/ml/predictor.py
Outdated
warnings.filterwarnings( | ||
"ignore", | ||
message=r"An issue occurred while importing 'torch-scatter'.*", | ||
category=UserWarning, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pytorch.*:UserWarning
should already be ignored through the filterwarnings
in pyproject.toml
. Please only add the additionally required ones there. The same goes for the other files, too. Thanks!
src/mqt/predictor/ml/predictor.py
Outdated
number_epochs: The number of epochs to train the GNN model. Defaults to 100. | ||
number_trials: The number of trials to run for hyperparameter optimization for the GNN. Defaults to 50. | ||
verbose: Whether to print verbose output during training GNN. Defaults to False. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
number_epochs: The number of epochs to train the GNN model. Defaults to 100. | |
number_trials: The number of trials to run for hyperparameter optimization for the GNN. Defaults to 50. | |
verbose: Whether to print verbose output during training GNN. Defaults to False. | |
**gnn_kwargs: Forwarded to `Predictor.train_gnn_model` when `gnn=True` | |
(e.g., `number_epochs=100`, `number_trials=50`, `verbose=False`). |
src/mqt/predictor/ml/predictor.py
Outdated
number_epochs: int = 100, | ||
number_trials: int = 50, | ||
verbose: bool = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we can use a gnn_kwargs
dictionary here to avoid cluttering the arguments with only GNN-specific things. It could also be useful in the future to add more hyperparameters if needed.
Co-authored-by: Patrick Hopf <[email protected]> Signed-off-by: Antonio Tudisco <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just another batch of feedback. Thank you for integrating the requested changes so fast!
src/mqt/predictor/ml/helper.py
Outdated
|
||
|
||
def create_dag(qc: QuantumCircuit) -> tuple[torch.Tensor, torch.Tensor, int]: | ||
"""Creates and returns the associate DAG of the quantum circuit. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Creates and returns the associate DAG of the quantum circuit. | |
"""Creates and returns the feature-annotated DAG of the quantum circuit. |
src/mqt/predictor/ml/helper.py
Outdated
|
||
from __future__ import annotations | ||
|
||
import math |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numpy is already imported and provides the same functionality. If I remember correctly, in a similar way, PyTorch provides these basic things too (perhaps we can further reduce imports here).
src/mqt/predictor/ml/helper.py
Outdated
return_arrays: bool = False, | ||
verbose: bool = False, | ||
) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]: | ||
"""Evaluate the models. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make the description a bit more detailed? Just so we know why this is necessary and that it is only required for the GNN models.
src/mqt/predictor/ml/helper.py
Outdated
restore_best: bool = True, | ||
scheduler: torch.optim.lr_scheduler._LRScheduler | None = None, | ||
) -> None: | ||
"""Trains the model with optional early stopping on validation loss. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Trains the model with optional early stopping on validation loss. | |
"""Trains a GNN model with optional early stopping on validation loss. |
src/mqt/predictor/ml/predictor.py
Outdated
qc = QuantumCircuit.from_qasm_file(path_uncompiled_circuit / file) | ||
feature_vec = create_feature_vector(qc) | ||
training_sample = (feature_vec, target_label) | ||
if not self.gnn: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not self.gnn: | |
if self.gnn: | |
x, edge_index, number_of_gates = create_dag(qc) | |
y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float) | |
training_sample = (x, y, edge_index, number_of_gates, target_label) | |
else: | |
feature_vec = create_feature_vector(qc) | |
training_sample = (feature_vec, target_label) | |
circuit_name = str(file).split(".")[0] | |
return training_sample, circuit_name, scores_list |
src/mqt/predictor/ml/predictor.py
Outdated
feature_vec = create_feature_vector(qc) | ||
training_sample = (feature_vec, target_label) | ||
circuit_name = str(file).split(".")[0] | ||
return training_sample, circuit_name, scores_list | ||
x, edge_index, number_of_gates = create_dag(qc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
feature_vec = create_feature_vector(qc) | |
training_sample = (feature_vec, target_label) | |
circuit_name = str(file).split(".")[0] | |
return training_sample, circuit_name, scores_list | |
x, edge_index, number_of_gates = create_dag(qc) |
src/mqt/predictor/ml/predictor.py
Outdated
self.devices_description = [dev.description for dev in self.devices] | ||
y = self.devices_description.index(target_label) | ||
print(target_label) | ||
return Data( | ||
x=x, | ||
y=torch.tensor([y], dtype=torch.float), | ||
circuit_name=circuit_name, | ||
edge_index=edge_index, | ||
target_label=target_label, # torch.tensor([target_label], dtype=torch.float), | ||
scores_list=scores_list, | ||
num_nodes=number_of_gates, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.devices_description = [dev.description for dev in self.devices] | |
y = self.devices_description.index(target_label) | |
print(target_label) | |
return Data( | |
x=x, | |
y=torch.tensor([y], dtype=torch.float), | |
circuit_name=circuit_name, | |
edge_index=edge_index, | |
target_label=target_label, # torch.tensor([target_label], dtype=torch.float), | |
scores_list=scores_list, | |
num_nodes=number_of_gates, | |
) |
src/mqt/predictor/ml/predictor.py
Outdated
|
||
return mdl.best_estimator_ | ||
|
||
def _get_prepared_training_graphs(self) -> TrainingData: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the changes above, we can drop this graph-specific method and instead use the _get_prepared_training_data
with a slight modification when loading the graph-specific training data (if self.gnn: ...
).
Co-authored-by: Patrick Hopf <[email protected]> Signed-off-by: Antonio Tudisco <[email protected]>
Co-authored-by: Patrick Hopf <[email protected]> Signed-off-by: Antonio Tudisco <[email protected]>
train_loss = running_loss / max(1, total) | ||
if scheduler is not None: | ||
scheduler.step() | ||
val_loss = float("inf") |
Check warning
Code scanning / CodeQL
Variable defined multiple times Warning
Description
This PR introduces a Graph Neural Network (GNN) as an alternative to the Random Forest model for predicting the best device to run a quantum circuit.
To support this, the preprocessing pipeline was redesigned: instead of manually extracting features from the circuit, the model now directly takes as input the Directed Acyclic Graph (DAG) representation of the quantum circuit.
🚀 Major Changes
Graph Neural Network Integration
🎯 Motivation
🔧 Fixes and Enhancements
📦 Dependency Updates
optuna>=4.5.0
torch-geometric>=2.6.1
Checklist: