22 changes: 17 additions & 5 deletions .github/workflows/push.yml
@@ -47,7 +47,6 @@ jobs:
-save_data /tmp/onmt_feat \
-src_vocab /tmp/onmt_feat.vocab.src \
-tgt_vocab /tmp/onmt_feat.vocab.tgt \
-src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
-n_sample -1 \
&& rm -rf /tmp/sample
- name: Test field/transform dump
@@ -259,21 +258,34 @@ jobs:
-config data/features_data.yaml \
-src_vocab /tmp/onmt_feat.vocab.src \
-tgt_vocab /tmp/onmt_feat.vocab.tgt \
-src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
-src_vocab_size 1000 -tgt_vocab_size 1000 \
-hidden_size 2 -batch_size 10 \
-num_workers 0 -bucket_size 1024 \
-word_vec_size 5 -hidden_size 10 \
-report_every 5 -train_steps 10 \
-save_model /tmp/onmt.model \
-save_checkpoint_steps 10
- name: Testing training with features and dynamic scoring
run: |
python onmt/bin/train.py \
-config data/features_data.yaml \
-src_vocab /tmp/onmt_feat.vocab.src \
-tgt_vocab /tmp/onmt_feat.vocab.tgt \
-src_vocab_size 1000 -tgt_vocab_size 1000 \
-hidden_size 2 -batch_size 10 \
-word_vec_size 5 -hidden_size 10 \
-num_workers 0 -bucket_size 1024 \
-report_every 5 -train_steps 10 \
-train_metrics "BLEU" "TER" \
-valid_metrics "BLEU" "TER" \
-save_model /tmp/onmt.model \
-save_checkpoint_steps 10
- name: Testing translation with features
run: |
python translate.py \
-model /tmp/onmt.model_step_10.pt \
-src data/data_features/src-test.txt \
-src_feats "{'feat0': 'data/data_features/src-test.feat0'}" \
-verbose
-src data/data_features/src-test-with-feats.txt \
-n_src_feats 1 -verbose
- name: Test RNN translation
run: |
head data/src-test.txt > /tmp/src-test.txt
1 change: 1 addition & 0 deletions data/data_features/src-test-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-test.feat0

This file was deleted.

3 changes: 3 additions & 0 deletions data/data_features/src-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
3 changes: 0 additions & 3 deletions data/data_features/src-train.feat0

This file was deleted.

1 change: 1 addition & 0 deletions data/data_features/src-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-val.feat0

This file was deleted.

16 changes: 12 additions & 4 deletions data/features_data.yaml
@@ -1,11 +1,19 @@

# Corpus opts:
data:
corpus_1:
path_src: data/data_features/src-train-with-feats.txt
path_tgt: data/data_features/tgt-train.txt
transforms: [inferfeats]
corpus_2:
path_src: data/data_features/src-train.txt
path_tgt: data/data_features/tgt-train.txt
src_feats:
feat0: data/data_features/src-train.feat0
transforms: [filterfeats, inferfeats]
transforms: [inferfeats]
valid:
path_src: data/data_features/src-val.txt
path_src: data/data_features/src-val-with-feats.txt
path_tgt: data/data_features/tgt-val.txt
transforms: [inferfeats]

# # Feats options
n_src_feats: 1
src_feats_defaults: "0"
75 changes: 28 additions & 47 deletions docs/source/FAQ.md
@@ -620,39 +620,34 @@ Training options to perform vocabulary update are:

## How can I use source word features?

Extra information can be added to the words in the source sentences by defining word features.
Additional word-level information can be incorporated into the model by defining word features in the source sentence.

Features should be defined in a separate file using blank spaces as a separator and with each row corresponding to a source sentence. An example of the input files:
Word features must be appended to the source text itself, using the special character │ as a feature separator. For instance:

data.src
```
however, according to the logs, she is hard-working.
however│C ■,│N according│L to│L the│L logs│L ■,│N she│L is│L hard-working│L ■.│N
```
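
To make the inline format concrete, here is a minimal Python sketch (illustrative only, not the library's own parsing code; `split_feats` is a hypothetical helper) that splits such a line into a token sequence and parallel feature sequences:

```python
# Minimal sketch: split "tok│feat" items into parallel token/feature strings.
SEP = "│"  # the special feature separator described above

def split_feats(line, n_feats=1):
    tokens, feats = [], [[] for _ in range(n_feats)]
    for item in line.strip().split(" "):
        fields = item.split(SEP)
        tokens.append(fields[0])
        for i in range(n_feats):
            feats[i].append(fields[1 + i])
    return " ".join(tokens), [" ".join(f) for f in feats]

line = "however│C ■,│N according│L to│L the│L logs│L ■,│N she│L is│L hard-working│L ■.│N"
src, src_feats = split_feats(line, n_feats=1)
# src       -> "however ■, according to the logs ■, she is hard-working ■."
# src_feats -> ["C N L L L L N L L L N"]
```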

feat.txt
Prior tokenization is not necessary: if a tokenization transform is applied, the `FeatInferTransform` transform will infer features for the resulting subwords. For instance:

```
A C C C C A A B
SRC: however,│C according│L to│L the│L logs,│L she│L is│L hard-working.│L
TOKENIZED SRC: however ■, according to the logs ■, she is hard-working ■.
RESULT: however│C ■,│C according│L to│L the│L logs│L ■,│L she│L is│L hard│L ■-■│L working│L ■.│L
```
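
As a rough illustration of this inference step, the sketch below propagates each original word's feature to all subwords produced from it, assuming the joiner character ■ marks subwords that still attach to the previous word (this is only a sketch of the idea, not the actual `FeatInferTransform` implementation):

```python
# Sketch: give every subword the feature of the word it was split from,
# using the "■" joiner to detect subwords that belong to the same word.
JOINER = "■"

def infer_subword_feats(subwords, word_feats):
    feats, word_idx, prev_joined = [], -1, False
    for sw in subwords:
        if not sw.startswith(JOINER) and not prev_joined:
            word_idx += 1  # this subword starts a new source word
        feats.append(word_feats[word_idx])
        prev_joined = sw.endswith(JOINER)
    return feats

subwords = "however ■, according to the logs ■, she is hard ■-■ working ■.".split()
word_feats = "C L L L L L L L".split()  # one feature per original word
print(list(zip(subwords, infer_subword_feats(subwords, word_feats))))
# -> [('however', 'C'), ('■,', 'C'), ('according', 'L'), ..., ('■.', 'L')]
```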

Prior tokenization is not necessary, features will be inferred by using the `FeatInferTransform` transform if tokenization has been applied.
**Options**
- `-n_src_feats`: the expected number of source features per token.
- `-src_feats_defaults` (optional): default feature values to use for each token. This is useful when mixing task-specific data (with features) with general data that has not been annotated, as illustrated below.
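
As a rough example of what these defaults mean, the hypothetical helper below (not part of OpenNMT-py) annotates an unlabelled line with the default values so it can be mixed with feature-annotated corpora:

```python
# Hypothetical helper: give every token the default feature value(s),
# e.g. src_feats_defaults: "0" for a single feature or "0│1" for two.
def add_default_feats(line, defaults="0"):
    return " ".join(f"{tok}│{defaults}" for tok in line.split())

print(add_default_feats("she is a hard-working."))
# -> she│0 is│0 a│0 hard-working.│0
```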

No previous tokenization:
```
SRC: this is a test.
FEATS: A A A B
TOKENIZED SRC: this is a test ■.
RESULT: A A A B <null>
```
For the Transformer architecture, make sure the following options are set appropriately:

Previously tokenized:
```
SRC: this is a test ■.
FEATS: A A A B A
RESULT: A A A B A
```
- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
- `feat_merge`: how the feature embeddings are merged with the word embeddings (see the sketch below)
- `feat_vec_size` or, alternatively, `feat_vec_exponent`
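
To show how these options interact, here is a conceptual PyTorch sketch (an assumption about the general mechanism, not OpenNMT-py's actual `Embeddings` module) of merging a feature embedding with the word embedding per token:

```python
import torch
import torch.nn as nn

# Toy sizes: with feat_merge "sum" the feature embedding must match
# word_vec_size; with "concat" the merged vector grows by feat_vec_size.
word_vec_size, feat_vec_size = 10, 10
word_emb = nn.Embedding(1000, word_vec_size)   # src vocab x word_vec_size
feat_emb = nn.Embedding(4, feat_vec_size)      # feature vocab x feat_vec_size

words = torch.tensor([[3, 17, 42]])   # (batch, seq) word ids
feats = torch.tensor([[0, 1, 0]])     # one source feature id per token

merged_sum = word_emb(words) + feat_emb(feats)                       # feat_merge: sum
merged_cat = torch.cat([word_emb(words), feat_emb(feats)], dim=-1)   # feat_merge: concat
```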

**Notes**
- `FilterFeatsTransform` and `FeatInferTransform` are required in order to ensure the functionality.
- The `FeatInferTransform` transform is required for source features to work correctly.
- Shared embeddings are not possible (at least with the `feat_merge: concat` method)

Sample config file:
@@ -662,50 +657,36 @@ data:
dummy:
path_src: data/train/data.src
path_tgt: data/train/data.tgt
src_feats:
feat_0: data/train/data.src.feat_0
feat_1: data/train/data.src.feat_1
transforms: [filterfeats, onmt_tokenize, inferfeats, filtertoolong]
transforms: [onmt_tokenize, inferfeats, filtertoolong]
weight: 1
valid:
path_src: data/valid/data.src
path_tgt: data/valid/data.tgt
src_feats:
feat_0: data/valid/data.src.feat_0
feat_1: data/valid/data.src.feat_1
transforms: [filterfeats, onmt_tokenize, inferfeats]
transforms: [onmt_tokenize, inferfeats]

# Transform options
reversible_tokenization: "joiner"
prior_tokenization: true

# Vocab opts
src_vocab: exp/data.vocab.src
tgt_vocab: exp/data.vocab.tgt
src_feats_vocab:
feat_0: exp/data.vocab.feat_0
feat_1: exp/data.vocab.feat_1

# Features options
n_src_feats: 2
src_feats_defaults: "0│1"
feat_merge: "sum"
```

During inference you can pass features by using the `--src_feats` argument. `src_feats` is expected to be a Python like dict, mapping feature names with their data file.
To allow source features in the server, add the following parameters to the server's config file:

```
{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}
```

**Important note!** During inference, input sentence is expected to be tokenized. Therefore feature inferring should be handled prior to running the translate command. Example:

```bash
python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
"features": {
"n_src_feats": 2,
"src_feats_defaults": "0│1",
"reversible_tokenization": "joiner"
}
```

When using the Transformer architecture make sure the following options are appropriately set:

- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
- `feat_merge`: how to handle features vecs
- `feat_vec_size` and maybe `feat_vec_exponent`

## How can I set up a translation server?
A REST server was implemented to serve OpenNMT-py models. A discussion is open on the OpenNMT forum: [discussion link](https://forum.opennmt.net/t/simple-opennmt-py-rest-server/1392).

54 changes: 24 additions & 30 deletions onmt/bin/build_vocab.py
@@ -6,10 +6,10 @@
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
from onmt.inputters.text_utils import process
from onmt.inputters.text_utils import process, append_features_to_text
from onmt.transforms import make_transforms, get_transforms_cls
from onmt.constants import CorpusName, CorpusTask
from collections import Counter, defaultdict
from collections import Counter
import multiprocessing as mp


@@ -40,21 +40,11 @@ def write_files_from_queues(sample_path, queues):
break


# Just for debugging purposes
# It appends features to subwords when dumping to file
def append_features_to_example(example, features):
ex_toks = example.split(' ')
feat_toks = features.split(' ')
toks = [f"{subword}│{feat}" for subword, feat in
zip(ex_toks, feat_toks)]
return " ".join(toks)


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = defaultdict(Counter)
sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -70,19 +60,22 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
continue
src_line, tgt_line = (maybe_example['src']['src'],
maybe_example['tgt']['tgt'])
src_line_pretty = src_line
for feat_name, feat_line in maybe_example["src"].items():
if feat_name not in ["src", "src_original"]:
sub_counter_src_feats[feat_name].update(
feat_line.split(' '))
if opts.dump_samples:
src_line_pretty = append_features_to_example(
src_line_pretty, feat_line)
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))

if 'feats' in maybe_example['src']:
src_feats_lines = maybe_example['src']['feats']
for i in range(opts.n_src_feats):
sub_counter_src_feats[i].update(
src_feats_lines[i].split(' '))
else:
src_feats_lines = []

if opts.dump_samples:
src_pretty_line = append_features_to_text(
src_line, src_feats_lines)
build_sub_vocab.queues[c_name][offset].put(
(i, src_line_pretty, tgt_line))
(i, src_pretty_line, tgt_line))
if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
@@ -113,7 +106,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, task=CorpusTask.TRAIN)
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = defaultdict(Counter)
counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -134,7 +127,8 @@ def build_vocab(opts, transforms, n_sample=3):
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
for i in range(opts.n_src_feats):
counter_src_feats[i].update(sub_counter_src_feats[i])
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt, counter_src_feats
@@ -166,10 +160,10 @@ def build_vocab_main(opts):
src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
for feat_name, feat_counter in src_feats_counter.items():
logger.info(f"Counters {feat_name}:{len(feat_counter)}")
logger.info(f"Counters src: {len(src_counter)}")
logger.info(f"Counters tgt: {len(tgt_counter)}")
for i, feat_counter in enumerate(src_feats_counter):
logger.info(f"Counters src feat_{i}: {len(feat_counter)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -186,8 +180,8 @@ def save_counter(counter, save_path):
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

for k, v in src_feats_counter.items():
save_counter(v, opts.src_feats_vocab[k])
for i, c in enumerate(src_feats_counter):
save_counter(c, f"{opts.src_vocab}_feat{i}")


def _get_parser():
32 changes: 14 additions & 18 deletions onmt/inputters/inputter.py
@@ -34,11 +34,10 @@ def build_vocab(opt, specials):
""" Build vocabs dict to be stored in the checkpoint
based on vocab files having each line [token, count]
Args:
opt: src_vocab, tgt_vocab, src_feats_vocab
opt: src_vocab, tgt_vocab, n_src_feats
Return:
vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
'src_feats' : {'feat0': pyonmttok.Vocab,
'feat1': pyonmttok.Vocab, ...},
'src_feats' : [pyonmttok.Vocab, ...]},
'data_task': seq2seq or lm
}
"""
@@ -85,10 +84,10 @@ def _pad_vocab_to_multiple(vocab, multiple):
opt.vocab_size_multiple)
vocabs['tgt'] = tgt_vocab

if opt.src_feats_vocab:
src_feats = {}
for feat_name, filepath in opt.src_feats_vocab.items():
src_f_vocab = _read_vocab_file(filepath, 1)
if opt.n_src_feats > 0:
src_feats_vocabs = []
for i in range(opt.n_src_feats):
src_f_vocab = _read_vocab_file(f"{opt.src_vocab}_feat{i}", 1)
src_f_vocab = pyonmttok.build_vocab_from_tokens(
src_f_vocab,
maximum_size=0,
@@ -101,8 +100,8 @@ def _pad_vocab_to_multiple(vocab, multiple):
if opt.vocab_size_multiple > 1:
src_f_vocab = _pad_vocab_to_multiple(src_f_vocab,
opt.vocab_size_multiple)
src_feats[feat_name] = src_f_vocab
vocabs['src_feats'] = src_feats
src_feats_vocabs.append(src_f_vocab)
vocabs["src_feats"] = src_feats_vocabs

vocabs['data_task'] = opt.data_task

@@ -146,10 +145,8 @@ def vocabs_to_dict(vocabs):
vocabs_dict['src'] = vocabs['src'].ids_to_tokens
vocabs_dict['tgt'] = vocabs['tgt'].ids_to_tokens
if 'src_feats' in vocabs.keys():
vocabs_dict['src_feats'] = {}
for feat in vocabs['src_feats'].keys():
vocabs_dict['src_feats'][feat] = \
vocabs['src_feats'][feat].ids_to_tokens
vocabs_dict['src_feats'] = [feat_vocab.ids_to_tokens
for feat_vocab in vocabs['src_feats']]
vocabs_dict['data_task'] = vocabs['data_task']
return vocabs_dict

@@ -167,9 +164,8 @@ def dict_to_vocabs(vocabs_dict):
else:
vocabs['tgt'] = pyonmttok.build_vocab_from_tokens(vocabs_dict['tgt'])
if 'src_feats' in vocabs_dict.keys():
vocabs['src_feats'] = {}
for feat in vocabs_dict['src_feats'].keys():
vocabs['src_feats'][feat] = \
pyonmttok.build_vocab_from_tokens(
vocabs_dict['src_feats'][feat])
vocabs['src_feats'] = []
for feat_vocab in vocabs_dict['src_feats']:
vocabs['src_feats'].append(
pyonmttok.build_vocab_from_tokens(feat_vocab))
return vocabs