@@ -41,6 +41,13 @@ class FlatIndexNode : public IndexNode {
4141 return err;
4242 }
4343
44+ std::vector<float >
45+ convertFloat16ToFloat32 (const std::vector<float16>& input) {
46+ std::vector<float > output (input.size ());
47+ std::transform (input.begin (), input.end (), output.begin (), [](float16 f) { return static_cast <float >(f); });
48+ return output;
49+ }
50+
4451 Status
4552 Train (const DataSet& dataset, const Config& cfg) override {
4653 const FlatConfig& f_cfg = static_cast <const FlatConfig&>(cfg);
@@ -55,6 +62,14 @@ class FlatIndexNode : public IndexNode {
5562 LOG_KNOWHERE_WARNING_ << " please check metric type: " << f_cfg.metric_type ;
5663 return metric.error ();
5764 }
65+
66+ auto dim_data = dataset.GetDim ();
67+
68+ // If dim_data is float16, convert it to float32
69+ if (typeid (dim_data[0 ]) == typeid (float16)) {
70+ dim_data = convertFloat16ToFloat32 (dim_data);
71+ }
72+
5873 index_ = std::make_unique<T>(dataset.GetDim (), metric.value ());
5974 return Status::success;
6075 }
@@ -63,6 +78,12 @@ class FlatIndexNode : public IndexNode {
6378 Add (const DataSet& dataset, const Config& cfg) override {
6479 auto x = dataset.GetTensor ();
6580 auto n = dataset.GetRows ();
81+
82+ if (typeid (x[0 ]) == typeid (float16)) {
83+ std::vector<float > x_float32 (x.begin (), x.end ());
84+ x = x_float32;
85+ }
86+
6687 if constexpr (std::is_same<T, faiss::IndexFlat>::value) {
6788 index_->add (n, (const float *)x);
6889 }
@@ -92,6 +113,11 @@ class FlatIndexNode : public IndexNode {
92113 auto x = dataset.GetTensor ();
93114 auto dim = dataset.GetDim ();
94115
116+ // If x is float16, convert it to float32
117+ if (typeid (x[0 ]) == typeid (float16)) {
118+ x = convertFloat16ToFloat32 (x);
119+ }
120+
95121 auto len = k * nq;
96122 int64_t * ids = nullptr ;
97123 float * distances = nullptr ;
@@ -150,6 +176,11 @@ class FlatIndexNode : public IndexNode {
150176 auto xq = dataset.GetTensor ();
151177 auto dim = dataset.GetDim ();
152178
179+ // If xq is float16, convert it to float32
180+ if (typeid (xq[0 ]) == typeid (float16)) {
181+ xq = convertFloat16ToFloat32 (xq);
182+ }
183+
153184 int64_t * ids = nullptr ;
154185 float * distances = nullptr ;
155186 size_t * lims = nullptr ;
@@ -212,7 +243,14 @@ class FlatIndexNode : public IndexNode {
212243 for (int64_t i = 0 ; i < rows; i++) {
213244 index_->reconstruct (ids[i], data + i * dim);
214245 }
215- return GenResultDataSet (rows, dim, data);
246+ // If original data was float16, convert it back before returning
247+ if (typeid (dataset.GetTensor ()[0 ]) == typeid (float16)) {
248+ auto data16 = convertFloat32ToFloat16 (data, rows * dim);
249+ delete[] data;
250+ return GenResultDataSet (rows, dim, data16);
251+ } else {
252+ return GenResultDataSet (rows, dim, data);
253+ }
216254 } catch (const std::exception& e) {
217255 std::unique_ptr<float []> auto_del (data);
218256 LOG_KNOWHERE_WARNING_ << " faiss inner error: " << e.what ();
0 commit comments