69 changes: 69 additions & 0 deletions tests/generator/molgan.py
@@ -0,0 +1,69 @@
import os
from rdkit import RDLogger
from torch_molecule.generator.molgan import (
MolGAN,
RewardOracle,
)

RDLogger.DisableLog("rdApp.*")

def test_molgan():
    # Sample SMILES list: small molecules only, so every graph fits within the
    # 9-node budget configured below and uses only atoms from model_decoder
    smiles_list = [
        "CCO", "CCN", "CCC", "COC",
        "CCCl", "CCF", "CBr", "CN(C)C", "CC(=O)O", "c1ccccc1",
    ]
    model_decoder = ["C", "N", "O", "F", "Cl", "Br"]

# 1. Initialize MolGAN
print("\n=== Testing MolGAN Initialization ===")
    gan_config = {
        "num_nodes": 9,
        "num_layers": 4,
        "num_atom_types": len(model_decoder),  # one channel per atom symbol
        "num_bond_types": 4,
        "latent_dim": 56,
        "hidden_dims_gen": [128, 128],
        "hidden_dims_disc": [128, 128],
        "tau": 1.0,
        "use_reward": True,
    }
    model = MolGAN(**gan_config, device="cpu")
print("MolGAN initialized successfully")

# 2. Fit with QED reward
print("\n=== Testing MolGAN Training with QED Reward ===")
reward = RewardOracle(kind="qed")
model.fit(X=smiles_list, reward=reward, epochs=5, batch_size=16)
print("MolGAN trained successfully")

# 3. Generation
print("\n=== Testing MolGAN Generation ===")
gen_smiles = model.generate(n_samples=10)
print(f"Generated {len(gen_smiles)} SMILES")
print("Example generated molecules:", gen_smiles[:3])

# 4. Save and Reload
print("\n=== Testing MolGAN Save & Load ===")
save_dir = "molgan-test"
model.save_pretrained(save_dir)
print(f"Model saved to {save_dir}")

model2 = MolGAN.from_pretrained(save_dir)
print("Model loaded successfully")

gen_smiles2 = model2.generate(n_samples=5)
print("Generated after loading:", gen_smiles2[:3])

# 5. Cleanup
import shutil
if os.path.exists(save_dir):
shutil.rmtree(save_dir)
print(f"Cleaned up {save_dir}")

if __name__ == "__main__":
test_molgan()
109 changes: 109 additions & 0 deletions torch_molecule/generator/molgan/molgan_dataset.py
@@ -0,0 +1,109 @@
from typing import List, Optional, Callable
from rdkit import Chem
import torch
from torch.utils.data import Dataset
from .molgan_utils import qed_reward_fn

class MolGANDataset(Dataset):
"""
A PyTorch Dataset for MolGAN, with all RDKit and graph tensor processing
precomputed in __init__ for fast, pure-tensor __getitem__ access.
Optionally caches property values for each molecule.
"""
def __init__(
self,
data: List[str],
atom_types: List[str],
bond_types: List[str],
max_num_atoms: int = 50,
cache_properties: bool = False,
property_fn: Optional[Callable] = None,
return_mol: bool = False,
device: Optional[torch.device] = None
):
self.data = data
self.atom_types = atom_types
self.bond_types = bond_types
self.max_num_atoms = max_num_atoms
self.atom_type_to_idx = {atom: idx for idx, atom in enumerate(atom_types)}
self.bond_type_to_idx = {bond: idx for idx, bond in enumerate(bond_types)}
self.return_mol = return_mol
self.device = torch.device(device) if device is not None else None
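        # Note: storing every precomputed tensor directly on `device` trades
        # host/GPU memory for a faster __getitem__; leave device=None and let
        # the DataLoader move batches instead when the dataset is large.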

self.node_features = []
self.adjacency_matrices = []
self.mols = []
        # Resolve the property function first so cache_properties=True works
        # even when no property_fn is supplied (falls back to QED).
        self.property_fn = property_fn if property_fn is not None else qed_reward_fn
        self.cached_properties = [] if cache_properties else None

for idx, smiles in enumerate(self.data):
mol = Chem.MolFromSmiles(smiles)
self.mols.append(mol)
# Default: if invalid, fill with zeros and (optionally) property 0
nf = torch.zeros((self.max_num_atoms, len(self.atom_types)), dtype=torch.float)
adj = torch.zeros((self.max_num_atoms, self.max_num_atoms, len(self.bond_types)), dtype=torch.float)
prop_val = 0.0 if cache_properties else None

if mol is not None:
num_atoms = mol.GetNumAtoms()
if num_atoms > self.max_num_atoms:
raise ValueError(f"Molecule at index {idx} exceeds max_num_atoms: {num_atoms} > {self.max_num_atoms}")

for i, atom in enumerate(mol.GetAtoms()):
atom_type = atom.GetSymbol()
if atom_type in self.atom_type_to_idx:
nf[i, self.atom_type_to_idx[atom_type]] = 1.0

for bond in mol.GetBonds():
begin_idx = bond.GetBeginAtomIdx()
end_idx = bond.GetEndAtomIdx()
bond_type = str(bond.GetBondType())
if bond_type in self.bond_type_to_idx:
bidx = self.bond_type_to_idx[bond_type]
adj[begin_idx, end_idx, bidx] = 1.0
adj[end_idx, begin_idx, bidx] = 1.0

                if cache_properties:
try:
prop_val = self.property_fn(mol)
except Exception:
prop_val = 0.0

# Move tensors to device immediately if a device is set
if self.device is not None:
nf = nf.to(self.device)
adj = adj.to(self.device)

self.node_features.append(nf)
self.adjacency_matrices.append(adj)
            if self.cached_properties is not None:
                self.cached_properties.append(prop_val)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
parts = [
self.node_features[idx],
self.adjacency_matrices[idx]
]
# add optional property
if self.cached_properties is not None:
parts.append(self.cached_properties[idx])
        # add the optional Mol object (also accessible directly via self.mols)
if self.return_mol:
parts.append(self.mols[idx])

# Default: (node_features, adjacency_matrix)
# With property: (node_features, adjacency_matrix, property)
# With property and mol: (node_features, adjacency_matrix, property, mol)
# With only mol: (node_features, adjacency_matrix, mol)
return tuple(parts)
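

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the test suite).
    # Assumptions: bond-type names match str(rdkit.Chem.BondType), e.g.
    # "SINGLE"/"DOUBLE"/"TRIPLE"/"AROMATIC", and qed_reward_fn maps an RDKit
    # Mol to a float, as its use as the default property_fn above implies.
    from torch.utils.data import DataLoader

    demo = MolGANDataset(
        data=["CCO", "c1ccccc1"],
        atom_types=["C", "N", "O"],
        bond_types=["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"],
        max_num_atoms=9,
        cache_properties=True,
    )
    loader = DataLoader(demo, batch_size=2)
    nf, adj, prop = next(iter(loader))
    print(nf.shape, adj.shape, prop)  # [2, 9, 3], [2, 9, 9, 4], QED values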







201 changes: 201 additions & 0 deletions torch_molecule/generator/molgan/molgan_gen_disc.py
@@ -0,0 +1,201 @@
import torch
from dataclasses import dataclass
from typing import Tuple
from .molgan_r_gcn import RelationalGCNLayer # Local import to avoid circular dependency
import torch.nn.functional as F

@dataclass
class MolGANGeneratorConfig:
    z_dim: int = 32
    g_conv_dim: int = 64
    d_conv_dim: int = 64
    g_num_layers: int = 3
    d_num_layers: int = 3
    num_atom_types: int = 5
    num_bond_types: int = 4
    max_num_atoms: int = 9
    dropout: float = 0.0
    tau: float = 1.0  # Gumbel-Softmax temperature
    use_batchnorm: bool = True
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'


# MolGAN Generator
class MolGANGenerator(torch.nn.Module):
    def __init__(self, config: MolGANGeneratorConfig):
        super().__init__()
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.g_num_layers = config.g_num_layers
        self.num_atom_types = config.num_atom_types
        self.num_bond_types = config.num_bond_types
        self.max_num_atoms = config.max_num_atoms
        self.dropout = config.dropout
        self.use_batchnorm = config.use_batchnorm
        self.tau = config.tau
        self.device = config.device

layers = []
input_dim = self.z_dim
for i in range(self.g_num_layers):
output_dim = self.g_conv_dim * (2 ** i)
layers.append(torch.nn.Linear(input_dim, output_dim))
if self.use_batchnorm:
layers.append(torch.nn.BatchNorm1d(output_dim))
layers.append(torch.nn.ReLU())
if self.dropout > 0:
layers.append(torch.nn.Dropout(self.dropout))
input_dim = output_dim

        self.fc_layers = torch.nn.Sequential(*layers)
        self.atom_fc = torch.nn.Linear(input_dim, self.max_num_atoms * self.num_atom_types)
        self.bond_fc = torch.nn.Linear(input_dim, self.num_bond_types * self.max_num_atoms * self.max_num_atoms)
        # Move to the target device only after all submodules are registered,
        # so that their parameters are moved along with the module.
        self.to(self.device)

def forward(self, z: torch.Tensor, sample_mode='softmax') -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = z.size(0)
h = self.fc_layers(z)
atom_logits = self.atom_fc(h).view(batch_size, self.max_num_atoms, self.num_atom_types)
# Output bond logits with [batch, num_bond_types, max_num_atoms, max_num_atoms] order
bond_logits = self.bond_fc(h).view(batch_size, self.num_bond_types, self.max_num_atoms, self.max_num_atoms)

        def _sample(logits: torch.Tensor, dim: int) -> torch.Tensor:
            if sample_mode == 'softmax':
                return torch.softmax(logits, dim=dim)
            if sample_mode == 'soft_gumbel':
                return F.gumbel_softmax(logits, tau=self.tau, hard=False, dim=dim)
            if sample_mode == 'hard_gumbel':
                # Straight-through: one-hot forward pass, soft gradients
                return F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=dim)
            if sample_mode == 'argmax':
                # Deterministic decode; returns index tensors, not one-hot
                return logits.argmax(dim=dim)
            raise ValueError(f"Unknown sample_mode: {sample_mode}")

        # Nodes: categorical over atom types (last dim);
        # adjacency: categorical over bond types (dim 1)
        node = _sample(atom_logits, dim=-1)
        adj = _sample(bond_logits, dim=1)

        return node, adj





# MolGAN Discriminator
@dataclass
class MolGANDiscriminatorConfig:
    in_dim: int = 5             # Number of atom types (node feature dim). Typically set automatically.
    hidden_dim: int = 64        # Hidden feature/channel size for GCN layers.
    num_layers: int = 3         # Number of R-GCN layers (depth).
    num_relations: int = 4      # Number of bond types (relation types per edge).
    max_num_atoms: int = 9      # Max node count in the padded tensor.
    dropout: float = 0.0        # Dropout between layers.
    use_batchnorm: bool = True  # Whether to apply BatchNorm between layers.
    readout: str = 'sum'        # Pooling of nodes to a graph-level vector: 'sum', 'mean', or 'max'.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'


class MolGANDiscriminator(torch.nn.Module):
def __init__(self, config: MolGANDiscriminatorConfig):
        super().__init__()

self.in_dim = config.in_dim
self.hidden_dim = config.hidden_dim
self.num_layers = config.num_layers
self.num_relations = config.num_relations
self.max_num_atoms = config.max_num_atoms
self.dropout = config.dropout
self.use_batchnorm = config.use_batchnorm
self.readout = config.readout

layers = []
input_dim = self.in_dim
for i in range(self.num_layers):
output_dim = self.hidden_dim * (2 ** i)
layers.append(RelationalGCNLayer(input_dim, output_dim, self.num_relations))
            if self.use_batchnorm:
                # num_features must be the channel dim; forward() permutes to
                # [batch, features, nodes] before applying BatchNorm1d
                layers.append(torch.nn.BatchNorm1d(output_dim))
layers.append(torch.nn.LeakyReLU(0.2))
if self.dropout > 0:
layers.append(torch.nn.Dropout(self.dropout))
input_dim = output_dim

self.gcn_layers = torch.nn.ModuleList(layers)
self.fc = torch.nn.Linear(input_dim, 1)
self.device = config.device
self.to(self.device)

def forward(
self,
atom_feats: torch.Tensor,
adj: torch.Tensor,
mask: torch.Tensor
) -> torch.Tensor:
# atom_feats: [batch, max_num_atoms, num_atom_types]
# adj: [batch, num_bond_types, max_num_atoms, max_num_atoms]
# mask: [batch, max_num_atoms] (float, 1=real, 0=pad)
h = atom_feats
for layer in self.gcn_layers:
if isinstance(layer, RelationalGCNLayer):
h = layer(h, adj)
else:
# If using BatchNorm1d, input should be [batch, features, nodes]
if isinstance(layer, torch.nn.BatchNorm1d):
# Permute for batchnorm: [batch, nodes, features] → [batch, features, nodes]
h = layer(h.permute(0, 2, 1)).permute(0, 2, 1)
else:
h = layer(h)

# MASKED GRAPH READOUT
# mask: [batch, max_num_atoms] float
mask = mask.unsqueeze(-1) # [batch, max_num_atoms, 1]
h_masked = h * mask # zeros padded nodes

if self.readout == 'sum':
g = h_masked.sum(dim=1) # [batch, hidden_dim]
        elif self.readout == 'mean':
            # mask.sum(dim=1) is [batch, 1]; the epsilon guards graphs whose
            # mask is all zeros (pure padding)
            g = h_masked.sum(dim=1) / (mask.sum(dim=1) + 1e-8)
elif self.readout == 'max':
# Set padded to large neg, then max
h_masked_pad = h.clone()
h_masked_pad[mask.squeeze(-1) == 0] = float('-inf')
g, _ = h_masked_pad.max(dim=1)
else:
raise ValueError(f"Unknown readout type: {self.readout}")

out = self.fc(g) # [batch, 1]
return out.squeeze(-1) # [batch]
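

if __name__ == "__main__":
    # Smoke-test sketch: pipe generator samples into the discriminator.
    # Assumptions: CPU-only run, an all-ones mask (no padded nodes), and the
    # RelationalGCNLayer(in, out, num_relations) signature used above.
    g_cfg = MolGANGeneratorConfig(device="cpu")
    d_cfg = MolGANDiscriminatorConfig(in_dim=g_cfg.num_atom_types, device="cpu")
    gen = MolGANGenerator(g_cfg)
    disc = MolGANDiscriminator(d_cfg)

    z = torch.randn(4, g_cfg.z_dim)
    nodes, adj = gen(z, sample_mode="soft_gumbel")  # [4, 9, 5], [4, 4, 9, 9]
    mask = torch.ones(4, g_cfg.max_num_atoms)
    scores = disc(nodes, adj, mask)
    print(scores.shape)  # torch.Size([4])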