diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 2990d5701a0d..9ee8202b1c00 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -125,6 +125,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = load_image(validation_image) + # Convert RGB to BGR for OpenPose validation images + validation_image_np = np.array(validation_image)[:, :, ::-1] + validation_image = Image.fromarray(validation_image_np) # maybe need to inference on 1024 to get a good image validation_image = validation_image.resize((args.resolution, args.resolution))