While reading up on references, I came across a notable salient object detection paper published in CVPR 2019: CPD (Cascaded Partial Decoder for Fast and Accurate Salient Object Detection). Several existing blog posts cover it in detail; they are listed below. I then drew a network diagram to deepen my own understanding.
catalogue

1. overall block diagram
2. RFB module
3. aggregation module
4. holistic attention module
1. overall block diagram

The network has two branches built on the same backbone: an attention branch, whose partial decoder over the high-level features produces an initial saliency map, and a detection branch, in which the layer-3 features refined by the holistic attention module are decoded together with the layer-4 and layer-5 features into the final saliency prediction.
2. RFB module
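All of the code snippets below call a `BasicConv2d` helper that the post does not show. As a minimal sketch to make them self-contained (my assumption, inferred from how it is called: a plain convolution followed by batch normalization), something like this would do:

```python
import torch
import torch.nn as nn

class BasicConv2d(nn.Module):
    # minimal conv + batch-norm block; signature inferred from the calls below
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        return self.bn(self.conv(x))
```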
```python
class RFB(nn.Module):
    # RFB-like multi-scale module
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        # four parallel branches with increasing receptive fields
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        # concatenate the branch outputs, fuse them, and add the 1x1 shortcut
        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
        x = self.relu(x_cat + self.conv_res(x))
        return x
```
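A quick sanity check of the module's shapes (the channel and spatial sizes here are illustrative, not taken from the paper):

```python
rfb = RFB(512, 32)
feat = torch.randn(1, 512, 44, 44)
out = rfb(feat)
print(out.shape)  # torch.Size([1, 32, 44, 44]) -- spatial size is preserved
```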
3. aggregation module
```python
class aggregation(nn.Module):
    # dense aggregation; it can be replaced by other aggregation models,
    # such as DSS, Amulet, and so on.
    # used after MSF
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv4 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3 * channel, 1, 1)

    def forward(self, x1, x2, x3):
        # x1 is the coarsest (smallest) feature map, x3 the finest
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
            * self.conv_upsample3(self.upsample(x2)) * x3

        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        x = self.conv5(x)
        return x
```
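The three inputs are expected to differ by a factor of two in spatial resolution, with `x1` the coarsest. A quick shape check with made-up sizes:

```python
agg = aggregation(32)
x1 = torch.randn(1, 32, 11, 11)  # coarsest features
x2 = torch.randn(1, 32, 22, 22)
x3 = torch.randn(1, 32, 44, 44)  # finest features
print(agg(x1, x2, x3).shape)     # torch.Size([1, 1, 44, 44]) -- a single-channel saliency map
```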
4. holistic attention module
- This module enlarges the coverage of the initial saliency prediction map, which in turn alleviates problems such as inaccurate boundaries and incomplete segmentation results.
- When the attention branch produces an accurate saliency map, this strategy effectively suppresses distracting features.
- Conversely, if a distractor is misclassified as a salient region, the strategy yields abnormal segmentation results, so the effectiveness of the initial saliency map needs to be improved. More specifically, the boundary details of the salient object may be filtered out by the initial saliency map because they are hard to predict accurately, and some objects in complex scenes are hard to segment completely. A holistic attention module is therefore proposed to enlarge the coverage of the initial saliency map.
The holistic attention map is computed as

$$A_h = \operatorname{MAX}\big(f_{\text{min-max}}(\operatorname{Conv}_g(A_i, k)),\ A_i\big)$$

where $A_i$ is the initial saliency map, $\operatorname{Conv}_g(\cdot, k)$ denotes a convolution with Gaussian kernel $k$ and zero bias, $f_{\text{min-max}}(\cdot)$ is a normalization function that rescales the blurred map to the range $[0, 1]$, and $\operatorname{MAX}$ takes the element-wise maximum.

This tends to increase the weight coefficients of the salient regions in the smoothed map. It retains the values of the original salient regions (inside them, the original saliency values are larger than the blurred values), while raising the attention paid to the boundary region of the original saliency map and enlarging the area of saliency perception (in the non-salient regions around the object, the blurred values are larger than the original saliency values).
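A tiny 1D illustration of this max-of-blur behavior (toy values; an averaging kernel stands in for the Gaussian, and the normalization step is omitted for brevity):

```python
import torch
import torch.nn.functional as F

# a 1D "saliency map" with a single sharp peak (illustrative values)
a = torch.tensor([[[0.0, 0.0, 1.0, 0.0, 0.0]]])
# smooth it with a small averaging kernel (stand-in for the Gaussian blur)
k = torch.full((1, 1, 3), 1.0 / 3.0)
blurred = F.conv1d(a, k, padding=1)
print(blurred)                # tensor([[[0.0000, 0.3333, 0.3333, 0.3333, 0.0000]]])
# the element-wise max keeps the original peak and lifts its neighborhood
print(torch.max(blurred, a))  # tensor([[[0.0000, 0.3333, 1.0000, 0.3333, 0.0000]]])
```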
The attention map $A_h$ is then multiplied element-wise with the third-layer convolution features to obtain the attention-refined features. Together with the layer-4 and layer-5 features, these are fed into the decoder to generate a new saliency prediction map. Compared with the initial attention, the proposed holistic attention mechanism adds some computational cost, but it highlights the whole salient object more completely.
Note: the size and standard deviation of the Gaussian kernel k are initialized to 32 and 4 (the code below actually builds a 31×31 kernel, which keeps the padding symmetric), and the kernel is learned automatically during training because it is registered as a Parameter.
```python
import numpy as np
import scipy.stats as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


def gkern(kernlen=16, nsig=3):
    # build a 2D Gaussian kernel of size kernlen x kernlen with std nsig
    interval = (2 * nsig + 1.) / kernlen
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw / kernel_raw.sum()
    return kernel


def min_max_norm(in_):
    # per-sample min-max normalization to [0, 1]
    max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    in_ = in_ - min_
    return in_.div(max_ - min_ + 1e-8)


class HA(nn.Module):
    # holistic attention module
    def __init__(self):
        super(HA, self).__init__()
        gaussian_kernel = np.float32(gkern(31, 4))
        gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
        # the kernel is a Parameter, so it is updated during training
        self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))

    def forward(self, attention, x):
        # blur the initial saliency map, normalize it to [0, 1],
        # then take the element-wise max with the original map
        soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)
        soft_attention = min_max_norm(soft_attention)
        x = torch.mul(x, soft_attention.max(attention))
        return x
```
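A quick shape check (sizes illustrative): the initial saliency map must be single-channel so it can be convolved with the 1×1×31×31 Gaussian kernel, and the refined output keeps the shape of `x` through broadcasting:

```python
ha = HA()
attention = torch.rand(1, 1, 44, 44)  # initial saliency map from the attention branch
feat = torch.randn(1, 32, 44, 44)     # third-layer convolution features
refined = ha(attention, feat)
print(refined.shape)                  # torch.Size([1, 32, 44, 44])
```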