PackedFunc
TVM 即可以在 C++ 端定义函数然后从 Python 端调用,反之亦可以。这一切都与 tvm/include/tvm/runtime/packed_func.h
中的 PackedFunc 类有关。
Constructor
我们先来看他的构造函数,用于将一个能转换成形如 std::function<void(TVMArgs, TVMRetValue*)>
的 TCallable
类型转包装成 PackedFunc,返回一个 PackedFuncSubObj
对象的智能指针。
并且重载了 ()
运算符,可以像调用函数一样调用 PackedFunc 对象。具体解释见下面注释。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| template <typename TCallable, typename = std::enable_if_t< std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value && !std::is_base_of<TCallable, PackedFunc>::value>> explicit PackedFunc(TCallable data) { using ObjType = PackedFuncSubObj<TCallable>; data_ = make_object<ObjType>(std::forward<TCallable>(data)); }
template <typename T, typename... Args> inline ObjectPtr<T> make_object(Args&&... args) { return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
template <typename... Args> inline TVMRetValue operator()(Args&&... args) const;
TVM_ALWAYS_INLINE void CallPacked(TVMArgs args, TVMRetValue* rv) const; bool operator==(std::nullptr_t null) const { return data_ == nullptr; } bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj); }
|
()
操作符重载的实现如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| template <typename... Args> inline TVMRetValue PackedFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...); TVMRetValue rv; (static_cast<PackedFuncObj*>(data_.get()))->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; }
|
TVMArgs & TVMRetValue
下面我们来看 TVMArgs 和 TVMRetValue 的定义。
TVMArgs
构造函数接受三个参数,分别代表参数值,参数值对应的数据类型代码,参数个数。
TVMValue
是一个 union,使用 DLDataType
来描述 TVMArgs 中数据的类型 (类型代码,位数,向量长度),使用 DLDevice
来描述该数据被存在哪种硬件上。他俩都位于第三方 dlpack 包中。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| class TVMArgs { public: const TVMValue* values; const int* type_codes; int num_args;
TVMArgs(const TVMValue* values, const int* type_codes, int num_args) : values(values), type_codes(type_codes), num_args(num_args) {} inline int size() const; inline TVMArgValue operator[](int i) const;
template <typename T> inline T At(int i) const; };
typedef union { int64_t v_int64; double v_float64; void* v_handle; const char* v_str; DLDataType v_type; DLDevice v_device; } TVMValue;
|
TVMRetValue
继承自 TVMPODValue_CRTP_<TVMRetValue>
,基类为 TVMPODValue_
,从名字看出他与 C++ 中的 Plain Old Data 行为类似。
▶
Curious Recurring Template Pattern (CRTP)
Curiously Recurring Template Pattern (CRTP,奇异递归模板模式) 是一种 C++ 模板元编程技巧,它允许在基类中使用派生类的类型信息,从而实现静态多态性,避免运行时类型检查的开销。
CRTP 的核心在于,派生类将自身作为模板参数传递给基类。 这使得基类可以在编译时访问派生类的类型信息,从而在基类中实现一些依赖于派生类类型的方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| template <typename Derived> class Base { public: void print() { static_cast<Derived*>(this)->print_impl(); } };
class Derived : public Base<Derived> { public: void print_impl() { std::cout << "Hello from Derived!" << std::endl; } };
int main() { Derived d; d.print(); return 0; }
|
- Base
是一个模板类,它接受一个类型参数
Derived`. Derived
继承自 Base<Derived>
,将自身作为模板参数传递给 Base。Base
类中的 print()
方法使用 static_cast
将 this 指针转换为 Derived
类型指针,然后调用 print_impl()
方法。 print_impl()
方法是在 Derived
类中定义的。
- 重载了各种类型转换运算符 (例如
operator void*()
, operator DLTensor*()
, operator NDArray()
…),允许以类型安全的方式访问 TVMValue
中存储的值并转换为对应的类型。
TryAsBool()
, TryAsInt()
, TryAsFloat()
等辅助函数提供了一种更安全的方式来尝试将 TVMValue
转换为布尔值、整数或浮点数,并返回 std::optional
来指示转换是否成功。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| class TVMPODValue_ { public:
protected: friend class TVMArgsSetter; friend class TVMRetValue; friend class TVMMovableArgValue_; TVMPODValue_() : type_code_(kTVMNullptr) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {}
TVMValue value_; int type_code_; };
|
C++ Function Register
C++ 端的函数注册宏定义是 TVM_REGISTER_GLOBAL,将生成的唯一变量名赋值为 ::tvm::runtime::Registry::Register(OpName)
的返回值。
1 2
| #define TVM_REGISTER_GLOBAL(OpName) \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName)
|
Register 函数会返回一个 Registry 类,提供多种方法来设置函数体。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| class Registry { public: Registry& set_body(PackedFunc f); template <typename FLambda> Registry& set_body_typed(FLambda f); template <typename T, typename R, typename... Args> Registry& set_body_method(R (T::*f)(Args...)); static Registry& Register(const std::string& name); static const PackedFunc* Get(const std::string& name); static std::vector ListNames(); protected: std::string name_; PackedFunc func_; friend struct Manager; };
|
Manager 的定义如下。
fmap
用于存储已注册的函数
mutex
是一个互斥锁,用于保护 fmap 的线程安全。
Global()
静态方法,返回一个全局的静态 Manager 对象 (只在第一次调用时被初始化).
1 2 3 4 5 6 7 8 9 10 11 12 13
| struct Registry::Manager {
std::unordered_map<String, Registry*> fmap; std::mutex mutex;
Manager() {}
static Manager* Global() { static Manager* inst = new Manager(); return inst; } };
|
Registry::Register
函数定义如下。使用 Manager 类来管理已注册的函数,并使用互斥锁来保证线程安全。 它允许注册新的函数,并检查函数名是否冲突,以及是否允许覆盖已存在的函数。 返回 Registry 对象的引用方便后续设置函数体等操作。
1 2 3 4 5 6 7 8 9 10 11 12
| Registry& Registry::Register(const String& name, bool can_override) { Manager m = Manager::Global(); std::lock_guardstd::mutex lock(m->mutex); if (m->fmap.count(name)) { ICHECK(can_override) << "Global PackedFunc " << name << " is already registered"; }
Registry* r = new Registry(); r->name_ = name; m->fmap[name] = r; return *r; }
|
▶
An Register Example in C++
1 2 3 4 5 6
| struct Example { int doThing(int x); }
TVM_REGISTER_GLOBAL("Example_doThing") .set_body_method(&Example::doThing);
|
Python Call C++ Function
python 端调用 C++ 函数最终都会进入到 _init_api_prefix 函数
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| def _init_api_prefix(module_name, prefix): module = sys.modules[module_name]
for name in list_global_func_names(): if not name.startswith(prefix): continue
fname = name[len(prefix) + 1 :] target_module = module
if fname.find(".") != -1: continue f = get_global_func(name) ff = _get_api(f) ff.__name__ = fname ff.__doc__ = "TVM PackedFunc %s. " % fname setattr(target_module, ff.__name__, ff)
|
get_global_func
会调用 _get_global_func 返回一个 python 端的 PackedFunc
对象。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| def _get_global_func(name, allow_missing=False): handle = PackedFuncHandle() check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))
if handle.value: return _make_packed_func(handle, False)
if allow_missing: return None
raise ValueError("Cannot find global function %s" % name)
def _make_packed_func(handle, is_global): """Make a packed function class""" obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC) obj.is_global = is_global obj.handle = handle return obj
|
上面代码中 _CLASS_PACKED_FUNC 是一个全局变量, 在 /python/tvm/runtime/packed_func.py
中会调用 _set_class_packed_func 将其设置为 python 端的 PackedFunc
类。
python 端的 PackedFunc
类继承自 PackedFuncBase 类,它本质上就是在 python 端定义了一个和 C++ 端具有相同形式的类,将传入的参数转换成 C++ 端的 TVMArgs, 再通过 ctype 加载的 lib 调用 C++ 端的 PackedFunc 对象,并将返回值转换成 python 端的 TVMRetValue.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| class PackedFuncBase(object): """Function base."""
__slots__ = ["handle", "is_global"]
def __init__(self, handle, is_global): self.handle = handle self.is_global = is_global
def __del__(self): if not self.is_global and _LIB is not None: if _LIB.TVMFuncFree(self.handle) != 0: raise_last_ffi_error()
def __call__(self, *args): temp_args = [] values, tcodes, num_args = _make_tvm_args(args, temp_args) ret_val = TVMValue() ret_tcode = ctypes.c_int() if ( _LIB.TVMFuncCall( self.handle, values, tcodes, ctypes.c_int(num_args), ctypes.byref(ret_val), ctypes.byref(ret_tcode), ) != 0 ): raise_last_ffi_error() _ = temp_args _ = args return RETURN_SWITCH[ret_tcode.value](ret_val)
|
C++ Call Python Function
register_func 实现注册 python 端的函数。如果 f 已经提供,则直接注册并返回已注册的函数。如果 f 未提供,则返回一个 register 函数以延迟注册 (通常以 decorator 形式存在).
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| def register_func(func_name, f=None, override=False): if callable(func_name): f = func_name func_name = f.__name__
if not isinstance(func_name, str): raise ValueError("expect string function name")
ioverride = ctypes.c_int(override)
def register(myf): """internal register function""" if not isinstance(myf, PackedFuncBase): myf = convert_to_tvm_func(myf) check_call(_LIB.TVMFuncRegisterGlobal(c_str(func_name), myf.handle, ioverride)) return myf
if f: return register(f) return register
|
内部的 register 函数通过 convert_to_tvm_func 将一个 Python 函数转换为 TVM 函数 (tvm.nd.Function),与 C++ 代码进行交互。
内部定义的 cfun 是一个 C++ 风格的回调函数,它将接收的 C++ 参数转换为 Python 参数,调用 pyfunc,并将结果返回给 C++ 端。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| def convert_to_tvm_func(pyfunc): """Convert a python function to TVM function
Parameters ---------- pyfunc : python function The python function to be converted.
Returns ------- tvmfunc: tvm.nd.Function The converted tvm function. """ local_pyfunc = pyfunc
def cfun(args, type_codes, num_args, ret, _): """ctypes function""" num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args pyargs = (C_TO_PY_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)) try: rv = local_pyfunc(*pyargs) except Exception as err: msg = traceback.format_exc() msg = py2cerror(msg) _LIB.TVMAPISetLastPythonError(ctypes.py_object(err))
return -1
if rv is not None: if isinstance(rv, tuple): raise ValueError("PackedFunction can only support one return value") temp_args = [] values, tcodes, _ = _make_tvm_args((rv,), temp_args) if not isinstance(ret, TVMRetValueHandle): ret = TVMRetValueHandle(ret) if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0: raise_last_ffi_error() _ = temp_args _ = rv return 0
handle = PackedFuncHandle() f = TVMPackedCFunc(cfun) pyobj = ctypes.py_object(f) ctypes.pythonapi.Py_IncRef(pyobj) if _LIB.TVMFuncCreateFromCFunc(f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0: raise_last_ffi_error() return _make_packed_func(handle, False)
|
注册完成之后该函数已经存在于 Manager 中,可以在 C++ 端调用如下调用
1 2 3 4 5 6 7 8 9 10
| tvm::runtime::PackedFunc f = tvm::runtime::Registry::Get("my_python_function");
const PackedFunc* Registry::Get(const String& name) { Manager* m = Manager::Global(); std::lock_guard<std::mutex> lock(m->mutex); auto it = m->fmap.find(name); if (it == m->fmap.end()) return nullptr; return &(it->second->func_); }
|
▶
An Register Example in Python
1 2 3 4 5
| def callback(msg): print(msg)
tvm.register_func("my_call_back", callback)
|
Summary
TVM 的互调机制可以简述为:在 C++ 和 Python 两边使用了一个统一的函数原型 void(TVMArgs args, TVMRetValue *rv)
,这就是PackedFunc
类的机制,实现主要是重载了函数调用运算符 ()
,真正的函数体是通过 set_body 去设置的。
相互调用其实是每次去全局注册函数表中寻找相应的函数名,然后做两种语言之间PackedFunc对象的转换,再去执行。