Merged

Changes from 8 commits
57 changes: 57 additions & 0 deletions lib/tokenizers/decode_stream.ex
@@ -0,0 +1,57 @@
defmodule Tokenizers.DecodeStream do
  @moduledoc """
  Implements streaming decoding functionality for tokenizers.
  """

  @enforce_keys [:resource]
  defstruct [:resource]

  @type t :: %__MODULE__{
          resource: reference()
        }

  @doc """
  Creates a new decode stream.

  The `skip_special_tokens` option determines whether special tokens should be skipped during decoding.
  By default, it is set to `false`.
  """
  @spec new(boolean()) :: t()
  def new(skip_special_tokens \\ false) do
    Tokenizers.Native.decoder_stream_new(skip_special_tokens)
  end

  @doc """
  Steps through the decode stream with the given tokenizer and token ID.

  Returns `{:ok, string}` if there is a decoded string, or `{:ok, nil}` if there is nothing more to decode.
Contributor

I would likely do `:done` for when there is nothing more to decode.

Contributor Author

@Virviil Virviil Apr 19, 2025

There is a problem here. While it looks like a stream process, it's actually not: the stream is never exhausted. One can call `step(12345678)` and then call `step(0)`; both are valid from the original library's perspective. One can also call `step(n)` multiple times (this is not even idempotent).

https://docs.rs/tokenizers/0.21.1/tokenizers/tokenizer/struct.DecodeStream.html

use tokenizers::Tokenizer;
let tokenizer = Tokenizer::from_file("data/roberta.json").unwrap();

let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(713).unwrap(), Some("This".to_string()));
assert_eq!(decode_stream.step(16).unwrap(), Some(" is".to_string()));
assert_eq!(decode_stream.step(41).unwrap(), Some(" an".to_string()));
assert_eq!(
    decode_stream.step(1246).unwrap(),
    Some(" example".to_string())
);

From this perspective, `:done` looks inappropriate: it carries the semantics of a "last action", which this is not.

Contributor

Maybe `:out_of_range`?

Contributor Author

Well,

`Enum.at/2` returns `nil` for out of range.
Erlang's libs (like `array`) raise `ArgumentError`.

I've never seen `:out_of_range` anywhere across the ecosystem, so it may not work well.

Contributor

`Enum.at` does not wrap it in a tuple, though. `Enum.fetch`, which returns `{:ok, _}` or `:error`, is a better example, but it has only two states. My understanding is that this code has three states:

  1. In range
  2. Out of range
  3. Error when decoding

So you need three states accordingly. I don't mind what we call it, but I'd say the three states should be distinct.

Contributor Author

Switched to `:out_of_range`

  Returns `{:error, reason}` if an error occurs during decoding.
  """
  def step(%__MODULE__{} = decode_stream, tokenizer, id) when is_integer(id) do
    case Tokenizers.Native.decoder_stream_step(decode_stream, tokenizer, id) do
      {:ok, result} -> {:ok, result}
      {:error, reason} -> {:error, reason}
    end
Contributor

Suggested change

-    case Tokenizers.Native.decoder_stream_step(decode_stream, tokenizer, id) do
-      {:ok, result} -> {:ok, result}
-      {:error, reason} -> {:error, reason}
-    end
+    Tokenizers.Native.decoder_stream_step(decode_stream, tokenizer, id)

Contributor

Is it actually efficient to go step after step, passing specific indexes? Or does the upstream code work better by calling something like `.next`? Is this something we should benchmark first?

Contributor Author

In this PR I translated the API as is. Some abstraction can definitely be written on top of it using `Stream`, but that seems to be part of a separate effort.

Decoding large chunks this way looks better for the BEAM because control is given back from the NIF between steps. It would be possible to write the decoding without a dirty scheduler, but I'm not sure anybody will invest in it.
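A `Stream`-based abstraction on top of `step/3` could be sketched roughly as follows. This is a hypothetical helper, not part of this PR; the module name `DecodeStreamHelper` and the choice of `Stream.transform/3` are assumptions:

```elixir
# Hypothetical wrapper, not part of this PR: lazily folds token IDs through
# the stateful decode stream and emits decoded fragments as they appear.
defmodule DecodeStreamHelper do
  def stream_tokens(tokenizer, ids) do
    ds = Tokenizers.DecodeStream.new()

    Stream.transform(ids, ds, fn id, ds ->
      case Tokenizers.DecodeStream.step(ds, tokenizer, id) do
        # No output yet for this token (e.g. partial multi-token grapheme)
        {:ok, nil} -> {[], ds}
        # A decoded fragment is available
        {:ok, chunk} -> {[chunk], ds}
        {:error, reason} -> raise "decode failed: #{inspect(reason)}"
      end
    end)
  end
end
```

Since the underlying NIF resource is mutated in place, the accumulator only threads the same struct through; the laziness of `Stream.transform/3` is what yields control back to the BEAM between steps.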

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it! Thank you!

  end

  @doc """
  Returns information about the decode stream state.
  """
  def info(%__MODULE__{} = decode_stream) do
    Tokenizers.Native.decoder_stream_info(decode_stream)
  end

  defimpl Inspect do
    def inspect(decode_stream, _opts) do
      info = Tokenizers.DecodeStream.info(decode_stream)
      "#Tokenizers.DecodeStream<#{inspect(info)}>"
    end
  end

  defimpl String.Chars do
    def to_string(decode_stream) do
      info = Tokenizers.DecodeStream.info(decode_stream)
      "#Tokenizers.DecodeStream<#{inspect(info)}>"
    end
  end
end
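Translating the upstream Rust doctest to the API added in this diff gives roughly the following sketch. The tokenizer file `data/roberta.json` and the token IDs are carried over from the Rust example above; the `Tokenizers.Tokenizer.from_file/1` call is an assumption about the existing Elixir API:

```elixir
# Sketch mirroring the upstream Rust doctest; the model file and token IDs
# are taken from that example and must exist locally for this to run.
{:ok, tokenizer} = Tokenizers.Tokenizer.from_file("data/roberta.json")

stream = Tokenizers.DecodeStream.new(false)

{:ok, "This"} = Tokenizers.DecodeStream.step(stream, tokenizer, 713)
{:ok, " is"} = Tokenizers.DecodeStream.step(stream, tokenizer, 16)
{:ok, " an"} = Tokenizers.DecodeStream.step(stream, tokenizer, 41)
{:ok, " example"} = Tokenizers.DecodeStream.step(stream, tokenizer, 1246)
```

Note that the stream struct itself is returned unchanged; the decoding state lives in the NIF resource, which is why repeated `step/3` calls on the same struct advance the decode.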
7 changes: 7 additions & 0 deletions lib/tokenizers/native.ex
Expand Up @@ -33,6 +33,13 @@ defmodule Tokenizers.Native do
  def decoders_ctc(_options), do: err()
  def decoders_sequence(_decoders), do: err()

  # DecoderStream
  def decoder_stream_step(_decoder_stream, _tokenizer, _id), do: err()
  def decoder_stream_info(_decoder_stream), do: err()
  def decoder_stream_new(_skip_special_tokens), do: err()

  # Encoding
  def encoding_get_length(_encoding), do: err()
  def encoding_get_n_sequences(_encoding), do: err()