Skip to content

Commit

Permalink
Add support for adding proof fields via !addproof command
Browse files Browse the repository at this point in the history
  • Loading branch information
leighmacdonald committed Mar 24, 2024
1 parent 24c67ca commit 0614890
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 24 deletions.
63 changes: 49 additions & 14 deletions tf2bdd/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,17 @@ func addEntry(ctx context.Context, database *sql.DB, sid steamid.SteamID, msg []
Time: time.Now().Unix(),
},
SteamID: sid,
Proof: []string{},
}

if err := AddPlayer(ctx, database, player, author); err != nil {
if err.Error() == "UNIQUE constraint failed: player.steamid" {
if errors.Is(err, ErrDuplicate) {
return "", fmt.Errorf("duplicate steam id: %s", sid.String())
}

slog.Error("Failed to add player", slog.String("error", err.Error()))

return "", fmt.Errorf("oops")
return "", err
}

return fmt.Sprintf("Added new entry successfully: %s", sid.String()), nil
Expand All @@ -152,12 +153,30 @@ func addEntry(ctx context.Context, database *sql.DB, sid steamid.SteamID, msg []
func checkEntry(ctx context.Context, database *sql.DB, sid steamid.SteamID) (string, error) {
player, errPlayer := getPlayer(ctx, database, sid)
if errPlayer != nil {
return "", fmt.Errorf("steam id does not exist in database: %d", sid.Int64())
if errors.Is(errPlayer, ErrNotFound) {
return "", fmt.Errorf("steam id does not exist in database: %d", sid.Int64())
}

return "", errPlayer
}

var builder strings.Builder
builder.WriteString(fmt.Sprintf("\n:skull_crossbones: **%s is a confirmed baddie** :skull_crossbones:\n", player.LastSeen.PlayerName))
builder.WriteString(fmt.Sprintf("**Attributes:** %s\n", strings.Join(player.Attributes, ", ")))
for idx, proof := range player.Proof {
if strings.HasPrefix(proof, "http") {
builder.WriteString(fmt.Sprintf("**Proof #%d:** <%s>\n", idx, proof))
} else {
builder.WriteString(fmt.Sprintf("**Proof #%d:** %s\n", idx, proof))
}
}
builder.WriteString(fmt.Sprintf("**Added on:** %s\n", player.CreatedOn.String()))
if player.Author > 0 {
builder.WriteString(fmt.Sprintf("**Author:** <@%d>\n", player.Author))
}
builder.WriteString(fmt.Sprintf("**Profile:** <https://steamcommunity.com/profiles/%s>", sid.String()))

return fmt.Sprintf(":skull_crossbones: %s is a confirmed baddie :skull_crossbones: "+
"https://steamcommunity.com/profiles/%d \nAttributes: %s\nAuthor: <@%d>\nCreated: %s",
player.LastSeen.PlayerName, sid.Int64(), strings.Join(player.Attributes, ","), player.Author, player.CreatedOn.String()), nil
return builder.String(), nil
}

func getSteamid(sid steamid.SteamID) string {
Expand Down Expand Up @@ -279,16 +298,19 @@ func messageCreate(ctx context.Context, database *sql.DB, config Config) func(*d
}
msg := strings.Split(strings.ToLower(message.Content), " ")
minArgs := map[string]int{
"!del": 2,
"!check": 2,
"!add": 2,
"!steamid": 2,
"!import": 1,
"!count": 1,
"!del": 2,
"!check": 2,
"!add": 2,
"!steamid": 2,
"!import": 1,
"!count": 1,
"!addproof": 3,
}

argCount, found := minArgs[msg[0]]
if !found {
sendMsg(session, message, fmt.Sprintf("unknown command: %s", msg[0]))

return
}

Expand Down Expand Up @@ -342,7 +364,7 @@ func messageCreate(ctx context.Context, database *sql.DB, config Config) func(*d
case "!check":
response, cmdErr = checkEntry(ctx, database, sid)
case "!addproof":
response, cmdErr = addProof(ctx, database, sid, msg[1:])
response, cmdErr = addProof(ctx, database, sid, msg[2:])
case "!add":
author, errAuthor := strconv.ParseInt(message.Author.ID, 10, 64)
if errAuthor != nil {
Expand All @@ -369,11 +391,24 @@ func messageCreate(ctx context.Context, database *sql.DB, config Config) func(*d
}
}

func addProof(ctx context.Context, database *sql.DB, sid steamid.SteamID, msg []string) (string, error) {
func addProof(ctx context.Context, database *sql.DB, sid steamid.SteamID, proofs []string) (string, error) {
player, errPlayer := getPlayer(ctx, database, sid)
if errPlayer != nil {
return "", errPlayer
}

for _, proof := range proofs {
if slices.Contains(player.Proof, proof) {
return "", errors.New("duplicate proof provided")
}
player.Proof = append(player.Proof, proof)
}

if errUpdate := updatePlayer(ctx, database, player); errUpdate != nil {
return "", errors.Join(errUpdate, errors.New("could not update player entry"))
}

return fmt.Sprintf("Added %d proof entries", len(proofs)), nil
}

func sendMsg(s *discordgo.Session, m *discordgo.MessageCreate, msg string) {
Expand Down
90 changes: 81 additions & 9 deletions tf2bdd/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package tf2bdd
import (
"context"
"database/sql"
"database/sql/driver"
"embed"
"errors"
"fmt"
"log/slog"
"strings"
"time"
Expand All @@ -13,6 +15,7 @@ import (
"github.com/golang-migrate/migrate/v4/database/sqlite"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/leighmacdonald/steamid/v4/steamid"
"github.com/ncruces/go-sqlite3"
)

//go:embed migrations/*.sql
Expand All @@ -25,8 +28,30 @@ var (
ErrStoreDriver = errors.New("failed to create db driver")
ErrCreateMigration = errors.New("failed to create migrator")
ErrPerformMigration = errors.New("failed to migrate database")
ErrDuplicate = errors.New("duplicate entry")
ErrNotFound = errors.New("entry not found")
)

func dbErr(err error) error {
if err == nil {
return nil
}
if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound
}
var sqliteErr *sqlite3.Error
if errors.As(err, &sqliteErr) {
switch {
case errors.Is(sqliteErr.Code(), sqlite3.CONSTRAINT):
return ErrDuplicate
default:
return fmt.Errorf("unhandled sqlite error: %w", err)
}
}

return err
}

func OpenDB(dbPath string) (*sql.DB, error) {
database, errOpen := sql.Open("sqlite3", dbPath)
if errOpen != nil {
Expand Down Expand Up @@ -75,8 +100,50 @@ func migrateDB(database *sql.DB) error {
return nil
}

const proofSep = "^^"

type Proof []string

func (p *Proof) Scan(value interface{}) error {
strVal, ok := value.(string)
if !ok {
return errors.New("invalid type")
}
if strVal == "" {
*p = []string{}

return nil
}

*p = strings.Split(strVal, proofSep)

return nil
}

func (p Proof) Value() (driver.Value, error) {
return strings.Join(p, proofSep), nil
}

func updatePlayer(ctx context.Context, database *sql.DB, player Player) error {
const query = `
UPDATE player
SET attributes = ?,
last_seen = ?,
last_name = ?,
author = ?,
proof = ?
WHERE steamid = ?`

if _, errExec := database.ExecContext(ctx, query, strings.Join(player.Attributes, ","), player.LastSeen.Time, player.LastSeen.PlayerName,
player.Author, player.Proof, player.SteamID.Int64()); errExec != nil {
return errExec
}

return nil
}

func getPlayer(ctx context.Context, database *sql.DB, steamID steamid.SteamID) (Player, error) {
const query = `SELECT steamid, attributes, last_seen, last_name, author, created_on FROM player WHERE steamid = ?`
const query = `SELECT steamid, attributes, last_seen, last_name, author, created_on, proof FROM player WHERE steamid = ?`

var (
player Player
Expand All @@ -85,12 +152,13 @@ func getPlayer(ctx context.Context, database *sql.DB, steamID steamid.SteamID) (
lastSeen int64
lastName string
createdOn int64
proof Proof
)

if errScan := database.
QueryRowContext(ctx, query, steamID.Int64()).
Scan(&sid, &attrs, &lastSeen, &lastName, &player.Author, &createdOn); errScan != nil {
return Player{}, errScan
Scan(&sid, &attrs, &lastSeen, &lastName, &player.Author, &createdOn, &proof); errScan != nil {
return Player{}, dbErr(errScan)
}

player.CreatedOn = time.Unix(createdOn, 0)
Expand All @@ -100,12 +168,13 @@ func getPlayer(ctx context.Context, database *sql.DB, steamID steamid.SteamID) (
PlayerName: lastName,
Time: lastSeen,
}
player.Proof = proof

return player, nil
}

func getPlayers(ctx context.Context, db *sql.DB) ([]Player, error) {
const query = `SELECT steamid, attributes, last_seen, last_name, author, created_on FROM player`
const query = `SELECT steamid, attributes, last_seen, last_name, author, created_on, proof FROM player`

rows, err := db.QueryContext(ctx, query)
if err != nil {
Expand All @@ -128,9 +197,10 @@ func getPlayers(ctx context.Context, db *sql.DB) ([]Player, error) {
lastSeen int64
lastName string
createdOn int64
proof Proof
)

if errScan := rows.Scan(&sid, &attrs, &lastSeen, &lastName, &player.Author, &createdOn); errScan != nil {
if errScan := rows.Scan(&sid, &attrs, &lastSeen, &lastName, &player.Author, &createdOn, &proof); errScan != nil {
return nil, errors.Join(errScan, errors.New("error scanning player row"))
}

Expand All @@ -141,6 +211,7 @@ func getPlayers(ctx context.Context, db *sql.DB) ([]Player, error) {
PlayerName: lastName,
Time: lastSeen,
}
player.Proof = proof

players = append(players, player)
}
Expand All @@ -154,17 +225,18 @@ func getPlayers(ctx context.Context, db *sql.DB) ([]Player, error) {

func AddPlayer(ctx context.Context, db *sql.DB, player Player, author int64) error {
const query = `
INSERT INTO player (steamid, attributes, last_seen, last_name, author, created_on)
VALUES(?, ?, ?, ?, ?, ?)`
INSERT INTO player (steamid, attributes, last_seen, last_name, author, created_on, proof)
VALUES(?, ?, ?, ?, ?, ?, ?)`

if _, err := db.ExecContext(ctx, query,
player.SteamID.Int64(),
strings.ToLower(strings.Join(player.Attributes, ",")),
player.LastSeen.Time,
player.LastSeen.PlayerName,
author,
int(time.Now().Unix())); err != nil {
return err
int(time.Now().Unix()),
player.Proof); err != nil {
return dbErr(err)
}

return nil
Expand Down
2 changes: 1 addition & 1 deletion tf2bdd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Player struct {
LastSeen LastSeen `json:"last_seen,omitempty"`
Author int64 `json:"-"`
CreatedOn time.Time `json:"-"`
Proof []string `json:"proof"`
Proof Proof `json:"proof"`
}

func handleGetSteamIDs(database *sql.DB, config Config) http.HandlerFunc {
Expand Down

0 comments on commit 0614890

Please sign in to comment.