Skip to content

Commit 966809e

Browse files
committed
refactor: improve assertion formatting in batch embedding tests
Updated assertion statements in test_batch_embeddings.py for better readability by restructuring multiline assertions. This change enhances code clarity without altering the test logic.
1 parent 519b0bc commit 966809e

1 file changed

Lines changed: 12 additions & 6 deletions

File tree

tests/test_batch_embeddings.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ async def test_hybrid_mode_batches_embeddings():
7171
)
7272

7373
# The embedding function should be called exactly once with all 3 texts batched
74-
assert embed_func.call_count == 1, (
75-
f"Expected 1 batched embedding call, got {embed_func.call_count}"
76-
)
74+
assert (
75+
embed_func.call_count == 1
76+
), f"Expected 1 batched embedding call, got {embed_func.call_count}"
7777
call_args = embed_func.call_args[0][0]
7878
assert len(call_args) == 3, f"Expected 3 texts in batch, got {len(call_args)}"
7979
assert call_args == ["test query", "entity1, entity2", "theme1, theme2"]
@@ -109,14 +109,20 @@ async def test_hybrid_mode_passes_embeddings_to_vdbs():
109109
assert entities_call is not None, "entities_vdb.query was not called"
110110
ll_embedding = entities_call.kwargs.get("query_embedding")
111111
assert ll_embedding is not None, "ll_embedding was not passed to entities_vdb.query"
112-
assert np.all(ll_embedding == 2.0), f"Expected ll_embedding=[2,2,...], got {ll_embedding[:3]}"
112+
assert np.all(
113+
ll_embedding == 2.0
114+
), f"Expected ll_embedding=[2,2,...], got {ll_embedding[:3]}"
113115

114116
# relationships_vdb.query should receive hl_embedding (index 2 → all 3s)
115117
rel_call = relationships_vdb.query.call_args
116118
assert rel_call is not None, "relationships_vdb.query was not called"
117119
hl_embedding = rel_call.kwargs.get("query_embedding")
118-
assert hl_embedding is not None, "hl_embedding was not passed to relationships_vdb.query"
119-
assert np.all(hl_embedding == 3.0), f"Expected hl_embedding=[3,3,...], got {hl_embedding[:3]}"
120+
assert (
121+
hl_embedding is not None
122+
), "hl_embedding was not passed to relationships_vdb.query"
123+
assert np.all(
124+
hl_embedding == 3.0
125+
), f"Expected hl_embedding=[3,3,...], got {hl_embedding[:3]}"
120126

121127

122128
@pytest.mark.offline

0 commit comments

Comments
 (0)