Skip to content

Commit 77686a6

Browse files
committed
[HOTFIX] revert server response with configuration_json for coll config
1 parent 60be7e3 commit 77686a6

File tree

9 files changed

+251
-94
lines changed

9 files changed

+251
-94
lines changed

chromadb/api/async_fastapi.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
UpdateCollectionConfiguration,
1616
create_collection_configuration_to_json,
1717
update_collection_configuration_to_json,
18+
create_collection_configuration_from_legacy_metadata_dict,
19+
populate_create_hnsw_defaults,
20+
validate_create_hnsw_config,
21+
CreateHNSWConfiguration,
22+
populate_create_spann_defaults,
23+
validate_create_spann_config,
1824
)
1925
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, Settings
2026
from chromadb.telemetry.opentelemetry import (
@@ -311,6 +317,57 @@ async def create_collection(
311317
)
312318
model = CollectionModel.from_json(resp_json)
313319

320+
# TODO: @jairad26 Remove this once server response contains configuration
321+
hnsw = None
322+
spann = None
323+
embedding_function = None
324+
if configuration is not None:
325+
hnsw = configuration.get("hnsw")
326+
spann = configuration.get("spann")
327+
embedding_function = configuration.get("embedding_function")
328+
329+
# if neither is specified, use the legacy metadata to populate the configuration
330+
if hnsw is None and spann is None:
331+
if model.metadata is not None:
332+
# update the configuration with the legacy metadata
333+
configuration = (
334+
create_collection_configuration_from_legacy_metadata_dict(
335+
model.metadata
336+
)
337+
)
338+
hnsw = configuration.get("hnsw")
339+
spann = configuration.get("spann")
340+
341+
else:
342+
# At this point we know at least one of hnsw or spann is not None
343+
if hnsw is not None:
344+
populate_create_hnsw_defaults(hnsw)
345+
validate_create_hnsw_config(hnsw)
346+
if spann is not None:
347+
populate_create_spann_defaults(spann)
348+
validate_create_spann_config(spann)
349+
350+
assert configuration is not None
351+
configuration["hnsw"] = hnsw
352+
configuration["spann"] = spann
353+
354+
# if hnsw and spann are both still None, neither was specified in the config or in the legacy metadata
355+
# in this case, rfe will take care of defaults, so we just need to populate the hnsw config
356+
if hnsw is None and spann is None:
357+
hnsw = CreateHNSWConfiguration()
358+
populate_create_hnsw_defaults(hnsw)
359+
validate_create_hnsw_config(hnsw)
360+
if configuration is not None:
361+
configuration["hnsw"] = hnsw
362+
else:
363+
configuration = CreateCollectionConfiguration(hnsw=hnsw)
364+
365+
assert configuration is not None
366+
configuration["embedding_function"] = embedding_function
367+
model.configuration_json = create_collection_configuration_to_json(
368+
configuration
369+
)
370+
314371
return model
315372

316373
@trace_method("AsyncFastAPI.get_collection", OpenTelemetryGranularity.OPERATION)

chromadb/api/collection_configuration.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,26 @@ class CreateSpannConfiguration(TypedDict, total=False):
224224
merge_threshold: int
225225

226226

227+
def populate_create_spann_defaults(
    config: CreateSpannConfiguration,
) -> CreateSpannConfiguration:
    """Fill in default values for unset SPANN creation parameters.

    A parameter counts as unset when its key is absent from ``config`` or
    is present with the value ``None``. Parameters that already hold a
    non-``None`` value, and any keys not listed below, are left untouched.

    Args:
        config: The SPANN creation configuration. It is mutated in place.

    Returns:
        The same ``config`` object that was passed in, after defaults
        have been applied.
    """
    # (key, default) pairs, applied in this order so that newly inserted
    # keys keep a stable ordering in the resulting dict.
    fallback_values = (
        ("space", "l2"),
        ("ef_construction", 200),
        ("max_neighbors", 64),
        ("ef_search", 200),
        ("reassign_neighbor_count", 64),
        ("split_threshold", 200),
        ("merge_threshold", 100),
    )
    for field_name, fallback in fallback_values:
        if config.get(field_name) is not None:
            # An explicit value was provided; keep it.
            continue
        config[field_name] = fallback  # type: ignore[literal-required]
    return config
245+
246+
227247
def validate_create_spann_config(
228248
config: Optional[CreateSpannConfiguration], ef: Optional[EmbeddingFunction] = None # type: ignore
229249
) -> None:

chromadb/api/fastapi.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
UpdateCollectionConfiguration,
1313
update_collection_configuration_to_json,
1414
create_collection_configuration_to_json,
15+
create_collection_configuration_from_legacy_metadata_dict,
16+
populate_create_hnsw_defaults,
17+
validate_create_hnsw_config,
18+
CreateHNSWConfiguration,
19+
validate_create_spann_config,
20+
populate_create_spann_defaults,
1521
)
1622
from chromadb import __version__
1723
from chromadb.api.base_http_client import BaseHTTPClient
@@ -264,6 +270,58 @@ def create_collection(
264270
)
265271
model = CollectionModel.from_json(resp_json)
266272

