class TokenizerType(Enum):
    """Enumeration of the tokenizer backends supported by get_tokenizer()."""

    LLAMA2C = "llama2c"
    SENTENCEPIECE = "sentencepiece"
    TIKTOKEN = "tiktoken"
    HUGGINGFACE = "huggingface"

    @classmethod
    def from_str(cls, value: str) -> "TokenizerType":
        """Create TokenizerType from string value (case-insensitive).

        Raises:
            ValueError: if ``value`` does not name a known tokenizer type.
        """
        try:
            # Enum lookup-by-value replaces the original manual member scan.
            return cls(value.lower())
        except ValueError:
            raise ValueError(
                f"Invalid tokenizer type: {value}. Valid options: {[t.value for t in cls]}"
            ) from None


def get_tokenizer(
    tokenizer_path: str,
    tokenizer_config_path: Optional[str] = None,
    tokenizer_type: Optional[TokenizerType] = None,
):
    """Construct a tokenizer for the model file at ``tokenizer_path``.

    Args:
        tokenizer_path: path to the tokenizer model/artifact file.
        tokenizer_config_path: optional config path, used only by the
            HuggingFace tokenizer.
        tokenizer_type: when given, selects the backend explicitly;
            otherwise the backend is auto-detected by file extension and
            trial construction.

    Returns:
        One of HuggingFaceTokenizer, Llama2cTokenizer,
        SentencePieceTokenizer, or TiktokenTokenizer.
    """
    if tokenizer_type is not None:
        if tokenizer_type == TokenizerType.HUGGINGFACE:
            return HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
        if tokenizer_type == TokenizerType.LLAMA2C:
            return Llama2cTokenizer(model_path=str(tokenizer_path))
        if tokenizer_type == TokenizerType.SENTENCEPIECE:
            return SentencePieceTokenizer(model_path=str(tokenizer_path))
        if tokenizer_type == TokenizerType.TIKTOKEN:
            return TiktokenTokenizer(model_path=str(tokenizer_path))
        # Defensive: a future enum member without a branch above would
        # otherwise fall through silently into auto-detection.
        raise ValueError(f"Unhandled tokenizer type: {tokenizer_type}")

    # Default fallback to auto-detection.
    if tokenizer_path.endswith(".json"):
        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
    else:
        try:
            tokenizer = Llama2cTokenizer(model_path=str(tokenizer_path))
        except Exception:
            try:
                tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
                # Print only after construction succeeds, so the message is
                # not emitted when we actually end up on the Tiktoken path.
                print("Using SentencePiece tokenizer")
            except Exception:
                print("Using Tiktokenizer")
                tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path))
    return tokenizer
class SentencePieceTokenizer:
    """Thin wrapper around ``sentencepiece.SentencePieceProcessor``.

    Exposes the same encode/decode surface as the sibling tokenizers in this
    package (Llama2c, Tiktoken, HuggingFace) so callers can switch backends.
    """

    def __init__(self, model_path: str):
        """Load a trained SentencePiece model from ``model_path``.

        Raises:
            FileNotFoundError: if ``model_path`` is not an existing file.
        """
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would silently skip the validation.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f"Need a valid tokenizer model path but got {model_path}"
            )
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # Cached model properties used by encode().
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"SentencePiece - #words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode ``s`` into token ids, optionally framing with BOS/EOS.

        Raises:
            TypeError: if ``s`` is not a str (explicit raise rather than
                `assert`, which disappears under -O).
        """
        if not isinstance(s, str):
            raise TypeError(f"expected str, got {type(s).__name__}")
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode a list of token ids back into text."""
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
        """Return the raw piece string for a single token id.

        NOTE(review): id_to_piece returns the untransformed piece, which may
        carry the SentencePiece whitespace marker ('▁') — presumably intended
        for streaming decode; confirm callers expect pieces, not plain text.
        """
        return self.sp_model.id_to_piece(t)