From 94ce95d7b814b188d6a2adc2aa3f521ed3a20643 Mon Sep 17 00:00:00 2001 From: krohmerNV <42233792+krohmerNV@users.noreply.github.com> Date: Thu, 12 Dec 2024 19:24:40 +0100 Subject: [PATCH] Add support for custom MDL nodes (#2125) The MDL generator is limited in a way that it only supported the mappings defined in the libraries/***lib/genmdl files. With this pull request we now fully support nodes. This comes in two flavors: 1. External .mdl file references: We allow to specify an MDL module located in an MDL search path using the `file` attribute. In combination with the `function` attribute a particular exported function can be selected. There is one aspect that needs to be considered when using this approach: - Input and output names defined in MaterialX could conflict with reserved names in MDL. Therefore, we add a `mxp_` prefix in case of such a conflict. The prefix is not applied in general because we want to be able to reference MDL modules without modifying them. 2. Inline MDL source code: By specifying the `sourcecode` attribute, MDL code can be added directly. The generator will inline this code into the generated MDL code, based on the interface of the corresponding node definition. Here, more aspects need to be considered: - all input and input names defined in MaterialX are prefixed with `mxp_` to keep things simple for authors that otherwise would need knowledge about the keywords in MDL. - because the source code parsed from the MaterialX document contains no line breaks, we need to forbid comments using `//`. Instead the `/* comment */` syntax can be used. - to keep these inline source code nodes portable, it is not allowed to import additional MDL modules --- source/MaterialXGenMdl/MdlShaderGenerator.cpp | 58 +++- source/MaterialXGenMdl/MdlSyntax.cpp | 89 +++++- source/MaterialXGenMdl/MdlSyntax.h | 11 + .../Nodes/ClosureCompoundNodeMdl.cpp | 4 +- .../MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp | 8 +- .../MaterialXGenMdl/Nodes/CustomNodeMdl.cpp | 278 ++++++++++++++++++ source/MaterialXGenMdl/Nodes/CustomNodeMdl.h | 56 ++++ source/MaterialXGenMdl/Nodes/ImageNodeMdl.h | 2 +- .../MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp | 5 +- .../Nodes/SourceCodeNodeMdl.cpp | 62 ++-- .../MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h | 1 + .../MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp | 5 +- .../Nodes/SourceCodeNode.cpp | 24 +- .../MaterialXGenShader/Nodes/SourceCodeNode.h | 3 + .../MaterialXTest/MaterialXGenMdl/GenMdl.cpp | 2 +- 15 files changed, 526 insertions(+), 82 deletions(-) create mode 100644 source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp create mode 100644 source/MaterialXGenMdl/Nodes/CustomNodeMdl.h diff --git a/source/MaterialXGenMdl/MdlShaderGenerator.cpp b/source/MaterialXGenMdl/MdlShaderGenerator.cpp index 71db0d7cf4..1bc65d5cc4 100644 --- a/source/MaterialXGenMdl/MdlShaderGenerator.cpp +++ b/source/MaterialXGenMdl/MdlShaderGenerator.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -157,6 +158,24 @@ ShaderPtr MdlShaderGenerator::generate(const string& name, ElementPtr element, G emitLineEnd(stage, true); } + // Emit custom node imports for nodes in the graph + for (ShaderNode* node : graph.getNodes()) + { + const ShaderNodeImpl& impl = node->getImplementation(); + const CustomCodeNodeMdl* customNode = dynamic_cast(&impl); + if (customNode) + { + const string& importName = customNode->getQualifiedModuleName(); + if (!importName.empty()) + { + emitString("import ", stage); + emitString(importName, stage); + emitString("::*", stage); + emitLineEnd(stage, true); + } + } + } + // Add global constants and type definitions emitTypeDefinitions(context, stage); @@ -353,14 +372,31 @@ ShaderNodeImplPtr MdlShaderGenerator::getImplementation(const NodeDef& nodedef, impl = _implFactory.create(name); if (!impl) { - // Fall back to source code implementation. - if (outputType.isClosure()) + // When `file` and `function` are provided we consider this node a user node + const string file = implElement->getTypedAttribute("file"); + const string function = implElement->getTypedAttribute("function"); + // Or, if `sourcecode` is provided we consider this node a user node with inline implementation + // inline implementations are not supposed to have replacement markers + const string sourcecode = implElement->getTypedAttribute("sourcecode"); + if ((!file.empty() && !function.empty()) || (!sourcecode.empty() && sourcecode.find("{{") == string::npos)) + { + impl = CustomCodeNodeMdl::create(); + } + else if (file.empty() && sourcecode.empty()) { - impl = ClosureSourceCodeNodeMdl::create(); + throw ExceptionShaderGenError("No valid MDL implementation found for '" + name + "'"); } else { - impl = SourceCodeNodeMdl::create(); + // Fall back to source code implementation. + if (outputType.isClosure()) + { + impl = ClosureSourceCodeNodeMdl::create(); + } + else + { + impl = SourceCodeNodeMdl::create(); + } } } } @@ -386,6 +422,7 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex return ShaderGenerator::getUpstreamResult(input, context); } + const MdlSyntax& mdlSyntax = static_cast(getSyntax()); string variable; const ShaderNode* upstreamNode = upstreamOutput->getNode(); if (!upstreamNode->isAGraph() && upstreamNode->numOutputs() > 1) @@ -397,7 +434,18 @@ string MdlShaderGenerator::getUpstreamResult(const ShaderInput* input, GenContex } else { - variable = upstreamNode->getName() + "_result.mxp_" + upstreamOutput->getName(); + const string& fieldName = upstreamOutput->getName(); + const CustomCodeNodeMdl* upstreamCustomNodeMdl = dynamic_cast(&upstreamNode->getImplementation()); + if (upstreamCustomNodeMdl) + { + // Prefix the port name depending on the CustomCodeNode + variable = upstreamNode->getName() + "_result." + upstreamCustomNodeMdl->modifyPortName(fieldName, mdlSyntax); + } + else + { + // Existing implementations and none user defined structs will keep the prefix always to not break existing content + variable = upstreamNode->getName() + "_result." + mdlSyntax.modifyPortName(upstreamOutput->getName()); + } } } else diff --git a/source/MaterialXGenMdl/MdlSyntax.cpp b/source/MaterialXGenMdl/MdlSyntax.cpp index 0244a8352f..c0859f5ae2 100644 --- a/source/MaterialXGenMdl/MdlSyntax.cpp +++ b/source/MaterialXGenMdl/MdlSyntax.cpp @@ -29,6 +29,8 @@ TYPEDESC_REGISTER_TYPE(MDL_SCATTER_MODE, "scatter_mode") namespace { +const string MARKER_MDL_VERSION_SUFFIX = "MDL_VERSION_SUFFIX"; + class MdlFilenameTypeSyntax : public ScalarTypeSyntax { public: @@ -195,6 +197,8 @@ const StringVec MdlSyntax::FILTERTYPE_MEMBERS = { "box", "gaussian" }; const StringVec MdlSyntax::DISTRIBUTIONTYPE_MEMBERS = { "ggx" }; const StringVec MdlSyntax::SCATTER_MODE_MEMBERS = { "R", "T", "RT" }; +const string MdlSyntax::PORT_NAME_PREFIX = "mxp_"; + // // MdlSyntax methods // @@ -202,22 +206,40 @@ const StringVec MdlSyntax::SCATTER_MODE_MEMBERS = { "R", "T", "RT" }; MdlSyntax::MdlSyntax() { // Add in all reserved words and keywords in MDL + // Formatted as in the MDL Specification 1.9.2 for easy comparing registerReservedWords( - { // Reserved words - "annotation", "bool", "bool2", "bool3", "bool4", "break", "bsdf", "bsdf_measurement", "case", "cast", "color", "const", - "continue", "default", "do", "double", "double2", "double2x2", "double2x3", "double3", "double3x2", "double3x3", "double3x4", - "double4", "double4x3", "double4x4", "double4x2", "double2x4", "edf", "else", "enum", "export", "false", "float", "float2", - "float2x2", "float2x3", "float3", "float3x2", "float3x3", "float3x4", "float4", "float4x3", "float4x4", "float4x2", "float2x4", - "for", "hair_bsdf", "if", "import", "in", "int", "int2", "int3", "int4", "intensity_mode", "intensity_power", "intensity_radiant_exitance", - "let", "light_profile", "material", "material_emission", "material_geometry", "material_surface", "material_volume", "mdl", "module", - "package", "return", "string", "struct", "switch", "texture_2d", "texture_3d", "texture_cube", "texture_ptex", "true", "typedef", "uniform", - "using", "varying", "vdf", "while", - // Reserved for future use - "auto", "catch", "char", "class", "const_cast", "delete", "dynamic_cast", "explicit", "extern", "external", "foreach", "friend", "goto", - "graph", "half", "half2", "half2x2", "half2x3", "half3", "half3x2", "half3x3", "half3x4", "half4", "half4x3", "half4x4", "half4x2", "half2x4", - "inline", "inout", "lambda", "long", "mutable", "namespace", "native", "new", "operator", "out", "phenomenon", "private", "protected", "public", - "reinterpret_cast", "sampler", "shader", "short", "signed", "sizeof", "static", "static_cast", "technique", "template", "this", "throw", "try", - "typeid", "typename", "union", "unsigned", "virtual", "void", "volatile", "wchar_t" }); + { // Reserved words + "annotation", "double2", "float", "in", "operator", + "auto", "double2x2", "float2", "int", "package", + "bool", "double2x3", "float2x2", "int2", "return", + "bool2", "double3", "float2x3", "int3", "string", + "bool3", "double3x2", "float3", "int4", "struct", + "bool4", "double3x3", "float3x2", "intensity_mode", "struct_category", + "break", "double3x4", "float3x3", "intensity_power", "switch", + "bsdf", "double4", "float3x4", "intensity_radiant_exitance", "texture_2d", + "bsdf_measurement", "double4x3", "float4", "let", "texture_3d", + "case", "double4x4", "float4x3", "light_profile", "texture_cube", + "cast", "double4x2", "float4x4", "material", "texture_ptex", + "color", "double2x4", "float4x2", "material_emission", "true", + "const", "edf", "float2x4", "material_geometry", "typedef", + "continue", "else", "for", "material_surface", "uniform", + "declarative", "enum", "hair_bsdf", "material_volume", "using", + "default", "export", "if", "mdl", "varying", + "do", "false", "import", "module", "vdf", + "double", "while", + + // Reserved for future use + "catch", "friend", "half3x4", "mutable", "sampler", "throw", + "char", "goto", "half4", "namespace", "shader", "try", + "class", "graph", "half4x3", "native", "short", "typeid", + "const_cast", "half", "half4x4", "new", "signed", "typename", + "delete", "half2", "half4x2", "out", "sizeof", "union", + "dynamic_cast", "half2x2", "half2x4", "phenomenon", "static", "unsigned", + "explicit", "half2x3", "inline", "private", "static_cast", "virtual", + "extern", "half3", "inout", "protected", "technique", "void", + "external", "half3x2", "lambda", "public", "template", "volatile", + "foreach", "half3x3", "long", "reinterpret_cast", "this", "wchar_t", + }); // Register restricted tokens in MDL StringMap tokens; @@ -533,4 +555,41 @@ void MdlSyntax::makeValidName(string& name) const } } +string MdlSyntax::modifyPortName(const string& word) const +{ + return PORT_NAME_PREFIX + word; +} + +string MdlSyntax::replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function lambda) const +{ + // An inline function call + // Replace tokens of the format "{{}}" + static const string prefix("{{"); + static const string postfix("}}"); + + size_t pos = 0; + size_t i = soureCode.find_first_of(prefix); + StringVec code; + while (i != string::npos) + { + code.push_back(soureCode.substr(pos, i - pos)); + size_t j = soureCode.find_first_of(postfix, i + 2); + if (j == string::npos) + { + throw ExceptionShaderGenError("Malformed inline expression in implementation for node " + nodeName); + } + const string marker = soureCode.substr(i + 2, j - i - 2); + code.push_back(lambda(marker)); + pos = j + 2; + i = soureCode.find_first_of(prefix, pos); + } + code.push_back(soureCode.substr(pos)); + return joinStrings(code, EMPTY_STRING); +} + +const string MdlSyntax::getMdlVersionSuffixMarker() const +{ + return MARKER_MDL_VERSION_SUFFIX; +} + MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenMdl/MdlSyntax.h b/source/MaterialXGenMdl/MdlSyntax.h index e282cefeca..48fd6214c0 100644 --- a/source/MaterialXGenMdl/MdlSyntax.h +++ b/source/MaterialXGenMdl/MdlSyntax.h @@ -53,6 +53,7 @@ class MX_GENMDL_API MdlSyntax : public Syntax static const StringVec FILTERTYPE_MEMBERS; static const StringVec DISTRIBUTIONTYPE_MEMBERS; static const StringVec SCATTER_MODE_MEMBERS; + static const string PORT_NAME_PREFIX; // Applied to input and output names to avoid collisions with reserved words in MDL /// Get an type description for an enumeration based on member value TypeDesc getEnumeratedType(const string& value) const; @@ -63,6 +64,16 @@ class MX_GENMDL_API MdlSyntax : public Syntax /// Modify the given name string to remove any invalid characters or tokens. void makeValidName(string& name) const override; + + /// To avoid collisions with reserved names in MDL, input and output names are prefixed. + string modifyPortName(const string& word) const; + + /// Replaces all markers in a source code string indicated by {{...}}. + /// The replacement is defined by a callback function. + string replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function lambda) const; + + /// Get the MDL language versing marker: {{MDL_VERSION_SUFFIX}}. + const string getMdlVersionSuffixMarker() const; }; namespace Type diff --git a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp index 0a2004cd71..59b488c875 100644 --- a/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/ClosureCompoundNodeMdl.cpp @@ -4,6 +4,7 @@ // #include +#include #include #include @@ -29,6 +30,7 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& mdlSyntax = static_cast(shadergen.getSyntax()); // Emit functions for all child nodes shadergen.emitFunctionDefinitions(*_rootGraph, context, stage); @@ -146,7 +148,7 @@ void ClosureCompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenC for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) { const string result = shadergen.getUpstreamResult(output, context); - shadergen.emitLine(resultVariableName + ".mxp_" + output->getName() + " = " + result, stage); + shadergen.emitLine(resultVariableName + mdlSyntax.modifyPortName(output->getName()) + " = " + result, stage); } shadergen.emitLine("return " + resultVariableName, stage); } diff --git a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp index 80c449e507..2dab721e24 100644 --- a/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/CompoundNodeMdl.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -46,6 +47,7 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); const bool isMaterialExpr = (_rootGraph->hasClassification(ShaderNode::Classification::CLOSURE) || _rootGraph->hasClassification(ShaderNode::Classification::SHADER)); @@ -83,7 +85,7 @@ void CompoundNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) { const string result = shadergen.getUpstreamResult(output, context); - shadergen.emitLine(resultVariableName + ".mxp_" + output->getName() + " = " + result, stage); + shadergen.emitLine(resultVariableName + "." + syntax.modifyPortName(output->getName()) + " = " + result, stage); } shadergen.emitLine("return " + resultVariableName, stage); } @@ -180,7 +182,7 @@ void CompoundNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& conte void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& context, ShaderStage& stage) const { const ShaderGenerator& shadergen = context.getShaderGenerator(); - const Syntax& syntax = shadergen.getSyntax(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); if (!_returnStruct.empty()) { @@ -208,7 +210,7 @@ void CompoundNodeMdl::emitFunctionSignature(const ShaderNode&, GenContext& conte shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); for (const ShaderGraphOutputSocket* output : _rootGraph->getOutputSockets()) { - shadergen.emitLine(syntax.getTypeName(output->getType()) + " mxp_" + output->getName(), stage); + shadergen.emitLine(syntax.getTypeName(output->getType()) + " " + syntax.modifyPortName(output->getName()), stage); } shadergen.emitScopeEnd(stage, true); shadergen.emitLineBreak(stage); diff --git a/source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp new file mode 100644 index 0000000000..7215b90256 --- /dev/null +++ b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.cpp @@ -0,0 +1,278 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include +#include +#include +#include +#include + +MATERIALX_NAMESPACE_BEGIN + +ShaderNodeImplPtr CustomCodeNodeMdl::create() +{ + return std::make_shared(); +} + +const string& CustomCodeNodeMdl::getQualifiedModuleName() const +{ + return _qualifiedModuleName; +} + +string CustomCodeNodeMdl::modifyPortName(const string& name, const MdlSyntax& syntax) const +{ + if (_useExternalSourceCode) + { + const StringSet& reservedWords = syntax.getReservedWords(); + if (reservedWords.find(name) == reservedWords.end()) + { + // Use existing MDL parameter names if they don't collide with a reserved word. + // This allows us to reference MDL existing functions without changing the MDL source code. + return name; + } + } + return syntax.modifyPortName(name); +} + +void CustomCodeNodeMdl::initialize(const InterfaceElement& element, GenContext& context) +{ + SourceCodeNodeMdl::initialize(element, context); + if (_inlined) + { + _useExternalSourceCode = false; + initializeForInlineSourceCode(element, context); + } + else + { + _useExternalSourceCode = true; + initializeForExternalSourceCode(element, context); + } +} + +void CustomCodeNodeMdl::initializeForInlineSourceCode(const InterfaceElement& element, GenContext& context) +{ + const Implementation& impl = static_cast(element); + // Store the inline source because the `_functionSource` is used for the function call template string + // that matched the regular MaterialX to MDL function mapping. + _inlineSourceCode = impl.getAttribute("sourcecode"); + if (_inlineSourceCode.empty()) + { + throw ExceptionShaderGenError("No source code was specified for the implementation '" + impl.getName() + "'"); + } + if (_inlineSourceCode.find("//") != string::npos) + { + throw ExceptionShaderGenError("Source code contains unsupported comments '//', please use '/* comment */' instead in '" + impl.getName() + "'"); + } + + NodeDefPtr nodeDef = impl.getNodeDef(); + _inlineFunctionName = nodeDef->getName(); + _hash = std::hash{}(_inlineFunctionName); // make sure we emit the function definition only once + + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); + // Construct the function call template string + initializeFunctionCallTemplateString(syntax, *nodeDef); + // Collect information about output names and defaults + initializeOutputDefaults(syntax, *nodeDef); +} + +void CustomCodeNodeMdl::initializeForExternalSourceCode(const InterfaceElement& element, GenContext& context) +{ + // Format the function source in a way that the ShaderCodeNodeMdl (the base class of the current one) can deal with it + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); + const string uniformPrefix = syntax.getUniformQualifier() + " "; + + // Map `file` to a qualified MDL module name + const Implementation& impl = static_cast(element); + string moduleName = impl.getAttribute("file"); + if (moduleName.empty()) + { + throw ExceptionShaderGenError("No source file was specified for the implementation '" + impl.getName() + "'"); + } + if (_functionName.empty()) + { + throw ExceptionShaderGenError("No function name was specified for the implementation '" + impl.getName() + "'"); + } + + string mdlModuleName = replaceSubstrings(moduleName, { { "/", "::" } }); + if (!stringStartsWith(mdlModuleName, "::")) + { + mdlModuleName = "::" + mdlModuleName; + } + + if (!stringEndsWith(mdlModuleName, ".mdl")) + { + throw ExceptionShaderGenError("Referenced source file is not an MDL module: '" + moduleName + + "' used by implementation '" + impl.getName() + "'"); + } + else + { + mdlModuleName = mdlModuleName.substr(0, mdlModuleName.size() - 4); + } + const string versionSuffix = shadergenMdl.getMdlVersionFilenameSuffix(context); + _qualifiedModuleName = syntax.replaceSourceCodeMarkers(element.getName(), mdlModuleName, + [&versionSuffix, &syntax](const string& marker) + { + return marker == syntax.getMdlVersionSuffixMarker() ? versionSuffix : marker; + }); + + NodeDefPtr nodeDef = impl.getNodeDef(); + // Construct the function call template string + initializeFunctionCallTemplateString(syntax, *nodeDef); + // Collect information about output names and defaults + initializeOutputDefaults(syntax, *nodeDef); +} + +void CustomCodeNodeMdl::initializeFunctionCallTemplateString(const MdlSyntax& syntax, const NodeDef& nodeDef) +{ + // Construct the fully qualified function name for external functions + if (_useExternalSourceCode) + { + _functionSource = _qualifiedModuleName.substr(2) + "::" + _functionName + "("; + } + // or simple name for local functions + else + { + _functionSource = _inlineFunctionName + "("; + } + + // Function parameters + string delim = EMPTY_STRING; + for (const InputPtr& input : nodeDef.getInputs()) + { + string inputName = modifyPortName(input->getName(), syntax); + _functionSource += delim + inputName + ": {{" + input->getName() + "}}"; + if (delim == EMPTY_STRING) + delim = Syntax::COMMA + " "; + } + _functionSource += ")"; + _inlined = true; +} + +void CustomCodeNodeMdl::initializeOutputDefaults(const MdlSyntax&, const NodeDef& nodeDef) +{ + for (const OutputPtr& output : nodeDef.getOutputs()) + { + _outputDefaults.push_back(output->getValue()); + } +} + +void CustomCodeNodeMdl::emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const +{ + // No source code printing for externally defined functions + if (_useExternalSourceCode) + { + return; + } + + const ShaderGenerator& shadergen = context.getShaderGenerator(); + const MdlSyntax& syntax = static_cast(shadergen.getSyntax()); + shadergen.emitComment("generated code for implementation: '" + node.getImplementation().getName() + "'", stage); + + // Function return type + struct Field + { + string name; + string type_name; + string default_value; + }; + vector outputs; + size_t i = 0; + for (const ShaderOutput* output : node.getOutputs()) + { + string name = modifyPortName(output->getName(), syntax); + TypeDesc type = output->getType(); + const ValuePtr defaultValue = _outputDefaults[i]; + outputs.push_back({ + name, + syntax.getTypeName(type), + defaultValue ? syntax.getValue(type, *defaultValue.get()) : syntax.getDefaultValue(type) + }); + ++i; + } + + size_t numOutputs = node.getOutputs().size(); + string returnTypeName; + if (numOutputs == 1) + { + returnTypeName = outputs.back().type_name; + } + else + { + returnTypeName = _inlineFunctionName + "_return_type"; + shadergen.emitLine("struct " + returnTypeName, stage, false); + shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); + for (const auto& field : outputs) + { + // ignore the default values here, they have to be initialized in the body + shadergen.emitLine(field.type_name + " " + field.name, stage); + } + shadergen.emitScopeEnd(stage, Syntax::CURLY_BRACKETS); + shadergen.emitLineEnd(stage, false); + } + // Signature + shadergen.emitString(returnTypeName + " " + _inlineFunctionName, stage); + { + // Function parameters + shadergen.emitScopeBegin(stage, Syntax::PARENTHESES); + size_t paramCount = node.getInputs().size(); + const string uniformPrefix = syntax.getUniformQualifier() + " "; + for (const ShaderInput* input : node.getInputs()) + { + const string& qualifier = input->isUniform() || input->getType() == Type::FILENAME ? uniformPrefix : EMPTY_STRING; + const string& type = syntax.getTypeName(input->getType()); + const string name = modifyPortName(input->getName(), syntax); + const string& delim = --paramCount == 0 ? EMPTY_STRING : Syntax::COMMA; + shadergen.emitString(" " + qualifier + type + " " + name + delim + Syntax::NEWLINE, stage); + } + shadergen.emitScopeEnd(stage, false, true); + } + { + // Function body + shadergen.emitScopeBegin(stage, Syntax::CURLY_BRACKETS); + + // Out variable initialization + shadergen.emitComment("initialize outputs:", stage); + for (const auto& field : outputs) + { + shadergen.emitLine(field.type_name + " " + field.name + " = " + field.default_value, stage); + } + + // User defined code + shadergen.emitComment("inlined shader source code:", stage); + shadergen.emitLine(_inlineSourceCode, stage, false); + + // Output packing + shadergen.emitComment("pack (in case of multiple outputs) and return outputs:", stage); + if (numOutputs == 1) + { + shadergen.emitLine("return " + outputs.back().name, stage, true); + } + else + { + // Return a constructor call of the return struct type + shadergen.emitString(" return " + returnTypeName + "(", stage); + string delim = EMPTY_STRING; + for (const auto& field : outputs) + { + shadergen.emitString(delim + field.name, stage); + if (delim == EMPTY_STRING) + delim = Syntax::COMMA + " "; + } + shadergen.emitString(")", stage); + shadergen.emitLineEnd(stage, true); + } + shadergen.emitScopeEnd(stage, false, true); + } + shadergen.emitLine("", stage, false); // empty line for spacing +} + +MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenMdl/Nodes/CustomNodeMdl.h b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.h new file mode 100644 index 0000000000..daf4c873a5 --- /dev/null +++ b/source/MaterialXGenMdl/Nodes/CustomNodeMdl.h @@ -0,0 +1,56 @@ +// +// Copyright Contributors to the MaterialX Project +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef MATERIALX_CUSTOMNODEMDL_H +#define MATERIALX_CUSTOMNODEMDL_H + +#include + +MATERIALX_NAMESPACE_BEGIN + +class MdlSyntax; +class NodeDef; + +/// Node to handle user defined implementations in external MDL files or using the inline `sourcecode` attribute. +class MX_GENMDL_API CustomCodeNodeMdl : public SourceCodeNodeMdl +{ + public: + static ShaderNodeImplPtr create(); + void initialize(const InterfaceElement& element, GenContext& context) override; + void emitFunctionDefinition(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; + + /// Get the MDL qualified name of the externally references user module. + /// It's used for import statements and functions calls in the generated target code. + const string& getQualifiedModuleName() const; + + /// To avoid collisions with reserved names in MDL, input and output names are prefixed. + /// In the `sourcecode` case all inputs and outputs are prefixed so authors don't need knowledge about reserved words in MDL. + /// In the `file` and `function` case, only reserved names are prefixed to support existing MDL implementations without changes. + string modifyPortName(const string& name, const MdlSyntax& syntax) const; + + protected: + /// Initialize function for nodes that use the inline `sourcecode` attribute. + void initializeForInlineSourceCode(const InterfaceElement& element, GenContext& context); + + /// Initialize function for nodes that use the `file` and `function` attribute. + void initializeForExternalSourceCode(const InterfaceElement& element, GenContext& context); + + /// Computes the function call string with replacement markers use by base class. + void initializeFunctionCallTemplateString(const MdlSyntax& syntax, const NodeDef& node); + + /// Keep track of the default values needed for the inline `sourcecode` case. + void initializeOutputDefaults(const MdlSyntax& syntax, const NodeDef& node); + + std::vector _outputDefaults; ///< store default values of the node definition + + bool _useExternalSourceCode; // Indicates that `file` and `function` are used by this node implementation + string _inlineFunctionName; // Name of the functionDefinition to emit + string _inlineSourceCode; // The actual inline source code + string _qualifiedModuleName; // MDL qualified name derived from the `file` attribute +}; + +MATERIALX_NAMESPACE_END + +#endif diff --git a/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h b/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h index fe88b2ce09..8363c6f4fa 100644 --- a/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/ImageNodeMdl.h @@ -18,7 +18,7 @@ class MX_GENMDL_API ImageNodeMdl : public SourceCodeNodeMdl using BASE = SourceCodeNodeMdl; public: - static const string FLIP_V; ///< the empty string "" + static const string FLIP_V; ///< name of the additional parameter "flip_v" static ShaderNodeImplPtr create(); diff --git a/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp index d13aa2871f..6891fc1778 100644 --- a/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/MaterialNodeMdl.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -32,6 +33,7 @@ void MaterialNodeMdl::emitFunctionCall(const ShaderNode& _node, GenContext& cont const ShaderGenerator& shadergen = context.getShaderGenerator(); const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); + const MdlSyntax& mdlSyntax = static_cast(shadergen.getSyntax()); // Emit the function call for upstream surface shader. const ShaderNode* surfaceshaderNode = surfaceshaderInput->getConnection()->getNode(); @@ -50,8 +52,7 @@ void MaterialNodeMdl::emitFunctionCall(const ShaderNode& _node, GenContext& cont for (ShaderInput* input : node.getInputs()) { shadergen.emitString(delim, stage); - shadergen.emitString("mxp_", stage); - shadergen.emitString(input->getName(), stage); + shadergen.emitString(mdlSyntax.modifyPortName(input->getName()), stage); shadergen.emitString(": ", stage); shadergen.emitInput(input, context, stage); delim = ", "; diff --git a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp index f237a0db54..5547ae309d 100644 --- a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.cpp @@ -13,51 +13,28 @@ #include #include +#include + #include MATERIALX_NAMESPACE_BEGIN -namespace // anonymous -{ -const string MARKER_MDL_VERSION_SUFFIX = "MDL_VERSION_SUFFIX"; - -StringVec replaceSourceCodeMarkers(const string& nodeName, const string& soureCode, std::function lambda) +ShaderNodeImplPtr SourceCodeNodeMdl::create() { - // An inline function call - // Replace tokens of the format "{{}}" - static const string prefix("{{"); - static const string postfix("}}"); - - size_t pos = 0; - size_t i = soureCode.find_first_of(prefix); - StringVec code; - while (i != string::npos) - { - code.push_back(soureCode.substr(pos, i - pos)); - size_t j = soureCode.find_first_of(postfix, i + 2); - if (j == string::npos) - { - throw ExceptionShaderGenError("Malformed inline expression in implementation for node " + nodeName); - } - const string marker = soureCode.substr(i + 2, j - i - 2); - code.push_back(lambda(marker)); - pos = j + 2; - i = soureCode.find_first_of(prefix, pos); - } - code.push_back(soureCode.substr(pos)); - return code; + return std::make_shared(); } -} // anonymous namespace - -ShaderNodeImplPtr SourceCodeNodeMdl::create() +void SourceCodeNodeMdl::resolveSourceCode(const InterfaceElement& /*element*/, GenContext& /*context*/) { - return std::make_shared(); + // Initialize without fetching the source code from file. + // The resolution of MDL modules is done by the MDL compiler when loading the generated source code. + // All references MDL modules must be accessible via MDL search paths set up by the consuming application. } void SourceCodeNodeMdl::initialize(const InterfaceElement& element, GenContext& context) { SourceCodeNode::initialize(element, context); + const MdlSyntax& syntax = static_cast(context.getShaderGenerator().getSyntax()); const Implementation& impl = static_cast(element); NodeDefPtr nodeDef = impl.getNodeDef(); @@ -77,11 +54,10 @@ void SourceCodeNodeMdl::initialize(const InterfaceElement& element, GenContext& const ShaderGenerator& shadergen = context.getShaderGenerator(); const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); const string versionSuffix = shadergenMdl.getMdlVersionFilenameSuffix(context); - StringVec code = replaceSourceCodeMarkers(getName(), functionName, [&versionSuffix](const string& marker) + functionName = syntax.replaceSourceCodeMarkers(getName(), functionName, [&versionSuffix, syntax](const string& marker) { - return marker == MARKER_MDL_VERSION_SUFFIX ? versionSuffix : EMPTY_STRING; + return marker == syntax.getMdlVersionSuffixMarker() ? versionSuffix : EMPTY_STRING; }); - functionName = std::accumulate(code.begin(), code.end(), EMPTY_STRING); _returnStruct = functionName + "__result"; } else @@ -103,12 +79,13 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con const MdlShaderGenerator& shadergenMdl = static_cast(shadergen); if (_inlined) { + const MdlSyntax& syntax = static_cast(shadergenMdl.getSyntax()); const string versionSuffix = shadergenMdl.getMdlVersionFilenameSuffix(context); - StringVec code = replaceSourceCodeMarkers(node.getName(), _functionSource, - [&shadergenMdl, &context, &node, &versionSuffix](const string& marker) + string code = syntax.replaceSourceCodeMarkers(node.getName(), _functionSource, + [&shadergenMdl, &context, &node, &versionSuffix, syntax](const string& marker) { // Special handling for the version suffix of MDL source code modules. - if (marker == MARKER_MDL_VERSION_SUFFIX) + if (marker == syntax.getMdlVersionSuffixMarker()) { return versionSuffix; } @@ -131,7 +108,7 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con // Emit the struct multioutput. const string resultVariableName = node.getName() + "_result"; shadergen.emitLineBegin(stage); - shadergen.emitString(_returnStruct + " " + resultVariableName + " = ", stage); + shadergen.emitString("auto " + resultVariableName + " = ", stage); } else { @@ -141,10 +118,7 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con shadergen.emitString(" = ", stage); } - for (const string& c : code) - { - shadergen.emitString(c, stage); - } + shadergen.emitString(code, stage); shadergen.emitLineEnd(stage); } else @@ -156,7 +130,7 @@ void SourceCodeNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& con // Emit the struct multioutput. const string resultVariableName = node.getName() + "_result"; shadergen.emitLineBegin(stage); - shadergen.emitString(_returnStruct + " " + resultVariableName + " = ", stage); + shadergen.emitString("auto " + resultVariableName + " = ", stage); } else { diff --git a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h index 23c47e6249..7f595704cc 100644 --- a/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h +++ b/source/MaterialXGenMdl/Nodes/SourceCodeNodeMdl.h @@ -25,6 +25,7 @@ class MX_GENMDL_API SourceCodeNodeMdl : public SourceCodeNode void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; protected: + void resolveSourceCode(const InterfaceElement& element, GenContext& context) override; string _returnStruct; }; diff --git a/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp b/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp index dd0444d129..8904d835f4 100644 --- a/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp +++ b/source/MaterialXGenMdl/Nodes/SurfaceNodeMdl.cpp @@ -6,6 +6,7 @@ #include #include +#include #include @@ -55,6 +56,7 @@ void SurfaceNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& contex DEFINE_SHADER_STAGE(stage, Stage::PIXEL) { const MdlShaderGenerator& shadergen = static_cast(context.getShaderGenerator()); + const MdlSyntax& mdlSyntax = static_cast(shadergen.getSyntax()); // Emit calls for the closure dependencies upstream from this node. shadergen.emitDependentFunctionCalls(node, context, stage, ShaderNode::Classification::CLOSURE); @@ -84,8 +86,7 @@ void SurfaceNodeMdl::emitFunctionCall(const ShaderNode& node, GenContext& contex for (ShaderInput* input : node.getInputs()) { shadergen.emitString(delim, stage); - shadergen.emitString("mxp_", stage); - shadergen.emitString(input->getName(), stage); + shadergen.emitString(mdlSyntax.modifyPortName(input->getName()), stage); shadergen.emitString(": ", stage); shadergen.emitInput(input, context, stage); delim = ", "; diff --git a/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp b/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp index 936edf9b46..32fc345354 100644 --- a/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp +++ b/source/MaterialXGenShader/Nodes/SourceCodeNode.cpp @@ -25,6 +25,20 @@ ShaderNodeImplPtr SourceCodeNode::create() return std::make_shared(); } +void SourceCodeNode::resolveSourceCode(const InterfaceElement& element, GenContext& context) +{ + const Implementation& impl = static_cast(element); + + FilePath localPath = FilePath(impl.getActiveSourceUri()).getParentPath(); + _sourceFilename = context.resolveSourceFile(impl.getAttribute("file"), localPath); + _functionSource = readFile(_sourceFilename); + if (_functionSource.empty()) + { + throw ExceptionShaderGenError("Failed to get source code from file '" + _sourceFilename.asString() + + "' used by implementation '" + impl.getName() + "'"); + } +} + void SourceCodeNode::initialize(const InterfaceElement& element, GenContext& context) { ShaderNodeImpl::initialize(element, context); @@ -40,19 +54,13 @@ void SourceCodeNode::initialize(const InterfaceElement& element, GenContext& con _functionSource = impl.getAttribute("sourcecode"); if (_functionSource.empty()) { - FilePath localPath = FilePath(impl.getActiveSourceUri()).getParentPath(); - _sourceFilename = context.resolveSourceFile(impl.getAttribute("file"), localPath); - _functionSource = readFile(_sourceFilename); - if (_functionSource.empty()) - { - throw ExceptionShaderGenError("Failed to get source code from file '" + _sourceFilename.asString() + - "' used by implementation '" + impl.getName() + "'"); - } + resolveSourceCode(element, context); } // Find the function name to use // If no function is given the source will be inlined. _functionName = impl.getAttribute("function"); + _inlined = _functionName.empty(); if (!_inlined) { diff --git a/source/MaterialXGenShader/Nodes/SourceCodeNode.h b/source/MaterialXGenShader/Nodes/SourceCodeNode.h index a208185a8d..6169dc0f61 100644 --- a/source/MaterialXGenShader/Nodes/SourceCodeNode.h +++ b/source/MaterialXGenShader/Nodes/SourceCodeNode.h @@ -26,6 +26,9 @@ class MX_GENSHADER_API SourceCodeNode : public ShaderNodeImpl void emitFunctionCall(const ShaderNode& node, GenContext& context, ShaderStage& stage) const override; protected: + /// Resolve the source file and read the source code during the initialization of the node. + virtual void resolveSourceCode(const InterfaceElement& element, GenContext& context); + bool _inlined; string _functionName; string _functionSource; diff --git a/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp b/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp index 4bec0016de..610b237163 100644 --- a/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp +++ b/source/MaterialXTest/MaterialXGenMdl/GenMdl.cpp @@ -270,7 +270,7 @@ void MdlShaderGeneratorTester::compileSource(const std::vector& so CHECK(returnValue == 0); } - if (!renderExec.empty()) // render if renderer is availabe + if (!renderExec.empty()) // render if renderer is available { std::string renderCommand = renderExec;