An Overview of mxnet

1 Installation

  • Compiler: g++ > 4.8, or clang
  • Dependencies: a BLAS library such as libblas or openblas. Different scenarios require different optional dependencies; since we will not use a GPU for now, there is no need to install CUDA.

git clone --recursive https://github.com/dmlc/mxnet

1.1 Directory Structure

|- mxnet
   |- make
   |- Makefile
   |- cmake
   |- CMakeLists.txt
   |- docker
   |- dmlc-core
   |- ps-lite
   |- mshadow
   |- include
   |- src
   |- scala-package
   |- R-package
   |- python
   |- ...

mxnet depends on three projects: dmlc-core, ps-lite, and mshadow. The way I see it, mxnet actually splits into two parts, which I will call the mxnet core and the mxnet api. In the core, the include folder defines a C API for other languages such as Python and Scala to call. The core does not implement the complete training logic of a neural network: it defines how the network performs forward and backward propagation, but logic such as the number of training iterations or starting and stopping the KV Store lives in the mxnet api layer, so the Python, Scala, and other bindings each carry their own implementation of that logic.
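Purely to illustrate that split: a JVM binding reaches the core through a handful of native declarations. The two C entry points below (MXSymbolCreateFromJSON, MXNDArrayFree) are real names from include/mxnet/c_api.h, but the Scala-side wrapper shown here is a hypothetical sketch, not the actual binding code.

object CApiSketch {
  // load the thin JNI wrapper that links against libmxnet.so
  // (library name here is an assumption for illustration)
  System.loadLibrary("mxnet-scala")

  // c_api.h convention: every call returns 0 on success and hands results
  // back through out-parameters, modeled here as a 1-element Long array
  @native def mxSymbolCreateFromJSON(json: String, handleOut: Array[Long]): Int
  @native def mxNDArrayFree(handle: Long): Int
}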

1.2 Building

mxnet currently has two build systems, one based directly on make and the other on cmake. I recommend make because it covers more functionality: the current cmake scripts cannot build the Scala package.

Build options are configured by editing the make/config.mk file. For our purposes we can do without a GPU, but since we want to integrate with Spark we need the distributed KV Store. In make/config.mk, change the configuration as follows:

USE_DIST_KVSTORE = 1

The distributed KV Store depends on protobuf and zmq, so the corresponding development libraries must be installed first (for example, libprotobuf-dev, protobuf-compiler, and libzmq3-dev on Ubuntu).

Start the build:

cd mxnet
make
make scalapkg # if you need the Scala package
make scalatest # run the Scala test suite

If the build succeeds, you will find libmxnet.so in the lib directory.

2 Advantages of the Parameter Server

Spark is by now the de facto standard for big data processing, and Spark MLlib implements a good number of machine learning algorithms. Spark, however, is still built on the Map/Reduce computation model, and that model is not a particularly good fit for machine learning workloads. A crucial step in machine learning is computing the optimal parameters, usually by gradient descent: \[ w = w - \lambda\Delta w \]
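Spelling out the notation (this is standard minibatch SGD, nothing mxnet-specific): \(\lambda\) is the learning rate and \(\Delta w\) is the gradient of the loss \(\ell\), averaged over a minibatch \(B\):

\[ \Delta w = \frac{1}{|B|} \sum_{(x_i, y_i) \in B} \nabla_w \, \ell(f(x_i; w), y_i) \]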

In Spark, each iteration computes the gradients per partition and then updates the weights on the driver, so the driver must wait until every executor has finished its gradient computation. The moment a single executor runs into a network delay or a similar problem, the whole computation suffers. Eliminating exactly this effect is the point of the parameter server: the latency of one node no longer drags down the overall computation, turning a synchronous process into an asynchronous one. Comparing the training time of a multi-layer neural network in mxnet and in Spark MLlib shows the performance gap.
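To make the synchronization cost concrete, here is a minimal, self-contained sketch of this pattern in plain Spark (the toy data and names are invented for illustration; this is neither mxnet nor MLlib code). The treeAggregate call is a global barrier: the driver cannot update w until the slowest partition has reported its gradient.

import org.apache.spark.{SparkConf, SparkContext}

object SyncSgdSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sync-sgd-sketch"))
    // toy least-squares data: pairs (x, y) with y roughly 2 * x
    val data = sc.parallelize(Seq((1.0, 2.0), (2.0, 4.1), (3.0, 5.9))).cache()
    var w = 0.0
    val lambda = 0.01
    for (_ <- 0 until 50) {
      val bw = sc.broadcast(w)
      // global barrier: wait for the gradient contribution of EVERY partition
      val (gradSum, n) = data.treeAggregate((0.0, 0L))(
        seqOp = { case ((g, c), (x, y)) => (g + 2 * x * (bw.value * x - y), c + 1) },
        combOp = { case ((g1, c1), (g2, c2)) => (g1 + g2, c1 + c2) })
      w -= lambda * gradSum / n // driver-side update: w = w - lambda * dw
    }
    println(s"w = $w") // converges toward 2.0
    sc.stop()
  }
}

mxnet's parameter server replaces that barrier with asynchronous push/pull, so a slow worker only delays its own updates.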

(Figure: perf.png — training time of a multi-layer network, mxnet versus Spark MLlib)

2.1 Implementation

A parameter server setup has three roles:

  1. worker: computes gradients
  2. server: receives gradients from the workers and updates the parameters
  3. scheduler: handles coordination; workers and servers must register themselves with the scheduler

(Figure: arch.png — the parameter server architecture)

The workflow (a worker-side code sketch follows the list):

  1. workers and servers register with the scheduler and receive the cluster setup
  2. a worker pulls the parameters w from the servers
  3. the worker computes gradients from w and its share of the data, then pushes the gradients to the servers
  4. the servers update the parameters w
  5. steps 2-4 repeat until training ends
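The worker side of this loop can be sketched against the mxnet Scala KVStore API (assuming the ml.dmlc.mxnet package; computeGradient is a hypothetical stand-in for step 3, and the scheduler/server bootstrapping a real deployment needs is omitted here — section 4 shows it in full):

import ml.dmlc.mxnet.{KVStore, NDArray, Shape}

object WorkerLoopSketch {
  // hypothetical stand-in: in real training the gradient comes from the
  // network's forward/backward pass
  def computeGradient(w: NDArray): NDArray = NDArray.zeros(w.shape)

  def main(args: Array[String]): Unit = {
    val kv = KVStore.create("dist_async") // connect to the cluster (step 1)
    val key = 0
    val weight = NDArray.zeros(Shape(10))
    kv.init(key, weight) // the servers now own parameter w

    for (_ <- 0 until 100) {                  // step 5: repeat steps 2-4
      kv.pull(key, weight)                    // step 2: fetch current w
      kv.push(key, computeGradient(weight))   // step 3: send local gradient
      // step 4, the update w = w - lambda * grad, runs on the servers
    }
    kv.dispose()
  }
}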

3 The Computation Model

The main references are two mxnet design notes:

http://mxnet.readthedocs.io/en/latest/system/program_model.html

http://mxnet.readthedocs.io/en/latest/system/note_memory.html

From the user's point of view, mxnet provides an interface for defining neural networks.

val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected(name = "fc1")(Map("data" -> data, "num_hidden" -> 128))
val act1 = Symbol.Activation(name = "relu1")(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected(name = "fc2")(Map("data" -> act1, "num_hidden" -> 64))
val act2 = Symbol.Activation(name = "relu2")(Map("data" -> fc2, "act_type" -> "relu"))
val fc3 = Symbol.FullyConnected(name = "fc3")(Map("data" -> act2, "num_hidden" -> 10))
val mlp = Symbol.SoftmaxOutput(name = "softmax")(Map("data" -> fc3))

The Scala snippet above defines a multi-layer neural network. At execution time, the Symbol is converted into a StaticGraph through its toStaticGraph method, and StaticGraph resolves the dependencies between the graph's nodes to produce a topological ordering (sketched below). Recall that training a neural network has two steps, forward propagation and backward propagation, and that there are two ways to compute the backward pass: share a single graph with the forward pass, or explicitly generate backward graph nodes.
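The topological ordering itself is nothing exotic. Here is a generic, self-contained sketch (not mxnet source; names are invented) of a Kahn-style ordering over a dependency map shaped like the mlp above:

import scala.collection.mutable

object TopoSortSketch {
  // node -> its inputs, mirroring the mlp defined above (names shortened)
  val deps: Map[String, Seq[String]] = Map(
    "data" -> Seq.empty,
    "fc1" -> Seq("data"), "relu1" -> Seq("fc1"),
    "fc2" -> Seq("relu1"), "softmax" -> Seq("fc2"))

  // Kahn's algorithm: repeatedly emit a node whose inputs are all satisfied
  def topoOrder(deps: Map[String, Seq[String]]): Seq[String] = {
    val inDegree = mutable.Map(deps.mapValues(_.size).toSeq: _*)
    val ready = mutable.Queue(deps.collect { case (n, ins) if ins.isEmpty => n }.toSeq: _*)
    val order = mutable.ArrayBuffer.empty[String]
    while (ready.nonEmpty) {
      val n = ready.dequeue()
      order += n
      for ((m, ins) <- deps if ins.contains(n)) { // each consumer of n
        inDegree(m) -= 1
        if (inDegree(m) == 0) ready.enqueue(m)
      }
    }
    order.toList
  }

  def main(args: Array[String]): Unit =
    println(topoOrder(deps).mkString(" -> ")) // data -> fc1 -> relu1 -> fc2 -> softmax
}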

(Figure: back_graph.png — a shared forward/backward graph versus explicit backward nodes)

Some deep learning libraries, such as caffe and torch, share a single graph; others, such as Theano, generate explicit backward nodes. mxnet also chooses explicit backward nodes, which makes graph-level optimization much more convenient.
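To see what the explicit backward nodes compute, take one FullyConnected + relu pair from the mlp above. These are the standard backpropagation equations (nothing mxnet-specific); each right-hand side becomes its own node in the backward graph:

\[ y = Wx + b, \qquad a = \max(0, y) \]

\[ \frac{\partial L}{\partial y} = \frac{\partial L}{\partial a} \odot \mathbf{1}[y > 0], \qquad \frac{\partial L}{\partial W} = \frac{\partial L}{\partial y}\, x^{\top}, \qquad \frac{\partial L}{\partial b} = \frac{\partial L}{\partial y}, \qquad \frac{\partial L}{\partial x} = W^{\top} \frac{\partial L}{\partial y} \]

Because each gradient is a first-class node, the framework can see, for instance, that \(\partial L/\partial y\) needs only the sign pattern of \(y\) rather than \(a\) itself, so memory for intermediate outputs can be freed or shared aggressively — the kind of graph-level optimization the note_memory article above discusses.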

4 Example

Let us start with an example of how mxnet runs end to end. Since Spark is essentially today's de facto standard for big data processing, we go straight to combining mxnet with Spark, which brings us closer to a production workflow. The mxnet source tree already ships an example of Spark integration, so we analyze it directly.

class ClassificationExample
object ClassificationExample {

  def main(args: Array[String]): Unit = {
    try {
      // initialize the SparkContext
      val conf = new SparkConf().setAppName("MXNet")
      val sc = new SparkContext(conf)

      // build the network (command-line parsing into cmdLine is elided in this excerpt)
      val network = if (cmdLine.model == "mlp") getMlp else getLenet
      val dimension = if (cmdLine.model == "mlp") Shape(784) else Shape(1, 28, 28)
      val devs =
        if (cmdLine.gpus != null) cmdLine.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
        else if (cmdLine.cpus != null) cmdLine.cpus.split(',').map(id => Context.cpu(id.trim.toInt))
        else Array(Context.cpu(0))

      // configure the training properties
      val mxnet = new MXNet()
        .setBatchSize(128)
        .setLabelName("softmax_label")
        .setContext(devs)
        .setDimension(dimension)
        .setNetwork(network)
        .setNumEpoch(cmdLine.numEpoch)
        .setNumServer(cmdLine.numServer)
        .setNumWorker(cmdLine.numWorker)
        .setExecutorJars(cmdLine.jars)
        .setJava(cmdLine.java)

      val trainData = parseRawData(sc, cmdLine.input)
      val start = System.currentTimeMillis

      // start training
      val model = mxnet.fit(trainData)
      val timeCost = System.currentTimeMillis - start
      logger.info("Training cost {} milli seconds", timeCost)
      model.save(sc, cmdLine.output + "/model")

      logger.info("Now do validation")
      val valData = parseRawData(sc, cmdLine.inputVal)

      // broadcast the model for prediction
      val brModel = sc.broadcast(model)
      val res = valData.mapPartitions { data =>
        // get real labels
        import org.apache.spark.mllib.linalg.Vector
        val points = ArrayBuffer.empty[Vector]
        val y = ArrayBuffer.empty[Float]
        while (data.hasNext) {
          val evalData = data.next()
          y += evalData.label.toFloat
          points += evalData.features
        }

        // get predicted labels
        val probArrays = brModel.value.predict(points.toIterator)
        require(probArrays.length == 1)
        val prob = probArrays(0)
        val py = NDArray.argmaxChannel(prob.get)
        require(y.length == py.size, s"${y.length} mismatch ${py.size}")

        // I'm too lazy to calculate the accuracy
        val res = Iterator((y.toArray zip py.toArray).map {
          case (y1, py1) => y1 + "," + py1
        }.mkString("\n"))

        py.dispose()
        prob.get.dispose()
        res
      }
      res.saveAsTextFile(cmdLine.output + "/data")

      sc.stop()
    } catch {
      case e: Throwable =>
        logger.error(e.getMessage, e)
        sys.exit(-1)
    }
  }

  def getMlp: Symbol = {
    val data = Symbol.Variable("data")
    val fc1 = Symbol.FullyConnected(name = "fc1")(Map("data" -> data, "num_hidden" -> 128))
    val act1 = Symbol.Activation(name = "relu1")(Map("data" -> fc1, "act_type" -> "relu"))
    val fc2 = Symbol.FullyConnected(name = "fc2")(Map("data" -> act1, "num_hidden" -> 64))
    val act2 = Symbol.Activation(name = "relu2")(Map("data" -> fc2, "act_type" -> "relu"))
    val fc3 = Symbol.FullyConnected(name = "fc3")(Map("data" -> act2, "num_hidden" -> 10))
    val mlp = Symbol.SoftmaxOutput(name = "softmax")(Map("data" -> fc3))
    mlp
  }
}

To cooperate with Spark, the first step is of course to initialize the SparkContext. We then define the neural network: getMlp builds a multi-layer network out of Symbols. Next we instantiate the MXNet class and set the training properties. As the code shows, the crucial step is mxnet.fit(trainData), which takes an RDD and produces the final model.

The mxnet.fit method performs these main steps:

  1. Start a ParameterServer scheduler. There is a weakness here: if the scheduler dies, the whole parameter server stops working, so this needs HA improvements.
  2. Start the ParameterServer server processes through Spark, one per partition of a helper RDD (numServer partitions in total).
  3. For the data set, start one ParameterServer worker per partition.
  4. In each partition, build a FeedForward network for its worker and call FeedForward.fit to train.

def fit(data: RDD[LabeledPoint]): MXNetModel = {
  val sc = data.context
  // distribute native jars
  params.jars.foreach(jar => sc.addFile(jar))

  val trainData = {
    if (params.numWorker > data.partitions.length) {
      logger.info("repartitioning training set to {} partitions", params.numWorker)
      data.repartition(params.numWorker)
    } else if (params.numWorker < data.partitions.length) {
      logger.info("repartitioning training set to {} partitions", params.numWorker)
      data.coalesce(params.numWorker)
    } else {
      data
    }
  }

  val schedulerIP = utils.Network.ipAddress
  val schedulerPort = utils.Network.availablePort
  // TODO: check ip & port available
  logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort)
  val scheduler = new ParameterServer(params.runtimeClasspath, role = "scheduler",
    rootUri = schedulerIP, rootPort = schedulerPort,
    numServer = params.numServer, numWorker = params.numWorker, java = params.javabin)
  require(scheduler.startProcess(), "Failed to start ps scheduler process")

  sc.parallelize(1 to params.numServer, params.numServer).foreachPartition { p =>
    logger.info("Starting server ...")
    val server = new ParameterServer(params.runtimeClasspath,
      role = "server",
      rootUri = schedulerIP, rootPort = schedulerPort,
      numServer = params.numServer,
      numWorker = params.numWorker,
      java = params.javabin)
    require(server.startProcess(), "Failed to start ps server process")
  }

  val job = trainData.mapPartitions { partition =>
    val dataIter = new LabeledPointIter(
      partition, params.dimension,
      params.batchSize,
      dataName = params.dataName,
      labelName = params.labelName)

    // TODO: more natural way to get the # of examples?
    var numExamples = 0
    while (dataIter.hasNext) {
      val dataBatch = dataIter.next()
      numExamples += dataBatch.label.head.shape(0)
    }
    logger.debug("Number of samples: {}", numExamples)
    dataIter.reset()

    logger.info("Launching worker ...")
    logger.info("Batch {}", params.batchSize)
    KVStoreServer.init(ParameterServer.buildEnv(role = "worker",
      rootUri = schedulerIP, rootPort = schedulerPort,
      numServer = params.numServer,
      numWorker = params.numWorker))
    val kv = KVStore.create("dist_async")

    val optimizer: Optimizer = new SGD(learningRate = 0.01f,
      momentum = 0.9f, wd = 0.00001f)

    logger.debug("Define model")
    val model = new FeedForward(ctx = params.context,
      symbol = params.getNetwork,
      numEpoch = params.numEpoch,
      optimizer = optimizer,
      initializer = new Xavier(factorType = "in", magnitude = 2.34f),
      argParams = null,
      auxParams = null,
      beginEpoch = 0,
      epochSize = numExamples / params.batchSize / kv.numWorkers)
    logger.info("Start training ...")
    model.fit(trainData = dataIter,
      evalData = null,
      evalMetric = new Accuracy(),
      kvStore = kv)

    logger.info("Training finished, waiting for other workers ...")
    dataIter.dispose()
    kv.barrier()
    kv.dispose()
    Iterator(new MXNetModel(
      model, params.dimension, params.batchSize,
      dataName = params.dataName, labelName = params.labelName))
  }.cache()

  // force the job to run
  job.foreachPartition(_ => ())
  // simply take the first model
  val mxModel = job.first()

  logger.info("Waiting for scheduler ...")
  scheduler.waitFor()
  mxModel
}

// FeedForward.fit
private def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric = new Accuracy(),
                kvStore: Option[KVStore], updateOnKVStore: Boolean,
                epochEndCallback: EpochEndCallback = null,
                batchEndCallback: BatchEndCallback = null, logger: Logger = FeedForward.logger,
                workLoadList: Seq[Float] = null): Unit = {
  require(evalMetric != null, "evalMetric cannot be null")
  val (argNames, paramNames, auxNames) =
    initParams(trainData.provideData ++ trainData.provideLabel)

  // init optimizer
  val batchSizeMultiplier = kvStore.map { kv =>
    if (kv.`type` == "dist_sync") {
      kv.numWorkers
    } else {
      1
    }
  }
  val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
  this.optimizer.setArgNames(argNames)
  this.optimizer.setRescaleGrad(1f / batchSize)

  logger.debug("Start training on multi-device")
  Model.trainMultiDevice(
    symbol, ctx, argNames, paramNames, auxNames,
    _argParams, _auxParams,
    this.beginEpoch, this.numEpoch,
    this.epochSize, this.optimizer,
    kvStore, updateOnKVStore,
    trainData = trainData, evalData = Option(evalData),
    evalMetric = evalMetric,
    epochEndCallback = Option(epochEndCallback),
    batchEndCallback = Option(batchEndCallback),
    logger = logger, workLoadList = workLoadList,
    monitor = monitor)
}
As the code shows, FeedForward.fit essentially delegates straight to Model.trainMultiDevice, which implements the network's forward and backward propagation and the KV store updates. Its main steps:

  1. fetch a batch
  2. run forward and backward propagation on that batch
  3. update the parameters through the KV store

private[mxnet] def trainMultiDevice(symbol: Symbol, ctx: Array[Context],
                                    argNames: Seq[String], paramNames: Seq[String],
                                    auxNames: Seq[String], argParams: Map[String, NDArray],
                                    auxParams: Map[String, NDArray],
                                    beginEpoch: Int, endEpoch: Int, epochSize: Int,
                                    optimizer: Optimizer,
                                    kvStore: Option[KVStore], updateOnKVStore: Boolean,
                                    trainData: DataIter = null,
                                    evalData: Option[DataIter] = None,
                                    evalMetric: EvalMetric,
                                    epochEndCallback: Option[EpochEndCallback] = None,
                                    batchEndCallback: Option[BatchEndCallback] = None,
                                    logger: Logger = logger,
                                    workLoadList: Seq[Float] = Nil,
                                    monitor: Option[Monitor] = None): Unit = {
  val executorManager = new DataParallelExecutorManager(
    symbol = symbol,
    ctx = ctx,
    trainData = trainData,
    paramNames = paramNames,
    argNames = argNames,
    auxNames = auxNames,
    workLoadList = workLoadList,
    logger = logger)

  monitor.foreach(executorManager.installMonitor)
  executorManager.setParams(argParams, auxParams)

  // updater for updateOnKVStore = false
  val updaterLocal = Optimizer.getUpdater(optimizer)

  kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
    argParams, executorManager._paramNames, updateOnKVStore))
  if (updateOnKVStore) {
    kvStore.foreach(_.setOptimizer(optimizer))
  }

  // Now start training
  for (epoch <- beginEpoch until endEpoch) {
    // Training phase
    val tic = System.currentTimeMillis
    evalMetric.reset()
    var nBatch = 0
    var epochDone = false
    // Iterate over training data.
    trainData.reset()
    while (!epochDone) {
      var doReset = true
      while (doReset && trainData.hasNext) {
        val dataBatch = trainData.next()
        executorManager.loadDataBatch(dataBatch)
        monitor.foreach(_.tic())
        executorManager.forward(isTrain = true)
        executorManager.backward()
        if (updateOnKVStore) {
          updateParamsOnKVStore(executorManager.paramArrays,
            executorManager.gradArrays,
            kvStore)
        } else {
          updateParams(executorManager.paramArrays,
            executorManager.gradArrays,
            updaterLocal, ctx.length,
            kvStore)
        }
        monitor.foreach(_.tocPrint())
        // evaluate at end, so out_cpu_array can lazy copy
        evalMetric.update(dataBatch.label, executorManager.cpuOutputArrays)

        nBatch += 1
        batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))

        // this epoch is done possibly earlier
        if (epochSize != -1 && nBatch >= epochSize) {
          doReset = false
        }
        dataBatch.dispose()
      }
      if (doReset) {
        trainData.reset()
      }

      // this epoch is done
      epochDone = (epochSize == -1 || nBatch >= epochSize)
    }

    val (name, value) = evalMetric.get
    logger.info(s"Epoch[$epoch] Train-$name=$value")
    val toc = System.currentTimeMillis
    logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")

    evalData.foreach { evalDataIter =>
      evalMetric.reset()
      evalDataIter.reset()
      // TODO: make DataIter implement Iterator
      while (evalDataIter.hasNext) {
        val evalBatch = evalDataIter.next()
        executorManager.loadDataBatch(evalBatch)
        executorManager.forward(isTrain = false)
        evalMetric.update(evalBatch.label, executorManager.cpuOutputArrays)
        evalBatch.dispose()
      }

      val (name, value) = evalMetric.get
      logger.info(s"Epoch[$epoch] Validation-$name=$value")
    }

    if (epochEndCallback.isDefined || epoch + 1 == endEpoch) {
      executorManager.copyTo(argParams, auxParams)
    }
    epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
  }

  updaterLocal.dispose()
  executorManager.dispose()
}

5 Components

5.1 dmlc-core

5.1.1 parameter.h

Similar to Spark, dmlc-core has its own system for defining parameters. C++ has no reflection mechanism like Java's, so the technique dmlc uses is rather hacky: it computes the offset of each declared field inside the class, relative to the object's base address, and uses those offsets to read and write the fields generically.

5.1.2 data.h

5.2 ps-lite

Rough notes on the ps-lite internals:

  • Postoffice: the central singleton that keeps track of the server, worker, and scheduler nodes
  • Control message types: empty, terminate, addnode, barrier, ack
  • Van: moves messages between nodes
  • Creating a KVWorker or KVServer creates a Customer inside it, which starts a dedicated thread for receiving messages at construction time:

Customer::Customer(int id, const Customer::RecvHandle& recv_handle)
    : id_(id), recv_handle_(recv_handle) {
  // register this customer with the singleton Postoffice
  Postoffice::Get()->AddCustomer(this);
  // spawn the thread that keeps receiving messages for this customer
  recv_thread_ = std::unique_ptr<std::thread>(new std::thread(&Customer::Receiving, this));
}

Van encapsulates the communication layer; the current implementation uses zmq.

5.3 mxnet
