Skip to content
Closed
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
828691b
added new methods for computing barycentric coords on tensor-product …
achanbour Feb 9, 2026
85e17ef
minor fixes
achanbour Feb 10, 2026
0b38a37
modified compute_barycentric_coords to consistently handle input poin…
achanbour Feb 11, 2026
78db95e
extended the tensor product barycentric coordinates computation funct…
achanbour Feb 16, 2026
4f3d9cd
removed line enforcing points to be numpy arrays to ensure code gener…
achanbour Feb 16, 2026
4677313
added conversion to numpy array only in axis_barycentric_coords
achanbour Feb 16, 2026
c19b8bd
generalised numpy operations in compute_barycentric_coordinates metho…
achanbour Feb 19, 2026
ac0d93e
changed compute_barycentric_coordinates method to work on input GEM e…
achanbour Feb 23, 2026
cf47fe9
extended gem matmul to convert numpy arrays to Literals and deal sepa…
achanbour Feb 23, 2026
c04b6fd
changes made to gem and FIAT to make barycentric coordinate computati…
achanbour Feb 25, 2026
f339b01
extended gem's slicing syntax and simplified the code in compute_axis…
achanbour Feb 26, 2026
ddf6eb4
tidied up and added comments
achanbour Feb 26, 2026
65408bd
Merge remote-tracking branch 'origin/main' into achanbour/bary-coords
achanbour Apr 20, 2026
7b35ca3
recent changes post merging
achanbour Apr 20, 2026
5919d61
compute barycentric coords symbolically in simplicies and hypercubes …
achanbour Apr 20, 2026
3d87fc1
modified unit tests for computing bary coords on points passed as num…
achanbour Apr 20, 2026
76e63f7
latest changes + gem tests
achanbour Apr 21, 2026
e2403c4
removed egg-info from tracking
achanbour Apr 21, 2026
26e9f39
implemented handler in gem.evaluate for FlexiblyIndexed nodes
achanbour Apr 23, 2026
d9a5e4b
added a new ListIndex type to GEM for indexing GEM tensors using an i…
achanbour Apr 23, 2026
05ea454
final changes
achanbour Apr 24, 2026
0cb0398
fixed formatting
achanbour Apr 24, 2026
fb46fd6
added docstrings to tests
achanbour Apr 24, 2026
11ee81f
update .gitignore from main
achanbour May 5, 2026
f6d32f1
update .gitignore from main
achanbour May 5, 2026
6bbddc5
changed ListIndex to be a subclass of Index
achanbour May 5, 2026
a5b8ab0
renamed tp bary coords methods + minor fixes + updated tests
achanbour May 5, 2026
82d51a9
separated ListIndex from Index
achanbour May 6, 2026
4833f4e
added explanation to facet permutation to reorder barycentric coords
achanbour May 6, 2026
f813c85
enforced facet ordering in the barycentric coordinates across all sim…
achanbour May 6, 2026
8084e8a
generalised the facet ordering permutation of bary coords computed on…
achanbour May 7, 2026
1944578
modified the index substitution logic to correctly handle ListIndex
achanbour May 7, 2026
2b6cca3
changed index substitution optim
achanbour May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

/build/
/dist/
/fenics_fiat.egg-info/
/firedrake_fiat.egg-info/

/.cache/
/doc/sphinx/source/api-doc
Expand Down
157 changes: 154 additions & 3 deletions FIAT/reference_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,10 @@ def get_dimension(self):

def compute_barycentric_coordinates(self, points, entity=None, rescale=False):
Comment thread
achanbour marked this conversation as resolved.
Outdated
"""Returns the barycentric coordinates of a list of points on the complex."""
if len(points) == 0:

if isinstance(points, numpy.ndarray) and len(points) == 0:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not start assuming that points may only be a numpy.ndarray. In FIAT most of the time points are treated as tuples, since those are immutable. We try to use numpy for our computations as much as possible, while relying on the casting done by most numpy functions.

return points

if entity is None:
entity = (self.get_spatial_dimension(), 0)
entity_dim, entity_id = entity
Expand All @@ -640,8 +642,11 @@ def compute_barycentric_coordinates(self, points, entity=None, rescale=False):
h = 1 / numpy.linalg.norm(A, axis=1)
b *= h
A *= h[:, None]
out = numpy.dot(points, A.T)
return numpy.add(out, b, out=out)
# out = numpy.dot(points, A.T)
out = points @ A.T

