Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: move message validation logic to msg_server.go #2473

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 234 additions & 2 deletions x/ccv/provider/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

errorsmod "cosmossdk.io/errors"
"cosmossdk.io/math"

cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand All @@ -19,6 +20,28 @@ import (
ccvtypes "github.com/cosmos/interchain-security/v6/x/ccv/types"
)

// validateDeprecatedChainId validates that the chain ID is not provided (deprecated field)
func validateDeprecatedChainId(chainId string) error {
if chainId != "" {
return fmt.Errorf("chain ID is deprecated, use consumer ID instead")
}
return nil
}

// validateProviderAddress validates that the provider address matches the signer
func validateProviderAddress(addr, signer string) error {
if addr == "" {
return fmt.Errorf("empty provider address")
}
if signer == "" {
return fmt.Errorf("empty signer address")
}
if addr != signer {
return fmt.Errorf("provider address %s does not match signer %s", addr, signer)
}
return nil
}

type msgServer struct {
*Keeper
}
Expand Down Expand Up @@ -52,6 +75,26 @@ func (k msgServer) UpdateParams(goCtx context.Context, msg *types.MsgUpdateParam
func (k msgServer) AssignConsumerKey(goCtx context.Context, msg *types.MsgAssignConsumerKey) (*types.MsgAssignConsumerKeyResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ProviderAddr: %s", err.Error())
}

if msg.ConsumerKey == "" {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ConsumerKey cannot be empty")
}
if _, _, err := types.ParseConsumerKeyFromJson(msg.ConsumerKey); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgAssignConsumerKey, "ConsumerKey: %s", err.Error())
}

providerValidatorAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -109,6 +152,34 @@ func (k msgServer) ChangeRewardDenoms(goCtx context.Context, msg *types.MsgChang
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

// Validate basic message properties
emptyDenomsToAdd := len(msg.DenomsToAdd) == 0
emptyDenomsToRemove := len(msg.DenomsToRemove) == 0
// Return error if both sets are empty or nil
if emptyDenomsToAdd && emptyDenomsToRemove {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms, "both DenomsToAdd and DenomsToRemove are empty")
}

denomMap := map[string]struct{}{}
for _, denom := range msg.DenomsToAdd {
// validate the denom
if !sdk.NewCoin(denom, math.NewInt(1)).IsValid() {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms, "DenomsToAdd: invalid denom(%s)", denom)
}
denomMap[denom] = struct{}{}
}
for _, denom := range msg.DenomsToRemove {
// validate the denom
if !sdk.NewCoin(denom, math.NewInt(1)).IsValid() {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms, "DenomsToRemove: invalid denom(%s)", denom)
}
// denom cannot be in both sets
if _, found := denomMap[denom]; found {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgChangeRewardDenoms,
"denom(%s) cannot be both added and removed", denom)
}
}

eventAttributes := k.Keeper.ChangeRewardDenoms(ctx, msg.DenomsToAdd, msg.DenomsToRemove)

ctx.EventManager().EmitEvent(
Expand All @@ -123,6 +194,16 @@ func (k msgServer) ChangeRewardDenoms(goCtx context.Context, msg *types.MsgChang

func (k msgServer) SubmitConsumerMisbehaviour(goCtx context.Context, msg *types.MsgSubmitConsumerMisbehaviour) (*types.MsgSubmitConsumerMisbehaviourResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerMisbehaviour, "ConsumerId: %s", err.Error())
}

if err := msg.Misbehaviour.ValidateBasic(); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerMisbehaviour, "Misbehaviour: %s", err.Error())
}

if err := k.Keeper.HandleConsumerMisbehaviour(ctx, msg.ConsumerId, *msg.Misbehaviour); err != nil {
return nil, err
}
Expand All @@ -147,6 +228,23 @@ func (k msgServer) SubmitConsumerMisbehaviour(goCtx context.Context, msg *types.
func (k msgServer) SubmitConsumerDoubleVoting(goCtx context.Context, msg *types.MsgSubmitConsumerDoubleVoting) (*types.MsgSubmitConsumerDoubleVotingResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if dve, err := tmtypes.DuplicateVoteEvidenceFromProto(msg.DuplicateVoteEvidence); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "DuplicateVoteEvidence: %s", err.Error())
} else {
if err = dve.ValidateBasic(); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "DuplicateVoteEvidence: %s", err.Error())
}
}

if err := types.ValidateHeaderForConsumerDoubleVoting(msg.InfractionBlockHeader); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "ValidateTendermintHeader: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSubmitConsumerDoubleVoting, "ConsumerId: %s", err.Error())
}

evidence, err := tmtypes.DuplicateVoteEvidenceFromProto(msg.DuplicateVoteEvidence)
if err != nil {
return nil, err
Expand Down Expand Up @@ -198,6 +296,25 @@ func (k msgServer) SubmitConsumerDoubleVoting(goCtx context.Context, msg *types.
func (k msgServer) OptIn(goCtx context.Context, msg *types.MsgOptIn) (*types.MsgOptInResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ProviderAddr: %s", err.Error())
}

if msg.ConsumerKey != "" {
if _, _, err := types.ParseConsumerKeyFromJson(msg.ConsumerKey); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptIn, "ConsumerKey: %s", err.Error())
}
}

