diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts
index 4a664c42e2..d2880b19c3 100644
--- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts
+++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts
@@ -323,4 +323,40 @@ describe("parseSafetensorsMetadata", () => {
 		assert(parse.sharded);
 		assert.strictEqual(Object.keys(parse.headers).length, 64);
 	});
+
+	it("fetch info for deepseek-ai/DeepSeek-Math-V2 (163 shards)", async () => {
+		const parse = await parseSafetensorsMetadata({
+			repo: "deepseek-ai/DeepSeek-Math-V2",
+			computeParametersCount: true,
+		});
+
+		assert(parse);
+		assert(parse.sharded, "Model should be sharded");
+		assert.strictEqual(Object.keys(parse.headers).length, 163, "Should have 163 shards");
+		assert.ok(parse.parameterCount, "Should have parameter count");
+		assert.ok(parse.index, "Should have index");
+
+		// Verify parameter count is computed; fall back to summing per-dtype counts only when no total is reported
+		const totalParams = parse.parameterTotal ?? sum(Object.values(parse.parameterCount));
+		assert.ok(totalParams > 0, "Total parameters should be greater than 0");
+
+		console.log("Total parameters:", totalParams);
+		console.log("Parameter count by dtype:", parse.parameterCount);
+	});
+
+	it("fetch info for Qwen/Qwen3.5-397B-A17B (94 shards)", async () => {
+		const parse = await parseSafetensorsMetadata({
+			repo: "Qwen/Qwen3.5-397B-A17B",
+			computeParametersCount: true,
+		});
+
+		assert(parse.sharded);
+		assert.strictEqual(Object.keys(parse.headers).length, 94);
+		assert.ok(parse.parameterCount);
+
+		const totalParams = parse.parameterTotal ?? sum(Object.values(parse.parameterCount));
+		assert.ok(totalParams > 0, "Total parameters should be greater than 0");
+		console.log("Qwen3.5-397B total parameters:", totalParams);
+		console.log("Qwen3.5-397B parameter count by dtype:", parse.parameterCount);
+	});
 });
diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts
index f85e845d58..3669bf9e18 100644
--- 
a/packages/hub/src/lib/parse-safetensors-metadata.ts
+++ b/packages/hub/src/lib/parse-safetensors-metadata.ts
@@ -1,4 +1,6 @@
 import type { CredentialsParams, RepoDesignation } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { createApiError } from "../error";
 import { omit } from "../utils/omit";
 import { toRepoId } from "../utils/toRepoId";
 import { typedEntries } from "../utils/typedEntries";
@@ -6,6 +8,7 @@ import { downloadFile } from "./download-file";
 import { fileExists } from "./file-exists";
 import { promisesQueue } from "../utils/promisesQueue";
 import type { SetRequired } from "../vendor/type-fest/set-required";
+import { HUB_URL } from "../consts";
 
 export const SAFETENSORS_FILE = "model.safetensors";
 export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
@@ -197,6 +200,84 @@ async function parseShardedIndex(
 	}
 }
 
+/**
+ * Fetches and parses the JSON header of a remote safetensors file with two range requests
+ * (8 bytes for the little-endian length prefix, then exactly the header bytes), bypassing
+ * downloadFile/fileDownloadInfo to reduce HTTP round-trips from 3 to 2 per shard.
+ *
+ * @param url - URL of the safetensors file to read.
+ * @param customFetch - `fetch` implementation used for both requests.
+ * @param accessToken - When defined, sent as an `Authorization: Bearer` header.
+ * @returns The JSON object decoded from the file header.
+ * @throws The error built by `createApiError` when a request gets a non-2xx response.
+ * @throws SafetensorParseError when the server does not answer with 206 Partial Content, when
+ *   the length prefix is truncated, zero or above MAX_HEADER_LENGTH, or when the header is
+ *   not a JSON object.
+ */
+async function fetchHeaderDirect(
+	url: string,
+	customFetch: typeof fetch,
+	accessToken: string | undefined,
+): Promise<SafetensorsFileHeader> {
+	const authHeaders = accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined;
+
+	// Both steps need the same guarantees, so they share one range-request helper.
+	const fetchRange = async (range: string): Promise<Response> => {
+		const resp = await customFetch(url, {
+			headers: { Range: range, ...authHeaders },
+		});
+
+		if (resp.status === 206) {
+			return resp;
+		}
+		if (!resp.ok) {
+			throw await createApiError(resp);
+		}
+
+		// The server ignored the Range header (2xx other than 206): abort to avoid downloading
+		// the entire multi-GB shard file into memory. Cancellation is best-effort, so a failure
+		// to cancel must not mask the error thrown below.
+		await resp.body?.cancel().catch(() => undefined);
+		throw new SafetensorParseError(
+			`Server does not support range requests (status ${resp.status}). Cannot efficiently parse safetensors header.`,
+		);
+	};
+
+	// Step 1: fetch the 8-byte header length prefix
+	const lengthBuf = await (await fetchRange("bytes=0-7")).arrayBuffer();
+
+	if (lengthBuf.byteLength < 8) {
+		throw new SafetensorParseError(`Failed to fetch safetensors header: response too small.`);
+	}
+
+	const lengthOfHeader = new DataView(lengthBuf).getBigUint64(0, true);
+
+	if (lengthOfHeader <= 0) {
+		throw new SafetensorParseError(`Failed to parse safetensors header: header is malformed.`);
+	}
+	if (lengthOfHeader > MAX_HEADER_LENGTH) {
+		throw new SafetensorParseError(
+			`Failed to parse safetensors header: header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`,
+		);
+	}
+
+	// Step 2: fetch exactly the header bytes
+	const headerResp = await fetchRange(`bytes=8-${8 + Number(lengthOfHeader) - 1}`);
+
+	let header: unknown;
+	try {
+		header = JSON.parse(await headerResp.text());
+	} catch {
+		throw new SafetensorParseError(`Failed to parse safetensors header: not valid JSON.`);
+	}
+
+	// JSON.parse also accepts scalars, null and arrays, none of which is a usable header.
+	if (typeof header !== "object" || header === null || Array.isArray(header)) {
+		throw new SafetensorParseError(`Failed to parse safetensors header: header is malformed.`);
+	}
+	return header as SafetensorsFileHeader;
+}
+
 async function fetchAllHeaders(
 	path: string,
 	index: SafetensorsIndexJson,
@@ -210,13 +291,26 @@ async function fetchAllHeaders(
 		fetch?: typeof fetch;
 	} & Partial<CredentialsParams>,
 ): Promise<SafetensorsShardedHeaders> {
+	const repoId = toRepoId(params.repo);
+	const accessToken = checkCredentials(params);
+	const hubUrl = params.hubUrl ?? HUB_URL;
+	const customFetch = params.fetch ?? fetch;
 	const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
 	const filenames = [...new Set(Object.values(index.weight_map))];
+
+	// Non-model repos are served under a pluralised type prefix ("datasets/", "spaces/") on the Hub.
+	const repoPrefix = repoId.type === "model" ? "" : `${repoId.type}s/`;
+	const resolveUrl = (filename: string) =>
+		`${hubUrl}/${repoPrefix}${repoId.name}/resolve/${encodeURIComponent(params.revision ?? 
"main")}/${pathPrefix}${filename}`; + const shardedMap: SafetensorsShardedHeaders = Object.fromEntries( await promisesQueue( filenames.map( (filename) => async () => - [filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader], + [filename, await fetchHeaderDirect(resolveUrl(filename), customFetch, accessToken)] satisfies [ + string, + SafetensorsFileHeader, + ], ), PARALLEL_DOWNLOADS, ),