Skip to content

Commit 1ac8e38

Browse files
added dask.dataframe.Series support to CountVectorizer & TfidfVectorizer
1 parent ebedfa8 commit 1ac8e38

File tree

2 files changed

+86
-39
lines changed

2 files changed

+86
-39
lines changed

dask_ml/feature_extraction/text.py

Lines changed: 78 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import dask.bag as db
99
import dask.dataframe as dd
1010
import distributed
11+
import pandas as pd
1112
import numpy as np
1213
import scipy.sparse
1314
import sklearn.base
@@ -16,7 +17,6 @@
1617
from dask.delayed import Delayed
1718
from distributed import get_client, wait
1819
from sklearn.utils.validation import check_is_fitted
19-
from builtins import getattr
2020

2121
FLOAT_DTYPES = (np.float64, np.float32, np.float16)
2222

@@ -120,18 +120,6 @@ def _hasher(self):
120120
return sklearn.feature_extraction.text.FeatureHasher
121121

122122

123-
def _n_samples(X):
124-
"""Count the number of samples dask array X."""
125-
def chunk_n_samples(chunk, axis, keepdims):
126-
return np.array([chunk.shape[0]])
127-
128-
return da.reduction(X,
129-
chunk=chunk_n_samples,
130-
aggregate=np.sum,
131-
concatenate=False,
132-
dtype=X.dtype).compute()
133-
134-
135123
def _document_frequency(X, dtype):
136124
"""Count the number of non-zero values for each feature in dask array X."""
137125
def chunk_doc_freq(chunk, axis, keepdims):
@@ -172,7 +160,9 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
172160
Examples
173161
--------
174162
The Dask-ML implementation currently requires that ``raw_documents``
175-
is a :class:`dask.bag.Bag` of documents (lists of strings).
163+
is either a :class:`dask.bag.Bag` of documents (lists of strings) or
164+
a :class:`dask.dataframe.Series` of documents (Series of strings)
165+
with partitions of type :class:`pandas.Series`.
176166
177167
>>> from dask_ml.feature_extraction.text import CountVectorizer
178168
>>> import dask.bag as db
@@ -184,10 +174,25 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
184174
... 'And this is the third one.',
185175
... 'Is this the first document?',
186176
... ]
187-
>>> corpus = db.from_sequence(corpus, npartitions=2)
177+
>>> corpus_bag = db.from_sequence(corpus, npartitions=2)
188178
>>> vectorizer = CountVectorizer()
189-
>>> X = vectorizer.fit_transform(corpus)
190-
dask.array<concatenate, shape=(nan, 9), dtype=int64, chunksize=(nan, 9), ...
179+
>>> X = vectorizer.fit_transform(corpus_bag)
180+
dask.array<concatenate, shape=(4, 9), dtype=int64, chunksize=(2, 9), ...
181+
chunktype=scipy.csr_matrix>
182+
>>> X.compute().toarray()
183+
array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
184+
[0, 2, 0, 1, 0, 1, 1, 0, 1],
185+
[1, 0, 0, 1, 1, 0, 1, 1, 1],
186+
[0, 1, 1, 1, 0, 0, 1, 0, 1]])
187+
>>> vectorizer.get_feature_names()
188+
['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
189+
190+
>>> import dask.dataframe as dd
191+
>>> import pandas as pd
192+
>>> corpus_dds = dd.from_pandas(pd.Series(corpus), npartitions=2)
193+
>>> vectorizer = CountVectorizer()
194+
>>> X = vectorizer.fit_transform(corpus_dds)
195+
dask.array<concatenate, shape=(4, 9), dtype=int64, chunksize=(2, 9), ...
191196
chunktype=scipy.csr_matrix>
192197
>>> X.compute().toarray()
193198
array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
@@ -199,13 +204,17 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
199204
"""
200205

201206
def fit_transform(self, raw_documents, y=None):
207+
# Note that in general 'self' could refer to an instance of either this
208+
# class or a subclass of this class. Hence it is possible that
209+
# self.get_params() could get unexpected parameters of an instance of a
210+
# subclass. Such parameters need to be excluded here:
202211
subclass_instance_params = self.get_params()
203212
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
204213
params = {key: subclass_instance_params[key]
205214
for key in subclass_instance_params
206215
if key not in excluded_keys}
207-
vocabulary = params.pop("vocabulary")
208216

217+
vocabulary = params.pop("vocabulary")
209218
vocabulary_for_transform = vocabulary
210219

211220
if self.vocabulary is not None:
@@ -217,26 +226,33 @@ def fit_transform(self, raw_documents, y=None):
217226
fixed_vocabulary = False
218227
# Case 2: learn vocabulary from the data.
219228
vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
220-
vocabulary = vocabulary_for_transform = _merge_vocabulary(
221-
*vocabularies.to_delayed()
222-
)
229+
vocabulary = vocabulary_for_transform = (
230+
_merge_vocabulary( *vocabularies.to_delayed() ))
223231
vocabulary_for_transform = vocabulary_for_transform.persist()
224232
vocabulary_ = vocabulary.compute()
225233
n_features = len(vocabulary_)
226234

227-
result = raw_documents.map_partitions(
228-
_count_vectorizer_transform, vocabulary_for_transform, params
229-
)
230-
231235
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
232-
result = build_array(result, n_features, meta)
236+
if isinstance(raw_documents, dd.Series):
237+
result = raw_documents.map_partitions(
238+
_count_vectorizer_transform, vocabulary_for_transform,
239+
params, meta=meta)
240+
else:
241+
result = raw_documents.map_partitions(
242+
_count_vectorizer_transform, vocabulary_for_transform, params)
243+
result = build_array(result, n_features, meta)
244+
result.compute_chunk_sizes()
233245

234246
self.vocabulary_ = vocabulary_
235247
self.fixed_vocabulary_ = fixed_vocabulary
236248

237249
return result
238250

239251
def transform(self, raw_documents):
252+
# Note that in general 'self' could refer to an instance of either this
253+
# class or a subclass of this class. Hence it is possible that
254+
# self.get_params() could get unexpected parameters of an instance of a
255+
# subclass. Such parameters need to be excluded here:
240256
subclass_instance_params = self.get_params()
241257
excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
242258
params = {key: subclass_instance_params[key]
@@ -262,12 +278,17 @@ def transform(self, raw_documents):
262278
vocabulary_for_transform = vocabulary
263279

264280
n_features = vocabulary_length(vocabulary_for_transform)
265-
transformed = raw_documents.map_partitions(
266-
_count_vectorizer_transform, vocabulary_for_transform, params
267-
)
268281
meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
269-
return build_array(transformed, n_features, meta)
270-
282+
if isinstance(raw_documents, dd.Series):
283+
result = raw_documents.map_partitions(
284+
_count_vectorizer_transform, vocabulary_for_transform,
285+
params, meta=meta)
286+
else:
287+
transformed = raw_documents.map_partitions(
288+
_count_vectorizer_transform, vocabulary_for_transform, params)
289+
result = build_array(transformed, n_features, meta)
290+
result.compute_chunk_sizes()
291+
return result
271292

272293
class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
273294
"""Transform a count matrix to a normalized tf or tf-idf representation
@@ -316,7 +337,7 @@ def fit(self, X, y=None):
316337
dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
317338

318339
if self.use_idf:
319-
n_samples, n_features = _n_samples(X), X.shape[1]
340+
n_samples, n_features = X.shape
320341
df = _document_frequency(X, dtype)
321342
# df = df.astype(dtype, **_astype_copy_false(df))
322343

@@ -409,7 +430,9 @@ class TfidfVectorizer(CountVectorizer):
409430
Examples
410431
--------
411432
The Dask-ML implementation currently requires that ``raw_documents``
412-
is a :class:`dask.bag.Bag` of documents (lists of strings).
433+
is either a :class:`dask.bag.Bag` of documents (lists of strings) or
434+
a :class:`dask.dataframe.Series` of documents (Series of strings)
435+
with partitions of type :class:`pandas.Series`.
413436
414437
>>> from dask_ml.feature_extraction.text import TfidfVectorizer
415438
>>> import dask.bag as db
@@ -421,10 +444,29 @@ class TfidfVectorizer(CountVectorizer):
421444
... 'And this is the third one.',
422445
... 'Is this the first document?',
423446
... ]
424-
>>> corpus = db.from_sequence(corpus, npartitions=2)
447+
>>> corpus_bag = db.from_sequence(corpus, npartitions=2)
448+
>>> vectorizer = TfidfVectorizer()
449+
>>> X = vectorizer.fit_transform(corpus_bag)
450+
dask.array<concatenate, shape=(4, 9), dtype=float64, chunksize=(2, 9), ...
451+
chunktype=scipy.csr_matrix>
452+
>>> X.compute().toarray()
453+
array([[0. , 0.46979139, 0.58028582, 0.38408524, 0. ,
454+
0. , 0.38408524, 0. , 0.38408524],
455+
[0. , 0.6876236 , 0. , 0.28108867, 0. ,
456+
0.53864762, 0.28108867, 0. , 0.28108867],
457+
[0.51184851, 0. , 0. , 0.26710379, 0.51184851,
458+
0. , 0.26710379, 0.51184851, 0.26710379],
459+
[0. , 0.46979139, 0.58028582, 0.38408524, 0. ,
460+
0. , 0.38408524, 0. , 0.38408524]])
461+
>>> vectorizer.get_feature_names()
462+
['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
463+
464+
>>> import dask.dataframe as dd
465+
>>> import pandas as pd
466+
>>> corpus_dds = dd.from_pandas(pd.Series(corpus), npartitions=2)
425467
>>> vectorizer = TfidfVectorizer()
426-
>>> X = vectorizer.fit_transform(corpus)
427-
dask.array<concatenate, shape=(nan, 9), dtype=float64, chunksize=(nan, 9), ...
468+
>>> X = vectorizer.fit_transform(corpus_dds)
469+
dask.array<concatenate, shape=(4, 9), dtype=float64, chunksize=(2, 9), ...
428470
chunktype=scipy.csr_matrix>
429471
>>> X.compute().toarray()
430472
array([[0. , 0.46979139, 0.58028582, 0.38408524, 0. ,

tests/feature_extraction/test_text.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -186,11 +186,13 @@ def test_count_vectorizer_remote_vocabulary():
186186

187187

188188
@pytest.mark.parametrize("distributed", [True, False])
189+
@pytest.mark.parametrize("collection_type", ["Bag", "Series"])
189190
@pytest.mark.parametrize("norm", ["l1", "l2"])
190191
@pytest.mark.parametrize("use_idf", [True, False])
191192
@pytest.mark.parametrize("smooth_idf", [True, False])
192193
@pytest.mark.parametrize("sublinear_tf", [True, False])
193194
def test_tfidf_vectorizer(distributed,
195+
collection_type,
194196
norm,
195197
use_idf,
196198
smooth_idf,
@@ -200,7 +202,10 @@ def test_tfidf_vectorizer(distributed,
200202
use_idf=use_idf,
201203
smooth_idf=smooth_idf,
202204
sublinear_tf=sublinear_tf))
203-
b = db.from_sequence(JUNK_FOOD_DOCS, npartitions=2)
205+
if collection_type == "Bag":
206+
docs = db.from_sequence(JUNK_FOOD_DOCS, npartitions=2)
207+
elif collection_type == "Series":
208+
docs = dd.from_pandas(pd.Series(JUNK_FOOD_DOCS), npartitions=2)
204209
r1 = m1.fit_transform(JUNK_FOOD_DOCS)
205210

206211
m2 = (dask_ml.feature_extraction.text
@@ -214,7 +219,7 @@ def test_tfidf_vectorizer(distributed,
214219
else:
215220
client = dummy_context()
216221

217-
r2 = m2.fit_transform(b)
222+
r2 = m2.fit_transform(docs)
218223

219224
with client:
220225
exclude = {"vocabulary_actor_", "stop_words_"}
@@ -228,7 +233,7 @@ def test_tfidf_vectorizer(distributed,
228233
np.testing.assert_array_almost_equal(r1.toarray(),
229234
r2.compute().toarray())
230235

231-
r3 = m2.transform(b)
236+
r3 = m2.transform(docs)
232237
assert isinstance(r3, da.Array)
233238
assert isinstance(r3._meta, scipy.sparse.csr_matrix)
234239
np.testing.assert_array_almost_equal(r1.toarray(),

0 commit comments

Comments
 (0)