Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ public void shouldRejectCompleteWhenPending() {

assertThatThrownBy(
() -> componentClient.forTask(taskId).complete(TEST_TASK, new TestResult("done", 100)))
.hasMessageContaining("Task can only be completed when ASSIGNED or IN_PROGRESS");
.hasMessageContaining(
"Task can only be completed when ASSIGNED, IN_PROGRESS, or RESULT_REJECTED");
}

// --- fail ---
Expand All @@ -114,7 +115,8 @@ public void shouldRejectFailWhenPending() {
componentClient.forTask(taskId).create(TEST_TASK.instructions("do something"));

assertThatThrownBy(() -> componentClient.forTask(taskId).fail("rejected"))
.hasMessageContaining("Task can only be failed when ASSIGNED or IN_PROGRESS");
.hasMessageContaining(
"Task can only be failed when ASSIGNED, IN_PROGRESS, or RESULT_REJECTED");
}

// --- result ---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,78 @@ public void shouldRejectFailWhenPending() {
assertThat(result.isError()).isTrue();
}

@Test
public void shouldRejectTask() {
var testKit = createTestKit();
testKit.method(TaskEntity::create).invoke(createRequest("research"));
testKit.method(TaskEntity::assign).invoke("agent-1");
testKit.method(TaskEntity::start).invoke();
publishedNotifications.clear();

EventSourcedResult<Done> result =
testKit
.method(TaskEntity::rejectResult)
.invoke(new TaskEntity.RejectResultRequest("com.example.MyRule", "result too short"));
assertThat(result.getReply()).isEqualTo(done());
result.getNextEventOfType(TaskEvent.TaskResultRejected.class);
assertThat(testKit.getState().status()).isEqualTo(TaskStatus.RESULT_REJECTED);
assertThat(testKit.getState().failureReason()).isEqualTo("result too short");

var rejected = (TaskNotification.ResultRejected) publishedNotifications.poll();
assertThat(rejected.ruleClassName()).isEqualTo("com.example.MyRule");
assertThat(rejected.reason()).isEqualTo("result too short");
assertThat(publishedNotifications).isEmpty();
}

@Test
public void shouldCompleteAfterRejection() {
var testKit = createTestKit();
testKit.method(TaskEntity::create).invoke(createRequest("research"));
testKit.method(TaskEntity::assign).invoke("agent-1");
testKit.method(TaskEntity::start).invoke();
testKit
.method(TaskEntity::rejectResult)
.invoke(new TaskEntity.RejectResultRequest("com.example.MyRule", "result too short"));
publishedNotifications.clear();

EventSourcedResult<Done> result =
testKit.method(TaskEntity::complete).invoke("{\"summary\":\"a longer result\"}");
assertThat(result.getReply()).isEqualTo(done());
assertThat(testKit.getState().status()).isEqualTo(TaskStatus.COMPLETED);
}

@Test
public void shouldTrackMultipleRejections() {
var testKit = createTestKit();
testKit.method(TaskEntity::create).invoke(createRequest("research"));
testKit.method(TaskEntity::assign).invoke("agent-1");
testKit.method(TaskEntity::start).invoke();

testKit
.method(TaskEntity::rejectResult)
.invoke(new TaskEntity.RejectResultRequest("Rule1", "first rejection"));
testKit
.method(TaskEntity::rejectResult)
.invoke(new TaskEntity.RejectResultRequest("Rule2", "second rejection"));

assertThat(testKit.getState().status()).isEqualTo(TaskStatus.RESULT_REJECTED);
}

@Test
public void shouldFailAfterRejection() {
var testKit = createTestKit();
testKit.method(TaskEntity::create).invoke(createRequest("research"));
testKit.method(TaskEntity::assign).invoke("agent-1");
testKit.method(TaskEntity::start).invoke();
testKit
.method(TaskEntity::rejectResult)
.invoke(new TaskEntity.RejectResultRequest("com.example.MyRule", "bad result"));

EventSourcedResult<Done> result = testKit.method(TaskEntity::fail).invoke("giving up");
assertThat(result.getReply()).isEqualTo(done());
assertThat(testKit.getState().status()).isEqualTo(TaskStatus.FAILED);
}

