1212from FIAT .dual_set import DualSet
1313from FIAT .finite_element import CiarletElement
1414from FIAT .barycentric_interpolation import LagrangeLineExpansionSet
15+ from FIAT .quadrature_schemes import create_quadrature
1516
1617__all__ = ['NodalEnrichedElement' ]
1718
@@ -38,36 +39,49 @@ def __init__(self, *elements):
3839 "of NodalEnrichedElement are nodal" )
3940
4041 # Extract common data
41- embedded_degrees = [e .get_nodal_basis (). get_embedded_degree () for e in elements ]
42+ embedded_degrees = [e .degree () for e in elements ]
4243 embedded_degree = max (embedded_degrees )
43- degree = max (e .degree () for e in elements )
4444 order = max (e .get_order () for e in elements )
4545 formdegree = None if any (e .get_formdegree () is None for e in elements ) \
4646 else max (e .get_formdegree () for e in elements )
47- # LagrangeExpansionSet has fixed degree, ensure we grab the highest one
48- elem = elements [embedded_degrees .index (embedded_degree )]
47+
48+ # Grab the ExpansionSet defined on the maximal complex with the highest degree
49+ elem = max (elements , key = lambda e : (e .get_reference_complex (), e .degree ()))
4950 ref_el = elem .get_reference_complex ()
5051 expansion_set = elem .get_nodal_basis ().get_expansion_set ()
5152 mapping = elem .mapping ()[0 ]
5253 value_shape = elem .value_shape ()
5354
5455 # Sanity check
55- assert all (e .get_reference_complex () = = ref_el for e in elements )
56+ assert all (e .get_reference_complex () < = ref_el for e in elements )
5657 assert all (set (e .mapping ()) == {mapping , } for e in elements )
5758 assert all (e .value_shape () == value_shape for e in elements )
5859
5960 # Merge polynomial sets
60- if isinstance (expansion_set , LagrangeLineExpansionSet ):
61+ if isinstance (expansion_set , LagrangeLineExpansionSet ) and expansion_set . degree == embedded_degree :
6162 # Obtain coefficients via interpolation
6263 points = expansion_set .get_points ()
6364 coeffs = np .vstack ([e .tabulate (0 , points )[(0 ,)] for e in elements ])
64- else :
65- assert all (e .get_nodal_basis ().get_expansion_set () == expansion_set
66- for e in elements )
65+ elif all (e .get_nodal_basis ().get_expansion_set () == expansion_set for e in elements ):
66+ # All elements have the same ExpansionSet, just merge coefficients
6767 coeffs = [e .get_coeffs () for e in elements ]
6868 coeffs = _merge_coeffs (coeffs , ref_el , embedded_degrees , expansion_set .continuity )
69+ else :
70+ # Obtain coefficients via projection
71+ sd = ref_el .get_spatial_dimension ()
72+ Q = create_quadrature (ref_el , 2 * embedded_degree )
73+ qpts = Q .get_points ()
74+ phis = expansion_set ._tabulate (embedded_degree , qpts , 0 )[(0 ,)* sd ]
75+ PhiW = np .multiply (phis , Q .get_weights ())
76+ M = np .tensordot (phis , PhiW , (- 1 , - 1 ))
77+ MinvPhiW = np .linalg .solve (M , PhiW )
78+
79+ tabulations = np .concatenate ([e .tabulate (0 , qpts )[(0 ,)* sd ] for e in elements ], axis = 0 )
80+ coeffs = np .tensordot (tabulations , MinvPhiW , (- 1 , - 1 ))
81+ assert coeffs .shape [1 :- 1 ] == value_shape
82+
6983 poly_set = PolynomialSet (ref_el ,
70- degree ,
84+ embedded_degree ,
7185 embedded_degree ,
7286 expansion_set ,
7387 coeffs )
@@ -135,5 +149,5 @@ def _merge_entity_ids(entity_ids, offsets):
135149 for entity in ids [dim ]:
136150 if entity not in ret [dim ]:
137151 ret [dim ][entity ] = []
138- ret [dim ][entity ].extend (np . array ( ids [ dim ][ entity ]) + offsets [ i ])
152+ ret [dim ][entity ].extend (offsets [ i ] + dof for dof in ids [ dim ][ entity ])
139153 return ret
0 commit comments