@@ -1295,41 +1295,113 @@ class StableDiffusionGGML {
12951295 return latent;
12961296 }
12971297
1298- ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1299- int64_t t0 = ggml_time_ms ();
1300- ggml_tensor* result = NULL ;
1301- int tile_size = 32 ;
1302- // TODO: arg instead of env?
1298+ void get_vae_tile_overlap (float & tile_overlap) {
1299+ const char * SD_TILE_OVERLAP = getenv (" SD_TILE_OVERLAP" );
1300+ if (SD_TILE_OVERLAP != nullptr ) {
1301+ std::string sd_tile_overlap_str = SD_TILE_OVERLAP;
1302+ try {
1303+ tile_overlap = std::stof (sd_tile_overlap_str);
1304+ if (tile_overlap < 0.0 ) {
1305+ LOG_WARN (" SD_TILE_OVERLAP too low, setting it to 0.0" );
1306+ tile_overlap = 0.0 ;
1307+ } else if (tile_overlap > 0.5 ) {
1308+ LOG_WARN (" SD_TILE_OVERLAP too high, setting it to 0.5" );
1309+ tile_overlap = 0.5 ;
1310+ }
1311+ } catch (const std::invalid_argument&) {
1312+ LOG_WARN (" SD_TILE_OVERLAP is invalid, keeping the default" );
1313+ } catch (const std::out_of_range&) {
1314+ LOG_WARN (" SD_TILE_OVERLAP is out of range, keeping the default" );
1315+ }
1316+ }
1317+ if (SD_TILE_OVERLAP != nullptr ) {
1318+ LOG_INFO (" VAE Tile overlap: %.2f" , tile_overlap);
1319+ }
1320+ }
1321+
1322+ void get_vae_tile_sizes (int & tile_size_x, int & tile_size_y, float tile_overlap, int latent_x, int latent_y) {
13031323 const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
13041324 if (SD_TILE_SIZE != nullptr ) {
1325+ // format is AxB, or just A (equivalent to AxA)
1326+ // A and B can be integers (tile size) or floating point
1327+ // floating point <= 1 means simple fraction of the latent dimension
1328+ // floating point > 1 means number of tiles across that dimension
1329+ // a single number gets applied to both
1330+ auto get_tile_factor = [tile_overlap](const std::string& factor_str) {
1331+ float factor = std::stof (factor_str);
1332+ if (factor > 1.0 )
1333+ factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1334+ return factor;
1335+ };
1336+ const int min_tile_dimension = 4 ;
13051337 std::string sd_tile_size_str = SD_TILE_SIZE;
1338+ size_t x_pos = sd_tile_size_str.find (' x' );
13061339 try {
1307- tile_size = std::stoi (sd_tile_size_str);
1340+ int tmp_x = tile_size_x, tmp_y = tile_size_y;
1341+ if (x_pos != std::string::npos) {
1342+ std::string tile_x_str = sd_tile_size_str.substr (0 , x_pos);
1343+ std::string tile_y_str = sd_tile_size_str.substr (x_pos + 1 );
1344+ if (tile_x_str.find (' .' ) != std::string::npos) {
1345+ tmp_x = std::round (latent_x * get_tile_factor (tile_x_str));
1346+ } else {
1347+ tmp_x = std::stoi (tile_x_str);
1348+ }
1349+ if (tile_y_str.find (' .' ) != std::string::npos) {
1350+ tmp_y = std::round (latent_y * get_tile_factor (tile_y_str));
1351+ } else {
1352+ tmp_y = std::stoi (tile_y_str);
1353+ }
1354+ } else {
1355+ if (sd_tile_size_str.find (' .' ) != std::string::npos) {
1356+ float tile_factor = get_tile_factor (sd_tile_size_str);
1357+ tmp_x = std::round (latent_x * tile_factor);
1358+ tmp_y = std::round (latent_y * tile_factor);
1359+ } else {
1360+ tmp_x = tmp_y = std::stoi (sd_tile_size_str);
1361+ }
1362+ }
1363+ tile_size_x = std::max (std::min (tmp_x, latent_x), min_tile_dimension);
1364+ tile_size_y = std::max (std::min (tmp_y, latent_y), min_tile_dimension);
13081365 } catch (const std::invalid_argument&) {
1309- LOG_WARN (" Invalid " );
1366+ LOG_WARN (" SD_TILE_SIZE is invalid, keeping the default " );
13101367 } catch (const std::out_of_range&) {
1311- LOG_WARN (" OOR " );
1368+ LOG_WARN (" SD_TILE_SIZE is out of range, keeping the default " );
13121369 }
13131370 }
1314- if (!decode){
1315- // TODO: also use and arg for this one?
1316- // to keep the compute buffer size consistent
1317- tile_size*=1.30539 ;
1371+ if (SD_TILE_SIZE != nullptr ) {
1372+ LOG_INFO (" VAE Tile size: %dx%d" , tile_size_x, tile_size_y);
13181373 }
1374+ }
1375+
1376+ ggml_tensor* encode_first_stage (ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false ) {
1377+ int64_t t0 = ggml_time_ms ();
1378+ ggml_tensor* result = NULL ;
1379+ // TODO: args instead of env for tile size / overlap?
13191380 if (!use_tiny_autoencoder) {
1381+ float tile_overlap = 0 .5f ;
1382+ int tile_size_x = 32 ;
1383+ int tile_size_y = 32 ;
1384+
1385+ get_vae_tile_overlap (tile_overlap);
1386+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1387+
1388+ // TODO: also use an arg for this one?
1389+ // multiply tile size for encode to keep the compute buffer size consistent
1390+ tile_size_x *= 1.30539 ;
1391+ tile_size_y *= 1.30539 ;
1392+
13201393 process_vae_input_tensor (x);
13211394 if (vae_tiling && !decode_video) {
1322- // split latent in 32x32 tiles and compute in several steps
13231395 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13241396 first_stage_model->compute (n_threads, in, true , &out, NULL );
13251397 };
1326- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1398+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
13271399 } else {
13281400 first_stage_model->compute (n_threads, x, false , &result, work_ctx);
13291401 }
13301402 first_stage_model->free_compute_buffer ();
13311403 } else {
1332- if (vae_tiling && !decode_video) {
1404+ if (vae_tiling && !decode_video) {
13331405 // split latent in 32x32 tiles and compute in several steps
13341406 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13351407 tae_first_stage->compute (n_threads, in, true , &out, NULL );
@@ -1454,29 +1526,23 @@ class StableDiffusionGGML {
14541526 C,
14551527 x->ne [3 ]);
14561528 }
1457- int tile_size = 32 ;
1458- // TODO: arg instead of env?
1459- const char * SD_TILE_SIZE = getenv (" SD_TILE_SIZE" );
1460- if (SD_TILE_SIZE != nullptr ) {
1461- std::string sd_tile_size_str = SD_TILE_SIZE;
1462- try {
1463- tile_size = std::stoi (sd_tile_size_str);
1464- } catch (const std::invalid_argument&) {
1465- LOG_WARN (" Invalid" );
1466- } catch (const std::out_of_range&) {
1467- LOG_WARN (" OOR" );
1468- }
1469- }
14701529 int64_t t0 = ggml_time_ms ();
14711530 if (!use_tiny_autoencoder) {
1531+ float tile_overlap = 0 .5f ;
1532+ int tile_size_x = 32 ;
1533+ int tile_size_y = 32 ;
1534+
1535+ get_vae_tile_overlap (tile_overlap);
1536+ get_vae_tile_sizes (tile_size_x, tile_size_y, tile_overlap, x->ne [0 ] / 8 , x->ne [1 ] / 8 );
1537+
14721538 process_latent_out (x);
14731539 // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
14741540 if (vae_tiling && !decode_video) {
14751541 // split latent in 32x32 tiles and compute in several steps
14761542 auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14771543 first_stage_model->compute (n_threads, in, true , &out, NULL );
14781544 };
1479- sd_tiling (x, result, 8 , tile_size, 0 . 5f , on_tiling);
1545+ sd_tiling_non_square (x, result, 8 , tile_size_x, tile_size_y, tile_overlap , on_tiling);
14801546 } else {
14811547 first_stage_model->compute (n_threads, x, true , &result, work_ctx);
14821548 }
0 commit comments