Skip to content

Commit

Permalink
fix: partial fix to mutation issue
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 1, 2024
1 parent b63beff commit 7e0ba29
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 6 deletions.
4 changes: 3 additions & 1 deletion lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0)
false_branch_fn_name = gensym(:false_branch)

all_input_vars = true_branch_input_list false_branch_input_list
all_output_vars = true_branch_assignments false_branch_assignments
filter!(x -> x != :(:), all_input_vars)
all_output_vars = all_true_branch_vars all_false_branch_vars
filter!(x -> x != :(:), all_output_vars)
discard_vars !== nothing && setdiff!(all_output_vars, discard_vars)

all_vars = all_input_vars all_output_vars
Expand Down
2 changes: 1 addition & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end

ReactantCore.is_traced(::TracedRArray) = true

new_traced_value(::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing)
new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A))

TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x

Expand Down
57 changes: 53 additions & 4 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,9 @@ end
@test res_ra res
end

# XXX: mutation is currently broken
function condition10_condition_with_setindex(x)
@trace if sum(x) > 0
x[1, 1] = -1.0
x[:, 1] = -1.0
else
x[1, 1] = 1.0
end
Expand All @@ -345,7 +344,57 @@ end

res_ra = @jit(condition10_condition_with_setindex(x_ra))
@test res_ra[1, 1] == -1.0 broken = true
@test res_ra[2, 1] == 1.0 broken = true
@test res_ra[2, 1] == -1.0 broken = true
@test x_ra[1, 1] == -1.0 broken = true
@test x_ra[2, 1] == 1.0 broken = true
@test x_ra[2, 1] == -1.0 broken = true

x = -rand(2, 10)
x[2, 1] = 0.0
x_ra = Reactant.to_rarray(x)

res_ra = @jit(condition10_condition_with_setindex(x_ra))
@test res_ra[1, 1] == -1.0 broken = true
@test res_ra[2, 1] == 0.0 broken = true
@test x_ra[1, 1] == -1.0 broken = true
@test x_ra[2, 1] == 0.0 broken = true
end

function condition11_nested_ifff(x, y, z)
x_sum = sum(x)
@trace if x_sum > 0
y_sum = sum(y)
if y_sum > 0
if sum(z) > 0
z = x_sum + y_sum + sum(z)
else
z = x_sum + y_sum
end
else
z = x_sum - y_sum
end
else
y_sum = sum(y)
z = x_sum - y_sum
end
return z
end

@testset "condition11: nested if 3 levels deep" begin
x = rand(2, 10)
y = rand(2, 10)
z = rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)
z_ra = Reactant.to_rarray(z)

@test @jit(condition11_nested_ifff(x_ra, y_ra, z_ra)) condition11_nested_ifff(x, y, z)

x = -rand(2, 10)
y = -rand(2, 10)
z = -rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)
z_ra = Reactant.to_rarray(z)

@test @jit(condition11_nested_ifff(x_ra, y_ra, z_ra)) condition11_nested_ifff(x, y, z)
end

0 comments on commit 7e0ba29

Please sign in to comment.