本文以向量加法为例,记录TVM中最基本的Tensor Expression用法,以及简单的编译运行流程。
下面的代码为简单的向量加法,参考自Tensor Expression官方教程,在TVM v0.6下执行(注意与v0.7dev的模块有区别)。
import tvm
import numpy as np
# Tensor Expression: describe the computation symbolically. Nothing is
# evaluated while these lines run; they only define how C is computed.
# placeholder args: (shape, name) -- two 1-D input tensors of length 10.
A = tvm.placeholder((10,), name='A')
B = tvm.placeholder((10,), name='B')
# compute args: (shape, function, name).
# The function is a lambda that takes one index per output axis and returns
# the expression for that element:
#   lambda axis1, axis2, ... : f(axis1, axis2, ...)
# Here it defines the element-wise sum C[i] = A[i] + B[i].
C = tvm.compute((10,), lambda i: A[i] + B[i], name="C")
# Create a schedule for the operation that produces C.
s = tvm.create_schedule(C.op)
# Lower the schedule with [A, B, C] as the argument list and print the
# resulting low-level code.
print(tvm.lower(s,[A,B,C],simple_mode=True))
其中`placeholder`代表特定维度的张量,最后生成的代码会要求用户输入两个tensor;如果生成的是C/C++代码,则要求用户传入两个`float*`。注意,这个过程实际上并没有发生任何计算,而只是定义了计算如何进行。
输出的low-level代码如下所示,还是相当好理解的:`i`从0循环到9(共10次),循环内每次计算`C[i]`的值。
produce C {
for (i, 0, 10) {
C[i] = (A[i] + B[i])
}
}
一些常用的循环优化API可以在TVM官方文档中找到。这里使用循环分割`split`作为尝试。
`split(parent[, factor, nparts])`返回`outer`和`inner`两个迭代变量(variable of iteration)。
bx, tx = s[C].split(C.op.axis[0],factor=2)
print(tvm.lower(s,[A,B,C],simple_mode=True))
由于对schedule的操作是原地变换,因此可以直接输出lower后的代码,发现确实已经改变了,原来的循环体变成5*2的循环。
produce C {
for (i.outer, 0, 5) {
for (i.inner, 0, 2) {
C[((i.outer*2) + i.inner)] = (A[((i.outer*2) + i.inner)] + B[((i.outer*2) + i.inner)])
}
}
}
当然这一个schedule变换并没有带来任何好处,只是为了说明Tensor Expression应该怎么用。
之后就可以调用`build`生成目标代码了,可以设置`target`和`target_host`。
# Target backend used for code generation.
tgt = "c" # alternatives: "llvm", "cuda"
# Build the schedule into a module whose entry function is named "myadd";
# A, B and C form its argument list.
fadd = tvm.build(s,[A,B,C],target=tgt,name="myadd")
然后可以创造运行时环境,进行运行测试。
# Number of elements; must match the length-10 shape declared for A, B and C.
n = 10
# Runtime context (device 0) derived from the target string.
ctx = tvm.context(tgt,0)
# Random inputs and a zero-filled output buffer, using the tensors' dtypes.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n,dtype=C.dtype), ctx)
# NOTE(review): with tgt = "c" the built module holds C source; confirm that
# it can be called directly in this TVM version -- if not, build with "llvm"
# for the run/check below and use "c" only to inspect the source.
fadd(a,b,c) # run the compiled function; the result is written into c
# Check the result against the NumPy reference a + b.
tvm.testing.assert_allclose(c.asnumpy(),a.asnumpy() + b.asnumpy())
# Print the source code generated for the module.
print(fadd.get_source())
生成的C代码中,核心的循环部分如下:
for (int32_t i_outer = 0; i_outer < 5; ++i_outer) {
for (int32_t i_inner = 0; i_inner < 2; ++i_inner) {
C[((i_outer * 2) + i_inner)] = (A[((i_outer * 2) + i_inner)] + B[((i_outer * 2) + i_inner)]);
}
}
`myadd.c`的完整代码如下:
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
extern void* __tvm_module_ctx = NULL;
#ifdef __cplusplus
extern "C"
#endif
/*
 * Generated entry point for the "myadd" kernel (explanatory comments added
 * by hand; they are not part of the generator's output).
 *
 * args         - array of TVMValue; each of the three entries is read as a
 *                handle to a TVMArray (A, B and the output C).
 * arg_type_ids - array of int32 type codes, one per argument.
 * num_args     - number of arguments; must be 3.
 *
 * Returns 0 on success. On any failed check it records a message with
 * TVMAPISetLastError and returns -1 without running the computation.
 */
