@@ -604,62 +604,67 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
// Per-tile callback type accepted by sd_tiling (its on_processing parameter).
// The callable receives two ggml_tensor pointers and a bool and returns nothing;
// sd_tiling invokes it as on_processing(input_tile, output_tile, false) for each tile.
// NOTE(review): the meaning of the bool flag is not visible in this excerpt — confirm against the callers.
604604typedef std::function<void (ggml_tensor*, ggml_tensor*, bool )> on_tile_process;
605605
606606// Tiling
607- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
607+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
608608 output = ggml_set_f32 (output, 0 );
609609
610610 int input_width = (int )input->ne [0 ];
611611 int input_height = (int )input->ne [1 ];
612612 int output_width = (int )output->ne [0 ];
613613 int output_height = (int )output->ne [1 ];
614614
615- int input_tile_size, output_tile_size;
616- if (scaled_out) {
617- input_tile_size = tile_size;
618- output_tile_size = tile_size * scale;
619- } else {
620- input_tile_size = tile_size * scale;
621- output_tile_size = tile_size;
615+ GGML_ASSERT (input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
616+ GGML_ASSERT (input_width / output_width == scale || output_width / input_width == scale);
617+
618+ int small_width = output_width;
619+ int small_height = output_height;
620+
621+ bool big_out = output_width > input_width;
622+ if (big_out) {
623+ // Ex: decode
624+ small_width = input_width;
625+ small_height = input_height;
622626 }
623- int tile_overlap = (input_tile_size * tile_overlap_factor);
624- int non_tile_overlap = input_tile_size - tile_overlap;
625627
626- int num_tiles_x = (input_width - tile_overlap) / non_tile_overlap;
627- int overshoot_x = ((num_tiles_x + 1 ) * non_tile_overlap + tile_overlap) % input_width;
628+ int tile_overlap = (tile_size * tile_overlap_factor);
629+ int non_tile_overlap = tile_size - tile_overlap;
630+
631+ int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
632+ int overshoot_x = ((num_tiles_x + 1 ) * non_tile_overlap + tile_overlap) % small_width;
628633
629- if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (input_tile_size / 2 - tile_overlap))) {
634+ if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
630635 // if tiles don't fit perfectly using the desired overlap
631636 // and there is enough room to squeeze an extra tile without overlap becoming >0.5
632637 num_tiles_x++;
633638 }
634639
635- float tile_overlap_factor_x = (float )(input_tile_size * num_tiles_x - input_width ) / (float )(input_tile_size * (num_tiles_x - 1 ));
640+ float tile_overlap_factor_x = (float )(tile_size * num_tiles_x - small_width ) / (float )(tile_size * (num_tiles_x - 1 ));
636641 if (num_tiles_x <= 2 ) {
637- if (input_width <= input_tile_size ) {
642+ if (small_width <= tile_size ) {
638643 num_tiles_x = 1 ;
639644 tile_overlap_factor_x = 0 ;
640645 } else {
641646 num_tiles_x = 2 ;
642- tile_overlap_factor_x = (2 * input_tile_size - input_width ) / (float )input_tile_size ;
647+ tile_overlap_factor_x = (2 * tile_size - small_width ) / (float )tile_size ;
643648 }
644649 }
645650
646- int num_tiles_y = (input_height - tile_overlap) / non_tile_overlap;
647- int overshoot_y = ((num_tiles_y + 1 ) * non_tile_overlap + tile_overlap) % input_height ;
651+ int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
652+ int overshoot_y = ((num_tiles_y + 1 ) * non_tile_overlap + tile_overlap) % small_height ;
648653
649- if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (input_tile_size / 2 - tile_overlap))) {
654+ if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
650655 // if tiles don't fit perfectly using the desired overlap
651656 // and there is enough room to squeeze an extra tile without overlap becoming >0.5
652657 num_tiles_y++;
653658 }
654659
655- float tile_overlap_factor_y = (float )(input_tile_size * num_tiles_y - input_height ) / (float )(input_tile_size * (num_tiles_y - 1 ));
660+ float tile_overlap_factor_y = (float )(tile_size * num_tiles_y - small_height ) / (float )(tile_size * (num_tiles_y - 1 ));
656661 if (num_tiles_y <= 2 ) {
657- if (input_height <= input_tile_size ) {
662+ if (small_height <= tile_size ) {
658663 num_tiles_y = 1 ;
659664 tile_overlap_factor_y = 0 ;
660665 } else {
661666 num_tiles_y = 2 ;
662- tile_overlap_factor_y = (2 * input_tile_size - input_height ) / (float )input_tile_size ;
667+ tile_overlap_factor_y = (2 * tile_size - small_height ) / (float )tile_size ;
663668 }
664669 }
665670
@@ -668,11 +673,20 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
668673
669674 GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
670675
671- int tile_overlap_x = (int32_t )(input_tile_size * tile_overlap_factor_x);
672- int non_tile_overlap_x = input_tile_size - tile_overlap_x;
676+ int tile_overlap_x = (int32_t )(tile_size * tile_overlap_factor_x);
677+ int non_tile_overlap_x = tile_size - tile_overlap_x;
673678
674- int tile_overlap_y = (int32_t )(input_tile_size * tile_overlap_factor_y);
675- int non_tile_overlap_y = input_tile_size - tile_overlap_y;
679+ int tile_overlap_y = (int32_t )(tile_size * tile_overlap_factor_y);
680+ int non_tile_overlap_y = tile_size - tile_overlap_y;
681+
682+ int input_tile_size = tile_size;
683+ int output_tile_size = tile_size;
684+
685+ if (big_out) {
686+ output_tile_size *= scale;
687+ } else {
688+ input_tile_size *= scale;
689+ }
676690
677691 struct ggml_init_params params = {};
678692 params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
@@ -693,37 +707,48 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
693707 // tiling
694708 ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne [2 ], 1 );
695709 ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne [2 ], 1 );
696- on_processing (input_tile, NULL , true );
697710 int num_tiles = num_tiles_x * num_tiles_y;
698711 LOG_INFO (" processing %i tiles" , num_tiles);
699- pretty_progress (1 , num_tiles, 0 .0f );
712+ pretty_progress (0 , num_tiles, 0 .0f );
700713 int tile_count = 1 ;
701714 bool last_y = false , last_x = false ;
702715 float last_time = 0 .0f ;
703- for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap_y) {
716+ for (int y = 0 ; y < small_height && !last_y; y += non_tile_overlap_y) {
704717 int dy = 0 ;
705- if (y + input_tile_size >= input_height ) {
718+ if (y + tile_size >= small_height ) {
706719 int _y = y;
707- y = input_height - input_tile_size ;
720+ y = small_height - tile_size ;
708721 dy = _y - y;
722+ if (big_out) {
723+ dy *= scale;
724+ }
709725 last_y = true ;
710726 }
711- for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap_x) {
727+ for (int x = 0 ; x < small_width && !last_x; x += non_tile_overlap_x) {
712728 int dx = 0 ;
713- if (x + input_tile_size >= input_width ) {
729+ if (x + tile_size >= small_width ) {
714730 int _x = x;
715- x = input_width - input_tile_size ;
731+ x = small_width - tile_size ;
716732 dx = _x - x;
733+ if (big_out) {
734+ dx *= scale;
735+ }
717736 last_x = true ;
718737 }
738+
739+ int x_in = big_out ? x : scale * x;
740+ int y_in = big_out ? y : scale * y;
741+ int x_out = big_out ? x * scale : x;
742+ int y_out = big_out ? y * scale : y;
743+
744+ int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
745+ int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;
746+
719747 int64_t t1 = ggml_time_ms ();
720- ggml_split_tensor_2d (input, input_tile, x, y );
748+ ggml_split_tensor_2d (input, input_tile, x_in, y_in );
721749 on_processing (input_tile, output_tile, false );
722- if (scaled_out) {
723- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
724- } else {
725- ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
726- }
750+ ggml_merge_tensor_2d (output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
751+
727752 int64_t t2 = ggml_time_ms ();
728753 last_time = (t2 - t1) / 1000 .0f ;
729754 pretty_progress (tile_count, num_tiles, last_time);
0 commit comments