本文最后更新于：14 天前
示例代码：
import torch.nn as nn
class AlexNet(nn.Module):
    """AlexNet-style CNN for 3-channel inputs, sized for 32x32 images.

    ``fc1`` expects a flattened 256x4x4 feature map. For a 32x32 input the
    spatial size evolves as::

        32 -> conv1 (k=2, p=1) 33 -> pool 16
           -> conv2 (k=2, p=1) 17 -> pool 8
           -> conv3 (k=3, p=1) 8  -> conv4 (k=3, p=1) 8
           -> conv5 (k=2, p=1) 9  -> pool 4

    Inputs whose feature map does not end up 4x4 fail with a shape error
    in ``fc1``.

    Args:
        num_classes: Number of logits produced by ``self.out``. Defaults to
            10, the value this model originally hard-coded.
        dropout: Drop probability given to both ``nn.Dropout`` layers.
            Defaults to 0.8, the value this model originally hard-coded.
            NOTE(review): in PyTorch ``p`` is the probability of *zeroing*
            an activation, so 0.8 drops 80% of them (the original AlexNet
            paper used 0.5). If 0.8 was ported from a TensorFlow
            ``keep_prob=0.8``, the intended value is 0.2 -- confirm.
    """

    def __init__(self, *, num_classes: int = 10, dropout: float = 0.8):
        super().__init__()
        # Submodule names and Sequential indices match the original layout,
        # so existing state_dict checkpoints still load.
        self.conv1 = self._conv_block(3, 96, kernel_size=2, pool=True)
        self.conv2 = self._conv_block(96, 256, kernel_size=2, pool=True)
        self.conv3 = self._conv_block(256, 384, kernel_size=3, pool=False)
        self.conv4 = self._conv_block(384, 384, kernel_size=3, pool=False)
        self.conv5 = self._conv_block(384, 256, kernel_size=2, pool=True)
        self.fc1 = nn.Sequential(
            nn.Linear(4 * 4 * 256, 4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        self.out = nn.Linear(1024, num_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel_size, pool):
        """Build Conv2d(stride=1, padding=1) -> ReLU, plus MaxPool2d(2) if *pool*."""
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run the network.

        Args:
            inputs: Batched image tensor; assumed shape (batch, 3, H, W)
                with H and W reducing to 4x4 (e.g. 32x32).

        Returns:
            A tuple ``(logits, features)``. ``logits`` has shape
            (batch, num_classes). ``features`` is the (batch, 1024) output
            of ``fc2`` (after its ReLU and dropout).
        """
        network = self.conv1(inputs)
        network = self.conv2(network)
        network = self.conv3(network)
        network = self.conv4(network)
        network = self.conv5(network)
        # Keep the batch dimension and flatten the rest; unlike view(),
        # flatten() also works on non-contiguous tensors.
        network = network.flatten(1)
        network = self.fc1(network)
        network = self.fc2(network)
        out = self.out(network)
        return out, network
本博客所有文章除特别声明外，均采用 CC BY-SA 3.0 协议。转载请注明出处！