|
4 | 4 | using System.Linq;
|
5 | 5 | using System.Numerics.Tensors;
|
6 | 6 | using System.Numerics;
|
| 7 | +using OnnxStack.Core.Model; |
7 | 8 |
|
8 | 9 | namespace OnnxStack.Core
|
9 | 10 | {
|
10 |
| - public static class TensorExtension |
| 11 | + public static partial class TensorExtension |
11 | 12 | {
|
12 | 13 | /// <summary>
|
13 | 14 | /// Divides the tensor by float.
|
@@ -423,5 +424,78 @@ private static DenseTensor<float> ConcatenateAxis2(DenseTensor<float> tensor1, D
|
423 | 424 |
|
424 | 425 | return concatenatedTensor;
|
425 | 426 | }
|
| 427 | + |
| 428 | + |
| 429 | + /// <summary> |
| 430 | + /// Splits the Tensor into 4 equal tiles. |
| 431 | + /// </summary> |
| 432 | + /// <param name="sourceTensor">The source tensor.</param> |
| 433 | + /// <returns>TODO: Optimize</returns> |
| 434 | + public static ImageTiles SplitTiles(this DenseTensor<float> sourceTensor) |
| 435 | + { |
| 436 | + int tileWidth = sourceTensor.Dimensions[3] / 2; |
| 437 | + int tileHeight = sourceTensor.Dimensions[2] / 2; |
| 438 | + |
| 439 | + return new ImageTiles( |
| 440 | + SplitTile(sourceTensor, 0, 0, tileHeight, tileWidth), |
| 441 | + SplitTile(sourceTensor, 0, tileWidth, tileHeight, tileWidth * 2), |
| 442 | + SplitTile(sourceTensor, tileHeight, 0, tileHeight * 2, tileWidth), |
| 443 | + SplitTile(sourceTensor, tileHeight, tileWidth, tileHeight * 2, tileWidth * 2)); |
| 444 | + } |
| 445 | + |
| 446 | + private static DenseTensor<float> SplitTile(DenseTensor<float> tensor, int startRow, int startCol, int endRow, int endCol) |
| 447 | + { |
| 448 | + int height = endRow - startRow; |
| 449 | + int width = endCol - startCol; |
| 450 | + int channels = tensor.Dimensions[1]; |
| 451 | + var slicedData = new DenseTensor<float>(new[] { 1, channels, height, width }); |
| 452 | + for (int c = 0; c < channels; c++) |
| 453 | + { |
| 454 | + for (int i = 0; i < height; i++) |
| 455 | + { |
| 456 | + for (int j = 0; j < width; j++) |
| 457 | + { |
| 458 | + slicedData[0, c, i, j] = tensor[0, c, startRow + i, startCol + j]; |
| 459 | + } |
| 460 | + } |
| 461 | + } |
| 462 | + return slicedData; |
| 463 | + } |
| 464 | + |
| 465 | + |
| 466 | + /// <summary> |
| 467 | + /// Rejoins the tiles into a single Tensor. |
| 468 | + /// </summary> |
| 469 | + /// <param name="tiles">The tiles.</param> |
| 470 | + /// <returns>TODO: Optimize</returns> |
| 471 | + public static DenseTensor<float> RejoinTiles(this ImageTiles tiles) |
| 472 | + { |
| 473 | + int totalHeight = tiles.Tile1.Dimensions[2] + tiles.Tile3.Dimensions[2]; |
| 474 | + int totalWidth = tiles.Tile1.Dimensions[3] + tiles.Tile2.Dimensions[3]; |
| 475 | + int channels = tiles.Tile1.Dimensions[1]; |
| 476 | + var destination = new DenseTensor<float>(new[] { 1, channels, totalHeight, totalWidth }); |
| 477 | + RejoinTile(destination, tiles.Tile1, 0, 0); |
| 478 | + RejoinTile(destination, tiles.Tile2, 0, tiles.Tile1.Dimensions[3]); |
| 479 | + RejoinTile(destination, tiles.Tile3, tiles.Tile1.Dimensions[2], 0); |
| 480 | + RejoinTile(destination, tiles.Tile4, tiles.Tile1.Dimensions[2], tiles.Tile1.Dimensions[3]); |
| 481 | + return destination; |
| 482 | + } |
| 483 | + |
| 484 | + private static void RejoinTile(DenseTensor<float> destination, DenseTensor<float> tile, int startRow, int startCol) |
| 485 | + { |
| 486 | + int channels = tile.Dimensions[1]; |
| 487 | + int height = tile.Dimensions[2]; |
| 488 | + int width = tile.Dimensions[3]; |
| 489 | + for (int c = 0; c < channels; c++) |
| 490 | + { |
| 491 | + for (int i = 0; i < height; i++) |
| 492 | + { |
| 493 | + for (int j = 0; j < width; j++) |
| 494 | + { |
| 495 | + destination[0, c, startRow + i, startCol + j] = tile[0, c, i, j]; |
| 496 | + } |
| 497 | + } |
| 498 | + } |
| 499 | + } |
426 | 500 | }
|
427 | 501 | }
|
0 commit comments