Skip to content

Commit

Permalink
Added docstrings to more functions and types (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
sshin23 authored Sep 8, 2023
1 parent 575526c commit 42a50e5
Show file tree
Hide file tree
Showing 9 changed files with 454 additions and 121 deletions.
4 changes: 3 additions & 1 deletion src/ExaModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ export ExaModel,
multipliers,
multipliers_L,
multipliers_U,
WrapperNLPModel
WrapperNLPModel,
@register_univariate,
@register_bivariate

end # module ExaModels
73 changes: 59 additions & 14 deletions src/gradient.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,78 @@
@inbounds @inline function drpass(d::D, y, adj) where {D<:AdjointNode1}
"""
drpass(d::D, y, adj)
Performs dense gradient evaluation via the reverse pass on the computation (sub)graph formed by forward pass
# Arguments:
- `d`: first-order computation (sub)graph
- `y`: result vector
- `adj`: adjoint propagated up to the current node
"""
@inline function drpass(d::D, y, adj) where {D<:AdjointNode1}
offset = drpass(d.inner, y, adj * d.y)
nothing
end
@inbounds @inline function drpass(d::D, y, adj) where {D<:AdjointNode2}
@inline function drpass(d::D, y, adj) where {D<:AdjointNode2}
offset = drpass(d.inner1, y, adj * d.y1)
offset = drpass(d.inner2, y, adj * d.y2)
nothing
end
@inbounds @inline function drpass(d::D, y, adj) where {D<:AdjointNodeVar}
y[d.i] += adj
@inline function drpass(d::D, y, adj) where {D<:AdjointNodeVar}
@inbounds y[d.i] += adj
nothing
end
@inbounds @inline function drpass(f::F, x, y, adj) where {F<:SIMDFunction} end
@inline function drpass(f::F, x, y, adj) where {F<:SIMDFunction} end

"""
gradient!(y, f, x, adj)
Performs dense gradient evalution
# Arguments:
- `y`: result vector
- `f`: the function to be differentiated in `SIMDFunction` format
- `x`: variable vector
- `adj`: initial adjoint
"""
function gradient!(y, f, x, adj)
@simd for k in eachindex(f.itr)
drpass(f.f.f(f.itr[k], AdjointNodeSource(x)), y, adj)
@inbounds drpass(f.f.f(f.itr[k], AdjointNodeSource(x)), y, adj)
end
return y
end

"""
grpass(d::D, comp, y, o1, cnt, adj)
Performs dsparse gradient evaluation via the reverse pass on the computation (sub)graph formed by forward pass
@inbounds @inline function grpass(d::D, comp, y, o1, cnt, adj) where {D<:AdjointNode1}
# Arguments:
- `d`: first-order computation (sub)graph
- `comp`: a `Compressor`, which helps map counter to sparse vector index
- `y`: result vector
- `o1`: index offset
- `cnt`: counter
- `adj`: adjoint propagated up to the current node
"""
@inline function grpass(d::D, comp, y, o1, cnt, adj) where {D<:AdjointNode1}
cnt = grpass(d.inner, comp, y, o1, cnt, adj * d.y)
return cnt
end
@inbounds @inline function grpass(d::D, comp, y, o1, cnt, adj) where {D<:AdjointNode2}
@inline function grpass(d::D, comp, y, o1, cnt, adj) where {D<:AdjointNode2}
cnt = grpass(d.inner1, comp, y, o1, cnt, adj * d.y1)
cnt = grpass(d.inner2, comp, y, o1, cnt, adj * d.y2)
return cnt
end
@inbounds @inline function grpass(d::D, comp, y, o1, cnt, adj) where {D<:AdjointNodeVar}
y[o1+comp(cnt += 1)] += adj
@inline function grpass(d::D, comp, y, o1, cnt, adj) where {D<:AdjointNodeVar}
@inbounds y[o1+comp(cnt += 1)] += adj
return cnt
end

@inbounds @inline function grpass(d::AdjointNodeVar, comp::Nothing, y, o1, cnt, adj) # despecialization
@inline function grpass(d::AdjointNodeVar, comp::Nothing, y, o1, cnt, adj) # despecialization
push!(y, d.i)
return (cnt += 1)
end
@inbounds @inline function grpass(
@inline function grpass(
d::D,
comp,
y::V,
Expand All @@ -47,13 +81,24 @@ end
adj,
) where {D<:AdjointNodeVar,V<:AbstractVector{Tuple{Int,Int}}}
ind = o1 + comp(cnt += 1)
y[ind] = (d.i, ind)
@inbounds y[ind] = (d.i, ind)
return cnt
end

"""
sgradient!(y, f, x, adj)
Performs sparse gradient evalution
# Arguments:
- `y`: result vector
- `f`: the function to be differentiated in `SIMDFunction` format
- `x`: variable vector
- `adj`: initial adjoint
"""
function sgradient!(y, f, x, adj)
@simd for k in eachindex(f.itr)
grpass(f.f.f(f.itr[k], AdjointNodeSource(x)), f.itr.comp1, y, offset1(f, k), 0, adj)
@inbounds grpass(f.f.f(f.itr[k], AdjointNodeSource(x)), f.itr.comp1, y, offset1(f, k), 0, adj)
end
return y
end
177 changes: 148 additions & 29 deletions src/graph.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,66 @@
abstract type AbstractIndex end
# Abstract node type for the computation graph for symbolic expression
abstract type AbstractNode end
abstract type AbstractPar <: AbstractNode end

# Abstract node type for first-order forward pass tree
abstract type AbstractAdjointNode end

# Abstract node type for the computation graph for second-order forward pass
abstract type AbstractSecondAdjointNode end

"""
Var{I}
A variable node for symbolic expression tree
# Fields:
- `i::I`: (parameterized) index
"""
struct Var{I} <: AbstractNode
i::I
end
struct Par <: AbstractPar end
struct ParIndexed{I,J} <: AbstractPar

"""
ParSource
A source of parameterized data
"""
struct ParSource <: AbstractNode end

"""
ParIndexed{I, J}
A parameterized data node
# Fields:
- `inner::I`: parameter for the data
"""
struct ParIndexed{I,J} <: AbstractNode
inner::I
end

@inline ParIndexed(inner::I, n) where {I} = ParIndexed{I,n}(inner)
"""
Node1{F, I}
A node with one child for symbolic expression tree
# Fields:
- `inner::I`: children
"""
struct Node1{F,I} <: AbstractNode
inner::I
end

"""
Node2{F, I1, I2}
A node with two children for symbolic expression tree
# Fields:
- `inner1::I1`: children #1
- `inner2::I2`: children #2
"""
struct Node2{F,I1,I2} <: AbstractNode
inner1::I1
inner2::I2
Expand All @@ -27,9 +73,9 @@ struct SecondFixed{F}
inner::F
end

@inline Base.getindex(n::Par, i) = ParIndexed(n, i)
@inline Base.getindex(n::ParSource, i) = ParIndexed(n, i)

Par(iter::DataType) = Par()
Par(iter::DataType) = ParSource()
Par(iter, idx...) = ParIndexed(Par(iter, idx[2:end]...), idx[1])
Par(iter::Type{T}, idx...) where {T<:Tuple} =
Tuple(Par(p, i, idx...) for (i, p) in enumerate(T.parameters))
Expand All @@ -44,60 +90,118 @@ Par(iter::Type{T}, idx...) where {T<:NamedTuple} = NamedTuple{T.parameters[1]}(

struct Identity end

struct NaNSource{T} <: AbstractVector{T} end
@inline Base.getindex(::NaNSource{T}, i) where {T} = T(NaN)

@inline (v::Var{I})(i, x) where {I} = @inbounds x[v.i(i, x)]
@inline (v::ParSource)(i, x) = i
@inline (v::ParIndexed{I,n})(i, x) where {I,n} = @inbounds v.inner(i, x)[n]

@inbounds @inline (v::Var{I})(i, x) where {I} = x[v.i(i, x)]
@inbounds @inline (v::Par)(i, x) = i
@inbounds @inline (v::ParIndexed{I,n})(i, x) where {I,n} = v.inner(i, x)[n]
(v::ParIndexed)(i::Identity, x) = NaN16 # despecialized
(v::ParSource)(i::Identity, x) = NaN16 # despecialized
(v::Var)(i::Identity, x) = @inbounds x[v.i] # despecialized

@inbounds (v::ParIndexed)(i::Identity, x) = NaN16 # despecialized
@inbounds (v::Par)(i::Identity, x) = NaN16 # despecialized
@inbounds (v::Var)(i::Identity, x) = x[v.i] # despecialized
"""
AdjointNode1{F, T, I}
A node with one child for first-order forward pass tree
# Fields:
- `x::T`: function value
- `y::T`: first-order sensitivity
- `inner::I`: children
"""
struct AdjointNode1{F,T,I} <: AbstractAdjointNode
x::T
y::T
inner::I
end
"""
AdjointNode2{F, T, I1, I2}
A node with two children for first-order forward pass tree
# Fields:
- `x::T`: function value
- `y1::T`: first-order sensitivity w.r.t. first argument
- `y2::T`: first-order sensitivity w.r.t. second argument
- `inner1::I1`: children #1
- `inner2::I2`: children #2
"""
struct AdjointNode2{F,T,I1,I2} <: AbstractAdjointNode
x::T
y1::T
y2::T
inner1::I1
inner2::I2
end
"""
AdjointNodeVar{I, T}
A variable node for first-order forward pass tree
# Fields:
- `i::I`: index
- `x::T`: value
"""
struct AdjointNodeVar{I,T} <: AbstractAdjointNode
i::I
x::T
end
struct AdjointNodeSource{T,VT<:AbstractVector{T}}

"""
AdjointNodeSource{VT}
A source of `AdjointNode`. `adjoint_node_source[i]` returns an `AdjointNodeVar` at index `i`.
# Fields:
- `inner::VT`: variable vector
"""
struct AdjointNodeSource{VT}
inner::VT
end
struct AdjointNodeNullSource end

@inline AdjointNode1(f::F, x::T, y, inner::I) where {F,T,I} =
AdjointNode1{F,T,I}(x, y, inner)
@inline AdjointNode2(f::F, x::T, y1, y2, inner1::I1, inner2::I2) where {F,T,I1,I2} =
AdjointNode2{F,T,I1,I2}(x, y1, y2, inner1, inner2)

@inline Base.getindex(x::I, i) where {I<:AdjointNodeSource{Nothing}} =
AdjointNodeVar(i, NaN16)
@inline Base.getindex(x::I, i) where {I<:AdjointNodeSource} =
@inbounds AdjointNodeVar(i, x.inner[i])

AdjointNodeSource(::Nothing) = AdjointNodeNullSource()

@inbounds @inline Base.getindex(x::I, i) where {I<:AdjointNodeNullSource} =
AdjointNodeVar(i, NaN16)
@inbounds @inline Base.getindex(x::I, i) where {I<:AdjointNodeSource} =
AdjointNodeVar(i, x.inner[i])
"""
SecondAdjointNode1{F, T, I}
A node with one child for second-order forward pass tree
# Fields:
- `x::T`: function value
- `y::T`: first-order sensitivity
- `h::T`: second-order sensitivity
- `inner::I`: DESCRIPTION
"""
struct SecondAdjointNode1{F,T,I} <: AbstractSecondAdjointNode
x::T
y::T
h::T
inner::I
end
"""
SecondAdjointNode2{F, T, I1, I2}
A node with one child for second-order forward pass tree
# Fields:
- `x::T`: function value
- `y1::T`: first-order sensitivity w.r.t. first argument
- `y2::T`: first-order sensitivity w.r.t. first argument
- `h11::T`: second-order sensitivity w.r.t. first argument
- `h12::T`: second-order sensitivity w.r.t. first and second argument
- `h22::T`: second-order sensitivity w.r.t. second argument
- `inner1::I1`: children #1
- `inner2::I2`: children #2
"""
struct SecondAdjointNode2{F,T,I1,I2} <: AbstractSecondAdjointNode
x::T
y1::T
Expand All @@ -109,11 +213,29 @@ struct SecondAdjointNode2{F,T,I1,I2} <: AbstractSecondAdjointNode
inner2::I2
end

"""
SecondAdjointNodeVar{I, T}
A variable node for first-order forward pass tree
# Fields:
- `i::I`: index
- `x::T`: value
"""
struct SecondAdjointNodeVar{I,T} <: AbstractSecondAdjointNode
i::I
x::T
end
struct SecondAdjointNodeSource{T,VT<:AbstractVector{T}}

"""
SecondAdjointNodeSource{VT}
A source of `AdjointNode`. `adjoint_node_source[i]` returns an `AdjointNodeVar` at index `i`.
# Fields:
- `inner::VT`: variable vector
"""
struct SecondAdjointNodeSource{VT}
inner::VT
end

Expand All @@ -132,10 +254,7 @@ end
) where {F,T,I1,I2} =
SecondAdjointNode2{F,T,I1,I2}(x, y1, y2, h11, h12, h22, inner1, inner2)

struct SecondAdjointNodeNullSource end
SecondAdjointNodeSource(::Nothing) = SecondAdjointNodeNullSource()

@inbounds @inline Base.getindex(x::I, i) where {I<:SecondAdjointNodeNullSource} =
SecondAdjointNodeVar(i, NaN)
@inbounds @inline Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource} =
SecondAdjointNodeVar(i, x.inner[i])
@inline Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource{Nothing}} =
SecondAdjointNodeVar(i, NaN16)
@inline Base.getindex(x::I, i) where {I<:SecondAdjointNodeSource} =
@inbounds SecondAdjointNodeVar(i, x.inner[i])
Loading

0 comments on commit 42a50e5

Please sign in to comment.