总字数:约12000字,阅读时间:约15分钟
写在前面
刚开始接触机器学习的时候,曾经试过用python写过一个简单的MLP(Multi-Layer Perceptron),其中仅仅只是实现了简单的Gradient Descent算法,激活函数上也只涉及了Sigmod激活函数。在这之后的很多研究中,都是用的现成的算法库在跑,完全接触不到实现方法,编程能力也急剧下降。所以趁着现在有点时间,认真学习一下别人是怎么样实现深度学习方法的。由于我一直使用的是python语言,而且我在入门深度学习的时候也是从Keras这个库开始的。自从被TensorFlow列为官网API之后,Keras使用的人也极具上升。所以决定从Keras源码开始看起,学习一下别人是如何实现深度学习框架的。
由于我是即兴记录的,目前还没有任何的目标。所以可能会出现结构上有些凌乱,等我都掌握了之后,我再做整理吧。如果发现了我在分析的过程中存在错误或者不足,希望大家能给我留言,欢迎指正。
在发表之前写了一些内容,但是自低向上写的。结果发现自低向上分析实在有点困难,要直接把底层结构讲清楚确实很困难。所以决定改一改策略,采用自顶向下的方式写,先分析基础构建,再深入底层分析运作原理。注:在代码讲解过程中,我将一些不必要的空行,包导入等不影响整体介绍的部分删掉了,将非重点讲解的内容修改为pass(并不是这部分没有代码),在部分位置添加了一些标识以定位代码。
我在开始写这个系列的时候,Keras最新版本还是1.2.2,现在已经出了2.X的版本了,不过核心框架没有变动,仅仅只是做了一些接口的改变,所以依旧可以根据1.2.2版本进行分析。Keras 2改动部分可以在这篇博文(Introducing Keras 2)上查看。
文档结构
我们先来看看Keras下的文档结构。
|
|
实际上,对于在深度学习这个模块里面,我们主要关注的是engine文件夹中的内容,这里面包含了关于model、layer、node的实现方式,也就是最底层的内容。而我们在应用当中,主要使用的Keras文件夹下.py的文件,其中常用的Sequential模型在models.py中实现,各种Optimizer类在optimizers.py中实现,各种激活函数在activations.py等等。这些内容我们以后慢慢来看,我们先从底层开始看起,看看Keras的底层是如何实现可扩展结构的。
从一个例子开始
首先,从一个示例程序入手,开始解析Keras。我从examples文件夹中挑选mnist_cnn
.py作为我们上手的第一个程序。选择它的主要原因是因为它整体框架比较简洁,而且包含了Keras常用的构建。
|
|
首先定义了这个模型的一些超参数,batch size,input size之类的。其次是获取数据。这里我们需要重点提一下了,在Keras的datasets模块中,包含了6个数据集(cifar, cifar10, cifar100, imdb, mnist, reuters),这些数据集不是一开始就存在的,而是在使用的时候才去下载的,不过只需要下载一次就可以了,之后它会在指定的文件夹中寻找。我们看看它是怎么实现的。
数据集下载
|
|
load_data函数接受一个参数path,表示存储的地址。关于这个参数,我觉得这是Keras设计上的一个失误。无论是文件名还是注释都表示这是一个地址,但实际的默认给出的参数却是文件名,然后又在注释中又添加了一段(relative to ~/.keras/datasets),表示实际的默认存储位置为用户目录下的一个隐藏文件中,确实让人很不解。继续往下看,从get_file的URL地址我们可以看出,Keras将数据文件寄托在了Amazon AWS上面。对于一些小一点的数据文件,直接用其自动下载功能不会存在太大的问题。但是对于一些比较大的文件,直接下载很容易断线,而默认是不支持断点续传的。这个时候我们就可以在源码里面找到它的下载URL,然后使用其他的下载工具进行下载,然后拷贝到默认的文件夹(~/.keras/datasets)中即可。
之后是检测是否为压缩包,若是则读取压缩文件,否则直接读取(可以看出这个应该是早期设计的函数,解压部分现在已经被集成到了
keras/utils/data_utils.py
def get_file(fname, origin, untar=False,
md5_hash=None, cache_subdir=’datasets’):
“””Downloads a file from a URL if it not already in the cache.
Passing the MD5 hash will verify the file after download
as well as if it is already present in the cache.
# Arguments
fname: name of the file
origin: original URL of the file
untar: boolean, whether the file should be decompressed
md5_hash: MD5 hash of the file for verification
cache_subdir: directory being used as the cache
# Returns
Path to the downloaded file
"""
######################## (1) start ########################
datadir_base = os.path.expanduser(os.path.join('~', '.keras'))
if not os.access(datadir_base, os.W_OK):
datadir_base = os.path.join('/tmp', '.keras')
datadir = os.path.join(datadir_base, cache_subdir)
if not os.path.exists(datadir):
os.makedirs(datadir)
######################### (1) end #########################
if untar:
untar_fpath = os.path.join(datadir, fname)
fpath = untar_fpath + '.tar.gz'
else:
fpath = os.path.join(datadir, fname)
######################## (2) start ########################
download = False
if os.path.exists(fpath):
# File found; verify integrity if a hash was provided.
if md5_hash is not None:
if not validate_file(fpath, md5_hash):
print('A local file was found, but it seems to be '
'incomplete or outdated.')
download = True
else:
download = True
if download:
print('Downloading data from', origin)
progbar = None
def dl_progress(count, block_size, total_size, progbar=None):
if progbar is None:
progbar = Progbar(total_size)
else:
progbar.update(count * block_size)
error_msg = 'URL fetch failure on {}: {} -- {}'
try:
try:
urlretrieve(origin, fpath,
functools.partial(dl_progress, progbar=progbar))
except URLError as e:
raise Exception(error_msg.format(origin, e.errno, e.reason))
except HTTPError as e:
raise Exception(error_msg.format(origin, e.code, e.msg))
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(fpath):
os.remove(fpath)
raise
progbar = None
######################### (2) end #########################
if untar:
if not os.path.exists(untar_fpath):
print('Untaring file...')
tfile = tarfile.open(fpath, 'r:gz')
try:
tfile.extractall(path=datadir)
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(untar_fpath):
if os.path.isfile(untar_fpath):
os.remove(untar_fpath)
else:
shutil.rmtree(untar_fpath)
raise
tfile.close()
return untar_fpath
return fpath
|
|
keras/utils/data_utils.py
if sys.version_info[0] == 2:
def urlretrieve(url, filename, reporthook=None, data=None):
pass
else:
from six.moves.urllib.request import urlretrieve
keras/utils/np_utils.py
def to_categorical(y, nb_classes=None):
“””Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
# Arguments
y: class vector to be converted into a matrix
(integers from 0 to nb_classes).
nb_classes: total number of classes.
# Returns
A binary matrix representation of the input.
"""
y = np.array(y, dtype='int').ravel()
if not nb_classes:
nb_classes = np.max(y) + 1
n = y.shape[0]
categorical = np.zeros((n, nb_classes))
categorical[np.arange(n), y] = 1
return categorical
|
|
def probas_to_classes(y_pred):
if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
return categorical_probas_to_classes(y_pred)
return np.array([1 if p > 0.5 else 0 for p in y_pred])
def categorical_probas_to_classes(p):
return np.argmax(p, axis=1)
```
实际上,categorical_probas_to_classes
函数才是to_categorical
函数的逆向操作,不过作者考虑到两类识别的兼容性问题,在其上又加了一个维度判断,所以我们直接用probas_to_classes
函数就好了。
小结
本节主要讲了一下主体框架,然后以一个例子入手,分析了数据下载模块和两个实用的类别转换小工具。下一个章节将重点分析网络模型建立。