日韩性视频-久久久蜜桃-www中文字幕-在线中文字幕av-亚洲欧美一区二区三区四区-撸久久-香蕉视频一区-久久无码精品丰满人妻-国产高潮av-激情福利社-日韩av网址大全-国产精品久久999-日本五十路在线-性欧美在线-久久99精品波多结衣一区-男女午夜免费视频-黑人极品ⅴideos精品欧美棵-人人妻人人澡人人爽精品欧美一区-日韩一区在线看-欧美a级在线免费观看

歡迎訪問(wèn) 生活随笔!

生活随笔

當(dāng)前位置: 首頁(yè) > 人文社科 > 生活经验 >内容正文

生活经验

TVM,Relay,Pass

發(fā)布時(shí)間:2023/11/28 生活经验 37 豆豆
生活随笔 收集整理的這篇文章主要介紹了 TVM,Relay,Pass 小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.

TVM,Relay,Pass
Relay介紹
主要結(jié)合TVM的文檔(https://tvm.apache.org/docs/dev/relay_intro.html),介紹一下NNVM的第二代Relay。Relay的設(shè)計(jì)目標(biāo)有以下幾點(diǎn):
支持傳統(tǒng)的數(shù)據(jù)流(DataFlow)風(fēng)格編程。支持functional-style scoping,并融合了編程語(yǔ)言領(lǐng)域的一些知識(shí),帶了一些新的特性(支持Let表達(dá)式,支持遞歸等等)支持?jǐn)?shù)據(jù)流風(fēng)格和函數(shù)式風(fēng)格混合編程。
使用Relay建立一個(gè)計(jì)算圖
傳統(tǒng)的深度學(xué)習(xí)框架使用計(jì)算圖作為的中間表示。計(jì)算圖(或數(shù)據(jù)流圖)是代表計(jì)算過(guò)程的有向無(wú)環(huán)圖(DAG)。盡管由于缺少控制流,數(shù)據(jù)流圖在計(jì)算能力方面受到限制,但簡(jiǎn)單性使其易于實(shí)現(xiàn)自動(dòng)微分,并針對(duì)異構(gòu)執(zhí)行環(huán)境進(jìn)行編譯(例如,在專用硬件上執(zhí)行計(jì)算圖的某些部分,即子圖)。

