diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml new file mode 100644 index 000000000..e58f9f347 --- /dev/null +++ b/contrib/langchain4j/pom.xml @@ -0,0 +1,384 @@ + + + + + 4.0.0 + + + com.google.adk + google-adk-parent + 0.1.1-SNAPSHOT + + + google-adk-contrib-langchain4j + jar + + Agent Development Kit - Contributions - LangChain4j + https://github.com/google/adk-java + + + The Apache License, Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0 + + + + scm:git:git@github.com/google:adk-java.git + + scm:git:git@github.com/google:adk-java.git + + git@github.com/google:adk-java.git + + + + Google Inc. + http://www.google.com + + + + Third-party contributions, integrations, and plugins for Agent Development Kit. + + + UTF-8 + 17 + ${java.version} + 0.10.0 + 2.38.0 + 1.33.1 + 2.28.0 + 1.0.0 + 1.11.0 + 4.31.0-RC1 + 5.11.4 + 5.17.0 + 1.6.0 + 2.19.0 + 4.12.0 + 1.0.1 + + + + + dev.langchain4j + langchain4j-bom + ${langchain4j.version} + pom + import + + + org.junit + junit-bom + ${junit.version} + pom + import + + + + + + + dev.langchain4j + langchain4j-core + + + com.google.adk + google-adk + ${project.version} + + + com.google.adk + google-adk-dev + ${project.version} + + + com.google.genai + google-genai + ${google.genai.version} + + + io.modelcontextprotocol.sdk + mcp + ${mcp-schema.version} + + + + + dev.langchain4j + langchain4j-anthropic + test + + + dev.langchain4j + langchain4j-open-ai + test + + + dev.langchain4j + langchain4j-google-ai-gemini + test + + + dev.langchain4j + langchain4j-ollama + test + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + com.google.truth + truth + 1.4.4 + test + + + org.assertj + assertj-core + 3.27.3 + test + + + org.mockito + mockito-core + ${mockito.version} + test + + + + + ossrh + Central Repository OSSRH + 
https://google.oss.sonatype.org/service/local/staging/deploy/maven2/ + + + ossrh + Central Repository OSSRH for snapshots + https://google.oss.sonatype.org/content/repositories/snapshots + + + + + + com.google.cloud.artifactregistry + artifactregistry-maven-wagon + 2.2.0 + + + + + + maven-clean-plugin + 3.1.0 + + + maven-resources-plugin + 3.0.2 + + + maven-compiler-plugin + 3.13.0 + + ${java.version} + ${java.version} + ${maven.compiler.release} + true + + + com.google.auto.value + auto-value + ${auto-value.version} + + + + + + maven-surefire-plugin + 3.5.3 + + + me.fabriciorby + maven-surefire-junit5-tree-reporter + 0.1.0 + + + + plain + + + **/*Test.java + + + + + maven-jar-plugin + 3.0.2 + + + maven-install-plugin + 2.5.2 + + + maven-deploy-plugin + 3.1.1 + + false + + + + maven-site-plugin + 3.7.1 + + + maven-project-info-reports-plugin + 3.0.0 + + + org.apache.maven.plugins + maven-gpg-plugin + 3.2.7 + + + sign-artifacts + verify + + sign + + + + + + org.apache.maven.plugins + maven-source-plugin + 3.3.1 + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.6.3 + + all,-missing + true + ${project.build.directory}/javadoc + Agent Development Kit + ${maven.compiler.release} + UTF-8 + + + + attach-javadocs + + jar + + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.7.0 + true + + ossrh + https://google.oss.sonatype.org/ + false + + + + + + + org.jacoco + jacoco-maven-plugin + 0.8.12 + + + + prepare-agent + + + + *MockitoMock* + *$$EnhancerByMockitoWithCGLIB$$* + *$$FastClassByMockitoWithCGLIB$$* + com/sun/tools/attach/* + sun/util/resources/cldr/provider/* + + + + + report + test + + report + + + + HTML + + + + + + + + + + release + + + + org.apache.maven.plugins + maven-source-plugin + + + attach-sources + + jar-no-fork + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + + attach-javadocs + + jar + + + + + + + + + diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java 
/*
 * Copyright 2025 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.google.adk.models.langchain4j;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.BaseLlmConnection;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.FunctionCallingConfigMode;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.ToolConfig;
import com.google.genai.types.Type;
import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.audio.Audio;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.AudioContent;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.PdfFileContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.VideoContent;
import dev.langchain4j.data.pdf.PdfFile;
import dev.langchain4j.data.video.Video;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ToolChoice;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;

/**
 * Adapter that exposes a LangChain4j {@link ChatModel} (blocking) and/or
 * {@link StreamingChatModel} (streaming) as an ADK {@link BaseLlm}.
 *
 * <p>Responsibilities:
 *
 * <ul>
 *   <li>translate ADK {@link LlmRequest}s (contents, system instructions, tool declarations,
 *       generation config) into LangChain4j {@link ChatRequest}s;
 *   <li>translate LangChain4j {@link ChatResponse}s (text and tool-execution requests) back into
 *       ADK {@link LlmResponse}s.
 * </ul>
 *
 * <p>Live (bidirectional) connections are not supported; see {@link #connect(LlmRequest)}.
 */
@Experimental
public class LangChain4j extends BaseLlm {

  private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE =
      new TypeReference<>() {};

  // Jackson mappers are thread-safe and expensive to create; share one instance.
  private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

  // Exactly one of these may be null: a blocking model, a streaming model, or both may be set.
  private final ChatModel chatModel;
  private final StreamingChatModel streamingChatModel;

  /**
   * Creates an adapter around a blocking chat model, deriving the model name from the model's
   * default request parameters.
   */
  public LangChain4j(ChatModel chatModel) {
    this(
        chatModel,
        Objects.requireNonNull(
            chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"));
  }

  /** Creates an adapter around a blocking chat model with an explicit model name. */
  public LangChain4j(ChatModel chatModel, String modelName) {
    super(Objects.requireNonNull(modelName, "chat model name cannot be null"));
    this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
    this.streamingChatModel = null;
  }

  /**
   * Creates an adapter around a streaming chat model, deriving the model name from the model's
   * default request parameters.
   */
  public LangChain4j(StreamingChatModel streamingChatModel) {
    this(
        streamingChatModel,
        Objects.requireNonNull(
            streamingChatModel.defaultRequestParameters().modelName(),
            "streaming chat model name cannot be null"));
  }

  /** Creates an adapter around a streaming chat model with an explicit model name. */
  public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
    super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
    this.chatModel = null;
    this.streamingChatModel =
        Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
  }

  /** Creates an adapter that supports both blocking and streaming generation. */
  public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
    super(Objects.requireNonNull(modelName, "model name cannot be null"));
    this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
    this.streamingChatModel =
        Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
  }

