Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/src/firebase.dart
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ extension FirebaseX on Firebase {
bool external = false,
String? documentPattern,
String? refPattern,
List<String>? allowedOrigins,
}) {
// Check for duplicate function names
if (functions.any((f) => f.name == name)) {
Expand All @@ -195,6 +196,7 @@ extension FirebaseX on Firebase {
name: transformedName,
handler: handler,
external: external,
allowedOrigins: allowedOrigins,
documentPattern: documentPattern,
refPattern: refPattern,
),
Expand All @@ -214,6 +216,7 @@ final class FirebaseFunctionDeclaration {
required this.name,
required this.handler,
required this.external,
this.allowedOrigins,
this.documentPattern,
this.refPattern,
}) : path = name;
Expand All @@ -238,6 +241,9 @@ final class FirebaseFunctionDeclaration {
/// Event-driven functions are internal (false, POST only).
final bool external;

/// Allowed origins for CORS (if specified).
final List<String>? allowedOrigins;

/// The function handler.
final FirebaseFunctionHandler handler;
}
Expand Down
27 changes: 16 additions & 11 deletions lib/src/https/https_namespace.dart
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,20 @@ class HttpsNamespace extends FunctionsNamespace {
// ignore: experimental_member_use
@mustBeConst HttpsOptions? options = const HttpsOptions(),
}) {
firebase.registerFunction(name, (request) async {
try {
return await handler(request);
} on HttpsError catch (e) {
return e.toShelfResponse();
} catch (e, stackTrace) {
return logInternalError(e, stackTrace).toShelfResponse();
}
}, external: true);
firebase.registerFunction(
name,
(request) async {
try {
return await handler(request);
} on HttpsError catch (e) {
return e.toShelfResponse();
} catch (e, stackTrace) {
return logInternalError(e, stackTrace).toShelfResponse();
}
},
external: true,
allowedOrigins: options?.cors?.runtimeValue(),
);
}

/// Creates an HTTPS callable function (untyped data).
Expand Down Expand Up @@ -141,7 +146,7 @@ class HttpsNamespace extends FunctionsNamespace {
(result) => result.data,
(result) => result.toResponse(),
);
});
}, allowedOrigins: options?.cors?.runtimeValue());
}

/// Creates an HTTPS callable function with typed data.
Expand Down Expand Up @@ -225,7 +230,7 @@ class HttpsNamespace extends FunctionsNamespace {
headers: {'Content-Type': 'application/json'},
),
);
});
}, allowedOrigins: options?.cors?.runtimeValue());
}

/// Internal handler for callable functions.
Expand Down
76 changes: 64 additions & 12 deletions lib/src/server.dart
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,38 @@ const _corsAnyOriginHeaders = {
'Access-Control-Allow-Headers': '*',
};

Response _buildOptionsCorsResponse(
Request request,
List<String> allowedOrigins,
) => Response.ok('', headers: corsHeadersFor(request, allowedOrigins));

Response _applyCorsHeaders(
Request request,
Response response,
List<String> allowedOrigins,
) => response.change(headers: corsHeadersFor(request, allowedOrigins));

@visibleForTesting
Map<String, String> corsHeadersFor(
Request request,
List<String> allowedOrigins,
) {
if (allowedOrigins.contains('*')) {
return _corsAnyOriginHeaders;
}

final origin = request.headers['origin'];
if (origin != null && allowedOrigins.contains(origin)) {
return {
'Access-Control-Allow-Origin': origin,
'Access-Control-Allow-Methods': '*',
'Access-Control-Allow-Headers': '*',
};
}

return const {};
}

