1 Star 0 Fork 0

wandongdong/tflite2opencl

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
create_pb_file.py 3.16 KB
一键复制 编辑 原始数据 按行查看 历史
wandongdong 提交于 2021-01-18 01:10 . add gather tflite op
from tensorflow.quantization import fake_quant_with_min_max_vars
from tensorflow.python.framework import graph_util
import math
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.keras.layers import dense_attention
# input_shape = [(1, 112, 112, 24), (1, 112, 112, 96), (1, 56, 56, 144)]
# filter_shape = [(3, 3, 24, 1), (3, 3, 96, 1), (3, 3, 24, 1)]
# stride_shape = [(1, 1, 1, 1), (2, 2, 2, 2), (1, 1, 1, 1)]
# input = tf.placeholder(name="input", dtype=tf.float32, shape=input_shape[0])
# output = tf.pad(output, paddings=((0, 0), (1, 1), (1, 1), (0, 0)))
# output = tf.nn.conv2d(output,
# tf.get_variable("w0", dtype=tf.float32, shape=(1, 1, 3, 32)),
# strides=[1, 1, 1, 1],
# padding='SAME',
# # padding=[[0, 0], [0, 0], [0, 0], [0, 0]],
# use_cudnn_on_gpu=False,
# data_format="NHWC",
# dilations=[1, 1, 1, 1])
# output = tf.nn.depthwise_conv2d_native(input,
# tf.get_variable("w1", dtype=tf.float32, shape=filter_shape[0]),
# strides=stride_shape[0],
# padding='SAME',
# data_format="NHWC",
# dilations=[1, 1, 1, 1])
# output = tf.nn.softmax(output)
# output = tf.identity(output, name="output")
# input = tf.constant(1., shape=[1, 8, 8, 1])
# # Kernel size is 3×3×1, with a single kernel
# w = tf.constant(1., shape=[3, 3, 1, 1])
# # Convolution: outputs a 6×6 single-channel image (8×8 input, 3×3 kernel, VALID padding)
# result = tf.nn.conv2d(input, w, strides=[1, 1, 1, 1], padding='VALID')
# # Transposed convolution: outputs an 8×8 single-channel image (output_shape=[1, 8, 8, 1])
# output = tf.nn.conv2d_transpose(result, w, output_shape=[1, 8, 8, 1], strides=[
# 1, 1, 1, 1], padding='VALID', name="output")
# input = tf.placeholder(name="input", dtype=tf.int8, shape=[1, 7, 7, 144])
# output = tf.reshape(input, [144, 1, 7, 7])
# Test graph: an int8 placeholder of shape [1, 7, 7, 16], from which every
# fourth slice along axis 3 (indices 0, 4, 8, 12) is gathered.
# NOTE: `input` shadows the builtin; the name is kept because the export code
# below (and every commented-out graph variant above) uses `input`/`output`.
input = tf.placeholder(dtype=tf.int8, shape=[1, 7, 7, 16], name="input")
output = tf.gather(params=input, indices=list(range(0, 16, 4)), axis=3)
# Freeze the graph built above and export it both as a .pb GraphDef and as a
# .tflite model. Both files are named "<output op name, lowercased>_<output dtype>".
with tf.Session() as sess:
    # The gather graph has no variables, but the commented-out graph variants
    # above create some via tf.get_variable, so initialize them if present.
    sess.run(tf.global_variables_initializer())
    print(sess.graph_def.node)

    # DType.name yields the bare dtype string (e.g. "int8"), the same value the
    # previous str(dtype) repr-parsing extracted, without depending on repr format.
    file_name = output.op.name.lower() + "_" + output.dtype.name

    # Export a frozen .pb: variables are folded into constants so the GraphDef
    # is self-contained. Background (in Chinese), "Saving a TensorFlow model as
    # a PB file": https://zhuanlan.zhihu.com/p/32887066
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [output.op.name])
    with tf.io.gfile.GFile(file_name + ".pb", mode="wb") as f:
        f.write(constant_graph.SerializeToString())

    # Export a .tflite model with the converter's default (non-quantized)
    # settings. A context manager guarantees the file is flushed and closed.
    converter = tf.lite.TFLiteConverter.from_session(sess, [input], [output])
    with open(file_name + ".tflite", "wb") as f:
        f.write(converter.convert())
# # export quantized tflite
# converter = tf.lite.TFLiteConverter.from_session(sess, [input], [output])
# converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
# converter.quantized_input_stats = {"input": (128, 128)}
# converter.default_ranges_stats = (0, 1)
# open("uint8.tflite", "wb").write(converter.convert())
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ddwsky/tflite2opencl.git
git@gitee.com:ddwsky/tflite2opencl.git
ddwsky
tflite2opencl
tflite2opencl
master

搜索帮助