使用Relay構(gòu)建一個(gè)簡(jiǎn)單的計(jì)算圖示例代碼,對(duì)應(yīng)的文本形式和AST抽象語(yǔ)法樹(shù),可以使用Relay來(lái)構(gòu)建一個(gè)計(jì)算(DataFlow)圖。具體來(lái)說(shuō),上面的代碼顯示了如何構(gòu)造一個(gè)簡(jiǎn)單的兩個(gè)節(jié)點(diǎn)的計(jì)算圖,可以發(fā)現(xiàn)這個(gè)示例的代碼和現(xiàn)有的Garph IR如NNVMv1沒(méi)有太大區(qū)別,唯一的區(qū)別是在術(shù)語(yǔ)方面:
現(xiàn)有框架通常使用圖和子圖Relay使用函數(shù),例如 – fn(%x),表示圖每個(gè)數(shù)據(jù)流節(jié)點(diǎn),都是Relay中的一個(gè)CallNode。通過(guò)Relay的Python DSL,可以快速構(gòu)建計(jì)算圖。上面的代碼需要注意,這里顯示構(gòu)造了一個(gè)Add節(jié)點(diǎn),兩個(gè)輸入都指向%1。當(dāng)一個(gè)深度學(xué)習(xí)框架。對(duì)上面的計(jì)算圖進(jìn)行推理時(shí),將會(huì)按照拓?fù)湫蜻M(jìn)行計(jì)算,并且%1只會(huì)被計(jì)算一次。雖然這個(gè)事實(shí)對(duì)于深度學(xué)習(xí)框架的開(kāi)發(fā)者,一件很自然的事情,但這或許會(huì)使得只關(guān)心算法的研究員困惑。如果實(shí)現(xiàn)一個(gè)簡(jiǎn)單的vistor打印結(jié)果,將結(jié)果視為嵌套的Call表達(dá)式,將是log(%x) + log(%x)。
當(dāng)DAG中存在共享節(jié)點(diǎn)時(shí),這種歧義是由程序語(yǔ)義的解釋不同引起的。在正常的函數(shù)式編程IR中,嵌套表達(dá)式被視為表達(dá)式樹(shù),沒(méi)有考慮%1,實(shí)際上在%2中被重用了2次的事實(shí)。
Relay IR注意到了這個(gè)區(qū)別。其實(shí)深度學(xué)習(xí)框架用戶,經(jīng)常使用這種方式構(gòu)建計(jì)算圖,其中經(jīng)常發(fā)生DAG節(jié)點(diǎn)重用。然后以文本格式打印Relay程序時(shí),每行打印一個(gè)CallNode,并為每個(gè)CallNode分配一個(gè)臨時(shí)ID(%1, %2),以便可以在程序的后續(xù)部分中引用每個(gè)公共節(jié)點(diǎn)。
Module:支持多個(gè)函數(shù)(Graphs)
上面介紹了如何構(gòu)建一個(gè)數(shù)據(jù)流圖為一個(gè)函數(shù)。然后一個(gè)很自然的問(wèn)題是可以做到構(gòu)建多個(gè)函數(shù)并相互調(diào)用嗎?Relay允許將多個(gè)函數(shù)組合在一個(gè)Module中,下面的代碼展示了一個(gè)函數(shù)調(diào)用另外一個(gè)函數(shù)的例子。
def @muladd(%x, %y, %z) { %1 = mul(%x, %y) %2 = add(%1, %z) %2}def @myfunc(%x) { %1 = @muladd(%x, 1, 2) %2 = @muladd(%1, 2, 3) %2}
Module可以被看作Map<GlobalVar, Function>,GlobalVar僅僅是一個(gè)表示函數(shù)名的ID,上面的程序中GlobalVar是@muladd和@myfunc。當(dāng)一個(gè)CallNode調(diào)用另外一個(gè)函數(shù)時(shí),相應(yīng)的GlobalVar被存在CallNode的OP中。包含了一個(gè)間接的等級(jí)關(guān)系—需要使用相應(yīng)的GlobalVar,從Module中查找調(diào)用函數(shù)的主體。也可以直接將引用的函數(shù)存儲(chǔ)為CallNode中的OP。為什么需要引入GlobalVar呢?主要原因是為了解耦定義和聲明,并支持了函數(shù)的遞歸和延遲聲明。
def @myfunc(%x) { %1 = equal(%x, 1)if (%1) { %x } else { %2 = sub(%x, 1) %3 = @myfunc(%2) %4 = add(%3, %3) %4 }}在上面的例子中,@myfunc遞歸調(diào)用。使用GlobalVar @myfunc表示函數(shù),避免了數(shù)據(jù)結(jié)構(gòu)中的循環(huán)依賴性。至此,已經(jīng)介紹完了Relay中的基本概念。相比NNVM,Relay在如下方面進(jìn)行了改進(jìn):
有文本形式中間表示,便于開(kāi)發(fā)和 debug支持子圖函數(shù)、聯(lián)合模塊,便于聯(lián)合優(yōu)化前端用戶友好,便于調(diào)優(yōu)0x2.3 Let Binding and Scopes
至此,已經(jīng)介紹了如何用深度學(xué)習(xí)框架中的舊方法,構(gòu)建計(jì)算圖。這一節(jié)將討論一個(gè)Relay的一個(gè)新的構(gòu)造-let bindings。
Let binding被每一種高級(jí)的編程語(yǔ)言應(yīng)用。在Relay中,一個(gè)擁有三個(gè)字段Let(var, value, body)的數(shù)據(jù)結(jié)構(gòu)。計(jì)算一個(gè)Let表達(dá)式時(shí),首先計(jì)算value部分,然后將其綁定到var,最后在body表達(dá)式中返回計(jì)算結(jié)果。
可以使用一系列的Let綁定,構(gòu)造一個(gè)邏輯上等效于數(shù)據(jù)流程序的程序,下面的代碼示例顯示了這個(gè)用法:

