Pytorch-MNIST手写识别

mtain 2024年03月08日 30次浏览

1. 官方数据集

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder

# 定义神经网络模型
class Net(nn.Module):
    """定义神经网络模型类"""
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 500)  # 第一个全连接层,输入维度为 28*28,输出维度为 500
        self.fc2 = nn.Linear(500, 10)  # 第二个全连接层,输入维度为 500,输出维度为 10

    def forward(self, x):
        """前向传播方法"""
        x = x.view(-1, 28*28)  # 将输入张量进行展平操作,维度变为 (batch_size, 28*28)
        x = torch.relu(self.fc1(x))  # 使用 ReLU 激活函数进行非线性变换
        x = self.fc2(x)  # 经过第二个全连接层,输出结果
        return x

# 加载 MNIST 数据集
transform=transforms.Compose([
   transforms.ToTensor(),  # 将图像数据转换为张量
   transforms.Normalize((0.1307,), (0.3081,))  # 归一化处理
])

# 自动下载MNIST数据集到datasets目录
train_dataset = datasets.MNIST('datasets', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('datasets', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)  # 每批加载 64 个样本,打乱顺序
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=True)  # 每批加载 1000 个样本,打乱顺序

# 创建模型实例
model = Net()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # 随机梯度下降优化器

# 训练循环
def train(epoch):
    """训练循环"""
    model.train()
    for batch_idx, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()  # 清零梯度
        output = model(images)  # 前向传播
        loss = criterion(output, labels)  # 计算损失
        
        loss.backward()  # 反向传播,计算梯度
        optimizer.step()  # 更新参数
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# 测试循环
def test():
    """测试循环"""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)  # 前向传播
            test_loss += criterion(output, target).item()  # 计算测试集损失
            pred = output.data.max(1, keepdim=True)[1]  # 获取预测结果
            correct += pred.eq(target.data.view_as(pred)).sum()  # 计算预测正确的样本数量
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# 运行训练和测试
for epoch in range(1, 11):
    train(epoch)  # 训练模型
    test()  # 在测试集上评估模型

# 训练结果保存到模型文件
torch.save(model,'mnist.pt')

2. 自定义数据集


3. 使用训练模型预测结果

# 加载训练结果模型,预测识别结果
import torch
from PIL import Image
from torchvision import transforms
import torch.nn as nn


# 定义神经网络模型
class Net(nn.Module):
    """定义神经网络模型类"""
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 500)  # 第一个全连接层,输入维度为 28*28,输出维度为 500
        self.fc2 = nn.Linear(500, 10)  # 第二个全连接层,输入维度为 500,输出维度为 10

    def forward(self, x):
        """前向传播方法"""
        x = x.view(-1, 28*28)  # 将输入张量进行展平操作,维度变为 (batch_size, 28*28)
        x = torch.relu(self.fc1(x))  # 使用 ReLU 激活函数进行非线性变换
        x = self.fc2(x)  # 经过第二个全连接层,输出结果
        return x

device = torch.device('cpu')

transform=transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
def prediect(img_path):
    net=Net()
    net=torch.load('mnist.pt')
    net.eval()
    net=net.to(device)
    torch.no_grad()
    img=Image.open(img_path)    
    img=transform(img).unsqueeze(0)
    img_ = img.to(device)
    outputs = net(img_)
    _, predicted = torch.max(outputs, 1)
    print(predicted)
    lable_idx = int(predicted[0])
    print('this picture maybe :', labels[lable_idx])

# 自己用画图工具写的数字,黑底白字,否则需要对图像进行转化
if __name__ == '__main__':
    prediect('./datasets/MNIST/test.png')

参考文章

一小时实践入门 PyTorch:https://zhuanlan.zhihu.com/p/644120304