博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
巧妙使用 TensorFlow 之 TensorLayer
阅读量:6429 次
发布时间:2019-06-23

本文共 3605 字,大约阅读时间需要 12 分钟。

TensorFlow 1.0 版本刚刚正式发布,各路媒体大量刷屏,可见其影响力。再加上 TensorFlow-Fold 的出现,让其建立 dynamic computation graphs 非常方便,具备了 dynet 和 pytorch 类似的功能。同学们可能发现 1.0 版本改动非常大,很多 API 被重新命名、使用方法也变了很多,导致旧版本的 TF 代码无法在 1.0 版本上直接运行。另外,TF1.0 版本中终于加入了tf.layers,虽然功能有限,但比没有要好多了。

今天介绍一个工具 TensorLayer(以下简称 TL),虽然它实现了各种各样的层,也提供类似 Keras 的 fit(), test(), predict() 等方法,但我称其为工具而不是库,因为它大量的功能以函数形式提供,通过巧妙使用提供的函数,可以非常简洁高效地实现复杂的应用。比如 TL 提供了大量的数据增强和预处理函数,自己可以根据应用而个性化地组合这些函数(比如 image segmentation 时 X 和 Y 要对应处理)。TL 的 API 设计要求尽可能地输入 TF 本身的 API,这样的好处是可以和 TF 方便地交互使用,其  和  是很好的例子(当然它也提供  )。因此在我看来,使用 TL 的设计是为了巧妙地使用 TF,它提供的代码例子都跟随一种编写风格,方便社区统一风格以分享阅读代码。

这里根据最近使用 TL 的经验,我总结了一些使用小技巧,若写得不客观请见谅,当作是自己的笔记吧。

第一次在知乎写文章,写得不好看看就好

我现在在美帝读博,非常喜欢深度学习,欢迎交流 :D

*** 更多小技巧将陆续在()补充。

1. 安装

* 为了方便阅读和拓展 TL 代码,建议把整个项目下载下来(在terminal中输入 git clone ),然后把 tensorlayer 的文件夹放到你的项目中。

* 由于近期 TL 发展很快,若想用 pip 安装,建议安装 master 版本。

* 对于研究 NLP 的同学,可能需要安装  以使用文本分析API,这些功能封装在 tl.nlp 中。(若不使用则不需要安装)

2. TF与TL相互转换

* TF 转为 TL : 通过  把 tensor 输入到 tl.layers

* TL 转为 TF : 通过  获取 tensor

* 其它途径 [], 多输入 []

3. Training/Testing 切换

* 通过  来 disable/enable  (这只当使用  时才可以,可参考  和 

* 更好的方法是把 noise 层的 "is_fix" 设为 "True",然后对 Training 和 Testing 分别建立不同的graph,这需要用到 parameter sharing。除了控制training/testing,这个方法可以让建立 graph 时使用不一样的参数,如 batch_size,,  等。 例子如下:

def mlp(x, is_train=True, reuse=False):    with tf.variable_scope("MLP", reuse=reuse):      tl.layers.set_name_reuse(reuse)      net = InputLayer(x, name='in')      net = DropoutLayer(net, 0.8, True, is_train, 'drop1')      net = DenseLayer(net, 800, tf.nn.relu, 'dense1')      net = DropoutLayer(net, 0.8, True, is_train, 'drop2')      net = DenseLayer(net, 800, tf.nn.relu, 'dense2')      net = DropoutLayer(net, 0.8, True, is_train, 'drop3')      net = DenseLayer(net, 10, tf.identity, 'out')      logits = net.outputs      net.outputs = tf.nn.sigmoid(net.outputs)      return net, logitsx = tf.placeholder(tf.float32, shape=[None, 784], name='x')y_ = tf.placeholder(tf.int64, shape=[None, ], name='y_')net_train, logits = mlp(x, is_train=True, reuse=False)net_test, _ = mlp(x, is_train=False, reuse=True)cost = tl.cost.cross_entropy(logits, y_, name='cost')

4. 获取variables

TL非常特殊的一点是:需要给每层输入一个唯一的名字,除非 reuse 该层。我刚开始用时,完全不明白这样设计的道理,后来发现这样的好处是杜绝了错误重用和方便参数管理。

* 使用  获取参数列表,尽量少用 ,下面的例子获取了上面例子中全部函数,因为上面例子中的函数都置于 “MLP” 之下:

train_vars = tl.layers.get_variables_with_name('MLP', True, True)train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_vars)

