Skip to content

Commit

Permalink
fix: support pg v10
Browse files Browse the repository at this point in the history
Signed-off-by: closetool <4closetool3@gmail.com>
  • Loading branch information
kilosonc committed May 6, 2021
1 parent 4c48282 commit 7b7b71d
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 90 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
name: Go

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
- push
- pull_request

jobs:

Expand Down
76 changes: 38 additions & 38 deletions adapter.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
package pgadapter

import (
"context"
"fmt"
"strings"

"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/go-pg/pg/v9"
"github.com/go-pg/pg/v9/orm"
"github.com/go-pg/pg/v9/types"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/mmcloughlin/meow"
)

const DefaultTableName = "casbin_rule"

// CasbinRule represents a rule in Casbin.
type CasbinRule struct {
ID string
PType string
V0 string
V1 string
V2 string
V3 string
V4 string
V5 string
tableName struct{} `pg:"_"`
ID string
PType string
V0 string
V1 string
V2 string
V3 string
V4 string
V5 string
}

type Filter struct {
Expand All @@ -48,7 +51,7 @@ func NewAdapter(arg interface{}) (*Adapter, error) {
return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
}

a := &Adapter{db: db}
a := &Adapter{db: db, tableName: DefaultTableName}

if err := a.createTableifNotExists(); err != nil {
return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
Expand All @@ -60,17 +63,11 @@ func NewAdapter(arg interface{}) (*Adapter, error) {
// NewAdapterByDB creates new Adapter by using existing DB connection
// creates table from CasbinRule struct if it doesn't exist
func NewAdapterByDB(db *pg.DB, opts ...Option) (*Adapter, error) {
a := &Adapter{db: db}
a := &Adapter{db: db, tableName: DefaultTableName}
for _, opt := range opts {
opt(a)
}

if len(a.tableName) > 0 {
a.db.Model((*CasbinRule)(nil)).TableModel().Table().Name = a.tableName
a.db.Model((*CasbinRule)(nil)).TableModel().Table().FullName = (types.Safe)(a.tableName)
a.db.Model((*CasbinRule)(nil)).TableModel().Table().FullNameForSelects = (types.Safe)(a.tableName)
}

if !a.skipTableCreate {
if err := a.createTableifNotExists(); err != nil {
return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
Expand Down Expand Up @@ -130,7 +127,7 @@ func (a *Adapter) Close() error {
}

func (a *Adapter) createTableifNotExists() error {
err := a.db.CreateTable(&CasbinRule{}, &orm.CreateTableOptions{
err := a.db.Model((*CasbinRule)(nil)).Table(a.tableName).CreateTable(&orm.CreateTableOptions{
Temp: false,
IfNotExists: true,
})
Expand Down Expand Up @@ -183,7 +180,7 @@ func (r *CasbinRule) String() string {
func (a *Adapter) LoadPolicy(model model.Model) error {
var lines []*CasbinRule

if err := a.db.Model(&lines).Select(); err != nil {
if err := a.db.Model(&lines).Table(a.tableName).Select(); err != nil {
return err
}

Expand Down Expand Up @@ -238,7 +235,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}
defer tx.Close()

_, err = tx.Model((*CasbinRule)(nil)).Where("id IS NOT NULL").Delete()
_, err = tx.Model((*CasbinRule)(nil)).Table(a.tableName).Where("id IS NOT NULL").Delete()
if err != nil {
return err
}
Expand All @@ -259,8 +256,8 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}
}

if len(lines) > 0 {
_, err = tx.Model(&lines).
for _, line := range lines {
_, err = tx.Model(line).Table(a.tableName).
OnConflict("DO NOTHING").
Insert()
if err != nil {
Expand All @@ -279,8 +276,9 @@ func (a *Adapter) SavePolicy(model model.Model) error {
// AddPolicy adds a policy rule to the storage.
func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
line := savePolicyLine(ptype, rule)
err := a.db.RunInTransaction(func(tx *pg.Tx) error {
err := a.db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
_, err := a.db.Model(line).
Table(a.tableName).
OnConflict("DO NOTHING").
Insert()

Expand All @@ -298,8 +296,9 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error
lines = append(lines, line)
}

err := a.db.RunInTransaction(func(tx *pg.Tx) error {
err := a.db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
_, err := tx.Model(&lines).
Table(a.tableName).
OnConflict("DO NOTHING").
Insert()
return err
Expand All @@ -311,8 +310,9 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error
// RemovePolicy removes a policy rule from the storage.
func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
line := savePolicyLine(ptype, rule)
err := a.db.RunInTransaction(func(tx *pg.Tx) error {
return a.db.Delete(line)
err := a.db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
_, err := a.db.Model(line).Table(a.tableName).WherePK().Delete()
return err
})

return err
Expand All @@ -326,8 +326,8 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err
lines = append(lines, line)
}

err := a.db.RunInTransaction(func(tx *pg.Tx) error {
_, err := tx.Model(&lines).
err := a.db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
_, err := tx.Model(&lines).Table(a.tableName).
Delete()
return err
})
Expand All @@ -337,7 +337,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err

// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
query := a.db.Model((*CasbinRule)(nil)).Where("p_type = ?", ptype)
query := a.db.Model((*CasbinRule)(nil)).Table(a.tableName).Where("p_type = ?", ptype)

idx := fieldIndex + len(fieldValues)
if fieldIndex <= 0 && idx > 0 && fieldValues[0-fieldIndex] != "" {
Expand All @@ -359,7 +359,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
query = query.Where("v5 = ?", fieldValues[5-fieldIndex])
}

err := a.db.RunInTransaction(func(tx *pg.Tx) error {
err := a.db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
_, err := query.Delete()
return err
})
Expand Down Expand Up @@ -413,7 +413,7 @@ func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter, handler
if filter.P != nil {
lines := []*CasbinRule{}

query := a.db.Model(&lines).Where("p_type = 'p'")
query := a.db.Model(&lines).Table(a.tableName).Where("p_type = 'p'")
query, err := buildQuery(query, filter.P)
if err != nil {
return err
Expand All @@ -430,7 +430,7 @@ func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter, handler
if filter.G != nil {
lines := []*CasbinRule{}

query := a.db.Model(&lines).Where("p_type = 'g'")
query := a.db.Model(&lines).Table(a.tableName).Where("p_type = 'g'")
query, err := buildQuery(query, filter.G)
if err != nil {
return err
Expand Down Expand Up @@ -484,12 +484,12 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, oldRules, new
newLines = append(newLines, line)
}

err := a.db.RunInTransaction(func(tx *pg.Tx) error {
_, err := tx.Model(&oldLines).Delete()
err := a.db.RunInTransaction(context.Background(), func(tx *pg.Tx) error {
_, err := tx.Model(&oldLines).Table(a.tableName).Delete()
if err != nil {
return err
}
_, err = tx.Model(&newLines).
_, err = tx.Model(&newLines).Table(a.tableName).
OnConflict("DO NOTHING").
Insert()
return err
Expand All @@ -506,7 +506,7 @@ func (a *Adapter) updatePolicies(oldLines, newLines []*CasbinRule) error {

for i, line := range oldLines {
newLines[i].ID = line.ID
_, err = tx.Model(newLines[i]).WherePK().Update()
_, err = tx.Model(newLines[i]).Table(a.tableName).WherePK().Update()
if err != nil {
tx.Rollback()
return err
Expand Down
2 changes: 1 addition & 1 deletion adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
"github.com/go-pg/pg/v9"
"github.com/go-pg/pg/v10"
"github.com/stretchr/testify/suite"
)

Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module github.com/casbin/casbin-pg-adapter

go 1.13
go 1.14

require (
github.com/casbin/casbin/v2 v2.23.4
github.com/go-pg/pg/v9 v9.1.0
github.com/go-pg/pg/v10 v10.9.1
github.com/mmcloughlin/meow v0.0.0-20181112033425-871e50784daf
github.com/stretchr/testify v1.6.1
github.com/stretchr/testify v1.7.0
)
Loading

0 comments on commit 7b7b71d

Please sign in to comment.