Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions src/pipelines/embedding_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
use super::model::EmbeddingModel;
use super::pipeline::EmbeddingPipeline;
use crate::core::ModelOptions;
use crate::pipelines::utils::{BasePipelineBuilder, DeviceRequest, DeviceSelectable, StandardPipelineBuilder};
use std::sync::Arc;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::{build_cache_key, DeviceRequest, DeviceSelectable, BasePipelineBuilder};

pub struct EmbeddingPipelineBuilder<M: EmbeddingModel> {
options: M::Options,
device_request: DeviceRequest,
}
pub struct EmbeddingPipelineBuilder<M: EmbeddingModel>(StandardPipelineBuilder<M::Options>);

impl<M: EmbeddingModel> EmbeddingPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device_request: DeviceRequest::Default,
}
Self(StandardPipelineBuilder::new(options))
}
}

impl<M: EmbeddingModel> DeviceSelectable for EmbeddingPipelineBuilder<M> {
fn device_request_mut(&mut self) -> &mut DeviceRequest {
&mut self.device_request
self.0.device_request_mut()
}
}

Expand All @@ -34,11 +28,11 @@ where
type Options = M::Options;

fn options(&self) -> &Self::Options {
&self.options
&self.0.options
}

fn device_request(&self) -> &DeviceRequest {
&self.device_request
&self.0.device_request
}

fn create_model(options: Self::Options, device: candle_core::Device) -> anyhow::Result<M> {
Expand All @@ -50,9 +44,9 @@ where
}

fn construct_pipeline(model: M, tokenizer: tokenizers::Tokenizer) -> anyhow::Result<Self::Pipeline> {
Ok(EmbeddingPipeline {
model: Arc::new(model),
tokenizer
Ok(EmbeddingPipeline {
model: Arc::new(model),
tokenizer,
})
}
}
Expand Down
20 changes: 7 additions & 13 deletions src/pipelines/fill_mask_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
use super::model::FillMaskModel;
use super::pipeline::FillMaskPipeline;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::{build_cache_key, DeviceRequest, DeviceSelectable, BasePipelineBuilder};
use crate::core::ModelOptions;
use crate::pipelines::utils::{BasePipelineBuilder, DeviceRequest, DeviceSelectable, StandardPipelineBuilder};

pub struct FillMaskPipelineBuilder<M: FillMaskModel> {
options: M::Options,
device_request: DeviceRequest,
}
pub struct FillMaskPipelineBuilder<M: FillMaskModel>(StandardPipelineBuilder<M::Options>);

impl<M: FillMaskModel> FillMaskPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device_request: DeviceRequest::Default,
}
Self(StandardPipelineBuilder::new(options))
}
}

impl<M: FillMaskModel> DeviceSelectable for FillMaskPipelineBuilder<M> {
fn device_request_mut(&mut self) -> &mut DeviceRequest {
&mut self.device_request
self.0.device_request_mut()
}
}

Expand All @@ -33,11 +27,11 @@ where
type Options = M::Options;

fn options(&self) -> &Self::Options {
&self.options
&self.0.options
}

fn device_request(&self) -> &DeviceRequest {
&self.device_request
&self.0.device_request
}

fn create_model(options: Self::Options, device: candle_core::Device) -> anyhow::Result<M> {
Expand Down
26 changes: 10 additions & 16 deletions src/pipelines/reranker_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
use super::model::RerankModel;
use super::pipeline::RerankPipeline;
use crate::core::ModelOptions;
use crate::pipelines::utils::{BasePipelineBuilder, DeviceRequest, DeviceSelectable, StandardPipelineBuilder};
use std::sync::Arc;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::{build_cache_key, DeviceRequest, DeviceSelectable, BasePipelineBuilder};

pub struct RerankPipelineBuilder<M: RerankModel> {
options: M::Options,
device_request: DeviceRequest,
}
pub struct RerankPipelineBuilder<M: RerankModel>(StandardPipelineBuilder<M::Options>);

impl<M: RerankModel> RerankPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device_request: DeviceRequest::Default,
}
Self(StandardPipelineBuilder::new(options))
}
}

impl<M: RerankModel> DeviceSelectable for RerankPipelineBuilder<M> {
fn device_request_mut(&mut self) -> &mut DeviceRequest {
&mut self.device_request
self.0.device_request_mut()
}
}

