Skip to content

Commit 6b353c1

Browse files
authored
[dotnet] [bidi] Decouple transport and processing (#17291)
1 parent 549261b commit 6b353c1

File tree

3 files changed

+168
-60
lines changed

3 files changed

+168
-60
lines changed

dotnet/src/webdriver/BiDi/Broker.cs

Lines changed: 157 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,20 @@
2121
using System.Collections.Concurrent;
2222
using System.Text.Json;
2323
using System.Text.Json.Serialization.Metadata;
24+
using System.Threading.Channels;
2425
using OpenQA.Selenium.BiDi.Session;
2526
using OpenQA.Selenium.Internal.Logging;
2627

2728
namespace OpenQA.Selenium.BiDi;
2829

2930
internal sealed class Broker : IAsyncDisposable
3031
{
32+
// Limits how many received messages can be buffered before backpressure is applied to the transport.
33+
private const int ReceivedMessageQueueCapacity = 16;
34+
35+
// How long to wait for a command response before cancelling.
36+
private static readonly TimeSpan DefaultCommandTimeout = TimeSpan.FromSeconds(30);
37+
3138
private readonly ILogger _logger = Internal.Logging.Log.GetLogger<Broker>();
3239

3340
private readonly ITransport _transport;
@@ -38,7 +45,16 @@ internal sealed class Broker : IAsyncDisposable
3845

3946
private long _currentCommandId;
4047

41-
private readonly Task _receivingMessageTask;
48+
private readonly Channel<PooledBufferWriter> _receivedMessages = Channel.CreateBounded<PooledBufferWriter>(
49+
new BoundedChannelOptions(ReceivedMessageQueueCapacity) { SingleReader = true, SingleWriter = true, FullMode = BoundedChannelFullMode.Wait });
50+
51+
private readonly Channel<PooledBufferWriter> _bufferPool = Channel.CreateBounded<PooledBufferWriter>(
52+
new BoundedChannelOptions(ReceivedMessageQueueCapacity) { SingleReader = false, SingleWriter = false });
53+
54+
private volatile Exception? _terminalReceiveException;
55+
56+
private readonly Task _receivingTask;
57+
private readonly Task _processingTask;
4258
private readonly CancellationTokenSource _receiveMessagesCancellationTokenSource;
4359

4460
public Broker(ITransport transport, IBiDi bidi, Func<ISessionModule> sessionProvider)
@@ -48,7 +64,8 @@ public Broker(ITransport transport, IBiDi bidi, Func<ISessionModule> sessionProv
4864
_eventDispatcher = new EventDispatcher(sessionProvider);
4965

5066
_receiveMessagesCancellationTokenSource = new CancellationTokenSource();
51-
_receivingMessageTask = Task.Run(() => ReceiveMessagesLoopAsync(_receiveMessagesCancellationTokenSource.Token));
67+
_receivingTask = Task.Run(() => ReceiveMessagesAsync(_receiveMessagesCancellationTokenSource.Token));
68+
_processingTask = Task.Run(ProcessMessagesAsync);
5269
}
5370

5471
public Task<Subscription> SubscribeAsync<TEventArgs>(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo<TEventArgs> jsonTypeInfo, CancellationToken cancellationToken)
@@ -61,6 +78,11 @@ public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand comma
6178
where TCommand : Command
6279
where TResult : EmptyResult
6380
{
81+
if (_terminalReceiveException is { } terminalException)
82+
{
83+
throw new BiDiException("The broker is no longer processing messages due to a transport error.", terminalException);
84+
}
85+
6486
command.Id = Interlocked.Increment(ref _currentCommandId);
6587

6688
var tcs = new TaskCompletionSource<EmptyResult>(TaskCreationOptions.RunContinuationsAsynchronously);
@@ -69,42 +91,49 @@ public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand comma
6991
? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)
7092
: new CancellationTokenSource();
7193

72-
var timeout = options?.Timeout ?? TimeSpan.FromSeconds(30);
94+
var timeout = options?.Timeout ?? DefaultCommandTimeout;
7395
cts.CancelAfter(timeout);
7496

75-
using var sendBuffer = new PooledBufferWriter();
97+
var sendBuffer = RentBuffer();
7698

77-
using (var writer = new Utf8JsonWriter(sendBuffer))
99+
try
78100
{
79-
JsonSerializer.Serialize(writer, command, jsonCommandTypeInfo);
80-
}
101+
using (var writer = new Utf8JsonWriter(sendBuffer))
102+
{
103+
JsonSerializer.Serialize(writer, command, jsonCommandTypeInfo);
104+
}
81105

82-
var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo);
83-
_pendingCommands[command.Id] = commandInfo;
106+
var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo);
107+
_pendingCommands[command.Id] = commandInfo;
84108

85-
using var ctsRegistration = cts.Token.Register(() =>
86-
{
87-
tcs.TrySetCanceled(cts.Token);
88-
_pendingCommands.TryRemove(command.Id, out _);
89-
});
109+
using var ctsRegistration = cts.Token.Register(() =>
110+
{
111+
tcs.TrySetCanceled(cts.Token);
112+
_pendingCommands.TryRemove(command.Id, out _);
113+
});
90114

91-
try
92-
{
93-
if (_logger.IsEnabled(LogEventLevel.Trace))
115+
try
94116
{
117+
if (_logger.IsEnabled(LogEventLevel.Trace))
118+
{
95119
#if NET8_0_OR_GREATER
96-
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.Span)}");
120+
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.Span)}");
97121
#else
98-
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.ToArray())}");
122+
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.ToArray())}");
99123
#endif
100-
}
124+
}
101125

102-
await _transport.SendAsync(sendBuffer.WrittenMemory, cts.Token).ConfigureAwait(false);
126+
await _transport.SendAsync(sendBuffer.WrittenMemory, cts.Token).ConfigureAwait(false);
127+
}
128+
catch
129+
{
130+
_pendingCommands.TryRemove(command.Id, out _);
131+
throw;
132+
}
103133
}
104-
catch
134+
finally
105135
{
106-
_pendingCommands.TryRemove(command.Id, out _);
107-
throw;
136+
ReturnBuffer(sendBuffer);
108137
}
109138

110139
return (TResult)await tcs.Task.ConfigureAwait(false);
@@ -114,22 +143,32 @@ public async ValueTask DisposeAsync()
114143
{
115144
_receiveMessagesCancellationTokenSource.Cancel();
116145

117-
await _eventDispatcher.DisposeAsync().ConfigureAwait(false);
118-
119146
try
120147
{
121-
await _receivingMessageTask.ConfigureAwait(false);
122-
}
123-
catch (OperationCanceledException) when (_receiveMessagesCancellationTokenSource.IsCancellationRequested)
124-
{
125-
// Expected when cancellation is requested, ignore.
126-
}
148+
try
149+
{
150+
await _receivingTask.ConfigureAwait(false);
151+
}
152+
catch (OperationCanceledException) when (_receiveMessagesCancellationTokenSource.IsCancellationRequested)
153+
{
154+
// Expected when cancellation is requested, ignore.
155+
}
127156

128-
_receiveMessagesCancellationTokenSource.Dispose();
157+
await _transport.DisposeAsync().ConfigureAwait(false);
129158

130-
await _transport.DisposeAsync().ConfigureAwait(false);
159+
await _processingTask.ConfigureAwait(false);
160+
161+
await _eventDispatcher.DisposeAsync().ConfigureAwait(false);
162+
}
163+
finally
164+
{
165+
_receiveMessagesCancellationTokenSource.Dispose();
131166

132-
GC.SuppressFinalize(this);
167+
while (_bufferPool.Reader.TryRead(out var buffer))
168+
{
169+
buffer.Dispose();
170+
}
171+
}
133172
}
134173

135174
private void ProcessReceivedMessage(ReadOnlySpan<byte> data)
@@ -281,30 +320,63 @@ private void ProcessReceivedMessage(ReadOnlySpan<byte> data)
281320
}
282321
}
283322

284-
private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken)
323+
private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
285324
{
286-
using var receiveBufferWriter = new PooledBufferWriter();
287-
288325
try
289326
{
290327
while (!cancellationToken.IsCancellationRequested)
291328
{
292-
receiveBufferWriter.Reset();
293-
294-
await _transport.ReceiveAsync(receiveBufferWriter, cancellationToken).ConfigureAwait(false);
329+
var buffer = RentBuffer();
295330

296-
if (_logger.IsEnabled(LogEventLevel.Trace))
331+
try
297332
{
333+
await _transport.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);
334+
335+
if (_logger.IsEnabled(LogEventLevel.Trace))
336+
{
298337
#if NET8_0_OR_GREATER
299-
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(receiveBufferWriter.WrittenMemory.Span)}");
338+
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(buffer.WrittenMemory.Span)}");
300339
#else
301-
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(receiveBufferWriter.WrittenMemory.ToArray())}");
340+
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(buffer.WrittenMemory.ToArray())}");
302341
#endif
342+
}
343+
344+
await _receivedMessages.Writer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
345+
}
346+
catch
347+
{
348+
ReturnBuffer(buffer);
349+
throw;
303350
}
351+
}
352+
}
353+
catch (Exception ex) when (ex is not OperationCanceledException)
354+
{
355+
if (_logger.IsEnabled(LogEventLevel.Error))
356+
{
357+
_logger.Error($"Unhandled error occurred while receiving remote messages: {ex}");
358+
}
304359

360+
// Propagated via _terminalReceiveException; not rethrown to keep disposal orderly.
361+
_terminalReceiveException = ex;
362+
}
363+
finally
364+
{
365+
_receivedMessages.Writer.TryComplete();
366+
}
367+
}
368+
369+
private async Task ProcessMessagesAsync()
370+
{
371+
var reader = _receivedMessages.Reader;
372+
373+
while (await reader.WaitToReadAsync().ConfigureAwait(false))
374+
{
375+
while (reader.TryRead(out var buffer))
376+
{
305377
try
306378
{
307-
ProcessReceivedMessage(receiveBufferWriter.WrittenMemory.Span);
379+
ProcessReceivedMessage(buffer.WrittenMemory.Span);
308380
}
309381
catch (Exception ex)
310382
{
@@ -313,25 +385,43 @@ private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken)
313385
_logger.Error($"Unhandled error occurred while processing remote message: {ex}");
314386
}
315387
}
388+
finally
389+
{
390+
ReturnBuffer(buffer);
391+
}
316392
}
317393
}
318-
catch (Exception ex) when (ex is not OperationCanceledException)
319-
{
320-
if (_logger.IsEnabled(LogEventLevel.Error))
321-
{
322-
_logger.Error($"Unhandled error occurred while receiving remote messages: {ex}");
323-
}
324394

325-
// Fail all pending commands, as the connection is likely broken if we failed to receive messages.
326-
foreach (var id in _pendingCommands.Keys)
395+
// Channel is fully drained. Fail any commands that didn't get a response:
396+
// either with the transport error or cancellation for clean shutdown.
397+
var terminalException = _terminalReceiveException;
398+
foreach (var id in _pendingCommands.Keys)
399+
{
400+
if (_pendingCommands.TryRemove(id, out var pendingCommand))
327401
{
328-
if (_pendingCommands.TryRemove(id, out var pendingCommand))
402+
if (terminalException is not null)
329403
{
330-
pendingCommand.TaskCompletionSource.TrySetException(ex);
404+
pendingCommand.TaskCompletionSource.TrySetException(terminalException);
405+
}
406+
else
407+
{
408+
pendingCommand.TaskCompletionSource.TrySetCanceled();
331409
}
332410
}
411+
}
412+
}
413+
414+
private PooledBufferWriter RentBuffer()
415+
{
416+
return _bufferPool.Reader.TryRead(out var buffer) ? buffer : new PooledBufferWriter();
417+
}
333418

334-
throw;
419+
private void ReturnBuffer(PooledBufferWriter buffer)
420+
{
421+
buffer.Reset();
422+
if (!_bufferPool.Writer.TryWrite(buffer))
423+
{
424+
buffer.Dispose();
335425
}
336426
}
337427

@@ -359,7 +449,13 @@ public void Reset()
359449
_written = 0;
360450
}
361451

362-
public void Advance(int count) => _written += count;
452+
public void Advance(int count)
453+
{
454+
if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
455+
if (_written + count > (_buffer?.Length ?? 0)) throw new InvalidOperationException("Cannot advance past the end of the buffer.");
456+
457+
_written += count;
458+
}
363459

364460
public Memory<byte> GetMemory(int sizeHint = 0)
365461
{
@@ -377,8 +473,10 @@ private void EnsureCapacity(int sizeHint)
377473
{
378474
var buffer = _buffer ?? throw new ObjectDisposedException(nameof(PooledBufferWriter));
379475

380-
if (sizeHint <= 0) sizeHint = buffer.Length - _written;
381-
if (sizeHint <= 0) sizeHint = buffer.Length;
476+
if (sizeHint <= 0)
477+
{
478+
sizeHint = Math.Max(1, buffer.Length - _written);
479+
}
382480

383481
if (_written + sizeHint > buffer.Length)
384482
{

dotnet/test/webdriver/BiDi/BiDiFixture.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ public async Task BiDiTearDown()
5757
await bidi.DisposeAsync();
5858
}
5959

60-
driver?.Dispose();
60+
if (driver is not null)
61+
{
62+
await driver.DisposeAsync();
63+
}
6164
}
6265

6366
public class BiDiEnabledDriverOptions : DriverOptions

dotnet/test/webdriver/BiDi/Session/SessionTests.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ namespace OpenQA.Selenium.Tests.BiDi.Session;
2121

2222
internal class SessionTests : BiDiTestFixture
2323
{
24+
[Test]
25+
public async Task ShouldHaveIdempotentDisposal()
26+
{
27+
await bidi.DisposeAsync();
28+
await bidi.DisposeAsync();
29+
}
30+
2431
[Test]
2532
public async Task CanGetStatus()
2633
{

0 commit comments

Comments
 (0)