Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
pip install -r benchmarks/requirements.in
- name: Typecheck the code with pytype
run: |
pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/
pytype --jobs auto --exclude "jetstream/engine/implementations/*" --disable=import-error,module-attr jetstream/ benchmarks/
- name: Analysing the code with pylint
run: |
pylint jetstream/ benchmarks/
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ logs/
tmp/
venv/
.vscode/

# engine implementation submodules
jetstream/engine/implementations/
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "jetstream/engine/implementations/maxtext"]
path = jetstream/engine/implementations/maxtext
url = https://github.com/google/maxtext.git
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pip install -r requirements.txt
Use the following commands to run a server locally:
```
# Start a server
python -m jetstream.core.implementations.mock.server
python -m jetstream.entrypoints.mock.server

# Test local mock server
python -m jetstream.tools.requester
Expand Down
1 change: 1 addition & 0 deletions jetstream/engine/implementations/maxtext
Submodule maxtext added at 34412d
66 changes: 66 additions & 0 deletions jetstream/entrypoints/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Config for JetStream Server (including engine init)."""

import functools
import os
from typing import Sequence, Type

import jax
from jetstream.core import config_lib
from jetstream.engine.implementations.maxtext.MaxText import maxengine_config, pyconfig
from jetstream_pt import config


def get_server_config(
config_str: str, argv: Sequence[str]
) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
match config_str:
case "MaxtextInterleavedServer":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
server_config = config_lib.ServerConfig(
prefill_slices=(),
generate_slices=(),
interleaved_slices=("tpu=" + str(jax.device_count()),),
prefill_engine_create_fns=(),
generate_engine_create_fns=(),
interleaved_engine_create_fns=(
functools.partial(
maxengine_config.create_maxengine, config=pyconfig.config
),
),
)
case "PyTorchInterleavedServer":
os.environ["XLA_FLAGS"] = (
"--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
)
engine = config.create_engine_from_config_flags()
server_config = config_lib.ServerConfig(
prefill_slices=(),
generate_slices=(),
interleaved_slices=("tpu=" + str(jax.device_count()),),
prefill_engine_create_fns=(),
generate_engine_create_fns=(),
interleaved_engine_create_fns=(lambda a: engine,),
)
case "InterleavedCPUTestServer":
server_config = config_lib.InterleavedCPUTestServer
case "CPUTestServer":
server_config = config_lib.CPUTestServer
case _:
raise NotImplementedError
return server_config
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from absl import app
from absl import flags

from jetstream.core.implementations.mock import config as mock_config
from jetstream.entrypoints.mock import config as mock_config
from jetstream.core import server_lib


Expand Down
62 changes: 62 additions & 0 deletions jetstream/entrypoints/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs a JetStream Server."""

from typing import Sequence

from absl import app
from absl import flags

from jetstream.entrypoints import config
from jetstream.core import config_lib, server_lib


# Command-line flags read by main().
flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
flags.DEFINE_string(
    "config",
    "InterleavedCPUTestServer",
    "available servers",
)
# When this flag is left at 0, main() builds no MetricsServerConfig and passes
# None to server_lib.run.
flags.DEFINE_integer(
    "prometheus_port",
    0,
    "port for the prometheus metrics server; 0 means no metrics server "
    "config is created",
)


def main(argv: Sequence[str]):
  """Launches a JetStream server from command-line flags and blocks on it.

  Prints the devices reported by `server_lib.get_devices()` and the server
  config resolved from the `config` flag, then starts the server through
  `server_lib.run` and waits for it to terminate.

  Args:
    argv: Command-line arguments handed over by `app.run`; forwarded to
      `config.get_server_config`.
  """
  parsed_flags = flags.FLAGS

  available_devices = server_lib.get_devices()
  print(f"devices: {available_devices}")

  selected_config = config.get_server_config(parsed_flags.config, argv)
  print(f"server_config: {selected_config}")
  del argv

  # A metrics server config is only created for a non-zero prometheus port.
  prometheus_port = parsed_flags.prometheus_port
  metrics_config: config_lib.MetricsServerConfig | None = (
      config_lib.MetricsServerConfig(port=prometheus_port)
      if prometheus_port != 0
      else None
  )

  # Credentials are kept out of run() so that it can be unit tested with
  # local credentials.
  # TODO: Add grpc credentials for OSS.
  server = server_lib.run(
      threads=parsed_flags.threads,
      port=parsed_flags.port,
      config=selected_config,
      devices=available_devices,
      metrics_server_config=metrics_config,
  )
  server.wait_for_termination()


# Script entry point: hand main over to absl's app.run.
if __name__ == "__main__":
  app.run(main)
27 changes: 27 additions & 0 deletions requirements-standalone.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# jetstream library
absl-py
coverage
flax
grpcio
jax
jaxlib
numpy
portpicker
prometheus-client
pytest
seqio
tiktoken
blobfile
parameterized
shortuuid
# jetstream benchmarks
nltk
evaluate
rouge-score
tqdm
# jetstream profiling
tensorboard-plugin-profile
# engines
# maxtext @ git+https://github.com/google/[email protected]#egg=maxtext
# maxtext @ {root:uri}/jetstream/engine/implementations/maxtext
jetstream_pt @ git+https://github.com/google/[email protected]#egg=jetstream_pt