Skip to content

Commit 15d0255

Browse files
psiddhfacebook-github-bot
authored andcommitted
Add LLM performance regression instrumentation tests
Summary: Adds `LlmPerformanceTest`, an Android instrumentation test that measures inference performance metrics (TPS, TPS stability, TTFT) for ExecuTorch LLM on the stories110M fixture and asserts they meet minimum thresholds. This enables OKR 3.3 (Performance Testing: TPS/latency regression detection) using the same zero-infra approach as D105741356 — same fixture, same CI paths, no new dependencies. Three performance aspects are tested: 1. `testTpsAboveThreshold` — decode speed regression gate. A warm-up run is excluded from measurement. Threshold is configurable via instrumentation arg (`minTps`) so the same APK works on emulator (1.0 TPS) and device (10+ TPS). 2. `testTpsStability` — checks coefficient of variation across 3 runs is below 0.5. Catches thread contention, GC pressure, or scheduling instability that causes inconsistent user experience. 3. `testTimeToFirstToken` — measures prompt evaluation latency (prefill time). Asserts TTFT < 30s. Catches regressions in the prefill/KV-cache-fill path that make the app feel unresponsive before generation starts. All metrics are reported via InstrumentationRegistry.sendStatus() for CI metric capture and future dashboarding. Differential Revision: D105840841
1 parent 3b5d18d commit 15d0255

1 file changed

Lines changed: 294 additions & 0 deletions

