Changes from 35 of 54 commits
12fad57
added function related to training and for GNN, needed to define GNN …
Aug 19, 2025
be734dd
Added the gnn part, must be fine-tuned hyper-params, no test
Aug 19, 2025
61c6824
Removed the barriers in the creation of the DAG
Aug 19, 2025
75875ff
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 19, 2025
081651c
coded tested and fixed, need to add a cross validation module
Aug 20, 2025
b82dc01
Merge branch 'gnn-branch' of https://github.com/antotu/predictor-gnn …
antotu Aug 20, 2025
5ebd202
fixed the problem of the predict_device_for_figure_of_merits
Aug 20, 2025
857cd6f
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 20, 2025
6081f6b
Hellinger test done: success
Aug 20, 2025
7c54da6
Merge branch 'gnn-branch' of https://github.com/antotu/predictor-gnn …
Aug 20, 2025
bb4da24
GNN predictor fixed with optuna and tested
Aug 21, 2025
10bb52c
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 21, 2025
06be0d6
GNN predictor fixed with optuna and tested
Aug 21, 2025
ce990e3
Modified the tolm for running on the MacOS
Aug 21, 2025
96ca75b
Problems modified TPESampler and not TYPESampler
Aug 21, 2025
a64a082
Problems modified TPESampler and not TYPESampler
Aug 21, 2025
f8c99b5
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 21, 2025
e4e2742
Problems modified TPESampler and not TYPESampler
Aug 21, 2025
5784ff7
Problems modified TPESampler and not TYPESampler
Aug 21, 2025
7e17379
Test modified with number of epochs as parameter
Aug 21, 2025
082de05
Eliminated trained model
Aug 21, 2025
5ed00a9
Changed the test estimated hellinger for windows
Aug 21, 2025
e59a941
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 21, 2025
3a9f16c
Changed the test estimated hellinger for windows
Aug 21, 2025
c43ee01
Merge branch 'gnn-branch' of https://github.com/antotu/predictor-gnn …
Aug 21, 2025
92eda99
Changed the test estimated hellinger for windows
Aug 21, 2025
dc1aa55
Problem with windows solved eliminating warning
Aug 21, 2025
6809ccb
Files modified according suggestion
Aug 22, 2025
8c77598
Fixed the comments related to test hellinger distance and utils
antotu Aug 25, 2025
dc0a824
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 25, 2025
2419952
Fixed modification also with pre-commit
antotu Aug 25, 2025
5335241
Fixed modification also with pre-commit
antotu Aug 25, 2025
96096a0
Refactor the test ml predictor considering to join function related M…
antotu Aug 25, 2025
4613012
Modified part of helper in order to solve problems code
antotu Aug 26, 2025
1c728e2
Pre-commit has substituted Wille in Will
antotu Aug 26, 2025
c31cb46
Update tests/device_selection/test_predictor_ml.py
antotu Aug 27, 2025
13cf0f4
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 27, 2025
713343f
first round fixes
antotu Aug 27, 2025
17c6575
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 27, 2025
169b00e
pre-commit fixes
antotu Aug 27, 2025
2248081
pre-commit fixes
antotu Aug 27, 2025
8f90b12
Update src/mqt/predictor/ml/predictor.py
antotu Aug 27, 2025
74ec34b
Update src/mqt/predictor/ml/predictor.py
antotu Aug 27, 2025
f99e17b
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 27, 2025
57b1a29
Partial modification
antotu Aug 27, 2025
96232f0
Merge branch 'gnn-branch' of github.com:antotu/predictor-gnn into gnn…
antotu Aug 27, 2025
61965d8
🎨 pre-commit fixes
pre-commit-ci[bot] Aug 27, 2025
93f5414
fixed comments repo
antotu Aug 27, 2025
95f5359
Merge branch 'gnn-branch' of github.com:antotu/predictor-gnn into gnn…
antotu Aug 27, 2025
4fb7112
Modified the gates accepted
antotu Aug 28, 2025
5ea1720
Modified list
antotu Aug 28, 2025
312e6ea
Fixed bug Swap and Cswap gates
antotu Sep 8, 2025
156b7e6
Edit for saving memory GPU
antotu Sep 12, 2025
77c9f5c
Added patience as variable
antotu Oct 6, 2025
11 changes: 10 additions & 1 deletion pyproject.toml
@@ -44,6 +44,8 @@ dependencies = [
     "numpy>=1.24; python_version >= '3.11'",
     "numpy>=1.22",
     "numpy>=1.22,<2; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict numpy v2 for macOS x86 since it is not supported anymore since torch v2.3.0
+    "optuna>=4.5.0",
+    "torch-geometric>=2.6.1",
     "torch>=2.7.1,<2.8.0; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'", # Restrict torch v2.3.0 for macOS x86 since it is not supported anymore.
     "typing-extensions>=4.1", # for `assert_never`
 ]
@@ -164,9 +166,15 @@ implicit_reexport = true
 # recent versions of `gym` are typed, but stable-baselines3 pins a very old version of gym.
 # qiskit is not yet marked as typed, but is typed mostly.
 # the other libraries do not have type stubs.
-module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*", "qiskit_ibm_runtime.*", "networkx.*", "stable_baselines3.*"]
+module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*", "qiskit_ibm_runtime.*", "networkx.*", "stable_baselines3.*", "torch", "torch.*", "torch_geometric", "torch_geometric.*", "optuna.*"]
 ignore_missing_imports = true

+[[tool.mypy.overrides]]
+module = ["mqt.predictor.ml.*"]
+disallow_subclassing_any = false
+
+
+
[Collaborator review comment: suggested change — remove the extra blank lines above.]
 [tool.ruff]
 line-length = 120
 extend-include = ["*.ipynb"]
@@ -245,6 +253,7 @@ wille = "wille"
 anc = "anc"
 aer = "aer"
 fom = "fom"
+TPE = "TPE"

 [tool.repo-review]
 ignore = ["GH200"]
40 changes: 40 additions & 0 deletions src/mqt/predictor/_version.py
[Collaborator review comment: This file should not be tracked and can be removed.]
@@ -0,0 +1,40 @@
# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM
# Copyright (c) 2025 Munich Quantum Software Company GmbH
# All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Licensed under the MIT License

# file generated by setuptools-scm
# don't change, don't track in version control
from __future__ import annotations

__all__ = [
    "__commit_id__",
    "__version__",
    "__version_tuple__",
    "commit_id",
    "version",
    "version_tuple",
]

TYPE_CHECKING = False
if TYPE_CHECKING:
    VERSION_TUPLE = tuple[int | str, ...]
[CodeQL check warning: Unreachable code — this statement is unreachable.]
    COMMIT_ID = str | None
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = "2.3.1.dev6+g1d835bd4c"
__version_tuple__ = version_tuple = (2, 3, 1, "dev6", "g1d835bd4c")

__commit_id__ = commit_id = None
10 changes: 5 additions & 5 deletions src/mqt/predictor/hellinger/utils.py
@@ -132,12 +132,12 @@ def calc_device_specific_features(
     return np.array(list(feature_dict.values()))


-def get_hellinger_model_path(device: Target) -> Path:
+def get_hellinger_model_path(device: Target, gnn: bool = False) -> Path:
     """Returns the path to the trained model folder resulting from the machine learning training."""
-    training_data_path = Path(str(resources.files("mqt.predictor"))) / "ml" / "training_data"
+    training_data_path = Path(str(resources.files("mqt.predictor"))) / "ml" / "training_data" / "trained_model"
     model_path = (
-        training_data_path
-        / "trained_model"
-        / ("trained_hellinger_distance_regressor_" + device.description + ".joblib")
+        (training_data_path / ("trained_hellinger_distance_regressor_gnn_" + device.description + ".pth"))
+        if gnn
+        else (training_data_path / ("trained_hellinger_distance_regressor_" + device.description + ".joblib"))
     )
     return Path(model_path)
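
As a quick illustration for reviewers (a sketch, not part of the diff): how the new gnn flag changes the resolved path, assuming this branch is installed; the device description below is a made-up placeholder.

from qiskit.transpiler import Target

from mqt.predictor.hellinger.utils import get_hellinger_model_path

# Hypothetical device; only its description string matters for path resolution.
device = Target(description="example_device")

# Classical regressor resolves to a ".joblib" file; the GNN variant to a
# "..._gnn_..." file with a ".pth" suffix, both under trained_model/.
print(get_hellinger_model_path(device))
print(get_hellinger_model_path(device, gnn=True))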
7 changes: 6 additions & 1 deletion src/mqt/predictor/ml/__init__.py
[Collaborator review comment: No need to change this file. GNN should never be accessed by a user directly.]
@@ -13,4 +13,9 @@
 from mqt.predictor.ml import helper
 from mqt.predictor.ml.predictor import Predictor, predict_device_for_figure_of_merit, setup_device_predictor

-__all__ = ["Predictor", "helper", "predict_device_for_figure_of_merit", "setup_device_predictor"]
+__all__ = [
+    "Predictor",
+    "helper",
+    "predict_device_for_figure_of_merit",
+    "setup_device_predictor",
+]
174 changes: 174 additions & 0 deletions src/mqt/predictor/ml/gnn.py
@@ -0,0 +1,174 @@
# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM
# Copyright (c) 2025 Munich Quantum Software Company GmbH
# All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Licensed under the MIT License

"""This module contains the GNN module for graph neural networks."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch_geometric.nn import SAGEConv, global_mean_pool

if TYPE_CHECKING:
    from collections.abc import Callable

    from torch_geometric.data import Data


class GraphConvolutionSage(nn.Module):
    """Graph convolutional encoder using SAGEConv."""

    def __init__(
        self,
        in_feats: int,
        hidden_dim: int,
        num_resnet_layers: int,
        *,
        conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        conv_act_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """A flexible SAGEConv graph encoder.

        Args:
            in_feats: dimensionality of the node features
            hidden_dim: output size of each SAGEConv layer
            num_resnet_layers: number of SAGEConv layers (with residual connections) to stack after the two initial SAGEConv layers
            conv_activation: activation function applied after each graph layer
            conv_act_kwargs: extra kwargs for conv_activation
        """
        super().__init__()
        self.conv_activation = conv_activation
        self.conv_act_kwargs = conv_act_kwargs or {}

        # --- GRAPH ENCODER ---
        self.convs = nn.ModuleList()
        # 1) Two initial convolutions without residual connections
        #    (hard-coded here; this could be generalized)
        self.convs.append(SAGEConv(in_feats, hidden_dim))
        self.convs.append(SAGEConv(hidden_dim, hidden_dim))

        # 2) Convolutions applied in residual configuration
        for _ in range(num_resnet_layers):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim))

    def forward(self, data: Data) -> torch.Tensor:
        """Encode the input graph and return a graph-level embedding."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # 1) Graph stack with residuals: the first two convolutions are applied
        #    without residual connections; all subsequent layers add their input
        #    back to their output (residual configuration)
        for i, conv in enumerate(self.convs):
            x_new = conv(x, edge_index)
            x_new = self.conv_activation(x_new, **self.conv_act_kwargs)
            x = x_new if i < 2 else x + x_new

        # 2) Global mean pooling yields one embedding per graph
        return global_mean_pool(x, batch)

    # 3) MLP head
[Collaborator review comment: suggested change — remove the stray "# 3) MLP head" comment above.]

class GNN(nn.Module):
    """Architecture composed of a graph-convolutional part (SAGEConv modules) followed by an MLP."""

    def __init__(
        self,
        in_feats: int,
        hidden_dim: int,
        num_resnet_layers: int,
        mlp_units: list[int],
        *,
        conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        conv_act_kwargs: dict[str, Any] | None = None,
        mlp_activation: Callable[..., torch.Tensor] = functional.leaky_relu,
        mlp_act_kwargs: dict[str, Any] | None = None,
        classes: list[str] | None = None,
        output_dim: int = 1,
    ) -> None:
        """Initialize the GNN.

        Arguments:
            in_feats: dimension of the input node features
            hidden_dim: dimension of the hidden output channels of the convolutional part
            num_resnet_layers: number of residual layers
            mlp_units: list of units for each layer of the final MLP
            conv_activation: activation function applied after each graph layer
            conv_act_kwargs: extra kwargs for conv_activation
            mlp_activation: activation function applied after each MLP layer
            mlp_act_kwargs: extra kwargs for mlp_activation
            classes: list of class names for classification tasks
            output_dim: dimension of the output; defaults to 1 for regression tasks
        """
        # ─────────────────────────────────────────────────────────────────────
        # Suppress torch-geometric "plugin" import warnings (torch-scatter, etc.)
        warnings.filterwarnings(
            "ignore",
            message=r"An issue occurred while importing 'torch-scatter'.*",
            category=UserWarning,
            module=r"torch_geometric.typing",
        )
        warnings.filterwarnings(
            "ignore",
            message=r"An issue occurred while importing 'torch-spline-conv'.*",
            category=UserWarning,
            module=r"torch_geometric.typing",
        )
        warnings.filterwarnings(
            "ignore",
            message=r"An issue occurred while importing 'torch-sparse'.*",
            category=UserWarning,
            module=r"torch_geometric.typing",
        )
        warnings.filterwarnings(
            "ignore",
            message=r"An issue occurred while importing 'torch-geometric'.*",
            category=UserWarning,
        )

        warnings.filterwarnings(
            "ignore",
            message=r".*'type_params' parameter of 'typing\._eval_type'.*",
            category=DeprecationWarning,
        )

        super().__init__()
        # Convolutional part
        self.graph_conv = GraphConvolutionSage(
            in_feats, hidden_dim, num_resnet_layers, conv_activation=conv_activation, conv_act_kwargs=conv_act_kwargs
        )

        # MLP head
        self.mlp_activation = mlp_activation
        self.mlp_act_kwargs = mlp_act_kwargs or {}
        self.classes = classes
        self.fcs = nn.ModuleList()
        last_dim = hidden_dim
        for out_dim in mlp_units:
            self.fcs.append(nn.Linear(last_dim, out_dim))
            last_dim = out_dim
        self.out = nn.Linear(last_dim, output_dim)

    def forward(self, data: Data) -> torch.Tensor:
        """Run the graph encoder followed by the MLP head on the input graph.

        Arguments:
            data: The input graph data.
        """
        # Apply the graph convolutions (one pooled embedding per graph)
        x = self.graph_conv(data)
        # Apply the MLP
        for fc in self.fcs:
            x = self.mlp_activation(fc(x), **self.mlp_act_kwargs)
        return self.out(x)
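
As a smoke-test sketch (illustrative only, not part of the PR; all dimensions are arbitrary): build a tiny three-node graph and run one forward pass through the new GNN.

import torch
from torch_geometric.data import Batch, Data

from mqt.predictor.ml.gnn import GNN

# Toy graph: 3 nodes with 8 features each, edges forming a directed cycle.
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
batch = Batch.from_data_list([Data(x=x, edge_index=edge_index)])

# Two initial SAGEConvs plus 2 residual layers, then an MLP 16 -> 32 -> 16 -> 1.
model = GNN(in_feats=8, hidden_dim=16, num_resnet_layers=2, mlp_units=[32, 16])
out = model(batch)  # shape [1, 1]: one scalar prediction for the single graph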