Skip to content

vllm.tool_parsers.utils

UnexpectedAstError

Bases: Exception

Raised when the AST structure does not match the expected pythonic tool call format.

Source code in vllm/tool_parsers/utils.py
class UnexpectedAstError(Exception):
    """Signal that an AST does not have the shape of a pythonic tool call.

    Raised by the helpers in this module when a parsed expression (or its
    partially streamed text) cannot be interpreted as one or more tool
    calls with literal keyword arguments.
    """

compute_tool_delta

compute_tool_delta(
    previously_sent_args: str,
    new_call: ToolCall,
    index: int,
    withheld_suffix: str,
) -> DeltaToolCall | None

Compute the incremental delta between previously streamed arguments and the current tool call state.

Returns:

Type Description
DeltaToolCall | None

A DeltaToolCall with only the new argument characters, or None if there is no difference from what was previously sent.

Source code in vllm/tool_parsers/utils.py
def compute_tool_delta(
    previously_sent_args: str,
    new_call: ToolCall,
    index: int,
    withheld_suffix: str,
) -> DeltaToolCall | None:
    """Build the streaming delta for one tool call.

    The current arguments string of ``new_call`` is first trimmed of
    ``withheld_suffix`` (closing characters that were synthesized to make
    the partial text parseable and must not be sent yet). The result is
    then compared with what has already been streamed.

    Args:
        previously_sent_args: Argument text already sent for this tool.
            Empty means nothing has been sent yet.
        new_call: The tool call parsed from the current partial output.
        index: Position of this tool call in the response.
        withheld_suffix: Trailing text to strip from the new arguments
            before diffing. Empty means nothing is withheld.

    Returns:
        On the first delta, a DeltaToolCall carrying the id, function name
        and all arguments so far. Afterwards, a DeltaToolCall carrying only
        the newly appended argument characters, or None if nothing new was
        produced.

    Raises:
        ValueError: If the arguments do not end with ``withheld_suffix``.
    """
    current_args = new_call.function.arguments

    if withheld_suffix:
        if current_args.endswith(withheld_suffix):
            current_args = current_args[: len(current_args) - len(withheld_suffix)]
        else:
            msg = (
                f"Tool call arguments '{current_args}' do not end with "
                f"expected withheld suffix '{withheld_suffix}'"
            )
            logger.error(msg)
            raise ValueError(msg)

    if previously_sent_args:
        # Subsequent deltas: the new text is assumed to extend what was sent.
        new_fragment = current_args[len(previously_sent_args) :]
        if not new_fragment:
            return None
        return DeltaToolCall(
            id=None,
            index=index,
            function=DeltaFunctionCall(arguments=new_fragment),
        )

    # First delta for this tool: include id and name along with the arguments.
    return DeltaToolCall(
        id=new_call.id,
        type="function",
        index=index,
        function=DeltaFunctionCall(
            name=new_call.function.name,
            arguments=current_args,
        ),
    )

extract_intermediate_diff

extract_intermediate_diff(curr: str, old: str) -> str

Given two strings that are known to share a common prefix and/or suffix, extract the part of the first string that lies between them.

This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, to help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and close-braces are not returned prematurely. The order of arguments IS important - the new version of the partially-parsed JSON must be the first argument, and the second argument must be from the previous generation.

It returns the tokens that should be streamed to the client.

e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') -> 'ple'

Source code in vllm/tool_parsers/utils.py
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Return the part of ``curr`` lying between the prefix and suffix it
    shares with ``old``.

    This is a UTILITY for JSON produced by partial_json_parser during
    streaming. The closing quotes, brackets and braces that both versions
    end with are excluded, so they are not sent to the client prematurely.
    Argument order matters: ``curr`` must be the newer partially-parsed
    JSON and ``old`` the one from the previous generation step.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'
    """
    # Trailing characters both versions share (never alphanumeric).
    shared_tail = find_common_suffix(curr, old)
    # shared_tail is a true suffix of both strings, so removesuffix drops
    # exactly that trailing run (and is a no-op when it is empty).
    old_body = old.removesuffix(shared_tail)
    curr_body = curr.removesuffix(shared_tail)

    # The head is measured against the full current text and the trimmed
    # previous text; it is a prefix of curr, so removeprefix strips only the
    # leading copy even if the same text reappears later.
    shared_head = find_common_prefix(curr, old_body)
    return curr_body.removeprefix(shared_head)

find_common_prefix

find_common_prefix(s1: str, s2: str) -> str

Finds a common prefix that is shared between two strings, if there is one. Order of arguments is NOT important.

This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, to help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and close-braces are not returned prematurely.

e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap'

Source code in vllm/tool_parsers/utils.py
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the longest prefix shared by two strings ("" if there is none).
    Order of arguments is NOT important.

    This is a UTILITY for JSON produced by partial_json_parser. It helps
    ensure that the right tokens are returned in streaming, so that
    close-quotes, close-brackets and close-braces are not returned
    prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    for position, (left, right) in enumerate(zip(s1, s2)):
        if left != right:
            return s1[:position]
    # No mismatch: the shorter string is entirely shared.
    return s1[: min(len(s1), len(s2))]

find_common_suffix

find_common_suffix(s1: str, s2: str) -> str

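Finds a common suffix shared between two strings, if there is one. Order of arguments is NOT important. Scanning from the end stops at the first position where the strings differ OR at the first alphanumeric character, so the returned suffix contains only non-alphanumeric characters (such as closing quotes and brackets).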

e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'

Source code in vllm/tool_parsers/utils.py
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.

    Scanning from the end stops at the first position where the strings
    differ OR at the first alphanumeric character. The result therefore
    contains only non-alphanumeric characters (e.g. closing quotes/brackets).

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    # Measure the suffix length first and slice once. Prepending one
    # character per step (`s1[-i] + suffix`) rebuilds the string every time.
    length = 0
    max_length = min(len(s1), len(s2))
    while length < max_length:
        char = s1[-(length + 1)]
        if char != s2[-(length + 1)] or char.isalnum():
            break
        length += 1
    # Slice from an explicit start index: s1[-0:] would return all of s1.
    return s1[len(s1) - length :]

get_parameter_value

get_parameter_value(val: expr) -> Any

Extract a Python literal value from an AST expression node.

Handles constants, dicts, lists, and JSON-style name literals (null, true, false) that some models produce instead of Python literals (None, True, False).

Raises:

Type Description
UnexpectedAstError

If the AST node is not a supported literal type.

Source code in vllm/tool_parsers/utils.py
def get_parameter_value(val: ast.expr) -> Any:
    """Extract a Python literal value from an AST expression node.

    Handles the following:
    - constants
    - signed numeric constants (e.g. ``-5``, ``+2.5``)
    - dicts and lists
    - JSON-style name literals (null, true, false) that some models produce
      instead of Python literals (None, True, False)

    Args:
        val: AST node for a single tool call argument value.

    Returns:
        The equivalent Python value. Dict and list elements are converted
        recursively.

    Raises:
        UnexpectedAstError: If the AST node is not a supported literal type.
    """
    if isinstance(val, ast.Constant):
        return val.value
    elif (
        isinstance(val, ast.UnaryOp)
        and isinstance(val.op, (ast.USub, ast.UAdd))
        and isinstance(val.operand, ast.Constant)
        and isinstance(val.operand.value, (int, float))
        and not isinstance(val.operand.value, bool)
    ):
        # ast.parse does not fold "-5" into a single constant; it produces
        # UnaryOp(USub, Constant(5)). Signed numbers therefore need explicit
        # handling, or arguments like `offset=-5` would be rejected.
        number = val.operand.value
        return -number if isinstance(val.op, ast.USub) else number
    elif isinstance(val, ast.Dict):
        # A None key marks `**mapping` unpacking inside the dict literal,
        # which is not an ast.Constant and is rejected here.
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            logger.warning(
                "Dict argument keys are not all literals: %s",
                ast.dump(val),
            )
            raise UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }
    elif isinstance(val, ast.List):
        return [get_parameter_value(v) for v in val.elts]
    elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
        return _JSON_NAME_LITERALS[val.id]
    else:
        logger.warning(
            "Unsupported AST node type in tool call arguments: %s",
            ast.dump(val),
        )
        raise UnexpectedAstError("Tool call arguments must be literals")

handle_single_tool

handle_single_tool(call: Call) -> ToolCall

Convert a single AST function call node into a ToolCall object.

Raises:

Type Description
UnexpectedAstError

If the call node does not have a simple function name (e.g. it's an attribute access or subscript).

Source code in vllm/tool_parsers/utils.py
def handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert a single AST function call node into a ToolCall object.

    Keyword arguments become the tool arguments, serialized as a JSON object
    string (``ensure_ascii=False``). Positional arguments carry no parameter
    name, so they cannot be represented; they are skipped with a warning.

    Raises:
        UnexpectedAstError: In any of these cases:
            - the call node does not have a simple function name (e.g. it's
              an attribute access or subscript);
            - it uses ``**`` keyword unpacking;
            - an argument value is not a supported literal (raised by
              ``get_parameter_value``).
    """
    if not isinstance(call.func, ast.Name):
        logger.warning(
            "Tool call has non-simple function name: %s",
            ast.dump(call.func),
        )
        raise UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    if call.args:
        # Positional arguments have no name to map to in the JSON object.
        # They are skipped (not an error), but logged so the loss is visible.
        logger.warning(
            "Ignoring %d positional argument(s) in tool call %r",
            len(call.args),
            function_name,
        )
    arguments: dict[str, Any] = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # `**mapping` unpacking has no name. Storing it under a None key
            # would be serialized by json.dumps as a bogus "null" property.
            logger.warning(
                "Tool call %r uses ** keyword unpacking: %s",
                function_name,
                ast.dump(keyword.value),
            )
            raise UnexpectedAstError("Tool call arguments must be named")
        arguments[keyword.arg] = get_parameter_value(keyword.value)
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )

make_valid_python

make_valid_python(text: str) -> tuple[str, str] | None

Attempt to close all open brackets/quotes to make partial Python valid.

Used during streaming to parse incomplete tool call expressions by appending the necessary closing characters.

Returns:

Type Description
tuple[str, str] | None

A tuple of (completed_text, added_suffix) if the text can be made valid, or None if the text is too incomplete to complete meaningfully (e.g. mid-parameter-name or mid-dict-key).

Raises:

Type Description
UnexpectedAstError

If mismatched brackets or parentheses are detected.

Source code in vllm/tool_parsers/utils.py
def make_valid_python(text: str) -> tuple[str, str] | None:
    """Attempt to close all open brackets/quotes to make partial Python valid.

    Used during streaming to parse incomplete tool call expressions by
    appending the necessary closing characters.

    Returns:
        A tuple of (completed_text, added_suffix) if the text can be
        made valid, or None if the text is too incomplete to complete
        meaningfully (e.g. mid-parameter-name or mid-dict-key).

    Raises:
        UnexpectedAstError: If mismatched brackets or parentheses
            are detected.
    """
    bracket_stack: list[str] = []
    for index, char in enumerate(text):
        if char in {"[", "(", "{"}:
            bracket_stack.append(char)
        elif char == "]":
            if not bracket_stack or bracket_stack.pop() != "[":
                raise UnexpectedAstError("Mismatched square brackets")
        elif char == ")":
            if not bracket_stack or bracket_stack.pop() != "(":
                raise UnexpectedAstError("Mismatched parentheses")
        elif char == "}":
            if not bracket_stack or bracket_stack.pop() != "{":
                raise UnexpectedAstError("Mismatched curly braces")
        elif char in {"'", '"'}:
            if bracket_stack and bracket_stack[-1] == char:
                if index > 0 and text[index - 1] == "\\":
                    pass
                else:
                    bracket_stack.pop()
            elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
                pass
            else:
                bracket_stack.append(char)

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        return None
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[: text.rfind("{")]
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[: text.rfind("(")]
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None
    if text.endswith(","):
        text = text[:-1]
    if (
        bracket_stack
        and bracket_stack[-1] == "["
        and not text.endswith("[")
        and not text.endswith(")")
    ):
        return None

    _CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
    added_text = ""
    for char in reversed(bracket_stack):
        added_text += _CLOSING[char]

    return text + added_text, added_text