博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow之dropout
阅读量:4937 次
发布时间:2019-06-11

本文共 2762 字,大约阅读时间需要 9 分钟。

训练神经网络模型时,如果训练样本较少,为了防止模型过拟合,可以使用Dropout, Dropout是指在模型训练时随机让网络某些隐含层节点的权重不工作,不工作的那些节点可以暂时认为不是网络结构的一部分,但是它的权重得保留下来(只是暂时不更新而已),因为下次样本输入时它可能又得工作了(有点抽象,具体实现看后面的实验部分)。

#!/usr/bin/env python2# -*- coding: utf-8 -*-"""tensorflow dropoutdropout一般用在全连接的部分,卷积部分不会用到dropout,输出曾也不会使用dropout,适用范围[输入,输出)只用在训练集,不用在测试集"""import tensorflow as tffrom sklearn.datasets import load_digitsfrom sklearn.model_selection import train_test_splitfrom sklearn.preprocessing import LabelBinarizer#加载数据digits = load_digits()X = digits.datay = digits.targety = LabelBinarizer().fit_transform(y)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)def add_layer(inputs, in_size, out_size, layer_name, activation_function=None, keep_prob=1.0):    Weights = tf.Variable(tf.random_normal([in_size, out_size]))    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, )    Wx_plus_b = tf.matmul(inputs, Weights) + biases        # 这里做 dropout    Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)        if activation_function is None:        outputs = Wx_plus_b    else:        outputs = activation_function(Wx_plus_b, )    tf.summary.histogram(layer_name + '/outputs', outputs)    return outputsxs = tf.placeholder(tf.float32, [None, 64])  # 8x8ys = tf.placeholder(tf.float32, [None, 10])#添加输出层l1 = add_layer(xs, 64 ,50, 'l1', activation_function=tf.nn.tanh)prediction = add_layer(l1, 50, 10, 'l2', activation_function=tf.nn.softmax)cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),                                              reduction_indices=[1]))  # loss                                              #scalar_summary记录存数值,用于画图tf.summary.scalar('loss', cross_entropy)train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#定义每次做连接的神经元个数keep_prob = tf.placeholder(tf.float32)sess = tf.Session()##在TensorFlow中,所有的操作只有当你执行,或者另一个操作依赖于它的输出时才会运行。#我们刚才创建的这些节点(summary nodes)都围绕着你的图像:没有任何操作依赖于它们的结果。#因此,为了生成汇总信息,我们需要运行所有这些节点。这样的手动工作是很乏味的,#因此可以使用tf.merge_all_summaries来将他们合并为一个操作。#http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/summaries_and_tensorboard.htmlmerged = tf.summary.merge_all()#summarytrain_writer = tf.train.summary.FileWriter("logs/train", sess.graph)test_writer = tf.train..summary.FileWriter("logs/test", sess.graph)sess.run(tf.initialize_all_variables())for i in range(500):    sess.run(train_step, feed_dict={xs: X_train, ys: y_train})    if i % 50 == 0:        # record loss        # record loss        train_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})        test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_prob: 1})        train_writer.add_summary(train_result, i)        test_writer.add_summary(test_result, i)

 

转载于:https://www.cnblogs.com/xmeo/p/7218797.html

你可能感兴趣的文章
extjs双层表头
查看>>
ajax请求插件vue-resource的学习
查看>>
网络相册产品分析(一):十年需求变迁
查看>>
ssh配置详解及公私钥批量分发
查看>>
JsDoc应用与配置
查看>>
beta冲刺4
查看>>
DM9000网卡驱动分析(转)
查看>>
如何分析解决Android ANR
查看>>
虚拟内存
查看>>
Python入门学习笔记07(time)
查看>>
Java错误和异常解析
查看>>
.net core 的图片处理及二维码的生成及解析
查看>>
ASP.NET Core 启动流程图
查看>>
从PRISM开始学WPF(四)Prism-Module-更新至Prism7.1
查看>>
.net 框架
查看>>
Docker的使用初探(一):常用指令说明
查看>>
.net core实践系列之短信服务-目录
查看>>
WPF中 PropertyPath XAML 语法
查看>>
Wix 安装部署教程(四) 添加安装文件及快捷方式
查看>>
Win10 IoT C#开发 3 - GPIO Pin 控制发光二极管
查看>>