Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/google/adk/models/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from anthropic import types as anthropic_types
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override

from ..utils._google_client_headers import get_tracking_headers
Expand Down Expand Up @@ -359,6 +360,11 @@ class AnthropicLlm(BaseLlm):
model: str = "claude-sonnet-4-20250514"
max_tokens: int = 8192

client: Optional[Union[AsyncAnthropic, AsyncAnthropicVertex]] = Field(
default=None, exclude=True
)
"""An optional pre-configured Anthropic client."""

@classmethod
@override
def supported_models(cls) -> list[str]:
Expand Down Expand Up @@ -495,6 +501,8 @@ async def _generate_content_streaming(

@cached_property
def _anthropic_client(self) -> AsyncAnthropic:
  """Return the Anthropic API client, preferring an injected one.

  Falls back to constructing a default ``AsyncAnthropic`` (configured from
  environment variables) when no client was supplied. The result is cached
  for the lifetime of this instance via ``cached_property``.
  """
  # A user-supplied client wins; otherwise lazily build the default.
  return self.client or AsyncAnthropic()


Expand Down
7 changes: 6 additions & 1 deletion src/google/adk/models/apigee_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
custom_headers: dict[str, str] | None = None,
retry_options: Optional[types.HttpRetryOptions] = None,
api_type: ApiType | str = ApiType.UNKNOWN,
client: Optional['Client'] = None,
):
"""Initializes the Apigee LLM backend.

Expand Down Expand Up @@ -121,9 +122,10 @@ def __init__(
authorization headers in Vertex AI and Gemini API calls.
retry_options: Allow google-genai to retry failed responses.
api_type: The type of API to use. One of `ApiType` or string.
client: An optional pre-configured google-genai Client.
""" # fmt: skip

super().__init__(model=model, retry_options=retry_options)
super().__init__(model=model, retry_options=retry_options, client=client)
# Validate the model string. Create a helper method to validate the model
# string.
if not _validate_model_string(model):
Expand Down Expand Up @@ -220,6 +222,9 @@ def api_client(self) -> Client:
Returns:
The api client.
"""
if self.client:
return self.client

from google.genai import Client

kwargs_for_http_options = {}
Expand Down
14 changes: 14 additions & 0 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from google.genai import types
from google.genai.errors import ClientError
from pydantic import Field
from typing_extensions import override

from ..utils._google_client_headers import get_tracking_headers
Expand Down Expand Up @@ -91,6 +92,13 @@ class Gemini(BaseLlm):

model: str = 'gemini-2.5-flash'

client: Optional[Any] = Field(default=None, exclude=True)
"""An optional pre-configured google-genai Client.

When provided, this client will be used for all API calls instead of
constructing a new one from environment variables or other attributes.
"""

base_url: Optional[str] = None
"""The base URL for the AI platform service endpoint."""

Expand Down Expand Up @@ -302,6 +310,9 @@ def api_client(self) -> Client:
Returns:
The api client.
"""
if self.client:
return self.client

from google.genai import Client

return Client(
Expand Down Expand Up @@ -334,6 +345,9 @@ def _live_api_version(self) -> str:

@cached_property
def _live_api_client(self) -> Client:
if self.client:
return self.client

from google.genai import Client

return Client(
Expand Down
110 changes: 110 additions & 0 deletions tests/unittests/models/test_custom_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from unittest import mock
from google.genai import Client
from anthropic import AsyncAnthropic
from google.adk.models.google_llm import Gemini
from google.adk.models.anthropic_llm import AnthropicLlm
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from google.genai.types import Content, Part

def test_gemini_custom_client():
  """Gemini should use the injected client instead of constructing one."""
  injected = mock.MagicMock(spec=Client)
  llm = Gemini(model="gemini-1.5-flash", client=injected)

  # Both the first access and a repeated access (api_client is a
  # cached_property) must yield the exact injected instance.
  first_access = llm.api_client
  second_access = llm.api_client
  assert first_access is injected
  assert second_access is injected

def test_anthropic_custom_client():
  """AnthropicLlm should use the injected client as its API client."""
  injected = mock.MagicMock(spec=AsyncAnthropic)
  llm = AnthropicLlm(model="claude-3-5-sonnet-20241022", client=injected)

  # The cached property must hand back the injected instance untouched.
  assert llm._anthropic_client is injected

@pytest.mark.asyncio
async def test_gemini_uses_custom_client_in_call():
  """Verify that Gemini calls route through the provided custom client.

  Uses ``mock.AsyncMock`` for the async ``generate_content`` method: a
  hand-built coroutine object assigned via ``return_value`` is single-use
  (it can only be awaited once) and leaves a "coroutine was never awaited"
  warning if the code path changes; AsyncMock returns a fresh awaitable on
  every call and supports await-specific assertions.
  """
  mock_client = mock.MagicMock(spec=Client)

  gemini = Gemini(model="gemini-1.5-flash", client=mock_client)

  request = LlmRequest(
      model="gemini-1.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
  )

  # Canned non-streaming model response for the mocked call.
  mock_response = types.GenerateContentResponse(
      candidates=[
          types.Candidate(
              content=Content(
                  role="model", parts=[Part.from_text(text="Hello")]
              ),
              finish_reason=types.FinishReason.STOP,
          )
      ]
  )

  mock_client.aio.models.generate_content = mock.AsyncMock(
      return_value=mock_response
  )

  # stream=False keeps the mock simple: exactly one response is yielded.
  responses = [
      r async for r in gemini.generate_content_async(request, stream=False)
  ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hello"
  mock_client.aio.models.generate_content.assert_awaited_once()

@pytest.mark.asyncio
async def test_anthropic_uses_custom_client_in_call():
  """Verify that AnthropicLlm calls route through the provided custom client.

  Uses ``mock.AsyncMock`` for the async ``messages.create`` method: a
  hand-built coroutine object assigned via ``return_value`` is single-use
  (it can only be awaited once) and leaves a "coroutine was never awaited"
  warning if the code path changes; AsyncMock returns a fresh awaitable on
  every call and supports await-specific assertions.
  """
  from anthropic import types as anthropic_types

  mock_client = mock.MagicMock(spec=AsyncAnthropic)

  anthropic_llm = AnthropicLlm(
      model="claude-3-5-sonnet-20241022", client=mock_client
  )

  request = LlmRequest(
      model="claude-3-5-sonnet-20241022",
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
  )

  # Canned Anthropic Message for the mocked call.
  mock_response = anthropic_types.Message(
      id="msg_test",
      content=[anthropic_types.TextBlock(text="Hello", type="text")],
      model="claude-3-5-sonnet-20241022",
      role="assistant",
      stop_reason="end_turn",
      type="message",
      usage=anthropic_types.Usage(input_tokens=1, output_tokens=1),
  )

  mock_client.messages.create = mock.AsyncMock(return_value=mock_response)

  responses = [
      r
      async for r in anthropic_llm.generate_content_async(
          request, stream=False
      )
  ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hello"
  mock_client.messages.create.assert_awaited_once()