Skip to content

Commit

Permalink
ci: print decoded token ids in the log (#4646)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Sep 18, 2024
1 parent 556deba commit 79cc9c3
Showing 1 changed file with 52 additions and 35 deletions.
87 changes: 52 additions & 35 deletions integration_tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
import re
import secrets
import sys
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -35,6 +36,7 @@
from urllib.request import urlopen

import httpx
import jwt
import pytest
from httpx import Headers, HTTPStatusError
from openinference.semconv.resource import ResourceAttributes
Expand Down Expand Up @@ -421,25 +423,19 @@ def __new__(cls) -> Self:
raise NotImplementedError("This class is intended as a singleton to be used directly.")

@classmethod
def stash(cls, port: int, cookies: str) -> None:
for cookie in cookies.split(","):
if cookie.startswith(PHOENIX_ACCESS_TOKEN_COOKIE_NAME) or cookie.startswith(
PHOENIX_REFRESH_TOKEN_COOKIE_NAME
):
token = _get_token_from_cookie(cookie)
with cls._lock:
cls._set.add((port, token))
def stash(cls, port: int, headers: Headers) -> None:
tokens = _extract_tokens(headers, "set-cookie").values()
for token in tokens:
with cls._lock:
cls._set.add((port, token))

@classmethod
def intersect(cls, port: int, cookies: str) -> bool:
for cookie in cookies.split(","):
if cookie.startswith(PHOENIX_ACCESS_TOKEN_COOKIE_NAME) or cookie.startswith(
PHOENIX_REFRESH_TOKEN_COOKIE_NAME
):
token = _get_token_from_cookie(cookie)
with cls._lock:
if (port, token) in cls._set:
return True
def intersect(cls, port: int, headers: Headers) -> bool:
tokens = _extract_tokens(headers).values()
for token in tokens:
with cls._lock:
if (port, token) in cls._set:
return True
return False


Expand All @@ -464,21 +460,12 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
if "auth/login" in path:
sequester_tokens = DEFAULT_ADMIN_EMAIL in request.content.decode()
elif "auth/refresh" in path:
if cookies := headers.get("cookie"):
sequester_tokens = _DefaultAdminTokens.intersect(port, cookies)
elif (
"auth/logout" in path
and (cookies := headers.get("cookie"))
and _DefaultAdminTokens.intersect(port, cookies)
):
sequester_tokens = _DefaultAdminTokens.intersect(port, headers)
elif "auth/logout" in path and _DefaultAdminTokens.intersect(port, headers):
raise self.exc_cls(self.message)
response = self._transport.handle_request(request)
if (
sequester_tokens
and response.status_code // 100 == 2
and (cookies := response.headers.get("set-cookie"))
):
_DefaultAdminTokens.stash(port, cookies)
if sequester_tokens and response.status_code // 100 == 2:
_DefaultAdminTokens.stash(port, response.headers)
return response


Expand All @@ -498,7 +485,7 @@ def _get_token_from_cookie(cookie: str) -> str:
return cookie.split(";", 1)[0].split("=", 1)[1]


_TEST_NAME: ContextVar[str] = ContextVar("test_name")
_TEST_NAME: ContextVar[str] = ContextVar("test_name", default="")
_HTTPX_OP_IDX: ContextVar[int] = ContextVar("httpx_operation_index", default=0)


Expand All @@ -509,16 +496,21 @@ def __init__(self, transport: httpx.BaseTransport) -> None:
def handle_request(self, request: httpx.Request) -> httpx.Response:
info = BytesIO()
info.write(f"{'-'*50}\n".encode())
op_idx = _HTTPX_OP_IDX.get()
_HTTPX_OP_IDX.set(op_idx + 1)
info.write(f"({op_idx})".encode())
info.write(f"{_TEST_NAME.get()}\n".encode())
if test_name := _TEST_NAME.get():
op_idx = _HTTPX_OP_IDX.get()
_HTTPX_OP_IDX.set(op_idx + 1)
info.write(f"({op_idx})".encode())
info.write(f"{test_name}\n".encode())
response = self._transport.handle_request(request)
info.write(f"{response.status_code} {request.method} {request.url}\n".encode())
if token_ids := _decode_token_ids(request.headers):
info.write(f"{' '.join(token_ids)}\n".encode())
info.write(f"{request.headers}\n".encode())
info.write(request.read())
info.write(b"\n")
info.write(f"{response.headers}\n".encode())
if returned_token_ids := _decode_token_ids(response.headers, "set-cookie"):
info.write(f"{' '.join(returned_token_ids)}\n".encode())
return _LogResponse(
info=info,
status_code=response.status_code,
Expand Down Expand Up @@ -895,3 +887,28 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None: ...
_EXPECTATION_401 = pytest.raises(HTTPStatusError, match="401 Unauthorized")
_EXPECTATION_403 = pytest.raises(HTTPStatusError, match="403 Forbidden")
_EXPECTATION_404 = pytest.raises(HTTPStatusError, match="404 Not Found")


def _extract_tokens(
headers: Headers,
key: Literal["cookie", "set-cookie"] = "cookie",
) -> Dict[str, str]:
if not (cookies := headers.get(key)):
return {}
parts = re.split(r"[ ,;=]", cookies)
return {
k: v
for k, v in zip(parts[:-1], parts[1:])
if v.strip('"')
and k in (PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
}


def _decode_token_ids(
headers: Headers,
key: Literal["cookie", "set-cookie"] = "cookie",
) -> List[str]:
return [
jwt.decode(v, options={"verify_signature": False})["jti"]
for v in _extract_tokens(headers, key).values()
]

0 comments on commit 79cc9c3

Please sign in to comment.