-
-
Notifications
You must be signed in to change notification settings - Fork 72
/
io.go
53 lines (47 loc) · 1.46 KB
/
io.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
package onnx
import (
"fmt"
"gorgonia.org/tensor"
)
// SetInput assigns tensor t to the i-th input of the graph.
//
// The node that receives t is the backend node whose ID is stored in
// m.Input[i]. An error is returned when i is out of range, when the backend
// has no node for that ID, or when that node does not implement DataCarrier.
// Otherwise the error (if any) from DataCarrier.SetTensor is returned.
func (m *Model) SetInput(i int, t tensor.Tensor) error {
	// Reject negative indices too: m.Input[i] below would panic on them.
	if i < 0 || i >= len(m.Input) {
		return fmt.Errorf("error, trying to set input #%v, but model has only #%v input", i, len(m.Input))
	}
	// Look the node up by the ID recorded for this input, not by the input's
	// position in the list; the two only coincide by accident.
	n := m.backend.Node(int64(m.Input[i]))
	if n == nil {
		return fmt.Errorf("cannot set input for node %v, node is nil", i)
	}
	dc, ok := n.(DataCarrier)
	if !ok {
		return fmt.Errorf("cannot set input because node is not a DataCarrier")
	}
	return dc.SetTensor(t)
}
// GetOutputTensors returns the tensors held by the graph's output nodes, in
// the same order as m.Output.
//
// It returns an error, and a nil slice, if an output ID has no matching node
// in the backend or if that node does not implement DataCarrier.
func (m *Model) GetOutputTensors() ([]tensor.Tensor, error) {
	output := make([]tensor.Tensor, len(m.Output))
	for i, id := range m.Output {
		n := m.backend.Node(int64(id))
		if n == nil {
			return nil, fmt.Errorf("cannot get output for node %v, node is nil", i)
		}
		dc, ok := n.(DataCarrier)
		if !ok {
			return nil, fmt.Errorf("cannot get output because node is not a DataCarrier")
		}
		output[i] = dc.GetTensor()
	}
	return output, nil
}
// GetInputTensors returns the tensors held by the graph's input nodes, in the
// same order as m.Input. It is useful for inspecting an input that is still a
// placeholder and does not contain any data yet.
//
// An entry is left nil when the backend has no node for that input ID, or
// when the node does not implement DataCarrier.
func (m *Model) GetInputTensors() []tensor.Tensor {
	output := make([]tensor.Tensor, len(m.Input))
	for i, id := range m.Input {
		// Use comma-ok so a missing (nil) node or a node that cannot carry
		// data yields a nil entry instead of a panic; this function has no
		// error return to report it through.
		if dc, ok := m.backend.Node(int64(id)).(DataCarrier); ok {
			output[i] = dc.GetTensor()
		}
	}
	return output
}