@Test
public void shouldNotPublishNotificationOnIdempotentComplete() {
var testKit = createTestKit();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,37 @@ public void shouldCompleteTaskWhenRuleAccepts() {
}

@Test
public void shouldFailTaskWhenRuleRejects() {
public void shouldRetryAndCompleteAfterRuleRejects() {
// First attempt: bad result rejected by rule. Retry: good result accepted.
agentModel
.whenMessage(msg -> msg.contains("Do something"))
.reply(completeTask(new TestTasks.TestResult("low quality", 3)));

agentModel
.whenMessage(msg -> msg.contains("Reminder"))
.reply(completeTask(new TestTasks.TestResult("improved result", 50)));

var taskId =
componentClient
.forAutonomousAgent(ValidatedTaskAgent.class, UUID.randomUUID().toString())
.runSingleTask(TestTasks.VALIDATED_TASK.instructions("Do something."));

Awaitility.await()
.ignoreExceptions()
.atMost(10, TimeUnit.SECONDS)
.untilAsserted(
() -> {
var snapshot = componentClient.forTask(taskId).get(TestTasks.VALIDATED_TASK);
assertThat(snapshot.status()).isEqualTo(TaskStatus.COMPLETED);
assertThat(snapshot.result().value()).isEqualTo("improved result");
assertThat(snapshot.result().score()).isEqualTo(50);
});
}

@Test
public void shouldFailAfterRepeatedRuleRejections() {
// Agent always returns the same bad result — rule rejects every attempt
// until maxIterationsPerTask is exhausted and the runtime fails the task
agentModel.fixedResponse(completeTask(new TestTasks.TestResult("low quality", 3)));

var taskId =
Expand All @@ -72,7 +102,7 @@ public void shouldFailTaskWhenRuleRejects() {
() -> {
var snapshot = componentClient.forTask(taskId).get(TestTasks.VALIDATED_TASK);
assertThat(snapshot.status()).isEqualTo(TaskStatus.FAILED);
assertThat(snapshot.failureReason()).contains("score must be >= 10");
assertThat(snapshot.failureReason()).contains("Max iterations");
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

package akka.javasdk.agent.task;

import java.util.List;

/**
* A task definition declares what kind of work an agent can do — a description of the task and the
* expected result type.
Expand All @@ -22,4 +24,7 @@ public sealed interface TaskDefinition<R> permits Task, TaskTemplate {

/** The expected result type. */
Class<R> resultType();

/** The validation rule classes for this task definition. */
List<Class<? extends TaskRule<R>>> ruleClasses();
}
37 changes: 33 additions & 4 deletions akka-javasdk/src/main/java/akka/javasdk/agent/task/TaskEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public record CreateRequest(
List<TaskAttachment> attachments,
List<String> ruleClassNames) {}

public record RejectResultRequest(String ruleClassName, String reason) {}

public record ReassignRequest(String newAssignee, String context) {}

@Override
Expand Down Expand Up @@ -92,8 +94,10 @@ public Effect<Done> complete(String result) {
return effects().reply(done()); // idempotent for any terminal state
}
if (currentState().status() != TaskStatus.ASSIGNED
&& currentState().status() != TaskStatus.IN_PROGRESS) {
return effects().error("Task can only be completed when ASSIGNED or IN_PROGRESS");
&& currentState().status() != TaskStatus.IN_PROGRESS
&& currentState().status() != TaskStatus.RESULT_REJECTED) {
return effects()
.error("Task can only be completed when ASSIGNED, IN_PROGRESS, or RESULT_REJECTED");
}
return effects()
.persist(new TaskEvent.TaskCompleted(taskId, result))
Expand All @@ -104,13 +108,37 @@ && currentState().status() != TaskStatus.IN_PROGRESS) {
});
}

public Effect<Done> rejectResult(TaskEntity.RejectResultRequest request) {
if (currentState().taskId().isEmpty()) {
return effects().error("Task does not exist");
}
if (currentState().status() != TaskStatus.ASSIGNED
&& currentState().status() != TaskStatus.IN_PROGRESS
&& currentState().status() != TaskStatus.RESULT_REJECTED) {
return effects()
.error("Task result can only be rejected when ASSIGNED, IN_PROGRESS, or RESULT_REJECTED");
}
return effects()
.persist(
new TaskEvent.TaskResultRejected(taskId, request.ruleClassName(), request.reason()))
.thenReply(
__ -> {
notificationPublisher.publish(
new TaskNotification.ResultRejected(
taskId, request.ruleClassName(), request.reason()));
return done();
});
}

public Effect<Done> fail(String reason) {
if (isTerminal()) {
return effects().reply(done()); // idempotent for any terminal state
}
if (currentState().status() != TaskStatus.ASSIGNED
&& currentState().status() != TaskStatus.IN_PROGRESS) {
return effects().error("Task can only be failed when ASSIGNED or IN_PROGRESS");
&& currentState().status() != TaskStatus.IN_PROGRESS
&& currentState().status() != TaskStatus.RESULT_REJECTED) {
return effects()
.error("Task can only be failed when ASSIGNED, IN_PROGRESS, or RESULT_REJECTED");
}
return effects()
.persist(new TaskEvent.TaskFailed(taskId, reason))
Expand Down Expand Up @@ -188,6 +216,7 @@ public TaskState applyEvent(TaskEvent event) {
case TaskEvent.TaskAssigned e -> currentState().withAssignee(e.assignee());
case TaskEvent.TaskStarted e -> currentState().withStatus(TaskStatus.IN_PROGRESS);
case TaskEvent.TaskCompleted e -> currentState().withResult(e.result());
case TaskEvent.TaskResultRejected e -> currentState().withResultRejection(e.reason());
case TaskEvent.TaskFailed e -> currentState().withFailure(e.reason());
case TaskEvent.TaskCancelled e -> currentState().withCancellation(e.reason());
case TaskEvent.TaskReassigned e ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ record TaskAssigned(String taskId, String assignee) implements TaskEvent {}
@TypeName("akka-task-started")
record TaskStarted(String taskId) implements TaskEvent {}

@TypeName("akka-task-result-rejected")
record TaskResultRejected(String taskId, String ruleClassName, String reason)
implements TaskEvent {}

@TypeName("akka-task-completed")
record TaskCompleted(String taskId, String result) implements TaskEvent {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ public String reason() {
return reason;
}

/** Thrown when a task result is rejected by a validation rule. */
public static final class ResultRejected extends TaskException {
private final String ruleClassName;

public ResultRejected(String taskId, String ruleClassName, String reason) {
super(taskId, reason);
this.ruleClassName = ruleClassName;
}

public String ruleClassName() {
return ruleClassName;
}
}

/** Thrown when a task reaches the {@link TaskStatus#FAILED} state. */
public static final class Failed extends TaskException {
public Failed(String taskId, String reason) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ public sealed interface TaskNotification {
@TypeName("akka-task-notification-completed")
record Completed(String taskId, String result) implements TaskNotification {}

@TypeName("akka-task-notification-result-rejected")
record ResultRejected(String taskId, String ruleClassName, String reason)
implements TaskNotification {}

@TypeName("akka-task-notification-failed")
record Failed(String taskId, String reason) implements TaskNotification {}

Expand Down
17 changes: 17 additions & 0 deletions akka-javasdk/src/main/java/akka/javasdk/agent/task/TaskState.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,23 @@ public TaskState withResult(String result) {
ruleClassNames);
}

public TaskState withResultRejection(String reason) {
return new TaskState(
taskId,
name,
description,
instructions,
TaskStatus.RESULT_REJECTED,
resultTypeName,
result,
reason,
dependencyTaskIds,
assignee,
attachments,
reassignmentContext,
ruleClassNames);
}

public TaskState withFailure(String reason) {
return new TaskState(
taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ public enum TaskStatus {
PENDING,
ASSIGNED,
IN_PROGRESS,
RESULT_REJECTED,
COMPLETED,
FAILED,
CANCELLED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ private[impl] final class AutonomousAgentImpl(
private val taskAssignMethod = classOf[TaskEntity].getMethod("assign", classOf[String])
private val taskStartMethod = classOf[TaskEntity].getMethod("start")
private val taskCompleteMethod = classOf[TaskEntity].getMethod("complete", classOf[String])
private val taskRejectResultMethod =
classOf[TaskEntity].getMethod("rejectResult", classOf[TaskEntity.RejectResultRequest])
private val taskFailMethod = classOf[TaskEntity].getMethod("fail", classOf[String])
private val taskCancelMethod = classOf[TaskEntity].getMethod("cancel", classOf[String])
private val taskReassignMethod = classOf[TaskEntity].getMethod("reassign", classOf[TaskEntity.ReassignRequest])
Expand Down Expand Up @@ -156,7 +158,7 @@ private[impl] final class AutonomousAgentImpl(
request.resultTypeName.orNull,
request.dependencyTaskIds.asJava,
attachments,
Seq.empty[String].asJava)
request.ruleClassNames.asJava)
taskEntityClient(taskId)
.methodRefOneArg[TaskEntity.CreateRequest, Done](taskCreateMethod)
.withMetadata(MetadataImpl.of(context))
Expand Down Expand Up @@ -211,12 +213,13 @@ private[impl] final class AutonomousAgentImpl(
taskId,
ruleClassName,
reason)
val rejectRequest = new TaskEntity.RejectResultRequest(ruleClassName, reason)
taskEntityClient(taskId)
.methodRefOneArg[String, Done](taskFailMethod)
.methodRefOneArg[TaskEntity.RejectResultRequest, Done](taskRejectResultMethod)
.withMetadata(MetadataImpl.of(context))
.invokeAsync("Task rule rejected: " + reason)
.invokeAsync(rejectRequest)
.asScala
.map(_ => Done)(sdkExecutionContext)
.flatMap(_ => Future.failed(new SpiTask.TaskResultRejectedException(reason)))(sdkExecutionContext)
}
}(sdkExecutionContext)

Expand Down Expand Up @@ -370,12 +373,13 @@ private[impl] final class AutonomousAgentImpl(

private def toSpiTaskState(state: akka.javasdk.agent.task.TaskState): SpiTask.SpiTaskState = {
val spiStatus = state.status() match {
case TaskStatus.PENDING => SpiTask.SpiTaskStatus.Pending
case TaskStatus.ASSIGNED => SpiTask.SpiTaskStatus.Assigned
case TaskStatus.IN_PROGRESS => SpiTask.SpiTaskStatus.InProgress
case TaskStatus.COMPLETED => SpiTask.SpiTaskStatus.Completed
case TaskStatus.FAILED => SpiTask.SpiTaskStatus.Failed
case TaskStatus.CANCELLED => SpiTask.SpiTaskStatus.Cancelled
case TaskStatus.PENDING => SpiTask.SpiTaskStatus.Pending
case TaskStatus.ASSIGNED => SpiTask.SpiTaskStatus.Assigned
case TaskStatus.IN_PROGRESS => SpiTask.SpiTaskStatus.InProgress
case TaskStatus.RESULT_REJECTED => SpiTask.SpiTaskStatus.ResultRejected
case TaskStatus.COMPLETED => SpiTask.SpiTaskStatus.Completed
case TaskStatus.FAILED => SpiTask.SpiTaskStatus.Failed
case TaskStatus.CANCELLED => SpiTask.SpiTaskStatus.Cancelled
}
val resultTypeName = Option(state.resultTypeName())
val resultSchema = resultTypeName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ private[javasdk] object CapabilityConverter {
(Option(template.instructionTemplate()).filter(_.nonEmpty), parameters)
case _ => (None, Seq.empty)
}
val ruleClassNames = taskDefinition.ruleClasses().asScala.toSeq.map(_.getName)
new SpiTask.SpiTaskDefinition(
name = taskDefinition.name(),
description = taskDefinition.description(),
resultTypeName = resultType.getName,
resultSchema = resultSchema,
instructionTemplate = instructionTemplate,
templateParameters = templateParameters)
templateParameters = templateParameters,
ruleClassNames = ruleClassNames)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ private[javasdk] final class TaskClientImpl(
taskId,
ruleClassName,
reason)
sendCommand("Fail", serializer.toBytes("Task rule rejected: " + reason)).asScala
val rejectRequest = new TaskEntity.RejectResultRequest(ruleClassName, reason)
sendCommand("RejectResult", serializer.toBytes(rejectRequest)).asScala
}
}
}
Expand Down
Loading
Loading