[1] torch fx 官方文档. https://pytorch.org/docs/stable/fx.html
[2] torch.fx: Practical Program Capture and Transformation for Deep Learning in Python. https://arxiv.org/pdf/2112.08429.pdf
[3] torch fx 应用于将 torch 转成 oneflow. https://github.com/siliconflow/onediff/pull/237
[4] 适配PyTorch FX,OneFlow让量化感知训练更简单. https://blog.csdn.net/OneFlow_Official/article/details/129359205
torch fx 是 PyTorch 官方发布的 Python 到 Python 的代码变换工具。如果你想做 torch 代码变换,torch.fx 是首选工具。
torch fx 会将 torch 代码 trace 成 6 种基础的 node 组成的 graph,基于这个 graph 可以方便的做各种变换,变换后的 graph 可以再生成 torch 代码(一个 nn.Module),然后像普通的 nn.Module 一样去执行。
torch 2.0 新发布的 torch.compile(也即 TorchDynamo) 默认将代码转换成了 torch fx 的 GraphModule,进一步加强了 torch fx 的重要性(关联文章:https://strint.github.io/221203-torchdynamo.html)。
关键词:PyTorch,图变换,编译
torch fx 有三块基础功能:
首先定义一个有代表性的 nn.Module,包括了 fx 要处理的6种基础操作:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()