Skip to content

Commit 727256f

Browse files
authored
Fix tf.function retracing in TensorFlow benchmark (#27665)
## Summary - Move `tf.function`-decorated forward functions out of the inner benchmark loop to prevent unnecessary graph retracing on every `(batch_size, sequence_length)` iteration - Update deprecated `experimental_compile` to `jit_compile` (available since TF 2.4) - Hoist `import random` out of the inner loop Fixes #14953 ## Motivation When `run_with_tf_optimizations` is used as a decorator inside the innermost `(batch_size, sequence_length)` loop, each iteration creates a new Python function object. Since `tf.function` keys its trace cache on function identity, a new object means a forced retrace every iteration — the cached graph is never reused. This defeats the purpose of `tf.function` and adds significant overhead from repeated graph construction and optimization passes. The [TensorFlow documentation on tracing](https://www.tensorflow.org/guide/function#rules_of_tracing) explicitly warns against defining `tf.function`-decorated functions inside loops. ## Changes **`onnxruntime/python/tools/transformers/benchmark.py`** (1 file, ~35 insertions / ~31 deletions): 1. **Hoisted forward function definitions** (`encoder_forward`, `encoder_decoder_forward`, `lxmert_forward`) from the inner `batch_size × sequence_length` loop to the per-model scope. They are now defined once per model, and the `@run_with_tf_optimizations` decorator (which applies `@tf.function`) is only invoked once per model. 2. **Changed forward functions to accept `input_ids` as a parameter** instead of closing over the loop variable. This lets `tf.function` trace based on the tensor's `(dtype, shape)` spec and reuse cached concrete functions when shapes repeat across iterations. 3. **Updated `experimental_compile=use_xla`** to **`jit_compile=use_xla`**. The `experimental_compile` parameter was deprecated in TF 2.4 (Dec 2020) and removed in TF 2.12. 4. **Moved `import random`** from the innermost loop body to before the outer model loop — the module only needs to be imported once. 5. **Moved inference function selection** (`if config.is_encoder_decoder ... elif isinstance(config, LxmertConfig) ...`) outside the batch/sequence loops since it depends only on the model config, not on batch size or sequence length. The original priority order (`is_encoder_decoder` checked before `LxmertConfig`) is preserved. ## Test Plan - [x] `lintrunner -a` passes cleanly (no RUFF or RUFF-FORMAT violations) - [x] `python -m py_compile benchmark.py` — syntax verified - [x] Change is purely structural — function behavior (inputs, outputs, control flow) is identical - [ ] Manual verification with TensorFlow installed (TF is an optional dependency not present in the standard CI matrix; this code path is exercised via `python benchmark.py -e tensorflow`)
1 parent aa6f2e3 commit 727256f

1 file changed

Lines changed: 34 additions & 31 deletions

File tree

onnxruntime/python/tools/transformers/benchmark.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import argparse
4444
import logging
4545
import os
46+
import random
4647
import timeit
4748
from datetime import datetime
4849

@@ -431,7 +432,7 @@ def run_in_eager_mode(*args, **kwargs):
431432
return func(*args, **kwargs)
432433

433434
@wraps(func)
434-
@tf.function(experimental_compile=use_xla)
435+
@tf.function(jit_compile=use_xla)
435436
def run_in_graph_mode(*args, **kwargs):
436437
return func(*args, **kwargs)
437438

@@ -500,6 +501,36 @@ def run_tensorflow(
500501

501502
max_input_size = tokenizer.model_max_length
502503

504+
# Define tf.function-decorated forward functions once per model, outside the
505+
# batch_size/sequence_length loops. Passing input_ids as an argument (instead
506+
# of closing over it) allows tf.function to cache traced graphs by input shape
507+
# rather than retracing on every loop iteration. See issue #14953.
508+
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
509+
def encoder_forward(input_ids):
510+
return model(input_ids, training=False) # noqa: B023
511+
512+
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
513+
def encoder_decoder_forward(input_ids):
514+
return model(input_ids, decoder_input_ids=input_ids, training=False) # noqa: B023
515+
516+
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
517+
def lxmert_forward(input_ids):
518+
feats = tf.random.normal([1, 1, config.visual_feat_dim]) # noqa: B023
519+
pos = tf.random.normal([1, 1, config.visual_pos_dim]) # noqa: B023
520+
return model( # noqa: B023
521+
input_ids,
522+
visual_feats=feats,
523+
visual_pos=pos,
524+
training=False,
525+
)
526+
527+
if config.is_encoder_decoder:
528+
inference = encoder_decoder_forward
529+
elif isinstance(config, LxmertConfig):
530+
inference = lxmert_forward
531+
else:
532+
inference = encoder_forward
533+
503534
for batch_size in batch_sizes:
504535
if batch_size <= 0:
505536
continue
@@ -510,42 +541,14 @@ def run_tensorflow(
510541

511542
logger.info(f"Run Tensorflow on {model_name} with input shape {[batch_size, sequence_length]}")
512543

513-
import random # noqa: PLC0415
514-
515544
rng = random.Random()
516545
values = [rng.randint(0, config.vocab_size - 1) for i in range(batch_size * sequence_length)]
517546
input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
518547

519548
try:
520-
# Disable both for better inference perf
521-
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
522-
def encoder_forward():
523-
return model(input_ids, training=False) # noqa: B023
524-
525-
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
526-
def encoder_decoder_forward():
527-
return model(input_ids, decoder_input_ids=input_ids, training=False) # noqa: B023
528-
529-
@run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
530-
def lxmert_forward():
531-
feats = tf.random.normal([1, 1, config.visual_feat_dim]) # noqa: B023
532-
pos = tf.random.normal([1, 1, config.visual_pos_dim]) # noqa: B023
533-
return model( # noqa: B023
534-
input_ids, # noqa: B023
535-
visual_feats=feats,
536-
visual_pos=pos,
537-
training=False,
538-
)
539-
540-
inference = encoder_forward
541-
if config.is_encoder_decoder:
542-
inference = encoder_decoder_forward
543-
elif isinstance(config, LxmertConfig):
544-
inference = lxmert_forward
545-
546-
inference()
549+
inference(input_ids)
547550

548-
runtimes = timeit.repeat(lambda: inference(), repeat=repeat_times, number=1) # noqa: B023
551+
runtimes = timeit.repeat(lambda: inference(input_ids), repeat=repeat_times, number=1) # noqa: B023
549552

550553
result = {
551554
"engine": "tensorflow",

0 commit comments

Comments
 (0)