"""Chat Model Components Derived from ChatModel/NVIDIA"""
from __future__ import annotations
import base64
import io
import logging
import os
import sys
import urllib.parse
import warnings
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
)
import requests
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
BaseMessage,
ChatMessage,
ChatMessageChunk,
)
from langchain_core.outputs import (
ChatGeneration,
ChatGenerationChunk,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable
from langchain_core.runnables.config import run_in_executor
from langchain_core.tools import BaseTool
from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints
from langchain_nvidia_ai_endpoints._statics import MODEL_SPECS
# Either flavour of LLM-run callback manager (async or sync).
_CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
# A schema given either as a plain dict or as a pydantic model class.
_DictOrPydanticClass = Union[Dict[str, Any], Type[BaseModel]]
# A value given either as a plain dict or as a pydantic model instance.
_DictOrPydantic = Union[Dict, BaseModel]

# Pillow is optional; `has_pillow` records whether `PIL.Image` could be imported.
try:
    import PIL.Image
except ImportError:
    has_pillow = False
else:
    has_pillow = True

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _is_url(s: str) -> bool:
try:
result = urllib.parse.urlparse(s)
return all([result.scheme, result.netloc])
except Exception as e:
logger.debug(f"Unable to parse URL: {e}")
return False
def _resize_image(img_data: bytes, max_dim: int = 1024) -> str:
    """Re-encode an image as JPEG, shrunk to fit ``max_dim``, and return base64.

    Args:
        img_data: Raw encoded image bytes.
        max_dim: Upper bound, in pixels, for the longer side of the output.

    Returns:
        Base64 text of the JPEG-encoded result. When Pillow is not installed
        a notice is printed and the base64 of ``img_data`` is returned
        unchanged.
    """
    if not has_pillow:
        print(  # noqa: T201
            "Pillow is required to resize images down to reasonable scale."
            " Please install it using `pip install pillow`."
            " For now, not resizing; may cause NVIDIA API to fail."
        )
        return base64.b64encode(img_data).decode("utf-8")

    image = PIL.Image.open(io.BytesIO(img_data))
    # Pillow's JPEG writer rejects modes carrying alpha or a palette
    # (e.g. RGBA, LA, P), so normalise those to RGB before saving.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Cap the factor at 1 so an image already within bounds is never enlarged.
    scale = min(1.0, max_dim / max(image.size))
    # Clamp to one pixel: int() truncation could otherwise yield a zero side
    # for very elongated images.
    new_w = max(1, int(image.size[0] * scale))
    new_h = max(1, int(image.size[1] * scale))
    resized_image = image.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)

    output_buffer = io.BytesIO()
    resized_image.save(output_buffer, format="JPEG")
    return base64.b64encode(output_buffer.getvalue()).decode("utf-8")
