diff --git a/README.md b/README.md
index 22bc939..19a5344 100644
--- a/README.md
+++ b/README.md
@@ -22,10 +22,12 @@ Toolbox:
 
 Change detection models:
-- [x] [Changer (arXiv'2022)](configs/changer)
-- [x] [FC-Siam-Diff (ICIP'2018)](configs/fcsn)
+- [x] [FC-EF (ICIP'2018)](configs/fcsn)
+- [x] [FC-Siam-diff (ICIP'2018)](configs/fcsn)
+- [x] [FC-Siam-conc (ICIP'2018)](configs/fcsn)
 - [x] [SNUNet (GRSL'2021)](configs/snunet)
 - [x] [BiT (TGRS'2021)](configs/bit)
+- [x] [Changer (arXiv'2022)](configs/changer)
 - [ ] ...
 
 The code of some models is borrowed directly from their official repositories.
 
diff --git a/configs/_base_/models/fc_ef.py b/configs/_base_/models/fc_ef.py
new file mode 100644
index 0000000..ea80c9f
--- /dev/null
+++ b/configs/_base_/models/fc_ef.py
@@ -0,0 +1,23 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+base_channels = 16
+model = dict(
+    type='DIEncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='FC_EF',
+        in_channels=6,
+        base_channel=base_channels),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=base_channels,
+        channels=base_channels,
+        in_index=-1,
+        num_convs=0,
+        concat_input=False,
+        num_classes=2,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
\ No newline at end of file
diff --git a/configs/_base_/models/fc_siam_conc.py b/configs/_base_/models/fc_siam_conc.py
new file mode 100644
index 0000000..1c8681d
--- /dev/null
+++ b/configs/_base_/models/fc_siam_conc.py
@@ -0,0 +1,23 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+base_channels = 16
+model = dict(
+    type='DIEncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='FC_Siam_conc',
+        in_channels=3,
+        base_channel=base_channels),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=base_channels,
+        channels=base_channels,
+        in_index=-1,
+        num_convs=0,
+        concat_input=False,
+        num_classes=2,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
\ No newline at end of file
diff --git a/configs/_base_/models/fc_siam_diff.py b/configs/_base_/models/fc_siam_diff.py
index 43252b3..66c5d1f 100644
--- a/configs/_base_/models/fc_siam_diff.py
+++ b/configs/_base_/models/fc_siam_diff.py
@@ -5,7 +5,7 @@
     type='DIEncoderDecoder',
     pretrained=None,
     backbone=dict(
-        type='SiamUnet_diff',
+        type='FC_Siam_diff',
         in_channels=3,
         base_channel=base_channels),
     decode_head=dict(
diff --git a/opencd/models/backbones/__init__.py b/opencd/models/backbones/__init__.py
index 55b692f..6da9ca7 100644
--- a/opencd/models/backbones/__init__.py
+++ b/opencd/models/backbones/__init__.py
@@ -1,6 +1,7 @@
 from .interaction_resnet import IA_ResNetV1c
 from .interaction_resnest import IA_ResNeSt
-from .siamunet_diff import SiamUnet_diff
+from .fcsn import FC_EF, FC_Siam_diff, FC_Siam_conc
 from .snunet import SNUNet_ECAM
 
-__all__ = ['IA_ResNetV1c', 'IA_ResNeSt', 'SiamUnet_diff', 'SNUNet_ECAM']
\ No newline at end of file
+__all__ = ['IA_ResNetV1c', 'IA_ResNeSt', 'FC_EF', 'FC_Siam_diff',
+           'FC_Siam_conc', 'SNUNet_ECAM']
\ No newline at end of file
diff --git a/opencd/models/backbones/fcsn.py b/opencd/models/backbones/fcsn.py
new file mode 100644
index 0000000..28f45af
--- /dev/null
+++ b/opencd/models/backbones/fcsn.py
@@ -0,0 +1,489 @@
+"""
+Daudt, R. C., Le Saux, B., & Boulch, A.
+"Fully convolutional siamese networks for change detection". +In 2018 25th IEEE International Conference on Image Processing (ICIP) +(pp. 4063-4067). IEEE. + +Some code in this file is borrowed from: +https://github.com/rcdaudt/fully_convolutional_change_detection +https://github.com/Bobholamovic/CDLab +https://github.com/likyoo/Siam-NestedUNet +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.padding import ReplicationPad2d + +from mmseg.models.builder import BACKBONES + + +@BACKBONES.register_module() +class FC_EF(nn.Module): + """FC_EF segmentation network.""" + + def __init__(self, in_channels, base_channel=16): + super(FC_EF, self).__init__() + + filters = [base_channel, base_channel * 2, base_channel * 4, + base_channel * 8, base_channel * 16] + + self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) + self.bn11 = nn.BatchNorm2d(filters[0]) + self.do11 = nn.Dropout2d(p=0.2) + self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + self.bn12 = nn.BatchNorm2d(filters[0]) + self.do12 = nn.Dropout2d(p=0.2) + + self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) + self.bn21 = nn.BatchNorm2d(filters[1]) + self.do21 = nn.Dropout2d(p=0.2) + self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.bn22 = nn.BatchNorm2d(filters[1]) + self.do22 = nn.Dropout2d(p=0.2) + + self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) + self.bn31 = nn.BatchNorm2d(filters[2]) + self.do31 = nn.Dropout2d(p=0.2) + self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32 = nn.BatchNorm2d(filters[2]) + self.do32 = nn.Dropout2d(p=0.2) + self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn33 = nn.BatchNorm2d(filters[2]) + self.do33 = nn.Dropout2d(p=0.2) + + self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) + self.bn41 = nn.BatchNorm2d(filters[3]) + self.do41 = nn.Dropout2d(p=0.2) + self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42 = nn.BatchNorm2d(filters[3]) + self.do42 = nn.Dropout2d(p=0.2) + self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn43 = nn.BatchNorm2d(filters[3]) + self.do43 = nn.Dropout2d(p=0.2) + + self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv43d = nn.ConvTranspose2d(filters[4], filters[3], kernel_size=3, padding=1) + self.bn43d = nn.BatchNorm2d(filters[3]) + self.do43d = nn.Dropout2d(p=0.2) + self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42d = nn.BatchNorm2d(filters[3]) + self.do42d = nn.Dropout2d(p=0.2) + self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn41d = nn.BatchNorm2d(filters[2]) + self.do41d = nn.Dropout2d(p=0.2) + + self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv33d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn33d = nn.BatchNorm2d(filters[2]) + self.do33d = nn.Dropout2d(p=0.2) + self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32d = nn.BatchNorm2d(filters[2]) + self.do32d = nn.Dropout2d(p=0.2) + self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn31d = nn.BatchNorm2d(filters[1]) + self.do31d = nn.Dropout2d(p=0.2) + + 
self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv22d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn22d = nn.BatchNorm2d(filters[1]) + self.do22d = nn.Dropout2d(p=0.2) + self.conv21d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn21d = nn.BatchNorm2d(filters[0]) + self.do21d = nn.Dropout2d(p=0.2) + + self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv12d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn12d = nn.BatchNorm2d(filters[0]) + self.do12d = nn.Dropout2d(p=0.2) + self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) + + def forward(self, x1, x2): + """Forward method.""" + x = torch.cat((x1, x2), 1) + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) + x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43, kernel_size=2, stride=2) + + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) + x4d = torch.cat((pad4(x4d), x43), 1) + x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) + x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) + x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) + x3d = torch.cat((pad3(x3d), x33), 1) + x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) + x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) + x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2))) + x2d = torch.cat((pad2(x2d), x22), 1) + x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) + x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) + x1d = torch.cat((pad1(x1d), x12), 1) + x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) + x11d = self.conv11d(x12d) + + return (x11d,) + + +@BACKBONES.register_module() +class FC_Siam_diff(nn.Module): + """FC_Siam_diff segmentation network.""" + + def __init__(self, in_channels, base_channel=16): + super(FC_Siam_diff, self).__init__() + + filters = [base_channel, base_channel * 2, base_channel * 4, + base_channel * 8, base_channel * 16] + + self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) + self.bn11 = nn.BatchNorm2d(filters[0]) + self.do11 = nn.Dropout2d(p=0.2) + self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + self.bn12 = nn.BatchNorm2d(filters[0]) + self.do12 = 
nn.Dropout2d(p=0.2) + + self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) + self.bn21 = nn.BatchNorm2d(filters[1]) + self.do21 = nn.Dropout2d(p=0.2) + self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.bn22 = nn.BatchNorm2d(filters[1]) + self.do22 = nn.Dropout2d(p=0.2) + + self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) + self.bn31 = nn.BatchNorm2d(filters[2]) + self.do31 = nn.Dropout2d(p=0.2) + self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32 = nn.BatchNorm2d(filters[2]) + self.do32 = nn.Dropout2d(p=0.2) + self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn33 = nn.BatchNorm2d(filters[2]) + self.do33 = nn.Dropout2d(p=0.2) + + self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) + self.bn41 = nn.BatchNorm2d(filters[3]) + self.do41 = nn.Dropout2d(p=0.2) + self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42 = nn.BatchNorm2d(filters[3]) + self.do42 = nn.Dropout2d(p=0.2) + self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn43 = nn.BatchNorm2d(filters[3]) + self.do43 = nn.Dropout2d(p=0.2) + + self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv43d = nn.ConvTranspose2d(filters[4], filters[3], kernel_size=3, padding=1) + self.bn43d = nn.BatchNorm2d(filters[3]) + self.do43d = nn.Dropout2d(p=0.2) + self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42d = nn.BatchNorm2d(filters[3]) + self.do42d = nn.Dropout2d(p=0.2) + self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn41d = nn.BatchNorm2d(filters[2]) + self.do41d = nn.Dropout2d(p=0.2) + + self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv33d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn33d = nn.BatchNorm2d(filters[2]) + self.do33d = nn.Dropout2d(p=0.2) + self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32d = nn.BatchNorm2d(filters[2]) + self.do32d = nn.Dropout2d(p=0.2) + self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn31d = nn.BatchNorm2d(filters[1]) + self.do31d = nn.Dropout2d(p=0.2) + + self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv22d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn22d = nn.BatchNorm2d(filters[1]) + self.do22d = nn.Dropout2d(p=0.2) + self.conv21d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn21d = nn.BatchNorm2d(filters[0]) + self.do21d = nn.Dropout2d(p=0.2) + + self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv12d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn12d = nn.BatchNorm2d(filters[0]) + self.do12d = nn.Dropout2d(p=0.2) + self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) + + def forward(self, x1, x2): + """Forward method.""" + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) + x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) + + # Stage 2 + x21 = 
self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) + + #################################################### + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) + x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) + + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) + x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) + x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) + x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) + x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) + x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) + x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) + x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) + x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) + x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) + x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) + x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) + x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) + x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) + x11d = self.conv11d(x12d) + + return (x11d,) + + +@BACKBONES.register_module() +class FC_Siam_conc(nn.Module): + """FC_Siam_conc segmentation network.""" + + def __init__(self, in_channels, base_channel=16): + super(FC_Siam_conc, self).__init__() + + filters = [base_channel, base_channel * 2, base_channel * 4, + base_channel * 8, base_channel * 16] + + self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) + self.bn11 = nn.BatchNorm2d(filters[0]) + self.do11 = nn.Dropout2d(p=0.2) + self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) + self.bn12 = nn.BatchNorm2d(filters[0]) + self.do12 = nn.Dropout2d(p=0.2) + + self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) + 
self.bn21 = nn.BatchNorm2d(filters[1]) + self.do21 = nn.Dropout2d(p=0.2) + self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) + self.bn22 = nn.BatchNorm2d(filters[1]) + self.do22 = nn.Dropout2d(p=0.2) + + self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) + self.bn31 = nn.BatchNorm2d(filters[2]) + self.do31 = nn.Dropout2d(p=0.2) + self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32 = nn.BatchNorm2d(filters[2]) + self.do32 = nn.Dropout2d(p=0.2) + self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn33 = nn.BatchNorm2d(filters[2]) + self.do33 = nn.Dropout2d(p=0.2) + + self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) + self.bn41 = nn.BatchNorm2d(filters[3]) + self.do41 = nn.Dropout2d(p=0.2) + self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42 = nn.BatchNorm2d(filters[3]) + self.do42 = nn.Dropout2d(p=0.2) + self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn43 = nn.BatchNorm2d(filters[3]) + self.do43 = nn.Dropout2d(p=0.2) + + self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv43d = nn.ConvTranspose2d(filters[3]+filters[4], filters[3], kernel_size=3, padding=1) + self.bn43d = nn.BatchNorm2d(filters[3]) + self.do43d = nn.Dropout2d(p=0.2) + self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) + self.bn42d = nn.BatchNorm2d(filters[3]) + self.do42d = nn.Dropout2d(p=0.2) + self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) + self.bn41d = nn.BatchNorm2d(filters[2]) + self.do41d = nn.Dropout2d(p=0.2) + + self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv33d = nn.ConvTranspose2d(filters[2]+filters[3], filters[2], kernel_size=3, padding=1) + self.bn33d = nn.BatchNorm2d(filters[2]) + self.do33d = nn.Dropout2d(p=0.2) + self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) + self.bn32d = nn.BatchNorm2d(filters[2]) + self.do32d = nn.Dropout2d(p=0.2) + self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) + self.bn31d = nn.BatchNorm2d(filters[1]) + self.do31d = nn.Dropout2d(p=0.2) + + self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv22d = nn.ConvTranspose2d(filters[1]+filters[2], filters[1], kernel_size=3, padding=1) + self.bn22d = nn.BatchNorm2d(filters[1]) + self.do22d = nn.Dropout2d(p=0.2) + self.conv21d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) + self.bn21d = nn.BatchNorm2d(filters[0]) + self.do21d = nn.Dropout2d(p=0.2) + + self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) + + self.conv12d = nn.ConvTranspose2d(filters[0]+filters[1], filters[0], kernel_size=3, padding=1) + self.bn12d = nn.BatchNorm2d(filters[0]) + self.do12d = nn.Dropout2d(p=0.2) + self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) + + def forward(self, x1, x2): + """Forward method.""" + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) + x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_1 = 
self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) + + #################################################### + # Stage 1 + x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) + x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) + x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) + + # Stage 2 + x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) + x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) + x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) + + # Stage 3 + x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) + x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) + x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) + x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) + + # Stage 4 + x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) + x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) + x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) + x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) + + #################################################### + # Stage 4d + x4d = self.upconv4(x4p) + pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) + x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) + x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) + x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) + x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) + + # Stage 3d + x3d = self.upconv3(x41d) + pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) + x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) + x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) + x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) + x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) + + # Stage 2d + x2d = self.upconv2(x31d) + pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) + x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) + x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) + x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) + + # Stage 1d + x1d = self.upconv1(x21d) + pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) + x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) + x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) + x11d = self.conv11d(x12d) + + return (x11d,) \ No newline at end of file diff --git a/opencd/models/backbones/siamunet_diff.py b/opencd/models/backbones/siamunet_diff.py deleted file mode 100644 index 94f50bb..0000000 --- a/opencd/models/backbones/siamunet_diff.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -Daudt, R. C., Le Saux, B., & Boulch, A. -"Fully convolutional siamese networks for change detection". -In 2018 25th IEEE International Conference on Image Processing (ICIP) -(pp. 4063-4067). IEEE. 
-""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.modules.padding import ReplicationPad2d - -from mmseg.models.builder import BACKBONES - - -@BACKBONES.register_module() -class SiamUnet_diff(nn.Module): - """SiamUnet_diff segmentation network.""" - - def __init__(self, in_channels, base_channel=16): - super(SiamUnet_diff, self).__init__() - - n1 = base_channel - filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] - self.input_nbr = in_channels - - self.conv11 = nn.Conv2d(in_channels, filters[0], kernel_size=3, padding=1) - self.bn11 = nn.BatchNorm2d(filters[0]) - self.do11 = nn.Dropout2d(p=0.2) - self.conv12 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1) - self.bn12 = nn.BatchNorm2d(filters[0]) - self.do12 = nn.Dropout2d(p=0.2) - - self.conv21 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1) - self.bn21 = nn.BatchNorm2d(filters[1]) - self.do21 = nn.Dropout2d(p=0.2) - self.conv22 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1) - self.bn22 = nn.BatchNorm2d(filters[1]) - self.do22 = nn.Dropout2d(p=0.2) - - self.conv31 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1) - self.bn31 = nn.BatchNorm2d(filters[2]) - self.do31 = nn.Dropout2d(p=0.2) - self.conv32 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) - self.bn32 = nn.BatchNorm2d(filters[2]) - self.do32 = nn.Dropout2d(p=0.2) - self.conv33 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1) - self.bn33 = nn.BatchNorm2d(filters[2]) - self.do33 = nn.Dropout2d(p=0.2) - - self.conv41 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1) - self.bn41 = nn.BatchNorm2d(filters[3]) - self.do41 = nn.Dropout2d(p=0.2) - self.conv42 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) - self.bn42 = nn.BatchNorm2d(filters[3]) - self.do42 = nn.Dropout2d(p=0.2) - self.conv43 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1) - self.bn43 = nn.BatchNorm2d(filters[3]) - self.do43 = nn.Dropout2d(p=0.2) - - self.upconv4 = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1, stride=2, output_padding=1) - - self.conv43d = nn.ConvTranspose2d(filters[4], filters[3], kernel_size=3, padding=1) - self.bn43d = nn.BatchNorm2d(filters[3]) - self.do43d = nn.Dropout2d(p=0.2) - self.conv42d = nn.ConvTranspose2d(filters[3], filters[3], kernel_size=3, padding=1) - self.bn42d = nn.BatchNorm2d(filters[3]) - self.do42d = nn.Dropout2d(p=0.2) - self.conv41d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) - self.bn41d = nn.BatchNorm2d(filters[2]) - self.do41d = nn.Dropout2d(p=0.2) - - self.upconv3 = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1, stride=2, output_padding=1) - - self.conv33d = nn.ConvTranspose2d(filters[3], filters[2], kernel_size=3, padding=1) - self.bn33d = nn.BatchNorm2d(filters[2]) - self.do33d = nn.Dropout2d(p=0.2) - self.conv32d = nn.ConvTranspose2d(filters[2], filters[2], kernel_size=3, padding=1) - self.bn32d = nn.BatchNorm2d(filters[2]) - self.do32d = nn.Dropout2d(p=0.2) - self.conv31d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) - self.bn31d = nn.BatchNorm2d(filters[1]) - self.do31d = nn.Dropout2d(p=0.2) - - self.upconv2 = nn.ConvTranspose2d(filters[1], filters[1], kernel_size=3, padding=1, stride=2, output_padding=1) - - self.conv22d = nn.ConvTranspose2d(filters[2], filters[1], kernel_size=3, padding=1) - self.bn22d = nn.BatchNorm2d(filters[1]) - self.do22d = nn.Dropout2d(p=0.2) - self.conv21d = nn.ConvTranspose2d(filters[1], 
filters[0], kernel_size=3, padding=1) - self.bn21d = nn.BatchNorm2d(filters[0]) - self.do21d = nn.Dropout2d(p=0.2) - - self.upconv1 = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1, stride=2, output_padding=1) - - self.conv12d = nn.ConvTranspose2d(filters[1], filters[0], kernel_size=3, padding=1) - self.bn12d = nn.BatchNorm2d(filters[0]) - self.do12d = nn.Dropout2d(p=0.2) - self.conv11d = nn.ConvTranspose2d(filters[0], filters[0], kernel_size=3, padding=1) - - def forward(self, x1, x2): - """Forward method.""" - # Stage 1 - x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) - x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) - x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) - - # Stage 2 - x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) - x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) - x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) - - # Stage 3 - x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) - x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) - x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) - x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) - - # Stage 4 - x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) - x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) - x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) - x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) - - #################################################### - # Stage 1 - x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) - x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) - x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) - - # Stage 2 - x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) - x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) - x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) - - # Stage 3 - x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) - x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) - x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) - x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) - - # Stage 4 - x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) - x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) - x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) - x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) - - # Stage 4d - x4d = self.upconv4(x4p) - pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) - x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) - x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) - x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) - x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) - - # Stage 3d - x3d = self.upconv3(x41d) - pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) - x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) - x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) - x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) - x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) - - # Stage 2d - x2d = self.upconv2(x31d) - pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) - x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) - x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) - x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) - - # Stage 1d - x1d = self.upconv1(x21d) - pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) - x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) - x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) - 
x11d = self.conv11d(x12d) - - return (x11d,) \ No newline at end of file
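
Not part of the diff, but a quick way to sanity-check the three new backbones. A minimal smoke test, assuming `torch` and `mmseg` are installed and `opencd` is importable as laid out above; the expected shapes follow from `fcsn.py` itself, where four 2x max-poolings are mirrored by four stride-2 transposed convs, so a 256x256 input comes back at full resolution with `base_channel` feature maps:

```python
# Hedged smoke test (not part of the diff): exercises the backbones
# registered in opencd/models/backbones/fcsn.py.
import torch

from opencd.models.backbones import FC_EF, FC_Siam_diff, FC_Siam_conc

t1 = torch.randn(2, 3, 256, 256)  # pre-change images
t2 = torch.randn(2, 3, 256, 256)  # post-change images

# FC-EF is early fusion: forward() concatenates (t1, t2) along the channel
# axis before the first conv, hence in_channels=6 in fc_ef.py above.
ef = FC_EF(in_channels=6, base_channel=16).eval()

# The Siamese variants run a shared encoder over each image (in_channels=3)
# and fuse at the skip connections instead: FC_Siam_diff injects |f1 - f2|,
# while FC_Siam_conc concatenates (f1, f2), which is why its decoder convs
# take filters[i] + filters[i + 1] input channels.
diff = FC_Siam_diff(in_channels=3, base_channel=16).eval()
conc = FC_Siam_conc(in_channels=3, base_channel=16).eval()

with torch.no_grad():
    for net in (ef, diff, conc):
        feat, = net(t1, t2)  # each backbone returns a 1-tuple of features
        assert feat.shape == (2, 16, 256, 256)  # base_channel maps, full resolution
```

The returned single feature map is what the `FCNHead` in the configs (with `num_convs=0`) projects down to `num_classes=2`.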
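
The backbones then drop into the existing config system. Below is a hedged sketch using the mmcv/mmseg 0.x-style builders this repo is based on (the exact entry point may differ); `opencd.models` must be imported first so that `DIEncoderDecoder` and the `FC_*` backbones are registered before the config is resolved:

```python
# Hedged sketch: build the full change detector from one of the new configs.
from mmcv import Config
from mmseg.models import build_segmentor

import opencd.models  # noqa: F401  registers DIEncoderDecoder and the FC_* backbones

cfg = Config.fromfile('configs/_base_/models/fc_siam_diff.py')
model = build_segmentor(cfg.model)  # backbone type='FC_Siam_diff' after this diff
```

One caveat worth noting: the configs define `norm_cfg = dict(type='SyncBN', ...)` but never pass it to the backbone or head, and the `FC_*` backbones hard-code `nn.BatchNorm2d`, so the SyncBN setting is currently inert.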