Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Debug script for test_annotation_function parallel mode issue."""
import pyreason as pr
from pyreason import Threshold
import numba
import numpy as np
from pyreason.scripts.numba_wrapper.numba_types.interval_type import closed


@numba.njit
def probability_func(annotations, weights):
prob_A = annotations[0][0].lower
prob_B = annotations[1][0].lower
union_prob = prob_A + prob_B
union_prob = np.round(union_prob, 3)
return union_prob, 1


def main():
# Setup parallel mode
pr.reset()
pr.reset_rules()
pr.reset_settings()
pr.settings.verbose = False # Disable verbose to speed up
pr.settings.parallel_computing = True
pr.settings.allow_ground_rules = True

print("Settings configured:")
print(f" parallel_computing: {pr.settings.parallel_computing}")
print(f" allow_ground_rules: {pr.settings.allow_ground_rules}")

print("=" * 80)
print("PARALLEL MODE DEBUG")
print("=" * 80)

# Add facts
pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))

# Add annotation function
pr.add_annotation_function(probability_func)

# Add rule
pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))

# Run reasoning
print("\nRunning reasoning for 1 timestep...")
interpretation = pr.reason(timesteps=1)

# Display results
print("\n" + "=" * 80)
print("RESULTS")
print("=" * 80)

dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability'])
for t, df in enumerate(dataframes):
print(f'\nTIMESTEP - {t}')
print(df)
print()

# Check what we actually got
print("\n" + "=" * 80)
print("QUERY RESULTS")
print("=" * 80)

# Try to query the actual value
query_result = interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]'))
print(f"\nQuery for [0.21, 1]: {query_result}")

# Let's also try to see what value we actually got
# Query with a wider range to see if it exists at all
wider_query = interpretation.query(pr.Query('union_probability(A, B) : [0, 1]'))
print(f"Query for [0, 1] (wider range): {wider_query}")

# Get the actual edge data
print("\n" + "=" * 80)
print("DETAILED EDGE INSPECTION")
print("=" * 80)

# Access the interpretation's internal data to see actual values
if hasattr(interpretation, 'get_dict'):
edge_dict = interpretation.get_dict()
print(f"\nEdge dictionary keys: {edge_dict.keys()}")
if ('A', 'B') in edge_dict:
print(f"\nEdge ('A', 'B') data:")
for key, value in edge_dict[('A', 'B')].items():
print(f" {key}: {value}")

# Alternative: inspect atoms directly
if hasattr(interpretation, 'atoms'):
print(f"\nAtoms available: {interpretation.atoms}")

print("\n" + "=" * 80)
print("EXPECTED vs ACTUAL")
print("=" * 80)
print(f"Expected: union_probability(A, B) with bounds [0.21, 1]")
print(f"Actual: See dataframe above")


if __name__ == "__main__":
main()
126 changes: 126 additions & 0 deletions debug_thresholds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Debug script for test_custom_thresholds parallel mode issue."""
import pyreason as pr
from pyreason import Threshold


def main():
# Setup parallel mode
pr.reset()
pr.reset_rules()
pr.reset_settings()
pr.settings.verbose = False # Disable verbose to speed up
pr.settings.parallel_computing = True
pr.settings.atom_trace = True

print("=" * 80)
print("CUSTOM THRESHOLDS PARALLEL MODE DEBUG")
print("=" * 80)
print(f"Settings:")
print(f" parallel_computing: {pr.settings.parallel_computing}")
print(f" atom_trace: {pr.settings.atom_trace}")

# Load graph
graph_path = "./tests/functional/group_chat_graph.graphml"
print(f"\nLoading graph from: {graph_path}")
pr.load_graphml(graph_path)

# Add custom thresholds
user_defined_thresholds = [
Threshold("greater_equal", ("number", "total"), 1),
Threshold("greater_equal", ("percent", "total"), 100),
]
print(f"\nCustom thresholds: {user_defined_thresholds}")

# Add rule
pr.add_rule(
pr.Rule(
"ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
"viewed_by_all_rule",
custom_thresholds=user_defined_thresholds,
)
)
print("Rule added: ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)")

# Add facts
pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
print("\nFacts added:")
print(" Viewed(Zach) at t=0")
print(" Viewed(Justin) at t=0")
print(" Viewed(Michelle) at t=1")
print(" Viewed(Amy) at t=2")

# Run reasoning
print("\n" + "=" * 80)
print("Running reasoning for 3 timesteps...")
print("=" * 80)
interpretation = pr.reason(timesteps=3)
print("Reasoning completed!")

# Display results
print("\n" + "=" * 80)
print("RESULTS - ViewedByAll at each timestep")
print("=" * 80)

dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
for t, df in enumerate(dataframes):
print(f"\nTIMESTEP {t}:")
print(f" Number of nodes with ViewedByAll: {len(df)}")
if len(df) > 0:
print(df)
else:
print(" (no nodes with ViewedByAll)")

