@@ -603,8 +603,38 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
603603
604604typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
605605
606+ __STATIC_INLINE__ void
607+ sd_tiling_calc_tiles (int &num_tiles_dim, float & tile_overlap_factor_dim, int small_dim, int tile_size, const float tile_overlap_factor) {
608+
609+ int tile_overlap = (tile_size * tile_overlap_factor);
610+ int non_tile_overlap = tile_size - tile_overlap;
611+
612+ num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap;
613+ int overshoot_dim = ((num_tiles_dim + 1 ) * non_tile_overlap + tile_overlap) % small_dim;
614+
615+ if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) {
616+ // if tiles don't fit perfectly using the desired overlap
617+ // and there is enough room to squeeze an extra tile without overlap becoming >0.5
618+ num_tiles_dim++;
619+ }
620+
621+ tile_overlap_factor_dim = (float )(tile_size * num_tiles_dim - small_dim) / (float )(tile_size * (num_tiles_dim - 1 ));
622+ if (num_tiles_dim <= 2 ) {
623+ if (small_dim <= tile_size) {
624+ num_tiles_dim = 1 ;
625+ tile_overlap_factor_dim = 0 ;
626+ } else {
627+ num_tiles_dim = 2 ;
628+ tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float )tile_size;
629+ }
630+ }
631+ }
632+
606633// Tiling
607- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
634+ __STATIC_INLINE__ void sd_tiling_non_square (ggml_tensor* input, ggml_tensor* output, const int scale,
635+ const int p_tile_size_x, const int p_tile_size_y,
636+ const float tile_overlap_factor, on_tile_process on_processing) {
637+
608638 output = ggml_set_f32 (output, 0 );
609639
610640 int input_width = (int )input->ne [0 ];
@@ -625,62 +655,27 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
625655 small_height = input_height;
626656 }
627657
628- int tile_overlap = (tile_size * tile_overlap_factor);
629- int non_tile_overlap = tile_size - tile_overlap;
630-
631- int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
632- int overshoot_x = ((num_tiles_x + 1 ) * non_tile_overlap + tile_overlap) % small_width;
633-
634- if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
635- // if tiles don't fit perfectly using the desired overlap
636- // and there is enough room to squeeze an extra tile without overlap becoming >0.5
637- num_tiles_x++;
638- }
639-
640- float tile_overlap_factor_x = (float )(tile_size * num_tiles_x - small_width) / (float )(tile_size * (num_tiles_x - 1 ));
641- if (num_tiles_x <= 2 ) {
642- if (small_width <= tile_size) {
643- num_tiles_x = 1 ;
644- tile_overlap_factor_x = 0 ;
645- } else {
646- num_tiles_x = 2 ;
647- tile_overlap_factor_x = (2 * tile_size - small_width) / (float )tile_size;
648- }
649- }
650-
651- int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
652- int overshoot_y = ((num_tiles_y + 1 ) * non_tile_overlap + tile_overlap) % small_height;
658+ int num_tiles_x;
659+ float tile_overlap_factor_x;
660+ sd_tiling_calc_tiles (num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor);
653661
654- if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
655- // if tiles don't fit perfectly using the desired overlap
656- // and there is enough room to squeeze an extra tile without overlap becoming >0.5
657- num_tiles_y++;
658- }
659-
660- float tile_overlap_factor_y = (float )(tile_size * num_tiles_y - small_height) / (float )(tile_size * (num_tiles_y - 1 ));
661- if (num_tiles_y <= 2 ) {
662- if (small_height <= tile_size) {
663- num_tiles_y = 1 ;
664- tile_overlap_factor_y = 0 ;
665- } else {
666- num_tiles_y = 2 ;
667- tile_overlap_factor_y = (2 * tile_size - small_height) / (float )tile_size;
668- }
669- }
662+ int num_tiles_y;
663+ float tile_overlap_factor_y;
664+ sd_tiling_calc_tiles (num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor);
670665
671666 LOG_DEBUG (" num tiles : %d, %d " , num_tiles_x, num_tiles_y);
672667 LOG_DEBUG (" optimal overlap : %f, %f (targeting %f)" , tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor);
673668
674669 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
675670
676- int tile_overlap_x = (int32_t )(tile_size * tile_overlap_factor_x);
677- int non_tile_overlap_x = tile_size - tile_overlap_x;
671+ int tile_overlap_x = (int32_t )(p_tile_size_x * tile_overlap_factor_x);
672+ int non_tile_overlap_x = p_tile_size_x - tile_overlap_x;
678673
679- int tile_overlap_y = (int32_t )(tile_size * tile_overlap_factor_y);
680- int non_tile_overlap_y = tile_size - tile_overlap_y;
674+ int tile_overlap_y = (int32_t )(p_tile_size_y * tile_overlap_factor_y);
675+ int non_tile_overlap_y = p_tile_size_y - tile_overlap_y;
681676
682- int tile_size_x = tile_size < small_width ? tile_size : small_width;
683- int tile_size_y = tile_size < small_height ? tile_size : small_height;
677+ int tile_size_x = p_tile_size_x < small_width ? p_tile_size_x : small_width;
678+ int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height;
684679
685680 int input_tile_size_x = tile_size_x;
686681 int input_tile_size_y = tile_size_y;
@@ -769,6 +764,11 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
769764 ggml_free (tiles_ctx);
770765}
771766
767+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale,
768+ const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
769+ sd_tiling_non_square (input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing);
770+ }
771+
772772__STATIC_INLINE__ struct ggml_tensor * ggml_group_norm_32 (struct ggml_context * ctx,
773773 struct ggml_tensor * a) {
774774 const float eps = 1e-6f ; // default eps parameter
0 commit comments