Hi, Thinking

Keras源码解析(1)-入门

总字数:约12000字,阅读时间:约15分钟

写在前面

刚开始接触机器学习的时候,曾经试过用python写过一个简单的MLP(Multi-Layer Perceptron),其中仅仅只是实现了简单的Gradient Descent算法,激活函数上也只涉及了Sigmod激活函数。在这之后的很多研究中,都是用的现成的算法库在跑,完全接触不到实现方法,编程能力也急剧下降。所以趁着现在有点时间,认真学习一下别人是怎么样实现深度学习方法的。由于我一直使用的是python语言,而且我在入门深度学习的时候也是从Keras这个库开始的。自从被TensorFlow列为官网API之后,Keras使用的人也极具上升。所以决定从Keras源码开始看起,学习一下别人是如何实现深度学习框架的。

由于我是即兴记录的,目前还没有任何的目标。所以可能会出现结构上有些凌乱,等我都掌握了之后,我再做整理吧。如果发现了我在分析的过程中存在错误或者不足,希望大家能给我留言,欢迎指正。

在发表之前写了一些内容,但是自低向上写的。结果发现自低向上分析实在有点困难,要直接把底层结构讲清楚确实很困难。所以决定改一改策略,采用自顶向下的方式写,先分析基础构建,再深入底层分析运作原理。注:在代码讲解过程中,我将一些不必要的空行,包导入等不影响整体介绍的部分删掉了,将非重点讲解的内容修改为pass(并不是这部分没有代码),在部分位置添加了一些标识以定位代码。

我在开始写这个系列的时候,Keras最新版本还是1.2.2,现在已经出了2.X的版本了,不过核心框架没有变动,仅仅只是做了一些接口的改变,所以依旧可以根据1.2.2版本进行分析。Keras 2改动部分可以在这篇博文(Introducing Keras 2)上查看。

文档结构

我们先来看看Keras下的文档结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|-- docs #说明文档
|-- examples #应用示例
|-- test #测试文件
|-- keras #源码核心
| |-- backend #底层backend
| |-- datasets #数据获取源码
| |-- engine #核心工具
| |-- layers #层源码
| |-- legacy #遗留源码
| |-- preprocessing #预处理函数
| |-- utils #实用工具
| |-- wrappers #scikit-learn 封装类
| |-- activations.py #可用的激活函数
| |-- callbacks.py #回调函数
| |-- constraints.py #权重约束项,如非零约束等
| |-- initializations.py #初始化方法
| |-- metrics.py #度量方法
| |-- models.py #包含Model和Sequential模型,以及各种存取方法
| |-- objectives.py #objectives function,也就是loss function
| |-- optimizers.py #优化方法,如SGD,Adam等
| |-- regularizers.py #规则项,如L1,L2等规则项

实际上,对于在深度学习这个模块里面,我们主要关注的是engine文件夹中的内容,这里面包含了关于model、layer、node的实现方式,也就是最底层的内容。而我们在应用当中,主要使用的Keras文件夹下.py的文件,其中常用的Sequential模型在models.py中实现,各种Optimizer类在optimizers.py中实现,各种激活函数在activations.py等等。这些内容我们以后慢慢来看,我们先从底层开始看起,看看Keras的底层是如何实现可扩展结构的。

从一个例子开始

