backend 以一个 Python 函数形式存在,函数的接口约定为 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable
.
其中 gm
的类型为 torch.fx.GraphModule
,用户可以使用 torch.fx 提供的工具对 gm
中的计算逻辑做自定义,然后返回新的计算函数,如此实现自定义的后端。
import torch
# 自定义的后端
def my_custom_backend(gm, example_inputs):
return gm.forward
# 一个示例函数
def f(...):
...
# compile 作为函数
f_opt = torch.compile(f, backend=my_custom_backend)
# compile 作为装饰器
@torch.compile(backend=my_custom_backend)
def g(...):
...
register_backend
函数用于注册一个 backend.
from torch._dynamo.optimizations import register_backend
@register_backend
def my_compiler(gm, example_inputs):
...
注册完成后,可以使用注册的名字来使用 backend:
torch.compile(model, backend="my_compiler")
options 参数可以给编译器后端传递自定义参数。但是当前(torch 2.0.1)仅限于 torch 内置的 inductor 后端可以使用这个 options 参数。