From 56c866ff405eb7083188cced8a82bc7f8b5408f9 Mon Sep 17 00:00:00 2001 From: Yevhen Radionov Date: Thu, 18 Jun 2020 18:31:49 +0300 Subject: [PATCH] Use explicit transactions (#11) * Use explicit transactions for DB manipulations * Preallocate memory for strings.Builder * Use tx.Close instead of tx.Rollback --- adapter.go | 48 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/adapter.go b/adapter.go index 98cd247..bb59d08 100644 --- a/adapter.go +++ b/adapter.go @@ -133,6 +133,12 @@ func (r *CasbinRule) String() string { const prefixLine = ", " var sb strings.Builder + sb.Grow( + len(r.PType) + + len(r.V0) + len(r.V1) + len(r.V2) + + len(r.V3) + len(r.V4) + len(r.V5), + ) + sb.WriteString(r.PType) if len(r.V0) > 0 { sb.WriteString(prefixLine) @@ -215,7 +221,13 @@ func savePolicyLine(ptype string, rule []string) *CasbinRule { // SavePolicy saves policy to database. func (a *Adapter) SavePolicy(model model.Model) error { - _, err := a.db.Model((*CasbinRule)(nil)).Where("id IS NOT NULL").Delete() + tx, err := a.db.Begin() + if err != nil { + return fmt.Errorf("start DB transaction: %v", err) + } + defer tx.Close() + + _, err = tx.Model((*CasbinRule)(nil)).Where("id IS NOT NULL").Delete() if err != nil { return err } @@ -237,20 +249,33 @@ func (a *Adapter) SavePolicy(model model.Model) error { } if len(lines) > 0 { - _, err = a.db.Model(&lines). + _, err = tx.Model(&lines). OnConflict("DO NOTHING"). Insert() - return err + if err != nil { + return err + } + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf("commit DB transaction: %v", err) } + return nil } // AddPolicy adds a policy rule to the storage. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { line := savePolicyLine(ptype, rule) - _, err := a.db.Model(line). - OnConflict("DO NOTHING"). - Insert() + err := a.db.RunInTransaction(func(tx *pg.Tx) error { + _, err := a.db.Model(line). + OnConflict("DO NOTHING"). + Insert() + + return err + }) + return err } @@ -275,7 +300,10 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error // RemovePolicy removes a policy rule from the storage. func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { line := savePolicyLine(ptype, rule) - err := a.db.Delete(line) + err := a.db.RunInTransaction(func(tx *pg.Tx) error { + return a.db.Delete(line) + }) + return err } @@ -320,7 +348,11 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, query = query.Where("v5 = ?", fieldValues[5-fieldIndex]) } - _, err := query.Delete() + err := a.db.RunInTransaction(func(tx *pg.Tx) error { + _, err := query.Delete() + return err + }) + return err }