We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fc7ae98 commit 1f0bd90Copy full SHA for 1f0bd90
src/python/training/train_ldm.py
@@ -74,12 +74,10 @@ def main(args):
74
model_type="diffusion",
75
)
76
77
- # Load Autoencoder to produce the latent representations
78
print(f"Loading Stage 1 from {args.stage1_uri}")
79
stage1 = mlflow.pytorch.load_model(args.stage1_uri)
80
stage1.eval()
81
82
- # Create the diffusion model
83
print("Creating model...")
84
config = OmegaConf.load(args.config_file)
85
diffusion = DiffusionModelUNet(**config["ldm"].get("params", dict()))
0 commit comments