torch.jit.trace 的实现

torch.jit.trace 函数

# 输入
torch.jit.trace(
    func,  # (callable or torch.nn.Module) – function 或者 torch.nn.Module
    example_inputs, # (tuple or torch.Tensor) – tracing时作为例子的输入
    ...
)

# 返回
# 如果输入的是个function,返回的是ScriptFunction
# 如果输入的是nn.Module或者nn.Module的forward函数,返回的是ScriptModule
# 返回的结果会经过jit编译优化

tracing的graph会忽略控制流,如果不想忽略需要使用torch.jit.script(不是使用trace op执行而是解析Python AST的方式来生成graph)。

使用例子

import torch

def func(x, h):
    """Toy function for tracing: returns -(x + h)."""
    new_h = -(x + h)
    return new_h

def func2(x):
    """Second toy function for tracing: returns -x."""
    new_x = -x
    return new_x

x, h, m = torch.full((3, 4), 1), torch.full((3, 4), 2), torch.full((3, 4), -3)
print("x:", x)
print("h:", h)
print("m:", m)

traced_func = torch.jit.trace(func, (x, h))
print("traced_func:", traced_func)
print("traced_func.graph:", traced_func.graph)

# example_inputs should be a tuple; (m) is just m, so write (m,).
traced_func2 = torch.jit.trace(func2, (m,))
# Fixed: these two lines previously printed traced_func / traced_func.graph
# (copy-paste error), which made traced_func2's graph appear identical to
# traced_func's in the output.
print("traced_func2:", traced_func2)
print("traced_func2.graph:", traced_func2.graph)

# graph + normal: traced result mixes freely with eager tensors
test_traced_with_normal = traced_func(x, h) + m
print("test_traced_with_normal:", test_traced_with_normal)

# graph + normal + graph: two traced functions combined with eager ops
test_traced_with_normal_with_traced = traced_func(x, h) + m + (2*traced_func2(m))
print("test_traced_with_normal_with_traced:", test_traced_with_normal_with_traced)

输出

x: tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
h: tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
m: tensor([[-3., -3., -3., -3.],
        [-3., -3., -3., -3.],
        [-3., -3., -3., -3.]])

traced_func: <torch.jit.ScriptFunction object at 0x7fc81430c7d0>
traced_func.graph: graph(%x : Float(3:4, 4:1),
      %h : Float(3:4, 4:1)):
  %2 : int = prim::Constant[value=1]() # test_jit.py:4:0, scalar是aten::add的默认参数alpha的值
  %3 : Float(3:4, 4:1) = aten::add(%x, %h, %2) # test_jit.py:4:0  # add
  %4 : Float(3:4, 4:1) = aten::neg(%3) # test_jit.py:4:0  # neg
  return (%4)

traced_func2: <torch.jit.ScriptFunction object at 0x7fc81430c8b0>
traced_func2.graph: graph(%x : Float(3:4, 4:1)):
  %1 : Float(3:4, 4:1) = aten::neg(%x) # test_jit.py:8:0
  return (%1)

test_traced_with_normal: tensor([[-6., -6., -6., -6.],
        [-6., -6., -6., -6.],
        [-6., -6., -6., -6.]])  # graph + normal
test_traced_with_normal_with_traced: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])  # graph + normal + graph

实现

# python -> c++, pybind11
def trace(...):
    traced = torch._C._create_function_from_trace(name, func, example_inputs, ...)

_create_function_from_trace调用

std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
    func, typed_inputs, var_lookup_fn, strict, force_outplace));

createGraphByTracing调用

tracer::trace(trace_inputs,
    [&func](Stack inputs) -> Stack { // 这里传入的是一个lambda
    py::tuple py_inputs(num_func_inputs);
    for (size_t i = 0; i < num_func_inputs; ++i) {
         py_inputs[i] = py::cast(inputs[i]);
    }
    auto out = func(*py_inputs);  // 这里借助pybind11执行python函数调用
    return {toTypeInferredIValue(out)};
    });

tracer::trace的实现

auto state = std::make_shared<TracingState>();
// 调用了上面的lambda function
auto out_stack = traced_fn(inputs);

构图的实现主要体现在这里,trace状态下调用了特殊的用于trace的op实现,trace op是模板生成的用于构图的op:

// Trace-mode implementation of the relu op (template-generated): records an
// aten::relu node into the traced graph, then redispatches to the real kernel
// to compute the actual result.
Tensor relu(const Tensor & self) {
  #if !defined(PYTORCH_DISABLE_TRACING)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::relu");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);  // create a graph node for this op
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);  // record the op's input
    tracer_state->graph->insertNode(node);  // append the node to the graph

    // Detach the tracing state before redispatching below; it is restored
    // afterwards, presumably so the inner kernel's ops are not traced again.
    jit::tracer::setTracingState(nullptr);
  }
  #endif
  static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::relu", "");
  auto result =c10::Dispatcher::singleton().redispatch<Tensor, const Tensor &>(op, c10::DispatchKey::Tracer, self);
  #if !defined(PYTORCH_DISABLE_TRACING)
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));  // restore tracing state
    jit::tracer::addOutput(node, result); // record the op's output
  }
  #endif
  return result;
}

到此为止创建了图,实现了trace生成逻辑图的功能。