diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index a31cee06..b1aba865 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -180,7 +180,9 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0) false_branch_fn_name = gensym(:false_branch) all_input_vars = true_branch_input_list ∪ false_branch_input_list - all_output_vars = true_branch_assignments ∪ false_branch_assignments + filter!(x -> x != :(:), all_input_vars) + all_output_vars = all_true_branch_vars ∪ all_false_branch_vars + filter!(x -> x != :(:), all_output_vars) discard_vars !== nothing && setdiff!(all_output_vars, discard_vars) all_vars = all_input_vars ∪ all_output_vars diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 1f6fcbf4..b3443a73 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -19,7 +19,7 @@ end ReactantCore.is_traced(::TracedRArray) = true -new_traced_value(::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing) +new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x diff --git a/test/control_flow.jl b/test/control_flow.jl index fcfa45ce..349be1a1 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -329,10 +329,9 @@ end @test res_ra ≈ res end -# XXX: mutation is currently broken function condition10_condition_with_setindex(x) @trace if sum(x) > 0 - x[1, 1] = -1.0 + x[:, 1] = -1.0 else x[1, 1] = 1.0 end @@ -345,7 +344,57 @@ end res_ra = @jit(condition10_condition_with_setindex(x_ra)) @test res_ra[1, 1] == -1.0 broken = true - @test res_ra[2, 1] == 1.0 broken = true + @test res_ra[2, 1] == -1.0 broken = true @test x_ra[1, 1] == -1.0 broken = true - @test x_ra[2, 1] == 1.0 broken = true + @test x_ra[2, 1] == -1.0 broken = true + + x = -rand(2, 10) + x[2, 1] = 0.0 + x_ra = Reactant.to_rarray(x) + + res_ra = @jit(condition10_condition_with_setindex(x_ra)) + @test res_ra[1, 1] == -1.0 broken = true + @test res_ra[2, 1] == 0.0 broken = true + @test x_ra[1, 1] == -1.0 broken = true + @test x_ra[2, 1] == 0.0 broken = true +end + +function condition11_nested_ifff(x, y, z) + x_sum = sum(x) + @trace if x_sum > 0 + y_sum = sum(y) + if y_sum > 0 + if sum(z) > 0 + z = x_sum + y_sum + sum(z) + else + z = x_sum + y_sum + end + else + z = x_sum - y_sum + end + else + y_sum = sum(y) + z = x_sum - y_sum + end + return z +end + +@testset "condition11: nested if 3 levels deep" begin + x = rand(2, 10) + y = rand(2, 10) + z = rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + z_ra = Reactant.to_rarray(z) + + @test @jit(condition11_nested_ifff(x_ra, y_ra, z_ra)) ≈ condition11_nested_ifff(x, y, z) + + x = -rand(2, 10) + y = -rand(2, 10) + z = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + z_ra = Reactant.to_rarray(z) + + @test @jit(condition11_nested_ifff(x_ra, y_ra, z_ra)) ≈ condition11_nested_ifff(x, y, z) end