1 change: 1 addition & 0 deletions .github/workflows/nv-a6000-fastgen.yml
@@ -9,6 +9,7 @@ on:
- 'mii/legacy/**'
- 'tests/legacy/**'
- '.github/workflows/nv-v100-legacy.yml'
- '.github/workflows/nv-a6000-sd.yml'

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
58 changes: 58 additions & 0 deletions .github/workflows/nv-a6000-sd.yml
@@ -0,0 +1,58 @@
name: nv-a6000-sd

on:
workflow_dispatch:
schedule:
- cron: "0 0 * * *"
pull_request:
paths:
- 'mii/legacy/**'
- 'tests/legacy/**'
- '.github/workflows/nv-a6000-sd.yml'

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
unit-tests:
runs-on: [self-hosted, nvidia, a6000]
container:
image: nvcr.io/nvidia/pytorch:24.03-py3
ports:
- 80
options: --gpus all --shm-size "8G"

steps:
- uses: actions/checkout@v4

- name: Check container state
run: |
ldd --version
nvcc --version
nvidia-smi
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install transformers
run: |
git clone --depth=1 https://github.com/huggingface/transformers
cd transformers
git rev-parse --short HEAD
python -m pip install .
- name: Install deepspeed
run: |
git clone --depth=1 https://github.com/microsoft/DeepSpeed
cd DeepSpeed
python -m pip install .
ds_report
- name: Install MII
run: |
pip install .[dev]
- name: Python environment
run: |
python -m pip list
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests/legacy
python -m pytest --color=yes --durations=0 --verbose -rF -m "stable_diffusion" ./
25 changes: 22 additions & 3 deletions mii/legacy/method_table.py
@@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import io

from abc import ABC, abstractmethod
from mii.legacy.constants import TaskType
from mii.legacy.grpc_related.proto import legacymodelresponse_pb2 as modelresponse_pb2
@@ -274,11 +276,28 @@ def pack_request_to_proto(self, request_dict, **query_kwargs):
negative_prompt = request_dict.get("negative_prompt", [""] * len(prompt))
negative_prompt = negative_prompt if isinstance(negative_prompt,
list) else [negative_prompt]
image = request_dict["image"] if isinstance(request_dict["image"],
list) else [request_dict["image"]]
mask_image = request_dict["mask_image"] if isinstance(
image_list = request_dict["image"] if isinstance(
request_dict["image"],
list) else [request_dict["image"]]
mask_image_list = request_dict["mask_image"] if isinstance(
request_dict["mask_image"],
list) else [request_dict["mask_image"]]
image = []
for img in image_list:
if isinstance(img, bytes):
image.append(img)
else:
imgByteArr = io.BytesIO()
img.save(imgByteArr, format=img.format)
image.append(imgByteArr.getvalue())
mask_image = []
for img in mask_image_list:
if isinstance(img, bytes):
mask_image.append(img)
else:
imgByteArr = io.BytesIO()
img.save(imgByteArr, format=img.format)
mask_image.append(imgByteArr.getvalue())

return modelresponse_pb2.InpaintingRequest(
prompt=prompt,
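The new packing logic converts any PIL images in the request to raw bytes before they enter the protobuf message. A minimal standalone sketch of that round-trip (not MII's API; note that `img.format` is `None` for images created in memory, so this sketch falls back to PNG where the PR relies on the format always being set):

```python
import io

from PIL import Image


def serialize_image(img):
    # Pass raw bytes through untouched, mirroring the isinstance(img, bytes) branch.
    if isinstance(img, bytes):
        return img
    buffer = io.BytesIO()
    # img.format is None for in-memory images; PNG is an assumed fallback here.
    img.save(buffer, format=img.format or "PNG")
    return buffer.getvalue()


# Round-trip check: the server can rebuild the image from the byte payload.
original = Image.new("RGB", (64, 64))
restored = Image.open(io.BytesIO(serialize_image(original)))
assert restored.size == (64, 64)
```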
10 changes: 9 additions & 1 deletion mii/legacy/models/providers/diffusers.py
@@ -4,11 +4,18 @@
# DeepSpeed Team
import os
import torch
from huggingface_hub import HfApi

from .utils import attempt_load
from mii.config import ModelConfig


def _get_model_revs(model_name):
api = HfApi()
branches = api.list_repo_refs(model_name).branches
return [b.name for b in branches]


def diffusers_provider(model_config: ModelConfig):
from diffusers import DiffusionPipeline

@@ -17,7 +24,8 @@ def diffusers_provider(model_config: ModelConfig):
kwargs = model_config.pipeline_kwargs
if model_config.dtype == torch.half:
kwargs["torch_dtype"] = torch.float16
kwargs["revision"] = "fp16"
if "fp16" in _get_model_revs(model_config.model):
kwargs["revision"] = "fp16"

pipeline = attempt_load(DiffusionPipeline.from_pretrained,
model_config.model,
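The `_get_model_revs` guard keeps `revision="fp16"` from being requested for repos that never published an fp16 branch. A quick standalone check of what the helper sees (assumes `huggingface_hub` is installed and the repo is public; the printed branch list is illustrative):

```python
from huggingface_hub import HfApi

api = HfApi()
refs = api.list_repo_refs("stabilityai/stable-diffusion-2-inpainting")
branch_names = [branch.name for branch in refs.branches]
print(branch_names)  # e.g. ['main', 'fp16'] -- revision="fp16" is only valid if listed
```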
1 change: 1 addition & 0 deletions requirements/requirements-dev.txt
@@ -1,4 +1,5 @@
clang-format==18.1.3
diffusers
einops
pre-commit>=2.20.0
pytest
42 changes: 38 additions & 4 deletions tests/legacy/conftest.py
@@ -7,6 +7,8 @@
import os
import mii.legacy as mii
from types import SimpleNamespace
from packaging import version as pkg_version
import torch


@pytest.fixture(scope="function", params=["fp16"])
Expand Down Expand Up @@ -84,13 +86,20 @@ def ds_config(request):
return request.param


@pytest.fixture(scope="function")
def replace_with_kernel_inject(model_name):
if "clip-vit" in model_name:
@pytest.fixture(scope="function", params=[None])
def replace_with_kernel_inject(request, model_name):
if request.param is not None:
return request.param
if model_name == "openai/clip-vit-base-patch32":
return False
return True


@pytest.fixture(scope="function", params=[False])
def enable_cuda_graph(request):
return request.param


@pytest.fixture(scope="function")
def model_config(
task_name: str,
@@ -104,6 +113,7 @@ def model_config(
enable_zero: bool,
ds_config: dict,
replace_with_kernel_inject: bool,
enable_cuda_graph: bool,
):
config = SimpleNamespace(
skip_model_check=True, # TODO: remove this once conversation task check is fixed
@@ -120,6 +130,7 @@ def model_config(
enable_zero=enable_zero,
ds_config=ds_config,
replace_with_kernel_inject=replace_with_kernel_inject,
enable_cuda_graph=enable_cuda_graph,
)
return config.__dict__

@@ -145,8 +156,31 @@ def expected_failure(request):
return request.param


@pytest.fixture(scope="function", params=[None])
def min_compute_capability(request):
return request.param


@pytest.fixture(scope="function")
def meets_compute_capability_reqs(min_compute_capability):
if min_compute_capability is None:
return
min_compute_ver = pkg_version.parse(str(min_compute_capability))
device_compute_ver = pkg_version.parse(".".join(
map(str,
torch.cuda.get_device_capability())))
if device_compute_ver < min_compute_ver:
pytest.skip(
f"Skipping test because device compute capability ({device_compute_ver}) is less than the minimum required ({min_compute_ver})."
)


@pytest.fixture(scope="function")
def deployment(deployment_name, mii_config, model_config, expected_failure):
def deployment(deployment_name,
mii_config,
model_config,
expected_failure,
meets_compute_capability_reqs):
if expected_failure is not None:
with pytest.raises(expected_failure) as excinfo:
mii.deploy(
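A test opts into the new compute-capability gate by overriding the fixture's default parameter; because `deployment` now depends on `meets_compute_capability_reqs`, the skip fires before any model is deployed. A hypothetical usage sketch (the test name and body are illustrative):

```python
import pytest


@pytest.mark.parametrize("min_compute_capability", [8])  # require SM 8.0 (Ampere) or newer
def test_requires_ampere(deployment, query):
    # The deployment fixture transitively pulls in meets_compute_capability_reqs,
    # which calls pytest.skip() on devices below the requested capability.
    ...
```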
2 changes: 1 addition & 1 deletion tests/legacy/pytest.ini
@@ -1,3 +1,3 @@
[pytest]
markers =
deepspeed:Run test for deepspeed CI
stable_diffusion:Run Stable Diffusion tests
102 changes: 101 additions & 1 deletion tests/legacy/test_local_deployment.py
@@ -5,9 +5,86 @@
import pytest
import mii.legacy as mii

import requests
from PIL import Image


@pytest.mark.parametrize(
"task_name, model_name, query",
[(
"conversational",
"microsoft/DialoGPT-small",
{
"text": "DeepSpeed is the greatest",
"conversation_id": 3,
"past_user_inputs": [],
"generated_responses": [],
},
),
(
"fill-mask",
"bert-base-uncased",
{
"query": "Hello I'm a [MASK] model."
},
),
(
"question-answering",
"deepset/roberta-large-squad2",
{
"question": "What is the greatest?",
"context": "DeepSpeed is the greatest",
},
),
(
"text-generation",
"bigscience/bloom-560m",
{
"query": ["DeepSpeed is the greatest",
"Seattle is"]
},
),
(
"token-classification",
"Jean-Baptiste/roberta-large-ner-english",
{
"query": "My name is jean-baptiste and I live in montreal."
},
),
(
"text-classification",
"roberta-large-mnli",
{
"query": "DeepSpeed is the greatest"
},
),
(
"zero-shot-image-classification",
"openai/clip-vit-base-patch32",
{
"image":
"https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
"candidate_labels": ["animals",
"humans",
"landscape"]
},
),
("text-to-image-inpainting",
"stabilityai/stable-diffusion-2-inpainting",
{
"prompt":
"a black cat with glowing eyes",
"image":
Image.open(
requests.get(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
stream=True).raw),
"mask_image":
Image.open(
requests.get(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png",
stream=True).raw),
})],
[
(
"fill-mask",
@@ -73,7 +150,7 @@ def test_single_GPU(deployment, query):


@pytest.mark.parametrize(
"task_name, model_name, query",
"task_name, model_name, query, tensor_parallel",
[
(
"text-generation",
@@ -82,6 +159,7 @@ def test_single_GPU(deployment, query):
"query": ["DeepSpeed is the greatest",
"Seattle is"]
},
2,
),
],
)
@@ -111,3 +189,25 @@ def test_session(deployment, query):
result = generator.query(query)
generator.destroy_session(session_name)
assert result


@pytest.mark.stable_diffusion
@pytest.mark.parametrize(
"task_name, model_name, query",
[
(
"text-to-image",
"openskyml/midjourney-mini",
{
"prompt": "a dog on a rocket",
"negative_prompt": "planet earth",
},
),
],
)
@pytest.mark.parametrize("enable_cuda_graph", [True])
@pytest.mark.parametrize("min_compute_capability", [8])
def test_SD_kernel_inject(deployment, query):
generator = mii.mii_query_handle(deployment)
result = generator.query(query)
assert result
7 changes: 0 additions & 7 deletions tests/legacy/test_non_persistent_deployment.py
@@ -26,13 +26,6 @@
"context": "DeepSpeed is the greatest",
},
),
(
"text-generation",
"distilgpt2",
{
"query": ["DeepSpeed is the greatest"]
},
),
(
"text-generation",
"bigscience/bloom-560m",