トップページ -> データベース関連技術 -> PyTorch の応用例 -> 単純な CNN を用いた画像分類(PyTorch のサンプルプログラムを使用)
[サイトマップへ], [サイト内検索へ]

単純な CNN を用いた画像分類(PyTorch のサンプルプログラムを使用)

ユースケース: PyTorch の動作確認を行いたい.CNN (PyTorch を使用)について練習したい

次のWebページに記載のソースコード(単純な CNN を用いた画像分類)を実行してみる

参考 Web ページ: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#sphx-glr-beginner-blitz-neural-networks-tutorial-py

先人に感謝

PyTorch の Web ページ: http://pytorch.org

GitHub の PyTorch の Webページ: https://github.com/pytorch/pytorch

前準備

PyTorch のインストール


単純な CNN を用いた画像分類(PyTorch のサンプルプログラムを使用)

Python プログラムを動かす.

※ Python プログラムを動かすために, Windows では,「python」コマンドを使う. Ubuntu では「python3」コマンドを使う.

開発環境や Python コンソール(Jupyter Qt ConsolespyderPyCharmPyScripter など)も便利である.

  1. インポート

    import torch
    import torchvision
    import torchvision.transforms as transforms
    

    Ubuntu での実行結果例

    [image]
  2. CIFAR 10 のダウンロード

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    

    Ubuntu での実行結果例

    [image]
  3. CNN の定義

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    

    Ubuntu での実行結果例

    [image]
  4. GPU デバイスがあれば,それを使いたい

    net = Net()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    

    Ubuntu での実行結果例

    [image]
  5. 損失関数と最適化の定義

    import torch.optim as optim
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    

    Ubuntu での実行結果例

    [image]
  6. 訓練(学習)

    inputs, labels = data[0].to(device), data[1].to(device)」は,GPUでもプログラムが動くようにするための処理

    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    
    print('Finished Training')
    

    Ubuntu での実行結果例

    [image]
  7. テストデータを用いたテスト

    dataiter = iter(testloader)
    d = dataiter.next()
    images, labels = d[0].to(device), d[1].to(device)
    
    # print images
    imshow(torchvision.utils.make_grid(images))
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    

    Ubuntu での実行結果例

    [image]