使用分布式训练
简介
TensorFlow只是library,分布式TensorFlow应用需要我们在多个节点启动Python脚本组成分布式计算集群。
Xiaomi Cloud-ML支持标准的分布式TensorFlow应用,用户只需编写对应的Python脚本即可提交运行,用法与单机版类似。
代码规范
由于分布式TensorFlow应用需要启动多节点,每个节点需要知道自己的角色,一般都是通过命令行参数传入,而用户自定义的命令行参数名和个数可能不同。Cloud-ML要求用户通过DISTRIBUTED_CONFIG或TF_CONFIG(Cloud-ML原先只支持tensorflow分布式时,使用TF_CONFIG这个环境变量传递分布式参数,当前仍保留,后期会统一为DISTRIBUTED_CONFIG)这个环境变量传入集群和节点的信息。
如1个master、1个ps、1个worker的情况,传入的参数如下:
DISTRIBUTED_CONFIG='{"cluster": {"master": ["127.0.0.1:3000"], "ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}, "environment": "cloud"}'
TF_CONFIG='{"cluster": {"master": ["127.0.0.1:3000"], "ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}, "environment": "cloud"}'
注: 其中 environment
赋值为 cloud 表明为云上分布式训练,tensorflow 框架会根据这个变量来判断 is_chief
。
然后用户Python代码中可以直接读取环境变量,获取cluster spec和type、index信息。
if os.environ.get('DISTRIBUTED_CONFIG', ""):
env = json.loads(os.environ.get('DISTRIBUTED_CONFIG', '{}'))
task_data = env.get('task', None)
cluster_spec = env["cluster"]
task_type = task_data["type"]
task_index = task_data["index"]
代码实例
我们也实现了标准的分布式TensorFlow应用,代码地址 https://github.com/XiaoMi/cloud-ml-sdk/blob/master/cloud_ml_samples/tensorflow/linear_regression/trainer/task.py 。
本地运行
本地启动分布式TensorFlow应用,以samples代码为例,可以先打开3个终端,然后分别运行下面的命令。
CUDA_VISIBLE_DEVICES='' DISTRIBUTED_CONFIG='{"cluster": {"ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "ps"}}' python -m trainer.task
CUDA_VISIBLE_DEVICES='' DISTRIBUTED_CONFIG='{"cluster": {"ps": ["127.0.0.1:3001"], "worker": ["127.0.0.1:3002"]}, "task": {"index": 0, "type": "worker"}}' python -m trainer.task
使用Xiaomi Cloud-ML
如果使用Xiaomi Cloud-ML,只需要把Python代码打包,然后运行时传入 -D
,之后根据提示输入task type的名称、数量、资源及定制参数等信息:
cloudml jobs submit -n distributed -m trainer.task -u fds://cloud-ml/linear/trainer-1.0.tar.gz -D
分布式训练任务提交后,可以通过命令行看到多个任务的启动,查看具体某个worker日志发现分布式训练任务也正常完成。
cloudml jobs list
cloudml jobs logs distributed-worker-0
cloudml jobs logs distributed-ps-0
参数介绍
-D
表示使用分布式,按照提示输入分布式相关的信息即可,支持通用分布式
旧版本分布式参数:
-p
表示集群的ps的个数,暂时只支持TensorFlow深度学习框架。-w
表示集群的worker个数,暂时只支持TensorFlow深度学习框架。