Skip to content

Commit

Permalink
Merge pull request #27 from UoMResearchIT/4-write-tests-for-single_em…
Browse files Browse the repository at this point in the history
…itter

Add a test to check the load methods give the same emitter results
  • Loading branch information
andrewgait authored Feb 4, 2025
2 parents c769b89 + 3b6eceb commit a2b9e14
Showing 1 changed file with 124 additions and 35 deletions.
159 changes: 124 additions & 35 deletions src/emitters/tests/test_photon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
import pandas as pd
import drjit as dr
import mitsuba as mi

Expand All @@ -20,8 +21,18 @@
]

# Set up some basic photon data
emitter_data = [2, -8.6815998e-02, 1.0280000e+03, 1.0275000e+02, -1.3769200e-01, 1.0289886e+03, 1.0284847e+02,
-8.6815998e-02, 1.0280000e+03, 1.0275000e+02, -3.8943999e-02, 1.0289916e+03, 1.0286631e+02]
photon_detected = pd.DataFrame([[1,-0.086816,102.75,1028.0,0.047872,0.11631,0.99156,412.48667183165185],
[1,-0.086816,102.75,1028.0,0.029383,0.14085,0.98779,412.48667183165185],
[1,-0.086816,102.75,1028.0,-0.040829,0.069474,0.9925,412.48667183165185]])
x_position, y_position, z_position = photon_detected.values[:, 1:4].T
x_momentum, y_momentum, z_momentum = photon_detected.values[:, 4:7].T
# calculate the target coordinates of the photons
x_target = x_position + x_momentum
y_target = y_position + y_momentum
z_target = z_position + z_momentum

emitter_data = np.column_stack((x_position, z_position, y_position, x_target, z_target, y_target)).flatten()
emitter_data = np.insert(emitter_data, 0, len(x_position))
photon_data = np.zeros((1, 1, len(emitter_data)), dtype=np.float32)
photon_data[0, 0, :] = emitter_data

Expand All @@ -47,20 +58,24 @@ def create_emitter_and_spectrum(lookat, s_key='d65'):
@pytest.mark.parametrize("spectrum_key", spectrum_dicts.keys())
@pytest.mark.parametrize("it_pos", [[2.0, 0.5, 0.0], [1.0, 0.5, -5.0]])
@pytest.mark.parametrize("wavelength_sample", [0.7])
# @pytest.mark.parametrize("cutoff_angle", [20, 80])
@pytest.mark.parametrize("lookat", lookat_transforms)
def test_sample_direction(variant_scalar_spectral, spectrum_key, it_pos, wavelength_sample, lookat):
""" Check the correctness of the sample_direction() method """

# Test a fixed cutoff angle?
cutoff_angle = 20
cutoff_angle_rad = cutoff_angle / 180 * dr.pi
beam_width_rad = cutoff_angle_rad * 0.75
inv_transition_width = 1 / (cutoff_angle_rad - beam_width_rad)
# Create an emitter and spectrum
emitter, spectrum = create_emitter_and_spectrum(lookat, spectrum_key)
eval_t = 0.3
# TODO: work out how to test the transforms used in the photon emitter here
trafo = mi.Transform4f(emitter.world_transform())

# Get the transforms for this photon data
origins = [mi.Point3f(mi.Float(x_position[c]),
mi.Float(z_position[c]),
mi.Float(y_position[c])) for c in range(len(x_position))]
targets = [mi.Point3f(mi.Float(x_target[c]),
mi.Float(z_target[c]),
mi.Float(y_target[c])) for c in range(len(x_target))]
up = mi.Point3f(0, 0, 1)
camera_coords = [mi.Transform4f().look_at(origin=origin, target=target, up=up) for origin, target in zip(origins, targets)]
m_transforms = [camera_coord.matrix for camera_coord in camera_coords]

