File tree Expand file tree Collapse file tree 4 files changed +21
-8
lines changed Expand file tree Collapse file tree 4 files changed +21
-8
lines changed Original file line number Diff line number Diff line change 35
35
- name : build maxdiffusion jax nightly image
36
36
run : |
37
37
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
38
+
39
+ build-gpu-image :
40
+ runs-on : ["self-hosted", "e2", "cpu"]
41
+ steps :
42
+ - uses : actions/checkout@v3
43
+ - name : Cleanup old docker images
44
+ run : docker system prune --all --force
45
+ - name : build maxdiffusion jax stable stack gpu image
46
+ run : |
47
+ bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu
48
+ - name : build maxdiffusion jax nightly image
49
+ run : |
50
+ bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu
Original file line number Diff line number Diff line change @@ -34,13 +34,15 @@ for ARGUMENT in "$@"; do
34
34
echo " $KEY " =" $VALUE "
35
35
done
36
36
37
+ export DEVICE=" ${DEVICE:- tpu} "
38
+
37
39
if [[ ! -v CLOUD_IMAGE_NAME ]] || [[ ! -v PROJECT ]] || [[ ! -v MODE ]] ; then
38
40
echo " You must set CLOUD_IMAGE_NAME, PROJECT and MODE"
39
41
exit 1
40
42
fi
41
43
42
44
gcloud auth configure-docker us-docker.pkg.dev --quiet
43
- bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE
45
+ bash docker_build_dependency_image.sh LOCAL_IMAGE_NAME=$LOCAL_IMAGE_NAME MODE=$MODE DEVICE= $DEVICE
44
46
image_date=$( date +%Y-%m-%d)
45
47
46
48
# Upload only dependencies image
Original file line number Diff line number Diff line change @@ -22,8 +22,7 @@ RUN apt-get update && apt-get install -y google-cloud-sdk
22
22
# Set environment variables for Google Cloud SDK
23
23
ENV PATH="/usr/local/google-cloud-sdk/bin:${PATH}"
24
24
25
- # Upgrade libcusprase to work with Jax
26
- RUN apt-get update && apt-get install -y libcusparse-12-3
25
+
27
26
28
27
ARG MODE
29
28
ENV ENV_MODE=$MODE
@@ -46,5 +45,4 @@ RUN ls .
46
45
RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION DEVICE=${ENV_DEVICE}"
47
46
RUN --mount=type=cache,target=/root/.cache/pip bash setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} DEVICE=${ENV_DEVICE}
48
47
49
-
50
48
WORKDIR /deps
Original file line number Diff line number Diff line change @@ -55,6 +55,9 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
55
55
exit 1
56
56
fi
57
57
58
+ # Install dependencies from requirements.txt first
59
+ pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
60
+
58
61
# Install JAX and JAXlib based on the specified mode
59
62
if [[ " $MODE " == " stable" || ! -v MODE ]]; then
60
63
# Stable mode
@@ -78,7 +81,7 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
78
81
pip3 install " jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
79
82
fi
80
83
export NVTE_FRAMEWORK=jax
81
- pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
84
+ pip3 install transformer_engine[jax]==2.1.0
82
85
fi
83
86
84
87
elif [[ $MODE == " nightly" ]]; then
106
109
exit 1
107
110
fi
108
111
109
- # Install dependencies from requirements.txt
110
- pip3 install -U -r requirements.txt || echo " Failed to install dependencies in the requirements" >&2
111
-
112
112
# Install maxdiffusion
113
113
pip3 install -U . || echo " Failed to install maxdiffusion" >&2
You can’t perform that action at this time.
0 commit comments