From 824d25bf517405abdf9124b17ac1ae78d2c95845 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Thu, 1 Jun 2023 14:34:54 +0800 Subject: [PATCH 1/2] from_dataset ignore null instance --- fastNLP/core/vocabulary.py | 4 ++++ tests/core/test_vocabulary.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 tests/core/test_vocabulary.py diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 3a6ab650..a5dac145 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -396,6 +396,10 @@ def from_dataset(self, *datasets, field_name:Union[str,List[str]], no_create_ent def construct_vocab(ins, no_create_entry=False): for fn in field_name: field = ins[fn] + # 如果 field 为空或者 None, 那么直接跳过即可。 + if field is None or len(field) == 0: + logger.warning(f"instance: {ins} has null field. Skip now!") + continue if isinstance(field, str) or not _is_iterable(field): self.add_word(field, no_create_entry=no_create_entry) else: diff --git a/tests/core/test_vocabulary.py b/tests/core/test_vocabulary.py new file mode 100644 index 00000000..de601252 --- /dev/null +++ b/tests/core/test_vocabulary.py @@ -0,0 +1,15 @@ +import pytest +from collections import Counter + +from fastNLP.core.dataset import DataSet +from fastNLP.core.vocabulary import Vocabulary +from fastNLP import logger + + +class TestVocabulary: + + def test_from_dataset(self): + ds = DataSet({"x": [[1, 2], [3, 4]], "y": ["apple", ""]}) + vocab = Vocabulary() + vocab.from_dataset(ds, field_name="y") + assert vocab.word_count == Counter({'apple': 1}) From 4237af3f409594ce791c4500f2115b0297e32da2 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Mon, 5 Jun 2023 11:00:24 +0800 Subject: [PATCH 2/2] fix bug --- fastNLP/core/vocabulary.py | 2 +- tests/core/test_vocabulary.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index a5dac145..5d4f4725 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -397,7 +397,7 @@ def construct_vocab(ins, no_create_entry=False): for fn in field_name: field = ins[fn] # 如果 field 为空或者 None, 那么直接跳过即可。 - if field is None or len(field) == 0: + if field is None or (hasattr(field, "__len__") and len(field) == 0): logger.warning(f"instance: {ins} has null field. Skip now!") continue if isinstance(field, str) or not _is_iterable(field): diff --git a/tests/core/test_vocabulary.py b/tests/core/test_vocabulary.py index de601252..7787840a 100644 --- a/tests/core/test_vocabulary.py +++ b/tests/core/test_vocabulary.py @@ -13,3 +13,9 @@ def test_from_dataset(self): vocab = Vocabulary() vocab.from_dataset(ds, field_name="y") assert vocab.word_count == Counter({'apple': 1}) + + def test_from_dataset1(self): + ds = DataSet({"x": [[1, 2], [3, 4], [5]], "y": [1, None, 2]}) + vocab = Vocabulary() + vocab.from_dataset(ds, field_name="y") + assert vocab.word_count == Counter({1: 1, 2: 1})