from __future__ import annotations
import asyncio
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    Generation,
    LLMResult,
)
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_google_vertexai._anthropic_utils import _format_messages_anthropic
from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon


class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM):
    """Large language models served from Vertex AI Model Garden."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        if self.single_example_per_request and len(instances) > 1:
            results = []
            for instance in instances:
                response = self.client.predict(
                    endpoint=self.endpoint_path, instances=[instance]
                )
                results.append(self._parse_prediction(response.predictions[0]))
            return LLMResult(
                generations=[[Generation(text=result)] for result in results]
            )
        response = self.client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        if self.single_example_per_request and len(instances) > 1:
            responses = []
            for instance in instances:
                responses.append(
                    self.async_client.predict(
                        endpoint=self.endpoint_path, instances=[instance]
                    )
                )
            responses = await asyncio.gather(*responses)
            return LLMResult(
                generations=[
                    [Generation(text=self._parse_prediction(response.predictions[0]))]
                    for response in responses
                ]
            )
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(response)
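
# A minimal usage sketch for VertexAIModelGarden. The project, location, and
# endpoint_id values below are hypothetical placeholders for an endpoint that
# has already been deployed from Model Garden; running it requires Google
# Cloud credentials.
#
#     llm = VertexAIModelGarden(
#         project="my-project",
#         location="us-central1",
#         endpoint_id="1234567890",
#     )
#     print(llm.invoke("What is the meaning of life?"))
#
# When single_example_per_request is set and more than one prompt is passed,
# each prompt is sent to the endpoint as its own predict() call; otherwise all
# prompts are batched into a single request.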


class ChatAnthropicVertex(_VertexAICommon, BaseChatModel):
    """Anthropic Claude chat models served through Vertex AI."""

    async_client: Any = None  #: :meta private:
    model_name: Optional[str] = Field(default=None, alias="model")  # type: ignore[assignment]
    "Underlying model name."
    max_output_tokens: int = Field(default=1024, alias="max_tokens")

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        from anthropic import (  # type: ignore[import-not-found]
            AnthropicVertex,
            AsyncAnthropicVertex,
        )

        values["client"] = AnthropicVertex(
            project_id=values["project"],
            region=values["location"],
            max_retries=values["max_retries"],
        )
        values["async_client"] = AsyncAnthropicVertex(
            project_id=values["project"],
            region=values["location"],
            max_retries=values["max_retries"],
        )
        return values

    @property
    def _default_params(self):
        return {
            "model": self.model_name,
            "max_tokens": self.max_output_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        system_message, formatted_messages = _format_messages_anthropic(messages)
        params = self._default_params
        params.update(kwargs)
        if kwargs.get("model_name"):
            params["model"] = params["model_name"]
        if kwargs.get("model"):
            params["model"] = kwargs["model"]
        params.pop("model_name", None)
        params.update(
            {
                "system": system_message,
                "messages": formatted_messages,
                "stop_sequences": stop,
            }
        )
        return {k: v for k, v in params.items() if v is not None}
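
    # Illustrative sketch of the dict _format_params returns (values are
    # hypothetical). For a call such as
    #     self._format_params(messages=[HumanMessage(content="Hi")], stop=["\n\n"])
    # on an instance configured with model_name="claude-3-sonnet@20240229",
    # the result would look roughly like:
    #     {
    #         "model": "claude-3-sonnet@20240229",
    #         "max_tokens": 1024,
    #         "messages": [...],  # Anthropic-style role/content dicts
    #         "stop_sequences": ["\n\n"],
    #     }
    # A "system" entry appears only when a SystemMessage was supplied, and any
    # None-valued entries (e.g. temperature, top_k, top_p left unset) are
    # dropped by the final dict comprehension.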

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        data_dict = data.model_dump()
        content = data_dict["content"]
        llm_output = {
            k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
        }
        if len(content) == 1 and content[0]["type"] == "text":
            msg = AIMessage(content=content[0]["text"])
        else:
            msg = AIMessage(content=content)
        return ChatResult(
            generations=[ChatGeneration(message=msg)],
            llm_output=llm_output,
        )
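
    # Illustrative mapping (hypothetical values): data.model_dump() for a
    # simple completion looks roughly like
    #     {
    #         "id": "msg_...",
    #         "type": "message",
    #         "role": "assistant",
    #         "content": [{"type": "text", "text": "Hello!"}],
    #         "model": "claude-3-sonnet@20240229",
    #         "stop_reason": "end_turn",
    #         "usage": {"input_tokens": 10, "output_tokens": 3},
    #     }
    # which _format_output turns into AIMessage(content="Hello!"), with the
    # remaining keys (id, model, stop_reason, usage, ...) preserved in
    # llm_output.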

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        data = self.client.messages.create(**params)
        return self._format_output(data, **kwargs)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        data = await self.async_client.messages.create(**params)
        return self._format_output(data, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat-vertexai"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        with self.client.messages.stream(**params) as stream:
            for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        async with self.async_client.messages.stream(**params) as stream:
            async for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    await run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk
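

# A minimal, hedged usage sketch for ChatAnthropicVertex. The project ID,
# region, and model name below are placeholders; running it requires the
# ``anthropic`` package and Google Cloud credentials with access to Claude
# models on Vertex AI.
if __name__ == "__main__":
    chat = ChatAnthropicVertex(
        project="my-project",  # placeholder project ID
        location="us-east5",  # a region where Claude models are served
        model="claude-3-sonnet@20240229",  # placeholder model name
    )
    # Non-streaming: _generate sends a single messages.create call and returns
    # an AIMessage wrapped in a ChatResult.
    print(chat.invoke("Tell me a joke about bears.").content)
    # Streaming: _stream yields ChatGenerationChunks as text arrives.
    for chunk in chat.stream("Tell me a joke about bears."):
        print(chunk.content, end="", flush=True)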