From 56c4e20d1ffdf2376ee8b2a0fefb6355902e8ebd Mon Sep 17 00:00:00 2001
From: Ray Luo <rayluo@microsoft.com>
Date: Fri, 5 Apr 2024 16:54:42 -0700
Subject: [PATCH] Switch to MSAL 1.27+'s TokenCache._find()

---
 docker_run.sh                  |  4 +++-
 msal_extensions/__init__.py    |  2 +-
 msal_extensions/token_cache.py |  4 ++--
 setup.py                       |  2 +-
 tests/test_agnostic_backend.py | 36 +++++++++++++++-------------------
 tests/test_macos_backend.py    |  1 +
 tests/test_windows_backend.py  |  1 +
 7 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/docker_run.sh b/docker_run.sh
index 9836192..04914d7 100755
--- a/docker_run.sh
+++ b/docker_run.sh
@@ -6,9 +6,11 @@ docker build -t $IMAGE_NAME - < Dockerfile
 echo "==== Integration Test for Persistence on Linux (libsecret) ===="
 echo "After seeing the bash prompt, run the following to test encryption on Linux:"
 echo "    pip install -e ."
-echo "    pytest"
+echo "    pytest -s tests/chosen_test_file.py"
+echo "Note that you probably need to set up ENV VAR for the test cases to run"
 docker run --rm -it \
     --privileged \
+    --env-file .env \
     -w /home -v $PWD:/home \
     $IMAGE_NAME \
     $1
diff --git a/msal_extensions/__init__.py b/msal_extensions/__init__.py
index 31b07c1..f0ee7ce 100644
--- a/msal_extensions/__init__.py
+++ b/msal_extensions/__init__.py
@@ -1,5 +1,5 @@
 """Provides auxiliary functionality to the `msal` package."""
