Skip to content

Commit

Permalink
make events expire, rewrite sending logic
Browse files Browse the repository at this point in the history
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
  • Loading branch information
BeryJu committed Feb 2, 2025
1 parent b075a1c commit 8b34541
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 96 deletions.
70 changes: 34 additions & 36 deletions authentik/enterprise/providers/ssf/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Generated by Django 5.0.11 on 2025-01-21 01:46
# Generated by Django 5.0.11 on 2025-02-02 19:39

import authentik.lib.utils.time
import django.contrib.postgres.fields
import django.db.models.deletion
import uuid
from django.conf import settings
from django.db import migrations, models


Expand All @@ -15,7 +15,6 @@ class Migration(migrations.Migration):
("authentik_core", "0042_authenticatedsession_authentik_c_expires_08251d_idx_and_more"),
("authentik_crypto", "0004_alter_certificatekeypair_name"),
("authentik_providers_oauth2", "0027_accesstoken_authentik_p_expires_9f24a5_idx_and_more"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]

operations = [
Expand All @@ -33,6 +32,13 @@ class Migration(migrations.Migration):
to="authentik_core.provider",
),
),
(
"event_retention",
models.TextField(
default="days=30",
validators=[authentik.lib.utils.time.timedelta_string_validator],
),
),
(
"oidc_auth_providers",
models.ManyToManyField(
Expand Down Expand Up @@ -60,8 +66,9 @@ class Migration(migrations.Migration):
),
],
options={
"verbose_name": "SSF Provider",
"verbose_name_plural": "SSF Providers",
"verbose_name": "Shared Signals Framework Provider",
"verbose_name_plural": "Shared Signals Framework Providers",
"permissions": [("add_stream", "Add stream to SSF provider")],
},
bases=("authentik_core.provider",),
),
Expand Down Expand Up @@ -128,24 +135,35 @@ class Migration(migrations.Migration):
to="authentik_providers_ssf.ssfprovider",
),
),
(
"user_subjects",
models.ManyToManyField(
related_name="UserStreamSubject", to=settings.AUTH_USER_MODEL
),
),
],
options={
"verbose_name": "SSF Stream",
"verbose_name_plural": "SSF Streams",
"default_permissions": ["change", "delete", "view"],
},
),
migrations.CreateModel(
name="StreamEvent",
fields=[
("expires", models.DateTimeField(default=None, null=True)),
("expiring", models.BooleanField(default=True)),
(
"uuid",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("status", models.TextField(choices=[("pending", "Pending"), ("sent", "Sent")])),
(
"status",
models.TextField(
choices=[
("pending_new", "Pending New"),
("pending_failed", "Pending Failed"),
("sent", "Sent"),
],
default="pending_new",
),
),
(
"type",
models.TextField(
Expand Down Expand Up @@ -174,29 +192,9 @@ class Migration(migrations.Migration):
),
),
],
),
migrations.CreateModel(
name="UserStreamSubject",
fields=[
(
"id",
models.AutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
),
),
(
"stream",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
to="authentik_providers_ssf.stream",
),
),
(
"user",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
),
],
options={
"verbose_name": "SSF Stream Event",
"verbose_name_plural": "SSF Stream Events",
},
),
]

This file was deleted.

39 changes: 25 additions & 14 deletions authentik/enterprise/providers/ssf/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.templatetags.static import static
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from jwt import encode

from authentik.core.models import BackchannelProvider, Token, User
from authentik.core.models import BackchannelProvider, ExpiringModel, Token
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider


Expand All @@ -34,7 +36,8 @@ class DeliveryMethods(models.TextChoices):
class SSFEventStatus(models.TextChoices):
"""SSF Event status"""

PENDING = "pending"
PENDING_NEW = "pending_new"
PENDING_FAILED = "pending_failed"
SENT = "sent"


Expand All @@ -56,6 +59,11 @@ class SSFProvider(BackchannelProvider):

token = models.ForeignKey(Token, on_delete=models.CASCADE, null=True, default=None)

event_retention = models.TextField(
default="days=30",
validators=[timedelta_string_validator],
)

@cached_property
def jwt_key(self) -> tuple[str | PrivateKeyTypes, str]:
"""Get either the configured certificate or the client secret"""
Expand Down Expand Up @@ -111,8 +119,6 @@ class Stream(models.Model):
format = models.TextField()
aud = ArrayField(models.TextField(), default=list)

user_subjects = models.ManyToManyField(User, "UserStreamSubject")

iss = models.TextField()

class Meta:
Expand All @@ -125,10 +131,13 @@ def __str__(self) -> str:

