Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## 2024-05-24 - Avoid `dict.setdefault` with Complex Defaults in Hot Loops
**Learning:** Calling `dict.setdefault(key, complex_default)` inside a hot loop can incur a significant performance penalty. Python evaluates every call argument before invoking the method, so `complex_default` is constructed on each iteration — even when the key already exists and the freshly built default is immediately discarded. This is particularly costly when the default involves list comprehensions or multiple object instantiations.
**Action:** Within hot loops, replace `d.setdefault(key, complex_default)` with an explicit membership check (`if key not in d: d[key] = complex_default`), then read the entry back with `d[key]`. This ensures the default is evaluated and instantiated only when the key is actually missing.
8 changes: 3 additions & 5 deletions src/geometry/hull_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
HardmaxResult,
NumberLike,
ValueLike,
_as_fraction,
_coerce_key,
_coerce_value,
_normalize_number,
Expand Down Expand Up @@ -205,10 +204,9 @@ def _rebuild_if_needed(self) -> None:
total_value_sum = [Fraction(0) for _ in range(self._value_width or 0)]

for index, (key, value) in enumerate(self._entries):
bucket = aggregates.setdefault(
key,
{"value_sum": [Fraction(0) for _ in value], "count": 0, "entry_indices": []},
)
if key not in aggregates:
aggregates[key] = {"value_sum": [Fraction(0) for _ in value], "count": 0, "entry_indices": []}
bucket = aggregates[key]
for coord_index, coord in enumerate(value):
bucket["value_sum"][coord_index] += coord
total_value_sum[coord_index] += coord
Expand Down
9 changes: 4 additions & 5 deletions src/model/free_running_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,14 +697,13 @@ def evaluate_free_running_programs(

outcomes.append(outcome)
bucket = bucket_name(outcome.program_steps)
bucket_state = per_bucket.setdefault(
bucket,
{
if bucket not in per_bucket:
per_bucket[bucket] = {
"program_count": 0,
"exact_trace_count": 0,
"exact_final_state_count": 0,
},
)
}
bucket_state = per_bucket[bucket]
bucket_state["program_count"] += 1
bucket_state["exact_trace_count"] += int(outcome.exact_trace_match)
bucket_state["exact_final_state_count"] += int(outcome.exact_final_state_match)
Expand Down
13 changes: 7 additions & 6 deletions src/model/softmax_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,15 +582,14 @@ def evaluate_teacher_forced_model(
total_correct += correct

bucket = baseline_bucket_name(example.program_steps)
bucket_state = per_bucket.setdefault(
bucket,
{
if bucket not in per_bucket:
per_bucket[bucket] = {
"example_count": 0,
"token_count": 0,
"correct_tokens": 0,
"weighted_loss": 0.0,
},
)
}
bucket_state = per_bucket[bucket]
bucket_state["example_count"] = int(bucket_state["example_count"]) + 1
bucket_state["token_count"] = int(bucket_state["token_count"]) + token_count
bucket_state["correct_tokens"] = int(bucket_state["correct_tokens"]) + correct
Expand Down Expand Up @@ -753,7 +752,9 @@ def evaluate_free_running_rollout(
)

bucket = baseline_bucket_name(example.program_steps)
bucket_state = per_bucket.setdefault(bucket, {"example_count": 0, "exact_count": 0})
if bucket not in per_bucket:
per_bucket[bucket] = {"example_count": 0, "exact_count": 0}
bucket_state = per_bucket[bucket]
bucket_state["example_count"] = int(bucket_state["example_count"]) + 1
bucket_state["exact_count"] = int(bucket_state["exact_count"]) + int(exact)

Expand Down
16 changes: 12 additions & 4 deletions src/model/trainable_latest_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def exact_program_accuracy(scorer: TrainableLatestWriteScorer, samples: Sequence
return 0.0
per_program: dict[str, list[bool]] = {}
for sample in samples:
per_program.setdefault(sample.program_name, []).append(scorer.predict_index(sample) == sample.target_index)
if sample.program_name not in per_program:
per_program[sample.program_name] = []
per_program[sample.program_name].append(scorer.predict_index(sample) == sample.target_index)
exact = sum(1 for outcomes in per_program.values() if all(outcomes))
return exact / len(per_program)

Expand All @@ -198,12 +200,18 @@ def evaluate_scorer(
correct_samples += int(correct)

bucket = bucket_name(sample.program_steps)
bucket_state = per_bucket.setdefault(bucket, {"sample_count": 0, "sample_correct": 0, "programs": {}})
if bucket not in per_bucket:
per_bucket[bucket] = {"sample_count": 0, "sample_correct": 0, "programs": {}}
bucket_state = per_bucket[bucket]
bucket_state["sample_count"] = int(bucket_state["sample_count"]) + 1
bucket_state["sample_correct"] = int(bucket_state["sample_correct"]) + int(correct)
bucket_state["programs"].setdefault(sample.program_name, []).append(correct)
if sample.program_name not in bucket_state["programs"]:
bucket_state["programs"][sample.program_name] = []
bucket_state["programs"][sample.program_name].append(correct)

per_program.setdefault(sample.program_name, []).append(correct)
if sample.program_name not in per_program:
per_program[sample.program_name] = []
per_program[sample.program_name].append(correct)
program_steps[sample.program_name] = sample.program_steps

exact_programs = sum(1 for outcomes in per_program.values() if all(outcomes))
Expand Down