Skip to content

Commit a6ef573

Browse files
committed
fix(cosine_sim): division by zero
1 parent fee19e9 commit a6ef573

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

neuralnetlib/preprocessing.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from time import time_ns
6-
from enum import Enum, auto
6+
from enum import Enum
77
from collections import defaultdict
88
from collections.abc import Generator
99

@@ -243,6 +243,8 @@ def cosine_similarity(vector1: np.ndarray, vector2: np.ndarray) -> float:
243243
dot_product = np.dot(vector1, vector2)
244244
norm_vector1 = np.linalg.norm(vector1)
245245
norm_vector2 = np.linalg.norm(vector2)
246+
if norm_vector1 == 0 or norm_vector2 == 0:
247+
return 0.0
246248
similarity = dot_product / (norm_vector1 * norm_vector2)
247249
return similarity
248250

@@ -749,15 +751,10 @@ def get_vocabulary(self) -> dict:
749751
return dict(sorted(self.vocabulary_.items(), key=lambda x: x[1]))
750752

751753

752-
class TokenType(Enum):
753-
CHAR = auto()
754-
WORD = auto()
755-
756-
757754
class NGram:
758755
def __init__(self,
759756
n: int = 3,
760-
token_type: TokenType = TokenType.CHAR,
757+
token_type: str = "char",
761758
start_token: str = '$',
762759
end_token: str = '^',
763760
separator: str = ' '):
@@ -771,12 +768,12 @@ def __init__(self,
771768
self.transitions = defaultdict(list)
772769

773770
def _tokenize(self, text: str) -> list[str]:
774-
if self.token_type == TokenType.CHAR:
771+
if self.token_type == "char":
775772
return list(text)
776773
return text.split(self.separator)
777774

778775
def _join_tokens(self, tokens: list[str]) -> str:
779-
if self.token_type == TokenType.CHAR:
776+
if self.token_type == "char":
780777
return ''.join(tokens)
781778
return self.separator.join(tokens)
782779

@@ -810,7 +807,7 @@ def fit(self, sequences: list[str]) -> "NGram":
810807
return self
811808

812809
def _get_random_start(self) -> list[str]:
813-
if self.token_type == TokenType.CHAR:
810+
if self.token_type == "char":
814811
return [self.start_token] * (self.n - 1)
815812

816813
start_contexts = [
@@ -840,7 +837,7 @@ def generate_sequence(self, min_length: int = 5, max_length: int = None, variabi
840837
context = tuple(current[-(self.n - 1):])
841838

842839
if context not in self.ngrams:
843-
if (self.token_type == TokenType.WORD and
840+
if (self.token_type == "word" and
844841
current[-1] in self.transitions):
845842
next_token = random.choice(self.transitions[current[-1]])
846843
current.append(next_token)
@@ -855,15 +852,15 @@ def generate_sequence(self, min_length: int = 5, max_length: int = None, variabi
855852
if len(sequence) >= min_length:
856853
if max_length is None or len(sequence) <= max_length:
857854
result = self._join_tokens(sequence)
858-
if self.token_type == TokenType.WORD:
855+
if self.token_type == "word":
859856
result = result.capitalize()
860857
return result
861858
break
862859

863860
if max_length and len(current) - (self.n - 1) > max_length:
864861
break
865862

866-
if (self.token_type == TokenType.WORD and
863+
if (self.token_type == "word" and
867864
random.random() < variability and
868865
current[-1] in self.transitions):
869866
next_token = random.choice(self.transitions[current[-1]])
@@ -894,9 +891,6 @@ def get_contexts(self) -> dict:
894891
return dict(self.ngrams)
895892

896893

897-
import numpy as np
898-
from time import time_ns
899-
900894
class ImageDataGenerator:
901895
def __init__(
902896
self,

0 commit comments

Comments
 (0)