Skip to content

Commit d1fcf84

Browse files
wbruna authored and stduhpf committed
non-square VAE tiling (#3)
* refactor tile number calculation
* support non-square tiles
* add env var to change tile overlap
* add safeguards and better error messages for SD_TILE_OVERLAP
* add safeguards and include overlapping factor for SD_TILE_SIZE
* avoid rounding issues when specifying SD_TILE_SIZE as a factor
* lower SD_TILE_OVERLAP limit
* zero-init empty output buffer
1 parent 4184254 commit d1fcf84

File tree

2 files changed

+129
-55
lines changed

2 files changed

+129
-55
lines changed

ggml_extend.hpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,38 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
603603

604604
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
605605

606+
// Computes, for a single dimension, how many tiles of length `tile_size` are
// used to cover `small_dim` and which overlap factor goes with that count.
//
// Outputs:
//   num_tiles_dim           - number of tiles along this dimension (always >= 1)
//   tile_overlap_factor_dim - resulting overlap between neighbouring tiles,
//                             expressed as a fraction of tile_size
// Inputs:
//   small_dim           - length of the dimension to cover
//   tile_size           - tile length along this dimension
//   tile_overlap_factor - desired overlap, as a fraction of tile_size
//
// Degenerate inputs (non-positive tile_size, a dimension that already fits in
// one tile, or an overlap that would leave no non-overlapping part) are handled
// explicitly so that no integer division or modulo by zero can occur.
__STATIC_INLINE__ void sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
    // A single tile is enough (or the tile size is unusable): nothing to overlap.
    // This also covers small_dim <= 0, which would otherwise hit `% small_dim`.
    if (tile_size <= 0 || small_dim <= tile_size) {
        num_tiles_dim           = 1;
        tile_overlap_factor_dim = 0;
        return;
    }

    // Desired overlap in elements (truncated toward zero).
    int tile_overlap = static_cast<int>(tile_size * tile_overlap_factor);
    if (tile_overlap > tile_size - 1) {
        // keep at least one non-overlapping element per tile, otherwise the
        // division below would be by zero (or by a negative stride)
        tile_overlap = tile_size - 1;
    }
    const int non_tile_overlap = tile_size - tile_overlap;

    num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
    const int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;

    if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
        // if tiles don't fit perfectly using the desired overlap
        // and there is enough room to squeeze an extra tile without overlap becoming >0.5
        num_tiles_dim++;
    }

    if (num_tiles_dim <= 2) {
        // small_dim > tile_size here, so exactly two tiles are used and the
        // overlap is whatever makes the pair span small_dim.
        num_tiles_dim           = 2;
        tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
    } else {
        // Total overlap spread evenly over the (num_tiles_dim - 1) seams;
        // the divisor is non-zero because num_tiles_dim > 2.
        tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
    }
}
632+
606633
// Tiling
607-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
634+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
635+
const int p_tile_size_x, const int p_tile_size_y,
636+
const float tile_overlap_factor, on_tile_process on_processing) {
637+
608638
output = ggml_set_f32(output, 0);
609639

610640
int input_width = (int)input->ne[0];
@@ -625,62 +655,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
625655
small_height = input_height;
626656
}
627657

628-
int tile_overlap = (tile_size * tile_overlap_factor);
629-
int non_tile_overlap = tile_size - tile_overlap;
630-
631-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
632-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
633-
634-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
635-
// if tiles don't fit perfectly using the desired overlap
636-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
637-
num_tiles_x++;
638-
}
639-
640-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
641-
if (num_tiles_x <= 2) {
642-
if (small_width <= tile_size) {
643-
num_tiles_x = 1;
644-
tile_overlap_factor_x = 0;
645-
} else {
646-
num_tiles_x = 2;
647-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
648-
}
649-
}
650-
651-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
652-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
658+
int num_tiles_x;
659+
float tile_overlap_factor_x;
660+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
653661

654-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
655-
// if tiles don't fit perfectly using the desired overlap
656-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
657-
num_tiles_y++;
658-
}
659-
660-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
661-
if (num_tiles_y <= 2) {
662-
if (small_height <= tile_size) {
663-
num_tiles_y = 1;
664-
tile_overlap_factor_y = 0;
665-
} else {
666-
num_tiles_y = 2;
667-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
668-
}
669-
}
662+
int num_tiles_y;
663+
float tile_overlap_factor_y;
664+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
670665

671666
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
672667
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
673668

674669
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
675670

676-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
677-
int non_tile_overlap_x = tile_size - tile_overlap_x;
671+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
672+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
678673

679-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
680-
int non_tile_overlap_y = tile_size - tile_overlap_y;
674+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
675+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
681676

682-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
683-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
677+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
678+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
684679

685680
int input_tile_size_x = tile_size_x;
686681
int input_tile_size_y = tile_size_y;
@@ -769,6 +764,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
769764
ggml_free(tiles_ctx);
770765
}
771766

