# --- tests/generator/molgan.py (new file; reconstructed from mangled diff) ---
# NOTE(review): the original test imported `MolGAN` and `RewardOracle` from
# `torch_molecule.generator.molgan`, but this same diff adds that package's
# __init__.py as an EMPTY file and defines neither name anywhere, so the
# import could never succeed.  Rewritten against the API the diff actually
# implements (MolGANGenerativeModel) -- confirm the intended public API.
import os
import shutil

from rdkit import RDLogger

from torch_molecule.generator.molgan.molgan_generator import MolGANGenerativeModel

RDLogger.DisableLog("rdApp.*")


def test_molgan():
    """End-to-end smoke test: construct, briefly train, and sample MolGAN."""
    # Small SMILES set; the vocabulary is restricted via `atom_types` below.
    smiles_list = [
        "CCO", "CCN", "CCC", "COC",
        "CCCl", "CCF", "CBr", "CN(C)C", "CC(=O)O", "c1ccccc1",
        'CNC[C@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@@H]1C',
        'CNC[C@@H]1OCc2cnnn2CCCC(=O)N([C@H](C)CO)C[C@H]1C',
        'C[C@H]1CN([C@@H](C)CO)C(=O)CCCn2cc(nn2)CO[C@@H]1CN(C)C(=O)CCC(F)(F)F',
        'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
    ]

    # 1. Initialization
    print("\n=== Testing MolGAN Initialization ===")
    model = MolGANGenerativeModel(device="cpu")
    print("MolGAN initialized successfully")

    # 2. Training.  MolGANModel.training_step feeds the reward function
    # *tensors* (node features, adjacency, mask), so a Mol-based QED reward
    # cannot be plugged in directly; train without the reward term here.
    print("\n=== Testing MolGAN Training ===")
    model.training_config(
        epochs=5,
        batch_size=16,
        use_reward=False,
        atom_types=["C", "N", "O", "F", "Cl", "Br"],
        max_num_atoms=50,
    )
    model.fit(X=smiles_list)
    print("MolGAN trained successfully")

    # 3. Generation.  The diff does not wire graph->SMILES decoding into the
    # wrapper, so sample raw graph tensors from the underlying MolGANModel.
    # TODO(review): decode with graph_to_smiles once that path exists.
    print("\n=== Testing MolGAN Generation ===")
    nodes, adj = model.model.generate(batch_size=10)
    print(f"Generated node tensor {tuple(nodes.shape)}, adjacency tensor {tuple(adj.shape)}")

    # 4. Save and reload.
    # NOTE(review): save_pretrained/from_pretrained are not defined in this
    # diff -- presumably inherited from BaseMolecularGenerator; confirm.
    print("\n=== Testing MolGAN Save & Load ===")
    save_dir = "molgan-test"
    model.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")

    model2 = MolGANGenerativeModel.from_pretrained(save_dir)
    print("Model loaded successfully")

    nodes2, adj2 = model2.model.generate(batch_size=5)
    print("Generated after loading:", tuple(nodes2.shape))
