TVM - Tensor Expression

本文以向量加法为例,记录TVM中最基本的Tensor Expression用法,以及简单的编译运行流程。

下面的代码为简单的向量加法,参考自Tensor Expression官方教程,在TVM v0.6下执行(注意与v0.7dev的模块有区别)。

import tvm
import numpy as np

# Tensor Expression: describe the computation symbolically; nothing is
# evaluated at this point.
# placeholder args: (shape, label) -- two 1-D inputs with 10 elements each.
A = tvm.placeholder((10,), name='A')
B = tvm.placeholder((10,), name='B')
# compute args: (shape, function, label)
# The function is a lambda over the output axes that gives the value of one
# output element:
#     lambda axis1, axis2, ... : f(axis1, axis2, ...)
# Here the output is 1-D, so there is a single axis i and C[i] = A[i] + B[i].
C = tvm.compute((10,), lambda i: A[i] + B[i], name="C")

# Create a schedule for the operation that produces C.
s = tvm.create_schedule(C.op)
# Lower the schedule (with A, B, C as the function arguments) and print the
# resulting low-level code.
print(tvm.lower(s,[A,B,C],simple_mode=True))

其中placeholder代表特定维度的张量,最后生成的代码会要求用户输入两个tensor,如果是C++代码,则要求用户输入两个float*。注意,这个过程实际上并没有发生任何计算,而只是定义了计算如何进行。

输出的low-level代码如下所示,还是相当好理解的,即i从0循环到9(共10次),循环内每次计算C[i]的值。

produce C {
  for (i, 0, 10) {
    C[i] = (A[i] + B[i])
  }
}

一些常用的循环优化API可以在这里找到。这里使用循环分割split作为尝试。

split(parent[, factor, nparts])
Split the stage either by factor providing outer scope, or both. Return outer, inner variable of iteration.
# Split C's first (and only) axis with factor=2; split returns the outer and
# inner iteration variables, bound here to bx and tx.
bx, tx = s[C].split(C.op.axis[0],factor=2)
# Lower the same schedule object again and print it to see the effect of the
# split.
print(tvm.lower(s,[A,B,C],simple_mode=True))

由于对schedule的操作是原地变换,因此可以直接输出lower后的代码,发现确实已经改变了,原来的循环体变成5*2的循环。

produce C {
  for (i.outer, 0, 5) {
    for (i.inner, 0, 2) {
      C[((i.outer*2) + i.inner)] = (A[((i.outer*2) + i.inner)] + B[((i.outer*2) + i.inner)])
    }
  }
}

当然这一个schedule变换并没有带来任何好处,只是为了说明Tensor Expression应该怎么用。

之后就可以调用build生成目标代码了,可以设置target和target_host。

# Code-generation target; "llvm" and "cuda" are listed as alternatives.
# NOTE(review): with "c" the built module may be source-only and not directly
# callable without compiling it first -- confirm before running fadd.
tgt = "c" # "llvm", "cuda"
# Build the schedule into a function named "myadd" taking A, B and C as
# arguments.
fadd = tvm.build(s,[A,B,C],target=tgt,name="myadd")

然后可以创造运行时环境,进行运行测试。

# Number of elements; must match the shape (10,) declared for A, B and C.
n = 10
# Context for device 0 of the chosen target.
ctx = tvm.context(tgt,0)
# Random inputs and a zero-filled output, each cast to the dtype of the
# corresponding tensor and placed on ctx.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n,dtype=C.dtype), ctx)
fadd(a,b,c) # run: the result is written into c

# Check the result against the same addition done with NumPy.
tvm.testing.assert_allclose(c.asnumpy(),a.asnumpy() + b.asnumpy())
# Print the source code generated for the built function.
print(fadd.get_source())

生成的C代码如下