767+
// Square-tile entry point: forwards every argument to sd_tiling_non_square,
// using `tile_size` for both the x and the y tile dimension.
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input,
                                 ggml_tensor* output,
                                 const int scale,
                                 const int tile_size,
                                 const float tile_overlap_factor,
                                 on_tile_process on_processing) {
    const int tile_size_x = tile_size;
    const int tile_size_y = tile_size;
    sd_tiling_non_square(input, output, scale, tile_size_x, tile_size_y, tile_overlap_factor, on_processing);
}
771+
772772
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
773773
struct ggml_tensor* a) {
774774
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,23 +1094,91 @@ class StableDiffusionGGML {
10941094
x->ne[3]); // channels
10951095
int64_t t0 = ggml_time_ms();
10961096

1097-
int tile_size = 32;
1098-
// TODO: arg instead of env?
1097+
// TODO: args instead of env for tile size / overlap?
1098+
1099+
float tile_overlap = 0.5f;
1100+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1101+
if (SD_TILE_OVERLAP != nullptr) {
1102+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1103+
try {
1104+
tile_overlap = std::stof(sd_tile_overlap_str);
1105+
if (tile_overlap < 0.0) {
1106+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1107+
tile_overlap = 0.0;
1108+
}
1109+
else if (tile_overlap > 0.5) {
1110+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1111+
tile_overlap = 0.5;
1112+
}
1113+
} catch (const std::invalid_argument&) {
1114+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1115+
} catch (const std::out_of_range&) {
1116+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1117+
}
1118+
}
1119+
1120+
int tile_size_x = 32;
1121+
int tile_size_y = 32;
10991122
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
11001123
if (SD_TILE_SIZE != nullptr) {
1124+
// format is AxB, or just A (equivalent to AxA)
1125+
// A and B can be integers (tile size) or floating point
1126+
// floating point <= 1 means simple fraction of the latent dimension
1127+
// floating point > 1 means number of tiles across that dimension
1128+
// a single number gets applied to both
1129+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1130+
float factor = std::stof(factor_str);
1131+
if (factor > 1.0)
1132+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1133+
return factor;
1134+
};
1135+
const int latent_x = W / (decode ? 1 : 8);
1136+
const int latent_y = H / (decode ? 1 : 8);
1137+
const int min_tile_dimension = 4;
11011138
std::string sd_tile_size_str = SD_TILE_SIZE;
1139+
size_t x_pos = sd_tile_size_str.find('x');
11021140
try {
1103-
tile_size = std::stoi(sd_tile_size_str);
1141+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1142+
if (x_pos != std::string::npos) {
1143+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1144+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1145+
if (tile_x_str.find('.') != std::string::npos) {
1146+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1147+
}
1148+
else {
1149+
tmp_x = std::stoi(tile_x_str);
1150+
}
1151+
if (tile_y_str.find('.') != std::string::npos) {
1152+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1153+
}
1154+
else {
1155+
tmp_y = std::stoi(tile_y_str);
1156+
}
1157+
}
1158+
else {
1159+
if (sd_tile_size_str.find('.') != std::string::npos) {
1160+
float tile_factor = get_tile_factor(sd_tile_size_str);
1161+
tmp_x = std::round(latent_x * tile_factor);
1162+
tmp_y = std::round(latent_y * tile_factor);
1163+
}
1164+
else {
1165+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1166+
}
1167+
}
1168+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1169+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
11041170
} catch (const std::invalid_argument&) {
1105-
LOG_WARN("Invalid");
1171+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
11061172
} catch (const std::out_of_range&) {
1107-
LOG_WARN("OOR");
1173+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
11081174
}
11091175
}
1176+
11101177
if(!decode){
11111178
// TODO: also use an arg for this one?
11121179
// to keep the compute buffer size consistent
1113-
tile_size*=1.30539;
1180+
tile_size_x*=1.30539;
1181+
tile_size_y*=1.30539;
11141182
}
11151183
if (!use_tiny_autoencoder) {
11161184
if (decode) {
@@ -1119,11 +1187,17 @@ class StableDiffusionGGML {
11191187
ggml_tensor_scale_input(x);
11201188
}
11211189
if (vae_tiling) {
1190+
if (SD_TILE_SIZE != nullptr) {
1191+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
1192+
}
1193+
if (SD_TILE_OVERLAP != nullptr) {
1194+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1195+
}
11221196
// split the latent into tiles (size configurable via SD_TILE_SIZE) and compute in several steps
11231197
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
11241198
first_stage_model->compute(n_threads, in, decode, &out);
11251199
};
1126-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1200+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
11271201
} else {
11281202
first_stage_model->compute(n_threads, x, decode, &result);
11291203
}

0 commit comments

Comments
 (0)