
Conversation

@ljt019 ljt019 (Owner) commented Jul 7, 2025

Summary

  • add embed_batch method to embedding model trait and pipeline (usage sketch below)
  • fix recursive call in Qwen3EmbeddingModel implementation
  • use shared DeviceRequest in all pipeline builders
  • add tests for batch embedding
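
As a quick illustration of the new API, here is a usage sketch. The builder call and model size mirror the tests added in this PR; the import paths are assumptions based on the module layout and have not been verified against the crate's re-exports.

// Sketch only: `use` paths are assumed from the module layout in this PR.
use transformers::pipelines::embedding_pipeline::EmbeddingPipelineBuilder;
use transformers::models::implementations::qwen3_embeddings::Qwen3EmbeddingSize;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Builder call and size variant mirror the tests in this PR.
    let pipeline = EmbeddingPipelineBuilder::qwen3(Qwen3EmbeddingSize::Size0_6B)
        .build()
        .await?;

    // New in this PR: embed several texts in a single call.
    let embeddings = pipeline.embed_batch(&["first text", "second text"])?;
    assert_eq!(embeddings.len(), 2);

    // The single-text method still works and now routes through embed_batch.
    let single = pipeline.embed("just one text")?;
    assert!(!single.is_empty());
    Ok(())
}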

Testing

  • cargo test --lib
  • cargo test --doc

https://chatgpt.com/codex/tasks/task_e_686c2639239083308b54be59e2578c46

Summary by CodeRabbit

  • New Features

    • Added support for batch embedding, allowing multiple texts to be embedded in a single call.
    • Introduced a new method to obtain batch embeddings from the embedding pipeline.
  • Refactor

    • Unified and centralized device selection logic across pipeline builders, improving device configuration options.
    • Updated builder APIs to use a new device request mechanism for more flexible hardware selection.
  • Tests

    • Added tests to verify batch embedding functionality and ensure correct results for multiple inputs.

@coderabbitai coderabbitai bot commented Jul 7, 2025

Walkthrough

The changes introduce a unified device selection mechanism via a new DeviceRequest enum, refactor pipeline builders to use this abstraction, and add batch embedding support to the embedding pipeline. The embedding model trait and pipeline are extended with batch methods, and corresponding tests are added to verify batch embedding functionality.

Changes

  • src/pipelines/utils/mod.rs: Introduced public DeviceRequest enum and its resolve method for unified device selection logic.
  • src/pipelines/embedding_pipeline/builder.rs, src/pipelines/reranker_pipeline/builder.rs: Refactored builders to use DeviceRequest instead of direct device handling; updated related methods.
  • src/pipelines/text_generation_pipeline/builder.rs: Switched to importing DeviceRequest from the pipelines::utils module.
  • src/pipelines/embedding_pipeline/embedding_model.rs: Added embed_batch default method to the EmbeddingModel trait for batch embedding support.
  • src/pipelines/embedding_pipeline/embedding_pipeline.rs: Added embed_batch method; refactored embed to use the batch embedding logic.
  • src/models/implementations/qwen3_embeddings.rs: Fixed recursive call in the embed method by using fully qualified syntax.
  • tests/embedding_pipeline_tests/batch_embedding.rs: Added async test for batch embedding with the Qwen3 model.
  • tests/embedding_pipeline_tests/main.rs: Imported the new batch_embedding test module.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant EmbeddingPipeline
    participant EmbeddingModel
    participant Tokenizer

    User->>EmbeddingPipeline: embed_batch(texts)
    EmbeddingPipeline->>EmbeddingModel: embed_batch(tokenizer, texts)
    loop for each text in texts
        EmbeddingModel->>EmbeddingModel: embed(tokenizer, text)
    end
    EmbeddingModel-->>EmbeddingPipeline: embeddings[]
    EmbeddingPipeline-->>User: embeddings[]


Poem

A rabbit hops through code so bright,
Batch embeddings now take flight!
Devices chosen with a single enum’s might,
No more recursion in the night.
Tests are added, all is well—
In this meadow, features swell! 🐇✨

Warning

There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure.

🔧 Clippy (1.86.0)
Updating crates.io index
Updating git repository `https://github.com/huggingface/candle.git`

error: failed to get candle-core as a dependency of package transformers v0.0.12 ()

Caused by:
failed to load source for dependency candle-core

Caused by:
Unable to update https://github.com/huggingface/candle.git#d0a3b33e

Caused by:
failed to create directory /usr/local/git/db/candle-5b4d092929d18d36

Caused by:
Permission denied (os error 13)


@ljt019 ljt019 self-assigned this Jul 7, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
src/pipelines/text_generation_pipeline/builder.rs (1)

125-130: Remove duplicate device resolution logic.

The manual matching on DeviceRequest variants duplicates the logic already implemented in DeviceRequest::resolve(). This defeats the purpose of the abstraction and creates maintenance issues.

Apply this diff to use the centralized resolution method:

-        let device = match self.device_request {
-            DeviceRequest::Default => load_device_with(None)?,
-            DeviceRequest::Cpu => Device::Cpu,
-            DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
-            DeviceRequest::Explicit(d) => d,
-        };
+        let device = self.device_request.resolve()?;
♻️ Duplicate comments (1)
src/pipelines/text_generation_pipeline/builder.rs (1)

168-173: Remove duplicate device resolution logic.

Same issue as in the build method - manual matching on DeviceRequest variants duplicates the centralized logic.

Apply this diff to use the centralized resolution method:

-        let device = match self.device_request {
-            DeviceRequest::Default => load_device_with(None)?,
-            DeviceRequest::Cpu => Device::Cpu,
-            DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
-            DeviceRequest::Explicit(d) => d,
-        };
+        let device = self.device_request.resolve()?;
🧹 Nitpick comments (2)
src/pipelines/embedding_pipeline/embedding_model.rs (1)

13-23: LGTM! Well-implemented batch embedding with good error handling.

The default implementation is correct and efficient:

  • Pre-allocates the results vector with proper capacity
  • Uses proper error propagation with the ? operator
  • Sequential processing ensures deterministic behavior

Consider for future enhancement: implementing parallel processing for better performance with large batches, though the current sequential approach is perfectly acceptable for most use cases.
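
For reference, a default method with exactly those properties would look roughly like the following. This is a sketch, not the crate's code: the signature is inferred from the sequence diagram and the cross-references below, and the Tokenizer type is assumed to be the Hugging Face tokenizers::Tokenizer.

use tokenizers::Tokenizer;

pub trait EmbeddingModel {
    fn embed(&self, tokenizer: &Tokenizer, text: &str) -> anyhow::Result<Vec<f32>>;

    fn embed_batch(&self, tokenizer: &Tokenizer, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        // Pre-allocate to the batch size, then embed sequentially,
        // propagating the first error via `?`.
        let mut out = Vec::with_capacity(texts.len());
        for text in texts {
            out.push(self.embed(tokenizer, text)?);
        }
        Ok(out)
    }
}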

tests/embedding_pipeline_tests/batch_embedding.rs (1)

1-15: Good integration test for batch embedding functionality.

The test effectively verifies the core batch embedding behavior. Consider adding additional test cases to improve coverage:

#[tokio::test]
async fn batch_embedding_empty_input() -> anyhow::Result<()> {
    let pipeline = EmbeddingPipelineBuilder::qwen3(Qwen3EmbeddingSize::Size0_6B)
        .build()
        .await?;
    let inputs: [&str; 0] = [];
    let embs = pipeline.embed_batch(&inputs)?;
    assert_eq!(embs.len(), 0);
    Ok(())
}

#[tokio::test]
async fn batch_embedding_single_input() -> anyhow::Result<()> {
    let pipeline = EmbeddingPipelineBuilder::qwen3(Qwen3EmbeddingSize::Size0_6B)
        .build()
        .await?;
    let inputs = ["single text"];
    let embs = pipeline.embed_batch(&inputs)?;
    assert_eq!(embs.len(), 1);
    assert!(!embs[0].is_empty());
    Ok(())
}
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6573036 and 7f6b709.

📒 Files selected for processing (9)
  • src/models/implementations/qwen3_embeddings.rs (1 hunks)
  • src/pipelines/embedding_pipeline/builder.rs (2 hunks)
  • src/pipelines/embedding_pipeline/embedding_model.rs (1 hunks)
  • src/pipelines/embedding_pipeline/embedding_pipeline.rs (1 hunks)
  • src/pipelines/reranker_pipeline/builder.rs (2 hunks)
  • src/pipelines/text_generation_pipeline/builder.rs (1 hunks)
  • src/pipelines/utils/mod.rs (1 hunks)
  • tests/embedding_pipeline_tests/batch_embedding.rs (1 hunks)
  • tests/embedding_pipeline_tests/main.rs (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (6)
tests/embedding_pipeline_tests/main.rs (1)
tests/embedding_pipeline_tests/batch_embedding.rs (1)
  • batch_embedding (4-15)
src/pipelines/text_generation_pipeline/builder.rs (1)
src/pipelines/utils/mod.rs (1)
  • load_device_with (7-17)
src/models/implementations/qwen3_embeddings.rs (1)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
  • embed (11-11)
tests/embedding_pipeline_tests/batch_embedding.rs (1)
src/pipelines/embedding_pipeline/builder.rs (1)
  • qwen3 (50-52)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
src/pipelines/embedding_pipeline/embedding_pipeline.rs (1)
  • embed_batch (15-17)
src/pipelines/embedding_pipeline/embedding_pipeline.rs (1)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
  • embed_batch (13-23)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Run Tests Development
🔇 Additional comments (10)
src/pipelines/utils/mod.rs (1)

25-54: Excellent abstraction for device selection.

The DeviceRequest enum provides a clean, type-safe way to specify device preferences across pipeline builders. The implementation properly delegates to existing device loading logic and follows good Rust patterns.
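
Pieced together from the diffs in this review, the abstraction has roughly this shape. The variants and match arms come from the builder diffs above; the import paths and the error type are assumptions.

use candle_core::{CudaDevice, Device};

pub enum DeviceRequest {
    Default,          // let the library pick a device (delegates to load_device_with)
    Cpu,              // force CPU execution
    Cuda(usize),      // select a CUDA device by ordinal index
    Explicit(Device), // caller supplies a concrete device
}

impl DeviceRequest {
    // Mirrors the match arms shown in the builder diffs above;
    // load_device_with lives in this same module.
    pub fn resolve(self) -> anyhow::Result<Device> {
        Ok(match self {
            DeviceRequest::Default => load_device_with(None)?,
            DeviceRequest::Cpu => Device::Cpu,
            DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
            DeviceRequest::Explicit(d) => d,
        })
    }
}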

src/pipelines/reranker_pipeline/builder.rs (1)

4-46: Proper implementation of DeviceRequest pattern.

The builder correctly uses the centralized DeviceRequest abstraction and properly delegates device resolution to the resolve() method. This is the intended usage pattern that should be followed by other builders.

tests/embedding_pipeline_tests/main.rs (1)

2-2: LGTM - enables batch embedding tests.

The module import properly includes the new batch embedding tests in the test suite.

src/models/implementations/qwen3_embeddings.rs (1)

168-168: Critical fix for infinite recursion.

The fully qualified syntax correctly resolves the method call to the inherent implementation instead of recursively calling the trait method, preventing a stack overflow.

src/pipelines/embedding_pipeline/embedding_pipeline.rs (1)

15-17: Clean delegation to model's batch method.

The implementation correctly delegates to the underlying model's embed_batch method, maintaining proper separation of concerns.
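
Concretely, the wrapper is presumably just the following (a sketch; the field names follow the pre-change embed shown later in this review):

pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
    // Delegate tokenizer handling and per-text iteration to the model trait.
    self.model.embed_batch(&self.tokenizer, texts)
}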

src/pipelines/embedding_pipeline/builder.rs (5)

4-4: Good architectural improvement with DeviceRequest abstraction.

The import of DeviceRequest from the utils module standardizes device selection across pipelines, improving code consistency and maintainability.


8-8: Consistent field type change to DeviceRequest.

Replacing the optional device field with DeviceRequest centralizes device selection logic and aligns with the standardization effort across pipeline builders.


12-17: Proper initialization with DeviceRequest::Default.

The constructor correctly initializes the device request to the default state, maintaining backward compatibility while using the new abstraction.


19-32: Builder methods correctly updated for DeviceRequest variants.

All builder methods (cpu(), cuda_device(), device()) properly map to their corresponding DeviceRequest enum variants, maintaining the same public API while using the new abstraction.
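
A sketch of those mappings (method bodies inferred from the description above, not copied from the diff):

impl EmbeddingPipelineBuilder {
    pub fn cpu(mut self) -> Self {
        self.device_request = DeviceRequest::Cpu; // force CPU
        self
    }

    pub fn cuda_device(mut self, index: usize) -> Self {
        self.device_request = DeviceRequest::Cuda(index); // CUDA by ordinal index
        self
    }

    pub fn device(mut self, device: Device) -> Self {
        self.device_request = DeviceRequest::Explicit(device); // explicit device
        self
    }
}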


39-39: Centralized device resolution improves maintainability.

The self.device_request.resolve()? call centralizes device resolution logic, making it consistent across all pipeline builders and easier to maintain.

Comment on lines 10 to +13
 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
-    self.model.embed(&self.tokenizer, text)
+    let mut out = self.embed_batch(&[text])?;
+    Ok(out.pop().unwrap())
 }

🛠️ Refactor suggestion

Potential panic risk with unwrap() - consider safer alternative.

The refactoring is architecturally sound, but the unwrap() call could panic if embed_batch unexpectedly returns an empty vector. While this shouldn't happen with a single-element input, it's better to be defensive.

Consider this safer implementation:

 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
-    let mut out = self.embed_batch(&[text])?;
-    Ok(out.pop().unwrap())
+    let mut out = self.embed_batch(&[text])?;
+    out.pop()
+        .ok_or_else(|| anyhow::anyhow!("embed_batch returned empty result for single input"))
 }

Alternatively, you could use expect() with a more descriptive panic message:

 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
-    let mut out = self.embed_batch(&[text])?;
-    Ok(out.pop().unwrap())
+    self.embed_batch(&[text])
+        .map(|mut embeddings| embeddings.pop().expect("embed_batch should return exactly one embedding for single input"))
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:
    pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let mut out = self.embed_batch(&[text])?;
        Ok(out.pop().unwrap())
    }

After:
    pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        let mut out = self.embed_batch(&[text])?;
        out.pop()
            .ok_or_else(|| anyhow::anyhow!("embed_batch returned empty result for single input"))
    }
🤖 Prompt for AI Agents
In src/pipelines/embedding_pipeline/embedding_pipeline.rs around lines 10 to 13,
the use of unwrap() on the result of out.pop() can cause a panic if the vector
is empty. To fix this, replace unwrap() with a safer alternative such as using
out.pop().ok_or_else() to return a descriptive error if the vector is empty, or
use indexing with get(0) combined with a clear error message to handle the empty
case gracefully without panicking.

@ljt019 ljt019 merged commit 3e45c20 into dev Jul 7, 2025
1 of 2 checks passed
@ljt019 ljt019 deleted the analyze-api-ergonomics-for-embedding-and-reranker-pipelines branch July 7, 2025 22:56