Skip to content

Commit

Permalink
feat: refactor StakingVault initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
loga4 committed Nov 29, 2024
1 parent 41ed2c7 commit 9b59268
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 52 deletions.
35 changes: 18 additions & 17 deletions contracts/0.8.25/vaults/StakingVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import {Versioned} from "../utils/Versioned.sol";

// TODO: extract interface and implement it

contract StakingVault is IStakingVault, IBeaconProxy, VaultBeaconChainDepositor, OwnableUpgradeable, Versioned {
contract StakingVault is IStakingVault, IBeaconProxy, VaultBeaconChainDepositor, OwnableUpgradeable {
/// @custom:storage-location erc7201:StakingVault.Vault
struct VaultStorage {
IStakingVault.Report report;
Expand All @@ -25,8 +25,7 @@ contract StakingVault is IStakingVault, IBeaconProxy, VaultBeaconChainDepositor,
int128 inOutDelta;
}

uint256 private constant _version = 1;
address private immutable _SELF;
uint64 private constant _version = 1;
VaultHub public immutable VAULT_HUB;

/// keccak256(abi.encode(uint256(keccak256("StakingVault.Vault")) - 1)) & ~bytes32(uint256(0xff));
Expand All @@ -39,32 +38,34 @@ contract StakingVault is IStakingVault, IBeaconProxy, VaultBeaconChainDepositor,
) VaultBeaconChainDepositor(_beaconChainDepositContract) {
if (_vaultHub == address(0)) revert ZeroArgument("_vaultHub");

_SELF = address(this);
VAULT_HUB = VaultHub(_vaultHub);

_disableInitializers();
}

modifier onlyBeacon() {
if (msg.sender != getBeacon()) revert UnauthorizedSender(msg.sender);
_;
}

/// @notice Initialize the contract storage explicitly.
/// The initialize function selector is not changed. For upgrades use `_params` variable
///
/// @param _owner owner address that can TBD
/// @param _owner vaultStaffRoom address
/// @param _params the calldata for initialize contract after upgrades
// solhint-disable-next-line no-unused-vars
function initialize(address _owner, bytes calldata _params) external {
if (_owner == address(0)) revert ZeroArgument("_owner");

if (address(this) == _SELF) {
revert NonProxyCallsForbidden();
}

_initializeContractVersionTo(1);

_transferOwnership(_owner);
function initialize(address _owner, bytes calldata _params) external onlyBeacon initializer {
__Ownable_init(_owner);
}

function version() public pure virtual returns(uint256) {
function version() public pure virtual returns(uint64) {
return _version;
}

function getInitializedVersion() public view returns (uint64) {
return _getInitializedVersion();
}

function getBeacon() public view returns (address) {
return ERC1967Utils.getBeacon();
}
Expand Down Expand Up @@ -228,5 +229,5 @@ contract StakingVault is IStakingVault, IBeaconProxy, VaultBeaconChainDepositor,
error NotHealthy();
error NotAuthorized(string operation, address sender);
error LockedCannotBeDecreased(uint256 locked);
error NonProxyCallsForbidden();
error UnauthorizedSender(address sender);
}
2 changes: 1 addition & 1 deletion contracts/0.8.25/vaults/interfaces/IBeaconProxy.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ pragma solidity 0.8.25;

interface IBeaconProxy {
function getBeacon() external view returns (address);
function version() external pure returns(uint256);
function version() external pure returns(uint64);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ import {IReportReceiver} from "contracts/0.8.25/vaults/interfaces/IReportReceive
import {IStakingVault} from "contracts/0.8.25/vaults/interfaces/IStakingVault.sol";
import {IBeaconProxy} from "contracts/0.8.25/vaults/interfaces/IBeaconProxy.sol";
import {VaultBeaconChainDepositor} from "contracts/0.8.25/vaults/VaultBeaconChainDepositor.sol";
import {Versioned} from "contracts/0.8.25/utils/Versioned.sol";

contract StakingVault__HarnessForTestUpgrade is IBeaconProxy, VaultBeaconChainDepositor, OwnableUpgradeable, Versioned {
contract StakingVault__HarnessForTestUpgrade is IBeaconProxy, VaultBeaconChainDepositor, OwnableUpgradeable {
/// @custom:storage-location erc7201:StakingVault.Vault
struct VaultStorage {
uint128 reportValuation;
Expand All @@ -25,7 +24,7 @@ contract StakingVault__HarnessForTestUpgrade is IBeaconProxy, VaultBeaconChainDe
int256 inOutDelta;
}

uint256 private constant _version = 2;
uint64 private constant _version = 2;
VaultHub public immutable vaultHub;

/// keccak256(abi.encode(uint256(keccak256("StakingVault.Vault")) - 1)) & ~bytes32(uint256(0xff));
Expand All @@ -41,25 +40,33 @@ contract StakingVault__HarnessForTestUpgrade is IBeaconProxy, VaultBeaconChainDe
vaultHub = VaultHub(_vaultHub);
}

modifier onlyBeacon() {
if (msg.sender != getBeacon()) revert UnauthorizedSender(msg.sender);
_;
}

/// @notice Initialize the contract storage explicitly.
/// @param _owner owner address that can TBD
/// @param _params the calldata for initialize contract after upgrades
function initialize(address _owner, bytes calldata _params) external {
if (_owner == address(0)) revert ZeroArgument("_owner");
if (getBeacon() == address(0)) revert NonProxyCall();
function initialize(address _owner, bytes calldata _params) external onlyBeacon reinitializer(_version) {
__StakingVault_init_v2();
__Ownable_init(_owner);
}

_initializeContractVersionTo(2);
function finalizeUpgrade_v2() public reinitializer(_version) {
__StakingVault_init_v2();
}

_transferOwnership(_owner);
event InitializedV2();
function __StakingVault_init_v2() internal {
emit InitializedV2();
}

function finalizeUpgrade_v2() external {
if (getContractVersion() == _version) {
revert AlreadyInitialized();
}
function getInitializedVersion() public view returns (uint64) {
return _getInitializedVersion();
}

function version() external pure virtual returns(uint256) {
function version() external pure virtual returns(uint64) {
return _version;
}

Expand All @@ -82,6 +89,5 @@ contract StakingVault__HarnessForTestUpgrade is IBeaconProxy, VaultBeaconChainDe
}

error ZeroArgument(string name);
error NonProxyCall();
error AlreadyInitialized();
error UnauthorizedSender(address sender);
}
14 changes: 4 additions & 10 deletions test/0.8.25/vaults/vault.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,17 @@ describe("StakingVault.sol", async () => {
});

describe("initialize", () => {
it("reverts if `_owner` is zero address", async () => {
await expect(stakingVault.initialize(ZeroAddress, "0x"))
.to.be.revertedWithCustomError(stakingVault, "ZeroArgument")
.withArgs("_owner");
});

it("reverts if call from non proxy", async () => {
it("reverts on impl initialization", async () => {
await expect(stakingVault.initialize(await owner.getAddress(), "0x")).to.be.revertedWithCustomError(
stakingVault,
"NonProxyCallsForbidden",
vaultProxy,
"UnauthorizedSender",
);
});

it("reverts if already initialized", async () => {
await expect(vaultProxy.initialize(await owner.getAddress(), "0x")).to.be.revertedWithCustomError(
vaultProxy,
"NonZeroContractVersionOnInit",
"UnauthorizedSender",
);
});
});
Expand Down
46 changes: 38 additions & 8 deletions test/0.8.25/vaults/vaultFactory.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ describe("VaultFactory.sol", () => {
await accounting.connect(admin).grantRole(await accounting.VAULT_REGISTRY_ROLE(), admin);

//the initialize() function cannot be called on a contract
await expect(implOld.initialize(stranger, "0x")).to.revertedWithCustomError(implOld, "NonProxyCallsForbidden");
await expect(implOld.initialize(stranger, "0x")).to.revertedWithCustomError(implOld, "UnauthorizedSender");
});

beforeEach(async () => (originalState = await Snapshot.take()));
Expand Down Expand Up @@ -135,7 +135,12 @@ describe("VaultFactory.sol", () => {
expect(await vault.getBeacon()).to.eq(await vaultFactory.getAddress());
});

it("works with non-empty `params`", async () => {});
it("check `version()`", async () => {
const { vault } = await createVaultProxy(vaultFactory, vaultOwner1);
expect(await vault.version()).to.eq(1);
});

it.skip("works with non-empty `params`", async () => {});
});

context("connect", () => {
Expand Down Expand Up @@ -247,13 +252,38 @@ describe("VaultFactory.sol", () => {
),
).to.revertedWithCustomError(accounting, "ImplNotAllowed");

const version1After = await vault1.version();
const version2After = await vault2.version();
const version3After = await vault3.version();
const vault1WithNewImpl = await ethers.getContractAt("StakingVault__HarnessForTestUpgrade", vault1, deployer);
const vault2WithNewImpl = await ethers.getContractAt("StakingVault__HarnessForTestUpgrade", vault2, deployer);
const vault3WithNewImpl = await ethers.getContractAt("StakingVault__HarnessForTestUpgrade", vault3, deployer);

//finalize first vault
await vault1WithNewImpl.finalizeUpgrade_v2();

const version1After = await vault1WithNewImpl.version();
const version2After = await vault2WithNewImpl.version();
const version3After = await vault3WithNewImpl.version();

const version1AfterV2 = await vault1WithNewImpl.getInitializedVersion();
const version2AfterV2 = await vault2WithNewImpl.getInitializedVersion();
const version3AfterV2 = await vault3WithNewImpl.getInitializedVersion();

expect(version1Before).to.eq(1);
expect(version1AfterV2).to.eq(2);

expect(version2Before).to.eq(1);
expect(version2AfterV2).to.eq(1);

expect(version3After).to.eq(2);

const v1 = { version: version1After, getInitializedVersion: version1AfterV2 };
const v2 = { version: version2After, getInitializedVersion: version2AfterV2 };
const v3 = { version: version3After, getInitializedVersion: version3AfterV2 };

console.table([v1, v2, v3]);

expect(version1Before).not.to.eq(version1After);
expect(version2Before).not.to.eq(version2After);
expect(2).to.eq(version3After);
// await vault1.initialize(stranger, "0x")
// await vault2.initialize(stranger, "0x")
// await vault3.initialize(stranger, "0x")
});
});
});
2 changes: 1 addition & 1 deletion test/0.8.25/vaults/vaultStaffRoom.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ describe("VaultStaffRoom.sol", () => {
await accounting.connect(admin).grantRole(await accounting.VAULT_MASTER_ROLE(), admin);

//the initialize() function cannot be called on a contract
await expect(implOld.initialize(stranger, "0x")).to.revertedWithCustomError(implOld, "NonProxyCallsForbidden");
await expect(implOld.initialize(stranger, "0x")).to.revertedWithCustomError(implOld, "UnauthorizedSender");
});

beforeEach(async () => (originalState = await Snapshot.take()));
Expand Down

0 comments on commit 9b59268

Please sign in to comment.