Cleanup + import shutil + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + print(f"Cleaned up {save_dir}") + +if __name__ == "__main__": + test_molgan() diff --git a/torch_molecule/generator/molgan/__init__.py b/torch_molecule/generator/molgan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch_molecule/generator/molgan/molgan_dataset.py b/torch_molecule/generator/molgan/molgan_dataset.py new file mode 100644 index 0000000..bcea90a --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_dataset.py @@ -0,0 +1,109 @@ +from typing import List, Optional, Callable +from rdkit import Chem +import torch +from torch.utils.data import Dataset +from .molgan_utils import qed_reward_fn + +class MolGANDataset(Dataset): + """ + A PyTorch Dataset for MolGAN, with all RDKit and graph tensor processing + precomputed in __init__ for fast, pure-tensor __getitem__ access. + Optionally caches property values for each molecule. + """ + def __init__( + self, + data: List[str], + atom_types: List[str], + bond_types: List[str], + max_num_atoms: int = 50, + cache_properties: bool = False, + property_fn: Optional[Callable] = None, + return_mol: bool = False, + device: Optional[torch.device] = None + ): + self.data = data + self.atom_types = atom_types + self.bond_types = bond_types + self.max_num_atoms = max_num_atoms + self.atom_type_to_idx = {atom: idx for idx, atom in enumerate(atom_types)} + self.bond_type_to_idx = {bond: idx for idx, bond in enumerate(bond_types)} + self.return_mol = return_mol + self.device = torch.device(device) if device is not None else None + + self.node_features = [] + self.adjacency_matrices = [] + self.mols = [] + self.cached_properties = [] if cache_properties and property_fn else None + + self.property_fn = property_fn if property_fn is not None else qed_reward_fn + + for idx, smiles in enumerate(self.data): + mol = Chem.MolFromSmiles(smiles) + self.mols.append(mol) + # Default: if invalid, fill with zeros and (optionally) 
property 0 + nf = torch.zeros((self.max_num_atoms, len(self.atom_types)), dtype=torch.float) + adj = torch.zeros((self.max_num_atoms, self.max_num_atoms, len(self.bond_types)), dtype=torch.float) + prop_val = 0.0 if cache_properties else None + + if mol is not None: + num_atoms = mol.GetNumAtoms() + if num_atoms > self.max_num_atoms: + raise ValueError(f"Molecule at index {idx} exceeds max_num_atoms: {num_atoms} > {self.max_num_atoms}") + + for i, atom in enumerate(mol.GetAtoms()): + atom_type = atom.GetSymbol() + if atom_type in self.atom_type_to_idx: + nf[i, self.atom_type_to_idx[atom_type]] = 1.0 + + for bond in mol.GetBonds(): + begin_idx = bond.GetBeginAtomIdx() + end_idx = bond.GetEndAtomIdx() + bond_type = str(bond.GetBondType()) + if bond_type in self.bond_type_to_idx: + bidx = self.bond_type_to_idx[bond_type] + adj[begin_idx, end_idx, bidx] = 1.0 + adj[end_idx, begin_idx, bidx] = 1.0 + + if cache_properties and self.property_fn: + try: + prop_val = self.property_fn(mol) + except Exception: + prop_val = 0.0 + + # Move tensors to device immediately if a device is set + if self.device is not None: + nf = nf.to(self.device) + adj = adj.to(self.device) + + self.node_features.append(nf) + self.adjacency_matrices.append(adj) + if cache_properties and property_fn and self.cached_properties is not None: + self.cached_properties.append(prop_val) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + parts = [ + self.node_features[idx], + self.adjacency_matrices[idx] + ] + # add optional property + if self.cached_properties is not None: + parts.append(self.cached_properties[idx]) + # add optional Mol object (can always access it if you want) + if self.return_mol: + parts.append(self.mols[idx]) + + # Default: (node_features, adjacency_matrix) + # With property: (node_features, adjacency_matrix, property) + # With property and mol: (node_features, adjacency_matrix, property, mol) + # With only mol: (node_features, adjacency_matrix, mol) + 
return tuple(parts) + + + + + + + diff --git a/torch_molecule/generator/molgan/molgan_gen_disc.py b/torch_molecule/generator/molgan/molgan_gen_disc.py new file mode 100644 index 0000000..e71b971 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_gen_disc.py @@ -0,0 +1,201 @@ +import torch +from dataclasses import dataclass +from typing import Tuple +from .molgan_r_gcn import RelationalGCNLayer # Local import to avoid circular dependency +import torch.nn.functional as F + +@dataclass +class MolGANGeneratorConfig: + def __init__( + self, + z_dim: int = 32, + g_conv_dim: int = 64, + d_conv_dim: int = 64, + g_num_layers: int = 3, + d_num_layers: int = 3, + num_atom_types: int = 5, + num_bond_types: int = 4, + max_num_atoms: int = 9, + dropout: float = 0.0, + tau: float = 1.0, + use_batchnorm: bool = True, + device: str = 'cuda' if torch.cuda.is_available() else 'cpu', + ): + self.z_dim = z_dim + self.g_conv_dim = g_conv_dim + self.d_conv_dim = d_conv_dim + self.g_num_layers = g_num_layers + self.d_num_layers = d_num_layers + self.num_atom_types = num_atom_types + self.num_bond_types = num_bond_types + self.max_num_atoms = max_num_atoms + self.dropout = dropout + self.use_batchnorm = use_batchnorm + self.tau = tau # Gumbel-Softmax temperature + self.device = device + + +# MolGAN Generotor +class MolGANGenerator(torch.nn.Module): + def __init__(self, config: MolGANGeneratorConfig): + super(MolGANGenerator, self).__init__() + self.z_dim = config.z_dim + self.g_conv_dim = config.g_conv_dim + self.g_num_layers = config.g_num_layers + self.num_atom_types = config.num_atom_types + self.num_bond_types = config.num_bond_types + self.max_num_atoms = config.max_num_atoms + self.dropout = config.dropout + self.use_batchnorm = config.use_batchnorm + self.tau = config.tau + self.device = config.device + self.to(self.device) + + layers = [] + input_dim = self.z_dim + for i in range(self.g_num_layers): + output_dim = self.g_conv_dim * (2 ** i) + 
layers.append(torch.nn.Linear(input_dim, output_dim)) + if self.use_batchnorm: + layers.append(torch.nn.BatchNorm1d(output_dim)) + layers.append(torch.nn.ReLU()) + if self.dropout > 0: + layers.append(torch.nn.Dropout(self.dropout)) + input_dim = output_dim + + self.fc_layers = torch.nn.Sequential(*layers) + self.atom_fc = torch.nn.Linear(input_dim, self.max_num_atoms * self.num_atom_types) + self.bond_fc = torch.nn.Linear(input_dim, self.num_bond_types * self.max_num_atoms * self.max_num_atoms) + + def forward(self, z: torch.Tensor, sample_mode='softmax') -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = z.size(0) + h = self.fc_layers(z) + atom_logits = self.atom_fc(h).view(batch_size, self.max_num_atoms, self.num_atom_types) + # Output bond logits with [batch, num_bond_types, max_num_atoms, max_num_atoms] order + bond_logits = self.bond_fc(h).view(batch_size, self.num_bond_types, self.max_num_atoms, self.max_num_atoms) + + # Nodes + if sample_mode == 'softmax': + node = torch.softmax(atom_logits, dim=-1) + elif sample_mode == 'soft_gumbel': + node = F.gumbel_softmax(atom_logits, tau=self.tau, hard=False, dim=-1) + elif sample_mode == 'hard_gumbel': + node = F.gumbel_softmax(atom_logits, tau=self.tau, hard=True, dim=-1) + elif sample_mode == 'argmax': + node = atom_logits.argmax(dim=-1) + else: + raise ValueError(f"Unknown sample_mode: {sample_mode}") + + # Adjacency + if sample_mode == 'softmax': + adj = torch.softmax(bond_logits, dim=1) + elif sample_mode == 'soft_gumbel': + adj = F.gumbel_softmax(bond_logits, tau=self.tau, hard=False, dim=1) + elif sample_mode == 'hard_gumbel': + adj = F.gumbel_softmax(bond_logits, tau=self.tau, hard=True, dim=1) + else: + raise ValueError(f"Unknown sample_mode: {sample_mode}") + + return node, adj + + + + + +# MolGAN Discriminator +@dataclass +class MolGANDiscriminatorConfig: + def __init__( + self, + in_dim: int = 5, # Number of atom types (node feature dim). Typically set automatically. 
+ hidden_dim: int = 64, # Hidden feature/channel size for GCN layers. + num_layers: int = 3, # Number of R-GCN layers (depth). + num_relations: int = 4, # Number of bond types (relation types per edge). + max_num_atoms: int = 9, # Max node count in padded tensor. + dropout: float = 0.0, # Dropout between layers. + use_batchnorm: bool = True, # BatchNorm or similar normalization. + readout: str = 'sum', # Readout type (sum/mean/max for pooling nodes to graph-level vector) + device: str = 'cuda' if torch.cuda.is_available() else 'cpu', + ): + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.num_relations = num_relations + self.max_num_atoms = max_num_atoms + self.dropout = dropout + self.use_batchnorm = use_batchnorm + self.readout = readout + self.device = device + + +class MolGANDiscriminator(torch.nn.Module): + def __init__(self, config: MolGANDiscriminatorConfig): + super(MolGANDiscriminator, self).__init__() + + self.in_dim = config.in_dim + self.hidden_dim = config.hidden_dim + self.num_layers = config.num_layers + self.num_relations = config.num_relations + self.max_num_atoms = config.max_num_atoms + self.dropout = config.dropout + self.use_batchnorm = config.use_batchnorm + self.readout = config.readout + + layers = [] + input_dim = self.in_dim + for i in range(self.num_layers): + output_dim = self.hidden_dim * (2 ** i) + layers.append(RelationalGCNLayer(input_dim, output_dim, self.num_relations)) + if self.use_batchnorm: + layers.append(torch.nn.BatchNorm1d(self.max_num_atoms)) + layers.append(torch.nn.LeakyReLU(0.2)) + if self.dropout > 0: + layers.append(torch.nn.Dropout(self.dropout)) + input_dim = output_dim + + self.gcn_layers = torch.nn.ModuleList(layers) + self.fc = torch.nn.Linear(input_dim, 1) + self.device = config.device + self.to(self.device) + + def forward( + self, + atom_feats: torch.Tensor, + adj: torch.Tensor, + mask: torch.Tensor + ) -> torch.Tensor: + # atom_feats: [batch, max_num_atoms, 
num_atom_types] + # adj: [batch, num_bond_types, max_num_atoms, max_num_atoms] + # mask: [batch, max_num_atoms] (float, 1=real, 0=pad) + h = atom_feats + for layer in self.gcn_layers: + if isinstance(layer, RelationalGCNLayer): + h = layer(h, adj) + else: + # If using BatchNorm1d, input should be [batch, features, nodes] + if isinstance(layer, torch.nn.BatchNorm1d): + # Permute for batchnorm: [batch, nodes, features] → [batch, features, nodes] + h = layer(h.permute(0, 2, 1)).permute(0, 2, 1) + else: + h = layer(h) + + # MASKED GRAPH READOUT + # mask: [batch, max_num_atoms] float + mask = mask.unsqueeze(-1) # [batch, max_num_atoms, 1] + h_masked = h * mask # zeros padded nodes + + if self.readout == 'sum': + g = h_masked.sum(dim=1) # [batch, hidden_dim] + elif self.readout == 'mean': + # Prevent divide-by-zero with (mask.sum(dim=1, keepdim=True)+1e-8) + g = h_masked.sum(dim=1) / (mask.sum(dim=1) + 1e-8) + elif self.readout == 'max': + # Set padded to large neg, then max + h_masked_pad = h.clone() + h_masked_pad[mask.squeeze(-1) == 0] = float('-inf') + g, _ = h_masked_pad.max(dim=1) + else: + raise ValueError(f"Unknown readout type: {self.readout}") + + out = self.fc(g) # [batch, 1] + return out.squeeze(-1) # [batch] diff --git a/torch_molecule/generator/molgan/molgan_generator.py b/torch_molecule/generator/molgan/molgan_generator.py new file mode 100644 index 0000000..4a02164 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_generator.py @@ -0,0 +1,206 @@ +import torch +from typing import Optional, Union, List, Callable +from .molgan_model import MolGANModel +from .molgan_gen_disc import MolGANGeneratorConfig, MolGANDiscriminatorConfig +from torch_molecule.base.generator import BaseMolecularGenerator +from .molgan_dataset import MolGANDataset +from torch_molecule.utils import graph_to_smiles, graph_from_smiles + + + + +class MolGANGenerativeModel(BaseMolecularGenerator): + + """ + This generator implements the MolGAN model for molecular graph generation. 
+ + The model uses a GAN like architecture with a generator and discriminator, + combined with a reward network to optimize for desired molecular properties. + The generator produces molecular graphs represented as adjacency matrices, with the discriminator + and reward network evaluating their validity and quality. The reward network can be trained to optimize + for specific chemical properties, such as drug-likeness or synthetic accessibility. + + + References: + ---------- + - De Cao, N., & Kipf, T. (2018). MolGAN: An implicit generative model for small molecular graphs. + arXiv preprint arXiv:1805.11973. Link: https://arxiv.org/pdf/1805.11973 + + Parameters: + ---------- + MolGANGeneratorConfig : MolGANGeneratorConfig, optional + Configuration for the generator network. If None, default values are used. + + MolGANDiscriminatorConfig : MolGANDiscriminatorConfig, Optional + Configuration for the discriminator and reward network. If None, default values are used. + + Lambda_rl : float, Optional + Weight for the reinforcement learning reward in the generator loss. Default is 0.25. + + device : Optional[Union[torch.device, str]], optional + Device to run the model on. If None, defaults to CPU or GPU if available. + + model_name : str, Optional + Name of the model. Default is "MolGANGenerativeModel". 
+ + """ + + def __init__( + self, + generator_config: Optional[MolGANGeneratorConfig] = None, + discriminator_config: Optional[MolGANDiscriminatorConfig] = None, + lambda_rl: float = 0.25, + device: Optional[Union[torch.device, str]] = None, + model_name: str = "MolGANGenerativeModel", + ): + super().__init__(device=device, model_name=model_name) + + # Initialize MolGAN model + self.model = MolGANModel( + generator_config=generator_config, + discriminator_config=discriminator_config, + reward_network_config=discriminator_config, + ).to(self.device) + + self.lambda_rl = lambda_rl + self.gen_optimizer = None + self.disc_optimizer = None + self.reward_optimizer = None + self.gen_scheduler = None + self.disc_scheduler = None + self.reward_scheduler = None + self.use_reward = False + + self.epoch = 0 + self.step = 0 + + def training_config( + self, + lambda_rl: float = 0.25, + reward_function: Optional[Callable] = None, + gen_optimizer: Optional[torch.optim.Optimizer] = None, + disc_optimizer: Optional[torch.optim.Optimizer] = None, + reward_optimizer: Optional[torch.optim.Optimizer] = None, + gen_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + disc_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + reward_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + use_reward: bool = True, + epochs: int = 300, + batch_size: int = 32, + atom_types: List[str] = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H'], + bond_types: List[str] = ['SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'], + max_num_atoms: int = 50, + ): + """ + Configure training parameters for MolGAN. + + Parameters: + ---------- + lambda_rl : float, Optional + Weight for the reinforcement learning reward in the generator loss. Default is 0.25. + + reward_function : Optional[Callable], optional + + gen_optimizer : torch.optim.Optimizer + Optimizer for the generator network. + + disc_optimizer : torch.optim.Optimizer + Optimizer for the discriminator network. 
+ + reward_optimizer : Optional[torch.optim.Optimizer], optional + Optimizer for the reward network. If None, the discriminator optimizer is used. + + gen_scheduler : Optional[torch.optim.lr_scheduler._LRScheduler], optional + Learning rate scheduler for the generator optimizer. + + disc_scheduler : Optional[torch.optim.lr_scheduler._LRScheduler], optional + Learning rate scheduler for the discriminator optimizer. + + reward_scheduler : Optional[torch.optim.lr_scheduler._LRScheduler], optional + Learning rate scheduler for the reward optimizer. + + use_reward : bool, optional + Whether to use the reward network during training. Default is True. + + epochs : int + Number of training epochs. Default is 300. + + atom_types : List[str] + List of atom types to consider in the molecular graphs. Default includes common organic atoms. + + bond_types : List[str] + List of bond types to consider in the molecular graphs. Default includes common bond types. + + max_num_atoms : int + Maximum number of atoms in the generated molecular graphs. Default is 50. 
+ """ + + if gen_optimizer is None: gen_optimizer = torch.optim.Adam(self.model.gen.parameters(), lr=0.0001, betas=(0.5, 0.999)) + if disc_optimizer is None: disc_optimizer = torch.optim.Adam(self.model.disc.parameters(), lr=0.0001, betas=(0.5, 0.999)) + if reward_optimizer is None: reward_optimizer = disc_optimizer + + self.model.config_training( + gen_optimizer=gen_optimizer, + disc_optimizer=disc_optimizer, + lambda_rl=lambda_rl, + reward_optimizer=reward_optimizer, + gen_scheduler=gen_scheduler, + disc_scheduler=disc_scheduler, + reward_scheduler=reward_scheduler, + ) + self.lambda_rl = lambda_rl + self.reward_function = reward_function + self.gen_optimizer = gen_optimizer + self.disc_optimizer = disc_optimizer + self.reward_optimizer = ( + reward_optimizer if reward_optimizer is not None else disc_optimizer + ) + self.gen_scheduler = gen_scheduler + self.disc_scheduler = disc_scheduler + self.reward_scheduler = reward_scheduler + self.use_reward = use_reward + self.epochs = epochs + self.atom_types = atom_types + self.bond_types = bond_types + self.max_num_atoms = max_num_atoms + self.batch_size = batch_size + + + def fit( self, X:List[str], y=None ) -> "BaseMolecularGenerator": + """ + Fit the MolGAN model to the training data. + + Parameters: + ---------- + X : List[str] + List of SMILES strings representing the training molecules. + + y : Optional[np.ndarray], optional + Optional array of target values for supervised training. Default is None. 
(Not used in MolGAN) + """ + + if self.gen_optimizer is None or self.disc_optimizer is None: + # raise ValueError("Please configure training optimizers using `training_config()` before fitting.") + # Set default optimizers if not configured + self.training_config( + gen_optimizer=torch.optim.Adam(self.model.gen.parameters(), lr=0.0001, betas=(0.5, 0.999)), + disc_optimizer=torch.optim.Adam(self.model.disc.parameters(), lr=0.0001, betas=(0.5, 0.999)), + lambda_rl=0.25, + use_reward=True, + ) + + # Create a dataloader from the SMILES strings + dataset = MolGANDataset(data=X, atom_types=self.atom_types, bond_types=self.bond_types, max_num_atoms=self.max_num_atoms, return_mol=False, device=self.device) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + self.model.train() + for _ in range(self.epochs): + self.model.train_epoch( + dataloader, + reward_fn= None if not self.use_reward else self.reward_function + ) + + return self + + + diff --git a/torch_molecule/generator/molgan/molgan_model.py b/torch_molecule/generator/molgan/molgan_model.py new file mode 100644 index 0000000..07a3b68 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_model.py @@ -0,0 +1,170 @@ +from typing import Optional +import torch +import torch.nn as nn +from .molgan_gen_disc import * + + +class MolGANModel(nn.Module): + def __init__( + self, + generator_config: Optional[MolGANGeneratorConfig] = None, + discriminator_config: Optional[MolGANDiscriminatorConfig] = None, + reward_network_config: Optional[MolGANDiscriminatorConfig] = None, + ): + super(MolGANModel, self).__init__() + + # Initialize generator and discriminator + self.gen_config = generator_config if generator_config is not None else MolGANGeneratorConfig() + self.gen: MolGANGenerator = MolGANGenerator(self.gen_config) + + self.disc_config = discriminator_config if discriminator_config is not None else MolGANDiscriminatorConfig() + self.disc = MolGANDiscriminator(self.disc_config) + 
+ # By default, the reward network is the same as the discriminator + self.reward_net = ( + MolGANDiscriminator(reward_network_config) + if reward_network_config is not None + else MolGANDiscriminator(self.disc_config) + ) + + + def generate(self, batch_size: int, sample_mode: Optional[str] = None): + """Generate a batch of molecules.""" + z = torch.randn( + batch_size, + self.gen_config.z_dim, + device = torch.device(self.gen.device) + ) + if sample_mode is None: + if self.training: + return self.gen(z, sample_mode='softmax') + else: + return self.gen(z, sample_mode='argmax') + else: + return self.gen(z, sample_mode=sample_mode) + + + def discriminate( + self, + atom_type_matrix: torch.Tensor, + bond_type_tensor: torch.Tensor, + molecule_mask: Optional[torch.Tensor], + ): + """Discriminate a batch of molecules.""" + return self.disc(atom_type_matrix, bond_type_tensor, molecule_mask) + + def reward( + self, + atom_type_matrix: torch.Tensor, + bond_type_tensor: torch.Tensor, + molecule_mask: Optional[torch.Tensor], + ): + """Compute reward for a batch of molecules.""" + return self.reward_net(atom_type_matrix, bond_type_tensor, molecule_mask) + + def config_training( + self, + gen_optimizer: torch.optim.Optimizer, + disc_optimizer: torch.optim.Optimizer, + lambda_rl: float = 0.25, + reward_optimizer: Optional[torch.optim.Optimizer] = None, + gen_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + disc_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + reward_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ): + """Configure optimizers and schedulers for training.""" + self.gen_optimizer = gen_optimizer + self.disc_optimizer = disc_optimizer + self.reward_optimizer = reward_optimizer + self.lambda_rl = lambda_rl + + self.gen_scheduler = gen_scheduler + self.disc_scheduler = disc_scheduler + self.reward_scheduler = reward_scheduler + + def training_step(self, batch, reward_fn, pretrain=False): + node_features, 
adjacency_matrix = batch + batch_size = node_features.size(0) + adjacency_matrix = adjacency_matrix.permute(0, 3, 1, 2) + mask = (node_features.sum(-1) != 0).float() + + z = torch.randn(batch_size, self.gen_config.z_dim, device=node_features.device) + fake_atom_logits, fake_bond_logits = self.gen(z, sample_mode='softmax') + fake_atom = torch.softmax(fake_atom_logits, -1) + fake_bond = torch.softmax(fake_bond_logits, 1) + fake_mask = (fake_atom.argmax(-1) != 0).float() + + # === Discriminator update === + self.disc_optimizer.zero_grad() + real_scores = self.disc(node_features, adjacency_matrix, mask) + fake_scores = self.disc(fake_atom, fake_bond, fake_mask) + wgan_loss = -(real_scores.mean() - fake_scores.mean()) + wgan_loss.backward() + self.disc_optimizer.step() + if self.disc_scheduler: self.disc_scheduler.step() + + # === Reward net update === + if self.reward_optimizer is not None: + self.reward_optimizer.zero_grad() + reward_targets = reward_fn(node_features, adjacency_matrix, mask) + pred_rewards = self.reward_net(node_features, adjacency_matrix, mask) + r_loss = torch.nn.functional.mse_loss(pred_rewards, reward_targets) + r_loss.backward() + self.reward_optimizer.step() + if self.reward_scheduler: self.reward_scheduler.step() + else: + r_loss = torch.tensor(0.0, device=node_features.device) + + # === Generator update === + self.gen_optimizer.zero_grad() + fake_atom_logits, fake_bond_logits = self.gen(z, sample_mode='softmax') + fake_mask = (fake_atom.argmax(-1) != 0).float() + fake_scores = self.disc(fake_atom, fake_bond, fake_mask) + g_wgan_loss = -fake_scores.mean() + if not pretrain and hasattr(self, 'lambda_rl') and self.lambda_rl > 0: + with torch.no_grad(): + rewards = self.reward_net(fake_atom, fake_bond, fake_mask) + rl_loss = -rewards.mean() + else: + rl_loss = torch.tensor(0.0, device=node_features.device) + total_loss = g_wgan_loss + getattr(self, 'lambda_rl', 0.0) * rl_loss + total_loss.backward() + self.gen_optimizer.step() + if 
self.gen_scheduler: self.gen_scheduler.step() + + return { + 'd_loss': wgan_loss.item(), + 'g_loss': g_wgan_loss.item(), + 'rl_loss': rl_loss.item(), + 'r_loss': r_loss.item() if self.reward_optimizer is not None else None + } + + + def train_epoch(self, dataloader, reward_fn, pretrain=False, log_interval=100): + self.gen.train() + self.disc.train() + if self.reward_net: self.reward_net.train() + for i, batch in enumerate(dataloader): + result = self.training_step(batch, reward_fn, pretrain) + if i % log_interval == 0: + print({k: round(v, 5) for k, v in result.items()}) + + def evaluate(self, dataloader, reward_fn): + self.gen.eval() + self.disc.eval() + if self.reward_net: self.reward_net.eval() + eval_metrics = {'d_loss': 0.0, 'g_loss': 0.0, 'rl_loss': 0.0, 'r_loss': 0.0} + count = 0 + with torch.no_grad(): + for batch in dataloader: + result = self.training_step(batch, reward_fn, pretrain=False) + for k in eval_metrics.keys(): + if result[k] is not None: + eval_metrics[k] += result[k] + count += 1 + for k in eval_metrics.keys(): + eval_metrics[k] /= count + return {k: round(v, 5) for k, v in eval_metrics.items()} + + + diff --git a/torch_molecule/generator/molgan/molgan_r_gcn.py b/torch_molecule/generator/molgan/molgan_r_gcn.py new file mode 100644 index 0000000..c0e8124 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_r_gcn.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + +class RelationalGCNLayer(nn.Module): + """ + Relational Graph Convolutional Layer for fully connected dense graphs. 
+ Input: + - node_feats: [batch, num_nodes, in_dim] + - adj: [batch, num_relations, num_nodes, num_nodes] + Output: + - node_feats: [batch, num_nodes, out_dim] + """ + def __init__(self, in_dim, out_dim, num_relations, use_bias=True): + super(RelationalGCNLayer, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.num_relations = num_relations + + # One weight matrix per relation/bond type + self.rel_weights = nn.Parameter(torch.Tensor(num_relations, in_dim, out_dim)) + if use_bias: + self.bias = nn.Parameter(torch.Tensor(out_dim)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.rel_weights) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, node_feats, adj): + # node_feats: [batch, num_nodes, in_dim] + # adj: [batch, num_relations, num_nodes, num_nodes] + batch_size, num_nodes, _ = node_feats.shape + + out = torch.zeros(batch_size, num_nodes, self.out_dim, device=node_feats.device) + + for rel in range(self.num_relations): + # Multiply node features by relation weight + # [batch, num_nodes, in_dim] @ [in_dim, out_dim] -> [batch, num_nodes, out_dim] + h_rel = torch.matmul(node_feats, self.rel_weights[rel]) + # Propagate messages using adjacency for this relation: + # [batch, num_nodes, out_dim] ← [batch, num_nodes, num_nodes] @ [batch, num_nodes, out_dim] + # Here adj[:, rel, :, :] gives [batch, num_nodes, num_nodes] + out += torch.bmm(adj[:, rel], h_rel) + + if self.bias is not None: + out += self.bias + + return out # You can add activation after this (ReLU, LeakyReLU, etc.) 
+ diff --git a/torch_molecule/generator/molgan/molgan_utils.py b/torch_molecule/generator/molgan/molgan_utils.py new file mode 100644 index 0000000..1f68ce3 --- /dev/null +++ b/torch_molecule/generator/molgan/molgan_utils.py @@ -0,0 +1,14 @@ +from rdkit.Chem import QED + +# This is used as the default reward function for MolGAN +def qed_reward_fn(mol): + """ + Computes the QED score of a single RDKit Mol object. + Returns 0.0 for invalid molecules or errors. + """ + if mol is not None: + try: + return QED.qed(mol) + except Exception: + return 0.0 + return 0.0