From 93af3b3d99810abc6cfa44fa3c9aa6993aca33c8 Mon Sep 17 00:00:00 2001
From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com>
Date: Mon, 26 May 2025 22:43:13 -0600
Subject: [PATCH] [*] Typo fixes

---
 benchmarks/mlperf/backend.py                    |  2 +-
 benchmarks/tests/test_benchmark_serving.py      |  2 +-
 ...y-prometheus-metrics-in-jetstream-server.md  |  2 +-
 .../entrypoint/mini_offline_benchmarking.py     | 12 ++++++------
 experimental/jax/inference/nn/linear.py         |  2 +-
 .../jax/inference/parallel/operations.py        |  2 +-
 experimental/jax/inference/parallel/util.py     |  2 +-
 .../jax/inference/runtime/batch_scheduler.py    | 18 +++++++++---------
 experimental/jax/inference/runtime/engine.py    |  2 +-
 .../jetstream-maxtext-stable-stack/build.sh     |  2 +-
 jetstream/core/README.md                        |  2 +-
 jetstream/core/lora/adapter_tensorstore.py      | 10 +++++-----
 jetstream/core/orchestrator.py                  |  8 ++++----
 jetstream/core/server_lib.py                    |  4 ++--
 jetstream/engine/engine_api.py                  |  8 ++++----
 jetstream/engine/mock_engine.py                 |  4 ++--
 jetstream/engine/mock_utils.py                  |  2 +-
 jetstream/engine/token_utils.py                 |  6 +++---
 jetstream/engine/tokenizer_api.py               |  2 +-
 jetstream/engine/warmup_utils.py                |  2 +-
 .../llama3/llama3_tokenizer.py                  |  2 +-
 jetstream/tests/engine/test_token_utils.py      | 10 +++++-----
 jetstream/tools/multi_lora_decode_requester.py  |  2 +-
 23 files changed, 54 insertions(+), 54 deletions(-)

diff --git a/benchmarks/mlperf/backend.py b/benchmarks/mlperf/backend.py
index f574e4c2..66f74d60 100644
--- a/benchmarks/mlperf/backend.py
+++ b/benchmarks/mlperf/backend.py
@@ -327,7 +327,7 @@ def flush_queries(self):
     self.accuracy_log.write(json.dumps(pred_outputs))
     self.accuracy_log.flush()
     self.accuracy_log.close()
-    log.info("Dumpped prediction outputs to accuracy log... ")
+    log.info("Dumped prediction outputs to accuracy log... ")
 
   def __del__(self):
     print("Finished destroying SUT.")
diff --git a/benchmarks/tests/test_benchmark_serving.py b/benchmarks/tests/test_benchmark_serving.py
index e208e269..6800765e 100644
--- a/benchmarks/tests/test_benchmark_serving.py
+++ b/benchmarks/tests/test_benchmark_serving.py
@@ -44,7 +44,7 @@ async def test_benchmark(self):
     disable_tqdm = True
 
     async def mocked_decode_response():
