@@ -187,7 +187,7 @@ class TieredHNSWIndex : public VecSimTieredIndex<DataType, DistType> {
187187 size_t indexSize () const override ;
188188 size_t indexLabelCount () const override ;
189189 size_t indexCapacity () const override ;
190- double getDistanceFrom (labelType label, const void *blob) const override ;
190+ double getDistanceFrom_Unsafe (labelType label, const void *blob) const override ;
191191 // Do nothing here, each tier (flat buffer and HNSW) should increase capacity for itself when
192192 // needed.
193193 VecSimIndexInfo info () const override ;
@@ -210,6 +210,17 @@ class TieredHNSWIndex : public VecSimTieredIndex<DataType, DistType> {
210210 " running asynchronous GC for tiered HNSW index" );
211211 this ->executeReadySwapJobs (this ->pendingSwapJobsThreshold );
212212 }
213+ void acquireSharedLocks () override {
214+ this ->flatIndexGuard .lock_shared ();
215+ this ->mainIndexGuard .lock_shared ();
216+ this ->getHNSWIndex ()->lockSharedIndexDataGuard ();
217+ }
218+
219+ void releaseSharedLocks () override {
220+ this ->flatIndexGuard .unlock_shared ();
221+ this ->mainIndexGuard .unlock_shared ();
222+ this ->getHNSWIndex ()->unlockSharedIndexDataGuard ();
223+ }
213224#ifdef BUILD_TESTS
214225 void getDataByLabel (labelType label, std::vector<std::vector<DataType>> &vectors_output) const ;
215226#endif
@@ -621,9 +632,9 @@ TieredHNSWIndex<DataType, DistType>::~TieredHNSWIndex() {
621632template <typename DataType, typename DistType>
622633size_t TieredHNSWIndex<DataType, DistType>::indexSize() const {
623634 this ->flatIndexGuard .lock_shared ();
624- this ->getHNSWIndex ()->lockIndexDataGuard ();
635+ this ->getHNSWIndex ()->lockSharedIndexDataGuard ();
625636 size_t res = this ->backendIndex ->indexSize () + this ->frontendIndex ->indexSize ();
626- this ->getHNSWIndex ()->unlockIndexDataGuard ();
637+ this ->getHNSWIndex ()->unlockSharedIndexDataGuard ();
627638 this ->flatIndexGuard .unlock_shared ();
628639 return res;
629640}
@@ -803,14 +814,18 @@ int TieredHNSWIndex<DataType, DistType>::deleteVector(labelType label) {
803814// 3. label exists in both indexes - we may have some of the vectors with the same label in the flat
804815// buffer only and some in the Main index only (and maybe temporal duplications).
805816// So, we get the distance from both indexes and return the minimum.
817+
818+ // IMPORTANT: the caller must hold the *tiered index locks for shared ownership*, along with the
819+ // HNSW index data guard lock (see acquireSharedLocks()). The internal getDistanceFrom_Unsafe calls
820+ // access the indexes' data, so it is not safe to run insert/delete operations in parallel. We
821+ // deliberately avoid acquiring the locks internally: this function is usually called for every
822+ // vector individually, and the per-call overhead of acquiring and releasing the locks is significant.
806823template <typename DataType, typename DistType>
807- double TieredHNSWIndex<DataType, DistType>::getDistanceFrom (labelType label,
808- const void *blob) const {
824+ double TieredHNSWIndex<DataType, DistType>::getDistanceFrom_Unsafe (labelType label,
825+ const void *blob) const {
809826 // Try to get the distance from the flat buffer.
810827 // If the label doesn't exist, the distance will be NaN.
811- this ->flatIndexGuard .lock_shared ();
812- auto flat_dist = this ->frontendIndex ->getDistanceFrom (label, blob);
813- this ->flatIndexGuard .unlock_shared ();
828+ auto flat_dist = this ->frontendIndex ->getDistanceFrom_Unsafe (label, blob);
814829
815830 // Optimization. TODO: consider having different implementations for single and multi indexes,
816831 // to avoid checking the index type on every query.
@@ -821,9 +836,7 @@ double TieredHNSWIndex<DataType, DistType>::getDistanceFrom(labelType label,
821836 }
822837
823838 // Try to get the distance from the Main index.
824- this ->mainIndexGuard .lock_shared ();
825- auto hnsw_dist = getHNSWIndex ()->safeGetDistanceFrom (label, blob);
826- this ->mainIndexGuard .unlock_shared ();
839+ auto hnsw_dist = getHNSWIndex ()->getDistanceFrom_Unsafe (label, blob);
827840
828841 // Return the minimum distance that is not NaN.
829842 return std::fmin (flat_dist, hnsw_dist);
0 commit comments