Skip to content

Commit

Permalink
CallSupport: isolate per variant
Browse files Browse the repository at this point in the history
  • Loading branch information
merlinND authored and wjakob committed Dec 13, 2024
1 parent b6cb13d commit bc0d501
Show file tree
Hide file tree
Showing 17 changed files with 67 additions and 33 deletions.
20 changes: 20 additions & 0 deletions include/mitsuba/core/class.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,25 @@ template <typename T, std::enable_if_t<is_constructible_v<T, Stream*>, int> = 0>
Class::UnserializeFunctor get_unserialize_functor() { return [](Stream* s) -> Object* { return new T(s); }; }
template <typename T, std::enable_if_t<!is_constructible_v<T, Stream*>, int> = 0>
Class::UnserializeFunctor get_unserialize_functor() { return {}; }

NAMESPACE_END(detail)

#define MI_REGISTRY_PUT(name, ptr) \
if constexpr (dr::is_jit_v<Float>) { \
jit_registry_put(::mitsuba::detail::get_variant<Float, Spectrum>(), \
"mitsuba::" name, ptr); \
}

#define MI_CALL_TEMPLATE_BEGIN(Name) \
DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Name)

#define MI_CALL_TEMPLATE_END(Name) \
public: \
static constexpr const char *variant_() { \
return ::mitsuba::detail::get_variant<Ts...>(); \
} \
static_assert(is_detected_v<detail::has_variant_override, CallSupport_>); \
DRJIT_CALL_END(mitsuba::Name)


NAMESPACE_END(mitsuba)
3 changes: 3 additions & 0 deletions include/mitsuba/core/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ NAMESPACE_END(mitsuba)
namespace drjit {
template <typename Array>
struct call_support<mitsuba::Object, Array> {
// This is for pointers to general `Object` instances, we don't have access
// to specific `Float` and `Spectrum` types here.
static constexpr const char *Variant = "";
static constexpr const char *Domain = "mitsuba::Object";
};
}
4 changes: 2 additions & 2 deletions include/mitsuba/render/bsdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::BSDF)
MI_CALL_TEMPLATE_BEGIN(BSDF)
DRJIT_CALL_METHOD(sample)
DRJIT_CALL_METHOD(eval)
DRJIT_CALL_METHOD(eval_null_transmission)
Expand All @@ -670,7 +670,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::BSDF)
auto needs_differentials() const {
return has_flag(flags(), mitsuba::BSDFFlags::NeedsDifferentials);
}
DRJIT_CALL_END(mitsuba::BSDF)
MI_CALL_TEMPLATE_END(BSDF)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Emitter)
MI_CALL_TEMPLATE_BEGIN(Emitter)
DRJIT_CALL_METHOD(sample_ray)
DRJIT_CALL_METHOD(sample_direction)
DRJIT_CALL_METHOD(pdf_direction)
Expand All @@ -117,7 +117,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Emitter)
DRJIT_CALL_GETTER(shape)
DRJIT_CALL_GETTER(medium)
DRJIT_CALL_GETTER(sampling_weight)
DRJIT_CALL_END(mitsuba::Emitter)
MI_CALL_TEMPLATE_END(Emitter)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/medium.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for packets of Medium pointers
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Medium)
MI_CALL_TEMPLATE_BEGIN(Medium)
DRJIT_CALL_GETTER(phase_function)
DRJIT_CALL_GETTER(use_emitter_sampling)
DRJIT_CALL_GETTER(is_homogeneous)
Expand All @@ -131,7 +131,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Medium)
DRJIT_CALL_METHOD(sample_interaction)
DRJIT_CALL_METHOD(transmittance_eval_pdf)
DRJIT_CALL_METHOD(get_scattering_coefficients)
DRJIT_CALL_END(mitsuba::Medium)
MI_CALL_TEMPLATE_END(Medium)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/phase.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,14 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::PhaseFunction)
MI_CALL_TEMPLATE_BEGIN(PhaseFunction)
DRJIT_CALL_METHOD(sample)
DRJIT_CALL_METHOD(eval_pdf)
DRJIT_CALL_METHOD(projected_area)
DRJIT_CALL_METHOD(max_projected_area)
DRJIT_CALL_GETTER(flags)
DRJIT_CALL_GETTER(component_count)
DRJIT_CALL_END(mitsuba::PhaseFunction)
MI_CALL_TEMPLATE_END(PhaseFunction)

//! @}
// -----------------------------------------------------------------------
4 changes: 2 additions & 2 deletions include/mitsuba/render/sensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Sensor)
MI_CALL_TEMPLATE_BEGIN(Sensor)
DRJIT_CALL_METHOD(sample_ray)
DRJIT_CALL_METHOD(sample_ray_differential)
DRJIT_CALL_METHOD(sample_direction)
Expand All @@ -326,4 +326,4 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Sensor)
DRJIT_CALL_GETTER(flags)
DRJIT_CALL_GETTER(shape)
DRJIT_CALL_GETTER(medium)
DRJIT_CALL_END(mitsuba::Sensor)
MI_CALL_TEMPLATE_END(Sensor)
4 changes: 2 additions & 2 deletions include/mitsuba/render/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ NAMESPACE_END(mitsuba)
//! @{ \name Dr.Jit support for vectorized function calls
// -----------------------------------------------------------------------

DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Shape)
MI_CALL_TEMPLATE_BEGIN(Shape)
DRJIT_CALL_METHOD(compute_surface_interaction)
DRJIT_CALL_METHOD(has_attribute)
DRJIT_CALL_METHOD(eval_attribute)
Expand Down Expand Up @@ -1112,7 +1112,7 @@ DRJIT_CALL_TEMPLATE_BEGIN(mitsuba::Shape)
auto is_mesh() const { return shape_type() == (uint32_t) mitsuba::ShapeType::Mesh; }
auto is_medium_transition() const { return interior_medium() != nullptr ||
exterior_medium() != nullptr; }
DRJIT_CALL_END(mitsuba::Shape)
MI_CALL_TEMPLATE_END(Shape)

//! @}
// -----------------------------------------------------------------------
20 changes: 20 additions & 0 deletions src/core/tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,23 @@ def test09_import_torch_order(order):
})

print(bsdf)


def test10_variant_switching_vcall():
import mitsuba as mi

# Switch variants between two scenes and try to render the second scene.
# This will only work if the two scenes do not share the same JIT vectorized
# call registry
if ('llvm_ad_spectral' not in mi.variants() or
'llvm_ad_rgb' not in mi.variants()):
pytest.skip(f"Missing variants to properly run the test.")

mi.set_variant('llvm_ad_spectral')
scene1 = mi.load_dict(mi.cornell_box())

mi.set_variant('llvm_ad_rgb')
scene2 = mi.load_dict(mi.cornell_box())

mi.render(scene2, spp=1)

3 changes: 1 addition & 2 deletions src/render/bsdf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ NAMESPACE_BEGIN(mitsuba)

MI_VARIANT BSDF<Float, Spectrum>::BSDF(const Properties &props)
: m_flags(+BSDFFlags::Empty), m_id(props.id()) {
if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::BSDF", this);
MI_REGISTRY_PUT("BSDF", this);
}

MI_VARIANT BSDF<Float, Spectrum>::~BSDF() {
Expand Down
5 changes: 2 additions & 3 deletions src/render/emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ MI_VARIANT Emitter<Float, Spectrum>::Emitter(const Properties &props)
: Base(props) {
m_sampling_weight = props.get<ScalarFloat>("sampling_weight", 1.0f);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Emitter", this);
MI_REGISTRY_PUT("Emitter", this);
}

MI_VARIANT Emitter<Float, Spectrum>::~Emitter() {
MI_VARIANT Emitter<Float, Spectrum>::~Emitter() {
if constexpr (dr::is_jit_v<Float>)
jit_registry_remove(this);
}
Expand Down
10 changes: 4 additions & 6 deletions src/render/medium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

NAMESPACE_BEGIN(mitsuba)

MI_VARIANT Medium<Float, Spectrum>::Medium() :
m_is_homogeneous(false),
MI_VARIANT Medium<Float, Spectrum>::Medium() :
m_is_homogeneous(false),
m_has_spectral_extinction(true) {

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Medium", this);
MI_REGISTRY_PUT("Medium", this);
}

MI_VARIANT Medium<Float, Spectrum>::Medium(const Properties &props) : m_id(props.id()) {
Expand All @@ -33,8 +32,7 @@ MI_VARIANT Medium<Float, Spectrum>::Medium(const Properties &props) : m_id(props

m_sample_emitters = props.get<bool>("sample_emitters", true);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Medium", this);
MI_REGISTRY_PUT("Medium", this);
}

MI_VARIANT Medium<Float, Spectrum>::~Medium() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/phase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ NAMESPACE_BEGIN(mitsuba)
MI_VARIANT
PhaseFunction<Float, Spectrum>::PhaseFunction(const Properties &props)
: m_flags(+PhaseFunctionFlags::Empty), m_id(props.id()) {
if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::PhaseFunction", this);
MI_REGISTRY_PUT("PhaseFunction", this);
}

MI_VARIANT PhaseFunction<Float, Spectrum>::~PhaseFunction() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/sensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ MI_VARIANT Sensor<Float, Spectrum>::Sensor(const Properties &props) : Base(props
}
}

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Sensor", this);
MI_REGISTRY_PUT("Sensor", this);
}

MI_VARIANT Sensor<Float, Spectrum>::~Sensor() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ MI_VARIANT Shape<Float, Spectrum>::Shape(const Properties &props) : m_id(props.i

m_silhouette_sampling_weight = props.get<ScalarFloat>("silhouette_sampling_weight", 1.0f);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Shape", this);
MI_REGISTRY_PUT("Shape", this);
}

MI_VARIANT Shape<Float, Spectrum>::~Shape() {
Expand Down
3 changes: 1 addition & 2 deletions src/render/shapegroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ MI_VARIANT ShapeGroup<Float, Spectrum>::ShapeGroup(const Properties &props) {
}
#endif

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::ShapeGroup", this);
MI_REGISTRY_PUT("ShapeGroup", this);
}

MI_VARIANT ShapeGroup<Float, Spectrum>::~ShapeGroup() {
Expand Down
3 changes: 1 addition & 2 deletions src/shapes/merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class MergeShape final : public Shape<Float, Spectrum> {
Log(Info, "Collapsed %zu into %zu meshes. (took %s, %zu objects ignored)",
visited, tbl.size(), util::time_string((float) timer.value()), ignored);

if constexpr (dr::is_jit_v<Float>)
jit_registry_put(dr::backend_v<Float>, "mitsuba::Shape", this);
MI_REGISTRY_PUT("Shape", this);
}

std::vector<ref<Object>> expand() const override {
Expand Down

0 comments on commit bc0d501

Please sign in to comment.