Skip to content

Commit 8e5367c

Browse files
authored
feat(socketio): make rooms available in disconnect handler (#613)
1 parent 4699ab9 commit 8e5367c

File tree

4 files changed

+68
-15
lines changed

4 files changed

+68
-15
lines changed

crates/socketioxide/src/handler/disconnect.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ use super::MakeErasedHandler;
5252
pub(crate) type BoxedDisconnectHandler<A> = Box<dyn ErasedDisconnectHandler<A>>;
5353
pub(crate) trait ErasedDisconnectHandler<A: Adapter>: Send + Sync + 'static {
5454
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason);
55+
56+
fn call_with_defer(
57+
&self,
58+
s: Arc<Socket<A>>,
59+
reason: DisconnectReason,
60+
defer: fn(Arc<Socket<A>>),
61+
);
5562
}
5663

5764
impl<A: Adapter, T, H> MakeErasedHandler<H, A, T>
@@ -74,6 +81,17 @@ where
7481
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason) {
7582
self.handler.call(s, reason);
7683
}
84+
85+
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self, s, defer), fields(id = ?s.id)))]
86+
#[inline(always)]
87+
fn call_with_defer(
88+
&self,
89+
s: Arc<Socket<A>>,
90+
reason: DisconnectReason,
91+
defer: fn(Arc<Socket<A>>),
92+
) {
93+
self.handler.call_with_defer(s, reason, defer);
94+
}
7795
}
7896

7997
/// A trait used to extract the arguments from the disconnect event.
@@ -114,7 +132,17 @@ See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for de
114132
)]
115133
pub trait DisconnectHandler<A: Adapter, T>: Send + Sync + 'static {
116134
/// Call the handler with the given arguments.
117-
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason);
135+
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason) {
136+
self.call_with_defer(s, reason, |_| ());
137+
}
138+
139+
/// Call the handler and issue a deferred function that will be executed **after** the disconnect handler
140+
fn call_with_defer(
141+
&self,
142+
s: Arc<Socket<A>>,
143+
reason: DisconnectReason,
144+
defer: fn(Arc<Socket<A>>),
145+
);
118146