# return numpy.add(out, b)
return out + b

def compute_bubble(self, points, entity=None):
"""Returns the lowest-order bubble on an entity evaluated at the given
Expand Down Expand Up @@ -1406,6 +1411,45 @@ def extrinsic_orientation_permutation_map(self):
def is_macrocell(self):
return any(c.is_macrocell() for c in self.cells)

def compute_factor_barycentric_coordinates(self, points, entity=None, rescale=False):
"""Compute barycentric coordinates on each axis (factor) of a tensor-product cell.
Comment thread
achanbour marked this conversation as resolved.

Parameters
----------
points: numpy.ndarray or GEM.Node
The reference coordinates of the points.

Returns
-------
numpy.ndarray
A flattened array of shape ``(total_bary_coords, )`` and dtype object if points are GEM nodes,
otherwise dtype numeric. The i-th entry contains the barycentric coordinates
on the i-th factor cell. If factor i is a simplex of dimension d, this will
have shape ``(npoints, d+1)``. If factor i is a hypercube of dimension d,
this will have shape ``(npoints, 2*d)``.
"""
import gem

if isinstance(points, numpy.ndarray) and len(points) == 0:
return points

axis_dims = [c.get_spatial_dimension() for c in self.cells]
point_slices = TensorProductCell._split_slices(axis_dims)

result = []
for factor, s in zip(self.cells, point_slices):
result.append(factor.compute_barycentric_coordinates(points[..., s], entity, rescale))
Comment thread
achanbour marked this conversation as resolved.
Outdated

# Flatten the array
# We cannot construct the flat array directly since we may not know upfront the total number
# of barycentric coordinates (e.g., in a simplex it is d+1, in a hypercube it is 2*d)
flat_result = numpy.array([bary[j] for bary in result for j in range(bary.shape[0])])

if isinstance(points, gem.Node):
return gem.as_gem(flat_result) # returns a ListTensor wrapping the scalar GEM expr. of bary coords.
Comment thread
achanbour marked this conversation as resolved.
Outdated

return flat_result


class Hypercube(Cell):
"""Abstract class for a reference hypercube"""
Expand All @@ -1423,6 +1467,8 @@ def __init__(self, dimension, product):
self.product = product
self.unflattening_map = compute_unflattening_map(pt)

self.facet_perm = compute_facet_permutation(self.unflattening_map, self.product)

def get_dimension(self):
"""Returns the subelement dimension of the cell. Same as the
spatial dimension."""
Expand Down Expand Up @@ -1521,6 +1567,27 @@ def __ge__(self, other):
def __le__(self, other):
return self.product <= other

def compute_barycentric_coordinates(self, points, entity=None, rescale=False):
"""Returns the barycentric coordinates of a list of points on the hypercube.

Parameters
----------
points: numpy.ndarray or GEM.Node
The reference coordinates of the points.

Returns
-------
List of numpy.ndarray or GEM.ComponentTensor
Returns a list of barycentric coordinates in local facet order such that for any point
lying on local facet `lf` of the cell, the barycentric coordinate at index `lf` vanishes.
"""
if isinstance(points, numpy.ndarray) and len(points) == 0:
return points

tp_bary_coords = self.product.compute_factor_barycentric_coordinates(points, entity, rescale)

return tp_bary_coords[self.facet_perm] # A[[1, 3, 5, 4]]


class UFCHypercube(Hypercube):
"""Reference UFC Hypercube
Expand Down Expand Up @@ -1839,6 +1906,90 @@ def compute_unflattening_map(topology_dict):
return unflattening_map


def compute_facet_permutation(unflattening_map, product):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the possible permutations for Prisms and Cubes? Don't they end up in a trivial permutation?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean by trivial permutation?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think he means an identity permutation. I would recommend that you add some sort of check to catch those somewhere.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's identity for simplicies. As written compute_facet_permutation only works for Hypercubes since we only have unflattening_map there. For general TP cells like prisms, my assumptions is that something similar to an unflattening map should be implemented first before generalising compute_facet_permutation.