Let表達(dá)式構(gòu)造和數(shù)據(jù)流程序等價(jià)的,計(jì)算圖嵌套的Let Binding,稱作A-normal形式,作為函數(shù)式編程語(yǔ)言中的常用IR。通過(guò)上面的圖,可以發(fā)現(xiàn)雖然這兩個(gè)程序的語(yǔ)義完全等價(jià),文本表示也一樣(除了A-norm形式有l(wèi)et的前綴),但AST抽象語(yǔ)法樹(shù)卻不一樣。
由于程序的優(yōu)化,使用了這些AST數(shù)據(jù)結(jié)構(gòu)進(jìn)行了變換,這兩種不同的結(jié)構(gòu),影響到最終編譯器生成的代碼。比如,想要檢測(cè)add(log(x), y)這個(gè)模式。在數(shù)據(jù)流程序中,可以首先進(jìn)入add節(jié)點(diǎn),然后直接檢查第一個(gè)參數(shù)是不是log。在A-form的程序中,不能直接檢查任何東西,因?yàn)閍dd節(jié)點(diǎn)的輸入是%v1-需要維護(hù)一個(gè)映射表,將變量和綁定的值進(jìn)行映射,然后查表才知道%v1代表的是log。
為什么可能需要Let Binding
Let Binding的一種關(guān)鍵用法,可以指定計(jì)算的scope。看一下下面這個(gè)沒(méi)有使用Let Binding的例子:

