Skip to content

Commit

Permalink
Support for [TAG tag] on AI.SCRIPTSET、AI.MODELSET (#20)
Browse files Browse the repository at this point in the history
* Support for [TAG tag] on AI.SCRIPTSET、AI.MODELSET
  • Loading branch information
dengliming authored May 17, 2020
1 parent 81ac224 commit dd710d8
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 28 deletions.
6 changes: 3 additions & 3 deletions redisai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func getConnectionDetails() (host string, password string) {
}

func createPool() *redis.Pool {
host,_ := getConnectionDetails()
host, _ := getConnectionDetails()
cpool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Expand Down Expand Up @@ -57,12 +57,12 @@ func getTLSdetails() (tlsready bool, tls_cert string, tls_key string, tls_cacert
}

func createTestClient() *Client {
host,_ := getConnectionDetails()
host, _ := getConnectionDetails()
return Connect(host, nil)
}

func TestConnect(t *testing.T) {
host,_ := getConnectionDetails()
host, _ := getConnectionDetails()
cpool1 := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Expand Down
19 changes: 16 additions & 3 deletions redisai/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (c *Client) TensorGetBlob(name string) (dt string, shape []int64, data []by

// ModelSet sets a RedisAI model from a blob
func (c *Client) ModelSet(keyName, backend, device string, data []byte, inputs, outputs []string) (err error) {
args := modelSetFlatArgs(keyName, backend, device, inputs, outputs, data)
args := modelSetFlatArgs(keyName, backend, device, "", inputs, outputs, data)
_, err = c.DoOrSend("AI.MODELSET", args, nil)
return
}
Expand All @@ -95,15 +95,21 @@ func (c *Client) ModelSetFromModel(keyName string, model ModelInterface) (err er
return
}

// ModelGet gets a RedisAI model from the RedisAI server
// The reply will an array, containing at
// - position 0 the backend used by the model as a String
// - position 1 the device used to execute the model as a String
// - position 2 the model's tag as a String
// - position 3 a blob containing the serialized model (when called with the BLOB argument) as a String
func (c *Client) ModelGet(keyName string) (data []interface{}, err error) {
var reply interface{}
data = make([]interface{}, 3)
data = make([]interface{}, 4)
args := modelGetFlatArgs(keyName)
reply, err = c.DoOrSend("AI.MODELGET", args, nil)
if err != nil || reply == nil {
return
}
err, data[0], data[1], data[2] = modelGetParseReply(reply)
err, data[0], data[1], data[2], data[3] = modelGetParseReply(reply)
return
}

Expand Down Expand Up @@ -138,6 +144,13 @@ func (c *Client) ScriptSet(name string, device string, script_source string) (er
return
}

// ScriptSetWithTag sets a RedisAI script from a blob with tag
func (c *Client) ScriptSetWithTag(name string, device string, script_source string, tag string) (err error) {
args := redis.Args{}.Add(name, device, "TAG", tag, "SOURCE", script_source)
_, err = c.DoOrSend("AI.SCRIPTSET", args, nil)
return
}

func (c *Client) ScriptGet(name string) (data map[string]string, err error) {
args := redis.Args{}.Add(name, "META", "SOURCE")
respInitial, err := c.DoOrSend("AI.SCRIPTGET", args, nil)
Expand Down
40 changes: 32 additions & 8 deletions redisai/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,12 @@ func TestCommand_ModelGet(t *testing.T) {
args args
wantBackend string
wantDevice string
wantTag string
wantData []byte
wantErr bool
}{
{keyModelUnexistent1, args{keyModelUnexistent1}, BackendTF, DeviceCPU, data, true},
{keyModel1, args{keyModel1}, BackendTF, DeviceCPU, data, false},
{keyModelUnexistent1, args{keyModelUnexistent1}, BackendTF, DeviceCPU, "", data, true},
{keyModel1, args{keyModel1}, BackendTF, DeviceCPU, "", data, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -501,8 +502,13 @@ func TestCommand_ModelGet(t *testing.T) {
}
}
if !tt.wantErr {
if !reflect.DeepEqual(gotData[2], tt.wantData) {
t.Errorf("ModelGetToModel() gotData = %v, want %v. gotData Type %v, want Type %v.", gotData[2], tt.wantData, reflect.TypeOf(gotData[2]), reflect.TypeOf(tt.wantData))
if !reflect.DeepEqual(gotData[2], tt.wantTag) {
t.Errorf("ModelGetToModel() gotTag = %v, want %v. gotTag Type %v, want Type %v.", gotData[2], tt.wantTag, reflect.TypeOf(gotData[2]), reflect.TypeOf(tt.wantTag))
}
}
if !tt.wantErr {
if !reflect.DeepEqual(gotData[3], tt.wantData) {
t.Errorf("ModelGetToModel() gotData = %v, want %v. gotData Type %v, want Type %v.", gotData[3], tt.wantData, reflect.TypeOf(gotData[3]), reflect.TypeOf(tt.wantData))
}
}

Expand Down Expand Up @@ -621,8 +627,12 @@ func TestCommand_FullFromModelFlow(t *testing.T) {
assert.Nil(t, err)
model1.SetInputs([]string{"transaction", "reference"})
model1.SetOutputs([]string{"output"})
model1.SetTag("financialTag")
err = client.ModelSetFromModel("financialNet1", model1)
assert.Nil(t, err)
model2 := implementations.NewEmptyModel()
err = client.ModelGetToModel("financialNet1", model2)
assert.Equal(t, model1.Tag(), model2.Tag())
}

func TestCommand_ScriptDel(t *testing.T) {
Expand Down Expand Up @@ -684,6 +694,14 @@ func TestCommand_ScriptGet(t *testing.T) {
return
}

keyScript2 := "test:ScriptGet:2"
keyScriptTag := "keyScriptTag"
err = simpleClient.ScriptSetWithTag(keyScript2, DeviceCPU, scriptBin, keyScriptTag)
if err != nil {
t.Errorf("Error preparing for ScriptGet(), while issuing ScriptSet. error = %v", err)
return
}

type args struct {
name string
}
Expand All @@ -692,11 +710,13 @@ func TestCommand_ScriptGet(t *testing.T) {
args args
wantDeviceType string
wantData string
wantTag string
wantErr bool
}{
{keyScript, args{keyScript}, DeviceCPU, "", false},
{keyScriptPipelined, args{keyScript}, DeviceCPU, "", false},
{keyScriptEmpty, args{keyScriptEmpty}, DeviceCPU, "", true},
{keyScript, args{keyScript}, DeviceCPU, "", "", false},
{keyScriptPipelined, args{keyScript}, DeviceCPU, "", "", false},
{keyScriptEmpty, args{keyScriptEmpty}, DeviceCPU, "", "", true},
{keyScriptTag, args{keyScript2}, DeviceCPU, "", keyScriptTag, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -706,13 +726,17 @@ func TestCommand_ScriptGet(t *testing.T) {
t.Errorf("ScriptGet() error = %v, wantErr %v", err, tt.wantErr)
return
}

if tt.wantErr == false {
if !reflect.DeepEqual(gotData["device"], tt.wantDeviceType) {
t.Errorf("ScriptGet() gotData = %v, want %v", gotData["device"], tt.wantDeviceType)
}
if !reflect.DeepEqual(gotData["source"], tt.wantData) {
t.Errorf("ScriptGet() gotData = %v, want %v", gotData["source"], tt.wantData)
}
if !reflect.DeepEqual(gotData["tag"], tt.wantTag) {
t.Errorf("ScriptGet() gotData = %v, want %v", gotData["tag"], tt.wantTag)
}
}

})
Expand Down Expand Up @@ -1014,4 +1038,4 @@ func TestClient_ModelRun(t *testing.T) {
}
})
}
}
}
1 change: 0 additions & 1 deletion redisai/example_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ func ExampleConnect() {
// Output: [1.1 2.2 3.3 4.4]
}


//Example of how to establish an connection with a shared pool to the RedisAI Server
func ExampleConnect_pool() {

Expand Down
5 changes: 2 additions & 3 deletions redisai/example_commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func ExampleClient_ModelGet() {
device := reply[1]
// print the error (should be <nil>)
fmt.Println(err)
fmt.Println(backend,device)
fmt.Println(backend, device)

// Output:
// <nil>
Expand Down Expand Up @@ -185,7 +185,6 @@ func ExampleClient_ModelRun() {
// <nil>
}


func ExampleClient_Info() {
// Create a client.
client := redisai.Connect("redis://localhost:6379", nil)
Expand Down Expand Up @@ -217,4 +216,4 @@ func ExampleClient_Info() {
// <nil>
// <nil>
// Total runs: 1
}
}
9 changes: 9 additions & 0 deletions redisai/implementations/AIModel.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type AIModel struct {
blob []byte
inputs []string
outputs []string
tag string
}

func (m *AIModel) Outputs() []string {
Expand Down Expand Up @@ -50,6 +51,14 @@ func (m *AIModel) SetBackend(backend string) {
m.backend = backend
}

func (m *AIModel) Tag() string {
return m.tag
}

func (m *AIModel) SetTag(tag string) {
m.tag = tag
}

func NewModel(backend string, device string) *AIModel {
return &AIModel{backend: backend, device: device}
}
Expand Down
39 changes: 35 additions & 4 deletions redisai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ type ModelInterface interface {
SetDevice(device string)
Backend() string
SetBackend(backend string)
Tag() string
SetTag(tag string)
}

func modelSetFlatArgs(keyName, backend, device string, inputs, outputs []string, blob []byte) redis.Args {
func modelSetFlatArgs(keyName, backend, device, tag string, inputs, outputs []string, blob []byte) redis.Args {
args := redis.Args{}.Add(keyName, backend, device)
if len(tag) > 0 {
args = args.Add("TAG", tag)
}
if len(inputs) > 0 {
args = args.Add("INPUTS").AddFlat(inputs)
}
Expand All @@ -33,7 +38,26 @@ func modelSetFlatArgs(keyName, backend, device string, inputs, outputs []string,
}

func modelSetInterfaceArgs(keyName string, modelInterface ModelInterface) redis.Args {
return modelSetFlatArgs(keyName, modelInterface.Backend(), modelInterface.Device(), modelInterface.Inputs(), modelInterface.Outputs(), modelInterface.Blob())
args := redis.Args{keyName}
if len(modelInterface.Backend()) > 0 {
args = args.Add(modelInterface.Backend())
}
if len(modelInterface.Device()) > 0 {
args = args.Add(modelInterface.Device())
}
if len(modelInterface.Tag()) > 0 {
args = args.Add("TAG", modelInterface.Tag())
}
if len(modelInterface.Inputs()) > 0 {
args = args.Add("INPUTS").AddFlat(modelInterface.Inputs())
}
if len(modelInterface.Outputs()) > 0 {
args = args.Add("OUTPUTS").AddFlat(modelInterface.Outputs())
}
if modelInterface.Blob() != nil {
args = args.Add("BLOB", modelInterface.Blob())
}
return args
}

func modelRunFlatArgs(name string, inputTensorNames, outputTensorNames []string) redis.Args {
Expand All @@ -51,18 +75,20 @@ func modelRunFlatArgs(name string, inputTensorNames, outputTensorNames []string)
func modelGetParseToInterface(reply interface{}, model ModelInterface) (err error) {
var backend string
var device string
var tag string
var blob []byte
err, backend, device, blob = modelGetParseReply(reply)
err, backend, device, tag, blob = modelGetParseReply(reply)
if err != nil {
return err
}
model.SetBackend(backend)
model.SetDevice(device)
model.SetTag(tag)
model.SetBlob(blob)
return
}

func modelGetParseReply(reply interface{}) (err error, backend string, device string, blob []byte) {
func modelGetParseReply(reply interface{}) (err error, backend string, device string, tag string, blob []byte) {
var replySlice []interface{}
var key string
replySlice, err = redis.Values(reply, err)
Expand Down Expand Up @@ -90,6 +116,11 @@ func modelGetParseReply(reply interface{}) (err error, backend string, device st
if err != nil {
return
}
case "tag":
tag, err = redis.String(replySlice[pos+1], err)
if err != nil {
return
}
}
}
return
Expand Down
16 changes: 10 additions & 6 deletions redisai/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@ func Test_modelGetParseReply(t *testing.T) {
args args
wantBackend string
wantDevice string
wantTag string
wantBlob []byte
wantErr bool
}{
{"empty", args{}, "", "", nil, true},
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", nil, true},
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
{"empty", args{}, "", "", "", nil, true},
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", "", nil, true},
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotErr, gotBackend, gotDevice, gotBlob := modelGetParseReply(tt.args.reply)
gotErr, gotBackend, gotDevice, gotTag, gotBlob := modelGetParseReply(tt.args.reply)
if gotErr != nil && !tt.wantErr {
t.Errorf("modelGetParseReply() gotErr = %v, want %v", gotErr, tt.wantErr)
}
Expand All @@ -35,6 +36,9 @@ func Test_modelGetParseReply(t *testing.T) {
if gotDevice != tt.wantDevice {
t.Errorf("modelGetParseReply() gotDevice = %v, want %v", gotDevice, tt.wantDevice)
}
if gotTag != tt.wantTag {
t.Errorf("modelGetParseReply() gotTag = %v, want %v", gotTag, tt.wantTag)
}
if !reflect.DeepEqual(gotBlob, tt.wantBlob) {
t.Errorf("modelGetParseReply() gotBlob = %v, want %v", gotBlob, tt.wantBlob)
}
Expand Down

0 comments on commit dd710d8

Please sign in to comment.