您好,欢迎来到欧得旅游网。
搜索
您的当前位置:首页那些年,我们踩过的Batch_normlization的坑

那些年,我们踩过的Batch_normlization的坑

来源:欧得旅游网

先说一下batch_normlization的原理吧,我就不复制粘贴了,大家可以去这个博客看一下,原理讲的挺好的

下面主要是我在实践当中遇到的问题。

我用tensorflow.slim模块进行训练的时候使用了bn层, 当时没有考虑那么多,就直接训练了

def bn_relu(x,flag):
        x = slim.batch_norm(x)
        x = tf.nn.relu(x)
        return x

但是到预估的时候发现,当一张图一张图的去做测试的时候,会发现效果很差,在前辈的指导下,知道问题出现在了BN层,只有batchsize和训练的时候保持一致或者接近时才会出现比较好的效果,这就和BN层的均值和方差有关,查完资料以后修改如下:

 def bn_relu(x,flag):
        x = slim.batch_norm(x,is_training=flag)#flag标志是训练还是测试,True or False,默认为True
        x = tf.nn.relu(x)
        return x

在训练的时候要添加如下语句:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) #在训练之前添加这条语句
with tf.control_dependencies(update_ops): 
    train = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

其他的保持不变,然后训练完一个epoch,保存新的模型,用新的模型再去做预测的时候 ,将flag改成False即可。

关于is_training的解释:

is_training: Whether or not the layer is in training mode. In training mode it would accumulate the statistics of the moments into `moving_mean` and  `moving_variance` using an exponential moving average with the given `decay`. When it is not in training mode then it would use the values of the `moving_mean` and the `moving_variance`.

关于训练时要添加的语句的解释:

注意一但使用batch_norm层,在训练节点定义时需要添加一些语句,slim.batch_norm里有moving_mean和moving_variance两个量,分别表示每个批次的均值和方差。在训练时还好理解,但在测试时,moving_mean和moving_variance的含义变了,在训练时,注意tf本体的BN操作也要这步操作

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 

    with tf.control_dependencies(update_ops): 

        train_step = tf.train.GradientDescentOptimizer(0.01).minimize(total_loss) 

# 注意并tf本体的batch_normal操作也需要这步操作

# 其中,tf.control_dependencies(update_ops)表示with段中的操作是在update_ops操作执行之后 再执行的

因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- ovod.cn 版权所有 湘ICP备2023023988号-4

违法及侵权请联系:TEL:199 1889 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务