From f3138f2473d8c061dae0909bc53b8febafba3106 Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 18 Jan 2024 15:42:55 +0100 Subject: [PATCH 01/12] Add Electra-wrapper as separate class Make training more stable --- chebai/models/electra.py | 180 +++++++++++++++++---------------------- 1 file changed, 78 insertions(+), 102 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 267a0063..e3be02c2 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -94,8 +94,7 @@ def filter_dict(d, filter_key): } -class Electra(ChebaiBaseNet): - NAME = "Electra" +class ElectraBasedModel(ChebaiBaseNet): def _process_batch(self, batch, batch_idx): model_kwargs = dict() @@ -122,10 +121,6 @@ def _process_batch(self, batch, batch_idx): idents=batch.additional_fields["idents"], ) - @property - def as_pretrained(self): - return self.electra.electra - def __init__( self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs ): @@ -141,13 +136,7 @@ def __init__( self.word_dropout = nn.Dropout(config.get("word_dropout", 0)) in_d = self.config.hidden_size - self.output = nn.Sequential( - nn.Dropout(self.config.hidden_dropout_prob), - nn.Linear(in_d, in_d), - nn.GELU(), - nn.Dropout(self.config.hidden_dropout_prob), - nn.Linear(in_d, self.config.num_labels), - ) + if pretrained_checkpoint: with open(pretrained_checkpoint, "rb") as fin: model_dict = torch.load(fin, map_location=self.device) @@ -196,12 +185,40 @@ def forward(self, data, **kwargs): electra = self.electra(inputs_embeds=inp, **kwargs) d = electra.last_hidden_state[:, 0, :] return dict( - logits=self.output(d), + output=d, attentions=electra.attentions, target_mask=data.get("target_mask"), ) +class Electra(ElectraBasedModel): + NAME = "Electra" + + def __init__( + self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs + ): + # Remove this property in order to prevent it from being stored as a + # hyper parameter + + super().__init__(**kwargs) + + in_d = self.config.hidden_size + self.output = nn.Sequential( + nn.Dropout(self.config.hidden_dropout_prob), + nn.Linear(in_d, in_d), + nn.GELU(), + nn.Dropout(self.config.hidden_dropout_prob), + nn.Linear(in_d, self.config.num_labels), + ) + + def forward(self, data, **kwargs): + d = super().forward(data, **kwargs) + return dict( + logits=self.output(d["output"]), + attentions=d["attentions"], + target_mask=d["target_mask"] + ) + class ElectraLegacy(ChebaiBaseNet): NAME = "ElectraLeg" @@ -236,65 +253,28 @@ def forward(self, data): return dict(logits=self.output(d), attentions=electra.attentions) +def gbmf(x, l, r, b = 6): + a = (r-l)+1e-3 + c = l+(r-l)/2 + return 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) -class ChebiBoxWithMemberships(ChebaiBaseNet): - NAME = "ChebiBoxWithMemberships" - def _process_batch(self, batch, batch_idx): - model_kwargs = dict() - loss_kwargs = batch.additional_fields["loss_kwargs"] - if "lens" in batch.additional_fields["model_kwargs"]: - model_kwargs["attention_mask"] = pad_sequence( - [ - torch.ones(l + 1, device=self.device) - for l in batch.additional_fields["model_kwargs"]["lens"] - ], - batch_first=True, - ) - cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN - ) - return dict( - features=torch.cat((cls_tokens, batch.x), dim=1), - labels=batch.y, - model_kwargs=model_kwargs, - loss_kwargs=loss_kwargs, - idents=batch.additional_fields["idents"], - ) +def normal(sigma, mu, x): + v = (x-mu)/sigma + return torch.exp(-0.5 * v*v) - def 
__init__( - self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs - ): - super().__init__(**kwargs) - if config is None: - config = dict() - if not "num_labels" in config and self.out_dim is not None: - config["num_labels"] = self.out_dim - self.config = ElectraConfig(**config, output_attentions=True) - self.word_dropout = nn.Dropout(config.get("word_dropout", 0)) - - if pretrained_checkpoint: - with open(pretrained_checkpoint, "rb") as fin: - model_dict = torch.load(fin, map_location=self.device) - if load_prefix: - state_dict = filter_dict(model_dict["state_dict"], load_prefix) - else: - state_dict = model_dict["state_dict"] - self.electra = ElectraModel.from_pretrained( - None, state_dict=state_dict, config=self.config - ) - else: - self.electra = ElectraModel(config=self.config) +class ChebiBoxWithMemberships(ElectraBasedModel): + NAME = "ChebiBoxWithMemberships" - self.config = ElectraConfig(**config, output_attentions=True) + def __init__(self, membership_method="normal", dimension_aggregation="lukaziewisz", **kwargs): + super().__init__(**kwargs) self.in_dim = self.config.hidden_size self.hidden_dim = self.config.embeddings_to_points_hidden_size self.out_dim = self.config.embeddings_dimensions self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 ) + self.membership_method = membership_method + self.dimension_aggregation = dimension_aggregation self.embeddings_to_points = nn.Sequential( nn.Linear(self.in_dim, self.hidden_dim), @@ -302,56 +282,52 @@ def __init__( nn.Linear(self.hidden_dim, self.out_dim) ) - def _process_for_loss(self, model_output, labels, loss_kwargs): - kwargs_copy = dict(loss_kwargs) - mask = kwargs_copy.pop("target_mask", None) - if mask is not None: - d = model_output["logits"] * mask - 100 * ~mask - else: - d = model_output["logits"] - if labels is not None: - labels = labels.float() - return d, labels, kwargs_copy + def _prod_agg(self, memberships, dim=-1): + return torch.relu(torch.sum(memberships, dim=dim)-(memberships.shape[dim]-1)) - def _get_prediction_and_labels(self, data, labels, model_output): - mask = model_output.get("target_mask") - if mask is not None: - d = model_output["logits"] * mask - 100 * ~mask - else: - d = model_output["logits"] - loss_kwargs = data.get("loss_kwargs", dict()) - if "non_null_labels" in loss_kwargs: - n = loss_kwargs["non_null_labels"] - d = d[n] - return torch.sigmoid(d), labels.int() + def _min_agg(self, memberships, dim=-1): + return torch.relu(torch.sum(memberships, dim=dim)-(memberships.shape[dim]-1)) + def _lukaziewisz_agg(self, memberships, dim=-1): + return torch.relu(torch.sum(memberships, dim=dim)-(memberships.shape[dim]-1)) + + def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs): + return gbmf(points, left_corners, right_corners) + + def _forward_normal_membership(self, points, left_corners, right_corners, **kwargs): + widths = 0.1 * (right_corners - left_corners) + max_distance_per_dim = nn.functional.relu(left_corners - points + widths**0.5) + nn.functional.relu(points - right_corners + widths**0.5) + return normal(widths**0.5, 0, max_distance_per_dim) def forward(self, data, **kwargs): - self.batch_size = data["features"].shape[0] - inp = self.electra.embeddings.forward(data["features"]) - inp = self.word_dropout(inp) - electra = self.electra(inputs_embeds=inp, **kwargs) - d = electra.last_hidden_state[:, 0, :] - points = self.embeddings_to_points(d) - self.points = points + d = super().forward(data, **kwargs) + points = 
self.embeddings_to_points(d["output"]).unsqueeze(1) - b = self.boxes.expand(self.batch_size, -1, -1, -1) + b = self.boxes.unsqueeze(0) l = torch.min(b, dim=-1)[0] r = torch.max(b, dim=-1)[0] - p = points.expand(self.config.num_labels, -1, -1).transpose(1, 0) - center = torch.mean(torch.stack([l, r]), dim=0) - width = 0.6 * (r - l) - slope = torch.sqrt(torch.abs(r - l)) - - membership = 1 / (1 + ((torch.abs(p - center) / width) ** (2 * slope))) - m = torch.mean(membership, dim=-1) + if self.membership_method == "normal": + m = self._forward_normal_membership(points, l, r) + elif self.membership_method == "gbmf": + m = self._forward_gbmf_membership(points, l, r) + else: + raise Exception("Unknown membership function:", self.membership_method) + + if self.dimension_aggregation == "prod": + m = self._prod_agg(m) + elif self.dimension_aggregation == "lukaziewisz": + m = self._lukaziewisz_agg(m) + elif self.dimension_aggregation == "min": + m = self._prod_min(m) + else: + raise Exception("Unknown aggregation function:", self.dimension_aggregation) return dict( boxes=b, embedded_points=points, logits=m, - attentions=electra.attentions, - target_mask=data.get("target_mask"), + attentions=d["attentions"], + target_mask=d["target_mask"], ) From 8d93d4360035fc66d36e304f83fcc6857ef003de Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 18 Jan 2024 16:23:07 +0100 Subject: [PATCH 02/12] Refactor using black --- chebai/models/electra.py | 41 ++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index e3be02c2..7fb9d305 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -95,7 +95,6 @@ def filter_dict(d, filter_key): class ElectraBasedModel(ChebaiBaseNet): - def _process_batch(self, batch, batch_idx): model_kwargs = dict() loss_kwargs = batch.additional_fields["loss_kwargs"] @@ -216,9 +215,10 @@ def forward(self, data, **kwargs): return dict( logits=self.output(d["output"]), attentions=d["attentions"], - target_mask=d["target_mask"] + target_mask=d["target_mask"], ) + class ElectraLegacy(ChebaiBaseNet): NAME = "ElectraLeg" @@ -253,49 +253,62 @@ def forward(self, data): return dict(logits=self.output(d), attentions=electra.attentions) -def gbmf(x, l, r, b = 6): - a = (r-l)+1e-3 - c = l+(r-l)/2 +def gbmf(x, l, r, b=6): + a = (r - l) + 1e-3 + c = l + (r - l) / 2 return 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) def normal(sigma, mu, x): - v = (x-mu)/sigma - return torch.exp(-0.5 * v*v) + v = (x - mu) / sigma + return torch.exp(-0.5 * v * v) class ChebiBoxWithMemberships(ElectraBasedModel): NAME = "ChebiBoxWithMemberships" - def __init__(self, membership_method="normal", dimension_aggregation="lukaziewisz", **kwargs): + def __init__( + self, membership_method="normal", dimension_aggregation="lukaziewisz", **kwargs + ): super().__init__(**kwargs) self.in_dim = self.config.hidden_size self.hidden_dim = self.config.embeddings_to_points_hidden_size self.out_dim = self.config.embeddings_dimensions - self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 ) + self.boxes = nn.Parameter( + torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 + ) self.membership_method = membership_method self.dimension_aggregation = dimension_aggregation self.embeddings_to_points = nn.Sequential( nn.Linear(self.in_dim, self.hidden_dim), nn.ReLU(), - nn.Linear(self.hidden_dim, self.out_dim) + nn.Linear(self.hidden_dim, self.out_dim), ) def _prod_agg(self, memberships, dim=-1): - 
return torch.relu(torch.sum(memberships, dim=dim)-(memberships.shape[dim]-1)) + return torch.relu( + torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1) + ) def _min_agg(self, memberships, dim=-1): - return torch.relu(torch.sum(memberships, dim=dim)-(memberships.shape[dim]-1)) + return torch.relu( + torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1) + ) + def _lukaziewisz_agg(self, memberships, dim=-1): - return torch.relu(torch.sum(memberships, dim=dim)-(memberships.shape[dim]-1)) + return torch.relu( + torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1) + ) def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs): return gbmf(points, left_corners, right_corners) def _forward_normal_membership(self, points, left_corners, right_corners, **kwargs): widths = 0.1 * (right_corners - left_corners) - max_distance_per_dim = nn.functional.relu(left_corners - points + widths**0.5) + nn.functional.relu(points - right_corners + widths**0.5) + max_distance_per_dim = nn.functional.relu( + left_corners - points + widths**0.5 + ) + nn.functional.relu(points - right_corners + widths**0.5) return normal(widths**0.5, 0, max_distance_per_dim) def forward(self, data, **kwargs): From dcff0d1d9818709971350ad84a076da7d59b611a Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 18 Jan 2024 16:26:40 +0100 Subject: [PATCH 03/12] Fix linting error --- chebai/models/electra.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 7fb9d305..c99317e6 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -350,7 +350,6 @@ def __init__(self, **kwargs): self.criteria = nn.BCELoss() def forward(self, outputs, targets, **kwargs): - criterion = self.criteria bce_loss = criterion(outputs, targets) From d9c396d9fbed445d3610b2ff4695d9745c5fcf5b Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 18 Jan 2024 16:35:37 +0100 Subject: [PATCH 04/12] Fix handling of non-logit outputs --- chebai/models/electra.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index c99317e6..18120fff 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -153,9 +153,9 @@ def _process_for_loss(self, model_output, labels, loss_kwargs): kwargs_copy = dict(loss_kwargs) mask = kwargs_copy.pop("target_mask", None) if mask is not None: - d = model_output["logits"] * mask - 100 * ~mask + d = model_output["output"] * mask - 100 * ~mask else: - d = model_output["logits"] + d = model_output["output"] if labels is not None: labels = labels.float() return d, labels, kwargs_copy @@ -163,14 +163,14 @@ def _process_for_loss(self, model_output, labels, loss_kwargs): def _get_prediction_and_labels(self, data, labels, model_output): mask = model_output.get("target_mask") if mask is not None: - d = model_output["logits"] * mask - 100 * ~mask + d = model_output["output"] * mask - 100 * ~mask else: - d = model_output["logits"] + d = model_output["output"] loss_kwargs = data.get("loss_kwargs", dict()) if "non_null_labels" in loss_kwargs: n = loss_kwargs["non_null_labels"] d = d[n] - return torch.sigmoid(d), labels.int() + return d, labels.int() def forward(self, data, **kwargs): self.batch_size = data["features"].shape[0] @@ -210,10 +210,14 @@ def __init__( nn.Linear(in_d, self.config.num_labels), ) + def _get_prediction_and_labels(self, data, labels, model_output): + preds, lbls = super()._get_prediction_and_labels(data, labels, model_output) 
+ return torch.sigmoid(preds), lbls + def forward(self, data, **kwargs): d = super().forward(data, **kwargs) return dict( - logits=self.output(d["output"]), + output=self.output(d["output"]), attentions=d["attentions"], target_mask=d["target_mask"], ) @@ -250,7 +254,7 @@ def __init__(self, **kwargs): def forward(self, data): electra = self.electra(data) d = torch.sum(electra.last_hidden_state, dim=1) - return dict(logits=self.output(d), attentions=electra.attentions) + return dict(output=self.output(d), attentions=electra.attentions) def gbmf(x, l, r, b=6): @@ -338,7 +342,7 @@ def forward(self, data, **kwargs): return dict( boxes=b, embedded_points=points, - logits=m, + output=m, attentions=d["attentions"], target_mask=d["target_mask"], ) From ccaf7346430e3e44785876bd09fb31faa1e58837 Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 18 Jan 2024 20:03:35 +0100 Subject: [PATCH 05/12] Fix some bugs in aggregation method --- chebai/models/electra.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 18120fff..7ab158ef 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -16,6 +16,7 @@ from chebai.models.base import ChebaiBaseNet from chebai.preprocessing.reader import CLS_TOKEN, MASK_TOKEN_INDEX import pytorch_lightning as pl +import math logging.getLogger("pysmiles").setLevel(logging.CRITICAL) @@ -279,7 +280,7 @@ def __init__( self.hidden_dim = self.config.embeddings_to_points_hidden_size self.out_dim = self.config.embeddings_dimensions self.boxes = nn.Parameter( - torch.rand((self.config.num_labels, self.out_dim, 2)) * 3 + 3 - torch.rand((self.config.num_labels, self.out_dim, 2)) * 6 ) self.membership_method = membership_method self.dimension_aggregation = dimension_aggregation @@ -291,19 +292,16 @@ def __init__( ) def _prod_agg(self, memberships, dim=-1): - return torch.relu( - torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1) - ) + return torch.prod(memberships, dim=dim) def _min_agg(self, memberships, dim=-1): - return torch.relu( - torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1) - ) + return torch.min(memberships, dim=dim)[0] - def _lukaziewisz_agg(self, memberships, dim=-1): - return torch.relu( - torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1) - ) + def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): + """ + This is a version of the Łukaziewish-T-norm using a modified softplus instead of max + """ + return 1/scale * torch.log(1+torch.exp(math.log(math.exp(scale)-1)*(torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1)))) def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs): return gbmf(points, left_corners, right_corners) @@ -324,25 +322,25 @@ def forward(self, data, **kwargs): r = torch.max(b, dim=-1)[0] if self.membership_method == "normal": - m = self._forward_normal_membership(points, l, r) + memberships_per_dim = self._forward_normal_membership(points, l, r) elif self.membership_method == "gbmf": - m = self._forward_gbmf_membership(points, l, r) + memberships_per_dim = self._forward_gbmf_membership(points, l, r) else: raise Exception("Unknown membership function:", self.membership_method) if self.dimension_aggregation == "prod": - m = self._prod_agg(m) + aggregated_memberships = self._prod_agg(memberships_per_dim) elif self.dimension_aggregation == "lukaziewisz": - m = self._lukaziewisz_agg(m) + aggregated_memberships = 
self._soft_lukaziewisz_agg(memberships_per_dim) elif self.dimension_aggregation == "min": - m = self._prod_min(m) + aggregated_memberships = self._min_agg(memberships_per_dim) else: raise Exception("Unknown aggregation function:", self.dimension_aggregation) return dict( boxes=b, embedded_points=points, - output=m, + output=aggregated_memberships, attentions=d["attentions"], target_mask=d["target_mask"], ) From 4cacf8faa8e18842385f718ff3354c828bdff81f Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 18 Jan 2024 20:27:59 +0100 Subject: [PATCH 06/12] Reformat using black --- chebai/models/electra.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 7ab158ef..ed63581e 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -301,7 +301,17 @@ def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): """ This is a version of the Łukaziewish-T-norm using a modified softplus instead of max """ - return 1/scale * torch.log(1+torch.exp(math.log(math.exp(scale)-1)*(torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1)))) + return ( + 1 + / scale + * torch.log( + 1 + + torch.exp( + math.log(math.exp(scale) - 1) + * (torch.sum(memberships, dim=dim) - (memberships.shape[dim] - 1)) + ) + ) + ) def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs): return gbmf(points, left_corners, right_corners) From c93c6bdc44f1c05eb32668cdbbb043309289c9fe Mon Sep 17 00:00:00 2001 From: Adel Memariani Date: Mon, 22 Jan 2024 15:58:49 +0100 Subject: [PATCH 07/12] add gbmf with adjusted slope --- chebai/models/base.py | 2 +- chebai/models/electra.py | 34 +++++++++++++++++++++++------ chebai/preprocessing/collect_all.py | 2 +- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 3a149832..cd6ab347 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -4,7 +4,7 @@ from lightning.pytorch.core.module import LightningModule import torch - +import pickle from chebai.preprocessing.structures import XYData logging.getLogger("pysmiles").setLevel(logging.CRITICAL) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index ed63581e..41e3dc82 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -263,6 +263,16 @@ def gbmf(x, l, r, b=6): c = l + (r - l) / 2 return 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) +def gbmf_adjusted_slope(x, l, r): + l = l.to(torch.float32) + r = r.to(torch.float32) + segment = torch.abs(r - l) + is_close = torch.allclose(l, r, rtol=1e-2, atol=1e-3) + a = torch.abs(r - l) - 1e-3 if is_close else 0.7 * torch.abs(r - l) + c = l + (r - l) / 2 + b = 2 / (1 + torch.exp(-(segment - 5) / 5)) + membership = 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) + return membership def normal(sigma, mu, x): v = (x - mu) / sigma @@ -273,17 +283,17 @@ class ChebiBoxWithMemberships(ElectraBasedModel): NAME = "ChebiBoxWithMemberships" def __init__( - self, membership_method="normal", dimension_aggregation="lukaziewisz", **kwargs + self, **kwargs ): super().__init__(**kwargs) + + self.membership_method = self.config.membership_method + self.dimension_aggregation = self.config.dimension_aggregation + self.in_dim = self.config.hidden_size self.hidden_dim = self.config.embeddings_to_points_hidden_size self.out_dim = self.config.embeddings_dimensions - self.boxes = nn.Parameter( - 3 - torch.rand((self.config.num_labels, self.out_dim, 2)) * 6 - ) - self.membership_method = 
membership_method - self.dimension_aggregation = dimension_aggregation + self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2))) self.embeddings_to_points = nn.Sequential( nn.Linear(self.in_dim, self.hidden_dim), @@ -297,6 +307,12 @@ def _prod_agg(self, memberships, dim=-1): def _min_agg(self, memberships, dim=-1): return torch.min(memberships, dim=dim)[0] + def _mean_agg(self, memberships, dim=-1): + return torch.mean(memberships, dim=dim) + + def _sum_agg(self, memberships, dim=-1): + return torch.sum(memberships, dim=dim) + def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): """ This is a version of the Łukaziewish-T-norm using a modified softplus instead of max @@ -314,7 +330,7 @@ def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): ) def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs): - return gbmf(points, left_corners, right_corners) + return gbmf_adjusted_slope(points, left_corners, right_corners) def _forward_normal_membership(self, points, left_corners, right_corners, **kwargs): widths = 0.1 * (right_corners - left_corners) @@ -344,6 +360,10 @@ def forward(self, data, **kwargs): aggregated_memberships = self._soft_lukaziewisz_agg(memberships_per_dim) elif self.dimension_aggregation == "min": aggregated_memberships = self._min_agg(memberships_per_dim) + elif self.dimension_aggregation == "mean": + aggregated_memberships = self._mean_agg(memberships_per_dim) + elif self.dimension_aggregation == "sum": + aggregated_memberships = self._sum_agg(memberships_per_dim) else: raise Exception("Unknown aggregation function:", self.dimension_aggregation) diff --git a/chebai/preprocessing/collect_all.py b/chebai/preprocessing/collect_all.py index f82ce71c..53b01381 100644 --- a/chebai/preprocessing/collect_all.py +++ b/chebai/preprocessing/collect_all.py @@ -191,11 +191,11 @@ def train(train_loader, validation_loader): checkpoint_callback = ModelCheckpoint( dirpath=os.path.join(tb_logger.log_dir, "checkpoints"), filename="{epoch}-{step}-{val_loss:.7f}", - save_top_k=5, save_last=True, verbose=True, monitor="val_loss", mode="min", + every_n_epochs=1 ) trainer = pl.Trainer( logger=tb_logger, From 9f396026271b81bde13ae0608102e7ea16f596a5 Mon Sep 17 00:00:00 2001 From: MGlauer Date: Wed, 24 Jan 2024 14:36:59 +0100 Subject: [PATCH 08/12] Implement alternate box model --- chebai/loss/boxes.py | 36 ++++++++++++++++++++++++++++++ chebai/models/electra.py | 48 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 chebai/loss/boxes.py diff --git a/chebai/loss/boxes.py b/chebai/loss/boxes.py new file mode 100644 index 00000000..671267ca --- /dev/null +++ b/chebai/loss/boxes.py @@ -0,0 +1,36 @@ +import torch + +class BoxLoss(torch.nn.Module): + def __init__( + self, base_loss: torch.nn.Module = None + ): + super().__init__() + self.base_loss = base_loss + + def forward(self, input, target, **kwargs): + b = input["boxes"] + points = input["embedded_points"] + target = target.float().unsqueeze(-1) + l, lind = torch.min(b, dim=-1) + r, rind = torch.max(b, dim=-1) + + widths = r - l + + l += 0.1*widths + r -= 0.1 * widths + inside = ((l < points) * (points < r)).float() + closer_to_l_than_to_r = (torch.abs(l - points) < torch.abs(r - points)).float() + fn_per_dim = ((1 - inside) * target) + fp_per_dim = (inside * (1 - target)) + diff = torch.abs(fp_per_dim - fn_per_dim) + return self.base_loss(diff * closer_to_l_than_to_r * points, diff * closer_to_l_than_to_r * l) + 
self.base_loss( + diff * (1 - closer_to_l_than_to_r) * points, diff * (1 - closer_to_l_than_to_r) * r) + + def _calculate_implication_loss(self, l, r): + capped_difference = torch.relu(l - r) + return torch.mean( + torch.sum( + (torch.softmax(capped_difference, dim=-1) * capped_difference), dim=-1 + ), + dim=0, + ) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 41e3dc82..767639dd 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -96,6 +96,8 @@ def filter_dict(d, filter_key): class ElectraBasedModel(ChebaiBaseNet): + NAME = "ElectraBase" + def _process_batch(self, batch, batch_idx): model_kwargs = dict() loss_kwargs = batch.additional_fields["loss_kwargs"] @@ -127,7 +129,11 @@ def __init__( # Remove this property in order to prevent it from being stored as a # hyper parameter + super().__init__(**kwargs) + + + if config is None: config = dict() if not "num_labels" in config and self.out_dim is not None: @@ -150,6 +156,9 @@ def __init__( else: self.electra = ElectraModel(config=self.config) + if self.out_dim is None: + self.out_dim = self.electra.config.hidden_size + def _process_for_loss(self, model_output, labels, loss_kwargs): kwargs_copy = dict(loss_kwargs) mask = kwargs_copy.pop("target_mask", None) @@ -190,7 +199,6 @@ def forward(self, data, **kwargs): target_mask=data.get("target_mask"), ) - class Electra(ElectraBasedModel): NAME = "Electra" @@ -389,6 +397,44 @@ def forward(self, outputs, targets, **kwargs): return total_loss +class CrispBoxClassifier(ElectraBasedModel): + NAME = "CripsBox" + + def __init__(self, box_dimensions=3, **kwargs): + super().__init__(**kwargs) + + self.point_embedding = nn.Linear(self.config.hidden_size, box_dimensions) + + self.num_boxes = kwargs["out_dim"] + b = torch.randn((self.num_boxes, box_dimensions, 2)) + self.boxes = nn.Parameter(b, requires_grad=True) + + def forward(self, x, **kwargs): + d = super().forward(x, **kwargs) + points = self.point_embedding(d["output"]).unsqueeze(1) + b = self.boxes.unsqueeze(0) + l, lind = torch.min(b, dim=-1) + r, rind = torch.max(b, dim=-1) + inside = torch.all((l < points) * (points < r), dim=-1).float() + return dict( + boxes=b, + embedded_points=points, + output=inside, + attentions=d["attentions"], + target_mask=d["target_mask"], + ) + + def _process_for_loss(self, model_output, labels, loss_kwargs): + kwargs_copy = dict(loss_kwargs) + mask = kwargs_copy.pop("target_mask", None) + if mask is not None: + d = model_output["output"] * mask - 100 * ~mask + else: + d = model_output["output"] + if labels is not None: + labels = labels.float() + model_output["output"] = d + return model_output, labels, kwargs_copy class ConeElectra(ChebaiBaseNet): NAME = "ConeElectra" From f5e650bae843aa24097cbdd47713819ff00bc735 Mon Sep 17 00:00:00 2001 From: Adel Memariani Date: Wed, 24 Jan 2024 17:04:12 +0100 Subject: [PATCH 09/12] use torch.log1p in lukaziewisz aggregations --- chebai/models/electra.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 41e3dc82..0a792087 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -257,7 +257,6 @@ def forward(self, data): d = torch.sum(electra.last_hidden_state, dim=1) return dict(output=self.output(d), attentions=electra.attentions) - def gbmf(x, l, r, b=6): a = (r - l) + 1e-3 c = l + (r - l) / 2 @@ -268,7 +267,7 @@ def gbmf_adjusted_slope(x, l, r): r = r.to(torch.float32) segment = torch.abs(r - l) is_close = 
torch.allclose(l, r, rtol=1e-2, atol=1e-3) - a = torch.abs(r - l) - 1e-3 if is_close else 0.7 * torch.abs(r - l) + a = torch.abs(r - l) - 1e-3 if is_close else 1 + (0.7 * torch.abs(r - l)) c = l + (r - l) / 2 b = 2 / (1 + torch.exp(-(segment - 5) / 5)) membership = 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) @@ -293,8 +292,19 @@ def __init__( self.in_dim = self.config.hidden_size self.hidden_dim = self.config.embeddings_to_points_hidden_size self.out_dim = self.config.embeddings_dimensions + + # Random boxes self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2))) + # Boxes with (relatively) same sizes + """ + base_box = torch.rand((2, self.out_dim)) + similar_boxes = base_box.repeat((self.config.num_labels, 1, 1)) + small_differences = 0.05 * torch.randn_like(similar_boxes) + similar_boxes += small_differences + self.boxes = nn.Parameter(similar_boxes) + """ + self.embeddings_to_points = nn.Sequential( nn.Linear(self.in_dim, self.hidden_dim), nn.ReLU(), @@ -304,14 +314,14 @@ def __init__( def _prod_agg(self, memberships, dim=-1): return torch.prod(memberships, dim=dim) - def _min_agg(self, memberships, dim=-1): - return torch.min(memberships, dim=dim)[0] - - def _mean_agg(self, memberships, dim=-1): - return torch.mean(memberships, dim=dim) - - def _sum_agg(self, memberships, dim=-1): - return torch.sum(memberships, dim=dim) + def _soft_lukaziewicz_agg_2(self, memberships, scale=10): + # TORCH.LOG1P is more accurate than torch.log() for small values of input + # https://pytorch.org/docs/stable/generated/torch.log1p.html + return torch.log1p( + torch.exp( + scale * (torch.sum(memberships, dim=-1) - (memberships.shape[-1] - 1)) + ) + ) / scale def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): """ @@ -360,10 +370,9 @@ def forward(self, data, **kwargs): aggregated_memberships = self._soft_lukaziewisz_agg(memberships_per_dim) elif self.dimension_aggregation == "min": aggregated_memberships = self._min_agg(memberships_per_dim) - elif self.dimension_aggregation == "mean": - aggregated_memberships = self._mean_agg(memberships_per_dim) - elif self.dimension_aggregation == "sum": - aggregated_memberships = self._sum_agg(memberships_per_dim) + elif self.dimension_aggregation == "lukaziewisz2": + aggregated_memberships = self._soft_lukaziewicz_agg_2(memberships_per_dim) + else: raise Exception("Unknown aggregation function:", self.dimension_aggregation) From 23462356de8cafb4d1e9c33586b971fed91dd82d Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 25 Jan 2024 13:19:57 +0100 Subject: [PATCH 10/12] Update box loss --- chebai/loss/boxes.py | 44 +++++++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/chebai/loss/boxes.py b/chebai/loss/boxes.py index 671267ca..85def4e4 100644 --- a/chebai/loss/boxes.py +++ b/chebai/loss/boxes.py @@ -14,17 +14,39 @@ def forward(self, input, target, **kwargs): l, lind = torch.min(b, dim=-1) r, rind = torch.max(b, dim=-1) - widths = r - l - - l += 0.1*widths - r -= 0.1 * widths - inside = ((l < points) * (points < r)).float() - closer_to_l_than_to_r = (torch.abs(l - points) < torch.abs(r - points)).float() - fn_per_dim = ((1 - inside) * target) - fp_per_dim = (inside * (1 - target)) - diff = torch.abs(fp_per_dim - fn_per_dim) - return self.base_loss(diff * closer_to_l_than_to_r * points, diff * closer_to_l_than_to_r * l) + self.base_loss( - diff * (1 - closer_to_l_than_to_r) * points, diff * (1 - closer_to_l_than_to_r) * r) + width = (r - l) / 2 + r_fp = r + 0.1 * width + r_fn = r - 
0.1 * width + + l_fp = l - 0.1 * width + l_fn = l + 0.1 * width + + inside = ((l < points) * (points < r)) + inside_fp = (l_fp < points) * (points < r_fp) + inside_fn = (l_fn < points) * (points < r_fn) + + fn_per_dim = ~inside_fn * target + fp_per_dim = inside_fp * (1 - target) + + false_per_dim = fn_per_dim + fp_per_dim + number_of_false_dims = torch.sum(false_per_dim, dim=-1, keepdim=True) + + dl = torch.abs(l - points) + dr = torch.abs(r - points) + + closer_to_l_than_r = dl < dr + + r_scale_fp = number_of_false_dims * torch.rand_like(fp_per_dim) * (fp_per_dim * ~closer_to_l_than_r) + l_scale_fp = number_of_false_dims * torch.rand_like(fp_per_dim) * (fp_per_dim * closer_to_l_than_r) + + r_scale_fn = number_of_false_dims * torch.rand_like(fn_per_dim) * (fn_per_dim * ~closer_to_l_than_r) + l_scale_fn = number_of_false_dims * torch.rand_like(fn_per_dim) * (fn_per_dim * closer_to_l_than_r) + + r_loss = torch.mean(torch.sum(torch.abs(r_scale_fp * (r_fp - points)), dim=-1) + torch.sum( + torch.abs(r_scale_fn * (r_fn - points)), dim=-1)) + l_loss = torch.mean(torch.sum(torch.abs(l_scale_fp * (l_fp - points)), dim=-1) + torch.sum( + torch.abs(l_scale_fn * (l_fn - points)), dim=-1)) + return l_loss + r_loss def _calculate_implication_loss(self, l, r): capped_difference = torch.relu(l - r) From 2d7ba5d5167ee7b95da1f8fbc0516d05cc183f20 Mon Sep 17 00:00:00 2001 From: MGlauer Date: Thu, 25 Jan 2024 21:00:10 +0100 Subject: [PATCH 11/12] Add comments to box loss --- chebai/loss/boxes.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/chebai/loss/boxes.py b/chebai/loss/boxes.py index 85def4e4..9950701b 100644 --- a/chebai/loss/boxes.py +++ b/chebai/loss/boxes.py @@ -11,37 +11,46 @@ def forward(self, input, target, **kwargs): b = input["boxes"] points = input["embedded_points"] target = target.float().unsqueeze(-1) - l, lind = torch.min(b, dim=-1) - r, rind = torch.max(b, dim=-1) + left_borders, lind = torch.min(b, dim=-1) + right_borders, rind = torch.max(b, dim=-1) + width = (right_borders - left_borders) / 2 - width = (r - l) / 2 - r_fp = r + 0.1 * width - r_fn = r - 0.1 * width - - l_fp = l - 0.1 * width - l_fn = l + 0.1 * width - - inside = ((l < points) * (points < r)) + # We want some safety margins around boxes. (False) positives should be drawn + # further into the box, whilst (false) negatives should be pushed further outside. + # Therefore, we use different borders for (false) positives and negatives. + r_fp = right_borders + 0.1 * width + r_fn = right_borders - 0.1 * width + l_fp = left_borders - 0.1 * width + l_fn = left_borders + 0.1 * width inside_fp = (l_fp < points) * (points < r_fp) inside_fn = (l_fn < points) * (points < r_fn) + # False positive and negatives, w.r.t. the adapted box borders fn_per_dim = ~inside_fn * target fp_per_dim = inside_fp * (1 - target) + # We also want to penalise wrong memberships in multiple dimensions. This + # is important, because a false positive in a single dimension is not wrong, + # if at least one dimension is true negative. false_per_dim = fn_per_dim + fp_per_dim number_of_false_dims = torch.sum(false_per_dim, dim=-1, keepdim=True) - dl = torch.abs(l - points) - dr = torch.abs(r - points) + # We calculate the gradient for left and right border simultaneously, but we only need the one + # closest to the point. Therefore, we create a filter for that. 
+ dl = torch.abs(left_borders - points) + dr = torch.abs(right_borders - points) closer_to_l_than_r = dl < dr + # The scaling factor encodes the conjunction of whether the respective dimension is false and whether the respective + # border is the closest to the point. r_scale_fp = number_of_false_dims * torch.rand_like(fp_per_dim) * (fp_per_dim * ~closer_to_l_than_r) l_scale_fp = number_of_false_dims * torch.rand_like(fp_per_dim) * (fp_per_dim * closer_to_l_than_r) - r_scale_fn = number_of_false_dims * torch.rand_like(fn_per_dim) * (fn_per_dim * ~closer_to_l_than_r) l_scale_fn = number_of_false_dims * torch.rand_like(fn_per_dim) * (fn_per_dim * closer_to_l_than_r) + # The loss for a border is then the mean of the scaled vector between the points for which the model would + # produce a wrong prediction and the closest border of the box r_loss = torch.mean(torch.sum(torch.abs(r_scale_fp * (r_fp - points)), dim=-1) + torch.sum( torch.abs(r_scale_fn * (r_fn - points)), dim=-1)) l_loss = torch.mean(torch.sum(torch.abs(l_scale_fp * (l_fp - points)), dim=-1) + torch.sum( From d6a79341335af63239728d8a9086b7784fab0306 Mon Sep 17 00:00:00 2001 From: MGlauer Date: Fri, 26 Jan 2024 01:39:19 +0100 Subject: [PATCH 12/12] Fix logic error in box loss --- chebai/loss/boxes.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/chebai/loss/boxes.py b/chebai/loss/boxes.py index 9950701b..6971a29e 100644 --- a/chebai/loss/boxes.py +++ b/chebai/loss/boxes.py @@ -29,11 +29,11 @@ def forward(self, input, target, **kwargs): fn_per_dim = ~inside_fn * target fp_per_dim = inside_fp * (1 - target) - # We also want to penalise wrong memberships in multiple dimensions. This + # We also want to penalise wrong memberships in different dimensions. This # is important, because a false positive in a single dimension is not wrong, # if at least one dimension is true negative. false_per_dim = fn_per_dim + fp_per_dim - number_of_false_dims = torch.sum(false_per_dim, dim=-1, keepdim=True) + all_dimensions_wrong = torch.min(false_per_dim, dim=-1, keepdim=True)[0] # We calculate the gradient for left and right border simultaneously, but we only need the one @@ -44,17 +44,27 @@ def forward(self, input, target, **kwargs): # The scaling factor encodes the conjunction of whether the respective dimension is false and whether the respective # border is the closest to the point. 
- r_scale_fp = number_of_false_dims * torch.rand_like(fp_per_dim) * (fp_per_dim * ~closer_to_l_than_r) - l_scale_fp = number_of_false_dims * torch.rand_like(fp_per_dim) * (fp_per_dim * closer_to_l_than_r) - r_scale_fn = number_of_false_dims * torch.rand_like(fn_per_dim) * (fn_per_dim * ~closer_to_l_than_r) - l_scale_fn = number_of_false_dims * torch.rand_like(fn_per_dim) * (fn_per_dim * closer_to_l_than_r) + + r_scale_fp = all_dimensions_wrong * (fp_per_dim * ~closer_to_l_than_r) + l_scale_fp = all_dimensions_wrong * (fp_per_dim * closer_to_l_than_r) + + r_scale_fn = (fn_per_dim * ~closer_to_l_than_r) + l_scale_fn = (fn_per_dim * closer_to_l_than_r) + + d_r_fp = r_scale_fp * torch.abs(r_fp - points) + d_l_fp = l_scale_fp * torch.abs(l_fp - points) + d_r_fn = r_scale_fn * torch.abs(r_fn - points) + d_l_fn = l_scale_fn * torch.abs(l_fn - points) + + w_r_fp = torch.nn.functional.softmin(d_r_fp, dim=-1) + w_r_fn = torch.nn.functional.softmin(d_r_fn, dim=-1) + w_l_fp = torch.nn.functional.softmin(d_l_fp, dim=-1) + w_l_fn = torch.nn.functional.softmin(d_l_fn, dim=-1) # The loss for a border is then the mean of the scaled vector between the points for which the model would # produce a wrong prediction and the closest border of the box - r_loss = torch.mean(torch.sum(torch.abs(r_scale_fp * (r_fp - points)), dim=-1) + torch.sum( - torch.abs(r_scale_fn * (r_fn - points)), dim=-1)) - l_loss = torch.mean(torch.sum(torch.abs(l_scale_fp * (l_fp - points)), dim=-1) + torch.sum( - torch.abs(l_scale_fn * (l_fn - points)), dim=-1)) + r_loss = torch.mean(torch.sum(w_r_fp * d_r_fp, dim=(1, 2)) + torch.sum(w_r_fn * d_r_fn, dim=(1, 2))) + l_loss = torch.mean(torch.sum(w_l_fp * d_l_fp, dim=(1, 2)) + torch.sum(w_l_fn * d_l_fn, dim=(1, 2))) return l_loss + r_loss def _calculate_implication_loss(self, l, r):