diff --git a/scratchpad/managers/toppings_manager.py b/scratchpad/managers/toppings_manager.py
index b5ec094..fe330cc 100644
--- a/scratchpad/managers/toppings_manager.py
+++ b/scratchpad/managers/toppings_manager.py
@@ -313,6 +313,7 @@ def prepare_topping_batch(self, forward_batch: ForwardBatch):
                     weight_indices,
                     lora_buffer=None,
                     delta_buffer=self.weights_buffer["lm_head"][:len_deltas],
+                    num_lora=len_loras,
                 )
             elif "embed_tokens" in module_name:
                 module.set_topping_info(
@@ -320,6 +321,7 @@ def prepare_topping_batch(self, forward_batch: ForwardBatch):
                     weight_indices,
                     lora_buffer=None,
                     delta_buffer=self.weights_buffer["embed_tokens"][:len_deltas],
+                    num_lora=len_loras,
                 )
             elif "qkv_proj" in module_name:
                 module.set_topping_info(
diff --git a/scratchpad/nn/toppings/topping_layer.py b/scratchpad/nn/toppings/topping_layer.py
index a49d2b2..2496ff9 100644
--- a/scratchpad/nn/toppings/topping_layer.py
+++ b/scratchpad/nn/toppings/topping_layer.py
@@ -44,9 +44,10 @@ def __init__(self, base_layer: VocabParallelEmbedding, config: Dict) -> None:
         super().__init__(base_layer, config)
         self.weight = base_layer.weight

-    def set_topping_info(self, bs, weight_indices, lora_buffer=None, delta_buffer=None):
+    def set_topping_info(self, bs, weight_indices, lora_buffer=None, delta_buffer=None, num_lora=0):
         self.weight_indices = weight_indices
         self.delta_weights = delta_buffer
+        self.num_lora = num_lora

     def forward(self, input_: torch.Tensor):
         if self.delta_weights == None or self.delta_weights.shape[0] == 0:
@@ -58,6 +59,10 @@ def forward(self, input_: torch.Tensor):
             dtype=self.delta_weights[0].dtype,
         )
         unique_indices = torch.unique(self.weight_indices)
+        # recover delta indices
+        mask_delta = (unique_indices >= self.num_lora) & (unique_indices != -1)
+        unique_indices = unique_indices[mask_delta] - self.num_lora
+
         for id in unique_indices:
             idx_mask = self.weight_indices == id
@@ -298,9 +303,10 @@ class LogitsProcessorWithTopping(BaseLayerWithTopping):
     def __init__(self, base_layer, config):
         super().__init__(base_layer, config)

-    def set_topping_info(self, bs, weight_indices, lora_buffer=None, delta_buffer=None):
+    def set_topping_info(self, bs, weight_indices, lora_buffer=None, delta_buffer=None, num_lora=0):
         self.weight_indices = weight_indices
         self.delta_buffer = delta_buffer
+        self.num_lora = num_lora

     def _get_logits(
         self,
@@ -317,6 +323,11 @@ def _get_logits(
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
         assert isinstance(logits_metadata, LogitsMetadata)

+        unique_indices = torch.unique(self.weight_indices)
+        # recover delta indices
+        mask_delta = (unique_indices >= self.num_lora) & (unique_indices != -1)
+        unique_indices = unique_indices[mask_delta] - self.num_lora
+
         # Get the last hidden states and last logits for the next token prediction
         if logits_metadata.forward_mode.is_decode():
             last_index = None
@@ -327,7 +338,6 @@ def _get_logits(
                 dtype=last_hidden.dtype,
                 device=last_hidden.device,
             )
-            unique_indices = torch.unique(self.weight_indices)
             for id in unique_indices:
                 idx_mask = self.weight_indices == id
                 inp = last_hidden[idx_mask]
@@ -346,7 +356,6 @@ def _get_logits(
                 dtype=last_hidden.dtype,
                 device=last_hidden.device,
            )
-            unique_indices = torch.unique(self.weight_indices)
             assert len(unique_indices) == 1, f"Prefill stage only supports one index"
             w_idx = unique_indices[0]
             if w_idx == -1:
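
Note on the change above: judging from the masks added here, set_topping_info() receives a single weight_indices tensor in which LoRA adapters occupy slots [0, num_lora), delta adapters are stored at num_lora + delta_index, and -1 marks requests served by the base model; the new mask recovers the row positions inside delta_buffer. A minimal standalone sketch of that recovery (tensor values are illustrative, not taken from the codebase):

    import torch

    num_lora = 2
    # base model (-1), two LoRA requests (0, 1), two delta requests (2, 3)
    weight_indices = torch.tensor([-1, 0, 1, 2, 3])

    unique_indices = torch.unique(weight_indices)
    mask_delta = (unique_indices >= num_lora) & (unique_indices != -1)
    delta_rows = unique_indices[mask_delta] - num_lora  # tensor([0, 1]) -> rows of delta_buffer
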
diff --git a/tools/benchmark/bench_perf_toppings.py b/tools/benchmark/bench_perf_toppings.py
new file mode 100644
index 0000000..74a4a30
--- /dev/null
+++ b/tools/benchmark/bench_perf_toppings.py
@@ -0,0 +1,216 @@
+import time
+import argparse
+from argparse import Namespace
+import asyncio
+import requests
+from typing import List, Optional, Callable, Dict
+from tqdm.asyncio import tqdm
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from tools.benchmark.common import (
+    construct_dataset,
+    get_request,
+    async_request_openai_completions,
+    async_request_sp_sysinfo,
+    RequestFuncOutput,
+    RequestFuncInput,
+    calculate_metrics,
+)
+from tools.benchmark.report import print_benchmark, write_benchmark
+
+
+def check_goodput_args(args: Namespace):
+    # Check and parse goodput arguments
+    goodput_config_dict = {}
+    VALID_NAMES = ["ttft", "tpot", "e2el"]
+    if args.goodput:
+        goodput_config_dict = parse_goodput(args.goodput)
+        for slo_name, slo_val in goodput_config_dict.items():
+            if slo_name not in VALID_NAMES:
+                raise ValueError(
+                    f"Invalid metric name found, {slo_name}: {slo_val}. "
+                    "The service level objective name should be one of "
+                    f"{str(VALID_NAMES)}. "
+                )
+            if slo_val < 0:
+                raise ValueError(
+                    f"Invalid value found, {slo_name}: {slo_val}. "
+                    "The service level objective value should be "
+                    "non-negative."
+                )
+    return goodput_config_dict
+
+
+def parse_goodput(slo_pairs):
+    goodput_config_dict = {}
+    try:
+        for slo_pair in slo_pairs:
+            slo_name, slo_val = slo_pair.split(":")
+            goodput_config_dict[slo_name] = float(slo_val)
+    except ValueError as err:
+        raise argparse.ArgumentTypeError(
+            "Invalid format found for service level objectives. "
+            'Specify service level objectives for goodput as "KEY:VALUE" '
+            "pairs, where the key is a metric name, and the value is a "
+            "number in milliseconds."
+        ) from err
+    return goodput_config_dict
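+
+
+# Example: "--goodput ttft:200 tpot:50" is parsed by parse_goodput() into
+# {"ttft": 200.0, "tpot": 50.0}; check_goodput_args() additionally rejects
+# unknown metric names and negative SLO values.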
+
+
+async def run_benchmark(
+    args: Namespace,
+    input_requests: List[RequestFuncInput],
+    request_func: Callable,
+    tokenizer: PreTrainedTokenizerBase,
+    goodput_config_dict: Dict[str, float],
+    max_concurrency: Optional[int] = None,
+):
+    system_info = await async_request_sp_sysinfo(args.endpoint)
+    pbar = tqdm(total=len(input_requests))
+    tasks: List[asyncio.Task] = []
+    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
+
+    async def limited_request_func(request_func_input, pbar):
+        if semaphore is None:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+        async with semaphore:
+            return await request_func(request_func_input=request_func_input, pbar=pbar)
+
+    benchmark_start_time = time.perf_counter()
+    async for request in get_request(input_requests, args.request_rate):
+        tasks.append(
+            asyncio.create_task(
+                limited_request_func(request_func_input=request, pbar=pbar)
+            )
+        )
+    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+    if pbar is not None:
+        pbar.close()
+
+    benchmark_duration = time.perf_counter() - benchmark_start_time
+    metrics, actual_output_lens = calculate_metrics(
+        input_requests=input_requests,
+        outputs=outputs,
+        dur_s=benchmark_duration,
+        tokenizer=tokenizer,
+        selected_percentile_metrics=args.percentile_metrics.split(","),
+        selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
+        goodput_config_dict=goodput_config_dict,
+    )
+    print_benchmark(metrics)
+    if args.output:
+        output_file = write_benchmark(
+            metrics,
+            args.output,
+            system_info,
+            args,
+            outputs,
+        )
+        print(f"Results written to {output_file}")
+    return metrics
+
+
+def benchmark(args):
+    print(args)
+    if not args.model:
+        args.model = [args.tokenizer]
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+    bench_requests = construct_dataset(
+        args.endpoint, args.dataset, tokenizer, args.num_prompts
+    )
+    # round-robin the requested models/adapters across the benchmark requests
+    for i, req in enumerate(bench_requests):
+        req.model = args.model[i % len(args.model)]
+    request_func = async_request_openai_completions
+    goodput_config_dict = check_goodput_args(args)
+    # check if server is ready
+    server_ready = False
+    if args.wait_until_ready:
+        while not server_ready:
+            try:
+                requests.get(args.endpoint)
+                server_ready = True
+            except Exception:
+                print("Server is not ready. Please start the server first.")
+                time.sleep(5)
+    asyncio.run(
+        run_benchmark(
+            args,
+            bench_requests,
+            request_func,
+            tokenizer,
+            goodput_config_dict=goodput_config_dict,
+        )
+    )
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--endpoint", type=str, default="http://localhost:8080/")
+    parser.add_argument("--tokenizer", type=str, default="meta-llama/Llama-3.2-3B")
+    parser.add_argument(
+        "--model",
+        type=str,
+        nargs="+",
+        default=[
+            "eltorio/Llama-3.2-3B-appreciation-1",
+            "deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs-1",
+        ],
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=float("inf"),
+        help="Number of requests per second. If this is inf, "
+        "then all the requests are sent at time 0. "
+        "Otherwise, we use a Poisson process to synthesize "
+        "the request arrival times.",
+    )
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=100,
+        help="Number of prompts to process.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="xiaozheyao/MegaChat:sharegpt",
+        help="Dataset to use for prompts.",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default=".local/benchmark_output",
+        help="Output file to save the benchmark results.",
+    )
+    parser.add_argument(
+        "--percentile-metrics",
+        type=str,
+        default="ttft,tpot,itl",
+        help="Comma-separated list of metrics to report percentiles for. "
+        'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
+        'Default value is "ttft,tpot,itl".',
+    )
+    parser.add_argument(
+        "--metric-percentiles",
+        type=str,
+        default="99",
+        help="Comma-separated list of percentiles for selected metrics. "
+        'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
+        'Default value is "99". '
+        'Use "--percentile-metrics" to select metrics.',
+    )
+    parser.add_argument(
+        "--goodput",
+        nargs="+",
+        required=False,
+        help='Specify service level objectives for goodput as "KEY:VALUE" '
+        "pairs, where the key is a metric name, and the value is in "
+        'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
+        "separated by spaces. Allowed request level metric names are "
+        '"ttft", "tpot", "e2el". For more context on the definition of '
+        "goodput, refer to the DistServe paper: https://arxiv.org/pdf/2401.09670 "
+        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
+    )
+    parser.add_argument(
+        "--wait-until-ready",
+        action="store_true",
+        default=True,
+        help="Wait until the server is ready before starting the benchmark.",
+    )
+    args = parser.parse_args()
+    benchmark(args)
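
Note: the helpers imported from tools.benchmark.common (construct_dataset, get_request, calculate_metrics, ...) are not part of this diff. Purely as an illustrative sketch of the pacing that the --request-rate help text describes (Poisson arrivals, or everything at once when the rate is inf), and not the actual implementation of get_request, the idea is:

    import asyncio
    import random

    async def poisson_paced(requests, request_rate: float):
        # Yield requests with exponentially distributed inter-arrival times,
        # i.e. a Poisson arrival process at `request_rate` requests/second.
        for request in requests:
            yield request
            if request_rate == float("inf"):
                continue  # infinite rate: send everything immediately
            await asyncio.sleep(random.expovariate(request_rate))
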
diff --git a/tools/utils/plot_e2e_benchmark.py b/tools/utils/plot_e2e_benchmark.py
new file mode 100644
index 0000000..c6fa880
--- /dev/null
+++ b/tools/utils/plot_e2e_benchmark.py
@@ -0,0 +1,131 @@
+import json
+
+import matplotlib.pyplot as plt
+
+colors = [
+    (0, 0, 0),
+    (0.7, 0, 0),
+    (0.9, 0.3, 0),
+    (1, 0.8, 0),
+    (0.6, 0.8, 0.2),
+    (0.2, 0.6, 0.4),
+    (0, 0.4, 0.6),
+]
+
+# Config:
+INTERMEDIATE_PLOTS = False
+
+y_label = ("latency (s)", "throughput (tokens/sec)")
+x_label = "selected adapters"
+
+# names of the adapter configurations
+nm = ["1 LoRA", "2 LoRA", "1 Delta", "2 Delta", "1 LoRA 1 Delta"]
+filenames = [
+    "benchmark_1_lora_new.jsonl",
+    "benchmark_2_lora_new.jsonl",
+    "benchmark_1_delta_new.jsonl",
+    "benchmark_2_delta_new.jsonl",
+    "benchmark_1_delta_1_lora_new.jsonl",
+]
+
+label_map = {"throughput (tokens/sec)": "output_throughput", "latency (s)": "mean_e2el_ms"}
+
+title = ("End to end latency", "Output token throughput")
+
+out_dir = "./"
+
+data_names = ["total_token_throughput", "output_throughput", "mean_e2el_ms", "std_e2el_ms"]
+# benchmark results, keyed by adapter configuration name
+data_metrics: dict[str, dict[str, float]] = {}
+
+for i, filename in enumerate(filenames):
+    with open(filename, "r") as f:
+        line = f.readline()
+        data = json.loads(line.strip())
+        data_metrics[nm[i]] = {}
+        for data_name in data_names:
+            if "ms" in data_name:
+                data_metrics[nm[i]][data_name] = data["metrics"][data_name] / 1000
+            else:
+                data_metrics[nm[i]][data_name] = data["metrics"][data_name]
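+
+# Note: each input .jsonl line is expected to carry the benchmark results under
+# a "metrics" key, e.g. (illustrative values only):
+#   {"metrics": {"total_token_throughput": 1234.5, "output_throughput": 567.8,
+#                "mean_e2el_ms": 2100.0, "std_e2el_ms": 150.0}, ...}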
+
+# create one plot per metric
+plot_titles_and_labels = [
+    (title[0], y_label[0], x_label),
+    (title[1], y_label[1], x_label),
+]
+
+
+def create_plot(
+    title: str,
+    x_label: str,
+    y_label: str,
+    names: list[str],
+    metrics: dict[str, dict[str, float]],
+):
+    global colors
+
+    # plot configuration
+    plt.figure(figsize=(9, 5), dpi=300)
+    plt.grid(True, which="major", axis="y", color="white", linestyle="-", linewidth=1.5, zorder=0)
+    plt.yticks(fontsize=12)
+    plt.gca().set_facecolor("#dbdbdb")
+
+    select = label_map[y_label]
+
+    # plot data
+    for j, name in enumerate(names):
+        color = (0.7, 0.7, 0.7)
+        if j == 2 or j == 3:
+            color = colors[j - 1]
+        y_data = metrics[name][select]
+
+        # plot bars at measurement points
+        plt.bar(name, y_data, color=color, label=name, zorder=2)
+        if select == "mean_e2el_ms":
+            # plot error bars (one standard deviation of the end-to-end latency)
+            std = metrics[name]["std_e2el_ms"]
+            plt.errorbar(
+                name,
+                y_data,
+                yerr=std,
+                fmt="o",
+                color=(0.4, 0.4, 0.4),
+                ecolor=(0.4, 0.4, 0.4),
+                elinewidth=1.5,
+                capsize=4,
+            )
+
+    # add legend
+    # plt.legend(names, fontsize=12)
+
+    # remove border
+    plt.gca().spines["left"].set_visible(False)
+    plt.gca().spines["top"].set_visible(False)
+    plt.gca().spines["right"].set_visible(False)
+
+    # set title and labels
+    plt.title(title, fontsize=14, loc="left", pad=24, fontweight="bold")
+    plt.xlabel(x_label, fontsize=14)
+    plt.ylabel(y_label, rotation=0, labelpad=40, loc="bottom", fontsize=14)
+    plt.gca().yaxis.set_label_coords(0, 1)
+    plt.gca().get_yaxis().get_offset_text().set_x(-0.04)
+
+    plt.savefig(f"{out_dir}/{select}.svg", format="svg", bbox_inches="tight")
+    plt.show()
+    # plt.close()
+
+
+for title, y_label, x_label in plot_titles_and_labels:
+    create_plot(title, x_label, y_label, nm, data_metrics)
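
If you want to sanity-check the plot styling without running the benchmark first, create_plot() can also be driven with synthetic metrics (values below are made up; reading the JSONL files above is the intended path):

    fake = {
        name: {"output_throughput": 900.0 - 100 * j,
               "mean_e2el_ms": 1.5 + 0.2 * j,
               "std_e2el_ms": 0.1}
        for j, name in enumerate(nm)
    }
    create_plot("End to end latency", "selected adapters", "latency (s)", nm, fake)
    create_plot("Output token throughput", "selected adapters", "throughput (tokens/sec)", nm, fake)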