Skip to content

Commit 739ac40

Browse files
wbrunastduhpf
authored andcommitted
non-square VAE tiling (#3)
* refactor tile number calculation * support non-square tiles * add env var to change tile overlap * add safeguards and better error messages for SD_TILE_OVERLAP * add safeguards and include overlapping factor for SD_TILE_SIZE * avoid rounding issues when specifying SD_TILE_SIZE as a factor * lower SD_TILE_OVERLAP limit * zero-init empty output buffer
1 parent af6c43c commit 739ac40

File tree

2 files changed

+143
-77
lines changed

2 files changed

+143
-77
lines changed

ggml_extend.hpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,38 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
736736

737737
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
738738

739+
__STATIC_INLINE__ void
740+
sd_tiling_calc_tiles(int &num_tiles_dim, float& tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
741+
742+
int tile_overlap = (tile_size * tile_overlap_factor);
743+
int non_tile_overlap = tile_size - tile_overlap;
744+
745+
num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
746+
int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim;
747+
748+
if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
749+
// if tiles don't fit perfectly using the desired overlap
750+
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
751+
num_tiles_dim++;
752+
}
753+
754+
tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1));
755+
if (num_tiles_dim <= 2) {
756+
if (small_dim <= tile_size) {
757+
num_tiles_dim = 1;
758+
tile_overlap_factor_dim = 0;
759+
} else {
760+
num_tiles_dim = 2;
761+
tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size;
762+
}
763+
}
764+
}
765+
739766
// Tiling
740-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
767+
__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* output, const int scale,
768+
const int p_tile_size_x, const int p_tile_size_y,
769+
const float tile_overlap_factor, on_tile_process on_processing) {
770+
741771
output = ggml_set_f32(output, 0);
742772

743773
int input_width = (int)input->ne[0];
@@ -758,62 +788,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
758788
small_height = input_height;
759789
}
760790

761-
int tile_overlap = (tile_size * tile_overlap_factor);
762-
int non_tile_overlap = tile_size - tile_overlap;
763-
764-
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
765-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
766-
767-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
768-
// if tiles don't fit perfectly using the desired overlap
769-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
770-
num_tiles_x++;
771-
}
772-
773-
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
774-
if (num_tiles_x <= 2) {
775-
if (small_width <= tile_size) {
776-
num_tiles_x = 1;
777-
tile_overlap_factor_x = 0;
778-
} else {
779-
num_tiles_x = 2;
780-
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
781-
}
782-
}
783-
784-
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
785-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
791+
int num_tiles_x;
792+
float tile_overlap_factor_x;
793+
sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
786794

787-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
788-
// if tiles don't fit perfectly using the desired overlap
789-
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
790-
num_tiles_y++;
791-
}
792-
793-
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
794-
if (num_tiles_y <= 2) {
795-
if (small_height <= tile_size) {
796-
num_tiles_y = 1;
797-
tile_overlap_factor_y = 0;
798-
} else {
799-
num_tiles_y = 2;
800-
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
801-
}
802-
}
795+
int num_tiles_y;
796+
float tile_overlap_factor_y;
797+
sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
803798

804799
LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y);
805800
LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
806801

807802
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
808803

809-
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
810-
int non_tile_overlap_x = tile_size - tile_overlap_x;
804+
int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x);
805+
int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
811806

812-
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
813-
int non_tile_overlap_y = tile_size - tile_overlap_y;
807+
int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y);
808+
int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
814809

815-
int tile_size_x = tile_size < small_width ? tile_size : small_width;
816-
int tile_size_y = tile_size < small_height ? tile_size : small_height;
810+
int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
811+
int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
817812

818813
int input_tile_size_x = tile_size_x;
819814
int input_tile_size_y = tile_size_y;
@@ -902,6 +897,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
902897
ggml_free(tiles_ctx);
903898
}
904899

