diff --git a/cmd/haobase/assets/assets.go b/cmd/haobase/assets/assets.go index b7f9285b..aadc0de3 100644 --- a/cmd/haobase/assets/assets.go +++ b/cmd/haobase/assets/assets.go @@ -3,6 +3,7 @@ package assets import ( "time" + "github.com/shopspring/decimal" "github.com/yzimhao/trading_engine/utils" ) @@ -65,3 +66,27 @@ func FindSymbol(user_id string, symbol string) *Assets { db_engine.Table(new(Assets)).Where("user_id=? and symbol=?", user_id, symbol).Get(&row) return &row } + +func BalanceOfTotal(user_id, symbol string) decimal.Decimal { + row := FindSymbol(user_id, symbol) + if row.Id > 0 { + return utils.D(row.Total) + } + return decimal.Zero +} + +func BalanceOfFreeze(user_id, symbol string) decimal.Decimal { + row := FindSymbol(user_id, symbol) + if row.Id > 0 { + return utils.D(row.Freeze) + } + return decimal.Zero +} + +func BalanceOfAvailable(user_id, symbol string) decimal.Decimal { + row := FindSymbol(user_id, symbol) + if row.Id > 0 { + return utils.D(row.Available) + } + return decimal.Zero +} diff --git a/cmd/haobase/www/demo.go b/cmd/haobase/www/demo.go new file mode 100644 index 00000000..e8100a80 --- /dev/null +++ b/cmd/haobase/www/demo.go @@ -0,0 +1,12 @@ +package www + +import ( + "github.com/yzimhao/trading_engine/cmd/haobase/base/symbols" + "github.com/yzimhao/trading_engine/utils/app" +) + +func demoBaseData() { + if app.RunMode == app.ModeDemo { + symbols.DemoData() + } +} diff --git a/cmd/haobase/www/middle/base.go b/cmd/haobase/www/middle/base.go index 20333479..82405939 100644 --- a/cmd/haobase/www/middle/base.go +++ b/cmd/haobase/www/middle/base.go @@ -2,6 +2,8 @@ package middle import ( "github.com/gin-gonic/gin" + "github.com/shopspring/decimal" + "github.com/yzimhao/trading_engine/cmd/haobase/assets" "github.com/yzimhao/trading_engine/utils" "github.com/yzimhao/trading_engine/utils/app" ) @@ -10,11 +12,24 @@ func CheckLogin() gin.HandlerFunc { return func(c *gin.Context) { user_id := "" if app.RunMode == app.ModeDemo { - user_id = c.GetHeader("UserId") + //自动为demo用户充值三种货币 + user_id = c.GetHeader("User-Id") + if user_id != "" { + if assets.BalanceOfTotal(user_id, "usd").Equal(decimal.Zero) { + assets.SysRecharge(user_id, "usd", "10000.00", "sys_recharge") + } + if assets.BalanceOfTotal(user_id, "jpy").Equal(decimal.Zero) { + assets.SysRecharge(user_id, "jpy", "10000.00", "sys_recharge") + } + if assets.BalanceOfTotal(user_id, "eur").Equal(decimal.Zero) { + assets.SysRecharge(user_id, "eur", "10000.00", "sys_recharge") + } + } } if user_id == "" { utils.ResponseFailJson(c, "需要登录") + c.Abort() return } diff --git a/cmd/haobase/www/route.go b/cmd/haobase/www/route.go index 07ffd593..ece59f5c 100644 --- a/cmd/haobase/www/route.go +++ b/cmd/haobase/www/route.go @@ -7,6 +7,8 @@ import ( ) func Run() { + demoBaseData() + g := gin.New() router(g) g.Run(viper.GetString("haobase.http.host"))