Skip to content

Commit 767dd3e

Browse files
Centralize type dispatch in ClientEvaluator, fix addApiHook singleton bypass, simplify TestFeatureProvider Hub (#109)
1 parent a16c808 commit 767dd3e

5 files changed

Lines changed: 247 additions & 111 deletions

File tree

core/src/main/scala/zio/openfeature/FeatureFlags.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ trait FeatureFlags {
159159
def clearHooks: UIO[Unit]
160160
def hooks: UIO[List[FeatureHook]]
161161

162+
/** Add an API-level hook that applies to all clients sharing this instance's OpenFeatureAPI. */
163+
def addApiHook(hook: dev.openfeature.sdk.Hook[_]): UIO[Unit]
164+
165+
/** Clear all API-level hooks on this instance's OpenFeatureAPI. */
166+
def clearApiHooks: UIO[Unit]
167+
162168
/** Replace the underlying provider at runtime.
163169
*
164170
* The old provider is shut down, the new provider is initialized, and the status transitions through `NotReady`
@@ -367,13 +373,11 @@ object FeatureFlags {
367373

368374
// API-level Hooks (per OpenFeature spec 4.4.1)
369375

370-
/** Add an API-level hook that applies to all clients. */
371-
def addApiHook(hook: dev.openfeature.sdk.Hook[_]): UIO[Unit] =
372-
ZIO.succeed(OpenFeatureAPI.getInstance().addHooks(hook))
376+
def addApiHook(hook: dev.openfeature.sdk.Hook[_]): ZIO[FeatureFlags, Nothing, Unit] =
377+
ZIO.serviceWithZIO(_.addApiHook(hook))
373378

374-
/** Clear all API-level hooks. */
375-
def clearApiHooks: UIO[Unit] =
376-
ZIO.succeed(OpenFeatureAPI.getInstance().clearHooks())
379+
def clearApiHooks: ZIO[FeatureFlags, Nothing, Unit] =
380+
ZIO.serviceWithZIO(_.clearApiHooks)
377381

378382
// Tracking API
379383

core/src/main/scala/zio/openfeature/FeatureFlagsLive.scala

Lines changed: 73 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -276,27 +276,6 @@ final private[openfeature] class FeatureFlagsLive(
276276
_ <- txState.record(eval)
277277
} yield resolution
278278

279-
// Evaluate a flag using a ClientEvaluator typeclass instance for type-safe SDK dispatch
280-
private def evaluateViaTypeclass[A](
281-
key: String,
282-
default: A,
283-
ofContext: dev.openfeature.sdk.EvaluationContext,
284-
timeout: Option[Duration] = None
285-
)(implicit ev: ClientEvaluator[A]): IO[FeatureFlagError, FlagResolution[A]] = {
286-
val rawEval = ev.evaluate(client, key, default, ofContext)
287-
val timedEval = timeout match {
288-
case Some(d) =>
289-
rawEval.disconnect
290-
.timeoutFail(new java.util.concurrent.TimeoutException(s"Evaluation of '$key' timed out after $d"))(d)
291-
case None => rawEval
292-
}
293-
timedEval
294-
.mapError(e => FeatureFlagError.ProviderError(e))
295-
.flatMap { details =>
296-
toFlagResolution(key, details).map(_.copy(value = ev.extractValue(details)))
297-
}
298-
}
299-
300279
private def evaluateFromClient[A: FlagType](
301280
key: String,
302281
default: A,
@@ -314,88 +293,78 @@ final private[openfeature] class FeatureFlagsLive(
314293
case None => effect
315294
}
316295

317-
val evaluation: IO[FeatureFlagError, FlagResolution[A]] = flagType.typeName match {
318-
case "Boolean" =>
319-
evaluateViaTypeclass(key, default.asInstanceOf[Boolean], ofContext, timeout)
320-
.map(_.asInstanceOf[FlagResolution[A]])
321-
322-
case "String" =>
323-
evaluateViaTypeclass(key, default.asInstanceOf[String], ofContext, timeout)
324-
.map(_.asInstanceOf[FlagResolution[A]])
325-
326-
case "Int" =>
327-
evaluateViaTypeclass(key, default.asInstanceOf[Int], ofContext, timeout)
328-
.map(_.asInstanceOf[FlagResolution[A]])
329-
330-
case "Long" =>
331-
evaluateViaTypeclass(key, default.asInstanceOf[Long], ofContext, timeout)
332-
.map(_.asInstanceOf[FlagResolution[A]])
333-
334-
case "Float" =>
335-
evaluateViaTypeclass(key, default.asInstanceOf[Float], ofContext, timeout)
336-
.map(_.asInstanceOf[FlagResolution[A]])
337-
338-
case "Double" =>
339-
evaluateViaTypeclass(key, default.asInstanceOf[Double], ofContext, timeout)
340-
.map(_.asInstanceOf[FlagResolution[A]])
341-
342-
case "Object" =>
343-
withTimeout(
344-
ZIO.attemptBlocking {
345-
val defaultValue = new dev.openfeature.sdk.Value(
346-
dev.openfeature.sdk.Structure.mapToStructure(
347-
default.asInstanceOf[Map[String, Any]].map { case (k, v) => k -> anyToObject(v) }.asJava
348-
)
349-
)
350-
client.getObjectDetails(key, defaultValue, ofContext)
296+
val evaluation: IO[FeatureFlagError, FlagResolution[A]] =
297+
ClientEvaluator.evaluateStandard[A](flagType.typeName, client, key, default, ofContext) match {
298+
case Some(erased) =>
299+
val timedEval = timeout match {
300+
case Some(d) =>
301+
erased.task.disconnect
302+
.timeoutFail(new java.util.concurrent.TimeoutException(s"Evaluation of '$key' timed out after $d"))(d)
303+
case None => erased.task
351304
}
352-
).mapError(e => FeatureFlagError.ProviderError(e))
353-
.flatMap { details =>
354-
val value = valueToMap(details.getValue)
355-
toFlagMetadata(details.getFlagMetadata).map { metadata =>
356-
FlagResolution(
357-
value = value.asInstanceOf[A],
358-
variant = Option(details.getVariant),
359-
reason = toResolutionReason(details.getReason),
360-
metadata = metadata,
361-
flagKey = key,
362-
errorCode = Option(details.getErrorCode).map(ErrorCodeConverter.fromJava),
363-
errorMessage = Option(details.getErrorMessage)
305+
timedEval
306+
.mapError(e => FeatureFlagError.ProviderError(e))
307+
.flatMap { details =>
308+
toFlagResolution(key, details).map(r => r.copy(value = erased.extract(details)))
309+
}
310+
311+
case None if flagType.typeName == "Object" =>
312+
withTimeout(
313+
ZIO.attemptBlocking {
314+
val defaultValue = new dev.openfeature.sdk.Value(
315+
dev.openfeature.sdk.Structure.mapToStructure(
316+
default.asInstanceOf[Map[String, Any]].map { case (k, v) => k -> anyToObject(v) }.asJava
317+
)
364318
)
319+
client.getObjectDetails(key, defaultValue, ofContext)
320+
}
321+
).mapError(e => FeatureFlagError.ProviderError(e))
322+
.flatMap { details =>
323+
val value = valueToMap(details.getValue)
324+
toFlagMetadata(details.getFlagMetadata).map { metadata =>
325+
FlagResolution(
326+
value = value.asInstanceOf[A],
327+
variant = Option(details.getVariant),
328+
reason = toResolutionReason(details.getReason),
329+
metadata = metadata,
330+
flagKey = key,
331+
errorCode = Option(details.getErrorCode).map(ErrorCodeConverter.fromJava),
332+
errorMessage = Option(details.getErrorMessage)
333+
)
334+
}
365335
}
366-
}
367336

368-
case _ =>
369-
// Custom type - try to decode from object
370-
withTimeout(
371-
ZIO.attemptBlocking {
372-
client.getObjectDetails(key, new dev.openfeature.sdk.Value(), ofContext)
373-
}
374-
).mapError(e => FeatureFlagError.ProviderError(e))
375-
.flatMap { details =>
376-
valueToAny(details.getValue) match {
377-
case Some(rawValue) =>
378-
flagType.decode(rawValue) match {
379-
case Right(decoded) =>
380-
toFlagMetadata(details.getFlagMetadata).map { metadata =>
381-
FlagResolution(
382-
value = decoded,
383-
variant = Option(details.getVariant),
384-
reason = toResolutionReason(details.getReason),
385-
metadata = metadata,
386-
flagKey = key,
387-
errorCode = Option(details.getErrorCode).map(ErrorCodeConverter.fromJava),
388-
errorMessage = Option(details.getErrorMessage)
389-
)
390-
}
391-
case Left(_) =>
392-
ZIO.fail(FeatureFlagError.TypeMismatch(key, flagType.typeName, "Object"))
393-
}
394-
case None =>
395-
ZIO.fail(FeatureFlagError.TypeMismatch(key, flagType.typeName, "null"))
337+
case None =>
338+
// Custom type - try to decode from object
339+
withTimeout(
340+
ZIO.attemptBlocking {
341+
client.getObjectDetails(key, new dev.openfeature.sdk.Value(), ofContext)
396342
}
397-
}
398-
}
343+
).mapError(e => FeatureFlagError.ProviderError(e))
344+
.flatMap { details =>
345+
valueToAny(details.getValue) match {
346+
case Some(rawValue) =>
347+
flagType.decode(rawValue) match {
348+
case Right(decoded) =>
349+
toFlagMetadata(details.getFlagMetadata).map { metadata =>
350+
FlagResolution(
351+
value = decoded,
352+
variant = Option(details.getVariant),
353+
reason = toResolutionReason(details.getReason),
354+
metadata = metadata,
355+
flagKey = key,
356+
errorCode = Option(details.getErrorCode).map(ErrorCodeConverter.fromJava),
357+
errorMessage = Option(details.getErrorMessage)
358+
)
359+
}
360+
case Left(_) =>
361+
ZIO.fail(FeatureFlagError.TypeMismatch(key, flagType.typeName, "Object"))
362+
}
363+
case None =>
364+
ZIO.fail(FeatureFlagError.TypeMismatch(key, flagType.typeName, "null"))
365+
}
366+
}
367+
}
399368

400369
// Check resolution error codes for provider-level failures (handles TOCTOU race
401370
// where checkProviderStatus passes but the Java SDK's internal state is stale)
@@ -716,6 +685,12 @@ final private[openfeature] class FeatureFlagsLive(
716685
override def hooks: UIO[List[FeatureHook]] =
717686
state.hooksRef.get
718687

688+
override def addApiHook(hook: dev.openfeature.sdk.Hook[_]): UIO[Unit] =
689+
ZIO.succeed(api.addHooks(hook))
690+
691+
override def clearApiHooks: UIO[Unit] =
692+
ZIO.succeed(api.clearHooks())
693+
719694
// Provider hooks (spec: provider hooks included in hook pipeline)
720695

721696
private def getProviderHooks: UIO[List[FeatureHook]] =

core/src/main/scala/zio/openfeature/internal/ClientEvaluator.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,40 @@ private[openfeature] trait ClientEvaluator[A] {
4040

4141
private[openfeature] object ClientEvaluator {
4242

43+
/** A standard-type evaluation produced by [[evaluateStandard]], with the typed extractor pre-applied to the caller's
44+
* `A`. The `asInstanceOf[A]` cast lives once here, not at each call site.
45+
*/
46+
final case class Erased[A](
47+
task: Task[FlagEvaluationDetails[_]],
48+
extract: FlagEvaluationDetails[_] => A
49+
)
50+
51+
/** Look up the evaluator for a standard type by name and produce the type-erased evaluation. Returns None for
52+
* non-standard types (Object, custom) which need special handling.
53+
*/
54+
def evaluateStandard[A](
55+
typeName: String,
56+
client: OFClient,
57+
key: String,
58+
default: A,
59+
context: dev.openfeature.sdk.EvaluationContext
60+
): Option[Erased[A]] = {
61+
def erased[T](ev: ClientEvaluator[T]): Erased[A] =
62+
Erased[A](
63+
ev.evaluate(client, key, default.asInstanceOf[T], context),
64+
details => ev.extractValue(details).asInstanceOf[A]
65+
)
66+
typeName match {
67+
case "Boolean" => Some(erased(booleanEvaluator))
68+
case "String" => Some(erased(stringEvaluator))
69+
case "Int" => Some(erased(intEvaluator))
70+
case "Long" => Some(erased(longEvaluator))
71+
case "Float" => Some(erased(floatEvaluator))
72+
case "Double" => Some(erased(doubleEvaluator))
73+
case _ => None
74+
}
75+
}
76+
4377
implicit val booleanEvaluator: ClientEvaluator[Boolean] = new ClientEvaluator[Boolean] {
4478
def evaluate(
4579
client: OFClient,
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package zio.openfeature
2+
3+
import dev.openfeature.sdk.{
4+
EvaluationContext => OFEvaluationContext,
5+
EventProvider,
6+
HookContext => JavaHookContext,
7+
Metadata,
8+
OpenFeatureAPIFactory,
9+
ProviderEvaluation,
10+
ProviderState
11+
}
12+
import zio._
13+
import zio.test._
14+
import java.util.concurrent.atomic.AtomicInteger
15+
16+
object ApiHookIsolationSpec extends ZIOSpecDefault {
17+
18+
private class SimpleBooleanProvider(name: String, value: Boolean) extends EventProvider {
19+
@scala.annotation.nowarn("msg=deprecated")
20+
override def getMetadata: Metadata = new Metadata { override def getName: String = name }
21+
override def getState: ProviderState = ProviderState.READY
22+
override def initialize(ctx: OFEvaluationContext): Unit = ()
23+
override def shutdown(): Unit = ()
24+
25+
override def getBooleanEvaluation(
26+
key: String,
27+
defaultValue: java.lang.Boolean,
28+
ctx: OFEvaluationContext
29+
): ProviderEvaluation[java.lang.Boolean] =
30+
ProviderEvaluation.builder[java.lang.Boolean]().value(value).reason("STATIC").build()
31+
32+
override def getStringEvaluation(
33+
key: String,
34+
defaultValue: String,
35+
ctx: OFEvaluationContext
36+
): ProviderEvaluation[String] =
37+
ProviderEvaluation.builder[String]().value(defaultValue).reason("STATIC").build()
38+
39+
override def getIntegerEvaluation(
40+
key: String,
41+
defaultValue: java.lang.Integer,
42+
ctx: OFEvaluationContext
43+
): ProviderEvaluation[java.lang.Integer] =
44+
ProviderEvaluation.builder[java.lang.Integer]().value(defaultValue).reason("STATIC").build()
45+
46+
override def getDoubleEvaluation(
47+
key: String,
48+
defaultValue: java.lang.Double,
49+
ctx: OFEvaluationContext
50+
): ProviderEvaluation[java.lang.Double] =
51+
ProviderEvaluation.builder[java.lang.Double]().value(defaultValue).reason("STATIC").build()
52+
53+
override def getObjectEvaluation(
54+
key: String,
55+
defaultValue: dev.openfeature.sdk.Value,
56+
ctx: OFEvaluationContext
57+
): ProviderEvaluation[dev.openfeature.sdk.Value] =
58+
ProviderEvaluation.builder[dev.openfeature.sdk.Value]().value(defaultValue).reason("STATIC").build()
59+
}
60+
61+
// Counts before-hook invocations from the Java SDK thread.
62+
private class CountingJavaHook extends dev.openfeature.sdk.Hook[java.lang.Boolean] {
63+
val count = new AtomicInteger(0)
64+
override def before(
65+
ctx: JavaHookContext[java.lang.Boolean],
66+
hints: java.util.Map[String, AnyRef]
67+
): java.util.Optional[OFEvaluationContext] = {
68+
count.incrementAndGet()
69+
java.util.Optional.empty()
70+
}
71+
}
72+
73+
private def buildIsolated(provider: EventProvider): ZIO[Scope, Throwable, FeatureFlags] = {
74+
val api = OpenFeatureAPIFactory.create()
75+
val domain = s"api-hook-iso-${java.util.UUID.randomUUID()}"
76+
for {
77+
ff <- FeatureFlags.build(
78+
provider,
79+
domain = Some(domain),
80+
version = None,
81+
initialHooks = Nil,
82+
statusRef = None,
83+
addShutdownFinalizer = false,
84+
apiOverride = Some(api)
85+
)
86+
_ <- ZIO.attemptBlocking(Thread.sleep(50)).ignore
87+
} yield ff
88+
}
89+
90+
def spec = suite("API Hook Isolation")(
91+
test("addApiHook installs the hook on this FeatureFlags' isolated API") {
92+
ZIO.scoped {
93+
for {
94+
ff <- buildIsolated(new SimpleBooleanProvider("p", true))
95+
hook = new CountingJavaHook
96+
_ <- ff.addApiHook(hook)
97+
_ <- ff.boolean("flag", default = false).provideEnvironment(ZEnvironment(ff))
98+
} yield assertTrue(hook.count.get() == 1)
99+
}
100+
},
101+
test("addApiHook on one isolated FeatureFlags does not leak to another") {
102+
ZIO.scoped {
103+
for {
104+
ffA <- buildIsolated(new SimpleBooleanProvider("pA", true))
105+
ffB <- buildIsolated(new SimpleBooleanProvider("pB", true))
106+
hookA = new CountingJavaHook
107+
_ <- ffA.addApiHook(hookA)
108+
_ <- ffA.boolean("flag", default = false).provideEnvironment(ZEnvironment(ffA))
109+
_ <- ffB.boolean("flag", default = false).provideEnvironment(ZEnvironment(ffB))
110+
} yield assertTrue(hookA.count.get() == 1)
111+
}
112+
},
113+
test("clearApiHooks removes previously-installed API hooks") {
114+
ZIO.scoped {
115+
for {
116+
ff <- buildIsolated(new SimpleBooleanProvider("p", true))
117+
hook = new CountingJavaHook
118+
_ <- ff.addApiHook(hook)
119+
_ <- ff.clearApiHooks
120+
_ <- ff.boolean("flag", default = false).provideEnvironment(ZEnvironment(ff))
121+
} yield assertTrue(hook.count.get() == 0)
122+
}
123+
}
124+
)
125+
}

0 commit comments

Comments
 (0)