Skip to content

Commit 7284ca0

Browse files
authored
Gpu image pipeline (#169)
* Add GPU image creation pipeline
  Signed-off-by: Kunjan <[email protected]>
* Move the jax[cuda] installation after requirements.txt so that the jax version installed by requirements.txt does not clobber it
  Signed-off-by: Kunjan <[email protected]>
---------
Signed-off-by: Kunjan <[email protected]>
1 parent e7d1f13 commit 7284ca0

File tree

4 files changed

+21
-8
lines changed

4 files changed

+21
-8
lines changed

.github/workflows/UploadDockerImages.yml

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,3 +35,16 @@ jobs:
3535
- name: build maxdiffusion jax nightly image
3636
run: |
3737
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
38+
39+
build-gpu-image:
40+
runs-on: ["self-hosted", "e2", "cpu"]
41+
steps:
42+
- uses: actions/checkout@v3
43+
- name: Cleanup old docker images
44+
run: docker system prune --all --force
45+
- name: build maxdiffusion jax stable stack gpu image
46+
run: |
47+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
48+
- name: build maxdiffusion jax nightly image
49+
run: |
50+
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu

.github/workflows/build_and_upload_images.sh

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,15 @@ for ARGUMENT in "$@"; do
3434
echo "$KEY"="$VALUE"
3535
done
3636

37+
export DEVICE="${DEVICE:-tpu}"
38+
3739
if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then
3840
echo "You must set CLOUD_IMAGE_NAME, PROJECT and MODE"
3941
exit 1
4042
fi
4143

4244
gcloud auth configure-docker us-docker.pkg.dev --quiet
43-
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE
45+
bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE=$DEVICE
4446
image_date=$(date +%Y-%m-%d)
4547

4648
# Upload only dependencies image

maxdiffusion_gpu_dependencies.Dockerfile

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
2222
# Set environment variables for Google Cloud SDK
2323
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
2424

25-
# Upgrade libcusparse to work with JAX
26-
RUN apt-get update && apt-get install -y libcusparse-12-3
25+
2726

2827
ARG MODE
2928
ENV ENV_MODE=$MODE
@@ -46,5 +45,4 @@ RUN ls .
4645
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
4746
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
4847

49-
5048
WORKDIR /deps

setup.sh

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,9 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
5555
exit 1
5656
fi
5757

58+
# Install dependencies from requirements.txt first
59+
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
60+
5861
# Install JAX and JAXlib based on the specified mode
5962
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
6063
# Stable mode
@@ -78,7 +81,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
7881
pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
7982
fi
8083
export NVTE_FRAMEWORK=jax
81-
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
84+
pip3 install transformer_engine[jax]==2.1.0
8285
fi
8386

8487
elif [[ $MODE == "nightly" ]]; then
@@ -106,8 +109,5 @@ else
106109
exit 1
107110
fi
108111

109-
# Install dependencies from requirements.txt
110-
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
111-
112112
# Install maxdiffusion
113113
pip3 install -U . || echo "Failed to install maxdiffusion" >&2

0 commit comments

Comments
 (0)