diff --git a/rust/index/src/spann/types.rs b/rust/index/src/spann/types.rs index 59269864c34..260babe658e 100644 --- a/rust/index/src/spann/types.rs +++ b/rust/index/src/spann/types.rs @@ -766,6 +766,61 @@ impl SpannIndexWriter { Ok(false) } + /** Deletes the HNSW entry and the posting list of `head_id` when every entry of that posting list is outdated. The result of `posting_list_partitioned_mutex.lock(&head_id)` is bound to `_write_guard` for the whole call (presumably a per-head lock guard). Returns `Ok(())` without changing anything when `is_head_deleted` reports the head as already deleted, when the posting list lookup yields `Ok(None)` (a missing posting list is not an error here), or when at least one entry is not outdated according to `is_outdated`. When the outdated count equals the number of entries (an empty list also satisfies this comparison), the head is first deleted from the HNSW index inside a nested block, so the HNSW write guard is dropped before the posting-list delete is awaited, and then the posting list itself is deleted. Errors: `PostingListGetError` when the lookup fails, `HnswIndexMutateError` when the HNSW delete fails, `PostingListSetError` when the posting-list delete fails; errors from `is_head_deleted` and `is_outdated` are propagated with `?`. NOTE(review): the turbofish arguments after `get_owned::` and `delete::` look truncated in this copy of the patch - confirm against the real file. NOTE(review): unlike the other head-deletion sites touched by this patch, this path does not update `stats.num_heads_deleted` - confirm that is intended. */ async fn try_delete_posting_list(&self, head_id: u32) -> Result<(), SpannIndexWriterError> { + let _write_guard = self.posting_list_partitioned_mutex.lock(&head_id).await; + if self.is_head_deleted(head_id as usize).await? { + return Ok(()); + } + let result = self + .posting_list_writer + .get_owned::>("", head_id) + .await; + // If the error is posting list not found, then return ok. + match result { + Ok(Some((doc_offset_ids, doc_versions, _))) => { + let mut outdated_count = 0; + for (doc_offset_id, doc_version) in doc_offset_ids.iter().zip(doc_versions.iter()) { + if self.is_outdated(*doc_offset_id, *doc_version).await? { + outdated_count += 1; + } + } + if outdated_count == doc_offset_ids.len() { + { + let hnsw_write_guard = self.hnsw_index.inner.write(); + hnsw_write_guard + .hnsw_index + .delete(head_id as usize) + .map_err(|e| { + tracing::error!( + "Error deleting head {} from hnsw index: {}", + head_id, + e + ); + SpannIndexWriterError::HnswIndexMutateError(e) + })?; + } + self.posting_list_writer + .delete::>("", head_id) + .await + .map_err(|e| { + tracing::error!( + "Error deleting posting list for head {}: {}", + head_id, + e + ); + SpannIndexWriterError::PostingListSetError(e) + })?; + } + } + Ok(None) => {} + Err(e) => { + tracing::error!("Error getting posting list for head {}: {}", head_id, e); + return Err(SpannIndexWriterError::PostingListGetError(e)); + } + } + Ok(()) + } + #[allow(clippy::too_many_arguments)] async fn collect_and_reassign_split_points( &self, @@ -814,6 +869,8 @@ impl SpannIndexWriter { .await?; } } + // Delete head if all points were moved out. 
+ self.try_delete_posting_list(new_head_ids[k] as u32).await?; } Ok(assigned_ids) } @@ -946,17 +1003,20 @@ impl SpannIndexWriter { let doc_versions; let doc_embeddings; { - // TODO(Sanket): Check if head is deleted, can happen if another concurrent thread - // deletes it. - (doc_offset_ids, doc_versions, doc_embeddings) = self + let result = self .posting_list_writer .get_owned::>("", head_id as u32) - .await - .map_err(|e| { - tracing::error!("Error getting posting list for head {}: {}", head_id, e); - SpannIndexWriterError::PostingListGetError(e) - })? - .ok_or(SpannIndexWriterError::PostingListNotFound)?; + .await; + match result { + Ok(Some((offset_ids, versions, embeddings))) => { + doc_offset_ids = offset_ids; + doc_versions = versions; + doc_embeddings = embeddings; + } + // Posting list can be concurrent deleted so bail out early if not found. + Ok(None) => return Ok(()), + Err(e) => return Err(SpannIndexWriterError::PostingListGetError(e)), + } } for (index, doc_offset_id) in doc_offset_ids.iter().enumerate() { if assigned_ids.contains(doc_offset_id) @@ -1004,6 +1064,8 @@ impl SpannIndexWriter { ) .await?; } + // Delete head if all points were moved out. + self.try_delete_posting_list(head_id as u32).await?; Ok(()) } @@ -1264,6 +1326,7 @@ impl SpannIndexWriter { if !same_head && distance_function .distance(&clustering_output.cluster_centers[k], &head_embedding) + .abs() < 1e-6 { same_head = true; @@ -1350,17 +1413,32 @@ impl SpannIndexWriter { } if !same_head { // Delete the old head - let hnsw_write_guard = self.hnsw_index.inner.write(); - hnsw_write_guard - .hnsw_index - .delete(head_id as usize) + // First delete from hnsw then from postings list. This order + // ensures that the head is never dangling. 
+ { + let hnsw_write_guard = self.hnsw_index.inner.write(); + hnsw_write_guard + .hnsw_index + .delete(head_id as usize) + .map_err(|e| { + tracing::error!( + "Error deleting head {} from hnsw index: {}", + head_id, + e + ); + SpannIndexWriterError::HnswIndexMutateError(e) + })?; + } + self.posting_list_writer + .delete::>("", head_id) + .await .map_err(|e| { tracing::error!( - "Error deleting head {} from hnsw index: {}", + "Error deleting posting list for head {}: {}", head_id, e ); - SpannIndexWriterError::HnswIndexMutateError(e) + SpannIndexWriterError::PostingListSetError(e) })?; self.stats .num_heads_deleted @@ -1755,12 +1833,29 @@ impl SpannIndexWriter { self.stats .num_pl_modified .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // Delete from hnsw. - let hnsw_write_guard = self.hnsw_index.inner.write(); - hnsw_write_guard.hnsw_index.delete(head_id).map_err(|e| { - tracing::error!("Error deleting head {} from hnsw index: {}", head_id, e); - SpannIndexWriterError::HnswIndexMutateError(e) - })?; + { + // Delete from hnsw. + let hnsw_write_guard = self.hnsw_index.inner.write(); + hnsw_write_guard.hnsw_index.delete(head_id).map_err(|e| { + tracing::error!( + "Error deleting head {} from hnsw index: {}", + head_id, + e + ); + SpannIndexWriterError::HnswIndexMutateError(e) + })?; + } + self.posting_list_writer + .delete::>("", head_id as u32) + .await + .map_err(|e| { + tracing::error!( + "Error deleting posting list for head {}: {}", + head_id, + e + ); + SpannIndexWriterError::PostingListSetError(e) + })?; self.stats .num_heads_deleted .fetch_add(1, std::sync::atomic::Ordering::Relaxed); @@ -1779,18 +1874,31 @@ impl SpannIndexWriter { self.stats .num_pl_modified .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - // Delete from hnsw. - let hnsw_write_guard = self.hnsw_index.inner.write(); - hnsw_write_guard - .hnsw_index - .delete(nearest_head_id) + { + // Delete from hnsw. 
+ let hnsw_write_guard = self.hnsw_index.inner.write(); + hnsw_write_guard + .hnsw_index + .delete(nearest_head_id) + .map_err(|e| { + tracing::error!( + "Error deleting head {} from hnsw index: {}", + nearest_head_id, + e + ); + SpannIndexWriterError::HnswIndexMutateError(e) + })?; + } + self.posting_list_writer + .delete::>("", nearest_head_id as u32) + .await .map_err(|e| { tracing::error!( - "Error deleting head {} from hnsw index: {}", + "Error deleting posting list for head {}: {}", nearest_head_id, e ); - SpannIndexWriterError::HnswIndexMutateError(e) + SpannIndexWriterError::PostingListSetError(e) })?; self.stats .num_heads_deleted @@ -3583,7 +3691,7 @@ mod tests { } #[tokio::test] - async fn test_reassign() { + async fn test_reassign_and_delete_center() { let tmp_dir = tempfile::tempdir().unwrap(); let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap())); let block_cache = new_cache_for_test(); @@ -3775,16 +3883,19 @@ mod tests { .expect("Expected reassign to succeed"); // See the reassigned points. { - // Center 1 should remain unchanged. + // Center 1 should get 100 points: original 50 + 50 reassigned from center 3. + // Points 51-100 from center 3 (near 1000,1000) get reassigned because center 2 + // was deleted, and center 1 is the only remaining nearby center. let pl = writer .posting_list_writer .get_owned::>("", 1) .await .expect("Error getting posting list") .unwrap(); - assert_eq!(pl.0.len(), 50); - assert_eq!(pl.1.len(), 50); - assert_eq!(pl.2.len(), 100); + assert_eq!(pl.0.len(), 100); + assert_eq!(pl.1.len(), 100); + assert_eq!(pl.2.len(), 200); + // First 50 are original points 1-50 at version 1 for i in 1..=50 { assert_eq!(pl.0[i - 1], i as u32); assert_eq!(pl.1[i - 1], 1); @@ -3794,28 +3905,26 @@ mod tests { split_doc_embeddings1[(i - 1) * 2 + 1] ); } - // Center 2 should get 50 points, all with version 2 migrating from center 3. 
- let pl = writer - .posting_list_writer - .get_owned::>("", 2) - .await - .expect("Error getting posting list") - .unwrap(); - assert_eq!(pl.0.len(), 50); - assert_eq!(pl.1.len(), 50); - assert_eq!(pl.2.len(), 100); - for i in 1..=50 { - assert_eq!(pl.0[i - 1], 50 + i as u32); + // Next 50 are reassigned points 51-100 at version 2 (from center 3) + for i in 51..=100 { + assert_eq!(pl.0[i - 1], i as u32); assert_eq!(pl.1[i - 1], 2); - assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 1) * 2]); + assert_eq!(pl.2[(i - 1) * 2], split_doc_embeddings3[(i - 51) * 2]); assert_eq!( pl.2[(i - 1) * 2 + 1], - split_doc_embeddings3[(i - 1) * 2 + 1] + split_doc_embeddings3[(i - 51) * 2 + 1] ); } - // Center 3 should get 100 points. 50 points with version 1 which weere - // originally in center 3 and 50 points with version 2 which were originally - // in center 2. + // Center 2 should be deleted (all its original points were reassigned out). + let pl = writer + .posting_list_writer + .get_owned::>("", 2) + .await + .expect("Error getting posting list"); + assert!(pl.is_none()); + // Center 3 should get 100 points. 50 points with version 1 which were + // originally in center 3 (now outdated since reassigned to center 1) and + // 50 points with version 2 which were originally in center 2. let pl = writer .posting_list_writer .get_owned::>("", 3)