Skip to content

Commit a2e5ec8

Browse files
fix: Wire /api/vision endpoint to main branch
Root cause: Vision API code existed but was never merged from feature/shimmy-vision-phase1 branch. This commit adds: - /api/vision POST route in server.rs - pub async fn vision() handler in api.rs - vision + vision_license module exports in lib.rs and main.rs - vision_license_manager field in AppState - generate_vision() method on LoadedModel trait - Remove shimmy-vision private crate dependency (use local code) The endpoint was working when testing from the feature branch but the main branch lacked the HTTP server wiring.
1 parent a4fd1a4 commit a2e5ec8

6 files changed

Lines changed: 159 additions & 10 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ full = ["huggingface", "llama", "mlx"] # Full compilation - includes all backend
3636
gpu = ["huggingface", "llama-cuda", "llama-vulkan", "llama-opencl"] # GPU-optimized build
3737
apple = ["huggingface", "mlx"] # Apple Silicon optimized - MLX + HuggingFace
3838
coverage = ["huggingface"] # Coverage testing - minimal deps for faster builds
39-
vision = ["dep:shimmy-vision", "dep:image", "dep:base64", "dep:chromiumoxide", "dep:ed25519-dalek", "dep:hex", "dep:sha2"] # Optional vision feature for image/web analysis
39+
vision = ["dep:image", "dep:base64", "dep:chromiumoxide", "dep:ed25519-dalek", "dep:hex", "dep:sha2"] # Optional vision feature for image/web analysis
4040

4141
[dependencies]
4242
anyhow = "1"
@@ -76,9 +76,6 @@ reqwest = { version = "0.11", features = ["json", "rustls-tls"], default-feature
7676
# llama.cpp bindings (optional) - published shimmy-llama-cpp-2 with MoE CPU offloading support
7777
shimmy-llama-cpp-2 = { version = "0.1.123", optional = true, default-features = false }
7878

79-
# Private vision crate (optional) - licensed vision processing
80-
shimmy-vision = { git = "https://github.com/Michael-A-Kuykendall/shimmy-vision-private.git", optional = true }
81-
8279
[dev-dependencies]
8380
tokio-tungstenite = "0.20"
8481
criterion = { version = "0.5", features = ["html_reports"] }

src/api.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,3 +1128,114 @@ mod tests {
11281128
assert_eq!(request.messages.as_ref().unwrap().len(), 1);
11291129
}
11301130
}
1131+
1132+
#[cfg(feature = "vision")]
1133+
#[axum::debug_handler]
1134+
pub async fn vision(
1135+
State(state): State<Arc<AppState>>,
1136+
Json(mut req): Json<crate::vision::VisionRequest>,
1137+
) -> impl IntoResponse {
1138+
// Extract license from environment or request
1139+
if req.license.is_none() {
1140+
req.license = std::env::var("SHIMMY_LICENSE_KEY").ok();
1141+
}
1142+
1143+
// Use default vision model or specified one
1144+
let env_model = std::env::var("SHIMMY_VISION_MODEL").ok();
1145+
let model_name = req
1146+
.model
1147+
.as_deref()
1148+
.or(env_model.as_deref())
1149+
.unwrap_or("minicpm-v")
1150+
.to_string();
1151+
1152+
let Some(license_manager) = state.vision_license_manager.as_ref() else {
1153+
tracing::error!("Vision license manager not initialized");
1154+
return (
1155+
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
1156+
Json(serde_json::json!({
1157+
"error": {
1158+
"code": "VISION_LICENSE_MANAGER_MISSING",
1159+
"message": "Vision subsystem not initialized",
1160+
}
1161+
})),
1162+
)
1163+
.into_response();
1164+
};
1165+
1166+
fn map_vision_error_status(message: &str) -> axum::http::StatusCode {
1167+
if message.contains("Either image_base64 or url must be provided") {
1168+
return axum::http::StatusCode::BAD_REQUEST;
1169+
}
1170+
if message.starts_with("Failed to decode base64 image") {
1171+
return axum::http::StatusCode::BAD_REQUEST;
1172+
}
1173+
if message.starts_with("Failed to preprocess image") {
1174+
return axum::http::StatusCode::UNPROCESSABLE_ENTITY;
1175+
}
1176+
if message.contains("Vision model '") && message.contains("not available in Ollama") {
1177+
return axum::http::StatusCode::UNPROCESSABLE_ENTITY;
1178+
}
1179+
if message.contains("Failed to fetch image from URL") {
1180+
if message.to_lowercase().contains("timed out") {
1181+
return axum::http::StatusCode::GATEWAY_TIMEOUT;
1182+
}
1183+
return axum::http::StatusCode::BAD_GATEWAY;
1184+
}
1185+
if message.contains("Vision inference timed out") {
1186+
return axum::http::StatusCode::GATEWAY_TIMEOUT;
1187+
}
1188+
if message.contains("Failed to load vision model")
1189+
|| message.contains("Vision inference failed")
1190+
{
1191+
return axum::http::StatusCode::BAD_GATEWAY;
1192+
}
1193+
1194+
axum::http::StatusCode::INTERNAL_SERVER_ERROR
1195+
}
1196+
1197+
match crate::vision::process_vision_request(
1198+
req,
1199+
&model_name,
1200+
license_manager,
1201+
&state,
1202+
)
1203+
.await
1204+
{
1205+
Ok(response) => Json(response).into_response(),
1206+
Err(e) => {
1207+
// Check if it's a license error
1208+
if let Some(license_err) = e.downcast_ref::<crate::vision_license::VisionLicenseError>()
1209+
{
1210+
return (
1211+
license_err.to_status_code(),
1212+
Json(license_err.to_json_error()),
1213+
)
1214+
.into_response();
1215+
}
1216+
1217+
let full_message = e.to_string();
1218+
let status = map_vision_error_status(&full_message);
1219+
1220+
tracing::error!(status = %status, "Vision processing error: {}", full_message);
1221+
// Expose client error messages (4xx) to help users fix their requests.
1222+
// Hide server error details (5xx) unless running in dev mode.
1223+
let message = if status.is_client_error() || std::env::var("SHIMMY_VISION_DEV_MODE").is_ok() {
1224+
full_message
1225+
} else {
1226+
"Vision processing error".to_string()
1227+
};
1228+
1229+
(
1230+
status,
1231+
Json(serde_json::json!({
1232+
"error": {
1233+
"code": "VISION_PROCESSING_ERROR",
1234+
"message": message,
1235+
}
1236+
})),
1237+
)
1238+
.into_response()
1239+
}
1240+
}
1241+
}

