Skip to content

Commit

Permalink
implement FilteredAdapter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
pckhoi committed Dec 27, 2019
1 parent d413a3f commit e91692c
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 32 deletions.
131 changes: 111 additions & 20 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/casbin/casbin/v2/persist"
"github.com/go-pg/pg/v9"
"github.com/go-pg/pg/v9/orm"
"golang.org/x/crypto/sha3"
"github.com/mmcloughlin/meow"
)

const (
Expand All @@ -27,9 +27,15 @@ type CasbinRule struct {
V5 string
}

type Filter struct {
P []string
G []string
}

// Adapter represents the github.com/go-pg/pg adapter for policy storage.
type Adapter struct {
db *pg.DB
db *pg.DB
filtered bool
}

// NewAdapter is the constructor for Adapter.
Expand Down Expand Up @@ -98,37 +104,37 @@ func (a *Adapter) createTable() error {
return nil
}

func loadPolicyLine(line *CasbinRule, model model.Model) {
func (r *CasbinRule) String() string {
const prefixLine = ", "
var sb strings.Builder

sb.WriteString(line.PType)
if len(line.V0) > 0 {
sb.WriteString(r.PType)
if len(r.V0) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V0)
sb.WriteString(r.V0)
}
if len(line.V1) > 0 {
if len(r.V1) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V1)
sb.WriteString(r.V1)
}
if len(line.V2) > 0 {
if len(r.V2) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V2)
sb.WriteString(r.V2)
}
if len(line.V3) > 0 {
if len(r.V3) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V3)
sb.WriteString(r.V3)
}
if len(line.V4) > 0 {
if len(r.V4) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V4)
sb.WriteString(r.V4)
}
if len(line.V5) > 0 {
if len(r.V5) > 0 {
sb.WriteString(prefixLine)
sb.WriteString(line.V5)
sb.WriteString(r.V5)
}

persist.LoadPolicyLine(sb.String(), model)
return sb.String()
}

// LoadPolicy loads policy from database.
Expand All @@ -140,16 +146,17 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
}

for _, line := range lines {
loadPolicyLine(line, model)
persist.LoadPolicyLine(line.String(), model)
}

a.filtered = false

return nil
}

func policyID(ptype string, rule []string) string {
data := strings.Join(append([]string{ptype}, rule...), ",")
sum := make([]byte, 64)
sha3.ShakeSum128(sum, []byte(data))
sum := meow.Checksum(0, []byte(data))
return fmt.Sprintf("%x", sum)
}

Expand Down Expand Up @@ -248,3 +255,87 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
_, err := query.Delete()
return err
}

func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
if filter == nil {
return a.LoadPolicy(model)
}

filterValue, ok := filter.(*Filter)
if !ok {
return fmt.Errorf("invalid filter type")
}
err := a.loadFilteredPolicy(model, filterValue, persist.LoadPolicyLine)
if err != nil {
return err
}
a.filtered = true
return nil
}

func buildQuery(query *orm.Query, values []string) (*orm.Query, error) {
for ind, v := range values {
if v == "" {
continue
}
switch ind {
case 0:
query = query.Where("v0 = ?", v)
case 1:
query = query.Where("v1 = ?", v)
case 2:
query = query.Where("v2 = ?", v)
case 3:
query = query.Where("v3 = ?", v)
case 4:
query = query.Where("v4 = ?", v)
case 5:
query = query.Where("v5 = ?", v)
default:
return nil, fmt.Errorf("filter has more values than expected, should not exceed 6 values")
}
}
return query, nil
}

func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter, handler func(string, model.Model)) error {
if filter.P != nil {
lines := []*CasbinRule{}

query := a.db.Model(&lines).Where("p_type = 'p'")
query, err := buildQuery(query, filter.P)
if err != nil {
return err
}
err = query.Select()
if err != nil {
return err
}

for _, line := range lines {
handler(line.String(), model)
}
}
if filter.G != nil {
lines := []*CasbinRule{}

query := a.db.Model(&lines).Where("p_type = 'g'")
query, err := buildQuery(query, filter.G)
if err != nil {
return err
}
err = query.Select()
if err != nil {
return err
}

for _, line := range lines {
handler(line.String(), model)
}
}
return nil
}