119147
#[doc(hidden)]
120148
fn phantom(&self) -> std::marker::PhantomData<T> {
@@ -135,7 +163,7 @@ macro_rules! impl_handler_async {
135163
A: Adapter,
136164
$( $ty: FromDisconnectParts<A> + Send, )*
137165
{
138-
fn call(&self, s: Arc<Socket<A>>, reason: DisconnectReason) {
166+
fn call_with_defer(&self, s: Arc<Socket<A>>, reason: DisconnectReason, defer: fn(Arc<Socket<A>>)) {
139167
$(
140168
let $ty = match $ty::from_disconnect_parts(&s, reason) {
141169
Ok(v) => v,
@@ -148,7 +176,11 @@ macro_rules! impl_handler_async {
148176
)*
149177

150178
let fut = (self.clone())($($ty,)*);
151-
tokio::spawn(fut);
179+
180+
tokio::spawn(async move {
181+
fut.await;
182+
defer(s);
183+
});
152184
}
153185
}
154186
};

crates/socketioxide/src/socket.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,11 @@ impl<A: Adapter> Socket<A> {
445445
/// # Register a disconnect handler.
446446
/// You can register only one disconnect handler per socket. If you register multiple handlers, only the last one will be used.
447447
///
448+
/// This implementation is slightly different to the socket.io spec.
449+
/// The difference being that [`rooms`](Self::rooms) are still available in this handler
450+
/// and only cleaned up AFTER the execution of this handler.
451+
/// Therefore you must not indefinitely stall/hang this handler, for example by entering an endless loop.
452+
///
448453
/// _It is recommended for code clarity to define your handler as top level function rather than closures._
449454
///
450455
/// * See the [`disconnect`](crate::handler::disconnect) module doc for more details on disconnect handler.
@@ -770,15 +775,16 @@ impl<A: Adapter> Socket<A> {
770775
pub(crate) fn close(self: Arc<Self>, reason: DisconnectReason) {
771776
self.set_connected(false);
772777

773-
let handler = { self.disconnect_handler.lock().unwrap().take() };
774-
if let Some(handler) = handler {
778+
let disconnect_handler = { self.disconnect_handler.lock().unwrap().take() };
779+
780+
if let Some(handler) = disconnect_handler {
775781
#[cfg(feature = "tracing")]
776782
tracing::trace!(?reason, ?self.id, "spawning disconnect handler");
777783

778-
handler.call(self.clone(), reason);
784+
handler.call_with_defer(self.clone(), reason, |s| s.ns.remove_socket(s.id));
785+
} else {
786+
self.ns.remove_socket(self.id);
779787
}
780-
781-
self.ns.remove_socket(self.id);
782788
}
783789

784790
/// Receive data from client

crates/socketioxide/tests/disconnect_reason.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::time::Duration;
1212

1313
use futures_util::{SinkExt, StreamExt};
1414
use socketioxide::{SocketIo, extract::SocketRef, socket::DisconnectReason};
15+
use socketioxide_core::adapter::Room;
1516
use tokio::sync::mpsc;
1617

1718
mod fixture;
@@ -198,25 +199,38 @@ pub async fn client_ns_disconnect() {
198199

199200
#[tokio::test]
200201
pub async fn server_ns_disconnect() {
201-
let (tx, mut rx) = mpsc::channel::<DisconnectReason>(1);
202+
let (tx, mut rx) = mpsc::channel::<(DisconnectReason, Vec<Room>)>(1);
202203
let (svc, io) = create_server().await;
204+
203205
io.ns("/", async move |socket: SocketRef, io: SocketIo| {
206+
socket.join("testRoom1");
207+
socket.join("testRoom2");
208+
socket.join("testRoom3");
209+
204210
tokio::spawn(async move {
205211
tokio::time::sleep(Duration::from_millis(100)).await;
206212
let s = io.sockets().into_iter().next().unwrap();
207213
s.disconnect().unwrap();
208214
});
209215

210-
socket.on_disconnect(async move |reason: DisconnectReason| tx.try_send(reason).unwrap());
216+
socket.on_disconnect(async move |s: SocketRef, reason: DisconnectReason| {
217+
tx.try_send((reason, s.rooms())).unwrap();
218+
});
211219
});
212220

213221
let _stream = create_ws_connection(&svc).await;
214222

215-
let data = tokio::time::timeout(Duration::from_millis(100), rx.recv())
216-
.await
217-
.expect("timeout waiting for DisconnectReason::ServerNSDisconnect")
218-
.unwrap();
219-
assert_eq!(data, DisconnectReason::ServerNSDisconnect);
223+
let (disconnect_reason, mut rooms) =
224+
tokio::time::timeout(Duration::from_millis(100), rx.recv())
225+
.await
226+
.expect("timeout waiting for DisconnectReason::ServerNSDisconnect")
227+
.unwrap();
228+
229+
assert_eq!(disconnect_reason, DisconnectReason::ServerNSDisconnect);
230+
231+
// Sort the rooms to guarantee order for the assertion
232+
rooms.sort();
233+
assert_eq!(rooms, vec!["testRoom1", "testRoom2", "testRoom3"]);
220234
}
221235

222236
#[tokio::test]

e2e/engineioxide/engineioxide.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ impl EngineIoHandler for MyHandler {
2525
fn on_connect(self: Arc<Self>, socket: Arc<Socket<Self::Data>>) {
2626
println!("socket connect {}", socket.id);
2727
}
28+
2829
fn on_disconnect(&self, socket: Arc<Socket<Self::Data>>, reason: DisconnectReason) {
2930
println!("socket disconnect {}: {:?}", socket.id, reason);
3031
}

0 commit comments

Comments
 (0)