@@ -57,6 +57,7 @@ const char* modes_str[] = {
5757 " txt2img" ,
5858 " img2img" ,
5959 " img2vid" ,
60+ " edit" ,
6061 " convert" ,
6162};
6263
@@ -71,6 +72,7 @@ enum SDMode {
7172 TXT2IMG,
7273 IMG2IMG,
7374 IMG2VID,
75+ EDIT,
7476 CONVERT,
7577 MODE_COUNT
7678};
@@ -96,8 +98,7 @@ struct SDParams {
9698 std::string input_path;
9799 std::string mask_path;
98100 std::string control_image_path;
99-
100- std::vector<std::string> kontext_image_paths;
101+ std::vector<std::string> ref_image_paths;
101102
102103 std::string prompt;
103104 std::string negative_prompt;
@@ -181,6 +182,10 @@ void print_params(SDParams params) {
181182 printf (" init_img: %s\n " , params.input_path .c_str ());
182183 printf (" mask_img: %s\n " , params.mask_path .c_str ());
183184 printf (" control_image: %s\n " , params.control_image_path .c_str ());
185+ printf (" ref_images_paths:\n " );
186+ for (auto & path : params.ref_image_paths ) {
187+ printf (" %s\n " , path.c_str ());
188+ };
184189 printf (" clip on cpu: %s\n " , params.clip_on_cpu ? " true" : " false" );
185190 printf (" controlnet cpu: %s\n " , params.control_net_cpu ? " true" : " false" );
186191 printf (" vae decoder on cpu:%s\n " , params.vae_on_cpu ? " true" : " false" );
@@ -241,6 +246,7 @@ void print_usage(int argc, const char* argv[]) {
241246 printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
242247 printf (" --mask [MASK] path to the mask image, required by img2img with mask\n " );
243248 printf (" --control-image [IMAGE] path to image condition, control net\n " );
249+ printf (" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n " );
244250 printf (" -o, --output OUTPUT path to write result image to (default: ./output.png)\n " );
245251 printf (" -p, --prompt [PROMPT] the prompt to render\n " );
246252 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
@@ -289,9 +295,8 @@ void print_usage(int argc, const char* argv[]) {
289295 printf (" %s is the fastest\n " , previews_str[SD_PREVIEW_PROJ]);
290296 printf (" --preview-interval [N] How often to save the image preview" );
291297 printf (" --preview-path [PATH} path to write preview image to (default: ./preview.png)\n " );
292- printf (" --color Colors the logging tags according to level\n " );
298+ printf (" --color colors the logging tags according to level\n " );
293299 printf (" -v, --verbose print extra info\n " );
294- printf (" -ki, --kontext_img [PATH] Reference image for Flux Kontext models (can be used multiple times) \n " );
295300}
296301
297302void parse_args (int argc, const char ** argv, SDParams& params) {
@@ -727,12 +732,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
727732 break ;
728733 }
729734 params.imatrix_in .push_back (std::string (argv[i]));
730- } else if (arg == " -ki " || arg == " --kontext-img " ) {
735+ } else if (arg == " -r " || arg == " --ref-image " ) {
731736 if (++i >= argc) {
732737 invalid_arg = true ;
733738 break ;
734739 }
735- params.kontext_image_paths .push_back (argv[i]);
740+ params.ref_image_paths .push_back (argv[i]);
736741 } else {
737742 fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
738743 print_usage (argc, argv);
@@ -797,7 +802,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
797802 }
798803
799804 if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path .length () == 0 ) {
800- fprintf (stderr, " error: when using the img2img mode, the following arguments are required: init-img\n " );
805+ fprintf (stderr, " error: when using the img2img/img2vid mode, the following arguments are required: init-img\n " );
806+ print_usage (argc, argv);
807+ exit (1 );
808+ }
809+
810+ if (params.mode == EDIT && params.ref_image_paths .size () == 0 ) {
811+ fprintf (stderr, " error: when using the edit mode, the following arguments are required: ref-image\n " );
801812 print_usage (argc, argv);
802813 exit (1 );
803814 }
@@ -1014,43 +1025,12 @@ int main(int argc, const char* argv[]) {
10141025 fprintf (stderr, " SVD support is broken, do not use it!!!\n " );
10151026 return 1 ;
10161027 }
1017- bool vae_decode_only = true ;
1018-
1019- std::vector<sd_image_t > kontext_imgs;
1020- for (auto & path : params.kontext_image_paths ) {
1021- vae_decode_only = false ;
1022- int c = 0 ;
1023- int width = 0 ;
1024- int height = 0 ;
1025- uint8_t * image_buffer = stbi_load (path.c_str (), &width, &height, &c, 3 );
1026- if (image_buffer == NULL ) {
1027- fprintf (stderr, " load image from '%s' failed\n " , path.c_str ());
1028- return 1 ;
1029- }
1030- if (c < 3 ) {
1031- fprintf (stderr, " the number of channels for the input image must be >= 3, but got %d channels\n " , c);
1032- free (image_buffer);
1033- return 1 ;
1034- }
1035- if (width <= 0 ) {
1036- fprintf (stderr, " error: the width of image must be greater than 0\n " );
1037- free (image_buffer);
1038- return 1 ;
1039- }
1040- if (height <= 0 ) {
1041- fprintf (stderr, " error: the height of image must be greater than 0\n " );
1042- free (image_buffer);
1043- return 1 ;
1044- }
1045- kontext_imgs.push_back ({(uint32_t )width,
1046- (uint32_t )height,
1047- 3 ,
1048- image_buffer});
1049- }
10501028
1029+ bool vae_decode_only = true ;
10511030 uint8_t * input_image_buffer = NULL ;
10521031 uint8_t * control_image_buffer = NULL ;
10531032 uint8_t * mask_image_buffer = NULL ;
1033+ std::vector<sd_image_t > ref_images;
10541034
10551035 if (params.mode == IMG2IMG || params.mode == IMG2VID) {
10561036 vae_decode_only = false ;
@@ -1102,6 +1082,37 @@ int main(int argc, const char* argv[]) {
11021082 free (input_image_buffer);
11031083 input_image_buffer = resized_image_buffer;
11041084 }
1085+ } else if (params.mode == EDIT) {
1086+ vae_decode_only = false ;
1087+ for (auto & path : params.ref_image_paths ) {
1088+ int c = 0 ;
1089+ int width = 0 ;
1090+ int height = 0 ;
1091+ uint8_t * image_buffer = stbi_load (path.c_str (), &width, &height, &c, 3 );
1092+ if (image_buffer == NULL ) {
1093+ fprintf (stderr, " load image from '%s' failed\n " , path.c_str ());
1094+ return 1 ;
1095+ }
1096+ if (c < 3 ) {
1097+ fprintf (stderr, " the number of channels for the input image must be >= 3, but got %d channels\n " , c);
1098+ free (image_buffer);
1099+ return 1 ;
1100+ }
1101+ if (width <= 0 ) {
1102+ fprintf (stderr, " error: the width of image must be greater than 0\n " );
1103+ free (image_buffer);
1104+ return 1 ;
1105+ }
1106+ if (height <= 0 ) {
1107+ fprintf (stderr, " error: the height of image must be greater than 0\n " );
1108+ free (image_buffer);
1109+ return 1 ;
1110+ }
1111+ ref_images.push_back ({(uint32_t )width,
1112+ (uint32_t )height,
1113+ 3 ,
1114+ image_buffer});
1115+ }
11051116 }
11061117
11071118 sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
@@ -1187,9 +1198,8 @@ int main(int argc, const char* argv[]) {
11871198 params.control_strength ,
11881199 params.style_ratio ,
11891200 params.normalize_input ,
1190- params.input_id_images_path .c_str (),
1191- kontext_imgs.data (), kontext_imgs.size ());
1192- } else {
1201+ params.input_id_images_path .c_str ());
1202+ } else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
11931203 sd_image_t input_image = {(uint32_t )params.width ,
11941204 (uint32_t )params.height ,
11951205 3 ,
@@ -1250,9 +1260,28 @@ int main(int argc, const char* argv[]) {
12501260 params.control_strength ,
12511261 params.style_ratio ,
12521262 params.normalize_input ,
1253- params.input_id_images_path .c_str (),
1254- kontext_imgs.data (), kontext_imgs.size ());
1263+ params.input_id_images_path .c_str ());
12551264 }
1265+ } else { // EDIT
1266+ results = edit (sd_ctx,
1267+ ref_images.data (),
1268+ ref_images.size (),
1269+ params.prompt .c_str (),
1270+ params.negative_prompt .c_str (),
1271+ params.clip_skip ,
1272+ guidance_params,
1273+ params.eta ,
1274+ params.width ,
1275+ params.height ,
1276+ params.sample_method ,
1277+ params.sample_steps ,
1278+ params.seed ,
1279+ params.batch_count ,
1280+ control_image,
1281+ params.control_strength ,
1282+ params.style_ratio ,
1283+ params.normalize_input ,
1284+ params.input_id_images_path .c_str ());
12561285 }
12571286
12581287 if (results == NULL ) {
@@ -1335,4 +1364,4 @@ int main(int argc, const char* argv[]) {
13351364 free (input_image_buffer);
13361365
13371366 return 0 ;
1338- }
1367+ }
0 commit comments