Skip to content

Commit

Permalink
daily updates
Browse files Browse the repository at this point in the history
  • Loading branch information
OmarFaig committed Jan 16, 2024
1 parent c5a453a commit 7789ede
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 16 deletions.
Binary file added 1912.12033.pdf
Binary file not shown.
12 changes: 6 additions & 6 deletions Codes/PointNet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def _sample(self,xyz,npoints):
sampled_xyz : sampled points [B,C(3),npoints]
'''
print("xyz.shape ",xyz.shape)
print("npoints: - ",npoints)
print("sample xyz.shape ",xyz.shape)
print("sample npoints: - ",npoints)

if npoints is None:
return None
Expand Down Expand Up @@ -149,14 +149,14 @@ def _group(self,xyz,points,sampled_xyz):

if points is not None:
grouped_points = index_points(points.transpose(-2,-1),group_idx).movedim(-1,-3)
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-3)
grouped_points = torch.cat([grouped_points, grouped_xyz],dim=1)
else:
grouped_points = grouped_xyz
else:#group all
grouped_xyz = xyz.unsqueeze(-2)
if points is not None:
grouped_points =points.unsqueeze(-2)
grouped_points = torch.cat([grouped_points,grouped_xyz],dim=-3)
grouped_points =points.unsqueeze(2)
grouped_points = torch.cat([grouped_points,grouped_xyz],dim=1)
else:
grouped_points = grouped_xyz
return grouped_points
Expand All @@ -174,7 +174,7 @@ def forward(self, xyz, points=None, sampled_xyz=None, npoints=None):
new_xyz: Tensor, (B, 3, npoint)
new_points: Tensor, (B, mlp[-1], npoint)
"""
print(f'npoints: {npoints}, self.nsample: {self.nsample}')
print(f'forward - npoints: {npoints}, self.nsample: {self.nsample}')

assert (npoints is None) or (self.npoints is None)
if npoints is None:
Expand Down
Binary file modified Codes/__pycache__/PointNet2.cpython-310.pyc
Binary file not shown.
16 changes: 9 additions & 7 deletions Codes/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from torch import einsum
import warnings
import logging
from pytorch3d.ops import knn_points
from pytorch3d.ops import knn_points,sample_farthest_points

from Codes.utils.seedformer_utils import fps_subsample
from PointNet2 import MLP_Res,MLP_CONV, PointNet_SA_Layer,vTransformer,grouping_operation
logger = logging.getLogger(__name__)

Expand All @@ -13,12 +15,13 @@ class FeatureExtractor(nn.Module):
def __init__(self, out_dim=1024, n_knn=20):
"""Encoder that encodes information of partial point cloud
"""
#in_channel values are not correct needs to be investigated
super(FeatureExtractor, self).__init__()
self.sa_module_1 = PointNet_SA_Layer(npoints=512,nsample=16,in_channel=3,mlp_channels=[64,128] )
self.sa_module_1 = PointNet_SA_Layer(npoints=512,nsample=16,in_channel=6,mlp_channels=[64,128] )
self.transformer_1 = vTransformer(128, dim=64, n_knn=n_knn)
self.sa_module_2 = PointNet_SA_Layer(npoints=128,nsample=16,in_channel=128,mlp_channels=[128,256])
self.sa_module_2 = PointNet_SA_Layer(npoints=128,nsample=16,in_channel=131,mlp_channels=[128,256])
self.transformer_2 = vTransformer(256, dim=64, n_knn=n_knn)
self.sa_module_3 = PointNet_SA_Layer(npoints=None,nsample=None,in_channel=256,mlp_channels=[512,out_dim])
self.sa_module_3 = PointNet_SA_Layer(npoints=None,nsample=None,in_channel=259,mlp_channels=[512,out_dim])

def forward(self, partial_cloud):
"""
Expand All @@ -39,8 +42,6 @@ def forward(self, partial_cloud):

return l3_points, l2_xyz, l2_points



class SeedGenerator(nn.Module):
def __init__(self, feat_dim=512, seed_dim=128, n_knn=20, factor=2, attn_channel=True):
super(SeedGenerator, self).__init__()
Expand Down Expand Up @@ -307,7 +308,8 @@ def forward_decoder(self, feat, partial_cloud, patch_xyz, patch_feat):
pred_pcds.append(seed)

# Upsample layers
pcd = fps_subsample(torch.cat([seed, partial_cloud], 1), self.num_p0) # (B, num_p0, 3)
#pcd = fps_subsample(torch.cat([seed, partial_cloud], 1), self.num_p0) # (B, num_p0, 3)
pcd,_=sample_farthest_points(torch.cat([seed, partial_cloud], 1), K=self.num_p0,random_start_point=True)
K_prev = None
pcd = pcd.permute(0, 2, 1).contiguous() # (B, 3, num_p0)
seed = seed.permute(0, 2, 1).contiguous() # (B, 3, 256)
Expand Down
Binary file modified Codes/utils/__pycache__/seedformer_utils.cpython-310.pyc
Binary file not shown.
2 changes: 2 additions & 0 deletions Codes/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import torch
from tqdm import tqdm
7 changes: 4 additions & 3 deletions Codes/utils/seedformer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def forward(self, xyz, points, idx=None):
return new_xyz, new_points


def fps_subsample(pcd, n_points=2048):#Not so correct !!!
def fps_subsample(pcd, n_points=512):
"""
Args
pcd: (b, 16384, 3)
Expand All @@ -689,13 +689,14 @@ def fps_subsample(pcd, n_points=2048):#Not so correct !!!
raise ValueError(
'FPS subsampling receives a larger n_points: {:d} > {:d}'.format(
n_points, pcd.shape[1]))
#indices = sample_farthest_points(pcd, K=n_points)

new_pcd = gather_operation(
pcd.permute(0, 2, 1).contiguous(),
sample_farthest_points(pcd, n_points))
sample_farthest_points(pcd, K=n_points))
new_pcd = new_pcd.permute(0, 2, 1).contiguous()
return new_pcd


def get_nearest_index(target, source, k=1, return_dis=False):
"""
Args:
Expand Down

0 comments on commit 7789ede

Please sign in to comment.