@@ -1289,41 +1289,113 @@ class StableDiffusionGGML {
12891289 return latent;
12901290 }
12911291
1292- ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1293- int64_t t0 = ggml_time_ms ();
1294- ggml_tensor* result = NULL ;
1295- int tile_size = 32 ;
1296- // TODO: arg instead of env?
1292+ void get_vae_tile_overlap (float & tile_overlap) {
1293+ const char * SD_TILE_OVERLAP = getenv (" SD_TILE_OVERLAP" );
1294+ if (SD_TILE_OVERLAP != nullptr ) {
1295+ std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1296+ try {
1297+ tile_overlap = std::stof (sd_tile_overlap_str);
1298+ if (tile_overlap < 0.0 ) {
1299+ LOG_WARN (" SD_TILE_OVERLAP too low, setting it to 0.0" );
1300+ tile_overlap = 0.0 ;
1301+ } else if (tile_overlap > 0.5 ) {
1302+ LOG_WARN (" SD_TILE_OVERLAP too high, setting it to 0.5" );
1303+ tile_overlap = 0.5 ;
1304+ }
1305+ } catch (const std::invalid_argument&) {
1306+ LOG_WARN (" SD_TILE_OVERLAP is invalid, keeping the default" );
1307+ } catch (const std::out_of_range&) {
1308+ LOG_WARN (" SD_TILE_OVERLAP is out of range, keeping the default" );
1309+ }
1310+ }
1311+ if (SD_TILE_OVERLAP != nullptr ) {
1312+ LOG_INFO (" VAE Tile overlap: %.2f" , tile_overlap);
1313+ }
1314+ }
1315+
1316+ void get_vae_tile_sizes (int & tile_size_x, int & tile_size_y, float tile_overlap, int latent_x, int latent_y) {
12971317 const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
12981318 if (SD_TILE_SIZE != nullptr ) {
1319+ // format is AxB, or just A (equivalent to AxA)
1320+ // A and B can be integers (tile size) or floating point
1321+ // floating point <= 1 means simple fraction of the latent dimension
1322+ // floating point > 1 means number of tiles across that dimension
1323+ // a single number gets applied to both
1324+ auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1325+ float factor = std::stof (factor_str);
1326+ if (factor > 1.0 )
1327+ factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1328+ return factor;
1329+ };
1330+ const int min_tile_dimension = 4 ;
12991331 std::string sd_tile_size_str = SD_TILE_SIZE;
1332+ size_t x_pos = sd_tile_size_str.find (' x' );
13001333 try {
1301- tile_size = std::stoi (sd_tile_size_str);
1334+ int tmp_x = tile_size_x, tmp_y = tile_size_y;
1335+ if (x_pos != std::string::npos) {
1336+ std::string tile_x_str = sd_tile_size_str.substr (0 , x_pos);
1337+ std::string tile_y_str = sd_tile_size_str.substr (x_pos + 1 );
1338+ if (tile_x_str.find (' .' ) != std::string::npos) {
1339+ tmp_x = std::round (latent_x * get_tile_factor (tile_x_str));
1340+ } else {
1341+ tmp_x = std::stoi (tile_x_str);
1342+ }
1343+ if (tile_y_str.find (' .' ) != std::string::npos) {
1344+ tmp_y = std::round (latent_y * get_tile_factor (tile_y_str));
1345+ } else {
1346+ tmp_y = std::stoi (tile_y_str);
1347+ }
1348+ } else {
1349+ if (sd_tile_size_str.find (' .' ) != std::string::npos) {
1350+ float tile_factor = get_tile_factor (sd_tile_size_str);
1351+ tmp_x = std::round (latent_x * tile_factor);
1352+ tmp_y = std::round (latent_y * tile_factor);
1353+ } else {
1354+ tmp_x = tmp_y = std::stoi (sd_tile_size_str);
1355+ }
1356+ }
1357+ tile_size_x = std::max (std::min (tmp_x, latent_x), min_tile_dimension);
1358+ tile_size_y = std::max (std::min (tmp_y, latent_y), min_tile_dimension);
13021359 } catch (const std::invalid_argument&) {
1303- LOG_WARN (" Invalid " );
1360+ LOG_WARN (" SD_TILE_SIZE is invalid, keeping the default " );
13041361 } catch (const std::out_of_range&) {
1305- LOG_WARN (" OOR " );
1362+ LOG_WARN (" SD_TILE_SIZE is out of range, keeping the default " );
13061363 }
13071364 }
1308- if (!decode){
1309- // TODO: also use and arg for this one?
1310- // to keep the compute buffer size consistent
1311- tile_size*=1.30539 ;
1365+ if (SD_TILE_SIZE != nullptr ) {
1366+ LOG_INFO (" VAE Tile size: %dx%d" , tile_size_x, tile_size_y);
13121367 }
1368+ }
1369+
1370+ ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1371+ int64_t t0 = ggml_time_ms ();
1372+ ggml_tensor* result = NULL ;
1373+ // TODO: args instead of env for tile size / overlap?
13131374 if (!use_tiny_autoencoder) {
1375+ float tile_overlap = 0 .5f ;
1376+ int tile_size_x = 32 ;
1377+ int tile_size_y = 32 ;
1378+
1379+ get_vae_tile_overlap (tile_overlap);
1380+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1381+
1382+ // TODO: also use an arg for this one?
1383+ // multiply tile size for encode to keep the compute buffer size consistent
1384+ tile_size_x *= 1.30539 ;
1385+ tile_size_y *= 1.30539 ;
1386+
13141387 process_vae_input_tensor (x);
13151388 if (vae_tiling && !decode_video) {
1316- // split latent in 32x32 tiles and compute in several steps
13171389 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13181390 first_stage_model->compute (n_threads, in, true , &out, NULL );
13191391 };
1320- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1392+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
13211393 } else {
13221394 first_stage_model->compute (n_threads, x, false , &result, work_ctx);
13231395 }
13241396 first_stage_model->free_compute_buffer ();
13251397 } else {
1326- if (vae_tiling && !decode_video) {
1398+ if (vae_tiling && !decode_video) {
13271399 // split latent in 32x32 tiles and compute in several steps
13281400 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13291401 tae_first_stage->compute (n_threads, in, true , &out, NULL );
@@ -1448,29 +1520,23 @@ class StableDiffusionGGML {
14481520 C,
14491521 x->ne [3 ]);
14501522 }
1451- int tile_size = 32 ;
1452- // TODO: arg instead of env?
1453- const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
1454- if (SD_TILE_SIZE != nullptr ) {
1455- std::string sd_tile_size_str = SD_TILE_SIZE;
1456- try {
1457- tile_size = std::stoi (sd_tile_size_str);
1458- } catch (const std::invalid_argument&) {
1459- LOG_WARN (" Invalid" );
1460- } catch (const std::out_of_range&) {
1461- LOG_WARN (" OOR" );
1462- }
1463- }
14641523 int64_t t0 = ggml_time_ms ();
14651524 if (!use_tiny_autoencoder) {
1525+ float tile_overlap = 0 .5f ;
1526+ int tile_size_x = 32 ;
1527+ int tile_size_y = 32 ;
1528+
1529+ get_vae_tile_overlap (tile_overlap);
1530+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1531+
14661532 process_latent_out (x);
14671533 // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14681534 if (vae_tiling && !decode_video) {
14691535 // split latent in 32x32 tiles and compute in several steps
14701536 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14711537 first_stage_model->compute (n_threads, in, true , &out, NULL );
14721538 };
1473- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1539+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
14741540 } else {
14751541 first_stage_model->compute (n_threads, x, true , &result, work_ctx);
14761542 }
0 commit comments