@pbrubeck pbrubeck May 6, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the comment with the quad example. I see the issue now. The problem is that the UFCInterval does not order the facets in the same way UFCTriangle and UFCTetrahedron do. UFCInterval has the vertices as the facets, and vertices are always ordered in vertex order, but facets are ordered by the vertex they exclude, so you get an inconsistency for 1D when the facets are also the vertices.

I think the implementation of compute_barycentric_coordinates on UFCInterval should apply the permutation [1, 0]. Then this gives us the right ordering for UFCQuadrilateral.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expected ordering for compute_barycentric_coordinates on simplices is not facet-based, but always vertex-based. If you would like a facet-based ordering, then only the interval should be permuted and tensor products would get the facet-based ordering for free.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Pablo. It is indeed very confusing but with your help I am able to see better. I'll try to re-iterate the reasoning here for the sake of clarity.

The expected ordering for compute_barycentric_coordinates on simplices is not facet-based, but always vertex-based.

  • This is consistent with the fact that $\lambda_i(P_j) = \delta_{ij}$ i.e., the i-th barycentric coordinate is 1 on vertex i and 0 on all other vertices. Hence it is 0 on the facet excluding vertex i.
  • Since in UFCTriangle and UFCTetrahedron order facets by the vertex they exclude we get that facet i excludes vertex i, i.e., $\lambda_i = 0$ on facet i.
  • Since in UFCInterval the vertices are the facets, we have that facet 1 contains vertex 1 and facet 2 contains vertex 2. This translates to $\lambda_1$ vanishing on facet 2 and $\lambda_2$ vanishing on facet 1 (reversed).

only the interval should be permuted and tensor products would get the facet-based ordering for free.

  • Since compute_factor_barycentric_coordinates recurses into calling compute_barycentric_coordinates on simplices forming the factors of the tensor-product cell, we just need to ensure that the facet-ordering is preserved on 1D intervals as the facet ordering already holds in all other simplices like UFCTriangle and UFCTetrahedron.

"""
Returns a permutation mapping each facet of a Hypercube to the index of the
barycentric coordinate that vanishes on it.

Let's take the example of a quad in 2D. Calling `compute_factor_barycentric_coordinates` returns
the barycentric coordinates on each of its 2 axes:

axis 0 (x): lambda_x_1, lambda_x_2
axis 1 (y): lambda_y_1, lambda_y_2

as a flat array bary_coords = [lambda_x_1, lambda_x_2, lambda_y_1, lambda_y_2]

A quad has 4 facets which are numbered in UFC order as:

facet 3
┌───────┐
│ │
facet 0 │ │ facet 1
│ │
│ │
└───────┘
facet 2

Since each axis is a UFCInterval (a simplex), with vertices at P1 = (0,) and P2 = (1,) its barycentric coordinates
are lambda_1 = 1 - t (vanishes at P2) and lambda_2 = t (vanishes at P1). This applies the rule for barycentric coordinates
on simplicies which is that lambda_i vanishes on the facet opposite vertex i.

Therefore:

- lambda_x_1 vanishes on facet 1 (x=1), lambda_x_2 vanishes on facet 0 (x=0)
- lambda_y_1 vanishes on facet 3 (y=1), lambda_y_2 vanishes on facet 2 (y=0)

The permutation computed in this function reorders the array of barycentric coordinates such that:

bary_coords[perm] = [lambda_x_2, lambda_x_1, lambda_y_2, lambda_y_1]

where the i-th entry corresponds to the barycentric coordinate vanishing on facet i.
"""
# First compute axis offsets into the flattened barycentric coordinate array.
axis_offsets = []
offset = 0
for axis_cell in product.cells:
axis_offsets.append(offset)
offset += axis_cell.get_dimension() + 1

# Initialise the integer permutation array
sd = len(product.cells)
num_facets = 2 * sd
perm = numpy.zeros(num_facets, dtype=int)

for f in range(num_facets):
# Recover the tensor-product representation of the facet as given by the unflattening map
dim_tuple, tp_entity = unflattening_map[(sd - 1, f)]

