From ffa5d222fd48730fc6b4eecd759695d3f867ebff Mon Sep 17 00:00:00 2001
From: Luc Blaeser <112870813+luc-blaeser@users.noreply.github.com>
Date: Mon, 15 Jul 2024 18:37:28 +0200
Subject: [PATCH] feat!: (ic-utils) support canister upgrade option
 wasm_memory_persistence (#502)

* Add upgrade options

* Adjust tests

* Adjust option for orthogonal persistence persistence

* changelog and MSRV 1.75.0

* fix lint

* fmt

* ref test cover new types

---------

Co-authored-by: Linwei Shang <linwei.shang@dfinity.org>
---
 CHANGELOG.md                                  |   4 +
 .../agent/http_transport/route_provider.rs    |   2 +-
 ic-identity-hsm/src/hsm.rs                    |   4 +-
 ic-transport-types/src/request_id.rs          | 162 +++++++++---------
 .../management_canister/builders.rs           |  32 +++-
 ref-tests/tests/ic-ref.rs                     |  25 +--
 rust-toolchain.toml                           |   2 +-
 7 files changed, 128 insertions(+), 103 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index efaf090c..f676d5b9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 * Removed the Bitcoin query methods from `ManagementCanister`. Users should use `BitcoinCanister` for that.
 * Added `BitcoinCanister` to `ic-utils`.
+* Upgraded MSRV to 1.75.0.
+* Changed `ic_utils::interfaces::management_canister::builders::InstallMode::Upgrade` variant to be `Option<CanisterUpgradeOptions>`:
+  * `CanisterUpgradeOptions` is a new struct which covers the new upgrade option: `wasm_memory_persistence: Option<WasmMemoryPersistence>`.
+  * `WasmMemoryPersistence` is a new enum which controls Wasm main memory retention on upgrades which has two variants: `Keep` and `Replace`.
 
 ## [0.36.0] - 2024-06-04
 
diff --git a/ic-agent/src/agent/http_transport/route_provider.rs b/ic-agent/src/agent/http_transport/route_provider.rs
index 3200e691..c554ded5 100644
--- a/ic-agent/src/agent/http_transport/route_provider.rs
+++ b/ic-agent/src/agent/http_transport/route_provider.rs
@@ -88,7 +88,7 @@ mod tests {
     fn test_routes_rotation() {
         let provider = RoundRobinRouteProvider::new(vec!["https://url1.com", "https://url2.com"])
             .expect("failed to create a route provider");
-        let url_strings = vec![
+        let url_strings = [
             "https://url1.com/api/v2/",
             "https://url2.com/api/v2/",
             "https://url1.com/api/v2/",
diff --git a/ic-identity-hsm/src/hsm.rs b/ic-identity-hsm/src/hsm.rs
index ab7dfeb4..706de909 100644
--- a/ic-identity-hsm/src/hsm.rs
+++ b/ic-identity-hsm/src/hsm.rs
@@ -283,7 +283,7 @@ fn get_ec_point(
 
     let blocks =
         from_der(der_encoded_ec_point.as_slice()).map_err(HardwareIdentityError::ASN1Decode)?;
-    let block = blocks.get(0).ok_or(HardwareIdentityError::EcPointEmpty)?;
+    let block = blocks.first().ok_or(HardwareIdentityError::EcPointEmpty)?;
     if let OctetString(_size, data) = block {
         Ok(data.clone())
     } else {
@@ -302,7 +302,7 @@ fn get_attribute_length(
     ctx.get_attribute_value(session_handle, object_handle, &mut attributes)?;
 
     let first = attributes
-        .get(0)
+        .first()
         .ok_or(HardwareIdentityError::AttributeNotFound(attribute_type))?;
     Ok(first.ulValueLen as usize)
 }
diff --git a/ic-transport-types/src/request_id.rs b/ic-transport-types/src/request_id.rs
index 0045c631..8a3c5d7a 100644
--- a/ic-transport-types/src/request_id.rs
+++ b/ic-transport-types/src/request_id.rs
@@ -496,6 +496,87 @@ impl SerializeTupleVariant for TupleVariantSerializer {
     }
 }
 
+// can't use serde_bytes on by-value arrays
+// these impls are effectively #[serde(with = "serde_bytes")]
+impl Serialize for RequestId {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        if serializer.is_human_readable() {
+            let mut text = [0u8; 64];
+            hex::encode_to_slice(self.0, &mut text).unwrap();
+            serializer.serialize_str(std::str::from_utf8(&text).unwrap())
+        } else {
+            serializer.serialize_bytes(&self.0)
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for RequestId {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        if deserializer.is_human_readable() {
+            deserializer.deserialize_str(RequestIdVisitor)
+        } else {
+            deserializer.deserialize_bytes(RequestIdVisitor)
+        }
+    }
+}
+
+struct RequestIdVisitor;
+
+impl<'de> Visitor<'de> for RequestIdVisitor {
+    type Value = RequestId;
+    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        formatter.write_str("a sha256 hash")
+    }
+
+    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        Ok(RequestId::new(v.try_into().map_err(|_| {
+            E::custom(format_args!("must be 32 bytes long, was {}", v.len()))
+        })?))
+    }
+
+    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+    where
+        A: de::SeqAccess<'de>,
+    {
+        let mut arr = Sha256Hash::default();
+        for (i, byte) in arr.iter_mut().enumerate() {
+            *byte = seq.next_element()?.ok_or(A::Error::custom(format_args!(
+                "must be 32 bytes long, was {}",
+                i - 1
+            )))?;
+        }
+        if seq.next_element::<u8>()?.is_some() {
+            Err(A::Error::custom("must be 32 bytes long, was more"))
+        } else {
+            Ok(RequestId(arr))
+        }
+    }
+
+    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+    where
+        E: de::Error,
+    {
+        if v.len() != 64 {
+            return Err(E::custom(format_args!(
+                "must be 32 bytes long, was {}",
+                v.len() / 2
+            )));
+        }
+        let mut arr = Sha256Hash::default();
+        hex::decode_to_slice(v, &mut arr).map_err(E::custom)?;
+        Ok(RequestId(arr))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -877,84 +958,3 @@ mod tests {
         );
     }
 }
-
-// can't use serde_bytes on by-value arrays
-// these impls are effectively #[serde(with = "serde_bytes")]
-impl Serialize for RequestId {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        if serializer.is_human_readable() {
-            let mut text = [0u8; 64];
-            hex::encode_to_slice(self.0, &mut text).unwrap();
-            serializer.serialize_str(std::str::from_utf8(&text).unwrap())
-        } else {
-            serializer.serialize_bytes(&self.0)
-        }
-    }
-}
-
-impl<'de> Deserialize<'de> for RequestId {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        if deserializer.is_human_readable() {
-            deserializer.deserialize_str(RequestIdVisitor)
-        } else {
-            deserializer.deserialize_bytes(RequestIdVisitor)
-        }
-    }
-}
-
-struct RequestIdVisitor;
-
-impl<'de> Visitor<'de> for RequestIdVisitor {
-    type Value = RequestId;
-    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        formatter.write_str("a sha256 hash")
-    }
-
-    fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
-    where
-        E: de::Error,
-    {
-        Ok(RequestId::new(v.try_into().map_err(|_| {
-            E::custom(format_args!("must be 32 bytes long, was {}", v.len()))
-        })?))
-    }
-
-    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
-    where
-        A: de::SeqAccess<'de>,
-    {
-        let mut arr = Sha256Hash::default();
-        for (i, byte) in arr.iter_mut().enumerate() {
-            *byte = seq.next_element()?.ok_or(A::Error::custom(format_args!(
-                "must be 32 bytes long, was {}",
-                i - 1
-            )))?;
-        }
-        if seq.next_element::<u8>()?.is_some() {
-            Err(A::Error::custom("must be 32 bytes long, was more"))
-        } else {
-            Ok(RequestId(arr))
-        }
-    }
-
-    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
-    where
-        E: de::Error,
-    {
-        if v.len() != 64 {
-            return Err(E::custom(format_args!(
-                "must be 32 bytes long, was {}",
-                v.len() / 2
-            )));
-        }
-        let mut arr = Sha256Hash::default();
-        hex::decode_to_slice(v, &mut arr).map_err(E::custom)?;
-        Ok(RequestId(arr))
-    }
-}
diff --git a/ic-utils/src/interfaces/management_canister/builders.rs b/ic-utils/src/interfaces/management_canister/builders.rs
index 1f5fa376..5cab9802 100644
--- a/ic-utils/src/interfaces/management_canister/builders.rs
+++ b/ic-utils/src/interfaces/management_canister/builders.rs
@@ -473,6 +473,27 @@ impl<'agent, 'canister: 'agent> IntoFuture for CreateCanisterBuilder<'agent, 'ca
     }
 }
 
