Skip to content

Conversation

@harshil-sanghvi
Copy link
Contributor

PR Description

This PR introduces a set of performance-focused changes across the Ragas codebase. The goal is to reduce unnecessary overhead in evaluation and sampling paths while keeping all outputs fully consistent with the existing behavior. Most updates replace repeated or quadratic operations with linear or cached implementations, resulting in noticeably faster runs on larger datasets.

Key Optimizations

Critical Impact Changes

1. Average Precision Calculation (O(n²) → O(n))

Files Modified:

  • src/ragas/metrics/_context_precision.py
  • src/ragas/metrics/collections/context_precision/metric.py

The previous implementation recalculated cumulative sums inside a loop. This PR replaces it with a single-pass cumulative sum approach. This brings down the time cost for average precision calculations, especially when the number of retrieved contexts is large.

Before:

numerator = sum(
    [
        (sum(verdict_list[: i + 1]) / (i + 1)) * verdict_list[i]
        for i in range(len(verdict_list))
    ]
)

After:

cumsum = 0
numerator = 0.0
for i, verdict in enumerate(verdict_list):
    cumsum += verdict
    if verdict:
        numerator += cumsum / (i + 1)

2. Node Lookup Optimization (O(n) → O(1))

Files Modified:

  • src/ragas/testset/graph.py

Repeated linear scans over graph nodes caused noticeable slowdown during test set generation. A dedicated _node_id_cache is added to support O(1) lookups. The cache reconstructs itself automatically after deserialization to avoid stale state.


3. Stratified Sampling Optimization (O(n²) → O(n))

Files Modified:

  • src/ragas/dataset_schema.py

The sampling loop previously rebuilt sets and lists on each iteration. The updated code computes the shortage once, determines remaining indices, and uses random.sample() to fetch all missing items in one step. This reduces overhead for large datasets.

Before:

while len(sampled_indices) < n:
    remaining_indices = set(range(len(self.samples))) - set(sampled_indices)
    if not remaining_indices:
        break
    sampled_indices.append(random.choice(list(remaining_indices)))

After:

if len(sampled_indices) < n:
    remaining_indices = set(range(len(self.samples))) - set(sampled_indices)
    shortage = n - len(sampled_indices)
    if remaining_indices and shortage > 0:
        additional_samples = random.sample(
            list(remaining_indices), min(shortage, len(remaining_indices))
        )
        sampled_indices.extend(additional_samples)

High Impact Changes

4. Vectorized Hamming Distance

Files Modified:

  • src/ragas/optimizers/utils.py

Distance computation is now implemented using scipy.spatial.distance utilities instead of nested Python loops. This shifts work to optimized C-backed functions and simplifies the code. The new version also ensures a symmetric distance matrix.

Before:

distances = np.zeros((len(vectors), len(vectors)), dtype=int)
for i in range(len(vectors)):
    for j in range(i + 1, len(vectors)):
        distance = np.sum(vectors[i] != vectors[j])
        distances[i][j] = distance

After:

from scipy.spatial.distance import pdist, squareform

vectors_array = np.array(vectors)
distances = squareform(pdist(vectors_array, metric='hamming') * length)
return distances.astype(int)

5. Persona Lookup Optimization (O(n) → O(1))

Files Modified:

  • src/ragas/testset/persona.py

A _name_cache lookup table is added and initialized automatically. This avoids repeated linear scans when resolving persona entries and keeps compatibility with Pydantic’s initialization flow.


Medium Impact Changes

6. Batch Creation Cleanup

Files Modified:

  • src/ragas/dataset_schema.py

Avoids evaluating the same slice twice by storing it in a variable before reuse. This slightly improves batch-related operations and makes the code easier to follow.


7. LLM Type Checking Streamline

Files Modified:

  • src/ragas/llms/base.py

Replaces a looped type check with a tuple-based isinstance() call. While not a major performance change, it simplifies the logic and reduces overhead for repeated checks.

Before:

for llm_type in MULTIPLE_COMPLETION_SUPPORTED:
    if isinstance(llm, llm_type):
        return True
return False

After:

return isinstance(llm, MULTIPLE_COMPLETION_SUPPORTED)

8. Counter Usage Simplification

Files Modified:

  • src/ragas/metrics/base.py

Replaces a multi-step process to find the most common element with Counter.most_common(1). This avoids unnecessary intermediate structures.


Design Notes

Cache Management

Both node and persona lookup caches rebuild automatically when needed, keeping lookup operations efficient without requiring callers to manage state.

Backward Compatibility

All optimizations preserve existing behavior, and test suites should pass without any required changes.

Dependencies

scipy is used for vectorized distance calculations. It is already part of the project dependencies.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant