-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add ByteString.equals(other, constantTime) overload
#1812
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 5 commits
dd9f614
d9deb8c
9445141
4bfb638
a59f221
a4e8c40
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,7 @@ import okio.internal.commonDecodeHex | |
| import okio.internal.commonEncodeUtf8 | ||
| import okio.internal.commonEndsWith | ||
| import okio.internal.commonEquals | ||
| import okio.internal.commonEqualsConstantTime | ||
| import okio.internal.commonGetByte | ||
| import okio.internal.commonGetSize | ||
| import okio.internal.commonHashCode | ||
|
|
@@ -189,6 +190,8 @@ internal actual constructor( | |
|
|
||
| actual override fun equals(other: Any?) = commonEquals(other) | ||
|
|
||
| actual fun equals(other: ByteString, constantTime: Boolean) = commonEqualsConstantTime(other) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not handle
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in a4e8c40. The overload now branches on the flag and falls back to the regular equals for constantTime=false, with the false path covered in ByteStringTest. |
||
|
|
||
| actual override fun hashCode() = commonHashCode() | ||
|
|
||
| actual override fun compareTo(other: ByteString) = commonCompareTo(other) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| /* | ||
| * Copyright (C) 2026 Square, Inc. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package okio | ||
|
|
||
| import okio.ByteString.Companion.toByteString | ||
| import org.junit.Assert.assertTrue | ||
| import org.junit.Test | ||
|
|
||
| /** | ||
| * Statistical timing test for [ByteString.equals] with [constantTime]=true. | ||
| * | ||
| * Given two very large byte strings: | ||
| * - When not equal, constant-time equals takes the same time as when equal (no short-circuit). | ||
| * - When not equal, normal equals takes statistically significantly less time than when equal | ||
| * (short-circuits at the first differing byte). | ||
| */ | ||
| class ConstantTimeEqualsTimingTest { | ||
|
|
||
| @Test | ||
| fun constantTimeEqualsDoesNotShortCircuit() { | ||
| val n = 1_000_000 | ||
| val aBytes = ByteArray(n) { 0 } | ||
| val bBytes = ByteArray(n) { 0 } // identical to a | ||
| val cBytes = ByteArray(n) { 0 }.also { it[0] = 1 } // differs at byte 0 | ||
|
|
||
| val a = aBytes.toByteString() | ||
| val b = bBytes.toByteString() | ||
| val c = cBytes.toByteString() | ||
|
|
||
| // Warm up the JIT. | ||
| repeat(200) { | ||
| a.equals(b, constantTime = true) | ||
| a.equals(c, constantTime = true) | ||
| a == b | ||
| a == c | ||
| } | ||
|
|
||
| val iterations = 500 | ||
|
|
||
| val ctMatch = median(iterations) { a.equals(b, constantTime = true) } | ||
| val ctMismatch = median(iterations) { a.equals(c, constantTime = true) } | ||
| val normalMismatch = median(iterations) { a == c } | ||
|
|
||
| // CT(match) and CT(mismatch) should be within 3x of each other: neither short-circuits. | ||
| val ctRatio = | ||
| if (ctMatch > ctMismatch) ctMatch.toDouble() / ctMismatch else ctMismatch.toDouble() / ctMatch | ||
| assertTrue( | ||
| "CT(match)=$ctMatch ns and CT(mismatch)=$ctMismatch ns differ by ${ctRatio}x (expected <3x)", | ||
| ctRatio < 3.0, | ||
| ) | ||
|
|
||
| // normal(mismatch) must be significantly faster than CT(mismatch): normal short-circuits at byte 0. | ||
| assertTrue( | ||
| "normal(mismatch)=$normalMismatch ns should be <2% of CT(mismatch)=$ctMismatch ns (short-circuit at byte 0)", | ||
| normalMismatch * 50L < ctMismatch, | ||
| ) | ||
| } | ||
|
|
||
| private inline fun median(n: Int, block: () -> Unit): Long { | ||
| val times = LongArray(n) | ||
| repeat(n) { i -> | ||
| val t0 = System.nanoTime() | ||
| block() | ||
| times[i] = System.nanoTime() - t0 | ||
| } | ||
| times.sort() | ||
| return times[n / 2] | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason you chose to use 0 as the true-y value and
orthe comparisons as opposed to using 1 andand? The result is the same, but this version is unintuitive to me because it's not how I think about the comparisons being done.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the standard constant-time idiom: XOR each byte pair (0 when equal) and OR the results, so result stays 0 only if every pair matched and becomes non-zero the moment any byte differs, with a single comparison at the very end. Going with 1 and AND would need a per-byte equality test (this[i] == other[i]), which reintroduces a branch/comparison inside the loop that the OR-of-XOR form avoids. It mirrors crypto/subtle.ConstantTimeCompare and CRYPTO_memcmp. Happy to add a short comment to that effect if it helps readability.