// Loop nest after the split: 5 outer x 2 inner iterations cover indices 0..9.
for (int32_t i_outer = 0; i_outer < 5; ++i_outer) {
  for (int32_t i_inner = 0; i_inner < 2; ++i_inner) {
    C[((i_outer * 2) + i_inner)] = (A[((i_outer * 2) + i_inner)] + B[((i_outer * 2) + i_inner)]);
  }
}
生成的myadd.c完整代码如下
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
extern void* __tvm_module_ctx = NULL;
#ifdef __cplusplus
extern "C"
#endif
// Generated entry point for the function built with name="myadd".
// Validates the three packed arguments (A, B, C), then runs the split
// vector-add loop. Returns 0 on success; on any failed check it records a
// message with TVMAPISetLastError and returns -1.
TVM_DLL int32_t myadd( void* args,  void* arg_type_ids, int32_t num_args) {
  // Exactly three arguments are required.
  if (!((num_args == 3))) {
    TVMAPISetLastError("myadd: num_args should be 3");
    return -1;
  }
  // Unpack each argument's handle and its type code.
  void* arg0 = (((TVMValue*)args)[0].v_handle);
  int32_t arg0_code = (( int32_t*)arg_type_ids)[0];
  void* arg1 = (((TVMValue*)args)[1].v_handle);
  int32_t arg1_code = (( int32_t*)arg_type_ids)[1];
  void* arg2 = (((TVMValue*)args)[2].v_handle);
  int32_t arg2_code = (( int32_t*)arg_type_ids)[2];
  // Treat each handle as a TVMArray and read its data pointer, shape and
  // strides. The reference device type/id are taken from arg0 only.
  float* A = (float*)(((TVMArray*)arg0)[0].data);
  int64_t* arg0_shape = (int64_t*)(((TVMArray*)arg0)[0].shape);
  int64_t* arg0_strides = (int64_t*)(((TVMArray*)arg0)[0].strides);
  int32_t dev_type = (((TVMArray*)arg0)[0].ctx.device_type);
  int32_t dev_id = (((TVMArray*)arg0)[0].ctx.device_id);
  float* B = (float*)(((TVMArray*)arg1)[0].data);
  int64_t* arg1_shape = (int64_t*)(((TVMArray*)arg1)[0].shape);
  int64_t* arg1_strides = (int64_t*)(((TVMArray*)arg1)[0].strides);
  float* C = (float*)(((TVMArray*)arg2)[0].data);
  int64_t* arg2_shape = (int64_t*)(((TVMArray*)arg2)[0].shape);
  int64_t* arg2_strides = (int64_t*)(((TVMArray*)arg2)[0].strides);
  // Each argument's type code must be one of 3, 13, 7 or 4 (the error text
  // describes these as "pointer").
  if (!(((((arg0_code == 3) || (arg0_code == 13)) || (arg0_code == 7)) || (arg0_code == 4)))) {
    TVMAPISetLastError("myadd: Expect arg[0] to be pointer");
    return -1;
  }
  if (!(((((arg1_code == 3) || (arg1_code == 13)) || (arg1_code == 7)) || (arg1_code == 4)))) {
    TVMAPISetLastError("myadd: Expect arg[1] to be pointer");
    return -1;
  }
  if (!(((((arg2_code == 3) || (arg2_code == 13)) || (arg2_code == 7)) || (arg2_code == 4)))) {
    TVMAPISetLastError("myadd: Expect arg[2] to be pointer");
    return -1;
  }
  // arg0 must live on device type 1 (presumably the CPU device -- confirm
  // against the device-type enum).
  if (!((dev_type == 1))) {
    TVMAPISetLastError("device_type need to be 1");
    return -1;
  }
  // arg0 (A): 1-D, float32 (dtype code 2, 32 bits, 1 lane), 10 elements,
  // unit stride when strides are present, zero byte offset.
  if (!((1 == (((TVMArray*)arg0)[0].ndim)))) {
    TVMAPISetLastError("arg0.ndim is expected to equal 1");
    return -1;
  }
  if (!(((((((TVMArray*)arg0)[0].dtype.code) == (uint8_t)2) && ((((TVMArray*)arg0)[0].dtype.bits) == (uint8_t)32)) && ((((TVMArray*)arg0)[0].dtype.lanes) == (uint16_t)1)))) {
    TVMAPISetLastError("arg0.dtype is expected to be float32");
    return -1;
  }
  if (!((10 == ((int32_t)arg0_shape[0])))) {
    TVMAPISetLastError("Argument arg0.shape[0] has an unsatisfied constraint");
    return -1;
  }
  if (!(arg0_strides == NULL)) {
    if (!((1 == ((int32_t)arg0_strides[0])))) {
      TVMAPISetLastError("arg0.strides: expected to be compact array");
      return -1;
    }
  }
  if (!(((uint64_t)0 == (((TVMArray*)arg0)[0].byte_offset)))) {
    TVMAPISetLastError("Argument arg0.byte_offset has an unsatisfied constraint");
    return -1;
  }
  // arg1 (B): same layout checks as arg0, plus device type 1 and the same
  // device id as arg0.
  if (!((1 == (((TVMArray*)arg1)[0].ndim)))) {
    TVMAPISetLastError("arg1.ndim is expected to equal 1");
    return -1;
  }
  if (!(((((((TVMArray*)arg1)[0].dtype.code) == (uint8_t)2) && ((((TVMArray*)arg1)[0].dtype.bits) == (uint8_t)32)) && ((((TVMArray*)arg1)[0].dtype.lanes) == (uint16_t)1)))) {
    TVMAPISetLastError("arg1.dtype is expected to be float32");
    return -1;
  }
  if (!((10 == ((int32_t)arg1_shape[0])))) {
    TVMAPISetLastError("Argument arg1.shape[0] has an unsatisfied constraint");
    return -1;
  }
  if (!(arg1_strides == NULL)) {
    if (!((1 == ((int32_t)arg1_strides[0])))) {
      TVMAPISetLastError("arg1.strides: expected to be compact array");
      return -1;
    }
  }
  if (!(((uint64_t)0 == (((TVMArray*)arg1)[0].byte_offset)))) {
    TVMAPISetLastError("Argument arg1.byte_offset has an unsatisfied constraint");
    return -1;
  }
  if (!((1 == (((TVMArray*)arg1)[0].ctx.device_type)))) {
    TVMAPISetLastError("Argument arg1.device_type has an unsatisfied constraint");
    return -1;
  }
  if (!((dev_id == (((TVMArray*)arg1)[0].ctx.device_id)))) {
    TVMAPISetLastError("Argument arg1.device_id has an unsatisfied constraint");
    return -1;
  }
  // arg2 (C): same checks as arg1.
  if (!((1 == (((TVMArray*)arg2)[0].ndim)))) {
    TVMAPISetLastError("arg2.ndim is expected to equal 1");
    return -1;
  }
  if (!(((((((TVMArray*)arg2)[0].dtype.code) == (uint8_t)2) && ((((TVMArray*)arg2)[0].dtype.bits) == (uint8_t)32)) && ((((TVMArray*)arg2)[0].dtype.lanes) == (uint16_t)1)))) {
    TVMAPISetLastError("arg2.dtype is expected to be float32");
    return -1;
  }
  if (!((10 == ((int32_t)arg2_shape[0])))) {
    TVMAPISetLastError("Argument arg2.shape[0] has an unsatisfied constraint");
    return -1;
  }
  if (!(arg2_strides == NULL)) {
    if (!((1 == ((int32_t)arg2_strides[0])))) {
      TVMAPISetLastError("arg2.strides: expected to be compact array");
      return -1;
    }
  }
  if (!(((uint64_t)0 == (((TVMArray*)arg2)[0].byte_offset)))) {
    TVMAPISetLastError("Argument arg2.byte_offset has an unsatisfied constraint");
    return -1;
  }
  if (!((1 == (((TVMArray*)arg2)[0].ctx.device_type)))) {
    TVMAPISetLastError("Argument arg2.device_type has an unsatisfied constraint");
    return -1;
  }
  if (!((dev_id == (((TVMArray*)arg2)[0].ctx.device_id)))) {
    TVMAPISetLastError("Argument arg2.device_id has an unsatisfied constraint");
    return -1;
  }
  // The actual computation: 5 outer x 2 inner iterations cover indices 0..9,
  // writing C[i] = A[i] + B[i].
  for (int32_t i_outer = 0; i_outer < 5; ++i_outer) {
    for (int32_t i_inner = 0; i_inner < 2; ++i_inner) {
      C[((i_outer * 2) + i_inner)] = (A[((i_outer * 2) + i_inner)] + B[((i_outer * 2) + i_inner)]);
    }
  }
  return 0;
}

最后通过fadd.save("myadd.c")保存文件。

References