ChatGPT

LangChainのソースコードから実装を見てみる(ChatModelのinvoke編)

生成AIのアプリケーション開発をするライブラリであるLangChainについて、いくつかのバージョンアップなどもあり昔の機能が非推奨になったり、様々な更新などがあったため、公式のドキュメントを漁っていてlangchainでモデル呼び出す際の入力について気になったことがあったので詳しく調べてみました。

モデルの呼び出し方

ドキュメントによると次の呼び出し方を実施しています。

import getpass
import os

if not os.environ.get("OPENAI_API_KEY"):
  os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o-mini")

model.invoke("Hello, world!")

invoke関数がモデルの呼び出しを行うもので、引数としては文字列が入力されています。

これは入力した文字列がgpt-4o-miniに送られているのだろうと想定されますので、理解は難しくないです。

しかし、チュートリアルでは以下のような呼び出しになっています。

from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage("Translate the following from English into Italian"),
    HumanMessage("hi!"),
]

model.invoke(messages)

これは2つのメッセージクラスのリストを入力しています。これも会話をリスト形式で保持しているのだろうとは想像できますが、次の形はどうでしょうか。

model.invoke("Hello")
model.invoke([{"role": "user", "content": "Hello"}])
model.invoke([HumanMessage("Hello")])

上記チュートリアルにも記載されていますが、上記はいずれも同じ処理が行われます。

入力方法が多岐に渡っているので、実際どんな入力がサポートされているのか確認してみましょう。

ドキュメントを読むとLanguageModelInputが想定されているようです。

ライブラリを見ていくと次のような形で定義されています。

LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]

つまりPromptValue, str, Sequence[MessageLikeRepresentation]のいずれかが入力されることを期待しているということです。

先ほどの例がそれぞれどれに当てはまるのかというと

  • model.invoke("Hello"):str
  • model.invoke([{"role": "user", "content": "Hello"}]):Sequence[MessageLikeRepresentation]
  • model.invoke([HumanMessage("Hello")]):Sequence[MessageLikeRepresentation]

上記のようになります。

PromptValueは?というとプロンプトテンプレートを使った場合はこの入力に当てはまります。

model.invoke(StringPromptValue(text="Hello"))

実験

以下のコードは実行できるでしょうか?

model.invoke(("user","Hello"))

これは動作しましたが、上記の指定の中にはない形です。

実際どのように動いているのか気になったのでソースコードを確認してみました。

まず、invokeの実装から見てみます。今回気になった部分はinputの部分なのでそこにフォーカスしてみます。

self._convert_input(input)

入力はconvert_inputによって別の値に変換されています。そっちも見てみましょう

確認したかった部分はここですね。

def _convert_input(self, input: LanguageModelInput) -> PromptValue:
    if isinstance(input, PromptValue):
        return input
    elif isinstance(input, str):
        return StringPromptValue(text=input)
    elif isinstance(input, Sequence):
        return ChatPromptValue(messages=convert_to_messages(input))
    else:
        msg = (
            f"Invalid input type {type(input)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )
        raise ValueError(msg)

この関数ではLanguageModelInputPromptValueに変換しています。

  • PromptValueはそのまま
  • 文字列はStringPromptValueへ
  • Sequenceは入力をconvert_to_messagesで変換しChatPromptValue

上の2つはすんなり理解できます。最後の関数も見てみましょう。

def convert_to_messages(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
) -> list[BaseMessage]:
    """Convert a sequence of messages to a list of messages.

    Args:
        messages: Sequence of messages to convert.

    Returns:
        list of messages (BaseMessages).
    """
    # Import here to avoid circular imports
    from langchain_core.prompt_values import PromptValue

    if isinstance(messages, PromptValue):
        return messages.to_messages()
    return [_convert_to_message(m) for m in messages]

この関数が呼び出される場合はPromptValueのインスタンスではなくSequenceのインスタンスのはずなので、そこを見てみます。下記の通りリストやタプルの各要素を変換しています。

def _convert_to_message(message: MessageLikeRepresentation) -> BaseMessage:
    """Instantiate a message from a variety of message formats.

    The message format can be one of the following:

    - BaseMessagePromptTemplate
    - BaseMessage
    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
    - dict: a message dict with role and content keys
    - string: shorthand for ("human", template); e.g., "{user_input}"

    Args:
        message: a representation of a message in one of the supported formats.

    Returns:
        an instance of a message or a message template.

    Raises:
        NotImplementedError: if the message type is not supported.
        ValueError: if the message dict does not contain the required keys.
    """
    if isinstance(message, BaseMessage):
        _message = message
    elif isinstance(message, str):
        _message = _create_message_from_message_type("human", message)
    elif isinstance(message, Sequence) and len(message) == 2:
        # mypy doesn't realise this can't be a string given the previous branch
        message_type_str, template = message  # type: ignore[misc]
        _message = _create_message_from_message_type(message_type_str, template)
    elif isinstance(message, dict):
        msg_kwargs = message.copy()
        try:
            try:
                msg_type = msg_kwargs.pop("role")
            except KeyError:
                msg_type = msg_kwargs.pop("type")
            # None msg content is not allowed
            msg_content = msg_kwargs.pop("content") or ""
        except KeyError as e:
            msg = f"Message dict must contain 'role' and 'content' keys, got {message}"
            msg = create_message(
                message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE
            )
            raise ValueError(msg) from e
        _message = _create_message_from_message_type(
            msg_type, msg_content, **msg_kwargs
        )
    else:
        msg = f"Unsupported message type: {type(message)}"
        msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE)
        raise NotImplementedError(msg)

    return _message

