From 7d766d88d54ce38b9ebc943aea78b95b0533983d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 Oct 2023 17:02:10 -0400 Subject: [PATCH] Prevent KrylovJL use for RecursiveArrayTools --- Project.toml | 6 +++++- ext/LinearSolveRecursiveArrayToolsExt.jl | 12 ++++++++++++ src/default.jl | 2 +- 3 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 ext/LinearSolveRecursiveArrayToolsExt.jl diff --git a/Project.toml b/Project.toml index 186db01b7..34e9bbea8 100644 --- a/Project.toml +++ b/Project.toml @@ -33,14 +33,15 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" [extensions] LinearSolveBandedMatricesExt = "BandedMatrices" @@ -53,6 +54,7 @@ LinearSolveKernelAbstractionsExt = "KernelAbstractions" LinearSolveKrylovKitExt = "KrylovKit" LinearSolveMetalExt = "Metal" LinearSolvePardisoExt = "Pardiso" +LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools" [compat] ArrayInterface = "7.4.11" @@ -72,6 +74,7 @@ Krylov = "0.9" KrylovKit = "0.5, 0.6" PrecompileTools = "1" Preferences = "1" +RecursiveArrayTools = "2" RecursiveFactorization = "0.2.8" Reexport = "1" Requires = "1" @@ -99,6 +102,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/ext/LinearSolveRecursiveArrayToolsExt.jl b/ext/LinearSolveRecursiveArrayToolsExt.jl new file mode 100644 index 000000000..98da9583c --- /dev/null +++ b/ext/LinearSolveRecursiveArrayToolsExt.jl @@ -0,0 +1,12 @@ +module LinearSolveRecursiveArrayToolsExt + +using LinearSolve, RecursiveArrayTools +import LinearSolve: init_cacheval + +# Krylov.jl tries to init with `ArrayPartition(undef, ...)`. Avoid hitting that! +function init_cacheval(alg::LinearSolve.KrylovJL, A, b::ArrayPartition, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions) + return nothing +end + +end diff --git a/src/default.jl b/src/default.jl index 04d9e0229..509a14571 100644 --- a/src/default.jl +++ b/src/default.jl @@ -93,7 +93,7 @@ end end function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions) - if assump.condition === OperatorConodition.IllConditioned || !assump.issq + if assump.condition === OperatorCondition.IllConditioned || !assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) else @static if VERSION >= v"1.8-"