1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
def train(model, device, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` into training mode, then for every ``(data, target)``
    batch: moves both tensors to ``device``, clears the gradients, computes
    the cross-entropy loss of ``model(data)`` against ``target``,
    back-propagates it and applies one ``optimizer`` step.

    Args:
        model: Module being trained; called as ``model(data)``.
        device: Target passed to ``Tensor.to`` for each batch.
        train_loader: Iterable yielding ``(data, target)`` pairs.
        optimizer: Object whose ``zero_grad()`` and ``step()`` are called
            once per batch.

    Returns:
        None. Parameters are updated in place through ``optimizer``.
    """
    model.train()
    # Built once: the loss module does not depend on the batch, so there is
    # no reason to re-instantiate it on every iteration.
    criterion = nn.CrossEntropyLoss()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
def test(model, device, test_loader, classes):
    """Evaluate ``model`` on ``test_loader`` and log the results to wandb.

    Puts ``model`` into eval mode and, with gradient tracking disabled,
    accumulates the cross-entropy loss and the number of correct
    predictions over all batches. The first sample of every batch is also
    collected as a ``wandb.Image`` captioned with its predicted and true
    class names.

    Logs a single ``wandb.log`` record with the keys ``"Examples"``,
    ``"Test Accuracy"`` (percentage of ``len(test_loader.dataset)``) and
    ``"Test Loss"`` (mean loss per sample).

    Args:
        model: Module under evaluation; called as ``model(data)``.
        device: Target passed to ``Tensor.to`` for each batch.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute that supports ``len()``.
        classes: Sequence mapping a class index to its display name.

    Returns:
        None. Results are reported through ``wandb.log`` only.
    """
    model.eval()
    # reduction="sum" yields the summed per-sample loss of a batch, so that
    # dividing by the dataset size below gives a true per-sample mean. Adding
    # up per-batch means instead would grow with the number of batches.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    num_samples = len(test_loader.dataset)
    test_loss = 0.0
    correct = 0
    example_images = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            # Index of the largest score along dim 1, kept as a column so it
            # can be compared element-wise with the reshaped targets.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            # One example per batch: the first sample with its labels.
            example_images.append(wandb.Image(
                data[0],
                caption="Pred:{} Truth:{}".format(
                    classes[pred[0].item()], classes[target[0].item()])))

    wandb.log({
        "Examples": example_images,
        "Test Accuracy": 100. * correct / num_samples,
        "Test Loss": test_loss / num_samples,
    })
# Sets the `watch_called` attribute on the wandb module to False when this
# module is executed.
# NOTE(review): presumably this clears wandb's "watch already called" guard so
# that `wandb.watch(...)` in `main()` can run again when the script is
# re-executed in the same interpreter — confirm against the installed wandb
# version.
wandb.watch_called = False
def main():
    """Fine-tune a pretrained ResNet-18 on CIFAR-10 and upload the weights.

    Reads ``use_cuda``, ``seed``, ``batch_size``, ``lr``, ``momentum`` and
    ``epochs`` from the module-level ``config`` object (defined outside this
    block). Builds the CIFAR-10 train/test loaders under ``dataset/``
    (downloading if needed), trains with SGD for ``config.epochs`` epochs,
    evaluating after each one, then writes the final ``state_dict`` to
    ``model.pth`` and hands that file to ``wandb.save``.

    Returns:
        None.
    """
    use_cuda = config.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # Extra DataLoader options are only supplied when running on the GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    set_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    # Convert images to tensors, then normalise each of the three channels
    # with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_loader = DataLoader(
        datasets.CIFAR10(
            root='dataset', train=True, download=True, transform=transform),
        batch_size=config.batch_size, shuffle=True, **kwargs)

    test_loader = DataLoader(
        datasets.CIFAR10(
            root='dataset', train=False, download=True, transform=transform),
        batch_size=config.batch_size, shuffle=False, **kwargs)

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    model = resnet18(pretrained=True)
    # NOTE(review): assumes `resnet18` is torchvision's, whose pretrained
    # network ends in a 1000-way `fc` layer — confirm against the import.
    # With that head the model can predict an index >= len(classes), which
    # makes the `classes[...]` lookup in `test()` raise IndexError. Swap in a
    # head sized for this dataset; the isinstance guard leaves any other
    # model layout untouched.
    head = getattr(model, "fc", None)
    if isinstance(head, nn.Linear) and head.out_features != len(classes):
        model.fc = nn.Linear(head.in_features, len(classes))
    model = model.to(device)
    # Created after the head swap so the optimizer tracks the new layer.
    optimizer = optim.SGD(model.parameters(), lr=config.lr,
                          momentum=config.momentum)

    wandb.watch(model, log="all")
    for epoch in range(1, config.epochs + 1):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader, classes)

    torch.save(model.state_dict(), 'model.pth')
    wandb.save('model.pth')
# Script entry point: start training only when this file is run directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|