ドキュメントにも記載されている通り、フォーマットに従った内容であれば下記の通り変換がされるようです。

  • メッセージクラスはそのまま
  • 文字列はHumanMessageへ変換
  • Sequenceかつ要素数が2なら、1つ目の要素をロール、2つ目の要素を値にしたメッセージクラスへ変換
  • 辞書型ならrole要素かtype要素をロール、content要素を値にしたメッセージクラスへ

ここで先ほどの入力がどうなっていたか見てみましょう。

model.invoke(("user","Hello"))

先ほどの入力ではSequenceであるタプルの中に文字列が含まれていたので、_convert_to_messageではHumanMessageへの変換がされていることになります。

つまり、ここで処理されているのは下記と同一だったわけです。

model.invoke([HumanMessage("user"),HumanMessage("Hello")])

ソースコードの確認は深い理解のためには重要ですね。

おまけ

今回の実験のためにいろいろ試していたのですが、その際、出力が微妙に異なるパターンがtemperatureが0に指定されているのにも拘らず発生してしまいました。誤差が大きいときでは100トークン前後の誤差が生じていました。

その原因としてはGPUでの計算のリアルタイム最適化の結果、演算順序が変動することで浮動小数点での誤差が発生してしまうためのようです。

これは生成したトークンから次のトークンを生成していく連鎖によって動作するLLMにとっては最終的に大きな誤差が生まれ得ることを意味しています。

これはGPUによる演算を最適化している都合上起こり得るものなので、回避は難しいです。

参考:The New World of LLM Functions: Integrating LLM Technology into the Wolfram Language (2025/1/30閲覧)

ピックアップ記事

  1. 最速で理解したい人のためのIT用語集

関連記事

  1. Python

    pythonを使ったダミーデータ生成

    最初になにか発見したことを総合研究所で発表したり、デモ資料を作ったり…

  2. Python

    わかりやすいPyTorch入門⑤(CNNとデータの拡張)

    CNNとデータの拡張データの拡張とは今回は前回学んだCNNの練習に…

  3. ChatGPT

    生成AIって何?今までのAIと何が違う?

    はじめにAIの世界は年々目覚しい発展を遂げています。最近では…

  4. Google Cloud Platform

    Vertex AI Embeddings for Text によるテキストエンベディングをやってみた…

    こんにちは、石原と申します。自然言語処理(NLP)は近年のA…

  5. ChatGPT

    Open Interpreter+VScode+Dockerで生成AIによるコード開発環境構築(Wi…

    はじめにこんにちは、エクスチュアの石原です。皆さん、…

  6. ChatGPT

    LangChainって何?: 次世代AIアプリケーション構築 その3

    こんにちは、エクスチュアの石原です。こちらは第3回の記事にな…

最近の記事

  1. LangChainのソースコードから実装を見てみる(Chat…
  2. Tableau×Teams連携
  3. AIを使ったマーケティングゲームを作ってみた
  4. Snowflakeや最新データ基盤が広義のマーケティングにも…
  5. 回帰分析はかく語りき Part3 ロジスティック回帰
  1. IT用語集

    メモリ(Memory)、仮想メモリ(Virtual Memory)って何?
  2. Adobe Audience Manager

    Adobe Audience Manager: REST APIを使ってTrai…
  3. ブログ

    インドネシアのデジタルマーケティング
  4. IT用語集

    アーカイブ(Archive)って何?
  5. Data Clean Room

    忘年会シーズンに「DCRごっこ」のご提案
PAGE TOP