diff --git a/cmd/haobase/assets/assets.go b/cmd/haobase/assets/assets.go index aadc0de3..fdfa66f2 100644 --- a/cmd/haobase/assets/assets.go +++ b/cmd/haobase/assets/assets.go @@ -5,6 +5,7 @@ import ( "github.com/shopspring/decimal" "github.com/yzimhao/trading_engine/utils" + "github.com/yzimhao/trading_engine/utils/app" ) // 用户资产冻结记录 @@ -62,8 +63,12 @@ type assetsFreeze struct { } func FindSymbol(user_id string, symbol string) *Assets { + + db := app.Database().NewSession() + defer db.Close() + var row Assets - db_engine.Table(new(Assets)).Where("user_id=? and symbol=?", user_id, symbol).Get(&row) + db.Table(new(Assets)).Where("user_id=? and symbol=?", user_id, symbol).Get(&row) return &row } diff --git a/cmd/haobase/assets/init.go b/cmd/haobase/assets/init.go index da5f6088..a04fbbb7 100644 --- a/cmd/haobase/assets/init.go +++ b/cmd/haobase/assets/init.go @@ -1,14 +1,9 @@ package assets import ( - "github.com/redis/go-redis/v9" "github.com/shopspring/decimal" "github.com/sirupsen/logrus" - "xorm.io/xorm" -) - -var ( - db_engine *xorm.Engine + "github.com/yzimhao/trading_engine/utils/app" ) const ( @@ -18,8 +13,8 @@ const ( UserSystemFee string = "system_fee" ) -func Init(db *xorm.Engine, rdc *redis.Client) { - db_engine = db +func Init() { + db_engine := app.Database() //同步表结构 err := db_engine.Sync2( @@ -33,6 +28,9 @@ func Init(db *xorm.Engine, rdc *redis.Client) { } func UserAssets(user_id string, symbol []string) []Assets { + db_engine := app.Database().NewSession() + defer db_engine.Close() + rows := []Assets{} q := db_engine.Table(new(Assets)).Where("user_id=?", user_id) if len(symbol) > 0 { @@ -43,10 +41,6 @@ func UserAssets(user_id string, symbol []string) []Assets { return rows } -func DB() *xorm.Engine { - return db_engine -} - func d(s string) decimal.Decimal { ss, _ := decimal.NewFromString(s) return ss diff --git a/cmd/haobase/assets/transfer_assets.go b/cmd/haobase/assets/transfer_assets.go index 7061ea14..d1f45be3 100644 --- a/cmd/haobase/assets/transfer_assets.go +++ b/cmd/haobase/assets/transfer_assets.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/yzimhao/trading_engine/utils/app" "xorm.io/xorm" ) @@ -12,7 +13,7 @@ func Transfer(db *xorm.Session, from, to string, symbol string, amount string, b } func SysRecharge(to string, symbol string, amount string, business_id string) (success bool, err error) { - db := db_engine.NewSession() + db := app.Database().NewSession() defer db.Close() db.Begin() diff --git a/cmd/haobase/base/base.go b/cmd/haobase/base/base.go index 13894154..7e3e7995 100644 --- a/cmd/haobase/base/base.go +++ b/cmd/haobase/base/base.go @@ -1,26 +1,11 @@ package base import ( - "github.com/redis/go-redis/v9" "github.com/yzimhao/trading_engine/cmd/haobase/base/symbols" - "xorm.io/xorm" ) -var ( - db *xorm.Engine - rdc *redis.Client -) - -func Init(_db *xorm.Engine, _rdc *redis.Client) { - db = _db - rdc = _rdc - symbols.Init(db, rdc) -} - -func DB() *xorm.Engine { - return db -} +var () -func RDC() *redis.Client { - return rdc +func Init() { + symbols.Init() } diff --git a/cmd/haobase/base/symbols/func.go b/cmd/haobase/base/symbols/func.go index 0434f31f..c0416ce3 100644 --- a/cmd/haobase/base/symbols/func.go +++ b/cmd/haobase/base/symbols/func.go @@ -1,12 +1,20 @@ package symbols +import "github.com/yzimhao/trading_engine/utils/app" + func NewVarieties(symbol string) *Varieties { + db := app.Database().NewSession() + defer db.Close() + var row Varieties db.Where("symbol=?", symbol).Get(&row) return &row } func NewTradingVarieties(symbol string) *TradingVarieties { + db := app.Database().NewSession() + defer db.Close() + var row TradingVarieties db.Where("symbol=?", symbol).Get(&row) @@ -18,6 +26,9 @@ func NewTradingVarieties(symbol string) *TradingVarieties { } func newVarietiesById(id int) *Varieties { + db := app.Database().NewSession() + defer db.Close() + var row Varieties db.Where("id=?", id).Get(&row) return &row diff --git a/cmd/haobase/base/symbols/init.go b/cmd/haobase/base/symbols/init.go index 10241a68..17235f24 100644 --- a/cmd/haobase/base/symbols/init.go +++ b/cmd/haobase/base/symbols/init.go @@ -1,24 +1,19 @@ package symbols import ( - "github.com/redis/go-redis/v9" "github.com/sirupsen/logrus" - "xorm.io/xorm" + "github.com/yzimhao/trading_engine/utils/app" ) -var ( - db *xorm.Engine - rdc *redis.Client -) - -func Init(_db *xorm.Engine, _rdc *redis.Client) { - db = _db - rdc = _rdc +var () +func Init() { init_db() } func init_db() { + db := app.Database() + err := db.Sync2( new(Varieties), new(TradingVarieties), @@ -54,6 +49,9 @@ func DemoData() { }, } + db := app.Database().NewSession() + defer db.Close() + _, err := db.Insert(symbols) if err != nil { logrus.Error(err) diff --git a/cmd/haobase/clearing/clearing.go b/cmd/haobase/clearing/clearing.go index 33743482..e04a70d3 100644 --- a/cmd/haobase/clearing/clearing.go +++ b/cmd/haobase/clearing/clearing.go @@ -1,22 +1,22 @@ package clearing import ( - "context" "encoding/json" "fmt" "time" + "github.com/gomodule/redigo/redis" "github.com/sirupsen/logrus" - "github.com/yzimhao/trading_engine/cmd/haobase/base" "github.com/yzimhao/trading_engine/cmd/haobase/base/symbols" "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/types" "github.com/yzimhao/trading_engine/utils" + "github.com/yzimhao/trading_engine/utils/app" ) func Run() { //load symbols - db := base.DB().NewSession() + db := app.Database().NewSession() defer db.Close() var rows []symbols.TradingVarieties @@ -37,16 +37,15 @@ func watch_redis_list(symbol string) { logrus.Infof("结算,正在监听%s成交日志...", symbol) for { func() { - cx := context.Background() - rdc := base.RDC() + rdc := app.RedisPool().Get() defer rdc.Close() - if n, _ := rdc.LLen(cx, key).Result(); n == 0 { + if n, _ := redis.Int64(rdc.Do("LLen", key)); n == 0 { time.Sleep(time.Duration(50) * time.Millisecond) return } - raw, _ := rdc.LPop(cx, key).Bytes() + raw, _ := redis.Bytes(rdc.Do("Lpop", key)) var data trading_core.TradeResult err := json.Unmarshal(raw, &data) @@ -64,7 +63,9 @@ func watch_redis_list(symbol string) { } //通知kline系统 - rdc.RPush(cx, quote_key, raw) + if _, err := rdc.Do("RPUSH", quote_key, raw); err != nil { + logrus.Errorf("rpush %s err: %s", quote_key, err.Error()) + } // if !data.Last { // go newClean(data) diff --git a/cmd/haobase/clearing/clearing_flow.go b/cmd/haobase/clearing/clearing_flow.go index 8637fa30..1c8fc1db 100644 --- a/cmd/haobase/clearing/clearing_flow.go +++ b/cmd/haobase/clearing/clearing_flow.go @@ -4,11 +4,11 @@ import ( "fmt" "github.com/yzimhao/trading_engine/cmd/haobase/assets" - "github.com/yzimhao/trading_engine/cmd/haobase/base" "github.com/yzimhao/trading_engine/cmd/haobase/base/symbols" "github.com/yzimhao/trading_engine/cmd/haobase/orders" "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/utils" + "github.com/yzimhao/trading_engine/utils/app" "xorm.io/xorm" ) @@ -23,7 +23,7 @@ type clean struct { } func newClean(raw trading_core.TradeResult) error { - db := base.DB().NewSession() + db := app.Database().NewSession() defer db.Close() item := clean{ diff --git a/cmd/haobase/clearing/clearing_test.go b/cmd/haobase/clearing/clearing_test.go index c4028620..ffd3e1a9 100644 --- a/cmd/haobase/clearing/clearing_test.go +++ b/cmd/haobase/clearing/clearing_test.go @@ -4,14 +4,13 @@ import ( "testing" "time" - "github.com/redis/go-redis/v9" "github.com/yzimhao/trading_engine/cmd/haobase/assets" "github.com/yzimhao/trading_engine/cmd/haobase/base" "github.com/yzimhao/trading_engine/cmd/haobase/base/symbols" "github.com/yzimhao/trading_engine/cmd/haobase/orders" "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/utils" - "xorm.io/xorm" + "github.com/yzimhao/trading_engine/utils/app" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -28,14 +27,10 @@ var ( ) func initdb(t *testing.T) { - db, err := xorm.NewEngine("mysql", "root:root@tcp(localhost:3306)/test?charset=utf8&loc=Local") - if err != nil { - t.Logf("mysql err: %s", err) - } + app.DatabaseInit("mysql", "root:root@tcp(localhost:3306)/test?charset=utf8&loc=Local", true) + app.RedisInit("127.0.0.1:6379", "", 0) - rdc := redis.NewClient(&redis.Options{Addr: "127.0.0.1:6379", DB: 0}) - base.Init(db, rdc) - db.ShowSQL(true) + base.Init() cleanAssets(t) cleanOrders(t) @@ -43,7 +38,7 @@ func initdb(t *testing.T) { } func initAssets(t *testing.T) { - assets.Init(base.DB(), base.RDC()) + assets.Init() symbols.DemoData() assets.SysRecharge("user1", "usd", "10000.00", "C001") @@ -53,10 +48,11 @@ func initAssets(t *testing.T) { } func cleanAssets(t *testing.T) { - base.DB().DropIndexes(new(assets.Assets)) - base.DB().DropIndexes("assets_freeze") - base.DB().DropIndexes("assets_log") - err := base.DB().DropTables(new(assets.Assets), "assets_freeze", "assets_log") + db := app.Database() + db.DropIndexes(new(assets.Assets)) + db.DropIndexes("assets_freeze") + db.DropIndexes("assets_log") + err := db.DropTables(new(assets.Assets), "assets_freeze", "assets_log") if err != nil { t.Logf("mysql droptables: %s", err) } @@ -64,13 +60,14 @@ func cleanAssets(t *testing.T) { } func cleanOrders(t *testing.T) { - base.DB().DropIndexes(orders.GetOrderTableName(testSymbol)) - base.DB().DropIndexes(new(orders.UnfinishedOrder)) - base.DB().DropIndexes(orders.GetTradelogTableName(testSymbol)) - - base.DB().DropTables(orders.GetOrderTableName(testSymbol)) - base.DB().DropTables(new(orders.UnfinishedOrder)) - base.DB().DropTables(orders.GetTradelogTableName(testSymbol)) + db := app.Database() + db.DropIndexes(orders.GetOrderTableName(testSymbol)) + db.DropIndexes(new(orders.UnfinishedOrder)) + db.DropIndexes(orders.GetTradelogTableName(testSymbol)) + + db.DropTables(orders.GetOrderTableName(testSymbol)) + db.DropTables(new(orders.UnfinishedOrder)) + db.DropTables(orders.GetTradelogTableName(testSymbol)) } diff --git a/cmd/haobase/main.go b/cmd/haobase/main.go index 97cef067..e44d174f 100644 --- a/cmd/haobase/main.go +++ b/cmd/haobase/main.go @@ -27,11 +27,11 @@ func main() { Before: func(ctx *cli.Context) error { app.ConfigInit(ctx.String("config")) - db := app.DatabaseInit() - rc := app.RedisInit() + app.DatabaseInit(app.Cstring("database.driver"), app.Cstring("database.dsn"), app.Cbool("database.show_sql")) + app.RedisInit(app.Cstring("redis.host"), app.Cstring("redis.password"), app.Cint("redis.db")) - base.Init(db, rc) - assets.Init(db, rc) + base.Init() + assets.Init() return nil }, diff --git a/cmd/haobase/orders/cancel_order.go b/cmd/haobase/orders/cancel_order.go index 3717dafa..79078e1d 100644 --- a/cmd/haobase/orders/cancel_order.go +++ b/cmd/haobase/orders/cancel_order.go @@ -3,10 +3,11 @@ package orders import ( "github.com/sirupsen/logrus" "github.com/yzimhao/trading_engine/cmd/haobase/assets" + "github.com/yzimhao/trading_engine/utils/app" ) func cancel_order(symbol, order_id string) (order *Order, err error) { - db := assets.DB().NewSession() + db := app.Database().NewSession() defer db.Close() err = db.Begin() diff --git a/cmd/haobase/orders/func.go b/cmd/haobase/orders/func.go index b8ca44a9..77c9b469 100644 --- a/cmd/haobase/orders/func.go +++ b/cmd/haobase/orders/func.go @@ -1,16 +1,15 @@ package orders import ( - "context" "fmt" "math/rand" "strings" "time" "github.com/sirupsen/logrus" - "github.com/yzimhao/trading_engine/cmd/haobase/base" "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/types" + "github.com/yzimhao/trading_engine/utils/app" ) func generate_order_id_by_side(side trading_core.OrderSide) string { @@ -32,9 +31,10 @@ func generate_order_id(prefix string) string { func push_new_order_to_redis(symbol string, data []byte) { topic := types.FormatNewOrder.Format(symbol) logrus.Infof("push %s new: %s", topic, data) - ctx := context.Background() - err := base.RDC().RPush(ctx, topic, data).Err() - if err != nil { + + rdc := app.RedisPool().Get() + defer rdc.Close() + if _, err := rdc.Do("RPUSH", topic, data); err != nil { logrus.Errorf("push %s err: %s", topic, err.Error()) } } diff --git a/cmd/haobase/orders/limit_order.go b/cmd/haobase/orders/limit_order.go index 5414d292..5e033d43 100644 --- a/cmd/haobase/orders/limit_order.go +++ b/cmd/haobase/orders/limit_order.go @@ -7,6 +7,7 @@ import ( "github.com/yzimhao/trading_engine/haotrader" "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/utils" + "github.com/yzimhao/trading_engine/utils/app" ) func NewLimitOrder(user_id string, symbol string, side trading_core.OrderSide, price, qty string) (order *Order, err error) { @@ -36,7 +37,7 @@ func limit_order(user_id string, symbol string, side trading_core.OrderSide, pri Status: OrderStatusNew, } - db := assets.DB().NewSession() + db := app.Database().NewSession() defer db.Close() //todo事务开启前创建需要的表 diff --git a/cmd/haobase/orders/market_order.go b/cmd/haobase/orders/market_order.go index 21838854..272d38d7 100644 --- a/cmd/haobase/orders/market_order.go +++ b/cmd/haobase/orders/market_order.go @@ -5,6 +5,7 @@ import ( "github.com/yzimhao/trading_engine/cmd/haobase/base/symbols" "github.com/yzimhao/trading_engine/haotrader" "github.com/yzimhao/trading_engine/trading_core" + "github.com/yzimhao/trading_engine/utils/app" ) func NewMarketOrderByQty(user_id string, symbol string, side trading_core.OrderSide, qty string) (*Order, error) { @@ -34,7 +35,7 @@ func market_order_qty(user_id string, symbol string, side trading_core.OrderSide Status: OrderStatusNew, } - db := assets.DB().NewSession() + db := app.Database().NewSession() defer db.Close() err = db.Begin() @@ -120,7 +121,7 @@ func market_order_amount(user_id string, symbol string, side trading_core.OrderS Status: OrderStatusNew, } - db := assets.DB().NewSession() + db := app.Database().NewSession() defer db.Close() err = db.Begin() diff --git a/cmd/haobase/orders/order.go b/cmd/haobase/orders/order.go index 3738e3a2..25cf92a6 100644 --- a/cmd/haobase/orders/order.go +++ b/cmd/haobase/orders/order.go @@ -4,8 +4,8 @@ import ( "fmt" "time" - "github.com/yzimhao/trading_engine/cmd/haobase/assets" "github.com/yzimhao/trading_engine/trading_core" + "github.com/yzimhao/trading_engine/utils/app" "xorm.io/xorm" ) @@ -42,6 +42,7 @@ type Order struct { func (o *Order) Save(db *xorm.Session) error { //todo 频繁查询表是否存在,后面考虑缓存一下 + exist, err := db.IsTableExist(o.TableName()) if err != nil { return err @@ -80,7 +81,7 @@ func GetOrderTableName(symbol string) string { } func Find(symbol string, order_id string) *Order { - db := assets.DB().NewSession() + db := app.Database().NewSession() defer db.Close() var row Order diff --git a/cmd/haobase/orders/unfinished_order.go b/cmd/haobase/orders/unfinished_order.go index c4dd6801..92ebce96 100644 --- a/cmd/haobase/orders/unfinished_order.go +++ b/cmd/haobase/orders/unfinished_order.go @@ -1,7 +1,7 @@ package orders import ( - "github.com/yzimhao/trading_engine/cmd/haobase/assets" + "github.com/yzimhao/trading_engine/utils/app" "xorm.io/xorm" ) @@ -45,7 +45,7 @@ func (u *UnfinishedOrder) Create(db *xorm.Session) error { } func FindUnfinished(symbol string, order_id string) *Order { - db := assets.DB().NewSession() + db := app.Database().NewSession() defer db.Close() var row Order diff --git a/cmd/haoquote/main.go b/cmd/haoquote/main.go index fd9b817e..4f460a08 100644 --- a/cmd/haoquote/main.go +++ b/cmd/haoquote/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "os" "github.com/sevlyar/go-daemon" @@ -74,11 +73,9 @@ func main() { } - db := app.DatabaseInit() - - rc := app.RedisInit() - ctext := context.Background() - haoquote.Run(&ctext, rc, db) + app.DatabaseInit(app.Cstring("database.driver"), app.Cstring("database.dsn"), app.Cbool("database.show_sql")) + app.RedisInit(app.Cstring("redis.host"), app.Cstring("redis.password"), app.Cint("redis.db")) + haoquote.Run() return nil }, } diff --git a/cmd/haotrader/main.go b/cmd/haotrader/main.go index 2fa67f61..b7b7ecc0 100644 --- a/cmd/haotrader/main.go +++ b/cmd/haotrader/main.go @@ -1,10 +1,8 @@ package main import ( - "context" "os" - "github.com/redis/go-redis/v9" "github.com/sevlyar/go-daemon" "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -13,10 +11,6 @@ import ( "github.com/yzimhao/trading_engine/utils/app" ) -var ( - rc *redis.Client -) - func main() { appm := &cli.App{ Name: "haotrader", @@ -54,12 +48,12 @@ func main() { Action: func(ctx *cli.Context) error { app.ConfigInit(ctx.String("config")) app.LogsInit("haotrader.run", false) - rc := app.RedisInit() + // rc := app.RedisInit() if ctx.String("side") == "ask" { - haotrader.InsertAsk(rc, ctx.String("symbol"), ctx.String("type"), ctx.Int("n"), ctx.String("price"), ctx.String("qty")) + // haotrader.InsertAsk(rc, ctx.String("symbol"), ctx.String("type"), ctx.Int("n"), ctx.String("price"), ctx.String("qty")) } else { - haotrader.InsertBid(rc, ctx.String("symbol"), ctx.String("type"), ctx.Int("n"), ctx.String("price"), ctx.String("qty")) + // haotrader.InsertBid(rc, ctx.String("symbol"), ctx.String("type"), ctx.Int("n"), ctx.String("price"), ctx.String("qty")) } return nil }, @@ -73,8 +67,7 @@ func main() { logrus.Infof("当前运行在%s模式下,生产环境时main.mode请务必成prod", viper.GetString("main.mode")) } - rc := app.RedisInit() - ctext := context.Background() + app.RedisInit(app.Cstring("redis.host"), app.Cstring("redis.password"), app.Cint("redis.db")) if ctx.Bool("deamon") { logrus.Info("开始守护进程") @@ -96,7 +89,7 @@ func main() { }(context) } - haotrader.Start(&ctext, rc) + haotrader.Run() return nil }, } diff --git a/go.mod b/go.mod index 940f932d..ec4e1b9d 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/snappy v0.0.4 // indirect + github.com/gomodule/redigo v1.8.9 // indirect github.com/gookit/goutil v0.6.12 // indirect github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/gorilla/websocket v1.5.0 // indirect diff --git a/go.sum b/go.sum index 9b354103..46a03f5d 100644 --- a/go.sum +++ b/go.sum @@ -194,6 +194,8 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8l github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= +github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= diff --git a/haoquote/main.go b/haoquote/main.go index f0c14e4f..f32bb5b0 100644 --- a/haoquote/main.go +++ b/haoquote/main.go @@ -1,25 +1,22 @@ package haoquote import ( - "context" "strings" "sync" - "github.com/redis/go-redis/v9" "github.com/spf13/viper" "github.com/yzimhao/trading_engine/haoquote/tradelog" "github.com/yzimhao/trading_engine/haoquote/www" "github.com/yzimhao/trading_engine/utils/filecache" - "xorm.io/xorm" ) -func Run(ctx *context.Context, rc *redis.Client, db *xorm.Engine) { +func Run() { wg := sync.WaitGroup{} wg.Add(1) init_symbols_quote() //http - tradelog.Init(rc, db) - www.Run(rc, db) + tradelog.Init() + www.Run() wg.Wait() } diff --git a/haoquote/tradelog/trade_log.go b/haoquote/tradelog/trade_log.go index b1736a47..8b09913a 100644 --- a/haoquote/tradelog/trade_log.go +++ b/haoquote/tradelog/trade_log.go @@ -1,13 +1,12 @@ package tradelog import ( - "context" "encoding/json" "fmt" "time" + "github.com/gomodule/redigo/redis" "github.com/gookit/goutil/arrutil" - "github.com/redis/go-redis/v9" "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/yzimhao/trading_engine/haoquote/period" @@ -15,13 +14,10 @@ import ( "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/types" "github.com/yzimhao/trading_engine/utils" - "xorm.io/xorm" + "github.com/yzimhao/trading_engine/utils/app" ) -var ( - db *xorm.Engine - rdc *redis.Client -) +var () type TradeLog struct { Id int64 `xorm:"autoincr pk" json:"-"` @@ -45,6 +41,8 @@ func (t *TradeLog) CreateTable() error { return fmt.Errorf("symbol is null") } + db := app.Database() + exist, err := db.IsTableExist(t.TableName()) if err != nil { return err @@ -68,14 +66,14 @@ func (t *TradeLog) CreateTable() error { } func (t *TradeLog) Save() error { + db := app.Database().NewSession() + defer db.Close() + _, err := db.Table(t.TableName()).Insert(t) return err } -func Init(rc *redis.Client, d *xorm.Engine) { - db = d - rdc = rc -} +func Init() {} func Monitor(symbol string, price_digit, qty_digit int64) { key := types.FormatQuoteTradeResult.Format(symbol) @@ -85,13 +83,22 @@ func Monitor(symbol string, price_digit, qty_digit int64) { for { func() { - cx := context.Background() - if n, _ := rdc.LLen(cx, key).Result(); n == 0 { + // cx := context.Background() + // if n, _ := rdc.LLen(cx, key).Result(); n == 0 { + // time.Sleep(time.Duration(50) * time.Millisecond) + // return + // } + + rdc := app.RedisPool().Get() + defer rdc.Close() + if n, _ := redis.Int64(rdc.Do("LLen", key)); n == 0 { time.Sleep(time.Duration(50) * time.Millisecond) return } - raw, _ := rdc.LPop(cx, key).Bytes() + // raw, _ := rdc.LPop(cx, key).Bytes() + + raw, _ := redis.Bytes(rdc.Do("LPop", key)) var data trading_core.TradeResult err := json.Unmarshal(raw, &data) @@ -172,6 +179,8 @@ func tradelog_msg(symbol string, data TradeLog, pd, qd int64) { } func save_db(row *period.Period) error { + + db := app.Database() row.CreateTable(db) sess := db.NewSession() diff --git a/haoquote/www/http.go b/haoquote/www/http.go index b82da8e6..c2b21cbd 100644 --- a/haoquote/www/http.go +++ b/haoquote/www/http.go @@ -5,10 +5,8 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" "github.com/sirupsen/logrus" "github.com/spf13/viper" - "xorm.io/xorm" _ "github.com/yzimhao/trading_engine/docs/api" // main 文件中导入 docs 包 "github.com/yzimhao/trading_engine/utils" @@ -19,15 +17,9 @@ import ( "github.com/yzimhao/trading_engine/haoquote/ws" ) -var ( - rdc *redis.Client - db *xorm.Engine -) - -func Run(rc *redis.Client, d *xorm.Engine) { - rdc = rc - db = d +var () +func Run() { ws.NewHub() sub_symbol_depth() http_start(viper.GetString("haoquote.http.host")) @@ -111,6 +103,9 @@ func trans_record(ctx *gin.Context) { Symbol: symbol, } + db := app.Database().NewSession() + defer db.Close() + db.Table(tl.TableName()).OrderBy("trade_at desc, id desc").Limit(limit).Find(&rows) // [ @@ -183,6 +178,9 @@ func kline(ctx *gin.Context) { // ] // ] + db := app.Database().NewSession() + defer db.Close() + var rows []period.Period db.Table(row.TableName()).OrderBy("open_at desc").Limit(limit).Find(&rows) diff --git a/haoquote/www/sub_symbol_depth.go b/haoquote/www/sub_symbol_depth.go index 378d2fb8..09fc24af 100644 --- a/haoquote/www/sub_symbol_depth.go +++ b/haoquote/www/sub_symbol_depth.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/viper" "github.com/yzimhao/trading_engine/haoquote/ws" "github.com/yzimhao/trading_engine/types" + "github.com/yzimhao/trading_engine/utils/app" ) var ( @@ -90,6 +91,9 @@ func sub_depth(symbol string, price_digit, qty_digit int64) { channel := types.FormatBroadcastDepth.Format(symbol) ctx := context.Background() + rdc := app.RedisPool().Get() + defer rdc.Close() + pubsub := rdc.PSubscribe(ctx, channel) defer pubsub.Close() diff --git a/haotrader/data.go b/haotrader/data.go index 5a901f38..384c8df1 100644 --- a/haotrader/data.go +++ b/haotrader/data.go @@ -1,68 +1,71 @@ package haotrader -import ( - "context" - "encoding/json" - "fmt" - "math/rand" - "strconv" - "time" +// import ( +// "context" +// "encoding/json" +// "fmt" +// "math/rand" +// "strconv" +// "time" - "github.com/google/uuid" - "github.com/redis/go-redis/v9" - "github.com/spf13/viper" - "github.com/yzimhao/trading_engine/types" -) +// "github.com/gomodule/redigo/redis" +// "github.com/google/uuid" +// "github.com/spf13/viper" +// "github.com/yzimhao/trading_engine/types" +// ) -func InsertAsk(rdc *redis.Client, symbol string, order_type string, n int, base_price string, qty string) { - if n <= 0 { - n = 1 - } +// func InsertAsk(rdc *redis.Pool, symbol string, order_type string, n int, base_price string, qty string) { +// if n <= 0 { +// n = 1 +// } - digit := viper.GetInt(fmt.Sprintf("symbol.%s.price_digit", symbol)) - pdigit := fmt.Sprintf("%%.%df", digit) +// digit := viper.GetInt(fmt.Sprintf("symbol.%s.price_digit", symbol)) +// pdigit := fmt.Sprintf("%%.%df", digit) - price, _ := strconv.ParseFloat(base_price, 64) - for i := 0; i < n; i++ { +// price, _ := strconv.ParseFloat(base_price, 64) +// for i := 0; i < n; i++ { - id, _ := uuid.NewRandom() - obj := Order{ - OrderId: fmt.Sprintf("a-%s-%d", id.String(), i), - Side: "ask", - OrderType: "limit", - Price: fmt.Sprintf(pdigit, price+rand.Float64()), - Qty: qty, - At: time.Now().Unix(), - } - s, _ := json.Marshal(obj) - ctx1 := context.Background() - key := types.FormatNewOrder.Format(symbol) - rdc.RPush(ctx1, key, s).Err() - } -} +// id, _ := uuid.NewRandom() +// obj := Order{ +// OrderId: fmt.Sprintf("a-%s-%d", id.String(), i), +// Side: "ask", +// OrderType: "limit", +// Price: fmt.Sprintf(pdigit, price+rand.Float64()), +// Qty: qty, +// At: time.Now().Unix(), +// } +// s, _ := json.Marshal(obj) +// // ctx1 := context.Background() +// key := types.FormatNewOrder.Format(symbol) +// // rdc.RPush(ctx1, key, s).Err() -func InsertBid(rdc *redis.Client, symbol string, order_type string, n int, base_price string, qty string) { - if n <= 0 { - n = 1 - } +// rdc.Get().Do("RPUSH", key, s) - digit := viper.GetInt(fmt.Sprintf("symbol.%s.price_digit", symbol)) - pdigit := fmt.Sprintf("%%.%df", digit) +// } +// } - price, _ := strconv.ParseFloat(base_price, 64) - for i := 0; i < n; i++ { - id, _ := uuid.NewRandom() - obj := Order{ - OrderId: fmt.Sprintf("b-%s-%d", id.String(), i), - Side: "bid", - OrderType: "limit", - Price: fmt.Sprintf(pdigit, price+rand.Float64()), - Qty: qty, - At: time.Now().Unix(), - } - s, _ := json.Marshal(obj) - ctx1 := context.Background() - key := types.FormatNewOrder.Format(symbol) - rdc.RPush(ctx1, key, s).Err() - } -} +// func InsertBid(rdc *redis.Pool, symbol string, order_type string, n int, base_price string, qty string) { +// if n <= 0 { +// n = 1 +// } + +// digit := viper.GetInt(fmt.Sprintf("symbol.%s.price_digit", symbol)) +// pdigit := fmt.Sprintf("%%.%df", digit) + +// price, _ := strconv.ParseFloat(base_price, 64) +// for i := 0; i < n; i++ { +// id, _ := uuid.NewRandom() +// obj := Order{ +// OrderId: fmt.Sprintf("b-%s-%d", id.String(), i), +// Side: "bid", +// OrderType: "limit", +// Price: fmt.Sprintf(pdigit, price+rand.Float64()), +// Qty: qty, +// At: time.Now().Unix(), +// } +// s, _ := json.Marshal(obj) +// ctx1 := context.Background() +// key := types.FormatNewOrder.Format(symbol) +// rdc.RPush(ctx1, key, s).Err() +// } +// } diff --git a/haotrader/tengine.go b/haotrader/tengine.go index 58d7b1ad..5c78d7c9 100644 --- a/haotrader/tengine.go +++ b/haotrader/tengine.go @@ -1,16 +1,17 @@ package haotrader import ( - "context" "encoding/json" "strings" "time" "github.com/gin-gonic/gin" + "github.com/gomodule/redigo/redis" "github.com/shopspring/decimal" "github.com/sirupsen/logrus" "github.com/yzimhao/trading_engine/trading_core" "github.com/yzimhao/trading_engine/types" + "github.com/yzimhao/trading_engine/utils/app" ) type tengine struct { @@ -52,24 +53,23 @@ func (t *tengine) broadcast_depth() { select { case <-t.broadcast: go func() { - logrus.Debugf("触发%s广播", depth_channel) - tx := context.Background() data := gin.H{ //todo 限制最大获取的数量 "asks": t.tp.GetAskDepth(100), "bids": t.tp.GetBidDepth(100), } + rdc := app.RedisPool().Get() + defer rdc.Close() + raw, _ := json.Marshal(data) - err := rdc.Publish(tx, depth_channel, raw).Err() - if err != nil { + + if _, err := rdc.Do("Publish", depth_channel, raw); err != nil { logrus.Warnf("广播%s消息失败: %s", depth_channel, err) } }() go func() { - tx := context.Background() - price, at := t.tp.LatestPrice() data := types.ChannelLatestPrice{ T: at, @@ -77,7 +77,11 @@ func (t *tengine) broadcast_depth() { } raw, _ := json.Marshal(data) - rdc.Publish(tx, price_channel, raw).Err() + rdc := app.RedisPool().Get() + defer rdc.Close() + if _, err := rdc.Do("Publish", price_channel, raw); err != nil { + logrus.Warnf("广播%s消息失败: %s", price_channel, err) + } }() default: @@ -170,55 +174,65 @@ func (t *tengine) pull_new_order() { logrus.Infof("正在监听redis队列: %s", key) for { - cx := context.Background() - if n, _ := rdc.LLen(cx, key).Result(); n == 0 || t.tp.IsPausePushNew() { - time.Sleep(time.Duration(50) * time.Millisecond) - continue - } + // cx := context.Background() + // if n, _ := rdc.LLen(cx, key).Result(); n == 0 || t.tp.IsPausePushNew() { + // time.Sleep(time.Duration(50) * time.Millisecond) + // continue + // } + func() { - raw, _ := rdc.LPop(cx, key).Bytes() - go func(raw []byte) { - var data Order - err := json.Unmarshal(raw, &data) - if err != nil { - logrus.Warnf("%s 解析json: %s 错误: %s", key, raw, err) + rdc := app.RedisPool().Get() + defer rdc.Close() + if n, _ := redis.Int64(rdc.Do("LLen", key)); n == 0 || t.tp.IsPausePushNew() { + time.Sleep(time.Duration(50) * time.Millisecond) + return } - if data.OrderId != "" { - logrus.Debugf("%s队列LPop: %s", key, raw) - side := strings.ToLower(data.Side) - order_type := strings.ToLower(data.OrderType) - - if order_type == "limit" { - if side == trading_core.OrderSideSell.String() { - t.tp.PushNewOrder(trading_core.NewAskLimitItem(data.OrderId, d(data.Price), d(data.Qty), data.At)) - } else if side == trading_core.OrderSideBuy.String() { - t.tp.PushNewOrder(trading_core.NewBidLimitItem(data.OrderId, d(data.Price), d(data.Qty), data.At)) - } else { - logrus.Errorf("新订单参数错误: %s side只能是sell/buy", raw) - } - } else if order_type == "market_qty" { - //按成交量 - if side == trading_core.OrderSideSell.String() { - t.tp.PushNewOrder(trading_core.NewAskMarketQtyItem(data.OrderId, d(data.Qty), data.At)) - } else if side == trading_core.OrderSideBuy.String() { - t.tp.PushNewOrder(trading_core.NewBidMarketQtyItem(data.OrderId, d(data.Qty), d(data.MaxAmount), data.At)) - } else { - logrus.Errorf("新订单参数错误: %s side只能是sell/buy", raw) - } - } else if order_type == "market_amount" { - //按成交金额 - if side == trading_core.OrderSideSell.String() { - t.tp.PushNewOrder(trading_core.NewAskMarketAmountItem(data.OrderId, d(data.Qty), d(data.MaxQty), data.At)) - } else if side == trading_core.OrderSideBuy.String() { - t.tp.PushNewOrder(trading_core.NewBidMarketAmountItem(data.OrderId, d(data.Amount), data.At)) - } else { - logrus.Errorf("新订单参数错误: %s side只能是sell/buy", raw) + raw, _ := redis.Bytes(rdc.Do("Lpop", key)) + + go func(raw []byte) { + var data Order + err := json.Unmarshal(raw, &data) + if err != nil { + logrus.Warnf("%s 解析json: %s 错误: %s", key, raw, err) + } + + if data.OrderId != "" { + logrus.Debugf("%s队列LPop: %s", key, raw) + side := strings.ToLower(data.Side) + order_type := strings.ToLower(data.OrderType) + + if order_type == "limit" { + if side == trading_core.OrderSideSell.String() { + t.tp.PushNewOrder(trading_core.NewAskLimitItem(data.OrderId, d(data.Price), d(data.Qty), data.At)) + } else if side == trading_core.OrderSideBuy.String() { + t.tp.PushNewOrder(trading_core.NewBidLimitItem(data.OrderId, d(data.Price), d(data.Qty), data.At)) + } else { + logrus.Errorf("新订单参数错误: %s side只能是sell/buy", raw) + } + } else if order_type == "market_qty" { + //按成交量 + if side == trading_core.OrderSideSell.String() { + t.tp.PushNewOrder(trading_core.NewAskMarketQtyItem(data.OrderId, d(data.Qty), data.At)) + } else if side == trading_core.OrderSideBuy.String() { + t.tp.PushNewOrder(trading_core.NewBidMarketQtyItem(data.OrderId, d(data.Qty), d(data.MaxAmount), data.At)) + } else { + logrus.Errorf("新订单参数错误: %s side只能是sell/buy", raw) + } + } else if order_type == "market_amount" { + //按成交金额 + if side == trading_core.OrderSideSell.String() { + t.tp.PushNewOrder(trading_core.NewAskMarketAmountItem(data.OrderId, d(data.Qty), d(data.MaxQty), data.At)) + } else if side == trading_core.OrderSideBuy.String() { + t.tp.PushNewOrder(trading_core.NewBidMarketAmountItem(data.OrderId, d(data.Amount), data.At)) + } else { + logrus.Errorf("新订单参数错误: %s side只能是sell/buy", raw) + } } } - } - }(raw) + }(raw) + }() } } @@ -229,13 +243,15 @@ func (t *tengine) pull_cancel_order() { logrus.Infof("正在监听redis队列: %s", key) for { func() { - cx := context.Background() - if n, _ := rdc.LLen(cx, key).Result(); n == 0 || t.tp.IsPausePushNew() { + rdc := app.RedisPool().Get() + defer rdc.Close() + + if n, _ := redis.Int64(rdc.Do("Llen", key)); n == 0 || t.tp.IsPausePushNew() { time.Sleep(time.Duration(50) * time.Millisecond) return } - raw, _ := rdc.LPop(cx, key).Bytes() + raw, _ := redis.Bytes(rdc.Do("LPOP", key)) // rdc.LPop(cx, key).Bytes() var data cancel_order err := json.Unmarshal(raw, &data) @@ -271,7 +287,6 @@ func (t *tengine) monitor_result() { }() case uniq := <-t.tp.ChCancelResult: go func() { - cx := context.Background() key := types.FormatCancelResult.Format(t.symbol) data := map[string]any{ @@ -279,9 +294,13 @@ func (t *tengine) monitor_result() { "cancel": "success", } + rdc := app.RedisPool().Get() + defer rdc.Close() + raw, _ := json.Marshal(data) - err := rdc.RPush(cx, key, raw).Err() - logrus.Infof("%s队列RPush: %s %s", key, raw, err) + if _, err := rdc.Do("RPUSH", key, raw); err != nil { //rdc.RPush(cx, key, raw).Err() + logrus.Warnf("%s队列RPush: %s %s", key, raw, err) + } }() default: @@ -292,10 +311,13 @@ func (t *tengine) monitor_result() { } func (t *tengine) push_match_result(data []byte) { - cx := context.Background() + rdc := app.RedisPool().Get() + defer rdc.Close() + key := types.FormatTradeResult.Format(t.symbol) - err := rdc.RPush(cx, key, data).Err() - logrus.Infof("往%s队列RPush: %s %s", key, data, err) + if _, err := rdc.Do("RPUSH", key, data); err != nil { + logrus.Warnf("往%s队列RPush: %s %s", key, data, err) + } // if viper.GetBool("haotrader.notify_quote") { // quote_key := types.FormatQuoteTradeResult.Format(t.symbol) // rdc.RPush(cx, quote_key, data) diff --git a/haotrader/trader.go b/haotrader/trader.go index b56ae8e0..cba79dc2 100644 --- a/haotrader/trader.go +++ b/haotrader/trader.go @@ -1,12 +1,10 @@ package haotrader import ( - "context" "strings" "sync" "time" - "github.com/redis/go-redis/v9" "github.com/sirupsen/logrus" "github.com/spf13/viper" "github.com/yzimhao/trading_engine/trading_core" @@ -14,24 +12,17 @@ import ( ) var ( - rdc *redis.Client - wg sync.WaitGroup + wg sync.WaitGroup localdb *filecache.Storage teps map[string]*trading_core.TradePair ) -func Start(ctx *context.Context, rc *redis.Client) { - rdc = rc +func Run() { teps = make(map[string]*trading_core.TradePair) localdb = filecache.NewStorage(viper.GetString("haotrader.storage_path"), time.Duration(10)) defer localdb.Close() - _, err := rdc.Ping(context.Background()).Result() - if err != nil { - panic(err) - } - wg = sync.WaitGroup{} //todo wg.done() diff --git a/utils/app/app.go b/utils/app/app.go index 7a60bcaf..cca2ef96 100644 --- a/utils/app/app.go +++ b/utils/app/app.go @@ -6,8 +6,8 @@ import ( "os" "time" + "github.com/gomodule/redigo/redis" "github.com/gookit/goutil/fsutil" - "github.com/redis/go-redis/v9" "github.com/sevlyar/go-daemon" "github.com/sirupsen/logrus" "github.com/spf13/viper" @@ -33,6 +33,10 @@ var ( Commit = "" Build = "" RunMode Mode = ModeProd + + // + redisPool *redis.Pool + database *xorm.Engine ) func ShowVersion() { @@ -55,27 +59,52 @@ func ConfigInit(fp string) { RunMode = Mode(viper.GetString("main.mode")) } -func RedisInit() *redis.Client { - return redis.NewClient(&redis.Options{ - Addr: viper.GetString("redis.host"), - Password: viper.GetString("redis.password"), - DB: viper.GetInt("redis.db"), - - DialTimeout: 10 * time.Second, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - ContextTimeoutEnabled: true, +func Cstring(key string) string { + return viper.GetString(key) +} - MaxRetries: -1, +func Cbool(key string) bool { + return viper.GetBool(key) +} - PoolTimeout: 30 * time.Second, - ConnMaxIdleTime: time.Minute, +func Cint(key string) int { + return viper.GetInt(key) +} - PoolSize: 15, - MinIdleConns: 10, +func RedisInit(addr, password string, db int) { + if redisPool == nil { + // addr := viper.GetString("redis.host") + // password := viper.GetString("redis.password") + // db := viper.GetInt("redis.db") + redisPool = &redis.Pool{ + MaxIdle: 10, //空闲数 + IdleTimeout: 240 * time.Second, + MaxActive: 20, //最大数 + Dial: func() (redis.Conn, error) { + c, err := redis.Dial("tcp", addr) + if err != nil { + return nil, err + } + if password != "" { + if _, err := c.Do("AUTH", password); err != nil { + c.Close() + return nil, err + } + } + + if _, err := c.Do("SELECT", db); err != nil { + return nil, err + } + + return c, err + }, + TestOnBorrow: func(c redis.Conn, t time.Time) error { + _, err := c.Do("PING") + return err + }, + } + } - ConnMaxLifetime: 0, - }) } func LogsInit(fn string, is_daemon bool) { @@ -111,23 +140,36 @@ func LogsInit(fn string, is_daemon bool) { logrus.SetOutput(mw) } -func DatabaseInit() *xorm.Engine { - dsn := viper.GetString("database.dsn") - driver := viper.GetString("database.driver") - conn, err := xorm.NewEngine(driver, dsn) - if err != nil { - logrus.Panic(err) - } +func DatabaseInit(driver, dsn string, show_sql bool) { + if database == nil { + // dsn := viper.GetString("database.dsn") + // driver := viper.GetString("database.driver") + conn, err := xorm.NewEngine(driver, dsn) + if err != nil { + logrus.Panic(err) + } - // tbMapper := core.NewPrefixMapper(core.SnakeMapper{}, "ex_") - // conn.SetTableMapper(tbMapper) - if viper.GetBool("database.show_sql") { - conn.ShowSQL(true) + // tbMapper := core.NewPrefixMapper(core.SnakeMapper{}, "ex_") + // conn.SetTableMapper(tbMapper) + // if viper.GetBool("database.show_sql") { + // conn.ShowSQL(true) + // } + if show_sql { + conn.ShowSQL(true) + } + + conn.DatabaseTZ = time.Local + conn.TZLocation = time.Local + database = conn } +} + +func Database() *xorm.Engine { + return database +} - conn.DatabaseTZ = time.Local - conn.TZLocation = time.Local - return conn +func RedisPool() *redis.Pool { + return redisPool } func Deamon(pid string, logfile string) (*daemon.Context, *os.Process, error) {