Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 38 additions & 55 deletions src/ProvidedTypes.fs
Original file line number Diff line number Diff line change
Expand Up @@ -2874,16 +2874,15 @@ module internal AssemblyReader =

type ILMethodDefs(larr: Lazy<ILMethodDef[]>) =

let mutable lmap = null
let getmap() =
if isNull lmap then
lmap <- Dictionary()
for y in larr.Force() do
let key = y.Name
match lmap.TryGetValue key with
| true, lmpak -> lmap.[key] <- Array.append [| y |] lmpak
| false, _ -> lmap.[key] <- [| y |]
lmap
let lmap = lazy (
let m = Dictionary()
for y in larr.Force() do
let key = y.Name
match m.TryGetValue key with
| true, lmpak -> m.[key] <- Array.append [| y |] lmpak
| false, _ -> m.[key] <- [| y |]
m)
let getmap() = lmap.Value

member __.Entries = larr.Force()
member __.FindByName nm =
Expand Down Expand Up @@ -3097,14 +3096,12 @@ module internal AssemblyReader =

and ILTypeDefs(larr: Lazy<(string uoption * string * Lazy<ILTypeDef>)[]>) =

let mutable lmap = null
let getmap() =
if isNull lmap then
lmap <- Dictionary()
for (nsp, nm, ltd) in larr.Force() do
let key = nsp, nm
lmap.[key] <- ltd
lmap
let lmap = lazy (
let m = Dictionary()
for (nsp, nm, ltd) in larr.Force() do
m.[(nsp, nm)] <- ltd
m)
let getmap() = lmap.Value

member __.Entries =
[| for (_, _, td) in larr.Force() -> td.Force() |]
Expand Down Expand Up @@ -3142,14 +3139,12 @@ module internal AssemblyReader =
override x.ToString() = "fwd " + x.Name

and ILExportedTypesAndForwarders(larr:Lazy<ILExportedTypeOrForwarder[]>) =
let mutable lmap = null
let getmap() =
if isNull lmap then
lmap <- Dictionary()
for ltd in larr.Force() do
let key = ltd.Namespace, ltd.Name
lmap.[key] <- ltd
lmap
let lmap = lazy (
let m = Dictionary()
for ltd in larr.Force() do
m.[(ltd.Namespace, ltd.Name)] <- ltd
m)
let getmap() = lmap.Value
member __.Entries = larr.Force()
member __.TryFindByName (nsp, nm) = match getmap().TryGetValue ((nsp, nm)) with true, v -> Some v | false, _ -> None

Expand Down Expand Up @@ -4579,34 +4574,25 @@ module internal AssemblyReader =

let mkCacheInt32 lowMem _infile _nm _sz =
if lowMem then (fun f x -> f x) else
let cache = ref null
let cache = ConcurrentDictionary<int32, _>()
fun f (idx:int32) ->
let cache =
match !cache with
| null -> cache := new Dictionary<int32, _>(11)
| _ -> ()
!cache
let mutable res = Unchecked.defaultof<_>
let ok = cache.TryGetValue(idx, &res)
if ok then
res
else
let res = f idx
cache.[idx] <- res;
res
match cache.TryGetValue idx with
| true, v -> v
| false, _ ->
let v = f idx
cache.TryAdd(idx, v) |> ignore
cache.[idx]

