From a8f66414724dc0d59f9bb73d83259fa5fac48d91 Mon Sep 17 00:00:00 2001 From: ajermaky Date: Tue, 30 Apr 2024 16:55:57 -0700 Subject: [PATCH] lint --- lib/go/edgecontext/oauth_client.go | 4 +- lib/go/edgecontext/service.go | 9 + lib/py/reddit_edgecontext/__init__.py | 9 +- lib/py/reddit_edgecontext/thrift/ttypes.py | 187 +++------------------ lib/py/tests/edge_context_tests.py | 10 +- 5 files changed, 40 insertions(+), 179 deletions(-) diff --git a/lib/go/edgecontext/oauth_client.go b/lib/go/edgecontext/oauth_client.go index 7b66651..ac0dbd6 100644 --- a/lib/go/edgecontext/oauth_client.go +++ b/lib/go/edgecontext/oauth_client.go @@ -23,11 +23,11 @@ func (o OAuthClient) ID() string { // // For example, use: // -// if client.IsType("third_party") +// if client.IsType("third_party") // // Instead of: // -// if !client.IsType("first_party") +// if !client.IsType("first_party") func (o OAuthClient) IsType(types ...string) bool { clientType := AuthenticationToken(o).OAuthClientType for _, t := range types { diff --git a/lib/go/edgecontext/service.go b/lib/go/edgecontext/service.go index f358995..d839525 100644 --- a/lib/go/edgecontext/service.go +++ b/lib/go/edgecontext/service.go @@ -27,6 +27,10 @@ func (s Service) Name() (name string, ok bool) { return } +// OnBehalfOfID returns the ID of the user on whose behalf the service is acting. +// +// If it's not coming from an authenticated service, +// ("", false) will be returned. func (s Service) OnBehalfOfID() (id string, ok bool) { if s.isService() { token := AuthenticationToken(s) @@ -41,6 +45,10 @@ func (s Service) OnBehalfOfID() (id string, ok bool) { return } +// OnBehalfOfRoles returns the roles of the user on whose behalf the service is acting. +// +// If it's not coming from an authenticated service, +// (nil, false) will be returned. func (s Service) OnBehalfOfRoles() (roles []string, ok bool) { if s.isService() { token := AuthenticationToken(s) @@ -52,6 +60,7 @@ func (s Service) OnBehalfOfRoles() (roles []string, ok bool) { return } +// IsElevatedAccess returns whether the service requested elevated access. func (s Service) IsElevatedAccess() bool { if s.isService() { return AuthenticationToken(s).ServiceRequestedElevatedAccess diff --git a/lib/py/reddit_edgecontext/__init__.py b/lib/py/reddit_edgecontext/__init__.py index acd1531..23b4742 100644 --- a/lib/py/reddit_edgecontext/__init__.py +++ b/lib/py/reddit_edgecontext/__init__.py @@ -17,9 +17,6 @@ from baseplate.lib.edgecontext import EdgeContextFactory as BaseEdgeContextFactory from baseplate.lib.secrets import SecretsStore from jwt.algorithms import get_default_algorithms -from thrift import TSerialization -from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory - from reddit_edgecontext.thrift.ttypes import Device as TDevice from reddit_edgecontext.thrift.ttypes import Geolocation as TGeolocation from reddit_edgecontext.thrift.ttypes import Locale as TLocale @@ -28,6 +25,8 @@ from reddit_edgecontext.thrift.ttypes import Request as TRequest from reddit_edgecontext.thrift.ttypes import RequestId as TRequestId from reddit_edgecontext.thrift.ttypes import Session as TSession +from thrift import TSerialization +from thrift.protocol.TBinaryProtocol import TBinaryProtocolAcceleratedFactory logger = logging.getLogger(__name__) @@ -537,9 +536,7 @@ def request_id(self) -> RequestId: @cached_property def locale(self) -> Locale: """:py:class:`~reddit_edgecontext.Locale` object for the current context.""" - return Locale( - locale_code=self._t_request.locale.locale_code, - ) + return Locale(locale_code=self._t_request.locale.locale_code) @cached_property def _t_request(self) -> TRequest: diff --git a/lib/py/reddit_edgecontext/thrift/ttypes.py b/lib/py/reddit_edgecontext/thrift/ttypes.py index 8538076..aaa3f05 100644 --- a/lib/py/reddit_edgecontext/thrift/ttypes.py +++ b/lib/py/reddit_edgecontext/thrift/ttypes.py @@ -37,16 +37,9 @@ class Loid(object): """ - __slots__ = ( - "id", - "created_ms", - ) + __slots__ = ("id", "created_ms") - def __init__( - self, - id=None, - created_ms=None, - ): + def __init__(self, id=None, created_ms=None): self.id = id self.created_ms = created_ms @@ -137,10 +130,7 @@ class Session(object): __slots__ = ("id",) - def __init__( - self, - id=None, - ): + def __init__(self, id=None): self.id = id def read(self, iprot): @@ -221,10 +211,7 @@ class Device(object): __slots__ = ("id",) - def __init__( - self, - id=None, - ): + def __init__(self, id=None): self.id = id def read(self, iprot): @@ -306,10 +293,7 @@ class OriginService(object): __slots__ = ("name",) - def __init__( - self, - name=None, - ): + def __init__(self, name=None): self.name = name def read(self, iprot): @@ -389,10 +373,7 @@ class Geolocation(object): __slots__ = ("country_code",) - def __init__( - self, - country_code=None, - ): + def __init__(self, country_code=None): self.country_code = country_code def read(self, iprot): @@ -473,10 +454,7 @@ class RequestId(object): __slots__ = ("readable_id",) - def __init__( - self, - readable_id=None, - ): + def __init__(self, readable_id=None): self.readable_id = readable_id def read(self, iprot): @@ -560,10 +538,7 @@ class Locale(object): __slots__ = ("locale_code",) - def __init__( - self, - locale_code=None, - ): + def __init__(self, locale_code=None): self.locale_code = locale_code def read(self, iprot): @@ -818,146 +793,32 @@ def __ne__(self, other): all_structs.append(Loid) Loid.thrift_spec = ( None, # 0 - ( - 1, - TType.STRING, - "id", - "UTF8", - None, - ), # 1 - ( - 2, - TType.I64, - "created_ms", - None, - None, - ), # 2 + (1, TType.STRING, "id", "UTF8", None), # 1 + (2, TType.I64, "created_ms", None, None), # 2 ) all_structs.append(Session) -Session.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "id", - "UTF8", - None, - ), # 1 -) +Session.thrift_spec = (None, (1, TType.STRING, "id", "UTF8", None)) # 0 # 1 all_structs.append(Device) -Device.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "id", - "UTF8", - None, - ), # 1 -) +Device.thrift_spec = (None, (1, TType.STRING, "id", "UTF8", None)) # 0 # 1 all_structs.append(OriginService) -OriginService.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "name", - "UTF8", - None, - ), # 1 -) +OriginService.thrift_spec = (None, (1, TType.STRING, "name", "UTF8", None)) # 0 # 1 all_structs.append(Geolocation) -Geolocation.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "country_code", - "UTF8", - None, - ), # 1 -) +Geolocation.thrift_spec = (None, (1, TType.STRING, "country_code", "UTF8", None)) # 0 # 1 all_structs.append(RequestId) -RequestId.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "readable_id", - "UTF8", - None, - ), # 1 -) +RequestId.thrift_spec = (None, (1, TType.STRING, "readable_id", "UTF8", None)) # 0 # 1 all_structs.append(Locale) -Locale.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "locale_code", - "UTF8", - None, - ), # 1 -) +Locale.thrift_spec = (None, (1, TType.STRING, "locale_code", "UTF8", None)) # 0 # 1 all_structs.append(Request) Request.thrift_spec = ( None, # 0 - ( - 1, - TType.STRUCT, - "loid", - [Loid, None], - None, - ), # 1 - ( - 2, - TType.STRUCT, - "session", - [Session, None], - None, - ), # 2 - ( - 3, - TType.STRING, - "authentication_token", - "UTF8", - None, - ), # 3 - ( - 4, - TType.STRUCT, - "device", - [Device, None], - None, - ), # 4 - ( - 5, - TType.STRUCT, - "origin_service", - [OriginService, None], - None, - ), # 5 - ( - 6, - TType.STRUCT, - "geolocation", - [Geolocation, None], - None, - ), # 6 - ( - 7, - TType.STRUCT, - "request_id", - [RequestId, None], - None, - ), # 7 - ( - 8, - TType.STRUCT, - "locale", - [Locale, None], - None, - ), # 8 + (1, TType.STRUCT, "loid", [Loid, None], None), # 1 + (2, TType.STRUCT, "session", [Session, None], None), # 2 + (3, TType.STRING, "authentication_token", "UTF8", None), # 3 + (4, TType.STRUCT, "device", [Device, None], None), # 4 + (5, TType.STRUCT, "origin_service", [OriginService, None], None), # 5 + (6, TType.STRUCT, "geolocation", [Geolocation, None], None), # 6 + (7, TType.STRUCT, "request_id", [RequestId, None], None), # 7 + (8, TType.STRUCT, "locale", [Locale, None], None), # 8 ) fix_spec(all_structs) del all_structs diff --git a/lib/py/tests/edge_context_tests.py b/lib/py/tests/edge_context_tests.py index 3fe909c..6f4a84c 100644 --- a/lib/py/tests/edge_context_tests.py +++ b/lib/py/tests/edge_context_tests.py @@ -2,7 +2,6 @@ import unittest from baseplate.testing.lib.secrets import FakeSecretsStore - from reddit_edgecontext import EdgeContextFactory from reddit_edgecontext import InvalidAuthenticationToken from reddit_edgecontext import NoAuthenticationError @@ -135,10 +134,7 @@ def test_validated_service_authentication_token(self): payload = { "sub": "service/test-service", "exp": 1574458470, - "obo": { - "aid": "t2_deadbeef", - "roles": ["admin"], - }, + "obo": {"aid": "t2_deadbeef", "roles": ["admin"]}, "sea": True, } @@ -208,9 +204,7 @@ def test_create_empty_context(self): self.assertEqual( request_context._header, # loid - b"\x0c\x00" # STRUCT - b"\x01" # tag number - b"\x00" # END STRUCT + b"\x0c\x00" b"\x01" b"\x00" # STRUCT # tag number # END STRUCT # session b"\x0c\x00\x02\x00" # device