273+
# TODO: @jairad26 Remove this once server response contains configuration
274+
hnsw = None
275+
spann = None
276+
embedding_function = None
277+
if configuration is not None:
278+
hnsw = configuration.get("hnsw")
279+
spann = configuration.get("spann")
280+
embedding_function = configuration.get("embedding_function")
281+
282+
# if neither is specified, use the legacy metadata to populate the configuration
283+
print("Test legacy metadata: ", model.metadata)
284+
if hnsw is None and spann is None:
285+
if model.metadata is not None:
286+
# update the configuration with the legacy metadata
287+
configuration = (
288+
create_collection_configuration_from_legacy_metadata_dict(
289+
model.metadata
290+
)
291+
)
292+
hnsw = configuration.get("hnsw")
293+
spann = configuration.get("spann")
294+
295+
else:
296+
# At this point we know at least one of hnsw or spann is not None
297+
if hnsw is not None:
298+
populate_create_hnsw_defaults(hnsw)
299+
validate_create_hnsw_config(hnsw)
300+
if spann is not None:
301+
populate_create_spann_defaults(spann)
302+
validate_create_spann_config(spann)
303+
304+
assert configuration is not None
305+
configuration["hnsw"] = hnsw
306+
configuration["spann"] = spann
307+
308+
# if hnsw and spann are both still None, neither was specified in the config or in the legacy metadata
309+
# in this case, rfe will take care of defaults, so we just need to populate the hnsw config
310+
if hnsw is None and spann is None:
311+
hnsw = CreateHNSWConfiguration()
312+
populate_create_hnsw_defaults(hnsw)
313+
validate_create_hnsw_config(hnsw)
314+
if configuration is not None:
315+
configuration["hnsw"] = hnsw
316+
else:
317+
configuration = CreateCollectionConfiguration(hnsw=hnsw)
318+
319+
assert configuration is not None
320+
configuration["embedding_function"] = embedding_function
321+
model.configuration_json = create_collection_configuration_to_json(
322+
configuration
323+
)
324+
267325
return model
268326

269327
@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)

chromadb/test/configurations/test_collection_configuration.py

Lines changed: 86 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import json
2424
import os
2525
from chromadb.utils.embedding_functions import register_embedding_function
26-
from chromadb.test.conftest import ClientFactories
26+
27+
# from chromadb.test.conftest import ClientFactories
2728

2829

2930
# Check if we are running in a mode where SPANN is disabled
@@ -274,48 +275,48 @@ def test_hnsw_configuration_updates(client: ClientAPI) -> None:
274275
assert hnsw_config.get("max_neighbors") == 16
275276

276277

