问答:如何解决TensorFlow保存的pb模型载入时报错?
By 青衣极客 Blue Geek In 2020-07-02
使用tensorflow训练模型并部署的朋友很可能会有这样一个需求:将checkpoint(ckpt)模型转换成protobuf(pb)格式,然后在部署时直接载入pb格式的模型。这样做有两方面的好处:(1) 网络结构保密,因为发布的pb模型可以选择为二进制的格式,不容易被窥探;(2) 多平台调用快速且方便,OpenCV、Python以及C++中调用pb模型会非常容易,而且由于使用二进制存储,文件比较小,载入速度也快。
但是,如果你的网络结构中使用了Batch Normalization,可能在载入时会遇到这样的一个报错:
Invalid argument: Input 0 of node * was passed float from * incompatible with expected float_ref.
这个报错显示的是“程序需要float类型的指针,但是传进去的却是 float 类型的变量”,当然,这里的报错也可能是其他类型,比如 int32,这些都属于同一类错误,即:传递不可修改的变量给可修改的逻辑。
为什么会出现这样的错误呢?经过一番搜索之后,终于找到了原因。
我们知道Batch Normalization的结构中有需要学习的参数,这些参数使用了 “AssignAdd” 和 “AssignSub” 这一类的操作来对变量进行更新。 原本tensorflow在导出网络参数时,应该将这种操作转换为 “Add” 和“Sub”,因为在应用模型的时候,这些参数并不需要改变。但是,tensorflow的接口并没有考虑到这一点,所以我们在导出pb模型的时候,仍然使用了 “Assign*” 的算子。 于是,在载入时就报错了,也就是上面展示的那种错误。
那么,该怎么来解决这个问题呢?
遇到这种情况,除了在导出时手动编写代码把不合适的算子名称修复一下,或许也没有其他更好的办法。所以,这里给出修复的代码:
# for fixing the bug of batch norm
gd = sess.graph.as_graph_def()
for node in gd.node:
if node.op == 'RefSwitch':
node.op = 'Switch'
for index in xrange(len(node.input)):
if 'moving_' in node.input[index]:
node.input[index] = node.input[index] + '/read'
elif node.op == 'AssignSub':
node.op = 'Sub'
if 'use_locking' in node.attr: del node.attr['use_locking']
elif node.op == 'AssignAdd':
node.op = 'Add'
if 'use_locking' in node.attr: del node.attr['use_locking']
如果你想确认一下修复是否真的生效,可以导出pb模型时选择 “text” 格式,然后打开这个文本文件看看。
这是目前能找到的一个比较简单的解决方案,如果还有其他方案,可以在后台留言讨论。
问题的解决得益于github上的一个issue,感兴趣的朋友可以直接访问:链接

COMMENT
博客评论区功能由Github Issue提供,提交Issue时请以本文标题为话题。
"BG112-问答:如何解决TensorFlow保存的pb模型载入时报错?"