Skip to content

Commit

Permalink
feat: specify dbname in NewAdapter() (#41)
Browse files Browse the repository at this point in the history
* feat: specify dbname in NewAdapter()

* Update adapter.go

Co-authored-by: Yang Luo <hsluoyz@qq.com>
  • Loading branch information
JalinWang and hsluoyz authored Jul 13, 2022
1 parent fc569f3 commit f5f993b
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

const DefaultTableName = "casbin_rule"
const DefaultDatabaseName = "casbin"

// CasbinRule represents a rule in Casbin.
type CasbinRule struct {
Expand Down Expand Up @@ -43,10 +44,20 @@ type Adapter struct {
type Option func(a *Adapter)

// NewAdapter is the constructor for Adapter.
// arg should be a PostgreS URL string or of type *pg.Options
// The adapter will create a DB named "casbin" if it doesn't exist
func NewAdapter(arg interface{}) (*Adapter, error) {
db, err := createCasbinDatabase(arg)
// param:arg should be a PostgreS URL string or of type *pg.Options
// param:dbname is the name of the database to use and can is optional.
// If no dbname is provided, the default database name is "casbin" which will be created automatically.
// If arg is *pg.Options, the arg.Database field is omitted and will be modified according to dbname
func NewAdapter(arg interface{}, dbname ...string) (*Adapter, error) {
var db *pg.DB
var err error

if len(dbname) > 0 {
db, err = createCasbinDatabase(arg, dbname[0])
} else {
db, err = createCasbinDatabase(arg, DefaultDatabaseName)
}

if err != nil {
return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
}
Expand Down Expand Up @@ -91,7 +102,7 @@ func SkipTableCreate() Option {
}
}

func createCasbinDatabase(arg interface{}) (*pg.DB, error) {
func createCasbinDatabase(arg interface{}, dbname string) (*pg.DB, error) {
var opts *pg.Options
var err error
if connURL, ok := arg.(string); ok {
Expand All @@ -109,10 +120,13 @@ func createCasbinDatabase(arg interface{}) (*pg.DB, error) {
db := pg.Connect(opts)
defer db.Close()

_, err = db.Exec("CREATE DATABASE casbin")
_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbname))
if err != nil && !strings.Contains(err.Error(), "42P04") {
return nil, err
}
db.Close()

opts.Database = "casbin"
opts.Database = dbname
db = pg.Connect(opts)

return db, nil
Expand Down

0 comments on commit f5f993b

Please sign in to comment.