Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions okio/api/okio.api
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ public class okio/ByteString : java/io/Serializable, java/lang/Comparable {
public final fun endsWith (Lokio/ByteString;)Z
public final fun endsWith ([B)Z
public fun equals (Ljava/lang/Object;)Z
public final fun equals (Lokio/ByteString;Z)Z
public final fun getByte (I)B
public fun hashCode ()I
public fun hex ()Ljava/lang/String;
Expand Down
4 changes: 4 additions & 0 deletions okio/src/appleMain/kotlin/okio/ByteString.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import okio.internal.commonDecodeHex
import okio.internal.commonEncodeUtf8
import okio.internal.commonEndsWith
import okio.internal.commonEquals
import okio.internal.commonEqualsConstantTime
import okio.internal.commonGetByte
import okio.internal.commonGetSize
import okio.internal.commonHashCode
Expand Down Expand Up @@ -168,6 +169,9 @@ internal actual constructor(

actual override fun equals(other: Any?) = commonEquals(other)

actual fun equals(other: ByteString, constantTime: Boolean) =
if (constantTime) commonEqualsConstantTime(other) else this == other

actual override fun hashCode() = commonHashCode()

actual override fun compareTo(other: ByteString) = commonCompareTo(other)
Expand Down
9 changes: 9 additions & 0 deletions okio/src/commonMain/kotlin/okio/ByteString.kt
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,15 @@ internal constructor(data: ByteArray) : Comparable<ByteString> {

override fun equals(other: Any?): Boolean

/**
* Returns true if the bytes of this equal the bytes of `other`. If [constantTime] is true this
* always inspects every byte and does not short-circuit on the first mismatch, so its running
* time does not depend on where the byte strings differ. Use that for timing-safe comparison of
* secrets like hashes or message authentication codes. If [constantTime] is false this behaves
* like [equals] and may return as soon as a mismatch is found.
*/
fun equals(other: ByteString, constantTime: Boolean): Boolean

override fun hashCode(): Int

override fun compareTo(other: ByteString): Int
Expand Down
10 changes: 10 additions & 0 deletions okio/src/commonMain/kotlin/okio/internal/ByteString.kt
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,16 @@ internal inline fun ByteString.commonEquals(other: Any?): Boolean {
}
}

internal fun ByteString.commonEqualsConstantTime(other: ByteString): Boolean {
if (other === this) return true
if (other.size != size) return false
var result = 0
for (i in 0 until size) {
result = result or (this[i].toInt() xor other[i].toInt())
Comment on lines +240 to +242

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you chose to use 0 as the true-y value and or the comparisons as opposed to using 1 and and? The result is the same, but this version is unintuitive to me because it's not how I think about the comparisons being done.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the standard constant-time idiom: XOR each byte pair (0 when equal) and OR the results, so result stays 0 only if every pair matched and becomes non-zero the moment any byte differs, with a single comparison at the very end. Going with 1 and AND would need a per-byte equality test (this[i] == other[i]), which reintroduces a branch/comparison inside the loop that the OR-of-XOR form avoids. It mirrors crypto/subtle.ConstantTimeCompare and CRYPTO_memcmp. Happy to add a short comment to that effect if it helps readability.

}
return result == 0
}

@Suppress("NOTHING_TO_INLINE")
internal inline fun ByteString.commonHashCode(): Int {
val result = hashCode
Expand Down
25 changes: 25 additions & 0 deletions okio/src/commonTest/kotlin/okio/ByteStringTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,31 @@ class ByteStringTest(
assertEquals(ByteString.of(), factory.decodeHex(""))
}

@Test fun equalsConstantTime() {
val byteString = factory.decodeHex("000102")
assertTrue(byteString.equals(byteString, constantTime = true))
assertTrue(byteString.equals("000102".decodeHex(), constantTime = true))
assertFalse(byteString.equals("800102".decodeHex(), constantTime = true))
assertFalse(byteString.equals("000180".decodeHex(), constantTime = true))
assertFalse(byteString.equals("0001".decodeHex(), constantTime = true))
assertFalse(byteString.equals("00010203".decodeHex(), constantTime = true))
}

@Test fun equalsConstantTimeEmptyTest() {
assertTrue(factory.decodeHex("").equals(ByteString.EMPTY, constantTime = true))
assertFalse(factory.decodeHex("").equals("00".decodeHex(), constantTime = true))
}

@Test fun equalsNotConstantTime() {
val byteString = factory.decodeHex("000102")
assertTrue(byteString.equals(byteString, constantTime = false))
assertTrue(byteString.equals("000102".decodeHex(), constantTime = false))
assertFalse(byteString.equals("800102".decodeHex(), constantTime = false))
assertFalse(byteString.equals("000180".decodeHex(), constantTime = false))
assertFalse(byteString.equals("0001".decodeHex(), constantTime = false))
assertFalse(byteString.equals("00010203".decodeHex(), constantTime = false))
}

private val bronzeHorseman = "На берегу пустынных волн"

@Test fun utf8() {
Expand Down
4 changes: 4 additions & 0 deletions okio/src/jvmMain/kotlin/okio/ByteString.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import okio.internal.commonDecodeHex
import okio.internal.commonEncodeUtf8
import okio.internal.commonEndsWith
import okio.internal.commonEquals
import okio.internal.commonEqualsConstantTime
import okio.internal.commonGetByte
import okio.internal.commonGetSize
import okio.internal.commonHashCode
Expand Down Expand Up @@ -189,6 +190,9 @@ internal actual constructor(

actual override fun equals(other: Any?) = commonEquals(other)

actual fun equals(other: ByteString, constantTime: Boolean) =
if (constantTime) commonEqualsConstantTime(other) else this == other

actual override fun hashCode() = commonHashCode()

actual override fun compareTo(other: ByteString) = commonCompareTo(other)
Expand Down
82 changes: 82 additions & 0 deletions okio/src/jvmTest/kotlin/okio/ConstantTimeEqualsTimingTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (C) 2026 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

import okio.ByteString.Companion.toByteString
import org.junit.Assert.assertTrue
import org.junit.Test

/**
* Statistical timing test for [ByteString.equals] with [constantTime]=true.
*
* Given two very large byte strings:
* - When not equal, constant-time equals takes the same time as when equal (no short-circuit).
* - When not equal, normal equals takes statistically significantly less time than when equal
* (short-circuits at the first differing byte).
*/
class ConstantTimeEqualsTimingTest {

@Test
fun constantTimeEqualsDoesNotShortCircuit() {
val n = 1_000_000
val aBytes = ByteArray(n) { 0 }
val bBytes = ByteArray(n) { 0 } // identical to a
val cBytes = ByteArray(n) { 0 }.also { it[0] = 1 } // differs at byte 0

val a = aBytes.toByteString()
val b = bBytes.toByteString()
val c = cBytes.toByteString()

// Warm up the JIT.
repeat(200) {
a.equals(b, constantTime = true)
a.equals(c, constantTime = true)
a == b
a == c
}

val iterations = 500

val ctMatch = median(iterations) { a.equals(b, constantTime = true) }
val ctMismatch = median(iterations) { a.equals(c, constantTime = true) }
val normalMismatch = median(iterations) { a == c }

// CT(match) and CT(mismatch) should be within 3x of each other: neither short-circuits.
val ctRatio =
if (ctMatch > ctMismatch) ctMatch.toDouble() / ctMismatch else ctMismatch.toDouble() / ctMatch
assertTrue(
"CT(match)=$ctMatch ns and CT(mismatch)=$ctMismatch ns differ by ${ctRatio}x (expected <3x)",
ctRatio < 3.0,
)

// normal(mismatch) must be significantly faster than CT(mismatch): normal short-circuits at byte 0.
assertTrue(
"normal(mismatch)=$normalMismatch ns should be <2% of CT(mismatch)=$ctMismatch ns (short-circuit at byte 0)",
normalMismatch * 50L < ctMismatch,
)
}

private inline fun median(n: Int, block: () -> Unit): Long {
val times = LongArray(n)
repeat(n) { i ->
val t0 = System.nanoTime()
block()
times[i] = System.nanoTime() - t0
}
times.sort()
return times[n / 2]
}
}
4 changes: 4 additions & 0 deletions okio/src/nonAppleMain/kotlin/okio/ByteString.kt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import okio.internal.commonDecodeHex
import okio.internal.commonEncodeUtf8
import okio.internal.commonEndsWith
import okio.internal.commonEquals
import okio.internal.commonEqualsConstantTime
import okio.internal.commonGetByte
import okio.internal.commonGetSize
import okio.internal.commonHashCode
Expand Down Expand Up @@ -162,6 +163,9 @@ internal actual constructor(

actual override fun equals(other: Any?) = commonEquals(other)

actual fun equals(other: ByteString, constantTime: Boolean) =
if (constantTime) commonEqualsConstantTime(other) else this == other

actual override fun hashCode() = commonHashCode()

actual override fun compareTo(other: ByteString) = commonCompareTo(other)
Expand Down
Loading