使用分布式训练

简介

TensorFlow只是library,分布式TensorFlow应用需要我们在多个节点启动Python脚本组成分布式计算集群。

Xiaomi Cloud-ML支持标准的分布式TensorFlow应用,用户只需编写对应的Python脚本即可提交运行,用法与单机版类似。

代码规范

由于分布式TensorFlow应用需要启动多节点,每个节点需要知道自己的角色,一般都是通过命令行参数传入,而用户自定义的命令行参数名和个数可能不同。Cloud-ML要求用户通过DISTRIBUTED_CONFIG或TF_CONFIG(Cloud-ML原先只支持tensorflow分布式时,使用TF_CONFIG这个环境变量传递分布式参数,当前仍保留,后期会统一为DISTRIBUTED_CONFIG)这个环境变量传入集群和节点的信息。

如1个master、1个ps、1个worker的情况,传入的参数如下:

DISTRIBUTED_CONFIG='{"cluster": {"master": ["127.0.0.1:3000"], "ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}}'
TF_CONFIG='{"cluster": {"master": ["127.0.0.1:3000"], "ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}}'

然后用户Python代码中可以直接读取环境变量,获取cluster spec和type、index信息。

if os.environ.get('DISTRIBUTED_CONFIG', ""):
  env = json.loads(os.environ.get('DISTRIBUTED_CONFIG', '{}'))
  task_data = env.get('task', None)
  cluster_spec = env["cluster"]
  task_type = task_data["type"]
  task_index = task_data["index"]

代码实例

我们也实现了标准的分布式TensorFlow应用,代码地址 https://github.com/XiaoMi/cloud-ml-sdk/blob/master/cloud_ml_samples/tensorflow/linear_regression/trainer/task.py

本地运行

本地启动分布式TensorFlow应用,以samples代码为例,可以先打开3个终端,然后分别运行下面的命令。

CUDA_VISIBLE_DEVICES='' DISTRIBUTED_CONFIG='{"cluster": {"ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}}' python -m trainer.task 

CUDA_VISIBLE_DEVICES='' DISTRIBUTED_CONFIG='{"cluster": {"ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "worker"}}' python -m trainer.task

使用Xiaomi Cloud-ML

如果使用Xiaomi Cloud-ML,只需要把Python代码打包,然后运行时传入 -D,之后根据提示输入task type的名称、数量、资源等信息:

cloudml jobs submit -n distributed -m trainer.task -u fds://cloud-ml/linear/trainer-1.0.tar.gz -c 0.3 -M 300M -D

分布式训练任务提交后,可以通过命令行看到多个任务的启动,查看具体某个worker日志发现分布式训练任务也正常完成。

cloudml jobs list

cloudml jobs logs distributed-worker-0

cloudml jobs logs distributed-ps-0

参数介绍

  • -D 表示使用分布式,按照提示输入分布式相关的信息即可,支持通用分布式

旧版本分布式参数:

  • -p 表示集群的ps的个数,暂时只支持TensorFlow深度学习框架。
  • -w 表示集群的worker个数,暂时只支持TensorFlow深度学习框架。

results matching ""

    No results matching ""