Skip to content

Commit 55cc4d0

Browse files
committed
Support save and load in numpy.lib.format.
1 parent d794576 commit 55cc4d0

File tree

11 files changed

+608
-281
lines changed

11 files changed

+608
-281
lines changed

src/TensorFlowNET.Core/GlobalUsing.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
global using System;
22
global using System.Collections.Generic;
33
global using System.Text;
4+
global using System.Collections;
5+
global using System.Data;
6+
global using System.Linq;

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.load.cs

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
 using System;
2-
using System.Collections;
3-
using System.Collections.Generic;
4-
using System.IO;
5-
using System.Linq;
6-
using System.Reflection;
7-
using System.Text;
8-
using Tensorflow.Util;
1+
using System.IO;
92

103
namespace Tensorflow.NumPy
114
{
@@ -15,10 +8,7 @@ public NDArray load(string file)
158
{
169
using var stream = new FileStream(file, FileMode.Open);
1710
using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true);
18-
int bytes;
19-
Type type;
20-
int[] shape;
21-
if (!ParseReader(reader, out bytes, out type, out shape))
11+
if (!ParseReader(reader, out var bytes, out var type, out var shape))
2212
throw new FormatException();
2313

2414
Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim));
@@ -31,10 +21,7 @@ public Array LoadMatrix(Stream stream)
3121
{
3222
using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII, leaveOpen: true))
3323
{
34-
int bytes;
35-
Type type;
36-
int[] shape;
37-
if (!ParseReader(reader, out bytes, out type, out shape))
24+
if (!ParseReader(reader, out var bytes, out var type, out var shape))
3825
throw new FormatException();
3926

4027
Array matrix = Array.CreateInstance(type, shape);
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*****************************************************************************
2+
Copyright 2023 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System.IO;
18+
using System.IO.Compression;
19+
20+
namespace Tensorflow.NumPy;
21+
22+
public partial class np
23+
{
24+
[AutoNumPy]
25+
public static NpzDictionary loadz(string file)
26+
{
27+
using var stream = new FileStream(file, FileMode.Open);
28+
return new NpzDictionary(stream);
29+
}
30+
31+
public static void save(string file, NDArray nd)
32+
{
33+
using var stream = new FileStream(file, FileMode.Create);
34+
NpyFormat.Save(nd, stream);
35+
}
36+
37+
public static void savez(string file, params NDArray[] nds)
38+
{
39+
using var stream = new FileStream(file, FileMode.Create);
40+
NpzFormat.Save(nds, stream);
41+
}
42+
43+
public static void savez(string file, object nds)
44+
{
45+
using var stream = new FileStream(file, FileMode.Create);
46+
NpzFormat.Save(nds, stream);
47+
}
48+
49+
public static void savez_compressed(string file, params NDArray[] nds)
50+
{
51+
using var stream = new FileStream(file, FileMode.Create);
52+
NpzFormat.Save(nds, stream, CompressionLevel.Fastest);
53+
}
54+
55+
public static void savez_compressed(string file, object nds)
56+
{
57+
using var stream = new FileStream(file, FileMode.Create);
58+
NpzFormat.Save(nds, stream, CompressionLevel.Fastest);
59+
}
60+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
using System.IO;
2+
using System.Runtime.InteropServices;
3+
4+
namespace Tensorflow.NumPy;
5+
6+
public class NpyFormat
7+
{
8+
public static void Save(NDArray array, Stream stream, bool leaveOpen = true)
9+
{
10+
using var writer = new BinaryWriter(stream, Encoding.ASCII, leaveOpen: leaveOpen);
11+
12+
string dtype = GetDtypeName(array, out var type, out var maxLength);
13+
int[] shape = array.shape.as_int_list();
14+
var bytesWritten = (ulong)writeHeader(writer, dtype, shape);
15+
stream.Write(array.ToByteArray(), 0, (int)array.bytesize);
16+
}
17+
18+
private static int writeHeader(BinaryWriter writer, string dtype, int[] shape)
19+
{
20+
// The first 6 bytes are a magic string: exactly "x93NUMPY"
21+
22+
char[] magic = { 'N', 'U', 'M', 'P', 'Y' };
23+
writer.Write((byte)147);
24+
writer.Write(magic);
25+
writer.Write((byte)1); // major
26+
writer.Write((byte)0); // minor;
27+
28+
string tuple = shape.Length == 1 ? $"{shape[0]}," : String.Join(", ", shape.Select(i => i.ToString()).ToArray());
29+
string header = "{{'descr': '{0}', 'fortran_order': False, 'shape': ({1}), }}";
30+
header = string.Format(header, dtype, tuple);
31+
int preamble = 10; // magic string (6) + 4
32+
33+
int len = header.Length + 1; // the 1 is to account for the missing \n at the end
34+
int headerSize = len + preamble;
35+
36+
int pad = 16 - (headerSize % 16);
37+
header = header.PadRight(header.Length + pad);
38+
header += "\n";
39+
headerSize = header.Length + preamble;
40+
41+
if (headerSize % 16 != 0)
42+
throw new Exception("");
43+
44+
writer.Write((ushort)header.Length);
45+
for (int i = 0; i < header.Length; i++)
46+
writer.Write((byte)header[i]);
47+
48+
return headerSize;
49+
}
50+
51+
private static string GetDtypeName(NDArray array, out Type type, out int bytes)
52+
{
53+
type = array.dtype.as_system_dtype();
54+
55+
bytes = 1;
56+
57+
if (type == typeof(string))
58+
{
59+
throw new NotSupportedException("");
60+
}
61+
else if (type == typeof(bool))
62+
{
63+
bytes = 1;
64+
}
65+
else
66+
{
67+
bytes = Marshal.SizeOf(type);
68+
}
69+
70+
if (type == typeof(bool))
71+
return "|b1";
72+
else if (type == typeof(byte))
73+
return "|i1";
74+
else if (type == typeof(short))
75+
return "<i2";
76+
else if (type == typeof(int))
77+
return "<i4";
78+
else if (type == typeof(long))
79+
return "<i8";
80+
else if (type == typeof(ushort))
81+
return "<u2";
82+
else if (type == typeof(uint))
83+
return "<u4";
84+
else if (type == typeof(ulong))
85+
return "<u8";
86+
else if (type == typeof(float))
87+
return "<f4";
88+
else if (type == typeof(double))
89+
return "<f8";
90+
else if (type == typeof(string))
91+
return "|S" + bytes;
92+
else
93+
throw new NotSupportedException();
94+
}
95+
}
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
using System.IO;
2+
using System.IO.Compression;
3+
4+
namespace Tensorflow.NumPy;
5+
6+
public class NpzDictionary<T> : IDisposable, IReadOnlyDictionary<string, T>, ICollection<T>
7+
where T : class,
8+
ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
9+
{
10+
Stream stream;
11+
ZipArchive archive;
12+
13+
bool disposedValue = false;
14+
15+
Dictionary<string, ZipArchiveEntry> entries;
16+
Dictionary<string, T> arrays;
17+
18+
19+
public NpzDictionary(Stream stream)
20+
{
21+
this.stream = stream;
22+
this.archive = new ZipArchive(stream, ZipArchiveMode.Read, leaveOpen: true);
23+
24+
this.entries = new Dictionary<string, ZipArchiveEntry>();
25+
foreach (var entry in archive.Entries)
26+
this.entries[entry.FullName] = entry;
27+
28+
this.arrays = new Dictionary<string, T>();
29+
}
30+
31+
32+
public IEnumerable<string> Keys
33+
{
34+
get { return entries.Keys; }
35+
}
36+
37+
38+
public IEnumerable<T> Values
39+
{
40+
get { return entries.Values.Select(OpenEntry); }
41+
}
42+
43+
public int Count
44+
{
45+
get { return entries.Count; }
46+
}
47+
48+
49+
public object SyncRoot
50+
{
51+
get { return ((ICollection)entries).SyncRoot; }
52+
}
53+
54+
55+
public bool IsSynchronized
56+
{
57+
get { return ((ICollection)entries).IsSynchronized; }
58+
}
59+
60+
public bool IsReadOnly
61+
{
62+
get { return true; }
63+
}
64+
65+
public T this[string key]
66+
{
67+
get { return OpenEntry(entries[key]); }
68+
}
69+
70+
private T OpenEntry(ZipArchiveEntry entry)
71+
{
72+
T array;
73+
if (arrays.TryGetValue(entry.FullName, out array))
74+
return array;
75+
76+
using (Stream s = entry.Open())
77+
{
78+
array = Load_Npz(s);
79+
arrays[entry.FullName] = array;
80+
return array;
81+
}
82+
}
83+
84+
protected virtual T Load_Npz(Stream s)
85+
{
86+
return np.Load<T>(s);
87+
}
88+
89+
public bool ContainsKey(string key)
90+
{
91+
return entries.ContainsKey(key);
92+
}
93+
94+
public bool TryGetValue(string key, out T value)
95+
{
96+
value = default(T);
97+
ZipArchiveEntry entry;
98+
if (!entries.TryGetValue(key, out entry))
99+
return false;
100+
value = OpenEntry(entry);
101+
return true;
102+
}
103+
104+
public IEnumerator<KeyValuePair<string, T>> GetEnumerator()
105+
{
106+
foreach (var entry in archive.Entries)
107+
yield return new KeyValuePair<string, T>(entry.FullName, OpenEntry(entry));
108+
}
109+
110+
IEnumerator IEnumerable.GetEnumerator()
111+
{
112+
foreach (var entry in archive.Entries)
113+
yield return new KeyValuePair<string, T>(entry.FullName, OpenEntry(entry));
114+
}
115+
116+
IEnumerator<T> IEnumerable<T>.GetEnumerator()
117+
{
118+
foreach (var entry in archive.Entries)
119+
yield return OpenEntry(entry);
120+
}
121+
122+
public void CopyTo(Array array, int arrayIndex)
123+
{
124+
foreach (var v in this)
125+
array.SetValue(v, arrayIndex++);
126+
}
127+
128+
public void CopyTo(T[] array, int arrayIndex)
129+
{
130+
foreach (var v in this)
131+
array.SetValue(v, arrayIndex++);
132+
}
133+
134+
public void Add(T item)
135+
{
136+
throw new ReadOnlyException();
137+
}
138+
139+
public void Clear()
140+
{
141+
throw new ReadOnlyException();
142+
}
143+
144+
public bool Contains(T item)
145+
{
146+
foreach (var v in this)
147+
if (Object.Equals(v.Value, item))
148+
return true;
149+
return false;
150+
}
151+
152+
public bool Remove(T item)
153+
{
154+
throw new ReadOnlyException();
155+
}
156+
157+
protected virtual void Dispose(bool disposing)
158+
{
159+
if (!disposedValue)
160+
{
161+
if (disposing)
162+
{
163+
archive.Dispose();
164+
stream.Dispose();
165+
}
166+
167+
archive = null;
168+
stream = null;
169+
entries = null;
170+
arrays = null;
171+
172+
disposedValue = true;
173+
}
174+
}
175+
176+
public void Dispose()
177+
{
178+
Dispose(true);
179+
}
180+
}

0 commit comments

Comments
 (0)