Skip to content

Commit 925e7a0

Browse files
committed
vae tiling: refactor again, base on smaller buffer for alignment
1 parent f57fc8c commit 925e7a0

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

ggml_extend.hpp

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -604,62 +604,67 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
604604
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
605605

606606
// Tiling
607-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
607+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
608608
output = ggml_set_f32(output, 0);
609609

610610
int input_width = (int)input->ne[0];
611611
int input_height = (int)input->ne[1];
612612
int output_width = (int)output->ne[0];
613613
int output_height = (int)output->ne[1];
614614

615-
int input_tile_size, output_tile_size;
616-
if (scaled_out) {
617-
input_tile_size = tile_size;
618-
output_tile_size = tile_size * scale;
619-
} else {
620-
input_tile_size = tile_size * scale;
621-
output_tile_size = tile_size;
615+
GGML_ASSERT(input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
616+
GGML_ASSERT(input_width / output_width == scale || output_width / input_width == scale);
617+
618+
int small_width = output_width;
619+
int small_height = output_height;
620+
621+
bool big_out = output_width > input_width;
622+
if (big_out) {
623+
// Ex: decode
624+
small_width = input_width;
625+
small_height = input_height;
622626
}
623-
int tile_overlap = (input_tile_size * tile_overlap_factor);
624-
int non_tile_overlap = input_tile_size - tile_overlap;
625627

626-
int num_tiles_x = (input_width - tile_overlap) / non_tile_overlap;
627-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % input_width;
628+
int tile_overlap = (tile_size * tile_overlap_factor);
629+
int non_tile_overlap = tile_size - tile_overlap;
630+
631+
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
632+
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
628633

629-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (input_tile_size / 2 - tile_overlap))) {
634+
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
630635
// if tiles don't fit perfectly using the desired overlap
631636
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
632637
num_tiles_x++;
633638
}
634639

635-
float tile_overlap_factor_x = (float)(input_tile_size * num_tiles_x - input_width) / (float)(input_tile_size * (num_tiles_x - 1));
640+
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
636641
if (num_tiles_x <= 2) {
637-
if (input_width <= input_tile_size) {
642+
if (small_width <= tile_size) {
638643
num_tiles_x = 1;
639644
tile_overlap_factor_x = 0;
640645
} else {
641646
num_tiles_x = 2;
642-
tile_overlap_factor_x = (2 * input_tile_size - input_width) / (float)input_tile_size;
647+
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
643648
}
644649
}
645650

646-
int num_tiles_y = (input_height - tile_overlap) / non_tile_overlap;
647-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % input_height;
651+
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
652+
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
648653

649-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (input_tile_size / 2 - tile_overlap))) {
654+
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
650655
// if tiles don't fit perfectly using the desired overlap
651656
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
652657
num_tiles_y++;
653658
}
654659

655-
float tile_overlap_factor_y = (float)(input_tile_size * num_tiles_y - input_height) / (float)(input_tile_size * (num_tiles_y - 1));
660+
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
656661
if (num_tiles_y <= 2) {
657-
if (input_height <= input_tile_size) {
662+
if (small_height <= tile_size) {
658663
num_tiles_y = 1;
659664
tile_overlap_factor_y = 0;
660665
} else {
661666
num_tiles_y = 2;
662-
tile_overlap_factor_y = (2 * input_tile_size - input_height) / (float)input_tile_size;
667+
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
663668
}
664669
}
665670

@@ -668,11 +673,20 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
668673

669674
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
670675

671-
int tile_overlap_x = (int32_t)(input_tile_size * tile_overlap_factor_x);
672-
int non_tile_overlap_x = input_tile_size - tile_overlap_x;
676+
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
677+
int non_tile_overlap_x = tile_size - tile_overlap_x;
673678

674-
int tile_overlap_y = (int32_t)(input_tile_size * tile_overlap_factor_y);
675-
int non_tile_overlap_y = input_tile_size - tile_overlap_y;
679+
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
680+
int non_tile_overlap_y = tile_size - tile_overlap_y;
681+
682+
int input_tile_size = tile_size;
683+
int output_tile_size = tile_size;
684+
685+
if (big_out) {
686+
output_tile_size *= scale;
687+
} else {
688+
input_tile_size *= scale;
689+
}
676690

677691
struct ggml_init_params params = {};
678692
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
@@ -693,37 +707,48 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
693707
// tiling
694708
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
695709
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
696-
on_processing(input_tile, NULL, true);
697710
int num_tiles = num_tiles_x * num_tiles_y;
698711
LOG_INFO("processing %i tiles", num_tiles);
699-
pretty_progress(1, num_tiles, 0.0f);
712+
pretty_progress(0, num_tiles, 0.0f);
700713
int tile_count = 1;
701714
bool last_y = false, last_x = false;
702715
float last_time = 0.0f;
703-
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap_y) {
716+
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
704717
int dy = 0;
705-
if (y + input_tile_size >= input_height) {
718+
if (y + tile_size >= small_height) {
706719
int _y = y;
707-
y = input_height - input_tile_size;
720+
y = small_height - tile_size;
708721
dy = _y - y;
722+
if (big_out) {
723+
dy *= scale;
724+
}
709725
last_y = true;
710726
}
711-
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap_x) {
727+
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
712728
int dx = 0;
713-
if (x + input_tile_size >= input_width) {
729+
if (x + tile_size >= small_width) {
714730
int _x = x;
715-
x = input_width - input_tile_size;
731+
x = small_width - tile_size;
716732
dx = _x - x;
733+
if (big_out) {
734+
dx *= scale;
735+
}
717736
last_x = true;
718737
}
738+
739+
int x_in = big_out ? x : scale * x;
740+
int y_in = big_out ? y : scale * y;
741+
int x_out = big_out ? x * scale : x;
742+
int y_out = big_out ? y * scale : y;
743+
744+
int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
745+
int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;
746+
719747
int64_t t1 = ggml_time_ms();
720-
ggml_split_tensor_2d(input, input_tile, x, y);
748+
ggml_split_tensor_2d(input, input_tile, x_in, y_in);
721749
on_processing(input_tile, output_tile, false);
722-
if (scaled_out) {
723-
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
724-
} else {
725-
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
726-
}
750+
ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
751+
727752
int64_t t2 = ggml_time_ms();
728753
last_time = (t2 - t1) / 1000.0f;
729754
pretty_progress(tile_count, num_tiles, last_time);

stable-diffusion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,7 @@ class StableDiffusionGGML {
11181118
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
11191119
first_stage_model->compute(n_threads, in, decode, &out);
11201120
};
1121-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling, decode);
1121+
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
11221122
} else {
11231123
first_stage_model->compute(n_threads, x, decode, &result);
11241124
}
@@ -1132,7 +1132,7 @@ class StableDiffusionGGML {
11321132
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
11331133
tae_first_stage->compute(n_threads, in, decode, &out);
11341134
};
1135-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, decode);
1135+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
11361136
} else {
11371137
tae_first_stage->compute(n_threads, x, decode, &result);
11381138
}

0 commit comments

Comments
 (0)