2121using System . Collections . Concurrent ;
2222using System . Text . Json ;
2323using System . Text . Json . Serialization . Metadata ;
24+ using System . Threading . Channels ;
2425using OpenQA . Selenium . BiDi . Session ;
2526using OpenQA . Selenium . Internal . Logging ;
2627
2728namespace OpenQA . Selenium . BiDi ;
2829
2930internal sealed class Broker : IAsyncDisposable
3031{
32+ // Limits how many received messages can be buffered before backpressure is applied to the transport.
33+ private const int ReceivedMessageQueueCapacity = 16 ;
34+
35+ // How long to wait for a command response before cancelling.
36+ private static readonly TimeSpan DefaultCommandTimeout = TimeSpan . FromSeconds ( 30 ) ;
37+
3138 private readonly ILogger _logger = Internal . Logging . Log . GetLogger < Broker > ( ) ;
3239
3340 private readonly ITransport _transport ;
@@ -38,7 +45,16 @@ internal sealed class Broker : IAsyncDisposable
3845
3946 private long _currentCommandId ;
4047
41- private readonly Task _receivingMessageTask ;
48+ private readonly Channel < PooledBufferWriter > _receivedMessages = Channel . CreateBounded < PooledBufferWriter > (
49+ new BoundedChannelOptions ( ReceivedMessageQueueCapacity ) { SingleReader = true , SingleWriter = true , FullMode = BoundedChannelFullMode . Wait } ) ;
50+
51+ private readonly Channel < PooledBufferWriter > _bufferPool = Channel . CreateBounded < PooledBufferWriter > (
52+ new BoundedChannelOptions ( ReceivedMessageQueueCapacity ) { SingleReader = false , SingleWriter = false } ) ;
53+
54+ private volatile Exception ? _terminalReceiveException ;
55+
56+ private readonly Task _receivingTask ;
57+ private readonly Task _processingTask ;
4258 private readonly CancellationTokenSource _receiveMessagesCancellationTokenSource ;
4359
4460 public Broker ( ITransport transport , IBiDi bidi , Func < ISessionModule > sessionProvider )
@@ -48,7 +64,8 @@ public Broker(ITransport transport, IBiDi bidi, Func<ISessionModule> sessionProv
4864 _eventDispatcher = new EventDispatcher ( sessionProvider ) ;
4965
5066 _receiveMessagesCancellationTokenSource = new CancellationTokenSource ( ) ;
51- _receivingMessageTask = Task . Run ( ( ) => ReceiveMessagesLoopAsync ( _receiveMessagesCancellationTokenSource . Token ) ) ;
67+ _receivingTask = Task . Run ( ( ) => ReceiveMessagesAsync ( _receiveMessagesCancellationTokenSource . Token ) ) ;
68+ _processingTask = Task . Run ( ProcessMessagesAsync ) ;
5269 }
5370
5471 public Task < Subscription > SubscribeAsync < TEventArgs > ( string eventName , EventHandler eventHandler , SubscriptionOptions ? options , JsonTypeInfo < TEventArgs > jsonTypeInfo , CancellationToken cancellationToken )
@@ -61,6 +78,11 @@ public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand comma
6178 where TCommand : Command
6279 where TResult : EmptyResult
6380 {
81+ if ( _terminalReceiveException is { } terminalException )
82+ {
83+ throw new BiDiException ( "The broker is no longer processing messages due to a transport error." , terminalException ) ;
84+ }
85+
6486 command . Id = Interlocked . Increment ( ref _currentCommandId ) ;
6587
6688 var tcs = new TaskCompletionSource < EmptyResult > ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
@@ -69,42 +91,49 @@ public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand comma
6991 ? CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken )
7092 : new CancellationTokenSource ( ) ;
7193
72- var timeout = options ? . Timeout ?? TimeSpan . FromSeconds ( 30 ) ;
94+ var timeout = options ? . Timeout ?? DefaultCommandTimeout ;
7395 cts . CancelAfter ( timeout ) ;
7496
75- using var sendBuffer = new PooledBufferWriter ( ) ;
97+ var sendBuffer = RentBuffer ( ) ;
7698
77- using ( var writer = new Utf8JsonWriter ( sendBuffer ) )
99+ try
78100 {
79- JsonSerializer . Serialize ( writer , command , jsonCommandTypeInfo ) ;
80- }
101+ using ( var writer = new Utf8JsonWriter ( sendBuffer ) )
102+ {
103+ JsonSerializer . Serialize ( writer , command , jsonCommandTypeInfo ) ;
104+ }
81105
82- var commandInfo = new CommandInfo ( tcs , jsonResultTypeInfo ) ;
83- _pendingCommands [ command . Id ] = commandInfo ;
106+ var commandInfo = new CommandInfo ( tcs , jsonResultTypeInfo ) ;
107+ _pendingCommands [ command . Id ] = commandInfo ;
84108
85- using var ctsRegistration = cts . Token . Register ( ( ) =>
86- {
87- tcs . TrySetCanceled ( cts . Token ) ;
88- _pendingCommands . TryRemove ( command . Id , out _ ) ;
89- } ) ;
109+ using var ctsRegistration = cts . Token . Register ( ( ) =>
110+ {
111+ tcs . TrySetCanceled ( cts . Token ) ;
112+ _pendingCommands . TryRemove ( command . Id , out _ ) ;
113+ } ) ;
90114
91- try
92- {
93- if ( _logger . IsEnabled ( LogEventLevel . Trace ) )
115+ try
94116 {
117+ if ( _logger . IsEnabled ( LogEventLevel . Trace ) )
118+ {
95119#if NET8_0_OR_GREATER
96- _logger . Trace ( $ "BiDi SND --> { System . Text . Encoding . UTF8 . GetString ( sendBuffer . WrittenMemory . Span ) } ") ;
120+ _logger . Trace ( $ "BiDi SND --> { System . Text . Encoding . UTF8 . GetString ( sendBuffer . WrittenMemory . Span ) } ") ;
97121#else
98- _logger . Trace ( $ "BiDi SND --> { System . Text . Encoding . UTF8 . GetString ( sendBuffer . WrittenMemory . ToArray ( ) ) } ") ;
122+ _logger . Trace ( $ "BiDi SND --> { System . Text . Encoding . UTF8 . GetString ( sendBuffer . WrittenMemory . ToArray ( ) ) } ") ;
99123#endif
100- }
124+ }
101125
102- await _transport . SendAsync ( sendBuffer . WrittenMemory , cts . Token ) . ConfigureAwait ( false ) ;
126+ await _transport . SendAsync ( sendBuffer . WrittenMemory , cts . Token ) . ConfigureAwait ( false ) ;
127+ }
128+ catch
129+ {
130+ _pendingCommands . TryRemove ( command . Id , out _ ) ;
131+ throw ;
132+ }
103133 }
104- catch
134+ finally
105135 {
106- _pendingCommands . TryRemove ( command . Id , out _ ) ;
107- throw ;
136+ ReturnBuffer ( sendBuffer ) ;
108137 }
109138
110139 return ( TResult ) await tcs . Task . ConfigureAwait ( false ) ;
@@ -114,22 +143,32 @@ public async ValueTask DisposeAsync()
114143 {
115144 _receiveMessagesCancellationTokenSource . Cancel ( ) ;
116145
117- await _eventDispatcher . DisposeAsync ( ) . ConfigureAwait ( false ) ;
118-
119146 try
120147 {
121- await _receivingMessageTask . ConfigureAwait ( false ) ;
122- }
123- catch ( OperationCanceledException ) when ( _receiveMessagesCancellationTokenSource . IsCancellationRequested )
124- {
125- // Expected when cancellation is requested, ignore.
126- }
148+ try
149+ {
150+ await _receivingTask . ConfigureAwait ( false ) ;
151+ }
152+ catch ( OperationCanceledException ) when ( _receiveMessagesCancellationTokenSource . IsCancellationRequested )
153+ {
154+ // Expected when cancellation is requested, ignore.
155+ }
127156
128- _receiveMessagesCancellationTokenSource . Dispose ( ) ;
157+ await _transport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
129158
130- await _transport . DisposeAsync ( ) . ConfigureAwait ( false ) ;
159+ await _processingTask . ConfigureAwait ( false ) ;
160+
161+ await _eventDispatcher . DisposeAsync ( ) . ConfigureAwait ( false ) ;
162+ }
163+ finally
164+ {
165+ _receiveMessagesCancellationTokenSource . Dispose ( ) ;
131166
132- GC . SuppressFinalize ( this ) ;
167+ while ( _bufferPool . Reader . TryRead ( out var buffer ) )
168+ {
169+ buffer . Dispose ( ) ;
170+ }
171+ }
133172 }
134173
135174 private void ProcessReceivedMessage ( ReadOnlySpan < byte > data )
@@ -281,30 +320,63 @@ private void ProcessReceivedMessage(ReadOnlySpan<byte> data)
281320 }
282321 }
283322
284- private async Task ReceiveMessagesLoopAsync ( CancellationToken cancellationToken )
323+ private async Task ReceiveMessagesAsync ( CancellationToken cancellationToken )
285324 {
286- using var receiveBufferWriter = new PooledBufferWriter ( ) ;
287-
288325 try
289326 {
290327 while ( ! cancellationToken . IsCancellationRequested )
291328 {
292- receiveBufferWriter . Reset ( ) ;
293-
294- await _transport . ReceiveAsync ( receiveBufferWriter , cancellationToken ) . ConfigureAwait ( false ) ;
329+ var buffer = RentBuffer ( ) ;
295330
296- if ( _logger . IsEnabled ( LogEventLevel . Trace ) )
331+ try
297332 {
333+ await _transport . ReceiveAsync ( buffer , cancellationToken ) . ConfigureAwait ( false ) ;
334+
335+ if ( _logger . IsEnabled ( LogEventLevel . Trace ) )
336+ {
298337#if NET8_0_OR_GREATER
299- _logger . Trace ( $ "BiDi RCV <-- { System . Text . Encoding . UTF8 . GetString ( receiveBufferWriter . WrittenMemory . Span ) } ") ;
338+ _logger . Trace ( $ "BiDi RCV <-- { System . Text . Encoding . UTF8 . GetString ( buffer . WrittenMemory . Span ) } ") ;
300339#else
301- _logger . Trace ( $ "BiDi RCV <-- { System . Text . Encoding . UTF8 . GetString ( receiveBufferWriter . WrittenMemory . ToArray ( ) ) } ") ;
340+ _logger . Trace ( $ "BiDi RCV <-- { System . Text . Encoding . UTF8 . GetString ( buffer . WrittenMemory . ToArray ( ) ) } ") ;
302341#endif
342+ }
343+
344+ await _receivedMessages . Writer . WriteAsync ( buffer , cancellationToken ) . ConfigureAwait ( false ) ;
345+ }
346+ catch
347+ {
348+ ReturnBuffer ( buffer ) ;
349+ throw ;
303350 }
351+ }
352+ }
353+ catch ( Exception ex ) when ( ex is not OperationCanceledException )
354+ {
355+ if ( _logger . IsEnabled ( LogEventLevel . Error ) )
356+ {
357+ _logger . Error ( $ "Unhandled error occurred while receiving remote messages: { ex } ") ;
358+ }
304359
360+ // Propagated via _terminalReceiveException; not rethrown to keep disposal orderly.
361+ _terminalReceiveException = ex ;
362+ }
363+ finally
364+ {
365+ _receivedMessages . Writer . TryComplete ( ) ;
366+ }
367+ }
368+
369+ private async Task ProcessMessagesAsync ( )
370+ {
371+ var reader = _receivedMessages . Reader ;
372+
373+ while ( await reader . WaitToReadAsync ( ) . ConfigureAwait ( false ) )
374+ {
375+ while ( reader . TryRead ( out var buffer ) )
376+ {
305377 try
306378 {
307- ProcessReceivedMessage ( receiveBufferWriter . WrittenMemory . Span ) ;
379+ ProcessReceivedMessage ( buffer . WrittenMemory . Span ) ;
308380 }
309381 catch ( Exception ex )
310382 {
@@ -313,25 +385,43 @@ private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken)
313385 _logger . Error ( $ "Unhandled error occurred while processing remote message: { ex } ") ;
314386 }
315387 }
388+ finally
389+ {
390+ ReturnBuffer ( buffer ) ;
391+ }
316392 }
317393 }
318- catch ( Exception ex ) when ( ex is not OperationCanceledException )
319- {
320- if ( _logger . IsEnabled ( LogEventLevel . Error ) )
321- {
322- _logger . Error ( $ "Unhandled error occurred while receiving remote messages: { ex } ") ;
323- }
324394
325- // Fail all pending commands, as the connection is likely broken if we failed to receive messages.
326- foreach ( var id in _pendingCommands . Keys )
395+ // Channel is fully drained. Fail any commands that didn't get a response:
396+ // either with the transport error or cancellation for clean shutdown.
397+ var terminalException = _terminalReceiveException ;
398+ foreach ( var id in _pendingCommands . Keys )
399+ {
400+ if ( _pendingCommands . TryRemove ( id , out var pendingCommand ) )
327401 {
328- if ( _pendingCommands . TryRemove ( id , out var pendingCommand ) )
402+ if ( terminalException is not null )
329403 {
330- pendingCommand . TaskCompletionSource . TrySetException ( ex ) ;
404+ pendingCommand . TaskCompletionSource . TrySetException ( terminalException ) ;
405+ }
406+ else
407+ {
408+ pendingCommand . TaskCompletionSource . TrySetCanceled ( ) ;
331409 }
332410 }
411+ }
412+ }
413+
414+ private PooledBufferWriter RentBuffer ( )
415+ {
416+ return _bufferPool . Reader . TryRead ( out var buffer ) ? buffer : new PooledBufferWriter ( ) ;
417+ }
333418
334- throw ;
419+ private void ReturnBuffer ( PooledBufferWriter buffer )
420+ {
421+ buffer . Reset ( ) ;
422+ if ( ! _bufferPool . Writer . TryWrite ( buffer ) )
423+ {
424+ buffer . Dispose ( ) ;
335425 }
336426 }
337427
@@ -359,7 +449,13 @@ public void Reset()
359449 _written = 0 ;
360450 }
361451
362- public void Advance ( int count ) => _written += count ;
452+ public void Advance ( int count )
453+ {
454+ if ( count < 0 ) throw new ArgumentOutOfRangeException ( nameof ( count ) ) ;
455+ if ( _written + count > ( _buffer ? . Length ?? 0 ) ) throw new InvalidOperationException ( "Cannot advance past the end of the buffer." ) ;
456+
457+ _written += count ;
458+ }
363459
364460 public Memory < byte > GetMemory ( int sizeHint = 0 )
365461 {
@@ -377,8 +473,10 @@ private void EnsureCapacity(int sizeHint)
377473 {
378474 var buffer = _buffer ?? throw new ObjectDisposedException ( nameof ( PooledBufferWriter ) ) ;
379475
380- if ( sizeHint <= 0 ) sizeHint = buffer . Length - _written ;
381- if ( sizeHint <= 0 ) sizeHint = buffer . Length ;
476+ if ( sizeHint <= 0 )
477+ {
478+ sizeHint = Math . Max ( 1 , buffer . Length - _written ) ;
479+ }
382480
383481 if ( _written + sizeHint > buffer . Length )
384482 {
0 commit comments