Skip to content

Commit 5a58571

Browse files
authored
Introduce vector_getrange and vector_getranges for VarInfo (#738)
* replaced a closure with `Fix1` * added correct implementation of `getrange` for `TypedVarInfo` * fixed calls to varinfo methods which should be metadata methods * fixed typo * use `setval!` on the metadata directly instead of on the varinfo * added `length` implementation for `VarInfo` and `Metadata` * added testing for `getranges * introduce `vector_length` instead of `length`, since `length` already refers to the dictionary-like length impl, not vector-like * fixed bug in `getranges` for untyped varinfo * added proper testing for other `VarInfo` types * bump patch version * separated the `getrange` version which returns the range of the vecto representaiton rather than the internal representaiton into `vector_getrange` to make its function explicit * formatting * removed `vector_getrange` for metadata * added handling of missing indices + tests for these cases * added handling of duplicated values * removed no-longer relevant comment * fixed impl of `vector_getrange` and `vector_getranges` for threadsafe varinfo * fixed `vector_getranges` when `vns` are not found
1 parent 2252a9b commit 5a58571

File tree

5 files changed

+130
-6
lines changed

5 files changed

+130
-6
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.31.1"
3+
version = "0.31.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

Diff for: src/threadsafe.jl

+6
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:
178178
return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns)
179179
end
180180

181+
vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo)
182+
vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn)
183+
function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName})
184+
return vector_getranges(vi.varinfo, vns)
185+
end
186+
181187
function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
182188
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
183189
end

Diff for: src/varinfo.jl

+79-5
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,15 @@ function VarInfo(
202202
end
203203
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
204204

205+
"""
206+
vector_length(varinfo::VarInfo)
207+
208+
Return the length of the vector representation of `varinfo`.
209+
"""
210+
vector_length(varinfo::VarInfo) = length(varinfo.metadata)
211+
vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata)
212+
vector_length(md::Metadata) = sum(length, md.ranges)
213+
205214
unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x)
206215

207216
# TODO: deprecate.
@@ -626,7 +635,72 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range
626635
Return the indices of `vns` in the metadata of `vi` corresponding to `vn`.
627636
"""
628637
function getranges(vi::VarInfo, vns::Vector{<:VarName})
629-
return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[])
638+
return map(Base.Fix1(getrange, vi), vns)
639+
end
640+
641+
"""
642+
vector_getrange(varinfo::VarInfo, varname::VarName)
643+
644+
Return the range corresponding to `varname` in the vector representation of `varinfo`.
645+
"""
646+
vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn)
647+
function vector_getrange(vi::TypedVarInfo, vn::VarName)
648+
offset = 0
649+
for md in values(vi.metadata)
650+
# First, we need to check if `vn` is in `md`.
651+
# In this case, we can just return the corresponding range + offset.
652+
haskey(md, vn) && return getrange(md, vn) .+ offset
653+
# Otherwise, we need to get the cumulative length of the ranges in `md`
654+
# and add it to the offset.
655+
offset += sum(length, md.ranges)
656+
end
657+
# If we reach this point, `vn` is not in `vi.metadata`.
658+
throw(KeyError(vn))
659+
end
660+
661+
"""
662+
vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName})
663+
664+
Return the range corresponding to `varname` in the vector representation of `varinfo`.
665+
"""
666+
function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName})
667+
return map(Base.Fix1(vector_getrange, varinfo), varname)
668+
end
669+
# Specialized version for `TypedVarInfo`.
670+
function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName})
671+
# TODO: Does it help if we _don't_ convert to a vector here?
672+
metadatas = collect(values(varinfo.metadata))
673+
# Extract the offsets.
674+
offsets = cumsum(map(vector_length, metadatas))
675+
# Extract the ranges from each metadata.
676+
ranges = Vector{UnitRange{Int}}(undef, length(vns))
677+
# Need to keep track of which ones we've seen.
678+
not_seen = fill(true, length(vns))
679+
for (i, metadata) in enumerate(metadatas)
680+
vns_metadata = filter(Base.Fix1(haskey, metadata), vns)
681+
# If none of the variables exist in the metadata, we return an empty array.
682+
isempty(vns_metadata) && continue
683+
# Otherwise, we extract the ranges.
684+
offset = i == 1 ? 0 : offsets[i - 1]
685+
for vn in vns_metadata
686+
r_vn = getrange(metadata, vn)
687+
# Get the index, so we return in the same order as `vns`.
688+
# NOTE: There might be duplicates in `vns`, so we need to handle that.
689+
indices = findall(==(vn), vns)
690+
for idx in indices
691+
not_seen[idx] = false
692+
ranges[idx] = r_vn .+ offset
693+
end
694+
end
695+
end
696+
# Raise key error if any of the variables were not found.
697+
if any(not_seen)
698+
inds = findall(not_seen)
699+
# Just use a `convert` to get the same type as the input; don't want to confuse by overly
700+
# specilizing the types in the error message.
701+
throw(KeyError(convert(typeof(vns), vns[inds])))
702+
end
703+
return ranges
630704
end
631705

632706
"""
@@ -1314,13 +1388,13 @@ end
13141388

13151389
function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
13161390
# TODO: Use inplace versions to avoid allocations
1317-
yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
1391+
yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn))
13181392
# Determine the new range.
1319-
start = first(getrange(vi, vn))
1393+
start = first(getrange(md, vn))
13201394
# NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`.
1321-
setrange!(vi, vn, start:(start + length(yvec) - 1))
1395+
setrange!(md, vn, start:(start + length(yvec) - 1))
13221396
# Set the new value.
1323-
setval!(vi, yvec, vn)
1397+
setval!(md, yvec, vn)
13241398
acclogp!!(vi, -logjac)
13251399
return vi
13261400
end

Diff for: src/varnamedvector.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,8 @@ function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {spa
10361036
return replace_raw_storage(vnv, vals)
10371037
end
10381038

1039+
vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv)
1040+
10391041
"""
10401042
unflatten(vnv::VarNamedVector, vals::AbstractVector)
10411043

Diff for: test/varinfo.jl

+42
Original file line numberDiff line numberDiff line change
@@ -813,4 +813,46 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
813813
@test DynamicPPL.istrans(varinfo2, vn)
814814
end
815815
end
816+
817+
# NOTE: It is not yet clear if this is something we want from all varinfo types.
818+
# Hence, we only test the `VarInfo` types here.
819+
@testset "vector_getranges for `VarInfo`" begin
820+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
821+
vns = DynamicPPL.TestUtils.varnames(model)
822+
nt = DynamicPPL.TestUtils.rand_prior_true(model)
823+
varinfos = DynamicPPL.TestUtils.setup_varinfos(
824+
model, nt, vns; include_threadsafe=true
825+
)
826+
# Only keep `VarInfo` types.
827+
varinfos = filter(
828+
Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos
829+
)
830+
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
831+
x = values_as(varinfo, Vector)
832+
833+
# Let's just check all the subsets of `vns`.
834+
@testset "$(convert(Vector{Any},vns_subset))" for vns_subset in
835+
combinations(vns)
836+
ranges = DynamicPPL.vector_getranges(varinfo, vns_subset)
837+
@test length(ranges) == length(vns_subset)
838+
for (r, vn) in zip(ranges, vns_subset)
839+
@test x[r] == DynamicPPL.tovec(varinfo[vn])
840+
end
841+
end
842+
843+
# Let's try some failure cases.
844+
@test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[]
845+
# Non-existent variables.
846+
@test_throws KeyError DynamicPPL.vector_getranges(
847+
varinfo, [VarName{gensym("vn")}()]
848+
)
849+
@test_throws KeyError DynamicPPL.vector_getranges(
850+
varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()]
851+
)
852+
# Duplicate variables.
853+
ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2))
854+
@test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2)
855+
end
856+
end
857+
end
816858
end

0 commit comments

Comments
 (0)