Skip to content

Commit

Permalink
KA extension update
Browse files Browse the repository at this point in the history
  • Loading branch information
sshin23 committed Nov 6, 2024
1 parent 21cdbf3 commit 24af3b2
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions ext/ExaModelsKernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ end
function _grad_structure!(backend, objs::ExaModels.ObjectiveNull, gsparsity) end

function ExaModels.jac_structure!(
m::ExaModels.ExaModel{T,VT,E} where {T,VT,E<:KAExtension},
m::ExaModels.ExaModel{T,VT,E},
rows::V,
cols::V,
) where {V<:AbstractVector}
) where {T,VT,E<:KAExtension,V<:AbstractVector}
if !isempty(rows)
_jac_structure!(m.ext.backend, m.cons, rows, cols)
end
Expand All @@ -184,10 +184,10 @@ function _jac_structure!(backend, cons::ExaModels.ConstraintNull, rows, cols) en


function ExaModels.hess_structure!(
m::ExaModels.ExaModel{T,VT,E} where {T,VT,E<:KAExtension},
m::ExaModels.ExaModel{T,VT,E},
rows::V,
cols::V,
) where {V<:AbstractVector}
) where {T,VT,E<:KAExtension,V<:AbstractVector}
if !isempty(rows)
_obj_hess_structure!(m.ext.backend, m.objs, rows, cols)
_con_hess_structure!(m.ext.backend, m.cons, rows, cols)
Expand Down Expand Up @@ -277,10 +277,10 @@ end
function _conaugs!(backend, y, con::ExaModels.ConstraintNull, x) end

function ExaModels.grad!(
m::ExaModels.ExaModel{T,VT,E} where {T,VT,E<:KAExtension},
m::ExaModels.ExaModel{T,VT,E},
x::V,
y::V,
) where {V<:AbstractVector}
) where {T,VT,E<:KAExtension,V<:AbstractVector}
gradbuffer = m.ext.gradbuffer

if !isempty(gradbuffer)
Expand Down Expand Up @@ -308,10 +308,10 @@ end
function _grad!(backend, y, objs::ExaModels.ObjectiveNull, x) end

function ExaModels.jac_coord!(
m::ExaModels.ExaModel{T,VT,E} where {T,VT,E<:KAExtension},
m::ExaModels.ExaModel{T,VT,E},
x::V,
y::V,
) where {V<:AbstractVector}
) where {T,VT,E<:KAExtension,V<:AbstractVector}
fill!(y, zero(eltype(y)))
_jac_coord!(m.ext.backend, y, m.cons, x)
return y
Expand Down Expand Up @@ -355,7 +355,7 @@ function ExaModels.jprod_nln!(
x::AbstractVector,
v::AbstractVector,
Jv::AbstractVector,
) where {T,VT,E<:KAExtension{T,VT,NamedTuple}}
) where {T,VT,N <: NamedTuple, E<:KAExtension{T,VT,N}}

fill!(Jv, zero(eltype(Jv)))
fill!(m.ext.prodhelper.jacbuffer, zero(eltype(Jv)))
Expand All @@ -377,7 +377,7 @@ function ExaModels.jtprod_nln!(
x::AbstractVector,
v::AbstractVector,
Jtv::AbstractVector,
) where {T,VT,E<:KAExtension{T,VT,NamedTuple}}
) where {T,VT,N <: NamedTuple, E<:KAExtension{T,VT,N}}

fill!(Jtv, zero(eltype(Jtv)))
fill!(m.ext.prodhelper.jacbuffer, zero(eltype(Jtv)))
Expand All @@ -401,7 +401,7 @@ function ExaModels.hprod!(
v::AbstractVector,
Hv::AbstractVector;
obj_weight = one(eltype(x)),
) where {T,VT,E<:KAExtension{T,VT,NamedTuple}}
) where {T,VT,N <: NamedTuple, E<:KAExtension{T,VT,NamedTuple}}

fill!(Hv, zero(eltype(Hv)))
fill!(m.ext.prodhelper.hessbuffer, zero(eltype(Hv)))
Expand Down Expand Up @@ -465,12 +465,12 @@ end


function ExaModels.hess_coord!(
m::ExaModels.ExaModel{T,VT,E} where {T,VT,E<:KAExtension},
m::ExaModels.ExaModel{T,VT,E},
x::V,
y::V,
hess::V;
obj_weight = one(eltype(y)),
) where {V<:AbstractVector}
) where {T,VT,E<:KAExtension, V<:AbstractVector}
fill!(hess, zero(eltype(hess)))
_obj_hess_coord!(m.ext.backend, hess, m.objs, x, obj_weight)
_con_hess_coord!(m.ext.backend, hess, m.cons, x, y)
Expand Down

0 comments on commit 24af3b2

Please sign in to comment.