参考

Custom backends 简介

backend 的后端接口约定

backend 以一个 Python 函数形式存在,函数的接口约定为 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable.

其中 gm 的类型为 torch.fx.GraphModule ,用户可以使用 torch.fx 提供的工具对 gm 中的计算逻辑做自定义,然后返回新的计算函数,如此实现自定义的后端。

backend 的使用

import torch

# 自定义的后端
def my_custom_backend(gm, example_inputs):
    return gm.forward

# 一个示例函数
def f(...):
    ...

# compile 作为函数
f_opt = torch.compile(f, backend=my_custom_backend)

# compile 作为装饰器
@torch.compile(backend=my_custom_backend)
def g(...):
    ...

backend 的注册

register_backend 函数用于注册一个 backend.

from torch._dynamo.optimizations import register_backend

@register_backend
def my_compiler(gm, example_inputs):
    ...

注册完成后,可以使用注册的名字来使用 backend:

torch.compile(model, backend="my_compiler")

torch.compile 的使用细节

给编译后端传递参数

options 参数可以给编译器后端传递自定义参数。但是当前(torch 2.0.1)仅限于 torch 内置的 inductor 后端可以使用这个 options 参数。