Skip to content

Commit 3204a48

Browse files
committed
Fix tests
1 parent 2d55540 commit 3204a48

4 files changed

Lines changed: 123 additions & 47 deletions

File tree

pyreason/scripts/interpretation/interpretation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
972972
add_head_var_node_to_graph = True
973973
groundings[head_var_1] = numba.typed.List([head_var_1])
974974

975+
# TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
976+
# is allocated K*N times (K head groundings * N clauses) below, but the data
977+
# is rule-static. Hoist a precomputed `clause_variables_precomputed` list
978+
# here and append references in the inner loop instead of fresh copies.
975979
for head_grounding in groundings[head_var_1]:
976980
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
977981
qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
@@ -1111,6 +1115,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
11111115
valid_edge_groundings.append((g1, g2))
11121116

11131117
# Loop through the head variable groundings
1118+
# TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
1119+
# is allocated K*N times (K edge groundings * N clauses) below, but the data
1120+
# is rule-static. Hoist a precomputed `clause_variables_precomputed` list
1121+
# here and append references in the inner loop instead of fresh copies.
11141122
for valid_e in valid_edge_groundings:
11151123
head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
11161124
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))

pyreason/scripts/interpretation/interpretation_fp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
10941094
add_head_var_node_to_graph = True
10951095
groundings[head_var_1] = numba.typed.List([head_var_1])
10961096

1097+
# TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
1098+
# is allocated K*N times (K head groundings * N clauses) below, but the data
1099+
# is rule-static. Hoist a precomputed `clause_variables_precomputed` list
1100+
# here and append references in the inner loop instead of fresh copies.
10971101
for head_grounding in groundings[head_var_1]:
10981102
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
10991103
qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
@@ -1232,6 +1236,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
12321236
valid_edge_groundings.append((g1, g2))
12331237

12341238
# Loop through the head variable groundings
1239+
# TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
1240+
# is allocated K*N times (K edge groundings * N clauses) below, but the data
1241+
# is rule-static. Hoist a precomputed `clause_variables_precomputed` list
1242+
# here and append references in the inner loop instead of fresh copies.
12351243
for valid_e in valid_edge_groundings:
12361244
head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
12371245
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))

pyreason/scripts/interpretation/interpretation_parallel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
972972
add_head_var_node_to_graph = True
973973
groundings[head_var_1] = numba.typed.List([head_var_1])
974974

975+
# TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
976+
# is allocated K*N times (K head groundings * N clauses) below, but the data
977+
# is rule-static. Hoist a precomputed `clause_variables_precomputed` list
978+
# here and append references in the inner loop instead of fresh copies.
975979
for head_grounding in groundings[head_var_1]:
976980
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
977981
qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
@@ -1111,6 +1115,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
11111115
valid_edge_groundings.append((g1, g2))
11121116

11131117
# Loop through the head variable groundings
1118+
# TODO(perf): when extended_ann_fn is True, `numba.typed.List(clause_variables)`
1119+
# is allocated K*N times (K edge groundings * N clauses) below, but the data
1120+
# is rule-static. Hoist a precomputed `clause_variables_precomputed` list
1121+
# here and append references in the inner loop instead of fresh copies.
11141122
for valid_e in valid_edge_groundings:
11151123
head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
11161124
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))

tests/unit/disable_jit/interpretations/test_interpretation_common.py

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -154,51 +154,58 @@ def init_interpretations_edge(edges, specific_labels):
154154
)
155155