Expand All @@ -34,11 +28,11 @@ where
type Options = M::Options;

fn options(&self) -> &Self::Options {
&self.options
&self.0.options
}

fn device_request(&self) -> &DeviceRequest {
&self.device_request
&self.0.device_request
}

fn create_model(options: Self::Options, device: candle_core::Device) -> anyhow::Result<M> {
Expand All @@ -50,9 +44,9 @@ where
}

fn construct_pipeline(model: M, tokenizer: tokenizers::Tokenizer) -> anyhow::Result<Self::Pipeline> {
Ok(RerankPipeline {
model: Arc::new(model),
tokenizer
Ok(RerankPipeline {
model: Arc::new(model),
tokenizer,
})
}
}
Expand Down
20 changes: 7 additions & 13 deletions src/pipelines/sentiment_analysis_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
use super::model::SentimentAnalysisModel;
use super::pipeline::SentimentAnalysisPipeline;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::{build_cache_key, DeviceRequest, DeviceSelectable, BasePipelineBuilder};
use crate::core::ModelOptions;
use crate::pipelines::utils::{BasePipelineBuilder, DeviceRequest, DeviceSelectable, StandardPipelineBuilder};

pub struct SentimentAnalysisPipelineBuilder<M: SentimentAnalysisModel> {
options: M::Options,
device_request: DeviceRequest,
}
pub struct SentimentAnalysisPipelineBuilder<M: SentimentAnalysisModel>(StandardPipelineBuilder<M::Options>);

impl<M: SentimentAnalysisModel> SentimentAnalysisPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device_request: DeviceRequest::Default,
}
Self(StandardPipelineBuilder::new(options))
}
}

impl<M: SentimentAnalysisModel> DeviceSelectable for SentimentAnalysisPipelineBuilder<M> {
fn device_request_mut(&mut self) -> &mut DeviceRequest {
&mut self.device_request
self.0.device_request_mut()
}
}

Expand All @@ -33,11 +27,11 @@ where
type Options = M::Options;

fn options(&self) -> &Self::Options {
&self.options
&self.0.options
}

fn device_request(&self) -> &DeviceRequest {
&self.device_request
&self.0.device_request
}

fn create_model(options: Self::Options, device: candle_core::Device) -> anyhow::Result<M> {
Expand Down
1 change: 0 additions & 1 deletion src/pipelines/utils/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
//!
//! The `BasePipelineBuilder` trait captures this common pattern.

use std::sync::Arc;
use anyhow::Result;
use crate::core::{global_cache, ModelOptions};
use super::{build_cache_key, DeviceRequest, DeviceSelectable};
Expand Down
20 changes: 7 additions & 13 deletions src/pipelines/zero_shot_classification_pipeline/builder.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
use super::model::ZeroShotClassificationModel;
use super::pipeline::ZeroShotClassificationPipeline;
use crate::core::{global_cache, ModelOptions};
use crate::pipelines::utils::{build_cache_key, DeviceRequest, DeviceSelectable, BasePipelineBuilder};
use crate::core::ModelOptions;
use crate::pipelines::utils::{BasePipelineBuilder, DeviceRequest, DeviceSelectable, StandardPipelineBuilder};

pub struct ZeroShotClassificationPipelineBuilder<M: ZeroShotClassificationModel> {
options: M::Options,
device_request: DeviceRequest,
}
pub struct ZeroShotClassificationPipelineBuilder<M: ZeroShotClassificationModel>(StandardPipelineBuilder<M::Options>);

impl<M: ZeroShotClassificationModel> ZeroShotClassificationPipelineBuilder<M> {
pub fn new(options: M::Options) -> Self {
Self {
options,
device_request: DeviceRequest::Default,
}
Self(StandardPipelineBuilder::new(options))
}
}

impl<M: ZeroShotClassificationModel> DeviceSelectable for ZeroShotClassificationPipelineBuilder<M> {
fn device_request_mut(&mut self) -> &mut DeviceRequest {
&mut self.device_request
self.0.device_request_mut()
}
}

Expand All @@ -33,11 +27,11 @@ where
type Options = M::Options;

fn options(&self) -> &Self::Options {
&self.options
&self.0.options
}

fn device_request(&self) -> &DeviceRequest {
&self.device_request
&self.0.device_request
}

fn create_model(options: Self::Options, device: candle_core::Device) -> anyhow::Result<M> {
Expand Down
Loading