Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions examples/run_streaming_phi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Description: phi_streaming_llm_v1.py This script demonstrates how to enable streaming LLM for a multimodal model.
import warnings
warnings.filterwarnings("ignore")

import torch
import argparse
import os
from streaming_llm.utils import load, download_url, load_jsonl
from enum import Enum
from transformers import AutoTokenizer

from streaming_llm.enable_streaming_llm import enable_streaming_llm

class InputMode(Enum):
LANGUAGE = 0
VISION = 1
SPEECH = 2
VISION_SPEECH = 3

@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
input_mode = torch.tensor([InputMode.LANGUAGE.value], dtype=torch.long, device=model.device)
outputs = model(
input_ids=input_ids,
past_key_values=past_key_values,
use_cache=True,
input_mode=input_mode,
num_logits_to_keep=1 # Only compute last token's logits
)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
pos = 0
for _ in range(max_gen_len - 1):
outputs = model(
input_ids=pred_token_idx,
past_key_values=past_key_values,
use_cache=True,
input_mode=input_mode,
num_logits_to_keep=1
)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids.append(pred_token_idx.item())
generated_text = (
tokenizer.decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
spaces_between_special_tokens=False,
)
.strip()
.split(" ")
)

now = len(generated_text) - 1
if now > pos:
print(" ".join(generated_text[pos:now]), end=" ", flush=True)
pos = now

if pred_token_idx == tokenizer.eos_token_id:
break
print(" ".join(generated_text[pos:]), flush=True)
return past_key_values


@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
past_key_values = None
for idx, prompt in enumerate(prompts):
prompt = "USER: " + prompt + "\n\nASSISTANT: "
print("\n" + prompt, end="")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
seq_len = input_ids.shape[1]
if kv_cache is not None:
space_needed = seq_len + max_gen_len
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)

past_key_values = greedy_generate(
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
)

def main(args):
model_name_or_path = args.model_name_or_path
model, tokenizer = load(model_name_or_path)
test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
print(f"Loading data from {test_filepath} ...")

if not os.path.exists(test_filepath):
download_url(
"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
args.data_root,
)
os.rename(os.path.join(args.data_root, "question.jsonl"), test_filepath)

list_data = load_jsonl(test_filepath)
prompts = []
for sample in list_data:
prompts += sample["turns"]

if args.enable_streaming:

print("---------------------------")
print("Enabling streaming LLM ...")
print("---------------------------")
kv_cache = enable_streaming_llm(
model, start_size=args.start_size, recent_size=args.recent_size
)
else:
kv_cache = None

streaming_inference(
model,
tokenizer,
prompts,
kv_cache,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name_or_path",
type=str,
default="microsoft/Phi-4-multimodal-instruct"
)
parser.add_argument("--data_root", type=str, default="data/")
parser.add_argument("--enable_streaming",default=True, action="store_true")
parser.add_argument("--start_size", type=int, default=200)
parser.add_argument("--recent_size", type=int, default=800)
args = parser.parse_args()

main(args)
Empty file added phi4mm/__init__.py
Empty file.
235 changes: 235 additions & 0 deletions phi4mm/configuration_phi4mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Phi-4-MM model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Phi4MMConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Phi4MMModel`]. It is used to instantiate a Phi-4-MM
model according to the specified arguments, defining the model architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 200064):
Vocabulary size of the Phi-4-MM model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Phi4MMModel`].
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
resid_pdrop (`float`, *optional*, defaults to 0.0):
Dropout probability for mlp outputs.
embd_pdrop (`int`, *optional*, defaults to 0.0):
The dropout ratio for the embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio after computing the attention scores.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
original_max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model was trained with. This is used to determine the size of the
original RoPE embeddings when using long scaling.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon value used for the RMSNorm.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
divided by the number of attention heads divided by 2.
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
Percentage of the query and keys which will have rotary embedding.
bos_token_id (`int`, *optional*, defaults to 199999):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 199999):
The id of the "end-of-sequence" token.
pad_token_id (`int`, *optional*, defaults to 199999):
The id of the padding token.
sliding_window (`int`, *optional*):
Sliding window attention window size. If `None`, no sliding window is applied.

Example:

```python
>>> from transformers import Phi4MMModel, Phi4MMConfig

>>> # Initializing a Phi-4-MM style configuration
>>> configuration = Phi4MMConfig.from_pretrained("TBA")

>>> # Initializing a model from the configuration
>>> model = Phi4MMModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "phi4mm"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=200064,
hidden_size=3072,
intermediate_size=8192,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
resid_pdrop=0.0,
embd_pdrop=0.0,
attention_dropout=0.0,
hidden_act="silu",
max_position_embeddings=4096,
original_max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=1,
bos_token_id=199999,
eos_token_id=199999,
pad_token_id=199999,
sliding_window=None,
embd_layer: str = "default",
img_processor=None,
audio_processor=None,
vision_lora=None,
speech_lora=None,
**kwargs,
):
self.embd_layer = embd_layer
self.img_processor = img_processor
self.audio_processor = audio_processor
self.vision_lora = vision_lora
self.speech_lora = speech_lora

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads

if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self._rope_scaling_adjustment()
self._rope_scaling_validation()
self.sliding_window = sliding_window

super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

def _rope_scaling_adjustment(self):
"""
Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
"""
if self.rope_scaling is None:
return

rope_scaling_type = self.rope_scaling.get("type", None)

# For backward compatibility if previous version used "su" or "yarn"
if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
self.rope_scaling["type"] = "longrope"

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
raise ValueError(
"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
if not len(rope_scaling_short_factor) == rotary_ndims // 2:
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if not len(rope_scaling_long_factor) == rotary_ndims // 2:
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
)
Loading