  /**
   * Generates content for the given request.
   *
   * @param llmRequest the ADK request to translate and send
   * @param stream when true, uses the streaming model and emits one {@link LlmResponse} per
   *     partial text chunk, followed by one response per tool-execution request (if any); when
   *     false, uses the blocking model and emits a single response
   * @return a {@link Flowable} that errors with {@link IllegalStateException} if the required
   *     model (streaming or blocking) was not configured
   */
  @Override
  public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
    if (stream) {
      if (this.streamingChatModel == null) {
        return Flowable.error(new IllegalStateException("StreamingChatModel is not configured"));
      }

      ChatRequest chatRequest = toChatRequest(llmRequest);

      return Flowable.create(
          emitter ->
              streamingChatModel.chat(
                  chatRequest,
                  new StreamingChatResponseHandler() {
                    @Override
                    public void onPartialResponse(String s) {
                      emitter.onNext(
                          LlmResponse.builder()
                              .content(Content.fromParts(Part.fromText(s)))
                              .build());
                    }

                    @Override
                    public void onCompleteResponse(ChatResponse chatResponse) {
                      // Text was already emitted chunk-by-chunk; only tool calls remain.
                      if (chatResponse.aiMessage().hasToolExecutionRequests()) {
                        AiMessage aiMessage = chatResponse.aiMessage();
                        toParts(aiMessage).stream()
                            .map(Part::functionCall)
                            .forEach(
                                functionCall ->
                                    functionCall.ifPresent(
                                        function ->
                                            emitter.onNext(
                                                LlmResponse.builder()
                                                    .content(
                                                        Content.fromParts(
                                                            Part.fromFunctionCall(
                                                                function.name().orElse(""),
                                                                function.args().orElse(Map.of()))))
                                                    .build())));
                      }
                      emitter.onComplete();
                    }

                    @Override
                    public void onError(Throwable throwable) {
                      emitter.onError(throwable);
                    }
                  }),
          BackpressureStrategy.BUFFER);
    } else {
      if (this.chatModel == null) {
        return Flowable.error(new IllegalStateException("ChatModel is not configured"));
      }

      ChatRequest chatRequest = toChatRequest(llmRequest);
      ChatResponse chatResponse = chatModel.chat(chatRequest);
      LlmResponse llmResponse = toLlmResponse(chatResponse);

      return Flowable.just(llmResponse);
    }
  }

  /** Translates an ADK request (messages, tools, generation config) into a LangChain4j request. */
  private ChatRequest toChatRequest(LlmRequest llmRequest) {
    ChatRequest.Builder requestBuilder = ChatRequest.builder();

    List<ToolSpecification> toolSpecifications = toToolSpecifications(llmRequest);
    requestBuilder.toolSpecifications(toolSpecifications);

    if (llmRequest.config().isPresent()) {
      GenerateContentConfig generateContentConfig = llmRequest.config().get();

      generateContentConfig
          .temperature()
          .ifPresent(temp -> requestBuilder.temperature(temp.doubleValue()));
      generateContentConfig.topP().ifPresent(topP -> requestBuilder.topP(topP.doubleValue()));
      generateContentConfig.topK().ifPresent(topK -> requestBuilder.topK(topK.intValue()));
      generateContentConfig.maxOutputTokens().ifPresent(requestBuilder::maxOutputTokens);
      generateContentConfig.stopSequences().ifPresent(requestBuilder::stopSequences);
      generateContentConfig
          .frequencyPenalty()
          .ifPresent(freqPenalty -> requestBuilder.frequencyPenalty(freqPenalty.doubleValue()));
      generateContentConfig
          .presencePenalty()
          .ifPresent(presPenalty -> requestBuilder.presencePenalty(presPenalty.doubleValue()));

      if (generateContentConfig.toolConfig().isPresent()) {
        ToolConfig toolConfig = generateContentConfig.toolConfig().get();
        toolConfig
            .functionCallingConfig()
            .ifPresent(
                functionCallingConfig ->
                    functionCallingConfig
                        .mode()
                        .ifPresent(
                            functionMode -> {
                              if (functionMode
                                  .knownEnum()
                                  .equals(FunctionCallingConfigMode.Known.AUTO)) {
                                requestBuilder.toolChoice(ToolChoice.AUTO);
                              } else if (functionMode
                                  .knownEnum()
                                  .equals(FunctionCallingConfigMode.Known.ANY)) {
                                // ANY maps to REQUIRED: the model must call one of the
                                // (optionally restricted) tools. TODO confirm this mapping.
                                requestBuilder.toolChoice(ToolChoice.REQUIRED);
                                functionCallingConfig
                                    .allowedFunctionNames()
                                    .ifPresent(
                                        allowedFunctionNames ->
                                            requestBuilder.toolSpecifications(
                                                toolSpecifications.stream()
                                                    .filter(
                                                        toolSpecification ->
                                                            allowedFunctionNames.contains(
                                                                toolSpecification.name()))
                                                    .toList()));
                              } else if (functionMode
                                  .knownEnum()
                                  .equals(FunctionCallingConfigMode.Known.NONE)) {
                                requestBuilder.toolSpecifications(List.of());
                              }
                            }));
        toolConfig
            .retrievalConfig()
            .ifPresent(
                retrievalConfig -> {
                  // TODO? It exposes Latitude / Longitude, what to do with this?
                });
      }
    }

    return requestBuilder.messages(toMessages(llmRequest)).build();
  }

  /** Builds the message list: system instructions first, then the conversation contents. */
  private List<ChatMessage> toMessages(LlmRequest llmRequest) {
    List<ChatMessage> messages = new ArrayList<>();
    messages.addAll(llmRequest.getSystemInstructions().stream().map(SystemMessage::from).toList());
    messages.addAll(llmRequest.contents().stream().map(this::toChatMessage).toList());
    return messages;
  }

  /** Dispatches on the content role; "user" may also carry tool results / calls. */
  private ChatMessage toChatMessage(Content content) {
    String role = content.role().orElseThrow().toLowerCase();
    return switch (role) {
      case "user" -> toUserOrToolResultMessage(content);
      case "model", "assistant" -> toAiMessage(content);
      default -> throw new IllegalStateException("Unexpected role: " + role);
    };
  }

  /**
   * Converts a user-role {@link Content} into a user message, a tool-execution result, or (when it
   * carries a function call) an AI message. Inline binary data is converted to the matching
   * LangChain4j media content type.
   *
   * @throws IllegalArgumentException if inline data lacks a mime type / bytes, or the mime type is
   *     not supported
   * @throws IllegalStateException if a part carries none of text / media / function data
   */
  private ChatMessage toUserOrToolResultMessage(Content content) {
    ToolExecutionResultMessage toolExecutionResultMessage = null;
    ToolExecutionRequest toolExecutionRequest = null;

    List<dev.langchain4j.data.message.Content> lc4jContents = new ArrayList<>();

    for (Part part : content.parts().orElse(List.of())) {
      if (part.text().isPresent()) {
        lc4jContents.add(TextContent.from(part.text().get()));
      } else if (part.functionResponse().isPresent()) {
        FunctionResponse functionResponse = part.functionResponse().get();
        toolExecutionResultMessage =
            ToolExecutionResultMessage.from(
                functionResponse.id().orElseThrow(),
                functionResponse.name().orElseThrow(),
                toJson(functionResponse.response().orElseThrow()));
      } else if (part.functionCall().isPresent()) {
        FunctionCall functionCall = part.functionCall().get();
        toolExecutionRequest =
            ToolExecutionRequest.builder()
                .id(functionCall.id().orElseThrow())
                .name(functionCall.name().orElseThrow())
                .arguments(toJson(functionCall.args().orElse(Map.of())))
                .build();
      } else if (part.inlineData().isPresent()) {
        Blob blob = part.inlineData().get();

        if (blob.mimeType().isEmpty() || blob.data().isEmpty()) {
          throw new IllegalArgumentException("Mime type and data required");
        }

        lc4jContents.add(toLc4jContent(blob.data().get(), blob.mimeType().get()));
      } else {
        throw new IllegalStateException("Text, media or functionCall is expected, but was: " + part);
      }
    }

    if (toolExecutionResultMessage != null) {
      return toolExecutionResultMessage;
    } else if (toolExecutionRequest != null) {
      return AiMessage.aiMessage(toolExecutionRequest);
    } else {
      return UserMessage.from(lc4jContents);
    }
  }

  /**
   * Maps inline blob data to the LangChain4j content type matching its mime type.
   *
   * <p>Bug fix: text-based mime types ({@code text/*}, JSON, XML) previously fell through to the
   * "unknown mime type" exception even though they had been handled; they are now returned as
   * {@link TextContent}.
   *
   * @throws IllegalArgumentException for unsupported mime types
   */
  private static dev.langchain4j.data.message.Content toLc4jContent(byte[] bytes, String mimeType) {
    Base64.Encoder encoder = Base64.getEncoder();

    if (mimeType.startsWith("audio/")) {
      return AudioContent.from(
          Audio.builder().base64Data(encoder.encodeToString(bytes)).mimeType(mimeType).build());
    } else if (mimeType.startsWith("video/")) {
      return VideoContent.from(
          Video.builder().base64Data(encoder.encodeToString(bytes)).mimeType(mimeType).build());
    } else if (mimeType.startsWith("image/")) {
      return ImageContent.from(
          Image.builder().base64Data(encoder.encodeToString(bytes)).mimeType(mimeType).build());
    } else if (mimeType.startsWith("application/pdf")) {
      return PdfFileContent.from(
          PdfFile.builder().base64Data(encoder.encodeToString(bytes)).mimeType(mimeType).build());
    } else if (mimeType.startsWith("text/")
        || mimeType.equals("application/json")
        || mimeType.endsWith("+json")
        || mimeType.endsWith("+xml")) {
      // TODO are there missing text based mime types?
      // Charset is assumed UTF-8 — TODO confirm whether the blob can carry other encodings.
      return TextContent.from(new String(bytes, StandardCharsets.UTF_8));
    }
    throw new IllegalArgumentException("Unknown or unhandled mime type: " + mimeType);
  }

  /**
   * Converts a model-role {@link Content} (text parts and/or function calls) into an
   * {@link AiMessage}.
   *
   * @throws IllegalStateException if a part is neither text nor a function call
   */
  private AiMessage toAiMessage(Content content) {
    List<String> texts = new ArrayList<>();
    List<ToolExecutionRequest> toolExecutionRequests = new ArrayList<>();

    content
        .parts()
        .orElse(List.of())
        .forEach(
            part -> {
              if (part.text().isPresent()) {
                texts.add(part.text().get());
              } else if (part.functionCall().isPresent()) {
                FunctionCall functionCall = part.functionCall().get();
                ToolExecutionRequest toolExecutionRequest =
                    ToolExecutionRequest.builder()
                        .id(functionCall.id().orElseThrow())
                        .name(functionCall.name().orElseThrow())
                        // Consistency fix: tolerate no-arg calls like the other conversion
                        // paths (previously orElseThrow, which failed on absent args).
                        .arguments(toJson(functionCall.args().orElse(Map.of())))
                        .build();
                toolExecutionRequests.add(toolExecutionRequest);
              } else {
                throw new IllegalStateException(
                    "Either text or functionCall is expected, but was: " + part);
              }
            });

    return AiMessage.builder()
        .text(String.join("\n", texts))
        .toolExecutionRequests(toolExecutionRequests)
        .build();
  }

  /** Serializes a value to JSON, wrapping checked Jackson failures as unchecked. */
  private String toJson(Object object) {
    try {
      return OBJECT_MAPPER.writeValueAsString(object);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Converts the request's declared tools into LangChain4j tool specifications.
   *
   * @throws IllegalStateException if a tool lacks a declaration or parameters
   */
  private List<ToolSpecification> toToolSpecifications(LlmRequest llmRequest) {
    List<ToolSpecification> toolSpecifications = new ArrayList<>();

    llmRequest
        .tools()
        .values()
        .forEach(
            baseTool -> {
              if (baseTool.declaration().isPresent()) {
                FunctionDeclaration functionDeclaration = baseTool.declaration().get();
                if (functionDeclaration.parameters().isPresent()) {
                  Schema schema = functionDeclaration.parameters().get();
                  ToolSpecification toolSpecification =
                      ToolSpecification.builder()
                          .name(baseTool.name())
                          .description(baseTool.description())
                          .parameters(toParameters(schema))
                          .build();
                  toolSpecifications.add(toolSpecification);
                } else {
                  // TODO exception or something else?
                  throw new IllegalStateException("Tool lacking parameters: " + baseTool);
                }
              } else {
                // TODO exception or something else?
                throw new IllegalStateException("Tool lacking declaration: " + baseTool);
              }
            });

    return toolSpecifications;
  }

  /** Top-level tool parameters must be an OBJECT schema. */
  private JsonObjectSchema toParameters(Schema schema) {
    if (schema.type().isPresent() && schema.type().get().knownEnum().equals(Type.Known.OBJECT)) {
      return JsonObjectSchema.builder()
          .addProperties(toProperties(schema))
          .required(schema.required().orElse(List.of()))
          .build();
    } else {
      throw new UnsupportedOperationException(
          "LangChain4jLlm does not support schema of type: " + schema.type());
    }
  }

  /** Recursively converts an OBJECT schema's properties. */
  private Map<String, JsonSchemaElement> toProperties(Schema schema) {
    Map<String, Schema> properties = schema.properties().orElse(Map.of());
    Map<String, JsonSchemaElement> result = new HashMap<>();
    properties.forEach((k, v) -> result.put(k, toJsonSchemaElement(v)));
    return result;
  }

  /**
   * Converts one genai {@link Schema} node to its LangChain4j JSON-schema equivalent.
   *
   * @throws IllegalArgumentException if the schema or its type is absent
   * @throws UnsupportedFeatureException for {@code TYPE_UNSPECIFIED}
   */
  private JsonSchemaElement toJsonSchemaElement(Schema schema) {
    if (schema != null && schema.type().isPresent()) {
      Type type = schema.type().get();
      return switch (type.knownEnum()) {
        case STRING ->
            JsonStringSchema.builder().description(schema.description().orElse(null)).build();
        case NUMBER ->
            JsonNumberSchema.builder().description(schema.description().orElse(null)).build();
        case INTEGER ->
            JsonIntegerSchema.builder().description(schema.description().orElse(null)).build();
        case BOOLEAN ->
            JsonBooleanSchema.builder().description(schema.description().orElse(null)).build();
        case ARRAY ->
            JsonArraySchema.builder()
                .description(schema.description().orElse(null))
                .items(toJsonSchemaElement(schema.items().orElseThrow()))
                .build();
        case OBJECT -> toParameters(schema);
        case TYPE_UNSPECIFIED ->
            throw new UnsupportedFeatureException(
                "LangChain4jLlm does not support schema of type: " + type);
      };
    } else {
      throw new IllegalArgumentException("Schema type cannot be null or absent");
    }
  }

  /** Wraps a complete LangChain4j response as a model-role ADK response. */
  private LlmResponse toLlmResponse(ChatResponse chatResponse) {
    Content content =
        Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();

    return LlmResponse.builder().content(content).build();
  }

  /**
   * Converts an {@link AiMessage} to genai parts: one function-call part per tool-execution
   * request, or a single text part. A random UUID is assigned when the provider omits the call id.
   */
  private List<Part> toParts(AiMessage aiMessage) {
    if (aiMessage.hasToolExecutionRequests()) {
      List<Part> parts = new ArrayList<>();
      aiMessage
          .toolExecutionRequests()
          .forEach(
              toolExecutionRequest -> {
                FunctionCall functionCall =
                    FunctionCall.builder()
                        .id(
                            toolExecutionRequest.id() != null
                                ? toolExecutionRequest.id()
                                : UUID.randomUUID().toString())
                        .name(toolExecutionRequest.name())
                        .args(toArgs(toolExecutionRequest))
                        .build();
                parts.add(Part.builder().functionCall(functionCall).build());
              });
      return parts;
    } else {
      return List.of(Part.builder().text(aiMessage.text()).build());
    }
  }

  /** Parses the tool-call JSON argument string into a map; wraps parse failures as unchecked. */
  private Map<String, Object> toArgs(ToolExecutionRequest toolExecutionRequest) {
    try {
      return OBJECT_MAPPER.readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
  }

  /** Live (bidirectional streaming) connections are not supported by this adapter. */
  @Override
  public BaseLlmConnection connect(LlmRequest llmRequest) {
    throw new UnsupportedOperationException("Live connection is not supported for LangChain4j models.");
  }
}
b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jIntegrationTest.java @@ -0,0 +1,452 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.models.langchain4j; + +import static com.google.adk.models.langchain4j.RunLoop.askAgent; +import static com.google.adk.models.langchain4j.RunLoop.askAgentStreaming; +import static org.junit.jupiter.api.Assertions.*; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.LlmAgent; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.tools.AgentTool; +import com.google.adk.tools.FunctionTool; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import dev.langchain4j.model.anthropic.AnthropicChatModel; +import dev.langchain4j.model.anthropic.AnthropicStreamingChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiStreamingChatModel; +import io.reactivex.rxjava3.core.Flowable; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +class LangChain4jIntegrationTest { + + public static final String 
CLAUDE_3_7_SONNET_20250219 = "claude-3-7-sonnet-20250219"; + public static final String GEMINI_2_0_FLASH = "gemini-2.0-flash"; + public static final String GPT_4_O_MINI = "gpt-4o-mini"; + + @Test + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = "\\S+") + void testSimpleAgent() { + // given + AnthropicChatModel claudeModel = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + LlmAgent agent = LlmAgent.builder() + .name("science-app") + .description("Science teacher agent") + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .instruction(""" + You are a helpful science teacher that explains science concepts + to kids and teenagers. + """) + .build(); + + // when + List events = askAgent(agent, "What is a qubit?"); + + // then + assertEquals(1, events.size()); + + Event firstEvent = events.get(0); + assertTrue(firstEvent.content().isPresent()); + + Content content = firstEvent.content().get(); + System.out.println("Answer: " + content.text()); + assertTrue(content.text().contains("quantum")); + } + + @Test + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = "\\S+") + void testSingleAgentWithTools() { + // given + AnthropicChatModel claudeModel = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + BaseAgent agent = LlmAgent.builder() + .name("friendly-weather-app") + .description("Friend agent that knows about the weather") + .model(new LangChain4j(claudeModel, CLAUDE_3_7_SONNET_20250219)) + .instruction(""" + You are a friendly assistant. + + If asked about the weather forecast for a city, + you MUST call the `getWeather` function. 
+ """) + .tools(FunctionTool.create(ToolExample.class, "getWeather")) + .build(); + + // when + List events = askAgent(agent, "What's the weather like in Paris?"); + + // then + assertEquals(3, events.size()); + + events.forEach(event -> { + assertTrue(event.content().isPresent()); + System.out.printf("%nevent: %s%n", event.stringifyContent()); + }); + + Event eventOne = events.get(0); + Event eventTwo = events.get(1); + Event eventThree = events.get(2); + + // assert the first event is a function call + assertTrue(eventOne.content().isPresent()); + Content contentOne = eventOne.content().get(); + assertTrue(contentOne.parts().isPresent()); + List partsOne = contentOne.parts().get(); + assertEquals(1, partsOne.size()); + Optional functionCall = partsOne.get(0).functionCall(); + assertTrue(functionCall.isPresent()); + assertTrue(functionCall.get().name().isPresent()); + assertEquals("getWeather", functionCall.get().name().get()); + assertTrue(functionCall.get().args().isPresent()); + assertTrue(functionCall.get().args().get().containsKey("city")); + + // assert the second event is a function response + assertTrue(eventTwo.content().isPresent()); + Content contentTwo = eventTwo.content().get(); + assertTrue(contentTwo.parts().isPresent()); + List partsTwo = contentTwo.parts().get(); + assertEquals(1, partsTwo.size()); + Optional functionResponseTwo = partsTwo.get(0).functionResponse(); + assertTrue(functionResponseTwo.isPresent()); + + // assert the third event is the final text response + assertTrue(eventThree.finalResponse()); + assertTrue(eventThree.content().isPresent()); + Content contentThree = eventThree.content().get(); + assertTrue(contentThree.parts().isPresent()); + List partsThree = contentThree.parts().get(); + assertEquals(1, partsThree.size()); + assertTrue(partsThree.get(0).text().isPresent()); + assertTrue(partsThree.get(0).text().get().contains("beautiful")); + } + + @Test + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "\\S+") + 
void testAgentTool() { + // given + OpenAiChatModel gptModel = OpenAiChatModel.builder() + .apiKey(System.getenv("OPENAI_API_KEY")) + .modelName(GPT_4_O_MINI) + .build(); + + LlmAgent weatherAgent = LlmAgent.builder() + .name("weather-agent") + .description("Weather agent") + .model(GEMINI_2_0_FLASH) + .instruction(""" + Your role is to always answer that the weather is sunny and 20°C. + """) + .build(); + + BaseAgent agent = LlmAgent.builder() + .name("friendly-weather-app") + .description("Friend agent that knows about the weather") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly assistant. + + If asked about the weather forecast for a city, + you MUST call the `weather-agent` function. + """) + .tools(AgentTool.create(weatherAgent)) + .build(); + + // when + List events = askAgent(agent, "What's the weather like in Paris?"); + + // then + assertEquals(3, events.size()); + events.forEach(event -> { + assertTrue(event.content().isPresent()); + System.out.printf("%nevent: %s%n", event.stringifyContent()); + }); + + assertEquals(1, events.get(0).functionCalls().size()); + assertEquals("weather-agent", events.get(0).functionCalls().get(0).name().get()); + + assertEquals(1, events.get(1).functionResponses().size()); + assertTrue(events.get(1).functionResponses().get(0).response().get().toString().toLowerCase().contains("sunny")); + assertTrue(events.get(1).functionResponses().get(0).response().get().toString().contains("20")); + + { + final var finalEvent = events.get(2); + assertTrue(finalEvent.finalResponse()); + final var text = finalEvent.content().orElseThrow().text(); + assertTrue(text.contains("sunny")); + assertTrue(text.contains("20")); + } + } + + @Test + @EnabledIfEnvironmentVariable(named = "GOOGLE_API_KEY", matches = "\\S+") + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "\\S+") + void testSubAgent() { + // given + OpenAiChatModel gptModel = OpenAiChatModel.builder() + 
.baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + + LlmAgent greeterAgent = LlmAgent.builder() + .name("greeterAgent") + .description("Friendly agent that greets users") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly that greets users. + """) + .build(); + + LlmAgent farewellAgent = LlmAgent.builder() + .name("farewellAgent") + .description("Friendly agent that says goodbye to users") + .model(new LangChain4j(gptModel)) + .instruction(""" + You are a friendly that says goodbye to users. + """) + .build(); + + LlmAgent coordinatorAgent = LlmAgent.builder() + .name("coordinator-agent") + .description("Coordinator agent") + .model(GEMINI_2_0_FLASH) + .instruction(""" + Your role is to coordinate 2 agents: + - `greeterAgent`: should reply to messages saying hello, hi, etc. + - `farewellAgent`: should reply to messages saying bye, goodbye, etc. + """) + .subAgents(greeterAgent, farewellAgent) + .build(); + + // when + List hiEvents = askAgent(coordinatorAgent, "Hi"); + List byeEvents = askAgent(coordinatorAgent, "Goodbye"); + + // then + hiEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); + byeEvents.forEach(event -> { System.out.println(event.stringifyContent()); }); + + // Assertions for hiEvents + assertEquals(3, hiEvents.size()); + + Event hiEvent1 = hiEvents.get(0); + assertTrue(hiEvent1.content().isPresent()); + assertFalse(hiEvent1.functionCalls().isEmpty()); + assertEquals(1, hiEvent1.functionCalls().size()); + FunctionCall hiFunctionCall = hiEvent1.functionCalls().get(0); + assertTrue(hiFunctionCall.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), hiFunctionCall.name()); + assertEquals(Optional.of(Map.of("agentName", "greeterAgent")), hiFunctionCall.args()); + + Event hiEvent2 = hiEvents.get(1); + assertTrue(hiEvent2.content().isPresent()); + 
assertFalse(hiEvent2.functionResponses().isEmpty()); + assertEquals(1, hiEvent2.functionResponses().size()); + FunctionResponse hiFunctionResponse = hiEvent2.functionResponses().get(0); + assertTrue(hiFunctionResponse.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), hiFunctionResponse.name()); + assertEquals(Optional.of(Map.of()), hiFunctionResponse.response()); // Empty map for response + + Event hiEvent3 = hiEvents.get(2); + assertTrue(hiEvent3.content().isPresent()); + assertTrue(hiEvent3.content().get().text().toLowerCase().contains("hello")); + assertTrue(hiEvent3.finalResponse()); + + // Assertions for byeEvents + assertEquals(3, byeEvents.size()); + + Event byeEvent1 = byeEvents.get(0); + assertTrue(byeEvent1.content().isPresent()); + assertFalse(byeEvent1.functionCalls().isEmpty()); + assertEquals(1, byeEvent1.functionCalls().size()); + FunctionCall byeFunctionCall = byeEvent1.functionCalls().get(0); + assertTrue(byeFunctionCall.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), byeFunctionCall.name()); + assertEquals(Optional.of(Map.of("agentName", "farewellAgent")), byeFunctionCall.args()); + + Event byeEvent2 = byeEvents.get(1); + assertTrue(byeEvent2.content().isPresent()); + assertFalse(byeEvent2.functionResponses().isEmpty()); + assertEquals(1, byeEvent2.functionResponses().size()); + FunctionResponse byeFunctionResponse = byeEvent2.functionResponses().get(0); + assertTrue(byeFunctionResponse.id().isPresent()); + assertEquals(Optional.of("transferToAgent"), byeFunctionResponse.name()); + assertEquals(Optional.of(Map.of()), byeFunctionResponse.response()); // Empty map for response + + Event byeEvent3 = byeEvents.get(2); + assertTrue(byeEvent3.content().isPresent()); + assertTrue(byeEvent3.content().get().text().toLowerCase().contains("goodbye")); + assertTrue(byeEvent3.finalResponse()); + } + + @Test + @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = "\\S+") + void testSimpleStreamingResponse() { + 
// given + AnthropicStreamingChatModel claudeStreamingModel = AnthropicStreamingChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(CLAUDE_3_7_SONNET_20250219) + .build(); + + LangChain4j lc4jClaude = new LangChain4j(claudeStreamingModel, CLAUDE_3_7_SONNET_20250219); + + // when + Flowable responses = lc4jClaude.generateContent(LlmRequest.builder() + .contents(List.of(Content.fromParts(Part.fromText("Why is the sky blue?")))) + .build(), true); + + String fullResponse = String.join("", responses.blockingStream() + .map(llmResponse -> llmResponse.content().get().text()) + .toList()); + + // then + assertTrue(fullResponse.contains("blue")); + assertTrue(fullResponse.contains("Rayleigh")); + assertTrue(fullResponse.contains("scatter")); + } + + @Test + @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = "\\S+") + void testStreamingRunConfig() { + // given + OpenAiStreamingChatModel streamingModel = OpenAiStreamingChatModel.builder() + .baseUrl("http://langchain4j.dev/demo/openai/v1") + .apiKey(Objects.requireNonNullElse(System.getenv("OPENAI_API_KEY"), "demo")) + .modelName(GPT_4_O_MINI) + .build(); + +// AnthropicStreamingChatModel streamingModel = AnthropicStreamingChatModel.builder() +// .apiKey(System.getenv("ANTHROPIC_API_KEY")) +// .modelName(CLAUDE_3_7_SONNET_20250219) +// .build(); + +// GoogleAiGeminiStreamingChatModel streamingModel = GoogleAiGeminiStreamingChatModel.builder() +// .apiKey(System.getenv("GOOGLE_API_KEY")) +// .modelName("gemini-2.0-flash") +// .build(); + + LlmAgent agent = LlmAgent.builder() + .name("streaming-agent") + .description("Friendly science teacher agent") + .instruction(""" + You're a friendly science teacher. + You give concise answers about science topics. + + When someone greets you, respond with "Hello". + If someone asks about the weather, call the `getWeather` function. 
+ """) + .model(new LangChain4j(streamingModel, "GPT_4_O_MINI")) +// .model(new LangChain4j(streamingModel, CLAUDE_3_7_SONNET_20250219)) + .tools(FunctionTool.create(ToolExample.class, "getWeather")) + .build(); + + // when + List eventsHi = askAgentStreaming(agent, "Hi"); + String responseToHi = String.join("", eventsHi.stream() + .map(event -> event.content().get().text()) + .toList()); + + List eventsQubit = askAgentStreaming(agent, "Tell me about qubits"); + String responseToQubit = String.join("", eventsQubit.stream() + .map(event -> event.content().get().text()) + .toList()); + + List eventsWeather = askAgentStreaming(agent, "What's the weather in Paris?"); + String responseToWeather = String.join("", eventsWeather.stream() + .map(Event::stringifyContent) + .toList()); + + // then + + // Assertions for "Hi" + assertFalse(eventsHi.isEmpty(), "eventsHi should not be empty"); + // Depending on the model and streaming behavior, the number of events can vary. + // If a single "Hello" is expected in one event: + // assertEquals(1, eventsHi.size(), "Expected 1 event for 'Hi'"); + // assertEquals("Hello", responseToHi, "Response to 'Hi' should be 'Hello'"); + // If "Hello" can be streamed in multiple parts: + assertTrue(responseToHi.trim().contains("Hello"), "Response to 'Hi' should be 'Hello'"); + + + // Assertions for "Tell me about qubits" + assertTrue(eventsQubit.size() > 1, "Expected multiple streaming events for 'qubit' question"); + assertTrue(responseToQubit.toLowerCase().contains("qubit"), "Response to 'qubit' should contain 'qubit'"); + assertTrue(responseToQubit.toLowerCase().contains("quantum"), "Response to 'qubit' should contain 'quantum'"); + assertTrue(responseToQubit.toLowerCase().contains("superposition"), "Response to 'qubit' should contain 'superposition'"); + + // Assertions for "What's the weather in Paris?" 
+ assertTrue(eventsWeather.size() > 2, "Expected multiple events for weather question (function call, response, text)"); + + // Check for function call + Optional functionCallEvent = eventsWeather.stream() + .filter(e -> !e.functionCalls().isEmpty()) + .findFirst(); + assertTrue(functionCallEvent.isPresent(), "Should contain a function call event for weather"); + FunctionCall fc = functionCallEvent.get().functionCalls().get(0); + assertEquals(Optional.of("getWeather"), fc.name(), "Function call name should be 'getWeather'"); + assertTrue(fc.args().isPresent() && "Paris".equals(fc.args().get().get("city")), "Function call should be for 'Paris'"); + + // Check for function response + Optional functionResponseEvent = eventsWeather.stream() + .filter(e -> !e.functionResponses().isEmpty()) + .findFirst(); + assertTrue(functionResponseEvent.isPresent(), "Should contain a function response event for weather"); + FunctionResponse fr = functionResponseEvent.get().functionResponses().get(0); + assertEquals(Optional.of("getWeather"), fr.name(), "Function response name should be 'getWeather'"); + assertTrue(fr.response().isPresent()); + Map weatherResponseMap = (Map) fr.response().get(); + assertEquals("Paris", weatherResponseMap.get("city")); + assertTrue(weatherResponseMap.get("forecast").toString().contains("beautiful and sunny")); + + // Check the final aggregated text response + // Consolidate text parts from events that are not function calls or responses + String finalWeatherTextResponse = eventsWeather.stream() + .filter(event -> event.functionCalls().isEmpty() && event.functionResponses().isEmpty() && event.content().isPresent() && event.content().get().text() != null) + .map(event -> event.content().get().text()) + .collect(java.util.stream.Collectors.joining()) + .trim(); + + assertTrue(finalWeatherTextResponse.contains("Paris"), "Final weather response should mention Paris"); + assertTrue(finalWeatherTextResponse.toLowerCase().contains("beautiful and sunny"), 
"Final weather response should mention 'beautiful and sunny'"); + assertTrue(finalWeatherTextResponse.contains("10"), "Final weather response should mention '10'"); + assertTrue(finalWeatherTextResponse.contains("24"), "Final weather response should mention '24'"); + + // You can also assert on the concatenated `responseToWeather` if it's meant to capture the full interaction text + assertTrue(responseToWeather.contains("Function Call") && responseToWeather.contains("getWeather") && responseToWeather.contains("Paris")); + assertTrue(responseToWeather.contains("Function Response") && responseToWeather.contains("beautiful and sunny weather")); + assertTrue(responseToWeather.contains("sunny")); + assertTrue(responseToWeather.contains("24")); + } +} diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java new file mode 100644 index 000000000..75ec6ac41 --- /dev/null +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java @@ -0,0 +1,590 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.adk.models.langchain4j; + +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.tools.FunctionTool; +import com.google.genai.types.*; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatModel; +import dev.langchain4j.model.chat.StreamingChatModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; +import io.reactivex.rxjava3.core.Flowable; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +class LangChain4jTest { + + private static final String MODEL_NAME = "test-model"; + + private ChatModel chatModel; + private StreamingChatModel streamingChatModel; + private LangChain4j langChain4j; + private LangChain4j streamingLangChain4j; + + @BeforeEach + void setUp() { + chatModel = mock(ChatModel.class); + streamingChatModel = mock(StreamingChatModel.class); + + langChain4j = new LangChain4j(chatModel, MODEL_NAME); + streamingLangChain4j = new LangChain4j(streamingChatModel, MODEL_NAME); + } + + @Test + @DisplayName("Should generate content using non-streaming chat model") + void testGenerateContentWithChatModel() { + // Given + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Hello")) 
+ )) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello, how can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + + // Then + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("Hello, how can I help you?"); + + // Verify the request conversion + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + assertThat(capturedRequest.messages()).hasSize(1); + assertThat(capturedRequest.messages().get(0)).isInstanceOf(UserMessage.class); + } + + @Test + @DisplayName("Should handle function calls in LLM responses") + void testGenerateContentWithFunctionCall() { + // Given + // Create a mock FunctionTool + final FunctionTool weatherTool = mock(FunctionTool.class); + when(weatherTool.name()).thenReturn("getWeather"); + when(weatherTool.description()).thenReturn("Get weather for a city"); + + // Create a mock FunctionDeclaration + final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class); + when(weatherTool.declaration()).thenReturn(Optional.of(functionDeclaration)); + + // Create a mock Schema + final Schema schema = mock(Schema.class); + when(functionDeclaration.parameters()).thenReturn(Optional.of(schema)); + + // Create a mock Type + final Type type = mock(Type.class); + when(schema.type()).thenReturn(Optional.of(type)); + when(type.knownEnum()).thenReturn(Type.Known.OBJECT); + + // Create a mock for schema properties and required fields + 
when(schema.properties()).thenReturn(Optional.of(Map.of("city", schema))); + when(schema.required()).thenReturn(Optional.of(List.of("city"))); + + // Create a real LlmRequest + // We'll use a real LlmRequest but we won't add any tools to it + // This is because we don't know the exact return type of LlmRequest.tools() + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .build(); + + // Mock the AI response with a function call + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + //language=json + .arguments("{\"city\":\"Paris\"}") + .build(); + + final List toolExecutionRequests = List.of(toolExecutionRequest); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(toolExecutionRequests) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final Flowable responseFlowable = langChain4j.generateContent(llmRequest, false); + final LlmResponse response = responseFlowable.blockingFirst(); + + // Then + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + final FunctionCall functionCall = parts.get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args()).isPresent(); + assertThat(functionCall.args().get()).containsEntry("city", "Paris"); + } + + @Test + @DisplayName("Should handle streaming responses correctly") + void testGenerateContentWithStreamingChatModel() { + // Given + final 
LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Hello")) + )) + .build(); + + // Create a list to collect the responses + final List responses = new ArrayList<>(); + + // Set up the mock to capture and store the handler + final StreamingChatResponseHandler[] handlerRef = new StreamingChatResponseHandler[1]; + + doAnswer(invocation -> { + // Store the handler for later use + handlerRef[0] = invocation.getArgument(1); + return null; + }).when(streamingChatModel).chat(any(ChatRequest.class), any(StreamingChatResponseHandler.class)); + + // When + final Flowable responseFlowable = streamingLangChain4j.generateContent(llmRequest, true); + + // Subscribe to the flowable to collect responses + final var disposable = responseFlowable.subscribe(responses::add); + + // Verify the streaming model was called + verify(streamingChatModel).chat(any(ChatRequest.class), any(StreamingChatResponseHandler.class)); + + // Get the captured handler + final StreamingChatResponseHandler handler = handlerRef[0]; + + // Simulate streaming responses + handler.onPartialResponse("Hello"); + handler.onPartialResponse(", how"); + handler.onPartialResponse(" can I help"); + handler.onPartialResponse(" you?"); + + // Simulate a function call in the complete response + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(toolExecutionRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + + // Simulate completion with a function call + handler.onCompleteResponse(chatResponse); + + // Then + assertThat(responses).hasSize(5); // 4 partial responses + 1 function call + + // Verify the partial responses + 
assertThat(responses.get(0).content().orElseThrow().text()).isEqualTo("Hello"); + assertThat(responses.get(1).content().orElseThrow().text()).isEqualTo(", how"); + assertThat(responses.get(2).content().orElseThrow().text()).isEqualTo(" can I help"); + assertThat(responses.get(3).content().orElseThrow().text()).isEqualTo(" you?"); + + // Verify the function call + assertThat(responses.get(4).content().orElseThrow().parts().orElseThrow()).hasSize(1); + assertThat(responses.get(4).content().orElseThrow().parts().orElseThrow().get(0).functionCall()).isPresent(); + final FunctionCall functionCall = responses.get(4).content().orElseThrow().parts().orElseThrow().get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args().orElseThrow()).containsEntry("city", "Paris"); + + disposable.dispose(); + } + + @Test + @DisplayName("Should pass configuration options to LangChain4j") + void testGenerateContentWithConfigOptions() { + // Given + final GenerateContentConfig config = GenerateContentConfig.builder() + .temperature(0.7f) + .topP(0.9f) + .topK(40f) + .maxOutputTokens(100) + .presencePenalty(0.5f) + .build(); + + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Hello")) + )) + .config(config) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + final AiMessage aiMessage = AiMessage.from("Hello, how can I help you?"); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final var llmResponse = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Assert the llmResponse + assertThat(llmResponse).isNotNull(); + assertThat(llmResponse.content()).isPresent(); + assertThat(llmResponse.content().get().text()).isEqualTo("Hello, how can I help you?"); + + // Assert the request configuration + final 
ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + assertThat(capturedRequest.temperature()).isCloseTo(0.7, offset(0.001)); + assertThat(capturedRequest.topP()).isCloseTo(0.9, offset(0.001)); + assertThat(capturedRequest.topK()).isEqualTo(40); + assertThat(capturedRequest.maxOutputTokens()).isEqualTo(100); + assertThat(capturedRequest.presencePenalty()).isCloseTo(0.5, offset(0.001)); + } + + @Test + @DisplayName("Should throw UnsupportedOperationException when connect is called") + void testConnectThrowsUnsupportedOperationException() { + // Given + final LlmRequest llmRequest = LlmRequest.builder().build(); + + // When/Then + assertThatThrownBy(() -> langChain4j.connect(llmRequest)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("Live connection is not supported for LangChain4j models."); + } + + @Test + @DisplayName("Should handle tool calling in LLM responses") + void testGenerateContentWithToolCalling() { + // Given + // Create a mock ChatResponse with a tool execution request + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(toolExecutionRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // Create a LlmRequest with a user message + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .build(); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response 
contains the expected function call + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + final FunctionCall functionCall = parts.get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args()).isPresent(); + assertThat(functionCall.args().get()).containsEntry("city", "Paris"); + + // Verify the ChatModel was called + verify(chatModel).chat(any(ChatRequest.class)); + } + + + @Test + @DisplayName("Should set ToolChoice to AUTO when FunctionCallingConfig mode is AUTO") + void testGenerateContentWithAutoToolChoice() { + // Given + // Create a FunctionCallingConfig with mode AUTO + final FunctionCallingConfig functionCallingConfig = mock(FunctionCallingConfig.class); + final FunctionCallingConfigMode functionMode = mock(FunctionCallingConfigMode.class); + + when(functionCallingConfig.mode()).thenReturn(Optional.of(functionMode)); + when(functionMode.knownEnum()).thenReturn(FunctionCallingConfigMode.Known.AUTO); + + // Create a ToolConfig with the FunctionCallingConfig + final ToolConfig toolConfig = mock(ToolConfig.class); + when(toolConfig.functionCallingConfig()).thenReturn(Optional.of(functionCallingConfig)); + + // Create a GenerateContentConfig with the ToolConfig + final GenerateContentConfig config = GenerateContentConfig.builder() + .toolConfig(toolConfig) + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .config(config) + .build(); + + // Mock the AI response + final AiMessage aiMessage = AiMessage.from("It's sunny in Paris"); + + final ChatResponse chatResponse = 
mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("It's sunny in Paris"); + + // Verify the request was built correctly with the tool config + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool choice is AUTO + assertThat(capturedRequest.toolChoice()).isEqualTo(dev.langchain4j.model.chat.request.ToolChoice.AUTO); + } + + @Test + @DisplayName("Should set ToolChoice to REQUIRED when FunctionCallingConfig mode is ANY") + void testGenerateContentWithAnyToolChoice() { + // Given + // Create a FunctionCallingConfig with mode ANY and allowed function names + final FunctionCallingConfig functionCallingConfig = mock(FunctionCallingConfig.class); + final FunctionCallingConfigMode functionMode = mock(FunctionCallingConfigMode.class); + + when(functionCallingConfig.mode()).thenReturn(Optional.of(functionMode)); + when(functionMode.knownEnum()).thenReturn(FunctionCallingConfigMode.Known.ANY); + when(functionCallingConfig.allowedFunctionNames()).thenReturn(Optional.of(List.of("getWeather"))); + + // Create a ToolConfig with the FunctionCallingConfig + final ToolConfig toolConfig = mock(ToolConfig.class); + when(toolConfig.functionCallingConfig()).thenReturn(Optional.of(functionCallingConfig)); + + // Create a GenerateContentConfig with the ToolConfig + final GenerateContentConfig config = GenerateContentConfig.builder() + .toolConfig(toolConfig) + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = 
LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .config(config) + .build(); + + // Mock the AI response with a function call + final ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder() + .id("123") + .name("getWeather") + .arguments("{\"city\":\"Paris\"}") + .build(); + + final AiMessage aiMessage = AiMessage.builder() + .text("") + .toolExecutionRequests(List.of(toolExecutionRequest)) + .build(); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains the expected function call + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isPresent(); + + final List parts = response.content().get().parts().orElseThrow(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).functionCall()).isPresent(); + + final FunctionCall functionCall = parts.get(0).functionCall().orElseThrow(); + assertThat(functionCall.name()).isEqualTo(Optional.of("getWeather")); + assertThat(functionCall.args()).isPresent(); + assertThat(functionCall.args().get()).containsEntry("city", "Paris"); + + // Verify the request was built correctly with the tool config + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool choice is REQUIRED (mapped from ANY) + assertThat(capturedRequest.toolChoice()).isEqualTo(dev.langchain4j.model.chat.request.ToolChoice.REQUIRED); + } + + @Test + @DisplayName("Should disable tool calling when FunctionCallingConfig mode is NONE") + void 
testGenerateContentWithNoneToolChoice() { + // Given + // Create a FunctionCallingConfig with mode NONE + final FunctionCallingConfig functionCallingConfig = mock(FunctionCallingConfig.class); + final FunctionCallingConfigMode functionMode = mock(FunctionCallingConfigMode.class); + + when(functionCallingConfig.mode()).thenReturn(Optional.of(functionMode)); + when(functionMode.knownEnum()).thenReturn(FunctionCallingConfigMode.Known.NONE); + + // Create a ToolConfig with the FunctionCallingConfig + final ToolConfig toolConfig = mock(ToolConfig.class); + when(toolConfig.functionCallingConfig()).thenReturn(Optional.of(functionCallingConfig)); + + // Create a GenerateContentConfig with the ToolConfig + final GenerateContentConfig config = GenerateContentConfig.builder() + .toolConfig(toolConfig) + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("What's the weather in Paris?")) + )) + .config(config) + .build(); + + // Mock the AI response with text (no function call) + final AiMessage aiMessage = AiMessage.from("It's sunny in Paris"); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains text (no function call) + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo("It's sunny in Paris"); + + // Verify the request was built correctly with the tool config + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify tool specifications are empty 
+ assertThat(capturedRequest.toolSpecifications()).isEmpty(); + } + + @Test + @DisplayName("Should handle structured responses with JSON schema") + void testGenerateContentWithStructuredResponseJsonSchema() { + // Given + // Create a JSON schema for the structured response + final JsonObjectSchema responseSchema = JsonObjectSchema.builder() + .addProperty("name", JsonStringSchema.builder().build()) + .addProperty("age", JsonStringSchema.builder().build()) + .addProperty("city", JsonStringSchema.builder().build()) + .build(); + + // Create a GenerateContentConfig without responseSchema + final GenerateContentConfig config = GenerateContentConfig.builder() + .build(); + + // Create a LlmRequest with the config + final LlmRequest llmRequest = LlmRequest.builder() + .contents(List.of( + Content.fromParts(Part.fromText("Give me information about John Doe")) + )) + .config(config) + .build(); + + // Mock the AI response with structured JSON data + final String jsonResponse = """ + { + "name": "John Doe", + "age": "30", + "city": "New York" + } + """; + final AiMessage aiMessage = AiMessage.from(jsonResponse); + + final ChatResponse chatResponse = mock(ChatResponse.class); + when(chatResponse.aiMessage()).thenReturn(aiMessage); + when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse); + + // When + final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst(); + + // Then + // Verify the response contains the expected JSON data + assertThat(response).isNotNull(); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().text()).isEqualTo(jsonResponse); + + // Verify the request was built correctly + final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ChatRequest.class); + verify(chatModel).chat(requestCaptor.capture()); + final ChatRequest capturedRequest = requestCaptor.getValue(); + + // Verify the request contains the expected messages + assertThat(capturedRequest.messages()).hasSize(1); + 
assertThat(capturedRequest.messages().get(0)).isInstanceOf(UserMessage.class); + final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0); + assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe"); + } +} diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java new file mode 100644 index 000000000..bfba6d9fb --- /dev/null +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/RunLoop.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.models.langchain4j; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.events.Event; +import com.google.adk.runner.InMemoryRunner; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.Session; +import com.google.genai.types.Content; +import com.google.genai.types.Part; + +import java.util.ArrayList; +import java.util.List; + +public class RunLoop { + public static List askAgent(BaseAgent agent, Object... messages) { + return runLoop(agent, false, messages); + } + + public static List askAgentStreaming(BaseAgent agent, Object... messages) { + return runLoop(agent, true, messages); + } + + public static List runLoop(BaseAgent agent, boolean streaming, Object... 
messages) { + ArrayList allEvents = new ArrayList<>(); + + Runner runner = new InMemoryRunner(agent, agent.name()); + Session session = runner.sessionService().createSession(agent.name(), "user132").blockingGet(); + + for (Object message : messages) { + Content messageContent = null; + if (message instanceof String) { + messageContent = Content.fromParts(Part.fromText((String) message)); + } else if (message instanceof Part) { + messageContent = Content.fromParts((Part) message); + } else if (message instanceof Content) { + messageContent = (Content) message; + } + allEvents.addAll( + runner.runAsync(session, messageContent, + RunConfig.builder() + .setStreamingMode(streaming ? RunConfig.StreamingMode.SSE : RunConfig.StreamingMode.NONE) + .build()) + .blockingStream() + .toList() + ); + } + + return allEvents; + } +} diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java new file mode 100644 index 000000000..a6f92b78e --- /dev/null +++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/ToolExample.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.google.adk.models.langchain4j; + +import com.google.adk.tools.Annotations; + +import java.util.Map; + +public class ToolExample { + @Annotations.Schema(description = "Function to get the weather forecast for a given city") + public static Map getWeather( + @Annotations.Schema(name = "city", description = "The city to get the weather forecast for") + String city) { + + return Map.of( + "city", city, + "forecast", "a beautiful and sunny weather", + "temperature", "from 10°C in the morning up to 24°C in the afternoon" + ); + } +} diff --git a/pom.xml b/pom.xml index 92bced1c7..4e2e73733 100644 --- a/pom.xml +++ b/pom.xml @@ -28,6 +28,7 @@ core dev + contrib/langchain4j