# Determine the axis that's orthogonal to the facet
# E.g., in a quad:
# if dim_tuple = (0,1) -> facet has dimension 0 on the first component -> fixed at x = 0 or x = 1
# if dim_tuple = (1,0) -> facet has dimension 0 on the second component -> fixed at y = 0 or y = 1
axis = next(
i for i, d in enumerate(dim_tuple)
if d == product.cells[i].get_dimension() - 1
)

# Determine the index of the endpoint that produces the facet
# which gives the local facet number in the axis space
entity_shape = tuple(
len(c.get_topology()[d])
for c, d in zip(product.cells, dim_tuple)
)
tuple_ei = numpy.unravel_index(tp_entity, entity_shape)
local_facet = tuple_ei[axis]

# For a simplex (UFCInterval, UFCTriangle), the barycentric coordinate that vanishes on local facet i
# corresponds to the ID of the vertex that doesn't belong to that facet
all_vertices = set(product.cells[axis].get_topology()[0].keys())
facet_vertices = set(product.cells[axis].get_topology()[0][local_facet])
bary_index = next(iter(all_vertices - facet_vertices))

perm[f] = axis_offsets[axis] + bary_index

return perm


def max_complex(complexes):
max_cell = max(complexes)
if all(max_cell >= b for b in complexes):
Expand Down
108 changes: 100 additions & 8 deletions gem/gem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

"""GEM is the intermediate language of TSFC for describing
tensor-valued mathematical expressions and tensor operations.
It is similar to Einstein's notation.
Expand All @@ -20,6 +22,8 @@
from operator import attrgetter
from numbers import Integral, Number

from types import EllipsisType

import numpy
from numpy import asarray

Expand All @@ -32,7 +36,7 @@
'Variable', 'Sum', 'Product', 'Division', 'FloorDiv', 'Remainder', 'Power',
'MathFunction', 'MinValue', 'MaxValue', 'Comparison',
'LogicalNot', 'LogicalAnd', 'LogicalOr', 'Conditional',
'Index', 'VariableIndex', 'Indexed', 'ComponentTensor',
'Index', 'VariableIndex', 'ListIndex', 'Indexed', 'ComponentTensor',
'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'OrientationVariableIndex',
'index_sum', 'partial_indexed', 'reshape', 'view',
'indices', 'as_gem', 'FlexiblyIndexed',
Expand Down Expand Up @@ -81,12 +85,53 @@ def is_equal(self, other):
self.children = other.children
return result

def __getitem__(self, indices):
try:
indices = tuple(indices)
except TypeError:
indices = (indices, )
return Indexed(self, indices)
def __getitem__(
self,
key: IndexT | tuple[IndexT, ...],
) -> ComponentTensor | Indexed:
"""A generalised interface for indexing GEM tensors"""
if not isinstance(key, tuple):
key = (key,)

# Expand ellipsis -> fill in remaining dimensions with slice(None)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand this behaviour. What if I have something like tensor[::2, ..., 3]? I wonder if we should only allow ... if there are no other indices provided.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, there's no support for mixed indexing involving slices and arrays so something like tensor[..., [1,2,3]] isn't supported.

fiat/gem/gem.py

Lines 110 to 114 in a5b8ab0

has_slice = any(isinstance(k, slice) for k in key)
has_array = any(isinstance(k, (numpy.ndarray, list)) for k in key)
if has_slice and has_array:
raise NotImplementedError("Mixed slice and array indexing is not supported.")

The reason for that is that indexing by ... (or by slice) is handled by view() which expresses a reshape while indexing with [1,2,3] is handled via ListIndex which expresses a gather-like operation. I'm not sure if there's a Node in GEM that allows to compose these operations.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to unify the code paths for __getitem__ and gem.view. I would definitely expect something like

tensor_2d[1::3, [2, 3, 4]]

to work

if any(k is Ellipsis for k in key):
if key.count(Ellipsis) > 1:
raise NotImplementedError("Multiple ellipses are not supported.")
ellipsis_pos = key.index(Ellipsis)
remaining_dims = len(self.shape) - (len(key) - 1)
if remaining_dims < 0:
raise IndexError("Too many indices provided.")
key = (
key[:ellipsis_pos]
+ (slice(None), ) * remaining_dims
+ key[ellipsis_pos + 1:]
)

has_slice = any(isinstance(k, slice) for k in key)
has_array = any(isinstance(k, (numpy.ndarray, list)) for k in key)

if has_slice and has_array:
raise NotImplementedError("Mixed slice and array indexing is not supported.")

# Slice indexing -> delegate to view()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if the implementation of view should live in here

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then everything is together. We could even deprecate gem.view

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not have everything in gem.view? In this way we can keep the dunder methods in gem short and concise.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind that. I just want to unify the code paths somewhat.

if has_slice:
# view expects one slice for each axis/dim of the tensor
if len(key) != len(self.shape):
raise IndexError("Expects the number of slices to match the gem.Node tensor rank")
return view(self, *key)

# Support a list or array of integer indices
# Previously, built a ListTensor out of Indexed nodes, one for each element of the permutation
if has_array:
arr_pos = next(i for i, k in enumerate(key) if isinstance(k, (numpy.ndarray, list)))
list_index = ListIndex(key[arr_pos])
new_key = key[:arr_pos] + (list_index,) + key[arr_pos+1:]
indexed = Indexed(self, new_key) # A[perm[i]] is a scalar
retval = ComponentTensor(indexed, (list_index.free_index,)) # CT(A[perm[i]], i) -> A[perm]
return retval

# Point indexing
return Indexed(self, key)

def __neg__(self):
return componentwise(Product, minus, self)
Expand Down Expand Up @@ -117,6 +162,7 @@ def __matmul__(self, other):
raise ValueError("Both objects must have shape for matmul")
elif self.shape[-1] != other.shape[0]:
raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul")

*i, k = indices(len(self.shape))
_, *j = indices(len(other.shape))
expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j)))
Expand Down Expand Up @@ -677,6 +723,48 @@ def __repr__(self):
def __reduce__(self):
return type(self), (self.expression,)

class ListIndex(IndexBase):
Comment thread
achanbour marked this conversation as resolved.
"""
Option 1: tensor[list_index] is tensor[index_array[list_index]] free index ranging over 0,1..., len(index_array)

