+
+# Overview
+
+This is the official MindSpore implementation of [CannyEdit](https://vaynexie.github.io/CannyEdit/).
+
+CannyEdit is a novel training-free framework for multitask image editing. It enables high-quality region-specific edits, which is especially useful where state-of-the-art free-form image editing methods fail to ground the edit accurately. When multiple masks are provided, it can also edit several user-specified regions in a single generation pass.
+
+
+
+## Installation
+
+1. Install [CANN 8.0.RC3.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC3.beta1) and MindSpore according to the [official instructions](https://www.mindspore.cn/install).
+2. Install requirements
+ ```shell
+ pip install -r requirements.txt
+ ```
+3. Install mindone
+ ```shell
+ cd mindone
+ pip install -e .
+ ```
+ Try `python -c "import mindone"`. If no error occurs, the installation is successful.
+
+## 🚀 Quick Start
+The CannyEdit workflow consists of three steps:
+1. Generate masks (optional; skip if you already have them)
+2. Generate prompts (optional; skip if you already have them)
+3. Generate the edited image
+
+### Step 1: Generate masks (Optional)
+This step requires the [SAM2](https://github.com/facebookresearch/sam2/) model weights. Download them with the tools in `examples/sam2`:
+```bash
+cd examples/sam2/checkpoints && \
+./download_ckpts.sh
+```
+The checkpoints will be downloaded into `examples/sam2/checkpoints`.
+
+Then adjust the checkpoint path in `run_app_mask.sh` if needed, and run it to launch the mask-generator app:
+```bash
+cd examples/canny_edit && \
+bash run_app_mask.sh
+```
+Then open http://localhost:5000 in a browser. If the app runs on a remote server, forward the port to your local machine first:
+```bash
+ssh -L 8081:localhost:5000 username@ip
+```
+With this port forwarding in place, open http://localhost:8081 in your local browser.
+
+On the mask-generator page, choose the method that matches your editing task:
+
+- Adding task: circle the target area where you want to add an object or person, then click "Generate Ellipse Mask".
+- Replace and removal tasks: draw a line over the existing object or person you want to replace or remove, then click "Generate SAM Mask".
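+
+If you already know the edit region, you can also build a mask offline instead of using the web app. The sketch below is a minimal example (the file name, mask size, and ellipse coordinates are placeholders): it writes the same kind of white-on-black mask the app produces, which you can then pass to `main.py` via `--mask_path`.
+```python
+# Minimal offline alternative to the mask app; coordinates are illustrative only.
+import cv2
+import numpy as np
+
+h, w = 768, 768  # size of the image you plan to edit
+mask = np.zeros((h, w), dtype=np.uint8)
+# Draw a filled white ellipse over the region to edit:
+# center (500, 400), axes (120, 160), no rotation.
+cv2.ellipse(mask, (500, 400), (120, 160), 0, 0, 360, 255, -1)
+cv2.imwrite("./assets/mask_temp/my_mask.png", mask)
+```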
+
+### Step 2: Generate prompts (Optional)
+If either the source prompt (describing the input image) or the target prompt (describing the edited image) is missing, `main.py` calls a vision-language model (VLM) to generate it automatically. Here we use [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
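+
+For reference, the snippet below is a condensed sketch of the call `main.py` makes internally (see `generate_output_by_qwen` in `main.py` for the full version, which also resizes the image). The checkpoint is the default of `--qwen_checkpoint_path`; the image path and word budget match the bundled example.
+```python
+import mindspore as ms
+from transformers import AutoProcessor
+from mindone.transformers import Qwen2_5_VLForConditionalGeneration
+from mindone.transformers.models.qwen2_vl.qwen_vl_utils import process_vision_info
+
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen2.5-VL-7B-Instruct", mindspore_dtype=ms.bfloat16
+)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+messages = [{"role": "user", "content": [
+    {"type": "image", "image": "./assets/imgs/girl33.jpeg"},
+    {"type": "text", "text": "Describe this image in 15 words."},
+]}]
+text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+image_inputs, video_inputs = process_vision_info(messages)
+inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="np")
+inputs = {k: ms.Tensor(v) for k, v in inputs.items()}
+
+generated = model.generate(**inputs, max_new_tokens=128)
+# Drop the prompt tokens and decode only the newly generated text.
+trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated)]
+print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])
+```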
+
+### Step 3: Generate edited image
+Several example commands are listed in `run_infer.sh`. Uncomment the one you want and run:
+```bash
+bash run_infer.sh
+```
+Below are the commands and example outputs for each test case.
+
+- **case 1: Replace background with mountains**
+```bash
+python main.py \
+ --image_path './assets/imgs/girl33.jpeg' \
+ --image_whratio_unchange \
+ --save_folder './results/' \
+ --prompt_local "A mountain." \
+ --prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+ --prompt_target "A young girl with red hair smiles brightly, wearing a red and white checkered shirt, sitting on a bench with mountains in the background." \
+ --mask_path "./assets/mask_temp/mask_209_inverse.png"
+```
+
+
+
+
+
+
+
+
+
+
+
+From left to right: the original image, the mask image, and the edited result.
+
+
+- **case 2: Replace the girl with a boy**
+```bash
+python main.py \
+ --image_path './assets/imgs/girl33.jpeg' \
+ --image_whratio_unchange \
+ --save_folder './results/' \
+ --prompt_local "A boy smiling." \
+ --prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+ --prompt_target "A young boy with red hair smiles brightly, wearing a red and white checkered shirt." \
+ --mask_path "./assets/mask_temp/mask_208.png"
+```
+
+
+
+
+
+
+
+
+
+From left to right: the original image, the mask image, and the edited result.
+
+
+- **case 3: Add a monkey**
+```bash
+python main.py \
+--image_path './assets/imgs/girl33.jpeg' \
+--image_whratio_unchange \
+--save_folder './results/' \
+--prompt_local "A monkey playing." \
+--prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+--prompt_target "A young girl with red hair smiles brightly, wearing a red and white checkered shirt, a monkey playing nearby." \
+--mask_path "./assets/mask_temp/mask_213.png"
+```
+
+
+
+
+
+
+
+
+
+From left to right: the original image, the mask image, and the edited result.
+
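+- **case 4: Remove the girl** (command taken from `run_infer.sh`; for removal, the source and target prompts are generated automatically by the VLM)
+```bash
+python main.py \
+    --image_path './assets/imgs/girl33.jpeg' \
+    --image_whratio_unchange \
+    --save_folder './results/' \
+    --prompt_local '[remove]' \
+    --mask_path "./assets/mask_temp/mask_208.png" \
+    --dilate_mask
+```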
+From left to right: the original image, the mask image, and the edited result.
+
+
+- **case 5: Replace the girl with a boy + add a monkey**
+```bash
+python main.py \
+ --image_path './assets/imgs/girl33.jpeg' \
+ --image_whratio_unchange \
+ --save_folder './results/' \
+ --prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+ --prompt_local "A boy smiling." \
+ --prompt_local "A monkey playing." \
+ --mask_path "./assets/mask_temp/mask_208.png" \
+ --mask_path "./assets/mask_temp/mask_215.png" \
+ --prompt_target "A young boy wearing a red and white checkered shirt, a monkey playing nearby."
+```
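+
+Note that `--prompt_local` and `--mask_path` can be repeated and are paired in order: the i-th local prompt describes the edit applied inside the i-th mask.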
+
+
+
+
+
+
+
+
+
+
+From left to right: the original image, the two mask images, and the edited result.
+
+
+
+## Performance
+
+
+Experiments are run on Ascend Atlas 800T A2 machines in PyNative mode.
+
+- mindspore 2.7.0
+
+| model | cards | resolution | task | steps | s/Step | s/Image |
+|------------|-------|------------|----------------|-------|--------------|---------------|
+| CannyEdit | 1 | 768x768 | Replace | 50 | 6.12 | 306 |
+| CannyEdit | 1 | 768x768 | Add | 50 | 1.96 | 98 |
+| CannyEdit | 1 | 768x768 | Removal | 50 | 6.6 | 330 |
+| CannyEdit | 1 | 768x768 | Replace + Add | 50 | 5.7 | 285 |
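+
+In the table, `s/Image` is the end-to-end sampling time per edit, i.e. `steps × s/Step`.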
+
+## Acknowledgement
+The codebase is adapted from [x-flux](https://github.com/XLabs-AI/x-flux).
diff --git a/examples/canny_edit/app_mask.py b/examples/canny_edit/app_mask.py
new file mode 100644
index 0000000000..97b6cc14b4
--- /dev/null
+++ b/examples/canny_edit/app_mask.py
@@ -0,0 +1,219 @@
+import argparse
+import base64
+import os
+import sys
+
+import cv2
+import numpy as np
+from flask import Flask, jsonify, render_template, request, send_file
+
+import mindspore as ms
+
+parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
+sam_dir = os.path.join(parent_dir, "sam2")
+sys.path.insert(0, sam_dir)
+
+# Check if we can use the model of SAM2.1
+try:
+ from sam2.build_sam import build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+ SAM_AVAILABLE = True
+except ImportError:
+ SAM_AVAILABLE = False
+
+app = Flask(__name__)
+app.config["UPLOAD_FOLDER"] = "uploads"
+app.config["MASK_TEMP_FOLDER"] = "mask_temp"
+app.config["MAX_CONTENT_LENGTH"] = 16 * 1024 * 1024 # 16MB max file size
+
+# Create the upload and temporary mask directories
+os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
+os.makedirs(app.config["MASK_TEMP_FOLDER"], exist_ok=True)
+
+
+def create_ellipse_mask(points, image_shape):
+ """create oval mask according to trace of points"""
+ if len(points) < 5:
+ # If points are much less, create a circle
+ center = np.mean(points, axis=0)
+ radius = max(10, int(np.max(np.abs(points - center)) / 2))
+ mask = np.zeros(image_shape[:2], dtype=np.uint8)
+ cv2.circle(mask, (int(center[0]), int(center[1])), radius, 255, -1)
+ return mask
+
+    # Fit an ellipse to the traced points
+ points = np.array(points, dtype=np.int32)
+ ellipse = cv2.fitEllipse(points)
+ mask = np.zeros(image_shape[:2], dtype=np.uint8)
+ cv2.ellipse(mask, ellipse, 255, -1)
+ return mask
+
+
+def create_sam_mask(points, image, min_mask_region_area=500):
+ """Create mask using SAM2.1"""
+ if not SAM_AVAILABLE or sam_predictor is None:
+ raise Exception("SAM2.1 model is not available")
+
+ # Set image
+ sam_predictor.set_image(image)
+
+ # Prepare input points
+ input_points = np.array(points)
+    input_labels = np.ones(len(points))  # All points are foreground (positive) prompts
+
+ # Predict mask
+ masks, scores, _ = sam_predictor.predict(
+ point_coords=input_points, point_labels=input_labels, multimask_output=True
+ )
+
+ # Choose one mask with the highest score
+ best_mask_idx = np.argmax(scores)
+ mask = masks[best_mask_idx]
+ mask = mask.astype(np.uint8)
+ mask = remove_small_regions(mask, min_mask_region_area, "holes")
+ mask = remove_small_regions(mask, min_mask_region_area, "islands")
+
+    # Convert to a uint8 image (0/255)
+ mask = (mask * 255).astype(np.uint8)
+ return mask
+
+
+def remove_small_regions(mask, area_thresh, mode):
+    """Remove small disconnected regions ("islands") or fill small holes in a binary mask."""
+    assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:]
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask
+
+
+@app.route("/")
+def index():
+ return render_template("index.html", sam_available=SAM_AVAILABLE)
+
+
+@app.route("/upload", methods=["POST"])
+def upload_image():
+ if "image" not in request.files:
+ return jsonify({"error": "No image file provided"}), 400
+
+ file = request.files["image"]
+ if file.filename == "":
+ return jsonify({"error": "No image selected"}), 400
+
+    # Save the uploaded image
+ filename = file.filename
+ filepath = os.path.join(app.config["UPLOAD_FOLDER"], filename)
+ file.save(filepath)
+
+ # Read an image and get its size
+ image = cv2.imread(filepath)
+ if image is None:
+ return jsonify({"error": "Invalid image file"}), 400
+
+ height, width = image.shape[:2]
+
+ return jsonify({"filename": filename, "filepath": filepath, "width": width, "height": height})
+
+
+@app.route("/process", methods=["POST"])
+def process_image():
+ data = request.json
+ filename = data.get("filename")
+ points = data.get("points", [])
+ method = data.get("method", "ellipse") # 'ellipse' or 'sam'
+
+ if not filename or not points:
+ return jsonify({"error": "Missing filename or points"}), 400
+
+ # Read an image
+ filepath = os.path.join(app.config["UPLOAD_FOLDER"], filename)
+ image = cv2.imread(filepath)
+ if image is None:
+ return jsonify({"error": "Image not found"}), 404
+
+ try:
+        # Create the mask with the chosen method
+ if method == "ellipse":
+ mask = create_ellipse_mask(points, image.shape)
+ elif method == "sam" and SAM_AVAILABLE:
+ mask = create_sam_mask(points, image)
+ else:
+ return jsonify({"error": "Invalid method or SAM not available"}), 400
+
+ # Save mask as a temporary file
+ mask_filename = f"mask_{filename}"
+ mask_file = mask_filename.split(".")[0]
+ count = 0
+ for file in os.listdir(app.config["MASK_TEMP_FOLDER"]):
+ if mask_file in file:
+ count += 1
+
+ mask_path = os.path.join(app.config["MASK_TEMP_FOLDER"], f"{mask_file}_{method}_{count}.png")
+
+ cv2.imwrite(mask_path, mask)
+
+        # Encode the mask as base64
+ _, buffer = cv2.imencode(".png", mask)
+ mask_base64 = base64.b64encode(buffer).decode("utf-8")
+
+ return jsonify(
+ {
+ "mask_filename": mask_filename,
+ "mask_path": mask_path,
+                "mask_data": f"data:image/png;base64,{mask_base64}",
+ }
+ )
+
+ except Exception as e:
+ return jsonify({"error": str(e)}), 500
+
+
+@app.route("/download/")
+def download_mask(filename):
+ mask_path = os.path.join(app.config["MASK_TEMP_FOLDER"], filename)
+ if os.path.exists(mask_path):
+ return send_file(mask_path, as_attachment=True)
+ return jsonify({"error": "File not found"}), 404
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--sam_checkpoint", type=str, default="./checkpoints/sam2.1_hiera_large.pt")
+ parser.add_argument("--model_cfg", type=str, default="./configs/sam2.1/sam2.1_hiera_l.yaml")
+ args = parser.parse_args()
+
+ # try to load model of SAM2.1
+ sam_model = None
+ sam_predictor = None
+ if SAM_AVAILABLE:
+ try:
+ # Note: you should download model weights of SAM2.1, and set the right path
+ sam_checkpoint = args.sam_checkpoint
+
+ if os.path.exists(sam_checkpoint):
+ model_cfg = args.model_cfg
+ sam = build_sam2(model_cfg, sam_checkpoint)
+ dtype = ms.float16
+ sam.to_float(dtype)
+ sam_predictor = SAM2ImagePredictor(sam)
+ print("SAM2.1 model loaded successfully")
+ else:
+ print(f"SAM2.1 model checkpoint not found at {sam_checkpoint}")
+ SAM_AVAILABLE = False
+ except Exception as e:
+ print(f"Error loading SAM2.1 model: {e}")
+ SAM_AVAILABLE = False
+ else:
+ print("SAM2.1 is not available. Please install segment-anything package.")
+ app.run(host="0.0.0.0", port=5000, debug=True)
diff --git a/examples/canny_edit/assets/example_results/result_338.png b/examples/canny_edit/assets/example_results/result_338.png
new file mode 100644
index 0000000000..947d76af51
Binary files /dev/null and b/examples/canny_edit/assets/example_results/result_338.png differ
diff --git a/examples/canny_edit/assets/example_results/result_339.png b/examples/canny_edit/assets/example_results/result_339.png
new file mode 100644
index 0000000000..668d2c7eff
Binary files /dev/null and b/examples/canny_edit/assets/example_results/result_339.png differ
diff --git a/examples/canny_edit/assets/example_results/result_345.png b/examples/canny_edit/assets/example_results/result_345.png
new file mode 100644
index 0000000000..fd70b522f6
Binary files /dev/null and b/examples/canny_edit/assets/example_results/result_345.png differ
diff --git a/examples/canny_edit/assets/example_results/result_346.png b/examples/canny_edit/assets/example_results/result_346.png
new file mode 100644
index 0000000000..fbc2f3b79d
Binary files /dev/null and b/examples/canny_edit/assets/example_results/result_346.png differ
diff --git a/examples/canny_edit/assets/example_results/result_800.png b/examples/canny_edit/assets/example_results/result_800.png
new file mode 100644
index 0000000000..c0b5268722
Binary files /dev/null and b/examples/canny_edit/assets/example_results/result_800.png differ
diff --git a/examples/canny_edit/assets/imgs/girl33.jpeg b/examples/canny_edit/assets/imgs/girl33.jpeg
new file mode 100644
index 0000000000..f8ae8805f5
Binary files /dev/null and b/examples/canny_edit/assets/imgs/girl33.jpeg differ
diff --git a/examples/canny_edit/assets/mask_temp/mask_208.png b/examples/canny_edit/assets/mask_temp/mask_208.png
new file mode 100644
index 0000000000..42863b59a4
Binary files /dev/null and b/examples/canny_edit/assets/mask_temp/mask_208.png differ
diff --git a/examples/canny_edit/assets/mask_temp/mask_209_inverse.png b/examples/canny_edit/assets/mask_temp/mask_209_inverse.png
new file mode 100644
index 0000000000..8ef859fb5a
Binary files /dev/null and b/examples/canny_edit/assets/mask_temp/mask_209_inverse.png differ
diff --git a/examples/canny_edit/assets/mask_temp/mask_213.png b/examples/canny_edit/assets/mask_temp/mask_213.png
new file mode 100644
index 0000000000..10cfb6926a
Binary files /dev/null and b/examples/canny_edit/assets/mask_temp/mask_213.png differ
diff --git a/examples/canny_edit/assets/mask_temp/mask_215.png b/examples/canny_edit/assets/mask_temp/mask_215.png
new file mode 100644
index 0000000000..8daa446de8
Binary files /dev/null and b/examples/canny_edit/assets/mask_temp/mask_215.png differ
diff --git a/examples/canny_edit/assets/page_imgs/grid_image.png b/examples/canny_edit/assets/page_imgs/grid_image.png
new file mode 100644
index 0000000000..e79786a9a1
Binary files /dev/null and b/examples/canny_edit/assets/page_imgs/grid_image.png differ
diff --git a/examples/canny_edit/main.py b/examples/canny_edit/main.py
new file mode 100644
index 0000000000..e539ccd78e
--- /dev/null
+++ b/examples/canny_edit/main.py
@@ -0,0 +1,462 @@
+import argparse
+import os
+import random
+import sys
+import warnings
+
+import numpy as np
+from PIL import Image
+from src.cannyedit_pipeline import CannyEditPipeline
+from src.util import plot_image_with_mask, process_mask
+
+import mindspore as ms
+import mindspore.nn as nn
+
+# Suppress all warnings
+warnings.filterwarnings("ignore")
+
+
+def create_argparser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--qwen_checkpoint_path", default="Qwen/Qwen2.5-VL-7B-Instruct")
+ parser.add_argument(
+ "--prompt_source", type=str, help="The text prompt that describes the source image" # required=True,
+ )
+ parser.add_argument(
+ "--prompt_target", help="The text prompt that describes the targeted image after editing" # required=True,
+ )
+ parser.add_argument(
+ "--prompt_local",
+ action="append",
+ help="The local prompt(s) for edit region(s)",
+ )
+ parser.add_argument(
+ "--mask_path",
+ action="append",
+ help="path(s) of mask(s) indicating the region to edit",
+ )
+ parser.add_argument("--dilate_mask", action="store_true", help="Dilate the mask")
+ parser.add_argument(
+ "--fill_hole_mask",
+ action="store_true",
+ default=True,
+ help="Fill the holes in the mask, useful for the imprecise segmentation masks",
+ )
+ parser.add_argument("--width", type=int, default=768, help="The width for generated image")
+ parser.add_argument("--height", type=int, default=768, help="The height for generated image")
+ parser.add_argument(
+ "--image_whratio_unchange",
+ action="store_true",
+ help="In default we use square input/output, set this to True if you wish to keep the original image width/height ratio unchanged.",
+ )
+ parser.add_argument("--save_folder", type=str, default="./cannyedit_outputs/", help="Folder to save")
+ parser.add_argument(
+ "--neg_prompt2",
+ type=str,
+ default="focus,centered foreground, humans, objects, noise, blurring, low resolution, artifacts, distortion, "
+ "overexposure, and uneven lighting, bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs,"
+ " bad arms, missing legs, missing arms, poorly drawn face, disconnected limbs",
+ help="The input text negative prompt2",
+ )
+ # 'oval, noise, plaid, polka-dot, leopard print, cartoon, unreal, animate, amputation, '
+ parser.add_argument(
+ "--neg_prompt",
+ type=str,
+ default="humans, objects, noise, blurring, low resolution, artifacts, distortion, overexposure, and uneven lighting,"
+ " bad anatomy, bad hands, missing fingers, extra fingers, three hands, three legs, bad arms, missing legs, missing arms,"
+ " poorly drawn face, disconnected limbs",
+ help="The input text negative prompt",
+ )
+ parser.add_argument("--control_weight2", type=float, default=0.7, help="Controlnet model strength (from 0 to 1.0)")
+ parser.add_argument(
+ "--multi_run",
+ action="store_true",
+ help="If true, we will cache the inversion result and previous generation result, and then allow the multi-run edits",
+ )
+ parser.add_argument("--inversion_save_path", type=str, default=None, help="Path to save the inversion result")
+ parser.add_argument(
+ "--generate_save_path", type=str, default=None, help="Path to save the previous generation result"
+ )
+ parser.add_argument("--img_prompt", type=str, default=None, help="Path to input image prompt")
+ parser.add_argument("--neg_img_prompt", type=str, default=None, help="Path to input negative image prompt")
+ parser.add_argument("--local_path", type=str, default=None, help="Local path to the model checkpoint (Controlnet)")
+ parser.add_argument(
+ "--repo_id", type=str, default=None, help="A HuggingFace repo id to download model (Controlnet)"
+ )
+ parser.add_argument("--name", type=str, default=None, help="A filename to download from HuggingFace")
+ parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
+ parser.add_argument("--use_controlnet", action="store_true", help="Load Controlnet model")
+ parser.add_argument("--use_paint", action="store_true", help="Load inpainting model")
+ parser.add_argument("--image_path", type=str, default=None, help="Path to image")
+ parser.add_argument("--control_weight", type=float, default=0.8, help="Controlnet model strength (from 0 to 1.0)")
+ parser.add_argument(
+ "--control_type",
+ type=str,
+ default="canny",
+ choices=("canny", "openpose", "depth", "zoe", "hed", "hough", "tile"),
+ help="Name of controlnet condition, example: canny",
+ )
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="flux-dev",
+ choices=("flux-dev", "flux-dev-fp8", "flux-schnell"),
+ help="Model type to use (flux-dev, flux-dev-fp8, flux-schnell)",
+ )
+ parser.add_argument("--num_steps", type=int, default=50, help="The num_steps for diffusion process")
+ parser.add_argument("--guidance", type=float, default=4, help="The guidance for diffusion process")
+ parser.add_argument(
+ "--seed", type=int, default=random.randint(0, 9999999), help="A seed for reproducible inference"
+ )
+ parser.add_argument("--true_gs", type=float, default=2, help="true guidance")
+ parser.add_argument("--timestep_to_start_cfg", type=int, default=5, help="timestep to start true guidance")
+ return parser
+
+
+def generate_output_by_qwen(qwen_model, qwen_processor, image_path, height, width, prompt, max_new_tokens=128):
+ """
+ Processes an image and text input, passes them through a model, and generates output text.
+
+ Args:
+ qwen_model: The pre-trained model for inference.
+ qwen_processor: The processor for text and vision inputs.
+ image_path (str): Path to the input image.
+ height (int): Original height of the input image.
+ width (int): Original width of the input image.
+ prompt (str): Text prompt to guide the model's generation.
+ max_new_tokens (int): Maximum number of tokens to generate. Default is 128.
+
+ Returns:
+ str: Decoded output text from the model.
+ """
+ # Prepare the messages with resized image and the text prompt
+ from mindone.transformers.models.qwen2_vl.qwen_vl_utils import process_vision_info
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": image_path,
+ "resized_height": int(height / 2.5),
+ "resized_width": int(width / 2.5),
+ },
+ {"type": "text", "text": prompt},
+ ],
+ }
+ ]
+
+ # Process the text template
+ text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ # Process image and video inputs
+ image_inputs, video_inputs = process_vision_info(messages)
+
+ # Prepare the input tensors
+ inputs = qwen_processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="np")
+ for k, v in inputs.items():
+ inputs[k] = ms.Tensor(v)
+
+ # Generate the output
+ generated_ids = qwen_model.generate(**inputs, max_new_tokens=max_new_tokens)
+
+ # Trim the generated IDs to exclude input tokens
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+
+ # Decode the output text
+ output_text = qwen_processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ return output_text
+
+
+def main(args):
+ removal_flag = False
+ mask_path_list = []
+ image = Image.open(args.image_path).convert("RGB")
+ if args.image_whratio_unchange is True:
+ widtho, heighto = image.size
+ maxone = np.max([widtho, heighto])
+        if maxone == widtho:
+            args.height = int(args.width * (heighto / widtho))
+        else:
+            args.width = int(args.height * (widtho / heighto))
+ print("Keep image width/height ratio unchanged, we now use:[width, height]=" + str([args.width, args.height]))
+
+ cannyedit_pipeline = CannyEditPipeline("flux-dev", offload=args.offload)
+ cannyedit_pipeline.set_controlnet(
+ "canny", None, "XLabs-AI/flux-controlnet-canny-v3", "flux-canny-controlnet-v3.safetensors"
+ )
+
+ # Input local prompt
+ if args.prompt_local is None:
+ args.prompt_local = []
+ print("No local prompt provided. Do you want to enter the local prompt here?")
+ resp = input("Press 'y' for yes, anything else for no and exit:").strip().lower()
+ if resp == "y":
+ args.prompt_local.append(input("Enter the first local prompt: "))
+ for kk in range(10):
+ resp = input(
+                "Enter the next local prompt if you have one, or enter 'done' when you have finished all inputs: "
+ )
+ if resp == "done":
+ break
+ else:
+ args.prompt_local.append(resp)
+ else:
+ print("\n")
+ print("Exiting CannyEdit.")
+ sys.exit(1) # Exit with an error code
+
+ for pp_ind in range(len(args.prompt_local)):
+ if "[remove]" in args.prompt_local[pp_ind]:
+ args.prompt_local[pp_ind] = "empty background" + " out-of-focus, atmospheric background"
+
+ # --------------------------------------------------------------------------------------
+    # Read the mask files if provided
+ if args.mask_path is not None:
+ dilate_mask = args.dilate_mask
+ if "empty background" in args.prompt_local[0]:
+ removal_flag = True
+ local_mask = process_mask(
+ args.mask_path[0],
+ args.height,
+ args.width,
+ dilate=dilate_mask,
+ dilation_kernel_size=(5, 5),
+ fill_holes=args.fill_hole_mask,
+ closing_kernel_size=(1, 1),
+ )
+ mask_path_list.append(args.mask_path[0])
+ local_mask_addition = []
+ mask_count = 1
+ for maskp in args.mask_path[1:]:
+ dilate_mask = args.dilate_mask
+ # removal_add
+ if "empty background" in args.prompt_local[mask_count]:
+ removal_flag = True
+ local_mask_addition.append(
+ process_mask(
+ maskp,
+ args.height,
+ args.width,
+ dilate=dilate_mask,
+ dilation_kernel_size=(5, 5),
+ fill_holes=args.fill_hole_mask,
+ closing_kernel_size=(1, 1),
+ )
+ )
+ mask_path_list.append(maskp)
+ mask_count += 1
+
+ else:
+ raise ValueError("mask_path must be provided!")
+
+ result_save_path = ""
+ # Apply vlm to generate source prompt and target prompt if not provided
+ if args.prompt_source is None or args.prompt_target is None:
+ print("no source/target prompt is provided, using QWEN2.5-VL to generate the prompt automatically \n")
+ from transformers import AutoProcessor
+
+ from mindone.transformers import Qwen2_5_VLForConditionalGeneration
+
+ with nn.no_init_parameters():
+ qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ args.qwen_checkpoint_path, mindspore_dtype=ms.bfloat16
+ )
+ min_pixels = 256 * 28 * 28
+ max_pixels = 512 * 28 * 28
+ qwen_processor = AutoProcessor.from_pretrained(
+ args.qwen_checkpoint_path, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ if args.prompt_source is None:
+ output_text = generate_output_by_qwen(
+ qwen_model,
+ qwen_processor,
+ args.image_path,
+ args.height,
+ args.width,
+ "Describe this image in 15 words.",
+ max_new_tokens=128,
+ )
+ args.prompt_source = output_text[0]
+ print("\n")
+ print("VLM generated source prompt: " + args.prompt_source)
+ print("\n")
+ if args.prompt_target is None:
+ # removal_add
+            print(
+                "Important: Currently the automatic generation of target prompts only supports **adding and removal**. If the editing "
+                "involves other tasks such as replacement, please provide the target prompt here; you may refer to the VLM-generated source prompt.\n"
+            )
+ if removal_flag is False:
+ resp = input(
+                    "Press '1' to use the VLM to generate the target prompt, or enter the target prompt directly: \n"
+ )
+ elif removal_flag is True:
+ resp = "1"
+
+ if resp != "1":
+ args.prompt_target = resp
+ print("Entered target prompt: " + args.prompt_target)
+ print("\n")
+ if resp == "1":
+ if removal_flag is False:
+ prompt_for_target = (
+ "Given the caption for this image:"
+ + str(args.prompt_source)
+ + "Suppose there would be new objects in the image:"
+ )
+ words_count = 15
+ for object_add in args.prompt_local:
+ prompt_for_target += object_add + "; and "
+ words_count += 5
+ prompt_for_target += (
+                        "\n Based on the original caption and the descriptions of the new objects, generate the new caption after the objects are added, in "
+ + str(words_count)
+ + " words."
+                    )  # Keep the original caption if possible
+ output_text = generate_output_by_qwen(
+ qwen_model,
+ qwen_processor,
+ args.image_path,
+ args.height,
+ args.width,
+ prompt_for_target,
+ max_new_tokens=128,
+ )
+
+ args.prompt_target = output_text[0]
+ print("VLM generated target prompt: " + args.prompt_target)
+ print("\n")
+ # removal_add
+ elif removal_flag is True:
+ args.prompt_target = " "
+ save_path = plot_image_with_mask(
+ args.image_path,
+ mask_path_list,
+ width=args.width,
+ height=args.height,
+ save_path="assets/mask_temp/masktpimage.png",
+ )
+ output_text = generate_output_by_qwen(
+ qwen_model,
+ qwen_processor,
+ save_path,
+ args.height,
+ args.width,
+                "Suppose the objects within the red bounding box will be removed. Describe the image background excluding"
+ " the removed objects in 10 words.",
+ max_new_tokens=128,
+ )
+
+ args.prompt_target = output_text[0]
+ print("\n")
+ print("VLM generated target prompt for the removal task: " + args.prompt_target)
+ print("\n")
+
+ output_text = generate_output_by_qwen(
+ qwen_model,
+ qwen_processor,
+ save_path,
+ args.height,
+ args.width,
+ "describe the region within the red bounding box in 10 words.",
+ max_new_tokens=128,
+ )
+ args.neg_prompt = output_text[0]
+ print("\n")
+ print("VLM generated negative prompt for the removal task: " + output_text[0])
+ print("\n")
+
+ del qwen_model
+ del qwen_processor
+
+ # --------------------------------------------------------------------------------------
+ print("Running CannyEdit")
+    # First pass: removal stage
+ stage1 = "stage_removal"
+ result = cannyedit_pipeline(
+ prompt_source=args.prompt_source,
+ prompt_local1=args.prompt_local[0],
+ prompt_target=args.prompt_target,
+ prompt_local_addition=args.prompt_local[1:],
+ controlnet_image=image,
+ local_mask=local_mask,
+ local_mask_addition=local_mask_addition,
+ width=args.width,
+ height=args.height,
+ guidance=args.guidance,
+ num_steps=args.num_steps,
+ seed=args.seed,
+ true_gs=args.true_gs,
+ control_weight=args.control_weight,
+ control_weight2=args.control_weight2,
+ neg_prompt=args.neg_prompt,
+ # removal_add
+ neg_prompt2=args.neg_prompt2,
+ timestep_to_start_cfg=args.timestep_to_start_cfg,
+ stage=stage1,
+ generate_save_path=args.generate_save_path,
+ inversion_save_path=args.inversion_save_path,
+ )
+
+ # Save the edited image
+ if not os.path.exists(args.save_folder):
+ os.mkdir(args.save_folder)
+ ind = len(os.listdir(args.save_folder))
+ result_save_path = os.path.join(args.save_folder, f"result_{ind}.png")
+ result.save(result_save_path)
+
+ if removal_flag is False:
+        # Second pass: generation stage (only when the edit is not a pure removal)
+ stage1 = "stage_generate"
+ print("Running CannyEdit")
+ result = cannyedit_pipeline(
+ prompt_source=args.prompt_source,
+ prompt_local1=args.prompt_local[0],
+ prompt_target=args.prompt_target,
+ prompt_local_addition=args.prompt_local[1:],
+ controlnet_image=image,
+ local_mask=local_mask,
+ local_mask_addition=local_mask_addition,
+ width=args.width,
+ height=args.height,
+ guidance=args.guidance,
+ num_steps=args.num_steps,
+ seed=args.seed,
+ true_gs=args.true_gs,
+ control_weight=args.control_weight,
+ control_weight2=args.control_weight2,
+ neg_prompt=args.neg_prompt,
+ neg_prompt2=args.neg_prompt2,
+ timestep_to_start_cfg=args.timestep_to_start_cfg,
+ stage=stage1,
+ generate_save_path=args.generate_save_path,
+ inversion_save_path=args.inversion_save_path,
+ )
+
+ # Save the edited image
+ if not os.path.exists(args.save_folder):
+ os.mkdir(args.save_folder)
+ ind = len(os.listdir(args.save_folder))
+ result_save_path = os.path.join(args.save_folder, f"result_{ind}.png")
+ result.save(result_save_path)
+
+ if result_save_path:
+ print(f"Generated image saved in {result_save_path}")
+
+ # remove all cached files
+ if args.inversion_save_path is not None and os.path.exists(args.inversion_save_path):
+ os.remove(args.inversion_save_path)
+ if args.generate_save_path is not None and os.path.exists(args.generate_save_path):
+ os.remove(args.generate_save_path)
+
+
+if __name__ == "__main__":
+ args = create_argparser().parse_args()
+ main(args)
diff --git a/examples/canny_edit/requirements.txt b/examples/canny_edit/requirements.txt
new file mode 100644
index 0000000000..61cbecb223
--- /dev/null
+++ b/examples/canny_edit/requirements.txt
@@ -0,0 +1,12 @@
+Flask==3.1.2
+huggingface_hub==0.26.0
+matplotlib==3.5.1
+opencv_python==4.10.0.84
+Pillow==11.3.0
+safetensors==0.6.2
+tqdm==4.67.1
+transformers==4.50.0
+hydra-core>=1.3.2
+torch # load SAM2 pytorch weights
+iopath>=0.1.10
+omegaconf>=2.3.0
diff --git a/examples/canny_edit/run_app_mask.sh b/examples/canny_edit/run_app_mask.sh
new file mode 100644
index 0000000000..97b7a2037f
--- /dev/null
+++ b/examples/canny_edit/run_app_mask.sh
@@ -0,0 +1,6 @@
+export DEVICE_ID=0
+export no_proxy="localhost,127.0.0.1"
+
+python app_mask.py \
+--sam_checkpoint ../sam2/checkpoints/sam2.1_hiera_large.pt \
+--model_cfg ../sam2/configs/sam2.1/sam2.1_hiera_l.yaml
diff --git a/examples/canny_edit/run_infer.sh b/examples/canny_edit/run_infer.sh
new file mode 100644
index 0000000000..cf28a093d5
--- /dev/null
+++ b/examples/canny_edit/run_infer.sh
@@ -0,0 +1,56 @@
+export DEVICE_ID=0
+
+# Test case 1: Replace background with mountains
+python main.py \
+ --image_path './assets/imgs/girl33.jpeg' \
+ --image_whratio_unchange \
+ --save_folder './results/' \
+ --prompt_local "A mountain." \
+ --prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+ --prompt_target "A young girl with red hair smiles brightly, wearing a red and white checkered shirt, sitting on a bench with mountains in the background." \
+ --mask_path "./assets/mask_temp/mask_209_inverse.png"
+
+
+## Test case 2: Replace the girl with a boy
+#python main.py \
+# --image_path './assets/imgs/girl33.jpeg' \
+# --image_whratio_unchange \
+# --save_folder './results/' \
+# --prompt_local "A boy smiling." \
+# --prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+# --prompt_target "A young boy with red hair smiles brightly, wearing a red and white checkered shirt." \
+# --mask_path "./assets/mask_temp/mask_208.png"
+
+
+## Test case 3: Add a monkey
+#python main.py \
+#--image_path './assets/imgs/girl33.jpeg' \
+#--image_whratio_unchange \
+#--save_folder './results/' \
+#--prompt_local "A monkey playing." \
+#--prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+#--prompt_target "A young girl with red hair smiles brightly, wearing a red and white checkered shirt, a monkey playing nearby." \
+#--mask_path "./assets/mask_temp/mask_213.png"
+
+
+## Test case 4: Remove the girl
+#python main.py \
+# --image_path './assets/imgs/girl33.jpeg' \
+# --image_whratio_unchange \
+# --save_folder './results/' \
+# --prompt_local '[remove]' \
+# --mask_path "./assets/mask_temp/mask_208.png" \
+# --dilate_mask \
+
+
+## Test case 5: Replace the girl with a boy + add a monkey
+#python main.py \
+# --image_path './assets/imgs/girl33.jpeg' \
+# --image_whratio_unchange \
+# --save_folder './results/' \
+# --prompt_source "A young girl with red hair smiles brightly, wearing a red and white checkered shirt." \
+# --prompt_local "A boy smiling." \
+# --prompt_local "A monkey playing." \
+# --mask_path "./assets/mask_temp/mask_208.png" \
+# --mask_path "./assets/mask_temp/mask_215.png" \
+# --prompt_target "A young boy wearing a red and white checkered shirt, a monkey playing nearby."
diff --git a/examples/canny_edit/src/__init__.py b/examples/canny_edit/src/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/canny_edit/src/annotator/__init__.py b/examples/canny_edit/src/annotator/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/canny_edit/src/annotator/canny/__init__.py b/examples/canny_edit/src/annotator/canny/__init__.py
new file mode 100644
index 0000000000..cb0da951dc
--- /dev/null
+++ b/examples/canny_edit/src/annotator/canny/__init__.py
@@ -0,0 +1,6 @@
+import cv2
+
+
+class CannyDetector:
+ def __call__(self, img, low_threshold, high_threshold):
+ return cv2.Canny(img, low_threshold, high_threshold)
diff --git a/examples/canny_edit/src/annotator/util.py b/examples/canny_edit/src/annotator/util.py
new file mode 100644
index 0000000000..36156978aa
--- /dev/null
+++ b/examples/canny_edit/src/annotator/util.py
@@ -0,0 +1,38 @@
+import os
+
+import cv2
+import numpy as np
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), "ckpts")
+
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def resize_image(input_image, resolution):
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ k = float(resolution) / min(H, W)
+ H *= k
+ W *= k
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+ return img
diff --git a/examples/canny_edit/src/cannyedit_pipeline.py b/examples/canny_edit/src/cannyedit_pipeline.py
new file mode 100644
index 0000000000..11e8ffa37e
--- /dev/null
+++ b/examples/canny_edit/src/cannyedit_pipeline.py
@@ -0,0 +1,281 @@
+from typing import List
+
+import numpy as np
+from PIL import Image
+from src.sampling import denoise_cannyedit, get_image_tensor, get_noise, get_schedule, prepare, unpack
+from src.sampling_removal import denoise_cannyedit_removal
+from src.util import (
+ Annotator,
+ load_ae,
+ load_checkpoint,
+ load_clip,
+ load_controlnet,
+ load_flow_model,
+ load_flow_model_quantized,
+ load_t5,
+)
+
+import mindspore as ms
+
+
+def prepare_conditional_inputs(base_input, suffix):
+ """
+ Helper function to process and restructure the input dictionary.
+ """
+ result = {}
+ for key in ["txt", "txt_ids", "vec"]:
+ result[f"{key}{suffix}"] = base_input[key]
+ base_input.pop("img") # Remove the key from the original dictionary
+ base_input.pop("img_ids") # Remove the key from the original dictionary
+ return result
+
+
+class CannyEditPipeline:
+ def __init__(self, model_type, offload: bool = False):
+ self.offload = offload
+ self.model_type = model_type
+
+ self.ae = load_ae(
+ model_type,
+ )
+
+ self.clip = load_clip()
+ self.t5 = load_t5(max_length=512)
+
+ if "fp8" in model_type:
+ self.model = load_flow_model_quantized(
+ model_type,
+ )
+ else:
+ self.model = load_flow_model(
+ model_type,
+ )
+
+ self.image_encoder_path = "openai/clip-vit-large-patch14"
+ self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
+ self.lora_types_to_names = {
+ "realism": "lora.safetensors",
+ }
+ self.controlnet_loaded = False
+ self.ip_loaded = False
+ self.paint_loaded = False
+
+ def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
+ self.controlnet = load_controlnet(
+ self.model_type,
+ ).to_float(ms.bfloat16)
+ checkpoint = load_checkpoint(local_path, repo_id, name)
+ self.controlnet.load_state_dict(checkpoint, strict=False)
+ self.annotator = Annotator(
+ control_type,
+ )
+ self.controlnet_loaded = True
+ self.control_type = control_type
+
+ def __call__(
+ self,
+ prompt_source: str,
+ prompt_local1: str,
+ prompt_target: str,
+ prompt_local_addition: List[str],
+ controlnet_image: Image = None,
+ local_mask=None,
+ local_mask_addition=[],
+ width: int = 512,
+ height: int = 512,
+ guidance: float = 4,
+ num_steps: int = 50,
+ seed: int = 123456789,
+ true_gs: float = 3,
+ control_weight: float = 0.9,
+ control_weight2: float = 0.5,
+ neg_prompt: str = "",
+ neg_prompt2: str = "",
+ timestep_to_start_cfg: int = 0,
+ generate_save_path=None,
+ inversion_save_path=None,
+ stage=None,
+ ):
+ width = 16 * (width // 16)
+ height = 16 * (height // 16)
+
+ # change: process the source image
+ if self.controlnet_loaded:
+ source_image = controlnet_image.copy()
+ controlnet_cond = self.annotator(controlnet_image, width, height)
+ controlnet_cond = ms.Tensor.from_numpy((np.array(controlnet_cond) / 127.5) - 1)
+ controlnet_cond = controlnet_cond.permute(2, 0, 1).unsqueeze(0).to(ms.bfloat16)
+
+ # change:add parameters
+ return self.construct(
+ prompt_source,
+ prompt_local1,
+ prompt_target,
+ prompt_local_addition,
+ local_mask,
+ local_mask_addition,
+ width,
+ height,
+ guidance,
+ num_steps,
+ seed,
+ source_image,
+ controlnet_cond,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ true_gs=true_gs,
+ control_weight=control_weight,
+ control_weight2=control_weight2,
+ neg_prompt=neg_prompt,
+ neg_prompt2=neg_prompt2,
+ generate_save_path=generate_save_path,
+ inversion_save_path=inversion_save_path,
+ stage=stage,
+ )
+
+ def construct(
+ self,
+ prompt_source,
+ prompt_local1,
+ prompt_target,
+ prompt_local_addition,
+ local_mask,
+ local_mask_addition,
+ width,
+ height,
+ guidance,
+ num_steps,
+ seed,
+ source_image=None,
+ controlnet_cond=None,
+ timestep_to_start_cfg=0,
+ true_gs=3.5,
+ control_weight=0.9,
+ control_weight2=0.5,
+ neg_prompt="",
+ neg_prompt2="",
+ generate_save_path=None,
+ inversion_save_path=None,
+ stage=None,
+ ):
+ x = get_noise(1, height, width, dtype=ms.bfloat16, seed=seed)
+
+ source_image_latent = self.ae.encode(get_image_tensor(source_image, height, width, dtype=ms.float32)).to(
+ ms.bfloat16
+ )
+
+ timesteps = get_schedule(
+ num_steps,
+ (width // 8) * (height // 8) // (16 * 16),
+ shift=True,
+ )
+
+ ms.manual_seed(seed)
+ with ms._no_grad():
+ if self.offload:
+                self.offload_model_to_cpu(self.t5, self.clip)
+
+ inp_cond_im = prepare(t5=self.t5, clip=self.clip, img=source_image_latent, prompt=prompt_source)
+
+ # Prepare inputs with different prompts
+ inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt="a real-world image of " + prompt_source)
+ neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
+ # removal_add
+ neg_inp_cond2 = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt2)
+ # Process inp_cond2, local prompt 1
+ inp_cond2 = prepare(t5=self.t5, clip=self.clip, img=x, prompt="a real-world image of " + prompt_local1)
+ inp_cond2 = prepare_conditional_inputs(inp_cond2, "2")
+ # Process inp_cond3, target prompt
+ inp_cond3 = prepare(t5=self.t5, clip=self.clip, img=x, prompt="a real-world image of " + prompt_target)
+ inp_cond3 = prepare_conditional_inputs(inp_cond3, "3")
+ # Process additional local prompts
+ inp_cond_addition = {}
+ inp_cond_addition["txt_addition"] = []
+ inp_cond_addition["txt_ids_addition"] = []
+ inp_cond_addition["vec_addition"] = []
+ for pp in prompt_local_addition:
+ inp_cond4 = prepare(t5=self.t5, clip=self.clip, img=x, prompt="a real-world image of " + str(pp))
+ inp_cond_addition["txt_addition"].append(inp_cond4["txt"])
+ inp_cond_addition["txt_ids_addition"].append(inp_cond4["txt_ids"])
+ inp_cond_addition["vec_addition"].append(inp_cond4["vec"])
+
+ source_image_latent_rg = inp_cond_im["img"]
+
+ if self.offload:
+ self.offload_model_to_cpu(self.t5, self.clip)
+
+ if self.controlnet_loaded:
+ if stage == "stage_generate":
+ x = denoise_cannyedit(
+ self.model,
+ **inp_cond,
+ **inp_cond2,
+ **inp_cond3,
+ **inp_cond_addition,
+ local_mask=local_mask,
+ local_mask_addition=local_mask_addition,
+ source_image_latent=source_image_latent,
+ source_image_latent_rg=source_image_latent_rg,
+ controlnet=self.controlnet,
+ timesteps=timesteps,
+ guidance=guidance,
+ controlnet_cond=controlnet_cond,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ neg_txt=neg_inp_cond["txt"],
+ neg_txt_ids=neg_inp_cond["txt_ids"],
+ neg_vec=neg_inp_cond["vec"],
+ true_gs=true_gs,
+ controlnet_gs=control_weight,
+ controlnet_gs2=control_weight2,
+ seed=seed,
+ generate_save_path=generate_save_path,
+ inversion_save_path=inversion_save_path,
+ stage=stage,
+ )
+ # removal_add
+ elif stage == "stage_removal":
+ x = denoise_cannyedit_removal(
+ self.model,
+ **inp_cond,
+ **inp_cond2,
+ **inp_cond3,
+ **inp_cond_addition,
+ local_mask=local_mask,
+ local_mask_addition=local_mask_addition,
+ source_image_latent=source_image_latent,
+ source_image_latent_rg=source_image_latent_rg,
+ controlnet=self.controlnet,
+ timesteps=timesteps,
+ guidance=guidance,
+ controlnet_cond=controlnet_cond,
+ timestep_to_start_cfg=timestep_to_start_cfg,
+ neg_txt=neg_inp_cond["txt"] * 0.5 + neg_inp_cond2["txt"] * 0.5,
+ neg_txt_ids=neg_inp_cond["txt_ids"],
+ neg_vec=neg_inp_cond["vec"] * 0.5 + neg_inp_cond2["vec"] * 0.5,
+ true_gs=true_gs,
+ controlnet_gs=control_weight,
+ controlnet_gs2=control_weight2,
+ seed=seed,
+ generate_save_path=generate_save_path,
+ inversion_save_path=inversion_save_path,
+ stage=stage,
+ )
+
+ if self.offload:
+ self.offload_model_to_cpu(self.model)
+
+ x = unpack(x.float(), height, width)
+ x = self.ae.decode(x)
+
+ self.offload_model_to_cpu(self.ae.decoder)
+
+ x1 = x.clamp(-1, 1)
+ # x1 = rearrange(x1[-1], "c h w -> h w c")
+ x1 = x1[-1].permute(1, 2, 0)
+ output_img = Image.fromarray((127.5 * (x1 + 1.0)).asnumpy().astype(np.uint8))
+ return output_img
+
+ def offload_model_to_cpu(self, *models):
+ if not self.offload:
+ return
+ raise NotImplementedError("Offload is not implemented")
diff --git a/examples/canny_edit/src/controlnet.py b/examples/canny_edit/src/controlnet.py
new file mode 100644
index 0000000000..1721b2aeb5
--- /dev/null
+++ b/examples/canny_edit/src/controlnet.py
@@ -0,0 +1,214 @@
+from dataclasses import dataclass
+
+import mindspore
+from mindspore import Parameter, Tensor
+from mindspore.common.initializer import Constant, initializer
+
+from .modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+def zero_module(module):
+ for p in module.get_parameters():
+ constant_(p, 0.0)
+ return module
+
+
+def constant_(tensor: Parameter, val: float) -> None:
+ tensor.set_data(initializer(Constant(val), tensor.shape, tensor.dtype))
+
+
+class ControlNetFlux(mindspore.nn.Cell):
+ """
+ Transformer model for flow matching on sequences.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, params: FluxParams, controlnet_depth=2):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in = mindspore.mint.nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ if params.guidance_embed
+ else mindspore.mint.nn.Identity()
+ )
+ self.txt_in = mindspore.mint.nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = mindspore.nn.CellList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(controlnet_depth)
+ ]
+ )
+
+ # add ControlNet blocks
+ self.controlnet_blocks = mindspore.nn.CellList([])
+ for _ in range(controlnet_depth):
+ controlnet_block = mindspore.mint.nn.Linear(self.hidden_size, self.hidden_size)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_blocks.append(controlnet_block)
+ self.pos_embed_input = mindspore.mint.nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.gradient_checkpointing = False
+ self.input_hint_block = mindspore.nn.SequentialCell(
+ mindspore.mint.nn.Conv2d(3, 16, 3, padding=1),
+ mindspore.mint.nn.SiLU(),
+ mindspore.mint.nn.Conv2d(16, 16, 3, padding=1),
+ mindspore.mint.nn.SiLU(),
+ mindspore.mint.nn.Conv2d(16, 16, 3, padding=1, stride=2),
+ mindspore.mint.nn.SiLU(),
+ mindspore.mint.nn.Conv2d(16, 16, 3, padding=1),
+ mindspore.mint.nn.SiLU(),
+ mindspore.mint.nn.Conv2d(16, 16, 3, padding=1, stride=2),
+ mindspore.mint.nn.SiLU(),
+ mindspore.mint.nn.Conv2d(16, 16, 3, padding=1),
+ mindspore.mint.nn.SiLU(),
+ mindspore.mint.nn.Conv2d(16, 16, 3, padding=1, stride=2),
+ mindspore.mint.nn.SiLU(),
+ zero_module(mindspore.mint.nn.Conv2d(16, 16, 3, padding=1)),
+ )
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @property
+ def attn_processors(self):
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: mindspore.nn.Cell, processors):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: mindspore.nn.Cell, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def construct(
+ self,
+ img: Tensor,
+ img_ids: Tensor,
+ controlnet_cond: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ timesteps: Tensor,
+ y: Tensor,
+ guidance: Tensor | None = None,
+ ) -> Tensor:
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ # running on sequences img
+ img = self.img_in(img)
+ controlnet_cond = self.input_hint_block(controlnet_cond)
+ # controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ b, c, h, w = controlnet_cond.shape
+ h = h // 2 # ph=2
+ w = w // 2 # pw=2
+ controlnet_cond = controlnet_cond.reshape(b, c, h, 2, w, 2)
+ controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+ controlnet_cond = controlnet_cond.reshape(b, h * w, c * 4)
+ controlnet_cond = self.pos_embed_input(controlnet_cond)
+ img = img + controlnet_cond
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+
+ ids = mindspore.mint.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+
+ block_res_samples = ()
+
+ for block in self.double_blocks:
+ if self.training and self.gradient_checkpointing:
+ raise NotImplementedError("Gradient checkpoint is not yet supported.")
+ else:
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+ block_res_samples = block_res_samples + (img,)
+
+ controlnet_block_res_samples = ()
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+ block_res_sample = controlnet_block(block_res_sample)
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+ return controlnet_block_res_samples
diff --git a/examples/canny_edit/src/model.py b/examples/canny_edit/src/model.py
new file mode 100644
index 0000000000..68840464d8
--- /dev/null
+++ b/examples/canny_edit/src/model.py
@@ -0,0 +1,198 @@
+from dataclasses import dataclass
+from typing import Union
+
+from src.modules.layers import DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
+
+from mindspore import Tensor, mint
+from mindspore import nn as nn
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+class Flux(nn.Cell):
+ """
+ Transformer model for flow matching on sequences.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, params: FluxParams):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in = mint.nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else mint.nn.Identity()
+ )
+ self.txt_in = mint.nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = nn.CellList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.CellList(
+ [
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @property
+ def attn_processors(self):
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: nn.Cell, processors):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: nn.Cell, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def construct(
+ self,
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ timesteps: Tensor,
+ y: Tensor,
+ block_controlnet_hidden_states=None,
+ guidance: Union[Tensor, None] = None,
+ image_proj: Union[Tensor, None] = None,
+ ip_scale: Union[Tensor, float] = 1.0,
+ # change: input the attention_kwargs (including the parameters for CannyEdit) into the model
+ attention_kwargs={},
+ ) -> Tensor:
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ # running on sequences img
+ img = self.img_in(img)
+
+ vec = self.time_in(timestep_embedding(timesteps, 256))
+ if self.params.guidance_embed:
+ if guidance is None:
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+
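+ # rotary positional embeddings are computed jointly over the concatenated text and image token ids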
+ ids = mint.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+ # if block_controlnet_hidden_states is not None:
+ # controlnet_depth = len(block_controlnet_hidden_states)
+
+ for index_block, block in enumerate(self.double_blocks):
+ if self.training and self.gradient_checkpointing:
+ raise NotImplementedError("Gradient checkpoint is not yet supported.")
+ else:
+ img, txt = block(
+ img=img,
+ txt=txt,
+ vec=vec,
+ pe=pe,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ # change: input the attention_kwargs (including the parameters for CannyEdit) into the model
+ attention_kwargs=attention_kwargs,
+ )
+ # controlnet residual
+ if block_controlnet_hidden_states is not None:
+ img = img + block_controlnet_hidden_states[index_block % 2]
+
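+ # concatenate text and image streams before the single-stream blocks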
+ img = mint.cat((txt, img), 1)
+
+ for index_block, block in enumerate(self.single_blocks):
+ if self.training and self.gradient_checkpointing:
+ raise NotImplementedError("Gradient checkpoint is not yet supported.")
+ else:
+ # change: input the attention_kwargs (including the parameters for CannyEdit) into the model
+ img = block(img, vec=vec, pe=pe, attention_kwargs=attention_kwargs)
+
+ img = img[:, txt.shape[1] :, ...]
+
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+ return img
diff --git a/examples/canny_edit/src/modules/__init__.py b/examples/canny_edit/src/modules/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/canny_edit/src/modules/autoencoder.py b/examples/canny_edit/src/modules/autoencoder.py
new file mode 100644
index 0000000000..3cbffb9952
--- /dev/null
+++ b/examples/canny_edit/src/modules/autoencoder.py
@@ -0,0 +1,313 @@
+from dataclasses import dataclass
+
+from src.modules.math import scaled_dot_product_attention
+
+import mindspore as ms
+from mindspore import Tensor, mint
+
+
+@dataclass
+class AutoEncoderParams:
+ resolution: int
+ in_channels: int
+ ch: int
+ out_ch: int
+ ch_mult: list[int]
+ num_res_blocks: int
+ z_channels: int
+ scale_factor: float
+ shift_factor: float
+
+
+def swish(x: Tensor) -> Tensor:
+ return x * mint.sigmoid(x)
+
+
+class AttnBlock(ms.nn.Cell):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = mint.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.q = mint.nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.k = mint.nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.v = mint.nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = mint.nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+ def attention(self, h_: Tensor) -> Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q = q.permute(0, 2, 3, 1).reshape(b, 1, h * w, c).contiguous()
+ k = k.permute(0, 2, 3, 1).reshape(b, 1, h * w, c).contiguous()
+ v = v.permute(0, 2, 3, 1).reshape(b, 1, h * w, c).contiguous()
+ h_ = scaled_dot_product_attention(q, k, v)
+
+ return h_.reshape(b, h, w, c).permute(0, 3, 1, 2)
+
+ def construct(self, x: Tensor) -> Tensor:
+ return x + self.proj_out(self.attention(x))
+
+
+class ResnetBlock(ms.nn.Cell):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = mint.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = mint.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = mint.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.conv2 = mint.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = mint.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def construct(self, x):
+ h = x
+ h = self.norm1(h)
+ h = swish(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = swish(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Downsample(ms.nn.Cell):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ # no asymmetric padding in mindspore conv, must do it ourselves
+ self.conv = mint.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def construct(self, x: Tensor):
+ pad = (0, 1, 0, 1)
+ x = mint.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(ms.nn.Cell):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = mint.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def construct(self, x: Tensor):
+ x = mint.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class Encoder(ms.nn.Cell):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = mint.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = ms.nn.CellList()
+ block_in = self.ch
+ for i_level in range(self.num_resolutions):
+ block = ms.nn.CellList()
+ attn = ms.nn.CellList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ down = ms.nn.Cell()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = ms.nn.Cell()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # end
+ self.norm_out = mint.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = mint.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+ def construct(self, x: Tensor) -> Tensor:
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(ms.nn.Cell):
+ def __init__(
+ self,
+ ch: int,
+ out_ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ in_channels: int,
+ resolution: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.ffactor = 2 ** (self.num_resolutions - 1)
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = mint.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = ms.nn.Cell()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # upsampling
+ self.up = ms.nn.CellList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = ms.nn.CellList()
+ attn = ms.nn.CellList()
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ up = ms.nn.Cell()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = mint.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = mint.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def construct(self, z: Tensor) -> Tensor:
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class DiagonalGaussian(ms.nn.Cell):
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
+ super().__init__()
+ self.sample = sample
+ self.chunk_dim = chunk_dim
+
+ def construct(self, z: Tensor) -> Tensor:
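+ # split channels into mean and log-variance; sample with the reparameterization trick when self.sample is True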
+ mean, logvar = mint.chunk(z, 2, dim=self.chunk_dim)
+ if self.sample:
+ std = mint.exp(0.5 * logvar)
+ return mean + std * mint.randn_like(mean)
+ else:
+ return mean
+
+
+class AutoEncoder(ms.nn.Cell):
+ def __init__(self, params: AutoEncoderParams):
+ super().__init__()
+ self.encoder = Encoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.decoder = Decoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ out_ch=params.out_ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.reg = DiagonalGaussian()
+
+ self.scale_factor = params.scale_factor
+ self.shift_factor = params.shift_factor
+
+ def encode(self, x: Tensor) -> Tensor:
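+ # sample a latent from the encoder posterior, then shift and scale it for the diffusion model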
+ z = self.reg(self.encoder(x))
+ z = self.scale_factor * (z - self.shift_factor)
+ return z
+
+ def decode(self, z: Tensor) -> Tensor:
+ z = z / self.scale_factor + self.shift_factor
+ return self.decoder(z)
+
+ def construct(self, x: Tensor) -> Tensor:
+ return self.decode(self.encode(x))
diff --git a/examples/canny_edit/src/modules/conditioner.py b/examples/canny_edit/src/modules/conditioner.py
new file mode 100644
index 0000000000..abd9e616c4
--- /dev/null
+++ b/examples/canny_edit/src/modules/conditioner.py
@@ -0,0 +1,45 @@
+from transformers import CLIPTokenizer, T5Tokenizer
+
+import mindspore as ms
+from mindspore import Tensor
+
+from mindone.transformers import CLIPTextModel, T5EncoderModel
+
+
+class HFEmbedder(ms.nn.Cell):
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
+ super().__init__()
+ self.is_clip = version.startswith("openai")
+ self.max_length = max_length
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+ if self.is_clip:
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+ else:
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+ self.hf_module.set_train(False)
+
+ for param in self.hf_module.get_parameters():
+ param.requires_grad = False
+
+ def construct(self, text: list[str]) -> Tensor:
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="np",
+ )
+
+ outputs = self.hf_module(
+ input_ids=Tensor(batch_encoding["input_ids"]),
+ attention_mask=None,
+ output_hidden_states=False,
+ return_dict=True,
+ )
+ return outputs[self.output_key]
diff --git a/examples/canny_edit/src/modules/layers.py b/examples/canny_edit/src/modules/layers.py
new file mode 100644
index 0000000000..a514c85eaa
--- /dev/null
+++ b/examples/canny_edit/src/modules/layers.py
@@ -0,0 +1,511 @@
+import math
+from dataclasses import dataclass
+from typing import Union
+
+from src.modules.math import apply_rope, attention, rope
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor, mint, ops
+
+from mindone.transformers.mindspore_adapter.utils import _DTYPE_2_MIN
+
+
+def scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, dtype=None, training=True
+):
+ # optionally force the computation dtype (fp16 or bf16)
+ ori_dtype = query.dtype
+ if dtype is not None:
+ query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == ms.bool_:
+ attn_mask = attn_mask.to(ms.float32)
+ attn_mask = attn_mask.masked_fill((1 - attn_mask).to(ms.bool_), _DTYPE_2_MIN[ms.float16])
+ attn_mask = attn_mask.to(query.dtype)
+
+ attn_weight = mint.nn.functional.softmax(
+ mint.matmul(query, mint.transpose(key, -2, -1)) / (query.shape[-1] ** 0.5) + attn_mask,
+ dim=-1,
+ dtype=ms.float32,
+ ).astype(query.dtype)
+ else:
+ L, S = query.shape[-2], key.shape[-2]
+ attn_bias = mint.zeros((L, S), dtype=query.dtype)
+ if is_causal:
+ temp_mask = mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0)
+ attn_bias = ops.masked_fill(attn_bias, mint.logical_not(temp_mask), _DTYPE_2_MIN[ms.float16])
+ attn_bias = attn_bias.to(query.dtype)
+
+ attn_weight = mint.nn.functional.softmax(
+ mint.matmul(query, mint.transpose(key, -2, -1)) / (query.shape[-1] ** 0.5) + attn_bias,
+ dim=-1,
+ dtype=ms.float32,
+ ).astype(query.dtype)
+
+ attn_weight = mint.nn.functional.dropout(attn_weight, p=dropout_p, training=training)
+
+ out = mint.matmul(attn_weight, value)
+ out = out.astype(ori_dtype)
+
+ return out
+
+
+# change
+def scaled_dot_product_attention2(
+ query,
+ key,
+ value,
+ image_size,
+ dropout_p=0.0,
+ is_causal=False,
+ attn_mask=None,
+ union_mask=None,
+ local_mask_list=[],
+ local_t2i_strength=1,
+ context_t2i_strength=1,
+ locali2i_strength=1,
+ local2out_i2i_strength=1,
+ num_edit_region=1,
+ scale=None,
+ enable_gqa=False,
+):
+ L, S = query.shape[-2], key.shape[-2]
+ scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
+ attn_bias = mint.zeros((L, S), dtype=query.dtype)
+ if is_causal:
+ assert attn_mask is None
+ temp_mask = mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0)
+ attn_bias = ops.masked_fill(attn_bias, mint.logical_not(temp_mask), _DTYPE_2_MIN[ms.float16])
+ attn_bias = attn_bias.to(query.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == ms.bool_:
+ attn_bias = ops.masked_fill(attn_bias, mint.logical_not(attn_mask), _DTYPE_2_MIN[ms.float16])
+ else:
+ attn_bias = attn_mask + attn_bias
+
+ if enable_gqa:
+ key = key.repeat_interleave(query.shape[-3] // key.shape[-3], -3)
+ value = value.repeat_interleave(query.shape[-3] // value.shape[-3], -3)
+
+ attn_weight = mint.matmul(query, mint.transpose(key, -2, -1)) * scale_factor
+ attn_weight += attn_bias
+
+ # Attention Amplification
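+ # key layout: the first 512 tokens belong to the target prompt, the next 512 * num_edit_region tokens
+ # to the local prompts, and the remaining image_size tokens are image tokens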
+ # amplify the attention between the local text prompt and local edit region
+ curr_atten = attn_weight[:, :, -image_size:, 512 : 512 * (num_edit_region + 1)].copy()
+ attn_weight[:, :, -image_size:, 512 : 512 * (num_edit_region + 1)] = mint.where(
+ union_mask == 1, curr_atten, curr_atten * (local_t2i_strength)
+ )
+ # amplify the attention between the target prompt and the whole image
+ curr_atten1 = attn_weight[:, :, -image_size:, :512].copy()
+ attn_weight[:, :, -image_size:, :512] = curr_atten1 * (context_t2i_strength)
+
+ for local_mask in local_mask_list:
+ # outside the union of masks is 1
+ mask1_flat = union_mask.flatten() # (local_mask).flatten()
+ mask1_indices = 512 * (num_edit_region + 1) + mint.nonzero(mask1_flat, as_tuple=True)[0]
+ # mask2_flat inside the mask is 1
+ mask2_flat = (1 - local_mask).flatten()
+ mask2_indices = 512 * (num_edit_region + 1) + mint.nonzero(mask2_flat, as_tuple=True)[0]
+ # inside the other masks is 1
+ mask3_flat = 1 - mint.logical_or(mask1_flat.bool(), mask2_flat.bool()).int()
+ mask3_indices = 512 * (num_edit_region + 1) + mint.nonzero(mask3_flat, as_tuple=True)[0]
+
+ # amplify the attention within the edit region
+ attn_weight[:, :, mask2_indices[:, None], mask2_indices] = (
+ locali2i_strength * attn_weight[:, :, mask2_indices[:, None], mask2_indices]
+ )
+ # amplify the attention between the edit region and the bg region
+ attn_weight[:, :, mask2_indices[:, None], mask1_indices] = (
+ local2out_i2i_strength * attn_weight[:, :, mask2_indices[:, None], mask1_indices]
+ )
+ # amplify the attention between the edit region and other edit regions
+ attn_weight[:, :, mask2_indices[:, None], mask3_indices] = (
+ local2out_i2i_strength * attn_weight[:, :, mask2_indices[:, None], mask3_indices]
+ )
+
+ # END of Amplification
+
+ attn_weight = mint.nn.functional.softmax(attn_weight, dim=-1, dtype=ms.float32).astype(query.dtype)
+ attn_weight = mint.nn.functional.dropout(attn_weight, p=dropout_p, training=True)
+
+ out = mint.matmul(attn_weight, value)
+ out = out.astype(query.dtype)
+ return out
+
+
+class EmbedND(nn.Cell):
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def construct(self, ids: Tensor) -> Tensor:
+ n_axes = ids.shape[-1]
+ emb = mint.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+
+ return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ t = time_factor * t
+ half = dim // 2
+ freqs = mint.exp(-math.log(max_period) * mint.arange(start=0, end=half, dtype=ms.float32) / half)
+
+ args = t[:, None].float() * freqs[None]
+ embedding = mint.cat([mint.cos(args), mint.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = mint.cat([embedding, mint.zeros_like(embedding[:, :1])], dim=-1)
+ if t.dtype in [ms.float16, ms.float32, ms.float64, ms.bfloat16]:
+ embedding = embedding.to(t.dtype)
+ return embedding
+
+
+class MLPEmbedder(nn.Cell):
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = mint.nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = mint.nn.SiLU()
+ self.out_layer = mint.nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def construct(self, x: Tensor) -> Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(nn.Cell):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.scale = ms.Parameter(mint.ones(dim))
+
+ def construct(self, x: Tensor):
+ x_dtype = x.dtype
+ x = x.float()
+ rrms = mint.rsqrt(mint.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+ return (x * rrms).to(dtype=x_dtype) * self.scale
+
+
+class QKNorm(nn.Cell):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = RMSNorm(dim)
+ self.key_norm = RMSNorm(dim)
+
+ def construct(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+ q = self.query_norm(q)
+ k = self.key_norm(k)
+ return q.to(v.dtype), k.to(v.dtype)
+
+
+class FLuxSelfAttnProcessor:
+ def __call__(self, attn, x, pe, **attention_kwargs):
+ qkv = attn.qkv(x)
+ # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ B, L, _ = qkv.shape
+ qkv_reshaped = qkv.reshape(B, L, 3, attn.num_heads, -1)
+ q, k, v = qkv_reshaped.permute(2, 0, 3, 1, 4)
+ q, k = attn.norm(q, k, v)
+ x = attention(q, k, v, pe=pe)
+ x = attn.proj(x)
+ return x
+
+
+class SelfAttention(nn.Cell):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.qkv = mint.nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.norm = QKNorm(head_dim)
+ self.proj = mint.nn.Linear(dim, dim)
+
+ def construct(self, x: Tensor):
+ # dummy construct; the actual attention is computed by the block processors
+ return x
+
+
+@dataclass
+class ModulationOut:
+ shift: Tensor
+ scale: Tensor
+ gate: Tensor
+
+
+class Modulation(nn.Cell):
+ def __init__(self, dim: int, double: bool):
+ super().__init__()
+ self.is_double = double
+ self.multiplier = 6 if double else 3
+ self.lin = mint.nn.Linear(dim, self.multiplier * dim, bias=True)
+
+ def construct(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+ out = self.lin(mint.nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+ return (
+ ModulationOut(*out[:3]),
+ ModulationOut(*out[3:]) if self.is_double else None,
+ )
+
+
+class DoubleStreamBlockProcessor:
+ def __call__(self, attn, img, txt, vec, pe, attention_kwargs):
+ img_mod1, img_mod2 = attn.img_mod(vec)
+ txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+ # prepare image for attention
+ img_modulated = attn.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = attn.img_attn.qkv(img_modulated)
+ # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+ B, L, _ = img_qkv.shape
+ img_qkv_reshaped = img_qkv.reshape(B, L, 3, attn.num_heads, attn.head_dim)
+ img_q, img_k, img_v = img_qkv_reshaped.permute(2, 0, 3, 1, 4)
+ img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = attn.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+
+ txt_qkv = attn.txt_attn.qkv(txt_modulated)
+
+ # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+ B, L, _ = txt_qkv.shape
+ txt_qkv_reshaped = txt_qkv.reshape(B, L, 3, attn.num_heads, attn.head_dim)
+ txt_q, txt_k, txt_v = txt_qkv_reshaped.permute(2, 0, 3, 1, 4)
+
+ txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ # change
+ if "regional_attention_mask" in attention_kwargs:
+ q = mint.cat((txt_q, img_q), dim=2)
+ k = mint.cat((txt_k, img_k), dim=2)
+ v = mint.cat((txt_v, img_v), dim=2)
+ q, k = apply_rope(q, k, pe)
+ attention_mask = attention_kwargs["regional_attention_mask"]
+ if "union_mask" in attention_kwargs:
+ x = scaled_dot_product_attention2(
+ q,
+ k,
+ v,
+ attention_kwargs["image_size"],
+ dropout_p=0.0,
+ is_causal=False,
+ attn_mask=attention_mask,
+ union_mask=attention_kwargs["union_mask"],
+ local_mask_list=attention_kwargs["local_mask_all_dilate"],
+ local_t2i_strength=attention_kwargs["local_t2i_strength"],
+ context_t2i_strength=attention_kwargs["context_t2i_strength"],
+ locali2i_strength=attention_kwargs["local_i2i_strength"],
+ local2out_i2i_strength=attention_kwargs["local2out_i2i_strength"],
+ num_edit_region=attention_kwargs["num_edit_region"],
+ )
+
+ else:
+ x = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)
+
+ # attn1 = rearrange(x, "B H L D -> B L (H D)")
+ B, H, L, D = x.shape
+ attn1 = x.permute(0, 2, 1, 3).reshape(B, L, -1)
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+ else:
+ q = mint.cat((txt_q, img_q), dim=2)
+ k = mint.cat((txt_k, img_k), dim=2)
+ v = mint.cat((txt_v, img_v), dim=2)
+ attn1 = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+
+ # calculate the img blocks
+ img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
+
+ # calculate the txt blocks
+ txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
+
+ return img, txt
+
+
+class DoubleStreamBlock(nn.Cell):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
+ super().__init__()
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.head_dim = hidden_size // num_heads
+
+ self.img_mod = Modulation(hidden_size, double=True)
+ self.img_norm1 = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+ self.img_norm2 = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_mlp = nn.SequentialCell(
+ mint.nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ mint.nn.GELU(approximate="tanh"),
+ mint.nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ self.txt_mod = Modulation(hidden_size, double=True)
+ self.txt_norm1 = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+ self.txt_norm2 = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_mlp = ms.nn.SequentialCell(
+ mint.nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ mint.nn.GELU(approximate="tanh"),
+ mint.nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+ processor = DoubleStreamBlockProcessor()
+ self.set_processor(processor)
+
+ def set_processor(self, processor) -> None:
+ self.processor = processor
+
+ def get_processor(self):
+ return self.processor
+
+ def construct(
+ self,
+ img: Tensor,
+ txt: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ image_proj: Tensor = None,
+ ip_scale: float = 1.0,
+ attention_kwargs={},
+ ) -> tuple[Tensor, Tensor]:
+ if image_proj is None:
+ return self.processor(self, img, txt, vec, pe, attention_kwargs)
+ else:
+ return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
+
+
+class SingleStreamBlockProcessor:
+ def __call__(self, attn: nn.Cell, x: Tensor, vec: Tensor, pe: Tensor, attention_kwargs) -> Tensor:
+ mod, _ = attn.modulation(vec)
+ x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+ qkv, mlp = mint.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
+
+ # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+ B, L, _ = qkv.shape
+ qkv_reshaped = qkv.reshape(B, L, 3, attn.num_heads, -1)
+ q, k, v = qkv_reshaped.permute(2, 0, 3, 1, 4)
+ q, k = attn.norm(q, k, v)
+
+ # change
+ if "regional_attention_mask" in attention_kwargs:
+ q, k = apply_rope(q, k, pe)
+ attention_mask = attention_kwargs["regional_attention_mask"]
+
+ if "union_mask" in attention_kwargs:
+ attn_1 = scaled_dot_product_attention2(
+ q,
+ k,
+ v,
+ attention_kwargs["image_size"],
+ dropout_p=0.0,
+ is_causal=False,
+ attn_mask=attention_mask,
+ union_mask=attention_kwargs["union_mask"],
+ local_mask_list=attention_kwargs["local_mask_all_dilate"],
+ local_t2i_strength=attention_kwargs["local_t2i_strength"],
+ context_t2i_strength=attention_kwargs["context_t2i_strength"],
+ locali2i_strength=attention_kwargs["local_i2i_strength"],
+ local2out_i2i_strength=attention_kwargs["local2out_i2i_strength"],
+ num_edit_region=attention_kwargs["num_edit_region"],
+ )
+ else:
+ attn_1 = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)
+
+ # attn_1 = rearrange(attn_1, "B H L D -> B L (H D)")
+ B, H, L, D = attn_1.shape
+ attn_1 = attn_1.permute(0, 2, 1, 3).reshape(B, L, -1)
+ else:
+ attn_1 = attention(q, k, v, pe=pe)
+
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = attn.linear2(mint.cat((attn_1, attn.mlp_act(mlp)), 2))
+ output = x + mod.gate * output
+
+ return output
+
+
+class SingleStreamBlock(nn.Cell):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ """
+
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: Union[float, None] = None):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = mint.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ # proj and mlp_out
+ self.linear2 = mint.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.norm = QKNorm(self.head_dim)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.mlp_act = mint.nn.GELU(approximate="tanh")
+ self.modulation = Modulation(hidden_size, double=False)
+
+ processor = SingleStreamBlockProcessor()
+ self.set_processor(processor)
+
+ def set_processor(self, processor) -> None:
+ self.processor = processor
+
+ def get_processor(self):
+ return self.processor
+
+ def construct(
+ self,
+ x: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ image_proj: Union[Tensor, None] = None,
+ ip_scale: float = 1.0,
+ attention_kwargs={},
+ ) -> Tensor:
+ if image_proj is None:
+ return self.processor(self, x, vec, pe, attention_kwargs)
+ else:
+ return self.processor(self, x, vec, pe, image_proj, ip_scale)
+
+
+class LastLayer(nn.Cell):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = mint.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = mint.nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = ms.nn.SequentialCell(
+ mint.nn.SiLU(), mint.nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def construct(self, x: Tensor, vec: Tensor) -> Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
diff --git a/examples/canny_edit/src/modules/math.py b/examples/canny_edit/src/modules/math.py
new file mode 100644
index 0000000000..9d0dc72185
--- /dev/null
+++ b/examples/canny_edit/src/modules/math.py
@@ -0,0 +1,34 @@
+import mindspore as ms
+from mindspore import Tensor, ops
+
+from mindone.transformers.mindspore_adapter import scaled_dot_product_attention
+
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+ q, k = apply_rope(q, k, pe)
+
+ x = scaled_dot_product_attention(q, k, v)
+ B, H, L, D = x.shape
+
+ # x = rearrange(x, "B H L D -> B L (H D)")
+ x = x.permute(0, 2, 1, 3).reshape(B, L, H * D)
+
+ return x
+
+
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+ assert dim % 2 == 0
+ scale = ops.arange(0, dim, 2, dtype=ms.float64) / dim
+ omega = 1.0 / (theta**scale)
+ out = pos.unsqueeze(-1) * omega
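+ # build a 2x2 rotation matrix [[cos, -sin], [sin, cos]] for every (position, frequency) pair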
+ out = ops.stack([ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1)
+ out = out.reshape(*out.shape[:-1], 2, 2)
+ return out.float()
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
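+ # view the head dimension as consecutive (even, odd) pairs and rotate each pair by its 2x2 matrix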
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
diff --git a/examples/canny_edit/src/sampling.py b/examples/canny_edit/src/sampling.py
new file mode 100644
index 0000000000..309421f54f
--- /dev/null
+++ b/examples/canny_edit/src/sampling.py
@@ -0,0 +1,720 @@
+import math
+import random
+from typing import Callable, Union
+
+import numpy as np
+from src.model import Flux
+from src.modules.conditioner import HFEmbedder
+from tqdm import tqdm
+
+import mindspore as ms
+import mindspore.dataset.transforms as transforms
+import mindspore.dataset.vision as vision
+from mindspore import Tensor
+from mindspore.dataset.vision import Inter
+
+
+def get_noise(
+ num_samples: int,
+ height: int,
+ width: int,
+ dtype: ms.dtype,
+ seed: int,
+):
+ ms.set_seed(seed)
+ return ms.mint.randn(
+ num_samples,
+ 16,
+ # allow for packing
+ 2 * math.ceil(height / 16),
+ 2 * math.ceil(width / 16),
+ dtype=dtype,
+ )
+
+
+def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
+ bs, c, h, w = img.shape
+ if bs == 1 and not isinstance(prompt, str):
+ bs = len(prompt)
+
+ b, c, h, w = img.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+ # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ img = img.reshape(b, c, sh, 2, sw, 2)
+ img = img.permute(0, 2, 4, 1, 3, 5)
+ img = img.reshape(b, sh * sw, c * 4)
+
+ if img.shape[0] == 1 and bs > 1:
+ img = img.broadcast_to((bs, *img.shape[1:]))
+
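+ # per-token position ids: channel 0 is unused (zeros), channels 1 and 2 carry the row and column indices for RoPE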
+ img_ids = ms.mint.zeros((h // 2, w // 2, 3))
+ img_ids[..., 1] = img_ids[..., 1] + ms.mint.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + ms.mint.arange(w // 2)[None, :]
+ # img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) # keep for debugging
+ h, w, _ = img_ids.shape
+ img_ids = img_ids.reshape(1, h, w, 3).broadcast_to((bs, h, w, 3))
+ img_ids = img_ids.reshape(bs, h * w, 3)
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ txt = t5(prompt)
+ # if txt.shape[0] == 1 and bs > 1:
+ # txt = repeat(txt, "1 ... -> bs ...", bs=bs) # keep for debugging
+ if txt.shape[0] == 1 and bs > 1:
+ txt = txt.broadcast_to((bs, *txt.shape[1:]))
+ txt_ids = ms.mint.zeros((bs, txt.shape[1], 3))
+
+ vec = clip(prompt)
+ # if vec.shape[0] == 1 and bs > 1:
+ # vec = repeat(vec, "1 ... -> bs ...", bs=bs) # keep for debugging
+ if vec.shape[0] == 1 and bs > 1:
+ vec = vec.broadcast_to((bs, *vec.shape[1:]))
+
+ return {
+ "img": img,
+ "img_ids": img_ids,
+ "txt": txt,
+ "txt_ids": txt_ids,
+ "vec": vec,
+ }
+
+
+def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
+ m = (y2 - y1) / (x2 - x1)
+ b = y1 - m * x1
+ return lambda x: m * x + b
+
+
+def time_shift(mu: float, sigma: float, t: Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_schedule(
+ num_steps: int,
+ image_seq_len: int,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+ shift: bool = True,
+) -> list[float]:
+ # extra step for zero
+ timesteps = ms.mint.linspace(1, 0, num_steps + 1)
+
+ # shifting the schedule to favor high timesteps for higher signal images
+ if shift:
+ # estimate mu by linear interpolation between two points
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+ timesteps = time_shift(mu, 1.0, timesteps)
+
+ return timesteps.tolist()
+
+
+def denoise(
+ model: Flux,
+ # model input
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ neg_txt: Tensor,
+ neg_txt_ids: Tensor,
+ neg_vec: Tensor,
+ # sampling parameters
+ timesteps: list[float],
+ guidance: float = 4.0,
+ true_gs=1,
+ timestep_to_start_cfg=0,
+ # ip-adapter parameters
+ image_proj: Tensor = None,
+ neg_image_proj: Tensor = None,
+ ip_scale: Tensor | float = 1.0,
+ neg_ip_scale: Tensor | float = 1.0,
+):
+ i = 0
+ # this is ignored for schnell
+ guidance_vec = ms.ops.full((img.shape[0],), guidance, dtype=img.dtype)
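+ # Euler steps along the flow-matching schedule; from timestep_to_start_cfg onwards the prediction is
+ # blended with a negative-prompt prediction using true_gs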
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
+ t_vec = ms.ops.full(
+ (img.shape[0],),
+ t_curr,
+ dtype=img.dtype,
+ )
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ )
+ if i >= timestep_to_start_cfg:
+ neg_pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=neg_txt,
+ txt_ids=neg_txt_ids,
+ y=neg_vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ image_proj=neg_image_proj,
+ ip_scale=neg_ip_scale,
+ )
+ pred = neg_pred + true_gs * (pred - neg_pred)
+ img = img + (t_prev - t_curr) * pred
+ i += 1
+ return img
+
+
+def denoise_fireflow(
+ model: Flux,
+ # model input
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ # sampling parameters
+ timesteps: list[float],
+ inverse,
+ info,
+ guidance: float = 4.0,
+):
+ if inverse:
+ timesteps = timesteps[::-1]
+ guidance_vec = ms.ops.full((img.shape[0],), guidance, dtype=img.dtype)
+
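+ # FireFlow-style integration: each step takes a midpoint (second-order) update, and the midpoint
+ # velocity is reused as the first velocity estimate of the next step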
+ next_step_velocity = None
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+ t_vec = ms.ops.full((img.shape[0],), t_curr, dtype=img.dtype)
+ info["t"] = t_prev if inverse else t_curr
+ info["inverse"] = inverse
+ info["second_order"] = False
+
+ if inverse is True:
+ if next_step_velocity is None:
+ block_res_samples = info["controlnet"](
+ img=img,
+ img_ids=img_ids,
+ controlnet_cond=info["controlnet_cond"],
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=[ij * info["controlnet_gs"] for ij in block_res_samples],
+ )
+
+ else:
+ pred = next_step_velocity
+
+ img_mid = img + (t_prev - t_curr) / 2 * pred
+
+ t_vec_mid = ms.ops.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype)
+ info["second_order"] = True
+
+ block_res_samples = info["controlnet"](
+ img=img_mid,
+ img_ids=img_ids,
+ controlnet_cond=info["controlnet_cond"],
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec_mid,
+ guidance=guidance_vec,
+ )
+
+ pred_mid = model(
+ img=img_mid,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec_mid,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=[ij * info["controlnet_gs"] for ij in block_res_samples],
+ )
+
+ next_step_velocity = pred_mid
+ img = img + (t_prev - t_curr) * pred_mid
+ info[t_curr] = img
+
+ return img, info
+
+
+def process_mask(input_mask, height, width, latent_image, kernel_size=1):
+ """
+ Process the input mask and return processed_mask, dilated_mask, and flattened_mask.
+
+ Args:
+ input_mask (ms.Tensor or None): Input mask tensor or None.
+ height (int): Height to be used for processing.
+ width (int): Width to be used for processing.
+ latent_image (ms.Tensor): Source image latent tensor (used for dtype).
+ kernel_size (int): Size of the dilation kernel (default is 1).
+
+ Returns:
+ tuple: (processed_mask, dilated_mask, flattened_mask)
+ """
+ # Initialize the processed mask based on the input mask
+ if input_mask is None:
+ processed_mask = ms.mint.ones((1, int(height / 16) * int(width / 16), 1))
+ else:
+ processed_mask = input_mask.copy()
+ # Ensure processed_mask matches the latent dtype
+ processed_mask = processed_mask.to(latent_image.dtype)
+
+ # Convert processed_mask to numpy and prepare for dilation
+ processed_mask_np = (1 - processed_mask.copy().float()).asnumpy()
+ processed_mask_np = np.squeeze(processed_mask_np).reshape(int(height / 16), int(width / 16))
+ # Dilation is currently disabled; the mask is passed through unchanged (cv2.dilate examples kept below for reference)
+ dilated_mask_np = processed_mask_np # Example: cv2.dilate(processed_mask_np, kernel, iterations=iterations)
+ dilated_mask_np_larger = processed_mask_np # Example: cv2.dilate(processed_mask_np, (4 * int(height / 512), 4 * int(height / 512)), iterations=iterations)
+ # Convert dilated masks back to mindspore tensors
+ dilated_mask = ms.tensor(dilated_mask_np, dtype=ms.float32).flatten().unsqueeze(1)
+ dilated_mask_larger = ms.tensor(dilated_mask_np_larger, dtype=ms.float32).flatten().unsqueeze(1)
+ # Update processed_mask and dilated_mask_larger
+ processed_mask = 1 - dilated_mask
+ dilated_mask = 1 - dilated_mask_larger
+ # Compute flattened_mask
+ flattened_mask = (1 - processed_mask).flatten()
+ return processed_mask, dilated_mask, flattened_mask
+
+
+def denoise_cannyedit(
+ model: Flux,
+ controlnet: None,
+ source_image_latent: Tensor,
+ source_image_latent_rg: Tensor,
+ img: Tensor,
+ img_ids: Tensor,
+ # source prompt-related embeddings
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ # local prompt 1-related embeddings
+ txt2: Tensor,
+ txt_ids2: Tensor,
+ vec2: Tensor,
+ # target prompt-related embeddings
+ txt3: Tensor,
+ txt_ids3: Tensor,
+ vec3: Tensor,
+ # additional local prompts-related embeddings
+ txt_addition: list[Tensor],
+ txt_ids_addition: list[Tensor],
+ vec_addition: list[Tensor],
+ # negative prompt-related embeddings
+ neg_txt: Tensor,
+ neg_txt_ids: Tensor,
+ neg_vec: Tensor,
+ local_mask,
+ local_mask_addition,
+ controlnet_cond,
+ # sampling parameters
+ timesteps: list[float],
+ guidance: float = 4.0,
+ true_gs=1,
+ controlnet_gs=0.7,
+ controlnet_gs2=0.5,
+ timestep_to_start_cfg=0,
+ # ip-adapter parameters
+ image_proj: Tensor = None,
+ neg_image_proj: Tensor = None,
+ ip_scale: Union[Tensor, float] = 1,
+ neg_ip_scale: Union[Tensor, float] = 1,
+ seed=random.randint(0, 99999),
+ generate_save_path=None,
+ inversion_save_path=None,
+ stage="stage_generate",
+):
+ guidance_vec = ms.ops.full((img.shape[0],), guidance, dtype=img.dtype)
+ t_length = len(timesteps)
+ info_generate = {}
+ time_to_start = 2 # 1
+ i = 0
+
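+ # Overall flow: step 0 inverts the source latent with FireFlow and caches per-timestep latents;
+ # masks and the regional attention mask are then built once; Stage 1 performs regional denoising with
+ # attention amplification, Stage 2 switches to normal denoising under the target prompt, and cached
+ # inversion latents are cyclically blended back into the unedited (background) region.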
+ for t_curr, t_prev in tqdm(
+ zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, desc="CannyEdit Denoising Steps"
+ ):
+ if i == 0:
+ # -------------------------Inversion-------------------------------------------------------------------------
+ timesteps_inv = timesteps # [time_to_start-2:]
+ info = {}
+ info["controlnet_cond"] = controlnet_cond
+ info["controlnet"] = controlnet
+ info["controlnet_gs"] = controlnet_gs2
+ z, info = denoise_fireflow(
+ model,
+ source_image_latent_rg,
+ img_ids,
+ txt,
+ txt_ids,
+ vec,
+ timesteps_inv,
+ guidance=1,
+ inverse=True,
+ info=info,
+ )
+ if inversion_save_path is not None and stage == "stage_generate":
+ np.save(inversion_save_path, info)
+
+ # -----------------------End of Inversion----------------------------------------------------------------------
+
+ # ---------------------Processing mask---------------------------------------------------------------------
+ # print('Denoising Start....')
+ bs, c, h, w = source_image_latent.shape
+ H_use = int(h * 8)
+ W_use = int(w * 8)
+
+ b, c, h, w = source_image_latent.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+ # source_image_latent = rearrange(source_image_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ source_image_latent = source_image_latent.reshape(b, c, sh, 2, sw, 2)
+ source_image_latent = source_image_latent.permute(0, 2, 4, 1, 3, 5)
+ source_image_latent = source_image_latent.reshape(b, sh * sw, c * 4)
+
+ source_image_latent = source_image_latent.broadcast_to((bs, *source_image_latent.shape[1:]))
+ # source_image_latent = repeat(source_image_latent, "1 ... -> bs ...", bs=bs) # keep for debugging
+
+ # process the first local mask
+ # after processing, the value inside the edit region is 0 and is 1 elsewhere in local_mask1_proceed and local_mask1_dilate;
+ # for local_mask1_flat, inside edit region is 1 instead
+ local_mask1_proceed, local_mask1_dilate, local_mask1_flat = process_mask(
+ local_mask, H_use, W_use, source_image_latent, kernel_size=1
+ )
+ # process the additional local masks
+ local_mask_add_proceed = []
+ local_mask_add_dilate = []
+ local_mask_add_flat = []
+ local_mask_all_dilate = []
+ local_mask_all_dilate.append(local_mask1_dilate)
+ if local_mask_addition != []:
+ for local_mask1 in local_mask_addition:
+ local_mask2_proceed, local_mask2_dilate, local_mask2_flat = process_mask(
+ local_mask1, H_use, W_use, source_image_latent, kernel_size=1
+ )
+ local_mask_add_proceed.append(local_mask2_proceed)
+ local_mask_add_dilate.append(local_mask2_dilate)
+ local_mask_add_flat.append(local_mask2_flat)
+ local_mask_all_dilate.append(local_mask2_dilate)
+
+ # initialize the mask (union_mask) used for canny control relaxation and blending, where the value inside the union
+ # of the edit regions is 0 and is 1 elsewhere.
+ if local_mask_addition == []:
+ union_mask = local_mask1_dilate
+ elif local_mask_addition != []:
+ union_inverted = 1 - local_mask1_dilate
+ for mask_dilate in local_mask_add_dilate:
+ mask_dilate_inverted = 1 - mask_dilate
+ union_inverted = ms.mint.logical_or(union_inverted.bool(), mask_dilate_inverted.bool())
+ union_inverted = union_inverted.int()
+ union_mask = 1 - union_inverted
+ # ------------------End of processing mask------------------------------------------------------------------
+
+ # ------------------Handle attention mask-------------------------------------------------------------------
+ # conds/masks: one entry per additional local prompt, plus one for the first local prompt and one for the target prompt
+ conds = [None] * (len(local_mask_add_proceed) + 2)
+ masks = [None] * (len(local_mask_add_proceed) + 2)
+
+ # the first local prompt and mask for the first local edit region
+ conds[1] = txt2
+ masks[1] = 1 - local_mask1_proceed.flatten().unsqueeze(1).repeat(1, conds[1].shape[1])
+ # the additional local prompts and their corresponding local edit regions
+ for indd in range(len(local_mask_add_proceed)):
+ conds[2 + indd] = txt_addition[indd]
+ masks[2 + indd] = 1 - local_mask_add_proceed[indd].flatten().unsqueeze(1).repeat(
+ 1, conds[2 + indd].shape[1]
+ )
+ # the target prompt and its mask see the whole image
+ conds[0] = txt3
+ masks[0] = ms.mint.ones_like(masks[1])
+
+ regional_embeds = ms.mint.cat(conds, dim=1)
+ encoder_seq_len = regional_embeds.shape[1]
+ hidden_seq_len = source_image_latent.shape[1]
+ txt_ids_region = ms.mint.zeros((regional_embeds.shape[1], 3)).to(dtype=txt_ids.dtype).unsqueeze(0)
+
+ # initialize attention mask
+ regional_attention_mask = ms.mint.zeros(
+ (encoder_seq_len + hidden_seq_len, encoder_seq_len + hidden_seq_len), dtype=ms.bool_
+ )
+ num_of_regions = len(masks)
+ each_prompt_seq_len = encoder_seq_len // num_of_regions
+
+ # ================================
+ # T2T, T2I and I2T attention mask
+ # Each text can only see itself
+ # Local prompt can only see/be seen by the local edit region
+ # Target prompt can see/be seen by the whole image
+ for ij in range(num_of_regions):
+ # t2t mask txt attends to itself
+ regional_attention_mask[
+ ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len,
+ ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len,
+ ] = True
+ # t2i and i2t mask
+ regional_attention_mask[
+ ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len, encoder_seq_len:
+ ] = masks[ij].transpose(-1, -2)
+ regional_attention_mask[
+ encoder_seq_len:, ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len
+ ] = masks[ij]
+
+ # ================================
+ # I2I attention mask initialization
+
+ attention_mask_i2i = ms.mint.zeros(
+ (int(H_use / 16) * int(W_use / 16), int(H_use / 16) * int(W_use / 16)), dtype=ms.int32
+ )
+ # background region can only see background region
+ # Find the union of regions where both masks are 0
+ zero_union_mask = local_mask1_flat == 0
+ # Iterate over all masks in the list and combine their conditions
+ for local_mask2_flat in local_mask_add_flat:
+ zero_union_mask &= local_mask2_flat == 0
+ zero_union_indices = ms.mint.nonzero(zero_union_mask, as_tuple=True)[0]
+ # attention_mask_i2i[zero_union_indices[:, None],] = 1
+ attention_mask_i2i[zero_union_indices[:, None], zero_union_indices] = 1
+
+ # edited region can only see the whole image
+ mask1_indices = ms.mint.nonzero(local_mask1_flat, as_tuple=True)[0]
+ attention_mask_i2i[mask1_indices, :] = 1
+
+ for local_mask2_flat in local_mask_add_flat:
+ mask2_indices = ms.mint.nonzero(local_mask2_flat, as_tuple=True)[0]
+ attention_mask_i2i[mask2_indices, :] = 1
+
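+ # note: the block-wise I2I mask built above is not applied; the image-to-image block is set to all ones
+ # so every image token can attend to every other image token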
+ regional_attention_mask[encoder_seq_len:, encoder_seq_len:] = ms.mint.ones_like(
+ regional_attention_mask[encoder_seq_len:, encoder_seq_len:]
+ )
+
+ # ------------------End of Handle attention mask-------------------------------------------------------------------
+
+ apply_local_point = 0.7
+ apply_extenda_point = 0.5
+
+ if timesteps[i] not in info:
+ tempp = timesteps[i + 1]
+ else:
+ tempp = timesteps[i]
+
+ if i >= time_to_start:
+ if i == time_to_start:
+ img = info[tempp]
+ # ------------------Reinitialize each local edit region--------------------------------------------------
+
+ x2 = get_noise(1, H_use, W_use, dtype=ms.bfloat16, seed=seed)
+ bs, c, h, w = x2.shape
+
+ b, c, h, w = x2.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+ # x2 = rearrange(x2, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ x2 = x2.reshape(b, c, sh, 2, sw, 2)
+ x2 = x2.permute(0, 2, 4, 1, 3, 5)
+ x2 = x2.reshape(b, sh * sw, c * 4)
+
+ # x2 = repeat(x2, "1 ... -> bs ...", bs=bs) # keep for debugging
+ x2 = x2.broadcast_to((bs, *x2.shape[1:]))
+ img[:, local_mask1_flat.bool(), :] = x2[:, local_mask1_flat.bool(), :]
+ seed += 1
+
+ for local_mask2_flat in local_mask_add_flat:
+ x3 = get_noise(1, H_use, W_use, dtype=ms.bfloat16, seed=seed)
+ bs, c, h, w = x3.shape
+
+ b, c, h, w = x3.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+ # x3 = rearrange(x3, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ x3 = x3.reshape(b, c, sh, 2, sw, 2)
+ x3 = x3.permute(0, 2, 4, 1, 3, 5)
+ x3 = x3.reshape(b, sh * sw, c * 4)
+
+ # x3= repeat(x3, "1 ... -> bs ...", bs=bs)
+ x3 = x3.broadcast_to((bs, *x3.shape[1:]))
+ img[:, local_mask2_flat.bool(), :] = x3[:, local_mask2_flat.bool(), :]
+ seed += 1
+
+ # ------------------END of Reinitialize each local edit region-------------------------------------------
+
+ # ================================== Start Denoising =================================================
+
+ t_vec = ms.mint.full((img.shape[0],), t_curr, dtype=img.dtype)
+ imgg = info[tempp]
+
+ block_res_samples = controlnet(
+ # use img=imgg if we want to use the original image+noise in the controlnet
+ img=imgg,
+ img_ids=img_ids,
+ controlnet_cond=controlnet_cond,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+ # ------------------Selective Canny Masking-------------------------------------------------------------------
+ soft_masks = []
+ # Build a soft mask from the union mask for each controlnet residual, so the canny control is kept only outside the edit regions
+ for tensor in block_res_samples:
+ soft_masks.append(union_mask.to(dtype=tensor.dtype)) # Add to list
+ block_res_samples = [block_res_samples[i] * soft_masks[i] for i in range(len(block_res_samples))]
+ # ------------------End of Selective Canny Masking--------------------------------------------------------------
+
+ # Stage 1: Regional Denoising
+ if i < int(apply_local_point * t_length):
+ attention_kwargs = {}
+ # *********************************************
+ # union of all local mask
+ attention_kwargs["union_mask"] = union_mask
+ # input each local mask
+ attention_kwargs["local_mask_all_dilate"] = local_mask_all_dilate
+ # number of local edit regions
+ attention_kwargs["num_edit_region"] = 1 + len(txt_addition)
+ # attention_mask
+ attention_kwargs["regional_attention_mask"] = regional_attention_mask.bool()
+ # amplify the attention between the local text prompt and the local edit region
+ attention_kwargs["local_t2i_strength"] = 1 + 0.2 * (1 - (i / (apply_local_point * t_length)))
+ # amplify the attention between the target prompt and the whole image
+ attention_kwargs["context_t2i_strength"] = 1
+ # amplify the attention within each edit region
+ attention_kwargs["local_i2i_strength"] = 1
+ # amplify the attention between each edit region and other regions
+ attention_kwargs["local2out_i2i_strength"] = 1.0
+ attention_kwargs["image_size"] = int(H_use / 16) * int(W_use / 16)
+ # *********************************************
+ if i <= int(t_length * apply_extenda_point):
+ controlnet_control = [ij * controlnet_gs2 for ij in block_res_samples]
+ else:
+ controlnet_control = None
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=regional_embeds,
+ txt_ids=txt_ids_region,
+ y=ms.mint.zeros_like(vec3),
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=controlnet_control,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ attention_kwargs=attention_kwargs,
+ )
+
+ # Stage 2: Normal Denoising
+ else:
+ attention_kwargs = {}
+ if i <= int(t_length * apply_extenda_point):
+ controlnet_control = [ij * controlnet_gs2 for ij in block_res_samples]
+ else:
+ controlnet_control = None
+
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt3,
+ txt_ids=txt_ids3,
+ y=vec3,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=controlnet_control,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ attention_kwargs=attention_kwargs,
+ )
+
+ img = img + (t_prev - t_curr) * pred
+
+ # Cyclical Blending
+ if i < 5 or (i <= 30 and i % 5 == 0):
+ img = ms.mint.where(union_mask == 1, 0.5 * info[tempp] + 0.5 * img, img)
+ elif i <= 40 and i % 10 == 0:
+ img = ms.mint.where(union_mask == 1, 0.2 * info[tempp] + 0.8 * img, img)
+
+ info_generate[i] = img
+
+ i += 1
+
+ if generate_save_path is not None and stage == "stage_generate":
+ np.save(generate_save_path, info_generate)
+ return img
+
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+ b = x.shape[0]
+ h = math.ceil(height / 16)
+ w = math.ceil(width / 16)
+ # return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+ # h=h, w=w, ph=2, pw=2) # keep for debugging
+ x = x.reshape(b, h * w, -1) # -1 will infer the correct size for c*ph*pw
+ c = x.shape[2] // 4 # since ph=pw=2, total is 4
+ x = x.reshape(b, h, w, c, 2, 2)
+ x = x.transpose(0, 3, 1, 4, 2, 5)
+ x = x.reshape(b, c, h * 2, w * 2)
+ return x
+
+
+def get_image_tensor(
+ image,
+ height: int,
+ width: int,
+ dtype: ms.dtype,
+):
+ # transforms used for preprocessing dataset
+ train_transforms = transforms.Compose(
+ [
+ vision.Resize((height, width), interpolation=Inter.BILINEAR),
+ vision.ToTensor(),
+ vision.Normalize([0.5], [0.5]),
+ ]
+ )
+ image_tensor = train_transforms(image)
+ image_tensor = ms.tensor(image_tensor)
+ image_tensor = image_tensor.to(dtype)
+ return image_tensor
+
+
+# def get_image_mask(img, height: int, width: int, dtype: ms.dtype, ):
+# img = np.array(img).astype(np.float32)
+# if len(img.shape) == 3:
+# img = img[:, :, 0]
+#
+# if np.max(img) > 128:
+# img = img / 255
+#
+# img[img > 0.5] = 1.0
+# img[img <= 0.5] = 0.0
+# img = img * 255.0
+#
+# _, img = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY)
+# img = Image.fromarray(img.astype("uint8")).convert("L")
+#
+# resize = vision.Resize((height, width))
+# img = resize(img)
+# toT = vision.ToTensor()
+# img = toT(img)
+# img[img != 0] = 1
+# img = img.unsqueeze(0)
+# img = img.to(dtype)
+# return img
diff --git a/examples/canny_edit/src/sampling_removal.py b/examples/canny_edit/src/sampling_removal.py
new file mode 100644
index 0000000000..10754b1724
--- /dev/null
+++ b/examples/canny_edit/src/sampling_removal.py
@@ -0,0 +1,654 @@
+import math
+import random
+from typing import Callable, Union
+
+import numpy as np
+from src.model import Flux
+from tqdm import tqdm
+
+import mindspore as ms
+from mindspore import Tensor, mint
+
+
+def get_noise(
+ num_samples: int,
+ height: int,
+ width: int,
+ dtype: ms.dtype,
+ seed: int,
+):
+ ms.set_seed(seed)
+ return ms.mint.randn(
+ num_samples,
+ 16,
+ # allow for packing
+ 2 * math.ceil(height / 16),
+ 2 * math.ceil(width / 16),
+ dtype=dtype,
+ )
+
+
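+# Linear function through (x1, y1) and (x2, y2); used to map the packed image sequence
+# length to the schedule shift mu, so longer sequences receive a larger shift.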
+def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
+ m = (y2 - y1) / (x2 - x1)
+ b = y1 - m * x1
+ return lambda x: m * x + b
+
+
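+# Remap a timestep t in (0, 1] as exp(mu) / (exp(mu) + (1/t - 1) ** sigma); a larger mu
+# concentrates the schedule at higher noise levels.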
+def time_shift(mu: float, sigma: float, t: Tensor):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_schedule(
+ num_steps: int,
+ image_seq_len: int,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+ shift: bool = True,
+) -> list[float]:
+ # extra step for zero
+ timesteps = ms.mint.linspace(1, 0, num_steps + 1)
+
+ # shifting the schedule to favor high timesteps for higher signal images
+ if shift:
+        # estimate mu based on linear interpolation between two points
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+ timesteps = time_shift(mu, 1.0, timesteps)
+
+ return timesteps.tolist()
+
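+# Example (illustrative): a 1024x1024 image packs into (1024 / 16) ** 2 = 4096 tokens, so
+# get_schedule(50, 4096) returns 51 timesteps descending from 1.0 to 0.0 with mu = 1.15.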
+
+# Second-order (midpoint) flow sampler used both to invert the source latent and to
+# reconstruct it, with ControlNet residuals injected at every model evaluation.
+def denoise_fireflow(
+ model: Flux,
+ # model input
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ # sampling parameters
+ timesteps: list[float],
+ inverse,
+ info,
+ guidance: float = 4.0,
+):
+ if inverse:
+ timesteps = timesteps[::-1]
+ guidance_vec = ms.ops.full((img.shape[0],), guidance, dtype=img.dtype)
+
+ next_step_velocity = None
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
+ t_vec = ms.ops.full((img.shape[0],), t_curr, dtype=img.dtype)
+ info["t"] = t_prev if inverse else t_curr
+ info["inverse"] = inverse
+ info["second_order"] = False
+
+ if inverse is True:
+ if next_step_velocity is None:
+ block_res_samples = info["controlnet"](
+ img=img,
+ img_ids=img_ids,
+ controlnet_cond=info["controlnet_cond"],
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=[ij * info["controlnet_gs"] for ij in block_res_samples],
+ )
+
+ else:
+ pred = next_step_velocity
+
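+        # Midpoint (second-order) step: estimate the velocity at t_curr + (t_prev - t_curr) / 2
+        # and use it for the full update; the midpoint velocity is cached as next_step_velocity
+        # so the next iteration can skip its first model evaluation.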
+ img_mid = img + (t_prev - t_curr) / 2 * pred
+
+ t_vec_mid = ms.ops.full((img.shape[0],), t_curr + (t_prev - t_curr) / 2, dtype=img.dtype)
+ info["second_order"] = True
+
+ block_res_samples = info["controlnet"](
+ img=img_mid,
+ img_ids=img_ids,
+ controlnet_cond=info["controlnet_cond"],
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec_mid,
+ guidance=guidance_vec,
+ )
+
+ pred_mid = model(
+ img=img_mid,
+ img_ids=img_ids,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec_mid,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=[ij * info["controlnet_gs"] for ij in block_res_samples],
+ )
+
+ next_step_velocity = pred_mid
+ img = img + (t_prev - t_curr) * pred_mid
+ info[t_curr] = img
+
+ return img, info
+
+
+def process_mask(input_mask, height, width, latent_image, kernel_size=1):
+ """
+ Process the input mask and return processed_mask, dilated_mask, and flattened_mask.
+
+ Args:
+ input_mask (ms.Tensor or None): Input mask tensor or None.
+ height (int): Height to be used for processing.
+ width (int): Width to be used for processing.
+ latent_image (ms.Tensor): Source image latent tensor (used for dtype).
+ kernel_size (int): Size of the dilation kernel (default is 1).
+
+ Returns:
+ tuple: (processed_mask, dilated_mask, flattened_mask)
+ """
+ # Initialize the processed mask based on the input mask
+ if input_mask is None:
+ processed_mask = ms.mint.ones((1, int(height / 16) * int(width / 16), 1))
+ else:
+ processed_mask = input_mask.copy()
+    # Ensure processed_mask has the same dtype as the source image latent
+ processed_mask = processed_mask.to(latent_image.dtype)
+ # original_mask = processed_mask.copy()
+ # Convert processed_mask to numpy and prepare for dilation
+ processed_mask_np = (1 - processed_mask.copy().float()).asnumpy()
+ processed_mask_np = np.squeeze(processed_mask_np).reshape(int(height / 16), int(width / 16))
+ # Kernel size and number of iterations for dilation
+ # iterations = 1
+ # kernel = np.ones((kernel_size, kernel_size), np.uint8)
+ # Perform dilation (currently commented out in the original code)
+ dilated_mask_np = processed_mask_np # Example: cv2.dilate(processed_mask_np, kernel, iterations=iterations)
+ dilated_mask_np_larger = processed_mask_np # Example: cv2.dilate(processed_mask_np, (4 * int(height / 512), 4 * int(height / 512)), iterations=iterations)
+ # Convert dilated masks back to mindspore tensors
+ dilated_mask = ms.tensor(dilated_mask_np, dtype=ms.float32).flatten().unsqueeze(1)
+ dilated_mask_larger = ms.tensor(dilated_mask_np_larger, dtype=ms.float32).flatten().unsqueeze(1)
+ # Update processed_mask and dilated_mask_larger
+ processed_mask = 1 - dilated_mask
+ dilated_mask = 1 - dilated_mask_larger
+ # Compute flattened_mask
+ flattened_mask = (1 - processed_mask).flatten()
+ return processed_mask, dilated_mask, flattened_mask
+
+
+def generate_combined_noise(local_mask1_flat, x2, epsilon, idx, info, timesteps):
+ """
+ Generate combined noise based on a local mask and input noise tensors.
+
+ Args:
+ local_mask1_flat (ms.Tensor): A 1D binary tensor indicating the mask (0s and 1s).
+ x2 (ms.Tensor): Noise tensor to be scaled and combined.
+        epsilon (float): Scaling factor for the second noise tensor.
+        idx (int): Index of the current denoising step.
+        info (dict): Latents cached during inversion, keyed by timestep.
+        timesteps (list[float]): Timestep schedule used during denoising.
+
+    Returns:
+        ms.Tensor: Combined, re-normalized noise tensor for the masked region.
+ """
+ # Count the number of 1s in the mask
+ num_ones = int(mint.sum(local_mask1_flat).item())
+ # Create a tensor `b` with the same size as `local_mask1_flat` and initialize it with 0
+ b = mint.zeros_like(local_mask1_flat)
+ # Get the indices of 0s and 1s in `local_mask1_flat`
+ zero_indices = mint.where(local_mask1_flat == 0)[0]
+ one_indices = mint.where(local_mask1_flat == 1)[0]
+
+ # Handle sampling based on the number of zeros and ones
+ if len(zero_indices) >= num_ones:
+ # Randomly select `num_ones` indices from the 0 region
+ selected_indices = mint.randperm(len(zero_indices))[:num_ones]
+ b[zero_indices[selected_indices]] = 1
+ else:
+ # Select all 0 indices and calculate remaining
+ b[zero_indices] = 1
+ remaining = num_ones - len(zero_indices)
+ # Randomly select the remaining indices from the 1 region
+ selected_indices = mint.randperm(len(one_indices))[:remaining]
+ b[one_indices[selected_indices]] = 1
+
+ # Extract and normalize noise1 based on the updated mask
+ noise1 = info[timesteps[idx + 10]][:, b.bool(), :]
+ mean = mint.mean(noise1)
+ std = mint.std(noise1)
+ noise1 = (noise1 - mean) / std
+
+ # Extract and normalize noise2 based on the original mask
+ noise2 = x2[:, local_mask1_flat.bool(), :]
+ mean = mint.mean(noise2)
+ std = mint.std(noise2)
+ noise2 = (noise2 - mean) / std
+ # Combine noise1 and scaled noise2
+ combined_noise = noise1 + epsilon * noise2
+ # Normalize the combined noise
+ mean = mint.mean(combined_noise)
+ std = mint.std(combined_noise)
+ combined_noise = (combined_noise - mean) / std
+
+ return combined_noise
+
+
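+# CannyEdit removal sampler: first inverts the source latent with denoise_fireflow, then
+# denoises with regional prompts and attention masks, selective canny masking of the
+# ControlNet residuals, negative-prompt CFG inside the edit regions, and cyclical blending
+# of the background with the cached inversion latents.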
+def denoise_cannyedit_removal(
+ model: Flux,
+ controlnet: None,
+ source_image_latent: Tensor,
+ source_image_latent_rg: Tensor,
+ img: Tensor,
+ img_ids: Tensor,
+ # source prompt-related embeddings
+ txt: Tensor,
+ txt_ids: Tensor,
+ vec: Tensor,
+ # local prompt 1-related embeddings
+ txt2: Tensor,
+ txt_ids2: Tensor,
+ vec2: Tensor,
+ # target prompt-related embeddings
+ txt3: Tensor,
+ txt_ids3: Tensor,
+ vec3: Tensor,
+ # additional local prompts-related embeddings
+ txt_addition: list[Tensor],
+ txt_ids_addition: list[Tensor],
+ vec_addition: list[Tensor],
+ # negative prompt-related embeddings
+ neg_txt: Tensor,
+ neg_txt_ids: Tensor,
+ neg_vec: Tensor,
+ local_mask,
+ local_mask_addition,
+ controlnet_cond,
+ # sampling parameters
+ timesteps: list[float],
+ guidance: float = 4.0,
+ true_gs=1,
+ controlnet_gs=0.7,
+ controlnet_gs2=0.5,
+ timestep_to_start_cfg=0,
+ # ip-adapter parameters
+ image_proj: Tensor = None,
+ neg_image_proj: Tensor = None,
+ ip_scale: Union[Tensor, float] = 1,
+ neg_ip_scale: Union[Tensor, float] = 1,
+ seed=random.randint(0, 99999),
+ generate_save_path=None,
+ inversion_save_path=None,
+ stage="stage_removal",
+):
+ guidance_vec = ms.ops.full((img.shape[0],), guidance, dtype=img.dtype)
+ t_length = len(timesteps)
+    if generate_save_path is not None:
+        info_generate = {}  # intermediate latents keyed by denoising step index
+
+ time_to_start = 2
+ i = 0
+
+ for t_curr, t_prev in tqdm(
+ zip(timesteps[:-1], timesteps[1:]), total=len(timesteps) - 1, desc="CannyEdit Denoising Steps"
+ ):
+ if i == 0:
+ # -------------------------Inversion-------------------------------------------------------------------------
+
+ if stage == "stage_removal":
+ timesteps_inv = timesteps # [time_to_start-2:]
+ info = {}
+ info["controlnet_cond"] = controlnet_cond
+ info["controlnet"] = controlnet
+ info["controlnet_gs"] = controlnet_gs2
+
+ z, info = denoise_fireflow(
+ model,
+ source_image_latent_rg,
+ img_ids,
+ txt,
+ txt_ids,
+ vec,
+ timesteps_inv,
+ guidance=1,
+ inverse=True,
+ info=info,
+ )
+
+ if inversion_save_path is not None and stage == "stage_removal":
+ np.save(inversion_save_path, info)
+
+ if stage == "stage_removal_regen" and inversion_save_path is not None:
+ print("load previous inversion results")
+ info = np.load(inversion_save_path, allow_pickle=True).item()
+ # print('Inversion End....')
+
+ # -----------------------End of Inversion----------------------------------------------------------------------
+
+ # ---------------------Processing mask---------------------------------------------------------------------
+ # print('Denoising Start....')
+ bs, c, h, w = source_image_latent.shape
+ H_use = int(h * 8)
+ W_use = int(w * 8)
+
+ b, c, h, w = source_image_latent.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+            # source_image_latent = rearrange(source_image_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ source_image_latent = source_image_latent.reshape(b, c, sh, 2, sw, 2)
+ source_image_latent = source_image_latent.permute(0, 2, 4, 1, 3, 5)
+ source_image_latent = source_image_latent.reshape(b, sh * sw, c * 4)
+
+ source_image_latent = source_image_latent.broadcast_to((bs, *source_image_latent.shape[1:]))
+            # source_image_latent = repeat(source_image_latent, "1 ... -> bs ...", bs=bs) # keep for debugging
+
+ # process the first local mask
+            # after processing, local_mask1_proceed and local_mask1_dilate are 0 inside the edit region and 1 elsewhere;
+            # local_mask1_flat is 1 inside the edit region and 0 elsewhere
+ local_mask1_proceed, local_mask1_dilate, local_mask1_flat = process_mask(
+ local_mask, H_use, W_use, source_image_latent, kernel_size=1
+ )
+ # process the additional local masks
+ local_mask_add_proceed = []
+ local_mask_add_dilate = []
+ local_mask_add_flat = []
+ local_mask_all_dilate = []
+ local_mask_all_dilate.append(local_mask1_dilate)
+ if local_mask_addition != []:
+ for local_mask1 in local_mask_addition:
+ local_mask2_proceed, local_mask2_dilate, local_mask2_flat = process_mask(
+ local_mask1, H_use, W_use, source_image_latent, kernel_size=1
+ )
+ local_mask_add_proceed.append(local_mask2_proceed)
+ local_mask_add_dilate.append(local_mask2_dilate)
+ local_mask_add_flat.append(local_mask2_flat)
+ local_mask_all_dilate.append(local_mask2_dilate)
+
+ # initialize the mask (union_mask) used for canny control relaxation and blending, where the value inside the union of the edit
+ # regions is 0 and is 1 elsewhere.
+ if local_mask_addition == []:
+ union_mask = local_mask1_dilate
+ elif local_mask_addition != []:
+ union_inverted = 1 - local_mask1_dilate
+ for mask_dilate in local_mask_add_dilate:
+ mask_dilate_inverted = 1 - mask_dilate
+ union_inverted = ms.mint.logical_or(union_inverted.bool(), mask_dilate_inverted.bool())
+ union_inverted = union_inverted.int()
+ union_mask = 1 - union_inverted
+ # ------------------End of processing mask------------------------------------------------------------------
+
+ # ------------------Handle attention mask-------------------------------------------------------------------
+            # number of conds/masks = number of additional local prompts + 1 (the first local prompt) + 1 (the target prompt)
+ conds = [None] * (len(local_mask_add_proceed) + 2)
+ masks = [None] * (len(local_mask_add_proceed) + 2)
+
+ # the first local prompt and mask for the first local edit region
+ conds[1] = txt2
+ masks[1] = 1 - local_mask1_proceed.flatten().unsqueeze(1).repeat(1, conds[1].shape[1])
+ # the additional local prompts and their corresponding local edit regions
+ for indd in range(len(local_mask_add_proceed)):
+ conds[2 + indd] = txt_addition[indd]
+ masks[2 + indd] = 1 - local_mask_add_proceed[indd].flatten().unsqueeze(1).repeat(
+ 1, conds[2 + indd].shape[1]
+ )
+ # the target prompt and its mask see the whole image
+ conds[0] = txt3
+ masks[0] = ms.mint.ones_like(masks[1])
+
+ regional_embeds = ms.mint.cat(conds, dim=1)
+ encoder_seq_len = regional_embeds.shape[1]
+ hidden_seq_len = source_image_latent.shape[1]
+ txt_ids_region = ms.mint.zeros((regional_embeds.shape[1], 3)).to(dtype=txt_ids.dtype).unsqueeze(0)
+
+ # initialize attention mask
+ regional_attention_mask = ms.mint.zeros(
+ (encoder_seq_len + hidden_seq_len, encoder_seq_len + hidden_seq_len), dtype=ms.bool
+ )
+ num_of_regions = len(masks)
+ each_prompt_seq_len = encoder_seq_len // num_of_regions
+
+ # ================================
+ # T2T, T2I and I2T attention mask
+ # Each text can only see itself
+ # Local prompt can only see/be seen by the local edit region
+ # Target prompt can see/be seen by the whole image
+ for ij in range(num_of_regions):
+ # t2t mask txt attends to itself
+ regional_attention_mask[
+ ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len,
+ ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len,
+ ] = True
+ # t2i and i2t mask
+ regional_attention_mask[
+ ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len, encoder_seq_len:
+ ] = masks[ij].transpose(-1, -2)
+ regional_attention_mask[
+ encoder_seq_len:, ij * each_prompt_seq_len : (ij + 1) * each_prompt_seq_len
+ ] = masks[ij]
+
+ # ================================
+ # I2I mask
+ # I2I attention mask:
+ # initialization
+
+ attention_mask_i2i = ms.mint.zeros(
+ (int(H_use / 16) * int(W_use / 16), int(H_use / 16) * int(W_use / 16)), dtype=ms.int
+ )
+ # background region can only see background region
+ # Find the union of regions where both masks are 0
+ zero_union_mask = local_mask1_flat == 0
+ # Iterate over all masks in the list and combine their conditions
+ for local_mask2_flat in local_mask_add_flat:
+ zero_union_mask &= local_mask2_flat == 0
+ zero_union_indices = ms.mint.nonzero(zero_union_mask, as_tuple=True)[0]
+ # attention_mask_i2i[zero_union_indices[:, None],] = 1
+ attention_mask_i2i[zero_union_indices[:, None], zero_union_indices] = 1
+
+            # edited regions can attend to the whole image
+ mask1_indices = ms.mint.nonzero(local_mask1_flat, as_tuple=True)[0]
+ attention_mask_i2i[mask1_indices, :] = 1
+
+ for local_mask2_flat in local_mask_add_flat:
+ mask2_indices = ms.mint.nonzero(local_mask2_flat, as_tuple=True)[0]
+ attention_mask_i2i[mask2_indices, :] = 1
+
+ regional_attention_mask[encoder_seq_len:, encoder_seq_len:] = ms.mint.ones_like(
+ regional_attention_mask[encoder_seq_len:, encoder_seq_len:]
+ )
+ # regional_attention_mask[encoder_seq_len:, encoder_seq_len:] = attention_mask_i2i
+ # ================================
+
+ # ------------------End of Handle attention mask-------------------------------------------------------------------
+
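+        # Fractions of the schedule where each mechanism is active: regional (multi-prompt)
+        # denoising runs while i < apply_local_point * t_length, and the ControlNet residuals
+        # are applied while i <= apply_extenda_point * t_length.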
+ apply_local_point = 0.7
+ apply_extenda_point = 0.5
+
+        # use the inversion latent from ~10 steps ahead for blending; fall back to the current step near the end of the schedule
+ try:
+ tempp = timesteps[i + 10]
+ except IndexError:
+ tempp = timesteps[i]
+
+ if i >= time_to_start:
+ if i == time_to_start:
+ img = info[tempp]
+ # ------------------Reinitialize each local edit region--------------------------------------------------
+
+ x2 = get_noise(1, H_use, W_use, dtype=ms.bfloat16, seed=seed)
+ bs, c, h, w = x2.shape
+
+ b, c, h, w = x2.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+ # x2 = rearrange(x2, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ x2 = x2.reshape(b, c, sh, 2, sw, 2)
+ x2 = x2.permute(0, 2, 4, 1, 3, 5)
+ x2 = x2.reshape(b, sh * sw, c * 4)
+
+ # x2 = repeat(x2, "1 ... -> bs ...", bs=bs) # keep for debugging
+ x2 = x2.broadcast_to((bs, *x2.shape[1:]))
+ epsilon = 3
+ combined_noise = generate_combined_noise(local_mask1_flat, x2, epsilon, i, info, timesteps)
+ img[:, local_mask1_flat.bool(), :] = combined_noise
+ seed += 1
+
+ for local_mask2_flat in local_mask_add_flat:
+ x3 = get_noise(1, H_use, W_use, dtype=ms.bfloat16, seed=seed)
+ bs, c, h, w = x3.shape
+
+ b, c, h, w = x3.shape
+ sh = h // 2 # ph=2
+ sw = w // 2 # pw=2
+ # x3 = rearrange(x3, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) # keep for debugging
+ x3 = x3.reshape(b, c, sh, 2, sw, 2)
+ x3 = x3.permute(0, 2, 4, 1, 3, 5)
+ x3 = x3.reshape(b, sh * sw, c * 4)
+
+ # x3= repeat(x3, "1 ... -> bs ...", bs=bs)
+ x3 = x3.broadcast_to((bs, *x3.shape[1:]))
+ combined_noise = generate_combined_noise(local_mask2_flat, x3, epsilon, i, info, timesteps)
+ img[:, local_mask2_flat.bool(), :] = combined_noise
+ seed += 1
+
+ # ------------------END of Reinitialize each local edit region-------------------------------------------
+
+ # ================================== Start Denoising =================================================
+
+ t_vec = ms.mint.full((img.shape[0],), t_curr, dtype=img.dtype)
+ imgg = info[tempp]
+
+ block_res_samples = controlnet(
+                # img=imgg feeds the noisy source latent from inversion to the controlnet instead of the current img
+ img=imgg,
+ img_ids=img_ids,
+ controlnet_cond=controlnet_cond,
+ txt=txt,
+ txt_ids=txt_ids,
+ y=vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ )
+ # ------------------Selective Canny Masking-------------------------------------------------------------------
+ soft_masks = []
+ # Generate a Gaussian soft mask for each tensor
+ for tensor in block_res_samples:
+ soft_masks.append(union_mask.to(dtype=tensor.dtype)) # Add to list
+ # print("soft_masks applied!")
+ block_res_samples = [block_res_samples[i] * soft_masks[i] for i in range(len(block_res_samples))]
+ # ------------------End of Selective Canny Masking--------------------------------------------------------------
+
+ # Stage 1: Regional Denoising
+ if i < int(apply_local_point * t_length):
+ attention_kwargs = {}
+ # *********************************************
+ # union of all local mask
+ attention_kwargs["union_mask"] = union_mask
+ # input each local mask
+ attention_kwargs["local_mask_all_dilate"] = local_mask_all_dilate
+ # number of local edit regions
+ attention_kwargs["num_edit_region"] = 1 + len(txt_addition)
+ # attention_mask
+ attention_kwargs["regional_attention_mask"] = regional_attention_mask.bool()
+                # amplify the attention between the local text prompt and local edit region
+ attention_kwargs["local_t2i_strength"] = 1 + 0.4 * (1 - (i / (apply_local_point * t_length)))
+ # amplify the attention between the target prompt and the whole image
+ attention_kwargs["context_t2i_strength"] = 1.0
+ # amplify the attention within each edit region
+ attention_kwargs["local_i2i_strength"] = 1.0
+ # amplify the attention between each edit region and other regions
+ attention_kwargs["local2out_i2i_strength"] = 1.0 + (0.35) * (1 - (i / (apply_local_point * t_length)))
+ attention_kwargs["image_size"] = int(H_use / 16) * int(W_use / 16)
+ # *********************************************
+ if i <= int(t_length * apply_extenda_point):
+ controlnet_control = [ij * controlnet_gs2 for ij in block_res_samples]
+ else:
+ controlnet_control = None
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=regional_embeds,
+ txt_ids=txt_ids_region,
+ y=ms.mint.zeros_like(vec3),
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=controlnet_control,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ attention_kwargs=attention_kwargs,
+ )
+
+ # Stage 2: Normal Denoising
+ else:
+ attention_kwargs = {}
+ if i <= int(t_length * apply_extenda_point):
+ controlnet_control = [ij * controlnet_gs2 for ij in block_res_samples]
+ else:
+ controlnet_control = None
+
+ pred = model(
+ img=img,
+ img_ids=img_ids,
+ txt=txt3,
+ txt_ids=txt_ids3,
+ y=vec3,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=controlnet_control,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ attention_kwargs=attention_kwargs,
+ )
+
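+            # Negative-prompt CFG: for the first cfg_step steps, classifier-free guidance is
+            # applied only inside the edit regions (union_mask == 0), with a scale that decays
+            # linearly from 3 + true_gs down to 3; the background keeps the unguided prediction.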
+ cfg_step = 30
+ if i < cfg_step:
+ neg_pred1 = model(
+ img=img,
+ img_ids=img_ids,
+ txt=neg_txt,
+ txt_ids=neg_txt_ids,
+ y=neg_vec,
+ timesteps=t_vec,
+ guidance=guidance_vec,
+ block_controlnet_hidden_states=controlnet_control,
+ image_proj=image_proj,
+ ip_scale=ip_scale,
+ )
+
+ true_gs1 = 3 + (true_gs) * (1 - (i / (cfg_step)))
+ pred = mint.where(union_mask == 1, pred, neg_pred1 + true_gs1 * (pred - neg_pred1))
+
+ img = img + (t_prev - t_curr) * pred
+
+ # Cyclical Blending
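+            # Periodically pull the background (union_mask == 1) back toward the cached
+            # inversion latent info[tempp] so regions outside the edit masks stay faithful
+            # to the source image.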
+ if i < 10 or (i <= 30 and i % 5 == 0):
+ img = mint.where(union_mask == 1, 0.5 * info[tempp] + 0.5 * img, img)
+ elif i <= 40 and i % 10 == 0:
+ img = mint.where(union_mask == 1, 0.3 * info[tempp] + 0.7 * img, img)
+
+ if generate_save_path is not None:
+ info_generate[i] = img
+
+ i += 1
+
+ if generate_save_path is not None and stage == "stage_generate":
+ np.save(generate_save_path, info_generate)
+ return img
+
+
+def unpack(x: Tensor, height: int, width: int) -> Tensor:
+ b = x.shape[0]
+ h = math.ceil(height / 16)
+ w = math.ceil(width / 16)
+ # return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+ # h=h, w=w, ph=2, pw=2) # keep for debugging
+ x = x.reshape(b, h * w, -1) # -1 will infer the correct size for c*ph*pw
+ c = x.shape[2] // 4 # since ph=pw=2, total is 4
+ x = x.reshape(b, h, w, c, 2, 2)
+ x = x.transpose(0, 3, 1, 4, 2, 5)
+ x = x.reshape(b, c, h * 2, w * 2)
+ return x
diff --git a/examples/canny_edit/src/util.py b/examples/canny_edit/src/util.py
new file mode 100644
index 0000000000..500a9e6d65
--- /dev/null
+++ b/examples/canny_edit/src/util.py
@@ -0,0 +1,395 @@
+import os
+from dataclasses import dataclass
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from safetensors import safe_open
+from transformers.utils import is_safetensors_available
+
+import mindspore
+import mindspore as ms
+from mindspore.nn.utils import no_init_parameters
+
+from .annotator.canny import CannyDetector
+from .controlnet import ControlNetFlux
+from .model import Flux, FluxParams
+from .modules.autoencoder import AutoEncoder, AutoEncoderParams
+from .modules.conditioner import HFEmbedder
+
+
+def load_safetensors(path):
+ if path.endswith(".safetensors") and is_safetensors_available():
+ # Check format of the archive
+ with safe_open(path, framework="np") as f:
+ metadata = f.metadata()
+ if metadata is not None:
+ format = metadata.get("format", None)
+ if format is not None and format not in ["pt", "tf", "flax", "np"]:
+ raise OSError(
+ f"The safetensors archive passed at {path} does not contain the valid metadata. Make sure "
+ "you save your model with the `save_pretrained` method."
+ )
+ return ms.load_checkpoint(path, format="safetensors")
+
+
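+# Infer the LoRA rank from the shape of the first "*.down.weight" tensor in the checkpoint.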
+def get_lora_rank(checkpoint):
+ for k in checkpoint.keys():
+ if k.endswith(".down.weight"):
+ return checkpoint[k].shape[0]
+
+
+def load_checkpoint(local_path, repo_id, name):
+ if local_path is not None:
+ if ".safetensors" in local_path:
+ print(f"Loading .safetensors checkpoint from {local_path}")
+ checkpoint = load_safetensors(local_path)
+ else:
+ print(f"Loading checkpoint from {local_path}")
+ checkpoint = ms.load_checkpoint(local_path)
+ elif repo_id is not None and name is not None:
+ print(f"Loading checkpoint {name} from repo id {repo_id}")
+ checkpoint = load_from_repo_id(repo_id, name)
+ else:
+ raise ValueError("LOADING ERROR: you must specify local_path or repo_id with name in HF to download")
+ return checkpoint
+
+
+def c_crop(image):
+ width, height = image.size
+ new_size = min(width, height)
+ left = (width - new_size) / 2
+ top = (height - new_size) / 2
+ right = (width + new_size) / 2
+ bottom = (height + new_size) / 2
+ return image.crop((left, top, right, bottom))
+
+
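+# Amount of padding needed to round x up to the next multiple of 64.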
+def pad64(x):
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+
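+# Convert a uint8 image to 3-channel RGB: grayscale is tiled across channels and RGBA is
+# alpha-composited over a white background.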
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def safer_memory(x):
+ # Fix many MAC/AMD problems
+ return np.ascontiguousarray(x.copy()).copy()
+
+
+# Resize so the shorter side equals `resolution`, then pad the bottom/right edges up to a
+# multiple of 64; also returns a function that removes the padding.
+def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode="edge"):
+ if skip_hwc3:
+ img = input_image
+ else:
+ img = HWC3(input_image)
+ H_raw, W_raw, _ = img.shape
+ if resolution == 0:
+ return img, lambda x: x
+ k = float(resolution) / float(min(H_raw, W_raw))
+ H_target = int(np.round(float(H_raw) * k))
+ W_target = int(np.round(float(W_raw) * k))
+ img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA)
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
+
+ def remove_pad(x):
+ return safer_memory(x[:H_target, :W_target, ...])
+
+ return safer_memory(img_padded), remove_pad
+
+
+class Annotator:
+ def __init__(self, name: str):
+ if name == "canny":
+ processor = CannyDetector()
+ else:
+ raise ValueError(f"Invalid annotator name: {name}")
+ self.name = name
+ self.processor = processor
+
+ def __call__(self, image: Image, width: int, height: int):
+ image = np.array(image)
+ detect_resolution = max(width, height)
+ image, remove_pad = resize_image_with_pad(image, detect_resolution)
+
+ image = np.array(image)
+ if self.name == "canny":
+ result = self.processor(image, low_threshold=100, high_threshold=200)
+ elif self.name == "hough":
+ result = self.processor(image, thr_v=0.05, thr_d=5)
+ elif self.name == "depth":
+ result = self.processor(image)
+ result, _ = result
+ else:
+ result = self.processor(image)
+
+ result = HWC3(remove_pad(result))
+ result = cv2.resize(result, (width, height))
+ return result
+
+
+@dataclass
+class ModelSpec:
+ params: FluxParams
+ ae_params: AutoEncoderParams
+ ckpt_path: str | None
+ ae_path: str | None
+ repo_id: str | None
+ repo_flow: str | None
+ repo_ae: str | None
+ repo_id_ae: str | None
+
+
+configs = {
+ "flux-dev": ModelSpec(
+ repo_id="black-forest-labs/FLUX.1-dev",
+ repo_id_ae="black-forest-labs/FLUX.1-dev",
+ repo_flow="flux1-dev.safetensors",
+ repo_ae="ae.safetensors",
+ ckpt_path=os.getenv("FLUX_DEV"),
+ params=FluxParams(
+ in_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ ),
+ ae_path=os.getenv("AE"),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
+}
+
+
+def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
+ if len(missing) > 0 and len(unexpected) > 0:
+ print(f"Got {len(missing)} missing keys: \n\t" + "\n\t".join(missing))
+ print("\n" + "-" * 79 + "\n")
+ print(f"Got {len(unexpected)} unexpected keys: \n\t" + "\n\t".join(unexpected))
+ elif len(missing) > 0:
+ print(f"Got {len(missing)} missing keys: \n\t" + "\n\t".join(missing))
+ elif len(unexpected) > 0:
+ print(f"Got {len(unexpected)} unexpected keys: \n\t" + "\n\t".join(unexpected))
+
+
+def load_from_repo_id(repo_id, checkpoint_name):
+ ckpt_path = hf_hub_download(repo_id, checkpoint_name)
+ sd = load_safetensors(ckpt_path)
+ return sd
+
+
+def load_flow_model(name: str, hf_download: bool = True):
+ # Loading Flux
+ print(f"Init model of {name}")
+ ckpt_path = configs[name].ckpt_path
+ if ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_flow is not None and hf_download:
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
+
+ with no_init_parameters():
+ model = Flux(configs[name].params)
+ model = set_model_param_dtype(model, dtype=ms.bfloat16)
+ model = model.to_float(ms.bfloat16)
+
+ if ckpt_path is not None:
+ print(f"Loading checkpoint of {name}")
+ sd = load_safetensors(ckpt_path)
+ # missing, unexpected = model.load_state_dict(sd, strict=False)
+ missing, unexpected = ms.load_param_into_net(model, sd, strict_load=False)
+ print_load_warning(missing, unexpected)
+ return model
+
+
+def load_flow_model_quantized(name: str, hf_download: bool = True):
+ raise NotImplementedError("Quantization is not supported in mindspore.")
+
+
+def load_controlnet(name, transformer=None):
+ controlnet = ControlNetFlux(configs[name].params)
+ if transformer is not None:
+ controlnet.load_state_dict(transformer.state_dict(), strict=False)
+ return controlnet
+
+
+def load_t5(max_length: int = 512) -> HFEmbedder:
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, mindspore_dtype=mindspore.bfloat16)
+
+
+def load_clip() -> HFEmbedder:
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, mindspore_dtype=mindspore.bfloat16)
+
+
+def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
+ ckpt_path = configs[name].ae_path
+ if ckpt_path is None and configs[name].repo_id is not None and configs[name].repo_ae is not None and hf_download:
+ ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
+
+ # Loading the autoencoder
+ print("Init autoencoder")
+ ae = AutoEncoder(configs[name].ae_params)
+
+ if ckpt_path is not None:
+ sd = load_safetensors(ckpt_path)
+ # missing, unexpected = ae.load_state_dict(sd, strict=False)
+ missing, unexpected = ms.load_param_into_net(ae, sd, strict_load=False)
+ print_load_warning(missing, unexpected)
+ return ae
+
+
+def set_model_param_dtype(model, dtype=ms.bfloat16, keep_norm_fp32=False):
+ if model is not None:
+ assert isinstance(model, ms.nn.Cell)
+
+ k_num, c_num = 0, 0
+ for _, p in model.parameters_and_names():
+ # filter norm/embedding position_ids param
+ if keep_norm_fp32 and ("norm" in p.name):
+ # print(f"param {p.name} keep {p.dtype}") # disable print
+ k_num += 1
+ elif "position_ids" in p.name:
+ k_num += 1
+ else:
+ c_num += 1
+ p.set_dtype(dtype)
+
+ print(f"Convert '{type(model).__name__}' param to {dtype}, keep/modify num {k_num}/{c_num}.")
+
+ return model
+
+
+def process_mask(
+ mask_path, height, width, dilate=False, dilation_kernel_size=(5, 5), fill_holes=False, closing_kernel_size=(5, 5)
+):
+ """
+ Processes a mask image, optionally fills holes, dilates it, and returns a simple mask tensor.
+
+ Args:
+ mask_path (str): The path to the mask image.
+ height (int): The desired height for the original image dimensions.
+ width (int): The desired width for the original image dimensions.
+ dilate (bool, optional): If True, performs a dilation operation to expand the
+ mask area. Defaults to False.
+ dilation_kernel_size (tuple, optional): The size of the kernel for dilation.
+ Defaults to (5, 5).
+ fill_holes (bool, optional): If True, performs a morphological closing operation
+ to fill small holes within the mask. Defaults to False.
+ closing_kernel_size (tuple, optional): The size of the kernel for the closing
+ operation. Defaults to (5, 5).
+
+ Returns:
+ ms.Tensor: The processed simple mask tensor, ready for use.
+ """
+ # Read the mask image
+ mask = cv2.imread(mask_path)
+ if mask is None:
+ raise FileNotFoundError(f"Could not read mask file from: {mask_path}")
+
+ # Convert the mask to grayscale
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+
+ # Downsample the mask to 1/16th of the target dimensions
+ downsampled_mask = cv2.resize(mask, (width // 16, height // 16), interpolation=cv2.INTER_AREA)
+
+ # Threshold the downsampled mask to a binary format (0 or 255)
+ _, binary_downsampled_mask = cv2.threshold(downsampled_mask, 127, 255, cv2.THRESH_BINARY)
+
+ # --- Optional Hole Filling Step ---
+ # This operation is ideal for making masks contiguous and removing "pepper" noise.
+ if fill_holes:
+ # Create the kernel for the closing operation.
+ kernel = np.ones(closing_kernel_size, np.uint8)
+ # Apply morphological closing.
+ binary_downsampled_mask = cv2.morphologyEx(binary_downsampled_mask, cv2.MORPH_CLOSE, kernel)
+
+ # --- Optional Dilation Step ---
+ # This expands the outer boundary of the mask.
+ if dilate:
+ # Create a kernel for the dilation.
+ kernel = np.ones(dilation_kernel_size, np.uint8)
+ # Apply the dilation operation.
+ binary_downsampled_mask = cv2.dilate(binary_downsampled_mask, kernel, iterations=1)
+
+ # Normalize the binary mask to have values of 0 and 1
+ binary_downsampled_mask = (binary_downsampled_mask // 255).astype(np.uint8)
+
+ # Invert the mask (object area becomes 0, background becomes 1)
+ local_mask = 1 - binary_downsampled_mask
+
+ # Convert the final mask to a MindSpore tensor
+ local_mask_tensor = ms.tensor(local_mask, dtype=ms.float32)
+
+ return local_mask_tensor
+
+
+def plot_image_with_mask(image_path, mask_path_list, width, height, save_path):
+ # Load the image
+ image = cv2.imread(image_path)
+
+ # Resize the image to the specified width and height
+ image = cv2.resize(image, (width, height))
+
+ # Convert the image from BGR to RGB for proper display in matplotlib
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ # Create a figure for plotting
+ fig, ax = plt.subplots(figsize=(10, 10))
+ ax.imshow(image)
+ ax.axis("off")
+
+ for mask_path in mask_path_list:
+ # Load the mask
+ mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+
+ # Resize the mask to match the resized image dimensions
+ mask = cv2.resize(mask, (width, height))
+
+ # Find the coordinates of the white region in the mask
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours:
+ # Get the bounding rectangle for each contour
+ x, y, w, h = cv2.boundingRect(contour)
+ # Draw the bounding rectangle as a red box
+ rect = plt.Rectangle((x, y), w, h, edgecolor="red", fill=False, linewidth=4)
+ ax.add_patch(rect)
+
+ # Save the figure to the specified output path
+ plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
+ plt.close(fig)
+
+ return save_path
diff --git a/examples/canny_edit/templates/index.html b/examples/canny_edit/templates/index.html
new file mode 100644
index 0000000000..fc2a30bf38
--- /dev/null
+++ b/examples/canny_edit/templates/index.html
@@ -0,0 +1,288 @@
+<!-- Interactive Image Segmentation page:
+     1. Upload Image
+     2. Draw on Image (click and drag on the image to mark the area you want to segment) -->