Skip to content

Add extracted data validation #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions node-zerox/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
OperationMode,
Page,
PageStatus,
ValidationLog,
ZeroxArgs,
ZeroxOutput,
} from "./types";
Expand Down Expand Up @@ -82,6 +83,7 @@ export const zerox = async ({
let priorPage: string = "";
let pages: Page[] = [];
let imagePaths: string[] = [];
let validationLog: ValidationLog = { extracted: [] };
const startTime = new Date();

if (openaiAPIKey && openaiAPIKey.length > 0) {
Expand Down Expand Up @@ -276,10 +278,10 @@ export const zerox = async ({
});
}

const response = CompletionProcessor.process(
OperationMode.OCR,
rawResponse
);
const response = CompletionProcessor.process({
mode: OperationMode.OCR,
response: rawResponse as CompletionResponse,
});

inputTokenCount += response.inputTokens;
outputTokenCount += response.outputTokens;
Expand Down Expand Up @@ -361,6 +363,7 @@ export const zerox = async ({
schema: Record<string, unknown>
): Promise<Record<string, unknown>> => {
let result: Record<string, unknown> = {};
let validationResult: Record<string, unknown> | null = null;
try {
await runRetries(
async () => {
Expand All @@ -381,16 +384,19 @@ export const zerox = async ({
});
}

const response = CompletionProcessor.process(
OperationMode.EXTRACTION,
rawResponse
);
const response = CompletionProcessor.process({
mode: OperationMode.EXTRACTION,
response: rawResponse as ExtractionResponse,
schema,
});

inputTokenCount += response.inputTokens;
outputTokenCount += response.outputTokens;

numSuccessfulExtractionRequests++;

if (response.issues && response.issues.length > 0) {
validationResult = { page: pageNumber, issues: response.issues };
}
for (const key of Object.keys(schema?.properties ?? {})) {
const value = response.extracted[key];
if (value !== null && value !== undefined) {
Expand All @@ -409,7 +415,7 @@ export const zerox = async ({
throw error;
}

return result;
return { result, validationResult };
};

if (perPageSchema) {
Expand Down Expand Up @@ -438,6 +444,7 @@ export const zerox = async ({
extractionTasks.push(
(async () => {
let result: Record<string, unknown> = {};
let validationResult: Record<string, unknown> | null = null;
try {
await runRetries(
async () => {
Expand All @@ -459,20 +466,25 @@ export const zerox = async ({
});
}

const response = CompletionProcessor.process(
OperationMode.EXTRACTION,
rawResponse
);
const response = CompletionProcessor.process({
mode: OperationMode.EXTRACTION,
response: rawResponse as ExtractionResponse,
schema,
});

inputTokenCount += response.inputTokens;
outputTokenCount += response.outputTokens;
numSuccessfulExtractionRequests++;
if (response.issues && response.issues.length > 0) {
validationResult = { page: null, issues: response.issues };
}

result = response.extracted;
},
maxRetries,
0
);
return result;
return { result, validationResult };
} catch (error) {
numFailedExtractionRequests++;
throw error;
Expand All @@ -482,8 +494,12 @@ export const zerox = async ({
}

const results = await Promise.all(extractionTasks);
extracted = results.reduce((acc, result) => {
Object.entries(result || {}).forEach(([key, value]) => {
validationLog.extracted = results.reduce(
(acc, result) => (result.validationResult ? [...acc, result.validationResult] : acc),
[]
);
extracted = results.reduce((acc, resultObj) => {
Object.entries(resultObj?.result || {}).forEach(([key, value]) => {
if (!acc[key]) {
acc[key] = [];
}
Expand Down Expand Up @@ -573,6 +589,7 @@ export const zerox = async ({
}
: null,
},
validationLog,
};
} finally {
if (correctOrientation && scheduler) {
Expand Down
28 changes: 27 additions & 1 deletion node-zerox/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export interface ZeroxOutput {
outputTokens: number;
pages: Page[];
summary: Summary;
validationLog: ValidationLog;
}

export interface AzureCredentials {
Expand Down Expand Up @@ -177,7 +178,11 @@ export interface ExtractionResponse {
outputTokens: number;
}

export type ProcessedExtractionResponse = Omit<ExtractionResponse, "logprobs">;
// export type ProcessedExtractionResponse = Omit<ExtractionResponse, "logprobs">;

export interface ProcessedExtractionResponse extends Omit<ExtractionResponse, "logprobs"> {
issues: any;
}

interface BaseLLMParams {
frequencyPenalty?: number;
Expand Down Expand Up @@ -254,3 +259,24 @@ export interface ExcelSheetContent {
contentLength: number;
sheetName: string;
}

// Define extraction-specific parameters
export interface ExtractionProcessParams {
mode: OperationMode.EXTRACTION;
response: ExtractionResponse;
schema: Record<string, unknown>;
}

// Define OCR-specific parameters
export interface CompletionProcessParams {
mode: OperationMode.OCR;
response: CompletionResponse;
schema?: undefined;
}

// Union type for all possible parameter combinations
export type ProcessParams = ExtractionProcessParams | CompletionProcessParams;

export interface ValidationLog {
extracted: { page: number | null, issues: any }[];
}
16 changes: 16 additions & 0 deletions node-zerox/src/utils/common.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { parse } from "flatted";

export const camelToSnakeCase = (str: string) =>
str.replace(/[A-Z]/g, (letter: string) => `_${letter.toLowerCase()}`);

Expand Down Expand Up @@ -119,3 +121,17 @@ export const splitSchema = (
: null,
};
};

export const formatJsonValue = (
value: any,
useFlatted: boolean = false
): any => {
if (typeof value === "string") {
try {
return useFlatted ? parse(value) : JSON.parse(value);
} catch {
return value;
}
}
return value;
};
114 changes: 114 additions & 0 deletions node-zerox/src/utils/fixSchemaValidationErrors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import { formatJsonValue } from "../utils";
import { JSONSchema, JSONSchemaDefinition } from "openai/lib/jsonschema";
import { ZodError } from "zod";

/**
* Handles specific cases of ZodError by traversing the
* error paths in the original value and modifying invalid entries (e.g., replacing
* invalid enum values with null or converting strings to booleans or numbers)
*
* Handled cases:
* - Boolean strings ("true" or "false") should be converted to actual booleans
* - Numeric strings (e.g., "123") should be converted to numbers
* - For other cases, default to the default value or null
*
* @param {ZodError} err - The error object containing validation details
* @returns {any} - The modified value object with resolved issues
*/
export const fixSchemaValidationErrors = ({
err,
schema,
value: originalValue,
}: {
err: ZodError<Record<string, any>>;
schema: JSONSchema;
value: Record<string, any>;
}) => {
const errors = err.issues;
let value = originalValue;

errors.forEach((error) => {
const lastKey = error.path[error.path.length - 1];

let parent = value;
for (let i = 0; i < error.path.length - 1; i++) {
parent = parent?.[error.path[i]];
}

let defaultValue = null;
if (schema) {
let schemaProperty = schema;
let properties: JSONSchema | JSONSchemaDefinition[] =
schemaProperty.properties || schemaProperty;

for (let i = 0; i < error.path.length; i++) {
const pathKey = error.path[i];
if (properties && properties[pathKey as keyof typeof properties]) {
schemaProperty = properties[
pathKey as keyof typeof properties
] as JSONSchema;
if (schemaProperty.type === "array" && schemaProperty.items) {
// If array of object (table)
if ((schemaProperty.items as JSONSchema).type === "object") {
properties = (schemaProperty.items as JSONSchema).properties || {};
i++; // Skip the numeric path (row index)
} else {
properties = schemaProperty.items as JSONSchema;
}
} else {
properties = schemaProperty.properties || {};
}
}
}

if (schemaProperty && "default" in schemaProperty) {
defaultValue = schemaProperty.default;
}
}

if (parent && typeof parent === "object") {
const currentValue = parent[lastKey];

if (
error.code === "invalid_type" &&
error.expected === "boolean" &&
error.received === "string" &&
(currentValue === "true" || currentValue === "false")
) {
parent[lastKey] = currentValue === "true";
} else if (
error.code === "invalid_type" &&
error.expected === "number" &&
error.received === "string" &&
!isNaN(Number(currentValue))
) {
parent[lastKey] = Number(currentValue);
} else if (
error.code === "invalid_type" &&
error.expected === "array" &&
error.received === "string"
) {
// TODO: could this be problematic? no check if the parsed array conformed to the schema
const value = formatJsonValue(currentValue);
if (Array.isArray(value)) {
parent[lastKey] = value;
}
} else if (
error.code === "invalid_type" &&
(error.expected === "array" ||
error.expected === "boolean" ||
error.expected === "integer" ||
error.expected === "number" ||
error.expected === "string" ||
error.expected.includes(" | ")) && // `Expected` for enums comes back as z.enum(['a', 'b']) => expected: "'a' | 'b'"
currentValue === undefined
) {
parent[lastKey] = defaultValue !== null ? defaultValue : null;
} else {
parent[lastKey] = defaultValue !== null ? defaultValue : null;
}
}
});

return value;
};
Loading