156156
_reason_fn = _py(interpretation.Interpretation.reason)
157-
if "num_ga" in inspect.signature(_reason_fn).parameters:
158-
def reason(
159-
interpretations_node,
160-
interpretations_edge,
161-
predicate_map_node,
162-
predicate_map_edge,
163-
tmax,
164-
prev_reasoning_data,
165-
rules,
166-
nodes,
167-
edges,
168-
neighbors,
169-
reverse_neighbors,
170-
rules_to_be_applied_node,
171-
rules_to_be_applied_edge,
172-
edges_to_be_added_node_rule,
173-
edges_to_be_added_edge_rule,
174-
rules_to_be_applied_node_trace,
175-
rules_to_be_applied_edge_trace,
176-
facts_to_be_applied_node,
177-
facts_to_be_applied_edge,
178-
facts_to_be_applied_node_trace,
179-
facts_to_be_applied_edge_trace,
180-
ipl,
181-
rule_trace_node,
182-
rule_trace_edge,
183-
rule_trace_node_atoms,
184-
rule_trace_edge_atoms,
185-
reverse_graph,
186-
atom_trace,
187-
save_graph_attributes_to_rule_trace,
188-
persistent,
189-
inconsistency_check,
190-
store_interpretation_changes,
191-
update_mode,
192-
allow_ground_rules,
193-
max_facts_time,
194-
annotation_functions,
195-
head_functions,
196-
convergence_mode,
197-
convergence_delta,
198-
verbose,
199-
again,
200-
closed_world_predicates,
201-
):
157+
_reason_params = inspect.signature(_reason_fn).parameters
158+
_has_num_ga = "num_ga" in _reason_params
159+
160+
# Both branches wrap so we can inject `extended_ann_fn_flags` (one bool per
161+
# rule, default False since the disable_jit tests don't use 6-arg ann_fns).
162+
# The num_ga branch additionally injects `[0]` for num_ga.
163+
def reason(
164+
interpretations_node,
165+
interpretations_edge,
166+
predicate_map_node,
167+
predicate_map_edge,
168+
tmax,
169+
prev_reasoning_data,
170+
rules,
171+
nodes,
172+
edges,
173+
neighbors,
174+
reverse_neighbors,
175+
rules_to_be_applied_node,
176+
rules_to_be_applied_edge,
177+
edges_to_be_added_node_rule,
178+
edges_to_be_added_edge_rule,
179+
rules_to_be_applied_node_trace,
180+
rules_to_be_applied_edge_trace,
181+
facts_to_be_applied_node,
182+
facts_to_be_applied_edge,
183+
facts_to_be_applied_node_trace,
184+
facts_to_be_applied_edge_trace,
185+
ipl,
186+
rule_trace_node,
187+
rule_trace_edge,
188+
rule_trace_node_atoms,
189+
rule_trace_edge_atoms,
190+
reverse_graph,
191+
atom_trace,
192+
save_graph_attributes_to_rule_trace,
193+
persistent,
194+
inconsistency_check,
195+
store_interpretation_changes,
196+
update_mode,
197+
allow_ground_rules,
198+
max_facts_time,
199+
annotation_functions,
200+
head_functions,
201+
convergence_mode,
202+
convergence_delta,
203+
verbose,
204+
again,
205+
closed_world_predicates,
206+
):
207+
extended_ann_fn_flags = [False] * len(rules)
208+
if _has_num_ga:
202209
return _reason_fn(
203210
interpretations_node[0],
204211
interpretations_edge[0],
@@ -236,6 +243,7 @@ def reason(
236243
allow_ground_rules,
237244
max_facts_time,
238245
annotation_functions,
246+
extended_ann_fn_flags,
239247
head_functions,
240248
convergence_mode,
241249
convergence_delta,
@@ -244,8 +252,52 @@ def reason(
244252
again,
245253
closed_world_predicates,
246254
)
247-
else:
248-
reason = _reason_fn
255+
else:
256+
return _reason_fn(
257+
interpretations_node,
258+
interpretations_edge,
259+
predicate_map_node,
260+
predicate_map_edge,
261+
tmax,
262+
prev_reasoning_data,
263+
rules,
264+
nodes,
265+
edges,
266+
neighbors,
267+
reverse_neighbors,
268+
rules_to_be_applied_node,
269+
rules_to_be_applied_edge,
270+
edges_to_be_added_node_rule,
271+
edges_to_be_added_edge_rule,
272+
rules_to_be_applied_node_trace,
273+
rules_to_be_applied_edge_trace,
274+
facts_to_be_applied_node,
275+
facts_to_be_applied_edge,
276+
facts_to_be_applied_node_trace,
277+
facts_to_be_applied_edge_trace,
278+
ipl,
279+
rule_trace_node,
280+
rule_trace_edge,
281+
rule_trace_node_atoms,
282+
rule_trace_edge_atoms,
283+
reverse_graph,
284+
atom_trace,
285+
save_graph_attributes_to_rule_trace,
286+
persistent,
287+
inconsistency_check,
288+
store_interpretation_changes,
289+
update_mode,
290+
allow_ground_rules,
291+
max_facts_time,
292+
annotation_functions,
293+
extended_ann_fn_flags,
294+
head_functions,
295+
convergence_mode,
296+
convergence_delta,
297+
verbose,
298+
again,
299+
closed_world_predicates,
300+
)
249301
ns.reason = reason
250302

251303
class FakeLabel:

0 commit comments

Comments
 (0)