277-
def test_configuration_persistence(client_factories: "ClientFactories") -> None:
278-
"""Test configuration persistence across client restarts"""
279-
# Use the factory to create the initial client
280-
client = client_factories.create_client_from_system()
281-
client.reset()
282-
283-
# Create collection with specific configuration
284-
hnsw_config: CreateHNSWConfiguration = {
285-
"space": "cosine",
286-
"ef_construction": 100,
287-
"max_neighbors": 10,
288-
}
289-
config: CreateCollectionConfiguration = {
290-
"hnsw": hnsw_config,
291-
"embedding_function": CustomEmbeddingFunction(dim=5),
292-
}
293-
294-
client.create_collection(
295-
name="test_persist_config",
296-
configuration=config,
297-
)
298-
299-
# Simulate client restart by creating a new client from the same system
300-
client2 = client_factories.create_client_from_system()
301-
302-
coll = client2.get_collection(
303-
name="test_persist_config",
304-
)
305-
306-
loaded_config = load_collection_configuration_from_json(
307-
coll._model.configuration_json
308-
)
309-
if loaded_config and isinstance(loaded_config, dict):
310-
hnsw_config = cast(CreateHNSWConfiguration, loaded_config.get("hnsw", {}))
311-
assert hnsw_config.get("space") == "cosine"
312-
assert hnsw_config.get("ef_construction") == 100
313-
assert hnsw_config.get("max_neighbors") == 10
314-
assert hnsw_config.get("ef_search") == 100
315-
316-
ef = loaded_config.get("embedding_function")
317-
assert ef is not None
318-
assert ef.name() == "custom_ef"
278+
# def test_configuration_persistence(client_factories: "ClientFactories") -> None:
279+
# """Test configuration persistence across client restarts"""
280+
# # Use the factory to create the initial client
281+
# client = client_factories.create_client_from_system()
282+
# client.reset()
283+
284+
# # Create collection with specific configuration
285+
# hnsw_config: CreateHNSWConfiguration = {
286+
# "space": "cosine",
287+
# "ef_construction": 100,
288+
# "max_neighbors": 10,
289+
# }
290+
# config: CreateCollectionConfiguration = {
291+
# "hnsw": hnsw_config,
292+
# "embedding_function": CustomEmbeddingFunction(dim=5),
293+
# }
294+
295+
# client.create_collection(
296+
# name="test_persist_config",
297+
# configuration=config,
298+
# )
299+
300+
# # Simulate client restart by creating a new client from the same system
301+
# client2 = client_factories.create_client_from_system()
302+
303+
# coll = client2.get_collection(
304+
# name="test_persist_config",
305+
# )
306+
307+
# loaded_config = load_collection_configuration_from_json(
308+
# coll._model.configuration_json
309+
# )
310+
# if loaded_config and isinstance(loaded_config, dict):
311+
# hnsw_config = cast(CreateHNSWConfiguration, loaded_config.get("hnsw", {}))
312+
# assert hnsw_config.get("space") == "cosine"
313+
# assert hnsw_config.get("ef_construction") == 100
314+
# assert hnsw_config.get("max_neighbors") == 10
315+
# assert hnsw_config.get("ef_search") == 100
316+
317+
# ef = loaded_config.get("embedding_function")
318+
# assert ef is not None
319+
# assert ef.name() == "custom_ef"
319320

320321

321322
def test_configuration_result_format(client: ClientAPI) -> None:
@@ -501,48 +502,48 @@ def supported_spaces(self) -> list[Space]:
501502
assert "SPANN is still in development" in str(excinfo.value)
502503

503504

