Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions top2vec/Top2Vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ class Top2Vec:

verbose: bool (Optional, default True)
Whether to print status data during training.

batch_size_embed: int (Optional, default 500)
The batch size used when embedding documents.
"""

def __init__(self,
Expand All @@ -186,7 +189,8 @@ def __init__(self,
use_embedding_model_tokenizer=False,
umap_args=None,
hdbscan_args=None,
verbose=True
verbose=True,
batch_size_embed=500
):

if verbose:
Expand Down Expand Up @@ -340,10 +344,10 @@ def return_doc(doc):

# embed documents
if use_embedding_model_tokenizer:
self.document_vectors = self._embed_documents(documents)
self.document_vectors = self._embed_documents(documents, batch_size_embed)
else:
train_corpus = [' '.join(tokens) for tokens in tokenized_corpus]
self.document_vectors = self._embed_documents(train_corpus)
self.document_vectors = self._embed_documents(train_corpus, batch_size_embed)

else:
raise ValueError(f"{embedding_model} is an invalid embedding model.")
Expand Down Expand Up @@ -518,13 +522,12 @@ def _l2_normalize(vectors):
else:
return normalize(vectors.reshape(1, -1))[0]

def _embed_documents(self, train_corpus):
def _embed_documents(self, train_corpus, batch_size):

self._check_import_status()
self._check_model_status()

# embed documents
batch_size = 500
document_vectors = []

current = 0
Expand Down Expand Up @@ -1131,7 +1134,7 @@ def get_documents_topics(self, doc_ids, reduced=False):

return doc_topics, doc_dist, topic_words, topic_word_scores

def add_documents(self, documents, doc_ids=None, tokenizer=None, use_embedding_model_tokenizer=False):
def add_documents(self, documents, doc_ids=None, tokenizer=None, use_embedding_model_tokenizer=False, batch_size_embed=500):
"""
Update the model with new documents.

Expand Down Expand Up @@ -1159,6 +1162,9 @@ def add_documents(self, documents, doc_ids=None, tokenizer=None, use_embedding_m
use_embedding_model_tokenizer: bool (Optional, default False)
If using an embedding model other than doc2vec, use the model's
tokenizer for document embedding.

batch_size_embed: int (Optional, default 500)
The batch size used when embedding documents.
"""
# if tokenizer is not passed use default
if tokenizer is None:
Expand Down Expand Up @@ -1205,7 +1211,7 @@ def add_documents(self, documents, doc_ids=None, tokenizer=None, use_embedding_m
else:
docs_processed = [tokenizer(doc) for doc in documents]
docs_training = [' '.join(doc) for doc in docs_processed]
document_vectors = self._embed_documents(docs_training)
document_vectors = self._embed_documents(docs_training, batch_size_embed)
self._set_document_vectors(np.vstack([self._get_document_vectors(), document_vectors]))

# update index
Expand Down