keras图像预处理

[TOC]

class: ImageDataGenerator

生成批次的带实时数据增益的张量图像数据。数据将按批次无限循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6,
rotation_range=0.,
width_shift_range=0.,
height_shift_range=0.,
shear_range=0.,
zoom_range=0.,
channel_shift_range=0.,
fill_mode='nearest',
cval=0.,
horizontal_flip=False,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=K.image_data_format())

常用参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
rotation_range: 整数。随机旋转的度数范围。
width_shift_range: 浮点数(总宽度的比例)。随机水平移动的范围。
height_shift_range: 浮点数(总高度的比例)。随机垂直移动的范围。
shear_range: 浮点数。剪切强度(以弧度逆时针方向剪切角度)。
zoom_range: 浮点数 或 [lower, upper]。随机缩放范围。如果是浮点数,[lower, upper] = [1-zoom_range, 1+zoom_range]。
channel_shift_range: 浮点数。随机通道转换的范围。
fill_mode: {"constant", "nearest", "reflect" or "wrap"} 之一。输入边界以外的点根据给定的模式填充:
"constant": kkkkkkkk|abcd|kkkkkkkk (cval=k)
"nearest": aaaaaaaa|abcd|dddddddd
"reflect": abcddcba|abcd|dcbaabcd
"wrap": abcdabcd|abcd|abcdabcd
cval: 浮点数或整数。当 fill_mode = "constant" 时,用于边界之外的点的值。
horizontal_flip: 布尔值。随机水平翻转。
vertical_flip: 布尔值。随机垂直翻转。
rescale: 重缩放因子。默认为 None。如果是 None 或 0,不进行缩放,否则将数据乘以所提供的值(在应用任何其他转换之前)。

类的方法

类的方法我就用了两个:flowflow_from_directory

flow():

传入 Numpy 数据和标签数组,生成批次的 增益的/标准化的 数据。在生成的批次数据上无限制地无限次循环。

  • 参数
1
2
3
4
5
6
7
8
x: 数据。秩应该为 4。在灰度数据的情况下,通道轴的值应该为 1,在 RGB 数据的情况下,它应该为 3。
y: 标签。
batch_size: 整数(默认 32)。
shuffle: 布尔值(默认 True)。
seed: 整数(默认 None)。
save_to_dir: None 或 字符串(默认 None)。这使你可以最佳地指定正在生成的增强图片要保存的目录(用于可视化你在做什么)。
save_prefix: 字符串(默认 '')。 保存图片的文件名前缀(仅当 save_to_dir 设置时可用)。
save_format: "png", "jpeg" 之一(仅当 save_to_dir 设置时可用)。默认:"png"。
  • yields: 元组 (x, y),其中 x 是图像数据的 Numpy 数组,y 是相应标签的 Numpy 数组。生成器将无限循环。

flow_from_directory():

以目录路径为参数,生成批次的 增益的/标准化的 数据。在生成的批次数据上无限制地无限次循环。

  • 参数
1
2
3
4
5
6
7
8
9
10
11
12
directory: 目标目录的路径。每个类应该包含至少一个子目录。任何在子目录下的图像,都将被包含在生成器中。
target_size: 整数元组 (height, width),默认:(256, 256)。所有的图像将被调整到的尺寸。
color_mode: "grayscale", "rbg" 之一。默认:"rgb"。图像是否被转换成1或3个颜色通道。
classes: 可选的类的子目录列表(例如 ['dogs', 'cats'])。默认:None。如果未提供,类的列表将自动从“目录”下的子目录名称/结构中推断出来,其中每个子目录都将被作为不同的类(类名将按字典序映射到标签的索引)。包含从类名到类索引的映射的字典可以通过class_indices属性获得。
class_mode: "categorical", "binary", "sparse", "input" 或 None 之一。默认:"categorical"。决定返回的标签数组的类型:"categorical" 将是 2D one-hot 编码标签,"binary" 将是 1D 二进制标签,"sparse" 将是 1D 整数标签,"input" 将是与输入图像相同的图像(主要用于与自动编码器一起工作)。如果为 None,不返回标签(生成器将只产生批量的图像数据,对于 model.predict_generator(), model.evaluate_generator() 等很有用)。请注意,如果 class_mode 为 None,那么数据仍然需要驻留在 directory 的子目录中才能正常工作。
batch_size: 一批数据的大小(默认 32)。
shuffle: 是否混洗数据(默认 True)。
seed: 可选随机种子,用于混洗和转换。
save_to_dir: None 或 字符串(默认 None)。这使你可以最佳地指定正在生成的增强图片要保存的目录(用于可视化你在做什么)。
save_prefix: 字符串。 保存图片的文件名前缀(仅当 save_to_dir 设置时可用)。
save_format: "png", "jpeg" 之一(仅当 save_to_dir 设置时可用)。默认:"png"。
follow_links: 是否跟踪类子目录下的符号链接(默认 False)。
  • yields: 元组 (x, y),其中 x 是图像数据的 Numpy 数组,y 是相应标签的 Numpy 数组。生成器将无限循环。

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
'''
任务一:使用keras数据生成器,对猫狗数据进行数值处理和空间处理(旋转,平移等),
对处理后的数据做三项检查,即数据类型,大小,数值范围,最后随机可视化8个生成器处理后的数据
'''

from __future__ import print_function
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# 从目录中读取图像
datagen = ImageDataGenerator(
rotation_range=90,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.5,
rescale=1./255,
# horizontal_flip=True,
# vertical_flip=False
)
gene = datagen.flow_from_directory(
'test',
# target_size=(1280, 720),
batch_size=8,
# save_to_dir = 'x_test',
# # save_prefix = 'cats_and_dogs',
# save_format = 'png'
)
data = next(gene)
print(data[0][0].shape)
for i in data[0] :
plt.imshow(i)
plt.show()

# 直接使用cifar10数据
from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
cifar10_datagen = ImageDataGenerator(
rescale=1./255,
)
result = cifar10_datagen.flow(
x_train,
y_train,
batch_size=8,
)
data = next(result)
for i in data[0] :
print("data type:",type(i))
print("data shape:",i.shape)
print("max:",i.max())
print("min:",i.min())
plt.imshow(i)
plt.show()

参考