From b6fffc0fa57479c768b96b036fce954c9127dae5 Mon Sep 17 00:00:00 2001 From: Felipe Forbeck Date: Mon, 28 Oct 2024 14:09:11 -0300 Subject: [PATCH 1/2] feat: count egress bytes --- src/bindings.d.ts | 12 +- src/index.js | 6 +- src/middleware/index.js | 1 + src/middleware/withEgressTracker.js | 117 +++++ src/middleware/withEgressTracker.types.ts | 6 + src/middleware/withRateLimit.js | 19 +- src/services/accounting.js | 4 +- test/fixtures/worker-fixture.js | 8 +- .../unit/middleware/withEgressTracker.spec.js | 462 ++++++++++++++++++ wrangler.toml | 9 +- 10 files changed, 621 insertions(+), 23 deletions(-) create mode 100644 src/middleware/withEgressTracker.js create mode 100644 src/middleware/withEgressTracker.types.ts create mode 100644 test/unit/middleware/withEgressTracker.spec.js diff --git a/src/bindings.d.ts b/src/bindings.d.ts index 3003676..929045e 100644 --- a/src/bindings.d.ts +++ b/src/bindings.d.ts @@ -2,19 +2,23 @@ import { CID } from '@web3-storage/gateway-lib/handlers' import { Environment as RateLimiterEnvironment } from './middleware/withRateLimit.types.ts' import { Environment as CarBlockEnvironment } from './middleware/withCarBlockHandler.types.ts' import { Environment as ContentClaimsDagulaEnvironment } from './middleware/withCarBlockHandler.types.ts' - +import { Environment as EgressTrackerEnvironment } from './middleware/withEgressTracker.types.ts' +import { UnknownLink } from 'multiformats' export interface Environment extends CarBlockEnvironment, RateLimiterEnvironment, - ContentClaimsDagulaEnvironment { + ContentClaimsDagulaEnvironment, + EgressTrackerEnvironment { VERSION: string + CONTENT_CLAIMS_SERVICE_URL?: string + ACCOUNTING_SERVICE_URL: string } export interface AccountingService { - record: (cid: CID, options: GetCIDRequestConfig) => Promise + record: (resource: UnknownLink, bytes: number, servedAt: string) => Promise getTokenMetadata: (token: string) => Promise } export interface Accounting { - create: ({ serviceURL }: { serviceURL?: string }) => AccountingService + create: ({ serviceURL }: { serviceURL: string }) => AccountingService } diff --git a/src/index.js b/src/index.js index c054127..ba074e5 100644 --- a/src/index.js +++ b/src/index.js @@ -22,7 +22,8 @@ import { withCarBlockHandler, withRateLimit, withNotFound, - withLocator + withLocator, + withEgressTracker } from './middleware/index.js' /** @@ -57,6 +58,9 @@ export default { // Rate-limit requests withRateLimit, + // Track egress bytes + withEgressTracker, + // Fetch data withCarBlockHandler, withNotFound, diff --git a/src/middleware/index.js b/src/middleware/index.js index ed13c8e..d2a7613 100644 --- a/src/middleware/index.js +++ b/src/middleware/index.js @@ -5,3 +5,4 @@ export { withRateLimit } from './withRateLimit.js' export { withVersionHeader } from './withVersionHeader.js' export { withNotFound } from './withNotFound.js' export { withLocator } from './withLocator.js' +export { withEgressTracker } from './withEgressTracker.js' \ No newline at end of file diff --git a/src/middleware/withEgressTracker.js b/src/middleware/withEgressTracker.js new file mode 100644 index 0000000..56d688a --- /dev/null +++ b/src/middleware/withEgressTracker.js @@ -0,0 +1,117 @@ +import { Accounting } from '../services/accounting.js' + +/** + * @import { Context, IpfsUrlContext, Middleware } from '@web3-storage/gateway-lib' + * @import { Environment } from './withEgressTracker.types.js' + * @import { AccountingService } from '../bindings.js' + * @typedef {IpfsUrlContext & { ACCOUNTING_SERVICE?: AccountingService }} EgressTrackerContext + */ + +/** + * The egress tracking handler must be enabled after the rate limiting handler, + * and before any handler that serves the response body. It uses the CID of the + * served content to record the egress in the accounting service, and it counts + * the bytes served with a TransformStream to determine the egress amount. + * + * @type {Middleware} + */ +export function withEgressTracker (handler) { + return async (req, env, ctx) => { + if (env.FF_EGRESS_TRACKER_ENABLED !== 'true') { + return handler(req, env, ctx) + } + + let response + try { + response = await handler(req, env, ctx) + } catch (error) { + console.error('Error in egress tracker handler:', error) + throw error + } + + if (!response.ok || !response.body) { + return response + } + + const { dataCid } = ctx + const accounting = ctx.ACCOUNTING_SERVICE ?? Accounting.create({ + serviceURL: env.ACCOUNTING_SERVICE_URL + }) + + const { readable, writable } = createEgressPassThroughStream(ctx, accounting, dataCid) + + try { + ctx.waitUntil(response.body.pipeTo(writable)) + } catch (error) { + console.error('Error in egress tracker handler:', error) + // Original response in case of an error to avoid breaking the chain and serve the content + return response + } + + return new Response(readable, { + status: response.status, + statusText: response.statusText, + headers: response.headers + }) + } +} + +/** + * Creates a TransformStream to count bytes served to the client. + * It records egress when the stream is finalized without an error. + * + * @param {import('@web3-storage/gateway-lib/middleware').Context} ctx - The context object. + * @param {AccountingService} accounting - The accounting service instance to record egress. + * @param {import('@web3-storage/gateway-lib/handlers').CID} dataCid - The CID of the served content. + * @returns {TransformStream} - The created TransformStream. + */ +function createEgressPassThroughStream (ctx, accounting, dataCid) { + let totalBytesServed = 0 + + return new TransformStream({ + /** + * The start function is called when the stream is being initialized. + * It resets the total bytes served to 0. + */ + start () { + totalBytesServed = 0 + }, + /** + * The transform function is called for each chunk of the response body. + * It enqueues the chunk and updates the total bytes served. + * If an error occurs, it signals an error to the controller and logs it. + * The bytes are not counted in case of enqueuing an error. + * @param {Uint8Array} chunk + * @param {TransformStreamDefaultController} controller + */ + async transform (chunk, controller) { + try { + controller.enqueue(chunk) + totalBytesServed += chunk.byteLength + } catch (error) { + console.error('Error while counting egress bytes:', error) + controller.error(error) + } + }, + + /** + * The flush function is called when the stream is being finalized, + * which is when the response is being sent to the client. + * So before the response is sent, we record the egress. + * It is called only once and it triggers a non-blocking call to the accounting service. + * If an error occurs, the egress is not recorded. + * NOTE: The flush function is NOT called in case of an stream error. + */ + async flush (controller) { + try { + // Non-blocking call to the accounting service to record egress + if (totalBytesServed > 0) { + ctx.waitUntil(accounting.record(dataCid, totalBytesServed, new Date().toISOString())) + } + } catch (error) { + console.error('Error while recording egress:', error) + controller.error(error) + } + } + }) +} diff --git a/src/middleware/withEgressTracker.types.ts b/src/middleware/withEgressTracker.types.ts new file mode 100644 index 0000000..d764548 --- /dev/null +++ b/src/middleware/withEgressTracker.types.ts @@ -0,0 +1,6 @@ +import { Environment as MiddlewareEnvironment } from '@web3-storage/gateway-lib' + +export interface Environment extends MiddlewareEnvironment { + ACCOUNTING_SERVICE_URL: string + FF_EGRESS_TRACKER_ENABLED: string +} diff --git a/src/middleware/withRateLimit.js b/src/middleware/withRateLimit.js index 3e5e4c7..5487dcb 100644 --- a/src/middleware/withRateLimit.js +++ b/src/middleware/withRateLimit.js @@ -12,6 +12,7 @@ import { Accounting } from '../services/accounting.js' * RateLimitService, * RateLimitExceeded * } from './withRateLimit.types.js' + * @typedef {Context & { ACCOUNTING_SERVICE?: import('../bindings.js').AccountingService }} RateLimiterContext */ /** @@ -20,7 +21,7 @@ import { Accounting } from '../services/accounting.js' * it can be enabled or disabled using the FF_RATE_LIMITER_ENABLED flag. * Every successful request is recorded in the accounting service. * - * @type {Middleware} + * @type {Middleware} */ export function withRateLimit (handler) { return async (req, env, ctx) => { @@ -33,20 +34,14 @@ export function withRateLimit (handler) { const isRateLimitExceeded = await rateLimitService.check(dataCid, req) if (isRateLimitExceeded === RATE_LIMIT_EXCEEDED.YES) { throw new HttpError('Too Many Requests', { status: 429 }) - } else { - const accounting = Accounting.create({ - serviceURL: env.ACCOUNTING_SERVICE_URL - }) - // NOTE: non-blocking call to the accounting service - ctx.waitUntil(accounting.record(dataCid, req)) - return handler(req, env, ctx) } + return handler(req, env, ctx) } } /** * @param {Environment} env - * @param {Context} ctx + * @param {RateLimiterContext} ctx * @returns {RateLimitService} */ function create (env, ctx) { @@ -105,7 +100,7 @@ async function isRateLimited (rateLimitAPI, cid) { /** * @param {Environment} env * @param {string} authToken - * @param {Context} ctx + * @param {RateLimiterContext} ctx * @returns {Promise} */ async function getTokenMetadata (env, authToken, ctx) { @@ -116,9 +111,7 @@ async function getTokenMetadata (env, authToken, ctx) { return decode(cachedValue) } - const accounting = Accounting.create({ - serviceURL: env.ACCOUNTING_SERVICE_URL - }) + const accounting = ctx.ACCOUNTING_SERVICE ?? Accounting.create({ serviceURL: env.ACCOUNTING_SERVICE_URL }) const tokenMetadata = await accounting.getTokenMetadata(authToken) if (tokenMetadata) { // NOTE: non-blocking call to the auth token metadata cache diff --git a/src/services/accounting.js b/src/services/accounting.js index 6bcc289..4fe2462 100644 --- a/src/services/accounting.js +++ b/src/services/accounting.js @@ -3,8 +3,8 @@ */ export const Accounting = { create: ({ serviceURL }) => ({ - record: async (cid, options) => { - console.log(`using ${serviceURL} to record a GET for ${cid} with options`, options) + record: async (cid, bytes, servedAt) => { + console.log(`using ${serviceURL} to record egress for ${cid} with total bytes: ${bytes} and servedAt: ${servedAt}`) }, getTokenMetadata: async () => { diff --git a/test/fixtures/worker-fixture.js b/test/fixtures/worker-fixture.js index 8c7eeaa..a4a75de 100644 --- a/test/fixtures/worker-fixture.js +++ b/test/fixtures/worker-fixture.js @@ -12,6 +12,8 @@ const __dirname = path.dirname(__filename) */ const wranglerEnv = process.env.WRANGLER_ENV || 'integration' +const DEBUG = process.env.DEBUG === 'true' || false + /** * Worker information object * @typedef {Object} WorkerInfo @@ -41,7 +43,7 @@ export const mochaGlobalSetup = async () => { ) console.log(`Output: ${await workerInfo.getOutput()}`) console.log('WorkerInfo:', workerInfo) - console.log('Test worker started!') + console.log(`Test worker started! ENV: ${wranglerEnv}, DEBUG: ${DEBUG}`) } catch (error) { console.error('Failed to start test worker:', error) throw error @@ -59,7 +61,9 @@ export const mochaGlobalTeardown = async () => { try { const { stop } = workerInfo await stop?.() - // console.log('getOutput', getOutput()) // uncomment for debugging + if (DEBUG) { + console.log('getOutput', await workerInfo.getOutput()) + } console.log('Test worker stopped!') } catch (error) { console.error('Failed to stop test worker:', error) diff --git a/test/unit/middleware/withEgressTracker.spec.js b/test/unit/middleware/withEgressTracker.spec.js new file mode 100644 index 0000000..2df3ccc --- /dev/null +++ b/test/unit/middleware/withEgressTracker.spec.js @@ -0,0 +1,462 @@ +/* eslint-disable no-unused-expressions + --- + `no-unused-expressions` doesn't understand that several of Chai's assertions + are implemented as getters rather than explicit function calls; it thinks + the assertions are unused expressions. */ +import { randomBytes } from 'node:crypto' +import { describe, it, afterEach, before } from 'mocha' +import { assert, expect } from 'chai' +import sinon from 'sinon' +import { CID } from 'multiformats' +import { withEgressHandler } from '../../../src/handlers/egress-tracker.js' +import { Builder, toBlobKey } from '../../helpers/builder.js' +import { CARReaderStream } from 'carstream' + +/** + * Creates a request with an optional authorization header. + * + * @param {Object} [options] + * @param {string} [options.authorization] The value for the `Authorization` + * header, if any. + */ +const createRequest = async ({ authorization } = {}) => + new Request('http://doesnt-matter.com/', { + headers: new Headers( + authorization ? { Authorization: authorization } : {} + ) + }) + +const env = + /** @satisfies {import('../../../src/handlers/egress-tracker.types.js').Environment} */ + ({ + DEBUG: 'true', + ACCOUNTING_SERVICE_URL: 'http://example.com', + FF_EGRESS_TRACKER_ENABLED: 'true' + }) + +const accountingRecordMethodStub = sinon.stub() + // @ts-expect-error + .returns(async (cid, bytes, servedAt) => { + console.log(`[mock] record called with cid: ${cid}, bytes: ${bytes}, servedAt: ${servedAt}`) + return Promise.resolve() + }) + +/** + * Mock implementation of the AccountingService. + * + * @param {Object} options + * @param {string} options.serviceURL - The URL of the accounting service. + * @returns {import('../../../src/bindings.js').AccountingService} + */ +const AccountingService = ({ serviceURL }) => { + console.log(`[mock] Accounting.create called with serviceURL: ${serviceURL}`) + + return { + record: accountingRecordMethodStub, + getTokenMetadata: sinon.stub().resolves(undefined) + } +} + +const ctx = + /** @satisfies {import('../../../src/handlers/egress-tracker.js').EgressTrackerContext} */ + ({ + dataCid: CID.parse('bafybeibv7vzycdcnydl5n5lbws6lul2omkm6a6b5wmqt77sicrwnhesy7y'), + waitUntil: sinon.stub().returns(undefined), + path: '', + searchParams: new URLSearchParams(), + ACCOUNTING_SERVICE: AccountingService({ serviceURL: env.ACCOUNTING_SERVICE_URL }) + }) + +describe('withEgressTracker', async () => { + /** @type {Builder} */ + let builder + /** @type {Map} */ + let bucketData + /** @type {{ put: (digest: string, bytes: Uint8Array) => Promise, get: (digest: string) => Promise }} */ + let bucket + + before(async () => { + bucketData = new Map() + bucket = { + put: async (/** @type {string} */ digest, /** @type {Uint8Array} */ bytes) => { + console.log(`[mock] bucket.put called with digest: ${digest}, bytes: ${bytes.byteLength}`) + bucketData.set(digest, bytes) + return Promise.resolve() + }, + // @ts-expect-error - don't need to check the type of the fake bucket + get: async (/** @type {string} */ blobKey) => { + console.log(`[mock] bucket.get called with digest: ${blobKey}`) + return Promise.resolve(bucketData.get(blobKey)) + } + } + builder = new Builder(bucket) + }) + + afterEach(() => { + accountingRecordMethodStub.reset() + bucketData.clear() + }) + + describe('withEgressTracker -> Successful Requests', () => { + it('should track egress bytes for a successful request', async () => { + const content = new TextEncoder().encode('Hello, world!') + const totalBytes = Buffer.byteLength(content) + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + const response = await handler(request, env, ctx) + // Ensure the response body is fully consumed + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Hello, world!') + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][0], 'first argument should be the cid').to.equal(ctx.dataCid) + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }).timeout(10_000) + + it('should record egress for a large file', async () => { + const largeContent = new Uint8Array(100 * 1024 * 1024) // 100 MB + const totalBytes = largeContent.byteLength + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(largeContent) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + await response.text() // Consume the response body + + expect(response.status).to.equal(200) + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][0], 'first argument should be the cid').to.equal(ctx.dataCid) + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }) + + it('should correctly track egress for responses with chunked transfer encoding', async () => { + const chunk1 = new TextEncoder().encode('Hello, ') + const chunk2 = new TextEncoder().encode('world!') + const totalBytes = Buffer.byteLength(chunk1) + Buffer.byteLength(chunk2) + + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(chunk1) + controller.enqueue(chunk2) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Hello, world!') + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }) + + it('should record egress bytes for a CAR file request', async () => { + // Simulate a CAR file content + const carContent = new Blob([randomBytes(256)]) + const { shards } = await builder.add(carContent) + assert.equal(shards.length, 1) + + const key = toBlobKey(shards[0].multihash) + /** @type {Uint8Array | undefined} */ + const carBytes = await bucket.get(key) + expect(carBytes).to.be.not.undefined + expect(carBytes).to.be.instanceOf(Uint8Array) + const expectedTotalBytes = carBytes.byteLength + + // Mock a response with the CAR file content + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(carBytes) + controller.close() + } + }), { + status: 200, + headers: { 'Content-Type': 'application/vnd.ipld.car; version=1;' } + }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + expect(response.status).to.equal(200) + + // Consume the response body by reading the CAR file + const source = /** @type {ReadableStream} */ (await response.body) + + /** @type {(import('carstream').Block & import('carstream').Position)[]} */ + const blocks = [] + await source + .pipeThrough(new CARReaderStream()) + .pipeTo(new WritableStream({ + write: (block) => { blocks.push(block) } + })) + + // expect(blocks[0].bytes).to.deep.equal(carBytes) - FIXME (fforbeck): how to get the correct byte count? + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(expectedTotalBytes) + }) + + it('should correctly track egress for delayed responses', async () => { + const content = new TextEncoder().encode('Delayed response content') + const totalBytes = Buffer.byteLength(content) + + const mockResponse = new Response(new ReadableStream({ + start (controller) { + setTimeout(() => { + controller.enqueue(content) + controller.close() + }, 2000) // Simulate a delay of 2 seconds + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Delayed response content') + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }).timeout(5000) + }) + + describe('withEgressTracker -> Feature Flag', () => { + it('should not track egress if the feature flag is disabled', async () => { + const innerHandler = sinon.stub().resolves(new Response(null, { status: 200 })) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + const envDisabled = { ...env, FF_EGRESS_TRACKER_ENABLED: 'false' } + + const response = await handler(request, envDisabled, ctx) + + expect(response.status).to.equal(200) + expect(accountingRecordMethodStub.notCalled, 'record should not be called').to.be.true + }) + }) + + describe('withEgressTracker -> Non-OK Responses', () => { + it('should not track egress for non-OK responses', async () => { + const mockResponse = new Response(null, { status: 404 }) + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + + expect(response.status).to.equal(404) + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + + it('should not track egress if the response has no body', async () => { + const mockResponse = new Response(null, { status: 200 }) + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + + expect(response.status).to.equal(200) + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + }) + + describe('withEgressTracker -> Concurrent Requests', () => { + it('should correctly track egress for multiple concurrent requests', async () => { + const content1 = new TextEncoder().encode('Hello, world!') + const content2 = new TextEncoder().encode('Goodbye, world!') + const totalBytes1 = Buffer.byteLength(content1) + const totalBytes2 = Buffer.byteLength(content2) + + const mockResponse1 = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content1) + controller.close() + } + }), { status: 200 }) + + const mockResponse2 = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content2) + controller.close() + } + }), { status: 200 }) + + const innerHandler1 = sinon.stub().resolves(mockResponse1) + const innerHandler2 = sinon.stub().resolves(mockResponse2) + + const handler1 = withEgressHandler(innerHandler1) + const handler2 = withEgressHandler(innerHandler2) + + const request1 = await createRequest() + const request2 = await createRequest() + + const [response1, response2] = await Promise.all([ + handler1(request1, env, ctx), + handler2(request2, env, ctx) + ]) + + const responseBody1 = await response1.text() + const responseBody2 = await response2.text() + + expect(response1.status).to.equal(200) + expect(responseBody1).to.equal('Hello, world!') + expect(response2.status).to.equal(200) + expect(responseBody2).to.equal('Goodbye, world!') + + expect(accountingRecordMethodStub.calledTwice, 'record should be called twice').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes for first request').to.equal(totalBytes1) + expect(accountingRecordMethodStub.args[1][1], 'second argument should be the total bytes for second request').to.equal(totalBytes2) + }).timeout(10_000) + }) + + describe('withEgressTracker -> Different Content Types', () => { + it('should track egress for JSON content type', async () => { + const jsonContent = JSON.stringify({ message: 'Hello, JSON!' }) + const totalBytes = Buffer.byteLength(jsonContent) + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(new TextEncoder().encode(jsonContent)) + controller.close() + } + }), { status: 200, headers: { 'Content-Type': 'application/json' } }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.json() + + expect(response.status).to.equal(200) + expect(responseBody).to.deep.equal({ message: 'Hello, JSON!' }) + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }).timeout(10_000) + }) + + describe('withEgressTracker -> Zero-byte Responses', () => { + it('should not record egress for zero-byte responses', async () => { + const mockResponse = new Response(new ReadableStream({ + start (controller) { + // Do not enqueue any data, simulating a zero-byte response + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('') + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + }) + + describe('withEgressTracker -> Interrupted Connection', () => { + it('should not record egress if there is a stream error while downloading', async () => { + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.error(new Error('Stream error')) + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + + try { + // Consume the response body to trigger the error + await response.text() + expect.fail('Expected a stream error') + } catch (/** @type {any} */ error) { + expect(error.message).to.equal('Stream error') + } + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + + it('should not record egress if the connection is interrupted during a large file download', async () => { + const largeContent = new Uint8Array(100 * 1024 * 1024) // 100 MB + const mockResponse = new Response(new ReadableStream({ + start (controller) { + // Stream a portion of the content + controller.enqueue(largeContent.subarray(0, 10 * 1024 * 1024)) // 10 MB + // Simulate connection interruption by raising an error + controller.error(new Error('Connection interrupted')) + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + const response = await handler(request, env, ctx) + + try { + // Consume the response body to trigger the error + await response.text() + expect.fail('Expected a connection interrupted error') + } catch (/** @type {any} */ error) { + expect(error.message).to.equal('Connection interrupted') + } + + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }).timeout(10_000) + }) + + describe('withEgressTracker -> Accounting Service', () => { + it('should log an error and continue serving the response if the accounting service fails', async () => { + const content = new TextEncoder().encode('Hello, world!') + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressHandler(innerHandler) + const request = await createRequest() + + // Simulate an error in the accounting service record method + ctx.ACCOUNTING_SERVICE.record = sinon.stub().rejects(new Error('Accounting service error')) + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Hello, world!') + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + }) +}) diff --git a/wrangler.toml b/wrangler.toml index 2f30a96..7b2918a 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -43,6 +43,7 @@ command = "npm run build" [env.production.vars] MAX_SHARDS = "825" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" CONTENT_CLAIMS_SERVICE_URL = "https://claims.web3.storage" # Staging! @@ -59,6 +60,7 @@ command = "npm run build" [env.staging.vars] MAX_SHARDS = "825" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" CONTENT_CLAIMS_SERVICE_URL = "https://staging.claims.web3.storage" # Test! @@ -71,6 +73,7 @@ r2_buckets = [ [env.test.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" MAX_SHARDS = "120" CONTENT_CLAIMS_SERVICE_URL = "https://test.claims.web3.storage" @@ -84,6 +87,7 @@ r2_buckets = [ [env.alanshaw.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" CONTENT_CLAIMS_SERVICE_URL = "https://dev.claims.web3.storage" [env.fforbeck] @@ -100,7 +104,8 @@ r2_buckets = [ [env.fforbeck.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "false" -CONTENT_CLAIMS_SERVICE_URL = "https://dev.claims.web3.storage" +FF_EGRESS_TRACKER_ENABLED = "false" +CONTENT_CLAIMS_SERVICE_URL = "https://staging.claims.web3.storage" [[env.fforbeck.unsafe.bindings]] name = "RATE_LIMITER" @@ -125,7 +130,9 @@ r2_buckets = [ [env.integration.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "true" +FF_EGRESS_TRACKER_ENABLED = "true" CONTENT_CLAIMS_SERVICE_URL = "https://staging.claims.web3.storage" +ACCOUNTING_SERVICE_URL = "https://example.com/service" [[env.integration.unsafe.bindings]] name = "RATE_LIMITER" From 25982f4fd07e7345f921cfdd9a257eab3cf79c2c Mon Sep 17 00:00:00 2001 From: Felipe Forbeck Date: Tue, 29 Oct 2024 16:38:00 -0300 Subject: [PATCH 2/2] reviewer suggestions implemented --- src/middleware/index.js | 2 +- src/middleware/withEgressTracker.js | 67 ++++++------------- test/fixtures/worker-fixture.js | 2 +- .../unit/middleware/withEgressTracker.spec.js | 46 ++++++------- 4 files changed, 46 insertions(+), 71 deletions(-) diff --git a/src/middleware/index.js b/src/middleware/index.js index d2a7613..6fe437e 100644 --- a/src/middleware/index.js +++ b/src/middleware/index.js @@ -5,4 +5,4 @@ export { withRateLimit } from './withRateLimit.js' export { withVersionHeader } from './withVersionHeader.js' export { withNotFound } from './withNotFound.js' export { withLocator } from './withLocator.js' -export { withEgressTracker } from './withEgressTracker.js' \ No newline at end of file +export { withEgressTracker } from './withEgressTracker.js' diff --git a/src/middleware/withEgressTracker.js b/src/middleware/withEgressTracker.js index 56d688a..125f5df 100644 --- a/src/middleware/withEgressTracker.js +++ b/src/middleware/withEgressTracker.js @@ -21,14 +21,7 @@ export function withEgressTracker (handler) { return handler(req, env, ctx) } - let response - try { - response = await handler(req, env, ctx) - } catch (error) { - console.error('Error in egress tracker handler:', error) - throw error - } - + const response = await handler(req, env, ctx) if (!response.ok || !response.body) { return response } @@ -38,17 +31,18 @@ export function withEgressTracker (handler) { serviceURL: env.ACCOUNTING_SERVICE_URL }) - const { readable, writable } = createEgressPassThroughStream(ctx, accounting, dataCid) - - try { - ctx.waitUntil(response.body.pipeTo(writable)) - } catch (error) { - console.error('Error in egress tracker handler:', error) - // Original response in case of an error to avoid breaking the chain and serve the content - return response - } + const responseBody = response.body.pipeThrough( + createByteCountStream((totalBytesServed) => { + // Non-blocking call to the accounting service to record egress + if (totalBytesServed > 0) { + ctx.waitUntil( + accounting.record(dataCid, totalBytesServed, new Date().toISOString()) + ) + } + }) + ) - return new Response(readable, { + return new Response(responseBody, { status: response.status, statusText: response.statusText, headers: response.headers @@ -60,36 +54,26 @@ export function withEgressTracker (handler) { * Creates a TransformStream to count bytes served to the client. * It records egress when the stream is finalized without an error. * - * @param {import('@web3-storage/gateway-lib/middleware').Context} ctx - The context object. - * @param {AccountingService} accounting - The accounting service instance to record egress. - * @param {import('@web3-storage/gateway-lib/handlers').CID} dataCid - The CID of the served content. - * @returns {TransformStream} - The created TransformStream. + * @param {(totalBytesServed: number) => void} onClose + * @template {Uint8Array} T + * @returns {TransformStream} - The created TransformStream. */ -function createEgressPassThroughStream (ctx, accounting, dataCid) { +function createByteCountStream (onClose) { let totalBytesServed = 0 return new TransformStream({ - /** - * The start function is called when the stream is being initialized. - * It resets the total bytes served to 0. - */ - start () { - totalBytesServed = 0 - }, /** * The transform function is called for each chunk of the response body. * It enqueues the chunk and updates the total bytes served. * If an error occurs, it signals an error to the controller and logs it. * The bytes are not counted in case of enqueuing an error. - * @param {Uint8Array} chunk - * @param {TransformStreamDefaultController} controller */ async transform (chunk, controller) { try { controller.enqueue(chunk) totalBytesServed += chunk.byteLength } catch (error) { - console.error('Error while counting egress bytes:', error) + console.error('Error while counting bytes:', error) controller.error(error) } }, @@ -97,21 +81,12 @@ function createEgressPassThroughStream (ctx, accounting, dataCid) { /** * The flush function is called when the stream is being finalized, * which is when the response is being sent to the client. - * So before the response is sent, we record the egress. - * It is called only once and it triggers a non-blocking call to the accounting service. + * So before the response is sent, we record the egress using the callback. * If an error occurs, the egress is not recorded. - * NOTE: The flush function is NOT called in case of an stream error. + * NOTE: The flush function is NOT called in case of a stream error. */ - async flush (controller) { - try { - // Non-blocking call to the accounting service to record egress - if (totalBytesServed > 0) { - ctx.waitUntil(accounting.record(dataCid, totalBytesServed, new Date().toISOString())) - } - } catch (error) { - console.error('Error while recording egress:', error) - controller.error(error) - } + async flush () { + onClose(totalBytesServed) } }) } diff --git a/test/fixtures/worker-fixture.js b/test/fixtures/worker-fixture.js index a4a75de..c5e9632 100644 --- a/test/fixtures/worker-fixture.js +++ b/test/fixtures/worker-fixture.js @@ -12,7 +12,7 @@ const __dirname = path.dirname(__filename) */ const wranglerEnv = process.env.WRANGLER_ENV || 'integration' -const DEBUG = process.env.DEBUG === 'true' || false +const DEBUG = process.env.DEBUG === 'true' /** * Worker information object diff --git a/test/unit/middleware/withEgressTracker.spec.js b/test/unit/middleware/withEgressTracker.spec.js index 2df3ccc..9e6dcbf 100644 --- a/test/unit/middleware/withEgressTracker.spec.js +++ b/test/unit/middleware/withEgressTracker.spec.js @@ -8,7 +8,7 @@ import { describe, it, afterEach, before } from 'mocha' import { assert, expect } from 'chai' import sinon from 'sinon' import { CID } from 'multiformats' -import { withEgressHandler } from '../../../src/handlers/egress-tracker.js' +import { withEgressTracker } from '../../../src/middleware/withEgressTracker.js' import { Builder, toBlobKey } from '../../helpers/builder.js' import { CARReaderStream } from 'carstream' @@ -27,7 +27,7 @@ const createRequest = async ({ authorization } = {}) => }) const env = - /** @satisfies {import('../../../src/handlers/egress-tracker.types.js').Environment} */ + /** @satisfies {import('../../../src/middleware/withEgressTracker.types.js').Environment} */ ({ DEBUG: 'true', ACCOUNTING_SERVICE_URL: 'http://example.com', @@ -35,11 +35,11 @@ const env = }) const accountingRecordMethodStub = sinon.stub() - // @ts-expect-error - .returns(async (cid, bytes, servedAt) => { - console.log(`[mock] record called with cid: ${cid}, bytes: ${bytes}, servedAt: ${servedAt}`) - return Promise.resolve() - }) + .returns( + /** @type {import('../../../src/bindings.js').AccountingService['record']} */ + async (cid, bytes, servedAt) => { + console.log(`[mock] record called with cid: ${cid}, bytes: ${bytes}, servedAt: ${servedAt}`) + }) /** * Mock implementation of the AccountingService. @@ -58,7 +58,7 @@ const AccountingService = ({ serviceURL }) => { } const ctx = - /** @satisfies {import('../../../src/handlers/egress-tracker.js').EgressTrackerContext} */ + /** @satisfies {import('../../../src/middleware/withEgressTracker.js').EgressTrackerContext} */ ({ dataCid: CID.parse('bafybeibv7vzycdcnydl5n5lbws6lul2omkm6a6b5wmqt77sicrwnhesy7y'), waitUntil: sinon.stub().returns(undefined), @@ -110,7 +110,7 @@ describe('withEgressTracker', async () => { const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) // Ensure the response body is fully consumed @@ -134,7 +134,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -160,7 +160,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -197,7 +197,7 @@ describe('withEgressTracker', async () => { }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -233,7 +233,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -249,7 +249,7 @@ describe('withEgressTracker', async () => { describe('withEgressTracker -> Feature Flag', () => { it('should not track egress if the feature flag is disabled', async () => { const innerHandler = sinon.stub().resolves(new Response(null, { status: 200 })) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const envDisabled = { ...env, FF_EGRESS_TRACKER_ENABLED: 'false' } @@ -264,7 +264,7 @@ describe('withEgressTracker', async () => { it('should not track egress for non-OK responses', async () => { const mockResponse = new Response(null, { status: 404 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -276,7 +276,7 @@ describe('withEgressTracker', async () => { it('should not track egress if the response has no body', async () => { const mockResponse = new Response(null, { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -310,8 +310,8 @@ describe('withEgressTracker', async () => { const innerHandler1 = sinon.stub().resolves(mockResponse1) const innerHandler2 = sinon.stub().resolves(mockResponse2) - const handler1 = withEgressHandler(innerHandler1) - const handler2 = withEgressHandler(innerHandler2) + const handler1 = withEgressTracker(innerHandler1) + const handler2 = withEgressTracker(innerHandler2) const request1 = await createRequest() const request2 = await createRequest() @@ -347,7 +347,7 @@ describe('withEgressTracker', async () => { }), { status: 200, headers: { 'Content-Type': 'application/json' } }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -370,7 +370,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -391,7 +391,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -418,7 +418,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() const response = await handler(request, env, ctx) @@ -445,7 +445,7 @@ describe('withEgressTracker', async () => { }), { status: 200 }) const innerHandler = sinon.stub().resolves(mockResponse) - const handler = withEgressHandler(innerHandler) + const handler = withEgressTracker(innerHandler) const request = await createRequest() // Simulate an error in the accounting service record method