Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 2ef4574

Browse files
committed
Methods to split and join a tensor into tiles
1 parent 314e093 commit 2ef4574

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

OnnxStack.Core/Extensions/TensorExtension.cs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
using System.Linq;
55
using System.Numerics.Tensors;
66
using System.Numerics;
7+
using OnnxStack.Core.Model;
78

89
namespace OnnxStack.Core
910
{
10-
public static class TensorExtension
11+
public static partial class TensorExtension
1112
{
1213
/// <summary>
1314
/// Divides the tensor by float.
@@ -423,5 +424,78 @@ private static DenseTensor<float> ConcatenateAxis2(DenseTensor<float> tensor1, D
423424

424425
return concatenatedTensor;
425426
}
427+
428+
429+
/// <summary>
430+
/// Splits the Tensor into 4 equal tiles.
431+
/// </summary>
432+
/// <param name="sourceTensor">The source tensor.</param>
433+
/// <returns>TODO: Optimize</returns>
434+
public static ImageTiles SplitTiles(this DenseTensor<float> sourceTensor)
435+
{
436+
int tileWidth = sourceTensor.Dimensions[3] / 2;
437+
int tileHeight = sourceTensor.Dimensions[2] / 2;
438+
439+
return new ImageTiles(
440+
SplitTile(sourceTensor, 0, 0, tileHeight, tileWidth),
441+
SplitTile(sourceTensor, 0, tileWidth, tileHeight, tileWidth * 2),
442+
SplitTile(sourceTensor, tileHeight, 0, tileHeight * 2, tileWidth),
443+
SplitTile(sourceTensor, tileHeight, tileWidth, tileHeight * 2, tileWidth * 2));
444+
}
445+
446+
private static DenseTensor<float> SplitTile(DenseTensor<float> tensor, int startRow, int startCol, int endRow, int endCol)
447+
{
448+
int height = endRow - startRow;
449+
int width = endCol - startCol;
450+
int channels = tensor.Dimensions[1];
451+
var slicedData = new DenseTensor<float>(new[] { 1, channels, height, width });
452+
for (int c = 0; c < channels; c++)
453+
{
454+
for (int i = 0; i < height; i++)
455+
{
456+
for (int j = 0; j < width; j++)
457+
{
458+
slicedData[0, c, i, j] = tensor[0, c, startRow + i, startCol + j];
459+
}
460+
}
461+
}
462+
return slicedData;
463+
}
464+
465+
466+
/// <summary>
467+
/// Rejoins the tiles into a single Tensor.
468+
/// </summary>
469+
/// <param name="tiles">The tiles.</param>
470+
/// <returns>TODO: Optimize</returns>
471+
public static DenseTensor<float> RejoinTiles(this ImageTiles tiles)
472+
{
473+
int totalHeight = tiles.Tile1.Dimensions[2] + tiles.Tile3.Dimensions[2];
474+
int totalWidth = tiles.Tile1.Dimensions[3] + tiles.Tile2.Dimensions[3];
475+
int channels = tiles.Tile1.Dimensions[1];
476+
var destination = new DenseTensor<float>(new[] { 1, channels, totalHeight, totalWidth });
477+
RejoinTile(destination, tiles.Tile1, 0, 0);
478+
RejoinTile(destination, tiles.Tile2, 0, tiles.Tile1.Dimensions[3]);
479+
RejoinTile(destination, tiles.Tile3, tiles.Tile1.Dimensions[2], 0);
480+
RejoinTile(destination, tiles.Tile4, tiles.Tile1.Dimensions[2], tiles.Tile1.Dimensions[3]);
481+
return destination;
482+
}
483+
484+
private static void RejoinTile(DenseTensor<float> destination, DenseTensor<float> tile, int startRow, int startCol)
485+
{
486+
int channels = tile.Dimensions[1];
487+
int height = tile.Dimensions[2];
488+
int width = tile.Dimensions[3];
489+
for (int c = 0; c < channels; c++)
490+
{
491+
for (int i = 0; i < height; i++)
492+
{
493+
for (int j = 0; j < width; j++)
494+
{
495+
destination[0, c, startRow + i, startCol + j] = tile[0, c, i, j];
496+
}
497+
}
498+
}
499+
}
426500
}
427501
}

OnnxStack.Core/Model/ImageTiles.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using Microsoft.ML.OnnxRuntime.Tensors;
2+
3+
namespace OnnxStack.Core.Model
4+
{
5+
public record ImageTiles(DenseTensor<float> Tile1, DenseTensor<float> Tile2, DenseTensor<float> Tile3, DenseTensor<float> Tile4);
6+
}

OnnxStack.Core/OnnxStack.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
<PackageReference Include="Microsoft.ML" Version="3.0.1" />
4545
<PackageReference Include="Microsoft.ML.OnnxRuntime.Extensions" Version="0.10.0" />
4646
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="1.17.1" />
47-
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.2" />
47+
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.3" />
4848
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
4949
</ItemGroup>
5050

0 commit comments

Comments
 (0)