Skip to content

Commit eba02b9

Browse files
committed
Optimise find_optimal_atomics
1 parent f7be49b commit eba02b9

1 file changed

Lines changed: 12 additions & 13 deletions

File tree

gem/coffee.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
This file is NOT for code generation as a COFFEE AST.
55
"""
66

7-
from itertools import chain, repeat
7+
from itertools import chain
88
import logging
99

1010
import numpy
@@ -54,6 +54,8 @@ def sort_monomials(monomials):
5454
5555
:returns: the reordered list of monomials.
5656
"""
57+
if len(monomials) <= 2:
58+
return monomials
5759
# Construct a monomial subset with non-intersecting atomics
5860
head = []
5961
rest = []
@@ -64,13 +66,8 @@ def sort_monomials(monomials):
6466
else:
6567
atomics.update(m.atomics)
6668
head.append(m)
67-
# Put non-intersecting subset first
68-
monomials = head + rest
69-
# Sort the unique atomics as they appear in the monomials
70-
atomics = tuple(dict.fromkeys(chain.from_iterable(monomial.atomics for monomial in monomials)))
71-
# Sort the rest by the sum of the indices of their atomics
72-
rest.sort(key=lambda m: sum(map(atomics.index, m.atomics)))
73-
monomials = head + rest
69+
# Put non-intersecting subset first and recurse on the rest
70+
monomials = head + sort_monomials(rest)
7471
return monomials
7572

7673

@@ -88,24 +85,26 @@ def find_optimal_atomics(monomials, linear_indices):
8885

8986
atomics = tuple(dict.fromkeys(chain.from_iterable(monomial.atomics for monomial in monomials)))
9087

88+
monomial_atomics = [set(map(atomics.index, m.atomics)) for m in monomials]
89+
9190
def cost(solution):
92-
extent = sum(map(index_extent, solution, repeat(linear_indices)))
91+
extent = sum(index_extent(atomics[i], linear_indices) for i in solution)
9392
# Prefer shorter solutions, but larger extents
9493
return (len(solution), -extent)
9594

96-
optimal_solution = set(atomics) # pessimal but feasible solution
95+
optimal_solution = set(range(len(atomics))) # pessimal but feasible solution
9796
solution = set()
9897

9998
max_it = 1 << 12
10099
it = iter(range(max_it))
101100

102101
def solve(idx):
103-
while idx < len(monomials) and solution.intersection(monomials[idx].atomics):
102+
while idx < len(monomials) and solution.intersection(monomial_atomics[idx]):
104103
idx += 1
105104

106105
if idx < len(monomials):
107106
if len(solution) < len(optimal_solution):
108-
for atomic in monomials[idx].atomics:
107+
for atomic in monomial_atomics[idx]:
109108
solution.add(atomic)
110109
solve(idx + 1)
111110
solution.remove(atomic)
@@ -122,7 +121,7 @@ def solve(idx):
122121
logger.warning("Solution to ILP problem may not be optimal: search "
123122
"interrupted after examining %d solutions.", max_it)
124123

125-
return tuple(atomic for atomic in atomics if atomic in optimal_solution)
124+
return tuple(atomics[i] for i in optimal_solution)
126125

127126

128127
def factorise_atomics(monomials, optimal_atomics, linear_indices):

0 commit comments

Comments
 (0)