代码拉取完成,页面将自动刷新
import os
from datetime import datetime
import argparse
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import socket
# Rendezvous settings consumed by torch.distributed's ``env://`` init method
# (see ``init_process_group`` in train()).  Evaluated at import time:
#   MASTER_ADDR    - IP address the hostname 'node-0' resolves to (DNS lookup).
#   MASTER_PORT    - fixed TCP port.
#   PAI_TASK_INDEX - last '-'-separated field of this machine's hostname
#                    (e.g. 'node-3' -> '3'); train() uses it as the node index
#                    when computing the global rank.
os.environ.update({
    'MASTER_ADDR': socket.gethostbyname('node-0'),
    'MASTER_PORT': "12345",
    'PAI_TASK_INDEX': socket.gethostname().split('-')[-1],
})
def main():
    """Parse command-line options, load CIFAR-10 and spawn one ``train`` worker per GPU.

    Options:
        -n/--nodes: number of nodes taking part in training (default: 1).
        -g/--gpus: number of GPUs, i.e. worker processes, per node (default: 1).
        -nr/--nr: rank of this node among the nodes (default: 0). ``train``
            derives the node index from ``PAI_TASK_INDEX`` instead, so this
            value is parsed but not used for ranking.
        --epochs: number of training epochs (default: 2).
        --dist-backend: ``torch.distributed`` backend name (default: 'nccl').

    Sets ``args.world_size = gpus * nodes`` and passes ``(args, trainset)`` to
    every spawned worker; ``mp.spawn`` supplies the local process index as the
    first argument of ``train``.
    """
    print('run main')
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes (default: 1)')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    args = parser.parse_args()
    # One worker process per GPU on every node.
    args.world_size = args.gpus * args.nodes
    print('world_size:', args.world_size)
    print('master:', os.environ['MASTER_ADDR'], 'port:', os.environ['MASTER_PORT'])

    # Data loading code.  The training set is downloaded here, in the parent
    # process, before any worker is spawned (presumably so the workers do not
    # download it concurrently).
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    mp.spawn(train, nprocs=args.gpus, args=(args, trainset))
class Net(nn.Module):
    """Small LeNet-style CNN that maps 3-channel images to 10 scores.

    Two 5x5 valid convolutions, each followed by ReLU and 2x2 max-pooling,
    feed three fully connected layers.  ``fc1`` expects 16 * 5 * 5 flattened
    features, which is what a 32x32 input produces (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order (conv1, conv2, fc1, fc2, fc3),
        # which fixes both the state_dict key order and which RNG draws
        # initialise which layer under a given seed.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalised class scores: a 2-D tensor with 10 columns."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = F.max_pool2d(F.relu(conv(features)), 2)
        hidden = features.view(-1, 16 * 5 * 5)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
def train(gpu, args, trainset):
    """Train ``Net`` on ``trainset`` as one DistributedDataParallel worker.

    Called by ``mp.spawn`` in ``main`` once per local GPU.

    Args:
        gpu: local process index on this node; also used as the CUDA device index.
        args: parsed command-line namespace; reads ``gpus``, ``world_size``,
            ``dist_backend`` and ``epochs``.
        trainset: dataset to train on (``main`` passes the CIFAR-10 training set).

    The global rank is ``int(PAI_TASK_INDEX) * args.gpus + gpu``.  The process
    group is initialised with ``init_method='env://'``, which reads
    ``MASTER_ADDR``/``MASTER_PORT`` from the environment, and is destroyed once
    training finishes.
    """
    print("start train")
    rank = int(os.environ['PAI_TASK_INDEX']) * args.gpus + gpu
    dist.init_process_group(backend=args.dist_backend, init_method='env://',
                            world_size=args.world_size, rank=rank)
    # Same seed in every process, so every replica builds identical initial weights.
    torch.manual_seed(0)
    model = Net()
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = 100
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)
    # Wrap the model so gradients are synchronised across all processes.
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # Data loading code.
    # Each rank iterates its own shard of the training set.  The sampler does
    # the shuffling, so the DataLoader itself must not shuffle.
    trainsampler = torch.utils.data.distributed.DistributedSampler(
        trainset,
        num_replicas=args.world_size,
        rank=rank,
        shuffle=True,
    )
    # Pinned host memory lets the non_blocking .cuda() copies below run asynchronously.
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=False, num_workers=2,
        pin_memory=True, sampler=trainsampler)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    # The test loader needs its own sampler built over testset.  Reusing
    # trainsampler would draw indices sized to trainset, not testset.
    testsampler = torch.utils.data.distributed.DistributedSampler(
        testset,
        num_replicas=args.world_size,
        rank=rank,
        shuffle=False,
    )
    # NOTE: testloader is built but not consumed by the training loop below.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=0,
        pin_memory=True, sampler=testsampler)

    start = datetime.now()
    total_step = len(trainloader)
    for epoch in range(args.epochs):
        # Without set_epoch, DistributedSampler reuses the same shuffle order every epoch.
        trainsampler.set_epoch(epoch)
        for i, (images, labels) in enumerate(trainloader):
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Logged on every step by every process.
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, args.epochs, i + 1, total_step, loss.item()))
    if gpu == 0:
        print("Training complete in: " + str(datetime.now() - start))
    # Release the process-group resources created by init_process_group.
    dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point: parse CLI options and launch the per-GPU workers.
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。