Add ChangeDetectionTask #2422
base: main
Conversation
I wonder if we should limit the scope to binary change between two timesteps. Then we can use binary metrics and provide a template for the plot methods. I say this because this is by far the most common change detection task. It might also simplify the augmentations approach? Treating it as a video sequence seems like overkill.
I'm personally okay with this, although @hfangcat has a recent work using multiple pre-event images that would be nice to support someday (could be a subclass if necessary).
Again, this would probably be fine as a starting point, although I would someday like to make all trainers support binary/multiclass/multilabel, e.g., #2219.
Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers: #2184
Agreed.
I actually like the video augmentations, but let me loop in the Kornia folks to get their opinion: @edgarriba @johnnv1
Correct, see #2382 for the big picture (I think I also sent you a recording of my presented plan).
Can you try
@ashnair1 would this work directly with
I will go ahead and make changes for this to be for binary change and two timesteps, sounds like a good starting point.
I tried this and it didn't get rid of the other dimension. I also looked into
I was going to add plotting in the trainer, but would you rather not then? What would this look like in the dataset?
Perhaps there should even be a base class `ChangeDetection` and subclasses like `BinaryChangeDetection`?
That's exactly what I'm trying to undo in #2219.
We can copy-n-paste the
See
I've updated this to now support only binary change with two timesteps. I still haven't been able to figure out how to make
Can you resolve the merge conflicts so we can run the tests?
```python
load_state_dict_from_url: None,
) -> WeightsEnum:
    path = tmp_path / f'{weights}.pth'
    # multiply in_chans by 2 since images are concatenated
```
How hard would it be to do late fusion, so pass each image through the encoder separately, then concatenate them, then pass them through the decoder? This would make it easier to use pre-trained models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's definitely possible, although I think we would need a custom Unet implementation in torchgeo/models to do this. It would simplify using the pretrained weights but is late fusion a common enough approach that many people would find this useful?
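To illustrate the trade-off being discussed, here is a minimal, hypothetical sketch of late fusion in plain PyTorch (the module and attribute names are placeholders, not torchgeo code): each timestep passes through a shared encoder, and only the decoder sees the concatenated features, so the encoder can keep its pretrained `in_chans`.

```python
import torch
import torch.nn as nn

class LateFusionChangeModel(nn.Module):
    """Hypothetical late-fusion sketch: shared (siamese) encoder, fused decoder."""

    def __init__(self, in_channels: int = 3, feat_channels: int = 16) -> None:
        super().__init__()
        # Shared encoder: pretrained weights could be loaded here unchanged,
        # since it still sees a single 3-channel image at a time.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, feat_channels, 3, padding=1), nn.ReLU()
        )
        # Decoder consumes the concatenated features from both timesteps.
        self.decoder = nn.Conv2d(2 * feat_channels, 1, 1)

    def forward(self, img_t0: torch.Tensor, img_t1: torch.Tensor) -> torch.Tensor:
        f0 = self.encoder(img_t0)
        f1 = self.encoder(img_t1)
        return self.decoder(torch.cat([f0, f1], dim=1))

model = LateFusionChangeModel()
out = model(torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 1, 32, 32])
```

A real version would need the decoder half of a Unet-style model, which is why the comment above notes a custom implementation in `torchgeo/models` would likely be required.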
```python
    monkeypatch.setattr(weights, 'url', str(path))
    return weights


@pytest.mark.parametrize('model', [6], indirect=True)
```
Remind me what `[6]` means here?
The number of input channels (two 3-channel images stacked).
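The early-fusion layout described above can be sketched in plain torch: the two RGB images are concatenated along the channel dimension, which is why `in_chans` becomes 6.

```python
import torch

# Two RGB images from different timesteps, shape (B, C, H, W)
img_t0 = torch.rand(4, 3, 64, 64)
img_t1 = torch.rand(4, 3, 64, 64)

# Early fusion: stack along the channel dimension -> in_chans = 6
x = torch.cat([img_t0, img_t1], dim=1)
print(x.shape)  # torch.Size([4, 6, 64, 64])
```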
torchgeo/datamodules/oscd.py
```python
        K.Normalize(mean=self.mean, std=self.std),
        _RandomNCrop(self.patch_size, batch_size),
    ),
    data_keys=['image', 'mask'],
```
Shouldn't this be `data_keys=None` so that Kornia automatically detects the key name from the dict keys?
Yes, I updated this when resolving the merge conflict.
```diff
@@ -240,7 +242,7 @@ def _load_target(self, path: Path) -> Tensor:
         array: np.typing.NDArray[np.int_] = np.array(img.convert('L'))
         tensor = torch.from_numpy(array)
         tensor = torch.clamp(tensor, min=0, max=1)
-        tensor = tensor.to(torch.long)
+        tensor = tensor.to(torch.float)
```
Why would the target be a float?
The loss function `BCEWithLogitsLoss` expects the target to be a float.
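A small torch sketch of the dtype requirement mentioned above: `BCEWithLogitsLoss` rejects an integer target, which is why the dataset casts the mask to float.

```python
import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()
logits = torch.randn(2, 1, 8, 8)
target = torch.randint(0, 2, (2, 1, 8, 8))  # integer labels as loaded from disk

# Passing the long-dtype target raises a RuntimeError:
try:
    loss_fn(logits, target)
    raise AssertionError('expected a dtype error')
except RuntimeError:
    pass

# Casting to float works:
loss = loss_fn(logits, target.float())
```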
torchgeo/losses/__init__.py
```diff
-__all__ = ('QRLoss', 'RQLoss')
+__all__ = ('QRLoss', 'RQLoss', 'BinaryFocalJaccardLoss', 'BinaryXEntJaccardLoss')
```
Never seen these losses before, are they standard in the literature? If not, I don't think we should add them.
I don't know if they are standard. I haven't seen them before, but they were in the code from the previous PR that I started from. I can take them out.
```python
        for param in self.model.decoder.parameters():
            param.requires_grad = False

    def _shared_step(self, batch: Any, batch_idx: int, stage: str) -> Tensor:
```
Needs a docstring
I will add this.
```python
@@ -102,7 +102,13 @@ def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
        batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy')

        # Torchmetrics does not support masks with a channel dimension
        if 'mask' in batch and batch['mask'].shape[1] == 1:
            # Kornia adds a temporal dimension to mask when passed through VideoSequential.
```
We should no longer be using our `AugmentationSequential` wrapper; let's use the Kornia version upstream.
Got it, I've now switched it to `K.AugmentationSequential`.
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
I'm going to need some help figuring out how to get this working. Also, disregard my earlier comments about Kornia.
This PR is to add a change detection trainer as mentioned in #2382.
Key points/items to discuss:
- `AugmentationSequential` doesn't work on its own, but can be combined with `VideoSequential` to support the temporal dimension (see the Kornia docs). I overrode `self.aug` in the `OSCDDataModule` to do this, but I'm not sure if this should be incorporated into the `BaseDataModule` instead.
- `VideoSequential` adds a temporal dimension to the mask. Not sure if there is a way to avoid this, or if this is desirable, but I added an if statement to the `AugmentationSequential` wrapper to check for and remove this added dimension.
- The `_RandomNCrop` augmentation does not work for time series data. I'm not sure how to modify `_RandomNCrop` to fix this and would appreciate some help/guidance.

cc @robmarkcole
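The mask fix-up described in the second point can be sketched in plain torch (this is a hypothetical stand-in for the wrapper's if statement, not the actual torchgeo code):

```python
import torch

# VideoSequential-style pipelines can return a mask shaped (B, T, C, H, W);
# downstream metrics typically expect (B, H, W).
mask = torch.randint(0, 2, (4, 1, 1, 32, 32))

if mask.ndim == 5 and mask.shape[1] == 1:
    mask = mask.squeeze(1)  # drop the temporal dimension -> (B, C, H, W)
if mask.ndim == 4 and mask.shape[1] == 1:
    mask = mask.squeeze(1)  # drop the channel dimension -> (B, H, W)

print(mask.shape)  # torch.Size([4, 32, 32])
```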