PyTorch设置随机数

通过设置随机种子,使训练结果可复现,有效对比模型测试结果;

目录

设置方式

参考资料


设置方式:

训练过程中的随机数包含:初始权重、随机数据增强、数据读取顺序等,将这些随机数固定,按理可保证训练结果一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 设置随机种子
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 设置Dataloader的种子
def set_random_Dataloader(worker_id, rank, seed):
worker_seed = rank + seed
random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)

set_seed(30)
set_random_Dataloader()

# 预处理、训练模型

参考资料:


版本历史

  • 2024-01-19: 初版文档