Skip to content

[HOTFIX] revert server response with configuration_json for coll config #4328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
UpdateCollectionConfiguration,
create_collection_configuration_to_json,
update_collection_configuration_to_json,
create_collection_configuration_from_legacy_metadata_dict,
populate_create_hnsw_defaults,
validate_create_hnsw_config,
CreateHNSWConfiguration,
populate_create_spann_defaults,
validate_create_spann_config,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, Settings
from chromadb.telemetry.opentelemetry import (
Expand Down Expand Up @@ -311,6 +317,57 @@ async def create_collection(
)
model = CollectionModel.from_json(resp_json)

# TODO: @jairad26 Remove this once server response contains configuration
hnsw = None
spann = None
embedding_function = None
if configuration is not None:
hnsw = configuration.get("hnsw")
spann = configuration.get("spann")
embedding_function = configuration.get("embedding_function")

# if neither are specified, use the legacy metadata to populate the configuration
if hnsw is None and spann is None:
if model.metadata is not None:
# update the configuration with the legacy metadata
configuration = (
create_collection_configuration_from_legacy_metadata_dict(
model.metadata
)
)
hnsw = configuration.get("hnsw")
spann = configuration.get("spann")

else:
# At this point we know at least one of hnsw or spann is not None
if hnsw is not None:
populate_create_hnsw_defaults(hnsw)
validate_create_hnsw_config(hnsw)
if spann is not None:
populate_create_spann_defaults(spann)
validate_create_spann_config(spann)

assert configuration is not None
configuration["hnsw"] = hnsw
configuration["spann"] = spann

# if hnsw and spann are both still None, it was neither specified in config nor in legacy metadata
# in this case, rfe will take care of defaults, so we just need to populate the hnsw config
if hnsw is None and spann is None:
hnsw = CreateHNSWConfiguration()
populate_create_hnsw_defaults(hnsw)
validate_create_hnsw_config(hnsw)
if configuration is not None:
configuration["hnsw"] = hnsw
else:
configuration = CreateCollectionConfiguration(hnsw=hnsw)

assert configuration is not None
configuration["embedding_function"] = embedding_function
model.configuration_json = create_collection_configuration_to_json(
configuration
)

return model

@trace_method("AsyncFastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
Expand Down
20 changes: 20 additions & 0 deletions chromadb/api/collection_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,26 @@ class CreateSpannConfiguration(TypedDict, total=False):
merge_threshold: int


def populate_create_spann_defaults(
config: CreateSpannConfiguration,
) -> CreateSpannConfiguration:
if config.get("space") is None:
config["space"] = "l2"
if config.get("ef_construction") is None:
config["ef_construction"] = 200
if config.get("max_neighbors") is None:
config["max_neighbors"] = 64
if config.get("ef_search") is None:
config["ef_search"] = 200
if config.get("reassign_neighbor_count") is None:
config["reassign_neighbor_count"] = 64
if config.get("split_threshold") is None:
config["split_threshold"] = 200
if config.get("merge_threshold") is None:
config["merge_threshold"] = 100
return config


def validate_create_spann_config(
config: Optional[CreateSpannConfiguration], ef: Optional[EmbeddingFunction] = None # type: ignore
) -> None:
Expand Down
58 changes: 58 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
UpdateCollectionConfiguration,
update_collection_configuration_to_json,
create_collection_configuration_to_json,
create_collection_configuration_from_legacy_metadata_dict,
populate_create_hnsw_defaults,
validate_create_hnsw_config,
CreateHNSWConfiguration,
validate_create_spann_config,
populate_create_spann_defaults,
)
from chromadb import __version__
from chromadb.api.base_http_client import BaseHTTPClient
Expand Down Expand Up @@ -264,6 +270,58 @@ def create_collection(
)
model = CollectionModel.from_json(resp_json)

# TODO: @jairad26 Remove this once server response contains configuration
hnsw = None
spann = None
embedding_function = None
if configuration is not None:
hnsw = configuration.get("hnsw")
spann = configuration.get("spann")
embedding_function = configuration.get("embedding_function")

# if neither are specified, use the legacy metadata to populate the configuration
print("Test legacy metadata: ", model.metadata)
if hnsw is None and spann is None:
if model.metadata is not None:
# update the configuration with the legacy metadata
configuration = (
create_collection_configuration_from_legacy_metadata_dict(
model.metadata
)
)
hnsw = configuration.get("hnsw")
spann = configuration.get("spann")

else:
# At this point we know at least one of hnsw or spann is not None
if hnsw is not None:
populate_create_hnsw_defaults(hnsw)
validate_create_hnsw_config(hnsw)
if spann is not None:
populate_create_spann_defaults(spann)
validate_create_spann_config(spann)

assert configuration is not None
configuration["hnsw"] = hnsw
configuration["spann"] = spann

# if hnsw and spann are both still None, it was neither specified in config nor in legacy metadata
# in this case, rfe will take care of defaults, so we just need to populate the hnsw config
if hnsw is None and spann is None:
hnsw = CreateHNSWConfiguration()
populate_create_hnsw_defaults(hnsw)
validate_create_hnsw_config(hnsw)
if configuration is not None:
configuration["hnsw"] = hnsw
else:
configuration = CreateCollectionConfiguration(hnsw=hnsw)

assert configuration is not None
configuration["embedding_function"] = embedding_function
model.configuration_json = create_collection_configuration_to_json(
configuration
)

return model

@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
Expand Down
171 changes: 86 additions & 85 deletions chromadb/test/configurations/test_collection_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import json
import os
from chromadb.utils.embedding_functions import register_embedding_function
from chromadb.test.conftest import ClientFactories

# from chromadb.test.conftest import ClientFactories


# Check if we are running in a mode where SPANN is disabled
Expand Down Expand Up @@ -274,48 +275,48 @@ def test_hnsw_configuration_updates(client: ClientAPI) -> None:
assert hnsw_config.get("max_neighbors") == 16


def test_configuration_persistence(client_factories: "ClientFactories") -> None:
"""Test configuration persistence across client restarts"""
# Use the factory to create the initial client
client = client_factories.create_client_from_system()
client.reset()

# Create collection with specific configuration
hnsw_config: CreateHNSWConfiguration = {
"space": "cosine",
"ef_construction": 100,
"max_neighbors": 10,
}
config: CreateCollectionConfiguration = {
"hnsw": hnsw_config,
"embedding_function": CustomEmbeddingFunction(dim=5),
}

client.create_collection(
name="test_persist_config",
configuration=config,
)

# Simulate client restart by creating a new client from the same system
client2 = client_factories.create_client_from_system()

coll = client2.get_collection(
name="test_persist_config",
)

loaded_config = load_collection_configuration_from_json(
coll._model.configuration_json
)
if loaded_config and isinstance(loaded_config, dict):
hnsw_config = cast(CreateHNSWConfiguration, loaded_config.get("hnsw", {}))
assert hnsw_config.get("space") == "cosine"
assert hnsw_config.get("ef_construction") == 100
assert hnsw_config.get("max_neighbors") == 10
assert hnsw_config.get("ef_search") == 100

ef = loaded_config.get("embedding_function")
assert ef is not None
assert ef.name() == "custom_ef"
# def test_configuration_persistence(client_factories: "ClientFactories") -> None:
# """Test configuration persistence across client restarts"""
# # Use the factory to create the initial client
# client = client_factories.create_client_from_system()
# client.reset()

# # Create collection with specific configuration
# hnsw_config: CreateHNSWConfiguration = {
# "space": "cosine",
# "ef_construction": 100,
# "max_neighbors": 10,
# }
# config: CreateCollectionConfiguration = {
# "hnsw": hnsw_config,
# "embedding_function": CustomEmbeddingFunction(dim=5),
# }

# client.create_collection(
# name="test_persist_config",
# configuration=config,
# )

# # Simulate client restart by creating a new client from the same system
# client2 = client_factories.create_client_from_system()

# coll = client2.get_collection(
# name="test_persist_config",
# )

# loaded_config = load_collection_configuration_from_json(
# coll._model.configuration_json
# )
# if loaded_config and isinstance(loaded_config, dict):
# hnsw_config = cast(CreateHNSWConfiguration, loaded_config.get("hnsw", {}))
# assert hnsw_config.get("space") == "cosine"
# assert hnsw_config.get("ef_construction") == 100
# assert hnsw_config.get("max_neighbors") == 10
# assert hnsw_config.get("ef_search") == 100

# ef = loaded_config.get("embedding_function")
# assert ef is not None
# assert ef.name() == "custom_ef"


def test_configuration_result_format(client: ClientAPI) -> None:
Expand Down Expand Up @@ -501,48 +502,48 @@ def supported_spaces(self) -> list[Space]:
assert "SPANN is still in development" in str(excinfo.value)


@pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
def test_spann_configuration_persistence(client_factories: "ClientFactories") -> None:
"""Test SPANN configuration persistence across client restarts"""
client = client_factories.create_client_from_system()
client.reset()

# Create collection with specific SPANN configuration
spann_config: CreateSpannConfiguration = {
"space": "cosine",
"ef_construction": 100,
"max_neighbors": 10,
"search_nprobe": 5,
"write_nprobe": 10,
}
config: CreateCollectionConfiguration = {
"spann": spann_config,
"embedding_function": CustomEmbeddingFunction(dim=5),
}

client.create_collection(
name="test_persist_spann_config",
configuration=config,
)

client2 = client_factories.create_client_from_system()

coll = client2.get_collection(
name="test_persist_spann_config",
)

loaded_config = load_collection_configuration_from_json(
coll._model.configuration_json
)
if loaded_config and isinstance(loaded_config, dict):
spann_config = cast(CreateSpannConfiguration, loaded_config.get("spann", {}))
ef = loaded_config.get("embedding_function")
assert spann_config.get("space") == "cosine"
assert spann_config.get("ef_construction") == 100
assert spann_config.get("max_neighbors") == 10
assert spann_config.get("search_nprobe") == 5
assert spann_config.get("write_nprobe") == 10
assert ef is not None
# @pytest.mark.skipif(is_spann_disabled_mode, reason=skip_reason_spann_disabled)
# def test_spann_configuration_persistence(client_factories: "ClientFactories") -> None:
# """Test SPANN configuration persistence across client restarts"""
# client = client_factories.create_client_from_system()
# client.reset()

# # Create collection with specific SPANN configuration
# spann_config: CreateSpannConfiguration = {
# "space": "cosine",
# "ef_construction": 100,
# "max_neighbors": 10,
# "search_nprobe": 5,
# "write_nprobe": 10,
# }
# config: CreateCollectionConfiguration = {
# "spann": spann_config,
# "embedding_function": CustomEmbeddingFunction(dim=5),
# }

# client.create_collection(
# name="test_persist_spann_config",
# configuration=config,
# )

# client2 = client_factories.create_client_from_system()

# coll = client2.get_collection(
# name="test_persist_spann_config",
# )

# loaded_config = load_collection_configuration_from_json(
# coll._model.configuration_json
# )
# if loaded_config and isinstance(loaded_config, dict):
# spann_config = cast(CreateSpannConfiguration, loaded_config.get("spann", {}))
# ef = loaded_config.get("embedding_function")
# assert spann_config.get("space") == "cosine"
# assert spann_config.get("ef_construction") == 100
# assert spann_config.get("max_neighbors") == 10
# assert spann_config.get("search_nprobe") == 5
# assert spann_config.get("write_nprobe") == 10
# assert ef is not None


def test_exclusive_hnsw_spann_configuration(client: ClientAPI) -> None:
Expand Down
13 changes: 6 additions & 7 deletions chromadb/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from chromadb.api.collection_configuration import (
CollectionConfiguration,
HNSWConfiguration,
SpannConfiguration,
collection_configuration_to_json,
load_collection_configuration_from_json,
)
Expand Down Expand Up @@ -156,7 +155,7 @@ def get_configuration(self) -> CollectionConfiguration:
)
return CollectionConfiguration(
hnsw=HNSWConfiguration(),
spann=SpannConfiguration(),
spann=None,
embedding_function=None,
)

Expand All @@ -175,11 +174,11 @@ def get_model_fields(self) -> Dict[Any, Any]:
@override
def from_json(cls, json_map: Dict[str, Any]) -> Self:
"""Deserializes a Collection object from JSON"""
configuration: CollectionConfiguration = {
"hnsw": {},
"spann": {},
"embedding_function": None,
}
configuration = CollectionConfiguration(
hnsw=None,
spann=None,
embedding_function=None,
)
try:
configuration_json = json_map.get("configuration_json", None)
configuration = load_collection_configuration_from_json(configuration_json)
Expand Down
Loading
Loading