celery -> arq
alexcannan committed Aug 11, 2023
1 parent 0f3a2b4 commit 59bc439
Showing 9 changed files with 84 additions and 54 deletions.
16 changes: 12 additions & 4 deletions articlesa/config.py
@@ -1,11 +1,19 @@
"""One-stop shop for environment configuration."""

import os
from urllib.parse import urlparse

class CeleryConfig:
""" Configuration for Celery. """
broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0")
result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")

class RedisConfig:
url: str = os.getenv("REDIS_URL", "redis://localhost:6379")

@property
def host(self) -> str:
return urlparse(self.url).hostname

@property
def port(self) -> int:
return urlparse(self.url).port


class ServeConfig:
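A minimal sketch of how the new RedisConfig above could feed arq's connection settings. Because host and port are instance properties, the sketch reads them off a RedisConfig instance; the helper name is illustrative and not part of the commit.

# sketch: building arq RedisSettings from RedisConfig (hypothetical helper)
from arq.connections import RedisSettings

from articlesa.config import RedisConfig


def redis_settings_from_config() -> RedisSettings:
    """Read host/port off a RedisConfig instance; the properties parse REDIS_URL."""
    config = RedisConfig()  # properties need an instance to resolve
    return RedisSettings(host=config.host, port=config.port)
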
1 change: 1 addition & 0 deletions articlesa/crawl/__init__.py
@@ -34,4 +34,5 @@ def get_articles(self) -> list[str]:
self.driver.get(self.url)
articles = self.driver.find_elements(By.CSS_SELECTOR, "a.story")
article_urls = [article.get_attribute("href") for article in articles]
logger.info(f"found {len(article_urls)} articles from mastodon.social")
return article_urls
1 change: 1 addition & 0 deletions articlesa/neo/__init__.py
@@ -86,6 +86,7 @@ async def put_article(self,
parsedAtUtc=parsed_article.parsedAtUtc,
authors=parsed_article.authors,
publisherNetLoc=parsed_article.publisherNetLoc,
parent_url=parent_url,
)

async def get_article(self, url: str) -> ParsedArticle:
19 changes: 12 additions & 7 deletions articlesa/serve/gateway.py
@@ -9,6 +9,8 @@
from datetime import datetime
import json
from typing import AsyncGenerator, Optional, cast
from arq import ArqRedis
from arq.jobs import JobStatus

from celery.app.task import Task
from fastapi import APIRouter, Request
@@ -25,8 +27,7 @@
PlaceholderArticle,
ParseFailure,
)
from articlesa.worker.parse import parse_article
parse_article = cast(Task, parse_article)
from articlesa.worker import create_pool


router = APIRouter()
@@ -49,13 +50,14 @@ def build_event(data: Optional[dict], id: str, event: StreamEvent) -> SSE:


async def retrieve_article(url: str,
arqpool: ArqRedis,
neodriver: Neo4JArticleDriver,
parent_url: Optional[str] = None,
) -> dict:
"""
Retrieve article from db or through celery; intended to be wrapped in asyncio.Task.
Retrieve article from db or through arq; intended to be wrapped in asyncio.Task.
Tries neo.Neo4jArticleDriver.get_article first, then falls back to celery.
Tries neo.Neo4jArticleDriver.get_article first, then falls back to arq enqueueing.
If a parent_url is passed, neo4j will create a relationship between the parent
and the child article.
@@ -64,8 +66,10 @@ async def retrieve_article(url: str,
parsed_article = await neodriver.get_article(url)
return parsed_article.dict()
except ArticleNotFound:
task = parse_article.delay(url)
article_dict = task.get()
job = await arqpool.enqueue_job("parse_article", url)
while await job.status() != JobStatus.complete:
await asyncio.sleep(0.1)
article_dict = await job.result()
try:
await neodriver.put_article(ParsedArticle(**article_dict), parent_url=parent_url)
except Exception as e:
@@ -83,6 +87,7 @@ async def _article_stream(
max_depth: maximum depth to parse to
"""
tasks = set()
arqpool = await create_pool()

async def _begin_processing_task(
url: str, depth: int, parent: Optional[str]
@@ -91,7 +96,7 @@ async def _begin_processing_task(
placeholder_node = PlaceholderArticle(
urlhash=url_to_hash(url), depth=depth, parent=parent
)
task = asyncio.create_task(retrieve_article(url, neodriver, parent_url=parent))
task = asyncio.create_task(retrieve_article(url, arqpool, neodriver, parent_url=parent))
task.set_name(f"{depth}/{url}")
tasks.add(task)
yield build_event(
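A standalone sketch of the enqueue-and-poll pattern retrieve_article now uses, outside FastAPI; the example URL is arbitrary and the 0.1 s poll interval mirrors the gateway code above.

# sketch: enqueue parse_article on arq and poll until the job completes
import asyncio

from arq import create_pool
from arq.connections import RedisSettings
from arq.jobs import JobStatus


async def main() -> None:
    pool = await create_pool(RedisSettings())  # defaults to redis://localhost:6379
    job = await pool.enqueue_job("parse_article", "https://example.com/story")
    while await job.status() != JobStatus.complete:
        await asyncio.sleep(0.1)  # same poll interval the gateway uses
    article_dict = await job.result()
    print(article_dict)


if __name__ == "__main__":
    asyncio.run(main())
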
38 changes: 38 additions & 0 deletions articlesa/worker/__init__.py
@@ -0,0 +1,38 @@
import asyncio

from aiohttp import ClientSession
from arq import create_pool
from arq.connections import RedisSettings

from articlesa.config import RedisConfig
from articlesa.logger import logger
from articlesa.worker.parse import parse_article


async def make_pool():
"""Create a redis pool for the worker, used to enqueue jobs."""
return await create_pool(
RedisSettings(
host=RedisConfig.host,
port=RedisConfig.port,
)
)


async def startup(ctx):
"""Startup function for arq worker, creates aiohttp session."""
logger.info("starting up")
ctx['session'] = await ClientSession().__aenter__()


async def shutdown(ctx):
"""Shutdown function for arq worker, closes aiohttp session."""
logger.info("shutting down")
await ctx['session'].__aexit__(None, None, None)


class WorkerSettings:
"""https://arq-docs.helpmanual.io/#arq.worker.Worker"""
functions = [parse_article]
on_startup = startup
on_shutdown = shutdown
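
A sketch of how this WorkerSettings class would typically be launched, assuming arq's run_worker helper; the commit itself does not add a launcher, and the shell equivalent is arq's standard CLI.

# sketch: running the arq worker programmatically (shell equivalent: arq articlesa.worker.WorkerSettings)
from arq.worker import run_worker

from articlesa.worker import WorkerSettings

if __name__ == "__main__":
    # blocks, polling redis and executing parse_article jobs until interrupted
    run_worker(WorkerSettings)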
14 changes: 14 additions & 0 deletions articlesa/worker/__main__.py
@@ -0,0 +1,14 @@
import asyncio

from arq import create_pool
from arq.connections import RedisSettings


async def main():
redis = await create_pool(RedisSettings())
for url in ('https://facebook.com', 'https://microsoft.com', 'https://github.com'):
await redis.enqueue_job('parse_article', url)


if __name__ == '__main__':
asyncio.run(main())
24 changes: 0 additions & 24 deletions articlesa/worker/app.py

This file was deleted.

22 changes: 4 additions & 18 deletions articlesa/worker/parse.py
@@ -17,7 +17,6 @@

from articlesa.logger import logger
from articlesa.types import ParsedArticle, relative_to_absolute_url, HostBlacklist
from articlesa.worker.app import app as celeryapp


session: Optional[ClientSession] = None
@@ -29,14 +28,6 @@
}


async def get_session_() -> ClientSession:
"""Create a global aiohttp session if one doesn't exist."""
global session
if not session:
session = await ClientSession().__aenter__()
return session


class MissingArticleText(Exception):
"""Raised when an article has no text."""
pass
@@ -63,10 +54,9 @@ async def download_article(url: str, session: ClientSession) -> str:
return await response.text()


@celeryapp.task(name="parse_article")
async def parse_article(url: str) -> dict:
async def parse_article(ctx, url: str) -> dict:
"""Given a url, parse the article and return a dict like ParsedArticle."""
session = await get_session_()
session: ClientSession = ctx["session"]

# Check for redirects
final_url = await check_redirect(url, session)
@@ -124,10 +114,6 @@ async def parse_article(url: str) -> dict:
links=article.links,
published=article.publish_date,
parsedAtUtc=datetime.utcnow(),
).dict()

return parsed_article

)

if __name__ == "__main__":
print("run me with:\ncelery -A articlesa.worker.parse worker -l info") # noqa: T201
return parsed_article.dict()
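
Since parse_article now takes the worker context as its first argument, here is a minimal test-harness sketch that calls it outside arq by hand-building a ctx with an aiohttp session; this mirrors what startup() stores for the worker and is not part of the commit.

# sketch: exercising parse_article directly with a hand-built ctx
import asyncio

from aiohttp import ClientSession

from articlesa.worker.parse import parse_article


async def main() -> None:
    async with ClientSession() as session:
        ctx = {"session": session}  # what arq's startup() would normally provide
        article_dict = await parse_article(ctx, "https://example.com/story")
        print(article_dict)


if __name__ == "__main__":
    asyncio.run(main())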
3 changes: 2 additions & 1 deletion blacklist.txt
@@ -2,4 +2,5 @@ google.com
twitter.com
t.co
linkedin.com
amazon.com
amazon.com
archive.org