-      """Mocks decode reponse as an async generator."""
+      """Mocks decode response as an async generator."""
       responses = [
           jetstream_pb2.DecodeResponse(
               stream_content=jetstream_pb2.DecodeResponse.StreamContent(
diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md
index ad091158..c1659f45 100644
--- a/docs/observability-prometheus-metrics-in-jetstream-server.md
+++ b/docs/observability-prometheus-metrics-in-jetstream-server.md
@@ -1,6 +1,6 @@
 # Observability in JetStream Server
 
-In JetStream Server, we use [Prometheus](https://prometheus.io/docs/introduction/overview/) to collect key metrics within JetStream orchestrator and engines. We implemented a [Prometheus client server](https://prometheus.github.io/client_python/exporting/http/) in JetStream `server_lib.py` and use `MetricsServerConfig` (by passing `prometheus_port` in server entrypoint) to gaurd the metrics observability feature.
+In JetStream Server, we use [Prometheus](https://prometheus.io/docs/introduction/overview/) to collect key metrics within JetStream orchestrator and engines. We implemented a [Prometheus client server](https://prometheus.github.io/client_python/exporting/http/) in JetStream `server_lib.py` and use `MetricsServerConfig` (by passing `prometheus_port` in server entrypoint) to guard the metrics observability feature.
 
 ## Enable Prometheus server to observe Jetstream metrics
 
diff --git a/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py b/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py
index 7ec4b37b..3a738751 100644
--- a/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py
+++ b/experimental/jax/inference/entrypoint/mini_offline_benchmarking.py
@@ -64,12 +64,12 @@ def benchmark():
   num_input_tokens = sum(map(lambda r: len(r.input_tokens), res_list))
   num_output_tokens = sum(map(lambda r: len(r.generated_tokens), res_list))
 
-  print("Benchmarking result: ")
-  print(" Total requests:", len(dataset))
-  print(" Total input tokens:", num_input_tokens)
-  print(" Total output tokens:", num_output_tokens)
-  print(f" Input token thruput: {num_input_tokens/duration: .2f} tokens/sec")
-  print(f" Output token thruput: {num_output_tokens/duration: .2f} tokens/sec")
+  print("Benchmarking result:")
+  print(" Total requests: ", len(dataset))
+  print(" Total input tokens: ", num_input_tokens)
+  print(" Total output tokens: ", num_output_tokens)
+  print(f" Input token throughput: {num_input_tokens/duration: .2f} tokens/sec")
+  print(f" Output token throughput: {num_output_tokens/duration: .2f} tokens/sec")
 
 
 if __name__ == "__main__":
diff --git a/experimental/jax/inference/nn/linear.py b/experimental/jax/inference/nn/linear.py
index b0e87756..eabb0a1e 100644
--- a/experimental/jax/inference/nn/linear.py
+++ b/experimental/jax/inference/nn/linear.py
@@ -7,7 +7,7 @@
 
 https://www.apache.org/licenses/LICENSE-2.0
 
-Unless reuired by applicable law or agreed to in writing, software
+Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
diff --git a/experimental/jax/inference/parallel/operations.py b/experimental/jax/inference/parallel/operations.py
index f87c34ec..b966fcba 100644
--- a/experimental/jax/inference/parallel/operations.py
+++ b/experimental/jax/inference/parallel/operations.py
@@ -22,7 +22,7 @@
 
 
 def reduce_scatter(operand, scatter_dimension, axis_names):
-  """reduce-scatter sum operation via ppermute."""
+  """reduce-scatter sum operation via permute."""
   idx = get_partition_index(axis_names=axis_names)
   num_partitions = get_num_partitions(axis_names=axis_names)
   chunk_size = operand.shape[scatter_dimension] // num_partitions
diff --git a/experimental/jax/inference/parallel/util.py b/experimental/jax/inference/parallel/util.py
index 59c300b7..f2b28cf5 100644
--- a/experimental/jax/inference/parallel/util.py
+++ b/experimental/jax/inference/parallel/util.py
@@ -27,6 +27,6 @@ def pspec(a):
     elif isinstance(a, int) or isinstance(a, float):
       return P()
     else:
-      raise ValueError(f"unknown parition spec for {a}")
+      raise ValueError(f"unknown partition spec for {a}")
 
   return jax.tree_util.tree_map(pspec, sharded_pytree)
diff --git a/experimental/jax/inference/runtime/batch_scheduler.py b/experimental/jax/inference/runtime/batch_scheduler.py
index 6ff858fd..7a53c144 100644
--- a/experimental/jax/inference/runtime/batch_scheduler.py
+++ b/experimental/jax/inference/runtime/batch_scheduler.py
@@ -110,10 +110,10 @@ def schedule(
       cur_prompt_chunk_len = (
           total_len - next_prefill_req.chunk_idx * next_prefill_req.chunk_size
       )
-      alloced_pages = self.kv_manager.alloc_prefill_hbm_pages(
+      allocated_pages = self.kv_manager.alloc_prefill_hbm_pages(
          cur_prompt_chunk_len
       )
-      if len(alloced_pages) == 0:
+      if len(allocated_pages) == 0:
        # TODO: introduce priority for the request and better
        # eviction algorithm.
        raise NotImplementedError("Eviction is not supported yet")
@@ -121,9 +121,9 @@
       start_idx = (
           next_prefill_req.chunk_idx * next_prefill_req.chunk_size
       ) // self.kv_manager.page_size
-      for i, page in enumerate(alloced_pages):
+      for i, page in enumerate(allocated_pages):
         next_prefill_req.page_indices[start_idx + i] = page
-      prefill_pages_update = PrefillPagesUpdate(alloced_pages)
+      prefill_pages_update = PrefillPagesUpdate(allocated_pages)
 
     # Schedule new generate reqs and allocate memory for all reqs.
     with generate_state.map_mutex:
@@ -150,12 +150,12 @@
             next_generate_reqs.append(gr)
       # Check and alloc memory for generate.
-      alloced_pages = self.kv_manager.alloc_hbm_pages(
+      allocated_pages = self.kv_manager.alloc_hbm_pages(
          len(generate_state.active_slot_req_map)
       )
       if (
          len(generate_state.active_slot_req_map) != 0
-         and len(alloced_pages) == 0
+         and len(allocated_pages) == 0
       ):
         raise NotImplementedError(
             "Eviction isn't supported yet, please set a lower value for batch_size"
         )
@@ -169,17 +169,17 @@
 
         if idx >= len(req.page_indices):
           continue
-        req.page_indices[idx] = alloced_pages[page_to_use]
+        req.page_indices[idx] = allocated_pages[page_to_use]
         generate_state_page_updates.append(
             GenerateStatePageUpdate(
                 slot=slot,
                 page_idx=idx,
-                mapped_idx=alloced_pages[page_to_use],
+                mapped_idx=allocated_pages[page_to_use],
             )
         )
         page_to_use += 1
 
-      self.kv_manager.free_hbm_pages(alloced_pages[page_to_use:])
+      self.kv_manager.free_hbm_pages(allocated_pages[page_to_use:])
 
       if len(generate_state.active_slot_req_map) == 0:
         schedule_generate = False
diff --git a/experimental/jax/inference/runtime/engine.py b/experimental/jax/inference/runtime/engine.py
index c4a7c16b..e455aa52 100644
--- a/experimental/jax/inference/runtime/engine.py
+++ b/experimental/jax/inference/runtime/engine.py
@@ -223,7 +223,7 @@ def __init__(
     )
     print(" preprocess,", end="")
     self._preprocess_queue: queue.Queue[Request] = queue.Queue()
-    # TODO: Seperate the running loop with the static inference model.
+    # TODO: Separate the running loop with the static inference model.
     self._preprocess_thread = threading.Thread(
         name="preprocess", target=self._preprocess
     )
diff --git a/experimental/jetstream-maxtext-stable-stack/build.sh b/experimental/jetstream-maxtext-stable-stack/build.sh
index c0d9d683..115f47e6 100755
--- a/experimental/jetstream-maxtext-stable-stack/build.sh
+++ b/experimental/jetstream-maxtext-stable-stack/build.sh
@@ -37,4 +37,4 @@
 docker build --no-cache \
   -t ${LOCAL_IMAGE_TAG} \
   -f ./Dockerfile .
-echo "********* Sucessfully built Stable Stack Image with tag $LOCAL_IMAGE_TAG *********"
+echo "********* Successfully built Stable Stack Image with tag $LOCAL_IMAGE_TAG *********"
diff --git a/jetstream/core/README.md b/jetstream/core/README.md
index 50532c70..042f4dc1 100644
--- a/jetstream/core/README.md
+++ b/jetstream/core/README.md
@@ -1,3 +1,3 @@
 # JetStream core Subpackage - Server and Library that support continuous batching serving.
 
-Interleaved mode: Provide continuous batching to optimize inference. Uses JAX directy on single-host TPU.
\ No newline at end of file
+Interleaved mode: Provide continuous batching to optimize inference. Uses JAX directly on single-host TPU.
\ No newline at end of file
diff --git a/jetstream/core/lora/adapter_tensorstore.py b/jetstream/core/lora/adapter_tensorstore.py
index e54c7261..516c73e9 100644
--- a/jetstream/core/lora/adapter_tensorstore.py
+++ b/jetstream/core/lora/adapter_tensorstore.py
@@ -231,9 +231,9 @@ def _initialize_decoding_adapters_cache(self, adapter_weights):
     """
     Create a new PyTree with zero tensors at the paths corresponding
     to non-None leaves in the input PyTree. The zero tensors have an added
-    dimension of size `self.totol_slots`.
+    dimension of size `self.total_slots`.
     Args:
-      adatper_weights: The input PyTree, whose structure will be mirrored.
+      adapter_weights: The input PyTree, whose structure will be mirrored.
     Returns:
       A new PyTree with zero Tensors or None values, mirroring the structure of the input PyTree.
     """
@@ -437,7 +437,7 @@ async def load_adapter(
     # --- Handle LOADING state ---
     if metadata.status == AdapterStatus.LOADING:
-      # Wait untill loading is done.
+      # Wait until loading is done.
       logging.info(
           "Adapter %s is already loading by another task, waiting...",
           adapter_id,
       )
@@ -655,7 +655,7 @@ async def get_lora_weights(
   async def unload_adapter(self, adapter_id: str):
     """Unloads a LoRA adapter's weights and removes it from the TensorStore."""
     if adapter_id not in self.adapter_registry:
-      raise ValueError(f"Adatper with ID '{adapter_id}' not found.")
+      raise ValueError(f"Adapter with ID '{adapter_id}' not found.")
 
     event_to_wait_on: Optional[asyncio.Event] = None
     async with self.lock:
@@ -677,7 +677,7 @@ async def unload_adapter(self, adapter_id: str):
       self._unsafe_unload_adapter(adapter_id)
 
   def list_adapters(self) -> Dict[str, AdapterMetadata]:
-    """Lists all registered adatpers and their metadata."""
+    """Lists all registered adapters and their metadata."""
     return self.adapter_registry
 
   def _evict(self, from_hbm: bool = True) -> bool:
diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index 7b774e7a..0d574c35 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -60,7 +60,7 @@
 on queues that don't have an ongoing activity (i.e. everything but the
 generation queue) because we don't control to go back to those queues until
 necessary. Blocking means that the GIL doesn't switch back to that thread,
-wheras continual queue get operations 'chop' control and mean that we do not
+whereas continual queue get operations 'chop' control and mean that we do not
 achieve good throughput. This is okay on the prefill/transfer/detokenization
 threads because we don't need to do anything other than react to the presence
 of items on these queues, wheras the generation thread needs to also run a
@@ -811,9 +811,9 @@ def _prefill_thread(self, idx: int):
 
       # Here we are applying the LoRA adapter params to the base params and
       # them. In the interleaved mode, the prefill and generate shares the
-      # same params. But as long as prefill and decode happens sequentially,
-      # there is no issues. Issue will arrise if prefill and decode is running
-      # in parallel and sharing the same params. Issue arrise because prefill
+      # same params. But as long as prefill and decode happen sequentially,
+      # there are no issues. Issue will arise if prefill and decode are running
+      # in parallel and sharing the same params. Issues arise because prefill
       # uses pre-merged weights and generate uses only base weights.
       final_prefill_params = prefill_params
       if adapter_id and adapter_tensorstore is not None:
diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index 28e7c16a..7b6c4e2d 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -147,7 +147,7 @@ def create_driver(
     config: A ServerConfig to config engine, model, device slices, etc.
     devices: Device objects, will be used to get engine with proper slicing.
     jax_padding: The flag to enable JAX padding during tokenization.
-    metrics_collector: The JetStream Promethus metric collector.
+    metrics_collector: The JetStream Prometheus metric collector.
     enable_model_warmup: The flag to enable model server warmup.
     multi_sampling: The flag to enable multi-sampling.
     prefix_caching_config: Config to prefix caching. Disable if None.
@@ -291,7 +291,7 @@ def run(
     threads: Number of RPC handlers worker threads. This should be at least
       equal to the decoding batch size to fully saturate the decoding queue.
     jax_padding: The flag to enable JAX padding during tokenization.
-    metrics_server_config: The config to enable Promethus metric server.
+    metrics_server_config: The config to enable Prometheus metric server.
     enable_jax_profiler: The flag to enable JAX profiler server.
     jax_profiler_port: The port JAX profiler server (default to 9999).
     enable_model_warmup: The flag to enable model server warmup.
diff --git a/jetstream/engine/engine_api.py b/jetstream/engine/engine_api.py
index 9279e216..5ba9e051 100644
--- a/jetstream/engine/engine_api.py
+++ b/jetstream/engine/engine_api.py
@@ -31,7 +31,7 @@
 
 
 # The model parameters - their partitioning will be unique for different prefill
-# and decode topoologies.
+# and decode topologies.
 Params = Any
 # The result of a prefill operation, often a batch size 1 KVCache.
 Prefix = Any
@@ -39,9 +39,9 @@
 DecodeState = Any
 # Accelerator representation of tokens.
 DeviceTokens = Any
-# Cpus asscociated with the mesh.
+# Cpus associated with the mesh.
 CpuDevices = Any
-# Tokenkizer used by the engine
+# Tokenizer used by the engine
 Tokenizer = Any
 # PRNG key used for prefilling
 PRNGKeyType = Any
@@ -264,7 +264,7 @@ def free_resource(
   ) -> Any:
     """Free cache and other decode resource for the slot.
 
-    This function is needed for advanced attetnion kenel like PageAttetion.
+    This function is needed for advanced attention kernel like PageAttention.
     After finishing one request, the engine need to free all used page block
     resource and reuse for coming requests.
     """
diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py
index 7466659f..24b6009f 100644
--- a/jetstream/engine/mock_engine.py
+++ b/jetstream/engine/mock_engine.py
@@ -332,7 +332,7 @@ def generate(
     # TODO: Do we need a left aligned one to test spec sampling?
     # Don't need the + 1 you normally would, because we don't provide a
     # token from prefill in the dummy.
-    # This iota and masking is to allow for a cicular cache.
+    # This iota and masking is to allow for a circular cache.
     length_mask = (
         -(l_iota - generate_cache_index) % self.cache_length
     ) <= generate_lengths[:, None]
@@ -540,7 +540,7 @@ def colocated_cpus(self) -> None:
 
   @property
   def use_chunked_prefill(self) -> bool:
-    """Wether to use chunked prefill."""
+    """Whether to use chunked prefill."""
     return self._use_chunked_prefill
 
   @property
diff --git a/jetstream/engine/mock_utils.py b/jetstream/engine/mock_utils.py
index a48a360f..3293d0e2 100644
--- a/jetstream/engine/mock_utils.py
+++ b/jetstream/engine/mock_utils.py
@@ -56,7 +56,7 @@ class TestVocab(Vocabulary):
   tokenizer: TestTokenizer = TestTokenizer()
 
   def _encode(self, s: str) -> Sequence[int]:
-    """Converts a string into a integer sequenc."""
+    """Converts a string into a integer sequence."""
    # 'We use array methods, not python iterables so we don't
    # implement this method in the mock vocab.
    raise NotImplementedError
diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py
index f0ff3bc9..ded5c3bd 100644
--- a/jetstream/engine/token_utils.py
+++ b/jetstream/engine/token_utils.py
@@ -409,7 +409,7 @@ def encode(
     return tokens, true_length
 
   def decode(self, token_ids: list[int], **kwargs) -> str:
-    """Processess input token ids to generate a string.
+    """Processes input token ids to generate a string.
     Args:
       token_ids: List of token ids.
       **kwargs: Additional keyword arguments.
@@ -483,7 +483,7 @@ def encode(
     return tokens, true_length
 
   def decode(self, token_ids: list[int]) -> str:
-    """Processess input token ids to generate a string.
+    """Processes input token ids to generate a string.
     Args:
       token_ids: List of token ids.
     Returns:
@@ -566,7 +566,7 @@ def encode(
     return tokens, true_length
 
   def decode(self, token_ids: list[int]) -> str:
-    """Processess input token ids to generate a string.
+    """Processes input token ids to generate a string.
     Args:
       token_ids: List of token ids.
     Returns:
diff --git a/jetstream/engine/tokenizer_api.py b/jetstream/engine/tokenizer_api.py
index a1461e60..cdc693a8 100644
--- a/jetstream/engine/tokenizer_api.py
+++ b/jetstream/engine/tokenizer_api.py
@@ -43,7 +43,7 @@ def encode(
 
   @abc.abstractmethod
   def decode(self, token_ids: list[int], **kwargs) -> str:
-    """Processess input token ids to generate a string.
+    """Processes input token ids to generate a string.
     Args:
       token_ids: List of token ids.
       **kwargs: Additional keyword arguments.
diff --git a/jetstream/engine/warmup_utils.py b/jetstream/engine/warmup_utils.py
index 5ed3d0eb..940397f5 100644
--- a/jetstream/engine/warmup_utils.py
+++ b/jetstream/engine/warmup_utils.py
@@ -141,7 +141,7 @@ def initialize_insert_generate_jit_cache(
     generate_params: Any,
     generate_idx: int,
 ):
-  """Initialiszes jit cache for insert and generate.
+  """Initializes jit cache for insert and generate.
 
   Args:
     generate_engine: A generate engine to be compiled for.
diff --git a/jetstream/external_tokenizers/llama3/llama3_tokenizer.py b/jetstream/external_tokenizers/llama3/llama3_tokenizer.py
index 230debe5..ddafe8aa 100644
--- a/jetstream/external_tokenizers/llama3/llama3_tokenizer.py
+++ b/jetstream/external_tokenizers/llama3/llama3_tokenizer.py
@@ -125,7 +125,7 @@ def encode(
         By default, setting disallowed_special=() encodes a string by ignoring
         special tokens. Specifically:
         - Setting `disallowed_special` to () will cause all text corresponding
-          to special tokens to be encoded as natural text (insteading of raising
+          to special tokens to be encoded as natural text (instead of raising
           an error).
         - Setting `allowed_special` to "all" will treat all text corresponding
           to special tokens to be encoded as special tokens.
diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py
index fe2de570..0a384437 100644
--- a/jetstream/tests/engine/test_token_utils.py
+++ b/jetstream/tests/engine/test_token_utils.py
@@ -27,7 +27,7 @@
 
 
 class SPTokenizer:
-  """Tokenier used in original llama2 git"""
+  """Tokenizer used in original llama2 git"""
 
   def __init__(self, tokenizer_path: str):
     self.tokenizer = SentencePieceProcessor()
@@ -40,7 +40,7 @@ def decode(self, t: List[int]) -> str:
 
 
 class JetStreamTokenizer:
-  """Tokenier used in JetStream before mix_token"""
+  """Tokenizer used in JetStream before mix_token"""
 
   def __init__(self, tokenizer_path: str):
     metadata = tokenizer_pb2.TokenizerParameters(path=tokenizer_path)
@@ -91,13 +91,13 @@ def setup_hftoken(self):
   def test_decode_vs_piece(self):
     self.setup_sentencepiece()
     tokens = [304, 13, 2266, 526, 777, 9590, 2020, 29901]
-    expeted_sp_output = []
+    expected_sp_output = []
     jt_output = []
     for t in tokens:
-      expeted_sp_output.append(self.sp_tokenizer.decode([t]))
+      expected_sp_output.append(self.sp_tokenizer.decode([t]))
       jt_output.append(self.jt_tokenizer.decode(t))
 
-    self.assertNotEqual(jt_output, expeted_sp_output)
+    self.assertNotEqual(jt_output, expected_sp_output)
 
   def test_sp_vs_seqio(self):
     self.setup_sentencepiece()
diff --git a/jetstream/tools/multi_lora_decode_requester.py b/jetstream/tools/multi_lora_decode_requester.py
index a651f9f8..9dc08032 100644
--- a/jetstream/tools/multi_lora_decode_requester.py
+++ b/jetstream/tools/multi_lora_decode_requester.py
@@ -69,7 +69,7 @@ def get_tokenizer(
     model_id: str,
     tokenizer_name: str,
 ) -> Any:
-  """Return a tokenizer or a tokenizer placholder."""
+  """Return a tokenizer or a tokenizer placeholder."""
   if tokenizer_name == "test":
     print("Using test tokenizer")
     return "test"