-
Notifications
You must be signed in to change notification settings - Fork 0
β‘ Bolt: [Replace setdefault with explicit membership checks in loops] #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| ## 2024-04-20 - [Avoid setdefault with complex defaults in loops] | ||
| **Learning:** `dict.setdefault(key, complex_default)` evaluates `complex_default` on every iteration, even if the key already exists. This causes significant performance overhead in hot loops when the default is a dictionary or list, as it constantly creates and immediately discards these objects. | ||
| **Action:** Replace `dict.setdefault(key, complex_default)` with `if key not in dict: dict[key] = complex_default` to avoid expensive eager evaluation. |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -582,15 +582,14 @@ def evaluate_teacher_forced_model( | |||||||||
| total_correct += correct | ||||||||||
|
|
||||||||||
| bucket = baseline_bucket_name(example.program_steps) | ||||||||||
| bucket_state = per_bucket.setdefault( | ||||||||||
| bucket, | ||||||||||
| { | ||||||||||
| if bucket not in per_bucket: | ||||||||||
| per_bucket[bucket] = { | ||||||||||
| "example_count": 0, | ||||||||||
| "token_count": 0, | ||||||||||
| "correct_tokens": 0, | ||||||||||
| "weighted_loss": 0.0, | ||||||||||
| }, | ||||||||||
| ) | ||||||||||
| } | ||||||||||
| bucket_state = per_bucket[bucket] | ||||||||||
| bucket_state["example_count"] = int(bucket_state["example_count"]) + 1 | ||||||||||
|
Comment on lines
+585
to
593
|
||||||||||
| bucket_state["token_count"] = int(bucket_state["token_count"]) + token_count | ||||||||||
| bucket_state["correct_tokens"] = int(bucket_state["correct_tokens"]) + correct | ||||||||||
|
|
@@ -753,7 +752,9 @@ def evaluate_free_running_rollout( | |||||||||
| ) | ||||||||||
|
|
||||||||||
| bucket = baseline_bucket_name(example.program_steps) | ||||||||||
| bucket_state = per_bucket.setdefault(bucket, {"example_count": 0, "exact_count": 0}) | ||||||||||
| if bucket not in per_bucket: | ||||||||||
| per_bucket[bucket] = {"example_count": 0, "exact_count": 0} | ||||||||||
| bucket_state = per_bucket[bucket] | ||||||||||
|
Comment on lines
+755
to
+757
|
||||||||||
| if bucket not in per_bucket: | |
| per_bucket[bucket] = {"example_count": 0, "exact_count": 0} | |
| bucket_state = per_bucket[bucket] | |
| bucket_state = per_bucket.setdefault(bucket, {"example_count": 0, "exact_count": 0}) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,7 +171,9 @@ def exact_program_accuracy(scorer: TrainableLatestWriteScorer, samples: Sequence | |
| return 0.0 | ||
| per_program: dict[str, list[bool]] = {} | ||
| for sample in samples: | ||
| per_program.setdefault(sample.program_name, []).append(scorer.predict_index(sample) == sample.target_index) | ||
| if sample.program_name not in per_program: | ||
| per_program[sample.program_name] = [] | ||
| per_program[sample.program_name].append(scorer.predict_index(sample) == sample.target_index) | ||
|
Comment on lines
+174
to
+176
|
||
| exact = sum(1 for outcomes in per_program.values() if all(outcomes)) | ||
| return exact / len(per_program) | ||
|
|
||
|
|
@@ -198,12 +200,20 @@ def evaluate_scorer( | |
| correct_samples += int(correct) | ||
|
|
||
| bucket = bucket_name(sample.program_steps) | ||
| bucket_state = per_bucket.setdefault(bucket, {"sample_count": 0, "sample_correct": 0, "programs": {}}) | ||
| if bucket not in per_bucket: | ||
| per_bucket[bucket] = {"sample_count": 0, "sample_correct": 0, "programs": {}} | ||
| bucket_state = per_bucket[bucket] | ||
| bucket_state["sample_count"] = int(bucket_state["sample_count"]) + 1 | ||
| bucket_state["sample_correct"] = int(bucket_state["sample_correct"]) + int(correct) | ||
| bucket_state["programs"].setdefault(sample.program_name, []).append(correct) | ||
|
|
||
| per_program.setdefault(sample.program_name, []).append(correct) | ||
| programs: dict = bucket_state["programs"] # type: ignore | ||
| if sample.program_name not in programs: | ||
| programs[sample.program_name] = [] | ||
| programs[sample.program_name].append(correct) | ||
|
Comment on lines
+209
to
+212
|
||
|
|
||
| if sample.program_name not in per_program: | ||
| per_program[sample.program_name] = [] | ||
| per_program[sample.program_name].append(correct) | ||
| program_steps[sample.program_name] = sample.program_steps | ||
|
|
||
| exact_programs = sum(1 for outcomes in per_program.values() if all(outcomes)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This membership-check + indexing pattern does two dict lookups per program (
bucket in per_bucketthenper_bucket[bucket]). Since the stated goal is improving hot-loop performance, consider a single-lookup approach (getwith initialization ortry/except KeyError) to avoid the extra lookup.