Skip to content

Commit 28890fe

Browse files
committed
[CLN] Use InternalUpdateConfiguration in Rust, cleanup go code
1 parent 2604be9 commit 28890fe

File tree

13 files changed

+335
-76
lines changed

13 files changed

+335
-76
lines changed

chromadb/test/configurations/test_collection_configuration.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,16 @@ def test_hnsw_configuration_updates(client: ClientAPI) -> None:
281281
assert hnsw_config.get("ef_construction") == 100
282282
assert hnsw_config.get("max_neighbors") == 16
283283

284+
coll = client.get_collection(name="test_updates")
285+
loaded_config = coll.configuration_json
286+
if loaded_config and isinstance(loaded_config, dict):
287+
hnsw_config = loaded_config.get("hnsw", {})
288+
if isinstance(hnsw_config, dict):
289+
assert hnsw_config.get("ef_search") == 20
290+
assert hnsw_config.get("space") == "cosine"
291+
assert hnsw_config.get("ef_construction") == 100
292+
assert hnsw_config.get("max_neighbors") == 16
293+
284294

285295
def test_configuration_persistence(client_factories: "ClientFactories") -> None:
286296
"""Test configuration persistence across client restarts"""
@@ -712,6 +722,15 @@ def test_configuration_spann_updates(client: ClientAPI) -> None:
712722
# Original values should remain unchanged
713723
assert spann_config.get("space") == "cosine"
714724

725+
coll = client.get_collection("test_spann_updates")
726+
loaded_config = coll.configuration_json
727+
if loaded_config and isinstance(loaded_config, dict):
728+
spann_config = loaded_config.get("spann", {})
729+
if isinstance(spann_config, dict):
730+
assert spann_config.get("ef_search") == 150
731+
assert spann_config.get("search_nprobe") == 20
732+
assert spann_config.get("space") == "cosine"
733+
715734

716735
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
717736
def test_spann_update_from_json(client: ClientAPI) -> None:
@@ -757,6 +776,21 @@ def test_spann_update_from_json(client: ClientAPI) -> None:
757776
assert spann_config.get("max_neighbors") == 12
758777
assert spann_config.get("write_nprobe") == 20
759778

779+
coll = client.get_collection("test_spann_json_update")
780+
loaded_config = coll.configuration_json
781+
if loaded_config and isinstance(loaded_config, dict):
782+
spann_config = loaded_config.get("spann", {})
783+
if isinstance(spann_config, dict):
784+
# Updated values
785+
assert spann_config.get("ef_search") == 200
786+
assert spann_config.get("search_nprobe") == 15
787+
788+
# Unchanged values
789+
assert spann_config.get("space") == "cosine"
790+
assert spann_config.get("ef_construction") == 150
791+
assert spann_config.get("max_neighbors") == 12
792+
assert spann_config.get("write_nprobe") == 20
793+
760794

761795
def test_overwrite_spann_configuration() -> None:
762796
"""Test the overwrite_spann_configuration function directly"""

go/pkg/sysdb/coordinator/model/collection_configuration.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
package model
22

33
type EmbeddingFunctionConfiguration struct {
4-
Type string `json:"type"`
5-
Config *EmbeddingFunctionNewConfiguration `json:"config,omitempty"`
6-
}
7-
8-
type EmbeddingFunctionNewConfiguration struct {
4+
Type string `json:"type"`
95
Name string `json:"name"`
106
Config interface{} `json:"config"`
117
}
128

