本文最后更新于:14 天前

示例代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/9/13 19:45
# @Author  : Seven
# @Site    : 
# @File    : train.py
# @Software: PyCharm

import torch
import os
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


DOWNLOAD_MNIST = False

if not(os.path.exists('./data/')) or not os.listdir('./data/'):  # 判断数据是否存在
    DOWNLOAD_MNIST = True

trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=DOWNLOAD_MNIST,
                                        transform=transform_train)

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=128,
                                          shuffle=True,
                                          num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=DOWNLOAD_MNIST,
                                       transform=transform_test)

testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False,
                                         num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!

Pytorch实现ResNextNet 上一篇
Pytorch实现VGGNet 下一篇