TVM_DLL int32_t myadd( void* args, void* arg_type_ids, int32_t num_args) {
if (!((num_args == 3))) {
TVMAPISetLastError("myadd: num_args should be 3");
return -1;
}
/* Unpack the handle and the type code of each argument. */
void* arg0 = (((TVMValue*)args)[0].v_handle);
int32_t arg0_code = (( int32_t*)arg_type_ids)[0];
void* arg1 = (((TVMValue*)args)[1].v_handle);
int32_t arg1_code = (( int32_t*)arg_type_ids)[1];
void* arg2 = (((TVMValue*)args)[2].v_handle);
int32_t arg2_code = (( int32_t*)arg_type_ids)[2];
/* Read data pointer, shape and strides from each array. The device type and
   device id are taken from arg0 and used as the reference for arg1/arg2. */
float* A = (float*)(((TVMArray*)arg0)[0].data);
int64_t* arg0_shape = (int64_t*)(((TVMArray*)arg0)[0].shape);
int64_t* arg0_strides = (int64_t*)(((TVMArray*)arg0)[0].strides);
int32_t dev_type = (((TVMArray*)arg0)[0].ctx.device_type);
int32_t dev_id = (((TVMArray*)arg0)[0].ctx.device_id);
float* B = (float*)(((TVMArray*)arg1)[0].data);
int64_t* arg1_shape = (int64_t*)(((TVMArray*)arg1)[0].shape);
int64_t* arg1_strides = (int64_t*)(((TVMArray*)arg1)[0].strides);
float* C = (float*)(((TVMArray*)arg2)[0].data);
int64_t* arg2_shape = (int64_t*)(((TVMArray*)arg2)[0].shape);
int64_t* arg2_strides = (int64_t*)(((TVMArray*)arg2)[0].strides);
/* Each argument's type code must be one of 3, 13, 7 or 4, which the error
   messages describe as "pointer".
   NOTE(review): confirm the enum names of these codes in c_runtime_api.h. */
if (!(((((arg0_code == 3) || (arg0_code == 13)) || (arg0_code == 7)) || (arg0_code == 4)))) {
TVMAPISetLastError("myadd: Expect arg[0] to be pointer");
return -1;
}
if (!(((((arg1_code == 3) || (arg1_code == 13)) || (arg1_code == 7)) || (arg1_code == 4)))) {
TVMAPISetLastError("myadd: Expect arg[1] to be pointer");
return -1;
}
if (!(((((arg2_code == 3) || (arg2_code == 13)) || (arg2_code == 7)) || (arg2_code == 4)))) {
TVMAPISetLastError("myadd: Expect arg[2] to be pointer");
return -1;
}
/* NOTE(review): device type 1 is presumably the CPU device -- confirm. */
if (!((dev_type == 1))) {
TVMAPISetLastError("device_type need to be 1");
return -1;
}
/* arg0 (A): must be 1-D, float32 (dtype code 2, 32 bits, 1 lane), of length
   10, compact (strides NULL or stride 1) and have a zero byte offset. */
if (!((1 == (((TVMArray*)arg0)[0].ndim)))) {
TVMAPISetLastError("arg0.ndim is expected to equal 1");
return -1;
}
if (!(((((((TVMArray*)arg0)[0].dtype.code) == (uint8_t)2) && ((((TVMArray*)arg0)[0].dtype.bits) == (uint8_t)32)) && ((((TVMArray*)arg0)[0].dtype.lanes) == (uint16_t)1)))) {
TVMAPISetLastError("arg0.dtype is expected to be float32");
return -1;
}
if (!((10 == ((int32_t)arg0_shape[0])))) {
TVMAPISetLastError("Argument arg0.shape[0] has an unsatisfied constraint");
return -1;
}
if (!(arg0_strides == NULL)) {
if (!((1 == ((int32_t)arg0_strides[0])))) {
TVMAPISetLastError("arg0.strides: expected to be compact array");
return -1;
}
}
if (!(((uint64_t)0 == (((TVMArray*)arg0)[0].byte_offset)))) {
TVMAPISetLastError("Argument arg0.byte_offset has an unsatisfied constraint");
return -1;
}
/* arg1 (B): same layout checks as arg0, plus device type 1 and the same
   device id as arg0. */
if (!((1 == (((TVMArray*)arg1)[0].ndim)))) {
TVMAPISetLastError("arg1.ndim is expected to equal 1");
return -1;
}
if (!(((((((TVMArray*)arg1)[0].dtype.code) == (uint8_t)2) && ((((TVMArray*)arg1)[0].dtype.bits) == (uint8_t)32)) && ((((TVMArray*)arg1)[0].dtype.lanes) == (uint16_t)1)))) {
TVMAPISetLastError("arg1.dtype is expected to be float32");
return -1;
}
if (!((10 == ((int32_t)arg1_shape[0])))) {
TVMAPISetLastError("Argument arg1.shape[0] has an unsatisfied constraint");
return -1;
}
if (!(arg1_strides == NULL)) {
if (!((1 == ((int32_t)arg1_strides[0])))) {
TVMAPISetLastError("arg1.strides: expected to be compact array");
return -1;
}
}
if (!(((uint64_t)0 == (((TVMArray*)arg1)[0].byte_offset)))) {
TVMAPISetLastError("Argument arg1.byte_offset has an unsatisfied constraint");
return -1;
}
if (!((1 == (((TVMArray*)arg1)[0].ctx.device_type)))) {
TVMAPISetLastError("Argument arg1.device_type has an unsatisfied constraint");
return -1;
}
if (!((dev_id == (((TVMArray*)arg1)[0].ctx.device_id)))) {
TVMAPISetLastError("Argument arg1.device_id has an unsatisfied constraint");
return -1;
}
/* arg2 (C, the output): same checks as arg1. */
if (!((1 == (((TVMArray*)arg2)[0].ndim)))) {
TVMAPISetLastError("arg2.ndim is expected to equal 1");
return -1;
}
if (!(((((((TVMArray*)arg2)[0].dtype.code) == (uint8_t)2) && ((((TVMArray*)arg2)[0].dtype.bits) == (uint8_t)32)) && ((((TVMArray*)arg2)[0].dtype.lanes) == (uint16_t)1)))) {
TVMAPISetLastError("arg2.dtype is expected to be float32");
return -1;
}
if (!((10 == ((int32_t)arg2_shape[0])))) {
TVMAPISetLastError("Argument arg2.shape[0] has an unsatisfied constraint");
return -1;
}
if (!(arg2_strides == NULL)) {
if (!((1 == ((int32_t)arg2_strides[0])))) {
TVMAPISetLastError("arg2.strides: expected to be compact array");
return -1;
}
}
if (!(((uint64_t)0 == (((TVMArray*)arg2)[0].byte_offset)))) {
TVMAPISetLastError("Argument arg2.byte_offset has an unsatisfied constraint");
return -1;
}
if (!((1 == (((TVMArray*)arg2)[0].ctx.device_type)))) {
TVMAPISetLastError("Argument arg2.device_type has an unsatisfied constraint");
return -1;
}
if (!((dev_id == (((TVMArray*)arg2)[0].ctx.device_id)))) {
TVMAPISetLastError("Argument arg2.device_id has an unsatisfied constraint");
return -1;
}
/* The computation itself: the 10-element loop split into 5 outer x 2 inner
   iterations, writing C[k] = A[k] + B[k] with k = i_outer * 2 + i_inner. */
for (int32_t i_outer = 0; i_outer < 5; ++i_outer) {
for (int32_t i_inner = 0; i_inner < 2; ++i_inner) {
C[((i_outer * 2) + i_inner)] = (A[((i_outer * 2) + i_inner)] + B[((i_outer * 2) + i_inner)]);
}
}
return 0;
}
最后通过`fadd.save("myadd.c")`保存文件。