From f60e972c48ba46990653c9c3f959caadc7d67a52 Mon Sep 17 00:00:00 2001 From: Vidminas Mikucionis <5411598+Vidminas@users.noreply.github.com> Date: Sun, 21 Apr 2024 19:04:04 +0100 Subject: [PATCH] Remove hardcoded callback uris --- src/chat_app/data/client_id.json | 2 -- src/chat_app/main.py | 15 ++------- src/chat_app/solid_oidc_button.py | 54 +++++++++++++++++++++++-------- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/chat_app/data/client_id.json b/src/chat_app/data/client_id.json index be16feb..9625df5 100644 --- a/src/chat_app/data/client_id.json +++ b/src/chat_app/data/client_id.json @@ -3,8 +3,6 @@ "client_id": "https://raw.githubusercontent.com/Vidminas/socialgenpod/main/chat_app/data/client_id.json", "client_name": "Social Gen Pod", - "redirect_uris": ["https://socialgenpod.ryey.icu/callback", "https://socialgenpod.azurewebsites.net/callback", "http://localhost:8501/callback"], - "post_logout_redirect_uris": ["https://socialgenpod.ryey.icu", "https://socialgenpod.azurewebsites.net", "http://localhost:8501"], "client_uri": "https://github.com/Vidminas/socialgenpod", "logo_uri" : "https://raw.githubusercontent.com/Vidminas/socialgenpod/main/chat_app/data/turtle.png", "tos_uri" : "https://github.com/Vidminas/socialgenpod/blob/main/README.md", diff --git a/src/chat_app/main.py b/src/chat_app/main.py index f5b24ff..52b1272 100644 --- a/src/chat_app/main.py +++ b/src/chat_app/main.py @@ -11,17 +11,6 @@ from chat_app.apis.openai_api import OpenAIEmbeddingsAPI, OpenAILLMAPI -@st.cache_data() -def get_callback_uri(): - hostname = os.environ.get("WEBSITE_HOSTNAME") - if hostname is not None: - OAUTH_CALLBACK_URI = f"https://{hostname}/callback" - else: - OAUTH_CALLBACK_URI = "http://localhost:8501/callback" - print(f"Auth endpoint set to {OAUTH_CALLBACK_URI}") - return OAUTH_CALLBACK_URI - - def show_login_sidebar(): from chat_app.solid_oidc_button import SolidOidcComponent @@ -54,7 +43,8 @@ def show_login_sidebar(): if solid_server_url not in st.session_state["solid_idps"]: st.session_state["solid_idps"][solid_server_url] = SolidOidcComponent( - solid_server_url + solid_server_url, + ) solid_client = st.session_state["solid_idps"][solid_server_url] @@ -63,7 +53,6 @@ def show_login_sidebar(): result = solid_client.authorize_button( name="Login with Solid", icon="https://raw.githubusercontent.com/CommunitySolidServer/CommunitySolidServer/main/templates/images/solid.svg", - redirect_uri=get_callback_uri(), key="solid", height=670, width=850, diff --git a/src/chat_app/solid_oidc_button.py b/src/chat_app/solid_oidc_button.py index 7696e59..5b6385b 100644 --- a/src/chat_app/solid_oidc_button.py +++ b/src/chat_app/solid_oidc_button.py @@ -1,3 +1,6 @@ +import os +import json +from pathlib import Path import urllib.parse import requests @@ -17,9 +20,26 @@ @st.cache_data(ttl=300) def generate_pkce_pair(client_id): + # client_id is not used but required to cache separate pkce pairs for different clients return create_verifier_challenge() +@st.cache_data() +def get_hostname_uri(): + hostname = os.environ.get("WEBSITE_HOSTNAME") + if hostname is not None: + return f"https://{hostname}" + else: + return "http://localhost:8501" + + +@st.cache_data() +def get_callback_uri(): + OAUTH_CALLBACK_URI = f"{get_hostname_uri()}/callback" + print(f"Auth endpoint set to {OAUTH_CALLBACK_URI}") + return OAUTH_CALLBACK_URI + + class SolidOidcComponent(OAuth2Component): def __init__(self, solid_server_url: str): self.client_id = "https://raw.githubusercontent.com/Vidminas/socialgenpod/main/chat_app/data/client_id.json" @@ -34,13 +54,18 @@ def __init__(self, solid_server_url: str): if "none" not in client.provider_info["token_endpoint_auth_methods_supported"]: # can't use public client, must register with server - res = requests.get(self.client_id) - client_metadata = res.json() + metadata_path = Path(__file__).parent / "data/client_id.json" + with metadata_path.open() as f: + client_metadata = json.load(f) + registration_response = client.client.register( - client.provider_info['registration_endpoint'], - **client_metadata) - self.client_id = registration_response['client_id'] - self.client_secret = registration_response['client_secret'] + client.provider_info["registration_endpoint"], + redirect_uris=[get_callback_uri()], + post_logout_redirect_uris=[get_hostname_uri()], + **client_metadata, + ) + self.client_id = registration_response["client_id"] + self.client_secret = registration_response["client_secret"] super().__init__( client_id=None, @@ -52,18 +77,18 @@ def __init__(self, solid_server_url: str): client=client, ) - def create_login_uri(self, state, redirect_uri, extras_params): + def create_login_uri(self, state, extras_params): code_verifier, code_challenge = generate_pkce_pair(self.client.client_id) authorization_endpoint = self.client.provider_info["authorization_endpoint"] self.client.storage.set(f"{state}_code_verifier", code_verifier) - self.client.storage.set(f"{state}_redirect_url", redirect_uri) + self.client.storage.set(f"{state}_redirect_url", get_callback_uri()) params = { "code_challenge": code_challenge, "code_challenge_method": "S256", "state": state, "response_type": "code", - "redirect_uri": redirect_uri, + "redirect_uri": get_callback_uri(), "client_id": self.client_id, # offline_access: also asks for refresh token "scope": "openid offline_access", @@ -75,7 +100,6 @@ def create_login_uri(self, state, redirect_uri, extras_params): def authorize_button( self, name, - redirect_uri, height=800, width=600, key=None, @@ -84,7 +108,7 @@ def authorize_button( use_container_width=False, ): state = _generate_state(key) - authorize_request = self.create_login_uri(state, redirect_uri, extras_params) + authorize_request = self.create_login_uri(state, extras_params) result = _authorize_button( authorization_url=authorize_request, name=name, @@ -109,11 +133,15 @@ def authorize_button( res = requests.post( token_endpoint, - auth=(self.client_id, self.client_secret) if self.client_secret is not None else None, + auth=( + (self.client_id, self.client_secret) + if self.client_secret is not None + else None + ), data={ "grant_type": "authorization_code", "client_id": self.client_id, - "redirect_uri": redirect_uri, + "redirect_uri": get_callback_uri(), "code": result["code"], "code_verifier": code_verifier, },