let mkCacheGeneric lowMem _inbase _nm _sz =
if lowMem then (fun f x -> f x) else
let cache = ref null
let cache = ConcurrentDictionary<'T, _>()
fun f (idx :'T) ->
let cache =
match !cache with
| null -> cache := new Dictionary<_, _>(11 (* sz:int *) )
| _ -> ()
!cache
match cache.TryGetValue idx with
| true, cached -> cached
| false, _ -> let res = f idx in cache.[idx] <- res; res
| true, v -> v
| false, _ ->
let v = f idx
cache.TryAdd(idx, v) |> ignore
cache.[idx]

let seekFindRow numRows rowChooser =
let mutable i = 1
Expand Down Expand Up @@ -7000,6 +6986,7 @@ namespace ProviderImplementation.ProvidedTypes

open System
open System.IO
open System.Collections.Concurrent
open System.Collections.Generic
open System.Reflection
open ProviderImplementation.ProvidedTypes.AssemblyReader
Expand All @@ -7012,14 +6999,10 @@ namespace ProviderImplementation.ProvidedTypes
// Unique wrapped type definition objects must be translated to unique wrapper objects, based
// on object identity.
type TxTable<'T2>() =
let tab = Dictionary<int, 'T2>()
let tab = ConcurrentDictionary<int, Lazy<'T2>>()
member __.Get inp f =
match tab.TryGetValue inp with
| true, tabVal -> tabVal
| false, _ ->
let res = f()
tab.[inp] <- res
res
let lazyVal = tab.GetOrAdd(inp, fun _ -> lazy (f()))
lazyVal.Value

member __.ContainsKey inp = tab.ContainsKey inp

Expand Down
43 changes: 43 additions & 0 deletions tests/BasicGenerativeProvisionTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,46 @@ let ``Generative custom attribute with named property argument encodes and round
Assert.Equal(1, namedArgs.Length)
Assert.Equal("Name", namedArgs.[0].MemberName)
Assert.Equal("MyProp", namedArgs.[0].TypedValue.Value :?> string)

[<Fact>]
let ``TargetTypeDefinition member-wrapper caches are thread-safe under parallel access``() =
// Regression test for https://github.com/fsprojects/FSharp.TypeProviders.SDK/issues/481
// PR #471 introduced lazy caches in TargetTypeDefinition. When multiple threads call
// GetConstructors/GetMethods/etc. concurrently on the same generated type the underlying
// shared caches must not corrupt. Run 8 parallel threads each interrogating every member
// kind on the same TargetTypeDefinition; if any internal collection races, the dictionaries
// will throw InvalidOperationException.
let runtimeAssemblyRefs = Targets.DotNetStandard20FSharpRefs()
let runtimeAssembly = runtimeAssemblyRefs.[0]
let cfg = Testing.MakeSimulatedTypeProviderConfig(__SOURCE_DIRECTORY__, runtimeAssembly, runtimeAssemblyRefs)
let staticArgs = [| box 5; box 6 |]
let tp = GenerativePropertyProviderWithStaticParams cfg :> TypeProviderForNamespaces
let providedNamespace = tp.Namespaces.[0]
let providedTypes = providedNamespace.GetTypes()
let providedType = providedTypes.[0]
let typeName = providedType.Name + (staticArgs |> Seq.map (fun s -> ",\"" + s.ToString() + "\"") |> Seq.reduce (+))
let t = (tp :> ITypeProvider).ApplyStaticArguments(providedType, [| typeName |], staticArgs)
let assemContents = (tp :> ITypeProvider).GetGeneratedAssemblyContents(t.Assembly)
let assem = tp.TargetContext.ReadRelatedAssembly(assemContents)
let typeName2 = providedType.Namespace + "." + typeName
let targetType = assem.GetType(typeName2)
Assert.NotNull(targetType)

let bf = BindingFlags.Public ||| BindingFlags.NonPublic ||| BindingFlags.Instance ||| BindingFlags.Static
let errors = System.Collections.Concurrent.ConcurrentBag<exn>()
let threads =
[| for _ in 1..8 ->
System.Threading.Thread(fun () ->
try
for _ in 1..50 do
targetType.GetConstructors(bf) |> ignore
targetType.GetMethods(bf) |> ignore
targetType.GetFields(bf) |> ignore
targetType.GetProperties(bf) |> ignore
targetType.GetEvents(bf) |> ignore
targetType.GetNestedTypes(bf) |> ignore
with ex ->
errors.Add(ex)) |]
for th in threads do th.Start()
for th in threads do th.Join()
Assert.True(errors.IsEmpty, sprintf "Thread-safety violations: %A" (errors |> Seq.toList))
Loading