|
@@ -238,6 +238,18 @@ if __name__ == "__main__":
|
|
|
# train_data_path: 训练数据路径
|
|
|
|
|
|
try:
|
|
|
+ # 设置环境变量
|
|
|
+ # 单机多卡
|
|
|
+ os.environ['RANK'] = '0' # 第一张卡的 rank
|
|
|
+ os.environ['WORLD_SIZE'] = '1' # 总共有 1 张卡
|
|
|
+ os.environ['MASTER_ADDR'] = 'localhost'
|
|
|
+ os.environ['MASTER_PORT'] = '12345'
|
|
|
+ # 多机多卡
|
|
|
+ # export RANK=0 # 第一台机器的 rank
|
|
|
+ # export WORLD_SIZE=4 # 总共有 4 台机器
|
|
|
+ # export MASTER_ADDR=<主节点 IP>
|
|
|
+ # export MASTER_PORT=12345
|
|
|
+
|
|
|
# 初始化进程组
|
|
|
dist.init_process_group(backend='nccl', init_method='env://')
|
|
|
# 初始化 ModelTrainer
|