Skip to content

Commit

Permalink
Merge pull request #1086 from uc-cdis/fix/google_update
Browse files Browse the repository at this point in the history
feat(google_update): updated Flask and deps, also don't fail early with Google errors PXP-10900
  • Loading branch information
Avantol13 authored May 4, 2023
2 parents 5b8f223 + f5f5b25 commit 0304957
Show file tree
Hide file tree
Showing 36 changed files with 845 additions and 869 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ addons:
postgresql: "9.6"

install:
- pip install --upgrade pip
- pip install poetry
- poetry install -vv --no-interaction
- poetry install --all-extras -vv --no-interaction
- poetry show -vv
- psql -c 'SELECT version();' -U postgres
- psql -U postgres -c "create database fence_test_tmp"
Expand Down
45 changes: 30 additions & 15 deletions fence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from urllib.parse import urljoin
import flask
from flask_cors import CORS
from flask_sqlalchemy_session import flask_scoped_session, current_session
from sqlalchemy.orm import scoped_session
from flask import _app_ctx_stack, current_app
from werkzeug.local import LocalProxy

from authutils.oauth2.client import OAuthClient
from cdislogging import get_logger
Expand Down Expand Up @@ -112,7 +114,24 @@ def app_sessions(app):
SQLAlchemyDriver.setup_db = lambda _: None
app.db = SQLAlchemyDriver(config["DB"])

session = flask_scoped_session(app.db.Session, app) # noqa
# app.db.Session is from SQLAlchemyDriver and uses
# SQLAlchemy's sessionmaker. Using scoped_session here ensures
# a thread-local db session is created. Effectively the below is expanded to:
# app.scoped_session = scoped_session(
# sessionmaker(
# bind=sqlalchemy.create_engine(config["DB"]),
# expire_on_commit=False,
# )
# )
#
# From the sqlalchemy docs: "The scoped_session object by default uses
# [Python's threading.local() construct] as
# storage, so that a single Session is maintained for all who call upon the
# scoped_session registry, but only within the scope of a single thread.
# Callers who call upon the registry in a different thread get a Session
# instance that is local to that other thread."
app.scoped_session = scoped_session(app.db.Session)

app.session_interface = UserSessionInterface()


Expand Down Expand Up @@ -509,28 +528,15 @@ def _setup_prometheus(app):
multiprocess,
make_wsgi_app,
)
from prometheus_flask_exporter import Counter
from prometheus_flask_exporter.multiprocess import (
UWsgiPrometheusMetrics,
)

app.prometheus_registry = CollectorRegistry()
multiprocess.MultiProcessCollector(app.prometheus_registry)

UWsgiPrometheusMetrics(app)

# Add prometheus wsgi middleware to route /metrics requests
app.wsgi_app = DispatcherMiddleware(
app.wsgi_app, {"/metrics": make_wsgi_app(registry=app.prometheus_registry)}
)

# set up counters
app.prometheus_counters["pre_signed_url_req"] = Counter(
"pre_signed_url_req",
"tracking presigned url requests",
["requested_protocol"],
)


@app.errorhandler(Exception)
def handle_error(error):
Expand Down Expand Up @@ -569,3 +575,12 @@ def check_csrf():
logger.debug("HTTP REFERER " + str(referer))
except Exception as e:
raise UserError("CSRF verification failed: {}. Request aborted".format(e))


@app.teardown_appcontext
def remove_scoped_session(*args, **kwargs):
if hasattr(app, "scoped_session"):
try:
app.scoped_session.remove()
except Exception as exc:
logger.warning(f"could not remove app.scoped_session. Error: {exc}")
34 changes: 21 additions & 13 deletions fence/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import flask
from flask_sqlalchemy_session import current_session
from flask import current_app
from datetime import datetime
from functools import wraps
import urllib.request, urllib.parse, urllib.error
Expand Down Expand Up @@ -95,7 +95,7 @@ def set_flask_session_values(user):
flask.g.scopes = ["_all"]
flask.g.token = None

user = query_for_user(session=current_session, username=username)
user = query_for_user(session=current_app.scoped_session(), username=username)
if user:
_update_users_email(user, email)
_update_users_id_from_idp(user, id_from_idp)
Expand All @@ -120,16 +120,17 @@ def set_flask_session_values(user):

# setup idp connection for new user (or existing user w/o it setup)
idp = (
current_session.query(IdentityProvider)
current_app.scoped_session()
.query(IdentityProvider)
.filter(IdentityProvider.name == provider)
.first()
)
if not idp:
idp = IdentityProvider(name=provider)

user.identity_provider = idp
current_session.add(user)
current_session.commit()
current_app.scoped_session().add(user)
current_app.scoped_session().commit()

set_flask_session_values(user)

Expand Down Expand Up @@ -238,7 +239,9 @@ def has_oauth(scope=None):
raise Unauthorized("failed to validate token: {}".format(e))
if "sub" in access_token_claims:
user_id = access_token_claims["sub"]
user = current_session.query(User).filter_by(id=int(user_id)).first()
user = (
current_app.scoped_session().query(User).filter_by(id=int(user_id)).first()
)
if not user:
raise Unauthorized("no user found with id: {}".format(user_id))
# set some application context for current user
Expand All @@ -250,7 +253,12 @@ def has_oauth(scope=None):


def get_user_from_claims(claims):
return current_session.query(User).filter(User.id == claims["sub"]).first()
return (
current_app.scoped_session()
.query(User)
.filter(User.id == claims["sub"])
.first()
)


def admin_required(f):
Expand Down Expand Up @@ -284,8 +292,8 @@ def _update_users_email(user, email):
)
user.email = email

current_session.add(user)
current_session.commit()
current_app.scoped_session().add(user)
current_app.scoped_session().commit()


def _update_users_id_from_idp(user, id_from_idp):
Expand All @@ -298,8 +306,8 @@ def _update_users_id_from_idp(user, id_from_idp):
)
user.id_from_idp = id_from_idp

current_session.add(user)
current_session.commit()
current_app.scoped_session().add(user)
current_app.scoped_session().commit()


def _update_users_last_auth(user):
Expand All @@ -309,5 +317,5 @@ def _update_users_last_auth(user):
logger.info(f"Updating username {user.username}'s _last_auth.")
user._last_auth = datetime.now()

current_session.add(user)
current_session.commit()
current_app.scoped_session().add(user)
current_app.scoped_session().commit()
Loading

0 comments on commit 0304957

Please sign in to comment.