def _url_to_b64_string(image_source: str) -> str:
b64_template = "data:image/png;base64,{b64_string}"
try:
if _is_url(image_source):
response = requests.get(image_source)
response.raise_for_status()
encoded = base64.b64encode(response.content).decode("utf-8")
if sys.getsizeof(encoded) > 200000:
## (VK) Temporary fix. NVIDIA API has a limit of 250KB for the input.
encoded = _resize_image(response.content)
return b64_template.format(b64_string=encoded)
elif image_source.startswith("data:image"):
return image_source
elif os.path.exists(image_source):
with open(image_source, "rb") as f:
encoded = base64.b64encode(f.read()).decode("utf-8")
return b64_template.format(b64_string=encoded)
else:
raise ValueError(
"The provided string is not a valid URL, base64, or file path."
)
except Exception as e:
raise ValueError(f"Unable to process the provided image source: {e}")
class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, BaseChatModel):
    """NVIDIA chat model.

    Example:
        .. code-block:: python

            from langchain_nvidia_ai_endpoints import ChatNVIDIA

            model = ChatNVIDIA(model="llama2_13b")
            response = model.invoke("Hello")
    """

    _default_model: str = "ai-mixtral-8x7b-instruct"
    infer_endpoint: str = Field("{base_url}/chat/completions")
    model: str = Field(_default_model, description="Name of the model to invoke")
    temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
    max_tokens: Optional[int] = Field(description="Maximum # of tokens to generate")
    top_p: Optional[float] = Field(description="Top-p for distribution sampling")
    seed: Optional[int] = Field(description="The seed for deterministic results")
    bad: Optional[Sequence[str]] = Field(description="Bad words to avoid (cased)")
    stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
    labels: Optional[Dict[str, float]] = Field(description="Steering parameters")
    streaming: bool = Field(True)

    @validator("model")
    def aifm_deprecated(cls, value: str) -> str:
        """All AI Foundation Models are deprecated; use API Catalog models instead.

        Emits a ``DeprecationWarning`` when ``value`` (or its ``playground_``
        prefixed form) is listed in ``MODEL_SPECS`` with ``api_type`` equal to
        ``"aifm"``. The value itself is always returned unchanged.
        """
        for model in [value, f"playground_{value}"]:
            if model in MODEL_SPECS and MODEL_SPECS[model].get("api_type") == "aifm":
                alternative = MODEL_SPECS[model].get(
                    "alternative", ChatNVIDIA._default_model
                )
                warnings.warn(
                    f"{value} is deprecated. Try {alternative} instead.",
                    DeprecationWarning,
                )
        return value

    @property
    def _llm_type(self) -> str:
        """Return type of NVIDIA AI Foundation Model Interface."""
        return "chat-nvidia-ai-playground"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one non-streaming request and wrap the reply in a ``ChatResult``.

        The raw response dict is also exposed as ``llm_output``.
        """
        responses = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        self._set_callback_out(responses, run_manager)
        message = ChatMessage(**self.custom_postprocess(responses))
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation], llm_output=responses)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async variant: delegates to ``_generate`` through ``run_in_executor``."""
        return await run_in_executor(
            None,
            self._generate,
            messages,
            stop=stop,
            # _generate expects the synchronous flavour of the manager.
            run_manager=run_manager.get_sync() if run_manager else None,
            **kwargs,
        )

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> dict:
        """Invoke on a single list of chat messages."""
        inputs = self.custom_preprocess(messages)
        responses = self.get_generation(inputs=inputs, stop=stop, **kwargs)
        return responses

    def _get_filled_chunk(self, **kwargs: Any) -> ChatGenerationChunk:
        """Fill the generation chunk."""
        return ChatGenerationChunk(message=ChatMessageChunk(**kwargs))

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Allows streaming to model!"""
        inputs = self.custom_preprocess(messages)
        for response in self.get_stream(inputs=inputs, stop=stop, **kwargs):
            self._set_callback_out(response, run_manager)
            chunk = self._get_filled_chunk(**self.custom_postprocess(response))
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``; yields one chunk per response."""
        inputs = self.custom_preprocess(messages)
        async for response in self.get_astream(inputs=inputs, stop=stop, **kwargs):
            self._set_callback_out(response, run_manager)
            chunk = self._get_filled_chunk(**self.custom_postprocess(response))
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def _set_callback_out(
        self,
        result: dict,
        run_manager: Optional[_CallbackManager],
    ) -> None:
        """Stamp ``model_name`` onto ``result`` and publish it to handlers.

        ``result`` is mutated in place. Only handlers that already have an
        ``llm_output`` attribute receive the result.
        """
        result.update({"model_name": self.model})
        if run_manager:
            for cb in run_manager.handlers:
                if hasattr(cb, "llm_output"):
                    cb.llm_output = result

    def custom_preprocess(
        self, msg_list: Sequence[BaseMessage]
    ) -> List[Dict[str, str]]:
        """Convert each message into a ``{"role", "content"}`` dict."""
        return [self.preprocess_msg(m) for m in msg_list]

    def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
        """Flatten message content into a single string.

        A plain string is returned as is. For a list, string parts and
        ``{"type": "text"}`` parts are concatenated, and ``{"type": "image_url"}``
        parts are rendered as ``<img src="..." />`` tags with the image
        converted by ``_url_to_b64_string``.

        Raises:
            ValueError: For a mapping part without a ``type`` key, with an
                unknown ``type``, or with an ``image_url`` dict lacking ``url``.
        """
        if isinstance(content, str):
            return content
        string_array: list = []
        for part in content:
            if isinstance(part, str):
                string_array.append(part)
            elif isinstance(part, Mapping):
                # OpenAI Format
                if "type" in part:
                    if part["type"] == "text":
                        string_array.append(str(part["text"]))
                    elif part["type"] == "image_url":
                        img_url = part["image_url"]
                        # The URL may be given directly or nested under "url".
                        if isinstance(img_url, dict):
                            if "url" not in img_url:
                                raise ValueError(
                                    f"Unrecognized message image format: {img_url}"
                                )
                            img_url = img_url["url"]
                        b64_string = _url_to_b64_string(img_url)
                        string_array.append(f'<img src="{b64_string}" />')
                    else:
                        raise ValueError(
                            f"Unrecognized message part type: {part['type']}"
                        )
                else:
                    raise ValueError(f"Unrecognized message part format: {part}")
        return "".join(string_array)

    def preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]:
        """Convert one ``BaseMessage`` into a ``{"role", "content"}`` dict.

        ``ChatMessage`` supplies its own ``role``; other messages use their
        ``type``. The names ``"ai"`` and ``"human"`` are mapped to
        ``"assistant"`` and ``"user"``.

        Raises:
            ValueError: If ``msg`` is not a ``BaseMessage``.
        """
        if isinstance(msg, BaseMessage):
            role_convert = {"ai": "assistant", "human": "user"}
            if isinstance(msg, ChatMessage):
                role = msg.role
            else:
                role = msg.type
            role = role_convert.get(role, role)
            content = self._process_content(msg.content)
            return {"role": role, "content": content}
        raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")

    def custom_postprocess(self, msg: dict) -> dict:
        """Split a raw response dict into message constructor arguments.

        ``role`` falls back to ``"assistant"`` and ``content`` to ``""`` when
        missing or falsy. Keys containing ``"tool"`` go to
        ``additional_kwargs``; every remaining key goes to
        ``response_metadata``. ``msg`` itself is left untouched (a shallow
        copy is consumed instead).
        """
        kw_left = msg.copy()
        out_dict = {
            "role": kw_left.pop("role", "assistant") or "assistant",
            "name": kw_left.pop("name", None),
            "id": kw_left.pop("id", None),
            "content": kw_left.pop("content", "") or "",
            "additional_kwargs": {},
            "response_metadata": {},
        }
        # Iterate over a snapshot of the keys because entries are popped.
        for k in list(kw_left.keys()):
            if "tool" in k:
                out_dict["additional_kwargs"][k] = kw_left.pop(k)
        out_dict["response_metadata"] = kw_left
        return out_dict

    ######################################################################################
    ## Core client-side interfaces

    def get_generation(
        self,
        inputs: Sequence[Dict],
        **kwargs: Any,
    ) -> dict:
        """Call to client generate method with call scope"""
        # A falsy call-scope `stop` falls back to the instance-level setting.
        stop = kwargs["stop"] = kwargs.get("stop") or self.stop
        payload = self.get_payload(inputs=inputs, stream=False, **kwargs)
        out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
        return out

    def get_stream(
        self,
        inputs: Sequence[Dict],
        **kwargs: Any,
    ) -> Iterator:
        """Call to client stream method with call scope"""
        # A falsy call-scope `stop` falls back to the instance-level setting.
        stop = kwargs["stop"] = kwargs.get("stop") or self.stop
        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
        return self.client.get_req_stream(self.model, stop=stop, payload=payload)

    def get_astream(
        self,
        inputs: Sequence[Dict],
        **kwargs: Any,
    ) -> AsyncIterator:
        """Call to client astream methods with call scope"""
        # A falsy call-scope `stop` falls back to the instance-level setting.
        stop = kwargs["stop"] = kwargs.get("stop") or self.stop
        payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
        return self.client.get_req_astream(self.model, stop=stop, payload=payload)

    def get_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
        """Generates payload for the _NVIDIAClient API to send to service.

        Instance-level settings that are not ``None`` act as defaults;
        call-scope ``kwargs`` take precedence over them.
        """
        attr_kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "seed": self.seed,
            "bad": self.bad,
            "stop": self.stop,
            "labels": self.labels,
        }
        if model_name := self.get_binding_model():
            attr_kwargs["model"] = model_name
        attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
        new_kwargs = {**attr_kwargs, **kwargs}
        return self.prep_payload(inputs=inputs, **new_kwargs)

    def prep_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
        """Prepares a message or list of messages for the payload

        Truthy ``labels`` are moved out of the keyword arguments into a
        trailing assistant message, and a ``stop`` of ``None`` is dropped.
        """
        messages = [self.prep_msg(m) for m in inputs]
        if kwargs.get("labels"):
            # (WFH) Labels are currently (?) always passed as an assistant
            # suffix message, but this API seems less stable.
            messages.append({"labels": kwargs.pop("labels"), "role": "assistant"})
        if kwargs.get("stop") is None:
            # Pop with a default: `stop` may be absent entirely when this
            # method is called directly, which previously raised KeyError.
            kwargs.pop("stop", None)
        return {"messages": messages, **kwargs}

    def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
        """Helper Method: Ensures a message is a dictionary with a role and content.

        Raises:
            ValueError: If a dict message has no (or ``None``) ``content``,
                or if ``msg`` is neither a string nor a dict.
        """
        if isinstance(msg, str):
            # (WFH) this shouldn't ever be reached but leaving this here bcs
            # it's a Chesterton's fence I'm unwilling to touch
            return dict(role="user", content=msg)
        if isinstance(msg, dict):
            if msg.get("content", None) is None:
                raise ValueError(f"Message {msg} has no content")
            return msg
        raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")

    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
        function_call: Optional[str] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Not implemented, awaiting server-side function-recieving API"
            " Consider following open-source LLM agent spec techniques:"
            " https://huggingface.co/blog/open-source-llms-as-agents"
        )

    def with_structured_output(
        self,
        schema: _DictOrPydanticClass,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        return_type: Literal["parsed", "all"] = "parsed",
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Not implemented, awaiting server-side function-recieving API"
            " Consider following open-source LLM agent spec techniques:"
            " https://huggingface.co/blog/open-source-llms-as-agents"
        )