-__version__ = "1.1.0"
+__version__ = "1.2.0a1"  # Note: During/after release, copy this number to Dockerfile
 
 from .persistence import (
     FilePersistence,
diff --git a/msal_extensions/token_cache.py b/msal_extensions/token_cache.py
index 119c9c2..99fd4b2 100644
--- a/msal_extensions/token_cache.py
+++ b/msal_extensions/token_cache.py
@@ -69,7 +69,7 @@ def modify(self, credential_type, old_entry, new_key_value_pairs=None):
             self._persistence.save(self.serialize())
             self._last_sync = time.time()
 
-    def find(self, credential_type, **kwargs):  # pylint: disable=arguments-differ
+    def _find(self, credential_type, **kwargs):  # pylint: disable=arguments-differ
         # Use optimistic locking rather than CrossPlatLock(self._lock_location)
         retry = 3
         for attempt in range(1, retry + 1):
@@ -83,6 +83,6 @@ def find(self, credential_type, **kwargs):  # pylint: disable=arguments-differ
                 else:
                     raise  # End of retry. Re-raise the exception as-is.
             else:  # If reload encountered no error, the data is considered intact
-                return super(PersistedTokenCache, self).find(credential_type, **kwargs)
+                return super(PersistedTokenCache, self)._find(credential_type, **kwargs)
         return []  # Not really reachable here. Just to keep pylint happy.
 
diff --git a/setup.py b/setup.py
index d49bb23..04a21f2 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
     package_data={'': ['LICENSE']},
     python_requires=">=3.7",
     install_requires=[
-        'msal>=0.4.1,<2.0.0',
+        'msal>=1.27,<2.0.0',
         'portalocker<3,>=1.4',
 
         ## We choose to NOT define a hard dependency on this.
diff --git a/tests/test_agnostic_backend.py b/tests/test_agnostic_backend.py
index 2d8454f..9f6eca9 100644
--- a/tests/test_agnostic_backend.py
+++ b/tests/test_agnostic_backend.py
@@ -15,32 +15,28 @@ def temp_location():
     yield os.path.join(test_folder, 'token_cache.bin')
     shutil.rmtree(test_folder, ignore_errors=True)
 
-
-def _test_token_cache_roundtrip(cache):
+def _test_token_cache_roundtrip(persistence):
     client_id = os.getenv('AZURE_CLIENT_ID')
     client_secret = os.getenv('AZURE_CLIENT_SECRET')
     if not (client_id and client_secret):
         pytest.skip('no credentials present to test TokenCache round-trip with.')
 
-    app = msal.ConfidentialClientApplication(
-        client_id=client_id,
-        client_credential=client_secret,
-        token_cache=cache)
     desired_scopes = ['https://graph.microsoft.com/.default']
-    token1 = app.acquire_token_for_client(scopes=desired_scopes)
-    os.utime(  # Mock having another process update the cache
-        cache._persistence.get_location(), None)
-    token2 = app.acquire_token_silent(scopes=desired_scopes, account=None)
-    assert token1['access_token'] == token2['access_token']
-
-def test_file_token_cache_roundtrip(temp_location):
-    _test_token_cache_roundtrip(PersistedTokenCache(FilePersistence(temp_location)))
-
-def test_current_platform_cache_roundtrip_with_persistence_builder(temp_location):
-    _test_token_cache_roundtrip(PersistedTokenCache(build_encrypted_persistence(temp_location)))
-
-def test_persisted_token_cache(temp_location):
-    _test_token_cache_roundtrip(PersistedTokenCache(FilePersistence(temp_location)))
+    apps = [  # Multiple apps sharing same persistence
+        msal.ConfidentialClientApplication(
+        client_id, client_credential=client_secret,
+        token_cache=PersistedTokenCache(persistence)) for i in range(2)]
+    token1 = apps[0].acquire_token_for_client(scopes=desired_scopes)
+    assert token1["token_source"] == "identity_provider", "Initial token should come from IdP"
+    token2 = apps[1].acquire_token_for_client(scopes=desired_scopes)  # Hit token cache in MSAL 1.23+
+    assert token2["token_source"] == "cache", "App2 should hit cache written by app1"
+    assert token1['access_token'] == token2['access_token'], "Cache should hit"
+
+def test_token_cache_roundtrip_with_persistence_biulder(temp_location):
+    _test_token_cache_roundtrip(build_encrypted_persistence(temp_location))
+
+def test_token_cache_roundtrip_with_file_persistence(temp_location):
+    _test_token_cache_roundtrip(FilePersistence(temp_location))
 
 def test_file_not_found_error_is_not_raised():
     persistence = FilePersistence('non_existing_file')
diff --git a/tests/test_macos_backend.py b/tests/test_macos_backend.py
index dfc7ca2..210f5e9 100644
--- a/tests/test_macos_backend.py
+++ b/tests/test_macos_backend.py
@@ -39,6 +39,7 @@ def test_osx_token_cache_roundtrip():
             token_cache=subject)
         desired_scopes = ['https://graph.microsoft.com/.default']
         token1 = app.acquire_token_for_client(scopes=desired_scopes)
+        # TODO: Modify this to same approach in test_agnostic_backend.py
         os.utime(cache_file, None)  # Mock having another process update the cache.
         token2 = app.acquire_token_silent(scopes=desired_scopes, account=None)
         assert token1['access_token'] == token2['access_token']
diff --git a/tests/test_windows_backend.py b/tests/test_windows_backend.py
index 6de0094..ab3197a 100644
--- a/tests/test_windows_backend.py
+++ b/tests/test_windows_backend.py
@@ -93,6 +93,7 @@ def test_windows_token_cache_roundtrip():
             token_cache=subject)
         desired_scopes = ['https://graph.microsoft.com/.default']
         token1 = app.acquire_token_for_client(scopes=desired_scopes)
+        # TODO: Modify this to same approach in test_agnostic_backend.py
         os.utime(cache_file, None)  # Mock having another process update the cache.
         token2 = app.acquire_token_silent(scopes=desired_scopes, account=None)
         assert token1['access_token'] == token2['access_token']