Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions io/zenoh-transport/src/unicast/establishment/open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ pub(crate) async fn open_link(
endpoint: EndPoint,
link: LinkUnicast,
manager: &TransportManager,
expected_zid: Option<&ZenohIdProto>,
) -> ZResult<TransportUnicast> {
let direction = TransportLinkUnicastDirection::Outbound;
let is_streamed = link.is_streamed();
Expand Down Expand Up @@ -725,6 +726,19 @@ pub(crate) async fn open_link(

let iack_out = step!(fsm.recv_init_ack((&mut link_unicast, &mut state)).await);

// Check if the expected_zid matches the peer's ZID
if let Some(zid) = expected_zid {
if &iack_out.other_zid != zid {
let _ = link_unicast.close(Some(close::reason::INVALID)).await;
return Err(zerror!(
"Expected peer ZID {} but received {}",
zid,
iack_out.other_zid
)
.into());
}
}

// Open handshake
let osyn_in = SendOpenSynIn {
mine_zid: manager.config.zid,
Expand Down
20 changes: 18 additions & 2 deletions io/zenoh-transport/src/unicast/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,23 @@ impl TransportManager {
}
}

pub async fn open_transport_unicast(
pub async fn open_transport_unicast(&self, endpoint: EndPoint) -> ZResult<TransportUnicast> {
self.open_transport_unicast_inner(endpoint, None).await
}

pub async fn open_transport_unicast_with_zid(
&self,
endpoint: EndPoint,
expected_zid: &ZenohIdProto,
) -> ZResult<TransportUnicast> {
self.open_transport_unicast_inner(endpoint, Some(expected_zid))
.await
}

async fn open_transport_unicast_inner(
&self,
mut endpoint: EndPoint,
expected_zid: Option<&ZenohIdProto>,
) -> ZResult<TransportUnicast> {
if self
.locator_inspector
Expand Down Expand Up @@ -865,7 +879,9 @@ impl TransportManager {
// Open the link
tokio::time::timeout(self.config.unicast.open_timeout, async {
match manager.new_link(endpoint.clone()).await {
Ok(link) => super::establishment::open::open_link(endpoint, link, self).await,
Ok(link) => {
super::establishment::open::open_link(endpoint, link, self, expected_zid).await
}
Err(e) => Err(e),
}
})
Expand Down
176 changes: 176 additions & 0 deletions io/zenoh-transport/tests/unicast_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2987,3 +2987,179 @@ async fn test_multilink_max_links(
.await;
(router_manager, client_manager, client_transport)
}

#[cfg(feature = "transport_tcp")]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn transport_unicast_with_zid_matching() {
zenoh_util::init_log_from_env_or("error");

let client_id = ZenohIdProto::try_from([1]).unwrap();
let router_id = ZenohIdProto::try_from([2]).unwrap();

let router_handler = Arc::new(SHRouter::default());
let unicast = make_transport_manager_builder(
#[cfg(feature = "transport_multilink")]
1,
false,
);
let router_manager = TransportManager::builder()
.zid(router_id)
.whatami(WhatAmI::Router)
.unicast(unicast)
.build_test(router_handler.clone())
.unwrap();

let endpoint: EndPoint = format!("tcp/127.0.0.1:{}", get_free_tcp_port())
.parse()
.unwrap();

let _ = ztimeout!(router_manager.add_listener(endpoint.clone())).unwrap();

let unicast = make_transport_manager_builder(
#[cfg(feature = "transport_multilink")]
1,
false,
);
let client_manager = TransportManager::builder()
.whatami(WhatAmI::Client)
.zid(client_id)
.unicast(unicast)
.build_test(Arc::new(SHClient))
.unwrap();

// Open transport with expected ZID - should succeed
let transport =
ztimeout!(client_manager.open_transport_unicast_with_zid(endpoint.clone(), &router_id))
.unwrap();
assert_eq!(transport.get_zid().unwrap(), router_id);

ztimeout!(transport.close()).unwrap();
ztimeout!(async {
while !router_manager.get_transports_unicast().await.is_empty() {
tokio::time::sleep(SLEEP).await;
}
});
ztimeout!(router_manager.del_listener(&endpoint)).unwrap();
ztimeout!(router_manager.close());
ztimeout!(client_manager.close());
}

// Test that open_transport_unicast_with_zid fails when connecting to a peer with non-matching ZID
#[cfg(feature = "transport_tcp")]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn transport_unicast_with_zid_mismatch() {
zenoh_util::init_log_from_env_or("error");

let client_id = ZenohIdProto::try_from([1]).unwrap();
let router_id = ZenohIdProto::try_from([2]).unwrap();
let wrong_zid = ZenohIdProto::try_from([99]).unwrap();

let router_handler = Arc::new(SHRouter::default());
let unicast = make_transport_manager_builder(
#[cfg(feature = "transport_multilink")]
1,
false,
);
let router_manager = TransportManager::builder()
.zid(router_id)
.whatami(WhatAmI::Router)
.unicast(unicast)
.build_test(router_handler.clone())
.unwrap();

let endpoint: EndPoint = format!("tcp/127.0.0.1:{}", get_free_tcp_port())
.parse()
.unwrap();

let _ = ztimeout!(router_manager.add_listener(endpoint.clone())).unwrap();

let unicast = make_transport_manager_builder(
#[cfg(feature = "transport_multilink")]
1,
false,
);
let client_manager = TransportManager::builder()
.whatami(WhatAmI::Client)
.zid(client_id)
.unicast(unicast)
.build_test(Arc::new(SHClient))
.unwrap();

// Open transport with wrong expected ZID - should fail
let result =
ztimeout!(client_manager.open_transport_unicast_with_zid(endpoint.clone(), &wrong_zid));
assert!(
result.is_err(),
"Expected connection to fail with wrong ZID"
);

// Cleanup
ztimeout!(router_manager.del_listener(&endpoint)).unwrap();
ztimeout!(router_manager.close());
ztimeout!(client_manager.close());
}

// Test that multilink scenario works with open_transport_unicast_with_zid
// When connecting to multiple endpoints of the same peer, the same ZID should be enforced
#[cfg(all(feature = "transport_tcp", feature = "transport_multilink"))]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn transport_unicast_with_zid_multilink() {
zenoh_util::init_log_from_env_or("error");

let client_id = ZenohIdProto::try_from([1]).unwrap();
let router_id = ZenohIdProto::try_from([2]).unwrap();

let router_handler = Arc::new(SHRouter::default());
let unicast = make_transport_manager_builder(2, false);
let router_manager = TransportManager::builder()
.zid(router_id)
.whatami(WhatAmI::Router)
.unicast(unicast)
.build_test(router_handler.clone())
.unwrap();

let endpoints: Vec<EndPoint> = vec![
format!("tcp/127.0.0.1:{}", get_free_tcp_port())
.parse()
.unwrap(),
format!("tcp/127.0.0.1:{}", get_free_tcp_port())
.parse()
.unwrap(),
];

for e in endpoints.iter() {
let _ = ztimeout!(router_manager.add_listener(e.clone())).unwrap();
}

let unicast = make_transport_manager_builder(2, false);
let client_manager = TransportManager::builder()
.whatami(WhatAmI::Client)
.zid(client_id)
.unicast(unicast)
.build_test(Arc::new(SHClient))
.unwrap();

// Open first transport with expected ZID - should succeed
let transport1 =
ztimeout!(client_manager.open_transport_unicast_with_zid(endpoints[0].clone(), &router_id))
.unwrap();
assert_eq!(transport1.get_zid().unwrap(), router_id);

// Open second transport with the same expected ZID - should succeed
let transport2 =
ztimeout!(client_manager.open_transport_unicast_with_zid(endpoints[1].clone(), &router_id))
.unwrap();
assert_eq!(transport2.get_zid().unwrap(), router_id);

ztimeout!(transport1.close()).unwrap();
ztimeout!(async {
while !router_manager.get_transports_unicast().await.is_empty() {
tokio::time::sleep(SLEEP).await;
}
});
for e in endpoints.iter() {
ztimeout!(router_manager.del_listener(e)).unwrap();
}
ztimeout!(router_manager.close());
ztimeout!(client_manager.close());
}
2 changes: 1 addition & 1 deletion zenoh/src/net/runtime/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ impl Runtime {
Err(e) => tracing::trace!("{} {} on {}: {}", ERR, zid, locator, e),
}
} else {
match manager.open_transport_unicast(endpoint).await {
match manager.open_transport_unicast_with_zid(endpoint, zid).await {
Ok(transport) => {
tracing::debug!(
"Successfully connected to newly scouted peer: {:?}",
Expand Down
Loading