# Check specific assertions
print("\n" + "=" * 80)
print("ASSERTION CHECKS")
print("=" * 80)

t0_check = len(dataframes[0]) == 0
print(f"✓ t=0: ViewedByAll count = {len(dataframes[0])} (expected: 0) - {'PASS' if t0_check else 'FAIL'}")

t2_check = len(dataframes[2]) == 1
print(f"✓ t=2: ViewedByAll count = {len(dataframes[2])} (expected: 1) - {'PASS' if t2_check else 'FAIL'}")

if len(dataframes[2]) > 0:
has_textmsg = "TextMessage" in dataframes[2]["component"].values
if has_textmsg:
bounds = dataframes[2].iloc[0].ViewedByAll
bounds_check = bounds == [1, 1]
print(f"✓ t=2: TextMessage bounds = {bounds} (expected: [1, 1]) - {'PASS' if bounds_check else 'FAIL'}")
else:
print(f"✗ t=2: TextMessage not found in ViewedByAll nodes")
print(f" Available nodes: {dataframes[2]['component'].values}")
else:
print("✗ t=2: No ViewedByAll nodes found (expected TextMessage)")

# Additional debugging: show all Viewed facts at each timestep
print("\n" + "=" * 80)
print("DEBUG - Viewed nodes at each timestep")
print("=" * 80)
viewed_dataframes = pr.filter_and_sort_nodes(interpretation, ["Viewed"])
for t, df in enumerate(viewed_dataframes):
print(f"\nTIMESTEP {t}:")
if len(df) > 0:
print(df)
else:
print(" (no Viewed nodes)")

# Show HaveAccess edges if possible
print("\n" + "=" * 80)
print("DEBUG - HaveAccess edges")
print("=" * 80)
try:
access_dataframes = pr.filter_and_sort_edges(interpretation, ["HaveAccess"])
print(f"Number of HaveAccess edges at t=0: {len(access_dataframes[0]) if access_dataframes else 'N/A'}")
if access_dataframes and len(access_dataframes[0]) > 0:
print("\nSample HaveAccess edges:")
print(access_dataframes[0].head(10))
except Exception as e:
print(f"Could not retrieve HaveAccess edges: {e}")


if __name__ == "__main__":
main()
39 changes: 32 additions & 7 deletions pyreason/scripts/interpretation/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
for idx, i in enumerate(rules_to_be_applied_edge):
if i[0] == t:
comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
print('applying edge rule at time', t, 'for component', comp, 'label', l, 'bound', bnd, 'set_static', set_static)
sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
print("adding edges:", sources, targets, edge_l)
edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
print('after adding, edges are:', edges)
changes_cnt += changes

# Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
Expand All @@ -475,7 +478,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
override = True if update_mode == 'override' else False
u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)

print('updating edge', e, 'label', edge_l, 'to bound', bnd)
update = u or update

# Update convergence params
Expand Down Expand Up @@ -545,6 +548,12 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])
# Threadsafe flags for in_loop and update within prange; merge after loop
in_loop_threadsafe = numba.typed.List.empty_list(numba.types.boolean)
update_threadsafe = numba.typed.List.empty_list(numba.types.boolean)
for _ in range(len(rules)):
in_loop_threadsafe.append(False)
update_threadsafe.append(True)

for i in prange(len(rules)):
rule = rules[i]
Expand All @@ -571,8 +580,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi

# If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop = True
update = False
in_loop_threadsafe[i] = True
update_threadsafe[i] = False

for applicable_rule in applicable_edge_rules:
e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
Expand All @@ -593,22 +602,38 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi

# If delta_t is zero we apply the rules and check if more are applicable
if delta_t == 0:
in_loop = True
update = False
in_loop_threadsafe[i] = True
update_threadsafe[i] = False

# Update lists after parallel run
# Update lists after parallel run
print("len", len(rules_to_be_applied_edge_threadsafe))
for i in rules_to_be_applied_edge_threadsafe:
print(i)
for i in range(len(rules)):
if len(rules_to_be_applied_node_threadsafe[i]) > 0:
rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
print('here, edge rules')
rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
print("rules_to_be_applied_edge", rules_to_be_applied_edge)
if atom_trace:
if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
print('here, edge add')
edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
print("edges_to_be_added_edge_rule", edges_to_be_added_edge_rule)

# Merge threadsafe flags for in_loop and update
in_loop = in_loop
update = update
for i in range(len(rules)):
if in_loop_threadsafe[i]:
in_loop = True
if not update_threadsafe[i]:
update = False

# Check for convergence after each timestep (perfect convergence or convergence specified by user)
# Check number of changed interpretations or max bound change
Expand Down Expand Up @@ -1964,4 +1989,4 @@ def str_to_int(value):
for i, v in enumerate(value):
result += (ord(v) - 48) * (10 ** (final_index - i))
result = -result if negative else result
return result
return result
Loading