diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java index 3983b08a5..ab10a6521 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/MessageConverter.java @@ -22,6 +22,7 @@ import com.google.adk.models.LlmResponse; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; +import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; import java.net.URI; import java.util.ArrayList; @@ -32,6 +33,8 @@ import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.EmptyUsage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.ChatOptions; @@ -318,11 +321,27 @@ public LlmResponse toLlmResponse(ChatResponse chatResponse, boolean isStreaming) boolean isPartial = isStreaming && isPartialResponse(assistantMessage); boolean isTurnComplete = !isStreaming || isTurnCompleteResponse(chatResponse); - return LlmResponse.builder() - .content(content) - .partial(isPartial) - .turnComplete(isTurnComplete) - .build(); + LlmResponse.Builder responseBuilder = + LlmResponse.builder().content(content).partial(isPartial).turnComplete(isTurnComplete); + + if (chatResponse.getMetadata() != null + && chatResponse.getMetadata().getUsage() != null + && !(chatResponse.getMetadata().getUsage() instanceof EmptyUsage)) { + Usage springUsage = chatResponse.getMetadata().getUsage(); + + GenerateContentResponseUsageMetadata adkUsage = + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(nullSafeInt(springUsage.getPromptTokens())) + .candidatesTokenCount(nullSafeInt(springUsage.getCompletionTokens())) + .totalTokenCount(nullSafeInt(springUsage.getTotalTokens())) + .build(); + responseBuilder.usageMetadata(adkUsage); + } + return responseBuilder.build(); + } + + private int nullSafeInt(Integer value) { + return value != null ? value.intValue() : 0; } /** Determines if an assistant message represents a partial response in streaming. */ diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java index b861a71f2..455190109 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/MessageConverterTest.java @@ -33,6 +33,8 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; @@ -237,6 +239,76 @@ void testToLlmResponseFromChatResponseWithToolCalls() { assertThat(functionCallPart.functionCall().get().id()).contains("call_123"); } + @Test + void testUsageMetadataShouldBeEmptyWhenSpringAiMetadataIsNull() { + MessageConverter converter = new MessageConverter(new ObjectMapper()); + AssistantMessage assistantMessage = new AssistantMessage("intermediate chunk"); + Generation generation = new Generation(assistantMessage); + + ChatResponse chatResponse = new ChatResponse(List.of(generation), null); + + LlmResponse llmResponse = converter.toLlmResponse(chatResponse, true); + + assertThat(llmResponse.usageMetadata().isEmpty()); + } + + @Test + void testUsageMetadataShouldBeEmptyWhenSpringAiUsageIsNull() { + MessageConverter converter = new MessageConverter(new ObjectMapper()); + AssistantMessage assistantMessage = new AssistantMessage("intermediate chunk"); + Generation generation = new Generation(assistantMessage); + + ChatResponseMetadata metadata = ChatResponseMetadata.builder().id("resp-no-usage").build(); + + ChatResponse chatResponse = new ChatResponse(List.of(generation), metadata); + + LlmResponse llmResponse = converter.toLlmResponse(chatResponse, true); + + assertThat(llmResponse.usageMetadata().isEmpty()); + } + + @Test + void testUsageMetadataShouldDefaultToZeroWhenSpringAiTokensAreNull() { + MessageConverter converter = new MessageConverter(new ObjectMapper()); + AssistantMessage assistantMessage = new AssistantMessage("final chunk"); + Generation generation = new Generation(assistantMessage); + + // Anonymous implementation to simulate incomplete provider data where some token counts are + // null + DefaultUsage incompleteUsage = new DefaultUsage(null, null, 42); + ChatResponseMetadata metadata = + ChatResponseMetadata.builder().id("resp-partial-tokens").usage(incompleteUsage).build(); + + ChatResponse chatResponse = new ChatResponse(List.of(generation), metadata); + + LlmResponse llmResponse = converter.toLlmResponse(chatResponse, false); + + assertThat(llmResponse.usageMetadata().isPresent()); + assertThat(llmResponse.usageMetadata().get().promptTokenCount().orElse(-1)).isEqualTo(0); + assertThat(llmResponse.usageMetadata().get().candidatesTokenCount().orElse(-1)).isEqualTo(0); + assertThat(llmResponse.usageMetadata().get().totalTokenCount().orElse(-1)).isEqualTo(42); + } + + @Test + void testUsageMetadataShouldMapCorrectlyWhenAllFieldsArePresent() { + MessageConverter converter = new MessageConverter(new ObjectMapper()); + AssistantMessage assistantMessage = new AssistantMessage("final chunk"); + Generation generation = new Generation(assistantMessage); + + DefaultUsage completeUsage = new DefaultUsage(15, 25, 40); + ChatResponseMetadata metadata = + ChatResponseMetadata.builder().id("resp-happy-path").usage(completeUsage).build(); + + ChatResponse chatResponse = new ChatResponse(List.of(generation), metadata); + + LlmResponse llmResponse = converter.toLlmResponse(chatResponse, false); + + assertThat(llmResponse.usageMetadata().isPresent()); + assertThat(llmResponse.usageMetadata().get().promptTokenCount().orElse(-1)).isEqualTo(15); + assertThat(llmResponse.usageMetadata().get().candidatesTokenCount().orElse(-1)).isEqualTo(25); + assertThat(llmResponse.usageMetadata().get().totalTokenCount().orElse(-1)).isEqualTo(40); + } + @Test void testToolCallIdPreservedInConversion() { // Create AssistantMessage with tool call including ID