src/engine/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use anyhow::Result;
1+
use anyhow::{anyhow, Result};
22
use async_trait::async_trait;
33
use serde::{Deserialize, Serialize};
44
use std::path::PathBuf;
@@ -126,6 +126,17 @@ pub trait LoadedModel: Send + Sync {
126126
opts: GenOptions,
127127
on_token: Option<Box<dyn FnMut(String) + Send>>,
128128
) -> Result<String>;
129+
130+
async fn generate_vision(
131+
&self,
132+
_image_data: &[u8],
133+
_prompt: &str,
134+
_opts: GenOptions,
135+
_on_token: Option<Box<dyn FnMut(String) + Send>>,
136+
) -> Result<String> {
137+
// Default implementation returns error - vision models should override
138+
Err(anyhow!("Vision not supported by this model"))
139+
}
129140
}
130141

131142
pub mod llama;

src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ pub mod safetensors_adapter;
2222
pub mod server;
2323
pub mod templates;
2424
pub mod tools;
25+
#[cfg(feature = "vision")]
26+
pub mod vision;
27+
#[cfg(feature = "vision")]
28+
pub mod vision_license;
2529
pub mod util {
2630
pub mod diag;
2731
pub mod memory;
@@ -42,6 +46,8 @@ pub struct AppState {
4246
pub registry: model_registry::Registry,
4347
pub observability: observability::ObservabilityManager,
4448
pub response_cache: cache::ResponseCache,
49+
#[cfg(feature = "vision")]
50+
pub vision_license_manager: Option<crate::vision_license::VisionLicenseManager>,
4551
}
4652

4753
impl AppState {
@@ -54,6 +60,8 @@ impl AppState {
5460
registry,
5561
observability: observability::ObservabilityManager::new(),
5662
response_cache: cache::ResponseCache::new(),
63+
#[cfg(feature = "vision")]
64+
vision_license_manager: Some(crate::vision_license::VisionLicenseManager::new()),
5765
}
5866
}
5967
}

src/main.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ mod openai_compat;
1616
mod port_manager;
1717
mod server;
1818
mod templates;
19+
#[cfg(feature = "vision")]
20+
mod vision;
21+
#[cfg(feature = "vision")]
22+
mod vision_license;
1923
mod util {
2024
pub mod diag;
2125
pub mod memory;
@@ -32,16 +36,28 @@ pub struct AppState {
3236
pub registry: Registry,
3337
pub observability: observability::ObservabilityManager,
3438
pub response_cache: cache::ResponseCache,
39+
#[cfg(feature = "vision")]
40+
pub vision_license_manager: Option<crate::vision_license::VisionLicenseManager>,
3541
}
3642

3743
impl AppState {
3844
pub fn new(engine: Box<dyn engine::InferenceEngine>, registry: Registry) -> Self {
39-
Self {
45+
#[allow(unused_mut)]
46+
let mut state = Self {
4047
engine,
4148
registry,
4249
observability: observability::ObservabilityManager::new(),
4350
response_cache: cache::ResponseCache::new(),
51+
#[cfg(feature = "vision")]
52+
vision_license_manager: None,
53+
};
54+
55+
#[cfg(feature = "vision")]
56+
{
57+
state.vision_license_manager = Some(crate::vision_license::VisionLicenseManager::new());
4458
}
59+
60+
state
4561
}
4662
}
4763

src/server.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ async fn metrics_endpoint(State(state): State<Arc<AppState>>) -> Json<Value> {
130130

131131
pub async fn run(addr: SocketAddr, state: Arc<AppState>) -> anyhow::Result<()> {
132132
let listener = tokio::net::TcpListener::bind(addr).await?;
133-
let app = Router::new()
133+
#[allow(unused_mut)]
134+
let mut app = Router::new()
134135
.route("/health", get(health_check))
135136
.route("/metrics", get(metrics_endpoint))
136137
.route("/diag", get(diag_handler))
@@ -150,9 +151,14 @@ pub async fn run(addr: SocketAddr, state: Arc<AppState>) -> anyhow::Result<()> {
150151
)
151152
.route("/v1/models", get(openai_compat::models))
152153
// Anthropic Claude API compatibility
153-
.route("/v1/messages", post(anthropic_compat::messages))
154-
.layer(middleware::from_fn(cors_layer))
155-
.with_state(state);
154+
.route("/v1/messages", post(anthropic_compat::messages));
155+
156+
#[cfg(feature = "vision")]
157+
{
158+
app = app.route("/api/vision", post(api::vision));
159+
}
160+
161+
let app = app.layer(middleware::from_fn(cors_layer)).with_state(state);
156162
axum::serve(listener, app).await?;
157163
Ok(())
158164
}

0 commit comments

Comments
 (0)