File tree

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
package org.pytorch.executorch
9+
10+
import android.os.Bundle
11+
import androidx.test.ext.junit.runners.AndroidJUnit4
12+
import androidx.test.platform.app.InstrumentationRegistry
13+
import java.io.File
14+
import java.io.IOException
15+
import java.util.Collections
16+
import org.apache.commons.io.FileUtils
17+
import org.json.JSONException
18+
import org.json.JSONObject
19+
import org.junit.After
20+
import org.junit.Assert.assertTrue
21+
import org.junit.Assert.fail
22+
import org.junit.Before
23+
import org.junit.Test
24+
import org.junit.runner.RunWith
25+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
26+
import org.pytorch.executorch.extension.llm.LlmCallback
27+
import org.pytorch.executorch.extension.llm.LlmModule
28+
29+
/**
30+
* Performance regression tests for LLM inference on ExecuTorch Android.
31+
*
32+
* Measures tokens-per-second (TPS), TPS stability, and time-to-first-token (TTFT). Results are
33+
* reported via [InstrumentationRegistry] so CI systems can capture and trend metrics over time.
34+
*
35+
* Uses the same TinyStories-110M fixture as [LlmModuleConversationHistoryTest], so no additional
36+
* test infrastructure is needed. Works on both OSS (GitHub Actions) and internal (Sandcastle) CI.
37+
*
38+
* To run locally:
39+
* ```
40+
* ./gradlew :executorch_android:connectedAndroidTest \
41+
* -Pandroid.testInstrumentationRunnerArguments.class=org.pytorch.executorch.LlmPerformanceTest
42+
* ```
43+
*
44+
* To override the TPS threshold for physical devices:
45+
* ```
46+
* -Pandroid.testInstrumentationRunnerArguments.minTps=10.0
47+
* ```
48+
*/
49+
@RunWith(AndroidJUnit4::class)
50+
class LlmPerformanceTest : LlmCallback {
51+
52+
private lateinit var llmModule: LlmModule
53+
private val generatedTokens: MutableList<String> =
54+
Collections.synchronizedList(mutableListOf<String>())
55+
private val tpsResults: MutableList<Float> =
56+
Collections.synchronizedList(mutableListOf<Float>())
57+
@Volatile private var lastStatsJson: String? = null
58+
59+
@Before
60+
@Throws(IOException::class)
61+
fun setUp() {
62+
val pteFile = File(getTestFilePath(TEST_FILE_NAME))
63+
val pteStream =
64+
requireNotNull(javaClass.getResourceAsStream(TEST_FILE_NAME)) {
65+
"Test resource $TEST_FILE_NAME not found; did android_test_setup.sh run?"
66+
}
67+
FileUtils.copyInputStreamToFile(pteStream, pteFile)
68+
pteStream.close()
69+
70+
val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME))
71+
val tokenizerStream =
72+
requireNotNull(javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)) {
73+
"Test resource $TOKENIZER_FILE_NAME not found; did android_test_setup.sh run?"
74+
}
75+
FileUtils.copyInputStreamToFile(tokenizerStream, tokenizerFile)
76+
tokenizerStream.close()
77+
78+
llmModule =
79+
LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f)
80+
}
81+
82+
@After
83+
fun tearDown() {
84+
if (::llmModule.isInitialized) {
85+
llmModule.close()
86+
}
87+
}
88+
89+
/**
90+
* Measures TPS after a warm-up run and asserts it exceeds a minimum threshold.
91+
*
92+
* The warm-up is necessary because the first inference includes one-time costs (memory
93+
* allocation, kernel compilation on some backends) that would unfairly penalize the measurement.
94+
*
95+
* Default threshold is conservative (1.0 TPS) for emulator CI. Override with the `minTps`
96+
* instrumentation argument for physical device runs where 10-30+ TPS is expected.
97+
*/
98+
@Test(timeout = MAX_TEST_TIMEOUT_MS)
99+
fun testTpsAboveThreshold() {
100+
val loadResult = llmModule.load()
101+
assertTrue("Model failed to load (result=$loadResult)", loadResult == 0)
102+
103+
// Warm-up: first inference includes one-time overhead
104+
resetState()
105+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
106+
assertTrue("Warm-up produced no tokens — model may be broken", generatedTokens.isNotEmpty())
107+
val warmupTps = tpsResults.lastOrNull() ?: 0f
108+
reportMetric("warmup_tps", warmupTps)
109+
110+
// Measured run
111+
resetState()
112+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
113+
114+
assertTrue("Measured run produced no tokens", generatedTokens.isNotEmpty())
115+
assertTrue("No TPS stats received from onStats callback", tpsResults.isNotEmpty())
116+
117+
val measuredTps = tpsResults.last()
118+
val minTps = getMinTpsThreshold()
119+
120+
reportMetric("measured_tps", measuredTps)
121+
reportMetric("measured_tokens", generatedTokens.size.toFloat())
122+
reportMetric("min_tps_threshold", minTps)
123+
124+
assertTrue(
125+
"TPS regression detected! measured=${"%.2f".format(measuredTps)} " +
126+
"< threshold=${"%.2f".format(minTps)}. Raw stats: $lastStatsJson",
127+
measuredTps >= minTps,
128+
)
129+
}
130+
131+
/**
132+
* Validates that TPS is stable across multiple consecutive runs.
133+
*
134+
* Large variance in TPS (high coefficient of variation) may indicate thread contention, GC
135+
* pressure, thermal throttling, or non-deterministic scheduling — all of which degrade the user
136+
* experience even if average TPS is acceptable.
137+
*/
138+
@Test(timeout = MAX_TEST_TIMEOUT_MS)
139+
fun testTpsStability() {
140+
val loadResult = llmModule.load()
141+
assertTrue("Model failed to load", loadResult == 0)
142+
143+
// Warm-up
144+
resetState()
145+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
146+
147+
// Collect TPS over multiple runs
148+
val measurements = mutableListOf<Float>()
149+
for (i in 1..STABILITY_ITERATIONS) {
150+
resetState()
151+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
152+
if (tpsResults.isNotEmpty()) {
153+
measurements.add(tpsResults.last())
154+
}
155+
}
156+
157+
assertTrue(
158+
"Not enough TPS measurements (${measurements.size}/$STABILITY_ITERATIONS)",
159+
measurements.size >= STABILITY_ITERATIONS,
160+
)
161+
162+
val mean = measurements.average().toFloat()
163+
val variance = measurements.map { (it - mean) * (it - mean) }.average().toFloat()
164+
val stddev = Math.sqrt(variance.toDouble()).toFloat()
165+
val cv = if (mean > 0f) stddev / mean else Float.MAX_VALUE
166+
167+
reportMetric("stability_mean_tps", mean)
168+
reportMetric("stability_stddev", stddev)
169+
reportMetric("stability_cv", cv)
170+
reportMetric("stability_min", measurements.min())
171+
reportMetric("stability_max", measurements.max())
172+
173+
assertTrue(
174+
"TPS too unstable! CV=${"%.3f".format(cv)} exceeds max $MAX_CV. " +
175+
"Measurements: $measurements",
176+
cv <= MAX_CV,
177+
)
178+
}
179+
180+
/**
181+
* Measures time-to-first-token (TTFT) — the delay from calling generate() until the first token
182+
* is produced (i.e., prompt evaluation / prefill time).
183+
*
184+
* High TTFT directly impacts perceived responsiveness: the user types a message and sees nothing
185+
* happen until prefill completes.
186+
*/
187+
@Test(timeout = MAX_TEST_TIMEOUT_MS)
188+
fun testTimeToFirstToken() {
189+
val loadResult = llmModule.load()
190+
assertTrue("Model failed to load", loadResult == 0)
191+
192+
// Warm-up
193+
resetState()
194+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
195+
196+
// Measured TTFT
197+
resetState()
198+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this)
199+
200+
val statsJson = lastStatsJson
201+
assertTrue("No stats JSON received from onStats callback", statsJson != null)
202+
203+
try {
204+
val json = JSONObject(statsJson!!)
205+
val inferenceStartMs = json.getLong("inference_start_ms")
206+
val promptEvalEndMs = json.getLong("prompt_eval_end_ms")
207+
val ttftMs = promptEvalEndMs - inferenceStartMs
208+
209+
reportMetric("ttft_ms", ttftMs.toFloat())
210+
211+
assertTrue(
212+
"TTFT too slow: ${ttftMs}ms exceeds max ${MAX_TTFT_MS}ms. " +
213+
"Prompt evaluation is taking too long.",
214+
ttftMs <= MAX_TTFT_MS,
215+
)
216+
} catch (e: JSONException) {
217+
fail("Failed to parse onStats JSON for TTFT: $statsJson. Error: ${e.message}")
218+
}
219+
}
220+
221+
// ─── LlmCallback ──────────────────────────────────────────────────────────────────
222+
223+
override fun onResult(result: String) {
224+
generatedTokens.add(result)
225+
}
226+
227+
override fun onStats(stats: String) {
228+
lastStatsJson = stats
229+
try {
230+
val json = JSONObject(stats)
231+
val numTokens = json.getInt("generated_tokens")
232+
val inferenceEndMs = json.getLong("inference_end_ms")
233+
val promptEvalEndMs = json.getLong("prompt_eval_end_ms")
234+
val decodeTimeMs = inferenceEndMs - promptEvalEndMs
235+
if (decodeTimeMs > 0) {
236+
tpsResults.add(numTokens.toFloat() / decodeTimeMs * 1000f)
237+
}
238+
} catch (_: JSONException) {
239+
// Parsing failure — test will fail on assertion
240+
}
241+
}
242+
243+
// ─── Helpers ─────────────────────────────────────────────────────────────────────
244+
245+
private fun resetState() {
246+
generatedTokens.clear()
247+
tpsResults.clear()
248+
lastStatsJson = null
249+
}
250+
251+
/**
252+
* Returns the minimum TPS threshold. Overridable via instrumentation arg `minTps` so the same
253+
* test binary can gate at different levels for emulator vs physical device CI.
254+
*/
255+
private fun getMinTpsThreshold(): Float {
256+
val override = InstrumentationRegistry.getArguments().getString("minTps")
257+
return override?.toFloatOrNull() ?: DEFAULT_MIN_TPS
258+
}
259+
260+
private fun reportMetric(key: String, value: Float) {
261+
val bundle = Bundle().apply { putFloat(key, value) }
262+
InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle)
263+
}
264+
265+
companion object {
266+
private const val TEST_FILE_NAME = "/stories.pte"
267+
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
268+
269+
/** Prompt for inference. Kept short to minimize test wall-time. */
270+
private const val TEST_PROMPT = "Once upon a time"
271+
private const val SEQ_LEN = 64
272+
273+
/**
274+
* Minimum TPS for the test to pass. Conservative for x86_64 emulator (API 34). For physical
275+
* devices, override via: -Pandroid.testInstrumentationRunnerArguments.minTps=10.0
276+
*/
277+
private const val DEFAULT_MIN_TPS = 1.0f
278+
279+
/** Maximum time-to-first-token in milliseconds. 30s is generous for emulator. */
280+
private const val MAX_TTFT_MS = 30_000
281+
282+
/**
283+
* Maximum coefficient of variation (stddev/mean) for TPS across runs. 0.5 = up to 50% relative
284+
* variance, which is generous for noisy emulator environments. Tighten for dedicated devices.
285+
*/
286+
private const val MAX_CV = 0.5f
287+
288+
/** Number of runs for the stability test. */
289+
private const val STABILITY_ITERATIONS = 3
290+
291+
/** Per-test timeout: 5 minutes to accommodate slow emulator environments. */
292+
private const val MAX_TEST_TIMEOUT_MS = 300_000L
293+
}
294+
}

0 commit comments

Comments
 (0)