Skip to content

Commit 05496d2

Browse files
author
Jesse Andrews
authored
Merge pull request #65 from anotherjesse/nsfw_checker
enable safety_checker for SD 2.1
2 parents 88550b8 + ec8e6b5 commit 05496d2

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

cog.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ build:
33
cuda: "11.6"
44
python_version: "3.10"
55
python_packages:
6-
- "diffusers==0.10.0"
6+
- "diffusers==0.11.1"
77
- "torch==1.13.0"
88
- "ftfy==6.1.1"
99
- "scipy==1.9.3"
1010
- "transformers==4.25.1"
11-
- "accelerate==0.14.0"
11+
- "accelerate==0.15.0"
1212

1313
predict: "predict.py:Predictor"

predict.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,28 @@
1212
EulerAncestralDiscreteScheduler,
1313
DPMSolverMultistepScheduler,
1414
)
15+
from diffusers.pipelines.stable_diffusion.safety_checker import (
16+
StableDiffusionSafetyChecker,
17+
)
18+
1519

1620
MODEL_ID = "stabilityai/stable-diffusion-2-1"
1721
MODEL_CACHE = "diffusers-cache"
22+
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
1823

1924

2025
class Predictor(BasePredictor):
2126
def setup(self):
2227
"""Load the model into memory to make running multiple predictions efficient"""
2328
print("Loading pipeline...")
29+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
30+
SAFETY_MODEL_ID,
31+
cache_dir=MODEL_CACHE,
32+
local_files_only=True,
33+
)
2434
self.pipe = StableDiffusionPipeline.from_pretrained(
2535
MODEL_ID,
36+
safety_checker=safety_checker,
2637
cache_dir=MODEL_CACHE,
2738
local_files_only=True,
2839
).to("cuda")
@@ -107,9 +118,15 @@ def predict(
107118

108119
output_paths = []
109120
for i, sample in enumerate(output.images):
110-
output_path = f"/tmp/out-{i}.png"
111-
sample.save(output_path)
112-
output_paths.append(Path(output_path))
121+
if output.nsfw_content_detected and not output.nsfw_content_detected[i]:
122+
output_path = f"/tmp/out-{i}.png"
123+
sample.save(output_path)
124+
output_paths.append(Path(output_path))
125+
126+
if len(output_paths) == 0:
127+
raise Exception(
128+
f"NSFW content detected. Try running it again, or try a different prompt."
129+
)
113130

114131
return output_paths
115132

script/download-weights

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,25 @@
33
import os
44
import shutil
55
from diffusers import StableDiffusionPipeline
6+
from diffusers.pipelines.stable_diffusion.safety_checker import (
7+
StableDiffusionSafetyChecker,
8+
)
69

710

8-
model_id = "stabilityai/stable-diffusion-2-1"
11+
MODEL_ID = "stabilityai/stable-diffusion-2-1"
912
MODEL_CACHE = "diffusers-cache"
13+
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
14+
1015
if os.path.exists(MODEL_CACHE):
1116
shutil.rmtree(MODEL_CACHE)
1217
os.makedirs(MODEL_CACHE, exist_ok=True)
1318

19+
saftey_checker = StableDiffusionSafetyChecker.from_pretrained(
20+
SAFETY_MODEL_ID,
21+
cache_dir=MODEL_CACHE,
22+
)
23+
1424
pipe = StableDiffusionPipeline.from_pretrained(
15-
model_id,
25+
MODEL_ID,
1626
cache_dir=MODEL_CACHE,
1727
)

0 commit comments

Comments
 (0)