33import numpy as np
44
55from time import time_ns
6- from enum import Enum , auto
6+ from enum import Enum
77from collections import defaultdict
88from collections .abc import Generator
99
def cosine_similarity(vector1: np.ndarray, vector2: np.ndarray) -> float:
    """Return the cosine similarity between two 1-D vectors.

    Parameters
    ----------
    vector1, vector2 : np.ndarray
        Input vectors of equal length.

    Returns
    -------
    float
        Cosine of the angle between the vectors, in [-1, 1].
        Returns 0.0 when either vector has zero magnitude, since the
        similarity is undefined (division by zero) in that case.
    """
    dot_product = np.dot(vector1, vector2)
    norm_vector1 = np.linalg.norm(vector1)
    norm_vector2 = np.linalg.norm(vector2)
    # Guard added by this change: avoid division by zero for
    # zero-magnitude vectors.
    if norm_vector1 == 0 or norm_vector2 == 0:
        return 0.0
    similarity = dot_product / (norm_vector1 * norm_vector2)
    return similarity
248250
def get_vocabulary(self) -> dict:
    """Return the fitted vocabulary as a dict ordered by token index.

    Returns
    -------
    dict
        Mapping of token -> index, with items sorted ascending by the
        index value (``x[1]``) so iteration order follows index order.
    """
    # assumes self.vocabulary_ maps token -> int index, set during fit —
    # TODO confirm against the (not visible here) fit method.
    return dict(sorted(self.vocabulary_.items(), key=lambda x: x[1]))
750752
751753
752- class TokenType (Enum ):
753- CHAR = auto ()
754- WORD = auto ()
755-
756-
757754class NGram :
758755 def __init__ (self ,
759756 n : int = 3 ,
760- token_type : TokenType = TokenType . CHAR ,
757+ token_type : str = "char" ,
761758 start_token : str = '$' ,
762759 end_token : str = '^' ,
763760 separator : str = ' ' ):
@@ -771,12 +768,12 @@ def __init__(self,
771768 self .transitions = defaultdict (list )
772769
def _tokenize(self, text: str) -> list[str]:
    """Split *text* into model tokens.

    In ``"char"`` mode every character is a token; otherwise the text is
    split on ``self.separator`` (word mode).
    """
    if self.token_type == "char":
        return list(text)
    return text.split(self.separator)
777774
def _join_tokens(self, tokens: list[str]) -> str:
    """Inverse of ``_tokenize``: reassemble *tokens* into a string.

    Characters are concatenated directly in ``"char"`` mode; otherwise
    tokens are joined with ``self.separator`` (word mode).
    """
    if self.token_type == "char":
        return ''.join(tokens)
    return self.separator.join(tokens)
782779
@@ -810,7 +807,7 @@ def fit(self, sequences: list[str]) -> "NGram":
810807 return self
811808
812809 def _get_random_start (self ) -> list [str ]:
813- if self .token_type == TokenType . CHAR :
810+ if self .token_type == "char" :
814811 return [self .start_token ] * (self .n - 1 )
815812
816813 start_contexts = [
@@ -840,7 +837,7 @@ def generate_sequence(self, min_length: int = 5, max_length: int = None, variabi
840837 context = tuple (current [- (self .n - 1 ):])
841838
842839 if context not in self .ngrams :
843- if (self .token_type == TokenType . WORD and
840+ if (self .token_type == "word" and
844841 current [- 1 ] in self .transitions ):
845842 next_token = random .choice (self .transitions [current [- 1 ]])
846843 current .append (next_token )
@@ -855,15 +852,15 @@ def generate_sequence(self, min_length: int = 5, max_length: int = None, variabi
855852 if len (sequence ) >= min_length :
856853 if max_length is None or len (sequence ) <= max_length :
857854 result = self ._join_tokens (sequence )
858- if self .token_type == TokenType . WORD :
855+ if self .token_type == "word" :
859856 result = result .capitalize ()
860857 return result
861858 break
862859
863860 if max_length and len (current ) - (self .n - 1 ) > max_length :
864861 break
865862
866- if (self .token_type == TokenType . WORD and
863+ if (self .token_type == "word" and
867864 random .random () < variability and
868865 current [- 1 ] in self .transitions ):
869866 next_token = random .choice (self .transitions [current [- 1 ]])
def get_contexts(self) -> dict:
    """Return the learned n-gram context table as a plain dict.

    Returns
    -------
    dict
        A shallow copy of ``self.ngrams`` (a defaultdict populated during
        fitting), converted to a regular dict so callers cannot trigger
        defaultdict auto-insertion on lookup.
    """
    return dict(self.ngrams)
895892
896893
897- import numpy as np
898- from time import time_ns
899-
900894class ImageDataGenerator :
901895 def __init__ (
902896 self ,
0 commit comments