Skip to content

Commit a3a8674

Browse files
wbrunastduhpf
authored andcommitted
non-square VAE tiling (#3)
* refactor tile number calculation * support non-square tiles * add env var to change tile overlap * add safeguards and better error messages for SD_TILE_OVERLAP * add safeguards and include overlapping factor for SD_TILE_SIZE * avoid rounding issues when specifying SD_TILE_SIZE as a factor * lower SD_TILE_OVERLAP limit * zero-init empty output buffer
1 parent a3fabad commit a3a8674

File tree

2 files changed

+131
-55
lines changed

2 files changed

+131
-55
lines changed

ggml_extend.hpp

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,40 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
598598

599599
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
600600

601+
__STATIC_INLINE__ void
602+
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
603+
604+
int tile_overlap = (tile_size * tile_overlap_factor);
605+
int non_tile_overlap = tile_size - tile_overlap;
606+
607+
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
608+
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;
609+
610+
if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
611+
// if tiles don't fit perfectly using the desired overlap
612+
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
613+
num_tiles_dim++;
614+
}
615+
616+
tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
617+
if (num_tiles_dim <= 2) {
618+
if (small_dim <= tile_size) {
619+
num_tiles_dim = 1;
620+
tile_overlap_factor_dim = 0;
621+
} else {
622+
num_tiles_dim = 2;
623+
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
624+
}
625+
}
626+
}
627+
601628
// Tiling
602-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
629+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
630+
const int p_tile_size_x, const int p_tile_size_y,
631+
const float tile_overlap_factor, on_tile_process on_processing) {
632+
633+
output = ggml_set_f32(output, 0);
634+
603635
int input_width = (int)input->ne[0];
604636
int input_height = (int)input->ne[1];
605637
int output_width = (int)output->ne[0];
@@ -618,62 +650,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
618650
small_height = input_height;
619651
}
620652

621-
int tile_overlap = (tile_size * tile_overlap_factor);
622-
int non_tile_overlap = tile_size - tile_overlap;
653+
int num_tiles_x;
654+
float tile_overlap_factor_x;
655+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
623656

624-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
625-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
626-
627-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
628-
// if tiles don't fit perfectly using the desired overlap
629-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
630-
num_tiles_x++;
631-
}
632-
633-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
634-
if (num_tiles_x <= 2) {
635-
if (small_width <= tile_size) {
636-
num_tiles_x = 1;
637-
tile_overlap_factor_x = 0;
638-
} else {
639-
num_tiles_x = 2;
640-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
641-
}
642-
}
643-
644-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
645-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
646-
647-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
648-
// if tiles don't fit perfectly using the desired overlap
649-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
650-
num_tiles_y++;
651-
}
652-
653-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
654-
if (num_tiles_y <= 2) {
655-
if (small_height <= tile_size) {
656-
num_tiles_y = 1;
657-
tile_overlap_factor_y = 0;
658-
} else {
659-
num_tiles_y = 2;
660-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
661-
}
662-
}
657+
int num_tiles_y;
658+
float tile_overlap_factor_y;
659+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
663660

664661
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
665662
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
666663

667664
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
668665

669-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
670-
int non_tile_overlap_x = tile_size - tile_overlap_x;
666+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
667+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
671668

672-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
673-
int non_tile_overlap_y = tile_size - tile_overlap_y;
669+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
670+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
674671

675-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
676-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
672+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
673+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
677674

678675
int input_tile_size_x = tile_size_x;
679676
int input_tile_size_y = tile_size_y;
@@ -762,6 +759,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
762759
ggml_free(tiles_ctx);
763760
}
764761

