Skip to content

Commit

Permalink
Make EC a class
Browse files Browse the repository at this point in the history
  • Loading branch information
grynspan committed Jan 25, 2025
1 parent 3009cf2 commit b0be207
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 124 deletions.
1 change: 1 addition & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ extension Array where Element == PackageDescription.SwiftSetting {
.unsafeFlags(["-require-explicit-sendable"]),
.enableUpcomingFeature("ExistentialAny"),
.enableExperimentalFeature("SuppressedAssociatedTypes"),
.enableExperimentalFeature("NonescapableTypes"),

.enableExperimentalFeature("AccessLevelOnImport"),
.enableUpcomingFeature("InternalImportsByDefault"),
Expand Down
28 changes: 14 additions & 14 deletions Sources/Testing/Expectations/ExpectationChecking+Macro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ func check(
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public func __checkCondition(
_ condition: (inout __ExpectationContext) throws -> Bool,
_ condition: (__ExpectationContext) throws -> Bool,
sourceCode: @escaping @autoclosure @Sendable () -> [__ExpressionID: String],
comments: @autoclosure () -> [Comment],
isRequired: Bool,
sourceLocation: SourceLocation
) rethrows -> Result<Void, any Error> {
var expectationContext = __ExpectationContext.init(sourceCode: sourceCode())
let condition = try condition(&expectationContext)
let expectationContext = __ExpectationContext.init(sourceCode: sourceCode())
let condition = try condition(expectationContext)

return check(
condition,
Expand All @@ -131,14 +131,14 @@ public func __checkCondition(
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public func __checkCondition<T>(
_ optionalValue: (inout __ExpectationContext) throws -> T?,
_ optionalValue: (__ExpectationContext) throws -> T?,
sourceCode: @escaping @autoclosure @Sendable () -> [__ExpressionID: String],
comments: @autoclosure () -> [Comment],
isRequired: Bool,
sourceLocation: SourceLocation
) rethrows -> Result<T, any Error> where T: ~Copyable {
var expectationContext = __ExpectationContext(sourceCode: sourceCode())
let optionalValue = try optionalValue(&expectationContext)
let expectationContext = __ExpectationContext(sourceCode: sourceCode())
let optionalValue = try optionalValue(expectationContext)

let result = check(
optionalValue != nil,
Expand Down Expand Up @@ -166,15 +166,15 @@ public func __checkCondition<T>(
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public func __checkConditionAsync(
_ condition: (inout __ExpectationContext) async throws -> Bool,
_ condition: (__ExpectationContext) async throws -> Bool,
sourceCode: @escaping @autoclosure @Sendable () -> [__ExpressionID: String],
comments: @autoclosure () -> [Comment],
isRequired: Bool,
isolation: isolated (any Actor)? = #isolation,
sourceLocation: SourceLocation
) async rethrows -> Result<Void, any Error> {
var expectationContext = __ExpectationContext(sourceCode: sourceCode())
let condition = try await condition(&expectationContext)
let expectationContext = __ExpectationContext(sourceCode: sourceCode())
let condition = try await condition(expectationContext)

return check(
condition,
Expand All @@ -193,15 +193,15 @@ public func __checkConditionAsync(
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public func __checkConditionAsync<T>(
_ optionalValue: (inout __ExpectationContext) async throws -> sending T?,
_ optionalValue: (__ExpectationContext) async throws -> sending T?,
sourceCode: @escaping @autoclosure @Sendable () -> [__ExpressionID: String],
comments: @autoclosure () -> [Comment],
isRequired: Bool,
isolation: isolated (any Actor)? = #isolation,
sourceLocation: SourceLocation
) async rethrows -> Result<T, any Error> where T: ~Copyable {
var expectationContext = __ExpectationContext(sourceCode: sourceCode())
let optionalValue = try await optionalValue(&expectationContext)
let expectationContext = __ExpectationContext(sourceCode: sourceCode())
let optionalValue = try await optionalValue(expectationContext)

let result = check(
optionalValue != nil,
Expand Down Expand Up @@ -516,7 +516,7 @@ public func __checkClosureCall<R>(
isRequired: Bool,
sourceLocation: SourceLocation
) -> Result<(any Error)?, any Error> {
var expectationContext = __ExpectationContext(sourceCode: sourceCode())
let expectationContext = __ExpectationContext(sourceCode: sourceCode())

var errorMatches = false
var mismatchExplanationValue: String? = nil
Expand Down Expand Up @@ -569,7 +569,7 @@ public func __checkClosureCall<R>(
isolation: isolated (any Actor)? = #isolation,
sourceLocation: SourceLocation
) async -> Result<(any Error)?, any Error> {
var expectationContext = __ExpectationContext(sourceCode: sourceCode())
let expectationContext = __ExpectationContext(sourceCode: sourceCode())

var errorMatches = false
var mismatchExplanationValue: String? = nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func callAsFunction<P>(_ value: P, _ id: __ExpressionID) -> P where P: _Pointer {
@inlinable public func callAsFunction<P>(_ value: consuming P, _ id: __ExpressionID) -> P where P: _Pointer {
captureValue(value, id)
}

Expand All @@ -49,7 +49,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func callAsFunction(_ value: String, _ id: __ExpressionID) -> String {
@inlinable public func callAsFunction(_ value: consuming String, _ id: __ExpressionID) -> String {
captureValue(value, id)
}

Expand All @@ -70,7 +70,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func callAsFunction<E>(_ value: Array<E>, _ id: __ExpressionID) -> Array<E> {
@inlinable public func callAsFunction<E>(_ value: consuming Array<E>, _ id: __ExpressionID) -> Array<E> {
captureValue(value, id)
}

Expand All @@ -90,7 +90,7 @@ extension __ExpectationContext {
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@_disfavoredOverload
@inlinable public mutating func callAsFunction<T>(_ value: T?, _ id: __ExpressionID) -> T? {
@inlinable public func callAsFunction<T>(_ value: consuming T?, _ id: __ExpressionID) -> T? {
captureValue(value, id)
}
}
Expand Down
26 changes: 13 additions & 13 deletions Sources/Testing/Expectations/ExpectationContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
///
/// - Warning: This type is used to implement the `#expect()` and `#require()`
/// macros. Do not use it directly.
public struct __ExpectationContext: ~Copyable {
public final class __ExpectationContext {
/// The source code representations of any captured expressions.
///
/// Unlike the rest of the state in this type, the source code dictionary is
Expand Down Expand Up @@ -167,7 +167,7 @@ extension __ExpectationContext {
///
/// This function helps overloads of `callAsFunction(_:_:)` disambiguate
/// themselves and avoid accidental recursion.
@usableFromInline mutating func captureValue<T>(_ value: T, _ id: __ExpressionID) -> T {
@usableFromInline func captureValue<T>(_ value: T, _ id: __ExpressionID) -> T {
runtimeValues[id] = { Expression.Value(reflecting: value) }
return value
}
Expand All @@ -185,7 +185,7 @@ extension __ExpectationContext {
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@_disfavoredOverload
@inlinable public mutating func callAsFunction<T>(_ value: T, _ id: __ExpressionID) -> T {
@inlinable public func callAsFunction<T>(_ value: T, _ id: __ExpressionID) -> T {
captureValue(value, id)
}

Expand All @@ -203,7 +203,7 @@ extension __ExpectationContext {
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@_disfavoredOverload
public mutating func callAsFunction<T>(_ value: consuming T, _ id: __ExpressionID) -> T where T: ~Copyable {
public func callAsFunction<T>(_ value: consuming T, _ id: __ExpressionID) -> T where T: ~Copyable {
// TODO: add support for borrowing non-copyable expressions (need @lifetime)
return value
}
Expand All @@ -219,7 +219,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public mutating func __inoutAfter<T>(_ value: T, _ id: __ExpressionID) {
public func __inoutAfter<T>(_ value: T, _ id: __ExpressionID) {
runtimeValues[id] = { Expression.Value(reflecting: value, timing: .after) }
}
}
Expand Down Expand Up @@ -272,14 +272,14 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func __cmp<T, U, R>(
_ op: (T, U) throws -> R,
@inlinable public func __cmp<T, U>(
_ op: (T, U) throws -> Bool,
_ opID: __ExpressionID,
_ lhs: T,
_ lhsID: __ExpressionID,
_ rhs: U,
_ rhsID: __ExpressionID
) rethrows -> R {
) rethrows -> Bool {
try captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
}

Expand All @@ -290,7 +290,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public mutating func __cmp<C>(
public func __cmp<C>(
_ op: (C, C) -> Bool,
_ opID: __ExpressionID,
_ lhs: C,
Expand All @@ -315,7 +315,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func __cmp<R>(
@inlinable public func __cmp<R>(
_ op: (R, R) -> Bool,
_ opID: __ExpressionID,
_ lhs: R,
Expand All @@ -334,7 +334,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public mutating func __cmp<S>(
public func __cmp<S>(
_ op: (S, S) -> Bool,
_ opID: __ExpressionID,
_ lhs: S,
Expand Down Expand Up @@ -392,7 +392,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func __as<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> U? {
@inlinable public func __as<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> U? {
let result = captureValue(value, valueID) as? U

if result == nil {
Expand Down Expand Up @@ -421,7 +421,7 @@ extension __ExpectationContext {
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
@inlinable public mutating func __is<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> Bool {
@inlinable public func __is<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> Bool {
let result = captureValue(value, valueID) is U

if !result {
Expand Down
31 changes: 12 additions & 19 deletions Sources/TestingMacros/Support/ConditionArgumentParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,10 @@ extension ConditionMacro {
// If we're inserting any additional code into the closure before
// the rewritten argument, we can't elide the return keyword.
ReturnStmtSyntax(
expression: expandedExpr.with(\.leadingTrivia, .space)
).with(\.trailingTrivia, .newline)
returnKeyword: .keyword(.return, trailingTrivia: .space),
expression: expandedExpr,
trailingTrivia: .space
)
}
}

Expand Down Expand Up @@ -715,19 +717,11 @@ extension ConditionMacro {
parameters: ClosureParameterListSyntax {
ClosureParameterSyntax(
firstName: expressionContextName,
colon: .colonToken().with(\.trailingTrivia, .space),
colon: .colonToken(trailingTrivia: .space),
type: TypeSyntax(
AttributedTypeSyntax(
specifiers: [
TypeSpecifierListSyntax.Element(
SimpleTypeSpecifierSyntax(specifier: .keyword(.inout))
.with(\.trailingTrivia, .space)
)
],
baseType: MemberTypeSyntax(
baseType: IdentifierTypeSyntax(name: .identifier("Testing")),
name: .identifier("__ExpectationContext")
)
MemberTypeSyntax(
baseType: IdentifierTypeSyntax(name: .identifier("Testing")),
name: .identifier("__ExpectationContext")
)
)
)
Expand All @@ -736,12 +730,11 @@ extension ConditionMacro {
),
returnClause: returnType.map { returnType in
ReturnClauseSyntax(
type: returnType.with(\.leadingTrivia, .space)
).with(\.leadingTrivia, .space)
arrow: .arrowToken(leadingTrivia: .space, trailingTrivia: .space),
type: returnType
)
},
inKeyword: .keyword(.in)
.with(\.leadingTrivia, .space)
.with(\.trailingTrivia, .newline)
inKeyword: .keyword(.in, leadingTrivia: .space, trailingTrivia: .space)
),
statements: codeBlockItems
)
Expand Down
26 changes: 22 additions & 4 deletions Tests/SubexpressionShowcase/SubexpressionShowcase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ struct T {
}

func subexpressionShowcase() async throws {
let fff = false
let ttt = true
#expect(false || true)

#expect((fff == ttt) == ttt)
Testing.__checkCondition({(__ec: Testing.__ExpectationContext) -> Swift.Bool in
__ec.__cmp(==,0x0,__ec((__ec.__cmp(==,0x3a,__ec(fff,0x7a),0x7a,__ec(ttt,0x43a),0x43a)),0x2),0x2,__ec(ttt,0x8000),0x8000)
},sourceCode: [0x0:"(fff == ttt) == ttt",0x2:"(fff == ttt)",0x3a:"fff == ttt",0x7a:"fff",0x43a:"ttt",0x8000:"ttt"],comments: [],isRequired: false,sourceLocation: Testing.SourceLocation.__here()).__expected()

#expect((Int)(123) == 124)
#expect((Int, Double)(123, 456.0) == (124, 457.0))
#expect((123, 456) == (789, 0x12))
Expand Down Expand Up @@ -101,11 +109,21 @@ func subexpressionShowcase() async throws {
}
#expect(await k(true))

func k2(_ x: @escaping @autoclosure () -> Bool) async -> Bool {
x()
}
#expect(await k2(true))

#if false
// Unsupported: __ec is necessarily inout and captures non-sendable state, so
// this will fail to compile. Making __ec a class instead is possible, but
// adds a very large amount of code and locking overhead for what we can
// assume is an edge case.
// Unsupported: __ec necessarily captures non-sendable state, so this will
// fail to compile because it is capturing __ec in a sendable closure. We
// could add locks guarding __ec's mutable state and eagerly capture state,
// but that would slow down tests significantly. The type checker cannot
// handle the number of `where T: Sendable` overloads of various functions
// that we would need in order to provide eager capture only for non-sendable
// values. However, this is a relatively narrow case, so for now we'll just
// accept it as unsupported and tell affected test authors to refactor their
// expectations so as to call m(_:) _before_ #expect().
func m(_ x: @autoclosure @Sendable () -> Bool) -> Bool {
x()
}
Expand Down
Loading

0 comments on commit b0be207

Please sign in to comment.