|
12 | 12 | EulerAncestralDiscreteScheduler, |
13 | 13 | DPMSolverMultistepScheduler, |
14 | 14 | ) |
| 15 | +from diffusers.pipelines.stable_diffusion.safety_checker import ( |
| 16 | + StableDiffusionSafetyChecker, |
| 17 | +) |
| 18 | + |
15 | 19 |
|
16 | 20 | MODEL_ID = "stabilityai/stable-diffusion-2-1" |
17 | 21 | MODEL_CACHE = "diffusers-cache" |
| 22 | +SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker" |
18 | 23 |
|
19 | 24 |
|
20 | 25 | class Predictor(BasePredictor): |
21 | 26 | def setup(self): |
22 | 27 | """Load the model into memory to make running multiple predictions efficient""" |
23 | 28 | print("Loading pipeline...") |
| 29 | + safety_checker = StableDiffusionSafetyChecker.from_pretrained( |
| 30 | + SAFETY_MODEL_ID, |
| 31 | + cache_dir=MODEL_CACHE, |
| 32 | + local_files_only=True, |
| 33 | + ) |
24 | 34 | self.pipe = StableDiffusionPipeline.from_pretrained( |
25 | 35 | MODEL_ID, |
| 36 | + safety_checker=safety_checker, |
26 | 37 | cache_dir=MODEL_CACHE, |
27 | 38 | local_files_only=True, |
28 | 39 | ).to("cuda") |
@@ -107,9 +118,15 @@ def predict( |
107 | 118 |
|
108 | 119 | output_paths = [] |
109 | 120 | for i, sample in enumerate(output.images): |
110 | | - output_path = f"/tmp/out-{i}.png" |
111 | | - sample.save(output_path) |
112 | | - output_paths.append(Path(output_path)) |
| 121 | + if output.nsfw_content_detected and not output.nsfw_content_detected[i]: |
| 122 | + output_path = f"/tmp/out-{i}.png" |
| 123 | + sample.save(output_path) |
| 124 | + output_paths.append(Path(output_path)) |
| 125 | + |
| 126 | + if len(output_paths) == 0: |
| 127 | + raise Exception( |
| 128 | + f"NSFW content detected. Try running it again, or try a different prompt." |
| 129 | + ) |
113 | 130 |
|
114 | 131 | return output_paths |
115 | 132 |
|
|
0 commit comments