import torch
import torch.nn as nn
import torch.nn.functional as F
[docs]
class ChannelSELayer3D(nn.Module):
"""
3D extension of Squeeze-and-Excitation (SE) block described in:
*Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
*Zhu et al., AnatomyNet, arXiv:arXiv:1808.05238*
"""
def __init__(self, num_channels, reduction_ratio=2):
"""
Args:
num_channels (int): No of input channels
reduction_ratio (int): By how much should the num_channels should be reduced
"""
super(ChannelSELayer3D, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
num_channels_reduced = num_channels // reduction_ratio
self.reduction_ratio = reduction_ratio
self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
[docs]
def forward(self, x):
batch_size, num_channels, D, H, W = x.size()
# Average along each channel
squeeze_tensor = self.avg_pool(x)
# channel excitation
fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
return output_tensor
[docs]
class SpatialSELayer3D(nn.Module):
"""
3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
*Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
"""
def __init__(self, num_channels):
"""
Args:
num_channels (int): No of input channels
"""
super(SpatialSELayer3D, self).__init__()
self.conv = nn.Conv3d(num_channels, 1, 1)
self.sigmoid = nn.Sigmoid()
[docs]
def forward(self, x, weights=None):
"""
Args:
weights (torch.Tensor): weights for few shot learning
x: X, shape = (batch_size, num_channels, D, H, W)
Returns:
(torch.Tensor): output_tensor
"""
# channel squeeze
batch_size, channel, D, H, W = x.size()
if weights:
weights = weights.view(1, channel, 1, 1)
out = F.conv2d(x, weights)
else:
out = self.conv(x)
squeeze_tensor = self.sigmoid(out)
# spatial excitation
output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W))
return output_tensor
[docs]
class ChannelSpatialSELayer3D(nn.Module):
"""
3D extension of concurrent spatial and channel squeeze & excitation:
*Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
"""
def __init__(self, num_channels, reduction_ratio=2):
"""
Args:
num_channels (int): No of input channels
reduction_ratio (int): By how much should the num_channels should be reduced
"""
super(ChannelSpatialSELayer3D, self).__init__()
self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
self.sSE = SpatialSELayer3D(num_channels)
[docs]
def forward(self, input_tensor):
output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
return output_tensor
[docs]
class ConvolutionBlock(nn.Module):
def __init__(self, strides,channels):
super().__init__()
self.strides = strides
self.channel1, self.channel2, self.channel3, self.channel4 = channels
self.in_planes = self.channel2
self.conv3d_r = nn.Conv3d(self.channel1,
self.channel4,
kernel_size=1,
stride = self.strides)
self.batch_normalization_r = nn.BatchNorm3d(self.in_planes)
self.conv3d1 = nn.Conv3d(self.channel1,
self.channel2,
kernel_size=1,
stride= self.strides,
bias=False)
self.batch_normalization1 = nn.BatchNorm3d(self.in_planes)
self.activation1 = nn.ReLU()
self.conv3d2 = nn.Conv3d(self.channel2,
self.channel3,
kernel_size=3,
padding='same')
self.batch_normalization2 = nn.BatchNorm3d(self.in_planes)
self.activation2 = nn.ReLU()
self.conv3d3 = nn.Conv3d(self.channel3,
self.channel4,
kernel_size=1)
self.batch_normalization3 = nn.BatchNorm3d(self.in_planes)
self.activation3 = nn.ReLU()
[docs]
def forward(self,x):
r = self.conv3d_r(x)
r = self.batch_normalization_r(r)
x = self.conv3d1(x)
x = self.batch_normalization1(x)
x = self.activation1(x)
x = self.conv3d2(x)
x = self.batch_normalization2(x)
x = self.activation2(x)
x = self.conv3d3(x)
x = self.batch_normalization3(x)
x = x+r
x = self.activation3(x)
return x
[docs]
class IdentityBlock(nn.Module):
def __init__(self, channels,layer=None):
super().__init__()
self.channel1, self.channel2, self.channel3, self.channel4 = channels
self.in_planes = self.channel2
self.conv3d1 = nn.Conv3d(self.channel1,
self.channel2,
kernel_size=1,
bias=False)
if layer==None:
self.batch_normalization1 = nn.BatchNorm3d(self.in_planes)
self.activation1 = nn.ReLU()
self.conv3d2 = nn.Conv3d(self.channel2,
self.channel3,
kernel_size=3,
padding='same')
if layer==None:
self.batch_normalization2 = nn.BatchNorm3d(self.in_planes)
self.activation2 = nn.ReLU()
self.conv3d3 = nn.Conv3d(self.channel3,
self.channel4,
kernel_size=1)
if layer==None:
self.batch_normalization3 = nn.BatchNorm3d(self.in_planes)
self.activation3 = nn.ReLU()
[docs]
def forward(self,x):
r = x.clone()
x = self.conv3d1(x)
x = self.batch_normalization1(x)
x = self.activation1(x)
x = self.conv3d2(x)
x = self.batch_normalization2(x)
x = self.activation2(x)
x = self.conv3d3(x)
x = self.batch_normalization3(x)
x = x+r
x = self.activation3(x)
return x
[docs]
class UpSamplingBlock(nn.Module):
def __init__(self,channels,stride,size,padding='same',layer=None):
super().__init__()
self.scale = size
self.strides = stride
self.channel1, self.channel2, self.channel3, self.channel4 = channels
self.in_planes = self.channel2
self.conv3d_r = nn.Conv3d(self.channel1,
self.channel4,
kernel_size=1,
stride = self.strides,
padding = 'same')
if layer ==None:
self.batch_normalization_r = nn.BatchNorm3d(self.in_planes)
self.conv3d1 = nn.Conv3d(self.channel1,
self.channel2,
kernel_size=1,
stride = self.strides,
bias=False)
if layer==None:
self.batch_normalization1 = nn.BatchNorm3d(self.in_planes)
self.activation1 = nn.ReLU()
self.conv3d2 = nn.Conv3d(self.channel2,
self.channel3,
kernel_size=3,
padding='same')
if layer==None:
self.batch_normalization2 = nn.BatchNorm3d(self.in_planes)
self.activation2 = nn.ReLU()
self.conv3d3 = nn.Conv3d(self.channel3,
self.channel4,
kernel_size=1)
if layer==None:
self.batch_normalization3 = nn.BatchNorm3d(self.in_planes)
self.activation3 = nn.ReLU()
[docs]
def forward(self,x):
r = x.clone()
x = F.interpolate(x,scale_factor=self.scale,mode = 'trilinear')
x = self.conv3d1(x)
x = self.batch_normalization1(x)
x = self.activation1(x)
x = self.conv3d2(x)
x = self.batch_normalization2(x)
x = self.activation2(x)
x = self.conv3d3(x)
x = self.batch_normalization3(x)
r = F.interpolate(r,scale_factor=self.scale, mode = 'trilinear')
r = self.conv3d_r(r)
r = self.batch_normalization_r(r)
x = x+r
x = self.activation3(x)
return x
[docs]
class SEResNet(nn.Module):
def __init__(self):
super().__init__()
input_channels = 18
f = input_channels
## downsampling
self.Convblock1 = ConvolutionBlock(channels = [input_channels, f,f,f],strides = 1)
self.identblock1 = IdentityBlock(channels = [f,f,f,f])
self.skipidentityblock1 = IdentityBlock(channels = [f,f,f,f])
self.Convblock2 = ConvolutionBlock(channels = [f,f*2,f*2,f*2],strides = 2)
self.identblock2 = IdentityBlock(channels = [f*2,f*2,f*2,f*2])
self.skipidentityblock2 = IdentityBlock(channels = [f*2,f*2,f*2,f*2])
self.Convblock3 = ConvolutionBlock(channels = [f*2, f*4,f*4,f*4],strides = 2)
self.identblock3 = IdentityBlock(channels = [f*4,f*4,f*4,f*4])
self.skipidentityblock3 = IdentityBlock(channels = [f*4,f*4,f*4,f*4])
self.Convblock4 = ConvolutionBlock(channels = [f*4, f*8,f*8,f*8],strides = 3)
self.identblock4 = IdentityBlock(channels = [f*8,f*8,f*8,f*8])
self.skipidentityblock4 = IdentityBlock(channels = [f*8,f*8,f*8,f*8])
self.Convblock5 = ConvolutionBlock(channels = [f*8, f*16,f*16,f*16],strides = 3)
self.identblock5 = IdentityBlock(channels = [f*16,f*16,f*16,f*16])
## upsampling
self.upblock6 = UpSamplingBlock(channels = [f*16,f*16,f*16,f*16],size=3,stride=1)
self.identblock6 = IdentityBlock(channels=[f*16,f*16,f*16,f*16])
self.upblock7 = UpSamplingBlock(channels = [f*16+f*8,f*8,f*8,f*8],size=3,stride=1)
self.identblock7 = IdentityBlock(channels=[f*8,f*8,f*8,f*8])
self.upblock8 = UpSamplingBlock(channels = [f*8+f*4,f*4,f*4,f*4],size=2,stride=1)
self.identblock8 = IdentityBlock(channels=[f*4,f*4,f*4,f*4])
self.upblock9 = UpSamplingBlock(channels = [f*4+f*2,f*2,f*2,f*2],size=2,stride=1)
self.identblock9 = IdentityBlock(channels=[f*2,f*2,f*2,f*2])
self.finalconv10 = nn.Conv3d(f*2+f,1,kernel_size=1)
self.finalact10 = nn.Sigmoid()
[docs]
def forward(self,x):
skip = []
x = self.Convblock1(x)
x = self.identblock1(x)
skip.append(self.skipidentityblock1(x))
x = self.Convblock2(x)
x = self.identblock2(x)
skip.append(self.skipidentityblock2(x))
x = self.Convblock3(x)
x = self.identblock3(x)
skip.append(self.skipidentityblock3(x))
x = self.Convblock4(x)
x = self.identblock4(x)
skip.append(self.skipidentityblock4(x))
x = self.Convblock5(x)
x = self.identblock5(x)
x = self.upblock6(x)
x = self.identblock6(x)
x = torch.cat([x,skip[-1]],dim=1)
x = self.upblock7(x)
x = self.identblock7(x)
x = torch.cat([x,skip[-2]],dim=1)
x = self.upblock8(x)
x = self.identblock8(x)
x = torch.cat([x,skip[-3]],dim=1)
x = self.upblock9(x)
x = self.identblock9(x)
x = torch.cat([x,skip[-4]],dim=1)
x = self.finalconv10(x)
x = self.finalact10(x)
return x