Skip to content

Commit

Permalink
Did multiple changes in this commit
Browse files Browse the repository at this point in the history
* Restructed Christoffel to use dictonary
* Wrote test for Christoffel
* created setup.py
* Restructed all imports to refect that we have a library now
* moved tests in seperate folder
  • Loading branch information
ThomasHelfer committed Dec 22, 2023
1 parent a3411d0 commit a9739de
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 31 deletions.
16 changes: 16 additions & 0 deletions GeneralRelativity/DimensionDefinitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
def FOR1():
return range(3)


def FOR2():
return ((i, j) for i in FOR1() for j in FOR1())


def FOR3():
return ((i, j, k) for i in FOR1() for j in FOR1() for k in FOR1())


def FOR4():
return (
(i, j, k, l) for i in FOR1() for j in FOR1() for k in FOR1() for l in FOR1()
)
61 changes: 36 additions & 25 deletions GeneralRelativity/TensorAlgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from tqdm.auto import tqdm, trange
import glob
import torch
from GeneralRelativity.DimensionDefinitions import FOR1, FOR2, FOR3, FOR4


def compute_christoffel(d1_metric: torch.tensor, h_UU: torch.tensor) -> torch.tensor:
Expand All @@ -17,34 +18,25 @@ def compute_christoffel(d1_metric: torch.tensor, h_UU: torch.tensor) -> torch.te
Tuple of np.ndarray: Two arrays representing the Christoffel symbols LLL and ULL, each of shape [batch, x, y, z, i, j, k].
"""

# Initialize the output arrays
shape = d1_metric.shape[:-1] + (
d1_metric.shape[-2],
) # shape is [batch, x, y, z, i, j, k]
LLL = torch.zeros(shape, dtype=d1_metric.dtype)
ULL = torch.zeros(shape, dtype=d1_metric.dtype)
chris = {
"LLL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
"ULL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
}

# Compute Christoffel symbols of the first kind (LLL)
# out.LLL[i][j][k] = 0.5 * (d1_metric[j][i][k] + d1_metric[k][i][j] -
# d1_metric[j][k][i]);
for i in range(3):
for j in range(3):
for k in range(3):
LLL[..., i, j, k] = 0.5 * (
+d1_metric[..., j, i, k]
+ d1_metric[..., k, i, j]
- d1_metric[..., j, k, i]
)
for i, j, k in FOR3():
chris["LLL"][..., i, j, k] = 0.5 * (
+d1_metric[..., j, i, k] + d1_metric[..., k, i, j] - d1_metric[..., j, k, i]
)

# Compute Christoffel symbols of the second kind
# FOR1(l) { out.ULL[i][j][k] += h_UU[i][l] * out.LLL[l][j][k]; }
for i in range(3):
for j in range(3):
for k in range(3):
for l in range(3):
ULL[..., i, j, k] += h_UU[..., i, l] * LLL[..., l, j, k]
for i, j, k, l in FOR4():
chris["ULL"][..., i, j, k] += h_UU[..., i, l] * chris["LLL"][..., l, j, k]

return LLL, ULL
return chris


def compute_christoffel_fast(
Expand All @@ -65,14 +57,16 @@ def compute_christoffel_fast(

# Initialize the output tensors
# batch, x, y, z, i, j, dx = d1_metric.shape
LLL = torch.zeros_like(d1_metric)
ULL = torch.zeros_like(LLL)
chris = {
"LLL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
"ULL": torch.zeros(d1_metric.shape, dtype=d1_metric.dtype),
}

# Compute Christoffel symbols of the first kind (LLL)
# Adjusting indices and dimensions for proper computation
# out.LLL[i][j][k] = 0.5 * (d1_metric[j][i][k] + d1_metric[k][i][j] - d1_metric[j][k][i]);
test = (d1_metric).clone()
LLL = 0.5 * (
chris["LLL"] = 0.5 * (
test.permute(0, 1, 2, 3, 4, 5, 6)
+ test.permute(0, 1, 2, 3, 4, 6, 5)
- d1_metric.permute(0, 1, 2, 3, 6, 5, 4)
Expand All @@ -83,6 +77,23 @@ def compute_christoffel_fast(
# Note: 'ijklmn->ijklm' aligns the dimensions correctly
# Compute Christoffel symbols of the second kind
# FOR1(l) { out.ULL[i][j][k] += h_UU[i][l] * out.LLL[l][j][k]; }
ULL = torch.einsum("bxzyil,bxzyijk->bxzyijk", h_UU, LLL)
chris["ULL"] = torch.einsum("bxzyil,bxzyijk->bxzyijk", h_UU, chris["LLL"])

return LLL, ULL
return chris


def compute_trace(tensor_LL, inverse_metric):
"""
Computes the trace of a 2-Tensor with lower indices given an inverse metric.
Args:
tensor_LL (torch.Tensor): The 2-Tensor with lower indices.
inverse_metric (torch.Tensor): The inverse metric tensor.
Returns:
float: The trace of the tensor.
"""
trace = 0.0
for i, j in FOR2():
trace += inverse_metric[..., i, j] * tensor_LL[..., i, j]
return trace
30 changes: 30 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from setuptools import setup
from setuptools import find_packages


def readme():
with open("README.md") as f:
return f.read()


setup(
name="GeneralRelativity",
version="0.1",
description="HBA with GPR and PCA",
long_description=readme(),
classifiers=[
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.8",
"Topic :: Machine learning :: Physics :: Simulation :: General Relativity",
],
keywords="Machine learning, Physics, Simulation, General Relativity",
author="ThomasHelfer",
author_email="thomashelfer@live.de",
license="MIT",
packages=find_packages(exclude=["tests"]),
install_requires=["torch", "black", "pre-commit", "pytest", "numpy"],
python_requires=">=3.5 ",
include_package_data=True,
zip_safe=False,
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import torch
from FourthOrderDerivatives import diff1, diff2
from Utils import get_box_format, TensorDict, cut_ghosts, keys, keys_all
from GeneralRelativity.FourthOrderDerivatives import diff1, diff2
from GeneralRelativity.Utils import (
get_box_format,
TensorDict,
cut_ghosts,
keys,
keys_all,
)
import os
import sys

Expand All @@ -15,8 +21,7 @@ def test_compare_diff_with_reference():
"""
# Define the path to the test data files for variable X
filenamesX = os.path.dirname(__file__) + "/TestData/Xdata_level0_step*"
print(filenamesX)
print(filenamesX)

