V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
getui
V2EX  ›  推广

TensorFlow 分布式实践

  •  
  •   getui · 2019-01-29 13:22:15 +08:00 · 916 次点击
    这是一个创建于 1886 天前的主题,其中的信息可能已经有所发展或是发生改变。

    大数据时代,基于单机的建模很难满足企业不断增长的数据量级的需求,开发者需要使用分布式的开发方式,在集群上进行建模。而单机和分布式的开发代码有一定的区别,本文就将为开发者们介绍,基于 TensorFlow 进行分布式开发的两种方式,帮助开发者在实践的过程中,更好地选择模块的开发方向。

    基于 TensorFlow 原生的分布式开发

    分布式开发会涉及到更新梯度的方式,有同步和异步的两个方案,同步更新的方式在模型的表现上能更快地进行收敛,而异步更新时,迭代的速度则会更加快。两种更新方式的图示如下:

    同步更新流程

    (图片来源:TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems )

    异步更新流程

    (图片来源:TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems )

    TensorFlow 是基于 ps、work 两种服务器进行分布式的开发。ps 服务器可以只用于参数的汇总更新,让各个 work 进行梯度的计算。

    基于 TensorFlow 原生的分布式开发的具体流程如下:

    首先指定 ps 服务器启动参数 – job_name=ps:

    python distribute.py --ps_hosts=192.168.100.42:2222 --worker_hosts=192.168.100.42:2224,192.168.100.253:2225 --job_name=ps --task_index=0
    

    接着指定 work 服务器参数(启动两个 work 节点) – job_name=work2:

    python distribute.py --ps_hosts=192.168.100.42:2222 --worker_hosts=192.168.100.42:2224,192.168.100.253:2225 --job_name=worker --task_index=0
    python distribute.py --ps_hosts=192.168.100.42:2222 --worker_hosts=192.168.100.42:2224,192.168.100.253:2225 --job_name=worker --task_index=1
    

    之后,上述指定的参数 worker_hosts ps_hosts job_name task_index 都需要在 py 文件中接受使用:

    tf.app.flags.DEFINE_string("worker_hosts", "默认值", "描述说明")
    

    接收参数后,需要分别注册 ps、work,使他们各司其职:

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)
    
    issync = FLAGS.issync
    if FLAGS.job_name == "ps":
       server.join()
    elif FLAGS.job_name == "worker":
       with tf.device(tf.train.replica_device_setter(
                       worker_device="/job:worker/task:%d" % FLAGS.task_index,
                       cluster=cluster)):
    

    继而更新梯度。

    ( 1 )同步更新梯度:

    rep_op = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=len(worker_hosts),
                                                   replica_id=FLAGS.task_index,
                                                   total_num_replicas=len(worker_hosts),
                                                   use_locking=True)
    train_op = rep_op.apply_gradients(grads_and_vars,global_step=global_step)
    init_token_op = rep_op.get_init_tokens_op()
    chief_queue_runner = rep_op.get_chief_queue_runner()
    

    ( 2 )异步更新梯度:

    train_op = optimizer.apply_gradients(grads_and_vars,global_step=global_step)
    

    最后,使用 tf.train.Supervisor 进行真的迭代

    另外,开发者还要注意,如果是同步更新梯度,则还需要加入如下代码:

    sv.start_queue_runners(sess, [chief_queue_runner])
    sess.run(init_token_op)
    

    需要注意的是,上述异步的方式需要自行指定集群 IP 和端口,不过,开发者们也可以借助 TensorFlowOnSpark,使用 Yarn 进行管理。

    基于 TensorFlowOnSpark 的分布式开发

    作为个推面向开发者服务的移动 APP 数据统计分析产品,个数所具有的用户行为预测功能模块,便是基于 TensorFlowOnSpark 这种分布式来实现的。基于 TensorFlowOnSpark 的分布式开发使其可以在屏蔽了端口和机器 IP 的情况下,也能够做到较好的资源申请和分配。而在多个千万级应用同时建模的情况下,集群也有良好的表现,在 sparkUI 中也能看到相对应的资源和进程的情况。最关键的是,TensorFlowOnSpark 可以在单机过度到分布式的情况下,使代码方便修改,且容易部署。

    基于 TensorFlowOnSpark 的分布式开发的具体流程如下:

    首先,需要使用 spark-submit 来提交任务,同时指定 spark 需要运行的参数(– num-executors 6 等)、模型代码、模型超参等,同样需要接受外部参数:

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--tracks", help="数据集路径")  
    args = parser.parse_args()
    

    之后,准备好参数和训练数据(DataFrame),调用模型的 API 进行启动。

    其中,soft_dist.map_fun 是要调起的方法,后面均是模型训练的参数。

    estimator = TFEstimator(soft_dist.map_fun, args) \
         .setInputMapping({'tracks': 'tracks', 'label': 'label'}) \
         .setModelDir(args.model) \
         .setExportDir(args.serving) \
         .setClusterSize(args.cluster_size) \
         .setNumPS(num_ps) \
         .setEpochs(args.epochs) \
         .setBatchSize(args.batch_size) \
         .setSteps(args.max_steps)
       model = estimator.fit(df)
    

    接下来是 soft_dist 定义一个 map_fun(args, ctx)的方法:

    def map_fun(args, ctx):
    ...
    worker_num = ctx.worker_num  # worker 数量
    job_name = ctx.job_name  # job 名
    task_index = ctx.task_index  # 任务索引
    if job_name == "ps":  # ps 节点(主节点)
      time.sleep((worker_num + 1) * 5)
      cluster, server = TFNode.start_cluster_server(ctx, 1, args.rdma)
      num_workers = len(cluster.as_dict()['worker'])
      if job_name == "ps":
           server.join()
      elif job_name == "worker":
           with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index, cluster=cluster)):
    

    之后,可以使用 tf.train.MonitoredTrainingSession 高级 API,进行模型训练和预测。

    总结

    基于 TensorFlow 的分布式开发大致就是本文中介绍的两种情况,第二种方式可以用于实际的生产环境,稳定性会更高。

    在运行结束的时候,开发者们也可通过设置邮件的通知,及时地了解到模型运行的情况。

    同时,如果开发者使用 SessionRunHook 来保存最后输出的模型,也需要了解到,框架代码中的一个 BUG,即它只能在规定的时间内保存,超出规定时间,即使运行没有结束,程序也会被强制结束。如果开发者使用的版本是未修复 BUG 的版本,则要自行处理,放宽运行时间。

    1 条回复    2019-01-29 15:20:17 +08:00
    ddzzhen
        1
    ddzzhen  
       2019-01-29 15:20:17 +08:00
    支持分享
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   我们的愿景   ·   实用小工具   ·   4389 人在线   最高记录 6543   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 27ms · UTC 10:07 · PVG 18:07 · LAX 03:07 · JFK 06:07
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.