Skip to content

Commit

Permalink
Made sure that dtypes of all tensors are consistent with input tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Dec 28, 2023
1 parent 971b3b4 commit 5360440
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 22 deletions.
12 changes: 7 additions & 5 deletions GeneralRelativity/CCZ4Geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def compute_ricci_Z(
h_UU: torch.Tensor,
chris: Dict[str, torch.Tensor],
Z_over_chi: torch.Tensor,
GR_SPACEDIM: int = 4,
GR_SPACEDIM: int = 3,
) -> Dict[str, torch.Tensor]:
"""
Compute the Ricci tensor Z using the provided variables, derivatives, and Christoffel symbols.
Expand All @@ -27,11 +27,13 @@ def compute_ricci_Z(
Returns:
dict: Dictionary containing the components of the Ricci tensor Z.
"""
out = {"LL": torch.zeros_like(vars["h"]), "scalar": 0}
dtype = vars["chi"].dtype

boxtildechi = 0
out = {"LL": torch.zeros_like(vars["h"], dtype=dtype), "scalar": 0.0}

covdtilde2chi = torch.zeros_like(vars["h"])
boxtildechi = 0.0

covdtilde2chi = torch.zeros_like(vars["h"], dtype=dtype)
for k, l in FOR2():
# covdtilde2chi[k][l] = d2.chi[k][l];
covdtilde2chi[..., k, l] = d2["chi"][..., k, l]
Expand Down Expand Up @@ -135,7 +137,7 @@ def compute_ricci(
d2: Dict[str, torch.Tensor],
h_UU: torch.Tensor,
chris: Dict[str, torch.Tensor],
GR_SPACEDIM: int = 4,
GR_SPACEDIM: int = 3,
) -> Dict[str, torch.Tensor]:
"""
Compute the Ricci tensor using the provided variables, derivatives, and Christoffel symbols.
Expand Down
15 changes: 8 additions & 7 deletions GeneralRelativity/Constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def constraint_equations(
d2: Dict[str, torch.Tensor],
h_UU: torch.Tensor,
chris: Dict[str, torch.Tensor],
GR_SPACEDIM: int = 4,
GR_SPACEDIM: int = 3,
cosmological_constant: float = 0,
) -> Dict[str, torch.Tensor]:
"""
Expand All @@ -33,12 +33,13 @@ def constraint_equations(
Dict[str, torch.Tensor]: A dictionary containing the computed constraint equations.
Keys include 'Ham', 'Ham_abs_terms', 'Mom', 'Mom_abs_terms'.
"""
dtype = vars["chi"].dtype

out = {
"Ham": torch.zeros_like(vars["chi"]),
"Ham_abs_terms": torch.zeros_like(vars["chi"]),
"Mom": torch.zeros_like(vars["shift"]),
"Mom_abs_terms": torch.zeros_like(vars["shift"]),
"Ham": torch.zeros_like(vars["chi"], dtype=dtype),
"Ham_abs_terms": torch.zeros_like(vars["chi"], dtype=dtype),
"Mom": torch.zeros_like(vars["shift"], dtype=dtype),
"Mom_abs_terms": torch.zeros_like(vars["shift"], dtype=dtype),
}

# auto ricci = CCZ4Geometry::compute_ricci(vars, d1, d2, h_UU, chris);
Expand Down Expand Up @@ -91,8 +92,8 @@ def constraint_equations(
# Tensor<1, data_t> covd_A_term = 0.0;
# Tensor<1, data_t> d1_chi_term = 0.0;
# const data_t chi_regularised = simd_max(1e-6, vars.chi);
covd_A_term = torch.zeros_like(d1["chi"])
d1_chi_term = torch.zeros_like(d1["chi"])
covd_A_term = torch.zeros_like(d1["chi"], dtype=dtype)
d1_chi_term = torch.zeros_like(d1["chi"], dtype=dtype)
chi_regularised = torch.maximum(torch.tensor(1e-6), vars["chi"])
for i, j, k in FOR3():
# covd_A_term[i] += h_UU[j][k] * covd_A[k][j][i];
Expand Down
17 changes: 9 additions & 8 deletions GeneralRelativity/FourthOrderDerivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ def diff1(tensor: torch.Tensor, one_over_dx: float) -> torch.Tensor:
:param one_over_dx: Inverse of the grid spacing.
:return: Tensor of shape (batchsize, x-4, y-4, z-4, num_variables, 3) containing derivatives.
"""
weight_far = torch.tensor(8.333333333333333e-2)
weight_near = torch.tensor(6.666666666666667e-1)

weight_far = torch.tensor(8.333333333333333e-2, dtype=tensor.dtype)
weight_near = torch.tensor(6.666666666666667e-1, dtype=tensor.dtype)

derivatives = []

Expand Down Expand Up @@ -55,9 +56,9 @@ def mixed_diff2_tensor(
:return: Tensor containing mixed second derivatives.
"""

weight_far_far = 6.94444444444444444444e-3
weight_near_far = 5.55555555555555555556e-2
weight_near_near = 4.44444444444444444444e-1
weight_far_far = torch.tensor(6.94444444444444444444e-3, dtype=tensor.dtype)
weight_near_far = torch.tensor(5.55555555555555555556e-2, dtype=tensor.dtype)
weight_near_near = torch.tensor(4.44444444444444444444e-1, dtype=tensor.dtype)

# Adjust indices for the spatial dimensions (add 1 because first dimension is batch)
dim1 = i + 1
Expand Down Expand Up @@ -151,9 +152,9 @@ def diff2_multidim(tensor: torch.tensor, i: int, one_over_dx2: float) -> torch.t
:param one_over_dx2: Inverse of the square of the grid spacing.
:return: Tensor of the same shape as the input tensor containing the second derivative along the specified dimension.
"""
weight_far = 8.33333333333333333333e-2
weight_near = 1.33333333333333333333e0
weight_local = 2.50000000000000000000e0
weight_far = torch.tensor(8.33333333333333333333e-2, dtype=tensor.dtype)
weight_near = torch.tensor(1.33333333333333333333e0, dtype=tensor.dtype)
weight_local = torch.tensor(2.50000000000000000000e0, dtype=tensor.dtype)

# Determine the spatial dimension to calculate the derivative
dim = i + 1 # Adjusting for batch dimension
Expand Down
4 changes: 2 additions & 2 deletions GeneralRelativity/TensorAlgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def raise_all_vector(
Returns:
torch.Tensor: The resulting tensor with the index raised.
"""
tensor_U = torch.zeros_like(tensor_L)
tensor_U = torch.zeros_like(tensor_L, dtype=tensor_L.dtype)
for i, j in FOR2():
tensor_U[..., i] += inverse_metric[..., i, j] * tensor_L[..., j]
return tensor_U
Expand All @@ -134,7 +134,7 @@ def raise_all_metric(
Returns:
torch.Tensor: The resulting tensor with indices raised (2-Tensor).
"""
tensor_UU = torch.zeros_like(tensor_LL)
tensor_UU = torch.zeros_like(tensor_LL, dtype=tensor_LL.dtype)
for i, j, k, l in FOR4():
tensor_UU[..., i, j] += (
inverse_metric[..., i, k] * inverse_metric[..., j, l] * tensor_LL[..., k, l]
Expand Down
4 changes: 4 additions & 0 deletions GeneralRelativity/Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def cut_ghosts(tensor: torch.Tensor) -> torch.Tensor:
"dz_B1",
"dz_B2",
"dz_B3",
"Ham",
"Mom1",
"Mom2",
"Mom3",
]

keys = [
Expand Down

0 comments on commit 5360440

Please sign in to comment.