139
type VectorIndexConfiguration struct {
14-
Type string `json:"type"`
1510
Hnsw *HnswConfiguration `json:"hnsw,omitempty"`
1611
Spann *SpannConfiguration `json:"spann,omitempty"`
1712
}
@@ -62,7 +57,6 @@ type InternalCollectionConfiguration struct {
6257
func DefaultHnswCollectionConfiguration() *InternalCollectionConfiguration {
6358
return &InternalCollectionConfiguration{
6459
VectorIndex: &VectorIndexConfiguration{
65-
Type: "hnsw",
6660
Hnsw: DefaultHnswConfiguration(),
6761
},
6862
}
@@ -119,7 +113,6 @@ type UpdateSpannConfiguration struct {
119113
}
120114

121115
type UpdateVectorIndexConfiguration struct {
122-
Type string `json:"type"`
123116
Hnsw *UpdateHnswConfiguration `json:"hnsw,omitempty"`
124117
Spann *UpdateSpannConfiguration `json:"spann,omitempty"`
125118
}

go/pkg/sysdb/coordinator/table_catalog.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -815,8 +815,8 @@ func (tc *Catalog) updateCollectionConfiguration(
815815

816816
// Update existing configuration with new values
817817
if updateConfig.VectorIndex != nil {
818-
if updateConfig.VectorIndex.Type == "hnsw" && updateConfig.VectorIndex.Hnsw != nil {
819-
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Type != "hnsw" {
818+
if updateConfig.VectorIndex.Hnsw != nil {
819+
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Hnsw == nil {
820820
return existingConfigJsonStr, nil
821821
}
822822
if updateConfig.VectorIndex.Hnsw.EfSearch != nil {
@@ -837,8 +837,8 @@ func (tc *Catalog) updateCollectionConfiguration(
837837
if updateConfig.VectorIndex.Hnsw.BatchSize != nil {
838838
existingConfig.VectorIndex.Hnsw.BatchSize = *updateConfig.VectorIndex.Hnsw.BatchSize
839839
}
840-
} else if updateConfig.VectorIndex.Type == "spann" && updateConfig.VectorIndex.Spann != nil {
841-
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Type != "spann" {
840+
} else if updateConfig.VectorIndex.Spann != nil {
841+
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Spann == nil {
842842
return existingConfigJsonStr, nil
843843
}
844844
if updateConfig.VectorIndex.Spann.EfSearch != nil {

go/pkg/sysdb/coordinator/table_catalog_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,12 +1425,10 @@ func TestUpdateCollectionConfiguration(t *testing.T) {
14251425
assert.NoError(t, err)
14261426

14271427
if tt.expectedHnswConfig != nil {
1428-
assert.Equal(t, "hnsw", config.VectorIndex.Type)
14291428
assert.Equal(t, tt.expectedHnswConfig, config.VectorIndex.Hnsw)
14301429
}
14311430

14321431
if tt.expectedSpannConfig != nil {
1433-
assert.Equal(t, "spann", config.VectorIndex.Type)
14341432
assert.Equal(t, tt.expectedSpannConfig, config.VectorIndex.Spann)
14351433
}
14361434
})

go/pkg/sysdb/metastore/db/dao/collection.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ func generateCollectionUpdatesWithoutID(in *dbmodel.Collection) map[string]inter
428428
if in.Name != nil {
429429
ret["name"] = *in.Name
430430
}
431+
if in.ConfigurationJsonStr != nil {
432+
ret["configuration_json_str"] = *in.ConfigurationJsonStr
433+
}
431434
if in.Dimension != nil {
432435
ret["dimension"] = *in.Dimension
433436
}

go/pkg/sysdb/metastore/db/dao/collection_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,50 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollectionByResourceName
332332
suite.NoError(err)
333333
}
334334

335+
func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateConfigurationJsonStr() {
336+
collectionName := "test_collection_update_config"
337+
dim := int32(128)
338+
collectionID, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, dim, suite.databaseId, nil))
339+
suite.NoError(err)
340+
341+
collections, err := suite.collectionDb.GetCollections([]string{collectionID}, nil, "", "", nil, nil, false)
342+
suite.NoError(err)
343+
suite.Len(collections, 1)
344+
defaultConfig := "{\"a\": \"param\", \"b\": \"param2\", \"3\": true}"
345+
suite.Equal(&defaultConfig, collections[0].Collection.ConfigurationJsonStr)
346+
347+
newConfig := "{\"c\": \"param3\", \"d\": \"param3\", \"4\": false}"
348+
err = suite.collectionDb.Update(&dbmodel.Collection{
349+
ID: collectionID,
350+
DatabaseID: suite.databaseId,
351+
ConfigurationJsonStr: &newConfig,
352+
UpdatedAt: time.Now(),
353+
})
354+
suite.NoError(err)
355+
356+
collections, err = suite.collectionDb.GetCollections([]string{collectionID}, nil, "", "", nil, nil, false)
357+
suite.NoError(err)
358+
suite.Len(collections, 1)
359+
suite.Equal(&newConfig, collections[0].Collection.ConfigurationJsonStr)
360+
361+
emptyConfig := ""
362+
err = suite.collectionDb.Update(&dbmodel.Collection{
363+
ID: collectionID,
364+
DatabaseID: suite.databaseId,
365+
ConfigurationJsonStr: &emptyConfig,
366+
UpdatedAt: time.Now(),
367+
})
368+
suite.NoError(err)
369+
370+
collections, err = suite.collectionDb.GetCollections([]string{collectionID}, nil, "", "", nil, nil, false)
371+
suite.NoError(err)
372+
suite.Len(collections, 1)
373+
suite.Equal(&emptyConfig, collections[0].Collection.ConfigurationJsonStr)
374+
375+
err = CleanUpTestCollection(suite.db, collectionID)
376+
suite.NoError(err)
377+
}
378+
335379
func TestCollectionDbTestSuiteSuite(t *testing.T) {
336380
testSuite := new(CollectionDbTestSuite)
337381
suite.Run(t, testSuite)

rust/frontend/src/server.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ use chroma_types::{
1818
DeleteCollectionRecordsResponse, DeleteDatabaseRequest, DeleteDatabaseResponse,
1919
GetCollectionRequest, GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse,
2020
GetTenantRequest, GetTenantResponse, GetUserIdentityResponse, HeartbeatResponse, IncludeList,
21-
InternalCollectionConfiguration, ListCollectionsRequest, ListCollectionsResponse,
22-
ListDatabasesRequest, ListDatabasesResponse, Metadata, QueryRequest, QueryResponse,
23-
UpdateCollectionConfiguration, UpdateCollectionRecordsResponse, UpdateCollectionResponse,
24-
UpdateMetadata, UpsertCollectionRecordsResponse,
21+
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListCollectionsRequest,
22+
ListCollectionsResponse, ListDatabasesRequest, ListDatabasesResponse, Metadata, QueryRequest,
23+
QueryResponse, UpdateCollectionConfiguration, UpdateCollectionRecordsResponse,
24+
UpdateCollectionResponse, UpdateMetadata, UpsertCollectionRecordsResponse,
2525
};
2626
use chroma_types::{ForkCollectionResponse, RawWhereFields};
2727
use mdac::{Rule, Scorecard, ScorecardTicket};
@@ -1101,13 +1101,18 @@ async fn update_collection(
11011101
let collection_id =
11021102
CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?;
11031103

1104+
let configuration = match payload.new_configuration {
1105+
Some(c) => Some(InternalUpdateCollectionConfiguration::try_from(c)?),
1106+
None => None,
1107+
};
1108+
11041109
let request = chroma_types::UpdateCollectionRequest::try_new(
11051110
collection_id,
11061111
payload.new_name,
11071112
payload
11081113
.new_metadata
11091114
.map(CollectionMetadataUpdate::UpdateMetadata),
1110-
payload.new_configuration,
1115+
configuration,
11111116
)?;
11121117

11131118
server.frontend.update_collection(request).await?;

rust/python_bindings/src/bindings.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ use chroma_types::{
1919
CountResponse, CreateCollectionRequest, CreateDatabaseRequest, CreateTenantRequest, Database,
2020
DeleteCollectionRequest, DeleteDatabaseRequest, GetCollectionRequest, GetDatabaseRequest,
2121
GetResponse, GetTenantRequest, GetTenantResponse, HeartbeatError, IncludeList,
22-
InternalCollectionConfiguration, KnnIndex, ListCollectionsRequest, ListDatabasesRequest,
23-
Metadata, QueryResponse, UpdateCollectionConfiguration, UpdateCollectionRequest,
24-
UpdateMetadata, WrappedSerdeJsonError,
22+
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, KnnIndex,
23+
ListCollectionsRequest, ListDatabasesRequest, Metadata, QueryResponse,
24+
UpdateCollectionConfiguration, UpdateCollectionRequest, UpdateMetadata, WrappedSerdeJsonError,
2525
};
2626
use pyo3::{exceptions::PyValueError, pyclass, pyfunction, pymethods, types::PyAnyMethods, Python};
2727
use std::time::SystemTime;
@@ -344,11 +344,16 @@ impl Bindings {
344344
None => None,
345345
};
346346

347+
let configuration = match configuration_json {
348+
Some(c) => Some(InternalUpdateCollectionConfiguration::try_from(c)?),
349+
None => None,
350+
};
351+
347352
let request = UpdateCollectionRequest::try_new(
348353
collection_id,
349354
new_name,
350355
new_metadata.map(CollectionMetadataUpdate::UpdateMetadata),
351-
configuration_json,
356+
configuration,
352357
)?;
353358

354359
let mut frontend = self.frontend.clone();

rust/sysdb/src/sqlite.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use chroma_types::{
1212
CreateTenantError, CreateTenantResponse, Database, DatabaseUuid, DeleteCollectionError,
1313
DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionWithSegmentsError,
1414
GetCollectionsError, GetDatabaseError, GetSegmentsError, GetTenantError, GetTenantResponse,
15-
InternalCollectionConfiguration, ListDatabasesError, Metadata, MetadataValue, ResetError,
16-
ResetResponse, Segment, SegmentScope, SegmentType, SegmentUuid, UpdateCollectionConfiguration,
17-
UpdateCollectionError,
15+
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListDatabasesError,
16+
Metadata, MetadataValue, ResetError, ResetResponse, Segment, SegmentScope, SegmentType,
17+
SegmentUuid, UpdateCollectionError,
1818
};
1919
use futures::TryStreamExt;
2020
use sea_query_binder::SqlxBinder;
@@ -356,7 +356,7 @@ impl SqliteSysDb {
356356
name: Option<String>,
357357
metadata: Option<CollectionMetadataUpdate>,
358358
dimension: Option<u32>,
359-
configuration: Option<UpdateCollectionConfiguration>,
359+
configuration: Option<InternalUpdateCollectionConfiguration>,
360360
) -> Result<(), UpdateCollectionError> {
361361
let mut tx = self
362362
.db
@@ -1048,8 +1048,9 @@ mod tests {
10481048
use super::*;
10491049
use chroma_sqlite::db::test_utils::get_new_sqlite_db;
10501050
use chroma_types::{
1051-
SegmentScope, SegmentType, SegmentUuid, UpdateHnswConfiguration, UpdateMetadata,
1052-
UpdateMetadataValue, VectorIndexConfiguration,
1051+
InternalUpdateCollectionConfiguration, SegmentScope, SegmentType, SegmentUuid,
1052+
UpdateHnswConfiguration, UpdateMetadata, UpdateMetadataValue,
1053+
UpdateVectorIndexConfiguration, VectorIndexConfiguration,
10531054
};
10541055

10551056
#[tokio::test]
@@ -1354,13 +1355,14 @@ mod tests {
13541355
Some("new_name".to_string()),
13551356
Some(CollectionMetadataUpdate::UpdateMetadata(metadata)),
13561357
Some(1024),
1357-
Some(UpdateCollectionConfiguration {
1358-
hnsw: Some(UpdateHnswConfiguration {
1359-
ef_search: Some(20),
1360-
num_threads: Some(4),
1361-
..Default::default()
1362-
}),
1363-
spann: None,
1358+
Some(InternalUpdateCollectionConfiguration {
1359+
vector_index: Some(UpdateVectorIndexConfiguration::Hnsw(Some(
1360+
UpdateHnswConfiguration {
1361+
ef_search: Some(10),
1362+
num_threads: Some(2),
1363+
..Default::default()
1364+
},
1365+
))),
13641366
embedding_function: None,
13651367
}),
13661368
)
@@ -1387,7 +1389,7 @@ mod tests {
13871389
// Access HNSW configuration through pattern matching
13881390
match &collection.config.vector_index {
13891391
VectorIndexConfiguration::Hnsw(hnsw) => {
1390-
assert_eq!(hnsw.ef_search, 20);
1392+
assert_eq!(hnsw.ef_search, 10);
13911393
}
13921394
_ => panic!("Expected HNSW configuration"),
13931395
}

rust/sysdb/src/sysdb.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ use chroma_types::{
1414
DeleteCollectionError, DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionSizeError,
1515
GetCollectionWithSegmentsError, GetCollectionsError, GetDatabaseError, GetDatabaseResponse,
1616
GetSegmentsError, GetTenantError, GetTenantResponse, InternalCollectionConfiguration,
17-
ListCollectionVersionsError, ListDatabasesError, ListDatabasesResponse, Metadata, ResetError,
18-
ResetResponse, SegmentFlushInfo, SegmentFlushInfoConversionError, SegmentUuid,
19-
UpdateCollectionConfiguration, UpdateCollectionError, VectorIndexConfiguration,
17+
InternalUpdateCollectionConfiguration, ListCollectionVersionsError, ListDatabasesError,
18+
ListDatabasesResponse, Metadata, ResetError, ResetResponse, SegmentFlushInfo,
19+
SegmentFlushInfoConversionError, SegmentUuid, UpdateCollectionError, VectorIndexConfiguration,
2020
};
2121
use chroma_types::{
2222
BatchGetCollectionSoftDeleteStatusError, BatchGetCollectionVersionFilePathsError, Collection,
@@ -284,7 +284,7 @@ impl SysDb {
284284
name: Option<String>,
285285
metadata: Option<CollectionMetadataUpdate>,
286286
dimension: Option<u32>,
287-
configuration: Option<UpdateCollectionConfiguration>,
287+
configuration: Option<InternalUpdateCollectionConfiguration>,
288288
) -> Result<(), UpdateCollectionError> {
289289
match self {
290290
SysDb::Grpc(grpc) => {
@@ -969,7 +969,7 @@ impl GrpcSysDb {
969969
name: Option<String>,
970970
metadata: Option<CollectionMetadataUpdate>,
971971
dimension: Option<u32>,
972-
configuration: Option<UpdateCollectionConfiguration>,
972+
configuration: Option<InternalUpdateCollectionConfiguration>,
973973
) -> Result<(), UpdateCollectionError> {
974974
let mut configuration_json_str = None;
975975
if let Some(configuration) = configuration {

0 commit comments

Comments
 (0)