From 9f63c259fbff785f113eb29ad14b0d330a7182c5 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Thu, 21 Jul 2022 14:53:37 -0700 Subject: [PATCH] fix `as` (#348) * fix `as` * add test * bump version --- Project.toml | 10 +++++----- src/primitives/as.jl | 12 ++++++------ test/runtests.jl | 9 +++++++++ 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index d0a2ceb..72f1163 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Soss" uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c" author = ["Chad Scherrer "] -version = "0.21.1" +version = "0.21.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -52,7 +52,7 @@ JuliaVariables = "0.2" MLStyle = "0.3,0.4" MacroTools = "0.5" MappedArrays = "0.3, 0.4" -MeasureBase = "0.9" +MeasureBase = "0.10, 0.11, 0.12" MeasureTheory = "0.16" NamedTupleTools = "0.12, 0.13, 0.14" NestedTuples = "0.3.9" @@ -61,16 +61,16 @@ Reexport = "1" Requires = "1" RuntimeGeneratedFunctions = "0.5" SampleChains = "0.5" -SimpleGraphs = "0.5, 0.6, 0.7" +SimpleGraphs = "= 0.7.18" SimplePartitions = "0.2, 0.3" -SimplePosets = "0.1" +SimplePosets = "= 0.1.5" SpecialFunctions = "1, 2" Static = "0.5, 0.6" StatsBase = "0.33" StatsFuns = "0.9, 1" SymbolicUtils = "0.17, 0.18, 0.19" TransformVariables = "0.5, 0.6" -TupleVectors = "0.1" +TupleVectors = "0.1, 0.2" julia = "1.6" [extras] diff --git a/src/primitives/as.jl b/src/primitives/as.jl index 1b4bbb3..77c07e6 100644 --- a/src/primitives/as.jl +++ b/src/primitives/as.jl @@ -44,9 +44,11 @@ function sourceXform(_data=NamedTuple()) rhs = st.rhs thecode = @q begin - _t = Soss.as($rhs, get(_data, $xname, NamedTuple())) - if !isnothing(_t) - _result = merge(_result, ($x=_t,)) + _d = get(_data, $xname, nothing) + if isnothing(_d) # xname is not defined in _data + _result = merge(_result, ($x = Soss.as($rhs),)) + elseif _d isa NamedTuple + _result = merge(_result, ($x = Soss.as($rhs, _d),)) end end @@ -90,11 +92,9 @@ function asTransform(supp:: Dists.RealInterval) return ScaledShiftedLogistic(ub-lb, lb) end -as(d, _data) = nothing - as(μ::AbstractMeasure, _data::NamedTuple) = as(μ) -as(d::Dists.AbstractMvNormal, _data::NamedTuple=NamedTuple()) = as(Array, size(d)) +as(d::Dists.AbstractMvNormal, _data::NamedTuple = NamedTuple()) = TV.as(Array, size(d)) @gg function _as(M::Type{<:TypeLevel}, _m::Model{Asub,B}, _args::A, _data) where {Asub,A,B} body = type2model(_m) |> sourceXform(_data) |> loadvals(_args, _data) diff --git a/test/runtests.jl b/test/runtests.jl index 4215a9e..116c365 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Soss using Test +using LinearAlgebra using MeasureTheory import TransformVariables as TV using TransformVariables: transform @@ -194,4 +195,12 @@ include("examples-list.jl") base = basemeasure(post) @test logdensity_def(base, (p=0.2, x=post.obs.x)) isa Real end + + @testset "https://github.com/cscherrer/Soss.jl/issues/342" begin + m = Soss.@model () begin + z ~ Dists.MvNormal(zeros(10), I) + end + t = Soss.as(m()) + @test TV.transform(t, zeros(10)) == (z = zeros(10), ) + end end