Add batch embedding #65
Conversation
Walkthrough

The changes introduce a unified device selection mechanism via a new `DeviceRequest` enum and add batch embedding support to the embedding model trait and pipeline.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant EmbeddingPipeline
    participant EmbeddingModel
    participant Tokenizer
    User->>EmbeddingPipeline: embed_batch(texts)
    EmbeddingPipeline->>EmbeddingModel: embed_batch(tokenizer, texts)
    loop for each text in texts
        EmbeddingModel->>EmbeddingModel: embed(tokenizer, text)
    end
    EmbeddingModel-->>EmbeddingPipeline: embeddings[]
    EmbeddingPipeline-->>User: embeddings[]
```
Warning: There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure.

🔧 Clippy (1.86.0): error: failed to get … (nested `Caused by` details elided)
Actionable comments posted: 1
🔭 Outside diff range comments (1)
src/pipelines/text_generation_pipeline/builder.rs (1)
125-130: Remove duplicate device resolution logic.

The manual matching on `DeviceRequest` variants duplicates the logic already implemented in `DeviceRequest::resolve()`. This defeats the purpose of the abstraction and creates maintenance issues. Apply this diff to use the centralized resolution method:
```diff
-        let device = match self.device_request {
-            DeviceRequest::Default => load_device_with(None)?,
-            DeviceRequest::Cpu => Device::Cpu,
-            DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
-            DeviceRequest::Explicit(d) => d,
-        };
+        let device = self.device_request.resolve()?;
```
♻️ Duplicate comments (1)
src/pipelines/text_generation_pipeline/builder.rs (1)
168-173: Remove duplicate device resolution logic.

Same issue as in the `build` method: manual matching on `DeviceRequest` variants duplicates the centralized logic. Apply this diff to use the centralized resolution method:
```diff
-        let device = match self.device_request {
-            DeviceRequest::Default => load_device_with(None)?,
-            DeviceRequest::Cpu => Device::Cpu,
-            DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
-            DeviceRequest::Explicit(d) => d,
-        };
+        let device = self.device_request.resolve()?;
```
🧹 Nitpick comments (2)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
13-23: LGTM! Well-implemented batch embedding with good error handling.

The default implementation is correct and efficient, as sketched below:

- Pre-allocates the results vector with proper capacity
- Uses proper error propagation with the `?` operator
- Sequential processing ensures deterministic behavior
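Reconstructed from those observations, the default body presumably looks something like this (the exact signature is an assumption based on the sequence diagram):

```rust
fn embed_batch(&self, tokenizer: &Tokenizer, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
    // Pre-allocate so the pushes below never reallocate.
    let mut embeddings = Vec::with_capacity(texts.len());
    for text in texts {
        // `?` short-circuits on the first embedding failure.
        embeddings.push(self.embed(tokenizer, text)?);
    }
    Ok(embeddings)
}
```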
Consider implementing parallel processing as a future enhancement for better performance with large batches; the current sequential approach is perfectly acceptable for most use cases.
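If parallelism is ever added, a default along these lines could work. This is a minimal sketch assuming a `rayon` dependency and a simplified trait (the real `embed` also takes a tokenizer, omitted here to keep the sketch self-contained):

```rust
use rayon::prelude::*;

// Simplified stand-in for the crate's trait; `Sync` is needed so `self`
// can be shared across rayon's worker threads.
trait EmbeddingModel: Sync {
    fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>>;

    // Parallel counterpart to the sequential default: embeds each text on a
    // worker thread, preserves input order, and short-circuits on the first error.
    fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        texts.par_iter().map(|text| self.embed(text)).collect()
    }
}
```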
tests/embedding_pipeline_tests/batch_embedding.rs (1)
1-15: Good integration test for batch embedding functionality.

The test effectively verifies the core batch embedding behavior. Consider adding additional test cases to improve coverage:
```rust
#[tokio::test]
async fn batch_embedding_empty_input() -> anyhow::Result<()> {
    let pipeline = EmbeddingPipelineBuilder::qwen3(Qwen3EmbeddingSize::Size0_6B)
        .build()
        .await?;
    let inputs: [&str; 0] = [];
    let embs = pipeline.embed_batch(&inputs)?;
    assert_eq!(embs.len(), 0);
    Ok(())
}

#[tokio::test]
async fn batch_embedding_single_input() -> anyhow::Result<()> {
    let pipeline = EmbeddingPipelineBuilder::qwen3(Qwen3EmbeddingSize::Size0_6B)
        .build()
        .await?;
    let inputs = ["single text"];
    let embs = pipeline.embed_batch(&inputs)?;
    assert_eq!(embs.len(), 1);
    assert!(!embs[0].is_empty());
    Ok(())
}
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- src/models/implementations/qwen3_embeddings.rs (1 hunks)
- src/pipelines/embedding_pipeline/builder.rs (2 hunks)
- src/pipelines/embedding_pipeline/embedding_model.rs (1 hunks)
- src/pipelines/embedding_pipeline/embedding_pipeline.rs (1 hunks)
- src/pipelines/reranker_pipeline/builder.rs (2 hunks)
- src/pipelines/text_generation_pipeline/builder.rs (1 hunks)
- src/pipelines/utils/mod.rs (1 hunks)
- tests/embedding_pipeline_tests/batch_embedding.rs (1 hunks)
- tests/embedding_pipeline_tests/main.rs (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (6)
tests/embedding_pipeline_tests/main.rs (1)
tests/embedding_pipeline_tests/batch_embedding.rs (1)
batch_embedding(4-15)
src/pipelines/text_generation_pipeline/builder.rs (1)
src/pipelines/utils/mod.rs (1)
load_device_with(7-17)
src/models/implementations/qwen3_embeddings.rs (1)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
embed(11-11)
tests/embedding_pipeline_tests/batch_embedding.rs (1)
src/pipelines/embedding_pipeline/builder.rs (1)
qwen3(50-52)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
src/pipelines/embedding_pipeline/embedding_pipeline.rs (1)
embed_batch(15-17)
src/pipelines/embedding_pipeline/embedding_pipeline.rs (1)
src/pipelines/embedding_pipeline/embedding_model.rs (1)
embed_batch(13-23)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Run Tests Development
🔇 Additional comments (10)
src/pipelines/utils/mod.rs (1)
25-54: Excellent abstraction for device selection.

The `DeviceRequest` enum provides a clean, type-safe way to specify device preferences across pipeline builders. The implementation properly delegates to existing device loading logic and follows good Rust patterns.
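For reference, here is what the enum looks like when reconstructed from the diffs in this review; field types and the `load_device_with` stub are assumptions:

```rust
use candle_core::{CudaDevice, Device};

pub enum DeviceRequest {
    Default,          // defer to the existing auto-selection helper
    Cpu,              // force CPU
    Cuda(usize),      // CUDA device chosen by ordinal
    Explicit(Device), // caller supplies a ready-made device
}

impl DeviceRequest {
    pub fn resolve(self) -> anyhow::Result<Device> {
        Ok(match self {
            DeviceRequest::Default => load_device_with(None)?,
            DeviceRequest::Cpu => Device::Cpu,
            DeviceRequest::Cuda(i) => Device::Cuda(CudaDevice::new_with_stream(i)?),
            DeviceRequest::Explicit(d) => d,
        })
    }
}

// Stub standing in for the existing helper in src/pipelines/utils/mod.rs;
// the real one picks the best available device.
fn load_device_with(_ordinal: Option<usize>) -> anyhow::Result<Device> {
    Ok(Device::Cpu)
}
```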
src/pipelines/reranker_pipeline/builder.rs (1)

4-46: Proper implementation of DeviceRequest pattern.

The builder correctly uses the centralized `DeviceRequest` abstraction and properly delegates device resolution to the `resolve()` method. This is the intended usage pattern that other builders should follow.

tests/embedding_pipeline_tests/main.rs (1)
2-2: LGTM - enables batch embedding tests.

The module import properly includes the new batch embedding tests in the test suite.
src/models/implementations/qwen3_embeddings.rs (1)
168-168: Critical fix for infinite recursion.

The fully qualified syntax correctly resolves the method call to the inherent implementation instead of recursively calling the trait method, preventing a stack overflow.
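For readers unfamiliar with the pattern, a sketch with hypothetical signatures (not the PR's exact code):

```rust
impl EmbeddingModel for Qwen3EmbeddingModel {
    fn embed(&self, tokenizer: &Tokenizer, text: &str) -> anyhow::Result<Vec<f32>> {
        // The type-qualified path names the call target explicitly, so it
        // resolves to the inherent `embed` defined on the struct rather than
        // looping back into this trait method.
        Qwen3EmbeddingModel::embed(self, tokenizer, text)
    }
}
```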
src/pipelines/embedding_pipeline/embedding_pipeline.rs (1)
15-17: Clean delegation to model's batch method.

The implementation correctly delegates to the underlying model's `embed_batch` method, maintaining proper separation of concerns.
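The delegation presumably reads roughly as follows (reconstructed from the sequence diagram, not verbatim from the PR):

```rust
pub fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
    // The pipeline owns the tokenizer and forwards it to the model.
    self.model.embed_batch(&self.tokenizer, texts)
}
```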
src/pipelines/embedding_pipeline/builder.rs (5)

4-4: Good architectural improvement with DeviceRequest abstraction.

The import of `DeviceRequest` from the utils module standardizes device selection across pipelines, improving code consistency and maintainability.
8-8: Consistent field type change to DeviceRequest.

Replacing the optional device field with `DeviceRequest` centralizes device selection logic and aligns with the standardization effort across pipeline builders.
12-17: Proper initialization with DeviceRequest::Default.

The constructor correctly initializes the device request to the default state, maintaining backward compatibility while using the new abstraction.
19-32: Builder methods correctly updated for DeviceRequest variants.

All builder methods (`cpu()`, `cuda_device()`, `device()`) properly map to their corresponding `DeviceRequest` enum variants, maintaining the same public API while using the new abstraction. A sketch of the mapping follows.
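Hypothetical bodies matching the described mapping, reusing the `DeviceRequest` sketch above (the real builder may differ in details):

```rust
impl EmbeddingPipelineBuilder {
    pub fn cpu(mut self) -> Self {
        self.device_request = DeviceRequest::Cpu;
        self
    }

    pub fn cuda_device(mut self, ordinal: usize) -> Self {
        self.device_request = DeviceRequest::Cuda(ordinal);
        self
    }

    pub fn device(mut self, device: Device) -> Self {
        self.device_request = DeviceRequest::Explicit(device);
        self
    }
}
```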
39-39: Centralized device resolution improves maintainability.

The `self.device_request.resolve()?` call centralizes device resolution logic, making it consistent across all pipeline builders and easier to maintain.
```diff
 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
-    self.model.embed(&self.tokenizer, text)
+    let mut out = self.embed_batch(&[text])?;
+    Ok(out.pop().unwrap())
 }
```
🛠️ Refactor suggestion
Potential panic risk with unwrap() - consider safer alternative.
The refactoring is architecturally sound, but the `unwrap()` call could panic if `embed_batch` unexpectedly returns an empty vector. While this shouldn't happen with a single-element input, it's better to be defensive.
Consider this safer implementation:
```diff
 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
-    let mut out = self.embed_batch(&[text])?;
-    Ok(out.pop().unwrap())
+    let mut out = self.embed_batch(&[text])?;
+    out.pop()
+        .ok_or_else(|| anyhow::anyhow!("embed_batch returned empty result for single input"))
 }
```

Alternatively, you could keep the panic but give it a more descriptive message:
```diff
 pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
-    let mut out = self.embed_batch(&[text])?;
-    Ok(out.pop().unwrap())
+    self.embed_batch(&[text])
+        .map(|mut embeddings| embeddings.pop().expect("embed_batch should return exactly one embedding for single input"))
 }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```rust
pub fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
    let mut out = self.embed_batch(&[text])?;
    out.pop()
        .ok_or_else(|| anyhow::anyhow!("embed_batch returned empty result for single input"))
}
```
🤖 Prompt for AI Agents
In src/pipelines/embedding_pipeline/embedding_pipeline.rs around lines 10 to 13,
the use of unwrap() on the result of out.pop() can cause a panic if the vector
is empty. To fix this, replace unwrap() with a safer alternative such as using
out.pop().ok_or_else() to return a descriptive error if the vector is empty, or
use indexing with get(0) combined with a clear error message to handle the empty
case gracefully without panicking.
Summary

- Add `embed_batch` method to embedding model trait and pipeline
- Update the `Qwen3EmbeddingModel` implementation
- Use `DeviceRequest` in all pipeline builders

Testing

- `cargo test --lib`
- `cargo test --doc`

https://chatgpt.com/codex/tasks/task_e_686c2639239083308b54be59e2578c46
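For reference, a hypothetical end-to-end use of the new API, pieced together from the tests and builder changes in this PR:

```rust
let pipeline = EmbeddingPipelineBuilder::qwen3(Qwen3EmbeddingSize::Size0_6B)
    .cpu() // device selection now flows through DeviceRequest
    .build()
    .await?;

let embeddings = pipeline.embed_batch(&["first text", "second text"])?;
assert_eq!(embeddings.len(), 2);
```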
Summary by CodeRabbit

New Features

- Batch embedding: embed multiple texts in one call via the new `embed_batch` API.

Refactor

- Unified device selection across pipeline builders through the `DeviceRequest` abstraction.

Tests

- Added integration tests for batch embedding.