Skip to content

Commit cc05599

Browse files
wbrunastduhpf
authored andcommitted
non-square VAE tiling (#3)
* refactor tile number calculation * support non-square tiles * add env var to change tile overlap * add safeguards and better error messages for SD_TILE_OVERLAP * add safeguards and include overlapping factor for SD_TILE_SIZE * avoid rounding issues when specifying SD_TILE_SIZE as a factor * lower SD_TILE_OVERLAP limit * zero-init empty output buffer
1 parent 92f784b commit cc05599

File tree

2 files changed

+143
-77
lines changed

2 files changed

+143
-77
lines changed

ggml_extend.hpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,38 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
736736

737737
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
738738

739+
__STATIC_INLINE__ void
740+
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
741+
742+
int tile_overlap = (tile_size * tile_overlap_factor);
743+
int non_tile_overlap = tile_size - tile_overlap;
744+
745+
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
746+
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;
747+
748+
if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
749+
// if tiles don't fit perfectly using the desired overlap
750+
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
751+
num_tiles_dim++;
752+
}
753+
754+
tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
755+
if (num_tiles_dim <= 2) {
756+
if (small_dim <= tile_size) {
757+
num_tiles_dim = 1;
758+
tile_overlap_factor_dim = 0;
759+
} else {
760+
num_tiles_dim = 2;
761+
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
762+
}
763+
}
764+
}
765+
739766
// Tiling
740-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
767+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
768+
const int p_tile_size_x, const int p_tile_size_y,
769+
const float tile_overlap_factor, on_tile_process on_processing) {
770+
741771
output = ggml_set_f32(output, 0);
742772

743773
int input_width = (int)input->ne[0];
@@ -758,62 +788,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
758788
small_height = input_height;
759789
}
760790

761-
int tile_overlap = (tile_size * tile_overlap_factor);
762-
int non_tile_overlap = tile_size - tile_overlap;
763-
764-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
765-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
766-
767-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
768-
// if tiles don't fit perfectly using the desired overlap
769-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
770-
num_tiles_x++;
771-
}
772-
773-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
774-
if (num_tiles_x <= 2) {
775-
if (small_width <= tile_size) {
776-
num_tiles_x = 1;
777-
tile_overlap_factor_x = 0;
778-
} else {
779-
num_tiles_x = 2;
780-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
781-
}
782-
}
783-
784-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
785-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
791+
int num_tiles_x;
792+
float tile_overlap_factor_x;
793+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
786794

787-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
788-
// if tiles don't fit perfectly using the desired overlap
789-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
790-
num_tiles_y++;
791-
}
792-
793-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
794-
if (num_tiles_y <= 2) {
795-
if (small_height <= tile_size) {
796-
num_tiles_y = 1;
797-
tile_overlap_factor_y = 0;
798-
} else {
799-
num_tiles_y = 2;
800-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
801-
}
802-
}
795+
int num_tiles_y;
796+
float tile_overlap_factor_y;
797+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
803798

804799
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
805800
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
806801

807802
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
808803

809-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
810-
int non_tile_overlap_x = tile_size - tile_overlap_x;
804+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
805+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
811806

812-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
813-
int non_tile_overlap_y = tile_size - tile_overlap_y;
807+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
808+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
814809

815-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
816-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
810+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
811+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
817812

818813
int input_tile_size_x = tile_size_x;
819814
int input_tile_size_y = tile_size_y;
@@ -902,6 +897,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
902897
ggml_free(tiles_ctx);
903898
}
904899

