# Input
torch.jit.trace(
    func,            # (callable or torch.nn.Module) – a function or a torch.nn.Module
    example_inputs,  # (tuple or torch.Tensor) – example inputs used during tracing
    ...
)
# Returns
# If func is a function, a ScriptFunction is returned.
# If func is an nn.Module or the forward of an nn.Module, a ScriptModule is returned.
# The returned object is optimized by the JIT compiler.
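For the nn.Module case, a minimal sketch (MyCell is a made-up example module; the traced result is a ScriptModule subclass):

import torch

class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        return torch.tanh(self.linear(x) + h)

cell = MyCell()
traced_cell = torch.jit.trace(cell, (torch.rand(3, 4), torch.rand(3, 4)))
print(isinstance(traced_cell, torch.jit.ScriptModule))  # True
print(traced_cell.graph)  # linear, add and tanh recorded as graph nodes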
The traced graph ignores control flow. If you do not want control flow ignored, use torch.jit.script instead, which builds the graph by parsing the Python AST rather than by executing trace ops.
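A minimal sketch of this difference, assuming a toy function f with a data-dependent branch: trace bakes in whichever branch the example input takes, while script keeps the branch in the graph as prim::If.

import torch

def f(x):
    if x.sum() > 0:  # data-dependent control flow
        return x
    return -x

traced = torch.jit.trace(f, torch.ones(2))  # emits a TracerWarning, records only the taken branch
scripted = torch.jit.script(f)              # parses the Python AST, keeps the branch

neg_in = -torch.ones(2)
print(traced(neg_in))    # tensor([-1., -1.]) – wrong, the "return x" branch was baked in
print(scripted(neg_in))  # tensor([1., 1.])  – correct, the if survives as prim::If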
import torch
def func(x, h):
    new_h = -(x + h)
    return new_h
def func2(x):
    new_x = -x
    return new_x
x, h, m = torch.full((3, 4), 1.), torch.full((3, 4), 2.), torch.full((3, 4), -3.)
print("x:", x)
print("h:", h)
print("m:", m)
traced_func = torch.jit.trace(func, (x, h))
print("traced_func:", traced_func)
print("traced_func.graph:", traced_func.graph)
traced_func2 = torch.jit.trace(func2, (m,))
print("traced_func2:", traced_func2)
print("traced_func2.graph:", traced_func2.graph)
# graph + normal
test_traced_with_normal = traced_func(x, h) + m
print("test_traced_with_normal:", test_traced_with_normal)
# graph + normal + graph
test_traced_with_normal_with_traced = traced_func(x, h) + m + (2*traced_func2(m))
print("test_traced_with_normal_with_traced:", test_traced_with_normal_with_traced)
Output:
x: tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
h: tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
m: tensor([[-3., -3., -3., -3.],
        [-3., -3., -3., -3.],
        [-3., -3., -3., -3.]])
traced_func: <torch.jit.ScriptFunction object at 0x7fc81430c7d0>
traced_func.graph: graph(%x : Float(3:4, 4:1),
      %h : Float(3:4, 4:1)):
  %2 : int = prim::Constant[value=1]() # test_jit.py:4:0, 1 is the default value of aten::add's scalar (alpha) argument
  %3 : Float(3:4, 4:1) = aten::add(%x, %h, %2) # test_jit.py:4:0 # add
  %4 : Float(3:4, 4:1) = aten::neg(%3) # test_jit.py:4:0 # neg
  return (%4)
traced_func2: <torch.jit.ScriptFunction object at 0x7fc814310590>
traced_func2.graph: graph(%x : Float(3:4, 4:1)):
  %1 : Float(3:4, 4:1) = aten::neg(%x) # test_jit.py:8:0
  return (%1)
test_traced_with_normal: tensor([[-6., -6., -6., -6.],
        [-6., -6., -6., -6.],
        [-6., -6., -6., -6.]]) # graph + normal
test_traced_with_normal_with_traced: tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]) # graph + normal + graph
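Besides .graph, the returned ScriptFunction also exposes .code, the Python-like source recovered from the traced graph (the output in the comments is approximate and may vary by version):

print(traced_func.code)
# def func(x: Tensor,
#     h: Tensor) -> Tensor:
#   return torch.neg(torch.add(x, h))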
# Python -> C++, via pybind11
def trace(...):
    traced = torch._C._create_function_from_trace(name, func, example_inputs, ...)
_create_function_from_trace calls:
std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
    func, typed_inputs, var_lookup_fn, strict, force_outplace));
createGraphByTracing calls:
tracer::trace(trace_inputs,
    [&func](Stack inputs) -> Stack { // a lambda is passed in here
      py::tuple py_inputs(num_func_inputs);
      for (size_t i = 0; i < num_func_inputs; ++i) {
        py_inputs[i] = py::cast(inputs[i]);
      }
      auto out = func(*py_inputs); // the Python function is called here via pybind11
      return {toTypeInferredIValue(out)};
    });
The implementation of tracer::trace:
auto state = std::make_shared<TracingState>();
// invokes the lambda above
auto out_stack = traced_fn(inputs);
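That the traced function really executes, with a TracingState active, can be observed from Python via the public torch.jit.is_tracing() (a small sketch; g is a made-up function):

import torch

def g(x):
    print("executing, is_tracing =", torch.jit.is_tracing())
    return x * 2

traced = torch.jit.trace(g, torch.ones(2))  # prints: executing, is_tracing = True
traced(torch.ones(2))  # prints nothing – the Python body is gone, only the graph runs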
The actual graph construction happens inside the ops that traced_fn executes: while tracing, a special trace implementation of each op is invoked. These trace ops are template-generated ops whose job is to build the graph:
// the trace implementation of the relu op
Tensor relu(const Tensor & self) {
#if !defined(PYTORCH_DISABLE_TRACING)
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::relu");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0); // create a node
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self); // record the inputs
    tracer_state->graph->insertNode(node); // insert the node into the graph
    jit::tracer::setTracingState(nullptr); // disable tracing around the real computation
  }
#endif
  static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::relu", "");
  auto result = c10::Dispatcher::singleton().redispatch<Tensor, const Tensor &>(op, c10::DispatchKey::Tracer, self);
#if !defined(PYTORCH_DISABLE_TRACING)
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state)); // restore the tracing state
    jit::tracer::addOutput(node, result); // record the output
  }
#endif
  return result;
}
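To make the pattern above concrete, here is a Python-level analogy (purely illustrative, not PyTorch's actual mechanism): a trace op records a node and its inputs, temporarily disables tracing, dispatches to the real computation, then records the output.

# illustrative analogy of a template-generated trace op, not real PyTorch code
graph = []      # stands in for tracer_state->graph
tracing = True  # stands in for jit::tracer::isTracing()

def real_relu(x):
    return max(x, 0.0)

def traced_relu(x):
    global tracing
    node = None
    if tracing:
        node = {"op": "aten::relu", "inputs": [x]}  # graph->create + addInputs
        graph.append(node)                          # graph->insertNode
        tracing = False                             # setTracingState(nullptr)
    result = real_relu(x)                           # redispatch to the real kernel
    if node is not None:
        tracing = True                              # restore the tracing state
        node["output"] = result                     # addOutput
    return result

print(traced_relu(-1.5))  # 0.0
print(graph)  # [{'op': 'aten::relu', 'inputs': [-1.5], 'output': 0.0}]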
At this point the graph has been built: this is how trace produces the logical graph from op execution.