首先,从一个示例程序入手,开始解析Keras。我从examples文件夹中挑选mnist_cnn
.py作为我们上手的第一个程序。选择它的主要原因是因为它整体框架比较简洁,而且包含了Keras常用的构建。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# examples/mnist_cnn.py
'''Trains a simple convnet on the MNIST dataset.
Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''
batch_size = 128
nb_classes = 10
nb_epoch = 12
# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
pool_size = (2, 2)
# convolution kernel size
kernel_size = (3, 3)
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
if K.image_dim_ordering() == 'th':
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
model = Sequential()
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
border_mode='valid',
input_shape=input_shape))
model.add(Activation('relu'))
model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=pool_size))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adadelta',
metrics=['accuracy'])
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
verbose=1, validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])

首先定义了这个模型的一些超参数,batch size,input size之类的。其次是获取数据。这里我们需要重点提一下了,在Keras的datasets模块中,包含了6个数据集(cifar, cifar10, cifar100, imdb, mnist, reuters),这些数据集不是一开始就存在的,而是在使用的时候才去下载的,不过只需要下载一次就可以了,之后它会在指定的文件夹中寻找。我们看看它是怎么实现的。

数据集下载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# keras/datasets/mnist.py
def load_data(path='mnist.pkl.gz'):
"""Loads the MNIST dataset.
# Arguments
path: path where to cache the dataset locally
(relative to ~/.keras/datasets).
# Returns
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
path = get_file(path, origin='https://s3.amazonaws.com/img-datasets/mnist.pkl.gz')
if path.endswith('.gz'):
f = gzip.open(path, 'rb')
else:
f = open(path, 'rb')
if sys.version_info < (3,):
data = cPickle.load(f)
else:
data = cPickle.load(f, encoding='bytes')
f.close()
return data # (x_train, y_train), (x_test, y_test)

load_data函数接受一个参数path,表示存储的地址。关于这个参数,我觉得这是Keras设计上的一个失误。无论是文件名还是注释都表示这是一个地址,但实际的默认给出的参数却是文件名,然后又在注释中又添加了一段(relative to ~/.keras/datasets),表示实际的默认存储位置为用户目录下的一个隐藏文件中,确实让人很不解。继续往下看,从get_file的URL地址我们可以看出,Keras将数据文件寄托在了Amazon AWS上面。对于一些小一点的数据文件,直接用其自动下载功能不会存在太大的问题。但是对于一些比较大的文件,直接下载很容易断线,而默认是不支持断点续传的。这个时候我们就可以在源码里面找到它的下载URL,然后使用其他的下载工具进行下载,然后拷贝到默认的文件夹(~/.keras/datasets)中即可。

之后是检测是否为压缩包,若是则读取压缩文件,否则直接读取(可以看出这个应该是早期设计的函数,解压部分现在已经被集成到了

1
2
这里我们再深入一点,看看```get_file```函数的运作方式。这个函数很重要,Keras中涉及到下载文件功能的部分均是由这个函数实现的。它保存在utils\data_utils.py文件中。

keras/utils/data_utils.py

def get_file(fname, origin, untar=False,
md5_hash=None, cache_subdir=’datasets’):
“””Downloads a file from a URL if it not already in the cache.
Passing the MD5 hash will verify the file after download
as well as if it is already present in the cache.

# Arguments
    fname: name of the file
    origin: original URL of the file
    untar: boolean, whether the file should be decompressed
    md5_hash: MD5 hash of the file for verification
    cache_subdir: directory being used as the cache
# Returns
    Path to the downloaded file
"""
######################## (1) start ########################
datadir_base = os.path.expanduser(os.path.join('~', '.keras'))
if not os.access(datadir_base, os.W_OK):
    datadir_base = os.path.join('/tmp', '.keras')
datadir = os.path.join(datadir_base, cache_subdir)
if not os.path.exists(datadir):
    os.makedirs(datadir)
######################### (1) end #########################
if untar:
    untar_fpath = os.path.join(datadir, fname)
    fpath = untar_fpath + '.tar.gz'
else:
    fpath = os.path.join(datadir, fname)
######################## (2) start ########################
download = False
if os.path.exists(fpath):
    # File found; verify integrity if a hash was provided.
    if md5_hash is not None:
        if not validate_file(fpath, md5_hash):
            print('A local file was found, but it seems to be '
                  'incomplete or outdated.')
            download = True
else:
    download = True
if download:
    print('Downloading data from', origin)
    progbar = None

    def dl_progress(count, block_size, total_size, progbar=None):
        if progbar is None:
            progbar = Progbar(total_size)
        else:
            progbar.update(count * block_size)

    error_msg = 'URL fetch failure on {}: {} -- {}'
    try:
        try:
            urlretrieve(origin, fpath,
                        functools.partial(dl_progress, progbar=progbar))
        except URLError as e:
            raise Exception(error_msg.format(origin, e.errno, e.reason))
        except HTTPError as e:
            raise Exception(error_msg.format(origin, e.code, e.msg))
    except (Exception, KeyboardInterrupt) as e:
        if os.path.exists(fpath):
            os.remove(fpath)
        raise
    progbar = None
######################### (2) end #########################
if untar:
    if not os.path.exists(untar_fpath):
        print('Untaring file...')
        tfile = tarfile.open(fpath, 'r:gz')
        try:
            tfile.extractall(path=datadir)
        except (Exception, KeyboardInterrupt) as e:
            if os.path.exists(untar_fpath):
                if os.path.isfile(untar_fpath):
                    os.remove(untar_fpath)
                else:
                    shutil.rmtree(untar_fpath)
            raise
        tfile.close()
    return untar_fpath
return fpath
1
2
3
4
5
6
7
8
这个函数包含五个参数,fname(文件名),origin(源URL),untar(是否解压),md5_hash(md5验证),cache_subdir(存储子文件夹)。这个函数首先获取一个合法的下载地址(1),如果根目录下.keras文件可写,则以这个路径为默认路径,否则默认路径改为/tmp/.keras。这里貌似有个bug,第一部分使用```os.path.expanduser```能兼容Windows、Linux等多种平台,但是对于后面直接使用```os.path.join('/tmp', '.keras')```,貌似在Windows系统下面是不存在这个路径的。
其后检测是否需要解压,这个功能应该是后来加的,在```mnist.load_data```函数中还留有需要外部解压的痕迹。对于需要解压的情况,函数的最后面对其进行了一个解压操作,并返回解压后的地址。
这个函数的核心部分就是下载(2)。首先定义一个布尔变量download(注意:这个变量表示是否需要下载,而非是否已经下载),然后检验是否需要下载,两种情况需要下载:文件不存在或者文件存在但MD5验证不正确。下载的过程就比较简单了,先定义了一个用于显示下载进度条的回调函数```dl_progress```,然后调取```urlretrieve```函数进行下载。我很喜欢这一部分的异常处理,在内部截获两种不同的异常,然后使用通用异常进行显示,并且将用户终止异常添加进去,还处理了下载失败后的损坏文件,避免出错。不过这里还是存在一个问题,这里缺少了MD5验证流程。实际上对于一个严谨的程序,应当将验证流程放置在整个程序的最末尾,以防止中间过程出现的不可预见性错误。所以这里应该适当调整MD5验证的位置,以防止意外情况出现。
最后小小的提一下```urlretrieve```函数,主要是觉得这种处理方式很优雅。

keras/utils/data_utils.py

if sys.version_info[0] == 2:
def urlretrieve(url, filename, reporthook=None, data=None):
pass
else:
from six.moves.urllib.request import urlretrieve

1
2
3
4
5
6
7
8
因为在Python 2.X的版本中,urlretrieve存在一些问题,所以对于这个版本作者自己写了一个获取程序,而对于Python 3.X就直接从外部库导入。在保持接口相同的情况下,就可以兼容两个版本,而且不需要额外的修改。
#### 一些小工具
回到上面的示例程序,获取完了数据之后,我们又可以开心的继续往下走了。下面是检验backend维度顺序,在Theano中默认维度顺序为(depth, row, column),而在TensorFlow中为(row, column, depth),所以这里要加以区分。之后就是做一些简单的数据预处理,这里直接跳过。这一小节重点要讲的是```to_categorical```函数,这个函数在Keras中经常使用到,用于做标签转换的。
在神经网络中,当我们要计算交叉熵的时候,我们需要将输出节点个数定为类别数C。并且将每一个样本label的维度转换为C维,标签值对应位置设为1,其余为0。这个时候就需要用到```to_categorical```函数了。

keras/utils/np_utils.py

def to_categorical(y, nb_classes=None):
“””Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.

# Arguments
    y: class vector to be converted into a matrix
        (integers from 0 to nb_classes).
    nb_classes: total number of classes.
# Returns
    A binary matrix representation of the input.
"""
y = np.array(y, dtype='int').ravel()
if not nb_classes:
    nb_classes = np.max(y) + 1
n = y.shape[0]
categorical = np.zeros((n, nb_classes))
categorical[np.arange(n), y] = 1
return categorical
1
2
3
4
实际上,这个函数的实现特别简单,就是获取类别数C,然后创建一个NxC维的零矩阵,然后将对应值设置为1。这里想提一下的是,它的类别数获取是根据标签值中的最大值加一来获取的,也就是说原始标签一定要从0开始取值。否则会出现未定义的类别标签。
当然啦,这里还有一个反向操作的工具,```probas_to_classes```函数

def probas_to_classes(y_pred):
if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
return categorical_probas_to_classes(y_pred)
return np.array([1 if p > 0.5 else 0 for p in y_pred])

def categorical_probas_to_classes(p):
return np.argmax(p, axis=1)
```

实际上,categorical_probas_to_classes函数才是to_categorical函数的逆向操作,不过作者考虑到两类识别的兼容性问题,在其上又加了一个维度判断,所以我们直接用probas_to_classes函数就好了。

小结

本节主要讲了一下主体框架,然后以一个例子入手,分析了数据下载模块和两个实用的类别转换小工具。下一个章节将重点分析网络模型建立。

Kivi.记