Copyright © 2023-2024 | Powered by dumi | GuoDapeng | 冀ICP备20004032号-1 | 冀公网安备 冀公网安备 13024002000293号


Image Classification (PyTorch Notes)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
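
# Convert each image (PIL or numpy) to a Tensor in [0, 1], then normalize
# every channel with mean 0.5 and std 0.5, which maps the values to [-1, 1]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.5, 0.5, 0.5),
            (0.5, 0.5, 0.5),
        ),
    ]
)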

# Training data
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test data
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(
    test_set,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)
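
# CIFAR-10 class names, in label order
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def im_show(img):
    img = img / 2 + 0.5  # undo the normalization: [-1, 1] back to [0, 1]
    np_img = img.numpy()
    plt.imshow(np.transpose(np_img, (1, 2, 0)))  # (C, H, W) -> (H, W, C) for matplotlib
    plt.show()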


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # convolution: 3 input channels, 6 output channels, 5x5 kernel; extracts features
        self.pool = nn.MaxPool2d(2, 2)  # max pooling (not a convolution): keeps the maximum of each 2x2 window
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # linear (fully connected) layer
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one output score per class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten each sample to a vector of 400 features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


if __name__ == '__main__':
    # The guard is required here: with num_workers > 0 the DataLoader starts
    # worker processes, and on platforms that spawn them (e.g. Windows, macOS)
    # the script is re-imported, so unguarded top-level code raises an error.

    # Show one training batch and the classes its images are labeled with
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    im_show(torchvision.utils.make_grid(images))
    print(' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))

    net = Net()  # create the network
    criterion = nn.CrossEntropyLoss()  # loss function
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # optimizer

    for epoch in range(3):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()  # reset the gradients of all model parameters
            outputs = net(inputs)
            loss = criterion(outputs, labels)  # compute the loss
            loss.backward()  # backpropagation
            optimizer.step()  # update the parameters
            running_loss += loss.item()
            if i % 2000 == 1999:  # report the mean loss every 2000 batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    # Compare labels and predictions on the first test batch
    dataiter = iter(test_loader)
    images, labels = next(dataiter)
    im_show(torchvision.utils.make_grid(images))
    print('Expected: ', ' '.join('%6s' % classes[labels[j]] for j in range(len(labels))))
    outputs = net(images)  # run the prediction
    _, predicted = torch.max(outputs, 1)  # index of the highest score = predicted class
    print('Predicted: ', ' '.join('%6s' % classes[predicted[j]] for j in range(len(labels))))

    # Overall accuracy on the test set
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # sum() here is the tensor method provided by torch
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

    # Accuracy per class
    class_correct = [0.0] * 10
    class_total = [0.0] * 10
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels)
            for i in range(labels.size(0)):  # do not hard-code the batch size
                label = labels[i].item()
                class_correct[label] += c[i].item()
                class_total[label] += 1
    for i in range(10):
        print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
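
The `16 * 5 * 5` input size of `fc1` is not arbitrary: it is the number of values left after the two convolution and pooling stages. The sketch below retraces that arithmetic in plain Python. The helper `conv_out` is a hypothetical name introduced here for illustration, and it assumes square inputs and no dilation, which matches the layers in `Net`.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square conv/pool layer (hypothetical helper, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 32                                   # CIFAR-10 images are 3 x 32 x 32
side = conv_out(side, kernel=5)             # conv1, 5x5 kernel, no padding: 32 -> 28
side = conv_out(side, kernel=2, stride=2)   # pool, 2x2 window, stride 2:    28 -> 14
side = conv_out(side, kernel=5)             # conv2, 5x5 kernel, no padding: 14 -> 10
side = conv_out(side, kernel=2, stride=2)   # pool, 2x2 window, stride 2:    10 -> 5

channels = 16                               # out_channels of conv2
flat_features = channels * side * side      # values per image fed into fc1

print(side, flat_features)                  # 5 400
```

If the input resolution or a kernel size changes, this number changes too, and both `fc1` and the `view` call in `forward` would need the new value.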

Output:

 ship   car horse  ship
[1,  2000] loss: 2.156
[1,  4000] loss: 1.835
[1,  6000] loss: 1.689
[1,  8000] loss: 1.577
[1, 10000] loss: 1.512
[1, 12000] loss: 1.467
[2,  2000] loss: 1.385
[2,  4000] loss: 1.354
[2,  6000] loss: 1.341
[2,  8000] loss: 1.306
[2, 10000] loss: 1.309
[2, 12000] loss: 1.272
[3,  2000] loss: 1.219
[3,  4000] loss: 1.193
[3,  6000] loss: 1.205
[3,  8000] loss: 1.183
[3, 10000] loss: 1.176
[3, 12000] loss: 1.154
Expected:     cat   ship   ship  plane
Predicted:     cat   ship  plane  plane
Accuracy of the network on the 10000 test images: 58 %
Accuracy of plane : 66 %
Accuracy of   car : 68 %
Accuracy of  bird : 59 %
Accuracy of   cat : 43 %
Accuracy of  deer : 50 %
Accuracy of   dog : 51 %
Accuracy of  frog : 69 %
Accuracy of horse : 58 %
Accuracy of  ship : 63 %
Accuracy of truck : 48 %
Process finished with exit code 0
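
A few numbers in this output can be checked by hand. The sketch below is only a sanity check, not part of the original script. It assumes the standard CIFAR-10 split sizes (50,000 training and 10,000 test images, with 1,000 test images per class) and the batch size of 4 used above. The variable names are illustrative.

```python
train_images = 50_000   # standard CIFAR-10 training split
batch_size = 4          # value passed to DataLoader in the script
num_classes = 10

# Each epoch has 12500 batches, so the loss is printed 6 times per epoch
# (at batches 2000 ... 12000); the last 500 batches are never reported.
batches_per_epoch = train_images // batch_size
log_lines_per_epoch = batches_per_epoch // 2000
unreported_batches = batches_per_epoch - log_lines_per_epoch * 2000

# Random guessing over 10 balanced classes would score about 10 %.
chance_accuracy = 100 / num_classes

# Per-class percentages copied from the output; %d already truncated them.
per_class = [66, 68, 59, 43, 50, 51, 69, 58, 63, 48]
mean_per_class = sum(per_class) / len(per_class)

print(batches_per_epoch, log_lines_per_epoch, unreported_batches)  # 12500 6 500
print(chance_accuracy, mean_per_class)                             # 10.0 57.5
```

Because every class has the same number of test images, the overall accuracy equals the mean of the per-class accuracies. The mean of the truncated figures is 57.5, which is consistent with the reported 58 % once the truncation by `%d` is taken into account, and it sits well above the 10 % chance level after only three epochs.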