func (a *Adapter) IsFiltered() bool {
return a.filtered
}
69 changes: 57 additions & 12 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ type AdapterTestSuite struct {
a *Adapter
}

func (s *AdapterTestSuite) testGetPolicy(res [][]string) {
func (s *AdapterTestSuite) assertPolicy(expected, res [][]string) {
s.T().Helper()
myRes := s.e.GetPolicy()
s.Assert().True(util.Array2DEquals(res, myRes), "Policy Got: %v, supposed to be %v", myRes, res)
s.Assert().True(util.Array2DEquals(expected, res), "Policy Got: %v, supposed to be %v", res, expected)
}

func (s *AdapterTestSuite) dropCasbinDB() {
Expand Down Expand Up @@ -56,7 +55,8 @@ func (s *AdapterTestSuite) TearDownTest() {
}

func (s *AdapterTestSuite) TestSaveLoad() {
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.Assert().False(s.e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}

func (s *AdapterTestSuite) TestAutoSave() {
Expand All @@ -72,7 +72,7 @@ func (s *AdapterTestSuite) TestAutoSave() {
err = s.e.LoadPolicy()
s.Require().NoError(err)
// This is still the original policy.
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())

// Now we enable the AutoSave.
s.e.EnableAutoSave(true)
Expand All @@ -85,14 +85,14 @@ func (s *AdapterTestSuite) TestAutoSave() {
err = s.e.LoadPolicy()
s.Require().NoError(err)
// The policy has a new rule: {"alice", "data1", "write"}.
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}, s.e.GetPolicy())

// Aditional AddPolicy have no effect
_, err = s.e.AddPolicy("alice", "data1", "write")
s.Require().NoError(err)
err = s.e.LoadPolicy()
s.Require().NoError(err)
s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}, s.e.GetPolicy())
s.Require().NoError(err)
}

Expand All @@ -107,31 +107,76 @@ func (s *AdapterTestSuite) TestConstructorOptions() {
s.e, err = casbin.NewEnforcer("examples/rbac_model.conf", a)
s.Require().NoError(err)

s.testGetPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}

func (s *AdapterTestSuite) TestRemovePolicy() {
_, err := s.e.RemovePolicy("alice", "data1", "read")
s.Require().NoError(err)

s.testGetPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())

err = s.e.LoadPolicy()
s.Require().NoError(err)

s.testGetPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
s.assertPolicy([][]string{{"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}

func (s *AdapterTestSuite) TestRemoveFilteredPolicy() {
_, err := s.e.RemoveFilteredPolicy(0, "", "data2")
s.Require().NoError(err)

s.testGetPolicy([][]string{{"alice", "data1", "read"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}}, s.e.GetPolicy())

err = s.e.LoadPolicy()
s.Require().NoError(err)

s.testGetPolicy([][]string{{"alice", "data1", "read"}})
s.assertPolicy([][]string{{"alice", "data1", "read"}}, s.e.GetPolicy())
}

func (s *AdapterTestSuite) TestLoadFilteredPolicy() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)

err = e.LoadFilteredPolicy(&Filter{
P: []string{"", "", "read"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"data2_admin", "data2", "read"}}, e.GetPolicy())
}

func (s *AdapterTestSuite) TestLoadFilteredGroupingPolicy() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)

err = e.LoadFilteredPolicy(&Filter{
G: []string{"bob"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{}, e.GetGroupingPolicy())

e, err = casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)

err = e.LoadFilteredPolicy(&Filter{
G: []string{"alice"},
})
s.Require().NoError(err)
s.Assert().True(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data2_admin"}}, e.GetGroupingPolicy())
}

func (s *AdapterTestSuite) TestLoadFilteredPolicyNilFilter() {
e, err := casbin.NewEnforcer("examples/rbac_model.conf", s.a)
s.Require().NoError(err)

err = e.LoadFilteredPolicy(nil)
s.Require().NoError(err)

s.Assert().False(e.IsFiltered())
s.assertPolicy([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}, s.e.GetPolicy())
}

func TestAdapterTestSuite(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/casbin/casbin v1.9.1
github.com/casbin/casbin/v2 v2.1.2
github.com/go-pg/pg/v9 v9.1.0
github.com/mmcloughlin/meow v0.0.0-20181112033425-871e50784daf
github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413
)

0 comments on commit e91692c

Please sign in to comment.