-
-
Notifications
You must be signed in to change notification settings - Fork 21
Add GNN-Based Predictor with DAG Preprocessing #430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 35 commits
12fad57
be734dd
61c6824
75875ff
081651c
b82dc01
5ebd202
857cd6f
6081f6b
7c54da6
bb4da24
10bb52c
06be0d6
ce990e3
96ca75b
a64a082
f8c99b5
e4e2742
5784ff7
7e17379
082de05
5ed00a9
e59a941
3a9f16c
c43ee01
92eda99
dc1aa55
6809ccb
8c77598
dc0a824
2419952
5335241
96096a0
4613012
1c728e2
c31cb46
13cf0f4
713343f
17c6575
169b00e
2248081
8f90b12
74ec34b
f99e17b
57b1a29
96232f0
61965d8
93f5414
95f5359
4fb7112
5ea1720
312e6ea
156b7e6
77c9f5c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file should not be tracked and can be removed. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM | ||
# Copyright (c) 2025 Munich Quantum Software Company GmbH | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: MIT | ||
# | ||
# Licensed under the MIT License | ||
|
||
# file generated by setuptools-scm | ||
# don't change, don't track in version control | ||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"__commit_id__", | ||
"__version__", | ||
"__version_tuple__", | ||
"commit_id", | ||
"version", | ||
"version_tuple", | ||
] | ||
|
||
TYPE_CHECKING = False | ||
if TYPE_CHECKING: | ||
VERSION_TUPLE = tuple[int | str, ...] | ||
Check warningCode scanning / CodeQL Unreachable code Warning
This statement is unreachable.
|
||
COMMIT_ID = str | None | ||
else: | ||
VERSION_TUPLE = object | ||
COMMIT_ID = object | ||
|
||
version: str | ||
__version__: str | ||
__version_tuple__: VERSION_TUPLE | ||
version_tuple: VERSION_TUPLE | ||
commit_id: COMMIT_ID | ||
__commit_id__: COMMIT_ID | ||
|
||
__version__ = version = "2.3.1.dev6+g1d835bd4c" | ||
__version_tuple__ = version_tuple = (2, 3, 1, "dev6", "g1d835bd4c") | ||
|
||
__commit_id__ = commit_id = None |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to change this file. GNN should never be accessed by a user directly. |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,174 @@ | ||||
# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM | ||||
# Copyright (c) 2025 Munich Quantum Software Company GmbH | ||||
# All rights reserved. | ||||
# | ||||
# SPDX-License-Identifier: MIT | ||||
# | ||||
# Licensed under the MIT License | ||||
|
||||
"""This module contains the GNN module for graph neural networks.""" | ||||
|
||||
from __future__ import annotations | ||||
|
||||
import warnings | ||||
from typing import TYPE_CHECKING, Any | ||||
|
||||
import torch | ||||
import torch.nn as nn | ||||
import torch.nn.functional as functional | ||||
from torch_geometric.nn import SAGEConv, global_mean_pool | ||||
|
||||
if TYPE_CHECKING: | ||||
from collections.abc import ( | ||||
Callable, # on 3.10+ prefer collections.abc | ||||
) | ||||
|
||||
from torch_geometric.data import Data | ||||
|
||||
|
||||
class GraphConvolutionSage(nn.Module): | ||||
"""Graph convolutional layer using SAGEConv.""" | ||||
|
||||
def __init__( | ||||
self, | ||||
in_feats: int, | ||||
hidden_dim: int, | ||||
num_resnet_layers: int, | ||||
*, | ||||
conv_activation: Callable[..., torch.Tensor] = functional.leaky_relu, | ||||
conv_act_kwargs: dict[str, Any] | None = None, | ||||
) -> None: | ||||
"""A flexible SageConv graph classification model. | ||||
Args: | ||||
in_feats: dimensionality of node features | ||||
hidden_dim: output size of SageConv | ||||
num_resnet_layers: how many SageConv layers (with residuals) to stack after the SageConvs | ||||
mlp_units: list of units for each layer of the final MLP | ||||
conv_activation: activation fn after each graph layer | ||||
conv_act_kwargs: extra kwargs for conv_activation | ||||
final_activation: activation applied to the final scalar output | ||||
""" | ||||
super().__init__() | ||||
self.conv_activation = conv_activation | ||||
self.conv_act_kwargs = conv_act_kwargs or {} | ||||
|
||||
# --- GRAPH ENCODER --- | ||||
self.convs = nn.ModuleList() | ||||
# 1) Convolution not in residual configuration | ||||
# Possible to generalize the code | ||||
self.convs.append(SAGEConv(in_feats, hidden_dim)) | ||||
self.convs.append(SAGEConv(hidden_dim, hidden_dim)) | ||||
|
||||
for _ in range(num_resnet_layers): | ||||
self.convs.append(SAGEConv(hidden_dim, hidden_dim)) | ||||
|
||||
def forward(self, data: Data) -> torch.Tensor: | ||||
"""Forward function that allows to elaborate the input graph.""" | ||||
x, edge_index, batch = data.x, data.edge_index, data.batch | ||||
# 1) Graph stack with residuals | ||||
for i, conv in enumerate(self.convs): | ||||
x_new = conv(x, edge_index) | ||||
x_new = self.conv_activation(x_new, **self.conv_act_kwargs) | ||||
# the number 2 is set because two convolution without residual configuration are applied | ||||
# and then all the others are in residual configuration | ||||
x = x_new if i < 2 else x + x_new | ||||
|
||||
# 2) Global pooling | ||||
return global_mean_pool(x, batch) | ||||
|
||||
# 3) MLP head | ||||
|
# 3) MLP head |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.