Option 2: tensor[list_index] is tensor[list_index.index_array[list_index.free_index]]
"""

__slots__ = ('index_array', 'free_index')

def __init__(self, index_array, name=None):
index_array = numpy.asarray(index_array)
assert numpy.issubdtype(index_array.dtype, numpy.integer)

# Wraps a free index together with the index array
self.index_array = index_array
self.free_index = Index(extent=len(index_array), name=name)

# super().__init__(name=name, extent=len(index_array))

def __str__(self):
return f"{self.index_array.tolist()}[i_{self.free_index}]"

def __repr__(self):
return f"ListIndex({self.index_array.tolist()}, {self.free_index})"

# def __eq__(self, other):
# if type(self) is not type(other):
# return False
# return numpy.array_equal(self.index_array, other.index_array)

# def __ne__(self, other):
# return not self.__eq__(other)

# def __hash__(self):
# return hash((type(self), self.index_array.tobytes()))

# def __reduce__(self):
# return type(self), (self.index_array, )


IndexT = int | Index | VariableIndex | ListIndex | slice | EllipsisType | list | numpy.ndarray


class Indexed(Scalar):
__slots__ = ('children', 'multiindex', 'indirect_children')
Expand Down Expand Up @@ -740,6 +828,8 @@ def __new__(cls, aggregate, multiindex):
new_indices.append(i)
elif isinstance(i, VariableIndex):
new_indices.extend(i.expression.free_indices)
elif isinstance(i, ListIndex):
new_indices.append(i.free_index)
self.free_indices = unique(aggregate.free_indices + tuple(new_indices))

return self
Expand All @@ -752,6 +842,8 @@ def index_ordering(self):
free_indices.append(i)
elif isinstance(i, VariableIndex):
free_indices.extend(i.expression.free_indices)
elif isinstance(i, ListIndex):
free_indices.append(i.free_index)
return tuple(free_indices)


Expand Down Expand Up @@ -869,7 +961,7 @@ def __new__(cls, expression, multiindex):
shape = tuple(index.extent for index in multiindex)
assert all(s >= 0 for s in shape)

# Zero folding
# Zero foldingc
if isinstance(expression, Zero):
return Zero(shape, dtype=expression.dtype)

Expand Down
Loading
Loading