900+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
901+
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
902+
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
903+
}
904+
905905
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
906906
struct ggml_tensor* a) {
907907
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,41 +1289,113 @@ class StableDiffusionGGML {
12891289
return latent;
12901290
}
12911291

1292-
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1293-
int64_t t0 = ggml_time_ms();
1294-
ggml_tensor* result = NULL;
1295-
int tile_size = 32;
1296-
// TODO: arg instead of env?
1292+
void get_vae_tile_overlap(float& tile_overlap) {
1293+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1294+
if (SD_TILE_OVERLAP != nullptr) {
1295+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1296+
try {
1297+
tile_overlap = std::stof(sd_tile_overlap_str);
1298+
if (tile_overlap < 0.0) {
1299+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1300+
tile_overlap = 0.0;
1301+
} else if (tile_overlap > 0.5) {
1302+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1303+
tile_overlap = 0.5;
1304+
}
1305+
} catch (const std::invalid_argument&) {
1306+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1307+
} catch (const std::out_of_range&) {
1308+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1309+
}
1310+
}
1311+
if (SD_TILE_OVERLAP != nullptr) {
1312+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1313+
}
1314+
}
1315+
1316+
void get_vae_tile_sizes(int& tile_size_x, int& tile_size_y, float tile_overlap, int latent_x, int latent_y) {
12971317
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
12981318
if (SD_TILE_SIZE != nullptr) {
1319+
// format is AxB, or just A (equivalent to AxA)
1320+
// A and B can be integers (tile size) or floating point
1321+
// floating point <= 1 means simple fraction of the latent dimension
1322+
// floating point > 1 means number of tiles across that dimension
1323+
// a single number gets applied to both
1324+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1325+
float factor = std::stof(factor_str);
1326+
if (factor > 1.0)
1327+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1328+
return factor;
1329+
};
1330+
const int min_tile_dimension = 4;
12991331
std::string sd_tile_size_str = SD_TILE_SIZE;
1332+
size_t x_pos = sd_tile_size_str.find('x');
13001333
try {
1301-
tile_size = std::stoi(sd_tile_size_str);
1334+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1335+
if (x_pos != std::string::npos) {
1336+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1337+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1338+
if (tile_x_str.find('.') != std::string::npos) {
1339+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1340+
} else {
1341+
tmp_x = std::stoi(tile_x_str);
1342+
}
1343+
if (tile_y_str.find('.') != std::string::npos) {
1344+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1345+
} else {
1346+
tmp_y = std::stoi(tile_y_str);
1347+
}
1348+
} else {
1349+
if (sd_tile_size_str.find('.') != std::string::npos) {
1350+
float tile_factor = get_tile_factor(sd_tile_size_str);
1351+
tmp_x = std::round(latent_x * tile_factor);
1352+
tmp_y = std::round(latent_y * tile_factor);
1353+
} else {
1354+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1355+
}
1356+
}
1357+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1358+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
13021359
} catch (const std::invalid_argument&) {
1303-
LOG_WARN("Invalid");
1360+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
13041361
} catch (const std::out_of_range&) {
1305-
LOG_WARN("OOR");
1362+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
13061363
}
13071364
}
1308-
if(!decode){
1309-
// TODO: also use and arg for this one?
1310-
// to keep the compute buffer size consistent
1311-
tile_size*=1.30539;
1365+
if (SD_TILE_SIZE != nullptr) {
1366+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
13121367
}
1368+
}
1369+
1370+
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1371+
int64_t t0 = ggml_time_ms();
1372+
ggml_tensor* result = NULL;
1373+
// TODO: args instead of env for tile size / overlap?
13131374
if (!use_tiny_autoencoder) {
1375+
float tile_overlap = 0.5f;
1376+
int tile_size_x = 32;
1377+
int tile_size_y = 32;
1378+
1379+
get_vae_tile_overlap(tile_overlap);
1380+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1381+
1382+
// TODO: also use an arg for this one?
1383+
// multiply tile size for encode to keep the compute buffer size consistent
1384+
tile_size_x *= 1.30539;
1385+
tile_size_y *= 1.30539;
1386+
13141387
process_vae_input_tensor(x);
13151388
if (vae_tiling && !decode_video) {
1316-
// split latent in 32x32 tiles and compute in several steps
13171389
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13181390
first_stage_model->compute(n_threads, in, true, &out, NULL);
13191391
};
1320-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1392+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
13211393
} else {
13221394
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
13231395
}
13241396
first_stage_model->free_compute_buffer();
13251397
} else {
1326-
if (vae_tiling && !decode_video) {
1398+
if (vae_tiling && !decode_video) {
13271399
// split latent in 32x32 tiles and compute in several steps
13281400
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13291401
tae_first_stage->compute(n_threads, in, true, &out, NULL);
@@ -1448,29 +1520,23 @@ class StableDiffusionGGML {
14481520
C,
14491521
x->ne[3]);
14501522
}
1451-
int tile_size = 32;
1452-
// TODO: arg instead of env?
1453-
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
1454-
if (SD_TILE_SIZE != nullptr) {
1455-
std::string sd_tile_size_str = SD_TILE_SIZE;
1456-
try {
1457-
tile_size = std::stoi(sd_tile_size_str);
1458-
} catch (const std::invalid_argument&) {
1459-
LOG_WARN("Invalid");
1460-
} catch (const std::out_of_range&) {
1461-
LOG_WARN("OOR");
1462-
}
1463-
}
14641523
int64_t t0 = ggml_time_ms();
14651524
if (!use_tiny_autoencoder) {
1525+
float tile_overlap = 0.5f;
1526+
int tile_size_x = 32;
1527+
int tile_size_y = 32;
1528+
1529+
get_vae_tile_overlap(tile_overlap);
1530+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1531+
14661532
process_latent_out(x);
14671533
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14681534
if (vae_tiling && !decode_video) {
14691535
// split latent in 32x32 tiles and compute in several steps
14701536
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14711537
first_stage_model->compute(n_threads, in, true, &out, NULL);
14721538
};
1473-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1539+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
14741540
} else {
14751541
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14761542
}

0 commit comments

Comments
 (0)