diff --git a/environment.yaml b/environment.yaml index f41c3cada..ce497402e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -21,7 +21,7 @@ dependencies: - streamlit>=0.73.1 - einops==0.3.0 - torch-fidelity==0.3.0 - - transformers==4.19.2 + - transformers==4.33.2 - torchmetrics==0.6.0 - kornia==0.6 - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py index c82918240..5103ece5f 100644 --- a/optimizedSD/optimized_txt2img.py +++ b/optimizedSD/optimized_txt2img.py @@ -129,6 +129,11 @@ def load_model_from_config(ckpt, verbose=False): type=str, help="if specified, load prompts from this file", ) +parser.add_argument( + "--negative-prompt-file", + type=str, + help="if specified, load negative prompts from this file", +) parser.add_argument( "--seed", type=int, @@ -235,6 +240,7 @@ def load_model_from_config(ckpt, verbose=False): batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: assert opt.prompt is not None prompt = opt.prompt @@ -250,6 +256,13 @@ def load_model_from_config(ckpt, verbose=False): data = batch_size * list(data) data = list(chunk(sorted(data), batch_size)) +if not opt.negative_prompt_file: + negative_prompt = "" +else: + print(f"reading negative prompts from {opt.negative_prompt_file}") + with open(opt.negative_prompt_file, "r") as f: + negative_prompt = f.read().splitlines() + negative_prompt = ' '.join(negative_prompt) if opt.precision == "autocast" and opt.device != "cpu": precision_scope = autocast @@ -271,7 +284,7 @@ def load_model_from_config(ckpt, verbose=False): modelCS.to(opt.device) uc = None if opt.scale != 1.0: - uc = modelCS.get_learned_conditioning(batch_size * [""]) + uc = modelCS.get_learned_conditioning(batch_size * [negative_prompt]) if isinstance(prompts, tuple): prompts = list(prompts)