@@ -263,22 +263,23 @@ def __new__(cls, *args, **kwargs):
263263 reference element."""
264264 if cls is not ExpansionSet :
265265 return super ().__new__ (cls )
266+ ref_el = args [0 ]
267+ shape = ref_el .get_shape ()
266268 try :
267- ref_el = args [0 ]
268269 expansion_set = {
269270 reference_element .POINT : PointExpansionSet ,
270271 reference_element .LINE : LineExpansionSet ,
271272 reference_element .TRIANGLE : TriangleExpansionSet ,
272273 reference_element .TETRAHEDRON : TetrahedronExpansionSet ,
273- }[ref_el .get_shape ()]
274- return expansion_set (* args , ** kwargs )
274+ }[shape ]
275275 except KeyError :
276- raise ValueError ("Invalid reference element type." )
276+ raise ValueError (f"Invalid reference element type { type (ref_el ).__name__ } ." )
277+ return expansion_set (* args , ** kwargs )
277278
278279 def __init__ (self , ref_el , scale = None , variant = None ):
279280 self .ref_el = ref_el
280281 self .variant = variant
281- sd = ref_el .get_spatial_dimension ()
282+ sd = ref_el .get_topological_dimension ()
282283 top = ref_el .get_topology ()
283284 base_ref_el = reference_element .default_simplex (sd )
284285 base_verts = base_ref_el .get_vertices ()
@@ -303,7 +304,7 @@ def reconstruct(self, ref_el=None, scale=None, variant=None):
303304
304305 def get_scale (self , n , cell = 0 ):
305306 scale = self .scale
306- sd = self .ref_el .get_spatial_dimension ()
307+ sd = self .ref_el .get_topological_dimension ()
307308 if isinstance (scale , str ):
308309 vol = self .ref_el .volume_of_subcomplex (sd , cell )
309310 scale = scale .lower ()
@@ -334,7 +335,7 @@ def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
334335 A , b = self .affine_mappings [cell ]
335336 ref_pts = numpy .add (numpy .dot (pts , A .T ), b ).T
336337 Jinv = A if direction is None else numpy .dot (A , direction )[:, None ]
337- sd = self .ref_el .get_spatial_dimension ()
338+ sd = self .ref_el .get_topological_dimension ()
338339 scale = self .get_scale (n , cell = cell )
339340 phi = dubiner_recurrence (sd , n , lorder , ref_pts , Jinv ,
340341 scale , variant = self .variant )
@@ -417,7 +418,7 @@ def tabulate_normal_jumps(self, n, ref_pts, facet, order=0):
417418
418419 :returns: a numpy array of tabulations of normal derivative jumps.
419420 """
420- sd = self .ref_el .get_spatial_dimension ()
421+ sd = self .ref_el .get_topological_dimension ()
421422 transform = self .ref_el .get_entity_transform (sd - 1 , facet )
422423 pts = transform (ref_pts )
423424 cell_point_map = compute_cell_point_map (self .ref_el , pts , unique = False )
@@ -507,7 +508,7 @@ def get_dmats(self, degree, cell=0):
507508 if degree == 0 :
508509 return cache .setdefault (key , numpy .zeros ((self .ref_el .get_spatial_dimension (), 1 , 1 ), "d" ))
509510
510- D = self .ref_el .get_dimension ()
511+ D = self .ref_el .get_topological_dimension ()
511512 top = self .ref_el .get_topology ()
512513 verts = self .ref_el .get_vertices_of_subcomplex (top [D ][cell ])
513514 pts = reference_element .make_lattice (verts , degree , variant = "gl" )
@@ -519,14 +520,14 @@ def get_dmats(self, degree, cell=0):
519520 def tabulate (self , n , pts ):
520521 if len (pts ) == 0 :
521522 return numpy .array ([])
522- sd = self .ref_el .get_spatial_dimension ()
523+ sd = self .ref_el .get_topological_dimension ()
523524 return self ._tabulate (n , pts )[(0 ,) * sd ]
524525
525526 def tabulate_derivatives (self , n , pts ):
526527 from FIAT .polynomial_set import mis
527528 vals = self ._tabulate (n , pts , order = 1 )
528529 # Create the ordinary data structure.
529- sd = self .ref_el .get_spatial_dimension ()
530+ sd = self .ref_el .get_topological_dimension ()
530531 v = vals [(0 ,) * sd ]
531532 dv = [vals [alpha ] for alpha in mis (sd , 1 )]
532533 data = [[(v [i , j ], [vi [i , j ] for vi in dv ])
@@ -537,7 +538,7 @@ def tabulate_derivatives(self, n, pts):
537538 def tabulate_jet (self , n , pts , order = 1 ):
538539 vals = self ._tabulate (n , pts , order = order )
539540 # Create the ordinary data structure.
540- sd = self .ref_el .get_spatial_dimension ()
541+ sd = self .ref_el .get_topological_dimension ()
541542 v0 = vals [(0 ,) * sd ]
542543 data = [v0 ]
543544 for r in range (1 , order + 1 ):
@@ -556,7 +557,7 @@ def __eq__(self, other):
556557class PointExpansionSet (ExpansionSet ):
557558 """Evaluates the point basis on a point reference element."""
558559 def __init__ (self , ref_el , ** kwargs ):
559- if ref_el .get_spatial_dimension () != 0 :
560+ if ref_el .get_topological_dimension () != 0 :
560561 raise ValueError ("Must have a point" )
561562 super ().__init__ (ref_el , ** kwargs )
562563
@@ -570,8 +571,8 @@ def _tabulate_on_cell(self, n, pts, order=0, cell=0, direction=None):
570571class LineExpansionSet (ExpansionSet ):
571572 """Evaluates the Legendre basis on a line reference element."""
572573 def __init__ (self , ref_el , ** kwargs ):
573- if ref_el .get_spatial_dimension () != 1 :
574- raise Exception ("Must have a line" )
574+ if ref_el .get_topological_dimension () != 1 :
575+ raise ValueError ("Must have a line" )
575576 super ().__init__ (ref_el , ** kwargs )
576577
577578 def _tabulate_on_cell (self , n , pts , order = 0 , cell = 0 , direction = None ):
@@ -600,15 +601,15 @@ class TriangleExpansionSet(ExpansionSet):
600601 """Evaluates the orthonormal Dubiner basis on a triangular
601602 reference element."""
602603 def __init__ (self , ref_el , ** kwargs ):
603- if ref_el .get_spatial_dimension () != 2 :
604+ if ref_el .get_topological_dimension () != 2 :
604605 raise Exception ("Must have a triangle" )
605606 super ().__init__ (ref_el , ** kwargs )
606607
607608
608609class TetrahedronExpansionSet (ExpansionSet ):
609610 """Collapsed orthonormal polynomial expansion on a tetrahedron."""
610611 def __init__ (self , ref_el , ** kwargs ):
611- if ref_el .get_spatial_dimension () != 3 :
612+ if ref_el .get_topological_dimension () != 3 :
612613 raise Exception ("Must be a tetrahedron" )
613614 super ().__init__ (ref_el , ** kwargs )
614615
@@ -627,7 +628,7 @@ def polynomial_dimension(ref_el, n, continuity=None):
627628 elif continuity == "C0" :
628629 space_dimension = sum (math .comb (n - 1 , dim ) * len (top [dim ]) for dim in top )
629630 else :
630- dim = ref_el .get_spatial_dimension ()
631+ dim = ref_el .get_topological_dimension ()
631632 space_dimension = math .comb (n + dim , dim ) * len (top [dim ])
632633 return space_dimension
633634
@@ -641,7 +642,7 @@ def polynomial_entity_ids(ref_el, n, continuity=None):
641642 :returns: a dict of dicts mapping dimension and entity id to basis functions.
642643 """
643644 top = ref_el .get_topology ()
644- sd = ref_el .get_spatial_dimension ()
645+ sd = ref_el .get_topological_dimension ()
645646 entity_ids = {}
646647 cur = 0
647648 for dim in sorted (top ):
@@ -668,7 +669,7 @@ def polynomial_cell_node_map(ref_el, n, continuity=None):
668669 :returns: a numpy array mapping cell id to basis functions supported on that cell.
669670 """
670671 top = ref_el .get_topology ()
671- sd = ref_el .get_spatial_dimension ()
672+ sd = ref_el .get_topological_dimension ()
672673
673674 entity_ids = polynomial_entity_ids (ref_el , n , continuity )
674675 ref_entity_ids = polynomial_entity_ids (ref_el .construct_subelement (sd ), n , continuity )
@@ -697,7 +698,7 @@ def compute_cell_point_map(ref_el, pts, unique=True, tol=1E-12):
697698 :returns: a dict mapping cell id to the point ids nearest to that cell.
698699 """
699700 top = ref_el .get_topology ()
700- sd = ref_el .get_spatial_dimension ()
701+ sd = ref_el .get_topological_dimension ()
701702 if len (top [sd ]) == 1 :
702703 return {0 : Ellipsis }
703704
0 commit comments