沒(méi)有使用Let Binding編程的一個(gè)例子,當(dāng)嘗試在該在哪里計(jì)算%1節(jié)點(diǎn)時(shí),問(wèn)題就來(lái)了。特別的是,雖然文本格式似乎建議,應(yīng)該在if的scope之外,計(jì)算節(jié)點(diǎn)%1,但AST卻不建議這樣做。實(shí)際上數(shù)據(jù)流圖,永遠(yuǎn)不會(huì)定義計(jì)算scope,這在語(yǔ)義上產(chǎn)生了一些歧義。
當(dāng)有閉包時(shí),這種歧義更加有趣,考慮下面的程序,該程序返回一個(gè)閉包。不知道在哪里計(jì)算%1,可以在閉包的內(nèi)部和外部。
fn (%x) { %1 = log(%x) %2 = fn(%y) { add(%y, %1) } %2}Let Binding解決了這些問(wèn)題,因?yàn)橹档挠?jì)算發(fā)生在let節(jié)點(diǎn)上。在這兩個(gè)程序中,如果將%1 = log(%x)改成let %v1 = log(%x),將計(jì)算位置明確指定為if scope和閉包之外。Let Binding為計(jì)算端提供了更精確的范圍,在生成后端代碼時(shí)會(huì)很有用(因?yàn)檫@種范圍在IR中)。
另一方面,沒(méi)有指定計(jì)算scope的數(shù)據(jù)流形式,也有其自身的優(yōu)勢(shì),不需要擔(dān)心在生成代碼時(shí),將let放到哪里。數(shù)據(jù)流格式還為后面決定將計(jì)算節(jié)點(diǎn)放到哪里的Passes,提供了更大的自由度。因此,在優(yōu)化的初始階段,如果發(fā)現(xiàn)數(shù)據(jù)流形式,還是挺方便的,那么,使用數(shù)據(jù)流圖的編碼方法,可能不是一個(gè)壞主意。目前在Relay中也實(shí)現(xiàn)了很多針對(duì)數(shù)據(jù)流圖的優(yōu)化方式。
但是,當(dāng)將IR lower到實(shí)際的運(yùn)行時(shí)程序時(shí),需要精確的計(jì)算scope。特別是當(dāng)使用子函數(shù)和閉包時(shí),要明確指定計(jì)算scope,應(yīng)在哪里發(fā)生。在后期執(zhí)行特定的優(yōu)化中,可以使用Let Binding來(lái)解決此問(wèn)題。
對(duì)IR轉(zhuǎn)換的影響
希望到目前為止,已經(jīng)熟悉兩種表示形式。大多數(shù)函數(shù)式編程語(yǔ)言都以A-normal形式進(jìn)行分析,分析人員無(wú)需注意表達(dá)式是DAG。
Relay選擇同時(shí)支持?jǐn)?shù)據(jù)流形式和Let Binding。TVM相信讓框架開(kāi)發(fā)者選擇熟悉的表達(dá)形式很重要。但是這確實(shí)對(duì)寫(xiě)通用的Passes產(chǎn)生了一些影響。這里還沒(méi)介紹Passes,對(duì)Passes理解不深,沒(méi)有使用過(guò)Let表達(dá)式來(lái)構(gòu)建網(wǎng)絡(luò),就不繼續(xù)介紹具體有哪些影響了。
詳細(xì)內(nèi)容可以參考:https://tvm.apache.org/docs/dev/relay_intro.html#let-binding-and-scopes
基于Relay構(gòu)建一個(gè)自定義的神經(jīng)網(wǎng)絡(luò)示例
基于Relay的接口定義一個(gè)Conv+BN+ReLU的小網(wǎng)絡(luò),展示一下Relay接口應(yīng)該如何使用,這里TVM版本是0.8.0.dev,代碼如下:
#coding=utf-8import tvmfrom tvm import relayimport numpy as npfrom tvm.contrib import graph_executor# 構(gòu)造BNdefbatch_norm(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs): name = kwargs.get(“name”) kwargs.pop(“name”)ifnot gamma: gamma = relay.var(name + “_gamma”)ifnot beta: beta = relay.var(name + “_beta”)ifnot moving_mean: moving_mean = relay.var(name + “_moving_mean”)ifnot moving_var: moving_var = relay.var(name + “_moving_var”)return relay.nn.batch_norm(data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs)[0]# 構(gòu)造卷積defconv2d(data, weight=None, **kwargs): name = kwargs.get(“name”) kwargs.pop(“name”)ifnot weight: weight = relay.var(name + “_weight”)return relay.nn.conv2d(data, weight, **kwargs)# 構(gòu)造卷積+BN+ReLU的simpleNetdefsimplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), epsilon=1e-5): conv = conv2d( data=data, channels=channels, kernel_size=kernel_size, strides=strides, padding=padding, data_layout=‘NCHW’, name=name+’_conv’) bn = batch_norm(data=conv, epsilon=epsilon, name=name + ‘_bn’) act = relay.nn.relu(data=bn)return actdata_shape = (1, 3, 224, 224)kernel_shape = (32, 3, 3, 3)dtype = "float32"data = relay.var(“data”, shape=data_shape, dtype=dtype)act = simplenet(data, “graph”, 32, strides=(2, 2))func = relay.Function(relay.analysis.free_vars(act), act)print(func)np_data = np.random.uniform(-1, 1, (1, 3, 224, 224))params = {“graph_conv_weight”: tvm.nd.array(np.random.uniform(-1, 1, (32, 3, 3, 3)).astype(dtype)),“graph_bn_gamma”: tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),“graph_bn_beta”: tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),“graph_bn_moving_mean”: tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),“graph_bn_moving_var”: tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),}with tvm.transform.PassContext(opt_level=3): lib = relay.build(func, “l(fā)lvm”, params=params)dev = tvm.cpu(0)dtype = "float32"m = graph_executor.GraphModule(lib"default")# set inputsm.set_input(“data”, tvm.nd.array(np_data.astype(dtype)))# executem.run()# get outputstvm_output = m.get_output(0)
就是一個(gè)很常規(guī)的過(guò)程,創(chuàng)建Relay Function,然后將所有的OP的權(quán)重信息用params這個(gè)字典存起來(lái),注意這里的權(quán)重信息是隨機(jī)初始化的。在編譯Relay IR之前可以先看一下優(yōu)化前的IR長(zhǎng)什么樣:
fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var) { %0 = nn.conv2d(%data, %graph_conv_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]); %1 = nn.batch_norm(%0, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var); %2 = %1.0; nn.relu(%2)}符合第二節(jié)介紹的規(guī)則,Relay IR時(shí)一個(gè)函數(shù)。
初識(shí)Pass
上面構(gòu)造simplenet的代碼中,relay.build外部包了一層tvm.transform.PassContext,如下:
with tvm.transform.PassContext(opt_level=3): lib = relay.build(func, “l(fā)lvm”, params=params)實(shí)際上tvm.transform.PassContext這個(gè)接口就定義了Pass,如文檔所示:

tvm.transform.PassContext用來(lái)控制對(duì)relay IR使用哪些Pass進(jìn)行優(yōu)化,Pass是TVM中基于Relay IR進(jìn)行的一系列優(yōu)化,類似于onnx-simplifier里面用到的onnxoptimizer,可以簡(jiǎn)化計(jì)算圖,去除一些冗余的算子,提高模型的推理效率。TVM將所有的pass都抽象到了tvm/include/tvm/ir/transform.h這個(gè)文件中,主要包含PassContext,PassInfo,Pass,以及Sequential。
這里的PassContext是上面Python接口對(duì)應(yīng)的C++實(shí)現(xiàn),包含了Pass執(zhí)行依賴的一些參數(shù),如優(yōu)化level,依賴特定Pass,以及設(shè)置不使用某種指定Pass等。PassInfo是用來(lái)記錄Pass信息的類,包含Pass的opy_level,name,以及當(dāng)前Pass需要哪些前置Pass。而Pass這個(gè)類,就執(zhí)行pass的主體,這是一個(gè)基類,每種Pass具體的C++代碼實(shí)現(xiàn)在tvm/src/relay/transforms中,都會(huì)繼承Pass這個(gè)基類。最后,Sequential是一個(gè)container,裝載所有Pass。
需要說(shuō)明一下,不是所有的Pass都定義在tvm/src/relay/transforms,比如下面的第一個(gè)例子,就在tvm/src/relay/backend/vm文件夾里。接下來(lái)將幾個(gè)Pass的例子,到底對(duì)Relay IR做了什么?
RemoveUnusedFunctions首先來(lái)看一下定義在tvm/src/relay/backend/vm/removed_unused_funcs.cc這里的RemoveUnusedFunctions 這個(gè)pass,核心的代碼實(shí)現(xiàn)如下:
voidVisitExpr_(const FunctionNode* func_node)final{auto func = GetRef(func_node);if (visiting_.find(func) == visiting_.end()) { visiting_.insert(func);for (auto param : func_node->params) { ExprVisitor::VisitExpr(param); } ExprVisitor::VisitExpr(func_node->body); } }IRModule RemoveUnusedFunctions(const IRModule& module, Arrayruntime::String entry_funcs){std::unordered_setstd::string called_funcs{};for (auto entry : entry_funcs) {auto funcs = CallTracer(module).Trace(entry); called_funcs.insert(funcs.cbegin(), funcs.cend()); }auto existing_functions = module->functions;for (auto f : existing_functions) {auto it = called_funcs.find(f.first->name_hint);if (it == called_funcs.end()) {module->Remove(f.first); } }returnmodule;}
這個(gè)pass就是去除Relay IR中的冗余節(jié)點(diǎn),VisitExpr_這個(gè)函數(shù)就是完成了一個(gè)圖的遍歷,然后把沒(méi)有遍歷到的節(jié)點(diǎn)刪掉。刪除發(fā)生在RemoveUnusedFunctions這個(gè)函數(shù)中。
ToBasicBlockNormalForm這個(gè)Pass實(shí)現(xiàn)在tvm/src/relay/transforms/to_basic_block_normal_form.cc,代碼實(shí)現(xiàn)如下:
Expr ToBasicBlockNormalFormAux(const Expr& e){// calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e);/* The scope of the whole expr is global. * The scope of any subexpr, is the lowest common ancestor of all incoming edge. * We also record the set of expressions whose scope is lifted. /std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);}IRModule ToBasicBlockNormalForm(const IRModule& mod){ DLOG(INFO) << “ToBBlock:” << std::endl << mod; tvm::Map<GlobalVar, Function> updates;auto funcs = mod->functions;for (constauto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << “Expected no free variables”;if (constauto n = it.second.as()) {if (n->GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); updates.Set(it.first, Downcast(ret)); }for (auto pair : updates) { mod->Add(pair.first, pair.second, true); } DLOG(INFO) << “ToBBlock: transformed” << std::endl << mod;return mod;}boolBasicBlockNormalFormCheck(const Expr& e){// calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e);std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);for (auto expr : scopes.second) { LOG(FATAL) << "The expression below violates the basic block normal form in that " << “its scope should be lifted:\n” << expr; }return scopes.second.size() == 0;}ToBasicBlockNormalForm
這個(gè)函數(shù)通過(guò)遍歷Relay IR中的function,將每個(gè)function轉(zhuǎn)換為基本塊形式(即ToBasicBlockNormalFormAux這個(gè)函數(shù)),ToBasicBlockNormalFormAux這個(gè)函數(shù)分成以下幾個(gè)部分:
調(diào)用DependencyGraph dg = DependencyGraph::Create(&arena, e)創(chuàng)建一個(gè)DependencyGraph,這個(gè)數(shù)據(jù)結(jié)構(gòu)是一個(gè)表達(dá)式相互依賴的圖結(jié)構(gòu)。通過(guò)std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg)計(jì)算每個(gè)節(jié)點(diǎn)的scope,這個(gè)scope可以簡(jiǎn)單理解為由跳轉(zhuǎn)指令如Ifnode,FunctionNode,LetNode等隔開(kāi)的那些子圖,因?yàn)橐坏┡龅竭@些節(jié)點(diǎn)在上面通過(guò)Relay Function創(chuàng)建DependencyGraph就會(huì)為這種節(jié)點(diǎn)分配一個(gè)new_scope標(biāo)志。然后CalcScope這個(gè)函數(shù)具體做了哪些事情,需要跟進(jìn)去看一下:std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg){ NodeScopeMap expr_scope; ExprSet lifted_exprs;std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;// 首先讓每個(gè)節(jié)點(diǎn)都屬于一個(gè)單獨(dú)的scopefor (auto expr_node : dg.expr_node) { node_to_expr[expr_node.second] = expr_node.first; }bool global_scope_used = false; Scope global_scope = std::make_shared();// 使用LCA算法來(lái)更新每個(gè)節(jié)點(diǎn)的真正scopefor (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it;auto iit = n->parents.head; Scope s;if (iit == nullptr) { ICHECK(!global_scope_used); s = global_scope; global_scope_used = true; } else { s = expr_scope.at(iit->value);constauto original_s = s; iit = iit->next;for (; iit != nullptr; iit = iit->next) { s = LCA(s, expr_scope.at(iit->value)); }if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {// filter out exprs whose scope do not matter Expr expr = node_to_expr[n];if (!expr.as()) { lifted_exprs.insert(expr); } } }if (n->new_scope) {auto child_scope = std::make_shared(s); expr_scope.insert({n, child_scope}); } else { expr_scope.insert({n, s}); } } ICHECK(global_scope_used);returnstd::make_pair(expr_scope, lifted_exprs);}
這個(gè)函數(shù)首先讓每個(gè)節(jié)點(diǎn)都屬于一個(gè)單獨(dú)的scope,然后使用LCA算法來(lái)更新每個(gè)節(jié)點(diǎn)的真正scope。這里簡(jiǎn)單介紹一下LCA算法以及這里具體是如何求取每個(gè)節(jié)點(diǎn)的scope的。
最近公共祖先簡(jiǎn)稱 LCA(Lowest Common Ancestor)。兩個(gè)節(jié)點(diǎn)的最近公共祖先,就是這兩個(gè)點(diǎn)的公共祖先里面,離根最遠(yuǎn)的那個(gè)。為了方便,記某點(diǎn)集 的最近公共祖先為 或 。LCA有以下性質(zhì),引自O(shè)I-wiki:

其實(shí)不看這個(gè)性質(zhì)也沒(méi)關(guān)系,了解LCA,可以求圖中兩個(gè)節(jié)點(diǎn)的最近公共祖先即可。然后CalcScope這個(gè)函數(shù)的具體思路,先將每個(gè)節(jié)點(diǎn)初始化為一個(gè)單獨(dú)的scope,然后按照后DFS序遍歷這些節(jié)點(diǎn),對(duì)于每一個(gè)遍歷到的節(jié)點(diǎn)(這里記作n),看一下它的父親節(jié)點(diǎn)iit是否存在,如果不存在則說(shuō)明當(dāng)前節(jié)點(diǎn)是根節(jié)點(diǎn),scope應(yīng)該為global_scope。如果iit存在,那么遍歷iit的子節(jié)點(diǎn),看一下這些節(jié)點(diǎn)的scope的LCA表達(dá)式,如果這個(gè)通過(guò)LCA求出來(lái)的表達(dá)式和iit節(jié)點(diǎn)的表達(dá)式完全相同,說(shuō)明這個(gè)子圖和當(dāng)前節(jié)點(diǎn)是屬于同一個(gè)scope的,否則就將當(dāng)前節(jié)點(diǎn)插入到lifted_exprs,lifted_exprs是一個(gè)集合用來(lái)保存這個(gè)DependencyGraph里面的那些跳轉(zhuǎn)指令節(jié)點(diǎn),這也是為什么上面再插入節(jié)點(diǎn)到lifted_exprs之前,需要判斷一下這個(gè)節(jié)點(diǎn)的類型是否為OpNode。另外如果當(dāng)前枚舉的節(jié)點(diǎn)有new_scope標(biāo)志,說(shuō)明當(dāng)前節(jié)點(diǎn)屬于一個(gè)新的scope,需要為當(dāng)前節(jié)點(diǎn)分配新的類型為ScopeNode的一個(gè)智能指針。
通過(guò)上面的算法,DependencyGraph中的節(jié)點(diǎn)和scope節(jié)點(diǎn)的關(guān)系就被映射到了一個(gè)map中,并且scope節(jié)點(diǎn)也被建立起了一個(gè)樹(shù)結(jié)構(gòu)。最后調(diào)用這個(gè)Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);來(lái)創(chuàng)建一個(gè)Fill類,這個(gè)類包含了DependencyGraph以及scope相關(guān)的信息,通過(guò)ToBasicBlockNormalForm成員函數(shù)實(shí)現(xiàn)基本塊轉(zhuǎn)換。實(shí)現(xiàn)在tvm/src/relay/transforms/to_a_normal_form.cc這個(gè)文件中,知乎對(duì)這個(gè)Pass也做了解釋,這里引用一下:
它(ToBasicBlockNormalForm)的基本邏輯通過(guò)VisitExpr函數(shù)遍歷dependency節(jié)點(diǎn),將具有相同scope的節(jié)點(diǎn)壓入到同一個(gè)let_list中。Let_list文檔中是這樣解釋的:
/! * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, * and pass them around freely without fear of AST explosion (or effect duplication). * for example, if one write ‘b = a + a; c = b + b; d = c + c’, the AST will contain 8 ‘a(chǎn)’. * if one instead write ‘b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);’, * the AST will contain 2 ‘a(chǎn)’, as b and c are now variables.
Let_list使得抽象語(yǔ)法樹(shù)簡(jiǎn)潔化,不會(huì)因?yàn)樽兞康膹?fù)制導(dǎo)致樹(shù)的爆炸。具有相同的scope的expr被約束到相同的let_list中,用一個(gè)var來(lái)表達(dá),這樣就將表達(dá)式轉(zhuǎn)化為var的形式。一個(gè)var也就對(duì)應(yīng)了一個(gè)基本塊。
EliminateCommonSubexpr最后再看一個(gè)消除公共子表達(dá)式的Pass,所謂公共子表達(dá)式指的就是具有相同的OP類型以及相同的參數(shù),參數(shù)的順序都是完全相同的,這些表達(dá)式就可以合成一個(gè)公共子表達(dá)式。舉個(gè)例子:
a = b + cd = b + c
可以看到這兩個(gè)表達(dá)式時(shí)完全一致的,經(jīng)過(guò)這個(gè)Pass之后,計(jì)算圖就會(huì)消除其中一個(gè)表達(dá)式。代碼實(shí)現(xiàn)在:tvm/src/relay/transforms/eliminate_common_subexpr.cc。這里定義了一個(gè)CommonSubexprEliminator類,這個(gè)類重載了兩個(gè)Rewrite_函數(shù),對(duì)expr進(jìn)行遍歷和重寫(xiě)。代碼實(shí)現(xiàn)如下:
Expr Rewrite_(const CallNode
call, const Expr& post)final{staticauto op_stateful = Op::GetAttrMap(“TOpIsStateful”); Expr new_expr = post;const CallNode* new_call = new_expr.as(); ICHECK(new_call);const OpNode* op = new_call->op.as(); StructuralEqual attrs_equal;if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef(op), false)) {return new_expr; }if (fskip_ != nullptr && fskip_(new_expr)) {return new_expr; }auto it = expr_map_.find(new_call->op);if (it != expr_map_.end()) {for (const Expr& candidate_expr : it->second) {if (const CallNode* candidate = candidate_expr.as()) {bool is_equivalent = true;// attrs匹配if (!attrs_equal(new_call->attrs, candidate->attrs)) {continue; }// args匹配for (size_t i = 0; i < new_call->args.size(); i++) {if (!new_call->args[i].same_as(candidate->args[i]) && !IsEqualScalar(new_call->args[i], candidate->args[i])) { is_equivalent = false;break; } }if (!is_equivalent) continue;return GetRef(candidate); } } } expr_map_[new_call->op].push_back(new_expr);return new_expr; }可以看到大概的思路就是利用expr_map_這個(gè)std::unordered_map<Expr, std::vector, ObjectPtrHash, ObjectPtrEqual> expr_map_;
映射遍歷過(guò)的具有相同op的expr,然后每次碰到相同op的表達(dá)式,都會(huì)對(duì)已經(jīng)記錄的expr進(jìn)行匹配,匹配不僅包含OP的attrs屬性,還包含參數(shù)列表,如果完全一樣,說(shuō)明這兩個(gè)表達(dá)式就是公共表達(dá)式,就不返回新的表達(dá)式。這樣就可以去掉Relay Function中的公共表達(dá)式了。
到這里可能還不是特別清楚最開(kāi)始加載的那個(gè)simplenet的Relay Function,經(jīng)過(guò)一些Pass之后,具體變成什么樣,其實(shí)目前也還沒(méi)搞清楚這個(gè)問(wèn)題,這個(gè)問(wèn)題應(yīng)該就需要留到后面再解答了。
小結(jié)
本文介紹了一下TVM的Relay,介紹了如何基于Relay構(gòu)建一個(gè)Conv+BN+ReLU的小網(wǎng)絡(luò),介紹了一下TVM中的Pass的工作機(jī)制,詳細(xì)的介紹了RemoveUnusedFunctions,ToBasicBlockNormalForm,EliminateCommonSubexpr三種Pass。其中Relay部分的詳細(xì)介紹大部分引用自官方文檔:https://tvm.apache.org/docs/tutorials/get_started/introduction.html。
0x6. 參考資料
https://zhuanlan.zhihu.com/p/358437531https://zhuanlan.zhihu.com/p/91283238https://tvm.apache.org/docs/tutorials/get_started/introduction.html

https://baijiahao.baidu.com/s?id=1700872402469787364&wfr=spider&for=pc

總結(jié)

以上是生活随笔為你收集整理的TVM,Relay,Pass的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。

如果覺(jué)得生活随笔網(wǎng)站內(nèi)容還不錯(cuò),歡迎將生活随笔推薦給好友。