diff --git a/cmd/nvidia-ctk/cdi/transform/root/root.go b/cmd/nvidia-ctk/cdi/transform/root/root.go index 6014cea8f..b89a6c156 100644 --- a/cmd/nvidia-ctk/cdi/transform/root/root.go +++ b/cmd/nvidia-ctk/cdi/transform/root/root.go @@ -18,9 +18,12 @@ package root import ( "fmt" + "io" + "os" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" ) @@ -32,73 +35,125 @@ type loadSaver interface { type command struct { logger *logrus.Logger +} - handler loadSaver +type transformOptions struct { + input string + output string } -type config struct { +type options struct { + transformOptions from string to string } // NewCommand constructs a generate-cdi command with the specified logger -func NewCommand(logger *logrus.Logger, specHandler loadSaver) *cli.Command { +func NewCommand(logger *logrus.Logger) *cli.Command { c := command{ - logger: logger, - handler: specHandler, + logger: logger, } return c.build() } // build creates the CLI command func (m command) build() *cli.Command { - cfg := config{} + opts := options{} c := cli.Command{ Name: "root", Usage: "Apply a root transform to a CDI specification", Before: func(c *cli.Context) error { - return m.validateFlags(c, &cfg) + return m.validateFlags(c, &opts) }, Action: func(c *cli.Context) error { - return m.run(c, &cfg) + return m.run(c, &opts) }, } c.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "input", + Usage: "Specify the file to read the CDI specification from. If this is '-' the specification is read from STDIN", + Value: "-", + Destination: &opts.input, + }, + &cli.StringFlag{ + Name: "output", + Usage: "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT", + Destination: &opts.output, + }, &cli.StringFlag{ Name: "from", Usage: "specify the root to be transformed", - Destination: &cfg.from, + Destination: &opts.from, }, &cli.StringFlag{ Name: "to", Usage: "specify the replacement root. If this is the same as the from root, the transform is a no-op.", Value: "", - Destination: &cfg.to, + Destination: &opts.to, }, } return &c } -func (m command) validateFlags(c *cli.Context, cfg *config) error { +func (m command) validateFlags(c *cli.Context, opts *options) error { return nil } -func (m command) run(c *cli.Context, cfg *config) error { - spec, err := m.handler.Load() +func (m command) run(c *cli.Context, opts *options) error { + spec, err := opts.Load() if err != nil { return fmt.Errorf("failed to load CDI specification: %w", err) } err = transform.NewRootTransformer( - cfg.from, - cfg.to, + opts.from, + opts.to, ).Transform(spec.Raw()) if err != nil { return fmt.Errorf("failed to transform CDI specification: %w", err) } - return m.handler.Save(spec) + return opts.Save(spec) +} + +// Load lodas the input CDI specification +func (o transformOptions) Load() (spec.Interface, error) { + contents, err := o.getContents() + if err != nil { + return nil, fmt.Errorf("failed to read spec contents: %v", err) + } + + raw, err := cdi.ParseSpec(contents) + if err != nil { + return nil, fmt.Errorf("failed to parse CDI spec: %v", err) + } + + return spec.New( + spec.WithRawSpec(raw), + ) +} + +func (o transformOptions) getContents() ([]byte, error) { + if o.input == "-" { + return io.ReadAll(os.Stdin) + } + + return os.ReadFile(o.input) +} + +// Save saves the CDI specification to the output file +func (o transformOptions) Save(s spec.Interface) error { + if o.output == "" { + _, err := s.WriteTo(os.Stdout) + if err != nil { + return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err) + } + return nil + } + + return s.Save(o.output) } diff --git a/cmd/nvidia-ctk/cdi/transform/transform.go b/cmd/nvidia-ctk/cdi/transform/transform.go index 22ad2e136..5538166d1 100644 --- a/cmd/nvidia-ctk/cdi/transform/transform.go +++ b/cmd/nvidia-ctk/cdi/transform/transform.go @@ -17,13 +17,7 @@ package transform import ( - "fmt" - "io" - "os" - "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/cdi/transform/root" - "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" - "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" ) @@ -32,11 +26,6 @@ type command struct { logger *logrus.Logger } -type options struct { - input string - output string -} - // NewCommand constructs a command with the specified logger func NewCommand(logger *logrus.Logger) *cli.Command { c := command{ @@ -47,82 +36,16 @@ func NewCommand(logger *logrus.Logger) *cli.Command { // build creates the CLI command func (m command) build() *cli.Command { - opts := options{} - c := cli.Command{ Name: "transform", Usage: "Apply a transform to a CDI specification", - Before: func(c *cli.Context) error { - return m.validateFlags(c, &opts) - }, - Action: func(c *cli.Context) error { - return m.run(c, &opts) - }, } - c.Flags = []cli.Flag{ - &cli.StringFlag{ - Name: "input", - Usage: "Specify the file to read the CDI specification from. If this is '-' the specification is read from STDIN", - Value: "-", - Destination: &opts.input, - }, - &cli.StringFlag{ - Name: "output", - Usage: "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT", - Destination: &opts.output, - }, - } + c.Flags = []cli.Flag{} c.Subcommands = []*cli.Command{ - root.NewCommand(m.logger, &opts), + root.NewCommand(m.logger), } return &c } - -func (m command) validateFlags(c *cli.Context, opts *options) error { - return nil -} - -func (m command) run(c *cli.Context, cfg *options) error { - return nil -} - -// Load lodas the input CDI specification -func (o options) Load() (spec.Interface, error) { - contents, err := o.getContents() - if err != nil { - return nil, fmt.Errorf("failed to read spec contents: %v", err) - } - - raw, err := cdi.ParseSpec(contents) - if err != nil { - return nil, fmt.Errorf("failed to parse CDI spec: %v", err) - } - - return spec.New( - spec.WithRawSpec(raw), - ) -} - -func (o options) getContents() ([]byte, error) { - if o.input == "-" { - return io.ReadAll(os.Stdin) - } - - return os.ReadFile(o.input) -} - -// Save saves the CDI specification to the output file -func (o options) Save(s spec.Interface) error { - if o.output == "" { - _, err := s.WriteTo(os.Stdout) - if err != nil { - return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err) - } - return nil - } - - return s.Save(o.output) -}