diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 83e19497..fe6206db 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -57,6 +57,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -87,46 +88,35 @@ public BaseLlmFlow( /** * Pre-processes the LLM request before sending it to the LLM. Executes all registered {@link - * RequestProcessor}. + * RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the + * events generated by them. */ - protected Single preprocess( - InvocationContext context, LlmRequest llmRequest) { - - List> eventIterables = new ArrayList<>(); + protected Flowable preprocess( + InvocationContext context, AtomicReference llmRequestRef) { LlmAgent agent = (LlmAgent) context.agent(); - Single currentLlmRequest = Single.just(llmRequest); - for (RequestProcessor processor : requestProcessors) { - currentLlmRequest = - currentLlmRequest - .flatMap(request -> processor.processRequest(context, request)) - .doOnSuccess( - result -> { - if (result.events() != null) { - eventIterables.add(result.events()); - } - }) - .map(RequestProcessingResult::updatedRequest); - } - - return currentLlmRequest.flatMap( - processedRequest -> { - LlmRequest.Builder updatedRequestBuilder = processedRequest.toBuilder(); - + RequestProcessor toolsProcessor = + (ctx, req) -> { + LlmRequest.Builder builder = req.toBuilder(); return agent - .canonicalTools(new ReadonlyContext(context)) + .canonicalTools(new ReadonlyContext(ctx)) .concatMapCompletable( - tool -> - tool.processLlmRequest( - updatedRequestBuilder, ToolContext.builder(context).build())) + tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) .andThen( Single.fromCallable( - () -> { - Iterable combinedEvents = Iterables.concat(eventIterables); - return RequestProcessingResult.create( - updatedRequestBuilder.build(), combinedEvents); - })); - }); + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); + }; + + Iterable allProcessors = + Iterables.concat(requestProcessors, ImmutableList.of(toolsProcessor)); + + return Flowable.fromIterable(allProcessors) + .concatMap( + processor -> + Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) + .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) + .flattenAsFlowable( + result -> result.events() != null ? result.events() : ImmutableList.of())); } /** @@ -343,24 +333,23 @@ private Single handleAfterModelCallback( * @throws IllegalStateException if a transfer agent is specified but not found. */ private Flowable runOneStep(InvocationContext context) { - LlmRequest initialLlmRequest = LlmRequest.builder().build(); - - return preprocess(context, initialLlmRequest) - .flatMapPublisher( - preResult -> { - LlmRequest llmRequestAfterPreprocess = preResult.updatedRequest(); - Iterable preEvents = preResult.events(); + AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); + Flowable preprocessEvents = preprocess(context, llmRequestRef); + return preprocessEvents.concatWith( + Flowable.defer( + () -> { + LlmRequest llmRequestAfterPreprocess = llmRequestRef.get(); if (context.endInvocation()) { logger.debug("End invocation requested during preprocessing."); - return Flowable.fromIterable(preEvents); + return Flowable.empty(); } try { context.incrementLlmCallsCount(); } catch (LlmCallsLimitExceededException e) { logger.error("LLM calls limit exceeded.", e); - return Flowable.fromIterable(preEvents).concatWith(Flowable.error(e)); + return Flowable.error(e); } final Event mutableEventTemplate = @@ -374,48 +363,44 @@ private Flowable runOneStep(InvocationContext context) { // events with fresh timestamp. mutableEventTemplate.setTimestamp(0L); - Flowable restOfFlow = - callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) - .concatMap( - llmResponse -> - postprocess( - context, - mutableEventTemplate, - llmRequestAfterPreprocess, - llmResponse) - .doFinally( - () -> { - String oldId = mutableEventTemplate.id(); - mutableEventTemplate.setId(Event.generateEventId()); - logger.debug( - "Updated mutableEventTemplate ID from {} to {} for" - + " next LlmResponse", - oldId, - mutableEventTemplate.id()); - })) - .concatMap( - event -> { - Flowable postProcessedEvents = Flowable.just(event); - if (event.actions().transferToAgent().isPresent()) { - String agentToTransfer = event.actions().transferToAgent().get(); - logger.debug("Transferring to agent: {}", agentToTransfer); - BaseAgent rootAgent = context.agent().rootAgent(); - BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer); - if (nextAgent == null) { - String errorMsg = - "Agent not found for transfer: " + agentToTransfer; - logger.error(errorMsg); - return postProcessedEvents.concatWith( - Flowable.error(new IllegalStateException(errorMsg))); - } - return postProcessedEvents.concatWith( - Flowable.defer(() -> nextAgent.runAsync(context))); - } - return postProcessedEvents; - }); - - return restOfFlow.startWithIterable(preEvents); - }); + return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) + .concatMap( + llmResponse -> + postprocess( + context, + mutableEventTemplate, + llmRequestAfterPreprocess, + llmResponse) + .doFinally( + () -> { + String oldId = mutableEventTemplate.id(); + mutableEventTemplate.setId(Event.generateEventId()); + logger.debug( + "Updated mutableEventTemplate ID from {} to {} for" + + " next LlmResponse", + oldId, + mutableEventTemplate.id()); + })) + .concatMap( + event -> { + Flowable postProcessedEvents = Flowable.just(event); + if (event.actions().transferToAgent().isPresent()) { + String agentToTransfer = event.actions().transferToAgent().get(); + logger.debug("Transferring to agent: {}", agentToTransfer); + BaseAgent rootAgent = context.agent().rootAgent(); + BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer); + if (nextAgent == null) { + String errorMsg = "Agent not found for transfer: " + agentToTransfer; + logger.error(errorMsg); + return postProcessedEvents.concatWith( + Flowable.error(new IllegalStateException(errorMsg))); + } + return postProcessedEvents.concatWith( + Flowable.defer(() -> nextAgent.runAsync(context))); + } + return postProcessedEvents; + }); + })); } /** @@ -465,14 +450,15 @@ private Flowable run(InvocationContext invocationContext, int stepsComple */ @Override public Flowable runLive(InvocationContext invocationContext) { - LlmRequest llmRequest = LlmRequest.builder().build(); + AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); + Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); - return preprocess(invocationContext, llmRequest) - .flatMapPublisher( - preResult -> { - LlmRequest llmRequestAfterPreprocess = preResult.updatedRequest(); + return preprocessEvents.concatWith( + Flowable.defer( + () -> { + LlmRequest llmRequestAfterPreprocess = llmRequestRef.get(); if (invocationContext.endInvocation()) { - return Flowable.fromIterable(preResult.events()); + return Flowable.empty(); } String eventIdForSendData = Event.generateEventId(); @@ -623,10 +609,8 @@ public void onError(Throwable e) { } }); - return receiveFlow - .takeWhile(event -> !event.actions().endInvocation().orElse(false)) - .startWithIterable(preResult.events()); - }); + return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false)); + })); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java index 5008718b..aa70d57b 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java @@ -17,7 +17,6 @@ package com.google.adk.flows.llmflows; import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME; -import static com.google.adk.flows.llmflows.Functions.TOOL_CALL_SECURITY_STATES; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -157,20 +156,8 @@ public Single processRequest( toolsToResumeWithArgs.values(), ImmutableMap.copyOf(toolsToResumeWithConfirmation)) .map( - assembledEvent -> { - clearToolCallSecurityStates(invocationContext, toolsToResumeWithArgs.keySet()); - - // Create an updated LlmRequest including the new event's content - ImmutableList.Builder updatedContentsBuilder = - ImmutableList.builder().addAll(llmRequest.contents()); - assembledEvent.content().ifPresent(updatedContentsBuilder::add); - - LlmRequest updatedLlmRequest = - llmRequest.toBuilder().contents(updatedContentsBuilder.build()).build(); - - return RequestProcessingResult.create( - updatedLlmRequest, ImmutableList.of(assembledEvent)); - }) + assembledEvent -> + RequestProcessingResult.create(llmRequest, ImmutableList.of(assembledEvent))) .toSingle() .onErrorReturn( e -> { @@ -255,36 +242,4 @@ private Optional> maybeCreateToolConfirmatio return Optional.empty(); } - - private void clearToolCallSecurityStates( - InvocationContext invocationContext, Collection processedFunctionCallIds) { - var state = invocationContext.session().state(); - Object statesObj = state.get(TOOL_CALL_SECURITY_STATES); - - if (statesObj == null) { - return; - } - if (!(statesObj instanceof Map)) { - logger.warn( - "Session key {} does not contain a Map, cannot clear tool states. Found: {}", - TOOL_CALL_SECURITY_STATES, - statesObj.getClass().getName()); - return; - } - - try { - @SuppressWarnings("unchecked") // safe after instanceof check - Map updatedToolCallStates = new HashMap<>((Map) statesObj); - - // Remove the entries for the function calls that just got processed - processedFunctionCallIds.forEach(updatedToolCallStates::remove); - - state.put(TOOL_CALL_SECURITY_STATES, updatedToolCallStates); - } catch (ClassCastException e) { - logger.warn( - "Session key {} has unexpected map types, cannot clear tool states.", - TOOL_CALL_SECURITY_STATES, - e); - } - } } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java index e96099a7..cc2fb443 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java @@ -27,11 +27,11 @@ public class SingleFlow extends BaseLlmFlow { protected static final ImmutableList REQUEST_PROCESSORS = ImmutableList.of( new Basic(), + new RequestConfirmationLlmRequestProcessor(), new Instructions(), new Identity(), new Contents(), new Examples(), - new RequestConfirmationLlmRequestProcessor(), CodeExecution.requestProcessor); protected static final ImmutableList RESPONSE_PROCESSORS = diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java index 37a04393..5f4932a8 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BaseLlmFlowTest.java @@ -46,6 +46,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import org.junit.Test; import org.junit.runner.RunWith; @@ -301,6 +302,118 @@ public void run_withTools_toolsAreAddedToRequest() { assertThat(testLlm.getLastRequest().tools()).containsEntry("my_function", testTool); } + @Test + public void run_withRequestProcessorsAndTools_modifiesRequestInOrder() { + Content content = Content.fromParts(Part.fromText("LLM response")); + TestLlm testLlm = createTestLlm(createLlmResponse(content)); + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(testLlm) + .tools(ImmutableList.of(new TestTool("my_function", ImmutableMap.of()))) + .build()); + RequestProcessor requestProcessor1 = + createRequestProcessor( + request -> + request.toBuilder().appendInstructions(ImmutableList.of("instruction1")).build()); + RequestProcessor requestProcessor2 = + createRequestProcessor( + request -> + request.toBuilder().appendInstructions(ImmutableList.of("instruction2")).build()); + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow( + ImmutableList.of(requestProcessor1, requestProcessor2), + /* responseProcessors= */ ImmutableList.of()); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(testLlm.getLastRequest().tools()).containsKey("my_function"); + assertThat(testLlm.getLastRequest().config().orElseThrow().systemInstruction().orElseThrow()) + .isEqualTo(Content.fromParts(Part.fromText("instruction1\n\ninstruction2"))); + } + + @Test + public void run_requestProcessorsEmitEventsDirectly() { + Event eventFromProcessor1 = + Event.builder() + .id("event1") + .invocationId("invId") + .author("user") + .content(Content.fromParts(Part.fromText("event1"))) + .build(); + RequestProcessor processor1 = + (unusedCtx, request) -> + Single.just( + RequestProcessingResult.create(request, ImmutableList.of(eventFromProcessor1))); + RequestProcessor processor2 = + (context, request) -> { + boolean sawEvent1 = + context.session().events().stream() + .anyMatch(e -> e.id().equals(eventFromProcessor1.id())); + + Event resultEvent = + Event.builder() + .id("event2") + .invocationId("invId") + .author("user") + .content( + Content.fromParts( + Part.fromText(sawEvent1 ? "event1 was seen" : "event1 was not seen"))) + .build(); + + return Single.just( + RequestProcessingResult.create(request, ImmutableList.of(resultEvent))); + }; + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow( + ImmutableList.of(processor1, processor2), /* responseProcessors= */ ImmutableList.of()); + InvocationContext invocationContext = + createInvocationContext( + createTestAgent( + createTestLlm( + createLlmResponse(Content.fromParts(Part.fromText("llm response")))))); + + List events = + baseLlmFlow + .run(invocationContext) + .doOnNext(event -> invocationContext.session().events().add(event)) + .toList() + .blockingGet(); + + assertThat(events.stream().map(Event::stringifyContent)) + .containsExactly("event1", "event1 was seen", "llm response") + .inOrder(); + } + + @Test + public void run_requestProcessorsAreCalledExactlyOnce() { + AtomicInteger processor1CallCount = new AtomicInteger(); + AtomicInteger processor2CallCount = new AtomicInteger(); + + RequestProcessor processor1 = + (unusedCtx, request) -> { + processor1CallCount.incrementAndGet(); + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + }; + RequestProcessor processor2 = + (unusedCtx, request) -> { + processor2CallCount.incrementAndGet(); + return Single.just(RequestProcessingResult.create(request, ImmutableList.of())); + }; + BaseLlmFlow baseLlmFlow = + createBaseLlmFlow( + ImmutableList.of(processor1, processor2), /* responseProcessors= */ ImmutableList.of()); + InvocationContext invocationContext = + createInvocationContext( + createTestAgent( + createTestLlm( + createLlmResponse(Content.fromParts(Part.fromText("llm response")))))); + + List unused = baseLlmFlow.run(invocationContext).toList().blockingGet(); + + assertThat(processor1CallCount.get()).isEqualTo(1); + assertThat(processor2CallCount.get()).isEqualTo(1); + } + private static BaseLlmFlow createBaseLlmFlowWithoutProcessors() { return createBaseLlmFlow(ImmutableList.of(), ImmutableList.of()); } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index f1a01100..a931ac87 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -16,10 +16,12 @@ package com.google.adk.runner; +import static com.google.adk.testing.TestUtils.createFunctionCallLlmResponse; import static com.google.adk.testing.TestUtils.createLlmResponse; import static com.google.adk.testing.TestUtils.createTestAgent; import static com.google.adk.testing.TestUtils.createTestAgentBuilder; import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.adk.testing.TestUtils.createTextLlmResponse; import static com.google.adk.testing.TestUtils.simplifyEvents; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -35,6 +37,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.Functions; import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; @@ -44,10 +47,13 @@ import com.google.adk.testing.TestUtils; import com.google.adk.testing.TestUtils.EchoTool; import com.google.adk.testing.TestUtils.FailingEchoTool; +import com.google.adk.tools.FunctionTool; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; @@ -900,4 +906,82 @@ public void runLive_withoutSessionAndAutoCreateSessionFalse_throwsException() { .test() .assertError(IllegalArgumentException.class); } + + @Test + public void runAsync_withToolConfirmation() { + TestLlm testLlm = + createTestLlm( + createFunctionCallLlmResponse( + "tool_call_id", "echoTool", ImmutableMap.of("message", "hello")), + createTextLlmResponse("Response after observing tool needs confirmation."), + createTextLlmResponse("Response after user confirmed.")); + LlmAgent agent = + createTestAgentBuilder(testLlm) + .tools(FunctionTool.create(Tools.class, "echoTool", /* requireConfirmation= */ true)) + .build(); + Runner runner = Runner.builder().agent(agent).appName("test").build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + List eventsBeforeConfirmation = + runner + .runAsync("user", session.id(), Content.fromParts(Part.fromText("from user"))) + .toList() + .blockingGet(); + FunctionCall askUserConfirmationFunctionCall = + Iterables.getOnlyElement( + eventsBeforeConfirmation.stream() + .map(Functions::getAskUserConfirmationFunctionCalls) + .filter(functionCalls -> !functionCalls.isEmpty()) + .findFirst() + .get()); + List eventsAfterConfirmation = + runner + .runAsync( + "user", + session.id(), + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id(askUserConfirmationFunctionCall.id().get()) + .name(askUserConfirmationFunctionCall.name().get()) + .response(ImmutableMap.of("confirmed", true))) + .build())) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(eventsBeforeConfirmation)) + .containsExactly( + "test agent: FunctionCall(name=echoTool, args={message=hello})", + "test agent: FunctionCall(name=adk_request_confirmation," + + " args={originalFunctionCall=FunctionCall{id=Optional[tool_call_id]," + + " args=Optional[{message=hello}], name=Optional[echoTool]," + + " partialArgs=Optional.empty, willContinue=Optional.empty}," + + " toolConfirmation=ToolConfirmation{hint=Please approve or reject the tool call" + + " echoTool() by responding with a FunctionResponse with an expected" + + " ToolConfirmation payload., confirmed=false, payload=null}})", + "test agent: FunctionResponse(name=echoTool, response={error=This tool call requires" + + " confirmation, please approve or reject.})", + "test agent: Response after observing tool needs confirmation.") + .inOrder(); + assertThat(simplifyEvents(eventsAfterConfirmation)) + .containsExactly( + "test agent: FunctionResponse(name=echoTool, response={message=hello})", + "test agent: Response after user confirmed.") + .inOrder(); + assertThat(testLlm.getLastRequest().contents().stream().map(TestUtils::formatContent)) + .containsExactly( + "from user", + "FunctionCall(name=echoTool, args={message=hello})", + "FunctionResponse(name=echoTool, response={message=hello})") + .inOrder(); + } + + public static class Tools { + private Tools() {} + + public static ImmutableMap echoTool(String message) { + return ImmutableMap.of("message", message); + } + } } diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index d21a19f4..d4ccade2 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -81,17 +81,23 @@ public static Event createEscalateEvent(String id) { .build(); } - public static ImmutableList simplifyEvents(List events) { + public static ImmutableList simplifyEvents(List events) { return events.stream() .map(event -> event.author() + ": " + formatEventContent(event)) .collect(toImmutableList()); } private static String formatEventContent(Event event) { - return event - .content() - .or(() -> event.actions().compaction().map(EventCompaction::compactedContent)) - .flatMap(Content::parts) + return formatContent( + event + .content() + .or(() -> event.actions().compaction().map(EventCompaction::compactedContent)) + .orElse(Content.builder().build())); + } + + public static String formatContent(Content content) { + return content + .parts() .map( parts -> { if (parts.size() == 1) { @@ -214,6 +220,22 @@ public static LlmResponse createLlmResponse(Content content) { return LlmResponse.builder().content(content).build(); } + public static LlmResponse createTextLlmResponse(String text) { + return createLlmResponse(Content.builder().role("model").parts(Part.fromText(text)).build()); + } + + public static LlmResponse createFunctionCallLlmResponse( + String id, String functionName, Map args) { + Content content = + Content.builder() + .parts( + Part.builder() + .functionCall(FunctionCall.builder().id(id).name(functionName).args(args))) + .role("model") + .build(); + return createLlmResponse(content); + } + public static class EchoTool extends BaseTool { public EchoTool() { super("echo_tool", "description");