diff --git a/chebai/loss/boxes.py b/chebai/loss/boxes.py
new file mode 100644
index 00000000..6971a29e
--- /dev/null
+++ b/chebai/loss/boxes.py
@@ -0,0 +1,77 @@
+import torch
+
+
+class BoxLoss(torch.nn.Module):
+    def __init__(self, base_loss: torch.nn.Module = None):
+        super().__init__()
+        self.base_loss = base_loss
+
+    def forward(self, input, target, **kwargs):
+        b = input["boxes"]
+        points = input["embedded_points"]
+        target = target.float().unsqueeze(-1)
+        left_borders, lind = torch.min(b, dim=-1)
+        right_borders, rind = torch.max(b, dim=-1)
+        width = (right_borders - left_borders) / 2
+
+        # We want some safety margin around each box: (false) positives should
+        # be drawn further into the box, whilst (false) negatives should be
+        # pushed further outside. Therefore, we use different borders for
+        # (false) positives and negatives.
+        r_fp = right_borders + 0.1 * width
+        r_fn = right_borders - 0.1 * width
+        l_fp = left_borders - 0.1 * width
+        l_fn = left_borders + 0.1 * width
+        inside_fp = (l_fp < points) * (points < r_fp)
+        inside_fn = (l_fn < points) * (points < r_fn)
+
+        # False positives and negatives, w.r.t. the adapted box borders
+        fn_per_dim = ~inside_fn * target
+        fp_per_dim = inside_fp * (1 - target)
+
+        # We also want to penalise wrong memberships across dimensions. This is
+        # important because a false positive in a single dimension is not wrong
+        # as long as at least one other dimension is a true negative.
+        false_per_dim = fn_per_dim + fp_per_dim
+        all_dimensions_wrong = torch.min(false_per_dim, dim=-1, keepdim=True)[0]
+
+        # We calculate the gradient for the left and right border
+        # simultaneously, but we only need the one closest to the point.
+        # Therefore, we create a filter for that.
+        dl = torch.abs(left_borders - points)
+        dr = torch.abs(right_borders - points)
+        closer_to_l_than_r = dl < dr
+
+        # The scaling factor encodes the conjunction of whether the respective
+        # dimension is false and whether the respective border is the closest
+        # to the point.
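+        # Multiplying these indicator tensors implements that conjunction: a
+        # border only receives a gradient signal in dimensions where the
+        # prediction is wrong and it is the nearer of the two borders (for
+        # false positives, additionally only if every dimension is wrong).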
+
+        r_scale_fp = all_dimensions_wrong * (fp_per_dim * ~closer_to_l_than_r)
+        l_scale_fp = all_dimensions_wrong * (fp_per_dim * closer_to_l_than_r)
+
+        r_scale_fn = fn_per_dim * ~closer_to_l_than_r
+        l_scale_fn = fn_per_dim * closer_to_l_than_r
+
+        d_r_fp = r_scale_fp * torch.abs(r_fp - points)
+        d_l_fp = l_scale_fp * torch.abs(l_fp - points)
+        d_r_fn = r_scale_fn * torch.abs(r_fn - points)
+        d_l_fn = l_scale_fn * torch.abs(l_fn - points)
+
+        w_r_fp = torch.nn.functional.softmin(d_r_fp, dim=-1)
+        w_r_fn = torch.nn.functional.softmin(d_r_fn, dim=-1)
+        w_l_fp = torch.nn.functional.softmin(d_l_fp, dim=-1)
+        w_l_fn = torch.nn.functional.softmin(d_l_fn, dim=-1)
+
+        # The loss for a border is then the mean of the softmin-weighted
+        # distances between the points for which the model would produce a
+        # wrong prediction and the closest border of the box.
+        r_loss = torch.mean(
+            torch.sum(w_r_fp * d_r_fp, dim=(1, 2))
+            + torch.sum(w_r_fn * d_r_fn, dim=(1, 2))
+        )
+        l_loss = torch.mean(
+            torch.sum(w_l_fp * d_l_fp, dim=(1, 2))
+            + torch.sum(w_l_fn * d_l_fn, dim=(1, 2))
+        )
+        return l_loss + r_loss
+
+    def _calculate_implication_loss(self, l, r):
+        capped_difference = torch.relu(l - r)
+        return torch.mean(
+            torch.sum(
+                torch.softmax(capped_difference, dim=-1) * capped_difference, dim=-1
+            ),
+            dim=0,
+        )
diff --git a/chebai/models/base.py b/chebai/models/base.py
index 3a149832..cd6ab347 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -4,7 +4,7 @@
 from lightning.pytorch.core.module import LightningModule
 import torch
-
+import pickle
 from chebai.preprocessing.structures import XYData
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
diff --git a/chebai/models/electra.py b/chebai/models/electra.py
index 267a0063..003703e9 100644
--- a/chebai/models/electra.py
+++ b/chebai/models/electra.py
@@ -16,6 +16,7 @@ from chebai.models.base import ChebaiBaseNet
 from chebai.preprocessing.reader import CLS_TOKEN, MASK_TOKEN_INDEX
 
 import pytorch_lightning as pl
+import math
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
@@ -94,8 +95,8 @@ def filter_dict(d, filter_key):
     }
 
 
-class Electra(ChebaiBaseNet):
-    NAME = "Electra"
+class ElectraBasedModel(ChebaiBaseNet):
+    NAME = "ElectraBase"
 
     def _process_batch(self, batch, batch_idx):
         model_kwargs = dict()
@@ -122,17 +123,17 @@ def _process_batch(self, batch, batch_idx):
             idents=batch.additional_fields["idents"],
         )
 
-    @property
-    def as_pretrained(self):
-        return self.electra.electra
-
     def __init__(
         self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
     ):
         # Remove this property in order to prevent it from being stored as a
         # hyper parameter
+        super().__init__(**kwargs)
+
         if config is None:
             config = dict()
         if not "num_labels" in config and self.out_dim is not None:
@@ -141,13 +142,7 @@ def __init__(
         self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
         in_d = self.config.hidden_size
-        self.output = nn.Sequential(
-            nn.Dropout(self.config.hidden_dropout_prob),
-            nn.Linear(in_d, in_d),
-            nn.GELU(),
-            nn.Dropout(self.config.hidden_dropout_prob),
-            nn.Linear(in_d, self.config.num_labels),
-        )
+
         if pretrained_checkpoint:
             with open(pretrained_checkpoint, "rb") as fin:
                 model_dict = torch.load(fin, map_location=self.device)
                 if load_prefix:
                     state_dict = filter_dict(model_dict["state_dict"], load_prefix)
                 else:
                     state_dict = model_dict["state_dict"]
                 self.electra = ElectraModel.from_pretrained(
                     None, state_dict=state_dict, config=self.config
                 )
         else:
             self.electra = ElectraModel(config=self.config)
 
+        if self.out_dim is None:
+            self.out_dim = self.electra.config.hidden_size
+
     def _process_for_loss(self, model_output, labels, loss_kwargs):
         kwargs_copy = dict(loss_kwargs)
         mask = kwargs_copy.pop("target_mask", None)
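+        # Where the mask is False, the model output is pushed to a large
+        # negative value (-100), so masked labels behave like confident
+        # negatives for both the loss and the sigmoid-based predictions.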
         if mask is not None:
-            d = model_output["logits"] * mask - 100 * ~mask
+            d = model_output["output"] * mask - 100 * ~mask
         else:
-            d = model_output["logits"]
+            d = model_output["output"]
         if labels is not None:
             labels = labels.float()
         return d, labels, kwargs_copy
@@ -175,14 +173,14 @@ def _process_for_loss(self, model_output, labels, loss_kwargs):
     def _get_prediction_and_labels(self, data, labels, model_output):
         mask = model_output.get("target_mask")
         if mask is not None:
-            d = model_output["logits"] * mask - 100 * ~mask
+            d = model_output["output"] * mask - 100 * ~mask
         else:
-            d = model_output["logits"]
+            d = model_output["output"]
         loss_kwargs = data.get("loss_kwargs", dict())
         if "non_null_labels" in loss_kwargs:
             n = loss_kwargs["non_null_labels"]
             d = d[n]
-        return torch.sigmoid(d), labels.int()
+        return d, labels.int()
 
     def forward(self, data, **kwargs):
         self.batch_size = data["features"].shape[0]
@@ -196,11 +194,43 @@ def forward(self, data, **kwargs):
         electra = self.electra(inputs_embeds=inp, **kwargs)
         d = electra.last_hidden_state[:, 0, :]
         return dict(
-            logits=self.output(d),
+            output=d,
             attentions=electra.attentions,
             target_mask=data.get("target_mask"),
         )
 
+
+class Electra(ElectraBasedModel):
+    NAME = "Electra"
+
+    def __init__(
+        self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
+    ):
+        # Pass the Electra-specific arguments through to ElectraBasedModel,
+        # which builds the encoder.
+        super().__init__(
+            config=config,
+            pretrained_checkpoint=pretrained_checkpoint,
+            load_prefix=load_prefix,
+            **kwargs,
+        )
+
+        in_d = self.config.hidden_size
+        self.output = nn.Sequential(
+            nn.Dropout(self.config.hidden_dropout_prob),
+            nn.Linear(in_d, in_d),
+            nn.GELU(),
+            nn.Dropout(self.config.hidden_dropout_prob),
+            nn.Linear(in_d, self.config.num_labels),
+        )
+
+    def _get_prediction_and_labels(self, data, labels, model_output):
+        preds, lbls = super()._get_prediction_and_labels(data, labels, model_output)
+        return torch.sigmoid(preds), lbls
+
+    def forward(self, data, **kwargs):
+        d = super().forward(data, **kwargs)
+        return dict(
+            output=self.output(d["output"]),
+            attentions=d["attentions"],
+            target_mask=d["target_mask"],
+        )
+
 
 class ElectraLegacy(ChebaiBaseNet):
     NAME = "ElectraLeg"
 
@@ -233,125 +263,133 @@ def __init__(self, **kwargs):
     def forward(self, data):
         electra = self.electra(data)
         d = torch.sum(electra.last_hidden_state, dim=1)
-        return dict(logits=self.output(d), attentions=electra.attentions)
-
-
-
-class ChebiBoxWithMemberships(ChebaiBaseNet):
+        return dict(output=self.output(d), attentions=electra.attentions)
+
+
+def gbmf(x, l, r, b=6):
+    a = (r - l) + 1e-3
+    c = l + (r - l) / 2
+    return 1 / (1 + (torch.abs((x - c) / a) ** (2 * b)))
+
+
+def gbmf_adjusted_slope(x, l, r):
+    l = l.to(torch.float32)
+    r = r.to(torch.float32)
+    segment = torch.abs(r - l)
+    is_close = torch.allclose(l, r, rtol=1e-2, atol=1e-3)
+    a = torch.abs(r - l) - 1e-3 if is_close else 1 + (0.7 * torch.abs(r - l))
+    c = l + (r - l) / 2
+    b = 2 / (1 + torch.exp(-(segment - 5) / 5))
+    membership = 1 / (1 + (torch.abs((x - c) / a) ** (2 * b)))
+    return membership
+
+
+def normal(sigma, mu, x):
+    v = (x - mu) / sigma
+    return torch.exp(-0.5 * v * v)
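+
+# The three helpers above map a point and a box interval [l, r] to a soft
+# membership value in (0, 1]: `gbmf` is a generalised bell membership function
+# centred at c = l + (r - l) / 2, `gbmf_adjusted_slope` additionally flattens
+# the slope for narrow boxes, and `normal` is an unnormalised Gaussian applied
+# to the distance from the box.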
+
+
+class ChebiBoxWithMemberships(ElectraBasedModel):
     NAME = "ChebiBoxWithMemberships"
 
-    def _process_batch(self, batch, batch_idx):
-        model_kwargs = dict()
-        loss_kwargs = batch.additional_fields["loss_kwargs"]
-        if "lens" in batch.additional_fields["model_kwargs"]:
-            model_kwargs["attention_mask"] = pad_sequence(
-                [
-                    torch.ones(l + 1, device=self.device)
-                    for l in batch.additional_fields["model_kwargs"]["lens"]
-                ],
-                batch_first=True,
-            )
-        cls_tokens = (
-            torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-                -1
-            )
-            * CLS_TOKEN
-        )
-        return dict(
-            features=torch.cat((cls_tokens, batch.x), dim=1),
-            labels=batch.y,
-            model_kwargs=model_kwargs,
-            loss_kwargs=loss_kwargs,
-            idents=batch.additional_fields["idents"],
-        )
-
     def __init__(
-        self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
+        self, **kwargs
     ):
         super().__init__(**kwargs)
 
-        if config is None:
-            config = dict()
-        if not "num_labels" in config and self.out_dim is not None:
-            config["num_labels"] = self.out_dim
-        self.config = ElectraConfig(**config, output_attentions=True)
-        self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
-
-        if pretrained_checkpoint:
-            with open(pretrained_checkpoint, "rb") as fin:
-                model_dict = torch.load(fin, map_location=self.device)
-                if load_prefix:
-                    state_dict = filter_dict(model_dict["state_dict"], load_prefix)
-                else:
-                    state_dict = model_dict["state_dict"]
-                self.electra = ElectraModel.from_pretrained(
-                    None, state_dict=state_dict, config=self.config
-                )
-        else:
-            self.electra = ElectraModel(config=self.config)
+        self.membership_method = self.config.membership_method
+        self.dimension_aggregation = self.config.dimension_aggregation
 
-        self.config = ElectraConfig(**config, output_attentions=True)
         self.in_dim = self.config.hidden_size
         self.hidden_dim = self.config.embeddings_to_points_hidden_size
         self.out_dim = self.config.embeddings_dimensions
-        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 )
+
+        # Random boxes
+        self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)))
+
+        # Boxes with (relatively) similar sizes
+        """
+        base_box = torch.rand((2, self.out_dim))
+        similar_boxes = base_box.repeat((self.config.num_labels, 1, 1))
+        small_differences = 0.05 * torch.randn_like(similar_boxes)
+        similar_boxes += small_differences
+        self.boxes = nn.Parameter(similar_boxes)
+        """
 
         self.embeddings_to_points = nn.Sequential(
             nn.Linear(self.in_dim, self.hidden_dim),
             nn.ReLU(),
-            nn.Linear(self.hidden_dim, self.out_dim)
+            nn.Linear(self.hidden_dim, self.out_dim),
         )
 
-    def _process_for_loss(self, model_output, labels, loss_kwargs):
-        kwargs_copy = dict(loss_kwargs)
-        mask = kwargs_copy.pop("target_mask", None)
-        if mask is not None:
-            d = model_output["logits"] * mask - 100 * ~mask
-        else:
-            d = model_output["logits"]
-        if labels is not None:
-            labels = labels.float()
-        return d, labels, kwargs_copy
+    def _prod_agg(self, memberships, dim=-1):
+        return torch.prod(memberships, dim=dim)
+
+    def _min_agg(self, memberships, dim=-1):
+        # Referenced by the "min" aggregation option in forward(); this method
+        # was missing previously.
+        return torch.min(memberships, dim=dim)[0]
 
-    def _get_prediction_and_labels(self, data, labels, model_output):
-        mask = model_output.get("target_mask")
-        if mask is not None:
-            d = model_output["logits"] * mask - 100 * ~mask
-        else:
-            d = model_output["logits"]
-        loss_kwargs = data.get("loss_kwargs", dict())
-        if "non_null_labels" in loss_kwargs:
-            n = loss_kwargs["non_null_labels"]
-            d = d[n]
-        return torch.sigmoid(d), labels.int()
+    def _soft_lukaziewicz_agg_2(self, memberships, scale=10):
+        # torch.log1p is more accurate than torch.log for small inputs, see
+        # https://pytorch.org/docs/stable/generated/torch.log1p.html
+        return torch.log1p(
+            torch.exp(
+                scale * (torch.sum(memberships, dim=-1) - (memberships.shape[-1] - 1))
+            )
+        ) / scale
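+
+    # Both soft Ɓukasiewicz variants approximate the n-ary Ɓukasiewicz t-norm
+    # max(0, sum(m) - (n - 1)), replacing the non-differentiable max(0, .)
+    # with a softplus-style function; `scale` controls the sharpness of the
+    # approximation.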
+    def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10):
+        """
+        A version of the Ɓukasiewicz t-norm that uses a modified softplus
+        instead of max.
+        """
+        return (
+            1
+            / scale
+            * torch.log(
+                1
+                + torch.exp(
+                    math.log(math.exp(scale) - 1)
+                    * (torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1))
+                )
+            )
+        )
+
+    def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs):
+        return gbmf_adjusted_slope(points, left_corners, right_corners)
+
+    def _forward_normal_membership(self, points, left_corners, right_corners, **kwargs):
+        widths = 0.1 * (right_corners - left_corners)
+        max_distance_per_dim = nn.functional.relu(
+            left_corners - points + widths**0.5
+        ) + nn.functional.relu(points - right_corners + widths**0.5)
+        return normal(widths**0.5, 0, max_distance_per_dim)
 
     def forward(self, data, **kwargs):
-        self.batch_size = data["features"].shape[0]
-        inp = self.electra.embeddings.forward(data["features"])
-        inp = self.word_dropout(inp)
-        electra = self.electra(inputs_embeds=inp, **kwargs)
-        d = electra.last_hidden_state[:, 0, :]
-        points = self.embeddings_to_points(d)
-        self.points = points
+        d = super().forward(data, **kwargs)
+        points = self.embeddings_to_points(d["output"]).unsqueeze(1)
 
-        b = self.boxes.expand(self.batch_size, -1, -1, -1)
+        b = self.boxes.unsqueeze(0)
         l = torch.min(b, dim=-1)[0]
         r = torch.max(b, dim=-1)[0]
-        p = points.expand(self.config.num_labels, -1, -1).transpose(1, 0)
-        center = torch.mean(torch.stack([l, r]), dim=0)
-        width = 0.6 * (r - l)
-        slope = torch.sqrt(torch.abs(r - l))
+        if self.membership_method == "normal":
+            memberships_per_dim = self._forward_normal_membership(points, l, r)
+        elif self.membership_method == "gbmf":
+            memberships_per_dim = self._forward_gbmf_membership(points, l, r)
+        else:
+            raise ValueError(f"Unknown membership function: {self.membership_method}")
 
-        membership = 1 / (1 + ((torch.abs(p - center) / width) ** (2 * slope)))
-        m = torch.mean(membership, dim=-1)
+        if self.dimension_aggregation == "prod":
+            aggregated_memberships = self._prod_agg(memberships_per_dim)
+        elif self.dimension_aggregation == "lukaziewisz":
+            aggregated_memberships = self._soft_lukaziewisz_agg(memberships_per_dim)
+        elif self.dimension_aggregation == "min":
+            aggregated_memberships = self._min_agg(memberships_per_dim)
+        elif self.dimension_aggregation == "lukaziewisz2":
+            aggregated_memberships = self._soft_lukaziewicz_agg_2(memberships_per_dim)
+        else:
+            raise ValueError(
+                f"Unknown aggregation function: {self.dimension_aggregation}"
+            )
 
         return dict(
             boxes=b,
             embedded_points=points,
-            logits=m,
-            attentions=electra.attentions,
-            target_mask=data.get("target_mask"),
+            output=aggregated_memberships,
+            attentions=d["attentions"],
+            target_mask=d["target_mask"],
         )
 
@@ -361,7 +399,6 @@ def __init__(self, **kwargs):
         self.criteria = nn.BCELoss()
 
     def forward(self, outputs, targets, **kwargs):
-
         criterion = self.criteria
         bce_loss = criterion(outputs, targets)
 
@@ -369,6 +406,44 @@
         return total_loss
 
+
+class CrispBoxClassifier(ElectraBasedModel):
+    NAME = "CrispBox"
+
+    def __init__(self, box_dimensions=3, **kwargs):
+        super().__init__(**kwargs)
+
+        self.point_embedding = nn.Linear(self.config.hidden_size, box_dimensions)
+
+        self.num_boxes = kwargs["out_dim"]
+        b = torch.randn((self.num_boxes, box_dimensions, 2))
+        self.boxes = nn.Parameter(b, requires_grad=True)
+
+    def forward(self, x, **kwargs):
+        d = super().forward(x, **kwargs)
+        points = self.point_embedding(d["output"]).unsqueeze(1)
+        b = self.boxes.unsqueeze(0)
+        l, lind = torch.min(b, dim=-1)
+        r, rind = torch.max(b, dim=-1)
+        inside = torch.all((l < points) * (points < r), dim=-1).float()
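+        # `inside` is a hard 0/1 membership (torch.all over box dimensions),
+        # so no gradient flows back to the boxes through it; training relies
+        # on a loss that reads `boxes` and `embedded_points` directly, e.g.
+        # BoxLoss in chebai/loss/boxes.py.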
+        return dict(
+            boxes=b,
+            embedded_points=points,
+            output=inside,
+            attentions=d["attentions"],
+            target_mask=d["target_mask"],
+        )
+
+    def _process_for_loss(self, model_output, labels, loss_kwargs):
+        kwargs_copy = dict(loss_kwargs)
+        mask = kwargs_copy.pop("target_mask", None)
+        if mask is not None:
+            d = model_output["output"] * mask - 100 * ~mask
+        else:
+            d = model_output["output"]
+        if labels is not None:
+            labels = labels.float()
+        model_output["output"] = d
+        return model_output, labels, kwargs_copy
 
 
 class ConeElectra(ChebaiBaseNet):
     NAME = "ConeElectra"
diff --git a/chebai/preprocessing/collect_all.py b/chebai/preprocessing/collect_all.py
index f82ce71c..53b01381 100644
--- a/chebai/preprocessing/collect_all.py
+++ b/chebai/preprocessing/collect_all.py
@@ -191,11 +191,11 @@ def train(train_loader, validation_loader):
     checkpoint_callback = ModelCheckpoint(
         dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
         filename="{epoch}-{step}-{val_loss:.7f}",
-        save_top_k=5,
         save_last=True,
         verbose=True,
         monitor="val_loss",
         mode="min",
+        every_n_epochs=1,
     )
     trainer = pl.Trainer(
         logger=tb_logger,