Comment prediction added (CommentCode2Seq) #120

Open
wants to merge 33 commits into base: master

Changes from all commits

33 commits
628c5e5
Added useful classes
malodetz Jul 14, 2022
1ff1994
Config added
malodetz Jul 14, 2022
f8ffa00
Added comment label processing
malodetz Jul 15, 2022
5f39e68
Added wrapper
malodetz Jul 15, 2022
8620414
Update requirements.txt
malodetz Jul 15, 2022
7624d44
Fixing train to use gpu
malodetz Jul 15, 2022
fb5fb26
New model with correct decoder
malodetz Jul 15, 2022
c4a203b
Fix black
malodetz Jul 15, 2022
6fdee03
Added custom chrf metric
malodetz Jul 15, 2022
db70270
Minor updates
malodetz Jul 16, 2022
b68cc10
Fix random
malodetz Jul 16, 2022
24c8484
Fixing chrf and f1
malodetz Jul 20, 2022
33db6b5
Preliminary new tokenizer
malodetz Jul 28, 2022
dc15cdb
Complete new tokenizer
malodetz Jul 28, 2022
26dd42f
Some fixes
malodetz Jul 28, 2022
f6fa424
New vocab size
malodetz Jul 28, 2022
e8b677d
Add tokenizer to config
malodetz Aug 7, 2022
b1b2e27
Move chrf metric
malodetz Aug 8, 2022
e70f96d
Better tokenization
malodetz Aug 8, 2022
5e93981
Implement comment transformer decoder
malodetz Aug 10, 2022
033560c
Greedy decoding for val/test
malodetz Aug 10, 2022
c74319f
Small fix
malodetz Aug 10, 2022
b1009cd
Some fixes
malodetz Aug 12, 2022
6aff1ed
Logits cut
malodetz Aug 12, 2022
c23e20a
Fixed greedy generation
malodetz Aug 12, 2022
37f3db2
Some train changes to fix
malodetz Aug 12, 2022
a19e752
Fix train
malodetz Aug 25, 2022
e43b664
Config for transformer decoder
malodetz Aug 30, 2022
c51f14c
Multiple decoders
malodetz Aug 30, 2022
d09dc30
Another config
malodetz Aug 30, 2022
9bfa25c
Add early stop for greedy generation
malodetz Aug 30, 2022
b0b7b81
Early generation stop
malodetz Aug 30, 2022
54cb711
Added predictions
malodetz Sep 17, 2022
102 changes: 102 additions & 0 deletions code2seq/comment_code2seq_wrapper.py
@@ -0,0 +1,102 @@
from argparse import ArgumentParser
from typing import cast

import torch
from commode_utils.common import print_config
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything

from code2seq.data.comment_path_context_data_module import CommentPathContextDataModule
from code2seq.model.comment_code2seq import CommentCode2Seq
from code2seq.utils.common import filter_warnings
from code2seq.utils.test import test
from code2seq.utils.train import train


def configure_arg_parser() -> ArgumentParser:
arg_parser = ArgumentParser()
arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test", "predict"])
arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str)
arg_parser.add_argument(
"-p", "--pretrained", help="Path to pretrained model", type=str, required=False, default=None
)
arg_parser.add_argument(
"-o", "--output", help="Output file for predictions", type=str, required=False, default=None
)
return arg_parser


def train_code2seq(config: DictConfig):
filter_warnings()

if config.print_config:
print_config(config, fields=["model", "data", "train", "optimizer"])

# Load data module
data_module = CommentPathContextDataModule(config.data_folder, config.data)

# Load model
code2seq = CommentCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)

train(code2seq, data_module, config)


def test_code2seq(model_path: str, config: DictConfig):
filter_warnings()

# Load data module
data_module = CommentPathContextDataModule(config.data_folder, config.data)

# Load model
code2seq = CommentCode2Seq.load_from_checkpoint(model_path, map_location=torch.device("cpu"))

test(code2seq, data_module, config.seed)


def save_predictions(model_path: str, config: DictConfig, output_path: str):
filter_warnings()

data_module = CommentPathContextDataModule(config.data_folder, config.data)
tokenizer = data_module.vocabulary.tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
code2seq = CommentCode2Seq.load_from_checkpoint(model_path)
code2seq.to(device)
code2seq.eval()

with open(output_path, "w") as f:
for batch in data_module.test_dataloader():
batch = data_module.transfer_batch_to_device(batch, device, 0)  # keep the returned batch in case the transfer is not in-place
logits, _ = code2seq.logits_from_batch(batch, None)

predictions = logits[:-1].argmax(-1)
targets = batch.labels[1:]

batch_size = targets.shape[1]
for batch_idx in range(batch_size):
target_seq = [token.item() for token in targets[:, batch_idx]]
predicted_seq = [token.item() for token in predictions[:, batch_idx]]

target_str = tokenizer.decode(target_seq, skip_special_tokens=True)
predicted_str = tokenizer.decode(predicted_seq, skip_special_tokens=True)

if target_str == "":
continue

print(target_str.replace(" ", "|"), predicted_str.replace(" ", "|"), file=f)


if __name__ == "__main__":
__arg_parser = configure_arg_parser()
__args = __arg_parser.parse_args()

__config = cast(DictConfig, OmegaConf.load(__args.config))
seed_everything(__config.seed)
if __args.mode == "train":
train_code2seq(__config)
else:
assert __args.pretrained is not None
if __args.mode == "test":
test_code2seq(__args.pretrained, __config)
else:
save_predictions(__args.pretrained, __config, __args.output)
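
For reference, save_predictions above writes one pair per line: the decoded target followed by the decoded prediction, with spaces inside each sequence replaced by "|" so the two fields stay whitespace-separable. A minimal sketch of reading such a file back; the file name is a placeholder for whatever -o/--output points at:

# Minimal sketch, assuming the output format produced by save_predictions above:
# "<target-with-|-separators> <prediction-with-|-separators>" per line.
with open("predictions.txt") as f:  # placeholder name; the script writes to the -o/--output path
    for line in f:
        target, prediction = line.rstrip("\n").split(" ", maxsplit=1)
        target_tokens = target.split("|")        # "|" stands in for the spaces of the decoded comment
        predicted_tokens = prediction.split("|")
        print(target_tokens, predicted_tokens)
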
68 changes: 68 additions & 0 deletions code2seq/data/comment_path_context_data_module.py
@@ -0,0 +1,68 @@
import pickle
from collections import Counter
from os.path import join, exists, dirname
from typing import Dict, Counter as TCounter, Type

from commode_utils.vocabulary import BaseVocabulary
from tqdm.auto import tqdm

from commode_utils.filesystem import count_lines_in_file
from omegaconf import DictConfig
from transformers import RobertaTokenizerFast

from code2seq.data.comment_path_context_dataset import CommentPathContextDataset
from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.data.vocabulary import CommentVocabulary


def _build_from_scratch(config: DictConfig, train_data: str, vocabulary_cls: Type[BaseVocabulary]):
total_samples = count_lines_in_file(train_data)
counters: Dict[str, TCounter[str]] = {
key: Counter() for key in [vocabulary_cls.LABEL, vocabulary_cls.TOKEN, vocabulary_cls.NODE]
}
with open(train_data, "r") as f_in:
for raw_sample in tqdm(f_in, total=total_samples):
vocabulary_cls.process_raw_sample(raw_sample, counters)

training_corpus = []
for string, amount in counters[vocabulary_cls.LABEL].items():
training_corpus.extend([string] * amount)
old_tokenizer = RobertaTokenizerFast.from_pretrained(config.base_tokenizer)
if config.train_new_tokenizer:
tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, config.max_tokenizer_vocab)
else:
tokenizer = old_tokenizer

for feature, counter in counters.items():
print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}")

dataset_dir = dirname(train_data)
vocabulary_file = join(dataset_dir, vocabulary_cls.vocab_filename)
with open(vocabulary_file, "wb") as f_out:
pickle.dump(counters, f_out)
pickle.dump(tokenizer, f_out)


class CommentPathContextDataModule(PathContextDataModule):
_vocabulary: CommentVocabulary

def __init__(self, data_dir: str, config: DictConfig):
super().__init__(data_dir, config)

def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPathContextDataset:
if self._vocabulary is None:
raise RuntimeError(f"Setup vocabulary before creating data loaders")
return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)

def setup_vocabulary(self) -> CommentVocabulary:
if not exists(join(self._data_dir, CommentVocabulary.vocab_filename)):
print("Can't find vocabulary, collect it from train holdout")
_build_from_scratch(self._config, join(self._data_dir, f"{self._train}.c2s"), CommentVocabulary)
vocabulary_path = join(self._data_dir, CommentVocabulary.vocab_filename)
return CommentVocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count)

@property
def vocabulary(self) -> CommentVocabulary:
if self._vocabulary is None:
raise RuntimeError(f"Setup data module for initializing vocabulary")
return self._vocabulary
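
For reference, the vocabulary file written by _build_from_scratch holds two consecutive pickle records: the counters first, then the tokenizer. A minimal sketch of inspecting such a file, with the dataset directory as a placeholder:

import pickle
from os.path import join

from code2seq.data.vocabulary import CommentVocabulary

# Two back-to-back pickle records, matching the dump order in _build_from_scratch above.
vocab_path = join("path/to/dataset", CommentVocabulary.vocab_filename)  # directory is a placeholder
with open(vocab_path, "rb") as f_in:
    counters = pickle.load(f_in)   # Dict[str, Counter[str]] for labels, tokens, and nodes
    tokenizer = pickle.load(f_in)  # the RobertaTokenizerFast trained or reused above

print(counters[CommentVocabulary.LABEL].most_common(5))
print(tokenizer.vocab_size)
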
22 changes: 22 additions & 0 deletions code2seq/data/comment_path_context_dataset.py
@@ -0,0 +1,22 @@
from typing import Dict, List, Optional

from code2seq.data.vocabulary import CommentVocabulary
from omegaconf import DictConfig

from code2seq.data.path_context_dataset import PathContextDataset


class CommentPathContextDataset(PathContextDataset):
def __init__(self, data_file: str, config: DictConfig, vocabulary: CommentVocabulary, random_context: bool):
super().__init__(data_file, config, vocabulary, random_context)
self._vocab: CommentVocabulary = vocabulary

def tokenize_label(self, raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
tokenizer = self._vocab.tokenizer
tokenized_snippet = tokenizer(
raw_label.replace(PathContextDataset._separator, " "),
add_special_tokens=True,
padding="max_length" if max_parts else "do_not_pad",
max_length=max_parts,
)
return tokenized_snippet["input_ids"]
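
tokenize_label above replaces the dataset's internal label separator with spaces and hands the comment text to the Hugging Face tokenizer, padding to max_parts when it is set. A rough illustration with a stock RoBERTa tokenizer standing in for the trained one; the label text, the "|" separator, and the length are invented for the example:

from transformers import RobertaTokenizerFast

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")  # stand-in for the PR's trained tokenizer

raw_label = "returns|the|sum|of|two|numbers"  # hypothetical comment, assuming "|" as the separator
text = raw_label.replace("|", " ")

encoded = tokenizer(text, add_special_tokens=True, padding="max_length", max_length=16)
print(encoded["input_ids"])  # <s> ... </s> followed by pad ids up to length 16
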
15 changes: 15 additions & 0 deletions code2seq/data/vocabulary.py
@@ -1,10 +1,12 @@
import pickle
from argparse import ArgumentParser
from collections import Counter
from os.path import dirname, join
from pickle import load, dump
from typing import Dict, Counter as CounterType, Optional, List

from commode_utils.vocabulary import BaseVocabulary, build_from_scratch
from transformers import PreTrainedTokenizerFast


class Vocabulary(BaseVocabulary):
@@ -71,6 +73,19 @@ def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]):
TypedVocabulary._process_raw_sample(raw_sample, counters, context_seq)


class CommentVocabulary(Vocabulary):
def __init__(
self,
vocabulary_file: str,
labels_count: Optional[int] = None,
tokens_count: Optional[int] = None,
):
super().__init__(vocabulary_file, labels_count, tokens_count)
with open(vocabulary_file, "rb") as f_in:
pickle.load(f_in)  # the counters are pickled first; skip them here and read only the tokenizer
self.tokenizer: PreTrainedTokenizerFast = pickle.load(f_in)


def convert_from_vanilla(vocabulary_path: str):
counters: Dict[str, CounterType[str]] = {}
with open(vocabulary_path, "rb") as dict_file:
28 changes: 15 additions & 13 deletions code2seq/model/code2seq.py
@@ -35,26 +35,26 @@ def __init__(
if vocabulary.SOS not in vocabulary.label_to_id:
raise ValueError(f"Can't find SOS token in label to id vocabulary")

self.__pad_idx = vocabulary.label_to_id[vocabulary.PAD]
self._pad_idx = vocabulary.label_to_id[vocabulary.PAD]
eos_idx = vocabulary.label_to_id[vocabulary.EOS]
ignore_idx = [vocabulary.label_to_id[vocabulary.SOS], vocabulary.label_to_id[vocabulary.UNK]]
metrics: Dict[str, Metric] = {
f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
for holdout in ["train", "val", "test"]
}
id2label = {v: k for k, v in vocabulary.label_to_id.items()}
metrics.update(
{f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
{f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self._pad_idx, eos_idx]) for holdout in ["val", "test"]}
)
self.__metrics = MetricCollection(metrics)
self._metrics = MetricCollection(metrics)

self._encoder = self._get_encoder(model_config)
decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx)
decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self._pad_idx)
self._decoder = Decoder(
decoder_step, len(vocabulary.label_to_id), vocabulary.label_to_id[vocabulary.SOS], teacher_forcing
)

self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean")
self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean")

@property
def vocabulary(self) -> Vocabulary:
@@ -107,16 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
target_sequence = batch.labels if step == "train" else None
# [seq length; batch size; vocab size]
logits, _ = self.logits_from_batch(batch, target_sequence)
result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}
logits = logits[1:]
batch.labels = batch.labels[1:]
result = {f"{step}/loss": self._loss(logits, batch.labels)}

with torch.no_grad():
prediction = logits.argmax(-1)
metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels)
result.update(
{f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
)
if step != "train":
result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)
result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels)

return result

@@ -140,17 +142,17 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
with torch.no_grad():
losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs]
mean_loss = torch.stack(losses).mean()
metric = self.__metrics[f"{step}_f1"].compute()
metric = self._metrics[f"{step}_f1"].compute()
log = {
f"{step}/loss": mean_loss,
f"{step}/f1": metric.f1_score,
f"{step}/precision": metric.precision,
f"{step}/recall": metric.recall,
}
self.__metrics[f"{step}_f1"].reset()
self._metrics[f"{step}_f1"].reset()
if step != "train":
log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
self.__metrics[f"{step}_chrf"].reset()
log[f"{step}/chrf"] = self._metrics[f"{step}_chrf"].compute()
self._metrics[f"{step}_chrf"].reset()
self.log_dict(log, on_step=False, on_epoch=True)

def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
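
The rename from double to single leading underscores in this diff is what allows CommentCode2Seq (added below) to share these attributes with its parent: names like __metrics are mangled per class (self.__metrics inside Code2Seq becomes self._Code2Seq__metrics), so an inherited method such as _shared_epoch_end could never see a __metrics set by the subclass. A minimal standalone sketch of the effect:

class Base:
    def __init__(self):
        self.__hidden = 1  # name-mangled: actually stored as _Base__hidden
        self._shared = 2   # single underscore: stored as _shared

    def read_hidden(self):
        return self.__hidden  # inside Base this resolves to _Base__hidden


class Child(Base):
    def __init__(self):
        super().__init__()
        self.__hidden = 10  # stored as _Child__hidden, does NOT replace _Base__hidden
        self._shared = 20   # updates the same _shared attribute the parent reads


child = Child()
print(child.read_hidden())  # 1  -- Base still sees its own _Base__hidden
print(child._shared)        # 20 -- the single-underscore attribute is shared
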
90 changes: 90 additions & 0 deletions code2seq/model/comment_code2seq.py
@@ -0,0 +1,90 @@
from typing import Dict

import torch
from commode_utils.losses import SequenceCrossEntropyLoss
from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
from commode_utils.modules import LSTMDecoderStep, Decoder
from omegaconf import DictConfig
from torchmetrics import MetricCollection, Metric

from code2seq.data.path_context import BatchedLabeledPathContext
from code2seq.data.vocabulary import CommentVocabulary
from code2seq.model import Code2Seq
from code2seq.model.modules.transformer_comment_decoder import TransformerCommentDecoder
from code2seq.model.modules.metrics import CommentChrF


class CommentCode2Seq(Code2Seq):
def __init__(
self,
model_config: DictConfig,
optimizer_config: DictConfig,
vocabulary: CommentVocabulary,
teacher_forcing: float = 0.0,
):
super(Code2Seq, self).__init__()

self.save_hyperparameters()
self._optim_config = optimizer_config
self._vocabulary = vocabulary

tokenizer = vocabulary.tokenizer

self._pad_idx = tokenizer.pad_token_id
self._eos_idx = tokenizer.eos_token_id
self._sos_idx = tokenizer.bos_token_id
ignore_idx = [self._sos_idx, tokenizer.unk_token_id]
metrics: Dict[str, Metric] = {
f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=self._eos_idx, ignore_idx=ignore_idx)
for holdout in ["train", "val", "test"]
}

# TODO add concatenation and rouge-L metric
metrics.update({f"{holdout}_chrf": CommentChrF(tokenizer) for holdout in ["val", "test"]})
self._metrics = MetricCollection(metrics)

self._encoder = self._get_encoder(model_config)
self._decoder = self.get_decoder(model_config, tokenizer.vocab_size, teacher_forcing)

self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="seq-mean")

def get_decoder(self, model_config: DictConfig, vocab_size: int, teacher_forcing: float) -> torch.nn.Module:
if model_config.decoder_type == "LSTM":
decoder_step = LSTMDecoderStep(model_config, vocab_size, self._pad_idx)
return Decoder(decoder_step, vocab_size, self._sos_idx, teacher_forcing)
elif model_config.decoder_type == "Transformer":
return TransformerCommentDecoder(
model_config,
vocab_size=vocab_size,
pad_token=self._pad_idx,
sos_token=self._sos_idx,
eos_token=self._eos_idx,
teacher_forcing=teacher_forcing,
)
else:
raise ValueError(f"Unknown decoder type: {model_config.decoder_type}")

def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
target_sequence = batch.labels if step != "test" else None
# [seq length; batch size; vocab size]
logits, _ = self.logits_from_batch(batch, target_sequence)
logits = logits[:-1]
batch.labels = batch.labels[1:]
result = {f"{step}/loss": self._loss(logits, batch.labels)}

with torch.no_grad():
if step != "train":
prediction = logits.argmax(-1)
metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels)
result.update(
{
f"{step}/f1": metric.f1_score,
f"{step}/precision": metric.precision,
f"{step}/recall": metric.recall,
}
)
result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels)
else:
result.update({f"{step}/f1": 0, f"{step}/precision": 0, f"{step}/recall": 0})

return result
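
get_decoder above dispatches on model_config.decoder_type, accepting either "LSTM" or "Transformer". A minimal sketch of such a config built with OmegaConf; apart from decoder_type, the keys are illustrative and not taken from this PR's YAML files:

from omegaconf import OmegaConf

# Only decoder_type is read by get_decoder above; the other keys are hypothetical
# examples of what LSTMDecoderStep or TransformerCommentDecoder might consume.
model_config = OmegaConf.create(
    {
        "decoder_type": "Transformer",  # or "LSTM"
        "decoder_size": 320,            # illustrative
        "decoder_num_layers": 2,        # illustrative
    }
)

assert model_config.decoder_type in ("LSTM", "Transformer")
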