Skip to content

Commit

Permalink
Allow splitting Basis created using ElementVector (#811)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala authored Nov 29, 2021
1 parent 264c9e9 commit be9642f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
19 changes: 15 additions & 4 deletions skfem/assembly/basis/abstract_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import numpy as np
from numpy import ndarray
from skfem.assembly.dofs import Dofs, DofsView
from skfem.element import DiscreteField, Element, ElementComposite
from skfem.element import (DiscreteField, Element, ElementComposite,
ElementVector)
from skfem.mapping import Mapping
from skfem.mesh import Mesh
from skfem.quadrature import get_quadrature
Expand Down Expand Up @@ -360,11 +361,17 @@ def linear_combination(n, refn):

def split_indices(self) -> List[ndarray]:
"""Return indices for the solution components."""
if isinstance(self.elem, ElementComposite):
if ((isinstance(self.elem, ElementComposite)
or isinstance(self.elem, ElementVector))):
nelems = (len(self.elem.elems)
if isinstance(self.elem, ElementComposite)
else self.mesh.dim())
o = np.zeros(4, dtype=np.int64)
output: List[ndarray] = []
for k in range(len(self.elem.elems)):
e = self.elem.elems[k]
for k in range(nelems):
e = (self.elem.elems[k]
if isinstance(self.elem, ElementComposite)
else self.elem.elem)
output.append(np.concatenate((
self.nodal_dofs[o[0]:(o[0] + e.nodal_dofs)].flatten('F'),
self.edge_dofs[o[1]:(o[1] + e.edge_dofs)].flatten('F'),
Expand All @@ -385,6 +392,10 @@ def split_bases(self) -> List['AbstractBasis']:
return [type(self)(self.mesh, e, self.mapping,
quadrature=self.quadrature)
for e in self.elem.elems]
elif isinstance(self.elem, ElementVector):
return [type(self)(self.mesh, self.elem.elem, self.mapping,
quadrature=self.quadrature)
for _ in range(self.mesh.dim())]
raise ValueError("AbstractBasis.elem has only a single component!")

@property
Expand Down
7 changes: 7 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ def bilinf(u, p, v, q, w):
self.assertTrue(abs((C1 - C2).min()) < 1e-10)
self.assertTrue(abs((C1 - C2).max()) < 1e-10)

# test splitting ElementVector
(ux, uxbasis), (uy, uybasis) = u_basis.split(u)
assert_allclose(ux[uxbasis.nodal_dofs[0]], u[u_basis.nodal_dofs[0]])
assert_allclose(ux[uxbasis.facet_dofs[0]], u[u_basis.facet_dofs[0]])
assert_allclose(uy[uybasis.nodal_dofs[0]], u[u_basis.nodal_dofs[1]])
assert_allclose(uy[uybasis.facet_dofs[0]], u[u_basis.facet_dofs[1]])


class TestCompositeFacetAssembly(TestCase):

Expand Down

0 comments on commit be9642f

Please sign in to comment.