Skip to content

Commit

Permalink
config
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed Dec 8, 2024
1 parent e9b847b commit 264b70e
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions pgxtester.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main

import (
"context"
"os"
"sync"
"testing"
"time"
Expand All @@ -11,31 +12,51 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)

const timeout = 2 * time.Second

type DBTX interface {
Exec(context.Context, string, ...any) (pgconn.CommandTag, error)
Query(context.Context, string, ...any) (pgx.Rows, error)
QueryRow(context.Context, string, ...any) pgx.Row
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
}

func Connect(t *testing.T, url string) DBTX {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := pgx.Connect(ctx, url)
if err != nil {
t.Fatalf("failed to connect to PostgreSQL: %v", err)
type Config struct {
// Connection URL to be used if Conn is not provided.
URL string
// The connection to wrap around. If nil, a new connection will be established using URL.
Conn *pgx.Conn
// The timeout for establishing connection, creating trnsaction, and rolling it back.
//
// Default is 2 seconds.
Timeout time.Duration
// The options for creating a transaction.
TxOptions pgx.TxOptions
}

func Connect(t *testing.T, c Config) DBTX {
if c.Timeout == 0 {
c.Timeout = 2 * time.Second
}
if c.Conn == nil {
if c.URL == "" {
c.URL = os.Getenv("POSTGRES_URL")
}
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
conn, err := pgx.Connect(ctx, c.URL)
if err != nil {
t.Fatalf("failed to connect to PostgreSQL: %v", err)
}
c.Conn = conn
}

ctx, cancel = context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
tx, err := conn.BeginTx(ctx, pgx.TxOptions{})
tx, err := c.Conn.BeginTx(ctx, c.TxOptions)
if err != nil {
t.Fatalf("failed to begin tansaction: %v", err)
}
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout)
defer cancel()
err := tx.Rollback(ctx)
if err != nil {
Expand Down

0 comments on commit 264b70e

Please sign in to comment.