Skip to content

Commit a4f22d0

Browse files
leejet stduhpf
authored and committed
add edit mode
1 parent 5b33b8d commit a4f22d0

File tree

6 files changed

+369
-229
lines changed

6 files changed

+369
-229
lines changed

diffusion_model.hpp

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ struct DiffusionModel {
1313
struct ggml_tensor* c_concat,
1414
struct ggml_tensor* y,
1515
struct ggml_tensor* guidance,
16-
int num_video_frames = -1,
17-
std::vector<struct ggml_tensor*> controls = {},
18-
float control_strength = 0.f,
19-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
20-
struct ggml_tensor** output = NULL,
21-
struct ggml_context* output_ctx = NULL,
22-
std::vector<int> skip_layers = std::vector<int>()) = 0;
16+
std::vector<ggml_tensor*> ref_latents = {},
17+
int num_video_frames = -1,
18+
std::vector<struct ggml_tensor*> controls = {},
19+
float control_strength = 0.f,
20+
struct ggml_tensor** output = NULL,
21+
struct ggml_context* output_ctx = NULL,
22+
std::vector<int> skip_layers = std::vector<int>()) = 0;
2323
virtual void alloc_params_buffer() = 0;
2424
virtual void free_params_buffer() = 0;
2525
virtual void free_compute_buffer() = 0;
@@ -69,13 +69,13 @@ struct UNetModel : public DiffusionModel {
6969
struct ggml_tensor* c_concat,
7070
struct ggml_tensor* y,
7171
struct ggml_tensor* guidance,
72-
int num_video_frames = -1,
73-
std::vector<struct ggml_tensor*> controls = {},
74-
float control_strength = 0.f,
75-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
76-
struct ggml_tensor** output = NULL,
77-
struct ggml_context* output_ctx = NULL,
78-
std::vector<int> skip_layers = std::vector<int>()) {
72+
std::vector<ggml_tensor*> ref_latents = {},
73+
int num_video_frames = -1,
74+
std::vector<struct ggml_tensor*> controls = {},
75+
float control_strength = 0.f,
76+
struct ggml_tensor** output = NULL,
77+
struct ggml_context* output_ctx = NULL,
78+
std::vector<int> skip_layers = std::vector<int>()) {
7979
(void)skip_layers; // SLG doesn't work with UNet models
8080
return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
8181
}
@@ -120,13 +120,13 @@ struct MMDiTModel : public DiffusionModel {
120120
struct ggml_tensor* c_concat,
121121
struct ggml_tensor* y,
122122
struct ggml_tensor* guidance,
123-
int num_video_frames = -1,
124-
std::vector<struct ggml_tensor*> controls = {},
125-
float control_strength = 0.f,
126-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
127-
struct ggml_tensor** output = NULL,
128-
struct ggml_context* output_ctx = NULL,
129-
std::vector<int> skip_layers = std::vector<int>()) {
123+
std::vector<struct ggml_tensor*> ref_latents = {},
124+
int num_video_frames = -1,
125+
std::vector<struct ggml_tensor*> controls = {},
126+
float control_strength = 0.f,
127+
struct ggml_tensor** output = NULL,
128+
struct ggml_context* output_ctx = NULL,
129+
std::vector<int> skip_layers = std::vector<int>()) {
130130
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
131131
}
132132
};
@@ -172,14 +172,14 @@ struct FluxModel : public DiffusionModel {
172172
struct ggml_tensor* c_concat,
173173
struct ggml_tensor* y,
174174
struct ggml_tensor* guidance,
175-
int num_video_frames = -1,
176-
std::vector<struct ggml_tensor*> controls = {},
177-
float control_strength = 0.f,
178-
std::vector<struct ggml_tensor*> kontext_imgs = std::vector<struct ggml_tensor*>(),
179-
struct ggml_tensor** output = NULL,
180-
struct ggml_context* output_ctx = NULL,
181-
std::vector<int> skip_layers = std::vector<int>()) {
182-
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, kontext_imgs, output, output_ctx, skip_layers);
175+
std::vector<ggml_tensor*> ref_latents = {},
176+
int num_video_frames = -1,
177+
std::vector<struct ggml_tensor*> controls = {},
178+
float control_strength = 0.f,
179+
struct ggml_tensor** output = NULL,
180+
struct ggml_context* output_ctx = NULL,
181+
std::vector<int> skip_layers = std::vector<int>()) {
182+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers);
183183
}
184184
};
185185