# Number of variables in the data
num_varsX = 100

Expand Down
73 changes: 73 additions & 0 deletions tests/test_TensorAlgebra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
from GeneralRelativity.FourthOrderDerivatives import diff1, diff2
from GeneralRelativity.Utils import (
get_box_format,
TensorDict,
cut_ghosts,
keys,
keys_all,
)
from GeneralRelativity.TensorAlgebra import (
compute_christoffel,
compute_christoffel_fast,
)
import os
import sys


def test_chris():
"""
Test function to validate the computation of Christoffel symbols.
This function reads tensor data from files, computes Christoffel symbols using two different
implementations (compute_christoffel and compute_christoffel_fast), and then compares the results
to ensure they are consistent with each other. It also checks the symmetry property of the Christoffel
symbols. Assertions are used to ensure that the differences are within a specified tolerance.
"""
# Define the path to the test data files for variable X
filenamesX = os.path.dirname(__file__) + "/TestData/Xdata_level0_step*"

# Number of variables in the data
num_varsX = 100

# Read the data in a box format
dataX = get_box_format(filenamesX, num_varsX)

# Tolerance for comparison
tol = 1e-10

# Compute the differential value
oneoverdx = 64.0 / 4.0

# Prepare the data and compute derivatives using TensorDict
vars = TensorDict(cut_ghosts(dataX), keys_all)
d1 = TensorDict(diff1(dataX, oneoverdx), keys_all)
h_UU = torch.inverse(vars["h"])
chris = compute_christoffel(d1["h"], h_UU)
chris_2nd_implementation = compute_christoffel_fast(d1["h"], h_UU)

# Compare two versions of Christoffel symbols
assert torch.mean(torch.abs(chris["LLL"] - chris_2nd_implementation["LLL"])) < tol
assert torch.mean(torch.abs(chris["ULL"] - chris_2nd_implementation["ULL"])) < tol

# Check symmetry of Christoffel symbols
for i in range(3):
for j in range(i, 3):
assert (
torch.mean(
torch.abs(chris["ULL"][..., i, j])
- torch.abs(chris["ULL"][..., j, i])
)
== 0
)
assert (
torch.mean(
torch.abs(chris_2nd_implementation["ULL"][..., i, j])
- torch.abs(chris_2nd_implementation["ULL"][..., j, i])
)
== 0
)


if __name__ == "__main__":
test_chris()
10 changes: 8 additions & 2 deletions GeneralRelativity/test_Utils.py → tests/test_Utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import torch
from FourthOrderDerivatives import diff1, diff2
from Utils import get_box_format, TensorDict, cut_ghosts, keys, keys_all
from GeneralRelativity.FourthOrderDerivatives import diff1, diff2
from GeneralRelativity.Utils import (
get_box_format,
TensorDict,
cut_ghosts,
keys,
keys_all,
)
import os
import sys

Expand Down

0 comments on commit a9739de

Please sign in to comment.