Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 78 additions & 94 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -87,46 +88,35 @@ public BaseLlmFlow(

/**
* Pre-processes the LLM request before sending it to the LLM. Executes all registered {@link
* RequestProcessor}.
* RequestProcessor} transforming the provided {@code llmRequestRef} in-place, and emits the
* events generated by them.
*/
protected Single<RequestProcessingResult> preprocess(
InvocationContext context, LlmRequest llmRequest) {

List<Iterable<Event>> eventIterables = new ArrayList<>();
protected Flowable<Event> preprocess(
InvocationContext context, AtomicReference<LlmRequest> llmRequestRef) {
LlmAgent agent = (LlmAgent) context.agent();

Single<LlmRequest> currentLlmRequest = Single.just(llmRequest);
for (RequestProcessor processor : requestProcessors) {
currentLlmRequest =
currentLlmRequest
.flatMap(request -> processor.processRequest(context, request))
.doOnSuccess(
result -> {
if (result.events() != null) {
eventIterables.add(result.events());
}
})
.map(RequestProcessingResult::updatedRequest);
}

return currentLlmRequest.flatMap(
processedRequest -> {
LlmRequest.Builder updatedRequestBuilder = processedRequest.toBuilder();

RequestProcessor toolsProcessor =
(ctx, req) -> {
LlmRequest.Builder builder = req.toBuilder();
return agent
.canonicalTools(new ReadonlyContext(context))
.canonicalTools(new ReadonlyContext(ctx))
.concatMapCompletable(
tool ->
tool.processLlmRequest(
updatedRequestBuilder, ToolContext.builder(context).build()))
tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build()))
.andThen(
Single.fromCallable(
() -> {
Iterable<Event> combinedEvents = Iterables.concat(eventIterables);
return RequestProcessingResult.create(
updatedRequestBuilder.build(), combinedEvents);
}));
});
() -> RequestProcessingResult.create(builder.build(), ImmutableList.of())));
};

Iterable<RequestProcessor> allProcessors =
Iterables.concat(requestProcessors, ImmutableList.of(toolsProcessor));

return Flowable.fromIterable(allProcessors)
.concatMap(
processor ->
Single.defer(() -> processor.processRequest(context, llmRequestRef.get()))
.doOnSuccess(result -> llmRequestRef.set(result.updatedRequest()))
.flattenAsFlowable(
result -> result.events() != null ? result.events() : ImmutableList.of()));
}

/**
Expand Down Expand Up @@ -343,24 +333,23 @@ private Single<LlmResponse> handleAfterModelCallback(
* @throws IllegalStateException if a transfer agent is specified but not found.
*/
private Flowable<Event> runOneStep(InvocationContext context) {
LlmRequest initialLlmRequest = LlmRequest.builder().build();

return preprocess(context, initialLlmRequest)
.flatMapPublisher(
preResult -> {
LlmRequest llmRequestAfterPreprocess = preResult.updatedRequest();
Iterable<Event> preEvents = preResult.events();
AtomicReference<LlmRequest> llmRequestRef = new AtomicReference<>(LlmRequest.builder().build());
Flowable<Event> preprocessEvents = preprocess(context, llmRequestRef);

return preprocessEvents.concatWith(
Flowable.defer(
() -> {
LlmRequest llmRequestAfterPreprocess = llmRequestRef.get();
if (context.endInvocation()) {
logger.debug("End invocation requested during preprocessing.");
return Flowable.fromIterable(preEvents);
return Flowable.empty();
}

try {
context.incrementLlmCallsCount();
} catch (LlmCallsLimitExceededException e) {
logger.error("LLM calls limit exceeded.", e);
return Flowable.fromIterable(preEvents).concatWith(Flowable.error(e));
return Flowable.error(e);
}

final Event mutableEventTemplate =
Expand All @@ -374,48 +363,44 @@ private Flowable<Event> runOneStep(InvocationContext context) {
// events with fresh timestamp.
mutableEventTemplate.setTimestamp(0L);

Flowable<Event> restOfFlow =
callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate)
.concatMap(
llmResponse ->
postprocess(
context,
mutableEventTemplate,
llmRequestAfterPreprocess,
llmResponse)
.doFinally(
() -> {
String oldId = mutableEventTemplate.id();
mutableEventTemplate.setId(Event.generateEventId());
logger.debug(
"Updated mutableEventTemplate ID from {} to {} for"
+ " next LlmResponse",
oldId,
mutableEventTemplate.id());
}))
.concatMap(
event -> {
Flowable<Event> postProcessedEvents = Flowable.just(event);
if (event.actions().transferToAgent().isPresent()) {
String agentToTransfer = event.actions().transferToAgent().get();
logger.debug("Transferring to agent: {}", agentToTransfer);
BaseAgent rootAgent = context.agent().rootAgent();
BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
if (nextAgent == null) {
String errorMsg =
"Agent not found for transfer: " + agentToTransfer;
logger.error(errorMsg);
return postProcessedEvents.concatWith(
Flowable.error(new IllegalStateException(errorMsg)));
}
return postProcessedEvents.concatWith(
Flowable.defer(() -> nextAgent.runAsync(context)));
}
return postProcessedEvents;
});

return restOfFlow.startWithIterable(preEvents);
});
return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate)
.concatMap(
llmResponse ->
postprocess(
context,
mutableEventTemplate,
llmRequestAfterPreprocess,
llmResponse)
.doFinally(
() -> {
String oldId = mutableEventTemplate.id();
mutableEventTemplate.setId(Event.generateEventId());
logger.debug(
"Updated mutableEventTemplate ID from {} to {} for"
+ " next LlmResponse",
oldId,
mutableEventTemplate.id());
}))
.concatMap(
event -> {
Flowable<Event> postProcessedEvents = Flowable.just(event);
if (event.actions().transferToAgent().isPresent()) {
String agentToTransfer = event.actions().transferToAgent().get();
logger.debug("Transferring to agent: {}", agentToTransfer);
BaseAgent rootAgent = context.agent().rootAgent();
BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
if (nextAgent == null) {
String errorMsg = "Agent not found for transfer: " + agentToTransfer;
logger.error(errorMsg);
return postProcessedEvents.concatWith(
Flowable.error(new IllegalStateException(errorMsg)));
}
return postProcessedEvents.concatWith(
Flowable.defer(() -> nextAgent.runAsync(context)));
}
return postProcessedEvents;
});
}));
}

