Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@
<version>${log4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<version>4.3.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>4.11.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
import com.alipay.remoting.rpc.common.SimpleServerUserProcessor;
import com.alipay.remoting.rpc.protocol.RpcProtocol;

import java.util.concurrent.TimeUnit;

import static org.awaitility.Awaitility.await;

/**
* Runtime operation connection heart beat test
*
Expand Down Expand Up @@ -93,27 +97,51 @@ public void stop() {

@Test
public void testRuntimeCloseAndEnableHeartbeat() throws InterruptedException {
// Register heart beat processor
server.getRpcServer().registerProcessor(RpcProtocol.PROTOCOL_CODE,
CommonCommandCode.HEARTBEAT, heartBeatProcessor);

// Establish connection
try {
client.getConnection(addr, 1000);
} catch (RemotingException e) {
logger.error("", e);
logger.error("Failed to establish connection", e);
}
Thread.sleep(1500);
logger.warn("before disable: " + heartBeatProcessor.getHeartBeatTimes());
Assert.assertTrue(heartBeatProcessor.getHeartBeatTimes() > 0);

// Phase 1: Verify heartbeats are being sent
await().atMost(3, TimeUnit.SECONDS)
.pollInterval(100, TimeUnit.MILLISECONDS)
.until(() -> heartBeatProcessor.getHeartBeatTimes() > 0);

logger.warn("before disable: {}", heartBeatProcessor.getHeartBeatTimes());

// Phase 2: Disable heartbeats
client.disableConnHeartbeat(addr);

// Wait a bit to make sure any in-flight heartbeats are processed
Thread.sleep(200);

// Reset counter
heartBeatProcessor.reset();
Thread.sleep(1500);
logger.warn("after disable: " + heartBeatProcessor.getHeartBeatTimes());
Assert.assertEquals(0, heartBeatProcessor.getHeartBeatTimes());

// Verify no new heartbeats arrive after disabling
await().pollDelay(500, TimeUnit.MILLISECONDS)
.during(1, TimeUnit.SECONDS)
.atMost(2, TimeUnit.SECONDS)
.pollInterval(100, TimeUnit.MILLISECONDS)
.until(() -> heartBeatProcessor.getHeartBeatTimes() == 0);

logger.warn("after disable: {}", heartBeatProcessor.getHeartBeatTimes());

// Phase 3: Re-enable heartbeats
client.enableConnHeartbeat(addr);
heartBeatProcessor.reset();
Thread.sleep(1500);
logger.warn("after enable: " + heartBeatProcessor.getHeartBeatTimes());
Assert.assertTrue(heartBeatProcessor.getHeartBeatTimes() > 0);

// Verify heartbeats resume
await().atMost(3, TimeUnit.SECONDS)
.pollInterval(100, TimeUnit.MILLISECONDS)
.until(() -> heartBeatProcessor.getHeartBeatTimes() > 0);

logger.warn("after enable: {}", heartBeatProcessor.getHeartBeatTimes());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,24 @@
import com.alipay.remoting.LifeCycleException;
import com.alipay.remoting.RemotingContext;
import com.alipay.remoting.rpc.RpcCommandFactory;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import java.util.concurrent.CopyOnWriteArrayList;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

import static org.awaitility.Awaitility.await;

/**
* @author Even
* @date 2024/4/29 11:20
Expand All @@ -43,21 +48,29 @@ public class RpcCommandHandlerTest {

private static RemotingContext remotingContext = null;

private static final List<RemotingContext> remotingContextList = new ArrayList<>();

private static final CountDownLatch countDownLatch = new CountDownLatch(2);
private static final List<RemotingContext> remotingContextList = new CopyOnWriteArrayList<>();

private static final Logger LOGGER = LoggerFactory.getLogger(RpcCommandHandlerTest.class);

@BeforeClass
public static void beforeClass() {
// Create a mock ChannelHandlerContext
ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class);

// Mock minimum required behavior if needed
Channel channel = Mockito.mock(Channel.class);
Mockito.when(ctx.channel()).thenReturn(channel);

ConcurrentHashMap<String, UserProcessor<?>> userProcessors = new ConcurrentHashMap<>();
userProcessors.put("testClass", new MockUserProcessors());
remotingContext = new RemotingContext(null, new InvokeContext(),true, userProcessors);
remotingContext = new RemotingContext(ctx, new InvokeContext(),true, userProcessors);
}

@Test
public void testHandleCommand() throws Exception {
// Clear any previous test data
remotingContextList.clear();

List<RpcRequestCommand> msg = new ArrayList<>();
RpcRequestCommand rpcRequestCommand = new RpcRequestCommand();
rpcRequestCommand.setTimeout(1000);
Expand All @@ -67,11 +80,16 @@ public void testHandleCommand() throws Exception {
rpcRequestCommand2.setRequestClass("testClass");
msg.add(rpcRequestCommand);
msg.add(rpcRequestCommand2);

RpcCommandHandler rpcCommandHandler = new RpcCommandHandler(new RpcCommandFactory());
rpcCommandHandler.handleCommand(remotingContext, msg);
boolean result = countDownLatch.await(15, TimeUnit.SECONDS);
Assert.assertTrue(result);
Assert.assertEquals(2, remotingContextList.size());

// Use Awaitility to wait for the conditions to be met
await().atMost(20, TimeUnit.SECONDS)
.pollInterval(100, TimeUnit.MILLISECONDS)
.until(() -> remotingContextList.size() == 2);

Assert.assertEquals(2, remotingContextList.size());
Assert.assertTrue(remotingContextList.get(0).getTimeout() != remotingContextList.get(1).getTimeout());
}

Expand All @@ -94,9 +112,8 @@ public boolean isStarted() {

@Override
public BizContext preHandleRequest(RemotingContext remotingCtx, Object request) {
Assert.assertNotSame(remotingCtx, remotingContext);
Assert.assertNotSame(remotingCtx, remotingContext);
remotingContextList.add(remotingCtx);
countDownLatch.countDown();
return null;
}

Expand Down Expand Up @@ -146,4 +163,4 @@ public ExecutorSelector getExecutorSelector() {
return null;
}
}
}
}
Loading