@@ -154,51 +154,58 @@ def init_interpretations_edge(edges, specific_labels):
154154 )
155155
156156 _reason_fn = _py (interpretation .Interpretation .reason )
157- if "num_ga" in inspect .signature (_reason_fn ).parameters :
158- def reason (
159- interpretations_node ,
160- interpretations_edge ,
161- predicate_map_node ,
162- predicate_map_edge ,
163- tmax ,
164- prev_reasoning_data ,
165- rules ,
166- nodes ,
167- edges ,
168- neighbors ,
169- reverse_neighbors ,
170- rules_to_be_applied_node ,
171- rules_to_be_applied_edge ,
172- edges_to_be_added_node_rule ,
173- edges_to_be_added_edge_rule ,
174- rules_to_be_applied_node_trace ,
175- rules_to_be_applied_edge_trace ,
176- facts_to_be_applied_node ,
177- facts_to_be_applied_edge ,
178- facts_to_be_applied_node_trace ,
179- facts_to_be_applied_edge_trace ,
180- ipl ,
181- rule_trace_node ,
182- rule_trace_edge ,
183- rule_trace_node_atoms ,
184- rule_trace_edge_atoms ,
185- reverse_graph ,
186- atom_trace ,
187- save_graph_attributes_to_rule_trace ,
188- persistent ,
189- inconsistency_check ,
190- store_interpretation_changes ,
191- update_mode ,
192- allow_ground_rules ,
193- max_facts_time ,
194- annotation_functions ,
195- head_functions ,
196- convergence_mode ,
197- convergence_delta ,
198- verbose ,
199- again ,
200- closed_world_predicates ,
201- ):
157+ _reason_params = inspect .signature (_reason_fn ).parameters
158+ _has_num_ga = "num_ga" in _reason_params
159+
160+ # Both branches wrap so we can inject `extended_ann_fn_flags` (one bool per
161+ # rule, default False since the disable_jit tests don't use 6-arg ann_fns).
162+ # The num_ga branch additionally injects `[0]` for num_ga.
163+ def reason (
164+ interpretations_node ,
165+ interpretations_edge ,
166+ predicate_map_node ,
167+ predicate_map_edge ,
168+ tmax ,
169+ prev_reasoning_data ,
170+ rules ,
171+ nodes ,
172+ edges ,
173+ neighbors ,
174+ reverse_neighbors ,
175+ rules_to_be_applied_node ,
176+ rules_to_be_applied_edge ,
177+ edges_to_be_added_node_rule ,
178+ edges_to_be_added_edge_rule ,
179+ rules_to_be_applied_node_trace ,
180+ rules_to_be_applied_edge_trace ,
181+ facts_to_be_applied_node ,
182+ facts_to_be_applied_edge ,
183+ facts_to_be_applied_node_trace ,
184+ facts_to_be_applied_edge_trace ,
185+ ipl ,
186+ rule_trace_node ,
187+ rule_trace_edge ,
188+ rule_trace_node_atoms ,
189+ rule_trace_edge_atoms ,
190+ reverse_graph ,
191+ atom_trace ,
192+ save_graph_attributes_to_rule_trace ,
193+ persistent ,
194+ inconsistency_check ,
195+ store_interpretation_changes ,
196+ update_mode ,
197+ allow_ground_rules ,
198+ max_facts_time ,
199+ annotation_functions ,
200+ head_functions ,
201+ convergence_mode ,
202+ convergence_delta ,
203+ verbose ,
204+ again ,
205+ closed_world_predicates ,
206+ ):
207+ extended_ann_fn_flags = [False ] * len (rules )
208+ if _has_num_ga :
202209 return _reason_fn (
203210 interpretations_node [0 ],
204211 interpretations_edge [0 ],
@@ -236,6 +243,7 @@ def reason(
236243 allow_ground_rules ,
237244 max_facts_time ,
238245 annotation_functions ,
246+ extended_ann_fn_flags ,
239247 head_functions ,
240248 convergence_mode ,
241249 convergence_delta ,
@@ -244,8 +252,52 @@ def reason(
244252 again ,
245253 closed_world_predicates ,
246254 )
247- else :
248- reason = _reason_fn
255+ else :
256+ return _reason_fn (
257+ interpretations_node ,
258+ interpretations_edge ,
259+ predicate_map_node ,
260+ predicate_map_edge ,
261+ tmax ,
262+ prev_reasoning_data ,
263+ rules ,
264+ nodes ,
265+ edges ,
266+ neighbors ,
267+ reverse_neighbors ,
268+ rules_to_be_applied_node ,
269+ rules_to_be_applied_edge ,
270+ edges_to_be_added_node_rule ,
271+ edges_to_be_added_edge_rule ,
272+ rules_to_be_applied_node_trace ,
273+ rules_to_be_applied_edge_trace ,
274+ facts_to_be_applied_node ,
275+ facts_to_be_applied_edge ,
276+ facts_to_be_applied_node_trace ,
277+ facts_to_be_applied_edge_trace ,
278+ ipl ,
279+ rule_trace_node ,
280+ rule_trace_edge ,
281+ rule_trace_node_atoms ,
282+ rule_trace_edge_atoms ,
283+ reverse_graph ,
284+ atom_trace ,
285+ save_graph_attributes_to_rule_trace ,
286+ persistent ,
287+ inconsistency_check ,
288+ store_interpretation_changes ,
289+ update_mode ,
290+ allow_ground_rules ,
291+ max_facts_time ,
292+ annotation_functions ,
293+ extended_ann_fn_flags ,
294+ head_functions ,
295+ convergence_mode ,
296+ convergence_delta ,
297+ verbose ,
298+ again ,
299+ closed_world_predicates ,
300+ )
249301 ns .reason = reason
250302
251303 class FakeLabel :
0 commit comments