Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions Sources/Atomics/Protocols/AtomicReference.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ internal var _concurrencyWindow: Int { 20 }

extension DoubleWord {
fileprivate init(_raw: UnsafeRawPointer?, readers: Int, version: Int) {
let r = UInt(bitPattern: readers) & Self._readersMask
assert(r == readers)
let r = UInt(bitPattern: readers)
precondition(r <= Self._readersMask)
self.init(
first: UInt(bitPattern: _raw),
second: r | (UInt(bitPattern: version) &<< Self._readersBitWidth))
Expand Down Expand Up @@ -107,8 +107,8 @@ extension DoubleWord {
fileprivate var _readers: Int {
get { Int(bitPattern: second & Self._readersMask) }
set {
let n = UInt(bitPattern: newValue) & Self._readersMask
assert(n == newValue)
let n = UInt(bitPattern: newValue)
precondition(n <= Self._readersMask)
second = (second & ~Self._readersMask) | n
}
}
Expand All @@ -123,6 +123,19 @@ extension DoubleWord {
second & Self._readersMask)
}
}

@inline(__always)
fileprivate func _incrementingReaders(
readerMask: UInt = Self._readersMask
) -> Self? {
precondition(readerMask <= Self._readersMask)
let readers = UInt(bitPattern: _readers)
guard readers < readerMask else { return nil }
return DoubleWord(
_raw: _raw,
readers: Int(readers + 1),
version: _version)
}
}

extension Optional where Wrapped == AnyObject {
Expand Down Expand Up @@ -166,17 +179,30 @@ internal struct _AtomicReferenceStorage {
from pointer: UnsafeMutablePointer<Self>,
hint: DoubleWord? = nil
) -> DoubleWord {
var hint = hint
while true {
if let result = _tryStartLoading(from: pointer, hint: hint) {
return result
}
hint = nil
}
}

fileprivate static func _tryStartLoading(
from pointer: UnsafeMutablePointer<Self>,
hint: DoubleWord? = nil,
readerMask: UInt = DoubleWord._readersMask
) -> DoubleWord? {
var old = hint ?? Storage.atomicLoad(at: pointer._extract, ordering: .relaxed)
if old._raw == nil {
atomicMemoryFence(ordering: .acquiring)
return old
}
// Increment reader count
while true {
let new = DoubleWord(
_raw: old._raw,
readers: old._readers &+ 1,
version: old._version)
guard let new = old._incrementingReaders(readerMask: readerMask) else {
return nil
}
var done: Bool
(done, old) = Storage.atomicWeakCompareExchange(
expected: old,
Expand All @@ -189,7 +215,7 @@ internal struct _AtomicReferenceStorage {
}
}

private static func _finishLoading(
fileprivate static func _finishLoading(
_ value: DoubleWord,
from pointer: UnsafeMutablePointer<Self>
) -> AnyObject? {
Expand Down Expand Up @@ -362,6 +388,44 @@ extension AtomicReferenceStorage {
}
}

extension AtomicReferenceStorage {
@_spi(Testing)
public static func _testStartLoadingReaderCounts(
for value: Value,
readerMask: UInt
) -> (counts: [Int], overflowed: Bool) {
precondition(readerMask <= DoubleWord._readersMask)

let pointer = UnsafeMutablePointer<_AtomicReferenceStorage>.allocate(capacity: 1)
pointer.initialize(to: _AtomicReferenceStorage(value))

var started: [DoubleWord] = []
defer {
for value in started.reversed() {
_ = _AtomicReferenceStorage._finishLoading(value, from: pointer)
}
_ = pointer.pointee.dispose()
pointer.deinitialize(count: 1)
pointer.deallocate()
}

for _ in 0 ..< Int(readerMask) {
guard let value = _AtomicReferenceStorage._tryStartLoading(
from: pointer,
readerMask: readerMask)
else {
return (started.map { $0._readers }, true)
}
started.append(value)
}

let overflowed = _AtomicReferenceStorage._tryStartLoading(
from: pointer,
readerMask: readerMask) == nil
return (started.map { $0._readers }, overflowed)
}
}

extension AtomicReferenceStorage: AtomicStorage {
@inlinable @inline(__always)
@_alwaysEmitIntoClient
Expand Down
11 changes: 10 additions & 1 deletion Tests/AtomicsTests/StrongReferenceRace.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//

import XCTest
import Atomics
@_spi(Testing) import Atomics
import Dispatch

private var iterations: Int {
Expand Down Expand Up @@ -75,6 +75,14 @@ class StrongReferenceRace: XCTestCase {
func testLoad8() { checkLoad(count: 8, iterations: iterations) }
func testLoad16() { checkLoad(count: 16, iterations: iterations) }

func testReaderCountDoesNotWrapAtCapacity() {
let result = AtomicReferenceStorage<Node>._testStartLoadingReaderCounts(
for: Node(),
readerMask: 3)
XCTAssertEqual(result.counts, [1, 2, 3])
XCTAssertTrue(result.overflowed)
}

func checkCompareExchange(count: Int, iterations: Int, file: StaticString = #file, line: UInt = #line) {
let a = Node()
let b = Node()
Expand Down Expand Up @@ -240,6 +248,7 @@ class StrongReferenceRace: XCTestCase {
("testLoad4", testLoad4),
("testLoad8", testLoad8),
("testLoad16", testLoad16),
("testReaderCountDoesNotWrapAtCapacity", testReaderCountDoesNotWrapAtCapacity),
("testCompareExchange1", testCompareExchange1),
("testCompareExchange2", testCompareExchange2),
("testCompareExchange4", testCompareExchange4),
Expand Down