134 changes: 100 additions & 34 deletions flair/embeddings/legacy.py
@@ -13,6 +13,19 @@
from flair.file_utils import cached_path
from flair.nn import LockedDropout, WordDropout

log = logging.getLogger("flair")

# NEW: Import diskcache instead of sqlitedict (note: the logger must be
# defined before this block, since the except branch uses it)
try:
    from diskcache import Cache
except ImportError:
    log.warning("-" * 100)
    log.warning("ATTENTION! The library 'diskcache' is not installed!")
    log.warning('Flair embeddings cache needs diskcache. Please install with "pip install diskcache"')
    log.warning("-" * 100)
    Cache = None  # set Cache to None so the code below can detect the missing dependency
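For reviewers unfamiliar with diskcache, a minimal sketch of the three calls this patch relies on (the cache path here is made up):

```python
from diskcache import Cache

cache = Cache("/tmp/flair-demo-cache")  # opens (or creates) an on-disk cache directory
cache.set("hello world", [[0.1, 0.2]])  # values are pickled transparently
assert cache.get("hello world") == [[0.1, 0.2]]
assert cache.get("missing key") is None  # get() returns None on a miss
cache.close()
```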


@@ -162,6 +175,9 @@ def __setstate__(self, state):
class CharLMEmbeddings(TokenEmbeddings):
    """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""

    # expose chars_per_chunk as a class attribute so the language model can
    # process long inputs in fixed-size chunks if needed
    chars_per_chunk: int = 512
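For intuition, a chars_per_chunk-style split works like the sketch below (illustrative only, not the language model's actual chunking code):

```python
# Long inputs are processed in fixed-size character chunks to bound memory
# use; 512 matches the class attribute above.
text = "x" * 1300
chunks = [text[i : i + 512] for i in range(0, len(text), 512)]
print([len(c) for c in chunks])  # [512, 512, 276]
```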

    @deprecated(version="0.4", reason="Use 'FlairEmbeddings' instead.")
    def __init__(
        self,
@@ -353,29 +369,49 @@ def __init__(

        from flair.models import LanguageModel

        self.instance_parameters = self.get_instance_parameters(locals=locals())  # keep instance params for serialization

        self.lm = LanguageModel.load_language_model(model)
        self.detach = detach

        self.is_forward_lm: bool = self.lm.is_forward_lm

-       # initialize cache if use_cache set
        # --- MODIFIED: initialize diskcache instead of sqlitedict ---
        self.cache = None
-       if use_cache:
-           cache_path = (
-               Path(f"{self.name}-tmp-cache.sqllite")
-               if not cache_directory
-               else cache_directory / f"{self.name}-tmp-cache.sqllite"
-           )
-           from sqlitedict import SqliteDict
-
-           self.cache = SqliteDict(str(cache_path), autocommit=True)
        if use_cache and Cache is not None:
            # derive a filesystem-safe directory name from the model name
            model_name_slug = re.sub(r"[^a-zA-Z0-9_-]", "-", Path(model).stem)
            if cache_directory is None:
                # default cache location within flair's cache root
                cache_path = flair.cache_root / "embeddings_cache" / model_name_slug
            else:
                # user-provided directory plus a model-specific sub-directory
                cache_path = Path(cache_directory) / model_name_slug

            log.info(f"Cache storage path is {cache_path}")
            # make sure the directory exists
            cache_path.mkdir(parents=True, exist_ok=True)
            # initialize diskcache.Cache (a size limit etc. can be passed here)
            self.cache = Cache(str(cache_path))
        elif use_cache and Cache is None:
            log.warning("diskcache is not installed, caching is disabled.")
            use_cache = False  # explicitly disable caching if the library is missing
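As a sanity check on the slug logic above, a quick sketch of the directory name a typical model path would produce (the path is made up):

```python
import re
from pathlib import Path

# every character outside [a-zA-Z0-9_-] in the model file's stem becomes "-"
model = "/home/user/models/news-forward-0.4.1.pt"
slug = re.sub(r"[^a-zA-Z0-9_-]", "-", Path(model).stem)
print(slug)  # news-forward-0-4-1
```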

        # embed a dummy sentence to determine embedding_length
        dummy_sentence: Sentence = Sentence(["hello"])
-       embedded_dummy = self.embed(dummy_sentence)
-       self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())
        # ensure the correct gradient context
        with torch.no_grad() if self.detach else torch.enable_grad():
            embedded_dummy = self.embed(dummy_sentence)
        # check that the embedding succeeded before reading its length
        if len(embedded_dummy) > 0 and len(embedded_dummy[0]) > 0 and embedded_dummy[0].get_token(1) is not None:
            self.__embedding_length: int = len(embedded_dummy[0].get_token(1).get_embedding())
        else:
            # fall back if the dummy embedding failed; the LM hidden size is a
            # less reliable proxy for the final embedding length
            log.warning("Could not determine embedding length from dummy sentence.")
            self.__embedding_length: int = self.lm.hidden_size

        # set to eval mode
        self.eval()

    def train(self, mode=True):
@@ -395,23 +431,40 @@ def embedding_length(self) -> int:
        return self.__embedding_length

    def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
-       # if cache is used, try setting embeddings from cache first
-       if "cache" in self.__dict__ and self.cache is not None:
-           # try populating embeddings from cache
-           all_embeddings_retrieved_from_cache: bool = True
-           for sentence in sentences:
-               key = sentence.to_tokenized_string()
-               embeddings = self.cache.get(key)
-
-               if not embeddings:
-                   all_embeddings_retrieved_from_cache = False
-                   break
-               else:
-                   for token, embedding in zip(sentence, embeddings):
-                       token.set_embedding(self.name, torch.FloatTensor(embedding))
-
-           if all_embeddings_retrieved_from_cache:
-               return sentences
        # ensure the LM is on the correct device (important if the device changed after init)
        self.lm.to(flair.device)

        # --- MODIFIED: use the diskcache API ---
        # proceed only if caching is enabled and the library is available
        if self.cache is not None:
            all_embeddings_retrieved_from_cache: bool = True
            try:
                for sentence in sentences:
                    key = sentence.to_tokenized_string()
                    # cache.get() returns None if the key is not found
                    cached_embeddings = self.cache.get(key)

                    if cached_embeddings is None:  # check for None explicitly
                        all_embeddings_retrieved_from_cache = False
                        break  # no need to check further sentences
                    else:
                        # ensure we have the correct number of embeddings
                        if len(cached_embeddings) == len(sentence.tokens):
                            for token, embedding_list in zip(sentence, cached_embeddings):
                                # convert the stored list back into a tensor
                                token.set_embedding(self.name, torch.FloatTensor(embedding_list).to(flair.device))
                        else:
                            # data corruption or mismatch, treat as a cache miss
                            log.warning(f"Cache data mismatch for sentence: '{key}'. Recomputing.")
                            all_embeddings_retrieved_from_cache = False
                            break

                if all_embeddings_retrieved_from_cache:
                    log.debug("Retrieved embeddings from cache.")
                    return sentences
            except Exception as e:
                log.error(f"Error accessing diskcache: {e}. Disabling cache for this batch.")
                all_embeddings_retrieved_from_cache = False  # force recompute
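Because the cache stores plain float lists rather than tensors, the hit path above has to rebuild tensors on the way out; a small sketch of that round trip:

```python
import torch

embedding = torch.randn(2048)
as_lists = embedding.cpu().tolist()     # what the write path stores (picklable, device-independent)
restored = torch.FloatTensor(as_lists)  # what the cache-hit path rebuilds
assert torch.equal(embedding.cpu(), restored)  # exact round trip for float32
```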

        # if this is not possible, use the LM to generate embeddings. First, get the sentence texts
        text_sentences = [sentence.to_tokenized_string() for sentence in sentences]
@@ -446,11 +499,24 @@ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:

                token.set_embedding(self.name, embedding)

if "cache" in self.__dict__ and self.cache is not None:
for sentence in sentences:
self.cache[sentence.to_tokenized_string()] = [
token._embeddings[self.name].tolist() for token in sentence
]
# --- MODIFIED: Store in diskcache ---
# Check if caching is enabled and library is available
if self.cache is not None:
try:
for sentence in sentences:
# Check if all tokens received embeddings before caching
if all(self.name in token._embeddings for token in sentence.tokens):
key = sentence.to_tokenized_string()
# Store list of lists of floats (move tensor to CPU before tolist)
value = [token.get_embedding().cpu().tolist() for token in sentence]
self.cache.set(key, value) # Use cache.set()
else:
log.warning(f"Skipping caching for sentence with missing embeddings: '{sentence.to_tokenized_string()[:50]}...'")
log.debug("Stored embeddings in cache.")
except Exception as e:
log.error(f"Error writing to diskcache: {e}.")
# Optionally: self.cache.close(), self.cache = None to disable further cache attempts
# --- End Store Modification ---

return sentences
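A hypothetical end-to-end sketch of the cached path (model name and cache directory are illustrative; note that CharLMEmbeddings itself is deprecated in favor of FlairEmbeddings):

```python
from flair.data import Sentence

embeddings = CharLMEmbeddings("news-forward", use_cache=True, cache_directory="/tmp/flair-cache")

embeddings.embed(Sentence("I love Berlin"))  # cache miss: computed by the LM, then stored
embeddings.embed(Sentence("I love Berlin"))  # cache hit: rebuilt from the stored float lists
```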

2 changes: 1 addition & 1 deletion requirements.txt
@@ -15,7 +15,7 @@ pytorch_revgrad>=0.2.0
regex>=2022.1.18
scikit-learn>=1.0.2
segtok>=1.5.11
-sqlitedict>=2.0.0
diskcache>=5.6.0
tabulate>=0.8.10
torch>=1.13.1
tqdm>=4.63.0