Skip to content

Commit

Permalink
Add mypy (#48)
Browse files Browse the repository at this point in the history
* make sure poetry.lock is up to date

* add mypy and a types lib for cachetools to dev dependencies

* make sure fn is_secret accounts for all cases, refactor to use try/except instead of type-checking

* declare get_secret and is_secret as staticmethods, because they are

* remove type hints on "params" aka "kwargs", unnecessary and results in undesired type-handling behavior

* refactor site_name in handlers/base.py to correctly use parent class methods while maintaining the same functionality

* change from Dict[str, str] to just dict for write_metric's tags param; the values we use are not limited to strings, and if they are not strings they are cast to strings. Can change it to a union of str | int | None if desired

* fix typing for LATENCY_BUFFERS

* simplify push_latency by changing LATENCY_BUFFERS to a defaultdict of LatencyBuffer

* add some type hints to handlers/recommendation.py

* update get_arguments override to be its own function with a different name; overriding a base/parent classes function structure is bad practice and can lead to undesired behavior

* fix validate_filters type annotation for output, remove unused dict output and return None for the case where no error messages are returned

* unnecessary unpacking of a dictionary argument and unnecessary use of kwargs, both in fetch_results; fixed

* fix typing for propagation of `site` through a few functions; there is no case where `site` is None, and should be treated as a str not an Optional[str]. Fixed

* remove code that isn't used from mappings/model.py

* add a mypy.ini file for explicit mypy config

* add more helpful comments to mypy.ini

* add a github action to run on PRs to the main branch that runs mypy if any python file is changed

* update gitignore

* update gitignore

* update readme with information about mypy and how to run it locally

* add whitespace to end of mypy.yml

* make sure poetry.lock and requirements are up to date
  • Loading branch information
raaidarshad authored Jun 30, 2022
1 parent 9c25ec4 commit 2dbc7f5
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 57 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: run mypy
on:
pull_request:
branches:
- main
paths:
- '**.py'

jobs:
typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
# change this python version if the python version is changed in pyproject.toml and the Dockerfile
python-version: '3.9'
- name: Set up Poetry
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: 1.1.13
- name: Install dependencies
run: poetry config virtualenvs.create false && poetry install --no-interaction
- name: run mypy
run: mypy
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
.envrc
.venv
__pycache__
*.egg-info
*.egg-info
.mypy_cache
.pytest_cache
.idea
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ versions. We also use [pre-commit](https://pre-commit.com/) with hooks for [isor
[black](https://github.com/psf/black), and [flake8](https://flake8.pycqa.org/en/latest/) for consistent code style and
readability. Note that this means code that doesn't meet the rules will fail to commit until it is fixed.

We also use [mypy](https://mypy.readthedocs.io/en/stable/index.html) for static type checking. This can be run manually,
and the CI runs it on PRs.

### Setup

1. [Install Poetry](https://python-poetry.org/docs/#installation).
Expand All @@ -170,6 +173,11 @@ This is done with Poetry via the `poetry.lock` file. As for the containerized co

To manually run isort, black, and flake8 all in one go, simply run `pre-commit run --all-files`.

### Run Static Type Checking

To manually run mypy, simply run `mypy` from the root directory of the project. It will use the default configuration
specified in the mypy.ini file.

### Update Dependencies

To update dependencies in your local environment, make changes to the `pyproject.toml` file then run `poetry update`.
Expand Down
4 changes: 2 additions & 2 deletions db/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
MAX_PAGE_SIZE = config.get("MAX_PAGE_SIZE")


def create_resource(mapping_class: BaseMapping, **params: dict) -> int:
def create_resource(mapping_class: BaseMapping, **params) -> int:
resource = mapping_class(**params)
resource.save()
return resource.id
Expand All @@ -23,7 +23,7 @@ def get_resource(mapping_class: BaseMapping, _id: int) -> dict:
return instance.to_dict()


def update_resources(mapping_class: BaseMapping, conditions: Expression, **params: dict) -> None:
def update_resources(mapping_class: BaseMapping, conditions: Expression, **params) -> None:
params["updated_at"] = tzaware_now()
q = mapping_class.update(**params).where(conditions)
q.execute()
Expand Down
15 changes: 1 addition & 14 deletions db/mappings/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import enum
import logging

from peewee import TextField

from db.helpers import update_resources
from db.mappings.base import BaseMapping, db_proxy
from db.mappings.base import BaseMapping


class Type(enum.Enum):
Expand Down Expand Up @@ -33,14 +31,3 @@ class Meta:
type = TextField(null=False)
status = TextField(null=False, default=Status.PENDING.value)
site = TextField(null=False, default="")

# If an exception occurs, the current transaction/savepoint will be rolled back.
# Otherwise the statements will be committed at the end.
@db_proxy.atomic()
def set_current(model_id: int, model_type: Type, model_site: Site) -> None:
current_model_query = (
(Model.type == model_type) & (Model.status == Status.CURRENT.value) & (Model.site == model_site)
)
update_resources(Model, current_model_query, status=Status.STALE.value)
update_resources(Model, Model.id == model_id, status=Status.CURRENT.value)
logging.info(f"Successfully updated model id {model_id} as current '{model_type}' model for '{model_site}'")
19 changes: 8 additions & 11 deletions handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import json
import logging
import time
from collections import defaultdict
from decimal import Decimal
from typing import Dict, List, Tuple

import tornado.web

Expand All @@ -12,8 +12,6 @@
from lib.metrics import Unit, write_metric

DEFAULT_PAGE_SIZE = config.get("DEFAULT_PAGE_SIZE")
# buffer of latency values for each handler/site combination
LATENCY_BUFFERS: Dict[Tuple[str, str], List[float]] = {}


def unix_time_ms(datetime_instance):
Expand Down Expand Up @@ -41,6 +39,10 @@ def flush(self):
return _buffer


# buffer of latency values for each handler/site combination
LATENCY_BUFFERS: dict[tuple[str, str], LatencyBuffer] = defaultdict(LatencyBuffer)


def admin_only(f):
def decorated(self, *args, **kwargs):
admin_token = self.request.headers.get("Authorization")
Expand Down Expand Up @@ -75,8 +77,7 @@ def handler_name(self) -> str:

@property
def site_name(self) -> str:
params = self.get_arguments()
return params.get("site", "n/a")
return self.get_argument("site", "n/a")

def prepare(self):
self.start_time = time.time()
Expand All @@ -92,11 +93,7 @@ def write_error_metric(self, latency: float):

def push_latency(self, latency, handler_name: str, site_name: str) -> None:
key = (handler_name, site_name)
if LATENCY_BUFFERS.get(key):
LATENCY_BUFFERS[key].push(latency)
else:
LATENCY_BUFFERS[key] = LatencyBuffer()
LATENCY_BUFFERS[key].push(latency)
LATENCY_BUFFERS[key].push(latency)

def on_finish(self):
if self.handler_name == "Health":
Expand Down Expand Up @@ -137,7 +134,7 @@ class APIHandler(BaseHandler):
def __init__(self, *args, **kwargs):
super(APIHandler, self).__init__(*args, **kwargs)

def get_arguments(self):
def get_arguments_as_dict(self) -> dict:
arguments = {k: self.get_argument(k) for k in self.request.arguments}
arguments["size"] = arguments.get("size", DEFAULT_PAGE_SIZE)
return arguments
Expand Down
2 changes: 1 addition & 1 deletion handlers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def apply_conditions(self, query, **filters):

@retry_rollback
async def get(self):
filters = self.get_arguments()
filters = self.get_arguments_as_dict()
query = self.mapping.select()
query = self.apply_conditions(query, **filters)
query = self.apply_sort(query, **filters)
Expand Down
24 changes: 11 additions & 13 deletions handlers/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
DEFAULT_SITE = config.get("DEFAULT_SITE")
STALE_AFTER_MIN = 15
# each result takes roughly 50,000 bytes; 2048 cached results ~= 100 MBs
TTL_CACHE = TTLCache(maxsize=2048, ttl=STALE_AFTER_MIN * 60)
TTL_CACHE: TTLCache = TTLCache(maxsize=2048, ttl=STALE_AFTER_MIN * 60)
# counter of default recs served for site
DEFAULT_REC_COUNTER: Dict[str, int] = {}
# counter of db hits by site
Expand All @@ -28,7 +28,7 @@
TOTAL_HANDLED: Dict[str, int] = {}


def incr_metric_total(counter: Dict[str, int], site: str) -> None:
def incr_metric_total(counter: dict[str, int], site: str) -> None:
"""
increment running metric totals to be flushed on an interval
"""
Expand All @@ -46,8 +46,8 @@ def instance_unaware_key(instance, *args, **kwargs):

class DefaultRecs:
DEFAULT_TYPE = Type.POPULARITY.value
_recs = {}
_last_updated = {}
_recs: dict[str, list[dict]] = {}
_last_updated: dict[str, datetime] = {}

@classmethod
@retry_rollback
Expand Down Expand Up @@ -124,9 +124,7 @@ def apply_conditions(self, query, **filters):

return query

def validate_filters(self, **filters) -> Dict[str, str]:
error_msgs = {}

def validate_filters(self, **filters) -> Optional[str]:
if "exclude" in filters:
for exclude in filters["exclude"].split(","):
try:
Expand All @@ -146,13 +144,13 @@ def validate_filters(self, **filters) -> Dict[str, str]:
except (ValueError, AssertionError):
return f"Invalid input for 'size' (int), must be below {MAX_PAGE_SIZE}: {filters['size']}"

return error_msgs
return None

@cached(cache=TTL_CACHE, key=instance_unaware_key)
def fetch_cached_results(
self,
site: str,
source_entity_id: Optional[str] = None,
site: Optional[str] = None,
model_type: Optional[str] = None,
model_id: Optional[str] = None,
exclude: Optional[str] = None,
Expand All @@ -168,10 +166,10 @@ def fetch_cached_results(
incr_metric_total(DB_HIT_COUNTER, site)
return [x.to_dict() for x in query]

def fetch_results(self, **filters: Dict[str, str]) -> List[Rec]:
def fetch_results(self, filters: dict[str, str]) -> List[Rec]:
results = self.fetch_cached_results(
site=filters["site"],
source_entity_id=filters.get("source_entity_id"),
site=filters.get("site"),
model_type=filters.get("model_type"),
model_id=filters.get("model_id"),
exclude=filters.get("exclude"),
Expand All @@ -183,14 +181,14 @@ def fetch_results(self, **filters: Dict[str, str]) -> List[Rec]:

@retry_rollback
async def get(self):
filters = self.get_arguments()
filters = self.get_arguments_as_dict()
filters["site"] = filters.get("site", DEFAULT_SITE)
validation_errors = self.validate_filters(**filters)
if validation_errors:
raise tornado.web.HTTPError(status_code=400, log_message=validation_errors)

res = {
"results": self.fetch_results(**filters)
"results": self.fetch_results(filters)
or DefaultRecs.get_recs(filters["site"], filters.get("source_entity_id"), int(filters["size"])),
}
incr_metric_total(TOTAL_HANDLED, filters["site"])
Expand Down
13 changes: 7 additions & 6 deletions lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class Config:
def __init__(self):
self._config = self.load_env()

def get_secret(self, secret_key: str) -> Any:
@staticmethod
def get_secret(secret_key: str) -> Any:
res = CLIENT.get_parameter(Name=secret_key, WithDecryption=True)
val = res["Parameter"]["Value"]
try:
Expand All @@ -28,13 +29,13 @@ def get_secret(self, secret_key: str) -> Any:
pass
return val

def is_secret(self, val: str) -> bool:
if not isinstance(val, str):
@staticmethod
def is_secret(val: str) -> bool:
try:
return val.startswith("/prod") or val.startswith("/dev")
except AttributeError:
return False

if val.startswith("/prod") or val.startswith("/dev"):
return True

def get(self, var_name: str) -> Any:
try:
val = self._config[var_name]
Expand Down
2 changes: 1 addition & 1 deletion lib/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def write_metric(
name: str,
value: float,
unit: str = Unit.COUNT,
tags: Dict[str, str] = None,
tags: dict = None,
) -> None:
if STAGE == "local":
logging.info(f"Skipping metric write for name:{name} | value:{value} | tags:{tags}")
Expand Down
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[mypy]
# documentation reference: https://mypy.readthedocs.io/en/stable/config_file.html
# start with this set to true, set to false and address 3rd party packages in future
ignore_missing_imports = True
# will also want to eventually check for missing type annotations
files = app.py
Loading

0 comments on commit 2dbc7f5

Please sign in to comment.