博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras Data augmentation(数据扩充)
阅读量:6439 次
发布时间:2019-06-23

本文共 3971 字,大约阅读时间需要 13 分钟。

       在深度学习中,我们经常需要用到一些技巧(比如将图片进行旋转,翻转等)来进行data augmentation, 来减少过拟合。 在本文中,我们将主要介绍如何用深度学习框架keras来自动的进行data augmentation。

keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,    samplewise_center=False,    featurewise_std_normalization=False,    samplewise_std_normalization=False,    zca_whitening=False,    zca_epsilon=1e-6,    rotation_range=0.,    width_shift_range=0.,    height_shift_range=0.,    shear_range=0.,    zoom_range=0.,    channel_shift_range=0.,    fill_mode='nearest',    cval=0.,    horizontal_flip=False,    vertical_flip=False,    rescale=None,    preprocessing_function=None,    data_format=K.image_data_format())
  • 生成批次的带实时数据增益的张量图像数据。数据将按批次无限循环。
  • 参数
    • featurewise_center: 布尔值。将输入数据的均值设置为 0,逐特征进行。
    • samplewise_center: 布尔值。将每个样本的均值设置为 0。
    • featurewise_std_normalization: 布尔值。将输入除以数据标准差,逐特征进行。
    • samplewise_std_normalization: 布尔值。将每个输入除以其标准差。
    • zca_epsilon: ZCA 白化的 epsilon 值,默认为 1e-6。
    • zca_whitening: 布尔值。应用 ZCA 白化。
    • rotation_range: 整数。随机旋转的度数范围。
    • width_shift_range: 浮点数(总宽度的比例)。随机水平移动的范围。
    • height_shift_range: 浮点数(总高度的比例)。随机垂直移动的范围。
    • shear_range: 浮点数。剪切强度(以弧度逆时针方向剪切角度)。
    • zoom_range: 浮点数 或 [lower, upper]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]
    • channel_shift_range: 浮点数。随机通道转换的范围。
    • fill_mode: {"constant", "nearest", "reflect" or "wrap"} 之一。输入边界以外的点根据给定的模式填充:
      • "constant": kkkkkkkk|abcd|kkkkkkkk (cval=k)
      • "nearest": aaaaaaaa|abcd|dddddddd
      • "reflect": abcddcba|abcd|dcbaabcd
      • "wrap": abcdabcd|abcd|abcdabcd
    • cval: 浮点数或整数。用于边界之外的点的值,当 fill_mode = "constant" 时。
    • horizontal_flip: 布尔值。随机水平翻转。
    • vertical_flip: 布尔值。随机垂直翻转。
    • rescale: 重缩放因子。默认为 None。如果是 None 或 0,不进行缩放,否则将数据乘以所提供的值(在应用任何其他转换之前)。
    • preprocessing_function: 应用于每个输入的函数。这个函数会在任何其他改变之前运行。这个函数需要一个参数:一张图像(秩为 3 的 Numpy 张量),并且应该输出一个同尺寸的 Numpy 张量。
    • data_format: {"channels_first", "channels_last"} 之一。"channels_last" 模式表示输入尺寸应该为 (samples, height, width, channels),"channels_first" 模式表示输入尺寸应该为 (samples, channels, height, width)。默认为 在 Keras 配置文件 ~/.keras/keras.json 中的 image_data_format 值。如果你从未设置它,那它就是 "channels_last"。
       
  • 方法:
  • fit(x): 根据一组样本数据,计算与数据相关转换有关的内部数据统计信息。当且仅当 featurewise_center 或 featurewise_std_normalization 或 zca_whitening 时才需要。
  • flow(x, y): 传入 Numpy 数据和标签数组,生成批次的 增益的/标准化的 数据。在生成的批次数据上无限制地无限次循环。
  • flow_from_directory(directory): 以目录路径为参数,生成批次的 增益的/标准化的 数据。在生成的批次数据上无限制地无限次循环。
from keras.preprocessing.image import ImageDataGenerator,array_to_img,img_to_array,load_imgdatagen=ImageDataGenerator(    rotation_range=40,    width_shift_range=0.2,    height_shift_range=0.2,    shear_range=0.2,    zoom_range=0.2,    horizontal_flip=True,    fill_mode='nearest')img=load_img("test.jpg")x=img_to_array(img) # 把PIL图像格式转换成numpy格式x=x.reshape((1,)+x.shape)i=0for batch in datagen.flow(x,batch_size=2,save_to_dir="datagen",save_prefix="cat",save_format="jpeg"):    i+=1    if i>10:        break

其他注意api:

compile

compile(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

用于配置训练模型。

fit

fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

以固定数量的轮次(数据集上的迭代)训练模型。

fit_generator

fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

使用 Python 生成器逐批生成的数据,按批次训练模型。

evaluate

evaluate(self, x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None)

在测试模式下返回模型的误差值和评估标准值。

evaluate_generator

evaluate_generator(self, generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False)

在数据生成器上评估模型。

predict

predict(self, x, batch_size=None, verbose=0, steps=None)

为输入样本生成输出预测。

predict_generator

predict_generator(self, generator, steps=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)

为来自数据生成器的输入样本生成预测。

转载地址:http://khzwo.baihongyu.com/

你可能感兴趣的文章
AsyncTask的小分析
查看>>
使用Redis实现关注关系
查看>>
Go抓取网页数据并存入MySQL和返回json数据<三>
查看>>
MySQL复制介绍及搭建
查看>>
Java在线调试工具
查看>>
[译]CSS-理解百分比的background-position
查看>>
虚拟机安装CentOS
查看>>
Idea里面老版本MapReduce设置FileInputFormat参数格式变化
查看>>
在 win10 环境下,设置自己写的 程序 开机自动 启动的方法
查看>>
Unity3d游戏开发之-单例设计模式-多线程一
查看>>
通过jquery定位元素
查看>>
Tooltip表单验证的注册表单
查看>>
UWP开发中两种网络图片缓存方法
查看>>
超8千Star,火遍Github的Python反直觉案例集!
查看>>
【msdn wpf forum翻译】如何在wpf程序(程序激活时)中捕获所有的键盘输入,而不管哪个元素获得焦点?...
查看>>
全球首家!阿里云获GNTC2018 网络创新大奖 成唯一获奖云服务商
查看>>
Python简单HttpServer
查看>>
Java LinkedList工作原理及实现
查看>>
负载均衡SLB的基本使用
查看>>
Centos 7 x86 安装JDK
查看>>