From 93c9c449043bbb4617781fb33252ac9555e7005b Mon Sep 17 00:00:00 2001 From: Rudrakh Panigrahi Date: Wed, 29 Jan 2025 00:27:06 +0530 Subject: [PATCH] feat: implement Lua EnvoyExtensionPolicy Signed-off-by: Rudrakh Panigrahi --- go.mod | 1 + go.sum | 2 + internal/gatewayapi/envoyextensionpolicy.go | 81 +++ .../gatewayapi/luavalidator/lua_validator.go | 59 ++ .../luavalidator/lua_validator_test.go | 59 ++ internal/gatewayapi/luavalidator/mocks.lua | 577 ++++++++++++++++++ internal/ir/xds.go | 12 + internal/ir/zz_generated.deepcopy.go | 27 + internal/xds/translator/lua.go | 119 ++++ 9 files changed, 937 insertions(+) create mode 100644 internal/gatewayapi/luavalidator/lua_validator.go create mode 100644 internal/gatewayapi/luavalidator/lua_validator_test.go create mode 100644 internal/gatewayapi/luavalidator/mocks.lua create mode 100644 internal/xds/translator/lua.go diff --git a/go.mod b/go.mod index 1993b7d139a..9df5746fc7c 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/telepresenceio/watchable v0.0.0-20220726211108-9bb86f92afa7 github.com/tetratelabs/func-e v1.1.5-0.20240822223546-c85a098d5bf0 github.com/tsaarni/certyaml v0.10.0 + github.com/yuin/gopher-lua v1.1.1 go.opentelemetry.io/otel v1.34.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.34.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.34.0 diff --git a/go.sum b/go.sum index 35e6872946c..54cb952f496 100644 --- a/go.sum +++ b/go.sum @@ -865,6 +865,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= diff --git a/internal/gatewayapi/envoyextensionpolicy.go b/internal/gatewayapi/envoyextensionpolicy.go index 145662e5480..e07cf956578 100644 --- a/internal/gatewayapi/envoyextensionpolicy.go +++ b/internal/gatewayapi/envoyextensionpolicy.go @@ -8,6 +8,8 @@ package gatewayapi import ( "errors" "fmt" + "github.com/envoyproxy/gateway/internal/gatewayapi/luavalidator" + gwapiv1 "sigs.k8s.io/gateway-api/apis/v1" "sort" "strconv" "strings" @@ -352,6 +354,7 @@ func (t *Translator) translateEnvoyExtensionPolicyForGateway( var ( extProcs []ir.ExtProc wasms []ir.Wasm + luas []ir.Lua err, errs error ) @@ -363,6 +366,10 @@ func (t *Translator) translateEnvoyExtensionPolicyForGateway( err = perr.WithMessage(err, "Wasm") errs = errors.Join(errs, err) } + if luas, err = t.buildLuas(policy, resources); err != nil { + err = perr.WithMessage(err, "Lua") + errs = errors.Join(errs, err) + } irKey := t.getIRKey(gateway.Gateway) // Should exist since we've validated this @@ -395,6 +402,7 @@ func (t *Translator) translateEnvoyExtensionPolicyForGateway( r.EnvoyExtensions = &ir.EnvoyExtensionFeatures{ ExtProcs: extProcs, Wasms: wasms, + Luas: luas, } } } @@ -402,6 +410,72 @@ func (t *Translator) translateEnvoyExtensionPolicyForGateway( return errs } +func (t *Translator) buildLuas(policy *egv1a1.EnvoyExtensionPolicy, resources *resource.Resources) ([]ir.Lua, error) { + var luaIRList []ir.Lua + + if policy == nil { + return nil, nil + } + + for idx, ep := range policy.Spec.Lua { + name := irConfigNameForLua(policy, idx) + luaIR, err := t.buildLua(name, policy, ep, resources) + if err != nil { + return nil, err + } + luaIRList = append(luaIRList, *luaIR) + } + return luaIRList, nil +} + +func (t *Translator) buildLua( + name string, + policy *egv1a1.EnvoyExtensionPolicy, + lua egv1a1.Lua, + resources *resource.Resources, +) (*ir.Lua, error) { + var luaBody *string + var err error + if lua.Type == egv1a1.LuaValueTypeValueRef { + luaBody, err = getLuaBodyFromLocalObjectReference(lua.ValueRef, resources, policy.Namespace) + } else { + luaBody = lua.Inline + } + if err != nil { + return nil, err + } + if err = luavalidator.NewLuaValidator(*luaBody).Validate(); err != nil { + return nil, fmt.Errorf("validation failed for lua body in policy with name %v: %v", name, err) + } + return &ir.Lua{ + Name: name, + Body: luaBody, + }, nil +} + +// getLuaBodyFromLocalObjectReference assumes the local object reference points to a Kubernetes ConfigMap +func getLuaBodyFromLocalObjectReference(valueRef *gwapiv1.LocalObjectReference, resources *resource.Resources, policyNs string) (*string, error) { + cm := resources.GetConfigMap(policyNs, string(valueRef.Name)) + if cm != nil { + b, dataOk := cm.Data["lua"] + switch { + case dataOk: + return &b, nil + case len(cm.Data) > 0: // Fallback to the first key if lua is not found + for _, value := range cm.Data { + b = value + break + } + return &b, nil + default: + return nil, fmt.Errorf("can't find the key lua in the referenced configmap %s", valueRef.Name) + } + + } else { + return nil, fmt.Errorf("can't find the referenced configmap %s", valueRef.Name) + } +} + func (t *Translator) buildExtProcs(policy *egv1a1.EnvoyExtensionPolicy, resources *resource.Resources, envoyProxy *egv1a1.EnvoyProxy) ([]ir.ExtProc, error) { var extProcIRList []ir.ExtProc @@ -522,6 +596,13 @@ func irConfigNameForExtProc(policy *egv1a1.EnvoyExtensionPolicy, index int) stri strconv.Itoa(index)) } +func irConfigNameForLua(policy *egv1a1.EnvoyExtensionPolicy, index int) string { + return fmt.Sprintf( + "%s/lua/%s", + irConfigName(policy), + strconv.Itoa(index)) +} + func (t *Translator) buildWasms( policy *egv1a1.EnvoyExtensionPolicy, resources *resource.Resources, diff --git a/internal/gatewayapi/luavalidator/lua_validator.go b/internal/gatewayapi/luavalidator/lua_validator.go new file mode 100644 index 00000000000..79820d1d696 --- /dev/null +++ b/internal/gatewayapi/luavalidator/lua_validator.go @@ -0,0 +1,59 @@ +// Copyright Envoy Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package luavalidator + +import ( + _ "embed" + "fmt" + lua "github.com/yuin/gopher-lua" + "strings" +) + +// mockData contains mocks of Envoy supported APIs for Lua filters. +// Refer: https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_filters/lua_filter#stream-handle-api +// +//go:embed mocks.lua +var mockData []byte + +// LuaValidator validates user provided Lua for compatibility with Envoy supported Lua HTTP filter +type LuaValidator struct { + body string +} + +// NewLuaValidator returns a LuaValidator for user provided Lua body +func NewLuaValidator(body string) *LuaValidator { + return &LuaValidator{ + body: body, + } +} + +// Validate runs all validations for the LuaValidator +func (l *LuaValidator) Validate() error { + if !strings.Contains(l.body, "envoy_on_request") && !strings.Contains(l.body, "envoy_on_response") { + return fmt.Errorf("expected one of envoy_on_request() or envoy_on_response() to be defined") + } + if strings.Contains(l.body, "envoy_on_request") { + if err := l.validate(string(mockData) + "\n" + l.body + "\nenvoy_on_request(StreamHandle)"); err != nil { + return fmt.Errorf("failed to mock run envoy_on_request: %v", err) + } + } + if strings.Contains(l.body, "envoy_on_response") { + if err := l.validate(string(mockData) + "\n" + l.body + "\nenvoy_on_response(StreamHandle)"); err != nil { + return fmt.Errorf("failed to mock run envoy_on_response: %v", err) + } + } + return nil +} + +// validate interprets and runs the provided Lua body in runtime +func (l *LuaValidator) validate(body string) error { + L := lua.NewState() + defer L.Close() + if err := L.DoString(body); err != nil { + return err + } + return nil +} diff --git a/internal/gatewayapi/luavalidator/lua_validator_test.go b/internal/gatewayapi/luavalidator/lua_validator_test.go new file mode 100644 index 00000000000..d2b0b5e52f8 --- /dev/null +++ b/internal/gatewayapi/luavalidator/lua_validator_test.go @@ -0,0 +1,59 @@ +// Copyright Envoy Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package luavalidator + +import ( + "strings" + "testing" +) + +func Test_Validate(t *testing.T) { + type args struct { + name string + body string + expectedErrSubstring string + } + tests := []args{ + { + name: "empty body", + body: "", + expectedErrSubstring: "expected one of envoy_on_request() or envoy_on_response() to be defined", + }, + { + name: "logInfo: envoy_on_response", + body: `function envoy_on_response(response_handle) + response_handle:logInfo("Goodbye.") + end`, + expectedErrSubstring: "", + }, + { + name: "logInfo: envoy_on_request", + body: `function envoy_on_request(request_handle) + request_handle:logInfo("Goodbye.") + end`, + expectedErrSubstring: "", + }, + { + name: "unknown api call", + body: `function envoy_on_request(request_handle) + request_handle:unknownApi() + end`, + expectedErrSubstring: "attempt to call a non-function object", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := NewLuaValidator(tt.body) + if err := l.Validate(); err != nil && tt.expectedErrSubstring == "" { + t.Errorf("Unexpected error: %v", err) + } else if err != nil && !strings.Contains(err.Error(), tt.expectedErrSubstring) { + t.Errorf("Expected substring in error: %v, got error: %v", tt.expectedErrSubstring, err) + } else if err == nil && tt.expectedErrSubstring != "" { + t.Errorf("Expected error with substring: %v", tt.expectedErrSubstring) + } + }) + } +} diff --git a/internal/gatewayapi/luavalidator/mocks.lua b/internal/gatewayapi/luavalidator/mocks.lua new file mode 100644 index 00000000000..2bce8266895 --- /dev/null +++ b/internal/gatewayapi/luavalidator/mocks.lua @@ -0,0 +1,577 @@ +-- ParsedName object. +local ParsedName = {} + +-- Logging methods. +function ParsedName:logTrace(message) + print("[TRACE]: " .. message) +end + +function ParsedName:logDebug(message) + print("[DEBUG]: " .. message) +end + +function ParsedName:logInfo(message) + print("[INFO]: " .. message) +end + +function ParsedName:logWarn(message) + print("[WARN]: " .. message) +end + +function ParsedName:logErr(message) + print("[ERROR]: " .. message) +end + +function ParsedName:logCritical(message) + print("[CRITICAL]: " .. message) +end + +-- Method to return the common name. +function ParsedName:commonName() + return "common_name" +end + +-- Method to return the organization name. +function ParsedName:organizationName() + return "organization_name" +end + +local SSLConnectionObject = {} + +-- Logging methods +function SSLConnectionObject:logTrace(message) + print("[TRACE]: " .. message) +end + +function SSLConnectionObject:logDebug(message) + print("[DEBUG]: " .. message) +end + +function SSLConnectionObject:logInfo(message) + print("[INFO]: " .. message) +end + +function SSLConnectionObject:logWarn(message) + print("[WARN]: " .. message) +end + +function SSLConnectionObject:logErr(message) + print("[ERROR]: " .. message) +end + +function SSLConnectionObject:logCritical(message) + print("[CRITICAL]: " .. message) +end + +-- SSL-related methods +function SSLConnectionObject:peerCertificatePresented() + return true -- Simulate that the peer certificate is always presented +end + +function SSLConnectionObject:peerCertificateValidated() + return false -- Simulate no validation for TLS session resumption +end + +function SSLConnectionObject:uriSanLocalCertificate() + local t = {} + for i = 1, 1000 do + t[i] = "san" + end + return t +end + +function SSLConnectionObject:sha256PeerCertificateDigest() + return "abcdef1234567890" -- Simulated SHA256 digest +end + +function SSLConnectionObject:serialNumberPeerCertificate() + return "123456789" -- Simulated serial number +end + +function SSLConnectionObject:issuerPeerCertificate() + return "CN=Mock Issuer, O=Example Org" -- Simulated issuer +end + +function SSLConnectionObject:subjectPeerCertificate() + return "CN=Mock Subject, O=Example Org" -- Simulated subject +end + +function SSLConnectionObject:parsedSubjectPeerCertificate() + return { + commonName = function() return "Mock Subject" end, + organizationName = function() return "Example Org" end, + } +end + +function SSLConnectionObject:uriSanPeerCertificate() + return {"uri1", "uri2"} -- Example SAN URIs for peer certificate +end + +function SSLConnectionObject:subjectLocalCertificate() + return "CN=Local Cert, O=Example Org" -- Simulated local certificate subject +end + +function SSLConnectionObject:urlEncodedPemEncodedPeerCertificate() + return "PEM_ENCODED_PEER_CERT" -- Simulated PEM encoding +end + +function SSLConnectionObject:urlEncodedPemEncodedPeerCertificateChain() + return "PEM_ENCODED_CERT_CHAIN" -- Simulated PEM chain +end + +function SSLConnectionObject:dnsSansPeerCertificate() + local t = {} + for i = 1, 1000 do + t[i] = "example.com" + end + return t +end + +function SSLConnectionObject:dnsSansLocalCertificate() + local t = {} + for i = 1, 1000 do + t[i] = "local.com" + end + return t +end + +function SSLConnectionObject:oidsPeerCertificate() + local t = {} + for i = 1, 1000 do + t[i] = "1.2.840.113549.1.1.1" + end + return t +end + +function SSLConnectionObject:oidsLocalCertificate() + local t = {} + for i = 1, 1000 do + t[i] = "1.2.840.113549.1.1.1" + end + return t +end + +function SSLConnectionObject:validFromPeerCertificate() + return 1672531200 +end + +function SSLConnectionObject:expirationPeerCertificate() + return 2724608000 +end + +function SSLConnectionObject:sessionId() + return "abcdef1234567890abcdef1234567890" +end + +function SSLConnectionObject:ciphersuiteId() + return "0x1301" +end + +function SSLConnectionObject:ciphersuiteString() + return "TLS_AES_128_GCM_SHA256" +end + +function SSLConnectionObject:tlsVersion() + return "TLSv1.3" +end + +-- Connection object +local Connection = {} + +-- Logging methods for Connection +function Connection:logTrace(message) + print("[TRACE]: " .. message) +end + +function Connection:logDebug(message) + print("[DEBUG]: " .. message) +end + +function Connection:logInfo(message) + print("[INFO]: " .. message) +end + +function Connection:logWarn(message) + print("[WARN]: " .. message) +end + +function Connection:logErr(message) + print("[ERROR]: " .. message) +end + +function Connection:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function Connection:ssl() + return SSLConnectionObject +end + +local DynamicMetadata = {} + +-- Logging methods +function DynamicMetadata:logTrace(message) + print("[TRACE]: " .. message) +end + +function DynamicMetadata:logDebug(message) + print("[DEBUG]: " .. message) +end + +function DynamicMetadata:logInfo(message) + print("[INFO]: " .. message) +end + +function DynamicMetadata:logWarn(message) + print("[WARN]: " .. message) +end + +function DynamicMetadata:logErr(message) + print("[ERROR]: " .. message) +end + +function DynamicMetadata:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function DynamicMetadata:get(filterName) + return "get_result" +end + +function DynamicMetadata:set(filterName, key, value) +end + +function DynamicMetadata:__pairs() + return function(tbl, prevKey) + return nil, nil + end, self, nil +end + +-- ConnectionStreamInfo Object +local ConnectionStreamInfo = {} + +-- Logging methods for ConnectionStreamInfo +function ConnectionStreamInfo:logTrace(message) + print("[TRACE]: " .. message) +end + +function ConnectionStreamInfo:logDebug(message) + print("[DEBUG]: " .. message) +end + +function ConnectionStreamInfo:logInfo(message) + print("[INFO]: " .. message) +end + +function ConnectionStreamInfo:logWarn(message) + print("[WARN]: " .. message) +end + +function ConnectionStreamInfo:logErr(message) + print("[ERROR]: " .. message) +end + +function ConnectionStreamInfo:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function ConnectionStreamInfo:dynamicMetadata() + return DynamicMetadata +end + +-- StreamInfo Object +local StreamInfo = {} + +function StreamInfo:logTrace(message) + print("[TRACE]: " .. message) +end + +function StreamInfo:logDebug(message) + print("[DEBUG]: " .. message) +end + +function StreamInfo:logInfo(message) + print("[INFO]: " .. message) +end + +function StreamInfo:logWarn(message) + print("[WARN]: " .. message) +end + +function StreamInfo:logErr(message) + print("[ERROR]: " .. message) +end + +function StreamInfo:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function StreamInfo:protocol() + return "HTTP/2" +end + +function StreamInfo:routeName() + return "example-route" +end + +function StreamInfo:virtualClusterName() + return "example-virtual-cluster" +end + +function StreamInfo:downstreamDirectLocalAddress() + return "127.0.0.1:8080" +end + +function StreamInfo:downstreamLocalAddress() + return "192.168.1.100:8080" +end + +function StreamInfo:downstreamDirectRemoteAddress() + return "192.168.1.50:5050" +end + +function StreamInfo:downstreamRemoteAddress() + return "10.0.0.1:5050" +end + +function StreamInfo:dynamicMetadata() + return DynamicMetadata +end + +function StreamInfo:downstreamSslConnection() + return SSLConnectionObject +end + +function StreamInfo:requestedServerName() + return "example-server" +end + +-- Metadata Object +local Metadata = {} + +-- Logging Methods +function Metadata:logTrace(message) + print("[TRACE]: " .. message) +end + +function Metadata:logDebug(message) + print("[DEBUG]: " .. message) +end + +function Metadata:logInfo(message) + print("[INFO]: " .. message) +end + +function Metadata:logWarn(message) + print("[WARN]: " .. message) +end + +function Metadata:logErr(message) + print("[ERROR]: " .. message) +end + +function Metadata:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function Metadata:get(key) + return "get_example" +end + +function Metadata:__pairs() + return function(tbl, prevKey) + return nil, nil + end, self, nil +end + +-- Buffer Object +local Buffer = {} + +-- Logging Methods +function Buffer:logTrace(message) + print("[TRACE]: " .. message) +end + +function Buffer:logDebug(message) + print("[DEBUG]: " .. message) +end + +function Buffer:logInfo(message) + print("[INFO]: " .. message) +end + +function Buffer:logWarn(message) + print("[WARN]: " .. message) +end + +function Buffer:logErr(message) + print("[ERROR]: " .. message) +end + +function Buffer:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function Buffer:length() + return 10 +end + +function Buffer:getBytes(index, length) + return "example_bytes" +end + +-- Headers Object +local Headers = {} + +-- Logging Methods +function Headers:logTrace(message) + print("[TRACE]: " .. message) +end + +function Headers:logDebug(message) + print("[DEBUG]: " .. message) +end + +function Headers:logInfo(message) + print("[INFO]: " .. message) +end + +function Headers:logWarn(message) + print("[WARN]: " .. message) +end + +function Headers:logErr(message) + print("[ERROR]: " .. message) +end + +function Headers:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function Headers:add(key, value) +end + +function Headers:get(key) + return "example_value" +end + +function Headers:getAtIndex(key, index) + return "example_value" +end + +function Headers:getNumValues(key) + return 1 +end + +function Headers:__pairs() + return function(tbl, prevKey) + return nil, nil + end, self, nil +end + +function Headers:remove(key) +end + +function Headers:replace(key, value) +end + +function Headers:setHttp1ReasonPhrase(reasonPhrase) +end + +function Headers:getHttp1ReasonPhrase() + return "reason" +end + +-- StreamHandle Object +local StreamHandle = {} + +-- Logging Methods (with internal logging mechanism) +function StreamHandle:logTrace(message) + print("[TRACE]: " .. message) +end + +function StreamHandle:logDebug(message) + print("[DEBUG]: " .. message) +end + +function StreamHandle:logInfo(message) + print("[INFO]: " .. message) +end + +function StreamHandle:logWarn(message) + print("[WARN]: " .. message) +end + +function StreamHandle:logErr(message) + print("[ERROR]: " .. message) +end + +function StreamHandle:logCritical(message) + print("[CRITICAL]: " .. message) +end + +function StreamHandle:headers() + return Headers +end + +function StreamHandle:body(always_wrap_body) + return "body" +end + +function StreamHandle:bodyChunks() + return function(tbl, prevKey) + return nil, nil + end, self, nil +end + +function StreamHandle:trailers() + return Headers +end + +function StreamHandle:httpCall(cluster, headers, body, timeout_ms, asynchronous) + self:logInfo("Making HTTP call to cluster: " .. cluster) + return headers, body +end + +function StreamHandle:respond(headers, body) + self:logInfo("Responding with status: " .. headers[":status"]) +end + +function StreamHandle:metadata() + return Metadata +end + +function StreamHandle:streamInfo() + return StreamInfo +end + +function StreamHandle:connection() + return Connection +end + +function StreamHandle:connectionStreamInfo() + return ConnectionStreamInfo +end + +function StreamHandle:setUpstreamOverrideHost(host, strict) +end + +function StreamHandle:importPublicKey(keyder, keyderLength) + return "mocked_public_key" +end + +function StreamHandle:verifySignature(hashFunction, pubkey, signature, signatureLength, data, dataLength) + return true, "" +end + +function StreamHandle:base64Escape(inputString) + return inputString +end + +function StreamHandle:timestamp(format) + return os.time() * 1000 +end + +function StreamHandle:timestampString(resolution) + return os.date("!%Y-%m-%dT%H:%M:%SZ", os.time()) +end \ No newline at end of file diff --git a/internal/ir/xds.go b/internal/ir/xds.go index f428297ace5..ca1f0ebdb66 100644 --- a/internal/ir/xds.go +++ b/internal/ir/xds.go @@ -862,6 +862,8 @@ type EnvoyExtensionFeatures struct { ExtProcs []ExtProc `json:"extProcs,omitempty" yaml:"extProcs,omitempty"` // Wasm extensions Wasms []Wasm `json:"wasms,omitempty" yaml:"wasms,omitempty"` + // Lua extensions + Luas []Lua `json:"luas,omitempty" yaml:"luas,omitempty"` } // UnstructuredRef holds unstructured data for an arbitrary k8s resource introduced by an extension @@ -2792,6 +2794,16 @@ type ExtProc struct { AllowModeOverride bool `json:"allowModeOverride,omitempty" yaml:"allowModeOverride,omitempty"` } +// Lua holds the information associated with Lua extensions +// +k8s:deepcopy-gen=true +type Lua struct { + // Name is a unique name for the LUa configuration. + // The xds translator only generates one Lua filter for each unique name + Name string + // Body is the Lua source body + Body *string +} + // Wasm holds the information associated with the Wasm extensions. // +k8s:deepcopy-gen=true type Wasm struct { diff --git a/internal/ir/zz_generated.deepcopy.go b/internal/ir/zz_generated.deepcopy.go index 3c42375daa5..541f28f750b 100644 --- a/internal/ir/zz_generated.deepcopy.go +++ b/internal/ir/zz_generated.deepcopy.go @@ -888,6 +888,13 @@ func (in *EnvoyExtensionFeatures) DeepCopyInto(out *EnvoyExtensionFeatures) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.Luas != nil { + in, out := &in.Luas, &out.Luas + *out = make([]Lua, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new EnvoyExtensionFeatures. @@ -2132,6 +2139,26 @@ func (in *LocalRateLimit) DeepCopy() *LocalRateLimit { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Lua) DeepCopyInto(out *Lua) { + *out = *in + if in.Body != nil { + in, out := &in.Body, &out.Body + *out = new(string) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Lua. +func (in *Lua) DeepCopy() *Lua { + if in == nil { + return nil + } + out := new(Lua) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Metrics) DeepCopyInto(out *Metrics) { *out = *in diff --git a/internal/xds/translator/lua.go b/internal/xds/translator/lua.go new file mode 100644 index 00000000000..f30a50a26ab --- /dev/null +++ b/internal/xds/translator/lua.go @@ -0,0 +1,119 @@ +// Copyright Envoy Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package translator + +import ( + "errors" + + egv1a1 "github.com/envoyproxy/gateway/api/v1alpha1" + "github.com/envoyproxy/gateway/internal/ir" + "github.com/envoyproxy/gateway/internal/xds/types" + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + luafilterv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/lua/v3" + hcmv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + "google.golang.org/protobuf/types/known/anypb" +) + +func init() { + registerHTTPFilter(&lua{}) +} + +type lua struct{} + +var _ httpFilter = &lua{} + +// patchHCM builds and appends the lua Filters to the HTTP Connection Manager +// if applicable, and it does not already exist. +// Note: this method creates a lua filter for each route that contains a lua config. +func (*lua) patchHCM(mgr *hcmv3.HttpConnectionManager, irListener *ir.HTTPListener) error { + var errs error + + if mgr == nil { + return errors.New("hcm is nil") + } + if irListener == nil { + return errors.New("ir listener is nil") + } + + for _, route := range irListener.Routes { + if !routeContainsLua(route) { + continue + } + for _, ep := range route.EnvoyExtensions.Luas { + if hcmContainsFilter(mgr, luaFilterName(ep)) { + continue + } + filter, err := buildHCMLuaPerRouteFilter(ep) + if err != nil { + errs = errors.Join(errs, err) + continue + } + mgr.HttpFilters = append(mgr.HttpFilters, filter) + } + } + return errs +} + +// buildHCMLuaPerRouteFilter returns a lua HTTP filter from the provided IR HTTPRoute. +func buildHCMLuaPerRouteFilter(lua ir.Lua) (*hcmv3.HttpFilter, error) { + var ( + luaProto *luafilterv3.LuaPerRoute + luaAny *anypb.Any + err error + ) + + if luaProto, err = luaPerRouteConfig(lua); err != nil { + return nil, err + } + if err = luaProto.ValidateAll(); err != nil { + return nil, err + } + if luaAny, err = anypb.New(luaProto); err != nil { + return nil, err + } + + return &hcmv3.HttpFilter{ + Name: luaFilterName(lua), + ConfigType: &hcmv3.HttpFilter_TypedConfig{ + TypedConfig: luaAny, + }, + }, nil +} + +func luaFilterName(lua ir.Lua) string { + return perRouteFilterName(egv1a1.EnvoyFilterWasm, lua.Name) +} + +func luaPerRouteConfig(lua ir.Lua) (*luafilterv3.LuaPerRoute, error) { + return &luafilterv3.LuaPerRoute{ + Override: &luafilterv3.LuaPerRoute_SourceCode{ + SourceCode: &corev3.DataSource{ + Specifier: &corev3.DataSource_InlineString{ + InlineString: *lua.Body, + }, + }, + }, + }, nil +} + +// routeContainsLua returns true if Luas exists for the provided route. +func routeContainsLua(irRoute *ir.HTTPRoute) bool { + if irRoute == nil { + return false + } + return irRoute.EnvoyExtensions != nil && len(irRoute.EnvoyExtensions.Luas) > 0 +} + +// patchResources patches the cluster resources for the http lua code source. +func (*lua) patchResources(_ *types.ResourceVersionTable, _ []*ir.HTTPRoute) error { + return nil +} + +// patchRoute patches the provided route with the lua config if applicable. +func (*lua) patchRoute(route *routev3.Route, irRoute *ir.HTTPRoute) error { + return nil +}