examples/cli/main.cpp

Lines changed: 75 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ const char* modes_str[] = {
5757
"txt2img",
5858
"img2img",
5959
"img2vid",
60+
"edit",
6061
"convert",
6162
};
6263

@@ -71,6 +72,7 @@ enum SDMode {
7172
TXT2IMG,
7273
IMG2IMG,
7374
IMG2VID,
75+
EDIT,
7476
CONVERT,
7577
MODE_COUNT
7678
};
@@ -96,8 +98,7 @@ struct SDParams {
9698
std::string input_path;
9799
std::string mask_path;
98100
std::string control_image_path;
99-
100-
std::vector<std::string> kontext_image_paths;
101+
std::vector<std::string> ref_image_paths;
101102

102103
std::string prompt;
103104
std::string negative_prompt;
@@ -181,6 +182,10 @@ void print_params(SDParams params) {
181182
printf(" init_img: %s\n", params.input_path.c_str());
182183
printf(" mask_img: %s\n", params.mask_path.c_str());
183184
printf(" control_image: %s\n", params.control_image_path.c_str());
185+
printf(" ref_images_paths:\n");
186+
for (auto& path : params.ref_image_paths) {
187+
printf(" %s\n", path.c_str());
188+
};
184189
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
185190
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
186191
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
@@ -241,6 +246,7 @@ void print_usage(int argc, const char* argv[]) {
241246
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
242247
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
243248
printf(" --control-image [IMAGE] path to image condition, control net\n");
249+
printf(" -r, --ref_image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
244250
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
245251
printf(" -p, --prompt [PROMPT] the prompt to render\n");
246252
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -289,9 +295,8 @@ void print_usage(int argc, const char* argv[]) {
289295
printf(" %s is the fastest\n", previews_str[SD_PREVIEW_PROJ]);
290296
printf(" --preview-interval [N] How often to save the image preview");
291297
printf(" --preview-path [PATH} path to write preview image to (default: ./preview.png)\n");
292-
printf(" --color Colors the logging tags according to level\n");
298+
printf(" --color colors the logging tags according to level\n");
293299
printf(" -v, --verbose print extra info\n");
294-
printf(" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n");
295300
}
296301

297302
void parse_args(int argc, const char** argv, SDParams& params) {
@@ -727,12 +732,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
727732
break;
728733
}
729734
params.imatrix_in.push_back(std::string(argv[i]));
730-
} else if (arg == "-ki" || arg == "--kontext-img") {
735+
} else if (arg == "-r" || arg == "--ref-image") {
731736
if (++i >= argc) {
732737
invalid_arg = true;
733738
break;
734739
}
735-
params.kontext_image_paths.push_back(argv[i]);
740+
params.ref_image_paths.push_back(argv[i]);
736741
} else {
737742
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
738743
print_usage(argc, argv);
@@ -797,7 +802,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
797802
}
798803

799804
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) {
800-
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
805+
fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n");
806+
print_usage(argc, argv);
807+
exit(1);
808+
}
809+
810+
if (params.mode == EDIT && params.ref_image_paths.size() == 0) {
811+
fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n");
801812
print_usage(argc, argv);
802813
exit(1);
803814
}
@@ -1014,43 +1025,12 @@ int main(int argc, const char* argv[]) {
10141025
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
10151026
return 1;
10161027
}
1017-
bool vae_decode_only = true;
1018-
1019-
std::vector<sd_image_t> kontext_imgs;
1020-
for (auto& path : params.kontext_image_paths) {
1021-
vae_decode_only = false;
1022-
int c = 0;
1023-
int width = 0;
1024-
int height = 0;
1025-
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
1026-
if (image_buffer == NULL) {
1027-
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
1028-
return 1;
1029-
}
1030-
if (c < 3) {
1031-
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
1032-
free(image_buffer);
1033-
return 1;
1034-
}
1035-
if (width <= 0) {
1036-
fprintf(stderr, "error: the width of image must be greater than 0\n");
1037-
free(image_buffer);
1038-
return 1;
1039-
}
1040-
if (height <= 0) {
1041-
fprintf(stderr, "error: the height of image must be greater than 0\n");
1042-
free(image_buffer);
1043-
return 1;
1044-
}
1045-
kontext_imgs.push_back({(uint32_t)width,
1046-
(uint32_t)height,
1047-
3,
1048-
image_buffer});
1049-
}
10501028

1029+
bool vae_decode_only = true;
10511030
uint8_t* input_image_buffer = NULL;
10521031
uint8_t* control_image_buffer = NULL;
10531032
uint8_t* mask_image_buffer = NULL;
1033+
std::vector<sd_image_t> ref_images;
10541034

10551035
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
10561036
vae_decode_only = false;
@@ -1102,6 +1082,37 @@ int main(int argc, const char* argv[]) {
11021082
free(input_image_buffer);
11031083
input_image_buffer = resized_image_buffer;
11041084
}
1085+
} else if (params.mode == EDIT) {
1086+
vae_decode_only = false;
1087+
for (auto& path : params.ref_image_paths) {
1088+
int c = 0;
1089+
int width = 0;
1090+
int height = 0;
1091+
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
1092+
if (image_buffer == NULL) {
1093+
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
1094+
return 1;
1095+
}
1096+
if (c < 3) {
1097+
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
1098+
free(image_buffer);
1099+
return 1;
1100+
}
1101+
if (width <= 0) {
1102+
fprintf(stderr, "error: the width of image must be greater than 0\n");
1103+
free(image_buffer);
1104+
return 1;
1105+
}
1106+
if (height <= 0) {
1107+
fprintf(stderr, "error: the height of image must be greater than 0\n");
1108+
free(image_buffer);
1109+
return 1;
1110+
}
1111+
ref_images.push_back({(uint32_t)width,
1112+
(uint32_t)height,
1113+
3,
1114+
image_buffer});
1115+
}
11051116
}
11061117

11071118
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
@@ -1187,9 +1198,8 @@ int main(int argc, const char* argv[]) {
11871198
params.control_strength,
11881199
params.style_ratio,
11891200
params.normalize_input,
1190-
params.input_id_images_path.c_str(),
1191-
kontext_imgs.data(), kontext_imgs.size());
1192-
} else {
1201+
params.input_id_images_path.c_str());
1202+
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
11931203
sd_image_t input_image = {(uint32_t)params.width,
11941204
(uint32_t)params.height,
11951205
3,
@@ -1250,9 +1260,28 @@ int main(int argc, const char* argv[]) {
12501260
params.control_strength,
12511261
params.style_ratio,
12521262
params.normalize_input,
1253-
params.input_id_images_path.c_str(),
1254-
kontext_imgs.data(), kontext_imgs.size());
1263+
params.input_id_images_path.c_str());
12551264
}
1265+
} else { // EDIT
1266+
results = edit(sd_ctx,
1267+
ref_images.data(),
1268+
ref_images.size(),
1269+
params.prompt.c_str(),
1270+
params.negative_prompt.c_str(),
1271+
params.clip_skip,
1272+
guidance_params,
1273+
params.eta,
1274+
params.width,
1275+
params.height,
1276+
params.sample_method,
1277+
params.sample_steps,
1278+
params.seed,
1279+
params.batch_count,
1280+
control_image,
1281+
params.control_strength,
1282+
params.style_ratio,
1283+
params.normalize_input,
1284+
params.input_id_images_path.c_str());
12561285
}
12571286

12581287
if (results == NULL) {
@@ -1335,4 +1364,4 @@ int main(int argc, const char* argv[]) {
13351364
free(input_image_buffer);
13361365

13371366
return 0;
1338-
}
1367+
}

examples/server/main.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,7 @@ void start_server(SDParams params) {
15161516
1,
15171517
params.lastRequest.style_ratio,
15181518
params.lastRequest.normalize_input,
1519-
params.input_id_images_path.c_str(),
1520-
NULL, 0);
1519+
params.input_id_images_path.c_str());
15211520

15221521
if (results == NULL) {
15231522
printf("generate failed\n");

0 commit comments

Comments
 (0)