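# NOTE: web-page residue from the code host, not part of the script:
# "Code pull complete; the page will refresh automatically."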
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.keras.layers import dense_attention
# input_shape = [(1, 112, 112, 24), (1, 112, 112, 96), (1, 56, 56, 144)]
# filter_shape = [(3, 3, 24, 1), (3, 3, 96, 1), (3, 3, 24, 1)]
# stride_shape = [(1, 1, 1, 1), (2, 2, 2, 2), (1, 1, 1, 1)]
# input = tf.placeholder(name="input", dtype=tf.float32, shape=input_shape[0])
# output = tf.pad(output, paddings=((0, 0), (1, 1), (1, 1), (0, 0)))
# output = tf.nn.conv2d(output,
# tf.get_variable("w0", dtype=tf.float32, shape=(1, 1, 3, 32)),
# strides=[1, 1, 1, 1],
# padding='SAME',
# # padding=[[0, 0], [0, 0], [0, 0], [0, 0]],
# use_cudnn_on_gpu=False,
# data_format="NHWC",
# dilations=[1, 1, 1, 1])
# output = tf.nn.depthwise_conv2d_native(input,
# tf.get_variable("w1", dtype=tf.float32, shape=filter_shape[0]),
# strides=stride_shape[0],
# padding='SAME',
# data_format="NHWC",
# dilations=[1, 1, 1, 1])
# output = tf.nn.softmax(output)
# output = tf.identity(output, name="output")
# input = tf.constant(1., shape=[1, 8, 8, 1])
# # The convolution kernel is 3×3×1, and there is a single kernel
# w = tf.constant(1., shape=[3, 3, 1, 1])
# # Convolution: with the 8×8 input and a 3×3 'VALID' kernel this yields a 6×6 single-channel image
# result = tf.nn.conv2d(input, w, strides=[1, 1, 1, 1], padding='VALID')
# # Transposed convolution: produces the 8×8 single-channel image requested by output_shape
# output = tf.nn.conv2d_transpose(result, w, output_shape=[1, 8, 8, 1], strides=[
# 1, 1, 1, 1], padding='VALID', name="output")
# input = tf.placeholder(name="input", dtype=tf.int8, shape=[1, 7, 7, 144])
# output = tf.reshape(input, [144, 1, 7, 7])
# Build a minimal graph: gather channels 0, 4, 8 and 12 along the last axis
# of a [1, 7, 7, 16] int8 placeholder.
# NOTE: `input` shadows the Python builtin; the name is kept because the
# commented-out experiment snippets in this file refer to it.
input = tf.placeholder(name="input", dtype=tf.int8, shape=[1, 7, 7, 16])
output = tf.gather(input, [0, 4, 8, 12], axis=3)

with tf.Session() as sess:
    # The graph built here creates no variables; the initializer is kept so
    # the commented-out variants that call tf.get_variable still work when
    # they are re-enabled.
    sess.run(tf.global_variables_initializer())

    # Export a frozen .pb graph. Reference article ("Saving a TensorFlow
    # model as a PB file"): https://zhuanlan.zhihu.com/p/32887066
    print(sess.graph_def.node)
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [output.op.name])

    # File stem is "<output op name, lower-cased>_<dtype name>".
    # DType.name yields the bare dtype name (e.g. "int8") without having to
    # parse the quoted text out of str(output.dtype).
    file_name = output.op.name.lower() + '_' + output.dtype.name

    # tf.gfile.FastGFile is deprecated; GFile is the supported equivalent.
    with tf.gfile.GFile(file_name + '.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # Export a (non-quantized) TFLite model of the same graph.
    converter = tf.lite.TFLiteConverter.from_session(sess, [input], [output])
    # Convert before opening the file so a failed conversion does not leave
    # an empty .tflite behind, and close the handle deterministically.
    tflite_model = converter.convert()
    with open(file_name + ".tflite", "wb") as f:
        f.write(tflite_model)

    # # Export a quantized (uint8) TFLite model
    # converter = tf.lite.TFLiteConverter.from_session(sess, [input], [output])
    # converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
    # converter.quantized_input_stats = {"input": (128, 128)}
    # converter.default_ranges_stats = (0, 1)
    # with open("uint8.tflite", "wb") as f:
    #     f.write(converter.convert())
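# NOTE: content-moderation notice left over from the hosting web page, not part of the script:
# "This section may contain content unsuitable for display and is hidden. You can review and
# edit it yourself. If you confirm the content does not involve inappropriate language,
# advertising, violence, vulgar or pornographic material, infringement, piracy, false or
# worthless content, or anything violating applicable laws and regulations, you may submit
# an appeal and it will be handled as soon as possible."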