762+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
763+
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
764+
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
765+
}
766+
765767
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
766768
struct ggml_tensor* a) {
767769
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,23 +1055,91 @@ class StableDiffusionGGML {
10551055
x->ne[3]); // channels
10561056
int64_t t0 = ggml_time_ms();
10571057

1058-
int tile_size = 32;
1059-
// TODO: arg instead of env?
1058+
// TODO: args instead of env for tile size / overlap?
1059+
1060+
float tile_overlap = 0.5f;
1061+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1062+
if (SD_TILE_OVERLAP != nullptr) {
1063+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1064+
try {
1065+
tile_overlap = std::stof(sd_tile_overlap_str);
1066+
if (tile_overlap < 0.0) {
1067+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1068+
tile_overlap = 0.0;
1069+
}
1070+
else if (tile_overlap > 0.5) {
1071+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1072+
tile_overlap = 0.5;
1073+
}
1074+
} catch (const std::invalid_argument&) {
1075+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1076+
} catch (const std::out_of_range&) {
1077+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1078+
}
1079+
}
1080+
1081+
int tile_size_x = 32;
1082+
int tile_size_y = 32;
10601083
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
10611084
if (SD_TILE_SIZE != nullptr) {
1085+
// format is AxB, or just A (equivalent to AxA)
1086+
// A and B can be integers (tile size) or floating point
1087+
// floating point <= 1 means simple fraction of the latent dimension
1088+
// floating point > 1 means number of tiles across that dimension
1089+
// a single number gets applied to both
1090+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1091+
float factor = std::stof(factor_str);
1092+
if (factor > 1.0)
1093+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1094+
return factor;
1095+
};
1096+
const int latent_x = W / (decode ? 1 : 8);
1097+
const int latent_y = H / (decode ? 1 : 8);
1098+
const int min_tile_dimension = 4;
10621099
std::string sd_tile_size_str = SD_TILE_SIZE;
1100+
size_t x_pos = sd_tile_size_str.find('x');
10631101
try {
1064-
tile_size = std::stoi(sd_tile_size_str);
1102+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1103+
if (x_pos != std::string::npos) {
1104+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1105+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1106+
if (tile_x_str.find('.') != std::string::npos) {
1107+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1108+
}
1109+
else {
1110+
tmp_x = std::stoi(tile_x_str);
1111+
}
1112+
if (tile_y_str.find('.') != std::string::npos) {
1113+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1114+
}
1115+
else {
1116+
tmp_y = std::stoi(tile_y_str);
1117+
}
1118+
}
1119+
else {
1120+
if (sd_tile_size_str.find('.') != std::string::npos) {
1121+
float tile_factor = get_tile_factor(sd_tile_size_str);
1122+
tmp_x = std::round(latent_x * tile_factor);
1123+
tmp_y = std::round(latent_y * tile_factor);
1124+
}
1125+
else {
1126+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1127+
}
1128+
}
1129+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1130+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
10651131
} catch (const std::invalid_argument&) {
1066-
LOG_WARN("Invalid");
1132+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
10671133
} catch (const std::out_of_range&) {
1068-
LOG_WARN("OOR");
1134+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
10691135
}
10701136
}
1137+
10711138
if(!decode){
10721139
// TODO: also use and arg for this one?
10731140
// to keep the compute buffer size consistent
1074-
tile_size*=1.30539;
1141+
tile_size_x*=1.30539;
1142+
tile_size_y*=1.30539;
10751143
}
10761144
if (!use_tiny_autoencoder) {
10771145
if (decode) {
@@ -1080,11 +1148,17 @@ class StableDiffusionGGML {
10801148
ggml_tensor_scale_input(x);
10811149
}
10821150
if (vae_tiling) {
1151+
if (SD_TILE_SIZE != nullptr) {
1152+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
1153+
}
1154+
if (SD_TILE_OVERLAP != nullptr) {
1155+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1156+
}
10831157
// split latent in 32x32 tiles and compute in several steps
10841158
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
10851159
first_stage_model->compute(n_threads, in, decode, &out);
10861160
};
1087-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1161+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
10881162
} else {
10891163
first_stage_model->compute(n_threads, x, decode, &result);
10901164
}

0 commit comments

Comments
 (0)