Skip to content

Commit

Permalink
fix aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Oct 29, 2024
1 parent d3eae72 commit f8f1597
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltyp

function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = default_alias_A(alg, prob.A, prob.b),
alias_b = default_alias_b(alg, prob.A, prob.b),
abstol = default_tol(real(eltype(prob.b))),
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
Expand All @@ -149,23 +147,50 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
Pr = nothing,
assumptions = OperatorAssumptions(issquare(prob.A)),
sensealg = LinearSolveAdjoint(),
alias = LinearAliases(),
alias = LinearAliasSpecifier(),
kwargs...)
(;A, b, u0, p) = prob

if isnothing(alias.alias_A)
to_alias_A = alias_A
has_A = haskey(kwargs,:alias_A)
has_b = haskey(kwargs,:alias_b)

if has_A || has_b
aliases = LinearAliasSpecifier()
if has_A
Base.depwarn("alias_A keyword argument is deprecated, to set `alias_A`,
please use a LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_A = true))", :alias_A)
SciMLBase.@reset aliases.alias_A = values(kwargs).alias_A
else
SciMLBase.@reset aliases.alias_A = default_alias_A(alg, prob.A, prob.b)
end

if has_b
Base.depwarn("alias_b keyword argument is deprecated, to set `alias_b`,
please use an LinearAliasSpecifier, e.g. `solve(prob, alias = LinearAliasSpecifier(alias_b = true))", :alias_b)
SciMLBase.@reset aliases.alias_b = values(kwargs).alias_b
else
SciMLBase.@reset aliases.alias_b = default_alias_b(alg, prob.A, prob.b)
end

aliases
else
to_alias_A = alias.alias_A
# If alias isa Bool, all fields of ODEAliases set to alias
if alias isa Bool
aliases = LinearAliasSpecifier(alias = alias)
elseif alias isa LinearAliasSpecifier || isnothing(alias)
aliases = alias
end

if isnothing(aliases.alias_A)
SciMLBase.@reset aliases.alias_A = default_alias_A(alg,prob.A,prob.b)
end
if isnothing(aliases.alias_b)
SciMLBase.@reset aliases.alias_b = default_alias_b(alg,prob.A,prob.b)
end
aliases
end

if isnothing(alias.alias_b)
to_alias_b = alias_b
else
to_alias_b = alias.alias_b
end

A = if to_alias_A || A isa SMatrix
A = if aliases.alias_A || A isa SMatrix
A
elseif A isa Array
copy(A)
Expand All @@ -177,7 +202,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,

b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
Array(b) # the solution to a linear solve will always be dense!
elseif to_alias_b || b isa SVector
elseif aliases.alias_b || b isa SVector
b
elseif b isa Array
copy(b)
Expand Down

0 comments on commit f8f1597

Please sign in to comment.