import torch
from torch import nn


class ChannelAttention(nn.Module):
    """Channel attention (CBAM channel branch): squeeze spatial dims, excite channels."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (B, C, H, W) -> (B, C, 1, 1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)  # 1x1 conv acts as a per-channel FC
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each channel by its learned attention weight.
        return x * self.act(self.fc(self.pool(x)))
class SpatialAttention(nn.Module):
    """Spatial attention (CBAM spatial branch): pool over channels, convolve to a per-pixel map."""

    def __init__(self, kernel_size: int = 7) -> None:
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stack the channel-wise mean and max maps, then convolve to a single attention map.
        attn = torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)
        return x * self.act(self.cv1(attn))
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by spatial attention."""

    def __init__(self, c1: int, c2: int, kernel_size: int = 7) -> None:
        super().__init__()
        # Attention preserves the channel count, so c1 must equal c2; the two-argument
        # signature is kept only to match the YOLO parser's (c1, c2) convention.
        assert c1 == c2, "CBAM preserves channel count; c1 must equal c2"
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.spatial_attention(self.channel_attention(x))
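# Usage sketch for CBAM (added for illustration, not part of the original listing;
# the shapes are assumptions). Attention modules leave the tensor shape unchanged.
def _cbam_demo() -> None:
    x = torch.randn(2, 64, 32, 32)   # dummy feature map
    print(CBAM(64, 64)(x).shape)     # shape is preserved: (2, 64, 32, 32)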
def channel_shuffle(x: torch.Tensor, groups: int = 2) -> torch.Tensor:
    """Interleave channels across `groups` (ShuffleNet-style); C must be divisible by groups."""
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    return out.view(B, C, H, W)
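# Toy illustration of channel_shuffle (added; values are assumptions): with C=4 and
# groups=2, channel order [0, 1, 2, 3] becomes [0, 2, 1, 3], interleaving the two
# groups so information can mix across them in the next grouped convolution.
def _shuffle_demo() -> None:
    x = torch.arange(4).view(1, 4, 1, 1).float()
    print(channel_shuffle(x, 2).flatten())  # tensor([0., 2., 1., 3.])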
class GAM_Attention(nn.Module):
    """Global Attention Mechanism: MLP channel attention plus convolutional spatial attention."""

    def __init__(self, c1: int, c2: int, group: bool = True, rate: int = 4) -> None:
        super().__init__()
        # Channel attention: an MLP over the channel dimension with reduction factor `rate`.
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, c1 // rate),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // rate, c1),
        )
        # Spatial attention: two 7x7 convs. With group=True the convs are grouped
        # (cheaper), which requires c1, c1 // rate and c2 to be divisible by `rate`;
        # the channel shuffle in forward() then mixes the groups.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate)
            if group
            else nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3),
            nn.BatchNorm2d(c1 // rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate)
            if group
            else nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3),
            nn.BatchNorm2d(c2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # Channel attention: flatten spatial positions so the MLP mixes channels per pixel.
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att
        # Spatial attention, then a channel shuffle to mix the grouped-conv groups
        # (requires c2 to be divisible by 4).
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)
        return x * x_spatial_att
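# Smoke test (added; the original listing ends at GAM_Attention, and _cbam_demo /
# _shuffle_demo are helpers introduced above for illustration). Note the constraints:
# with group=True, c1, c1 // rate and c2 must be divisible by rate; c2 must be
# divisible by 4 for channel_shuffle; and c1 must equal c2 for the final product.
if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)
    print(GAM_Attention(64, 64)(x).shape)  # (2, 64, 32, 32)
    _cbam_demo()
    _shuffle_demo()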