トップページ人工知能,実世界DBPyTorch の応用例単純な CNN を用いた画像分類(PyTorch のサンプルプログラムを使用)

単純な CNN を用いた画像分類(PyTorch のサンプルプログラムを使用)

ユースケース: PyTorch の動作確認を行いたい.CNN (PyTorch を使用)について練習したい

次のWebページに記載のソースコード(単純な CNN を用いた画像分類)を実行してみる

参考 Web ページ: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#sphx-glr-beginner-blitz-neural-networks-tutorial-py

先人に感謝

PyTorch の Web ページ: http://pytorch.org

GitHub の PyTorch の Webページ: https://github.com/pytorch/pytorch

前準備

PyTorch のインストール

単純な CNN を用いた画像分類(PyTorch のサンプルプログラムを使用)

Python プログラムを動かす.

Python プログラムを動かすために, Windows では「python」, Ubuntu では「python3」などのコマンドを使う.

あるいは, 開発環境や Python コンソール(Jupyter Qt ConsoleSpyderPyCharmPyScripter など)の利用も便利である.

あるいは,オンラインで動くGoogle Colaboratory のノートブックの利用も,場合によっては便利である.

  1. インポート

    import torch
    import torchvision
    import torchvision.transforms as transforms
    

    Ubuntu での実行結果例

    [image]
  2. CIFAR 10 のダウンロード

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    

    Ubuntu での実行結果例

    [image]
  3. CNN の定義

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    

    Ubuntu での実行結果例

    [image]
  4. GPU デバイスがあれば,それを使いたい

    net = Net()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    

    Ubuntu での実行結果例

    [image]
  5. 損失関数と最適化の定義

    import torch.optim as optim
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    

    Ubuntu での実行結果例

    [image]
  6. 学習

    inputs, labels = data[0].to(device), data[1].to(device)」は,GPUでもプログラムが動くようにするための処理

    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    
    print('Finished Training')
    

    Ubuntu での実行結果例

    [image]
  7. テストデータを用いたテスト

    dataiter = iter(testloader)
    d = dataiter.next()
    images, labels = d[0].to(device), d[1].to(device)
    
    # print images
    imshow(torchvision.utils.make_grid(images))
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    

    Ubuntu での実行結果例

    [image]