本文最后更新于：14 天前
示例代码：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/9/14 16:02
# @Author : Seven
# @Site :
# @File : DenseNet.py
# @Software: PyCharm
import math
import torch
import torch.nn as nn
class Bn_act_conv_drop(nn.Module):
    """Pre-activation unit: BatchNorm2d -> ReLU -> Conv2d (stride 1) -> Dropout.

    Args:
        inputs: number of input channels.
        outs: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero padding added on each spatial side.
        dropout: dropout probability applied after the convolution
            (default 0.5, the same rate as the bare ``nn.Dropout()`` this
            unit has always used).

    Returns (from ``forward``): a tensor with ``outs`` channels.
    """

    def __init__(self, inputs, outs, kernel_size, padding, dropout=0.5):
        super(Bn_act_conv_drop, self).__init__()
        self.bn = nn.Sequential(
            nn.BatchNorm2d(inputs),
            nn.ReLU()
        )
        # No activation after the convolution: in pre-activation ordering
        # (bn -> act -> conv -> drop, as the class name says) the next unit's
        # BN-ReLU supplies it. Index 0 is still the Conv2d, so state_dict keys
        # (conv.0.weight / conv.0.bias) are unchanged and old checkpoints load.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=inputs,
                out_channels=outs,
                kernel_size=kernel_size,
                padding=padding,
                stride=1),
            nn.Dropout(p=dropout)
        )

    def forward(self, inputs):
        network = self.bn(inputs)
        network = self.conv(network)
        return network
class Transition(nn.Module):
    """Transition layer between two dense stages.

    Runs a 1x1 ``Bn_act_conv_drop`` unit to change the channel count from
    ``inputs`` to ``outs``, then a 2x2 average pool with stride 2 that halves
    height and width.
    """

    def __init__(self, inputs, outs):
        super(Transition, self).__init__()
        # A 1x1 kernel with no padding leaves the spatial size untouched.
        self.conv = Bn_act_conv_drop(inputs, outs, kernel_size=1, padding=0)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, inputs):
        return self.avgpool(self.conv(inputs))
class Block(nn.Module):
    """Bottleneck dense layer.

    A 1x1 ``Bn_act_conv_drop`` unit widens to ``4 * growth`` channels, then a
    3x3 unit (padding 1, so spatial size is kept) produces ``growth`` new
    channels. The new channels are concatenated in front of the input along
    dim 1, giving ``inputs + growth`` output channels.
    """

    def __init__(self, inputs, growth):
        super(Block, self).__init__()
        bottleneck_width = 4 * growth
        self.conv1 = Bn_act_conv_drop(inputs, bottleneck_width, kernel_size=1, padding=0)
        self.conv2 = Bn_act_conv_drop(bottleneck_width, growth, kernel_size=3, padding=1)

    def forward(self, inputs):
        new_features = self.conv2(self.conv1(inputs))
        # New features first, then the untouched input (dense connectivity).
        return torch.cat([new_features, inputs], 1)
class DenseNet(nn.Module):
    """DenseNet with a 3x3 stem, four dense stages and a linear classifier.

    The stem convolution has stride 1 and there is no initial pooling, so the
    three ``Transition`` layers are the only spatial downsampling (each halves
    height and width). A 32x32 input therefore reaches the last stage as a 4x4
    feature map.

    Args:
        blocks: stage depths; ``blocks[0]`` .. ``blocks[3]`` give the number of
            ``Block`` layers in each of the four dense stages (extra entries
            are ignored).
        growth: growth rate -- channels each ``Block`` appends to its input.
        num_classes: number of classifier outputs (default 10).
        reduction: fraction of channels kept by each transition layer
            (default 0.5); the product is floored to an int.

    ``forward`` returns ``(logits, features)``: the classifier output and the
    flattened pooled features it was computed from.
    """

    def __init__(self, blocks, growth, num_classes=10, reduction=0.5):
        super(DenseNet, self).__init__()
        num_planes = 2 * growth
        inputs = 3  # input image channels
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=inputs,
                out_channels=num_planes,
                kernel_size=3,
                padding=1),
            nn.ReLU(),
        )

        # Modules are created in the same order as before so parameter
        # initialization and state_dict ordering are unchanged.
        self.block1 = self._block(blocks[0], num_planes, growth)
        num_planes += blocks[0] * growth
        out_planes = int(math.floor(num_planes * reduction))
        self.tran1 = Transition(inputs=num_planes, outs=out_planes)
        num_planes = out_planes

        self.block2 = self._block(blocks[1], num_planes, growth)
        num_planes += blocks[1] * growth
        out_planes = int(math.floor(num_planes * reduction))
        self.tran2 = Transition(inputs=num_planes, outs=out_planes)
        num_planes = out_planes

        self.block3 = self._block(blocks[2], num_planes, growth)
        num_planes += blocks[2] * growth
        out_planes = int(math.floor(num_planes * reduction))
        self.tran3 = Transition(inputs=num_planes, outs=out_planes)
        num_planes = out_planes

        self.block4 = self._block(blocks[3], num_planes, growth)
        num_planes += blocks[3] * growth

        self.bn = nn.Sequential(
            nn.BatchNorm2d(num_planes),
            nn.ReLU()
        )
        # Global average pooling to 1x1 so the classifier input is always
        # num_planes features, whatever the input resolution. For the 4x4 map a
        # 32x32 input produces this equals AvgPool2d(kernel_size=4); a fixed
        # 4x4 kernel would instead crash (or pool only the top-left window)
        # on larger inputs.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(num_planes, num_classes)

    def forward(self, inputs):
        network = self.conv(inputs)
        network = self.block1(network)
        network = self.tran1(network)
        network = self.block2(network)
        network = self.tran2(network)
        network = self.block3(network)
        network = self.tran3(network)
        network = self.block4(network)
        network = self.bn(network)
        network = self.avgpool(network)
        network = network.view(network.size(0), -1)
        out = self.linear(network)
        return out, network

    @staticmethod
    def _block(layers, inputs, growth):
        """Stack ``layers`` Blocks; each one sees ``growth`` more input channels."""
        block_layer = []
        for _ in range(layers):
            block_layer.append(Block(inputs, growth))
            inputs += growth
        return nn.Sequential(*block_layer)
def DenseNet121():
    """Return a DenseNet with stage depths [6, 12, 24, 16] and growth rate 32."""
    stage_depths = [6, 12, 24, 16]
    return DenseNet(blocks=stage_depths, growth=32)
def DenseNet169():
    """Return a DenseNet with stage depths [6, 12, 32, 32] and growth rate 32."""
    stage_depths = [6, 12, 32, 32]
    return DenseNet(blocks=stage_depths, growth=32)
def DenseNet201():
    """Return a DenseNet with stage depths [6, 12, 48, 32] and growth rate 32."""
    stage_depths = [6, 12, 48, 32]
    return DenseNet(blocks=stage_depths, growth=32)
def DenseNet161():
    """Return a DenseNet with stage depths [6, 12, 36, 24] and growth rate 48."""
    stage_depths = [6, 12, 36, 24]
    return DenseNet(blocks=stage_depths, growth=48)
def DenseNet_cifar():
    """Return a DenseNet with stage depths [6, 12, 24, 16] and growth rate 12."""
    stage_depths = [6, 12, 24, 16]
    return DenseNet(blocks=stage_depths, growth=12)
本博客所有文章除特别声明外，均采用 CC BY-SA 3.0 协议。转载请注明出处！