From 4329087515d48c74618abc4eec15b7ff9dccba73 Mon Sep 17 00:00:00 2001 From: closetool Date: Tue, 2 Mar 2021 14:49:41 +0800 Subject: [PATCH] fix: update action change all rows(#23) Signed-off-by: closetool --- adapter.go | 10 +++++----- adapter_test.go | 29 +++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/adapter.go b/adapter.go index 3aac5e5..c62a6a6 100644 --- a/adapter.go +++ b/adapter.go @@ -480,13 +480,13 @@ func (a *Adapter) updatePolicies(oldLines, newLines []*CasbinRule) error { for i, line := range oldLines { newLines[i].ID = line.ID + _, err = tx.Model(newLines[i]).WherePK().Update() + if err != nil { + tx.Rollback() + return err + } } - _, err = tx.Model(&newLines).WherePK().Update() - if err != nil { - tx.Rollback() - return err - } if err = tx.Commit(); err != nil { return err } diff --git a/adapter_test.go b/adapter_test.go index d2d55dd..acb45a8 100644 --- a/adapter_test.go +++ b/adapter_test.go @@ -293,10 +293,35 @@ func (s *AdapterTestSuite) TestUpdatePolicy() { err = s.e.SavePolicy() s.Require().NoError(err) - _, err = s.e.UpdatePolicies([][]string{{"alice", "data1", "read"}}, [][]string{{"bob", "data1", "read"}}) + _, err = s.e.UpdatePolicies([][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}, [][]string{{"bob", "data1", "read"}, {"alice", "data2", "write"}}) s.Require().NoError(err) - s.assertPolicy(s.e.GetPolicy(), [][]string{{"bob", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + err = s.e.LoadPolicy() + s.Require().NoError(err) + + s.assertPolicy(s.e.GetPolicy(), [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data1", "read"}, {"alice", "data2", "write"}}) +} + +func (s *AdapterTestSuite) TestUpdatePolicyWithLoadFilteredPolicy() { + var err error + s.e, err = casbin.NewEnforcer("examples/rbac_model.conf", "examples/rbac_policy.csv") + s.Require().NoError(err) + + s.e.SetAdapter(s.a) + + err = s.e.SavePolicy() + s.Require().NoError(err) + + err = s.e.LoadFilteredPolicy(&Filter{P: []string{"data2_admin"}}) + s.Require().NoError(err) + + _, err = s.e.UpdatePolicies(s.e.GetPolicy(), [][]string{{"bob", "data2", "read"}, {"alice", "data2", "write"}}) + s.Require().NoError(err) + + err = s.e.LoadPolicy() + s.Require().NoError(err) + + s.assertPolicy(s.e.GetPolicy(), [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"bob", "data2", "read"}, {"alice", "data2", "write"}}) } func TestAdapterTestSuite(t *testing.T) {