def prepare_event_payload(self, type: EventTypes, event_data: dict, **kwargs) -> dict:
jti = uuid4()
_now = now()
return {
"uuid": jti,
"stream_id": str(self.pk),
"type": type,
"expiring": True,
"expires": _now + timedelta_from_string(self.provider.event_retention),
"payload": {
"jti": jti.hex,
"aud": self.aud,
Expand All @@ -147,24 +156,26 @@ def encode(self, data: dict) -> str:
return encode(data, key, algorithm=alg, headers=headers)


class UserStreamSubject(models.Model):
stream = models.ForeignKey(Stream, on_delete=models.CASCADE)
user = models.ForeignKey(User, on_delete=models.CASCADE)

def __str__(self) -> str:
return f"Stream subject {self.stream_id} to {self.user_id}"


class StreamEvent(models.Model):
class StreamEvent(ExpiringModel):
"""Single stream event to be sent"""

uuid = models.UUIDField(default=uuid4, primary_key=True, editable=False)

stream = models.ForeignKey(Stream, on_delete=models.CASCADE)
status = models.TextField(choices=SSFEventStatus.choices)
status = models.TextField(choices=SSFEventStatus.choices, default=SSFEventStatus.PENDING_NEW)

type = models.TextField(choices=EventTypes.choices)
payload = models.JSONField(default=dict)

def expire_action(self, *args, **kwargs):
"""Only allow automatic cleanup of successfully sent event"""
if self.status != SSFEventStatus.SENT:
return
return super().expire_action(*args, **kwargs)

def __str__(self):
return f"Stream event {self.type}"

class Meta:
verbose_name = _("SSF Stream Event")
verbose_name_plural = _("SSF Stream Events")
47 changes: 34 additions & 13 deletions authentik/enterprise/providers/ssf/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from celery import group
from requests.exceptions import RequestException
from structlog.stdlib import get_logger

from authentik.enterprise.providers.ssf.models import (
DeliveryMethods,
Expand All @@ -8,10 +9,13 @@
Stream,
StreamEvent,
)
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP

session = get_http_session()
LOGGER = get_logger()


def send_ssf_event(
Expand All @@ -36,13 +40,12 @@ def _send_ssf_event(event_data: list[tuple[str, dict]]):
tasks = []
for stream, data in event_data:
event = StreamEvent.objects.create(**data)
tasks.append(send_single_ssf_event.si(stream, str(event.uuid)))
tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
main_task = group(*tasks)
main_task()


@CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True)
def send_single_ssf_event(self, stream_id: str, evt_id: str):
def send_single_ssf_event(stream_id: str, evt_id: str):
stream = Stream.objects.filter(pk=stream_id).first()
if not stream:
return
Expand All @@ -52,15 +55,33 @@ def send_single_ssf_event(self, stream_id: str, evt_id: str):
if event.status == SSFEventStatus.SENT:
return
if stream.delivery_method == DeliveryMethods.RISC_PUSH:
ssf_push_request(event)
event.status = SSFEventStatus.SENT
event.save()
return [ssf_push_event.si(str(event.pk))]
return []


def ssf_push_request(event: StreamEvent):
response = session.post(
event.stream.endpoint_url,
data=event.stream.encode(event.payload),
headers={"Content-Type": "application/secevent+jwt", "Accept": "application/json"},
)
response.raise_for_status()
@CELERY_APP.task(bind=True, base=SystemTask)
def ssf_push_event(self: SystemTask, event_id: str):
self.save_on_success = False
event = StreamEvent.objects.filter(pk=event_id).first()
if not event:
return
self.set_uid(event)
if event.status == SSFEventStatus.SENT:
self.set_status(TaskStatus.SUCCESSFUL)
return
try:
response = session.post(
event.stream.endpoint_url,
data=event.stream.encode(event.payload),
headers={"Content-Type": "application/secevent+jwt", "Accept": "application/json"},
)
response.raise_for_status()
event.status = SSFEventStatus.SENT
event.save()
self.set_status(TaskStatus.SUCCESSFUL)
return
except RequestException as exc:
LOGGER.warning("Failed to send SSF event", exc=exc)
self.set_error(exc)
event.status = SSFEventStatus.PENDING_FAILED
event.save()
20 changes: 16 additions & 4 deletions authentik/enterprise/providers/ssf/tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from authentik.core.models import Application
from authentik.core.tests.utils import create_test_cert
from authentik.enterprise.providers.ssf.models import SSFProvider, Stream
from authentik.enterprise.providers.ssf.models import (
SSFEventStatus,
SSFProvider,
Stream,
StreamEvent,
)
from authentik.lib.generators import generate_id


Expand All @@ -25,9 +30,7 @@ def test_stream_add(self):
),
data={
"iss": "https://screw-fotos-bracelets-longitude.trycloudflare.com/.well-known/ssf-configuration/abm-ssf/5",
"aud": [
"https://app.authentik.company"
],
"aud": ["https://app.authentik.company"],
"delivery": {
"method": "https://schemas.openid.net/secevent/risc/delivery-method/push",
"endpoint_url": "https://app.authentik.company",
Expand All @@ -41,6 +44,15 @@ def test_stream_add(self):
HTTP_AUTHORIZATION=f"Bearer {self.provider.token.key}",
)
self.assertEqual(res.status_code, 201)
stream = Stream.objects.filter(provider=self.provider).first()
self.assertIsNotNone(stream)
event = StreamEvent.objects.filter(stream=stream).first()
self.assertIsNotNone(event)
self.assertEqual(event.status, SSFEventStatus.PENDING_FAILED)
self.assertEqual(
event.payload["events"],
{"https://schemas.openid.net/secevent/ssf/event-type/verification": {"state": None}},
)

def test_stream_delete(self):
"""delete stream"""
Expand Down

0 comments on commit 8b34541

Please sign in to comment.