/// Routes incoming requests to the appropriate function handler.
FutureOr<Response> _routeRequest(
Request request,
Expand Down Expand Up @@ -144,7 +176,7 @@ FutureOr<Response> _routeToTargetFunction(
Firebase firebase,
FirebaseEnv env,
String functionTarget,
) {
) async {
final functions = firebase.functions;

// Find the function with matching name
Expand All @@ -165,6 +197,11 @@ FutureOr<Response> _routeToTargetFunction(
// from the Node.js model does not apply here.

// Validate HTTP method for event functions
if (request.method.toUpperCase() == 'OPTIONS' &&
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this fixes the validation order too

targetFunction.allowedOrigins != null) {
return _buildOptionsCorsResponse(request, targetFunction.allowedOrigins!);
}

if (!targetFunction.external && request.method.toUpperCase() != 'POST') {
return Response(
405,
Expand All @@ -173,10 +210,12 @@ FutureOr<Response> _routeToTargetFunction(
);
}

// Execute the target function (all requests go to this function)
// Wrap with onInit to ensure initialization callback runs before first execution
final wrappedHandler = withInit(targetFunction.handler);
return wrappedHandler(request);
final response = await wrappedHandler(request);
if (targetFunction.allowedOrigins != null) {
return _applyCorsHeaders(request, response, targetFunction.allowedOrigins!);
}
return response;
}

/// Routes request by path matching (development/shared process mode).
Expand Down Expand Up @@ -215,16 +254,29 @@ FutureOr<Response> _routeByPath(

// Try to find a matching function by name
for (final function in functions) {
// Internal functions (events) only accept POST requests
if (!function.external && currentRequest.method.toUpperCase() != 'POST') {
continue;
}

// Match by function name
if (functionName == function.name) {
// Wrap with onInit to ensure initialization callback runs before first execution
if (currentRequest.method.toUpperCase() == 'OPTIONS' &&
function.allowedOrigins != null) {
return _buildOptionsCorsResponse(
currentRequest,
function.allowedOrigins!,
);
}

if (!function.external && currentRequest.method.toUpperCase() != 'POST') {
continue;
}

final wrappedHandler = withInit(function.handler);
return wrappedHandler(currentRequest);
final response = await wrappedHandler(currentRequest);
if (function.allowedOrigins != null) {
return _applyCorsHeaders(
currentRequest,
response,
function.allowedOrigins!,
);
}
return response;
}
}

Expand Down
17 changes: 16 additions & 1 deletion test/unit/https_namespace_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ void main() {
(request) async => Response.ok('OK'),
);

expect(_findFunction(firebase, 'options-function'), isNotNull);
final func = _findFunction(firebase, 'options-function')!;
expect(func.allowedOrigins, ['https://example.com']);
});

test('CallableOptions can be provided', () {
Expand All @@ -399,6 +400,20 @@ void main() {

expect(_findFunction(firebase, 'callable-options-function'), isNotNull);
});

test('CallableOptions can be provided passing allowedOrigins', () {
https.onCall(
name: 'callableOptionsFunctionWithOrigins',
options: const CallableOptions(cors: Cors(['https://example.com'])),
(request, response) async => CallableResult('OK'),
);

final func = _findFunction(
firebase,
'callable-options-function-with-origins',
)!;
expect(func.allowedOrigins, ['https://example.com']);
});
});
});
}
Expand Down
29 changes: 29 additions & 0 deletions test/unit/server_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

import 'package:firebase_functions/src/server.dart';
import 'package:shelf/shelf.dart';
import 'package:test/test.dart';

void main() {
Expand Down Expand Up @@ -49,5 +50,33 @@ void main() {
expect(extractTraceId('1234567890xyzdef1234567890abcdef/5'), isNull);
});
});

group('corsHeadersFor', () {
test('returns asterisk when allowedOrigins contains asterisk', () {
final request = Request('GET', Uri.parse('http://localhost/test'));
final headers = corsHeadersFor(request, ['*']);
expect(headers['Access-Control-Allow-Origin'], '*');
});

test('echoes the Origin header if it matches allowedOrigins', () {
final request = Request(
'GET',
Uri.parse('http://localhost/test'),
headers: {'origin': 'https://example.com'},
);
final headers = corsHeadersFor(request, ['https://example.com']);
expect(headers['Access-Control-Allow-Origin'], 'https://example.com');
});

test('returns empty map if no match is found', () {
final request = Request(
'GET',
Uri.parse('http://localhost/test'),
headers: {'origin': 'https://evil.com'},
);
final headers = corsHeadersFor(request, ['https://example.com']);
expect(headers, isEmpty);
});
});
});
}
Loading