From e15bdf86c1146dede143e961db823ac15ff18df2 Mon Sep 17 00:00:00 2001
From: Thomas Ieong
Date: Wed, 13 Dec 2023 16:10:35 +0100
Subject: [PATCH 1/2] Switch to asyncpg
---
gitlab_matrix/bot.py | 12 +-
gitlab_matrix/commands/alias.py | 12 +-
gitlab_matrix/commands/room.py | 2 +-
gitlab_matrix/commands/server.py | 8 +-
gitlab_matrix/commands/webhook.py | 2 +-
gitlab_matrix/db.py | 344 ++++++++++++++----------------
gitlab_matrix/migrations.py | 61 ++++++
gitlab_matrix/util/decorators.py | 5 +-
gitlab_matrix/webhook.py | 23 +-
maubot.yaml | 1 +
10 files changed, 259 insertions(+), 211 deletions(-)
create mode 100644 gitlab_matrix/migrations.py
diff --git a/gitlab_matrix/bot.py b/gitlab_matrix/bot.py
index 88b1b54..4cd75ac 100644
--- a/gitlab_matrix/bot.py
+++ b/gitlab_matrix/bot.py
@@ -17,23 +17,25 @@
from typing import Type
from mautrix.util.config import BaseProxyConfig
+from mautrix.util.async_db import UpgradeTable
from maubot import Plugin
-from .db import Database
+from .db import DBManager
from .util import Config
from .webhook import GitlabWebhook
from .commands import GitlabCommands
+from .migrations import upgrade_table
class GitlabBot(Plugin):
- db: Database
+ db: DBManager
webhook: GitlabWebhook
commands: GitlabCommands
async def start(self) -> None:
self.config.load_and_update()
- self.db = Database(self.database)
+ self.db = DBManager(self.database)
self.webhook = await GitlabWebhook(self).start()
self.commands = GitlabCommands(self)
@@ -46,3 +48,7 @@ async def stop(self) -> None:
@classmethod
def get_config_class(cls) -> Type[BaseProxyConfig]:
return Config
+
+ @classmethod
+ def get_db_upgrade_table(cls) -> UpgradeTable:
+ return upgrade_table
diff --git a/gitlab_matrix/commands/alias.py b/gitlab_matrix/commands/alias.py
index 0377934..83364ac 100644
--- a/gitlab_matrix/commands/alias.py
+++ b/gitlab_matrix/commands/alias.py
@@ -29,18 +29,20 @@ async def alias(self) -> None:
@command.argument("url", "server URL")
@command.argument("alias", "server alias")
async def alias_add(self, evt: MessageEvent, url: str, alias: str) -> None:
- if url not in self.bot.db.get_servers(evt.sender):
+ servers = await self.bot.db.get_servers(evt.sender)
+ if url not in servers:
await evt.reply("You can't add an alias to a GitLab server you are not logged in to.")
return
- if self.bot.db.has_alias(evt.sender, alias):
+ has_alias = await self.bot.db.has_alias(evt.sender, alias)
+ if has_alias:
await evt.reply("Alias already in use.")
return
- self.bot.db.add_alias(evt.sender, url, alias)
+ await self.bot.db.add_alias(evt.sender, url, alias)
await evt.reply(f"Added alias {alias} to server {url}")
@alias.subcommand("list", aliases=("l", "ls"), help="Show your Gitlab server aliases.")
async def alias_list(self, evt: MessageEvent) -> None:
- aliases = self.bot.db.get_aliases(evt.sender)
+ aliases = await self.bot.db.get_aliases(evt.sender)
if not aliases:
await evt.reply("You don't have any aliases.")
return
@@ -52,5 +54,5 @@ async def alias_list(self, evt: MessageEvent) -> None:
help="Remove a alias to a Gitlab server.")
@command.argument("alias", "server alias")
async def alias_rm(self, evt: MessageEvent, alias: str) -> None:
- self.bot.db.rm_alias(evt.sender, alias)
+ await self.bot.db.rm_alias(evt.sender, alias)
await evt.reply(f"Removed alias {alias}.")
diff --git a/gitlab_matrix/commands/room.py b/gitlab_matrix/commands/room.py
index 947e963..5499851 100644
--- a/gitlab_matrix/commands/room.py
+++ b/gitlab_matrix/commands/room.py
@@ -48,5 +48,5 @@ async def default_repo(self, evt: MessageEvent, repo: str, gl: Gl) -> None:
await evt.reply(f"Couldn't find {repo} on {gl.url}")
return
raise
- self.bot.db.set_default_repo(evt.room_id, gl.url, repo)
+ await self.bot.db.set_default_repo(evt.room_id, gl.url, repo)
await evt.reply(f"Changed the default repo to {repo} on {gl.url}")
diff --git a/gitlab_matrix/commands/server.py b/gitlab_matrix/commands/server.py
index 2a996e5..2949cbb 100644
--- a/gitlab_matrix/commands/server.py
+++ b/gitlab_matrix/commands/server.py
@@ -32,12 +32,12 @@ async def server(self) -> None:
@server.subcommand("default", aliases=("d",), help="Change your default GitLab server.")
@command.argument("url", "server URL")
async def server_default(self, evt: MessageEvent, url: str) -> None:
- self.bot.db.change_default(evt.sender, url)
+ await self.bot.db.change_default(evt.sender, url)
await evt.reply(f"Changed the default server to {url}")
@server.subcommand("list", aliases=("ls",), help="Show your GitLab servers.")
async def server_list(self, evt: MessageEvent) -> None:
- servers = self.bot.db.get_servers(evt.sender)
+ servers = await self.bot.db.get_servers(evt.sender)
if not servers:
await evt.reply("You are not logged in to any server.")
return
@@ -60,13 +60,13 @@ async def server_login(self, evt: MessageEvent, url: str, token: str) -> None:
exc_info=True)
await evt.reply(f"GitLab login failed: {e}")
return
- self.bot.db.add_login(evt.sender, url, token)
+ await self.bot.db.add_login(evt.sender, url, token)
await evt.reply(f"Successfully logged into GitLab at {url} as {gl.user.name}")
@server.subcommand("logout", help="Remove the access token from the bot's database.")
@command.argument("url", "server URL")
async def server_logout(self, evt: MessageEvent, url: str) -> None:
- self.bot.db.rm_login(evt.sender, url)
+ await self.bot.db.rm_login(evt.sender, url)
await evt.reply(f"Removed {url} from the database.")
@Command.gitlab.subcommand("ping", aliases=("p",), help="Ping the bot.")
diff --git a/gitlab_matrix/commands/webhook.py b/gitlab_matrix/commands/webhook.py
index a48ab13..81919d3 100644
--- a/gitlab_matrix/commands/webhook.py
+++ b/gitlab_matrix/commands/webhook.py
@@ -34,7 +34,7 @@ async def webhook(self) -> None:
@with_gitlab_session
async def webhook_add(self, evt: MessageEvent, repo: str, gl: Gl) -> None:
token = secrets.token_urlsafe(64)
- self.bot.db.add_webhook_room(token, evt.room_id)
+ await self.bot.db.add_webhook_room(token, evt.room_id)
project = gl.projects.get(repo)
hook = project.hooks.create({
"url": f"{self.bot.webapp_url}/webhooks",
diff --git a/gitlab_matrix/db.py b/gitlab_matrix/db.py
index 550b12f..6ff443c 100644
--- a/gitlab_matrix/db.py
+++ b/gitlab_matrix/db.py
@@ -1,6 +1,7 @@
# gitlab - A GitLab client and webhook receiver for maubot
# Copyright (C) 2019 Lorenz Steinert
# Copyright (C) 2019 Tulir Asokan
+# Copyright (C) 2023 Thomas Ieong
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@@ -17,203 +18,176 @@
from typing import List, NamedTuple, Optional
import logging as log
-from sqlalchemy import Column, String, Text, ForeignKeyConstraint, or_, ForeignKey
-from sqlalchemy.orm import sessionmaker, relationship, Session
-from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound
-from sqlalchemy.engine.base import Engine
-from sqlalchemy.ext.declarative import declarative_base
-
from mautrix.types import UserID, EventID, RoomID
+from mautrix.util.async_db import Database
AuthInfo = NamedTuple('AuthInfo', server=str, api_token=str)
AliasInfo = NamedTuple('AliasInfo', server=str, alias=str)
DefaultRepoInfo = NamedTuple('DefaultRepoInfo', server=str, repo=str)
-Base = declarative_base()
-
-
-class Token(Base):
- __tablename__ = "token"
-
- user_id: UserID = Column(String(255), primary_key=True, nullable=False)
- gitlab_server = Column(Text, primary_key=True, nullable=False)
- api_token = Column(Text, nullable=False)
- aliases = relationship("Alias", back_populates="token",
- cascade="all, delete-orphan")
- default = relationship("Default", back_populates="token",
- cascade="all, delete-orphan",
- primaryjoin="Token.user_id==Default.user_id")
-
-
-class Alias(Base):
- __tablename__ = "alias"
-
- user_id: UserID = Column(String(255), primary_key=True)
- gitlab_server = Column(Text, primary_key=True)
- alias = Column(Text, primary_key=True, nullable=False)
- __table_args__ = (ForeignKeyConstraint((user_id, gitlab_server),
- (Token.user_id, Token.gitlab_server)),)
- token = relationship("Token", back_populates="aliases")
-
-
-class Default(Base):
- __tablename__ = "default"
-
- user_id: UserID = Column(String(255), ForeignKey("token.user_id"), primary_key=True)
- gitlab_server = Column(Text, ForeignKey("token.gitlab_server"))
- token = relationship("Token", back_populates="default",
- primaryjoin="Token.user_id==Default.user_id")
-
-
-class DefaultRepo(Base):
- __tablename__ = "default_repo"
- room_id: RoomID = Column(String(255), primary_key=True)
- server: str = Column(String(255), nullable=False)
- repo: str = Column(String(255), nullable=False)
+class DBManager:
+ db: Database
-class MatrixMessage(Base):
- __tablename__ = "matrix_message"
-
- message_id: str = Column(String(255), primary_key=True)
- room_id: RoomID = Column(String(255), primary_key=True)
- event_id: EventID = Column(String(255), nullable=False)
-
-
-class WebhookToken(Base):
- __tablename__ = "webhook_token"
-
- room_id: RoomID = Column(Text, nullable=False)
- secret: str = Column(Text, primary_key=True)
-
-
-class Database:
- db: Engine
-
- def __init__(self, db: Engine) -> None:
+ def __init__(self, db: Database) -> None:
self.db = db
- Base.metadata.create_all(db)
- self.Session = sessionmaker(bind=self.db)
- def get_event(self, message_id: str, room_id: RoomID) -> Optional[EventID]:
+ async def get_event(self, message_id: str, room_id: RoomID) -> Optional[EventID]:
if not message_id:
return None
- s: Session = self.Session()
- event = s.query(MatrixMessage).get((message_id, room_id))
- return event.event_id if event else None
-
- def put_event(self, message_id: str, room_id: RoomID, event_id: EventID, merge: bool = False
- ) -> None:
- s: Session = self.Session()
- evt = MatrixMessage(message_id=message_id, room_id=room_id, event_id=event_id)
- if merge:
- s.merge(evt)
- else:
- s.add(evt)
- s.commit()
-
- def get_default_repo(self, room_id: RoomID) -> DefaultRepoInfo:
- s: Session = self.Session()
- default = s.query(DefaultRepo).get((room_id,))
- return DefaultRepoInfo(default.server, default.repo) if default else None
-
- def set_default_repo(self, room_id: RoomID, server: str, repo: str) -> None:
- s: Session = self.Session()
- s.merge(DefaultRepo(room_id=room_id, server=server, repo=repo))
- s.commit()
-
- def get_servers(self, mxid: UserID) -> List[str]:
- s = self.Session()
- rows = s.query(Token).filter(Token.user_id == mxid)
- return [row.gitlab_server for row in rows]
-
- def add_login(self, mxid: UserID, url: str, token: str) -> None:
- token_row = Token(user_id=mxid, gitlab_server=url, api_token=token)
- default = Default(user_id=mxid, gitlab_server=url)
- s = self.Session()
- try:
- s.add(token_row)
- s.query(Default).filter(Default.user_id == mxid).one()
- except NoResultFound:
- s.add(default)
- except MultipleResultsFound as e:
+ q = (
+ "SELECT message_id, room_id, event_id FROM matrix_message "
+ "WHERE message_id = $1 AND room_id = $2"
+ )
+ event = await self.db.fetchrow(q, message_id, room_id)
+ return event["event_id"] if event else None
+
+ async def put_event(
+ self,
+ message_id: str,
+ room_id: RoomID,
+ event_id: EventID,
+ ) -> None:
+ q = (
+ "INSERT INTO matrix_message (message_id, room_id, event_id) VALUES ($1, $2, $3) "
+ "ON CONFLICT (message_id, room_id) DO UPDATE "
+ "SET event_id = excluded.event_id"
+ )
+ await self.db.execute(q, message_id, room_id, event_id)
+
+ async def get_default_repo(self, room_id: RoomID) -> DefaultRepoInfo:
+ q = "SELECT room_id, server, repo FROM default_repo WHERE room_id = $1"
+ default = await self.db.fetchrow(q, room_id)
+ return DefaultRepoInfo(default["server"], default["repo"]) if default else None
+
+ async def set_default_repo(self, room_id: RoomID, server: str, repo: str) -> None:
+ q = (
+ "INSERT INTO default_repo (room_id, server, repo) "
+ "VALUES ($1, $2, $3) "
+ "ON CONFLICT (room_id) DO UPDATE "
+ "SET server = excluded.server, repo = excluded.repo"
+ )
+ await self.db.execute(q, room_id, server, repo)
+
+ async def get_servers(self, mxid: UserID) -> List[str]:
+ q = "SELECT user_id, gitlab_server, api_token FROM token WHERE user_id = $1"
+ rows = await self.db.fetch(q, mxid)
+ return [row["gitlab_server"] for row in rows]
+
+ async def add_login(self, mxid: UserID, url: str, token: str) -> None:
+ token_query = (
+ "INSERT INTO token (user_id, gitlab_server, api_token) "
+ "VALUES ($1, $2, $3)"
+ )
+ await self.db.execute(token_query, mxid, url, token)
+
+ default_query = (
+ "SELECT user_id, gitlab_server FROM 'default' "
+ "WHERE user_id = $1"
+ )
+ result = await self.db.fetch(default_query, mxid)
+
+ if len(result) > 1:
log.warning("Multiple default servers found.")
- log.warning(e)
- raise e
- s.commit()
-
- def rm_login(self, mxid: UserID, url: str) -> None:
- s = self.Session()
- token = s.query(Token).get((mxid, url))
- s.delete(token)
- s.commit()
-
- def get_login(self, mxid: UserID, url_alias: str = None) -> AuthInfo:
- s = self.Session()
+
+ if not result:
+ q = (
+ "INSERT INTO 'default' (user_id, gitlab_server) "
+ "VALUES ($1, $2)"
+ )
+ await self.db.execute(q, mxid, url)
+
+ async def rm_login(self, mxid: UserID, url: str) -> None:
+ q = "DELETE FROM token WHERE user_id = $1 AND gitlab_server = $2"
+ await self.db.execute(q, mxid, url)
+
+ async def get_login(self, mxid: UserID, url_alias: str = None) -> AuthInfo:
if url_alias:
- row = (s.query(Token)
- .join(Alias)
- .filter(Token.user_id == mxid,
- or_(Token.gitlab_server == url_alias,
- Alias.alias == url_alias)).one())
+ q = (
+ "SELECT user_id, gitlab_server, api_token FROM token "
+ "JOIN alias ON token.user_id = alias.user_id "
+ "AND token.gitlab_server = alias.gitlab_server "
+ "WHERE token.user_id = $1 AND ( token.gitlab_server = $2 OR alias.alias = $2 )"
+ )
+ row = await self.db.fetchrow(q, mxid, url_alias)
else:
- row = (s.query(Token)
- .join(Default, Default.user_id == Token.user_id)
- .filter(Token.user_id == mxid).first())
- return AuthInfo(server=row.gitlab_server, api_token=row.api_token)
-
- def get_login_by_server(self, mxid: UserID, url: str) -> AuthInfo:
- s = self.Session()
- row = s.query(Token).get((mxid, url))
- return AuthInfo(server=row.gitlab_server, api_token=row.api_token)
-
- def get_login_by_alias(self, mxid: UserID, alias: str) -> AuthInfo:
- s = self.Session()
- row = s.query(Token).join(Alias).filter(Token.user_id == mxid,
- Alias.alias == alias).one()
- return AuthInfo(server=row.gitlab_server, api_token=row.api_token)
-
- def add_alias(self, mxid: UserID, url: str, alias: str) -> None:
- s = self.Session()
- alias = Alias(user_id=mxid, gitlab_server=url, alias=alias)
- s.add(alias)
- s.commit()
-
- def rm_alias(self, mxid: UserID, alias: str) -> None:
- s = self.Session()
- alias = s.query(Alias).filter(Alias.user_id == mxid,
- Alias.alias == alias).one()
- s.delete(alias)
- s.commit()
-
- def has_alias(self, user_id: UserID, alias: str) -> bool:
- s: Session = self.Session()
- return s.query(Alias).filter(Alias.user_id == user_id, Alias.alias == alias).count() > 0
-
- def get_aliases(self, user_id: UserID) -> List[AliasInfo]:
- s = self.Session()
- rows = s.query(Alias).filter(Alias.user_id == user_id)
- return [AliasInfo(row.gitlab_server, row.alias) for row in rows]
-
- def get_aliases_per_server(self, user_id: UserID, url: str) -> List[AliasInfo]:
- s = self.Session()
- rows = s.query(Alias).filter(Alias.user_id == user_id,
- Alias.gitlab_server == url)
- return [AliasInfo(row.gitlab_server, row.alias) for row in rows]
-
- def change_default(self, mxid: UserID, url: str) -> None:
- s = self.Session()
- default = s.query(Default).get((mxid,))
- default.gitlab_server = url
- s.commit()
-
- def get_webhook_room(self, secret: str) -> Optional[RoomID]:
- s = self.Session()
- webhook_token = s.query(WebhookToken).get((secret,))
- return webhook_token.room_id if webhook_token else None
-
- def add_webhook_room(self, secret: str, room_id: RoomID) -> None:
- s = self.Session()
- webhook_token = WebhookToken(secret=secret, room_id=room_id)
- s.add(webhook_token)
- s.commit()
+ q = (
+ "SELECT user_id, gitlab_server, api_token FROM token "
+ "JOIN 'default' ON token.user_id = 'default'.user_id "
+ "AND token.gitlab_server = 'default'.gitlab_server "
+ "WHERE token.user_id = $1"
+ )
+ row = await self.db.fetchrow(q, mxid)
+ return AuthInfo(server=row["gitlab_server"], api_token=row["api_token"])
+
+ async def get_login_by_server(self, mxid: UserID, url: str) -> AuthInfo:
+ q = (
+ "SELECT user_id, gitlab_server, api_token FROM token "
+ "WHERE user_id = $1 AND gitlab_server = $2"
+ )
+ row = await self.db.fetchrow(q, mxid, url)
+ return AuthInfo(server=row["gitlab_server"], api_token=row["api_token"])
+
+ async def get_login_by_alias(self, mxid: UserID, alias: str) -> AuthInfo:
+ q = (
+ "SELECT user_id, gitlab_server, api_token FROM token "
+ "JOIN alias ON "
+ "token.user_id = alias.user_id AND token.gitlab_server = alias.gitlab_server "
+ "WHERE token.user_id = $1 AND alias.alias = $2"
+ )
+ row = await self.db.fetchrow(q, mxid, alias)
+ return AuthInfo(server=row["gitlab_server"], api_token=row["api_token"])
+
+ async def add_alias(self, mxid: UserID, url: str, alias: str) -> None:
+ q = "INSERT INTO alias (user_id, gitlab_server, alias) VALUES ($1, $2, $3)"
+ await self.db.execute(q, mxid, url, alias)
+
+ async def rm_alias(self, mxid: UserID, alias: str) -> None:
+ q = "DELETE FROM alias WHERE user_id = $1 AND alias = $2"
+ await self.db.execute(q, mxid, alias)
+
+ async def has_alias(self, user_id: UserID, alias: str) -> bool:
+ q = (
+ "SELECT user_id, gitlab_server, alias FROM alias "
+ "WHERE user_id = $1 AND alias = $2"
+ )
+ rows = await self.db.fetch(q, user_id, alias)
+ return len(rows) > 0
+
+ async def get_aliases(self, user_id: UserID) -> List[AliasInfo]:
+ q = (
+ "SELECT user_id, gitlab_server, alias FROM alias "
+ "WHERE user_id = $1"
+ )
+ rows = await self.db.fetch(q, user_id)
+ return [AliasInfo(row["gitlab_server"], row["alias"]) for row in rows]
+
+ async def get_aliases_per_server(self, user_id: UserID, url: str) -> List[AliasInfo]:
+ q = (
+ "SELECT user_id, gitlab_server, alias FROM alias "
+ "WHERE user_id = $1 AND gitlab_server = $2"
+ )
+ rows = await self.db.fetch(q, user_id, url)
+ return [AliasInfo(row["gitlab_server"], row["alias"]) for row in rows]
+
+ async def change_default(self, mxid: UserID, url: str) -> None:
+ q = (
+ "SELECT user_id, gitlab_server FROM 'default' "
+ "WHERE user_id = $1"
+ )
+ default = await self.db.fetchrow(q, mxid)
+ if default:
+ q = (
+ "UPDATE 'default' SET gitlab_server = $2 "
+ "WHERE user_id = $1"
+ )
+ await self.db.execute(q, mxid, url)
+
+ async def get_webhook_room(self, secret: str) -> Optional[RoomID]:
+ q = "SELECT room_id, secret FROM webhook_token WHERE secret = $1"
+ webhook_token = await self.db.fetchrow(q, secret)
+ return webhook_token["room_id"] if webhook_token else None
+
+ async def add_webhook_room(self, secret: str, room_id: RoomID) -> None:
+ q = "INSERT INTO webhook_token (room_id, secret) VALUES ($1, $2)"
+ await self.db.execute(q, room_id, secret)
diff --git a/gitlab_matrix/migrations.py b/gitlab_matrix/migrations.py
new file mode 100644
index 0000000..a1895d0
--- /dev/null
+++ b/gitlab_matrix/migrations.py
@@ -0,0 +1,61 @@
+from mautrix.util.async_db import Connection, UpgradeTable
+
+upgrade_table = UpgradeTable()
+
+
+@upgrade_table.register(description="Initial revision")
+async def upgrade_v1(conn: Connection) -> None:
+ await conn.execute(
+ f"""CREATE TABLE IF NOT EXISTS token (
+ user_id VARCHAR(255),
+ gitlab_server TEXT,
+ api_token TEXT NOT NULL,
+
+ PRIMARY KEY (user_id, gitlab_server)
+ )"""
+ )
+ await conn.execute(
+ f"""CREATE TABLE IF NOT EXISTS alias (
+ user_id VARCHAR(255),
+ gitlab_server TEXT,
+ alias TEXT,
+
+ PRIMARY KEY (user_id, gitlab_server, alias),
+ FOREIGN KEY (user_id, gitlab_server) REFERENCES token (user_id, gitlab_server) ON DELETE CASCADE
+ )"""
+ )
+ await conn.execute(
+ """CREATE TABLE IF NOT EXISTS "default" (
+ user_id VARCHAR(255),
+ gitlab_server TEXT NOT NULL,
+
+ PRIMARY KEY (user_id),
+ FOREIGN KEY (user_id, gitlab_server) REFERENCES token (user_id, gitlab_server) ON DELETE CASCADE
+ )"""
+ ) # add gitlab in primary key ?
+ await conn.execute(
+ """CREATE TABLE IF NOT EXISTS default_repo (
+ room_id VARCHAR(255),
+ server VARCHAR(255) NOT NULL,
+ repo VARCHAR(255) NOT NULL,
+
+ PRIMARY KEY (room_id)
+ )"""
+ )
+ await conn.execute(
+ """CREATE TABLE IF NOT EXISTS matrix_message (
+ message_id VARCHAR(255),
+ room_id VARCHAR(255),
+ event_id VARCHAR(255) NOT NULL,
+
+ PRIMARY KEY (message_id, room_id)
+ )"""
+ )
+ await conn.execute(
+ """CREATE TABLE IF NOT EXISTS webhook_token (
+ room_id TEXT NOT NULL,
+ secret TEXT,
+
+ PRIMARY KEY (secret)
+ )"""
+ )
diff --git a/gitlab_matrix/util/decorators.py b/gitlab_matrix/util/decorators.py
index 7a10c99..5b0c964 100644
--- a/gitlab_matrix/util/decorators.py
+++ b/gitlab_matrix/util/decorators.py
@@ -37,10 +37,11 @@ async def wrapper(self: 'Command', evt: MessageEvent, login: AuthInfo, **kwargs
try:
repo: Any = kwargs["repo"]
if isinstance(repo, DefaultRepoInfo):
- if repo.server not in self.bot.db.get_servers(evt.sender):
+ servers = await self.bot.db.get_servers(evt.sender)
+ if repo.server not in servers:
await evt.reply(f"You're not logged into {repo.server}")
return
- login = self.bot.db.get_login(evt.sender, url_alias=repo.server)
+ login = await self.bot.db.get_login(evt.sender, url_alias=repo.server)
kwargs["repo"] = repo.repo
except KeyError:
pass
diff --git a/gitlab_matrix/webhook.py b/gitlab_matrix/webhook.py
index aa69a32..bd98985 100644
--- a/gitlab_matrix/webhook.py
+++ b/gitlab_matrix/webhook.py
@@ -77,7 +77,7 @@ async def post_handler(self, request: Request) -> Response:
"Did you forget the ?room query parameter?\n",
status=400)
else:
- room_id = self.bot.db.get_webhook_room(token)
+ room_id = await self.bot.db.get_webhook_room(token)
if not room_id:
return Response(text="401: Unauthorized\n", status=401)
@@ -177,15 +177,16 @@ def abort() -> None:
}
content["com.beeper.linkpreviews"] = []
- edit_evt = self.bot.db.get_event(subevt.message_id, room_id)
+ edit_evt = await self.bot.db.get_event(subevt.message_id, room_id)
if edit_evt:
content.set_edit(edit_evt)
- event_id = await self.bot.client.send_message(room_id, content)
- if not edit_evt and subevt.message_id:
- self.bot.db.put_event(subevt.message_id, room_id, event_id)
+ event_id = await self.bot.client.send_message(room_id, content, retry_count=5)
+
+ if not edit_evt and subevt.message_id and event_id:
+ await self.bot.db.put_event(subevt.message_id, room_id, event_id)
async def handle_job_event(self, evt: GitlabJobEvent, evt_type: str, room_id: RoomID) -> None:
- push_evt = self.bot.db.get_event(evt.push_id, room_id)
+ push_evt = await self.bot.db.get_event(evt.push_id, room_id)
if not push_evt:
self.bot.log.debug(f"No message found to react to push {evt.push_id}")
return
@@ -201,11 +202,13 @@ async def handle_job_event(self, evt: GitlabJobEvent, evt_type: str, room_id: Ro
**evt.meta,
}
- prev_reaction = self.bot.db.get_event(evt.reaction_id, room_id)
+ prev_reaction = await self.bot.db.get_event(evt.reaction_id, room_id)
if prev_reaction:
- await self.bot.client.redact(room_id, prev_reaction)
- event_id = await self.bot.client.send_message_event(room_id, EventType.REACTION, reaction)
- self.bot.db.put_event(evt.reaction_id, room_id, event_id, merge=prev_reaction is not None)
+ await self.bot.client.redact(room_id, prev_reaction, retry_count=5)
+
+ event_id = await self.bot.client.send_message_event(room_id, EventType.REACTION, reaction, retry_count=5)
+ if event_id:
+ await self.bot.db.put_event(evt.reaction_id, room_id, event_id)
@event.on(EventType.ROOM_MEMBER)
async def member_handler(self, evt: StateEvent) -> None:
diff --git a/maubot.yaml b/maubot.yaml
index a871e45..9d888e4 100644
--- a/maubot.yaml
+++ b/maubot.yaml
@@ -18,4 +18,5 @@ soft_dependencies:
webapp: true
database: true
+database_type: asyncpg
config: true
From 3b5a05ab411bf1316cd9825a7d3c32891a484b4f Mon Sep 17 00:00:00 2001
From: Thomas Ieong
Date: Wed, 13 Dec 2023 16:13:09 +0100
Subject: [PATCH 2/2] Add a feature to notify only when a job fails
---
gitlab_matrix/util/config.py | 9 +++++++++
gitlab_matrix/webhook.py | 5 ++++-
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/gitlab_matrix/util/config.py b/gitlab_matrix/util/config.py
index 72c2ca2..a9f4c42 100644
--- a/gitlab_matrix/util/config.py
+++ b/gitlab_matrix/util/config.py
@@ -15,11 +15,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
import secrets
+import os
+from typing import Any
from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper
class Config(BaseProxyConfig):
+ def __getitem__(self, key: str) -> Any:
+ try:
+ return os.environ[f"MAUBOT_GITLAB_{key.replace('.', '_').upper()}"]
+ except KeyError:
+ return super().__getitem__(key)
+
def do_update(self, helper: ConfigUpdateHelper) -> None:
if not self["secret"] or self["secret"] == "put a random password here":
helper.base["secret"] = secrets.token_urlsafe(32)
@@ -28,3 +36,4 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
helper.copy("base_command")
helper.copy("send_as_notice")
helper.copy("time_format")
+ helper.copy("notify_only_on_failure")
diff --git a/gitlab_matrix/webhook.py b/gitlab_matrix/webhook.py
index bd98985..c908af4 100644
--- a/gitlab_matrix/webhook.py
+++ b/gitlab_matrix/webhook.py
@@ -29,7 +29,7 @@
from mautrix.util.formatter import parse_html
from maubot.handlers import web, event
-from .types import GitlabJobEvent, EventParse, Action, OTHER_ENUMS
+from .types import GitlabJobEvent, EventParse, Action, BuildStatus, OTHER_ENUMS
from .util import TemplateManager, TemplateUtil
if TYPE_CHECKING:
@@ -131,6 +131,9 @@ async def process_hook(self, body: JSON, evt_type: str, room_id: RoomID) -> None
was_manually_handled = True
if isinstance(evt, GitlabJobEvent):
+ if self.bot.config["notify_only_on_failure"] and evt.build_status != BuildStatus.FAILED:
+ return
+
await self.handle_job_event(evt, evt_type, room_id)
else:
was_manually_handled = False