+#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash, CandidType, Copy)]
+/// Wasm main memory retention on upgrades.
+/// Currently used to specify the persistence of Wasm main memory.
+pub enum WasmMemoryPersistence {
+    /// Retain the main memory across upgrades.
+    /// Used for enhanced orthogonal persistence, as implemented in Motoko
+    Keep,
+    /// Reinitialize the main memory on upgrade.
+    /// Default behavior without enhanced orthogonal persistence.
+    Replace,
+}
+
+#[derive(Debug, Copy, Clone, CandidType, Deserialize, Eq, PartialEq)]
+/// Upgrade options.
+pub struct CanisterUpgradeOptions {
+    /// Skip pre-upgrade hook. Only for exceptional cases, see the IC documentation. Not useful for Motoko.
+    pub skip_pre_upgrade: Option<bool>,
+    /// Support for enhanced orthogonal persistence: Retain the main memory on upgrade.
+    pub wasm_memory_persistence: Option<WasmMemoryPersistence>,
+}
+
 /// The install mode of the canister to install. If a canister is already installed,
 /// using [InstallMode::Install] will be an error. [InstallMode::Reinstall] overwrites
 /// the module, and [InstallMode::Upgrade] performs an Upgrade step.
@@ -484,12 +505,9 @@ pub enum InstallMode {
     /// Overwrite the canister with this module.
     #[serde(rename = "reinstall")]
     Reinstall,
-    /// Upgrade the canister with this module.
+    /// Upgrade the canister with this module and some options.
     #[serde(rename = "upgrade")]
-    Upgrade {
-        /// If true, skip a canister's `#[pre_upgrade]` function.
-        skip_pre_upgrade: Option<bool>,
-    },
+    Upgrade(Option<CanisterUpgradeOptions>),
 }
 
 /// A prepared call to `install_code`.
