如何插入8bit量化节点(tensorflow)
目錄
tf流圖graph基礎知識
默認圖
創建顯式圖
創建多個圖
調用tf偽量化接口
插入kernel、層間量化節點
tf流圖graph基礎知識
默認圖
import tensorflow as tf import numpy as npa = tf.constant(123) print(a.graph) print(tf.get_default_graph())輸出:
<tensorflow.python.framework.ops.Graph object at 0x7f32df766668>
<tensorflow.python.framework.ops.Graph object at 0x7f32df766668>
當tensorflow庫被加載時,即使用戶沒有顯示地創建一個圖,他也會自動創建一個圖對象,并將其作為默認的額數據流圖
創建顯式圖
import tensorflow as tf import numpy as npg= tf.Graph() #創建了一個圖對象g with g.as_default(): #返回一個上下文管理器,使得當前圖對象稱為當前默認圖對象a = tf.constant(123)print(a.graph) #a.graph:獲得a所在的圖print(tf.get_default_graph()) #get_default_graph:獲取當前圖對象的句柄輸出:
<tensorflow.python.framework.ops.Graph object at 0x7f71a7a2b710>
<tensorflow.python.framework.ops.Graph object at 0x7f71a7a2b710>
創建多個圖
?
import tensorflow as tf import numpy as npg1 = tf.Graph() g2 = tf.Graph()with g1.as_default():a = tf.constant(123)print(a.graph)print(tf.get_default_graph())with g2.as_default():b = tf.multiply(2, 3)print(b.graph)print(tf.get_default_graph())?輸出:
<tensorflow.python.framework.ops.Graph object at 0x7f570354fcc0>
<tensorflow.python.framework.ops.Graph object at 0x7f570354fcc0>
<tensorflow.python.framework.ops.Graph object at 0x7f5684813860>
<tensorflow.python.framework.ops.Graph object at 0x7f5684813860>
?
調用tf偽量化接口
class Quantizationint8(object):#初始化量化參數#n為量化bit數8或16,d是實際小數點位數def __init__(self, n, d):d = float(d)self._quant_min = -(2 ** (n - 1) - 1) * (2 ** d)self._quant_max = (2 ** (n - 1) - 1) * (2 ** d)self._num_bits = nself._narrow = True#利用tf接口函數計算偽量化值,并返回def __call__(self, inputs):return array_ops.fake_quant_with_min_max_vars(inputs, self._quant_min, self._quant_max,num_bits=self._num_bits, narrow_range=self._narrow)tf接口釋義
可參考:https://blog.csdn.net/weixin_36670529/article/details/100560469:?
?
插入kernel、層間量化節點
#g為靜態圖句柄,producer為全精度浮點數,quant_tensor為上面偽量化值節點的輸出 def insert_slim_quant_op(g, producer, name, quant_tensor):Tar_op= []for (op_name, op) in g._nodes_by_name.items():if producer in op.inputs._inputs and name in op_name:Tar_op.append(op)assert len(Tar_op) != 0, "do not have the Tar_op of node {}\n".format(producer.name)with variable_scope.variable_scope(name + '/SlimQuant'):graph_editor.reroute_ts([quant_tensor], [producer], can_modify=Tar_op)?
總結
以上是生活随笔為你收集整理的如何插入8bit量化节点(tensorflow)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: smooth l1(huber)+bin
- 下一篇: Check failed: error