本文主要介绍如何将Relay IR的计算图(computational graph)/数据流图(dataflow graph)进行可视化输出。
参照TVM #3259的Pull Request,将下列代码复制到python/tvm/relay/visualize.py
中,注意代码做了一定的适应性修改。
from .expr_functor import ExprFunctor
from . import expr as _expr
import networkx as nx
class VisualizeExpr(ExprFunctor):
def __init__(self):
super().__init__()
self.graph = nx.DiGraph()
self.counter = 0
def viz(self, expr):
assert isinstance(expr, _expr.Function)
for param in expr.params:
self.visit(param)
return self.visit(expr.body)
def visit_constant(self, const): # overload this!
pass
def visit_var(self, var):
name = var.name_hint
self.graph.add_node(name)
self.graph.nodes[name]['style'] = 'filled'
self.graph.nodes[name]['fillcolor'] = 'mistyrose'
return var.name_hint
def visit_tuple_getitem(self, get_item):
tuple = self.visit(get_item.tuple_value)
# self.graph.nodes[tuple]
index = get_item.index
# import pdb; pdb.set_trace()
return tuple
def visit_call(self, call):
parents = []
for arg in call.args:
parents.append(self.visit(arg))
# assert isinstance(call.op, _expr.Op)
name = "{}({})".format(call.op.name, self.counter)
self.counter += 1
self.graph.add_node(name)
self.graph.nodes[name]['style'] = 'filled'
self.graph.nodes[name]['fillcolor'] = 'turquoise'
self.graph.nodes[name]['shape'] = 'diamond'
edges = []
for i, parent in enumerate(parents):
edges.append((parent, name, { 'label': 'arg{}'.format(i) }))
self.graph.add_edges_from(edges)
return name
def visualize(expr,mydir="relay_ir.png"):
viz_expr = VisualizeExpr()
viz_expr.viz(expr)
graph = viz_expr.graph
dotg = nx.nx_pydot.to_pydot(graph)
dotg.write_png(mydir)
注意传入的参数需要时一个ExprFunctor
实例,因此原文给出的测试实例调用relay.testing.renet.getworkload()
得到模型并输出对v0.6版本并不可行。
下面复用上次GCN的例子,来生成计算图。
from tvm.relay.visualize import visualize
func = relay.Function(relay.analysis.free_vars(output), output)
visualize(func)
执行上述代码之前需要先安装pydot和graphviz
pip install pydot
apt-get install graphviz
最后会生成对应的relay_ir.png
图片,如下。
实际上VisualizeExpr
就是一个计算图的遍历器(以visit
对结点进行访问),因此只要重载对应的结点函数,就可以实现对应的功能。