diff --git a/src/bindings.d.ts b/src/bindings.d.ts index 3003676..929045e 100644 --- a/src/bindings.d.ts +++ b/src/bindings.d.ts @@ -2,19 +2,23 @@ import { CID } from '@web3-storage/gateway-lib/handlers' import { Environment as RateLimiterEnvironment } from './middleware/withRateLimit.types.ts' import { Environment as CarBlockEnvironment } from './middleware/withCarBlockHandler.types.ts' import { Environment as ContentClaimsDagulaEnvironment } from './middleware/withCarBlockHandler.types.ts' - +import { Environment as EgressTrackerEnvironment } from './middleware/withEgressTracker.types.ts' +import { UnknownLink } from 'multiformats' export interface Environment extends CarBlockEnvironment, RateLimiterEnvironment, - ContentClaimsDagulaEnvironment { + ContentClaimsDagulaEnvironment, + EgressTrackerEnvironment { VERSION: string + CONTENT_CLAIMS_SERVICE_URL?: string + ACCOUNTING_SERVICE_URL: string } export interface AccountingService { - record: (cid: CID, options: GetCIDRequestConfig) => Promise + record: (resource: UnknownLink, bytes: number, servedAt: string) => Promise getTokenMetadata: (token: string) => Promise } export interface Accounting { - create: ({ serviceURL }: { serviceURL?: string }) => AccountingService + create: ({ serviceURL }: { serviceURL: string }) => AccountingService } diff --git a/src/index.js b/src/index.js index c054127..ba074e5 100644 --- a/src/index.js +++ b/src/index.js @@ -22,7 +22,8 @@ import { withCarBlockHandler, withRateLimit, withNotFound, - withLocator + withLocator, + withEgressTracker } from './middleware/index.js' /** @@ -57,6 +58,9 @@ export default { // Rate-limit requests withRateLimit, + // Track egress bytes + withEgressTracker, + // Fetch data withCarBlockHandler, withNotFound, diff --git a/src/middleware/index.js b/src/middleware/index.js index ed13c8e..6fe437e 100644 --- a/src/middleware/index.js +++ b/src/middleware/index.js @@ -5,3 +5,4 @@ export { withRateLimit } from './withRateLimit.js' export { withVersionHeader } from './withVersionHeader.js' export { withNotFound } from './withNotFound.js' export { withLocator } from './withLocator.js' +export { withEgressTracker } from './withEgressTracker.js' diff --git a/src/middleware/withEgressTracker.js b/src/middleware/withEgressTracker.js new file mode 100644 index 0000000..125f5df --- /dev/null +++ b/src/middleware/withEgressTracker.js @@ -0,0 +1,92 @@ +import { Accounting } from '../services/accounting.js' + +/** + * @import { Context, IpfsUrlContext, Middleware } from '@web3-storage/gateway-lib' + * @import { Environment } from './withEgressTracker.types.js' + * @import { AccountingService } from '../bindings.js' + * @typedef {IpfsUrlContext & { ACCOUNTING_SERVICE?: AccountingService }} EgressTrackerContext + */ + +/** + * The egress tracking handler must be enabled after the rate limiting handler, + * and before any handler that serves the response body. It uses the CID of the + * served content to record the egress in the accounting service, and it counts + * the bytes served with a TransformStream to determine the egress amount. + * + * @type {Middleware} + */ +export function withEgressTracker (handler) { + return async (req, env, ctx) => { + if (env.FF_EGRESS_TRACKER_ENABLED !== 'true') { + return handler(req, env, ctx) + } + + const response = await handler(req, env, ctx) + if (!response.ok || !response.body) { + return response + } + + const { dataCid } = ctx + const accounting = ctx.ACCOUNTING_SERVICE ?? Accounting.create({ + serviceURL: env.ACCOUNTING_SERVICE_URL + }) + + const responseBody = response.body.pipeThrough( + createByteCountStream((totalBytesServed) => { + // Non-blocking call to the accounting service to record egress + if (totalBytesServed > 0) { + ctx.waitUntil( + accounting.record(dataCid, totalBytesServed, new Date().toISOString()) + ) + } + }) + ) + + return new Response(responseBody, { + status: response.status, + statusText: response.statusText, + headers: response.headers + }) + } +} + +/** + * Creates a TransformStream to count bytes served to the client. + * It records egress when the stream is finalized without an error. + * + * @param {(totalBytesServed: number) => void} onClose + * @template {Uint8Array} T + * @returns {TransformStream} - The created TransformStream. + */ +function createByteCountStream (onClose) { + let totalBytesServed = 0 + + return new TransformStream({ + /** + * The transform function is called for each chunk of the response body. + * It enqueues the chunk and updates the total bytes served. + * If an error occurs, it signals an error to the controller and logs it. + * The bytes are not counted in case of enqueuing an error. + */ + async transform (chunk, controller) { + try { + controller.enqueue(chunk) + totalBytesServed += chunk.byteLength + } catch (error) { + console.error('Error while counting bytes:', error) + controller.error(error) + } + }, + + /** + * The flush function is called when the stream is being finalized, + * which is when the response is being sent to the client. + * So before the response is sent, we record the egress using the callback. + * If an error occurs, the egress is not recorded. + * NOTE: The flush function is NOT called in case of a stream error. + */ + async flush () { + onClose(totalBytesServed) + } + }) +} diff --git a/src/middleware/withEgressTracker.types.ts b/src/middleware/withEgressTracker.types.ts new file mode 100644 index 0000000..d764548 --- /dev/null +++ b/src/middleware/withEgressTracker.types.ts @@ -0,0 +1,6 @@ +import { Environment as MiddlewareEnvironment } from '@web3-storage/gateway-lib' + +export interface Environment extends MiddlewareEnvironment { + ACCOUNTING_SERVICE_URL: string + FF_EGRESS_TRACKER_ENABLED: string +} diff --git a/src/middleware/withRateLimit.js b/src/middleware/withRateLimit.js index 3e5e4c7..5487dcb 100644 --- a/src/middleware/withRateLimit.js +++ b/src/middleware/withRateLimit.js @@ -12,6 +12,7 @@ import { Accounting } from '../services/accounting.js' * RateLimitService, * RateLimitExceeded * } from './withRateLimit.types.js' + * @typedef {Context & { ACCOUNTING_SERVICE?: import('../bindings.js').AccountingService }} RateLimiterContext */ /** @@ -20,7 +21,7 @@ import { Accounting } from '../services/accounting.js' * it can be enabled or disabled using the FF_RATE_LIMITER_ENABLED flag. * Every successful request is recorded in the accounting service. * - * @type {Middleware} + * @type {Middleware} */ export function withRateLimit (handler) { return async (req, env, ctx) => { @@ -33,20 +34,14 @@ export function withRateLimit (handler) { const isRateLimitExceeded = await rateLimitService.check(dataCid, req) if (isRateLimitExceeded === RATE_LIMIT_EXCEEDED.YES) { throw new HttpError('Too Many Requests', { status: 429 }) - } else { - const accounting = Accounting.create({ - serviceURL: env.ACCOUNTING_SERVICE_URL - }) - // NOTE: non-blocking call to the accounting service - ctx.waitUntil(accounting.record(dataCid, req)) - return handler(req, env, ctx) } + return handler(req, env, ctx) } } /** * @param {Environment} env - * @param {Context} ctx + * @param {RateLimiterContext} ctx * @returns {RateLimitService} */ function create (env, ctx) { @@ -105,7 +100,7 @@ async function isRateLimited (rateLimitAPI, cid) { /** * @param {Environment} env * @param {string} authToken - * @param {Context} ctx + * @param {RateLimiterContext} ctx * @returns {Promise} */ async function getTokenMetadata (env, authToken, ctx) { @@ -116,9 +111,7 @@ async function getTokenMetadata (env, authToken, ctx) { return decode(cachedValue) } - const accounting = Accounting.create({ - serviceURL: env.ACCOUNTING_SERVICE_URL - }) + const accounting = ctx.ACCOUNTING_SERVICE ?? Accounting.create({ serviceURL: env.ACCOUNTING_SERVICE_URL }) const tokenMetadata = await accounting.getTokenMetadata(authToken) if (tokenMetadata) { // NOTE: non-blocking call to the auth token metadata cache diff --git a/src/services/accounting.js b/src/services/accounting.js index 6bcc289..4fe2462 100644 --- a/src/services/accounting.js +++ b/src/services/accounting.js @@ -3,8 +3,8 @@ */ export const Accounting = { create: ({ serviceURL }) => ({ - record: async (cid, options) => { - console.log(`using ${serviceURL} to record a GET for ${cid} with options`, options) + record: async (cid, bytes, servedAt) => { + console.log(`using ${serviceURL} to record egress for ${cid} with total bytes: ${bytes} and servedAt: ${servedAt}`) }, getTokenMetadata: async () => { diff --git a/test/fixtures/worker-fixture.js b/test/fixtures/worker-fixture.js index 8c7eeaa..c5e9632 100644 --- a/test/fixtures/worker-fixture.js +++ b/test/fixtures/worker-fixture.js @@ -12,6 +12,8 @@ const __dirname = path.dirname(__filename) */ const wranglerEnv = process.env.WRANGLER_ENV || 'integration' +const DEBUG = process.env.DEBUG === 'true' + /** * Worker information object * @typedef {Object} WorkerInfo @@ -41,7 +43,7 @@ export const mochaGlobalSetup = async () => { ) console.log(`Output: ${await workerInfo.getOutput()}`) console.log('WorkerInfo:', workerInfo) - console.log('Test worker started!') + console.log(`Test worker started! ENV: ${wranglerEnv}, DEBUG: ${DEBUG}`) } catch (error) { console.error('Failed to start test worker:', error) throw error @@ -59,7 +61,9 @@ export const mochaGlobalTeardown = async () => { try { const { stop } = workerInfo await stop?.() - // console.log('getOutput', getOutput()) // uncomment for debugging + if (DEBUG) { + console.log('getOutput', await workerInfo.getOutput()) + } console.log('Test worker stopped!') } catch (error) { console.error('Failed to stop test worker:', error) diff --git a/test/unit/middleware/withEgressTracker.spec.js b/test/unit/middleware/withEgressTracker.spec.js new file mode 100644 index 0000000..9e6dcbf --- /dev/null +++ b/test/unit/middleware/withEgressTracker.spec.js @@ -0,0 +1,462 @@ +/* eslint-disable no-unused-expressions + --- + `no-unused-expressions` doesn't understand that several of Chai's assertions + are implemented as getters rather than explicit function calls; it thinks + the assertions are unused expressions. */ +import { randomBytes } from 'node:crypto' +import { describe, it, afterEach, before } from 'mocha' +import { assert, expect } from 'chai' +import sinon from 'sinon' +import { CID } from 'multiformats' +import { withEgressTracker } from '../../../src/middleware/withEgressTracker.js' +import { Builder, toBlobKey } from '../../helpers/builder.js' +import { CARReaderStream } from 'carstream' + +/** + * Creates a request with an optional authorization header. + * + * @param {Object} [options] + * @param {string} [options.authorization] The value for the `Authorization` + * header, if any. + */ +const createRequest = async ({ authorization } = {}) => + new Request('http://doesnt-matter.com/', { + headers: new Headers( + authorization ? { Authorization: authorization } : {} + ) + }) + +const env = + /** @satisfies {import('../../../src/middleware/withEgressTracker.types.js').Environment} */ + ({ + DEBUG: 'true', + ACCOUNTING_SERVICE_URL: 'http://example.com', + FF_EGRESS_TRACKER_ENABLED: 'true' + }) + +const accountingRecordMethodStub = sinon.stub() + .returns( + /** @type {import('../../../src/bindings.js').AccountingService['record']} */ + async (cid, bytes, servedAt) => { + console.log(`[mock] record called with cid: ${cid}, bytes: ${bytes}, servedAt: ${servedAt}`) + }) + +/** + * Mock implementation of the AccountingService. + * + * @param {Object} options + * @param {string} options.serviceURL - The URL of the accounting service. + * @returns {import('../../../src/bindings.js').AccountingService} + */ +const AccountingService = ({ serviceURL }) => { + console.log(`[mock] Accounting.create called with serviceURL: ${serviceURL}`) + + return { + record: accountingRecordMethodStub, + getTokenMetadata: sinon.stub().resolves(undefined) + } +} + +const ctx = + /** @satisfies {import('../../../src/middleware/withEgressTracker.js').EgressTrackerContext} */ + ({ + dataCid: CID.parse('bafybeibv7vzycdcnydl5n5lbws6lul2omkm6a6b5wmqt77sicrwnhesy7y'), + waitUntil: sinon.stub().returns(undefined), + path: '', + searchParams: new URLSearchParams(), + ACCOUNTING_SERVICE: AccountingService({ serviceURL: env.ACCOUNTING_SERVICE_URL }) + }) + +describe('withEgressTracker', async () => { + /** @type {Builder} */ + let builder + /** @type {Map} */ + let bucketData + /** @type {{ put: (digest: string, bytes: Uint8Array) => Promise, get: (digest: string) => Promise }} */ + let bucket + + before(async () => { + bucketData = new Map() + bucket = { + put: async (/** @type {string} */ digest, /** @type {Uint8Array} */ bytes) => { + console.log(`[mock] bucket.put called with digest: ${digest}, bytes: ${bytes.byteLength}`) + bucketData.set(digest, bytes) + return Promise.resolve() + }, + // @ts-expect-error - don't need to check the type of the fake bucket + get: async (/** @type {string} */ blobKey) => { + console.log(`[mock] bucket.get called with digest: ${blobKey}`) + return Promise.resolve(bucketData.get(blobKey)) + } + } + builder = new Builder(bucket) + }) + + afterEach(() => { + accountingRecordMethodStub.reset() + bucketData.clear() + }) + + describe('withEgressTracker -> Successful Requests', () => { + it('should track egress bytes for a successful request', async () => { + const content = new TextEncoder().encode('Hello, world!') + const totalBytes = Buffer.byteLength(content) + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + const response = await handler(request, env, ctx) + // Ensure the response body is fully consumed + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Hello, world!') + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][0], 'first argument should be the cid').to.equal(ctx.dataCid) + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }).timeout(10_000) + + it('should record egress for a large file', async () => { + const largeContent = new Uint8Array(100 * 1024 * 1024) // 100 MB + const totalBytes = largeContent.byteLength + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(largeContent) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + await response.text() // Consume the response body + + expect(response.status).to.equal(200) + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][0], 'first argument should be the cid').to.equal(ctx.dataCid) + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }) + + it('should correctly track egress for responses with chunked transfer encoding', async () => { + const chunk1 = new TextEncoder().encode('Hello, ') + const chunk2 = new TextEncoder().encode('world!') + const totalBytes = Buffer.byteLength(chunk1) + Buffer.byteLength(chunk2) + + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(chunk1) + controller.enqueue(chunk2) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Hello, world!') + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }) + + it('should record egress bytes for a CAR file request', async () => { + // Simulate a CAR file content + const carContent = new Blob([randomBytes(256)]) + const { shards } = await builder.add(carContent) + assert.equal(shards.length, 1) + + const key = toBlobKey(shards[0].multihash) + /** @type {Uint8Array | undefined} */ + const carBytes = await bucket.get(key) + expect(carBytes).to.be.not.undefined + expect(carBytes).to.be.instanceOf(Uint8Array) + const expectedTotalBytes = carBytes.byteLength + + // Mock a response with the CAR file content + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(carBytes) + controller.close() + } + }), { + status: 200, + headers: { 'Content-Type': 'application/vnd.ipld.car; version=1;' } + }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + expect(response.status).to.equal(200) + + // Consume the response body by reading the CAR file + const source = /** @type {ReadableStream} */ (await response.body) + + /** @type {(import('carstream').Block & import('carstream').Position)[]} */ + const blocks = [] + await source + .pipeThrough(new CARReaderStream()) + .pipeTo(new WritableStream({ + write: (block) => { blocks.push(block) } + })) + + // expect(blocks[0].bytes).to.deep.equal(carBytes) - FIXME (fforbeck): how to get the correct byte count? + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(expectedTotalBytes) + }) + + it('should correctly track egress for delayed responses', async () => { + const content = new TextEncoder().encode('Delayed response content') + const totalBytes = Buffer.byteLength(content) + + const mockResponse = new Response(new ReadableStream({ + start (controller) { + setTimeout(() => { + controller.enqueue(content) + controller.close() + }, 2000) // Simulate a delay of 2 seconds + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Delayed response content') + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }).timeout(5000) + }) + + describe('withEgressTracker -> Feature Flag', () => { + it('should not track egress if the feature flag is disabled', async () => { + const innerHandler = sinon.stub().resolves(new Response(null, { status: 200 })) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + const envDisabled = { ...env, FF_EGRESS_TRACKER_ENABLED: 'false' } + + const response = await handler(request, envDisabled, ctx) + + expect(response.status).to.equal(200) + expect(accountingRecordMethodStub.notCalled, 'record should not be called').to.be.true + }) + }) + + describe('withEgressTracker -> Non-OK Responses', () => { + it('should not track egress for non-OK responses', async () => { + const mockResponse = new Response(null, { status: 404 }) + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + + expect(response.status).to.equal(404) + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + + it('should not track egress if the response has no body', async () => { + const mockResponse = new Response(null, { status: 200 }) + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + + expect(response.status).to.equal(200) + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + }) + + describe('withEgressTracker -> Concurrent Requests', () => { + it('should correctly track egress for multiple concurrent requests', async () => { + const content1 = new TextEncoder().encode('Hello, world!') + const content2 = new TextEncoder().encode('Goodbye, world!') + const totalBytes1 = Buffer.byteLength(content1) + const totalBytes2 = Buffer.byteLength(content2) + + const mockResponse1 = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content1) + controller.close() + } + }), { status: 200 }) + + const mockResponse2 = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content2) + controller.close() + } + }), { status: 200 }) + + const innerHandler1 = sinon.stub().resolves(mockResponse1) + const innerHandler2 = sinon.stub().resolves(mockResponse2) + + const handler1 = withEgressTracker(innerHandler1) + const handler2 = withEgressTracker(innerHandler2) + + const request1 = await createRequest() + const request2 = await createRequest() + + const [response1, response2] = await Promise.all([ + handler1(request1, env, ctx), + handler2(request2, env, ctx) + ]) + + const responseBody1 = await response1.text() + const responseBody2 = await response2.text() + + expect(response1.status).to.equal(200) + expect(responseBody1).to.equal('Hello, world!') + expect(response2.status).to.equal(200) + expect(responseBody2).to.equal('Goodbye, world!') + + expect(accountingRecordMethodStub.calledTwice, 'record should be called twice').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes for first request').to.equal(totalBytes1) + expect(accountingRecordMethodStub.args[1][1], 'second argument should be the total bytes for second request').to.equal(totalBytes2) + }).timeout(10_000) + }) + + describe('withEgressTracker -> Different Content Types', () => { + it('should track egress for JSON content type', async () => { + const jsonContent = JSON.stringify({ message: 'Hello, JSON!' }) + const totalBytes = Buffer.byteLength(jsonContent) + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(new TextEncoder().encode(jsonContent)) + controller.close() + } + }), { status: 200, headers: { 'Content-Type': 'application/json' } }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.json() + + expect(response.status).to.equal(200) + expect(responseBody).to.deep.equal({ message: 'Hello, JSON!' }) + expect(accountingRecordMethodStub.calledOnce, 'record should be called once').to.be.true + expect(accountingRecordMethodStub.args[0][1], 'second argument should be the total bytes').to.equal(totalBytes) + }).timeout(10_000) + }) + + describe('withEgressTracker -> Zero-byte Responses', () => { + it('should not record egress for zero-byte responses', async () => { + const mockResponse = new Response(new ReadableStream({ + start (controller) { + // Do not enqueue any data, simulating a zero-byte response + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('') + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + }) + + describe('withEgressTracker -> Interrupted Connection', () => { + it('should not record egress if there is a stream error while downloading', async () => { + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.error(new Error('Stream error')) + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + const response = await handler(request, env, ctx) + + try { + // Consume the response body to trigger the error + await response.text() + expect.fail('Expected a stream error') + } catch (/** @type {any} */ error) { + expect(error.message).to.equal('Stream error') + } + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + + it('should not record egress if the connection is interrupted during a large file download', async () => { + const largeContent = new Uint8Array(100 * 1024 * 1024) // 100 MB + const mockResponse = new Response(new ReadableStream({ + start (controller) { + // Stream a portion of the content + controller.enqueue(largeContent.subarray(0, 10 * 1024 * 1024)) // 10 MB + // Simulate connection interruption by raising an error + controller.error(new Error('Connection interrupted')) + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + const response = await handler(request, env, ctx) + + try { + // Consume the response body to trigger the error + await response.text() + expect.fail('Expected a connection interrupted error') + } catch (/** @type {any} */ error) { + expect(error.message).to.equal('Connection interrupted') + } + + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }).timeout(10_000) + }) + + describe('withEgressTracker -> Accounting Service', () => { + it('should log an error and continue serving the response if the accounting service fails', async () => { + const content = new TextEncoder().encode('Hello, world!') + const mockResponse = new Response(new ReadableStream({ + start (controller) { + controller.enqueue(content) + controller.close() + } + }), { status: 200 }) + + const innerHandler = sinon.stub().resolves(mockResponse) + const handler = withEgressTracker(innerHandler) + const request = await createRequest() + + // Simulate an error in the accounting service record method + ctx.ACCOUNTING_SERVICE.record = sinon.stub().rejects(new Error('Accounting service error')) + + const response = await handler(request, env, ctx) + const responseBody = await response.text() + + expect(response.status).to.equal(200) + expect(responseBody).to.equal('Hello, world!') + expect(accountingRecordMethodStub.called, 'record should not be called').to.be.false + }) + }) +}) diff --git a/wrangler.toml b/wrangler.toml index 2f30a96..7b2918a 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -43,6 +43,7 @@ command = "npm run build" [env.production.vars] MAX_SHARDS = "825" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" CONTENT_CLAIMS_SERVICE_URL = "https://claims.web3.storage" # Staging! @@ -59,6 +60,7 @@ command = "npm run build" [env.staging.vars] MAX_SHARDS = "825" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" CONTENT_CLAIMS_SERVICE_URL = "https://staging.claims.web3.storage" # Test! @@ -71,6 +73,7 @@ r2_buckets = [ [env.test.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" MAX_SHARDS = "120" CONTENT_CLAIMS_SERVICE_URL = "https://test.claims.web3.storage" @@ -84,6 +87,7 @@ r2_buckets = [ [env.alanshaw.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "false" +FF_EGRESS_TRACKER_ENABLED = "false" CONTENT_CLAIMS_SERVICE_URL = "https://dev.claims.web3.storage" [env.fforbeck] @@ -100,7 +104,8 @@ r2_buckets = [ [env.fforbeck.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "false" -CONTENT_CLAIMS_SERVICE_URL = "https://dev.claims.web3.storage" +FF_EGRESS_TRACKER_ENABLED = "false" +CONTENT_CLAIMS_SERVICE_URL = "https://staging.claims.web3.storage" [[env.fforbeck.unsafe.bindings]] name = "RATE_LIMITER" @@ -125,7 +130,9 @@ r2_buckets = [ [env.integration.vars] DEBUG = "true" FF_RATE_LIMITER_ENABLED = "true" +FF_EGRESS_TRACKER_ENABLED = "true" CONTENT_CLAIMS_SERVICE_URL = "https://staging.claims.web3.storage" +ACCOUNTING_SERVICE_URL = "https://example.com/service" [[env.integration.unsafe.bindings]] name = "RATE_LIMITER"