Commit

fix: dynamic threshold was incorrect, midas depth dimensions
samedii committed Aug 18, 2022
1 parent 4d804b9 commit eec8453
Showing 8 changed files with 113 additions and 259 deletions.
5 changes: 4 additions & 1 deletion perceptor/losses/owlvit.py
@@ -24,7 +24,10 @@ def device(self):
         return next(iter(self.model.parameters())).device

     def to(self, device):
-        self.encodings.to(device)
+        if self.encodings is not None:
+            self.encodings.to(device)
+        if self.weights is not None:
+            self.weights.to(device)
         return super().to(device)

     def cuda(self):
3 changes: 2 additions & 1 deletion perceptor/models/midas_depth/midas_depth.py
@@ -125,7 +125,7 @@ def forward(self, images: Tensor.dims("NCHW")) -> Tensor.dims("NCHW"):
             images,
             out_shape=self.image_size,
         )
-        return -self.model(self.normalization(images)).float()
+        return -self.model(self.normalization(images)).float()[:, None]


 def test_midas_depth():
@@ -145,6 +145,7 @@ def test_midas_depth():

     with torch.enable_grad():
         depths = model(images)
+        assert len(depths.shape) == 4
         depths.mean().backward()

     assert images.grad is not None
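For context, the MiDaS head returns one depth map per image as an (N, H, W) tensor, so indexing with [:, None] inserts a channel axis and the negated depths come back as NCHW with a single channel, which is what the new assert in the test checks. A minimal standalone sketch with placeholder shapes (illustrative, not taken from the actual model):

    import torch

    raw_depths = torch.rand(2, 384, 384)    # (N, H, W), as a MiDaS-style head returns
    depths = -raw_depths[:, None]           # insert a channel axis -> (N, 1, H, W)
    assert depths.shape == (2, 1, 384, 384)
    assert len(depths.shape) == 4           # same check as in test_midas_depth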
28 changes: 21 additions & 7 deletions perceptor/models/velocity_diffusion/predictions.py
@@ -184,15 +184,17 @@ def dynamic_threshold(self, quantile=0.95) -> "Predictions":
         Thresholding heuristic from imagen paper
         """
         dynamic_threshold = torch.quantile(
-            self.denoised_xs.flatten(start_dim=1).mul(2).sub(1).abs(), quantile, dim=1
+            self.denoised_xs.flatten(start_dim=1).abs(), quantile, dim=1
         ).clamp(min=1.0)
-        denoised_xs = clamp_with_grad(
-            self.denoised_xs,
-            -dynamic_threshold,
-            dynamic_threshold,
+        denoised_xs = (
+            clamp_with_grad(
+                self.denoised_xs,
+                -dynamic_threshold,
+                dynamic_threshold,
+            )
+            # / dynamic_threshold
+            # imagen's dynamic thresholding divides by threshold but this makes the images gray
         )
-        # note: imagen dynamic thresholding is dividing by dynamic_threshold
-        # but this makes everything gray
         return self.forced_denoised(diffusion_space.decode(denoised_xs))

     def static_threshold(self):
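The corrected heuristic follows Imagen-style dynamic thresholding: take a per-sample quantile of the absolute values of the denoised prediction, clamp it to at least 1.0, and clip the prediction to that range; the paper's extra division by the threshold stays commented out because it made the images gray. A self-contained sketch of the same idea, using plain torch.clamp instead of the repository's gradient-preserving clamp:

    import torch

    def dynamic_threshold(denoised_xs: torch.Tensor, quantile: float = 0.95) -> torch.Tensor:
        # per-sample quantile of |x0|, never below 1.0
        threshold = torch.quantile(
            denoised_xs.flatten(start_dim=1).abs(), quantile, dim=1
        ).clamp(min=1.0)
        # reshape so the per-sample threshold broadcasts over C, H, W
        threshold = threshold.view(-1, 1, 1, 1)
        return denoised_xs.clamp(-threshold, threshold)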
@@ -211,6 +213,18 @@ def forced_denoised(self, denoised_images) -> "Predictions":
             - self.from_sigmas * denoised_xs
         )

+    def forced_predicted_noise(self, predicted_noise) -> "Predictions":
+        if (self.from_alphas >= 1e-3).all():
+            denoised_xs = (
+                self.from_diffused_xs - predicted_noise * self.from_sigmas
+            ) / self.from_alphas
+        else:
+            denoised_xs = self.denoised_xs
+        return self.replace(
+            velocities=self.from_alphas * predicted_noise
+            - self.from_sigmas * denoised_xs
+        )
+
     def wasserstein_distance(self):
         sorted_noise = self.predicted_noise.flatten(start_dim=1).sort(dim=1)[0]
         n = sorted_noise.shape[1]
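The new forced_predicted_noise mirrors forced_denoised but starts from a noise estimate: with the usual v-diffusion convention x_t = alpha * x0 + sigma * eps, a given eps implies x0 = (x_t - sigma * eps) / alpha, and the velocity is rebuilt as v = alpha * eps - sigma * x0; the alpha >= 1e-3 guard avoids dividing by a near-zero alpha. A small sanity check of that algebra with placeholder values (not the module's actual schedule):

    import torch

    alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)  # placeholder schedule values, alpha**2 + sigma**2 == 1
    x0 = torch.randn(2, 3, 8, 8)                          # denoised image
    eps = torch.randn_like(x0)                            # noise
    x_t = alpha * x0 + sigma * eps                        # diffused sample

    recovered_x0 = (x_t - eps * sigma) / alpha            # what forced_predicted_noise computes
    velocity = alpha * eps - sigma * recovered_x0         # velocity parameterization

    assert torch.allclose(recovered_x0, x0, atol=1e-5)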
4 changes: 3 additions & 1 deletion perceptor/models/velocity_diffusion/velocity_diffusion.py
@@ -45,7 +45,9 @@ def shape(self):
         return self.model.shape

     @staticmethod
-    def schedule_ts(n_steps=500, from_sigma=1, to_sigma=1e-2, rho=0.7):
+    def schedule_ts(n_steps=500, from_ts=1, to_ts=1e-2, rho=0.7):
+        from_sigma = Model.sigmas(from_ts).squeeze().item()
+        to_sigma = Model.sigmas(to_ts).squeeze().item()
         ramp = torch.linspace(0, 1, n_steps + 1)
         min_inv_rho = to_sigma ** (1 / rho)
         max_inv_rho = from_sigma ** (1 / rho)
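schedule_ts now takes its endpoints in t instead of sigma and converts them through Model.sigmas before building the schedule. The visible lines set up a Karras-style interpolation in sigma ** (1 / rho); the rest of the function is not shown in this hunk, but the usual form of that ramp looks like the following sketch (a generic illustration, not the repository's exact code):

    import torch

    def karras_sigma_schedule(n_steps=500, from_sigma=1.0, to_sigma=1e-2, rho=0.7):
        # interpolate linearly in sigma ** (1 / rho), then raise back to the rho power
        ramp = torch.linspace(0, 1, n_steps + 1)
        min_inv_rho = to_sigma ** (1 / rho)
        max_inv_rho = from_sigma ** (1 / rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho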
5 changes: 3 additions & 2 deletions perceptor/transforms/clamp_with_grad.py
@@ -23,7 +23,8 @@ def backward(ctx, grad_in):
         )


-clamp_with_grad = ClampWithGradFunction.apply
+def clamp_with_grad(tensor, min=0.0, max=1.0):
+    return ClampWithGradFunction.apply(tensor, min, max)


 class ClampWithGrad(TransformInterface):
@@ -33,7 +34,7 @@ def __init__(self, min=0, max=1):
         self.max = max

     def encode(self, tensor):
-        return clamp_with_grad(tensor)
+        return clamp_with_grad(tensor, self.min, self.max)

     def decode(self, tensor):
         return tensor
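clamp_with_grad is now a small wrapper so the bounds can be passed explicitly instead of going through ClampWithGradFunction.apply directly, and ClampWithGrad.encode forwards its own min/max. A rough usage sketch, assuming the import path follows the file location:

    import torch
    from perceptor.transforms.clamp_with_grad import clamp_with_grad, ClampWithGrad

    images = (torch.rand(1, 3, 64, 64) * 2 - 0.5).requires_grad_()  # values outside [0, 1]

    clamped = clamp_with_grad(images, 0.0, 1.0)  # functional form with explicit bounds
    transform = ClampWithGrad(min=-1, max=1)     # transform form keeps its own bounds
    encoded = transform.encode(images)

    clamped.mean().backward()                    # gradients still pass through the clamped region
    assert images.grad is not None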
4 changes: 2 additions & 2 deletions perceptor/utils/gradient_checkpoint.py
@@ -8,10 +8,10 @@ class GradientCheckpoint(FunctionalBase):
     def __init__(self, tensor):
         super().__init__(original=tensor, detached=tensor.detach().requires_grad_())

-    def continue_backward(self):
+    def continue_backward(self, retain_graph=False):
         if self.grad is None:
             raise ValueError("Gradient is not defined")
-        return self.original.backward(self.detached.grad)
+        return self.original.backward(self.detached.grad, retain_graph=retain_graph)

     @property
     def grad(self):
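continue_backward now forwards retain_graph, which is needed when the same checkpoint has to be backpropagated through the upstream graph more than once. A rough usage sketch, assuming the import path follows the file location and that the detached tensor is exposed as checkpoint.detached (gradients on it accumulate across passes, as usual for .grad):

    import torch
    from perceptor.utils.gradient_checkpoint import GradientCheckpoint

    x = torch.randn(4, requires_grad=True)
    hidden = x * 3                                   # upstream part of the graph
    checkpoint = GradientCheckpoint(hidden)          # stores the original plus a detached, grad-enabled copy

    (checkpoint.detached ** 2).sum().backward()      # first loss: gradient lands on the detached copy
    checkpoint.continue_backward(retain_graph=True)  # push it upstream, keep the graph alive

    checkpoint.detached.abs().sum().backward()       # second loss on the same checkpoint
    checkpoint.continue_backward()                   # works because the graph was retained
    assert x.grad is not None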