Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions finat/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ def value_shape(self):
@property
def mapping(self):
return self.product.mapping

def dual_evaluation(self, argument, coordinate_mapping=None):
return self.product.dual_evaluation(argument, coordinate_mapping)
49 changes: 48 additions & 1 deletion finat/enriched.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,29 @@
from gem.utils import cached_property

from finat.finiteelementbase import FiniteElementBase
from finat.hdivcurl import HCurlElement, HDivElement


class EnrichedElement(FiniteElementBase):
"""A finite element whose basis functions are the union of the
basis functions of several other finite elements."""

def __new__(cls, elements):
def __new__(cls, elements, is_nodal_enriched=None):
elements = tuple(chain.from_iterable(e.elements if isinstance(e, EnrichedElement) else (e,) for e in elements))
if len(elements) == 1:
return elements[0]
else:
self = super().__new__(cls)
self.elements = elements

if is_nodal_enriched is None:
is_nodal_enriched = all(
is_orthogonal(elements[i], elements[j])
for i in range(len(elements))
for j in range(i+1, len(elements))
)

self.is_nodal_enriched = is_nodal_enriched
return self

@cached_property
Expand Down Expand Up @@ -149,6 +159,34 @@ def mapping(self):
result, = mappings
return result

def dual_evaluation(self, argument, coordinate_mapping=None):
if not self.is_nodal_enriched:
raise NotImplementedError(
f"Dual evaluation not defined for element {type(self).__name__}"
)
# Gather results from all sub-elements
# Each sub_result is (eval_expr, local_indices)
sub_results = [sub.dual_evaluation(argument, coordinate_mapping=coordinate_mapping)
for sub in self.elements]

# Extract the evaluation sub-expressions
# We must ensure that all subindices are in the free indices of subexpr
# before wrapping in ComponentTensor. If some are missing (e.g. if the
# expression simplified to a constant), we multiply by a dummy ones tensor.
evals = []
for sub, (subexpr, subindices) in zip(self.elements, sub_results):
missing_indices = tuple(idx for idx in subindices if idx not in subexpr.free_indices)
if missing_indices:
shape = tuple(idx.extent for idx in missing_indices)
ones = gem.Literal(numpy.ones(shape))
dummy = gem.Indexed(ones, missing_indices)
subexpr = gem.Product(subexpr, dummy)
evals.append(gem.ComponentTensor(subexpr, subindices))

beta = self.get_indices()
expr = gem.Indexed(gem.Concatenate(*evals), beta)
return expr, beta


def tree_map(f, *args):
"""Like the built-in :py:func:`map`, but applies to a tuple tree."""
Expand Down Expand Up @@ -201,3 +239,12 @@ def concatenate_entity_permutations(elements):
offset = len(o_e_dim_permutations)
o_e_dim_permutations += list(offset + q for q in p)
return permutations


def is_orthogonal(A, B):
"""Test whether two elements are orthogonal."""
if isinstance(A, (HCurlElement, HDivElement)) and isinstance(B, (HCurlElement, HDivElement)):
Amap = A.transform(gem.Literal(numpy.ones(A.wrappee.value_shape)))
Bmap = B.transform(gem.Literal(numpy.ones(B.wrappee.value_shape)))
return sum(a * b for a, b in zip(Amap, Bmap)) == gem.Literal(0.0)
return False
2 changes: 1 addition & 1 deletion finat/restricted.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def restrict_tpe(element, domain, take_closure):
if all(f is not null_element for f in new_factors):
elements.append(finat.TensorProductElement(new_factors))
if elements:
return finat.EnrichedElement(elements)
return finat.EnrichedElement(elements, is_nodal_enriched=True)
else:
return null_element

Expand Down
18 changes: 18 additions & 0 deletions test/finat/test_dual_basis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy
import finat
import gem
from FIAT import ufc_simplex


Expand All @@ -27,3 +28,20 @@ def test_collapse_repeated_points(dim):

assert len(points) == len(numpy.unique(numpy.round(points, decimals=7), axis=0))
assert len(points) == expected


def test_enriched_element_dual_evaluation():
cell = ufc_simplex(2)
fe = finat.Lagrange(cell, 3)

fe1 = finat.RestrictedElement(fe, restriction_domain="interior")
fe2 = finat.RestrictedElement(fe, restriction_domain="facet")
enriched = finat.EnrichedElement([fe1, fe2], is_nodal_enriched=True)

# Check that calling dual_evaluation returns a valid Indexed expression
fn = lambda x: gem.Literal(1.0)
expr, indices = enriched.dual_evaluation(fn)
assert isinstance(expr, gem.Indexed)
assert isinstance(expr.children[0], gem.Concatenate)
assert len(indices) == 1
assert indices[0].extent == enriched.space_dimension()
Loading