/**
Expand Down Expand Up @@ -465,14 +450,15 @@ private Flowable<Event> run(InvocationContext invocationContext, int stepsComple
*/
@Override
public Flowable<Event> runLive(InvocationContext invocationContext) {
LlmRequest llmRequest = LlmRequest.builder().build();
AtomicReference<LlmRequest> llmRequestRef = new AtomicReference<>(LlmRequest.builder().build());
Flowable<Event> preprocessEvents = preprocess(invocationContext, llmRequestRef);

return preprocess(invocationContext, llmRequest)
.flatMapPublisher(
preResult -> {
LlmRequest llmRequestAfterPreprocess = preResult.updatedRequest();
return preprocessEvents.concatWith(
Flowable.defer(
() -> {
LlmRequest llmRequestAfterPreprocess = llmRequestRef.get();
if (invocationContext.endInvocation()) {
return Flowable.fromIterable(preResult.events());
return Flowable.empty();
}

String eventIdForSendData = Event.generateEventId();
Expand Down Expand Up @@ -623,10 +609,8 @@ public void onError(Throwable e) {
}
});

return receiveFlow
.takeWhile(event -> !event.actions().endInvocation().orElse(false))
.startWithIterable(preResult.events());
});
return receiveFlow.takeWhile(event -> !event.actions().endInvocation().orElse(false));
}));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.google.adk.flows.llmflows;

import static com.google.adk.flows.llmflows.Functions.REQUEST_CONFIRMATION_FUNCTION_CALL_NAME;
import static com.google.adk.flows.llmflows.Functions.TOOL_CALL_SECURITY_STATES;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
Expand Down Expand Up @@ -157,20 +156,8 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
toolsToResumeWithArgs.values(),
ImmutableMap.copyOf(toolsToResumeWithConfirmation))
.map(
assembledEvent -> {
clearToolCallSecurityStates(invocationContext, toolsToResumeWithArgs.keySet());

// Create an updated LlmRequest including the new event's content
ImmutableList.Builder<Content> updatedContentsBuilder =
ImmutableList.<Content>builder().addAll(llmRequest.contents());
assembledEvent.content().ifPresent(updatedContentsBuilder::add);

LlmRequest updatedLlmRequest =
llmRequest.toBuilder().contents(updatedContentsBuilder.build()).build();

return RequestProcessingResult.create(
updatedLlmRequest, ImmutableList.of(assembledEvent));
})
assembledEvent ->
RequestProcessingResult.create(llmRequest, ImmutableList.of(assembledEvent)))
.toSingle()
.onErrorReturn(
e -> {
Expand Down Expand Up @@ -255,36 +242,4 @@ private Optional<Map.Entry<String, ToolConfirmation>> maybeCreateToolConfirmatio

return Optional.empty();
}

private void clearToolCallSecurityStates(
InvocationContext invocationContext, Collection<String> processedFunctionCallIds) {
var state = invocationContext.session().state();
Object statesObj = state.get(TOOL_CALL_SECURITY_STATES);

if (statesObj == null) {
return;
}
if (!(statesObj instanceof Map)) {
logger.warn(
"Session key {} does not contain a Map, cannot clear tool states. Found: {}",
TOOL_CALL_SECURITY_STATES,
statesObj.getClass().getName());
return;
}

try {
@SuppressWarnings("unchecked") // safe after instanceof check
Map<String, String> updatedToolCallStates = new HashMap<>((Map<String, String>) statesObj);

// Remove the entries for the function calls that just got processed
processedFunctionCallIds.forEach(updatedToolCallStates::remove);

state.put(TOOL_CALL_SECURITY_STATES, updatedToolCallStates);
} catch (ClassCastException e) {
logger.warn(
"Session key {} has unexpected map types, cannot clear tool states.",
TOOL_CALL_SECURITY_STATES,
e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ public class SingleFlow extends BaseLlmFlow {
protected static final ImmutableList<RequestProcessor> REQUEST_PROCESSORS =
ImmutableList.of(
new Basic(),
new RequestConfirmationLlmRequestProcessor(),
new Instructions(),
new Identity(),
new Contents(),
new Examples(),
new RequestConfirmationLlmRequestProcessor(),
CodeExecution.requestProcessor);

protected static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS =
Expand Down
Loading