valAddress, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -250,6 +367,19 @@ func (k msgServer) OptIn(goCtx context.Context, msg *types.MsgOptIn) (*types.Msg
func (k msgServer) OptOut(goCtx context.Context, msg *types.MsgOptOut) (*types.MsgOptOutResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptOut, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptOut, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgOptOut, "ProviderAddr: %s", err.Error())
}

valAddress, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -300,13 +430,30 @@ func (k msgServer) OptOut(goCtx context.Context, msg *types.MsgOptOut) (*types.M
func (k msgServer) SetConsumerCommissionRate(goCtx context.Context, msg *types.MsgSetConsumerCommissionRate) (*types.MsgSetConsumerCommissionRateResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

providerValidatorAddr, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
// Validate basic message properties
if err := validateDeprecatedChainId(msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "ChainId: %s", err.Error())
}

if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "ConsumerId: %s", err.Error())
}

if err := validateProviderAddress(msg.ProviderAddr, msg.Signer); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "ProviderAddr: %s", err.Error())
}

if !msg.Commission.IsPositive() || msg.Commission.GT(math.LegacyOneDec()) {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgSetConsumerCommissionRate, "commission must be between 0 and 1")
}

valAddress, err := sdk.ValAddressFromBech32(msg.ProviderAddr)
if err != nil {
return nil, err
}

// validator must already be registered
validator, err := k.stakingKeeper.GetValidator(ctx, providerValidatorAddr)
validator, err := k.stakingKeeper.GetValidator(ctx, valAddress)
if err != nil {
return nil, stakingtypes.ErrNoValidatorFound
}
Expand Down Expand Up @@ -350,6 +497,36 @@ func (k msgServer) SetConsumerCommissionRate(goCtx context.Context, msg *types.M
// CreateConsumer creates a consumer chain
func (k msgServer) CreateConsumer(goCtx context.Context, msg *types.MsgCreateConsumer) (*types.MsgCreateConsumerResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := types.ValidateChainId("ChainId", msg.ChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "ChainId: %s", err.Error())
}

if err := types.ValidateConsumerMetadata(msg.Metadata); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "Metadata: %s", err.Error())
}

if err := types.ValidateInitializationParameters(*msg.InitializationParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "InitializationParameters: %s", err.Error())
}

if err := types.ValidatePowerShapingParameters(*msg.PowerShapingParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "PowerShapingParameters: %s", err.Error())
}

if err := types.ValidateAllowlistedRewardDenoms(*msg.AllowlistedRewardDenoms); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "AllowlistedRewardDenoms: %s", err.Error())
}

if err := types.ValidateInfractionParameters(*msg.InfractionParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgCreateConsumer, "InfractionParameters: %s", err.Error())
}

if k.GetAuthority() != msg.Authority {
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

resp := types.MsgCreateConsumerResponse{}

// initialize an empty slice to store event attributes
Expand Down Expand Up @@ -470,6 +647,52 @@ func (k msgServer) CreateConsumer(goCtx context.Context, msg *types.MsgCreateCon
// UpdateConsumer updates the metadata, power-shaping or initialization parameters of a consumer chain
func (k msgServer) UpdateConsumer(goCtx context.Context, msg *types.MsgUpdateConsumer) (*types.MsgUpdateConsumerResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "ConsumerId: %s", err.Error())
}

if msg.Metadata != nil {
if err := types.ValidateConsumerMetadata(*msg.Metadata); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "Metadata: %s", err.Error())
}
}

if msg.InitializationParameters != nil {
if err := types.ValidateInitializationParameters(*msg.InitializationParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "InitializationParameters: %s", err.Error())
}
}

if msg.PowerShapingParameters != nil {
if err := types.ValidatePowerShapingParameters(*msg.PowerShapingParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "PowerShapingParameters: %s", err.Error())
}
}

if msg.AllowlistedRewardDenoms != nil {
if err := types.ValidateAllowlistedRewardDenoms(*msg.AllowlistedRewardDenoms); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "AllowlistedRewardDenoms: %s", err.Error())
}
}

if msg.InfractionParameters != nil {
if err := types.ValidateInfractionParameters(*msg.InfractionParameters); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "InfractionParameters: %s", err.Error())
}
}

if msg.NewChainId != "" {
if err := types.ValidateChainId("NewChainId", msg.NewChainId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgUpdateConsumer, "NewChainId: %s", err.Error())
}
}

if k.GetAuthority() != msg.Authority {
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

resp := types.MsgUpdateConsumerResponse{}

// initialize an empty slice to store event attributes
Expand Down Expand Up @@ -704,6 +927,15 @@ func (k msgServer) UpdateConsumer(goCtx context.Context, msg *types.MsgUpdateCon
func (k msgServer) RemoveConsumer(goCtx context.Context, msg *types.MsgRemoveConsumer) (*types.MsgRemoveConsumerResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

// Validate basic message properties
if err := ccvtypes.ValidateConsumerId(msg.ConsumerId); err != nil {
return nil, errorsmod.Wrapf(types.ErrInvalidMsgRemoveConsumer, "ConsumerId: %s", err.Error())
}

if k.GetAuthority() != msg.Authority {
return nil, errorsmod.Wrapf(types.ErrUnauthorized, "expected %s, got %s", k.GetAuthority(), msg.Authority)
}

resp := types.MsgRemoveConsumerResponse{}

consumerId := msg.ConsumerId
Expand Down
Loading