Tensorflow源码解析3 -- TensorFlow核心对象 - Graph
1 Graph概述
計(jì)算圖Graph是TensorFlow的核心對(duì)象,TensorFlow的運(yùn)行流程基本都是圍繞它進(jìn)行的。包括圖的構(gòu)建、傳遞、剪枝、按worker分裂、按設(shè)備二次分裂、執(zhí)行、注銷等。因此理解計(jì)算圖Graph對(duì)掌握TensorFlow運(yùn)行尤為關(guān)鍵。
2 默認(rèn)Graph
默認(rèn)圖替換
之前講解Session的時(shí)候就說(shuō)過(guò),一個(gè)Session只能run一個(gè)Graph,但一個(gè)Graph可以運(yùn)行在多個(gè)Session中。常見(jiàn)情況是,session會(huì)運(yùn)行全局唯一的隱式的默認(rèn)的Graph,operation也是注冊(cè)到這個(gè)Graph中。
也可以顯示創(chuàng)建Graph,并調(diào)用as_default()使他替換默認(rèn)Graph。在該上下文管理器中創(chuàng)建的op都會(huì)注冊(cè)到這個(gè)graph中。退出上下文管理器后,則恢復(fù)原來(lái)的默認(rèn)graph。一般情況下,我們不用顯式創(chuàng)建Graph,使用系統(tǒng)創(chuàng)建的那個(gè)默認(rèn)Graph即可。
print tf.get_default_graph()with tf.Graph().as_default() as g:print tf.get_default_graph() is gprint tf.get_default_graph()print tf.get_default_graph()輸出如下
<tensorflow.python.framework.ops.Graph object at 0x106329fd0> True <tensorflow.python.framework.ops.Graph object at 0x18205cc0d0> <tensorflow.python.framework.ops.Graph object at 0x10d025fd0>由此可見(jiàn),在上下文管理器中,當(dāng)前線程的默認(rèn)圖被替換了,而退出上下文管理后,則恢復(fù)為了原來(lái)的默認(rèn)圖。
默認(rèn)圖管理
默認(rèn)graph和默認(rèn)session一樣,也是線程作用域的。當(dāng)前線程中,永遠(yuǎn)都有且僅有一個(gè)graph為默認(rèn)圖。TensorFlow同樣通過(guò)棧來(lái)管理線程的默認(rèn)graph。
@tf_export("Graph") class Graph(object):# 替換線程默認(rèn)圖def as_default(self):return _default_graph_stack.get_controller(self)# 棧式管理,push pop@tf_contextlib.contextmanagerdef get_controller(self, default):try:context.context_stack.push(default.building_function, default.as_default)finally:context.context_stack.pop()替換默認(rèn)圖采用了堆棧的管理方式,通過(guò)push pop操作進(jìn)行管理。獲取默認(rèn)圖的操作如下,通過(guò)默認(rèn)graph棧_default_graph_stack來(lái)獲取。
@tf_export("get_default_graph") def get_default_graph():return _default_graph_stack.get_default()下面來(lái)看_default_graph_stack的創(chuàng)建
_default_graph_stack = _DefaultGraphStack() class _DefaultGraphStack(_DefaultStack): def __init__(self):# 調(diào)用父類來(lái)創(chuàng)建super(_DefaultGraphStack, self).__init__()self._global_default_graph = Noneclass _DefaultStack(threading.local):def __init__(self):super(_DefaultStack, self).__init__()self._enforce_nesting = True# 和默認(rèn)session棧一樣,本質(zhì)上也是一個(gè)listself.stack = []_default_graph_stack的創(chuàng)建如上所示,最終和默認(rèn)session棧一樣,本質(zhì)上也是一個(gè)list。
3 前端Graph數(shù)據(jù)結(jié)構(gòu)
Graph數(shù)據(jù)結(jié)構(gòu)
理解一個(gè)對(duì)象,先從它的數(shù)據(jù)結(jié)構(gòu)開(kāi)始。我們先來(lái)看Python前端中,Graph的數(shù)據(jù)結(jié)構(gòu)。Graph主要的成員變量是Operation和Tensor。Operation是Graph的節(jié)點(diǎn),它代表了運(yùn)算算子。Tensor是Graph的邊,它代表了運(yùn)算數(shù)據(jù)。
@tf_export("Graph") class Graph(object):def __init__(self):# 加線程鎖,使得注冊(cè)op時(shí),不會(huì)有其他線程注冊(cè)op到graph中,從而保證共享graph是線程安全的self._lock = threading.Lock()# op相關(guān)數(shù)據(jù)。# 為graph的每個(gè)op分配一個(gè)id,通過(guò)id可以快速索引到相關(guān)op。故創(chuàng)建了_nodes_by_id字典self._nodes_by_id = dict() # GUARDED_BY(self._lock)self._next_id_counter = 0 # GUARDED_BY(self._lock)# 同時(shí)也可以通過(guò)name來(lái)快速索引op,故創(chuàng)建了_nodes_by_name字典self._nodes_by_name = dict() # GUARDED_BY(self._lock)self._version = 0 # GUARDED_BY(self._lock)# tensor相關(guān)數(shù)據(jù)。# 處理tensor的placeholderself._handle_feeders = {}# 處理tensor的read操作self._handle_readers = {}# 處理tensor的move操作self._handle_movers = {}# 處理tensor的delete操作self._handle_deleters = {}下面看graph如何添加op的,以及保證線程安全的。
def _add_op(self, op):# graph被設(shè)置為final后,就是只讀的了,不能添加op了。self._check_not_finalized()# 保證共享graph的線程安全with self._lock:# 將op以id和name分別構(gòu)建字典,添加到_nodes_by_id和_nodes_by_name字典中,方便后續(xù)快速索引self._nodes_by_id[op._id] = opself._nodes_by_name[op.name] = opself._version = max(self._version, op._id)GraphKeys 圖分組
每個(gè)Operation節(jié)點(diǎn)都有一個(gè)特定的標(biāo)簽,從而實(shí)現(xiàn)節(jié)點(diǎn)的分類。相同標(biāo)簽的節(jié)點(diǎn)歸為一類,放到同一個(gè)Collection中。標(biāo)簽是一個(gè)唯一的GraphKey,GraphKey被定義在類GraphKeys中,如下
@tf_export("GraphKeys") class GraphKeys(object):GLOBAL_VARIABLES = "variables"QUEUE_RUNNERS = "queue_runners"SAVERS = "savers"WEIGHTS = "weights"BIASES = "biases"ACTIVATIONS = "activations"UPDATE_OPS = "update_ops"LOSSES = "losses"TRAIN_OP = "train_op"# 省略其他name_scope 節(jié)點(diǎn)命名空間
使用name_scope對(duì)graph中的節(jié)點(diǎn)進(jìn)行層次化管理,上下層之間通過(guò)斜杠分隔。
# graph節(jié)點(diǎn)命名空間 g = tf.get_default_graph() with g.name_scope("scope1"):c = tf.constant("hello, world", name="c")print c.op.namewith g.name_scope("scope2"):c = tf.constant("hello, world", name="c")print c.op.name輸出如下
scope1/c scope1/scope2/c # 內(nèi)層的scope會(huì)繼承外層的,類似于棧,形成層次化管理4 后端Graph數(shù)據(jù)結(jié)構(gòu)
Graph
先來(lái)看graph.h文件中的Graph類的定義,只看關(guān)鍵代碼
class Graph {private:// 所有已知的op計(jì)算函數(shù)的注冊(cè)表FunctionLibraryDefinition ops_;// GraphDef版本號(hào)const std::unique_ptr<VersionDef> versions_;// 節(jié)點(diǎn)node列表,通過(guò)id來(lái)訪問(wèn)std::vector<Node*> nodes_;// node個(gè)數(shù)int64 num_nodes_ = 0;// 邊edge列表,通過(guò)id來(lái)訪問(wèn)std::vector<Edge*> edges_;// graph中非空edge的數(shù)目int num_edges_ = 0;// 已分配了內(nèi)存,但還沒(méi)使用的node和edgestd::vector<Node*> free_nodes_;std::vector<Edge*> free_edges_;}后端中的Graph主要成員也是節(jié)點(diǎn)node和邊edge。節(jié)點(diǎn)node為計(jì)算算子Operation,邊為算子所需要的數(shù)據(jù),或者代表節(jié)點(diǎn)間的依賴關(guān)系。這一點(diǎn)和Python中的定義相似。邊Edge的持有它的源節(jié)點(diǎn)和目標(biāo)節(jié)點(diǎn)的指針,從而將兩個(gè)節(jié)點(diǎn)連接起來(lái)。下面看Edge類的定義。
Edge
class Edge {private:Edge() {}friend class EdgeSetTest;friend class Graph;// 源節(jié)點(diǎn), 邊的數(shù)據(jù)就來(lái)源于源節(jié)點(diǎn)的計(jì)算。源節(jié)點(diǎn)是邊的生產(chǎn)者Node* src_;// 目標(biāo)節(jié)點(diǎn),邊的數(shù)據(jù)提供給目標(biāo)節(jié)點(diǎn)進(jìn)行計(jì)算。目標(biāo)節(jié)點(diǎn)是邊的消費(fèi)者Node* dst_;// 邊id,也就是邊的標(biāo)識(shí)符int id_;// 表示當(dāng)前邊為源節(jié)點(diǎn)的第src_output_條邊。源節(jié)點(diǎn)可能會(huì)有多條輸出邊int src_output_;// 表示當(dāng)前邊為目標(biāo)節(jié)點(diǎn)的第dst_input_條邊。目標(biāo)節(jié)點(diǎn)可能會(huì)有多條輸入邊。int dst_input_; };Edge既可以承載tensor數(shù)據(jù),提供給節(jié)點(diǎn)Operation進(jìn)行運(yùn)算,也可以用來(lái)表示節(jié)點(diǎn)之間有依賴關(guān)系。對(duì)于表示節(jié)點(diǎn)依賴的邊,其src_output_, dst_input_均為-1,此時(shí)邊不承載任何數(shù)據(jù)。
下面來(lái)看Node類的定義。
Node
class Node {public:// NodeDef,節(jié)點(diǎn)算子Operation的信息,比如op分配到哪個(gè)設(shè)備上了,op的名字等,運(yùn)行時(shí)有可能變化。const NodeDef& def() const;// OpDef, 節(jié)點(diǎn)算子Operation的元數(shù)據(jù),不會(huì)變的。比如Operation的入?yún)⒘斜?#xff0c;出參列表等const OpDef& op_def() const;private:// 輸入邊,傳遞數(shù)據(jù)給節(jié)點(diǎn)。可能有多條EdgeSet in_edges_;// 輸出邊,節(jié)點(diǎn)計(jì)算后得到的數(shù)據(jù)。可能有多條EdgeSet out_edges_; }節(jié)點(diǎn)Node中包含的主要數(shù)據(jù)有輸入邊和輸出邊的集合,從而能夠由Node找到跟他關(guān)聯(lián)的所有邊。Node中還包含NodeDef和OpDef兩個(gè)成員。NodeDef表示節(jié)點(diǎn)算子的信息,運(yùn)行時(shí)可能會(huì)變,創(chuàng)建Node時(shí)會(huì)new一個(gè)NodeDef對(duì)象。OpDef表示節(jié)點(diǎn)算子的元信息,運(yùn)行時(shí)不會(huì)變,創(chuàng)建Node時(shí)不需要new OpDef,只需要從OpDef倉(cāng)庫(kù)中取出即可。因?yàn)樵畔⑹谴_定的,比如Operation的入?yún)€(gè)數(shù)等。
由Node和Edge,即可以組成圖Graph,通過(guò)任何節(jié)點(diǎn)和任何邊,都可以遍歷完整圖。Graph執(zhí)行計(jì)算時(shí),按照拓?fù)浣Y(jié)構(gòu),依次執(zhí)行每個(gè)Node的op計(jì)算,最終即可得到輸出結(jié)果。入度為0的節(jié)點(diǎn),也就是依賴數(shù)據(jù)已經(jīng)準(zhǔn)備好的節(jié)點(diǎn),可以并發(fā)執(zhí)行,從而提高運(yùn)行效率。
系統(tǒng)中存在默認(rèn)的Graph,初始化Graph時(shí),會(huì)添加一個(gè)Source節(jié)點(diǎn)和Sink節(jié)點(diǎn)。Source表示Graph的起始節(jié)點(diǎn),Sink為終止節(jié)點(diǎn)。Source的id為0,Sink的id為1,其他節(jié)點(diǎn)id均大于1.
5 Graph運(yùn)行時(shí)生命周期
Graph是TensorFlow的核心對(duì)象,TensorFlow的運(yùn)行均是圍繞Graph進(jìn)行的。運(yùn)行時(shí)Graph大致經(jīng)過(guò)了以下階段
這些階段根據(jù)TensorFlow運(yùn)行時(shí)的不同,會(huì)進(jìn)行不同的處理。運(yùn)行時(shí)有兩種,本地運(yùn)行時(shí)和分布式運(yùn)行時(shí)。故Graph生命周期到后面分析本地運(yùn)行時(shí)和分布式運(yùn)行時(shí)的時(shí)候,再詳細(xì)講解。
本文作者:揚(yáng)易
閱讀原文
本文為云棲社區(qū)原創(chuàng)內(nèi)容,未經(jīng)允許不得轉(zhuǎn)載。
總結(jié)
以上是生活随笔為你收集整理的Tensorflow源码解析3 -- TensorFlow核心对象 - Graph的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 带上一份技能地图
- 下一篇: 归纳DOM事件中各种阻止方法