504-
@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
505-
def test_spann_configuration_persistence(client_factories: "ClientFactories") -> None:
506-
"""Test SPANN configuration persistence across client restarts"""
507-
client = client_factories.create_client_from_system()
508-
client.reset()
509-
510-
# Create collection with specific SPANN configuration
511-
spann_config: CreateSpannConfiguration = {
512-
"space": "cosine",
513-
"ef_construction": 100,
514-
"max_neighbors": 10,
515-
"search_nprobe": 5,
516-
"write_nprobe": 10,
517-
}
518-
config: CreateCollectionConfiguration = {
519-
"spann": spann_config,
520-
"embedding_function": CustomEmbeddingFunction(dim=5),
521-
}
522-
523-
client.create_collection(
524-
name="test_persist_spann_config",
525-
configuration=config,
526-
)
527-
528-
client2 = client_factories.create_client_from_system()
529-
530-
coll = client2.get_collection(
531-
name="test_persist_spann_config",
532-
)
533-
534-
loaded_config = load_collection_configuration_from_json(
535-
coll._model.configuration_json
536-
)
537-
if loaded_config and isinstance(loaded_config, dict):
538-
spann_config = cast(CreateSpannConfiguration, loaded_config.get("spann", {}))
539-
ef = loaded_config.get("embedding_function")
540-
assert spann_config.get("space") == "cosine"
541-
assert spann_config.get("ef_construction") == 100
542-
assert spann_config.get("max_neighbors") == 10
543-
assert spann_config.get("search_nprobe") == 5
544-
assert spann_config.get("write_nprobe") == 10
545-
assert ef is not None
505+
# @pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
506+
# def test_spann_configuration_persistence(client_factories: "ClientFactories") -> None:
507+
# """Test SPANN configuration persistence across client restarts"""
508+
# client = client_factories.create_client_from_system()
509+
# client.reset()
510+
511+
# # Create collection with specific SPANN configuration
512+
# spann_config: CreateSpannConfiguration = {
513+
# "space": "cosine",
514+
# "ef_construction": 100,
515+
# "max_neighbors": 10,
516+
# "search_nprobe": 5,
517+
# "write_nprobe": 10,
518+
# }
519+
# config: CreateCollectionConfiguration = {
520+
# "spann": spann_config,
521+
# "embedding_function": CustomEmbeddingFunction(dim=5),
522+
# }
523+
524+
# client.create_collection(
525+
# name="test_persist_spann_config",
526+
# configuration=config,
527+
# )
528+
529+
# client2 = client_factories.create_client_from_system()
530+
531+
# coll = client2.get_collection(
532+
# name="test_persist_spann_config",
533+
# )
534+
535+
# loaded_config = load_collection_configuration_from_json(
536+
# coll._model.configuration_json
537+
# )
538+
# if loaded_config and isinstance(loaded_config, dict):
539+
# spann_config = cast(CreateSpannConfiguration, loaded_config.get("spann", {}))
540+
# ef = loaded_config.get("embedding_function")
541+
# assert spann_config.get("space") == "cosine"
542+
# assert spann_config.get("ef_construction") == 100
543+
# assert spann_config.get("max_neighbors") == 10
544+
# assert spann_config.get("search_nprobe") == 5
545+
# assert spann_config.get("write_nprobe") == 10
546+
# assert ef is not None
546547

547548

548549
def test_exclusive_hnsw_spann_configuration(client: ClientAPI) -> None:

chromadb/types.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from chromadb.api.collection_configuration import (
1818
CollectionConfiguration,
1919
HNSWConfiguration,
20-
SpannConfiguration,
2120
collection_configuration_to_json,
2221
load_collection_configuration_from_json,
2322
)
@@ -156,7 +155,7 @@ def get_configuration(self) -> CollectionConfiguration:
156155
)
157156
return CollectionConfiguration(
158157
hnsw=HNSWConfiguration(),
159-
spann=SpannConfiguration(),
158+
spann=None,
160159
embedding_function=None,
161160
)
162161

@@ -175,11 +174,11 @@ def get_model_fields(self) -> Dict[Any, Any]:
175174
@override
176175
def from_json(cls, json_map: Dict[str, Any]) -> Self:
177176
"""Deserializes a Collection object from JSON"""
178-
configuration: CollectionConfiguration = {
179-
"hnsw": {},
180-
"spann": {},
181-
"embedding_function": None,
182-
}
177+
configuration = CollectionConfiguration(
178+
hnsw=None,
179+
spann=None,
180+
embedding_function=None,
181+
)
183182
try:
184183
configuration_json = json_map.get("configuration_json", None)
185184
configuration = load_collection_configuration_from_json(configuration_json)

0 commit comments

Comments
 (0)