@@ -514,9 +532,7 @@ impl FromStr for InstallMode {
         match s {
             "install" => Ok(InstallMode::Install),
             "reinstall" => Ok(InstallMode::Reinstall),
-            "upgrade" => Ok(InstallMode::Upgrade {
-                skip_pre_upgrade: Some(false),
-            }),
+            "upgrade" => Ok(InstallMode::Upgrade(None)),
             &_ => Err(format!("Invalid install mode: {}", s)),
         }
     }
diff --git a/ref-tests/tests/ic-ref.rs b/ref-tests/tests/ic-ref.rs
index 2bf4271e..fb4e9f32 100644
--- a/ref-tests/tests/ic-ref.rs
+++ b/ref-tests/tests/ic-ref.rs
@@ -33,7 +33,9 @@ mod management_canister {
         call::AsyncCall,
         interfaces::{
             management_canister::{
-                builders::{CanisterSettings, InstallMode},
+                builders::{
+                    CanisterSettings, CanisterUpgradeOptions, InstallMode, WasmMemoryPersistence,
+                },
                 CanisterStatus, StatusCallResult,
             },
             wallet::CreateResult,
@@ -161,18 +163,20 @@ mod management_canister {
 
             // Upgrade should succeed.
             ic00.install_code(&canister_id, &canister_wasm)
-                .with_mode(InstallMode::Upgrade {
-                    skip_pre_upgrade: None,
-                })
+                .with_mode(InstallMode::Upgrade(Some(CanisterUpgradeOptions {
+                    skip_pre_upgrade: Some(true),
+                    wasm_memory_persistence: None,
+                })))
                 .call_and_wait()
                 .await?;
 
             // Upgrade with another agent should fail.
             let result = other_ic00
                 .install_code(&canister_id, &canister_wasm)
-                .with_mode(InstallMode::Upgrade {
+                .with_mode(InstallMode::Upgrade(Some(CanisterUpgradeOptions {
                     skip_pre_upgrade: None,
-                })
+                    wasm_memory_persistence: Some(WasmMemoryPersistence::Keep),
+                })))
                 .call_and_wait()
                 .await;
             assert!(matches!(result, Err(AgentError::UncertifiedReject(..))));
@@ -302,7 +306,7 @@ mod management_canister {
                 .iter()
                 .cloned()
                 .collect::<HashSet<_>>();
-            let expected = vec![agent_principal, other_agent_principal]
+            let expected = [agent_principal, other_agent_principal]
                 .iter()
                 .cloned()
                 .collect::<HashSet<_>>();
@@ -320,7 +324,7 @@ mod management_canister {
                 .iter()
                 .cloned()
                 .collect::<HashSet<_>>();
-            let expected = vec![agent_principal, other_agent_principal]
+            let expected = [agent_principal, other_agent_principal]
                 .iter()
                 .cloned()
                 .collect::<HashSet<_>>();
@@ -485,9 +489,10 @@ mod management_canister {
 
             // Upgrade should succeed
             ic00.install_code(&canister_id, &canister_wasm)
-                .with_mode(InstallMode::Upgrade {
+                .with_mode(InstallMode::Upgrade(Some(CanisterUpgradeOptions {
                     skip_pre_upgrade: None,
-                })
+                    wasm_memory_persistence: Some(WasmMemoryPersistence::Replace),
+                })))
                 .call_and_wait()
                 .await?;
 
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index cca9dd55..18b402cd 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -2,6 +2,6 @@
 # MSRV
 # Avoid updating this field unless we use new Rust features
 # Sync rust-version in workspace Cargo.toml
-channel = "1.70.0"
+channel = "1.75.0"
 components = ["rustfmt", "clippy"]
 targets = ["wasm32-unknown-unknown"]