* 这个方法常常用于选择哪些参数需要被更新,比如训练 GAN 时,可以分别获取 G 和 D 的参数列表,放到对应的 optimizer 中。

* 其它方法 [], [], []

5. 使用预训练的 CNN 及建立 Resnet

很多应用中需要用到预训练好的CNN模型,比如 image captioning, VQA, 以及在小数据集中 fine-tune 做分类器等等。

* 预训练的 CNN

。TL 网站上提供了 VGG16, VGG19, Inception 等例子,请见 

。此外通过  可以使用  中全部预训练好的模型!

* Resnet

。Implement by "for" loop []

。 Other methods []

6. 数据增强

* 使用 TF 提供的 TFRecord,参考 ; 这里介绍一个很好的工具: 

* TL提供了  来使用 python-threading,并提供了大量图像增强的函数: ,请参考 

7. 句子ID化(Sentences tokenization)

NLP中,词语需要转换为ID来处理,TL 的  提供了大量的方法,但我觉得下面的几个han s基本够用了。

* 使用  把句子分隔,对于中文推荐使用 

* 然后使用  来建立词汇表并保存成为 txt 文件,该函数还会返回一个 实例

* 最后建议从  保存的 txt 文件中实例化一个 ,以方便词语和数字ID之间的转换

* 更多文本处理函数请见  和 

8. Dynamic RNN 与 sequence length

* 使用  来帮助  自动计算每个句子的 sequence length

* 对一个 batch 的数据做 zero padding:

b_sentence_ids = tl.prepro.pad_sequences(b_sentence_ids, padding='post')

* 其它方法 []

9. 常见bug

* Matplotlib issue arise when importing TensorLayer [] [] (这个问题往往在远程连接 ubuntu 时出现)

10. 其它小技巧

* TL默认模式下,在执行每一个 layer 时会把相关信息显示到terminal中。但当你在建立非常深的网络时,这些信息没有太大帮助。因此可以通过  来禁止print输出:

print("You can see me")with tl.ops.suppress_stdout():    print("You can't see me") # 在这里建立模型print("You can see me")

Useful links

* TL official sites: [], [], []

转自:https://zhuanlan.zhihu.com/p/25296966

转载地址:http://utiga.baihongyu.com/

你可能感兴趣的文章
浅析HTML5的10大优势
查看>>
实例讲解基于 React+Redux 的前端开发流程
查看>>
[转]Vim配置与高级技巧
查看>>
查找SQL数据表或视图中的字段属性信息
查看>>
如何优化UPS的工作模式为数据中心节省运营成本
查看>>
使用python来访问Hadoop HDFS存储实现文件的操作
查看>>
靠能力赚大钱,是最最可笑的谎言
查看>>
WORD设置节起始页码后出现诡异隐藏页/跳页、节首页页面边框丢失
查看>>
团队文化之表扬和批评
查看>>
国家能源局:《电力企业网络与信息安全专项监管报告》
查看>>
Gartner:2012年SIEM(安全信息与事件管理)市场分析报告
查看>>
社交大革命,不可遏止的互联网春天
查看>>
也谈nginx的安全限制
查看>>
mongodb数据库问题三则
查看>>
【翻译】了解ASP.NET MVC的HTML助手
查看>>
老男孩:Linux运维岗位强于开发岗位的6点优势
查看>>
烂泥:mysql5.5多实例部署
查看>>
了解BYOD---工作方式的新时尚
查看>>
fastdfs binlog同步BUG
查看>>
基于DRBD构建高可用主从MySQL服务器
查看>>