diff --git a/.nojekyll b/.nojekyll
new file mode 100644
index 0000000..e69de29
diff --git a/doctrees/environment.pickle b/doctrees/environment.pickle
new file mode 100644
index 0000000..3b82022
Binary files /dev/null and b/doctrees/environment.pickle differ
diff --git a/doctrees/index.doctree b/doctrees/index.doctree
new file mode 100644
index 0000000..d5e2745
Binary files /dev/null and b/doctrees/index.doctree differ
diff --git a/doctrees/modules.doctree b/doctrees/modules.doctree
new file mode 100644
index 0000000..b93e93e
Binary files /dev/null and b/doctrees/modules.doctree differ
diff --git a/doctrees/pancax.bcs.doctree b/doctrees/pancax.bcs.doctree
new file mode 100644
index 0000000..1489074
Binary files /dev/null and b/doctrees/pancax.bcs.doctree differ
diff --git a/doctrees/pancax.bvps.doctree b/doctrees/pancax.bvps.doctree
new file mode 100644
index 0000000..63495f5
Binary files /dev/null and b/doctrees/pancax.bvps.doctree differ
diff --git a/doctrees/pancax.constitutive_models.doctree b/doctrees/pancax.constitutive_models.doctree
new file mode 100644
index 0000000..d894f50
Binary files /dev/null and b/doctrees/pancax.constitutive_models.doctree differ
diff --git a/doctrees/pancax.data.doctree b/doctrees/pancax.data.doctree
new file mode 100644
index 0000000..644ca0b
Binary files /dev/null and b/doctrees/pancax.data.doctree differ
diff --git a/doctrees/pancax.doctree b/doctrees/pancax.doctree
new file mode 100644
index 0000000..0d2bf50
Binary files /dev/null and b/doctrees/pancax.doctree differ
diff --git a/doctrees/pancax.domains.doctree b/doctrees/pancax.domains.doctree
new file mode 100644
index 0000000..629db48
Binary files /dev/null and b/doctrees/pancax.domains.doctree differ
diff --git a/doctrees/pancax.fem.doctree b/doctrees/pancax.fem.doctree
new file mode 100644
index 0000000..fe06a2d
Binary files /dev/null and b/doctrees/pancax.fem.doctree differ
diff --git a/doctrees/pancax.fem.elements.doctree b/doctrees/pancax.fem.elements.doctree
new file mode 100644
index 0000000..155adcc
Binary files /dev/null and b/doctrees/pancax.fem.elements.doctree differ
diff --git a/doctrees/pancax.kernels.doctree b/doctrees/pancax.kernels.doctree
new file mode 100644
index 0000000..861766f
Binary files /dev/null and b/doctrees/pancax.kernels.doctree differ
diff --git a/doctrees/pancax.loss_functions.doctree b/doctrees/pancax.loss_functions.doctree
new file mode 100644
index 0000000..cf14539
Binary files /dev/null and b/doctrees/pancax.loss_functions.doctree differ
diff --git a/doctrees/pancax.math.doctree b/doctrees/pancax.math.doctree
new file mode 100644
index 0000000..169ef69
Binary files /dev/null and b/doctrees/pancax.math.doctree differ
diff --git a/doctrees/pancax.networks.doctree b/doctrees/pancax.networks.doctree
new file mode 100644
index 0000000..9e9bed0
Binary files /dev/null and b/doctrees/pancax.networks.doctree differ
diff --git a/doctrees/pancax.optimizers.doctree b/doctrees/pancax.optimizers.doctree
new file mode 100644
index 0000000..91f02ff
Binary files /dev/null and b/doctrees/pancax.optimizers.doctree differ
diff --git a/html/.buildinfo b/html/.buildinfo
new file mode 100644
index 0000000..8d73133
--- /dev/null
+++ b/html/.buildinfo
@@ -0,0 +1,4 @@
+# Sphinx build info version 1
+# This file records the configuration used when building these files. When it is not found, a full rebuild will be done.
+config: 030c48f1638182c071f4607d697d8ccf
+tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/html/_modules/index.html b/html/_modules/index.html
new file mode 100644
index 0000000..b17cec1
--- /dev/null
+++ b/html/_modules/index.html
@@ -0,0 +1,184 @@
+
+
+
+
+
+[docs]
+defget_edges(domain,sset_names:List[str]):
+ mesh=domain.fspace.mesh
+ ssets=[mesh.sideSets[name]fornameinsset_names]
+ edges=jnp.vstack(ssets)
+ edges=jnp.sort(edges,axis=1)
+ edges=jnp.unique(edges,axis=0)# TODO not sure about this one
+ returnedges
+[docs]
+classEssentialBC(eqx.Module):
+"""
+ :param nodeSet: A name for a nodeset in the mesh
+ :param component: The dof to apply the essential bc to
+ :param function: A function f(x, t) = u that gives the value
+ to enforce on the (nodeset, component) of a field.
+ This defaults to the zero function
+ """
+ nodeSet:str
+ component:int
+ function:Optional[BCFunc]=lambdax,t:0.0
+
+
+fromjaxtypingimportArray,Float
+frompancax.femimportsurface
+fromtypingimportCallable,Optional
+importequinoxaseqx
+importjax
+importjax.numpyasjnp
+
+
+BCFunc=Callable[[Float[Array,"nd"],float],Float[Array,"nf"]]
+# remove component from the definition.
+# it doesn't appear to be doing anything
+# class NaturalBC(NamedTuple):
+
+[docs]
+classConstitutiveModel(ABC):
+"""
+ Base class for consistutive models.
+
+ The interface to be defined by derived classes include
+ the energy method and the unpack_properties method.
+
+ :param n_properties: The number of properties
+ :param property_names: The names of the properties
+ """
+ n_properties:int
+ property_names:List[str]
+
+
+[docs]
+ @abstractmethod
+ defunpack_properties(self,props:Float[Array,"np"]):
+"""
+ This method unpacks properties from 'props' and returns
+ them with potentially static properties bound to the model.
+ """
+ pass
+[docs]
+classFullFieldData(eqx.Module):
+"""
+ Data structure to store full field data used as ground truth
+ for output fields of a PINN when solving inverse problems.
+
+ :param inputs: Data that serves as inputs to the PINN
+ :param outputs: Data that serves as outputs of the PINN
+ :param n_time_steps: Variable used for book keeping
+ """
+ inputs:Array
+ outputs:Array
+ n_time_steps:int
+
+
+fromjaxtypingimportArray
+fromtypingimportOptional,Union
+importequinoxaseqx
+importexodus3asexodus
+importjax.numpyasjnp
+importmatplotlib.pyplotasplt
+importnumpyasnp
+importpandas
+
+
+# TODO currently hardcoded to force which may be limiting
+# for others interested in doing other physics
+
+[docs]
+classGlobalData(eqx.Module):
+"""
+ Data structure that holds global data to be used as
+ ground truth for some global field calculated from
+ PINN outputs used in inverse modeling training
+
+ :param times: A set of times used to compare to physics calculations
+ :param displacements: Currently hardcoded to use a displacement-force curve TODO
+ :param outputs: Field used as ground truth, hardcoded essentially to a reaction force now
+ :param n_nodes: Book-keeping variable for number of nodes on nodeset to measure global response from
+ :param n_time_steps: Book-keeping variable
+ :param reaction_nodes: Node set nodes for where to measure reaction forces
+ :param reaction_dof: Degree of freedom to use for reaction force calculation
+ """
+ times:Array# change to inputs?
+ displacements:Array
+ outputs:Array
+ n_nodes:int
+ n_time_steps:int
+ reaction_nodes:Array
+ reaction_dof:int
+
+
+[docs]
+classBaseDomain(eqx.Module):
+"""
+ Base domain for all problem types.
+ This holds essential things for the problem
+ such as a mesh to load a geometry from,
+ times, physics, bcs, etc.
+
+ :param mesh: A mesh from an exodus file most likely
+ :param coords: An array of coordinates
+ :param times: An array of times
+ :param physics: An initialized physics object
+ :param essential_bcs: A list of EssentialBCs
+ :param natural_bcs: a list of NaturalBCs
+ :param dof_manager: A DofManager for keeping track of essential bcs
+ """
+ mesh_file:str
+ mesh:Mesh
+ coords:Float[Array,"nn nd"]
+ times:Union[Float[Array,"nt"],Float[Array,"nn 1"]]
+ physics:PhysicsKernel
+ essential_bcs:List[EssentialBC]
+ natural_bcs:List[NaturalBC]
+ dof_manager:DofManager
+
+
+[docs]
+ def__init__(
+ self,
+ physics:PhysicsKernel,
+ essential_bcs:List[EssentialBC],
+ natural_bcs:any,# TODO figure out to handle this
+ mesh_file:str,
+ times:Float[Array,"nt"],
+ p_order:Optional[int]=1
+ ):
+ withTimer('BaseDomain.__init__'):
+ # setup
+ n_dofs=physics.n_dofs
+
+ # mesh
+ mesh=read_exodus_mesh(mesh_file)
+
+ # if tri mesh, we can make it higher order from lower order
+ iftype(mesh.parentElement)==SimplexTriElement:
+ mesh=create_higher_order_mesh_from_simplex_mesh(mesh,p_order,copyNodeSets=True)
+ else:
+ print('WARNING: Ignoring polynomial order flag for non tri mesh')
+
+ withTimer('move coordinates to device'):
+ coords=jnp.array(mesh.coords)
+
+ # dof book keeping
+ dof_manager=DofManager(mesh,n_dofs,essential_bcs)
+ # TODO move below to dof manager
+ dof_manager.isUnknown=jnp.array(dof_manager.isUnknown,dtype=jnp.bool)
+ dof_manager.unknownIndices=jnp.array(dof_manager.unknownIndices)
+
+ # setting all at once
+ self.mesh_file=mesh_file
+ self.mesh=mesh
+ self.coords=coords
+ self.times=times
+ self.physics=physics
+ # self.essential_bcs = essential_bcs
+ self.essential_bcs=EssentialBCSet(essential_bcs)
+ self.natural_bcs=natural_bcs
+ self.dof_manager=dof_manager
+[docs]
+classCollocationDomain(BaseDomain):
+"""
+ Base domain for all problem types.
+ This holds essential things for the problem
+ such as a mesh to load a geometry from,
+ times, physics, bcs, etc.
+
+ :param mesh: A mesh from an exodus file most likely
+ :param coords: An array of coordinates
+ :param times: An array of times
+ :param physics: An initialized physics object
+ :param essential_bcs: A list of EssentialBCs
+ :param natural_bcs: a list of NaturalBCs
+ :param dof_manager: A DofManager for keeping track of essential bcs
+ """
+ mesh:Mesh
+ coords:Float[Array,"nn nd"]
+ times:Union[Float[Array,"nt"],Float[Array,"nn 1"]]
+ physics:PhysicsKernel
+ essential_bcs:List[EssentialBC]
+ natural_bcs:List[NaturalBC]
+ dof_manager:DofManager
+ neumann_xs:Float[Array,"nn nd"]
+ neumann_ns:Float[Array,"nn nd"]
+ # neumann_outputs: Float[Array, "nn nf"]
+
+
+[docs]
+ def__init__(
+ self,
+ physics:PhysicsKernel,
+ essential_bcs:List[EssentialBC],
+ natural_bcs:List[NaturalBC],
+ mesh_file:str,
+ times:Float[Array,"nt"],
+ p_order:Optional[int]=1,
+ q_order:Optional[int]=2,
+ vectorize_over_time:Optional[bool]=False
+ )->None:
+"""
+ :param physics: A ``PhysicsKernel`` object
+ :param essential_bcs: A list of ``EssentiablBC`` objects
+ :param natural_bcs: TODO
+ :param mesh_file: mesh file name as string
+ :param times: set of times
+ :param p_order: Polynomial order for mesh. Only hooked up to tri meshes.
+ :param q_order: Quadrature order to use.
+ :param vectorize_over_time: Flag to enable vectorization over time
+ this likely only makes sense for path-independent problems.
+ """
+ withTimer('CollocationDomain.__init__'):
+ super().__init__(
+ physics,essential_bcs,natural_bcs,
+ mesh_file,times,
+ p_order=p_order
+ )
+
+ # TODO currently all of the below is busted for transient problems
+
+ # TODO need to gather dirichlet inputs/outputs
+ # mesh = self.fspace.mesh
+ # if len(essential_bcs) > 0:
+ # self.dirichlet_xs = jnp.vstack([bc.coordinates(mesh) for bc in essential_bcs])
+ # self.dirichlet_outputs = jnp.vstack([
+ # jax.vmap(bc.function, in_axes=(0, None))(bc.coordinates(mesh), )
+ # ])
+
+ # TODO below should eventually be move to the base class maybe?
+ # get neumann xs and ns
+ # mesh = self.fspace.mesh
+ mesh=self.mesh
+ ifmesh.num_dimensions!=2:
+ raiseValueError(
+ 'Only 2D meshes currently supported for collocation problems. '
+ 'Need to implement surface normal calculations on 3D elements.'
+ )
+ # q_rule_1d = self.q_rule_1d
+ # currently only support 2D meshes
+ q_rule_1d=QuadratureRule(LineElement(1),q_order)
+ iflen(natural_bcs)>0:
+ self.neumann_xs=[bc.coordinates(mesh,q_rule_1d)forbcinnatural_bcs]
+ self.neumann_ns=[bc.normals(mesh,q_rule_1d)forbcinnatural_bcs]
+ print('Warning this neumann condition will fail for inhomogenous conditions with time')
+ # self.neumann_outputs = jnp.vstack([
+ # jax.vmap(bc.function, in_axes=(0, None))(bc.coordinates(mesh, q_rule_1d), 0.0) \
+ # for bc in natural_bcs
+ # ])
+ else:
+ self.neumann_xs=[]
+ self.neumann_ns=[]
+[docs]
+classDeltaPINNDomain(VariationalDomain):
+"""
+ Base domain for all problem types.
+ This holds essential things for the problem
+ such as a mesh to load a geometry from,
+ times, physics, bcs, etc.
+
+ :param mesh: A mesh from an exodus file most likely
+ :param coords: An array of coordinates
+ :param times: An array of times
+ :param physics: An initialized physics object
+ :param essential_bcs: A list of EssentialBCs
+ :param natural_bcs: a list of NaturalBCs
+ :param dof_manager: A DofManager for keeping track of essential bcs
+ :param conns: An array of connectivities
+ :param fspace: A FunctionSpace to help with integration
+ :param fspace_centroid: A FunctionSpace to help with integration
+ :param n_eigen_values: Number of eigenvalues to use
+ """
+ mesh:Mesh
+ coords:Float[Array,"nn nd"]
+ times:Union[Float[Array,"nt"],Float[Array,"nn 1"]]
+ physics:PhysicsKernel
+ essential_bcs:List[EssentialBC]
+ natural_bcs:List[NaturalBC]
+ dof_manager:DofManager
+ conns:Int[Array,"ne nnpe"]
+ fspace:FunctionSpace
+ fspace_centroid:FunctionSpace
+ n_eigen_values:int
+ eigen_modes:Float[Array,"nn nev"]
+
+
+[docs]
+ def__init__(
+ self,
+ physics:PhysicsKernel,
+ essential_bcs:List[EssentialBC],
+ natural_bcs:List[NaturalBC],
+ mesh_file:str,
+ times:Float[Array,"nt"],
+ n_eigen_values:int,
+ p_order:Optional[int]=1,
+ q_order:Optional[int]=2
+ )->None:
+"""
+ :param physics: A ``PhysicsKernel`` object
+ :param essential_bcs: A list of ``EssentiablBC`` objects
+ :param natural_bcs: TODO
+ :param mesh_file: mesh file name as string
+ :param times: set of times
+ :param p_order: Polynomial order for mesh. Only hooked up to tri meshes.
+ :param q_order: Quadrature order to use.
+ :param n_eigen_values: Number of eigenvalues to use
+ """
+ ifnotphysics.use_delta_pinn:
+ raiseValueError('Need a physics object set up for DeltaPINN')
+
+ withTimer('DeltaPINNDomain.__init__'):
+ super().__init__(
+ physics,essential_bcs,natural_bcs,mesh_file,times,
+ p_order=p_order,q_order=q_order
+ )
+ self.n_eigen_values=n_eigen_values
+ self.eigen_modes=self.solve_eigen_problem(mesh_file,p_order,q_order)
+[docs]
+classInverseDomain(VariationalDomain):
+"""
+ Inverse domain type derived from a ForwardDomain
+
+ Note that currently this likely only supports single block meshes.
+
+ :param physics: A physics kernel to use for physics calculations.
+ :param dof_manager: A DofManager to track what dofs are free/fixed.
+ :param fspace: A FunctionSpace to help with integration
+ :param q_rule_1d: A quadrature rule for line/surface integrations. TODO rename this
+ :param q_rule_2d: A quadrature rule for cell integrations. TODO rename this
+ :param coords: Nodal coordinates in the reference configuration.
+ :param conns: Element connectivity matrix
+ :param field_data: Data structure that holds the full field data.
+ :param global_data: Data structure that holds the global data.
+ """
+ physics:PhysicsKernel
+ dof_manager:DofManager
+ fspace:FunctionSpace
+ fspace_centroid:FunctionSpace
+ coords:Float[Array,"nn nd"]
+ conns:Float[Array,"ne nnpe"]
+ times:Union[Float[Array,"nt"],Float[Array,"nn 1"]]
+ field_data:FullFieldData
+ global_data:GlobalData
+
+
+[docs]
+ def__init__(
+ self,
+ physics:PhysicsKernel,
+ essential_bcs:List[EssentialBC],
+ natural_bcs:any,# TODO figure out to handle this
+ mesh_file:str,
+ times:Float[Array,"nt"],
+ field_data:FullFieldData,
+ global_data:GlobalData,
+ p_order:Optional[int]=1,
+ q_order:Optional[int]=2
+ )->None:
+"""
+ :param physics: A ``PhysicsKernel`` object
+ :param essential_bcs: A list of ``EssentiablBC`` objects
+ :param natural_bcs: TODO
+ :param mesh_file: mesh file name as string
+ :param times: An array of time values to use
+ :param field_data: ``FieldData`` object
+ :param global_data: ``GlobalData`` object
+ :param p_order: Polynomial order for mesh. Only hooked up to tri meshes.
+ :param q_order: Quadrature order to use.
+ """
+ super().__init__(
+ physics,essential_bcs,natural_bcs,mesh_file,times,
+ p_order=p_order,q_order=q_order
+ )
+ self.field_data=field_data
+ self.global_data=global_data
+[docs]
+classVariationalDomain(BaseDomain):
+"""
+ Base domain for all problem types.
+ This holds essential things for the problem
+ such as a mesh to load a geometry from,
+ times, physics, bcs, etc.
+
+ :param mesh: A mesh from an exodus file most likely
+ :param coords: An array of coordinates
+ :param times: An array of times
+ :param physics: An initialized physics object
+ :param essential_bcs: A list of EssentialBCs
+ :param natural_bcs: a list of NaturalBCs
+ :param dof_manager: A DofManager for keeping track of essential bcs
+ :param conns: An array of connectivities
+ :param fspace: A FunctionSpace to help with integration
+ :param fspace_centroid: A FunctionSpace to help with integration
+ """
+ mesh:Mesh
+ coords:Float[Array,"nn nd"]
+ times:Union[Float[Array,"nt"],Float[Array,"nn 1"]]
+ physics:PhysicsKernel
+ essential_bcs:List[EssentialBC]
+ natural_bcs:List[NaturalBC]
+ dof_manager:DofManager
+ conns:Int[Array,"ne nnpe"]
+ fspace:FunctionSpace
+ fspace_centroid:FunctionSpace
+
+
+[docs]
+ def__init__(
+ self,
+ physics:PhysicsKernel,
+ essential_bcs:List[EssentialBC],
+ natural_bcs:any,# TODO figure out to handle this
+ mesh_file:str,
+ times:Float[Array,"nt"],
+ p_order:Optional[int]=1,
+ q_order:Optional[int]=2
+ )->None:
+"""
+ :param physics: A ``PhysicsKernel`` object
+ :param essential_bcs: A list of ``EssentiablBC`` objects
+ :param natural_bcs: TODO
+ :param mesh_file: mesh file name as string
+ :param times: set of times
+ :param p_order: Polynomial order for mesh. Only hooked up to tri meshes.
+ :param q_order: Quadrature order to use.
+ """
+ withTimer('VariationalDomain.__init__'):
+ super().__init__(
+ physics,essential_bcs,natural_bcs,mesh_file,times,
+ p_order=p_order
+ )
+ withTimer('move connectivity to device'):
+ self.conns=jnp.array(self.mesh.conns)
+
+ self.fspace=NonAllocatedFunctionSpace(
+ self.mesh,QuadratureRule(self.mesh.parentElement,q_order)
+ )
+ self.fspace_centroid=NonAllocatedFunctionSpace(
+ self.mesh,QuadratureRule(self.mesh.parentElement,1)
+ )
+from.function_spaceimportFunctionSpace
+fromjaxtypingimportArray,Bool,Float,Int
+frompancax.bcsimportEssentialBC
+frompancax.timerimportTimer
+fromtypingimportList,Tuple
+importjax.numpyasnp
+importnumpyasonp
+
+# TODO
+# getting some error when making this a child of eqx.Module
+
+[docs]
+classDofManager:
+"""
+ Collection of arrays needed to differentiate between
+ fixed and free dofs for fem like calculations.
+
+ TODO better document the parameters in this guy
+ """
+
+[docs]
+ def__init__(self,mesh,dim:int,EssentialBCs:List[EssentialBC])->None:
+"""
+ :param functionSpace: ``FunctionSpace`` object
+ :param dim: The number of dims (really the number of active dofs for the physics)
+ :param EssentialBCs: A list of of ``EssentialBC`` objects
+ """
+ withTimer('DofManager.__init__'):
+ self.fieldShape=mesh.num_nodes,dim
+ self.isBc=onp.full(self.fieldShape,False,dtype=bool)
+ forebcinEssentialBCs:
+ self.isBc[mesh.nodeSets[ebc.nodeSet],ebc.component]=True
+
+ self.isUnknown=~self.isBc
+
+ self.ids=onp.arange(self.isBc.size).reshape(self.fieldShape)
+
+ self.unknownIndices=self.ids[self.isUnknown]
+ self.bcIndices=self.ids[self.isBc]
+
+ ones=onp.ones(self.isBc.size,dtype=int)*-1
+ # self.dofToUnknown = ones.at[self.unknownIndices].set(np.arange(self.unknownIndices.size))
+ self.dofToUnknown=ones
+ self.dofToUnknown[self.unknownIndices]=onp.arange(self.unknownIndices.size)
+
+ self.HessRowCoords,self.HessColCoords=self._make_hessian_coordinates(onp.array(mesh.conns))
+
+ self.hessian_bc_mask=self._make_hessian_bc_mask(onp.array(mesh.conns))
+
+
+
+[docs]
+ defget_bc_size(self)->int:
+"""
+ :return: the number of fixed dofs
+ """
+ returnnp.sum(self.isBc).item()# item() method casts to Python int
+
+
+
+[docs]
+ defget_unknown_size(self)->int:
+"""
+ :return: the size of the unkowns vector
+ """
+ returnnp.sum(self.isUnknown).item()# item() method casts to Python int
+
+
+
+[docs]
+ defcreate_field(self,Uu,Ubc=0.0)->Float[Array,"nn nd"]:
+"""
+ :param Uu: Vector of unknown values
+ :param Ubc: Values for bc to apply
+ :return: U, a field of unknowns and bcs combined.
+ """
+ U=np.zeros(self.isBc.shape).at[self.isBc].set(Ubc)
+ returnU.at[self.isUnknown].set(Uu)
+
+
+
+[docs]
+ defget_bc_values(self,U)->Float[Array,"nb"]:
+"""
+ :param U: a nodal field
+ :return: the bc values in the field U
+ """
+ returnU[self.isBc]
+
+
+
+[docs]
+ defget_unknown_values(self,U)->Float[Array,"nu"]:
+"""
+ :param U: a nodal field
+ :return: the unknown values in the field U
+ """
+ returnU[self.isUnknown]
+[docs]
+defvander2d(x,degree):
+ x=np.asarray(x)
+ nNodes=(degree+1)*(degree+2)//2
+ pq=pascal_triangle_monomials(degree)
+
+ # It's easier to process if the input arrays
+ # always have the same shape
+ # If a 1D array is given (a single point),
+ # convert to the equivalent 2D array
+ x=x.reshape(-1,2)
+
+ # switch to bi-unit triangle (-1,-1)--(1,-1)--(-1,1)
+ z=2.0*x-1.0
+
+ defmap_from_tri_to_square(xi):
+ small=1e-12
+ # The mapping has a singularity at the vertex (-1, 1).
+ # Handle that point specially.
+ indexSingular=xi[:,1]>1.0-small
+ xiShifted=xi.copy()
+ xiShifted[indexSingular,1]=1.0-small
+ eta=np.zeros_like(xi)
+ eta[:,0]=2.0*(1.0+xiShifted[:,0])/(1.0-xiShifted[:,1])-1.0
+ eta[:,1]=xiShifted[:,1]
+ eta[indexSingular,0]=-1.0
+ eta[indexSingular,1]=1.0
+
+ # Jacobian of map.
+ # Actually, deta is just the first row of the Jacobian.
+ # The second row is trivially [0, 1], so we don't compute it.
+ # We just use that fact directly in the derivative Vandermonde
+ # expressions.
+ deta=np.zeros_like(xi)
+ deta[:,0]=2/(1-xiShifted[:,1])
+ deta[:,1]=2*(1+xiShifted[:,0])/(1-xiShifted[:,1])**2
+ returneta,deta
+
+ E,dE=map_from_tri_to_square(np.asarray(z))
+
+ A=np.zeros((x.shape[0],nNodes))
+ Ax=A.copy()
+ Ay=A.copy()
+ N1D=np.polynomial.Polynomial([0.5,-0.5])
+ foriinrange(nNodes):
+ p=np.polynomial.Legendre.basis(pq[i,0])
+
+ # SciPy's polynomials use the deprecated poly1d type
+ # of NumPy. To convert to the modern Polynomial type,
+ # we need to reverse the order of the coefficients.
+ qPoly1d=special.jacobi(pq[i,1],2*pq[i,0]+1,0)
+ q=np.polynomial.Polynomial(qPoly1d.coef[::-1])
+
+ forjinrange(pq[i,0]):
+ q*=N1D
+
+ # orthonormality weight
+ weight=np.sqrt((2*pq[i,0]+1)*2*(pq[i,0]+pq[i,1]+1))
+
+ A[:,i]=weight*p(E[:,0])*q(E[:,1])
+
+ # derivatives
+ dp=p.deriv()
+ dq=q.deriv()
+ Ax[:,i]=2*weight*dp(E[:,0])*q(E[:,1])*dE[:,0]
+ Ay[:,i]=2*weight*(dp(E[:,0])*q(E[:,1])*dE[:,1]
+ +p(E[:,0])*dq(E[:,1]))
+
+ returnA,Ax,Ay
+
+
+
+# TODO add hessians maybe?
+
+[docs]
+classShapeFunctions(NamedTuple):
+"""
+ Shape functions and shape function gradients (in the parametric space).
+
+ :param values: Values of the shape functions at a discrete set of points.
+ Shape is ``(nPts, nNodes)``, where ``nPts`` is the number of
+ points at which the shame functinos are evaluated, and ``nNodes``
+ is the number of nodes in the element (which is equal to the
+ number of shape functions).
+ :param gradients: Values of the parametric gradients of the shape functions.
+ Shape is ``(nPts, nDim, nNodes)``, where ``nDim`` is the number
+ of spatial dimensions. Line elements are an exception, which
+ have shape ``(nPts, nNdodes)``.
+ """
+ values:Float[Array,"np nn"]
+ gradients:Float[Array,"np nd nn"]
+
+
+
+
+[docs]
+classBaseElement(eqx.Module):
+"""
+ Base class for different element technologies
+
+ :param elementType: Element type name
+ :param degree: Polynomial degree
+ :param coordinates: Nodal coordinates in the reference configuration
+ :param vertexNodes: Vertex node number, 0-based
+ :param faceNodes: Nodes associated with each face, 0-based
+ :param interiorNodes: Nodes in the interior, 0-based or empty
+ """
+ elementType:str
+ degree:int
+ coordinates:Float[Array,"nn nd"]
+ vertexNodes:Int[Array,"nn"]
+ faceNodes:Int[Array,"nf nnpf"]
+ interiorNodes:Int[Array,"nni"]
+
+ # def __init__(self, elementType, degree, coordinates, vertexNodes, faceNodes, interiorNodes):
+ # self.elementType = elementType
+ # self.degree = degree
+ # self.coordinates = coordinates
+ # self.vertexNodes = vertexNodes
+ # self.faceNodes = faceNodes
+ # self.interiorNodes = interiorNodes
+
+
+[docs]
+ @abstractmethod
+ defcompute_shapes(self,nodalPoints,evaluationPoints):
+"""
+ Method to be defined to calculate shape function values
+ and gradients given a list of nodal points (usually the vertexNodes)
+ and a list of evaluation points (usually the quadrature points).
+ """
+ pass
+
+
+
+ # TODO figure out how to rope in quadrature rules into this class
+
+fromjaxtypingimportArray,Float,Int
+from.meshimportMesh
+frompancax.fem.quadrature_rulesimportQuadratureRule
+frompancax.fem.elements.base_elementimportShapeFunctions
+frompancax.timerimportTimer
+importequinoxaseqx
+importjax
+importjax.numpyasnp
+
+
+# TODO need to do checks on inputs to make sure they're compatable
+
+[docs]
+ defcompute_field_gradient(self,u,X):
+"""
+ Takes in element level coordinates X and field u
+ """
+ grad_Ns=self.shape_function_gradients(X)
+ returnjax.vmap(lambdau,grad_N:u.T@grad_N,in_axes=(None,0))(u,grad_Ns)
+
+
+
+[docs]
+ defevaluate_on_element(self,U,X,state,dt,props,func):
+"""
+ Takes in element level field, coordinates, states, etc.
+ and evaluates the function func
+ """
+ Ns=self.shape_function_values(X)
+ grad_Ns=self.shape_function_gradients(X)
+
+ u_qs=jax.vmap(lambdau,N:u.T@N,in_axes=(None,0))(U,Ns)
+ grad_u_qs=jax.vmap(lambdau,grad_N:u.T@grad_N,in_axes=(None,0))(U,grad_Ns)
+ X_qs=jax.vmap(lambdaX,N:X.T@N,in_axes=(None,0))(X,Ns)
+ func_vals=jax.vmap(func,in_axes=(0,0,0,None,0,None))(
+ u_qs,grad_u_qs,state,props,X_qs,dt
+ )
+ returnfunc_vals
+[docs]
+classFunctionSpace(eqx.Module):
+"""
+ Data needed for calculus on functions in the discrete function space.
+
+ In describing the shape of the attributes, ``ne`` is the number of
+ elements in the mesh, ``nqpe`` is the number of quadrature points per
+ element, ``npe`` is the number of nodes per element, and ``nd`` is the
+ spatial dimension of the domain.
+
+ :param shapes: Shape function values on each element, shape (ne, nqpe, npe)
+ :param vols: Volume attributed to each quadrature point. That is, the
+ quadrature weight (on the parameteric element domain) multiplied by
+ the Jacobian determinant of the map from the parent element to the
+ element in the domain. Shape (ne, nqpe).
+ :param shapeGrads: Derivatives of the shape functions with respect to the
+ spatial coordinates of the domain. Shape (ne, nqpe, npe, nd).
+ :param mesh: The ``Mesh`` object of the domain.
+ :param quadratureRule: The ``QuadratureRule`` on which to sample the shape
+ functions.
+ :param isAxisymmetric: boolean indicating if the function space data are
+ axisymmetric.
+ """
+ shapes:Float[Array,"ne nqpe npe"]
+ vols:Float[Array,"ne nqpe"]
+ shapeGrads:Float[Array,"ne nqpe npe nd"]
+ # mesh: any
+ conns:Int[Array,"ne nnpe"]
+ quadratureRule:QuadratureRule
+ isAxisymmetric:bool
+
+
+
+
+[docs]
+defconstruct_function_space(mesh,quadratureRule,mode2D='cartesian'):
+"""Construct a discrete function space.
+
+ Parameters
+ ----------
+ :param mesh: The mesh of the domain.
+ :param quadratureRule: The quadrature rule to be used for integrating on the
+ domain.
+ :param mode2D: A string indicating how the 2D domain is interpreted for
+ integration. Valid values are ``cartesian`` and ``axisymmetric``.
+ Axisymetric mode will include the factor of 2*pi*r in the ``vols``
+ attribute.
+
+ Returns
+ -------
+ The ``FunctionSpace`` object.
+ """
+ withTimer('construct_function_space'):
+ # shapeOnRef = interpolants.compute_shapes(mesh.parentElement, quadratureRule.xigauss)
+ shapeOnRef=mesh.parentElement.compute_shapes(mesh.parentElement.coordinates,quadratureRule.xigauss)
+ returnconstruct_function_space_from_parent_element(mesh,shapeOnRef,quadratureRule,mode2D)
+
+
+
+
+[docs]
+defconstruct_function_space_from_parent_element(mesh,shapeOnRef,quadratureRule,mode2D='cartesian'):
+"""
+ Construct a function space with precomputed shape function data on the parent element.
+
+ This version of the function space constructor is Jax-transformable,
+ and in particular can be jitted. The computation of the shape function
+ values and derivatives on the parent element is not transformable in
+ general. However, the mapping of the shape function data to the elements in
+ the mesh is transformable. One can precompute the parent element shape
+ functions once and for all, and then use this special factory function to
+ construct the function space and avoid the non-transformable part of the
+ operation. The primary use case is for shape sensitivities: the coordinates
+ of the mesh change, and we want Jax to pick up the sensitivities of the
+ shape function derivatives in space to the coordinate changes
+ (which occurs through the mapping from the parent element to the spatial
+ domain).
+
+ Parameters
+ ----------
+ :param mesh: The mesh of the domain.
+ :param shapeOnRef: A tuple of the shape function values and gradients on the
+ parent element, evaluated at the quadrature points. The caller must
+ take care to ensure the shape functions are evaluated at the same
+ points as contained in the ``quadratureRule`` parameter.
+ :param quadratureRule: The quadrature rule to be used for integrating on the
+ domain.
+ :param mode2D: A string indicating how the 2D domain is interpreted for
+ integration. See the default factory function for details.
+
+ Returns
+ -------
+ The ``FunctionSpace`` object.
+ """
+
+ shapes=jax.vmap(lambdaelConns,elShape:elShape,(0,None))(mesh.conns,shapeOnRef.values)
+
+ shapeGrads=jax.vmap(map_element_shape_grads,(None,0,None,None))(
+ mesh.coords,mesh.conns,mesh.parentElement,shapeOnRef.gradients
+ )
+
+ ifmode2D=='cartesian':
+ el_vols=compute_element_volumes
+ isAxisymmetric=False
+ elifmode2D=='axisymmetric':
+ el_vols=compute_element_volumes_axisymmetric
+ isAxisymmetric=True
+ vols=jax.vmap(el_vols,(None,0,None,None,None,None))(
+ mesh.coords,mesh.conns,mesh.parentElement,shapeOnRef.values,shapeOnRef.gradients,quadratureRule.wgauss
+ )
+
+ # return FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric)
+ returnFunctionSpace(shapes,vols,shapeGrads,mesh.conns,quadratureRule,isAxisymmetric)
+
+
+
+
+[docs]
+defmap_element_shape_grads(coordField,nodeOrdinals,parentElement,shapeGradients):
+ # coords here should be 3 x 2
+ # shapegrads shoudl be 3 x 2
+ # need J to be 2 x 2 but be careful about transpose
+ # below from Cthonios
+ # J = (X_el * ∇N_ξ)'
+ # J_inv = inv(J)
+ # ∇N_X = (J_inv * ∇N_ξ')'
+ Xn=coordField.take(nodeOrdinals,0)
+ Js=jax.vmap(lambdax,dN:(x.T@dN).T,in_axes=(None,0))(Xn,shapeGradients)
+ Jinvs=jax.vmap(lambdaJ:np.linalg.inv(J),in_axes=(0,))(Js)
+ returnjax.vmap(lambdaJinv,dN:(Jinv@dN.T).T,in_axes=(0,0))(Jinvs,shapeGradients)
+[docs]
+defintegrate_over_block(functionSpace,U,X,stateVars,props,dt,func,block,
+ *params,modify_element_gradient=default_modify_element_gradient):
+"""
+ Integrates a density function over a block of the mesh.
+
+ :param functionSpace: Function space object to do the integration with.
+ :param U: The vector of dofs for the primal field in the functional.
+ :param X: Nodal coordinates
+ :param stateVars: Internal state variable array.
+ :param dt: Current time increment
+ :param func: Lagrangian density function to integrate, Must have the signature
+ ``func(u, dudx, q, x, *params) -> scalar``, where ``u`` is the primal field, ``q`` is the
+ value of the internal variables, ``x`` is the current point coordinates, and ``*params`` is
+ a variadic set of additional parameters, which correspond to the ``*params`` argument.
+ block: Group of elements to integrate over. This is an array of element indices. For
+ performance, the elements within the block should be numbered consecutively.
+ :param modify_element_gradient: Optional function that modifies the gradient at the element level.
+ This can be to set the particular 2D mode, and additionally to enforce volume averaging
+ on the gradient operator. This is a keyword-only argument.
+
+ Returns
+ A scalar value for the integral of the density functional ``func`` integrated over the
+ block of elements.
+ """
+ # below breaks the sphinx doc stuff
+ # :param *params: Optional parameter fields to pass into Lagrangian density function. These are
+
+
+ vals=evaluate_on_block(functionSpace,U,X,stateVars,props,dt,func,block,*params,modify_element_gradient=modify_element_gradient)
+ returnnp.dot(vals.ravel(),functionSpace.vols[block].ravel())
+[docs]
+defevaluate_on_block(functionSpace,U,X,stateVars,dt,props,func,block,
+ *params,modify_element_gradient=default_modify_element_gradient):
+"""Evaluates a density function at every quadrature point in a block of the mesh.
+
+ :param functionSpace: Function space object to do the evaluation with.
+ :param U: The vector of dofs for the primal field in the functional.
+ :param X: Nodal coordinates
+ :param stateVars: Internal state variable array.
+ :param dt: Current time increment
+ :param func: Lagrangian density function to evaluate, Must have the signature
+ ```func(u, dudx, q, x, *params) -> scalar```, where ```u``` is the primal field, ```q``` is the
+ value of the internal variables, ```x``` is the current point coordinates, and ```*params``` is
+ a variadic set of additional parameters, which correspond to the ```*params``` argument.
+ :param block: Group of elements to evaluate over. This is an array of element indices. For
+ performance, the elements within the block should be numbered consecutively.
+
+ :param modify_element_gradient: Optional function that modifies the gradient at the element level.
+ This can be to set the particular 2D mode, and additionally to enforce volume averaging
+ on the gradient operator. This is a keyword-only argument.
+
+ Returns
+ An array of shape (numElements, numQuadPtsPerElement) that contains the scalar values of the
+ density functional ```func``` at every quadrature point in the block.
+ """
+ # below breaks sphinx doc stuff
+ # :param *params: Optional parameter fields to pass into Lagrangian density function. These are
+ # represented as a single value per element.
+ fs=functionSpace
+ compute_elem_values=jax.vmap(evaluate_on_element,(None,None,0,None,None,0,0,0,0,None,None,*tuple(0forpinparams)))
+
+ blockValues=compute_elem_values(U,X,stateVars[block],props,dt,fs.shapes[block],
+ fs.shapeGrads[block],fs.vols[block],
+ fs.conns[block],func,modify_element_gradient,*params)
+ returnblockValues
+
+
+
+
+[docs]
+defintegrate_element_from_local_field(elemNodalField,elemNodalCoords,elemStates,dt,elemShapes,elemShapeGrads,elemVols,func,modify_element_gradient=default_modify_element_gradient):
+"""
+ Integrate over element with element nodal field as input.
+ This allows element residuals and element stiffness matrices to computed.
+ """
+ elemVals=jax.vmap(interpolate_to_point,(None,0))(elemNodalField,elemShapes)
+ elemGrads=jax.vmap(compute_quadrature_point_field_gradient,(None,0))(elemNodalField,elemShapeGrads)
+ elemGrads=modify_element_gradient(elemGrads,elemShapes,elemVols,elemNodalField,elemNodalCoords)
+ elemPoints=jax.vmap(interpolate_to_point,(None,0))(elemNodalCoords,elemShapes)
+ fVals=jax.vmap(func,(0,0,0,0,None))(elemVals,elemGrads,elemStates,elemPoints,dt)
+ returnnp.dot(fVals,elemVols)
+[docs]
+defget_nodal_values_on_edge(functionSpace,nodalField,edge):
+"""
+ Get nodal values of a field on an element edge.
+
+ :param functionSpace: a FunctionSpace object
+ :param nodalField: The nodal vector defined over the mesh (shape is number of
+ nodes by number of field components)
+ :param edge: tuple containing the element number containing the edge and the
+ permutation (0, 1, or 2) of the edge within the triangle
+ """
+ edgeNodes=functionSpace.mesh.parentElement.faceNodes[edge[1],:]
+ nodes=functionSpace.mesh.conns[edge[0],edgeNodes]
+ returnnodalField[nodes]
+
+
+
+
+[docs]
+definterpolate_nodal_field_on_edge(functionSpace,U,interpolationPoints,edge):
+"""
+ Interpolate a nodal field to specified points on an element edge.
+
+ :param functionSpace: a FunctionSpace object
+ :param U: the nodal values array
+ :param interpolationPoints: coordinates of points (in the 1D parametric space) to
+ interpolate to
+ :param edge: tuple containing the element number containing the edge and the
+ permutation (0, 1, or 2) of the edge within the triangle
+ """
+ # edgeShapes = interpolants.compute_shapes(functionSpace.mesh.parentElement1d, interpolationPoints)
+ edgeShapes=functionSpace.mesh.parentElement1d.compute_shapes(
+ functionSpace.mesh.parentElement1d.coordinates,interpolationPoints
+ )
+ edgeU=get_nodal_values_on_edge(functionSpace,U,edge)
+ returnedgeShapes.values.T@edgeU
+[docs]
+classMesh(eqx.Module):
+"""
+ Triangle mesh representing a domain.
+
+ :param coords: Coordinates of the nodes, shape ``(nNodes, nDim)``.
+ :param conns: Nodal connectivity table of the elements.
+ :param simplexNodesOrdinals: Indices of the nodes that are vertices.
+ :param parentElement: A ``ParentElement`` that is the element type in
+ parametric space. A mesh can contain only 1 element type.
+ :param parentElement1d:
+ :param blocks: A dictionary mapping element block names to the indices of the
+ elements in the block.
+ :param nodeSets: A dictionary mapping node set names to the indices of the
+ nodes.
+ :param sideSets: A dictionary mapping side set names to the edges. The
+ edge data structure is a tuple of the element index and the local
+ number of the edge within that element. For example, triangle
+ elements will have edge 0, 1, or 2 for this entry.
+ """
+ coords:Float[Array,"nn nd"]
+ conns:Float[Array,"ne nnpe"]
+ simplexNodesOrdinals:Float[Array,"ne 3"]
+ # TODO finish out the typing below
+ parentElement:any
+ parentElement1d:any
+ blocks:Optional[Dict[str,Float]]=None
+ nodeSets:Optional[Dict[str,Float]]=None
+ sideSets:Optional[Dict[str,Float]]=None
+
+
+[docs]
+defcreate_edges(conns):
+"""Generate topological information about edges in a triangulation.
+
+ Parameters
+ ----------
+ conns : (nTriangles, 3) array
+ Connectivity table of the triangulation.
+
+ Returns
+ -------
+ edgeConns : (nEdges, 2) array
+ Vertices of each edge. Boundary edges are always in the
+ counter-clockwise sense, so that the interior of the body is on the left
+ side when walking from the first vertex to the second.
+ edges : (nEdges, 4) array
+ Edge-to-triangle topological information. Each row provides the
+ follwing information for each edge: [leftT, leftP, rightT, rightP],
+ where leftT is the ID of the triangle to the left, leftP is the
+ permutation of the edge in the left triangle (edge 0, 1, or 2), rightT
+ is the ID of the triangle to the right, and rightP is the permutation
+ of the edge in the right triangle. If the edge is a boundary edge, the
+ values of rightT and rightP are -1.
+ """
+ nTris=conns.shape[0]
+ allTriFaces=onp.vstack((conns[:,(0,1)],conns[:,(1,2)],conns[:,(2,0)]))
+ foo=onp.sort(allTriFaces,axis=1)
+ bar,i=onp.unique(foo,return_index=True,axis=0)
+ edgeConns=(allTriFaces[i,:])
+
+ nEdges=edgeConns.shape[0]
+ edges=-onp.ones((nEdges,4),dtype=onp.int_)
+ edgeElementIds=onp.tile(np.arange(nTris),3)
+ edges[:,0]=edgeElementIds[i]
+ edges[:,1]=i//nTris
+
+ fori,ecinenumerate(edgeConns):
+ rowsMatch=onp.all(onp.flip(ec)==allTriFaces,axis=1)
+ ifonp.any(rowsMatch):
+ j=onp.where(rowsMatch)[0][0]
+ # there should only be one matching row, but take element 0
+ # because j will have the same number of axes (2) as
+ # rowsMatch.
+ edges[i,2]=edgeElementIds[j]
+ edges[i,3]=j//nTris
+
+ returnedgeConns,edges
+[docs]
+defget_edge_field(mesh:Mesh,edge,field):
+"""
+ Evaluate field on nodes of an element edge.
+ Arguments:
+
+ :param mesh: a Mesh object
+ :param edge: tuple containing the element number containing the edge and the
+ permutation (0, 1, or 2) of the edge within the triangle
+ """
+ returnfield[get_edge_node_indices(mesh,edge)]
+[docs]
+defcompute_edge_vectors(mesh:Mesh,edgeCoords):
+"""
+ Get geometric vectors for an element edge.
+
+ Assumes that the edgs has a constant shape jacobian, that is, the
+ transformation from the parent element is affine.
+
+ Arguments
+ :param mesh: a Mesh object
+ :param edgeCoords: coordinates of all nodes on the edge, in the order
+ defined by the 1D parent element convention
+
+ Returns
+ tuple (t, n, j) with
+ :return t: the unit tangent vector
+ :return n: the outward unit normal vector
+ :return j: jacobian of the transformation from parent to physical space
+ """
+ Xv=edgeCoords[mesh.parentElement1d.vertexNodes,:]
+ tangent=Xv[1]-Xv[0]
+ normal=np.array([tangent[1],-tangent[0]])
+ jac=np.linalg.norm(tangent)
+ returntangent/jac,normal/jac,jac
+from.elementsimport*
+from.elements.base_elementimportBaseElement
+fromjax.laximportswitch
+fromjaxtypingimportArray,Float
+frompancax.timerimportTimer
+importequinoxaseqx
+importjax.numpyasjnp
+importmath
+importnumpyasojnp
+importscipy.special
+
+
+# TODO
+# think about moving this stuff into elements
+
+
+[docs]
+classQuadratureRule(eqx.Module):
+"""
+ Quadrature rule points and weights.
+ A ``namedtuple`` containing ``xigauss``, a numpy array of the
+ coordinates of the sample points in the reference domain, and
+ ``wgauss``, a numpy array with the weights.
+
+ :param xigauss: coordinates of gauss points in reference element
+ :param wgauss: weights of gauss points in reference element
+ """
+ xigauss:Float[Array,"nq nd"]
+ wgauss:Float[Array,"nq"]
+
+
+
+
+ # TODO maybe there's a better way?
+ # TODO below is just so we don't have to change any tests for now
+ def__iter__(self):
+ yieldself.xigauss
+ yieldself.wgauss
+
+ def__len__(self):
+ returnself.xigauss.shape[0]
+
+
+
+
+[docs]
+defcreate_quadrature_rule_1D(degree:int)->QuadratureRule:
+"""Creates a Gauss-Legendre quadrature on the unit interval.
+
+ The rule can exactly integrate polynomials of degree up to
+ ``degree``.
+
+ Parameters
+ ----------
+ degree: Highest degree polynomial to be exactly integrated by the quadrature rule
+
+ Returns
+ -------
+ A ``QuadratureRule`` named tuple containing the quadrature point coordinates
+ and the weights.
+ """
+
+ n=math.ceil((degree+1)/2)
+ xi,w=scipy.special.roots_sh_legendre(n)
+ returnxi,w
+[docs]
+defcreate_quadrature_rule_on_triangle(degree:int)->QuadratureRule:
+"""Creates a Gauss-Legendre quadrature on the unit triangle.
+
+ The rule can exactly integrate 2D polynomials up to the value of
+ ``degree``. The domain is the triangle between the vertices
+ (0, 0)-(1, 0)-(0, 1). The rules here are guaranteed to be
+ cyclically symmetric in triangular coordinates and to have strictly
+ positive weights.
+
+ Parameters
+ ----------
+ degree: Highest degree polynomial to be exactly integrated by the quadrature rule
+
+ Returns
+ -------
+ A ``QuadratureRule`` named tuple containing the quadrature point coordinates
+ and the weights.
+ """
+ ifdegree==1:
+ xi=ojnp.array([[3.33333333333333333E-01,3.33333333333333333E-01]])
+
+ w=ojnp.array([5.00000000000000000E-01])
+ elifdegree==2:
+ xi=ojnp.array([[6.66666666666666667E-01,1.66666666666666667E-01],
+ [1.66666666666666667E-01,6.66666666666666667E-01],
+ [1.66666666666666667E-01,1.66666666666666667E-01]])
+
+ w=ojnp.array([1.66666666666666666E-01,
+ 1.66666666666666667E-01,
+ 1.66666666666666667E-01])
+ elifdegree<=4:
+ xi=ojnp.array([[1.081030181680700E-01,4.459484909159650E-01],
+ [4.459484909159650E-01,1.081030181680700E-01],
+ [4.459484909159650E-01,4.459484909159650E-01],
+ [8.168475729804590E-01,9.157621350977100E-02],
+ [9.157621350977100E-02,8.168475729804590E-01],
+ [9.157621350977100E-02,9.157621350977100E-02]])
+
+ w=ojnp.array([1.116907948390055E-01,
+ 1.116907948390055E-01,
+ 1.116907948390055E-01,
+ 5.497587182766100E-02,
+ 5.497587182766100E-02,
+ 5.497587182766100E-02])
+ elifdegree<=5:
+ xi=ojnp.array([[3.33333333333333E-01,3.33333333333333E-01],
+ [5.97158717897700E-02,4.70142064105115E-01],
+ [4.70142064105115E-01,5.97158717897700E-02],
+ [4.70142064105115E-01,4.70142064105115E-01],
+ [7.97426985353087E-01,1.01286507323456E-01],
+ [1.01286507323456E-01,7.97426985353087E-01],
+ [1.01286507323456E-01,1.01286507323456E-01]])
+
+ w=ojnp.array([1.12500000000000E-01,
+ 6.61970763942530E-02,
+ 6.61970763942530E-02,
+ 6.61970763942530E-02,
+ 6.29695902724135E-02,
+ 6.29695902724135E-02,
+ 6.29695902724135E-02])
+ elifdegree<=6:
+ xi=ojnp.array([[5.01426509658179E-01,2.49286745170910E-01],
+ [2.49286745170910E-01,5.01426509658179E-01],
+ [2.49286745170910E-01,2.49286745170910E-01],
+ [8.73821971016996E-01,6.30890144915020E-02],
+ [6.30890144915020E-02,8.73821971016996E-01],
+ [6.30890144915020E-02,6.30890144915020E-02],
+ [5.31450498448170E-02,3.10352451033784E-01],
+ [6.36502499121399E-01,5.31450498448170E-02],
+ [3.10352451033784E-01,6.36502499121399E-01],
+ [5.31450498448170E-02,6.36502499121399E-01],
+ [6.36502499121399E-01,3.10352451033784E-01],
+ [3.10352451033784E-01,5.31450498448170E-02]])
+
+ w=ojnp.array([5.83931378631895E-02,
+ 5.83931378631895E-02,
+ 5.83931378631895E-02,
+ 2.54224531851035E-02,
+ 2.54224531851035E-02,
+ 2.54224531851035E-02,
+ 4.14255378091870E-02,
+ 4.14255378091870E-02,
+ 4.14255378091870E-02,
+ 4.14255378091870E-02,
+ 4.14255378091870E-02,
+ 4.14255378091870E-02])
+ elifdegree<=10:
+ xi=ojnp.array([[0.33333333333333333E+00,0.33333333333333333E+00],
+ [0.4269134091050342E-02,0.49786543295447483E+00],
+ [0.49786543295447483E+00,0.4269134091050342E-02],
+ [0.49786543295447483E+00,0.49786543295447483E+00],
+ [0.14397510054188759E+00,0.42801244972905617E+00],
+ [0.42801244972905617E+00,0.14397510054188759E+00],
+ [0.42801244972905617E+00,0.42801244972905617E+00],
+ [0.6304871745135507E+00,0.18475641274322457E+00],
+ [0.18475641274322457E+00,0.6304871745135507E+00],
+ [0.18475641274322457E+00,0.18475641274322457E+00],
+ [0.9590375628566448E+00,0.20481218571677562E-01],
+ [0.20481218571677562E-01,0.9590375628566448E+00],
+ [0.20481218571677562E-01,0.20481218571677562E-01],
+ [0.3500298989727196E-01,0.1365735762560334E+00],
+ [0.1365735762560334E+00,0.8284234338466947E+00],
+ [0.8284234338466947E+00,0.3500298989727196E-01],
+ [0.1365735762560334E+00,0.3500298989727196E-01],
+ [0.8284234338466947E+00,0.1365735762560334E+00],
+ [0.3500298989727196E-01,0.8284234338466947E+00],
+ [0.37549070258442674E-01,0.3327436005886386E+00],
+ [0.3327436005886386E+00,0.6297073291529187E+00],
+ [0.6297073291529187E+00,0.37549070258442674E-01],
+ [0.3327436005886386E+00,0.37549070258442674E-01],
+ [0.6297073291529187E+00,0.3327436005886386E+00],
+ [0.37549070258442674E-01,0.6297073291529187E+00]])
+
+ w=ojnp.array([0.4176169990259819E-01,
+ 0.36149252960283717E-02,
+ 0.36149252960283717E-02,
+ 0.36149252960283717E-02,
+ 0.3724608896049025E-01,
+ 0.3724608896049025E-01,
+ 0.3724608896049025E-01,
+ 0.39323236701554264E-01,
+ 0.39323236701554264E-01,
+ 0.39323236701554264E-01,
+ 0.3464161543553752E-02,
+ 0.3464161543553752E-02,
+ 0.3464161543553752E-02,
+ 0.147591601673897E-01,
+ 0.147591601673897E-01,
+ 0.147591601673897E-01,
+ 0.147591601673897E-01,
+ 0.147591601673897E-01,
+ 0.147591601673897E-01,
+ 0.1978968359803062E-01,
+ 0.1978968359803062E-01,
+ 0.1978968359803062E-01,
+ 0.1978968359803062E-01,
+ 0.1978968359803062E-01,
+ 0.1978968359803062E-01])
+ else:
+ raiseValueError("Quadrature of precision this high is not implemented.")
+
+ returnxi,w
+
+
+
+# TODO remove this stuff
+
+[docs]
+defcreate_padded_quadrature_rule_1D(degree):
+"""Creates 1D Gauss quadrature rule data that are padded to maintain a
+ uniform size, which makes this function jit-able.
+
+ This function is inteded to be used only when jit compilation of calls to the
+ quadrature rules are needed. Otherwise, prefer to use the standard quadrature
+ rules. The standard rules do not contain extra 0s for padding, which makes
+ them more efficient when used repeatedly (such as in the global energy).
+
+ Args:
+ degree: degree of highest polynomial to be integrated exactly
+ """
+
+ jnpts=jnp.ceil((degree+1)/2).astype(int)
+ xi,w=switch(jnpts,
+ [_gauss_quad_1D_1pt,_gauss_quad_1D_2pt,_gauss_quad_1D_3pt,
+ _gauss_quad_1D_4pt,_gauss_quad_1D_5pt],
+ None)
+ return0.5*(xi+1.0),0.5*w
+[docs]
+definterpolate_nodal_field_on_edge(mesh,U,interpolationPoints,edge):
+ # This function isn't used yet.
+ # We may want to replicate parts of FunctionSpace to do surface integrals.
+ #
+ fieldIndex=Surface.get_field_index(edge,mesh.conns)
+ nodalValues=Surface.eval_field(U,fieldIndex)
+ return0.0
+[docs]
+deffull_tensor_names(base_name:str):
+"""
+ Provides a full list of tensorial variable component names
+ :param base_name: base name for a tensor variable e.g. base_name_xx
+ """
+ return[
+ f'{base_name}_xx',f'{base_name}_xy',f'{base_name}_xz',
+ f'{base_name}_yx',f'{base_name}_yy',f'{base_name}_yz',
+ f'{base_name}_zx',f'{base_name}_zy',f'{base_name}_zz'
+ ]
+
+
+
+
+[docs]
+deffull_tensor_names_2D(base_name:str):
+"""
+ Provides a full list of tensorial variable component names
+ :param base_name: base name for a tensor variable e.g. base_name_xx
+ """
+ return[
+ f'{base_name}_xx',f'{base_name}_xy',
+ f'{base_name}_yx',f'{base_name}_yy'
+ ]
+
+
+# TODO improve this
+
+[docs]
+defelement_pp(func,has_props=False,jit=True):
+"""
+ :param func: Function to use for an element property output variable
+ :param has_props: Whether or not this function need properties
+ :param jit: Whether or not to jit this function
+ """
+ ifjit:
+ returneqx.filter_jit(func)
+ else:
+ returnfunc
+
+
+
+
+[docs]
+defnodal_pp(func,has_props=False,jit=True):
+"""
+ :param func: Function to use for a nodal property output variable
+ :param has_props: Whether or not this function need properties
+ :param jit: Whether or not to jit this function
+ """
+ ifhas_props:
+ # new_func = lambda p, d, t: vmap(
+ # func, in_axes=(None, 0, None, None)
+ # )(p.fields, d.coords, t, p.properties)
+ new_func=lambdap,d,t:vmap(
+ func,in_axes=(None,0,None)
+ )(p,d.coords,t)
+ else:
+ new_func=lambdap,d,t:vmap(
+ func,in_axes=(None,0,None)
+ )(p.fields,d.coords,t)
+
+ ifjit:
+ returneqx.filter_jit(new_func)
+ else:
+ returnnew_func
+
+
+
+# make a standard pp method that just has nodal fields, element grads, etc.
+
+[docs]
+classPhysicsKernel(ABC):
+ n_dofs:int
+ field_value_names:List[str]
+ bc_func:Callable# further type this guy
+ var_name_to_method:Dict[str,Dict[str,Union[Callable,List[str]]]]={}
+ use_delta_pinn:bool
+
+
+[docs]
+classStrongFormPhysicsKernel(PhysicsKernel):
+ n_dofs:int
+ field_value_names:List[str]
+ bc_func:Callable# further type this guy
+ use_delta_pinn:bool
+
+
+[docs]
+ def__init__(self,mesh_file,bc_func,use_delta_pinn)->None:
+ # if use_delta_pinn:
+ # raise ValueError('DeltaPINNs are currently not supported with collocation PINNs.')
+ super().__init__(mesh_file,bc_func,use_delta_pinn)
+[docs]
+classWeakFormPhysicsKernel(PhysicsKernel):
+ n_dofs:int
+ field_value_names:List[str]
+ bc_func:Callable# further type this guy
+ use_delta_pinn:bool
+
+
+[docs]
+classBaseLossFunction(eqx.Module):
+"""
+ Base class for loss functions.
+ Currently does nothing but helps build a
+ type hierarchy.
+ """
+
+
+[docs]
+classBCLossFunction(BaseLossFunction):
+"""
+ Base class for boundary condition loss functions.
+
+ A ``load_step`` method is expect with the following
+ type signature
+ ``load_step(self, params, domain, t)``
+ """
+
+[docs]
+classPhysicsLossFunction(BaseLossFunction):
+"""
+ Base class for physics loss functions.
+
+ A ``load_step`` method is expect with the following
+ type signature
+ ``load_step(self, params, domain, t)``
+ """
+
+
+
+ def__call__(self,params,domain):
+ field_network,_=params
+ n_dims=domain.coords.shape[1]
+ xs=domain.field_data.inputs[:,0:n_dims]
+ # TODO need time normalization
+ ts=domain.field_data.inputs[:,n_dims]
+ # TODO below is currenlty the odd ball for the field_value API
+ u_pred=vmap(domain.physics.field_values,in_axes=(None,0,0))(
+ field_network,xs,ts
+ )
+
+ # TODO add output normalization
+ loss=jnp.square(u_pred-domain.field_data.outputs).mean()
+ aux={'field_data_loss':loss}
+ returnself.weight*loss,aux
Source code for pancax.loss_functions.strong_form_loss_functions
+from.base_loss_functionimportPhysicsLossFunction
+fromjaximportvmap
+fromtypingimportOptional
+importjax.numpyasjnp
+
+
+# NOTE this probably does not currently dsupport deltaPINNs
+
+[docs]
+ defload_step(self,params,domain,t):
+ func=domain.physics.strong_form_residual
+ # TODO this will fail on delta PINNs currently
+ residuals=vmap(func,in_axes=(None,0,None))(params,domain.coords,t)
+ returnjnp.square(residuals).mean()
+[docs]
+classEnergyLoss(PhysicsLossFunction):
+r"""
+ Energy loss function akin to the deep energy method.
+
+ Calculates the following quantity
+
+ .. math::
+ \mathcal{L} = w\Pi\left[u\right] = w\int_\Omega\psi\left(\mathbf{F}\right)
+
+ :param weight: weight for this loss function
+ """
+ weight:float
+
+
+[docs]
+classEnergyAndResidualLoss(PhysicsLossFunction):
+r"""
+ Energy and residual loss function used in Hamel et. al
+
+ Calculates the following quantity
+
+ .. math::
+ \mathcal{L} = w_1\Pi\left[u\right] + w_2\delta\Pi\left[u\right]_{free}
+
+ :param energy_weight: Weight for the energy w_1
+ :param residual_weight: Weight for the residual w_2
+ """
+ energy_weight:float
+ residual_weight:float
+
+
+[docs]
+classIncompressibleEnergyLoss(PhysicsLossFunction):
+r"""
+ Energy loss function akin to the deep energy method.
+
+ Calculates the following quantity
+
+ .. math::
+ \mathcal{L} = w\Pi\left[u\right] = w\int_\Omega\psi\left(\mathbf{F}\right)
+
+ :param weight: weight for this loss function
+ """
+ weight:float
+
+
+[docs]
+classIncompressibleEnergyAndResidualLoss(PhysicsLossFunction):
+r"""
+ Energy and residual loss function used in Hamel et. al
+
+ Calculates the following quantity
+
+ .. math::
+ \mathcal{L} = w_1\Pi\left[u\right] + w_2\delta\Pi\left[u\right]_{free}
+
+ :param energy_weight: Weight for the energy w_1
+ :param residual_weight: Weight for the residual w_2
+ """
+ energy_weight:float
+ residual_weight:float
+
+
+[docs]
+defsum2(a):
+"""
+ Sum a vector to much higher accuracy than numpy.sum.
+
+ Parameters
+ ----------
+ a : ndarray, with only one axis (shape [n,])
+
+ Returns
+ -------
+ sum : real
+ The sum of the numbers in the array
+
+
+ This special sum method computes the result as accurate as if
+ computed in quadruple precision.
+
+ Reference:
+ T. Ogita, S. M. Rump, and S. Oishi. Accurate sum and dot product.
+ SIAM J. Sci. Comput., Vol 26, No 6, pp. 1955-1988.
+ doi: 10.1137/030601818
+ """
+ deff(carry,ai):
+ p,sigma=carry
+ p,q=_two_sum(p,ai)
+ sigma+=q
+ return(p,sigma),p
+
+ total=0.0
+ c=0.0
+ (total,c),partialSums=lax.scan(f,(total,c),a)
+ returntotal+c
+[docs]
+defdot2(x,y):
+"""
+ Compute inner product of 2 vectors to much higher accuracy than numpy.dot.
+
+ Parameters
+ ----------
+ :param x: ndarray, with only one axis (shape [n,])
+ :param y: ndarray, with only one axis (shape [n,])
+
+ Returns
+ -------
+ :return dotprod: real
+ The inner product of the input vectors.
+
+
+ This special inner product method computes the result as accurate
+ as if computed in quadruple precision. This algorithm is useful to
+ computing objective functions from numerical integration. It avoids
+ accumulation of floating point cancellation error that can obscure
+ whether an objective function has truly decreased.
+
+ The environment variable setting
+ 'XLA_FLAGS = "--xla_cpu_enable_fast_math=false"'
+ is critical for this function to work on the CPU. Otherwise, xla
+ apparently sets a flag for LLVM that allows unsafe floating point
+ optimizations that can change associativity.
+
+ Reference
+ T. Ogita, S. M. Rump, and S. Oishi. Accurate sum and dot product.
+ SIAM J. Sci. Comput., Vol 26, No 6, pp. 1955-1988.
+ doi 10.1137/030601818
+
+ """
+ deff(carry,xy):
+ p,s=carry
+ xi,yi=xy
+ h,r=_two_product(xi,yi)
+ p,q=_two_sum(p,h)
+ s=s+(q+r)
+ return(p,s),p
+
+ rawTotal=0.0
+ compensation=0.0
+ X=np.column_stack((x,y))
+ (rawTotal,compensation),partialSums=lax.scan(f,
+ (rawTotal,compensation),
+ X)
+ returnrawTotal+compensation
+[docs]
+deftriaxiality(A):
+ mean_normal=np.trace(A)/3.0
+ mises_norm=mises_equivalent_stress(A)
+ # avoid division by zero in case of spherical tensor
+ mises_norm+=np.finfo(np.dtype("float64")).eps
+ returnmean_normal/mises_norm
+
+
+
+# Compute eigen values and vectors of a symmetric 3x3 tensor
+# Note, returned eigen vectors may not be unit length
+#
+# Note, this routine involves high powers of the input tensor (~M^8).
+# Thus results can start to denormalize when the infinity norm of the input
+# tensor falls outside the range 1.0e-40 to 1.0e+40.
+#
+# Outside this range use eigen_sym33_unit
+
+
+
+
+# Helper function for 3x3 spectral decompositions
+# Pade approximation to cos( acos(x)/3 )
+# was obtained from Mathematica with the following commands:
+#
+# Needs["FunctionApproximations`"]
+# r1 = MiniMaxApproximation[Cos[ArcCos[x]/3], {x, {0, 1}, 6, 5}, WorkingPrecision -> 18, MaxIterations -> 500]
+#
+# 6 and 5 indicate the polynomial order in the numerator and denominator.
+
+[docs]
+@sqrtm.defjvp
+defjvp_sqrtm(primals,tangents):
+ A,=primals
+ H,=tangents
+ sqrtA=sqrtm(A)
+ dim=A.shape[0]
+ # TODO(brandon): Use a stable algorithm for solving a Sylvester equation.
+ # See https://en.wikipedia.org/wiki/Bartels%E2%80%93Stewart_algorithm
+ # The following will only reliably work for small matrices.
+ I=np.identity(dim)
+ M=np.kron(sqrtA.T,I)+np.kron(I,sqrtA)
+ Hvec=H.T.ravel()
+ returnsqrtA,(linalg.solve(M,Hvec)).reshape((dim,dim)).T
+
+
+
+
+[docs]
+defsqrtm_dbp(A):
+""" Matrix square root by product form of Denman-Beavers iteration.
+
+ Translated from the Matrix Function Toolbox
+ http://www.ma.man.ac.uk/~higham/mftoolbox
+ Nicholas J. Higham, Functions of Matrices: Theory and Computation,
+ SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,
+ """
+ dim=A.shape[0]
+ tol=0.5*np.sqrt(dim)*np.finfo(np.dtype("float64")).eps
+ maxIters=32
+ scaleTol=0.01
+
+ defscaling(M):
+ d=np.abs(linalg.det(M))**(1.0/(2.0*dim))
+ g=1.0/d
+ returng
+
+ defcond_f(loopData):
+ _,_,error,k,_=loopData
+ p=np.array([k<maxIters,error>tol],dtype=bool)
+ returnnp.all(p)
+
+ defbody_f(loopData):
+ X,M,error,k,diff=loopData
+ g=np.where(diff>=scaleTol,
+ scaling(M),
+ 1.0)
+
+ X*=g
+ M*=g*g
+
+ Y=X
+ N=linalg.inv(M)
+ I=np.identity(dim)
+ X=0.5*X@(I+N)
+ M=0.5*(I+0.5*(M+N))
+ error=np.linalg.norm(M-I,'fro')
+ diff=np.linalg.norm(X-Y,'fro')/np.linalg.norm(X,'fro')
+ k+=1
+ return(X,M,error,k,diff)
+
+ X0=A
+ M0=A
+ error0=np.finfo(np.dtype("float64")).max
+ k0=0
+ diff0=2.0*scaleTol# want to force scaling on first iteration
+ loopData0=(X0,M0,error0,k0,diff0)
+
+ X,_,_,k,_=while_loop(cond_f,body_f,loopData0)
+
+ returnX,k
+[docs]
+def_logm_iss(A):
+"""Logarithmic map by inverse scaling and squaring and Padé approximants
+
+ Translated from the Matrix Function Toolbox
+ http://www.ma.man.ac.uk/~higham/mftoolbox
+ Nicholas J. Higham, Functions of Matrices: Theory and Computation,
+ SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,
+ """
+ dim=A.shape[0]
+ c15=log_pade_coefficients[15]
+
+ defcond_f(loopData):
+ _,_,k,_,_,converged=loopData
+ conditions=np.array([~converged,k<16],dtype=bool)
+ returnconditions.all()
+
+ defcompute_pade_degree(diff,j,itk):
+ j+=1
+ # Manually force the return type of searchsorted to be 64-bit int, because it
+ # returns 32-bit ints, ignoring the global `jax_enable_x64` flag. This looks
+ # like a bug. I filed an issue (#11375) with Jax to correct this.
+ # If they fix it, the conversions on p and q can be removed.
+ p=np.searchsorted(log_pade_coefficients[2:16],diff,side='right').astype(np.int64)
+ p+=2
+ q=np.searchsorted(log_pade_coefficients[2:16],diff/2.0,side='right').astype(np.int64)
+ q+=2
+ m,j,converged=if_then_else((2*(p-q)//3<itk)|(j==2),
+ (p+1,j,True),(0,j,False))
+ returnm,j,converged
+
+ defbody_f(loopData):
+ X,j,k,m,itk,converged=loopData
+ diff=np.linalg.norm(X-np.identity(dim),ord=1)
+ m,j,converged=if_then_else(diff<c15,
+ compute_pade_degree(diff,j,itk),
+ (m,j,converged))
+ X,itk=sqrtm_dbp(X)
+ k+=1
+ returnX,j,k,m,itk,converged
+
+ X=A
+ j=0
+ k=0
+ m=0
+ itk=5
+ converged=False
+ X,j,k,m,itk,converged=while_loop(cond_f,body_f,(X,j,k,m,itk,converged))
+ returnX,k,m
+[docs]
+classBasePancaxModel(eqx.Module):
+"""
+ Base class for pancax model parameters.
+
+ This includes a few helper methods
+ """
+
+[docs]
+ defserialise(self,base_name,epoch):
+ file_name=f'{base_name}_{str(epoch).zfill(7)}.eqx'
+ print(f'Serialising current parameters to {file_name}')
+ eqx.tree_serialise_leaves(file_name,self)
+
+
+
+
+
+[docs]
+classFieldPropertyPair(BasePancaxModel):
+"""
+ Data structure for storing a set of field network
+ parameters and a set of material properties
+
+ :param fields: field network parameters object
+ :param properties: property parameters object
+ """
+ fields:eqx.Module
+ properties:eqx.Module
+
+ def__iter__(self):
+"""
+ Iterator for user friendliness
+ """
+ returniter((self.fields,self.properties))
+
+
+[docs]
+defzero_init(key:jax.random.PRNGKey,shape)->Float[Array,"no ni"]:
+"""
+ :param weight: current weight array for sizing
+ :param key: rng key
+ :return: A new set of weights
+ """
+ out,in_=weight.shape
+ returnjnp.zeros(shape,dtype=jnp.float64)
+
+
+
+
+[docs]
+deftrunc_init(key:jax.random.PRNGKey,shape)->Float[Array,"no ni"]:
+"""
+ :param weight: current weight array for sizing
+ :param key: rng key
+ :return: A new set of weights
+ """
+ stddev=jnp.sqrt(1/shape[0])
+ returnstddev*jax.random.truncated_normal(key,shape=shape,lower=-2,upper=2)
+
+
+
+
+[docs]
+definit_linear_weight(model:eqx.Module,init_fn:Callable,key:jax.random.PRNGKey)->eqx.Module:
+"""
+ :param model: equinox model
+ :param init_fn: function to initialize weigth with
+ :param key: rng key
+ :return: a new equinox model
+ """
+ is_linear=lambdax:isinstance(x,eqx.nn.Linear)
+ get_weights=lambdam:[
+ x.weight
+ forxinjax.tree_util.tree_leaves(m,is_leaf=is_linear)
+ ifis_linear(x)
+ ]
+ weights=get_weights(model)
+ new_weights=[
+ init_fn(subkey,weight.shape)
+ forsubkey,weightinzip(jax.random.split(key,len(weights)),weights)
+ ]
+ new_model=eqx.tree_at(get_weights,model,new_weights)
+ returnnew_model
+from.initializationimport*
+fromtypingimportCallable
+fromtypingimportOptional
+importequinoxaseqx
+importjax
+
+
+# TODO should we convert these to actual classes?
+
+
+[docs]
+defLinear(
+ n_inputs:int,
+ n_outputs:int,
+ key:jax.random.PRNGKey
+):
+"""
+ :param n_inputs: Number of inputs to linear layer
+ :param n_outputs: Number of outputs of the linear layer
+ :param key: rng key
+ :return: Equinox Linear layer
+ """
+ model=eqx.nn.Linear(
+ n_inputs,n_outputs,
+ use_bias=False,
+ key=key
+ )
+ model=eqx.tree_at(lambdal:l.weight,model,jnp.zeros((n_outputs,n_inputs),dtype=jnp.float64))
+ returnmodel
+
+
+
+
+[docs]
+defMLP(
+ n_inputs:int,
+ n_outputs:int,
+ n_neurons:int,
+ n_layers:int,
+ activation:Callable,
+ key:jax.random.PRNGKey,
+ use_final_bias:Optional[bool]=False,
+ init_func:Optional[Callable]=trunc_init
+):
+"""
+ :param n_inputs: Number of inputs to the MLP
+ :param n_outputs: Number of outputs of the MLP
+ :param n_neurons: Number of neurons in each hidden layer of the MLP
+ :param n_layers: Number of hidden layers in the MLP
+ :param activation: Activation function, e.g. tanh or relu
+ :param key: rng key
+ :param use_final_bias: Boolean for whether or not to use a bias
+ vector in the final layer
+ :return: Equinox MLP layer
+ """
+ model=eqx.nn.MLP(
+ n_inputs,n_outputs,n_neurons,n_layers,
+ activation=activation,
+ use_final_bias=use_final_bias,
+ key=key
+ )
+ # model = init_linear_weight(model, init_func, key)
+ # model = init_linear(model, init_func, key)
+ returnmodel
+[docs]
+ defensemble_init(self,params):
+ self.step=self.make_step_method()
+ # if self.jit:
+ # self.step = eqx.filter_jit(self.step)
+
+ # need to now make an ensemble wrapper our self.step
+ # but make sure not to jit it until after the vmap
+ defensemble_step(params,domain,opt_st):
+ params,opt_st,loss=eqx.filter_vmap(
+ self.step,in_axes=(eqx.if_array(0),None,eqx.if_array(0))
+ )(params,domain,opt_st)
+ returnparams,opt_st,loss
+
+ ifself.jit:
+ self.ensemble_step=eqx.filter_jit(ensemble_step)
+
+ defvmap_func(p):
+ returnself.opt.init(eqx.filter(p,eqx.is_array))
+
+ opt_st=eqx.filter_vmap(vmap_func,in_axes=(eqx.if_array(0),))(params)
+ returnopt_st
+[docs]
+classTimerError(Exception):
+"""A custom exception used to report errors in use of Timer class"""
+
+
+
+
+[docs]
+@dataclass
+classTimer(ContextDecorator):
+"""Time your code using a class, context manager, or decorator"""
+
+ timers:ClassVar[Dict[str,float]]=dict()
+ name:Optional[str]=None
+ text:str="Time in {name}: {:0.8f} seconds"
+ logger:Optional[Callable[[str],None]]=print
+ _start_time:Optional[float]=field(default=None,init=False,repr=False)
+
+ def__post_init__(self)->None:
+"""Initialization: add timer to dict of timers"""
+ ifself.name:
+ self.timers.setdefault(self.name,0)
+
+
+[docs]
+ defstart(self)->None:
+"""Start a new timer"""
+ ifself._start_timeisnotNone:
+ raiseTimerError(f"Timer is running. Use .stop() to stop it")
+
+ self._start_time=time.perf_counter()
+
+
+
+[docs]
+ defstop(self)->float:
+"""Stop the timer, and report the elapsed time"""
+ ifself._start_timeisNone:
+ raiseTimerError(f"Timer is not running. Use .start() to start it")
+
+ # Calculate elapsed time
+ elapsed_time=time.perf_counter()-self._start_time
+ self._start_time=None
+
+ # Report elapsed time
+ ifself.logger:
+ if(self.name):
+ self.logger(self.text.format(elapsed_time,name=self.name))
+ else:
+ self.logger(self.text.format(elapsed_time,name=''))
+ ifself.name:
+ self.timers[self.name]+=elapsed_time
+
+ returnelapsed_time
+
+
+ def__enter__(self)->"Timer":
+"""Start a new timer as a context manager"""
+ self.start()
+ returnself
+
+ def__exit__(self,*exc_info:Any)->None:
+"""Stop the context manager timer"""
+ self.stop()
+frompancax.history_writerimportHistoryWriter
+frompancax.loggingimportLogger
+frompancax.post_processorimportPostProcessor
+frompancax.utilsimportset_checkpoint_file
+frompathlibimportPath
+importos
+
+# TODO make this a proper equinox module
+
"
+ )
+ );
+ },
+
+ /**
+ * helper function to hide the search marks again
+ */
+ hideSearchWords: () => {
+ document
+ .querySelectorAll("#searchbox .highlight-link")
+ .forEach((el) => el.remove());
+ document
+ .querySelectorAll("span.highlighted")
+ .forEach((el) => el.classList.remove("highlighted"));
+ localStorage.removeItem("sphinx_highlight_terms")
+ },
+
+ initEscapeListener: () => {
+ // only install a listener if it is really needed
+ if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return;
+
+ document.addEventListener("keydown", (event) => {
+ // bail for input elements
+ if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return;
+ // bail with special keys
+ if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return;
+ if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) {
+ SphinxHighlight.hideSearchWords();
+ event.preventDefault();
+ }
+ });
+ },
+};
+
+_ready(() => {
+ /* Do not call highlightSearchWords() when we are on the search page.
+ * It will highlight words from the *previous* search query.
+ */
+ if (typeof Search === "undefined") SphinxHighlight.highlightSearchWords();
+ SphinxHighlight.initEscapeListener();
+});
diff --git a/html/genindex.html b/html/genindex.html
new file mode 100644
index 0000000..c62a6b1
--- /dev/null
+++ b/html/genindex.html
@@ -0,0 +1,2845 @@
+
+
+
+
+
+
+
+ Index — pancax 0.0.2 documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Index
+
+
+
+
+
+
+
+
+
+
Index
+
+
+ _
+ | A
+ | B
+ | C
+ | D
+ | E
+ | F
+ | G
+ | H
+ | I
+ | J
+ | K
+ | L
+ | M
+ | N
+ | O
+ | P
+ | Q
+ | R
+ | S
+ | T
+ | U
+ | V
+ | W
+ | X
+ | Z
+
+
Data structure that holds global data to be used as
+ground truth for some global field calculated from
+PINN outputs used in inverse modeling training
+
+
Parameters:
+
+
times – A set of times used to compare to physics calculations
+
displacements – Currently hardcoded to use a displacement-force curve TODO
+
outputs – Field used as ground truth, hardcoded essentially to a reaction force now
+
n_nodes – Book-keeping variable for number of nodes on nodeset to measure global response from
+
n_time_steps – Book-keeping variable
+
reaction_nodes – Node set nodes for where to measure reaction forces
+
reaction_dof – Degree of freedom to use for reaction force calculation
Shape functions and shape function gradients (in the parametric space).
+
+
Parameters:
+
+
values – Values of the shape functions at a discrete set of points.
+Shape is (nPts,nNodes), where nPts is the number of
+points at which the shame functinos are evaluated, and nNodes
+is the number of nodes in the element (which is equal to the
+number of shape functions).
+
gradients – Values of the parametric gradients of the shape functions.
+Shape is (nPts,nDim,nNodes), where nDim is the number
+of spatial dimensions. Line elements are an exception, which
+have shape (nPts,nNdodes).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Method to be defined to calculate shape function values
+and gradients given a list of nodal points (usually the vertexNodes)
+and a list of evaluation points (usually the quadrature points).
Data needed for calculus on functions in the discrete function space.
+
In describing the shape of the attributes, ne is the number of
+elements in the mesh, nqpe is the number of quadrature points per
+element, npe is the number of nodes per element, and nd is the
+spatial dimension of the domain.
+
+
Parameters:
+
+
shapes – Shape function values on each element, shape (ne, nqpe, npe)
+
vols – Volume attributed to each quadrature point. That is, the
+quadrature weight (on the parameteric element domain) multiplied by
+the Jacobian determinant of the map from the parent element to the
+element in the domain. Shape (ne, nqpe).
+
shapeGrads – Derivatives of the shape functions with respect to the
+spatial coordinates of the domain. Shape (ne, nqpe, npe, nd).
+
mesh – The Mesh object of the domain.
+
quadratureRule – The QuadratureRule on which to sample the shape
+functions.
+
isAxisymmetric – boolean indicating if the function space data are
+axisymmetric.
:param :
+:type : param mesh: The mesh of the domain.
+:param :
+:type : param quadratureRule: The quadrature rule to be used for integrating on the
+:param domain.:
+:param :
+:type : param mode2D: A string indicating how the 2D domain is interpreted for
+:param integration. Valid values are cartesian and axisymmetric.:
+:param Axisymetric mode will include the factor of 2*pi*r in the vols:
+:param attribute.:
Construct a function space with precomputed shape function data on the parent element.
+
This version of the function space constructor is Jax-transformable,
+and in particular can be jitted. The computation of the shape function
+values and derivatives on the parent element is not transformable in
+general. However, the mapping of the shape function data to the elements in
+the mesh is transformable. One can precompute the parent element shape
+functions once and for all, and then use this special factory function to
+construct the function space and avoid the non-transformable part of the
+operation. The primary use case is for shape sensitivities: the coordinates
+of the mesh change, and we want Jax to pick up the sensitivities of the
+shape function derivatives in space to the coordinate changes
+(which occurs through the mapping from the parent element to the spatial
+domain).
+
:param :
+:type : param mesh: The mesh of the domain.
+:param :
+:type : param shapeOnRef: A tuple of the shape function values and gradients on the
+:param parent element:
+:param evaluated at the quadrature points. The caller must:
+:param take care to ensure the shape functions are evaluated at the same:
+:param points as contained in the quadratureRule parameter.:
+:param : domain.
+:type : param quadratureRule: The quadrature rule to be used for integrating on the
+:param :
+:type : param mode2D: A string indicating how the 2D domain is interpreted for
+:param integration. See the default factory function for details.:
+pancax.fem.function_space.integrate_over_block(functionSpace, U, X, stateVars, props, dt, func, block, *params, modify_element_gradient=<functiondefault_modify_element_gradient>)[source]
+
Integrates a density function over a block of the mesh.
+
+
Parameters:
+
+
functionSpace – Function space object to do the integration with.
+
U – The vector of dofs for the primal field in the functional.
+
X – Nodal coordinates
+
stateVars – Internal state variable array.
+
dt – Current time increment
+
func – Lagrangian density function to integrate, Must have the signature
+func(u,dudx,q,x,*params)->scalar, where u is the primal field, q is the
+value of the internal variables, x is the current point coordinates, and *params is
+a variadic set of additional parameters, which correspond to the *params argument.
+block: Group of elements to integrate over. This is an array of element indices. For
+performance, the elements within the block should be numbered consecutively.
+
modify_element_gradient – Optional function that modifies the gradient at the element level.
+This can be to set the particular 2D mode, and additionally to enforce volume averaging
+on the gradient operator. This is a keyword-only argument.
+
+
+
+
Returns
+A scalar value for the integral of the density functional func integrated over the
+block of elements.
+
+
+
+
+pancax.fem.function_space.evaluate_on_block(functionSpace, U, X, stateVars, dt, props, func, block, *params, modify_element_gradient=<functiondefault_modify_element_gradient>)[source]
+
Evaluates a density function at every quadrature point in a block of the mesh.
+
+
Parameters:
+
+
functionSpace – Function space object to do the evaluation with.
+
U – The vector of dofs for the primal field in the functional.
+
X – Nodal coordinates
+
stateVars – Internal state variable array.
+
dt – Current time increment
+
func – Lagrangian density function to evaluate, Must have the signature
+`func(u,dudx,q,x,*params)->scalar`, where `u` is the primal field, `q` is the
+value of the internal variables, `x` is the current point coordinates, and `*params` is
+a variadic set of additional parameters, which correspond to the `*params` argument.
+
block – Group of elements to evaluate over. This is an array of element indices. For
+performance, the elements within the block should be numbered consecutively.
+
modify_element_gradient – Optional function that modifies the gradient at the element level.
+This can be to set the particular 2D mode, and additionally to enforce volume averaging
+on the gradient operator. This is a keyword-only argument.
+
+
+
+
Returns
+An array of shape (numElements, numQuadPtsPerElement) that contains the scalar values of the
+density functional `func` at every quadrature point in the block.
coords – Coordinates of the nodes, shape (nNodes,nDim).
+
conns – Nodal connectivity table of the elements.
+
simplexNodesOrdinals – Indices of the nodes that are vertices.
+
parentElement – A ParentElement that is the element type in
+parametric space. A mesh can contain only 1 element type.
+
parentElement1d
+
blocks – A dictionary mapping element block names to the indices of the
+elements in the block.
+
nodeSets – A dictionary mapping node set names to the indices of the
+nodes.
+
sideSets – A dictionary mapping side set names to the edges. The
+edge data structure is a tuple of the element index and the local
+number of the edge within that element. For example, triangle
+elements will have edge 0, 1, or 2 for this entry.
Generate topological information about edges in a triangulation.
+
+
Parameters:
+
conns ((nTriangles, 3) array) – Connectivity table of the triangulation.
+
+
Returns:
+
+
edgeConns ((nEdges, 2) array) – Vertices of each edge. Boundary edges are always in the
+counter-clockwise sense, so that the interior of the body is on the left
+side when walking from the first vertex to the second.
+
edges ((nEdges, 4) array) – Edge-to-triangle topological information. Each row provides the
+follwing information for each edge: [leftT, leftP, rightT, rightP],
+where leftT is the ID of the triangle to the left, leftP is the
+permutation of the edge in the left triangle (edge 0, 1, or 2), rightT
+is the ID of the triangle to the right, and rightP is the permutation
+of the edge in the right triangle. If the edge is a boundary edge, the
+values of rightT and rightP are -1.
Assumes that the edgs has a constant shape jacobian, that is, the
+transformation from the parent element is affine.
+
Arguments
+:param mesh: a Mesh object
+:param edgeCoords: coordinates of all nodes on the edge, in the order
+defined by the 1D parent element convention
+
Returns
+tuple (t, n, j) with
+:return t: the unit tangent vector
+:return n: the outward unit normal vector
+:return j: jacobian of the transformation from parent to physical space
Quadrature rule points and weights.
+A namedtuple containing xigauss, a numpy array of the
+coordinates of the sample points in the reference domain, and
+wgauss, a numpy array with the weights.
+
+
Parameters:
+
+
xigauss – coordinates of gauss points in reference element
+
wgauss – weights of gauss points in reference element
Creates a Gauss-Legendre quadrature on the unit triangle.
+
The rule can exactly integrate 2D polynomials up to the value of
+degree. The domain is the triangle between the vertices
+(0, 0)-(1, 0)-(0, 1). The rules here are guaranteed to be
+cyclically symmetric in triangular coordinates and to have strictly
+positive weights.
+
+
Parameters:
+
degree (Highest degree polynomial to be exactly integrated by the quadrature rule)
+
+
Returns:
+
+
A QuadratureRule named tuple containing the quadrature point coordinates
Creates 1D Gauss quadrature rule data that are padded to maintain a
+uniform size, which makes this function jit-able.
+
This function is inteded to be used only when jit compilation of calls to the
+quadrature rules are needed. Otherwise, prefer to use the standard quadrature
+rules. The standard rules do not contain extra 0s for padding, which makes
+them more efficient when used repeatedly (such as in the global energy).
+
+
Parameters:
+
degree – degree of highest polynomial to be integrated exactly
Jacobian of element_quantity_new with respect to positional argument(s) 3. Takes the same arguments as element_quantity_new but returns the jacobian of the output with respect to the arguments at positions 3.
Jacobian of element_quantity_new with respect to positional argument(s) 3. Takes the same arguments as element_quantity_new but returns the jacobian of the output with respect to the arguments at positions 3.
Gradient of potential_energy with respect to positional argument(s) 1. Takes the same arguments as potential_energy but returns the gradient, which has the same shape as the arguments at positions 1.
Value and gradient of potential_energy with respect to positional argument(s) 1. Takes the same arguments as potential_energy but returns a two-element tuple where the first element is the value of potential_energy and the second element is the gradient, which has the same shape as the arguments at positions 1.
Gradient of incompressible_energy with respect to positional argument(s) 1. Takes the same arguments as incompressible_energy but returns the gradient, which has the same shape as the arguments at positions 1.
Value and gradient of incompressible_energy with respect to positional argument(s) 1. Takes the same arguments as incompressible_energy but returns a two-element tuple where the first element is the value of incompressible_energy and the second element is the gradient, which has the same shape as the arguments at positions 1.
Sum a vector to much higher accuracy than numpy.sum.
+
+
Parameters:
+
a (ndarray, with only one axis (shape [n,]))
+
+
Returns:
+
sum – The sum of the numbers in the array
+
+
Return type:
+
real
+
+
+
This special sum method computes the result as accurate as if
+computed in quadruple precision.
+
Reference:
+T. Ogita, S. M. Rump, and S. Oishi. Accurate sum and dot product.
+SIAM J. Sci. Comput., Vol 26, No 6, pp. 1955-1988.
+doi: 10.1137/030601818
Compute inner product of 2 vectors to much higher accuracy than numpy.dot.
+
:param :
+:type : param x: ndarray, with only one axis (shape [n,])
+:param :
+:type : param y: ndarray, with only one axis (shape [n,])
+
+
Returns:
+
The inner product of the input vectors.
+
+
Return type:
+
return dotprod: real
+
+
+
This special inner product method computes the result as accurate
+as if computed in quadruple precision. This algorithm is useful to
+computing objective functions from numerical integration. It avoids
+accumulation of floating point cancellation error that can obscure
+whether an objective function has truly decreased.
+
The environment variable setting
+‘XLA_FLAGS = “–xla_cpu_enable_fast_math=false”’
+is critical for this function to work on the CPU. Otherwise, xla
+apparently sets a flag for LLVM that allows unsafe floating point
+optimizations that can change associativity.
+
Reference
+T. Ogita, S. M. Rump, and S. Oishi. Accurate sum and dot product.
+SIAM J. Sci. Comput., Vol 26, No 6, pp. 1955-1988.
+doi 10.1137/030601818
Matrix square root by product form of Denman-Beavers iteration.
+
Translated from the Matrix Function Toolbox
+http://www.ma.man.ac.uk/~higham/mftoolbox
+Nicholas J. Higham, Functions of Matrices: Theory and Computation,
+SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,
Logarithmic map by inverse scaling and squaring and Padé approximants
+
Translated from the Matrix Function Toolbox
+http://www.ma.man.ac.uk/~higham/mftoolbox
+Nicholas J. Higham, Functions of Matrices: Theory and Computation,
+SIAM, Philadelphia, PA, USA, 2008. ISBN 978-0-898716-46-7,