@@ -930,14 +930,13 @@ namespace Flux {
930930 }
931931
932932 struct ggml_tensor * forward (struct ggml_context * ctx,
933- struct ggml_tensor * x ,
933+ std::vector< struct ggml_tensor *> imgs ,
934934 struct ggml_tensor * timestep,
935935 struct ggml_tensor * context,
936936 struct ggml_tensor * c_concat,
937937 struct ggml_tensor * y,
938938 struct ggml_tensor * guidance,
939939 struct ggml_tensor * pe,
940- bool kontext_concat = false ,
941940 struct ggml_tensor * arange = NULL ,
942941 std::vector<int > skip_layers = std::vector<int >(),
943942 SDVersion version = VERSION_FLUX) {
@@ -951,19 +950,31 @@ namespace Flux {
951950 // pe: (L, d_head/2, 2, 2)
952951 // return: (N, C, H, W)
953952
953+ auto x = imgs[0 ];
954954 GGML_ASSERT (x->ne [3 ] == 1 );
955955
956956 int64_t W = x->ne [0 ];
957957 int64_t H = x->ne [1 ];
958958 int64_t C = x->ne [2 ];
959959 int64_t patch_size = 2 ;
960- int pad_h = (patch_size - H % patch_size) % patch_size;
961- int pad_w = (patch_size - W % patch_size) % patch_size;
962- x = ggml_pad (ctx, x, pad_w, pad_h, 0 , 0 ); // [N, C, H + pad_h, W + pad_w]
960+ int pad_h = (patch_size - x->ne [0 ] % patch_size) % patch_size;
961+ int pad_w = (patch_size - x->ne [1 ] % patch_size) % patch_size;
963962
964963 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
965- auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
966- int64_t patchified_img_size = img->ne [1 ];
964+ ggml_tensor* img = NULL ; // [N, h*w, C * patch_size * patch_size]
965+ int64_t patchified_img_size;
966+ for (auto & x : imgs) {
967+ int pad_h = (patch_size - x->ne [0 ] % patch_size) % patch_size;
968+ int pad_w = (patch_size - x->ne [1 ] % patch_size) % patch_size;
969+ ggml_tensor* pad_x = ggml_pad (ctx, x, pad_w, pad_h, 0 , 0 );
970+ pad_x = patchify (ctx, pad_x, patch_size);
971+ if (img) {
972+ img = ggml_concat (ctx, img, pad_x, 1 );
973+ } else {
974+ img = pad_x;
975+ patchified_img_size = img->ne [1 ];
976+ }
977+ }
967978 if (version == VERSION_FLUX_FILL) {
968979 GGML_ASSERT (c_concat != NULL );
969980 ggml_tensor* masked = ggml_view_4d (ctx, c_concat, c_concat->ne [0 ], c_concat->ne [1 ], C, 1 , c_concat->nb [1 ], c_concat->nb [2 ], c_concat->nb [3 ], 0 );
@@ -999,10 +1010,6 @@ namespace Flux {
9991010 control = patchify (ctx, control, patch_size);
10001011
10011012 img = ggml_concat (ctx, img, control, 0 );
1002- } else if (kontext_concat && c_concat != NULL ) {
1003- ggml_tensor* kontext = ggml_pad (ctx, c_concat, pad_w, pad_h, 0 , 0 );
1004- kontext = patchify (ctx, kontext, patch_size);
1005- img = ggml_concat (ctx, img, kontext, 1 );
10061013 }
10071014
10081015 auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size]
@@ -1097,8 +1104,8 @@ namespace Flux {
10971104 struct ggml_tensor * c_concat,
10981105 struct ggml_tensor * y,
10991106 struct ggml_tensor * guidance,
1100- bool kontext_concat = false ,
1101- std::vector<int > skip_layers = std::vector<int >()) {
1107+ std::vector< struct ggml_tensor *> kontext_imgs = std::vector< struct ggml_tensor *>() ,
1108+ std::vector<int> skip_layers = std::vector<int>()) {
11021109 GGML_ASSERT (x->ne [3 ] == 1 );
11031110 struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, FLUX_GRAPH_SIZE, false );
11041111
@@ -1109,6 +1116,9 @@ namespace Flux {
11091116 if (c_concat != NULL ) {
11101117 c_concat = to_backend (c_concat);
11111118 }
1119+ for (auto &img : kontext_imgs){
1120+ img = to_backend (img);
1121+ }
11121122 if (flux_params.is_chroma ) {
11131123 const char * SD_CHROMA_ENABLE_GUIDANCE = getenv (" SD_CHROMA_ENABLE_GUIDANCE" );
11141124 bool disable_guidance = true ;
@@ -1148,11 +1158,8 @@ namespace Flux {
11481158 if (flux_params.guidance_embed || flux_params.is_chroma ) {
11491159 guidance = to_backend (guidance);
11501160 }
1151-
1152- std::vector<struct ggml_tensor *> imgs{x};
1153- if (kontext_concat && c_concat != NULL ) {
1154- imgs.push_back (c_concat);
1155- }
1161+ auto imgs = kontext_imgs;
1162+ imgs.insert (imgs.begin (), x);
11561163
11571164 pe_vec = flux.gen_pe (imgs, context, 2 , flux_params.theta , flux_params.axes_dim );
11581165 int pos_len = pe_vec.size () / flux_params.axes_dim_sum / 2 ;
@@ -1175,14 +1182,13 @@ namespace Flux {
11751182 // }
11761183
11771184 struct ggml_tensor * out = flux.forward (compute_ctx,
1178- x ,
1185+ imgs ,
11791186 timesteps,
11801187 context,
11811188 c_concat,
11821189 y,
11831190 guidance,
11841191 pe,
1185- kontext_concat,
11861192 precompute_arange,
11871193 skip_layers,
11881194 version);
@@ -1199,17 +1205,17 @@ namespace Flux {
11991205 struct ggml_tensor * c_concat,
12001206 struct ggml_tensor * y,
12011207 struct ggml_tensor * guidance,
1202- bool kontext_concat = false ,
1203- struct ggml_tensor ** output = NULL ,
1204- struct ggml_context * output_ctx = NULL ,
1205- std::vector<int > skip_layers = std::vector<int >()) {
1208+ std::vector< struct ggml_tensor *> kontext_imgs = std::vector< struct ggml_tensor *>() ,
1209+ struct ggml_tensor** output = NULL,
1210+ struct ggml_context* output_ctx = NULL,
1211+ std::vector<int> skip_layers = std::vector<int>()) {
12061212 // x: [N, in_channels, h, w]
12071213 // timesteps: [N, ]
12081214 // context: [N, max_position, hidden_size]
12091215 // y: [N, adm_in_channels] or [1, adm_in_channels]
12101216 // guidance: [N, ]
12111217 auto get_graph = [&]() -> struct ggml_cgraph * {
1212- return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_concat , skip_layers);
1218+ return build_graph(x, timesteps, context, c_concat, y, guidance, kontext_imgs , skip_layers);
12131219 };
12141220
12151221 return GGMLRunner::compute (get_graph, n_threads, false , output, output_ctx);
@@ -1249,7 +1255,7 @@ namespace Flux {
12491255 struct ggml_tensor * out = NULL ;
12501256
12511257 int t0 = ggml_time_ms ();
1252- compute (8 , x, timesteps, context, NULL , y, guidance, false , &out, work_ctx);
1258+ compute (8 , x, timesteps, context, NULL , y, guidance, std::vector< struct ggml_tensor *>() , &out, work_ctx);
12531259 int t1 = ggml_time_ms ();
12541260
12551261 LOG_DEBUG (" flux test done in %dms" , t1 - t0);
0 commit comments