Skip to content

Commit

Permalink
Add support for custom MDL nodes (#2125)
Browse files Browse the repository at this point in the history
The MDL generator is limited in a way that it only supported the mappings defined in the libraries/***lib/genmdl files.
With this pull request we now fully support <implementation target="genmdl"> nodes.
This comes in two flavors:

1. External .mdl file references:
We allow to specify an MDL module located in an MDL search path using the `file` attribute. In combination with the `function` attribute a particular exported function can be selected. There is one aspect that needs to be considered when using this approach:
- Input and output names defined in MaterialX could conflict with reserved names in MDL. Therefore, we add a `mxp_` prefix in case of such a conflict. The prefix is not applied in general because we want to be able to reference MDL modules without modifying them.

2. Inline MDL source code: 
By specifying the `sourcecode` attribute, MDL code can be added directly. The generator will inline this code into the generated MDL code, based on the interface of the corresponding node definition. Here, more aspects need to be considered:
- all input and input names defined in MaterialX are prefixed with `mxp_` to keep things simple for authors that otherwise would need knowledge about the keywords in MDL.
- because the source code parsed from the MaterialX document contains no line breaks, we need to forbid comments using `//`. Instead the `/* comment */` syntax can be used.
- to keep these inline source code nodes portable, it is not allowed to import additional MDL modules
  • Loading branch information
krohmerNV authored Dec 12, 2024
1 parent c46e8a3 commit 94ce95d
Show file tree
Hide file tree
Showing 15 changed files with 526 additions and 82 deletions.
58 changes: 53 additions & 5 deletions source/MaterialXGenMdl/MdlShaderGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <MaterialXGenMdl/Nodes/ClosureLayerNodeMdl.h>
#include <MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h>
#include <MaterialXGenMdl/Nodes/ClosureSourceCodeNodeMdl.h>
#include <MaterialXGenMdl/Nodes/CustomNodeMdl.h>
#include <MaterialXGenMdl/Nodes/ImageNodeMdl.h>

#include <MaterialXGenShader/GenContext.h>
Expand Down Expand Up @@ -157,6 +158,24 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G
emitLineEnd(stage, true);
}

// Emit custom node imports for nodes in the graph
for (ShaderNode* node : graph.getNodes())
{
const ShaderNodeImpl& impl = node->getImplementation();
const CustomCodeNodeMdl* customNode = dynamic_cast<const CustomCodeNodeMdl*>(&impl);
if (customNode)
{
const string& importName = customNode->getQualifiedModuleName();
if (!importName.empty())
{
emitString("import ", stage);
emitString(importName, stage);
emitString("::*", stage);
emitLineEnd(stage, true);
}
}
}

// Add global constants and type definitions
emitTypeDefinitions(context, stage);

Expand Down Expand Up @@ -353,14 +372,31 @@ ShaderNodeImplPtr MdlShaderGenerator::getImplementation(const NodeDef& nodedef,
impl = _implFactory.create(name);
if (!impl)
{
// Fall back to source code implementation.
if (outputType.isClosure())
// When `file` and `function` are provided we consider this node a user node
const string file = implElement->getTypedAttribute<string>("file");
const string function = implElement->getTypedAttribute<string>("function");
// Or, if `sourcecode` is provided we consider this node a user node with inline implementation
// inline implementations are not supposed to have replacement markers
const string sourcecode = implElement->getTypedAttribute<string>("sourcecode");
if ((!file.empty() && !function.empty()) || (!sourcecode.empty() && sourcecode.find("{{") == string::npos))
{
impl = CustomCodeNodeMdl::create();
}
else if (file.empty() && sourcecode.empty())
{
impl = ClosureSourceCodeNodeMdl::create();
throw ExceptionShaderGenError("No valid MDL implementation found for '" + name + "'");
}
else
{
impl = SourceCodeNodeMdl::create();
// Fall back to source code implementation.
if (outputType.isClosure())
{
impl = ClosureSourceCodeNodeMdl::create();
}
else
{
impl = SourceCodeNodeMdl::create();
}
}
}
}
Expand All @@ -386,6 +422,7 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex
return ShaderGenerator::getUpstreamResult(input, context);
}

const MdlSyntax& mdlSyntax = static_cast<const MdlSyntax&>(getSyntax());
string variable;
const ShaderNode* upstreamNode = upstreamOutput->getNode();
if (!upstreamNode->isAGraph() && upstreamNode->numOutputs() > 1)
Expand All @@ -397,7 +434,18 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex
}
else
{
variable = upstreamNode->getName() + "_result.mxp_" + upstreamOutput->getName();
const string& fieldName = upstreamOutput->getName();
const CustomCodeNodeMdl* upstreamCustomNodeMdl = dynamic_cast<const CustomCodeNodeMdl*>(&upstreamNode->getImplementation());
if (upstreamCustomNodeMdl)
{
// Prefix the port name depending on the CustomCodeNode
variable = upstreamNode->getName() + "_result." + upstreamCustomNodeMdl->modifyPortName(fieldName, mdlSyntax);
}
else
{
// Existing implementations and none user defined structs will keep the prefix always to not break existing content
variable = upstreamNode->getName() + "_result." + mdlSyntax.modifyPortName(upstreamOutput->getName());
}
}
}
else
Expand Down
89 changes: 74 additions & 15 deletions source/MaterialXGenMdl/MdlSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ TYPEDESC_REGISTER_TYPE(MDL_SCATTER_MODE, "scatter_mode")
namespace
{

const string MARKER_MDL_VERSION_SUFFIX = "MDL_VERSION_SUFFIX";

class MdlFilenameTypeSyntax : public ScalarTypeSyntax
{
public:
Expand Down Expand Up @@ -195,29 +197,49 @@ const StringVec MdlSyntax::FILTERTYPE_MEMBERS = { "box", "gaussian" };
const StringVec MdlSyntax::DISTRIBUTIONTYPE_MEMBERS = { "ggx" };
const StringVec MdlSyntax::SCATTER_MODE_MEMBERS = { "R", "T", "RT" };

const string MdlSyntax::PORT_NAME_PREFIX = "mxp_";

//
// MdlSyntax methods
//

MdlSyntax::MdlSyntax()
{
// Add in all reserved words and keywords in MDL
// Formatted as in the MDL Specification 1.9.2 for easy comparing
registerReservedWords(
{ // Reserved words
"annotation", "bool", "bool2", "bool3", "bool4", "break", "bsdf", "bsdf_measurement", "case", "cast", "color", "const",
"continue", "default", "do", "double", "double2", "double2x2", "double2x3", "double3", "double3x2", "double3x3", "double3x4",
"double4", "double4x3", "double4x4", "double4x2", "double2x4", "edf", "else", "enum", "export", "false", "float", "float2",
"float2x2", "float2x3", "float3", "float3x2", "float3x3", "float3x4", "float4", "float4x3", "float4x4", "float4x2", "float2x4",
"for", "hair_bsdf", "if", "import", "in", "int", "int2", "int3", "int4", "intensity_mode", "intensity_power", "intensity_radiant_exitance",
"let", "light_profile", "material", "material_emission", "material_geometry", "material_surface", "material_volume", "mdl", "module",
"package", "return", "string", "struct", "switch", "texture_2d", "texture_3d", "texture_cube", "texture_ptex", "true", "typedef", "uniform",
"using", "varying", "vdf", "while",
// Reserved for future use
"auto", "catch", "char", "class", "const_cast", "delete", "dynamic_cast", "explicit", "extern", "external", "foreach", "friend", "goto",
"graph", "half", "half2", "half2x2", "half2x3", "half3", "half3x2", "half3x3", "half3x4", "half4", "half4x3", "half4x4", "half4x2", "half2x4",
"inline", "inout", "lambda", "long", "mutable", "namespace", "native", "new", "operator", "out", "phenomenon", "private", "protected", "public",
"reinterpret_cast", "sampler", "shader", "short", "signed", "sizeof", "static", "static_cast", "technique", "template", "this", "throw", "try",
"typeid", "typename", "union", "unsigned", "virtual", "void", "volatile", "wchar_t" });
{ // Reserved words
"annotation", "double2", "float", "in", "operator",
"auto", "double2x2", "float2", "int", "package",
"bool", "double2x3", "float2x2", "int2", "return",
"bool2", "double3", "float2x3", "int3", "string",
"bool3", "double3x2", "float3", "int4", "struct",
"bool4", "double3x3", "float3x2", "intensity_mode", "struct_category",
"break", "double3x4", "float3x3", "intensity_power", "switch",
"bsdf", "double4", "float3x4", "intensity_radiant_exitance", "texture_2d",
"bsdf_measurement", "double4x3", "float4", "let", "texture_3d",
"case", "double4x4", "float4x3", "light_profile", "texture_cube",
"cast", "double4x2", "float4x4", "material", "texture_ptex",
"color", "double2x4", "float4x2", "material_emission", "true",
"const", "edf", "float2x4", "material_geometry", "typedef",
"continue", "else", "for", "material_surface", "uniform",
"declarative", "enum", "hair_bsdf", "material_volume", "using",
"default", "export", "if", "mdl", "varying",
"do", "false", "import", "module", "vdf",
"double", "while",

// Reserved for future use
"catch", "friend", "half3x4", "mutable", "sampler", "throw",
"char", "goto", "half4", "namespace", "shader", "try",
"class", "graph", "half4x3", "native", "short", "typeid",
"const_cast", "half", "half4x4", "new", "signed", "typename",
"delete", "half2", "half4x2", "out", "sizeof", "union",
"dynamic_cast", "half2x2", "half2x4", "phenomenon", "static", "unsigned",
"explicit", "half2x3", "inline", "private", "static_cast", "virtual",
"extern", "half3", "inout", "protected", "technique", "void",
"external", "half3x2", "lambda", "public", "template", "volatile",
"foreach", "half3x3", "long", "reinterpret_cast", "this", "wchar_t",
});

// Register restricted tokens in MDL
StringMap tokens;
Expand Down Expand Up @@ -533,4 +555,41 @@ void MdlSyntax::makeValidName(string& name) const
}
}

string MdlSyntax::modifyPortName(const string& word) const
{
return PORT_NAME_PREFIX + word;
}

string MdlSyntax::replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function<string(const string&)> lambda) const
{
// An inline function call
// Replace tokens of the format "{{<var>}}"
static const string prefix("{{");
static const string postfix("}}");

size_t pos = 0;
size_t i = soureCode.find_first_of(prefix);
StringVec code;
while (i != string::npos)
{
code.push_back(soureCode.substr(pos, i - pos));
size_t j = soureCode.find_first_of(postfix, i + 2);
if (j == string::npos)
{
throw ExceptionShaderGenError("Malformed inline expression in implementation for node " + nodeName);
}
const string marker = soureCode.substr(i + 2, j - i - 2);
code.push_back(lambda(marker));
pos = j + 2;
i = soureCode.find_first_of(prefix, pos);
}
code.push_back(soureCode.substr(pos));
return joinStrings(code, EMPTY_STRING);
}

const string MdlSyntax::getMdlVersionSuffixMarker() const
{
return MARKER_MDL_VERSION_SUFFIX;
}

MATERIALX_NAMESPACE_END
11 changes: 11 additions & 0 deletions source/MaterialXGenMdl/MdlSyntax.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class MX_GENMDL_API MdlSyntax : public Syntax
static const StringVec FILTERTYPE_MEMBERS;
static const StringVec DISTRIBUTIONTYPE_MEMBERS;
static const StringVec SCATTER_MODE_MEMBERS;
static const string PORT_NAME_PREFIX; // Applied to input and output names to avoid collisions with reserved words in MDL

/// Get an type description for an enumeration based on member value
TypeDesc getEnumeratedType(const string& value) const;
Expand All @@ -63,6 +64,16 @@ class MX_GENMDL_API MdlSyntax : public Syntax

/// Modify the given name string to remove any invalid characters or tokens.
void makeValidName(string& name) const override;

/// To avoid collisions with reserved names in MDL, input and output names are prefixed.
string modifyPortName(const string& word) const;

/// Replaces all markers in a source code string indicated by {{...}}.
/// The replacement is defined by a callback function.
string replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function<string(const string&)> lambda) const;

/// Get the MDL language versing marker: {{MDL_VERSION_SUFFIX}}.
const string getMdlVersionSuffixMarker() const;
};

namespace Type
Expand Down
4 changes: 3 additions & 1 deletion source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//

#include <MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.h>
#include <MaterialXGenMdl/MdlSyntax.h>

#include <MaterialXGenShader/HwShaderGenerator.h>
#include <MaterialXGenShader/ShaderGenerator.h>
Expand All @@ -29,6 +30,7 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC
DEFINE_SHADER_STAGE(stage, Stage::PIXEL)
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
const MdlSyntax& mdlSyntax = static_cast<const MdlSyntax&>(shadergen.getSyntax());

// Emit functions for all child nodes
shadergen.emitFunctionDefinitions(*_rootGraph, context, stage);
Expand Down Expand Up @@ -146,7 +148,7 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC
for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
{
const string result = shadergen.getUpstreamResult(output, context);
shadergen.emitLine(resultVariableName + ".mxp_" + output->getName() + " = " + result, stage);
shadergen.emitLine(resultVariableName + mdlSyntax.modifyPortName(output->getName()) + " = " + result, stage);
}
shadergen.emitLine("return " + resultVariableName, stage);
}
Expand Down
8 changes: 5 additions & 3 deletions source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <MaterialXGenMdl/Nodes/CompoundNodeMdl.h>
#include <MaterialXGenMdl/MdlShaderGenerator.h>
#include <MaterialXGenMdl/MdlSyntax.h>

#include <MaterialXGenShader/HwShaderGenerator.h>
#include <MaterialXGenShader/ShaderGenerator.h>
Expand Down Expand Up @@ -46,6 +47,7 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext&
DEFINE_SHADER_STAGE(stage, Stage::PIXEL)
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
const MdlSyntax& syntax = static_cast<const MdlSyntax&>(shadergen.getSyntax());

const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) ||
_rootGraph->hasClassification(ShaderNode::Classification::SHADER));
Expand Down Expand Up @@ -83,7 +85,7 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext&
for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
{
const string result = shadergen.getUpstreamResult(output, context);
shadergen.emitLine(resultVariableName + ".mxp_" + output->getName() + " = " + result, stage);
shadergen.emitLine(resultVariableName + "." + syntax.modifyPortName(output->getName()) + " = " + result, stage);
}
shadergen.emitLine("return " + resultVariableName, stage);
}
Expand Down Expand Up @@ -180,7 +182,7 @@ void CompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& conte
void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& context, ShaderStage& stage) const
{
const ShaderGenerator& shadergen = context.getShaderGenerator();
const Syntax& syntax = shadergen.getSyntax();
const MdlSyntax& syntax = static_cast<const MdlSyntax&>(shadergen.getSyntax());

if (!_returnStruct.empty())
{
Expand Down Expand Up @@ -208,7 +210,7 @@ void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& conte
shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS);
for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets())
{
shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage);
shadergen.emitLine(syntax.getTypeName(output->getType()) + " " + syntax.modifyPortName(output->getName()), stage);
}
shadergen.emitScopeEnd(stage, true);
shadergen.emitLineBreak(stage);
Expand Down
Loading

0 comments on commit 94ce95d

Please sign in to comment.