-
Notifications
You must be signed in to change notification settings - Fork 7
Generalising compute barycentric coordinates #245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 29 commits
828691b
85e17ef
0b38a37
78db95e
4f3d9cd
4677313
c19b8bd
ac0d93e
cf47fe9
c04b6fd
f339b01
ddf6eb4
65408bd
7b35ca3
5919d61
3d87fc1
76e63f7
e2403c4
26e9f39
d9a5e4b
05ea454
0cb0398
fb46fd6
11ee81f
f6d32f1
6bbddc5
a5b8ab0
82d51a9
4833f4e
f813c85
8084e8a
1944578
2b6cca3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |
|
|
||
| /build/ | ||
| /dist/ | ||
| /fenics_fiat.egg-info/ | ||
| /firedrake_fiat.egg-info/ | ||
|
|
||
| /.cache/ | ||
| /doc/sphinx/source/api-doc | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -615,8 +615,10 @@ def get_dimension(self): | |
|
|
||
| def compute_barycentric_coordinates(self, points, entity=None, rescale=False): | ||
| """Returns the barycentric coordinates of a list of points on the complex.""" | ||
| if len(points) == 0: | ||
|
|
||
| if isinstance(points, numpy.ndarray) and len(points) == 0: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not start assuming that |
||
| return points | ||
|
|
||
| if entity is None: | ||
| entity = (self.get_spatial_dimension(), 0) | ||
| entity_dim, entity_id = entity | ||
|
|
@@ -640,8 +642,11 @@ def compute_barycentric_coordinates(self, points, entity=None, rescale=False): | |
| h = 1 / numpy.linalg.norm(A, axis=1) | ||
| b *= h | ||
| A *= h[:, None] | ||
| out = numpy.dot(points, A.T) | ||
| return numpy.add(out, b, out=out) | ||
| # out = numpy.dot(points, A.T) | ||
| out = points @ A.T | ||
|
|
||
| # return numpy.add(out, b) | ||
| return out + b | ||
|
|
||
| def compute_bubble(self, points, entity=None): | ||
| """Returns the lowest-order bubble on an entity evaluated at the given | ||
|
|
@@ -1406,6 +1411,45 @@ def extrinsic_orientation_permutation_map(self): | |
| def is_macrocell(self): | ||
| return any(c.is_macrocell() for c in self.cells) | ||
|
|
||
| def compute_factor_barycentric_coordinates(self, points, entity=None, rescale=False): | ||
| """Compute barycentric coordinates on each axis (factor) of a tensor-product cell. | ||
|
achanbour marked this conversation as resolved.
|
||
|
|
||
| Parameters | ||
| ---------- | ||
| points: numpy.ndarray or GEM.Node | ||
| The reference coordinates of the points. | ||
|
|
||
| Returns | ||
| ------- | ||
| numpy.ndarray | ||
| A flattened array of shape ``(total_bary_coords, )`` and dtype object if points are GEM nodes, | ||
| otherwise dtype numeric. The i-th entry contains the barycentric coordinates | ||
| on the i-th factor cell. If factor i is a simplex of dimension d, this will | ||
| have shape ``(npoints, d+1)``. If factor i is a hypercube of dimension d, | ||
| this will have shape ``(npoints, 2*d)``. | ||
| """ | ||
| import gem | ||
|
|
||
| if isinstance(points, numpy.ndarray) and len(points) == 0: | ||
| return points | ||
|
|
||
| axis_dims = [c.get_spatial_dimension() for c in self.cells] | ||
| point_slices = TensorProductCell._split_slices(axis_dims) | ||
|
|
||
| result = [] | ||
| for factor, s in zip(self.cells, point_slices): | ||
| result.append(factor.compute_barycentric_coordinates(points[..., s], entity, rescale)) | ||
|
achanbour marked this conversation as resolved.
Outdated
|
||
|
|
||
| # Flatten the array | ||
| # We cannot construct the flat array directly since we may not know upfront the total number | ||
| # of barycentric coordinates (e.g., in a simplex it is d+1, in a hypercube it is 2*d) | ||
| flat_result = numpy.array([bary[j] for bary in result for j in range(bary.shape[0])]) | ||
|
|
||
| if isinstance(points, gem.Node): | ||
| return gem.as_gem(flat_result) # returns a ListTensor wrapping the scalar GEM expr. of bary coords. | ||
|
achanbour marked this conversation as resolved.
Outdated
|
||
|
|
||
| return flat_result | ||
|
|
||
|
|
||
| class Hypercube(Cell): | ||
| """Abstract class for a reference hypercube""" | ||
|
|
@@ -1423,6 +1467,8 @@ def __init__(self, dimension, product): | |
| self.product = product | ||
| self.unflattening_map = compute_unflattening_map(pt) | ||
|
|
||
| self.facet_perm = compute_facet_permutation(self.unflattening_map, self.product) | ||
|
|
||
| def get_dimension(self): | ||
| """Returns the subelement dimension of the cell. Same as the | ||
| spatial dimension.""" | ||
|
|
@@ -1521,6 +1567,27 @@ def __ge__(self, other): | |
| def __le__(self, other): | ||
| return self.product <= other | ||
|
|
||
| def compute_barycentric_coordinates(self, points, entity=None, rescale=False): | ||
| """Returns the barycentric coordinates of a list of points on the hypercube. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| points: numpy.ndarray or GEM.Node | ||
| The reference coordinates of the points. | ||
|
|
||
| Returns | ||
| ------- | ||
| List of numpy.ndarray or GEM.ComponentTensor | ||
| Returns a list of barycentric coordinates in local facet order such that for any point | ||
| lying on local facet `lf` of the cell, the barycentric coordinate at index `lf` vanishes. | ||
| """ | ||
| if isinstance(points, numpy.ndarray) and len(points) == 0: | ||
| return points | ||
|
|
||
| tp_bary_coords = self.product.compute_factor_barycentric_coordinates(points, entity, rescale) | ||
|
|
||
| return tp_bary_coords[self.facet_perm] # A[[1, 3, 5, 4]] | ||
|
|
||
|
|
||
| class UFCHypercube(Hypercube): | ||
| """Reference UFC Hypercube | ||
|
|
@@ -1839,6 +1906,90 @@ def compute_unflattening_map(topology_dict): | |
| return unflattening_map | ||
|
|
||
|
|
||
| def compute_facet_permutation(unflattening_map, product): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are the possible permutations for Prisms and Cubes? Don't they end up in a trivial permutation?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you mean by trivial permutation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think he means an identity permutation. I would recommend that you add some sort of check to catch those somewhere.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's identity for simplicies. As written There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding the comment with the quad example. I see the issue now. The problem is that the I think the implementation of compute_barycentric_coordinates on There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The expected ordering for
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks Pablo. It is indeed very confusing but with your help I am able to see better. I'll try to re-iterate the reasoning here for the sake of clarity.
|
||
| """ | ||
| Returns a permutation mapping each facet of a Hypercube to the index of the | ||
| barycentric coordinate that vanishes on it. | ||
|
|
||
| Let's take the example of a quad in 2D. Calling `compute_factor_barycentric_coordinates` returns | ||
| the barycentric coordinates on each of its 2 axes: | ||
|
|
||
| axis 0 (x): lambda_x_1, lambda_x_2 | ||
| axis 1 (y): lambda_y_1, lambda_y_2 | ||
|
|
||
| as a flat array bary_coords = [lambda_x_1, lambda_x_2, lambda_y_1, lambda_y_2] | ||
|
|
||
| A quad has 4 facets which are numbered in UFC order as: | ||
|
|
||
| facet 3 | ||
| ┌───────┐ | ||
| │ │ | ||
| facet 0 │ │ facet 1 | ||
| │ │ | ||
| │ │ | ||
| └───────┘ | ||
| facet 2 | ||
|
|
||
| Since each axis is a UFCInterval (a simplex), with vertices at P1 = (0,) and P2 = (1,) its barycentric coordinates | ||
| are lambda_1 = 1 - t (vanishes at P2) and lambda_2 = t (vanishes at P1). This applies the rule for barycentric coordinates | ||
| on simplicies which is that lambda_i vanishes on the facet opposite vertex i. | ||
|
|
||
| Therefore: | ||
|
|
||
| - lambda_x_1 vanishes on facet 1 (x=1), lambda_x_2 vanishes on facet 0 (x=0) | ||
| - lambda_y_1 vanishes on facet 3 (y=1), lambda_y_2 vanishes on facet 2 (y=0) | ||
|
|
||
| The permutation computed in this function reorders the array of barycentric coordinates such that: | ||
|
|
||
| bary_coords[perm] = [lambda_x_2, lambda_x_1, lambda_y_2, lambda_y_1] | ||
|
|
||
| where the i-th entry corresponds to the barycentric coordinate vanishing on facet i. | ||
| """ | ||
| # First compute axis offsets into the flattened barycentric coordinate array. | ||
| axis_offsets = [] | ||
| offset = 0 | ||
| for axis_cell in product.cells: | ||
| axis_offsets.append(offset) | ||
| offset += axis_cell.get_dimension() + 1 | ||
|
|
||
| # Initialise the integer permutation array | ||
| sd = len(product.cells) | ||
| num_facets = 2 * sd | ||
| perm = numpy.zeros(num_facets, dtype=int) | ||
|
|
||
| for f in range(num_facets): | ||
| # Recover the tensor-product representation of the facet as given by the unflattening map | ||
| dim_tuple, tp_entity = unflattening_map[(sd - 1, f)] | ||
|
|
||
| # Determine the axis that's orthogonal to the facet | ||
| # E.g., in a quad: | ||
| # if dim_tuple = (0,1) -> facet has dimension 0 on the first component -> fixed at x = 0 or x = 1 | ||
| # if dim_tuple = (1,0) -> facet has dimension 0 on the second component -> fixed at y = 0 or y = 1 | ||
| axis = next( | ||
| i for i, d in enumerate(dim_tuple) | ||
| if d == product.cells[i].get_dimension() - 1 | ||
| ) | ||
|
|
||
| # Determine the index of the endpoint that produces the facet | ||
| # which gives the local facet number in the axis space | ||
| entity_shape = tuple( | ||
| len(c.get_topology()[d]) | ||
| for c, d in zip(product.cells, dim_tuple) | ||
| ) | ||
| tuple_ei = numpy.unravel_index(tp_entity, entity_shape) | ||
| local_facet = tuple_ei[axis] | ||
|
|
||
| # For a simplex (UFCInterval, UFCTriangle), the barycentric coordinate that vanishes on local facet i | ||
| # corresponds to the ID of the vertex that doesn't belong to that facet | ||
| all_vertices = set(product.cells[axis].get_topology()[0].keys()) | ||
| facet_vertices = set(product.cells[axis].get_topology()[0][local_facet]) | ||
| bary_index = next(iter(all_vertices - facet_vertices)) | ||
|
|
||
| perm[f] = axis_offsets[axis] + bary_index | ||
|
|
||
| return perm | ||
|
|
||
|
|
||
| def max_complex(complexes): | ||
| max_cell = max(complexes) | ||
| if all(max_cell >= b for b in complexes): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,3 +1,5 @@ | ||||||||||||
| from __future__ import annotations | ||||||||||||
|
|
||||||||||||
| """GEM is the intermediate language of TSFC for describing | ||||||||||||
| tensor-valued mathematical expressions and tensor operations. | ||||||||||||
| It is similar to Einstein's notation. | ||||||||||||
|
|
@@ -20,6 +22,8 @@ | |||||||||||
| from operator import attrgetter | ||||||||||||
| from numbers import Integral, Number | ||||||||||||
|
|
||||||||||||
| from types import EllipsisType | ||||||||||||
|
|
||||||||||||
| import numpy | ||||||||||||
| from numpy import asarray | ||||||||||||
|
|
||||||||||||
|
|
@@ -32,7 +36,7 @@ | |||||||||||
| 'Variable', 'Sum', 'Product', 'Division', 'FloorDiv', 'Remainder', 'Power', | ||||||||||||
| 'MathFunction', 'MinValue', 'MaxValue', 'Comparison', | ||||||||||||
| 'LogicalNot', 'LogicalAnd', 'LogicalOr', 'Conditional', | ||||||||||||
| 'Index', 'VariableIndex', 'Indexed', 'ComponentTensor', | ||||||||||||
| 'Index', 'VariableIndex', 'ListIndex', 'Indexed', 'ComponentTensor', | ||||||||||||
| 'IndexSum', 'ListTensor', 'Concatenate', 'Delta', 'OrientationVariableIndex', | ||||||||||||
| 'index_sum', 'partial_indexed', 'reshape', 'view', | ||||||||||||
| 'indices', 'as_gem', 'FlexiblyIndexed', | ||||||||||||
|
|
@@ -81,12 +85,53 @@ def is_equal(self, other): | |||||||||||
| self.children = other.children | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
| def __getitem__(self, indices): | ||||||||||||
| try: | ||||||||||||
| indices = tuple(indices) | ||||||||||||
| except TypeError: | ||||||||||||
| indices = (indices, ) | ||||||||||||
| return Indexed(self, indices) | ||||||||||||
| def __getitem__( | ||||||||||||
| self, | ||||||||||||
| key: IndexT | tuple[IndexT, ...], | ||||||||||||
| ) -> ComponentTensor | Indexed: | ||||||||||||
| """A generalised interface for indexing GEM tensors""" | ||||||||||||
| if not isinstance(key, tuple): | ||||||||||||
| key = (key,) | ||||||||||||
|
|
||||||||||||
| # Expand ellipsis -> fill in remaining dimensions with slice(None) | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't quite understand this behaviour. What if I have something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. WDYT?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, there's no support for mixed indexing involving slices and arrays so something like Lines 110 to 114 in a5b8ab0
The reason for that is that indexing by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be nice to unify the code paths for to work |
||||||||||||
| if any(k is Ellipsis for k in key): | ||||||||||||
| if key.count(Ellipsis) > 1: | ||||||||||||
| raise NotImplementedError("Multiple ellipses are not supported.") | ||||||||||||
| ellipsis_pos = key.index(Ellipsis) | ||||||||||||
| remaining_dims = len(self.shape) - (len(key) - 1) | ||||||||||||
| if remaining_dims < 0: | ||||||||||||
| raise IndexError("Too many indices provided.") | ||||||||||||
| key = ( | ||||||||||||
| key[:ellipsis_pos] | ||||||||||||
| + (slice(None), ) * remaining_dims | ||||||||||||
| + key[ellipsis_pos + 1:] | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| has_slice = any(isinstance(k, slice) for k in key) | ||||||||||||
| has_array = any(isinstance(k, (numpy.ndarray, list)) for k in key) | ||||||||||||
|
|
||||||||||||
| if has_slice and has_array: | ||||||||||||
| raise NotImplementedError("Mixed slice and array indexing is not supported.") | ||||||||||||
|
|
||||||||||||
| # Slice indexing -> delegate to view() | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if the implementation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. then everything is together. We could even deprecate gem.view There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not have everything in gem.view? In this way we can keep the dunder methods in gem short and concise. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't mind that. I just want to unify the code paths somewhat. |
||||||||||||
| if has_slice: | ||||||||||||
| # view expects one slice for each axis/dim of the tensor | ||||||||||||
| if len(key) != len(self.shape): | ||||||||||||
| raise IndexError("Expects the number of slices to match the gem.Node tensor rank") | ||||||||||||
| return view(self, *key) | ||||||||||||
|
|
||||||||||||
| # Support a list or array of integer indices | ||||||||||||
| # Previously, built a ListTensor out of Indexed nodes, one for each element of the permutation | ||||||||||||
| if has_array: | ||||||||||||
| arr_pos = next(i for i, k in enumerate(key) if isinstance(k, (numpy.ndarray, list))) | ||||||||||||
| list_index = ListIndex(key[arr_pos]) | ||||||||||||
| new_key = key[:arr_pos] + (list_index,) + key[arr_pos+1:] | ||||||||||||
| indexed = Indexed(self, new_key) # A[perm[i]] is a scalar | ||||||||||||
| retval = ComponentTensor(indexed, (list_index.free_index,)) # CT(A[perm[i]], i) -> A[perm] | ||||||||||||
| return retval | ||||||||||||
|
|
||||||||||||
| # Point indexing | ||||||||||||
| return Indexed(self, key) | ||||||||||||
|
|
||||||||||||
| def __neg__(self): | ||||||||||||
| return componentwise(Product, minus, self) | ||||||||||||
|
|
@@ -117,6 +162,7 @@ def __matmul__(self, other): | |||||||||||
| raise ValueError("Both objects must have shape for matmul") | ||||||||||||
| elif self.shape[-1] != other.shape[0]: | ||||||||||||
| raise ValueError(f"Mismatching shapes {self.shape} and {other.shape} in matmul") | ||||||||||||
|
|
||||||||||||
| *i, k = indices(len(self.shape)) | ||||||||||||
| _, *j = indices(len(other.shape)) | ||||||||||||
| expr = Product(Indexed(self, (*i, k)), Indexed(other, (k, *j))) | ||||||||||||
|
|
@@ -677,6 +723,48 @@ def __repr__(self): | |||||||||||
| def __reduce__(self): | ||||||||||||
| return type(self), (self.expression,) | ||||||||||||
|
|
||||||||||||
| class ListIndex(IndexBase): | ||||||||||||
|
achanbour marked this conversation as resolved.
|
||||||||||||
| """ | ||||||||||||
| Option 1: tensor[list_index] is tensor[index_array[list_index]] free index ranging over 0,1..., len(index_array) | ||||||||||||
|
|
||||||||||||
| Option 2: tensor[list_index] is tensor[list_index.index_array[list_index.free_index]] | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| __slots__ = ('index_array', 'free_index') | ||||||||||||
|
|
||||||||||||
| def __init__(self, index_array, name=None): | ||||||||||||
| index_array = numpy.asarray(index_array) | ||||||||||||
| assert numpy.issubdtype(index_array.dtype, numpy.integer) | ||||||||||||
|
|
||||||||||||
| # Wraps a free index together with the index array | ||||||||||||
| self.index_array = index_array | ||||||||||||
| self.free_index = Index(extent=len(index_array), name=name) | ||||||||||||
|
|
||||||||||||
| # super().__init__(name=name, extent=len(index_array)) | ||||||||||||
|
|
||||||||||||
| def __str__(self): | ||||||||||||
| return f"{self.index_array.tolist()}[i_{self.free_index}]" | ||||||||||||
|
|
||||||||||||
| def __repr__(self): | ||||||||||||
| return f"ListIndex({self.index_array.tolist()}, {self.free_index})" | ||||||||||||
|
|
||||||||||||
| # def __eq__(self, other): | ||||||||||||
| # if type(self) is not type(other): | ||||||||||||
| # return False | ||||||||||||
| # return numpy.array_equal(self.index_array, other.index_array) | ||||||||||||
|
|
||||||||||||
| # def __ne__(self, other): | ||||||||||||
| # return not self.__eq__(other) | ||||||||||||
|
|
||||||||||||
| # def __hash__(self): | ||||||||||||
| # return hash((type(self), self.index_array.tobytes())) | ||||||||||||
|
|
||||||||||||
| # def __reduce__(self): | ||||||||||||
| # return type(self), (self.index_array, ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| IndexT = int | Index | VariableIndex | ListIndex | slice | EllipsisType | list | numpy.ndarray | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class Indexed(Scalar): | ||||||||||||
| __slots__ = ('children', 'multiindex', 'indirect_children') | ||||||||||||
|
|
@@ -740,6 +828,8 @@ def __new__(cls, aggregate, multiindex): | |||||||||||
| new_indices.append(i) | ||||||||||||
| elif isinstance(i, VariableIndex): | ||||||||||||
| new_indices.extend(i.expression.free_indices) | ||||||||||||
| elif isinstance(i, ListIndex): | ||||||||||||
| new_indices.append(i.free_index) | ||||||||||||
| self.free_indices = unique(aggregate.free_indices + tuple(new_indices)) | ||||||||||||
|
|
||||||||||||
| return self | ||||||||||||
|
|
@@ -752,6 +842,8 @@ def index_ordering(self): | |||||||||||
| free_indices.append(i) | ||||||||||||
| elif isinstance(i, VariableIndex): | ||||||||||||
| free_indices.extend(i.expression.free_indices) | ||||||||||||
| elif isinstance(i, ListIndex): | ||||||||||||
| free_indices.append(i.free_index) | ||||||||||||
| return tuple(free_indices) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
@@ -869,7 +961,7 @@ def __new__(cls, expression, multiindex): | |||||||||||
| shape = tuple(index.extent for index in multiindex) | ||||||||||||
| assert all(s >= 0 for s in shape) | ||||||||||||
|
|
||||||||||||
| # Zero folding | ||||||||||||
| # Zero foldingc | ||||||||||||
| if isinstance(expression, Zero): | ||||||||||||
| return Zero(shape, dtype=expression.dtype) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
Uh oh!
There was an error while loading. Please reload this page.