900+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale,
901+
const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
902+
sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
903+
}
904+
905905
__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
906906
struct ggml_tensor* a) {
907907
const float eps = 1e-6f; // default eps parameter

stable-diffusion.cpp

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,41 +1295,113 @@ class StableDiffusionGGML {
12951295
return latent;
12961296
}
12971297

1298-
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1299-
int64_t t0 = ggml_time_ms();
1300-
ggml_tensor* result = NULL;
1301-
int tile_size = 32;
1302-
// TODO: arg instead of env?
1298+
void get_vae_tile_overlap(float& tile_overlap) {
1299+
const char* SD_TILE_OVERLAP = getenv("SD_TILE_OVERLAP");
1300+
if (SD_TILE_OVERLAP != nullptr) {
1301+
std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1302+
try {
1303+
tile_overlap = std::stof(sd_tile_overlap_str);
1304+
if (tile_overlap < 0.0) {
1305+
LOG_WARN("SD_TILE_OVERLAP too low, setting it to 0.0");
1306+
tile_overlap = 0.0;
1307+
} else if (tile_overlap > 0.5) {
1308+
LOG_WARN("SD_TILE_OVERLAP too high, setting it to 0.5");
1309+
tile_overlap = 0.5;
1310+
}
1311+
} catch (const std::invalid_argument&) {
1312+
LOG_WARN("SD_TILE_OVERLAP is invalid, keeping the default");
1313+
} catch (const std::out_of_range&) {
1314+
LOG_WARN("SD_TILE_OVERLAP is out of range, keeping the default");
1315+
}
1316+
}
1317+
if (SD_TILE_OVERLAP != nullptr) {
1318+
LOG_INFO("VAE Tile overlap: %.2f", tile_overlap);
1319+
}
1320+
}
1321+
1322+
void get_vae_tile_sizes(int& tile_size_x, int& tile_size_y, float tile_overlap, int latent_x, int latent_y) {
13031323
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
13041324
if (SD_TILE_SIZE != nullptr) {
1325+
// format is AxB, or just A (equivalent to AxA)
1326+
// A and B can be integers (tile size) or floating point
1327+
// floating point <= 1 means simple fraction of the latent dimension
1328+
// floating point > 1 means number of tiles across that dimension
1329+
// a single number gets applied to both
1330+
auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1331+
float factor = std::stof(factor_str);
1332+
if (factor > 1.0)
1333+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1334+
return factor;
1335+
};
1336+
const int min_tile_dimension = 4;
13051337
std::string sd_tile_size_str = SD_TILE_SIZE;
1338+
size_t x_pos = sd_tile_size_str.find('x');
13061339
try {
1307-
tile_size = std::stoi(sd_tile_size_str);
1340+
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1341+
if (x_pos != std::string::npos) {
1342+
std::string tile_x_str = sd_tile_size_str.substr(0, x_pos);
1343+
std::string tile_y_str = sd_tile_size_str.substr(x_pos + 1);
1344+
if (tile_x_str.find('.') != std::string::npos) {
1345+
tmp_x = std::round(latent_x * get_tile_factor(tile_x_str));
1346+
} else {
1347+
tmp_x = std::stoi(tile_x_str);
1348+
}
1349+
if (tile_y_str.find('.') != std::string::npos) {
1350+
tmp_y = std::round(latent_y * get_tile_factor(tile_y_str));
1351+
} else {
1352+
tmp_y = std::stoi(tile_y_str);
1353+
}
1354+
} else {
1355+
if (sd_tile_size_str.find('.') != std::string::npos) {
1356+
float tile_factor = get_tile_factor(sd_tile_size_str);
1357+
tmp_x = std::round(latent_x * tile_factor);
1358+
tmp_y = std::round(latent_y * tile_factor);
1359+
} else {
1360+
tmp_x = tmp_y = std::stoi(sd_tile_size_str);
1361+
}
1362+
}
1363+
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1364+
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
13081365
} catch (const std::invalid_argument&) {
1309-
LOG_WARN("Invalid");
1366+
LOG_WARN("SD_TILE_SIZE is invalid, keeping the default");
13101367
} catch (const std::out_of_range&) {
1311-
LOG_WARN("OOR");
1368+
LOG_WARN("SD_TILE_SIZE is out of range, keeping the default");
13121369
}
13131370
}
1314-
if(!decode){
1315-
// TODO: also use and arg for this one?
1316-
// to keep the compute buffer size consistent
1317-
tile_size*=1.30539;
1371+
if (SD_TILE_SIZE != nullptr) {
1372+
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
13181373
}
1374+
}
1375+
1376+
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
1377+
int64_t t0 = ggml_time_ms();
1378+
ggml_tensor* result = NULL;
1379+
// TODO: args instead of env for tile size / overlap?
13191380
if (!use_tiny_autoencoder) {
1381+
float tile_overlap = 0.5f;
1382+
int tile_size_x = 32;
1383+
int tile_size_y = 32;
1384+
1385+
get_vae_tile_overlap(tile_overlap);
1386+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1387+
1388+
// TODO: also use an arg for this one?
1389+
// multiply tile size for encode to keep the compute buffer size consistent
1390+
tile_size_x *= 1.30539;
1391+
tile_size_y *= 1.30539;
1392+
13201393
process_vae_input_tensor(x);
13211394
if (vae_tiling && !decode_video) {
1322-
// split latent in 32x32 tiles and compute in several steps
13231395
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13241396
first_stage_model->compute(n_threads, in, true, &out, NULL);
13251397
};
1326-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1398+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
13271399
} else {
13281400
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
13291401
}
13301402
first_stage_model->free_compute_buffer();
13311403
} else {
1332-
if (vae_tiling && !decode_video) {
1404+
if (vae_tiling && !decode_video) {
13331405
// split latent in 32x32 tiles and compute in several steps
13341406
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13351407
tae_first_stage->compute(n_threads, in, true, &out, NULL);
@@ -1454,29 +1526,23 @@ class StableDiffusionGGML {
14541526
C,
14551527
x->ne[3]);
14561528
}
1457-
int tile_size = 32;
1458-
// TODO: arg instead of env?
1459-
const char* SD_TILE_SIZE = getenv("SD_TILE_SIZE");
1460-
if (SD_TILE_SIZE != nullptr) {
1461-
std::string sd_tile_size_str = SD_TILE_SIZE;
1462-
try {
1463-
tile_size = std::stoi(sd_tile_size_str);
1464-
} catch (const std::invalid_argument&) {
1465-
LOG_WARN("Invalid");
1466-
} catch (const std::out_of_range&) {
1467-
LOG_WARN("OOR");
1468-
}
1469-
}
14701529
int64_t t0 = ggml_time_ms();
14711530
if (!use_tiny_autoencoder) {
1531+
float tile_overlap = 0.5f;
1532+
int tile_size_x = 32;
1533+
int tile_size_y = 32;
1534+
1535+
get_vae_tile_overlap(tile_overlap);
1536+
get_vae_tile_sizes(tile_size_x, tile_size_y, tile_overlap, x->ne[0] / 8, x->ne[1] / 8);
1537+
14721538
process_latent_out(x);
14731539
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14741540
if (vae_tiling && !decode_video) {
14751541
// split latent in 32x32 tiles and compute in several steps
14761542
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14771543
first_stage_model->compute(n_threads, in, true, &out, NULL);
14781544
};
1479-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
1545+
sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling);
14801546
} else {
14811547
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14821548
}

0 commit comments

Comments
 (0)