# Create a surface iteration
it = mi.SurfaceInteraction3f()
Expand All @@ -72,38 +87,35 @@ def test_sample_direction(variant_scalar_spectral, spectrum_key, it_pos, wavelen
wav, spec = spectrum.sample_spectrum(it, mi.sample_shifted(wavelength_sample))
it.wavelengths = wav

# Mimic what the matrix gather does in the C++ code (not available in python)
sample = [0,0]
index = dr.arange(dr.llvm.UInt32, dr.width(sample) % dr.width(m_transforms))
m_trans_at_index = [[m_transforms[0][n][index[0]]] * 4 for n in range(len(m_transforms[0]))]
m_matrix = mi.Matrix4f(m_trans_at_index)
m_transform = mi.Transform4f(m_matrix)
m_vector = m_transform.translation()

# Direction from the position to the point emitter
d = mi.Vector3f(-it.p + lookat.translation())
d = mi.Vector3f(-it.p + m_vector)
dist = dr.norm(d)
d /= dist

# Calculate angle between lookat direction and ray direction
angle = dr.acos((trafo.inverse() @ (-d))[2])
angle = dr.select(dr.abs(angle - beam_width_rad)
< 1e-3, beam_width_rad, angle)
angle = dr.select(dr.abs(angle - cutoff_angle_rad)
< 1e-3, cutoff_angle_rad, angle)

# Sample a direction from the emitter
ds, res = emitter.sample_direction(it, [0, 0])
ds, res = emitter.sample_direction(it, sample)

# Evaluate the spectrum
spec = spectrum.eval(it)
spec = dr.select(angle <= beam_width_rad, spec, spec *
((cutoff_angle_rad - angle) * inv_transition_width))
spec = dr.select(angle <= cutoff_angle_rad, spec, 0)

assert ds.time == it.time
assert ds.pdf == 1.0
assert ds.delta
# assert dr.allclose(ds.d, d)
# assert dr.allclose(res, spec / (dist**2))
assert dr.allclose(ds.d, d)
assert dr.allclose(res, spec / (dist**2))


@pytest.mark.parametrize("spectrum_key", spectrum_dicts.keys())
@pytest.mark.parametrize("wavelength_sample", [0.7])
@pytest.mark.parametrize("pos_sample", [[0.4, 0.5], [0.1, 0.4]])
# @pytest.mark.parametrize("cutoff_angle", [20, 80])
@pytest.mark.parametrize("lookat", lookat_transforms)
def test_sample_ray(variants_vec_spectral, spectrum_key, wavelength_sample, pos_sample, lookat):
# Check the correctness of the sample_ray() method
Expand All @@ -114,11 +126,19 @@ def test_sample_ray(variants_vec_spectral, spectrum_key, wavelength_sample, pos_
cos_cutoff_angle_rad = dr.cos(cutoff_angle_rad)
beam_width_rad = cutoff_angle_rad * 0.75
inv_transition_width = 1 / (cutoff_angle_rad - beam_width_rad)
emitter, spectrum = create_emitter_and_spectrum(
lookat, spectrum_key)
emitter, spectrum = create_emitter_and_spectrum(lookat, spectrum_key)
eval_t = 0.3
# TODO: work out how to test the transforms used in the photon emitter here
trafo = mi.Transform4f(emitter.world_transform())

# Get the transforms for this photon data
origins = [mi.Point3f(mi.Float(x_position[c]),
mi.Float(z_position[c]),
mi.Float(y_position[c])) for c in range(len(x_position))]
targets = [mi.Point3f(mi.Float(x_target[c]),
mi.Float(z_target[c]),
mi.Float(y_target[c])) for c in range(len(x_target))]
up = mi.Point3f(0, 0, 1)
camera_coords = [mi.Transform4f().look_at(origin=origin, target=target, up=up) for origin, target in zip(origins, targets)]
m_transforms = [camera_coord.matrix for camera_coord in camera_coords]

# Sample a local direction and calculate local angle
dir_sample = pos_sample # not being used anyway
Expand All @@ -130,6 +150,10 @@ def test_sample_ray(variants_vec_spectral, spectrum_key, wavelength_sample, pos_
angle = dr.select(dr.abs(angle - cutoff_angle_rad)
< 1e-3, cutoff_angle_rad, angle)

index = dr.arange(dr.llvm.UInt32, dr.width(wavelength_sample))
new_dir_4f = m_transforms[index[0]] * mi.Vector4f(0., 0., 1., 0.)
new_dir = mi.Vector3f([new_dir_4f[n] for n in range(3)])

# Sample a ray (position, direction, wavelengths) from the emitter
ray, res = emitter.sample_ray(
eval_t, wavelength_sample, pos_sample, dir_sample)
Expand All @@ -143,20 +167,16 @@ def test_sample_ray(variants_vec_spectral, spectrum_key, wavelength_sample, pos_
((cutoff_angle_rad - angle) * inv_transition_width))
spec = dr.select(angle <= cutoff_angle_rad, spec, 0)

# assert dr.allclose(
# res, spec / mi.warp.square_to_uniform_cone_pdf(trafo.inverse() @ ray.d, cos_cutoff_angle_rad))
pdf_dir = 445029
assert dr.allclose(res, spec / pdf_dir)
assert dr.allclose(ray.time, eval_t)
assert dr.all(local_dir.z >= cos_cutoff_angle_rad)
assert dr.allclose(ray.wavelengths, wav)
# TODO: work out how to test the transforms used in the photon emitter here
# assert dr.allclose(ray.d, trafo @ local_dir)
# assert dr.allclose(ray.o, lookat.translation())
assert dr.allclose(ray.d, new_dir)
assert dr.allclose(ray.o, mi.Transform4f(m_transforms[0]).translation())


@pytest.mark.parametrize("spectrum_key", spectrum_dicts.keys())
# @pytest.mark.parametrize("cutoff_angle", [20, 60])
@pytest.mark.parametrize("lookat", lookat_transforms)
def test_eval(variants_vec_spectral, spectrum_key, lookat):
# Check the correctness of the eval() method
Expand All @@ -168,3 +188,72 @@ def test_eval(variants_vec_spectral, spectrum_key, lookat):
it = dr.zeros(mi.SurfaceInteraction3f, 3)
it.wi = [0, 1, 0]
assert dr.allclose(emitter.eval(it), 0.)


def test_load_methods():
# Check that the same emitter is created for each type of loading method
photon_list = mi.VolumeGrid(photon_data)
intensity = 1000.0
volume_grid_emitter = mi.load_dict({
'type' : 'photon',
'photon_list' : photon_list,
# 'cutoff_angle' : cutoff_angle,
'intensity' : intensity
})

# Create a binary file from the photon data
import struct

with open("photon_geometry.bin", "wb") as f:
f.write(struct.pack("<Q", len(x_position))) # Use 'Q' format for 64-bit unsigned integer (size_t)
for x1, y1, z1, x2, y2, z2 in zip(x_position, y_position, z_position, x_target, y_target, z_target):
f.write(struct.pack("<f", x1))
f.write(struct.pack("<f", z1))
f.write(struct.pack("<f", y1))
f.write(struct.pack("<f", x2))
f.write(struct.pack("<f", z2))
f.write(struct.pack("<f", y2))

binary_file_emitter = mi.load_dict({
'type' : 'photon',
'filename' : 'photon_geometry.bin',
'intensity' : intensity
})

# We now have two emitters so compare the two using the same sampling parameters
cutoff_angle = 20
cutoff_angle_rad = cutoff_angle / 180 * dr.pi
beam_width_rad = cutoff_angle_rad * 0.75
eval_t = 0.3
# TODO: work out how to test the transforms used in the photon emitter here
vol_trafo = mi.Transform4f(volume_grid_emitter.world_transform())
bin_trafo = mi.Transform4f(binary_file_emitter.world_transform())

# Sample a local direction and calculate local angle
pos_sample = [0.4, 0.5]
dir_sample = pos_sample
local_dir = mi.ScalarVector3f(0., 0., 1.)
angle = dr.acos(local_dir[2])
angle = dr.select(dr.abs(angle - beam_width_rad)
< 1e-3, beam_width_rad, angle)
angle = dr.select(dr.abs(angle - cutoff_angle_rad)
< 1e-3, cutoff_angle_rad, angle)

wavelength_sample = 0.7

# Sample a ray (position, direction, wavelengths) from the emitters
vol_ray, vol_res = volume_grid_emitter.sample_ray(
eval_t, wavelength_sample, pos_sample, dir_sample)

bin_ray, bin_res = binary_file_emitter.sample_ray(
eval_t, wavelength_sample, pos_sample, dir_sample)

assert(vol_ray.o[0] == bin_ray.o[0])
assert(vol_ray.d[0] == bin_ray.d[0])
assert(vol_ray.wavelengths[0] == bin_ray.wavelengths[0])
assert(vol_res[0] == bin_res[0])

# Delete the photon_geometry binary file
import os
os.remove('photon_geometry.bin')

0 comments on commit a2b9e14

Please sign in to comment.