diff --git a/.codecov.yml b/.codecov.yml index 8153244d..ed9d686a 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,7 +1,9 @@ ignore: - "client/app/app.go" - - "cmd/ocfclient/ocfclient.go" + - "cmd/bridge-device/*.go" + - "cmd/ocfclient/*.go" - "**/main.go" + - "**/test/**/*.go" - "**/*.pb.go" - "**/*.pb.gw.go" - "**/*_test.go" diff --git a/.github/workflows/build-publish-cfg.yaml b/.github/workflows/build-publish-cfg.yaml index f6ef7d66..55d3a423 100644 --- a/.github/workflows/build-publish-cfg.yaml +++ b/.github/workflows/build-publish-cfg.yaml @@ -12,10 +12,6 @@ on: description: Name of the container type: string required: true - directory: - description: Directory of service - type: string - required: true file: description: Dockerfile to build type: string @@ -102,7 +98,6 @@ jobs: platforms: linux/amd64,linux/arm64 builder: ${{ steps.buildx.outputs.name }} build-args: | - DIRECTORY=${{ inputs.directory }} NAME=${{ inputs.name }} COMMIT_DATE=${{ steps.build-args.outputs.commit_date }} SHORT_COMMIT=${{ steps.build-args.outputs.short_commit }} @@ -123,7 +118,6 @@ jobs: platforms: linux/amd64,linux/arm64 builder: ${{ steps.buildx.outputs.name }} build-args: | - DIRECTORY=${{ inputs.directory }} NAME=${{ inputs.name }} COMMIT_DATE=${{ steps.build-args.outputs.commit_date }} SHORT_COMMIT=${{ steps.build-args.outputs.short_commit }} diff --git a/.github/workflows/build-publish.yaml b/.github/workflows/build-publish.yaml index fab8eb10..5ea80beb 100644 --- a/.github/workflows/build-publish.yaml +++ b/.github/workflows/build-publish.yaml @@ -27,11 +27,11 @@ jobs: matrix: include: - name: test-cloud-server - directory: test/cloud-server file: test/cloud-server/Dockerfile + - name: bridge-device + file: cmd/bridge-device/Dockerfile uses: ./.github/workflows/build-publish-cfg.yaml with: name: ${{ matrix.name }} - directory: ${{ matrix.directory }} file: ${{ matrix.file }} diff --git a/.github/workflows/test-with-cfg.yml b/.github/workflows/test-with-cfg.yml index 9af269c8..6864075b 100644 --- a/.github/workflows/test-with-cfg.yml +++ b/.github/workflows/test-with-cfg.yml @@ -11,6 +11,10 @@ on: name: type: string required: true + bridgetest-enabled: + type: boolean + required: false + default: false coverage: type: boolean required: false @@ -22,10 +26,10 @@ on: type: string required: false default: "" - unittest-args: - type: string + unittest-enabled: + type: boolean required: false - default: "" + default: false jobs: test: runs-on: ubuntu-latest @@ -51,11 +55,16 @@ jobs: if: ${{ failure() }} run: docker logs -t devsim-net-host && cat .tmp/devsim-net-host/0.log - # Run after integration tests, because they first clean-up the output directory + # Run after integration tests, because they always clean-up the output directory + - name: Run bridge tests + if: ${{ inputs.bridgetest-enabled }} + run: | + make test-bridge + - name: Run unit tests - if: ${{ inputs.coverage }} + if: ${{ inputs.unittest-enabled }} run: | - make unit-test ${{ inputs.unittest-args }} + make unit-test - name: Get output file name if: ${{ inputs.coverage }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a6693156..5c9e1961 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,8 +21,10 @@ jobs: matrix: include: - name: cloud-server-debug + bridgetest-enabled: true coverage: true tag: ghcr.io/iotivity/iotivity-lite/cloud-server-debug:vnext + unittest-enabled: true - name: cloud-server-debug-sha384 test-args: CERT_TOOL_SIGN_ALG=ECDSA-SHA384 CERT_TOOL_ELLIPTIC_CURVE=P384 coverage: true @@ -37,9 +39,11 @@ jobs: uses: ./.github/workflows/test-with-cfg.yml with: name: ${{ matrix.name }} + bridgetest-enabled: ${{ matrix.bridgetest-enabled || false }} coverage: ${{ matrix.coverage }} tag: ${{ matrix.tag }} test-args: ${{ matrix.test-args }} + unittest-enabled: ${{ matrix.unittest-enabled || false }} analysis: name: SonarCloud and codecov analysis diff --git a/.gitignore b/.gitignore index 4bd75fbc..bcada047 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,6 @@ vendor/ .vscode/ .tmp/ debug - +cmd/bridge-device/bridge-device +cmd/ocfclient/ocfclient +test/bridge-device/bridge-device diff --git a/.vscode/settings.json b/.vscode/settings.json index dfa25688..17339c48 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,6 +3,7 @@ "-race", "-v", "-cover", + "-coverpkg=./...", ], "go.testEnvVars": { "ROOT_CA_CRT": "${workspaceFolder}/.tmp/pki_certs/cloudca.pem", @@ -13,9 +14,12 @@ "MFG_KEY": "${workspaceFolder}/.tmp/pki_certs/mfgkey.pem", "IDENTITY_CRT": "${workspaceFolder}/.tmp/pki_certs/identitycrt.pem", "IDENTITY_KEY": "${workspaceFolder}/.tmp/pki_certs/identitykey.pem", + "COAP_CRT": "${workspaceFolder}/.tmp/pki_certs/coapcrt.pem", + "COAP_KEY": "${workspaceFolder}/.tmp/pki_certs/coapkey.pem", + "CLOUD_SID": "adebc667-1f2b-41e3-bf5c-6d6eabc68cc6", }, "files.watcherExclude": { - "**/plgd-dev/device/v2/**": true - }, + "**/plgd-dev/device/v2/**": true + }, "go.testTimeout": "600s", } \ No newline at end of file diff --git a/Makefile b/Makefile index e6af0958..65f9ca67 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,11 @@ SHELL = /bin/bash SERVICE_NAME = cloud-server-test VERSION_TAG = vnext-$(shell git rev-parse --short=7 --verify HEAD) SIMULATOR_NAME_SUFFIX ?= $(shell hostname) +USER_ID := $(shell id -u) +GROUP_ID := $(shell id -g) TMP_PATH = $(shell pwd)/.tmp CERT_PATH = $(TMP_PATH)/pki_certs +CLOUD_SID ?= adebc667-1f2b-41e3-bf5c-6d6eabc68cc6 DEVSIM_NET_HOST_PATH = $(shell pwd)/.tmp/devsim-net-host CERT_TOOL_IMAGE ?= ghcr.io/plgd-dev/hub/cert-tool:vnext # supported values: ECDSA-SHA256, ECDSA-SHA384, ECDSA-SHA512 @@ -11,6 +14,7 @@ CERT_TOOL_SIGN_ALG ?= ECDSA-SHA256 # supported values: P256, P384, P521 CERT_TOOL_ELLIPTIC_CURVE ?= P256 DEVSIM_IMAGE ?= ghcr.io/iotivity/iotivity-lite/cloud-server-discovery-resource-observable-debug:vnext +HUB_TEST_DEVICE_IMAGE = ghcr.io/plgd-dev/hub/test-cloud-server:vnext-pr1202 default: build @@ -28,21 +32,58 @@ build-testcontainer: build: build-testcontainer +ROOT_CA_CRT = $(CERT_PATH)/cloudca.pem +ROOT_CA_KEY = $(CERT_PATH)/cloudcakey.pem +INTERMEDIATE_CA_CRT = $(CERT_PATH)/intermediatecacrt.pem +INTERMEDIATE_CA_KEY = $(CERT_PATH)/intermediatecakey.pem +MFG_CRT = $(CERT_PATH)/mfgcrt.pem +MFG_KEY = $(CERT_PATH)/mfgkey.pem +COAP_CRT = $(CERT_PATH)/coapcrt.pem +COAP_KEY = $(CERT_PATH)/coapkey.pem + certificates: mkdir -p $(CERT_PATH) chmod 0777 $(CERT_PATH) docker pull $(CERT_TOOL_IMAGE) - docker run --rm -v $(CERT_PATH):/out $(CERT_TOOL_IMAGE) --outCert=/out/cloudca.pem --outKey=/out/cloudcakey.pem \ - --cert.subject.cn="ca" --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) \ - --cmd.generateRootCA - docker run --rm -v $(CERT_PATH):/out $(CERT_TOOL_IMAGE) --signerCert=/out/cloudca.pem --signerKey=/out/cloudcakey.pem \ - --outCert=/out/intermediatecacrt.pem --outKey=/out/intermediatecakey.pem --cert.basicConstraints.maxPathLen=0 \ - --cert.subject.cn="intermediateCA" --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) \ - --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) --cmd.generateIntermediateCA - docker run --rm -v $(CERT_PATH):/out $(CERT_TOOL_IMAGE) --signerCert=/out/intermediatecacrt.pem \ - --signerKey=/out/intermediatecakey.pem --outCert=/out/mfgcrt.pem --outKey=/out/mfgkey.pem --cert.san.domain=localhost \ - --cert.san.ip=127.0.0.1 --cert.subject.cn="mfg" --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) \ - --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) --cmd.generateCertificate + + docker run \ + --rm -v $(CERT_PATH):/out \ + --user $(USER_ID):$(GROUP_ID) \ + $(CERT_TOOL_IMAGE) \ + --outCert=/out/cloudca.pem --outKey=/out/cloudcakey.pem \ + --cert.subject.cn="ca" --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) \ + --cmd.generateRootCA + + docker run \ + --rm -v $(CERT_PATH):/out \ + --user $(USER_ID):$(GROUP_ID) \ + $(CERT_TOOL_IMAGE) \ + --signerCert=/out/cloudca.pem --signerKey=/out/cloudcakey.pem \ + --outCert=/out/intermediatecacrt.pem --outKey=/out/intermediatecakey.pem \ + --cert.basicConstraints.maxPathLen=0 --cert.subject.cn="intermediateCA" \ + --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) \ + --cmd.generateIntermediateCA + + docker run \ + --rm -v $(CERT_PATH):/out \ + --user $(USER_ID):$(GROUP_ID) \ + $(CERT_TOOL_IMAGE) \ + --signerCert=/out/intermediatecacrt.pem --signerKey=/out/intermediatecakey.pem \ + --outCert=/out/mfgcrt.pem --outKey=/out/mfgkey.pem --cert.san.domain=localhost \ + --cert.san.ip=127.0.0.1 --cert.subject.cn="mfg" \ + --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) \ + --cmd.generateCertificate + + docker run \ + --rm -v $(CERT_PATH):/out \ + --user $(USER_ID):$(GROUP_ID) \ + ${CERT_TOOL_IMAGE} \ + --signerCert=/out/cloudca.pem --signerKey=/out/cloudcakey.pem \ + --outCert=/out/coapcrt.pem --outKey=/out/coapkey.pem \ + --cert.san.ip=127.0.0.1 --cert.san.domain=localhost \ + --cert.signatureAlgorithm=$(CERT_TOOL_SIGN_ALG) --cert.ellipticCurve=$(CERT_TOOL_ELLIPTIC_CURVE) \ + --cmd.generateCertificate --cert.subject.cn=uuid:$(CLOUD_SID) + sudo chown -R $(shell whoami) $(CERT_PATH) chmod -R 0777 $(CERT_PATH) @@ -64,8 +105,15 @@ env: clean certificates unit-test: certificates mkdir -p $(TMP_PATH) - go test -race -v ./schema/... -covermode=atomic -coverprofile=$(TMP_PATH)/schema.coverage.txt - ROOT_CA_CRT="$(CERT_PATH)/cloudca.pem" ROOT_CA_KEY="$(CERT_PATH)/cloudcakey.pem" go test -race -v ./pkg/... -covermode=atomic -coverprofile=$(TMP_PATH)/pkg.coverage.txt + ROOT_CA_CRT="$(ROOT_CA_CRT)" ROOT_CA_KEY="$(ROOT_CA_KEY)" \ + MFG_CRT="$(MFG_CRT)" MFG_KEY="$(MFG_KEY)" \ + INTERMEDIATE_CA_CRT="$(INTERMEDIATE_CA_CRT)" INTERMEDIATE_CA_KEY=$(INTERMEDIATE_CA_KEY) \ + COAP_CRT="$(COAP_CRT)" COAP_KEY="$(COAP_KEY)" \ + CLOUD_SID=$(CLOUD_SID) \ + go test -race -parallel 1 -v ./bridge/... -coverpkg=./... -covermode=atomic -coverprofile=$(TMP_PATH)/bridge.unit.coverage.txt + go test -race -v ./schema/... -covermode=atomic -coverprofile=$(TMP_PATH)/schema.unit.coverage.txt + ROOT_CA_CRT="$(ROOT_CA_CRT)" ROOT_CA_KEY="$(ROOT_CA_KEY)" \ + go test -race -v ./pkg/... -covermode=atomic -coverprofile=$(TMP_PATH)/pkg.unit.coverage.txt test: env build-testcontainer docker run \ @@ -75,8 +123,63 @@ test: env build-testcontainer -v $(TMP_PATH):/tmp \ $(SERVICE_NAME):$(VERSION_TAG) -test.parallel 1 -test.v -test.coverprofile=/tmp/coverage.txt +test-bridge: + sudo rm -rf $(TMP_PATH)/data || : + mkdir -p $(TMP_PATH)/data + # pull image + docker pull $(HUB_TEST_DEVICE_IMAGE) + # prepare environment + docker run \ + --rm \ + --network=host \ + --name hub-device-tests-environment \ + --env PREPARE_ENV=true \ + --env RUN=false \ + --env COAP_GATEWAY_CLOUD_ID="$(CLOUD_SID)" \ + -v $(TMP_PATH):/tmp \ + -v $(TMP_PATH)/data:/data \ + $(HUB_TEST_DEVICE_IMAGE) + + # start device + rm -rf $(TMP_PATH)/bridge || : + mkdir -p $(TMP_PATH)/bridge + go build -C ./test/bridge-device -cover -o ./bridge-device + pkill -KILL bridge-device || : + CLOUD_SID=$(CLOUD_SID) CA_POOL=$(TMP_PATH)/data/certs/root_ca.crt \ + CERT_FILE=$(TMP_PATH)/data/certs/external/coap-gateway.crt \ + KEY_FILE=$(TMP_PATH)/data/certs/external/coap-gateway.key \ + GOCOVERDIR=$(TMP_PATH)/bridge \ + ./test/bridge-device/bridge-device & + + # run tests + docker run \ + --rm \ + --network=host \ + --name hub-device-tests \ + --env PREPARE_ENV=false \ + --env RUN=true \ + --env COAP_GATEWAY_CLOUD_ID="$(CLOUD_SID)" \ + --env TEST_DEVICE_NAME="bridged-device-0" \ + --env TEST_DEVICE_TYPE="bridged" \ + --env GRPC_GATEWAY_TEST_DISABLED=1 \ + --env IOTIVITY_LITE_TEST_RUN="(TestOffboard|TestOffboardWithoutSignIn|TestOffboardWithRepeat|TestRepublishAfterRefresh)$$" \ + -v $(TMP_PATH):/tmp \ + -v $(TMP_PATH)/data:/data \ + $(HUB_TEST_DEVICE_IMAGE) + + # stop device + pkill -TERM bridge-device || : + while pgrep -x bridge-device > /dev/null; do \ + echo "waiting for bridge-device to exit"; \ + sleep 1; \ + done + go tool covdata textfmt -i=$(TMP_PATH)/bridge -o $(TMP_PATH)/bridge.coverage.txt + clean: - docker rm -f devsim-net-host || true + docker rm -f devsim-net-host || : + docker rm -f hub-device-tests-environment || : + docker rm -f hub-device-tests || : + pkill -KILL bridge-device || : sudo rm -rf .tmp/* .PHONY: build-testcontainer build certificates clean env test unit-test diff --git a/bridge/README.md b/bridge/README.md new file mode 100644 index 00000000..72c1480e --- /dev/null +++ b/bridge/README.md @@ -0,0 +1,5 @@ +# TODO + +- Set RootCAs to the cloud connection +- Unit tests +- Set logger to package and propagate to package diff --git a/bridge/device/cloud/config.go b/bridge/device/cloud/config.go new file mode 100644 index 00000000..146c30d8 --- /dev/null +++ b/bridge/device/cloud/config.go @@ -0,0 +1,35 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "github.com/plgd-dev/device/v2/schema/cloud" +) + +type Configuration struct { + ResourceTypes []string `yaml:"-" json:"rt"` + Interfaces []string `yaml:"-" json:"if"` + Name string `yaml:"-" json:"n"` + AuthorizationProvider string `yaml:"authorizationProvider" json:"apn"` + CloudID string `yaml:"cloudID" json:"sid"` + URL string `yaml:"cloudEndpoint" json:"cis"` + LastErrorCode int `yaml:"-" json:"clec"` + ProvisioningStatus cloud.ProvisioningStatus `yaml:"-" json:"cps"` + AuthorizationCode string `yaml:"-" json:"-"` +} diff --git a/bridge/device/cloud/manager.go b/bridge/device/cloud/manager.go new file mode 100644 index 00000000..d1dbdbde --- /dev/null +++ b/bridge/device/cloud/manager.go @@ -0,0 +1,477 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/bridge/resources/discovery" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/plgd-dev/device/v2/pkg/net/coap" + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/device/v2/schema/device" + plgdResources "github.com/plgd-dev/device/v2/schema/resources" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/mux" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/options" + "github.com/plgd-dev/go-coap/v3/tcp" + "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +type ( + GetLinksFilteredBy func(endpoints schema.Endpoints, deviceIDfilter uuid.UUID, resourceTypesFitler []string, policyBitMaskFitler schema.BitMask) (links schema.ResourceLinks) + GetCertificates func(deviceID string) []tls.Certificate + RemoveCloudCAs func(cloudID ...string) +) + +type Config struct { + AccessToken string + UserID string + RefreshToken string + ValidUntil time.Time + AuthorizationProvider string + CloudID string + CloudURL string +} + +type CAPoolGetter = interface { + IsValid() bool + GetPool() (*x509.CertPool, error) +} + +type Manager struct { + handler net.RequestHandler + getLinks GetLinksFilteredBy + maxMessageSize uint32 + deviceID uuid.UUID + save func() + caPool CAPoolGetter + getCertificates GetCertificates + removeCloudCAs RemoveCloudCAs + + private struct { + mutex sync.Mutex + cfg Configuration + previousCloudIDs []string + } + + logger log.Logger + creds ocfCloud.CoapSignUpResponse + client *client.Conn + signedIn bool + resourcesPublished bool + done chan struct{} + trigger chan bool +} + +func New(cfg Config, deviceID uuid.UUID, save func(), handler net.RequestHandler, getLinks GetLinksFilteredBy, caPool CAPoolGetter, opts ...Option) (*Manager, error) { + if !caPool.IsValid() { + return nil, fmt.Errorf("invalid ca pool") + } + o := OptionsCfg{ + maxMessageSize: net.DefaultMaxMessageSize, + getCertificates: func(string) []tls.Certificate { + return nil + }, + removeCloudCAs: func(...string) { + // do nothing + }, + logger: log.NewNilLogger(), + } + for _, opt := range opts { + opt(&o) + } + + c := &Manager{ + done: make(chan struct{}), + trigger: make(chan bool, 10), + handler: handler, + getLinks: getLinks, + deviceID: deviceID, + maxMessageSize: o.maxMessageSize, + save: save, + caPool: caPool, + getCertificates: o.getCertificates, + removeCloudCAs: o.removeCloudCAs, + logger: o.logger, + } + c.private.cfg.ProvisioningStatus = cloud.ProvisioningStatus_UNINITIALIZED + c.importConfig(cfg) + return c, nil +} + +func (c *Manager) Get(request *net.Request) (*pool.Message, error) { + cfg := c.getCloudConfiguration() + return resources.CreateResponseContent(request.Context(), cfg, codes.Content) +} + +func (c *Manager) ExportConfig() Config { + configuration := c.getCloudConfiguration() + creds := c.getCreds() + return Config{ + CloudID: configuration.CloudID, + AuthorizationProvider: configuration.AuthorizationProvider, + CloudURL: configuration.URL, + AccessToken: creds.AccessToken, + UserID: creds.UserID, + RefreshToken: creds.RefreshToken, + ValidUntil: creds.ValidUntil, + } +} + +func (c *Manager) importConfig(cfg Config) { + c.setCloudConfiguration(cloud.ConfigurationUpdateRequest{ + AuthorizationProvider: cfg.AuthorizationProvider, + URL: cfg.CloudURL, + CloudID: cfg.CloudID, + }) + c.setCreds(ocfCloud.CoapSignUpResponse{ + AccessToken: cfg.AccessToken, + UserID: cfg.UserID, + RefreshToken: cfg.RefreshToken, + ValidUntil: cfg.ValidUntil, + }) +} + +func (c *Manager) Init() { + if c.private.cfg.URL != "" { + c.triggerRunner(false) + } + go c.run() +} + +func (c *Manager) resetCredentials(ctx context.Context, signOff bool) { + if signOff { + resetCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + if err := c.signOff(resetCtx); err != nil { + c.logger.Debugf("%w", err) + } + } + c.creds = ocfCloud.CoapSignUpResponse{} + c.signedIn = false + c.resourcesPublished = false + if err := c.close(); err != nil { + c.logger.Warnf("cannot close connection: %w", err) + } + c.save() + c.removePreviousCloudIDs() +} + +func (c *Manager) cleanup() { + c.setCloudConfiguration(cloud.ConfigurationUpdateRequest{}) + c.resetCredentials(context.Background(), false) + c.triggerRunner(false) +} + +func (c *Manager) triggerRunner(reset bool) { + select { + case c.trigger <- reset: + default: + } +} + +func validateConfigurationUpdate(cfg cloud.ConfigurationUpdateRequest) error { + if cfg.CloudID == "" { + return fmt.Errorf("cloud ID cannot be empty") + } + if cfg.AuthorizationProvider == "" { + return fmt.Errorf("authorization provider cannot be empty") + } + if cfg.URL == "" { + return fmt.Errorf("URL cannot be empty") + } + return nil +} + +func (c *Manager) Post(request *net.Request) (*pool.Message, error) { + var cfg cloud.ConfigurationUpdateRequest + err := cbor.ReadFrom(request.Body(), &cfg) + if err != nil { + return resources.CreateResponseBadRequest(request.Context(), err) + } + if cfg.URL == "" { + c.setCloudConfiguration(cloud.ConfigurationUpdateRequest{}) + } else { + err = validateConfigurationUpdate(cfg) + if err != nil { + return resources.CreateResponseBadRequest(request.Context(), err) + } + c.setCloudConfiguration(cfg) + } + c.triggerRunner(true) + currentCfg := c.getCloudConfiguration() + return resources.CreateResponseContent(request.Context(), currentCfg, codes.Changed) +} + +func (c *Manager) popPreviousCloudIDs() []string { + c.private.mutex.Lock() + defer c.private.mutex.Unlock() + previousCloudIDs := c.private.previousCloudIDs + c.private.previousCloudIDs = nil + return previousCloudIDs +} + +func (c *Manager) removePreviousCloudIDs() { + c.removeCloudCAs(c.popPreviousCloudIDs()...) +} + +func (c *Manager) setCloudConfiguration(cfg cloud.ConfigurationUpdateRequest) { + c.private.mutex.Lock() + defer c.private.mutex.Unlock() + c.private.previousCloudIDs = append(c.private.previousCloudIDs, c.private.cfg.CloudID) + c.private.cfg.AuthorizationProvider = cfg.AuthorizationProvider + c.private.cfg.CloudID = cfg.CloudID + c.private.cfg.URL = cfg.URL + c.private.cfg.AuthorizationCode = cfg.AuthorizationCode + if cfg.URL == "" { + c.private.cfg.ProvisioningStatus = cloud.ProvisioningStatus_UNINITIALIZED + } else { + c.private.cfg.ProvisioningStatus = cloud.ProvisioningStatus_READY_TO_REGISTER + } +} + +func (c *Manager) setProvisioningStatus(status cloud.ProvisioningStatus) { + c.private.mutex.Lock() + defer c.private.mutex.Unlock() + c.private.cfg.ProvisioningStatus = status +} + +func (c *Manager) getCloudConfiguration() Configuration { + c.private.mutex.Lock() + defer c.private.mutex.Unlock() + return c.private.cfg +} + +func validUntil(expiresIn int64) time.Time { + if expiresIn == -1 { + return time.Time{} + } + return time.Now().Add(time.Duration(expiresIn) * time.Second) +} + +func (c *Manager) setCreds(creds ocfCloud.CoapSignUpResponse) { + c.creds = creds + c.signedIn = false +} + +func (c *Manager) getCreds() ocfCloud.CoapSignUpResponse { + return c.creds +} + +func (c *Manager) isCredsExpiring() bool { + if !c.signedIn || c.creds.ValidUntil.IsZero() { + return false + } + return !time.Now().Before(c.creds.ValidUntil.Add(-time.Second * 10)) +} + +func (c *Manager) serveCOAP(w mux.ResponseWriter, request *mux.Message) { + request.Message.AddQuery("di=" + c.deviceID.String()) + r := net.Request{ + Message: request.Message, + Endpoints: nil, + Conn: w.Conn(), + } + var resp *pool.Message + p, err := r.Path() + if err == nil { + switch p { + case device.ResourceURI: + links := c.getLinks(schema.Endpoints{}, c.deviceID, nil, resources.PublishToCloud) + for _, link := range links { + if link.HasType(device.ResourceType) { + _ = r.SetPath(link.Href) + break + } + } + resp, err = c.handler(&r) + case plgdResources.ResourceURI: + links := c.getLinks(schema.Endpoints{}, c.deviceID, nil, resources.PublishToCloud) + links = patchDeviceLink(links) + links = discovery.PatchLinks(links, c.deviceID.String()) + resp, err = resources.CreateResponseContent(request.Context(), links, codes.Content) + default: + resp, err = c.handler(&r) + } + } else { + resp, err = c.handler(&r) + } + if err != nil { + resp = net.CreateResponseError(request.Context(), err, request.Token()) + } + if resp != nil { + resp.SetToken(w.Message().Token()) + resp.SetMessageID(w.Message().MessageID()) + resp.SetType(w.Message().Type()) + w.SetMessage(resp) + } +} + +func (c *Manager) close() error { + c.signedIn = false + if c.client == nil { + return nil + } + client := c.client + c.client = nil + return client.Close() +} + +func (c *Manager) dial(ctx context.Context) error { + if c.client != nil && c.client.Context().Err() == nil { + return nil + } + _ = c.close() + cfg := c.getCloudConfiguration() + + caPool, err := c.caPool.GetPool() + if err != nil { + return fmt.Errorf("cannot get ca pool: %w", err) + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + Certificates: c.getCertificates(c.deviceID.String()), + VerifyPeerCertificate: coap.NewVerifyPeerCertificate(caPool, func(cert *x509.Certificate) error { + cloudID, errP := uuid.Parse(c.getCloudConfiguration().CloudID) + if errP != nil { + return fmt.Errorf("cannot parse cloudID: %w", errP) + } + return coap.VerifyCloudCertificate(cert, cloudID) + }), + } + + ep := schema.Endpoint{ + URI: cfg.URL, + } + addr, err := ep.GetAddr() + if err != nil { + return fmt.Errorf("cannot get address from %v: %w", ep, err) + } + m := mux.NewRouter() + m.Use(net.CreateLoggingMiddleware(c.logger)) + m.DefaultHandle(mux.HandlerFunc(c.serveCOAP)) + conn, err := tcp.Dial(addr.String(), + options.WithTLS(tlsConfig), + options.WithMux(m), + options.WithContext(ctx), + options.WithMaxMessageSize(c.maxMessageSize), + options.WithBlockwise(false, blockwise.SZX1024, time.Second*4), + options.WithErrors(func(err error) { + c.logger.Errorf("cloud connection error: %w", err) + }), + options.WithKeepAlive(2, time.Second*10, func(conn *client.Conn) { + c.logger.Infof("cloud connection: keepalive timeout") + if errC := conn.Close(); errC != nil { + c.logger.Warnf("cannot close cloud connection: %w", errC) + } + })) + if err != nil { + return fmt.Errorf("cannot dial to %v: %w", addr.String(), err) + } + c.client = conn + return nil +} + +func patchDeviceLink(links schema.ResourceLinks) schema.ResourceLinks { + for idx, link := range links { + if !link.HasType(device.ResourceType) { + continue + } + newLink := link + newLink.Href = device.ResourceURI + links[idx] = newLink + break + } + return links +} + +func (c *Manager) run() { + ctx := context.Background() + t := time.NewTicker(time.Second * 10) + for { + select { + case <-c.done: + return + case wantToReset := <-c.trigger: + if wantToReset { + c.resetCredentials(ctx, true) + } + case <-t.C: + } + if c.getCloudConfiguration().URL != "" { + if err := c.connect(ctx); err != nil { + c.logger.Errorf("cannot connect to cloud: %w", err) + } else { + c.setProvisioningStatus(cloud.ProvisioningStatus_REGISTERED) + } + } + } +} + +func (c *Manager) connect(ctx context.Context) error { + var funcs []func(ctx context.Context) error + if c.isCredsExpiring() { + funcs = append(funcs, c.refreshToken) + } + funcs = append(funcs, []func(ctx context.Context) error{ + c.signUp, + c.signIn, + c.publishResources, + }...) + err := c.dial(ctx) + if err != nil { + return err + } + for _, f := range funcs { + r := func(ctx context.Context) error { + fctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + return f(fctx) + } + err := r(ctx) + if err != nil { + _ = c.close() + return err + } + } + return nil +} + +func (c *Manager) Close() { + c.done <- struct{}{} +} + +func (c *Manager) Unregister() { + c.setCloudConfiguration(cloud.ConfigurationUpdateRequest{}) + c.triggerRunner(true) +} diff --git a/bridge/device/cloud/manager_test.go b/bridge/device/cloud/manager_test.go new file mode 100644 index 00000000..93218363 --- /dev/null +++ b/bridge/device/cloud/manager_test.go @@ -0,0 +1,114 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + bridgeTest "github.com/plgd-dev/device/v2/bridge/test" + cloudSchema "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" + mockCoapGW "github.com/plgd-dev/device/v2/test/coap-gateway" + mockCoapGWService "github.com/plgd-dev/device/v2/test/coap-gateway/service" + "github.com/stretchr/testify/require" +) + +// device is restarted with an imported configuration with valid cloud credentials +func TestProvisioningOnDeviceRestart(t *testing.T) { + ch := mockCoapGW.NewCoapHandlerWithCounter(-1) + makeHandler := func(s *mockCoapGWService.Service, opts ...mockCoapGWService.Option) mockCoapGWService.ServiceHandler { + return ch + } + coapShutdown := mockCoapGW.New(t, makeHandler, func(handler mockCoapGWService.ServiceHandler) { + h := handler.(*mockCoapGW.DefaultHandlerWithCounter) + fmt.Printf("%+v\n", h.CallCounter.Data) + // d1 -> signup + signin + publish + // d2 -> should use the stored credentials to skip signup and only do sign in + publish + require.Equal(t, 1, h.CallCounter.Data[mockCoapGW.SignUpKey]) + require.Equal(t, 2, h.CallCounter.Data[mockCoapGW.SignInKey]) + require.Equal(t, 2, h.CallCounter.Data[mockCoapGW.PublishKey]) + require.Equal(t, 0, h.CallCounter.Data[mockCoapGW.RefreshTokenKey]) + }) + defer coapShutdown() + + s1 := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s1.Shutdown() + }) + deviceID := uuid.New().String() + d1 := bridgeTest.NewBridgedDevice(t, s1, deviceID, true, false) + s1Shutdown := bridgeTest.RunBridgeService(s1) + t.Cleanup(func() { + _ = s1Shutdown() + }) + + c, err := testClient.NewTestSecureClientWithBridgeSupport() + require.NoError(t, err) + defer func() { + errC := c.Close(context.Background()) + require.NoError(t, errC) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + err = c.OnboardDevice(ctx, deviceID, "authorizationProvider", "coaps+tcp://"+mockCoapGW.COAP_GW_HOST, "authorizationCode", test.CloudSID()) + require.NoError(t, err) + + // wait for sign in + require.Equal(t, 1, ch.WaitForSignIn(time.Second*20)) + + // stop service + err = s1Shutdown() + require.NoError(t, err) + + // save the device configuration + cfg := d1.ExportConfig() + + // recreate device using the saved configuration from a signed in device + s2 := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s2.Shutdown() + }) + d2 := bridgeTest.NewBridgedDeviceWithConfig(t, s2, cfg) + s2Shutdown := bridgeTest.RunBridgeService(s2) + defer func() { + errS := s2Shutdown() + require.NoError(t, errS) + }() + require.Equal(t, 2, ch.WaitForSignIn(time.Second*20)) + + // check provisioning status + var cloudCfg cloud.Configuration + err = c.GetResource(ctx, deviceID, cloudSchema.ResourceURI, &cloudCfg) + require.NoError(t, err) + require.Equal(t, cloudCfg.ProvisioningStatus, cloudSchema.ProvisioningStatus_REGISTERED) + + // sign off + cloudManager := d2.GetCloudManager() + if cloudManager != nil { + cloudManager.Unregister() + } + require.Equal(t, 1, ch.WaitForSignOff(time.Second*20)) +} diff --git a/bridge/device/cloud/options.go b/bridge/device/cloud/options.go new file mode 100644 index 00000000..89ea26f2 --- /dev/null +++ b/bridge/device/cloud/options.go @@ -0,0 +1,58 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "github.com/plgd-dev/device/v2/pkg/log" +) + +type OptionsCfg struct { + maxMessageSize uint32 + getCertificates GetCertificates + removeCloudCAs RemoveCloudCAs + logger log.Logger +} + +type Option func(*OptionsCfg) + +func WithMaxMessageSize(maxMessageSize uint32) Option { + return func(o *OptionsCfg) { + if maxMessageSize > 0 { + o.maxMessageSize = maxMessageSize + } + } +} + +func WithGetCertificates(getCertificates GetCertificates) Option { + return func(o *OptionsCfg) { + o.getCertificates = getCertificates + } +} + +func WithRemoveCloudCAs(removeCloudCA RemoveCloudCAs) Option { + return func(o *OptionsCfg) { + o.removeCloudCAs = removeCloudCA + } +} + +func WithLogger(logger log.Logger) Option { + return func(o *OptionsCfg) { + o.logger = logger + } +} diff --git a/bridge/device/cloud/publishResources.go b/bridge/device/cloud/publishResources.go new file mode 100644 index 00000000..f32f61d4 --- /dev/null +++ b/bridge/device/cloud/publishResources.go @@ -0,0 +1,63 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "context" + "fmt" + + "github.com/plgd-dev/device/v2/bridge/resources" + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +var ErrCannotPublishResources = fmt.Errorf("cannot publish resources") + +func errCannotPublishResources(err error) error { + return fmt.Errorf("%w: %w", ErrCannotPublishResources, err) +} + +func (c *Manager) publishResources(ctx context.Context) error { + if c.resourcesPublished { + return nil + } + + links := c.getLinks(schema.Endpoints{}, c.deviceID, nil, resources.PublishToCloud) + links = patchDeviceLink(links) + wkRd := ocfCloud.PublishResourcesRequest{ + DeviceID: c.deviceID.String(), + Links: links, + TimeToLive: 0, + } + req, err := newPostRequest(ctx, c.client, ocfCloud.ResourceDirectory, wkRd) + if err != nil { + return errCannotPublishResources(err) + } + resp, err := c.client.Do(req) + if err != nil { + return errCannotPublishResources(err) + } + if resp.Code() != codes.Changed { + return errCannotPublishResources(fmt.Errorf("unexpected status code %v", resp.Code())) + } + c.resourcesPublished = true + c.logger.Infof("resources published") + return nil +} diff --git a/bridge/device/cloud/refreshToken.go b/bridge/device/cloud/refreshToken.go new file mode 100644 index 00000000..23c9ec44 --- /dev/null +++ b/bridge/device/cloud/refreshToken.go @@ -0,0 +1,76 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "context" + "fmt" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +var ErrCannotRefreshToken = fmt.Errorf("cannot refresh token") + +func errCannotRefreshToken(err error) error { + return fmt.Errorf("%w: %w", ErrCannotRefreshToken, err) +} + +func (c *Manager) refreshToken(ctx context.Context) error { + creds := c.getCreds() + if creds.RefreshToken == "" { + return nil + } + req, err := newPostRequest(ctx, c.client, ocfCloud.RefreshToken, ocfCloud.CoapRefreshTokenRequest{ + DeviceID: c.deviceID.String(), + UserID: creds.UserID, + RefreshToken: creds.RefreshToken, + }) + if err != nil { + return errCannotRefreshToken(err) + } + c.setProvisioningStatus(cloud.ProvisioningStatus_REGISTERING) + resp, err := c.client.Do(req) + if err != nil { + return errCannotRefreshToken(err) + } + if resp.Code() != codes.Changed { + if resp.Code() == codes.Unauthorized { + c.cleanup() + } + return errCannotRefreshToken(fmt.Errorf("unexpected status code %v", resp.Code())) + } + var refreshResp ocfCloud.CoapRefreshTokenResponse + if err = cbor.ReadFrom(resp.Body(), &refreshResp); err != nil { + return errCannotRefreshToken(err) + } + c.updateCredsByRefreshTokenResponse(refreshResp) + c.logger.Infof("refreshed token") + c.save() + return nil +} + +func (c *Manager) updateCredsByRefreshTokenResponse(resp ocfCloud.CoapRefreshTokenResponse) { + c.creds.AccessToken = resp.AccessToken + c.creds.RefreshToken = resp.RefreshToken + c.creds.ValidUntil = validUntil(resp.ExpiresIn) + c.signedIn = false +} diff --git a/bridge/device/cloud/request.go b/bridge/device/cloud/request.go new file mode 100644 index 00000000..9b0f7df1 --- /dev/null +++ b/bridge/device/cloud/request.go @@ -0,0 +1,58 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "bytes" + "context" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +func newRequestWithToken(ctx context.Context, c *client.Conn, uri string) (*pool.Message, error) { + req := c.AcquireMessage(ctx) + token, err := message.GetToken() + if err != nil { + return nil, err + } + req.SetToken(token) + if err = req.SetPath(uri); err != nil { + return nil, err + } + return req, nil +} + +func newPostRequest(ctx context.Context, c *client.Conn, uri string, data interface{}) (*pool.Message, error) { + req, err := newRequestWithToken(ctx, c, uri) + if err != nil { + return nil, err + } + req.SetCode(codes.POST) + inputCbor, err := cbor.Encode(data) + if err != nil { + return nil, err + } + req.SetContentFormat(message.AppOcfCbor) + req.SetBody(bytes.NewReader(inputCbor)) + return req, nil +} diff --git a/bridge/device/cloud/security.go b/bridge/device/cloud/security.go new file mode 100644 index 00000000..a30a1aff --- /dev/null +++ b/bridge/device/cloud/security.go @@ -0,0 +1,59 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "crypto/x509" + "fmt" +) + +type GetCAPool func() []*x509.Certificate + +type CAPool struct { + getCAPool GetCAPool + useSystemCAPool bool +} + +func MakeCAPool(getCAPool GetCAPool, useSystemCAPool bool) CAPool { + return CAPool{ + getCAPool: getCAPool, + useSystemCAPool: useSystemCAPool, + } +} + +func (c CAPool) IsValid() bool { + return c.useSystemCAPool || (c.getCAPool != nil && c.getCAPool() != nil) +} + +func (c CAPool) GetPool() (*x509.CertPool, error) { + var pool *x509.CertPool + if c.useSystemCAPool { + systemPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("cannot get system pool: %w", err) + } + pool = systemPool + } else { + pool = x509.NewCertPool() + } + for _, ca := range c.getCAPool() { + pool.AddCert(ca) + } + return pool, nil +} diff --git a/bridge/device/cloud/security_test.go b/bridge/device/cloud/security_test.go new file mode 100644 index 00000000..027249a1 --- /dev/null +++ b/bridge/device/cloud/security_test.go @@ -0,0 +1,42 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud_test + +import ( + "crypto/x509" + "testing" + + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/stretchr/testify/require" +) + +func TestGetPool(t *testing.T) { + getCAPool := func() []*x509.Certificate { + return []*x509.Certificate{{}} + } + caPool1 := cloud.MakeCAPool(getCAPool, true) + require.True(t, caPool1.IsValid()) + _, err := caPool1.GetPool() + require.NoError(t, err) + + caPool2 := cloud.MakeCAPool(getCAPool, false) + require.True(t, caPool2.IsValid()) + _, err = caPool2.GetPool() + require.NoError(t, err) +} diff --git a/bridge/device/cloud/signIn.go b/bridge/device/cloud/signIn.go new file mode 100644 index 00000000..1719619b --- /dev/null +++ b/bridge/device/cloud/signIn.go @@ -0,0 +1,94 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "context" + "fmt" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +var ( + ErrMissingAccessToken = fmt.Errorf("access token missing") + ErrCannotSignIn = fmt.Errorf("cannot sign in") +) + +func MakeSignInRequest(deviceID, userID, accessToken string) (ocfCloud.CoapSignInRequest, error) { + if accessToken == "" { + return ocfCloud.CoapSignInRequest{}, ErrMissingAccessToken + } + return ocfCloud.CoapSignInRequest{ + DeviceID: deviceID, + UserID: userID, + AccessToken: accessToken, + Login: true, + }, nil +} + +func errCannotSignIn(err error) error { + return fmt.Errorf("%w: %w", ErrCannotSignIn, err) +} + +func (c *Manager) signIn(ctx context.Context) error { + if c.client == nil { + return errCannotSignIn(fmt.Errorf("no connection")) + } + if c.signedIn { + return nil + } + creds := c.getCreds() + signInReq, err := MakeSignInRequest(c.deviceID.String(), creds.UserID, creds.AccessToken) + if err != nil { + return errCannotSignIn(err) + } + req, err := newPostRequest(ctx, c.client, ocfCloud.SignIn, signInReq) + if err != nil { + return errCannotSignIn(err) + } + c.setProvisioningStatus(cloud.ProvisioningStatus_REGISTERING) + resp, err := c.client.Do(req) + if err != nil { + return err + } + if resp.Code() != codes.Changed { + if resp.Code() == codes.Unauthorized { + c.cleanup() + } + return errCannotSignIn(fmt.Errorf("unexpected status code %v", resp.Code())) + } + var signInResp ocfCloud.CoapSignInResponse + err = cbor.ReadFrom(resp.Body(), &signInResp) + if err != nil { + return err + } + c.updateCredsBySignInResponse(signInResp) + c.logger.Infof("signed in") + c.save() + return nil +} + +func (c *Manager) updateCredsBySignInResponse(resp ocfCloud.CoapSignInResponse) { + c.creds.ExpiresIn = resp.ExpiresIn + c.creds.ValidUntil = validUntil(resp.ExpiresIn) + c.signedIn = true +} diff --git a/bridge/device/cloud/signIn_test.go b/bridge/device/cloud/signIn_test.go new file mode 100644 index 00000000..558ecce8 --- /dev/null +++ b/bridge/device/cloud/signIn_test.go @@ -0,0 +1,38 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud_test + +import ( + "testing" + + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/stretchr/testify/require" +) + +func TestMakeSignInRequest(t *testing.T) { + _, err := cloud.MakeSignInRequest("id", "uid", "") + require.ErrorIs(t, err, cloud.ErrMissingAccessToken) + + req, err := cloud.MakeSignInRequest("id", "uid", "token") + require.NoError(t, err) + require.Equal(t, "id", req.DeviceID) + require.Equal(t, "uid", req.UserID) + require.Equal(t, "token", req.AccessToken) + require.True(t, req.Login) +} diff --git a/bridge/device/cloud/signOff.go b/bridge/device/cloud/signOff.go new file mode 100644 index 00000000..6a581c77 --- /dev/null +++ b/bridge/device/cloud/signOff.go @@ -0,0 +1,75 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "context" + "fmt" + + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +const ProvisioningStatusDEREGISTERING cloud.ProvisioningStatus = "deregistering" + +var ErrCannotSignOff = fmt.Errorf("cannot sign off") + +func newSignOffReq(ctx context.Context, c *client.Conn, deviceID, userID string) (*pool.Message, error) { + req, err := newRequestWithToken(ctx, c, ocfCloud.SignUp) + if err != nil { + return nil, err + } + req.SetCode(codes.DELETE) + req.AddQuery("di=" + deviceID) + req.AddQuery("uid=" + userID) + return req, nil +} + +func errCannotSignOff(err error) error { + return fmt.Errorf("%w: %w", ErrCannotSignOff, err) +} + +func (c *Manager) signOff(ctx context.Context) error { + if c.client == nil { + return nil + } + // signIn / refresh token fails + if ctx.Err() != nil { + return errCannotSignOff(ctx.Err()) + } + + req, err := newSignOffReq(ctx, c.client, c.deviceID.String(), c.getCreds().UserID) + if err != nil { + return errCannotSignOff(err) + } + c.setProvisioningStatus(ProvisioningStatusDEREGISTERING) + resp, err := c.client.Do(req) + defer c.setProvisioningStatus(cloud.ProvisioningStatus_UNINITIALIZED) + if err != nil { + return errCannotSignOff(err) + } + if resp.Code() != codes.Deleted { + return errCannotSignOff(fmt.Errorf("unexpected status code %v", resp.Code())) + } + c.logger.Infof("signed off") + return nil +} diff --git a/bridge/device/cloud/signUp.go b/bridge/device/cloud/signUp.go new file mode 100644 index 00000000..88b22aa4 --- /dev/null +++ b/bridge/device/cloud/signUp.go @@ -0,0 +1,90 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "context" + "fmt" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +var ( + ErrMissingAuthorizationCode = fmt.Errorf("authorization code missing") + ErrMissingAuthorizationProvider = fmt.Errorf("authorization provider missing") + ErrCannotSignUp = fmt.Errorf("cannot sign up") +) + +func MakeSignUpRequest(deviceID, code, provider string) (ocfCloud.CoapSignUpRequest, error) { + if code == "" { + return ocfCloud.CoapSignUpRequest{}, ErrMissingAuthorizationCode + } + if provider == "" { + return ocfCloud.CoapSignUpRequest{}, ErrMissingAuthorizationProvider + } + + return ocfCloud.CoapSignUpRequest{ + DeviceID: deviceID, + AuthorizationCode: code, + AuthorizationProvider: provider, + }, nil +} + +func errCannotSignUp(err error) error { + return fmt.Errorf("%w: %w", ErrCannotSignUp, err) +} + +func (c *Manager) signUp(ctx context.Context) error { + creds := c.getCreds() + if creds.AccessToken != "" { + return nil + } + cfg := c.getCloudConfiguration() + signUpRequest, err := MakeSignUpRequest(c.deviceID.String(), cfg.AuthorizationCode, cfg.AuthorizationProvider) + if err != nil { + return errCannotSignUp(err) + } + req, err := newPostRequest(ctx, c.client, ocfCloud.SignUp, signUpRequest) + if err != nil { + return errCannotSignUp(err) + } + c.setProvisioningStatus(cloud.ProvisioningStatus_REGISTERING) + resp, err := c.client.Do(req) + if err != nil { + return errCannotSignUp(err) + } + if resp.Code() != codes.Changed { + return errCannotSignUp(fmt.Errorf("unexpected status code %v", resp.Code())) + } + var signUpResp ocfCloud.CoapSignUpResponse + err = cbor.ReadFrom(resp.Body(), &signUpResp) + if err != nil { + return errCannotSignUp(err) + } + if signUpResp.ExpiresIn != -1 { + c.creds.ValidUntil = validUntil(signUpResp.ExpiresIn) + } + c.setCreds(signUpResp) + c.logger.Infof("signed up") + c.save() + return nil +} diff --git a/bridge/device/cloud/signUp_test.go b/bridge/device/cloud/signUp_test.go new file mode 100644 index 00000000..123d0b64 --- /dev/null +++ b/bridge/device/cloud/signUp_test.go @@ -0,0 +1,40 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud_test + +import ( + "testing" + + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/stretchr/testify/require" +) + +func TestMakeSignUpRequest(t *testing.T) { + _, err := cloud.MakeSignUpRequest("id", "", "provider") + require.ErrorIs(t, err, cloud.ErrMissingAuthorizationCode) + + _, err = cloud.MakeSignUpRequest("id", "code", "") + require.ErrorIs(t, err, cloud.ErrMissingAuthorizationProvider) + + req, err := cloud.MakeSignUpRequest("id", "code", "provider") + require.NoError(t, err) + require.Equal(t, "id", req.DeviceID) + require.Equal(t, "code", req.AuthorizationCode) + require.Equal(t, "provider", req.AuthorizationProvider) +} diff --git a/bridge/device/config.go b/bridge/device/config.go new file mode 100644 index 00000000..c8fad9ce --- /dev/null +++ b/bridge/device/config.go @@ -0,0 +1,62 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package device + +import ( + "fmt" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/device/credential" +) + +type CloudConfig struct { + Enabled bool + cloud.Config +} + +type CredentialConfig struct { + Enabled bool + credential.Config +} + +type Config struct { + ID uuid.UUID + Name string + ProtocolIndependentID uuid.UUID + ResourceTypes []string + MaxMessageSize uint32 + Cloud CloudConfig + Credential CredentialConfig +} + +func (cfg *Config) Validate() error { + if cfg.ProtocolIndependentID == uuid.Nil { + return fmt.Errorf("protocolIndependentID is required") + } + if cfg.ID == uuid.Nil { + cfg.ID = uuid.New() + } + + if cfg.Name == "" { + cfg.Name = "Unnamed" + } + + return nil +} diff --git a/bridge/device/config_test.go b/bridge/device/config_test.go new file mode 100644 index 00000000..35510bbc --- /dev/null +++ b/bridge/device/config_test.go @@ -0,0 +1,68 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package device_test + +import ( + "testing" + + "github.com/google/uuid" + bridgeDevice "github.com/plgd-dev/device/v2/bridge/device" + "github.com/stretchr/testify/require" +) + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg bridgeDevice.Config + wantError bool + }{ + { + name: "Valid Configuration", + cfg: bridgeDevice.Config{ + ProtocolIndependentID: uuid.New(), + ID: uuid.New(), + Name: "ValidName", + }, + }, + { + name: "Valid Configuration with empty ID and Name", + cfg: bridgeDevice.Config{ + ProtocolIndependentID: uuid.New(), + }, + }, + { + name: "Invalid ProtocolIndependentID", + cfg: bridgeDevice.Config{ + ID: uuid.New(), + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.Validate() + if tt.wantError { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/bridge/device/credential/config.go b/bridge/device/credential/config.go new file mode 100644 index 00000000..486b5920 --- /dev/null +++ b/bridge/device/credential/config.go @@ -0,0 +1,23 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package credential + +import "github.com/plgd-dev/device/v2/schema/credential" + +type Config = credential.CredentialResponse diff --git a/bridge/device/credential/manager.go b/bridge/device/credential/manager.go new file mode 100644 index 00000000..c4439fed --- /dev/null +++ b/bridge/device/credential/manager.go @@ -0,0 +1,178 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package credential + +import ( + "crypto/x509" + + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" + "github.com/plgd-dev/device/v2/schema/credential" + "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/pkg/sync" +) + +type Manager struct { + credentials *sync.Map[int, credential.Credential] + save func() +} + +func New(cfg Config, save func()) *Manager { + if save == nil { + save = func() { + // do nothing + } + } + m := Manager{ + save: save, + credentials: sync.NewMap[int, credential.Credential](), + } + m.importConfig(cfg) + return &m +} + +func (m *Manager) importConfig(cfg Config) { + m.AddOrReplaceCredentials(cfg.Credentials...) +} + +func (m *Manager) GetCAPool() []*x509.Certificate { + certs := make([]*x509.Certificate, 0, m.credentials.Length()) + m.credentials.Range(func(key int, value credential.Credential) bool { + if value.Type != credential.CredentialType_ASYMMETRIC_SIGNING_WITH_CERTIFICATE { + return true + } + if value.Usage != credential.CredentialUsage_TRUST_CA && value.Usage != credential.CredentialUsage_MFG_TRUST_CA { + return true + } + cas, err := pkgX509.ParsePemCertificates(value.PublicData.Data()) + if err != nil { + return true + } + certs = append(certs, cas...) + return true + }) + return certs +} + +func (m *Manager) getNextID() int { + var id int + m.credentials.Range(func(key int, value credential.Credential) bool { + if key > id { + id = key + } + return true + }) + return id + 1 +} + +func (m *Manager) add(c credential.Credential) { + for { + _, loaded := m.credentials.LoadOrStore(c.ID, c) + if !loaded { + return + } + c.ID = m.getNextID() + } +} + +func (m *Manager) AddOrReplaceCredential(c credential.Credential) { + if c.Type == credential.CredentialType_EMPTY { + return + } + if c.ID != 0 { + // replace + m.credentials.Store(c.ID, c) + return + } + // add + m.add(c) +} + +func (m *Manager) AddOrReplaceCredentials(ca ...credential.Credential) { + for _, c := range ca { + m.AddOrReplaceCredential(c) + } +} + +func (m *Manager) RemoveCredentials(ids ...int) { + for _, id := range ids { + m.credentials.Delete(id) + } +} + +func (m *Manager) RemoveCredentialsBySubjects(subjects ...string) { + m.credentials.Range(func(key int, value credential.Credential) bool { + for _, subject := range subjects { + if value.Subject == subject { + m.credentials.Delete(key) + } + } + return true + }) +} + +func (m *Manager) ClearCredentials() { + _ = m.credentials.LoadAndDeleteAll() +} + +func (m *Manager) getRep(privateData bool) credential.CredentialResponse { + cas := m.credentials.CopyData() + creds := credential.CredentialResponse{ + Interfaces: []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}, + ResourceTypes: []string{credential.ResourceType}, + Credentials: make([]credential.Credential, 0, len(cas)), + } + for _, cred := range cas { + if !privateData { + // remove private data + cred.PrivateData = nil + } + creds.Credentials = append(creds.Credentials, cred) + } + return creds +} + +func (m *Manager) Get(request *net.Request) (*pool.Message, error) { + creds := m.getRep(false) + return resources.CreateResponseContent(request.Context(), creds, codes.Content) +} + +func (m *Manager) Post(request *net.Request) (*pool.Message, error) { + var cfg credential.CredentialUpdateRequest + err := cbor.ReadFrom(request.Body(), &cfg) + if err != nil { + return resources.CreateResponseBadRequest(request.Context(), err) + } + m.AddOrReplaceCredentials(cfg.Credentials...) + m.save() + creds := m.getRep(false) + return resources.CreateResponseContent(request.Context(), creds, codes.Changed) +} + +func (m *Manager) ExportConfig() Config { + return m.getRep(true) +} + +func (m *Manager) Close() { + m.ClearCredentials() +} diff --git a/bridge/device/credential/manager_internal_test.go b/bridge/device/credential/manager_internal_test.go new file mode 100644 index 00000000..c70dd444 --- /dev/null +++ b/bridge/device/credential/manager_internal_test.go @@ -0,0 +1,313 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package credential + +import ( + "bytes" + "context" + "testing" + + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/schema/credential" + "github.com/plgd-dev/device/v2/test" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/stretchr/testify/require" +) + +func TestGetCAPool(t *testing.T) { + m := New(Config{}, func() {}) + + // Add some credentials to the manager + cred1 := credential.Credential{ + Type: credential.CredentialType_ASYMMETRIC_SIGNING_WITH_CERTIFICATE, + Usage: credential.CredentialUsage_TRUST_CA, + PublicData: &credential.CredentialPublicData{ + DataInternal: test.GetRootCApem(t), + }, + } + cred2 := credential.Credential{ + Type: credential.CredentialType_ASYMMETRIC_SIGNING_WITH_CERTIFICATE, + Usage: credential.CredentialUsage_MFG_TRUST_CA, + PublicData: &credential.CredentialPublicData{ + DataInternal: test.GetRootCApem(t), + }, + } + cred3 := credential.Credential{ + Type: credential.CredentialType_ASYMMETRIC_SIGNING_WITH_CERTIFICATE, + Usage: credential.CredentialUsage_CERT, + PublicData: &credential.CredentialPublicData{ + DataInternal: test.GetRootCApem(t), + }, + } + m.credentials.Store(1, cred1) + m.credentials.Store(2, cred2) + m.credentials.Store(3, cred3) + + // Call the GetCAPool function + certs := m.GetCAPool() + + // Verify the result + require.Len(t, certs, 2) +} + +func TestAddOrReplaceCredential(t *testing.T) { + m := New(Config{}, func() {}) + + // Test adding a new credential + cred1 := credential.Credential{ + ID: 1, + Subject: "subject", + Type: credential.CredentialType_PIN_OR_PASSWORD, + } + m.AddOrReplaceCredential(cred1) + _, ok := m.credentials.Load(cred1.ID) + require.True(t, ok) + + // Test replacing an existing credential + cred2 := credential.Credential{ + ID: 1, + Subject: "subject1", + Type: credential.CredentialType_ASYMMETRIC_SIGNING, + } + m.AddOrReplaceCredential(cred2) + c, ok := m.credentials.Load(cred2.ID) + require.True(t, ok) + require.Equal(t, cred2, c) + + // Test adding an empty credential + cred3 := credential.Credential{ + ID: 0, + Type: credential.CredentialType_EMPTY, + } + m.AddOrReplaceCredential(cred3) + _, ok = m.credentials.Load(cred3.ID) + require.False(t, ok) +} + +func TestAddOrReplaceCredentials(t *testing.T) { + m := New(Config{}, func() {}) // Create an instance of the Manager struct + + // Create some test credentials + cred1 := credential.Credential{ + Subject: "subject1", + Type: credential.CredentialType_PIN_OR_PASSWORD, + } + cred2 := credential.Credential{ + Subject: "subject2", + Type: credential.CredentialType_PIN_OR_PASSWORD, + } + cred3 := credential.Credential{ + Subject: "subject3", + Type: credential.CredentialType_EMPTY, + } + + // Add the credentials to the manager + m.AddOrReplaceCredentials(cred1, cred2, cred3) + + // Verify that the credentials were added correctly + require.Equal(t, 2, m.credentials.Length()) +} + +func TestRemoveCredentials(t *testing.T) { + m := New(Config{}, func() {}) + + // Add some credentials to the Manager + m.credentials.Store(1, credential.Credential{ + Subject: "subject1", + Type: credential.CredentialType_PIN_OR_PASSWORD, + }) + m.credentials.Store(2, credential.Credential{ + Subject: "subject2", + Type: credential.CredentialType_PIN_OR_PASSWORD, + }) + m.credentials.Store(3, credential.Credential{ + Subject: "subject3", + Type: credential.CredentialType_PIN_OR_PASSWORD, + }) + + // Remove credentials with IDs 1 and 3 + m.RemoveCredentials(1, 3) + + // Check if the credentials were removed + _, ok := m.credentials.Load(1) + require.False(t, ok, "Credential with ID 1 should be removed") + _, ok = m.credentials.Load(3) + require.False(t, ok, "Credential with ID 3 should be removed") + + // Check if the credential with ID 2 still exists + _, ok = m.credentials.Load(2) + require.True(t, ok, "Credential with ID 2 should still exist") +} + +func TestRemoveCredentialsBySubjects(t *testing.T) { + // Create a new instance of the Manager + m := New(Config{}, func() {}) + + // Add some credentials to the Manager + cred1 := credential.Credential{Subject: "subject1", Type: credential.CredentialType_PIN_OR_PASSWORD} + cred2 := credential.Credential{Subject: "subject2", Type: credential.CredentialType_PIN_OR_PASSWORD} + cred3 := credential.Credential{Subject: "subject3", Type: credential.CredentialType_PIN_OR_PASSWORD} + m.credentials.Store(1, cred1) + m.credentials.Store(2, cred2) + m.credentials.Store(3, cred3) + + // Call the RemoveCredentialsBySubjects method with some subjects + m.RemoveCredentialsBySubjects("subject1", "subject3") + + // Check that the credentials with the specified subjects have been removed + _, ok1 := m.credentials.Load(1) + _, ok2 := m.credentials.Load(2) + _, ok3 := m.credentials.Load(3) + require.False(t, ok1, "Credential with subject 'subject1' should have been removed") + require.True(t, ok2, "Credential with subject 'subject2' should not have been removed") + require.False(t, ok3, "Credential with subject 'subject3' should have been removed") +} + +func TestClearCredentials(t *testing.T) { + m := New(Config{}, func() {}) // Create an instance of the Manager struct + + // Add some credentials to the Manager + cred1 := credential.Credential{Subject: "subject1", Type: credential.CredentialType_PIN_OR_PASSWORD} + cred2 := credential.Credential{Subject: "subject2", Type: credential.CredentialType_PIN_OR_PASSWORD} + cred3 := credential.Credential{Subject: "subject3", Type: credential.CredentialType_PIN_OR_PASSWORD} + m.credentials.Store(1, cred1) + m.credentials.Store(2, cred2) + m.credentials.Store(3, cred3) + + // Call the ClearCredentials method + m.ClearCredentials() + + require.Equal(t, 0, m.credentials.Length()) +} + +func TestGetCredential(t *testing.T) { + m := New(Config{}, func() {}) // Create an instance of the Manager struct + + // Add some credentials to the Manager + cred1 := credential.Credential{Subject: "subject1", Type: credential.CredentialType_PIN_OR_PASSWORD, PrivateData: &credential.CredentialPrivateData{DataInternal: []byte("private data")}} + cred2 := credential.Credential{Subject: "subject2", Type: credential.CredentialType_PIN_OR_PASSWORD} + cred3 := credential.Credential{Subject: "subject3", Type: credential.CredentialType_PIN_OR_PASSWORD} + m.credentials.Store(1, cred1) + m.credentials.Store(2, cred2) + m.credentials.Store(3, cred3) + + msg := pool.NewMessage(context.Background()) + msg.SetCode(codes.GET) + err := msg.SetPath(credential.ResourceURI) + require.NoError(t, err) + msg.SetToken([]byte{0x01}) + + // Create a mock request + request := &net.Request{ + Message: msg, + } + + // Call the Get method + response, err := m.Get(request) + + // Check if the response and error are as expected + require.NoError(t, err) + require.NotNil(t, response) + require.Equal(t, codes.Content, response.Code()) + require.NotNil(t, response.Body()) + + var cred credential.CredentialResponse + err = cbor.ReadFrom(response.Body(), &cred) + require.NoError(t, err) + require.Len(t, cred.Credentials, 3) + for _, c := range cred.Credentials { + require.Nil(t, c.PrivateData) + } +} + +func TestPostCredential(t *testing.T) { + m := New(Config{}, func() {}) // Create an instance of the Manager struct + + msg := pool.NewMessage(context.Background()) + msg.SetCode(codes.POST) + err := msg.SetPath(credential.ResourceURI) + require.NoError(t, err) + msg.SetToken([]byte{0x01}) + + updBody := credential.CredentialUpdateRequest{ + Credentials: []credential.Credential{ + { + Subject: "subject1", + Type: credential.CredentialType_ASYMMETRIC_SIGNING_WITH_CERTIFICATE, + Usage: credential.CredentialUsage_TRUST_CA, + PublicData: &credential.CredentialPublicData{ + DataInternal: test.GetRootCApem(t), + Encoding: credential.CredentialPublicDataEncoding_PEM, + }, + }, + }, + } + body, err := cbor.Encode(updBody) + require.NoError(t, err) + msg.SetBody(bytes.NewReader(body)) + + // Create a mock request + req := &net.Request{ + Message: msg, + } + + // Call the Post method + resp, err := m.Post(req) + require.NoError(t, err) + + // Verify the response + require.NotNil(t, resp) + require.Equal(t, codes.Changed, resp.Code()) + + // Verify the credentials were added or replaced + creds := m.getRep(false) + require.Len(t, creds.Credentials, 1) + require.Equal(t, "subject1", creds.Credentials[0].Subject) +} + +func TestExportConfig(t *testing.T) { + m := New(Config{}, func() {}) // Create an instance of the Manager struct + + // Add some credentials to the Manager + cred1 := credential.Credential{Subject: "subject1", Type: credential.CredentialType_PIN_OR_PASSWORD, PrivateData: &credential.CredentialPrivateData{DataInternal: []byte("private data")}} + m.credentials.Store(1, cred1) + + // Call the ExportConfig function + config := m.ExportConfig() + + // Assert that the returned config is not nil + require.Len(t, config.Credentials, 1) + require.Equal(t, "subject1", config.Credentials[0].Subject) + require.NotNil(t, config.Credentials[0].PrivateData) + require.NotNil(t, config.Credentials[0].PrivateData.DataInternal) +} + +func TestClose(t *testing.T) { + m := New(Config{}, func() {}) // Create an instance of the Manager struct + + // Add some credentials to the Manager + cred1 := credential.Credential{Subject: "subject1", Type: credential.CredentialType_PIN_OR_PASSWORD, PrivateData: &credential.CredentialPrivateData{DataInternal: []byte("private data")}} + m.credentials.Store(1, cred1) + + // Call the Close function + m.Close() + require.Equal(t, 0, m.credentials.Length()) +} diff --git a/bridge/device/credential/security.go b/bridge/device/credential/security.go new file mode 100644 index 00000000..bb87840a --- /dev/null +++ b/bridge/device/credential/security.go @@ -0,0 +1,70 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package credential + +import ( + "crypto/x509" +) + +type GetCAPool func() []*x509.Certificate + +type CAPoolGetter = interface { + IsValid() bool + GetPool() (*x509.CertPool, error) +} + +type CAPool struct { + origCAPool CAPoolGetter + getCAPool GetCAPool +} + +func MakeCAPool(caPool CAPoolGetter, getCAPool GetCAPool) CAPool { + return CAPool{ + getCAPool: getCAPool, + origCAPool: caPool, + } +} + +func (c CAPool) IsValid() bool { + if c.getCAPool != nil { + return true + } + if c.origCAPool == nil { + return false + } + return c.origCAPool.IsValid() +} + +func (c CAPool) GetPool() (*x509.CertPool, error) { + pool := x509.NewCertPool() + if c.origCAPool != nil && c.origCAPool.IsValid() { + p, err := c.origCAPool.GetPool() + if err != nil { + return nil, err + } + pool = p + } + if c.getCAPool == nil { + return pool, nil + } + for _, ca := range c.getCAPool() { + pool.AddCert(ca) + } + return pool, nil +} diff --git a/bridge/device/credential/security_test.go b/bridge/device/credential/security_test.go new file mode 100644 index 00000000..921b2de4 --- /dev/null +++ b/bridge/device/credential/security_test.go @@ -0,0 +1,54 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package credential_test + +import ( + "crypto/x509" + "testing" + + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/device/credential" + "github.com/stretchr/testify/require" +) + +func TestGetPool(t *testing.T) { + getCAPool := func() []*x509.Certificate { + return []*x509.Certificate{{}} + } + + credPool1 := credential.MakeCAPool(nil, getCAPool) + require.True(t, credPool1.IsValid()) + _, err := credPool1.GetPool() + require.NoError(t, err) + + cloudCAPool1 := cloud.MakeCAPool(getCAPool, true) + require.True(t, cloudCAPool1.IsValid()) + credPool2 := credential.MakeCAPool(cloudCAPool1, nil) + require.True(t, credPool2.IsValid()) + _, err = credPool2.GetPool() + require.NoError(t, err) + + credPool3 := credential.MakeCAPool(cloudCAPool1, getCAPool) + require.True(t, credPool3.IsValid()) + _, err = credPool3.GetPool() + require.NoError(t, err) + + credPool4 := credential.MakeCAPool(nil, nil) + require.False(t, credPool4.IsValid()) +} diff --git a/bridge/device/device.go b/bridge/device/device.go new file mode 100644 index 00000000..ae7addac --- /dev/null +++ b/bridge/device/device.go @@ -0,0 +1,275 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package device + +import ( + "bytes" + "context" + "fmt" + "log" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/device/credential" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + cloudResource "github.com/plgd-dev/device/v2/bridge/resources/cloud" + resourcesDevice "github.com/plgd-dev/device/v2/bridge/resources/device" + "github.com/plgd-dev/device/v2/bridge/resources/discovery" + "github.com/plgd-dev/device/v2/bridge/resources/maintenance" + credentialResource "github.com/plgd-dev/device/v2/bridge/resources/secure/credential" + pkgLog "github.com/plgd-dev/device/v2/pkg/log" + "github.com/plgd-dev/device/v2/schema" + cloudSchema "github.com/plgd-dev/device/v2/schema/cloud" + credentialSchema "github.com/plgd-dev/device/v2/schema/credential" + plgdDevice "github.com/plgd-dev/device/v2/schema/device" + maintenanceSchema "github.com/plgd-dev/device/v2/schema/maintenance" + plgdResources "github.com/plgd-dev/device/v2/schema/resources" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/pkg/sync" +) + +type Resource interface { + Close() + ETag() []byte + GetHref() string + GetResourceTypes() []string + GetResourceInterfaces() []string + HandleRequest(req *net.Request) (*pool.Message, error) + GetPolicyBitMask() schema.BitMask + SetObserveHandler(createSubscription resources.CreateSubscriptionFunc) + UpdateETag() +} + +type Device struct { + cfg Config + resources *sync.Map[string, Resource] + cloudManager *cloud.Manager + credentialManager *credential.Manager + onDeviceUpdated func(d *Device) +} + +func NewLogger(id uuid.UUID, level pkgLog.Level) pkgLog.Logger { + l := pkgLog.NewStdLogger(level) + if id != uuid.Nil { + l.SetPrefix(fmt.Sprintf("[%v] ", id.String())) + l.SetFlags(l.Flags() | log.Lmsgprefix) + } + return l +} + +func (d *Device) GetID() uuid.UUID { + return d.cfg.ID +} + +func (d *Device) GetName() string { + return d.cfg.Name +} + +func (d *Device) GetResourceTypes() []string { + return d.cfg.ResourceTypes +} + +func (d *Device) GetProtocolIndependentID() uuid.UUID { + return d.cfg.ProtocolIndependentID +} + +func (d *Device) ExportConfig() Config { + cfg := d.cfg + if d.cloudManager != nil { + cfg.Cloud.Config = d.cloudManager.ExportConfig() + } else { + cfg.Cloud.Enabled = false + } + if d.credentialManager != nil { + cfg.Credential.Config = d.credentialManager.ExportConfig() + } else { + cfg.Credential.Enabled = false + } + return cfg +} + +func New(cfg Config, opts ...Option) (*Device, error) { + o := OptionsCfg{ + onDeviceUpdated: func(d *Device) { + // do nothing + }, + getAdditionalProperties: func() map[string]interface{} { return nil }, + caPool: cloud.MakeCAPool(nil, false), + logger: NewLogger(cfg.ID, pkgLog.LevelInfo), + } + for _, opt := range opts { + opt(&o) + } + + cfg.ResourceTypes = resources.Unique(append(cfg.ResourceTypes, plgdDevice.ResourceType)) + d := &Device{ + cfg: cfg, + resources: sync.NewMap[string, Resource](), + onDeviceUpdated: o.onDeviceUpdated, + } + + cloudOpts := []cloud.Option{ + cloud.WithMaxMessageSize(cfg.MaxMessageSize), + cloud.WithLogger(o.logger), + } + if cfg.Credential.Enabled { + d.credentialManager = credential.New(cfg.Credential.Config, func() { + d.onDeviceUpdated(d) + }) + d.AddResource(credentialResource.New(credentialSchema.ResourceURI, d.credentialManager)) + o.caPool = credential.MakeCAPool(o.caPool, d.credentialManager.GetCAPool) + cloudOpts = append(cloudOpts, cloud.WithRemoveCloudCAs(d.credentialManager.RemoveCredentialsBySubjects)) + } + if cfg.Cloud.Enabled { + if o.getCertificates != nil { + cloudOpts = append(cloudOpts, cloud.WithGetCertificates(o.getCertificates)) + } + cm, err := cloud.New(cfg.Cloud.Config, d.cfg.ID, func() { + d.onDeviceUpdated(d) + }, d.HandleRequest, d.GetLinksFilteredBy, o.caPool, cloudOpts...) + if err != nil { + return nil, fmt.Errorf("cannot create cloud manager: %w", err) + } + d.cloudManager = cm + d.AddResource(cloudResource.New(cloudSchema.ResourceURI, d.cloudManager)) + } + + d.AddResource(resourcesDevice.New(plgdDevice.ResourceURI, d, o.getAdditionalProperties)) + // oic/res is not discoverable + discoverResource := discovery.New(plgdResources.ResourceURI, d.GetLinks) + discoverResource.PolicyBitMask = schema.Discoverable + d.AddResource(discoverResource) + + d.AddResource(maintenance.New(maintenanceSchema.ResourceURI, func() { + if d.cloudManager != nil { + d.cloudManager.Unregister() + } + })) + + return d, nil +} + +func (d *Device) AddResource(resource Resource) { + d.resources.Store(resource.GetHref(), resource) +} + +func (d *Device) Init() { + if d.cloudManager != nil { + d.cloudManager.Init() + } +} + +func (d *Device) GetCloudManager() *cloud.Manager { + return d.cloudManager +} + +func (d *Device) Range(f func(resourceHref string, resource Resource) bool) { + d.resources.Range(f) +} + +func (d *Device) GetResource(resourceHref string) (Resource, bool) { + return d.resources.Load(resourceHref) +} + +func hasResourceTypes(resourceTypes []string, oneOf []string) bool { + for _, rt := range oneOf { + for _, rrt := range resourceTypes { + if rt == rrt { + return true + } + } + } + return false +} + +func (d *Device) GetLinksFilteredBy(endpoints schema.Endpoints, deviceIDfilter uuid.UUID, resourceTypesFitler []string, policyBitMaskFitler schema.BitMask) (links schema.ResourceLinks) { + di := deviceIDfilter + if di != uuid.Nil && di != d.GetID() { + return nil + } + links = make(schema.ResourceLinks, 0, d.resources.Length()) + d.resources.Range(func(key string, resource Resource) bool { + if len(resourceTypesFitler) > 0 && !hasResourceTypes(resource.GetResourceTypes(), resourceTypesFitler) { + return true + } + if policyBitMaskFitler != 0 && resource.GetPolicyBitMask()&policyBitMaskFitler == 0 { + return true + } + links = append(links, schema.ResourceLink{ + Href: key, + ResourceTypes: resource.GetResourceTypes(), + Interfaces: resource.GetResourceInterfaces(), + Policy: &schema.Policy{ + BitMask: resource.GetPolicyBitMask() & (^resources.PublishToCloud), + }, + DeviceID: d.GetID().String(), + Endpoints: endpoints, + }) + return true + }) + return links +} + +func (d *Device) GetLinks(request *net.Request) (links schema.ResourceLinks) { + return d.GetLinksFilteredBy(request.Endpoints, request.DeviceID(), request.ResourceTypes(), 0) +} + +func (d *Device) LoadAndDeleteResource(resourceHref string) (Resource, bool) { + return d.resources.LoadAndDelete(resourceHref) +} + +func (d *Device) CloseAndDeleteResource(resourceHref string) bool { + r, ok := d.LoadAndDeleteResource(resourceHref) + if ok { + r.Close() + } + return ok +} + +func createResponseNotFound(ctx context.Context, uri string, token message.Token) *pool.Message { + msg := pool.NewMessage(ctx) + msg.SetCode(codes.NotFound) + msg.SetToken(token) + msg.SetBody(bytes.NewReader([]byte(fmt.Sprintf("uri %v not found", uri)))) + return msg +} + +func (d *Device) HandleRequest(req *net.Request) (*pool.Message, error) { + uri := req.URIPath() + res, ok := d.resources.Load(uri) + if !ok { + return createResponseNotFound(req.Context(), uri, req.Token()), nil + } + return res.HandleRequest(req) +} + +func (d *Device) Close() { + if d.cloudManager != nil { + d.cloudManager.Close() + } + if d.credentialManager != nil { + d.credentialManager.Close() + } + for _, resource := range d.resources.LoadAndDeleteAll() { + resource.Close() + } +} diff --git a/bridge/device/device_test.go b/bridge/device/device_test.go new file mode 100644 index 00000000..cb42a20d --- /dev/null +++ b/bridge/device/device_test.go @@ -0,0 +1,165 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package device_test + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/resources" + cloudSchema "github.com/plgd-dev/device/v2/schema/cloud" + plgdDevice "github.com/plgd-dev/device/v2/schema/device" + "github.com/plgd-dev/device/v2/schema/interfaces" + maintenanceSchema "github.com/plgd-dev/device/v2/schema/maintenance" + plgdResources "github.com/plgd-dev/device/v2/schema/resources" + "github.com/stretchr/testify/require" +) + +var ( + deviceCfg = device.Config{ + ID: uuid.New(), + Name: "test", + ProtocolIndependentID: uuid.New(), + ResourceTypes: []string{"rt1", "rt2"}, + MaxMessageSize: 1024, + } + + cloudCfg = device.CloudConfig{ + Enabled: true, + Config: cloud.Config{ + AccessToken: "access token", + UserID: "user id", + RefreshToken: "refresh token", + ValidUntil: time.Date(2042, 4, 2, 0, 0, 0, 0, time.UTC), // 4.2.2042 + AuthorizationProvider: "auth provider", + CloudID: "cloud id", + CloudURL: "cloud url", + }, + } +) + +func TestNewDevice(t *testing.T) { + cfg := deviceCfg + dev, err := device.New(cfg) + require.NoError(t, err) + require.Equal(t, cfg.ID, dev.GetID()) + require.Equal(t, cfg.Name, dev.GetName()) + require.Equal(t, cfg.ProtocolIndependentID, dev.GetProtocolIndependentID()) + // oic.wk.d will be added automatically + require.Subset(t, dev.GetResourceTypes(), cfg.ResourceTypes) + require.Contains(t, dev.GetResourceTypes(), "oic.wk.d") + + cfg.ResourceTypes = append(cfg.ResourceTypes, "oic.wk.d") + require.Equal(t, cfg, dev.ExportConfig()) +} + +func TestNewDeviceWithCloud(t *testing.T) { + cfg := deviceCfg + cfg.Cloud = cloudCfg + + dev, err := device.New(cfg, device.WithCAPool(cloud.MakeCAPool(nil, true))) + require.NoError(t, err) + cfg.ResourceTypes = append(cfg.ResourceTypes, "oic.wk.d") + require.Equal(t, cfg, dev.ExportConfig()) +} + +func TestGetResource(t *testing.T) { + dev, err := device.New(deviceCfg) + require.NoError(t, err) + _, ok := dev.GetResource(plgdDevice.ResourceURI) + require.True(t, ok) + + // cloud was not enabled in the cfg + _, ok = dev.GetResource(cloudSchema.ResourceURI) + require.False(t, ok) + + // not existing resource + _, ok = dev.GetResource("/no-resource") + require.False(t, ok) +} + +func TestRangeResources(t *testing.T) { + cfg := deviceCfg + dev, err := device.New(cfg) + require.NoError(t, err) + resourceHrefs := []string{} + dev.Range(func(href string, _ device.Resource) bool { + resourceHrefs = append(resourceHrefs, href) + return true + }) + // default resources: device, discovery and maintenance + require.Len(t, resourceHrefs, 3) + require.Contains(t, resourceHrefs, plgdDevice.ResourceURI) + require.Contains(t, resourceHrefs, plgdResources.ResourceURI) + require.Contains(t, resourceHrefs, maintenanceSchema.ResourceURI) + + cfg.Cloud = cloudCfg + devWithCloud, err := device.New(cfg, device.WithCAPool(cloud.MakeCAPool(nil, true))) + require.NoError(t, err) + resourceHrefs = []string{} + devWithCloud.Range(func(href string, _ device.Resource) bool { + resourceHrefs = append(resourceHrefs, href) + return true + }) + // default resources: device, discovery, maintenance and cloud + require.Len(t, resourceHrefs, 4) + require.Contains(t, resourceHrefs, plgdDevice.ResourceURI) + require.Contains(t, resourceHrefs, plgdResources.ResourceURI) + require.Contains(t, resourceHrefs, maintenanceSchema.ResourceURI) + require.Contains(t, resourceHrefs, cloudSchema.ResourceURI) +} + +func TestLoadAndDeleteResource(t *testing.T) { + dev, err := device.New(deviceCfg) + require.NoError(t, err) + + _, ok := dev.LoadAndDeleteResource("/fail") + require.False(t, ok) + + res := resources.NewResource("/test", nil, nil, []string{"oic.d.virtual", "oic.d.test"}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + dev.AddResource(res) + _, ok = dev.GetResource(res.GetHref()) + require.True(t, ok) + + _, ok = dev.LoadAndDeleteResource(res.GetHref()) + require.True(t, ok) + _, ok = dev.GetResource(res.GetHref()) + require.False(t, ok) +} + +func TestCloseAndDeleteResource(t *testing.T) { + dev, err := device.New(deviceCfg) + require.NoError(t, err) + + ok := dev.CloseAndDeleteResource("/fail") + require.False(t, ok) + + res := resources.NewResource("/test", nil, nil, []string{"oic.d.virtual", "oic.d.test"}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + dev.AddResource(res) + _, ok = dev.GetResource(res.GetHref()) + require.True(t, ok) + + ok = dev.CloseAndDeleteResource(res.GetHref()) + require.True(t, ok) + _, ok = dev.GetResource(res.GetHref()) + require.False(t, ok) +} diff --git a/bridge/device/options.go b/bridge/device/options.go new file mode 100644 index 00000000..6acdae5f --- /dev/null +++ b/bridge/device/options.go @@ -0,0 +1,74 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package device + +import ( + "crypto/x509" + + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/resources/device" + "github.com/plgd-dev/device/v2/pkg/log" +) + +type OnDeviceUpdated func(d *Device) + +type CAPoolGetter interface { + IsValid() bool + GetPool() (*x509.CertPool, error) +} + +type OptionsCfg struct { + onDeviceUpdated OnDeviceUpdated + getAdditionalProperties device.GetAdditionalPropertiesForResponseFunc + getCertificates cloud.GetCertificates + caPool CAPoolGetter + logger log.Logger +} + +type Option func(*OptionsCfg) + +func WithOnDeviceUpdated(onDeviceUpdated OnDeviceUpdated) Option { + return func(o *OptionsCfg) { + o.onDeviceUpdated = onDeviceUpdated + } +} + +func WithGetAdditionalPropertiesForResponse(getAdditionalProperties device.GetAdditionalPropertiesForResponseFunc) Option { + return func(o *OptionsCfg) { + o.getAdditionalProperties = getAdditionalProperties + } +} + +func WithGetCertificates(getCertificates cloud.GetCertificates) Option { + return func(o *OptionsCfg) { + o.getCertificates = getCertificates + } +} + +func WithCAPool(caPool CAPoolGetter) Option { + return func(o *OptionsCfg) { + o.caPool = caPool + } +} + +func WithLogger(logger log.Logger) Option { + return func(o *OptionsCfg) { + o.logger = logger + } +} diff --git a/bridge/getDevices_test.go b/bridge/getDevices_test.go new file mode 100644 index 00000000..6412ca8f --- /dev/null +++ b/bridge/getDevices_test.go @@ -0,0 +1,97 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package bridge_test + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + bridgeTest "github.com/plgd-dev/device/v2/bridge/test" + testClient "github.com/plgd-dev/device/v2/test/client" + "github.com/stretchr/testify/require" +) + +func TestGetDevices(t *testing.T) { + s := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s.Shutdown() + }) + deviceID1 := uuid.New().String() + d1 := bridgeTest.NewBridgedDevice(t, s, deviceID1, false, false) + deviceID2 := uuid.New().String() + d2 := bridgeTest.NewBridgedDevice(t, s, deviceID2, false, false) + deviceID3 := uuid.New().String() + d3 := bridgeTest.NewBridgedDevice(t, s, deviceID3, false, false) + defer func() { + s.DeleteAndCloseDevice(d3.GetID()) + s.DeleteAndCloseDevice(d2.GetID()) + s.DeleteAndCloseDevice(d1.GetID()) + }() + + cleanup := bridgeTest.RunBridgeService(s) + defer func() { + errC := cleanup() + require.NoError(t, errC) + }() + + c, err := testClient.NewTestSecureClientWithBridgeSupport() + require.NoError(t, err) + defer func() { + errC := c.Close(context.Background()) + require.NoError(t, errC) + }() + + c1, err := testClient.NewTestSecureClient() + require.NoError(t, err) + defer func() { + errC := c1.Close(context.Background()) + require.NoError(t, errC) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) + defer cancel() + + devices, err := c.GetDevicesDetails(ctx) + require.NoError(t, err) + require.NotEmpty(t, devices[deviceID1]) + require.NotEmpty(t, devices[deviceID2]) + require.NotEmpty(t, devices[deviceID3]) + + eps := devices[deviceID1].Endpoints + require.NotEmpty(t, eps) + addr, err := eps[0].GetAddr() + require.NoError(t, err) + + ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second*1) + defer cancel1() + + ipDevs, err := c1.GetDevicesByIP(ctx1, addr.String()) + require.NoError(t, err) + require.NotEmpty(t, ipDevs) + devs := make(map[string]bool) + for _, d := range ipDevs { + devs[d.Device.DeviceID()] = true + require.NotEmpty(t, d.Links) + } + require.True(t, devs[deviceID1]) + require.True(t, devs[deviceID2]) + require.True(t, devs[deviceID3]) +} diff --git a/bridge/getResource_test.go b/bridge/getResource_test.go new file mode 100644 index 00000000..9140d72f --- /dev/null +++ b/bridge/getResource_test.go @@ -0,0 +1,165 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package bridge_test + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/resources" + bridgeTest "github.com/plgd-dev/device/v2/bridge/test" + "github.com/plgd-dev/device/v2/client" + "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/net/coap" + "github.com/plgd-dev/device/v2/schema/device" + "github.com/plgd-dev/device/v2/schema/interfaces" + testClient "github.com/plgd-dev/device/v2/test/client" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/stretchr/testify/require" +) + +func TestGetResource(t *testing.T) { + s := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s.Shutdown() + }) + deviceID1 := uuid.New().String() + d1 := bridgeTest.NewBridgedDevice(t, s, deviceID1, false, false) + deviceID2 := uuid.New().String() + d2 := bridgeTest.NewBridgedDevice(t, s, deviceID2, false, false) + defer func() { + s.DeleteAndCloseDevice(d2.GetID()) + s.DeleteAndCloseDevice(d1.GetID()) + }() + + failRes := resources.NewResource("/fail", + nil, + nil, + []string{"oic.d.virtual", "oic.d.test"}, + []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_R}, + ) + d1.AddResource(failRes) + + cleanup := bridgeTest.RunBridgeService(s) + defer func() { + errC := cleanup() + require.NoError(t, errC) + }() + + type args struct { + deviceID string + href string + opts []client.GetOption + } + tests := []struct { + name string + args args + want coap.DetailedResponse[interface{}] + wantErr bool + }{ + { + name: "valid", + args: args{ + deviceID: d1.GetID().String(), + href: device.ResourceURI, + opts: []client.GetOption{ + client.WithDiscoveryConfiguration(core.DefaultDiscoveryConfiguration()), + }, + }, + want: coap.DetailedResponse[interface{}]{ + Code: codes.Content, + Body: map[interface{}]interface{}{ + "di": d1.GetID().String(), + "piid": d1.GetProtocolIndependentID().String(), + "n": d1.GetName(), + }, + }, + }, + { + name: "valid with interface", + args: args{ + deviceID: d2.GetID().String(), + href: device.ResourceURI, + opts: []client.GetOption{ + client.WithInterface(interfaces.OC_IF_BASELINE), + }, + }, + + want: coap.DetailedResponse[interface{}]{ + Code: codes.Content, + Body: map[interface{}]interface{}{ + "di": d2.GetID().String(), + "piid": d2.GetProtocolIndependentID().String(), + "n": d2.GetName(), + "if": []interface{}{interfaces.OC_IF_BASELINE, interfaces.OC_IF_R}, + "rt": []interface{}{"oic.d.virtual", "oic.wk.d"}, + }, + }, + }, + { + name: "invalid href", + args: args{ + deviceID: d1.GetID().String(), + href: "/invalid/href", + }, + wantErr: true, + }, + { + name: "invalid deviceID", + args: args{ + deviceID: "notfound", + href: device.ResourceURI, + }, + wantErr: true, + }, + { + name: "invalid get handler", + args: args{ + deviceID: d1.GetID().String(), + href: failRes.GetHref(), + }, + wantErr: true, + }, + } + + c, err := testClient.NewTestSecureClientWithBridgeSupport() + require.NoError(t, err) + defer func() { + errC := c.Close(context.Background()) + require.NoError(t, errC) + }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) + defer cancel() + var got coap.DetailedResponse[interface{}] + err := c.GetResource(ctx, tt.args.deviceID, tt.args.href, &got, tt.args.opts...) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + got.ETag = nil + require.Equal(t, tt.want, got) + }) + } +} diff --git a/bridge/net/config.go b/bridge/net/config.go new file mode 100644 index 00000000..8359f419 --- /dev/null +++ b/bridge/net/config.go @@ -0,0 +1,130 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgn.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implien. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package net + +import ( + "fmt" + gonet "net" + "strconv" + "time" +) + +const ( + UDP4 = "udp4" + UDP6 = "udp6" +) + +var ErrInvalidExternalAddress = fmt.Errorf("invalid externalAddress") + +type externalAddressPort struct { + host string + port string + network string +} + +type externalAddressesPort []externalAddressPort + +func (extAddresses externalAddressesPort) filterByNetwork(network string) externalAddressesPort { + var filtered externalAddressesPort + for _, e := range extAddresses { + if e.network == network { + filtered = append(filtered, e) + } + } + return filtered +} + +func (extAddresses externalAddressesPort) filterByPort(port string) externalAddressesPort { + var filtered externalAddressesPort + for _, e := range extAddresses { + if e.port == port { + filtered = append(filtered, e) + } + } + return filtered +} + +type Config struct { + ExternalAddresses []string `yaml:"externalAddresses"` + MaxMessageSize uint32 `yaml:"maxMessageSize"` + DeduplicationLifetime time.Duration `yaml:"deduplicationLifetime"` + externalAddressesPort externalAddressesPort `yaml:"-"` +} + +const DefaultMaxMessageSize = 2 * 1024 * 1024 + +func errInvalidExternalAddress(err error) error { + return fmt.Errorf("%w: %w", ErrInvalidExternalAddress, err) +} + +func validateExternalAddress(addr string) (externalAddressPort, error) { + host, portStr, err := gonet.SplitHostPort(addr) + if err != nil { + return externalAddressPort{}, errInvalidExternalAddress(err) + } + if host == "" { + return externalAddressPort{}, errInvalidExternalAddress(fmt.Errorf("host cannot be empty")) + } + _, err = strconv.ParseUint(portStr, 10, 16) + if err != nil { + return externalAddressPort{}, errInvalidExternalAddress(err) + } + + _, errIpv4 := gonet.ResolveUDPAddr(UDP4, addr) + _, errIpv6 := gonet.ResolveUDPAddr(UDP6, addr) + if errIpv4 != nil && errIpv6 != nil { + err = errIpv4 + if errIpv6 != nil { + err = errIpv6 + } + return externalAddressPort{}, fmt.Errorf("%w: %w", ErrInvalidExternalAddress, err) + } + network := UDP6 + if errIpv4 == nil { + network = UDP4 + } + + return externalAddressPort{ + host: host, + port: portStr, + network: network, + }, nil +} + +func (cfg *Config) Validate() error { + if len(cfg.ExternalAddresses) == 0 { + return fmt.Errorf("%w: cannot be empty", ErrInvalidExternalAddress) + } + if cfg.MaxMessageSize == 0 { + cfg.MaxMessageSize = DefaultMaxMessageSize + } + if cfg.DeduplicationLifetime == 0 { + cfg.DeduplicationLifetime = 8 * time.Second + } + externalAddressesPort := make([]externalAddressPort, 0, len(cfg.ExternalAddresses)) + for i, e := range cfg.ExternalAddresses { + extAddress, err := validateExternalAddress(e) + if err != nil { + return fmt.Errorf("invalid configuration [%v:%v]: %w", i, e, err) + } + externalAddressesPort = append(externalAddressesPort, extAddress) + } + cfg.externalAddressesPort = externalAddressesPort + return nil +} diff --git a/bridge/net/config_internal_test.go b/bridge/net/config_internal_test.go new file mode 100644 index 00000000..50812fec --- /dev/null +++ b/bridge/net/config_internal_test.go @@ -0,0 +1,132 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgn.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implien. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package net + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfigValidate(t *testing.T) { + type data struct { + maxMsgSize uint32 + externalAddressesPort externalAddressesPort + } + tests := []struct { + name string + config *Config + wantErr bool + want data + }{ + { + name: "ValidConfig", + config: &Config{ExternalAddresses: []string{"localhost:12345"}, MaxMessageSize: 1024}, + want: data{ + maxMsgSize: 1024, + externalAddressesPort: externalAddressesPort{{ + host: "localhost", + port: "12345", + network: UDP4, + }}, + }, + }, + { + name: "ValidConfigUDP6", + config: &Config{ExternalAddresses: []string{"[::1]:12345"}, MaxMessageSize: 1024}, + want: data{ + maxMsgSize: 1024, + externalAddressesPort: externalAddressesPort{{ + host: "::1", + port: "12345", + network: UDP6, + }}, + }, + }, + { + name: "ValidConfigUDP4andUDP6", + config: &Config{ExternalAddresses: []string{"localhost:12345", "[::1]:12345"}, MaxMessageSize: 1024}, + want: data{ + maxMsgSize: 1024, + externalAddressesPort: externalAddressesPort{ + { + host: "localhost", + port: "12345", + network: UDP4, + }, + { + host: "::1", + port: "12345", + network: UDP6, + }, + }, + }, + }, + { + name: "ValidConfigWithDefaultMaxMessageSize", + config: &Config{ExternalAddresses: []string{"localhost:12345"}}, + want: data{ + maxMsgSize: DefaultMaxMessageSize, + externalAddressesPort: externalAddressesPort{{ + host: "localhost", + port: "12345", + network: UDP4, + }}, + }, + }, + { + name: "MissingExternalAddress", + config: &Config{}, + wantErr: true, + }, + { + name: "InvalidExternalAddress", + config: &Config{ExternalAddresses: []string{"invalid-address"}}, + wantErr: true, + }, + { + name: "EmptyHostInExternalAddress", + config: &Config{ExternalAddresses: []string{":12345"}}, + wantErr: true, + }, + { + name: "InvalidPortInExternalAddress", + config: &Config{ExternalAddresses: []string{"localhost:invalid"}}, + wantErr: true, + }, + { + name: "PortGreaterThanMaxUint16", + config: &Config{ExternalAddresses: []string{"localhost:65536"}}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want.maxMsgSize, tt.config.MaxMessageSize) + require.Equal(t, tt.want.externalAddressesPort, tt.config.externalAddressesPort) + }) + } +} diff --git a/bridge/net/network.go b/bridge/net/network.go new file mode 100644 index 00000000..91566729 --- /dev/null +++ b/bridge/net/network.go @@ -0,0 +1,455 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgn.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implien. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package net + +import ( + "bytes" + "context" + "fmt" + "io" + gonet "net" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/pkg/codec/json" + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/message/status" + "github.com/plgd-dev/go-coap/v3/mux" + "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/options" + coapCache "github.com/plgd-dev/go-coap/v3/pkg/cache" + "github.com/plgd-dev/go-coap/v3/udp" + "github.com/plgd-dev/go-coap/v3/udp/server" +) + +type Net struct { + cfg Config + mux *mux.Router + handler RequestHandler + logger log.Logger + + servers coAPServers + serving atomic.Bool + done chan struct{} + cache *coapCache.Cache[int32, bool] +} + +func newMCastConn(multicastAddr string, logger log.Logger) (*net.UDPConn, error) { + networks := []string{UDP4, UDP6} + var a *gonet.UDPAddr + var err error + var network string + for _, net := range networks { + a, err = gonet.ResolveUDPAddr(net, multicastAddr) + if err == nil { + network = net + break + } + } + if err != nil { + return nil, err + } + + ifaces, err := gonet.Interfaces() + if err != nil { + return nil, err + } + + mcastListener, err := net.NewListenUDP(network, multicastAddr) + if err != nil { + return nil, err + } + + var anySet bool + for i := range ifaces { + iface := ifaces[i] + err = mcastListener.JoinGroup(&iface, a) + if err == nil { + anySet = true + } + if err != nil { + logger.Warnf("cannot JoinGroup(%v, %v): %v", iface, a, err) + } + } + if !anySet { + _ = mcastListener.Close() + return nil, fmt.Errorf("cannot JoinGroup(%v): %w", a, err) + } + + err = mcastListener.SetMulticastLoopback(true) + if err != nil { + _ = mcastListener.Close() + return nil, err + } + + return mcastListener, nil +} + +func newConn(network, port string) (*net.UDPConn, error) { + return net.NewListenUDP(network, ":"+port) +} + +func getLogContent(r *pool.Message) string { + content := "" + if r == nil { + return content + } + body := r.Body() + if body == nil { + return content + } + defer func() { + _, _ = body.Seek(0, io.SeekStart) + }() + contentFormat := message.TextPlain + if m, err := r.Options().ContentFormat(); err == nil { + contentFormat = m + } + + switch contentFormat { + case message.AppCBOR, message.AppOcfCbor: + var v interface{} + if err := cbor.ReadFrom(body, &v); err == nil { + if data, err := json.Encode(v); err == nil { + content = string(data) + } + } + case message.TextPlain: + data, err := io.ReadAll(body) + if err == nil { + content = string(data) + } + } + return content +} + +func logReqResp(logger log.Logger, c mux.Conn, r *mux.Message, resp *pool.Message) { + content := getLogContent(resp) + p, err := r.Path() + if err == nil && p == "/.well-known/core" { + // don't log core discovery + return + } + respStr := "" + if resp != nil { + respStr = resp.String() + } + logger.Debugf("%v, req=%v resp=%v, content=%v", c.RemoteAddr(), r.String(), respStr, content) +} + +func CreateResponseError(ctx context.Context, err error, token message.Token) *pool.Message { + if err == nil { + return nil + } + s, ok := status.FromError(err) + code := codes.BadRequest + if ok { + code = s.Code() + } + msg := pool.NewMessage(ctx) + msg.SetCode(code) + msg.SetToken(token) + // Don't set content format for diagnostic message: https://tools.ietf.org/html/rfc7252#section-5.5.2 + msg.SetBody(bytes.NewReader([]byte(err.Error()))) + return msg +} + +func CreateLoggingMiddleware(logger log.Logger) func(next mux.Handler) mux.Handler { + return func(next mux.Handler) mux.Handler { + return mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + next.ServeCOAP(w, r) + logReqResp(logger, w.Conn(), r, w.Message()) + }) + } +} + +func (n *Net) ServeCOAP(w mux.ResponseWriter, request *mux.Message) { + now := time.Now() + messageID := request.MessageID() + if messageID >= 0 && request.Type() != message.Confirmable { + v, loaded := n.cache.LoadOrStore(messageID, coapCache.NewElement(true, now.Add(n.cfg.DeduplicationLifetime), func(bool) { + // no-op + })) + if loaded && !v.IsExpired(now) { + n.logger.Debugf("duplicate message %v according messageID: %v", request, messageID) + return + } + } + request.Hijack() + go func(w mux.ResponseWriter, request *mux.Message) { + r := Request{ + Message: request.Message, + Endpoints: n.GetEndpoints(request.ControlMessage(), w.Conn().NetConn().LocalAddr().String()), + Conn: w.Conn(), + } + + resp, err := n.handler(&r) + if err != nil { + resp = CreateResponseError(request.Context(), err, request.Token()) + } + if resp != nil { + resp.SetToken(request.Token()) + logReqResp(n.logger, w.Conn(), request, resp) + if err = w.Conn().WriteMessage(resp); err != nil { + n.logger.Errorf("cannot write response: %w", err) + } + } + }(w, request) +} + +type coAPServer struct { + s *server.Server + l *net.UDPConn +} + +type coAPServers []coAPServer + +func (s coAPServers) Stop() { + for _, cs := range s { + cs.s.Stop() + } +} + +func (s coAPServers) Close() error { + var errors *multierror.Error + for _, cs := range s { + err := cs.l.Close() + if err != nil { + errors = multierror.Append(errors, err) + } + } + return errors.ErrorOrNil() +} + +func getPortFromAddress(addr gonet.Addr) (string, error) { + udpAddr, ok := addr.(*gonet.UDPAddr) + if ok { + return fmt.Sprintf("%d", udpAddr.Port), nil + } + addrStr := addr.String() + _, port, err := gonet.SplitHostPort(addrStr) + if err != nil { + return "", err + } + return port, nil +} + +func newServers(cfg *Config, m *mux.Router, logger log.Logger) (coAPServers, bool, bool, error) { + servers := make(coAPServers, 0, len(cfg.externalAddressesPort)) + hasIPv4 := false + hasIPv6 := false + for i, addr := range cfg.externalAddressesPort { + var conn *net.UDPConn + var err error + if addr.network == UDP4 { + hasIPv4 = true + conn, err = newConn(addr.network, addr.port) + } + if addr.network == UDP6 { + hasIPv6 = true + conn, err = newConn(addr.network, addr.port) + } + if err != nil { + _ = servers.Close() + return nil, false, false, err + } + if addr.port == "0" { + port, err := getPortFromAddress(conn.LocalAddr()) + if err != nil { + _ = servers.Close() + return nil, false, false, err + } + cfg.externalAddressesPort[i].port = port + } + + if conn != nil { + servers = append(servers, coAPServer{ + s: udp.NewServer( + options.WithMux(m), + options.WithErrors(func(err error) { logger.Errorf("server: %w", err) }), + options.WithMaxMessageSize(cfg.MaxMessageSize), + ), + l: conn, + }) + } + } + if len(servers) == 0 { + return nil, false, false, fmt.Errorf("cannot create any server") + } + return servers, hasIPv4, hasIPv6, nil +} + +func appendMCastServers(servers coAPServers, mcastAddresses []string, cfg Config, m *mux.Router, logger log.Logger) (coAPServers, error) { + for _, addr := range mcastAddresses { + if addr == "" { + continue + } + conn, err := newMCastConn(addr, logger) + if err != nil { + _ = servers.Close() + return nil, err + } + servers = append(servers, coAPServer{ + s: udp.NewServer(options.WithMux(m), + options.WithMaxMessageSize(cfg.MaxMessageSize), + options.WithErrors(func(err error) { logger.Errorf("mcast server: %w", err) }), + ), + l: conn, + }) + } + return servers, nil +} + +func New(cfg Config, handler RequestHandler, logger log.Logger) (*Net, error) { + err := cfg.Validate() + if err != nil { + return nil, err + } + m := mux.NewRouter() + servers, hasIPv4, hasIPv6, err := newServers(&cfg, m, logger) + if err != nil { + return nil, err + } + if hasIPv4 { + servers, err = appendMCastServers(servers, core.DefaultDiscoveryConfiguration().MulticastAddressUDP4, cfg, m, logger) + if err != nil { + return nil, err + } + } + if hasIPv6 { + servers, err = appendMCastServers(servers, core.DefaultDiscoveryConfiguration().MulticastAddressUDP6, cfg, m, logger) + if err != nil { + return nil, err + } + } + + n := &Net{ + cfg: cfg, + servers: servers, + mux: m, + handler: handler, + logger: logger, + done: make(chan struct{}), + cache: coapCache.NewCache[int32, bool](), + } + m.DefaultHandle(mux.HandlerFunc(n.ServeCOAP)) + go func() { + for { + select { + case <-n.done: + return + case <-time.After(n.cfg.DeduplicationLifetime / 2): + now := time.Now() + n.cache.CheckExpirations(now) + } + } + }() + return n, nil +} + +func (n *Net) getNetwork(cm *net.ControlMessage, localHost, localPort string) string { + if cm != nil { + if cm.Dst.To4() == nil { + return UDP6 + } + return UDP4 + } + p := n.cfg.externalAddressesPort.filterByPort(localPort) + if len(p) == 1 { + return p[0].network + } + ip := gonet.ParseIP(localHost) + if ip.To4() == nil { + return UDP6 + } + return UDP4 +} + +func (n *Net) GetEndpoints(cm *net.ControlMessage, localAddr string) schema.Endpoints { + localHost, localPort, err := gonet.SplitHostPort(localAddr) + if err != nil { + n.logger.Warnf("cannot get local address: %v", err) + return nil + } + network := n.getNetwork(cm, localHost, localPort) + filteredByNetwork := n.cfg.externalAddressesPort.filterByNetwork(network) + filtered := filteredByNetwork.filterByPort(localPort) + if len(filtered) == 0 { + filtered = filteredByNetwork + } + ep := localAddr + if len(filtered) > 0 { + ep = filtered[0].host + if filtered[0].network == UDP6 { + ep = "[" + ep + "]" + } + ep = ep + ":" + filtered[0].port + } + return schema.Endpoints{ + { + URI: fmt.Sprintf("coap://%v", ep), + }, + } +} + +func (n *Net) Serve() error { + if !n.serving.CompareAndSwap(false, true) { + return fmt.Errorf("already serving") + } + defer close(n.done) + var wg sync.WaitGroup + errCh := make(chan error, len(n.servers)) + wg.Add(len(n.servers)) + for _, cs := range n.servers { + go func(cs coAPServer) { + defer wg.Done() + err := cs.s.Serve(cs.l) + errCh <- err + }(cs) + } + wg.Wait() + var errors *multierror.Error + for { + select { + case err := <-errCh: + if err != nil { + errors = multierror.Append(errors, err) + } + default: + return errors.ErrorOrNil() + } + } +} + +func (n *Net) Close() error { + if !n.serving.Load() { + return n.servers.Close() + } + n.servers.Stop() + <-n.done + return nil +} diff --git a/bridge/net/network_internal_test.go b/bridge/net/network_internal_test.go new file mode 100644 index 00000000..898ca0b5 --- /dev/null +++ b/bridge/net/network_internal_test.go @@ -0,0 +1,96 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgn.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implien. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package net + +import ( + "net" + "testing" + + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/stretchr/testify/require" +) + +func TestGetPortFromAddress(t *testing.T) { + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:42") + require.NoError(t, err) + port, err := getPortFromAddress(udpAddr) + require.NoError(t, err) + require.Equal(t, "42", port) + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:42") + require.NoError(t, err) + port, err = getPortFromAddress(tcpAddr) + require.NoError(t, err) + require.Equal(t, "42", port) +} + +func TestGetPortFromAddress_Fail(t *testing.T) { + ipAddr, err := net.ResolveIPAddr("ip", "127.0.0.1") + require.NoError(t, err) + _, err = getPortFromAddress(ipAddr) + require.Error(t, err) +} + +func TestGetNetwork(t *testing.T) { + cfg := Config{ + ExternalAddresses: []string{"127.0.0.1:42", "[::1]:42", "127.0.0.1:13", "[::1]:37"}, + } + err := cfg.Validate() + require.NoError(t, err) + + n := &Net{ + cfg: cfg, + } + + network := n.getNetwork(nil, "127.0.0.1", "42") + require.Equal(t, UDP4, network) + network = n.getNetwork(nil, "[::1]", "42") + require.Equal(t, UDP6, network) + network = n.getNetwork(nil, "127.0.0.1", "13") + require.Equal(t, UDP4, network) + network = n.getNetwork(nil, "[::1]", "37") + require.Equal(t, UDP6, network) +} + +func TestClose(t *testing.T) { + cfg := Config{ + ExternalAddresses: []string{"127.0.0.1:0"}, + } + n, err := New(cfg, nil, log.NewNilLogger()) + require.NoError(t, err) + + require.NotEqual(t, 0, n.cfg.externalAddressesPort[0]) + checkUDPPort := func(opened bool) { + listenAddress, errL := net.ResolveUDPAddr(n.cfg.externalAddressesPort[0].network, ":"+n.cfg.externalAddressesPort[0].port) + require.NoError(t, errL) + conn, errL := net.ListenUDP(n.cfg.externalAddressesPort[0].network, listenAddress) + if opened { + require.Error(t, errL) + return + } + require.NoError(t, errL) + errL = conn.Close() + require.NoError(t, errL) + } + checkUDPPort(true) + + err = n.Close() + require.NoError(t, err) + checkUDPPort(false) +} diff --git a/bridge/net/request.go b/bridge/net/request.go new file mode 100644 index 00000000..6ddfbeb8 --- /dev/null +++ b/bridge/net/request.go @@ -0,0 +1,95 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgn.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implien. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package net + +import ( + "errors" + "strings" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/mux" +) + +type Request struct { + *pool.Message + Conn mux.Conn + Endpoints schema.Endpoints +} + +type RequestHandler func(req *Request) (*pool.Message, error) + +var ErrKeyNotFound = errors.New("key not found") + +func (r *Request) GetValueFromQuery(key string) (string, error) { + q, err := r.Queries() + if err != nil { + return "", err + } + prefix := key + "=" + for _, query := range q { + if strings.HasPrefix(query, prefix) { + return strings.TrimPrefix(query, prefix), nil + } + } + return "", ErrKeyNotFound +} + +func (r *Request) URIPath() string { + p, err := r.Message.Options().Path() + if err != nil { + return "" + } + return p +} + +func (r *Request) Interface() string { + v, err := r.GetValueFromQuery("if") + if err != nil { + return "" + } + return v +} + +func (r *Request) DeviceID() uuid.UUID { + v, err := r.GetValueFromQuery("di") + if err != nil { + return uuid.Nil + } + di, err := uuid.Parse(v) + if err != nil { + return uuid.Nil + } + return di +} + +func (r *Request) ResourceTypes() []string { + q, err := r.Queries() + if err != nil { + return nil + } + resourceTypes := make([]string, 0, len(q)) + for _, query := range q { + if strings.HasPrefix(query, "rt=") { + resourceTypes = append(resourceTypes, strings.TrimPrefix(query, "rt=")) + } + } + return resourceTypes +} diff --git a/bridge/net/request_test.go b/bridge/net/request_test.go new file mode 100644 index 00000000..526c7ae4 --- /dev/null +++ b/bridge/net/request_test.go @@ -0,0 +1,148 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgn.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implien. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package net_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/stretchr/testify/require" +) + +func TestGetValueFromQuery(t *testing.T) { + req := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + req.AddQuery("param1=value1") + req.AddQuery("param2=value2") + reqNoQuery := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + + tests := []struct { + req *net.Request + key string + wantErr bool + want string + }{ + {req, "param1", false, "value1"}, + {req, "param2", false, "value2"}, + {req, "param3", true, ""}, + {reqNoQuery, "param1", true, ""}, + } + + for _, tt := range tests { + v, err := tt.req.GetValueFromQuery(tt.key) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, v) + } +} + +func TestURIPath(t *testing.T) { + req := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + const path = "/path/to/resource" + err := req.SetPath(path) + require.NoError(t, err) + require.Equal(t, path, req.URIPath()) + + reqNoPath := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + require.Equal(t, "", reqNoPath.URIPath()) +} + +func TestInterface(t *testing.T) { + req := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + req.AddQuery("q1=v1") + req.AddQuery("if=interface") + req.AddQuery("q2=v2") + require.Equal(t, "interface", req.Interface()) + + reqNoInterface := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + reqNoInterface.AddQuery("q1=v1") + require.Equal(t, "", reqNoInterface.Interface()) + + reqNoQuery := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + require.Equal(t, "", reqNoQuery.Interface()) +} + +func TestDeviceID(t *testing.T) { + req := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + req.AddQuery("q1=v1") + deviceID := uuid.New() + req.AddQuery("di=" + deviceID.String()) + req.AddQuery("q2=v2") + require.Equal(t, deviceID, req.DeviceID()) + + reqInvalidDeviceID := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + reqInvalidDeviceID.AddQuery("di=invalid") + require.Equal(t, uuid.Nil, reqInvalidDeviceID.DeviceID()) + + reqNoDeviceID := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + reqNoDeviceID.AddQuery("q1=v1") + require.Equal(t, uuid.Nil, reqNoDeviceID.DeviceID()) + + reqNoQuery := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + require.Equal(t, uuid.Nil, reqNoQuery.DeviceID()) +} + +func TestResourceTypes(t *testing.T) { + req := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + req.AddQuery("q1=v1") + req.AddQuery("rt=type1") + req.AddQuery("rt=type2") + req.AddQuery("q2=v2") + require.Equal(t, []string{"type1", "type2"}, req.ResourceTypes()) + + reqNoResourceTypes := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + reqNoResourceTypes.AddQuery("q1=v1") + require.Equal(t, []string{}, reqNoResourceTypes.ResourceTypes()) + + reqNoQuery := &net.Request{ + Message: pool.NewMessage(context.Background()), + } + require.Nil(t, reqNoQuery.ResourceTypes()) +} diff --git a/bridge/observeResource_test.go b/bridge/observeResource_test.go new file mode 100644 index 00000000..9bfea93d --- /dev/null +++ b/bridge/observeResource_test.go @@ -0,0 +1,143 @@ +package bridge_test + +import ( + "bytes" + "context" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + bridgeTest "github.com/plgd-dev/device/v2/bridge/test" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" + "github.com/plgd-dev/device/v2/pkg/net/coap" + "github.com/plgd-dev/device/v2/schema/interfaces" + testClient "github.com/plgd-dev/device/v2/test/client" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/stretchr/testify/require" +) + +func TestObserveResource(t *testing.T) { + s := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s.Shutdown() + }) + d := bridgeTest.NewBridgedDevice(t, s, uuid.New().String(), false, false) + defer func() { + s.DeleteAndCloseDevice(d.GetID()) + }() + + rds := resourceDataSync{ + resourceData: resourceData{ + Name: "test", + }, + } + + resHandler := func(req *net.Request) (*pool.Message, error) { + resp := pool.NewMessage(req.Context()) + switch req.Code() { + case codes.GET: + resp.SetCode(codes.Content) + case codes.POST: + resp.SetCode(codes.Changed) + default: + return nil, fmt.Errorf("invalid method %v", req.Code()) + } + resp.SetContentFormat(message.AppOcfCbor) + data, err := cbor.Encode(rds.copy()) + if err != nil { + return nil, err + } + resp.SetBody(bytes.NewReader(data)) + return resp, nil + } + + res := resources.NewResource("/test", resHandler, func(req *net.Request) (*pool.Message, error) { + codec := codecOcf.VNDOCFCBORCodec{} + var newData resourceData + err := codec.Decode(req.Message, &newData) + if err != nil { + return nil, err + } + rds.setName(newData.Name) + return resHandler(req) + }, []string{"oic.d.virtual", "oic.d.test"}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + res.SetObserveHandler(func(req *net.Request, handler func(msg *pool.Message, err error)) (cancel func(), err error) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer cancel() + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Millisecond * 100): + resp, err := resHandler(req) + if err != nil { + handler(nil, err) + return + } + handler(resp, nil) + } + } + }() + return cancel, nil + }) + d.AddResource(res) + + cleanup := bridgeTest.RunBridgeService(s) + defer func() { + errC := cleanup() + require.NoError(t, errC) + }() + + c, err := testClient.NewTestSecureClientWithBridgeSupport() + require.NoError(t, err) + defer func() { + errC := c.Close(context.Background()) + require.NoError(t, errC) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*8) + defer cancel() + + h := testClient.MakeMockResourceObservationHandler() + obsID, err := c.ObserveResource(ctx, d.GetID().String(), "/test", h) + require.NoError(t, err) + defer func() { + _, errC := c.StopObservingResource(ctx, obsID) + require.NoError(t, errC) + }() + + n, err := h.WaitForNotification(ctx) + require.NoError(t, err) + var originalResource resourceData + err = n(&originalResource) + require.NoError(t, err) + require.Equal(t, "test", originalResource.Name) + + var got coap.DetailedResponse[interface{}] + err = c.UpdateResource(ctx, d.GetID().String(), "/test", map[string]interface{}{ + "name": "updated", + }, &got) + require.NoError(t, err) + require.Equal(t, codes.Changed, got.Code) + require.Equal(t, "updated", rds.getName()) + + n, err = h.WaitForNotification(ctx) + require.NoError(t, err) + var updatedResource resourceData + err = n(&updatedResource) + require.NoError(t, err) + require.Equal(t, "updated", updatedResource.Name) + + // fail - invalid data + err = c.UpdateResource(ctx, d.GetID().String(), "/test", map[string]interface{}{ + "name": 1, + }, &got) + require.Error(t, err) +} diff --git a/bridge/onboardDevice_test.go b/bridge/onboardDevice_test.go new file mode 100644 index 00000000..41d59890 --- /dev/null +++ b/bridge/onboardDevice_test.go @@ -0,0 +1,157 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package bridge_test + +import ( + "context" + "crypto/x509" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + bridgeTest "github.com/plgd-dev/device/v2/bridge/test" + schemaCredential "github.com/plgd-dev/device/v2/schema/credential" + "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" + "github.com/stretchr/testify/require" +) + +func TestOnboardDevice(t *testing.T) { + s := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s.Shutdown() + }) + deviceID := uuid.New().String() + d := bridgeTest.NewBridgedDevice(t, s, deviceID, true, false) + defer func() { + s.DeleteAndCloseDevice(d.GetID()) + }() + deviceIDwithoutCAPool := uuid.New().String() + deviceWithoutCAPool := bridgeTest.NewBridgedDevice(t, s, deviceIDwithoutCAPool, true, true, device.WithCAPool(cloud.MakeCAPool(func() []*x509.Certificate { + return nil + }, false))) + defer func() { + s.DeleteAndCloseDevice(deviceWithoutCAPool.GetID()) + }() + cleanup := bridgeTest.RunBridgeService(s) + defer func() { + errC := cleanup() + require.NoError(t, errC) + }() + + type args struct { + deviceID string + authorizationProvider string + authorizationCode string + cloudURL string + cloudID string + cloudCA []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "notFound", + args: args{ + deviceID: "notFound", + authorizationProvider: "authorizationProvider", + authorizationCode: "authorizationCode", + cloudURL: "coaps+tcp://test:5684", + cloudID: "cloudID", + }, + wantErr: true, + }, + { + name: "valid", + args: args{ + deviceID: deviceID, + authorizationProvider: "authorizationProvider", + authorizationCode: "authorizationCode", + cloudURL: "coaps+tcp://test:5684", + cloudID: "cloudID", + }, + }, + { + name: "validWithCA", + args: args{ + deviceID: deviceIDwithoutCAPool, + authorizationProvider: "authorizationProvider", + authorizationCode: "authorizationCode", + cloudURL: "coaps+tcp://test:5684", + cloudID: "cloudID", + cloudCA: test.GetRootCApem(t), + }, + }, + } + + c, err := testClient.NewTestSecureClientWithBridgeSupport() + require.NoError(t, err) + defer func() { + errC := c.Close(context.Background()) + require.NoError(t, errC) + }() + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) + defer cancel() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ttCtx, ttCancel := context.WithTimeout(ctx, time.Second) + defer ttCancel() + if len(tt.args.cloudCA) > 0 { + err = c.UpdateResource(ctx, tt.args.deviceID, schemaCredential.ResourceURI, schemaCredential.CredentialUpdateRequest{ + Credentials: []schemaCredential.Credential{ + { + Subject: tt.args.cloudID, + Type: schemaCredential.CredentialType_ASYMMETRIC_SIGNING_WITH_CERTIFICATE, + Usage: schemaCredential.CredentialUsage_TRUST_CA, + PublicData: &schemaCredential.CredentialPublicData{ + DataInternal: tt.args.cloudCA, + Encoding: schemaCredential.CredentialPublicDataEncoding_PEM, + }, + }, + }, + }, nil) + require.NoError(t, err) + var res schemaCredential.CredentialResponse + err = c.GetResource(ctx, tt.args.deviceID, schemaCredential.ResourceURI, &res) + require.NoError(t, err) + require.Len(t, res.Credentials, 1) + } + err = c.OnboardDevice(ttCtx, tt.args.deviceID, tt.args.authorizationProvider, tt.args.cloudURL, tt.args.authorizationCode, tt.args.cloudID) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + err = c.OffboardDevice(ttCtx, tt.args.deviceID) + require.NoError(t, err) + if len(tt.args.cloudCA) > 0 { + // remove CA is async + time.Sleep(time.Second) + var res schemaCredential.CredentialResponse + err = c.GetResource(ctx, tt.args.deviceID, schemaCredential.ResourceURI, &res) + require.NoError(t, err) + require.Len(t, res.Credentials, 0) + } + }) + } +} diff --git a/bridge/resources/cbor.go b/bridge/resources/cbor.go new file mode 100644 index 00000000..0c0544c8 --- /dev/null +++ b/bridge/resources/cbor.go @@ -0,0 +1,47 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources + +import "github.com/plgd-dev/device/v2/pkg/codec/cbor" + +func MergeCBORStructs(a ...interface{}) interface{} { + var merged map[interface{}]interface{} + for _, v := range a { + if v == nil { + continue + } + data, err := cbor.Encode(v) + if err != nil { + continue + } + var m map[interface{}]interface{} + err = cbor.Decode(data, &m) + if err != nil { + continue + } + if merged == nil { + merged = m + } else { + for k, v := range m { + merged[k] = v + } + } + } + return merged +} diff --git a/bridge/resources/cbor_test.go b/bridge/resources/cbor_test.go new file mode 100644 index 00000000..31652f78 --- /dev/null +++ b/bridge/resources/cbor_test.go @@ -0,0 +1,113 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources_test + +import ( + "testing" + + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/stretchr/testify/require" +) + +func TestMergeCBORStructs(t *testing.T) { + type args struct { + structs []interface{} + } + tests := []struct { + name string + args args + want interface{} + }{ + { + name: "valid", + args: args{ + structs: []interface{}{ + map[interface{}]interface{}{"key1": "value1"}, + map[interface{}]interface{}{"key2": "value2"}, + }, + }, + want: map[interface{}]interface{}{ + "key1": "value1", + "key2": "value2", + }, + }, + { + name: "merging nil structs", + args: args{ + structs: []interface{}{ + nil, + nil, + }, + }, + want: nil, + }, + { + name: "invalid CBOR encoding", + args: args{ + structs: []interface{}{ + map[interface{}]interface{}{"key1": "value1"}, + "invalid CBOR", + }, + }, + want: map[interface{}]interface{}{ + "key1": "value1", + }, + }, + { + name: "merging struct with empty CBOR encoding", + args: args{ + structs: []interface{}{ + map[interface{}]interface{}{"key1": "value1"}, + map[interface{}]interface{}{}, + }, + }, + want: map[interface{}]interface{}{ + "key1": "value1", + }, + }, + { + name: "merging multiple structs", + args: args{ + structs: []interface{}{ + map[interface{}]interface{}{"key1": "value1"}, + map[interface{}]interface{}{"key2": "value2"}, + map[interface{}]interface{}{"key3": "value3"}, + map[interface{}]interface{}{"key4": "value4"}, + }, + }, + want: map[interface{}]interface{}{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + "key4": "value4", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resources.MergeCBORStructs(tt.args.structs...) + if tt.want != nil { + require.Equal(t, tt.want, got) + return + } + require.Nil(t, got) + }) + } +} diff --git a/bridge/resources/cloud/resource.go b/bridge/resources/cloud/resource.go new file mode 100644 index 00000000..f5df1c4e --- /dev/null +++ b/bridge/resources/cloud/resource.go @@ -0,0 +1,44 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + plgdCloud "github.com/plgd-dev/device/v2/schema/cloud" + "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/go-coap/v3/message/pool" +) + +type Resource struct { + *resources.Resource +} + +type Manager interface { + Get(req *net.Request) (*pool.Message, error) + Post(req *net.Request) (*pool.Message, error) +} + +func New(uri string, m Manager) *Resource { + d := &Resource{} + d.Resource = resources.NewResource(uri, m.Get, m.Post, []string{plgdCloud.ResourceType}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + // don't publish cloud resource to cloud + d.PolicyBitMask &= ^resources.PublishToCloud + return d +} diff --git a/bridge/resources/device/resource.go b/bridge/resources/device/resource.go new file mode 100644 index 00000000..c0fa455d --- /dev/null +++ b/bridge/resources/device/resource.go @@ -0,0 +1,85 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package device + +import ( + "bytes" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/schema/device" + "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" +) + +type Device interface { + GetID() uuid.UUID + GetName() string + GetResourceTypes() []string + GetProtocolIndependentID() uuid.UUID +} + +type GetAdditionalPropertiesForResponseFunc func() map[string]interface{} + +type Resource struct { + *resources.Resource + device Device + getAdditionalProperties GetAdditionalPropertiesForResponseFunc +} + +func New(uri string, dev Device, getAdditionalProperties GetAdditionalPropertiesForResponseFunc) *Resource { + if getAdditionalProperties == nil { + getAdditionalProperties = func() map[string]interface{} { return nil } + } + d := &Resource{ + device: dev, + getAdditionalProperties: getAdditionalProperties, + } + d.Resource = resources.NewResource(uri, d.Get, nil, dev.GetResourceTypes(), []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_R}) + return d +} + +func (d *Resource) Get(request *net.Request) (*pool.Message, error) { + additionalProperties := d.getAdditionalProperties() + deviceProperties := device.Device{ + ID: d.device.GetID().String(), + Name: d.device.GetName(), + ProtocolIndependentID: d.device.GetProtocolIndependentID().String(), + // DataModelVersion: "ocf.res.1.3.0", + // SpecificationVersion: "ocf.2.0.5", + } + if request.Interface() == interfaces.OC_IF_BASELINE { + deviceProperties.ResourceTypes = d.Resource.ResourceTypes + deviceProperties.Interfaces = d.Resource.ResourceInterfaces + } + properties := resources.MergeCBORStructs(additionalProperties, deviceProperties) + res := pool.NewMessage(request.Context()) + res.SetCode(codes.Content) + res.SetContentFormat(message.AppOcfCbor) + data, err := cbor.Encode(properties) + if err != nil { + return nil, err + } + res.SetBody(bytes.NewReader(data)) + return res, nil +} diff --git a/bridge/resources/discovery/resource.go b/bridge/resources/discovery/resource.go new file mode 100644 index 00000000..59f269af --- /dev/null +++ b/bridge/resources/discovery/resource.go @@ -0,0 +1,56 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package discovery + +import ( + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/device/v2/schema/interfaces" + plgdResources "github.com/plgd-dev/device/v2/schema/resources" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" +) + +type GetLinksHandler func(request *net.Request) schema.ResourceLinks + +type Resource struct { + *resources.Resource + getLinks GetLinksHandler +} + +func New(uri string, getLinks GetLinksHandler) *Resource { + d := &Resource{ + getLinks: getLinks, + } + d.Resource = resources.NewResource(uri, d.Get, nil, []string{plgdResources.ResourceType}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_R}) + return d +} + +func PatchLinks(links schema.ResourceLinks, deviceID string) schema.ResourceLinks { + for _, l := range links { + l.Anchor = "ocf://" + deviceID + } + return links +} + +func (d *Resource) Get(request *net.Request) (*pool.Message, error) { + links := PatchLinks(d.getLinks(request), request.DeviceID().String()) + return resources.CreateResponseContent(request.Context(), links, codes.Content) +} diff --git a/bridge/resources/etag.go b/bridge/resources/etag.go new file mode 100644 index 00000000..ffd3c13c --- /dev/null +++ b/bridge/resources/etag.go @@ -0,0 +1,52 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources + +import ( + "crypto/rand" + "encoding/binary" + "sync/atomic" + "time" +) + +var globalETag atomic.Uint64 + +func generateNextETag(currentETag uint64) uint64 { + buf := make([]byte, 4) + _, err := rand.Read(buf) + if err != nil { + return currentETag + uint64(time.Now().UnixNano()%1000) + } + return currentETag + uint64(binary.BigEndian.Uint32(buf)%1000) +} + +func GetETag() uint64 { + for { + now := uint64(time.Now().UnixNano()) + oldEtag := globalETag.Load() + etag := oldEtag + if now > etag { + etag = now + } + newEtag := generateNextETag(etag) + if globalETag.CompareAndSwap(oldEtag, newEtag) { + return newEtag + } + } +} diff --git a/bridge/resources/etag_test.go b/bridge/resources/etag_test.go new file mode 100644 index 00000000..7903aac3 --- /dev/null +++ b/bridge/resources/etag_test.go @@ -0,0 +1,60 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources_test + +import ( + "sync" + "testing" + + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/stretchr/testify/require" +) + +func TestGetETag(t *testing.T) { + var last uint64 + for i := 0; i < 1000; i++ { + etag := resources.GetETag() + require.Greater(t, etag, last) + last = etag + } +} + +func TestGetETagParallel(t *testing.T) { + const numRoutines = 1000 + + etagMap := make(map[uint64]struct{}) + var mutex sync.Mutex + + var wg sync.WaitGroup + wg.Add(numRoutines) + + for i := 0; i < numRoutines; i++ { + go func() { + defer wg.Done() + etag := resources.GetETag() + mutex.Lock() + etagMap[etag] = struct{}{} + mutex.Unlock() + }() + } + + wg.Wait() + + require.Len(t, etagMap, numRoutines) +} diff --git a/bridge/resources/maintenance/resource.go b/bridge/resources/maintenance/resource.go new file mode 100644 index 00000000..c4ded3e2 --- /dev/null +++ b/bridge/resources/maintenance/resource.go @@ -0,0 +1,69 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package maintenance + +import ( + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/device/v2/schema/maintenance" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" +) + +type OnFactoryReset func() + +type Resource struct { + *resources.Resource + onFactoryReset OnFactoryReset +} + +func (r *Resource) Get(request *net.Request) (*pool.Message, error) { + return resources.CreateResponseContent(request.Context(), maintenance.Maintenance{ + FactoryReset: false, + }, codes.Content) +} + +func (r *Resource) Post(request *net.Request) (*pool.Message, error) { + var upd maintenance.MaintenanceUpdateRequest + err := cbor.ReadFrom(request.Body(), &upd) + if err != nil { + return resources.CreateResponseBadRequest(request.Context(), err) + } + if upd.FactoryReset { + r.onFactoryReset() + } + return resources.CreateResponseContent(request.Context(), maintenance.Maintenance{ + FactoryReset: false, + }, codes.Changed) +} + +func New(uri string, onFactoryReset OnFactoryReset) *Resource { + r := &Resource{ + onFactoryReset: onFactoryReset, + } + r.Resource = resources.NewResource(uri, + r.Get, + r.Post, + []string{maintenance.ResourceType}, + []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}, + ) + return r +} diff --git a/bridge/resources/maintenance/resource_test.go b/bridge/resources/maintenance/resource_test.go new file mode 100644 index 00000000..82d94e66 --- /dev/null +++ b/bridge/resources/maintenance/resource_test.go @@ -0,0 +1,87 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package maintenance_test + +import ( + "bytes" + "context" + "testing" + + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources/maintenance" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + maintenanceSchema "github.com/plgd-dev/device/v2/schema/maintenance" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/stretchr/testify/require" +) + +func TestMaintenanceGet(t *testing.T) { + mnt := maintenance.New(maintenanceSchema.ResourceURI, func() {}) + require.NotNil(t, mnt) + + req := pool.NewMessage(context.Background()) + req.SetContentFormat(message.AppOcfCbor) + resp, err := mnt.Get(&net.Request{ + Message: req, + }) + require.NoError(t, err) + + var mntData maintenanceSchema.Maintenance + err = cbor.ReadFrom(resp.Body(), &mntData) + require.NoError(t, err) + require.False(t, mntData.FactoryReset) +} + +func TestMaintenancePost(t *testing.T) { + invoked := false + mnt := maintenance.New(maintenanceSchema.ResourceURI, func() { + invoked = true + }) + require.NotNil(t, mnt) + + reqInvalid := pool.NewMessage(context.Background()) + reqInvalid.SetContentFormat(message.TextPlain) + reqInvalid.SetBody(bytes.NewReader([]byte(""))) + resp, err := mnt.Post(&net.Request{ + Message: reqInvalid, + }) + require.NoError(t, err) + require.Equal(t, codes.BadRequest, resp.Code()) + require.False(t, invoked) + + d, err := cbor.Encode(maintenanceSchema.MaintenanceUpdateRequest{ + FactoryReset: true, + }) + require.NoError(t, err) + req := pool.NewMessage(context.Background()) + req.SetContentFormat(message.AppOcfCbor) + req.SetBody(bytes.NewReader(d)) + resp, err = mnt.Post(&net.Request{ + Message: req, + }) + require.NoError(t, err) + require.Equal(t, codes.Changed, resp.Code()) + require.True(t, invoked) + var mntData maintenanceSchema.Maintenance + err = cbor.ReadFrom(resp.Body(), &mntData) + require.NoError(t, err) + require.False(t, mntData.FactoryReset) +} diff --git a/bridge/resources/resource.go b/bridge/resources/resource.go new file mode 100644 index 00000000..eb5e4f80 --- /dev/null +++ b/bridge/resources/resource.go @@ -0,0 +1,319 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "hash/crc64" + "io" + "reflect" + + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/pkg/sync" + "go.uber.org/atomic" +) + +type GetHandlerFunc func(req *net.Request) (*pool.Message, error) + +type PostHandlerFunc func(req *net.Request) (*pool.Message, error) + +type CreateSubscriptionFunc func(req *net.Request, handler func(msg *pool.Message, err error)) (cancel func(), err error) + +const PublishToCloud schema.BitMask = 1 << 7 + +type subscription struct { + done <-chan struct{} + cancel func() +} + +type Resource struct { + Href string + ResourceTypes []string + ResourceInterfaces []string + PolicyBitMask schema.BitMask + getHandler GetHandlerFunc + postHandler PostHandlerFunc + createSubscription CreateSubscriptionFunc + closed atomic.Bool + wakeUpSubscription chan bool + createdSubscription *sync.Map[string, *subscription] + etag atomic.Uint64 +} + +func (r *Resource) GetPolicyBitMask() schema.BitMask { + return r.PolicyBitMask +} + +func (r *Resource) GetHref() string { + return r.Href +} + +func (r *Resource) GetResourceTypes() []string { + return r.ResourceTypes +} + +func (r *Resource) GetResourceInterfaces() []string { + return r.ResourceInterfaces +} + +func Unique(strSlice []string) []string { + allKeys := make(map[string]bool) + list := []string{} + for _, item := range strSlice { + if _, value := allKeys[item]; !value { + allKeys[item] = true + list = append(list, item) + } + } + return list +} + +func NewResource(href string, getHandler GetHandlerFunc, postHandler PostHandlerFunc, resourceTypes, resourceInterfaces []string) *Resource { + r := &Resource{ + Href: href, + ResourceTypes: Unique(resourceTypes), + ResourceInterfaces: Unique(resourceInterfaces), + PolicyBitMask: schema.Discoverable | PublishToCloud, + getHandler: getHandler, + postHandler: postHandler, + wakeUpSubscription: make(chan bool, 1), + createdSubscription: sync.NewMap[string, *subscription](), + } + r.etag.Store(GetETag()) + go r.watchSubscriptions() + return r +} + +func CreateResponseMethodNotAllowed(ctx context.Context, token message.Token) *pool.Message { + msg := pool.NewMessage(ctx) + msg.SetCode(codes.MethodNotAllowed) + msg.SetToken(token) + msg.SetContentFormat(message.TextPlain) + msg.SetBody(bytes.NewReader([]byte(fmt.Sprintf("unsupported method %v", codes.MethodNotAllowed)))) + return msg +} + +func CreateResponseContent(ctx context.Context, data interface{}, code codes.Code) (*pool.Message, error) { + if str, ok := data.(string); ok { + res := pool.NewMessage(ctx) + res.SetCode(code) + res.SetContentFormat(message.TextPlain) + res.SetBody(bytes.NewReader([]byte(str))) + return res, nil + } + d, err := cbor.Encode(data) + if err != nil { + return nil, err + } + res := pool.NewMessage(ctx) + res.SetCode(code) + res.SetContentFormat(message.AppOcfCbor) + res.SetBody(bytes.NewReader(d)) + return res, nil +} + +func CreateResponseBadRequest(ctx context.Context, err error) (*pool.Message, error) { + res := pool.NewMessage(ctx) + res.SetCode(codes.BadRequest) + res.SetContentFormat(message.TextPlain) + res.SetBody(bytes.NewReader([]byte(err.Error()))) + return res, nil +} + +func (r *Resource) SetObserveHandler(createSubscription CreateSubscriptionFunc) { + if createSubscription == nil { + r.createSubscription = nil + r.PolicyBitMask &^= schema.Observable + return + } + r.createSubscription = createSubscription + r.PolicyBitMask |= schema.Observable +} + +func (r *Resource) wakeWatchSubscriptions() { + select { + case r.wakeUpSubscription <- true: + default: + } +} + +// Close closes resource and cancel all subscriptions +func (r *Resource) Close() { + if !r.closed.CompareAndSwap(false, true) { + return + } + r.createdSubscription.Range(func(key string, value *subscription) bool { + value.cancel() + return true + }) + r.wakeUpSubscription <- false +} + +func (r *Resource) watchSubscriptions() { + for { + keys := make([]string, 0, r.createdSubscription.Length()) + cases := make([]reflect.SelectCase, 0, r.createdSubscription.Length()+1) + // wake up subscription + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(r.wakeUpSubscription)}) + r.createdSubscription.Range(func(key string, value *subscription) bool { + cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(value.done)}) + keys = append(keys, key) + return true + }) + idx, recv, ok := reflect.Select(cases) + if idx == 0 { + if ok && recv.Bool() { + // wake up subscription - added/ + continue + } + // resource closed + // cancel all subscriptions + r.createdSubscription.Range(func(key string, value *subscription) bool { + value.cancel() + return true + }) + return + } + r.removeSubscription(keys[idx-1]) + } +} + +func (r *Resource) removeSubscription(key string) { + sub, ok := r.createdSubscription.LoadAndDelete(key) + if ok { + sub.cancel() + r.wakeWatchSubscriptions() + } +} + +func (r *Resource) ETag() []byte { + etag := r.etag.Load() + if etag == 0 { + return nil + } + e := make([]byte, 8) + binary.BigEndian.PutUint64(e, etag) + return e +} + +func (r *Resource) UpdateETag() { + r.etag.Store(GetETag()) +} + +func calcCRC64(body io.ReadSeeker) uint64 { + if body == nil { + return 0 + } + h := crc64.New(crc64.MakeTable(crc64.ISO)) + _, _ = io.Copy(h, body) + _, _ = body.Seek(0, io.SeekStart) + return h.Sum64() +} + +func (r *Resource) observerHandler(req *net.Request, createSubscription bool) (*pool.Message, error) { + if !createSubscription { + r.removeSubscription(req.Conn.RemoteAddr().String()) + return r.getHandler(req) + } + req.Message.Hijack() + sequence := atomic.NewUint32(1) + var cancel func() + var err error + var deduplicationNotification atomic.Uint64 + cancel, err = r.createSubscription(req, func(resp *pool.Message, err error) { + if err == nil { + d := calcCRC64(resp.Body()) + if deduplicationNotification.Swap(d) == d { + return + } + resp.SetObserve(sequence.Inc()) + } else { + defer r.removeSubscription(req.Conn.RemoteAddr().String()) + resp, err = CreateResponseBadRequest(req.Conn.Context(), fmt.Errorf("error while observing %s: %w", r.Href, err)) + if err != nil { + return + } + } + resp.SetContext(req.Conn.Context()) + resp.SetToken(req.Token()) + etag := r.ETag() + if etag != nil { + _ = resp.SetETag(etag) + } + + err = req.Conn.WriteMessage(resp) + if err != nil { + r.removeSubscription(req.Conn.RemoteAddr().String()) + return + } + }) + if err != nil { + return CreateResponseBadRequest(req.Context(), err) + } + resp, err := r.getHandler(req) + if err != nil { + cancel() + return nil, err + } + // set deduplicationNotification value to current value of body + d := calcCRC64(resp.Body()) + deduplicationNotification.Store(d) + resp.SetToken(req.Token()) + resp.SetObserve(sequence.Inc()) + oldSub, oldLoaded := r.createdSubscription.Replace(req.Conn.RemoteAddr().String(), &subscription{ + done: req.Context().Done(), + cancel: cancel, + }) + if oldLoaded { + oldSub.cancel() + } + r.wakeWatchSubscriptions() + return resp, nil +} + +func (r *Resource) HandleRequest(req *net.Request) (*pool.Message, error) { + if req.Code() == codes.GET && r.getHandler != nil { + var resp *pool.Message + var err error + if obs, errObs := req.Observe(); errObs == nil && r.createSubscription != nil && r.PolicyBitMask&schema.Observable != 0 { + resp, err = r.observerHandler(req, obs == 0) + } else { + resp, err = r.getHandler(req) + } + if resp != nil && resp.Code() == codes.Content { + etag := r.ETag() + if etag != nil { + _ = resp.SetETag(etag) + } + } + return resp, err + } + if req.Code() == codes.POST && r.postHandler != nil { + return r.postHandler(req) + } + return CreateResponseMethodNotAllowed(req.Context(), req.Token()), nil +} diff --git a/bridge/resources/secure/credential/resource.go b/bridge/resources/secure/credential/resource.go new file mode 100644 index 00000000..bedd5745 --- /dev/null +++ b/bridge/resources/secure/credential/resource.go @@ -0,0 +1,44 @@ +/**************************************************************************** + * + * Copyright (c) 2023 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package credential + +import ( + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/schema/credential" + "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/go-coap/v3/message/pool" +) + +type Resource struct { + *resources.Resource +} + +type Manager interface { + Get(req *net.Request) (*pool.Message, error) + Post(req *net.Request) (*pool.Message, error) +} + +func New(uri string, m Manager) *Resource { + d := &Resource{} + d.Resource = resources.NewResource(uri, m.Get, m.Post, []string{credential.ResourceType}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + // don't publish credential resource to cloud + d.PolicyBitMask &= ^resources.PublishToCloud + return d +} diff --git a/bridge/resources/uuid.go b/bridge/resources/uuid.go new file mode 100644 index 00000000..91b87753 --- /dev/null +++ b/bridge/resources/uuid.go @@ -0,0 +1,29 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources + +import "github.com/google/uuid" + +func ToUUID(id string) uuid.UUID { + v, err := uuid.Parse(id) + if err != nil { + return uuid.NewSHA1(uuid.NameSpaceURL, []byte(id)) + } + return v +} diff --git a/bridge/resources/uuid_test.go b/bridge/resources/uuid_test.go new file mode 100644 index 00000000..0e9090bd --- /dev/null +++ b/bridge/resources/uuid_test.go @@ -0,0 +1,59 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package resources_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/stretchr/testify/require" +) + +func TestToUUID(t *testing.T) { + type args struct { + id string + } + tests := []struct { + name string + args args + want uuid.UUID + }{ + { + name: "valid", + args: args{ + id: "00000000-0000-0000-0000-000000000001", + }, + want: uuid.MustParse("00000000-0000-0000-0000-000000000001"), + }, + { + name: "invalid", + args: args{ + id: "00000000-0000-0000-0000-0000000000", + }, + want: uuid.NewSHA1(uuid.NameSpaceURL, []byte("00000000-0000-0000-0000-0000000000")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := resources.ToUUID(tt.args.id) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/bridge/service/config.go b/bridge/service/config.go new file mode 100644 index 00000000..6fd312c9 --- /dev/null +++ b/bridge/service/config.go @@ -0,0 +1,62 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "fmt" + + "github.com/plgd-dev/device/v2/bridge/net" +) + +type CoAPConfig struct { + ID string `yaml:"id"` + net.Config `yaml:",inline"` +} + +func (c *CoAPConfig) Validate() error { + if c.ID == "" { + return fmt.Errorf("id is required") + } + if err := c.Config.Validate(); err != nil { + return err + } + return nil +} + +type APIConfig struct { + CoAP CoAPConfig `yaml:"coap"` +} + +func (c *APIConfig) Validate() error { + if err := c.CoAP.Validate(); err != nil { + return fmt.Errorf("coap.%w", err) + } + return nil +} + +type Config struct { + API APIConfig `yaml:"apis"` +} + +func (c *Config) Validate() error { + if err := c.API.Validate(); err != nil { + return fmt.Errorf("api.%w", err) + } + return nil +} diff --git a/bridge/service/config_test.go b/bridge/service/config_test.go new file mode 100644 index 00000000..da75ac4f --- /dev/null +++ b/bridge/service/config_test.go @@ -0,0 +1,119 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service_test + +import ( + "testing" + + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/service" + "github.com/stretchr/testify/require" +) + +func TestCoAPConfigValidate(t *testing.T) { + tests := []struct { + name string + coapConfig service.CoAPConfig + wantErr bool + }{ + { + name: "ValidCoAPConfig", + coapConfig: service.CoAPConfig{ID: "test", Config: net.Config{ExternalAddresses: []string{"localhost:12345"}}}, + }, + { + name: "MissingID", + coapConfig: service.CoAPConfig{Config: net.Config{ExternalAddresses: []string{"localhost:12345"}}}, + wantErr: true, + }, + { + name: "InvalidConfig", + coapConfig: service.CoAPConfig{ID: "test", Config: net.Config{}}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.coapConfig.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestAPIConfigValidate(t *testing.T) { + tests := []struct { + name string + apiConfig service.APIConfig + wantErr bool + }{ + { + name: "ValidAPIConfig", + apiConfig: service.APIConfig{CoAP: service.CoAPConfig{ID: "test", Config: net.Config{ExternalAddresses: []string{"localhost:12345"}}}}, + }, + { + name: "InvalidCoAPConfig", + apiConfig: service.APIConfig{CoAP: service.CoAPConfig{Config: net.Config{ExternalAddresses: []string{"localhost:12345"}}}}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.apiConfig.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config service.Config + wantErr bool + }{ + { + name: "ValidConfig", + config: service.Config{API: service.APIConfig{CoAP: service.CoAPConfig{ID: "test", Config: net.Config{ExternalAddresses: []string{"localhost:12345"}}}}}, + }, + { + name: "InvalidAPIConfig", + config: service.Config{API: service.APIConfig{CoAP: service.CoAPConfig{Config: net.Config{ExternalAddresses: []string{"localhost:12345"}}}}}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/bridge/service/options.go b/bridge/service/options.go new file mode 100644 index 00000000..4006d403 --- /dev/null +++ b/bridge/service/options.go @@ -0,0 +1,45 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/pkg/log" +) + +type OptionsCfg struct { + onDiscoveryDevices func(req *net.Request) + logger log.Logger +} + +func WithOnDiscoveryDevices(f func(req *net.Request)) Option { + return func(o *OptionsCfg) { + if f != nil { + o.onDiscoveryDevices = f + } + } +} + +func WithLogger(logger log.Logger) Option { + return func(o *OptionsCfg) { + o.logger = logger + } +} + +type Option func(*OptionsCfg) diff --git a/bridge/service/service.go b/bridge/service/service.go new file mode 100644 index 00000000..f5eb2a2a --- /dev/null +++ b/bridge/service/service.go @@ -0,0 +1,253 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "fmt" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/bridge/resources/discovery" + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/plgd-dev/device/v2/schema" + plgdResources "github.com/plgd-dev/device/v2/schema/resources" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" +) + +type Device interface { + Init() // start all goroutines for device + Close() // stop all goroutines for device + + ExportConfig() device.Config // export device config + + GetID() uuid.UUID + GetLinks(request *net.Request) (links schema.ResourceLinks) + GetLinksFilteredBy(endpoints schema.Endpoints, deviceIDfilter uuid.UUID, resourceTypesFilter []string, policyBitMaskFitler schema.BitMask) (links schema.ResourceLinks) + GetName() string + GetProtocolIndependentID() uuid.UUID + GetResourceTypes() []string + + HandleRequest(req *net.Request) (*pool.Message, error) + + Range(f func(key string, resource device.Resource) bool) + AddResource(resource device.Resource) + LoadAndDeleteResource(resourceHref string) (device.Resource, bool) + CloseAndDeleteResource(resourceHref string) bool + GetResource(resourceHref string) (device.Resource, bool) + + GetCloudManager() *cloud.Manager +} + +type Service struct { + cfg Config + net *net.Net + devices *coapSync.Map[uuid.UUID, Device] + done chan struct{} + onDiscoveryDevices func(req *net.Request) +} + +func (c *Service) LoadDevice(di uuid.UUID) (Device, error) { + d, ok := c.devices.Load(di) + if !ok { + return d, fmt.Errorf("invalid queries: device with di %v not found", di) + } + return d, nil +} + +func (c *Service) handleDiscoverAllLinks(req *net.Request) (*pool.Message, error) { + if req.Message.Type() != message.Acknowledgement && req.Message.Type() != message.Reset { + // discovery is only allowed for CON, NON, UNSET messages + c.onDiscoveryDevices(req) + } + res := discovery.New(plgdResources.ResourceURI, func(request *net.Request) schema.ResourceLinks { + links := make(schema.ResourceLinks, 0, c.devices.Length()+1) + for _, d := range c.devices.CopyData() { + dlinks := d.GetLinks(req) + if len(dlinks) > 0 { + links = append(links, dlinks...) + } + } + return links + }) + defer res.Close() + return res.Get(req) +} + +func (c *Service) DefaultRequestHandler(req *net.Request) (*pool.Message, error) { + uriPath := req.URIPath() + if uriPath == "" { + return nil, nil //nolint:nilnil + } + if req.Code() == codes.GET && uriPath == "/.well-known/core" { + // ignore well-known/core + return nil, nil //nolint:nilnil + } + deviceID := req.DeviceID() + if req.Code() == codes.GET && uriPath == plgdResources.ResourceURI && deviceID == uuid.Nil { + return c.handleDiscoverAllLinks(req) + } + if deviceID == uuid.Nil { + return nil, fmt.Errorf("invalid queries: di query is not set") + } + d, err := c.LoadDevice(deviceID) // check if device exists et + if err != nil { + if req.ControlMessage().Dst.IsMulticast() { + return nil, nil //nolint:nilnil + } + return nil, err + } + return d.HandleRequest(req) +} + +func New(cfg Config, opts ...Option) (*Service, error) { + err := cfg.Validate() + if err != nil { + return nil, err + } + + o := OptionsCfg{ + onDiscoveryDevices: func(req *net.Request) { + // nothing to do + }, + logger: log.NewNilLogger(), + } + for _, opt := range opts { + opt(&o) + } + + c := Service{ + cfg: cfg, + devices: coapSync.NewMap[uuid.UUID, Device](), + onDiscoveryDevices: o.onDiscoveryDevices, + done: make(chan struct{}), + } + n, err := net.New(cfg.API.CoAP.Config, c.DefaultRequestHandler, o.logger) + if err != nil { + return nil, err + } + c.net = n + + return &c, nil +} + +func (c *Service) Serve() error { + defer close(c.done) + return c.net.Serve() +} + +func (c *Service) Shutdown() error { + err := c.net.Close() + if err != nil { + return err + } + <-c.done + devices := c.devices.LoadAndDeleteAll() + for _, d := range devices { + d.Close() + } + return nil +} + +type NewDeviceFunc func(id uuid.UUID, piid uuid.UUID) (Device, error) + +func (c *Service) CreateDevice(id uuid.UUID, newDevice NewDeviceFunc) (Device, error) { + var d Device + var err error + _, oldLoaded := c.devices.ReplaceWithFunc(id, func(oldValue Device, oldLoaded bool) (newValue Device, doDelete bool) { + if oldLoaded { + return oldValue, false + } + d, err = newDevice(id, resources.ToUUID(c.cfg.API.CoAP.ID)) + if err != nil { + if oldLoaded { + return oldValue, false + } + return nil, true + } + return d, false + }) + if err != nil { + return nil, err + } + if oldLoaded { + return nil, fmt.Errorf("device with id %v already exists", id) + } + return d, nil +} + +func (c *Service) GetOrCreateDevice(id uuid.UUID, newDevice NewDeviceFunc) (d Device, loaded bool, err error) { + oldDevice, oldLoaded := c.devices.ReplaceWithFunc(id, func(oldValue Device, oldLoaded bool) (newValue Device, doDelete bool) { + if oldLoaded { + return oldValue, false + } + d, err = newDevice(id, resources.ToUUID(c.cfg.API.CoAP.ID)) + if err != nil { + if oldLoaded { + return oldValue, false + } + return nil, true + } + return d, false + }) + if err != nil { + return nil, false, err + } + if oldLoaded { + return oldDevice, true, nil + } + return d, false, nil +} + +func (c *Service) GetDevice(id uuid.UUID) (Device, bool) { + return c.devices.Load(id) +} + +func (c *Service) CopyDevices() map[uuid.UUID]Device { + return c.devices.CopyData() +} + +func (c *Service) Range(f func(key uuid.UUID, value Device) bool) { + c.devices.Range(f) +} + +func (c *Service) RangeWithLock(f func(key uuid.UUID, value Device) bool) { + c.devices.Range2(f) +} + +func (c *Service) Length() int { + return c.devices.Length() +} + +func (c *Service) GetAndDeleteDevice(id uuid.UUID) (Device, bool) { + return c.devices.LoadAndDelete(id) +} + +func (c *Service) DeleteAndCloseDevice(id uuid.UUID) bool { + d, ok := c.devices.LoadAndDelete(id) + if ok { + d.Close() + } + return ok +} diff --git a/bridge/test/test.go b/bridge/test/test.go new file mode 100644 index 00000000..8d57f901 --- /dev/null +++ b/bridge/test/test.go @@ -0,0 +1,111 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package test + +import ( + "crypto/tls" + "crypto/x509" + "testing" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/service" + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/plgd-dev/device/v2/test" + "github.com/stretchr/testify/require" +) + +const ( + BRIDGE_SERVICE_PIID = "f47ac10b-58cc-4372-a567-0e02b2c3d479" + BRIDGE_DEVICE_HOST = "127.0.0.1:0" // 0 means random port + BRIDGE_DEVICE_HOST_IPv6 = "[::1]:0" // 0 means random port +) + +func MakeConfig(t *testing.T) service.Config { + var cfg service.Config + cfg.API.CoAP.ID = BRIDGE_SERVICE_PIID + cfg.API.CoAP.Config.ExternalAddresses = []string{BRIDGE_DEVICE_HOST, BRIDGE_DEVICE_HOST_IPv6} + require.NoError(t, cfg.API.Validate()) + return cfg +} + +func NewBridgeService(t *testing.T, opts ...service.Option) *service.Service { + opts = append([]service.Option{service.WithLogger(log.NewStdLogger(log.LevelDebug))}, opts...) + s, err := service.New(MakeConfig(t), opts...) + require.NoError(t, err) + return s +} + +func RunBridgeService(s *service.Service) func() error { + cleanup := func() error { + return s.Shutdown() + } + go func() { + _ = s.Serve() + }() + return cleanup +} + +func NewBridgedDeviceWithConfig(t *testing.T, s *service.Service, cfg device.Config, opts ...device.Option) service.Device { + newDevice := func(di uuid.UUID, piid uuid.UUID) (service.Device, error) { + cfg.ID = di + cfg.ProtocolIndependentID = piid + require.NoError(t, cfg.Validate()) + caPool := cloud.MakeCAPool(func() []*x509.Certificate { + return test.GetRootCA(t) + }, false) + deviceOpts := []device.Option{ + device.WithCAPool(caPool), + device.WithGetCertificates(func(deviceID string) []tls.Certificate { + return []tls.Certificate{test.GetMfgCertificate(t)} + }), + device.WithLogger(device.NewLogger(di, log.LevelDebug)), + } + // allow to override default options + deviceOpts = append(deviceOpts, opts...) + return device.New(cfg, deviceOpts...) + } + d, err := s.CreateDevice(cfg.ID, newDevice) + require.NoError(t, err) + d.Init() + return d +} + +func makeDeviceConfig(id uuid.UUID, cloudEnabled bool, credentialEnabled bool) device.Config { + cfg := device.Config{ + ID: id, + Name: "bridged-device", + ResourceTypes: []string{"oic.d.virtual"}, + MaxMessageSize: 1024 * 256, + } + cfg.Cloud.Enabled = cloudEnabled + if cloudEnabled { + cfg.Cloud.CloudID = test.CloudSID() + } + cfg.Credential.Enabled = credentialEnabled + return cfg +} + +func NewBridgedDevice(t *testing.T, s *service.Service, id string, cloudEnabled bool, credentialEnabled bool, opts ...device.Option) service.Device { + u, err := uuid.Parse(id) + require.NoError(t, err) + cfg := makeDeviceConfig(u, cloudEnabled, credentialEnabled) + return NewBridgedDeviceWithConfig(t, s, cfg, opts...) +} diff --git a/bridge/updateResource_test.go b/bridge/updateResource_test.go new file mode 100644 index 00000000..d0f92947 --- /dev/null +++ b/bridge/updateResource_test.go @@ -0,0 +1,128 @@ +package bridge_test + +import ( + "bytes" + "context" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + bridgeTest "github.com/plgd-dev/device/v2/bridge/test" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" + "github.com/plgd-dev/device/v2/pkg/net/coap" + "github.com/plgd-dev/device/v2/schema/interfaces" + testClient "github.com/plgd-dev/device/v2/test/client" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/stretchr/testify/require" +) + +type resourceData struct { + Name string `json:"name,omitempty"` +} + +type resourceDataSync struct { + resourceData + lock sync.Mutex +} + +func (r *resourceDataSync) setName(name string) { + r.lock.Lock() + defer r.lock.Unlock() + r.Name = name +} + +func (r *resourceDataSync) getName() string { + r.lock.Lock() + defer r.lock.Unlock() + return r.Name +} + +func (r *resourceDataSync) copy() resourceData { + r.lock.Lock() + defer r.lock.Unlock() + return resourceData{ + Name: r.Name, + } +} + +func TestUpdateResource(t *testing.T) { + s := bridgeTest.NewBridgeService(t) + t.Cleanup(func() { + _ = s.Shutdown() + }) + d := bridgeTest.NewBridgedDevice(t, s, uuid.New().String(), false, false) + defer func() { + s.DeleteAndCloseDevice(d.GetID()) + }() + + rds := resourceDataSync{ + resourceData: resourceData{ + Name: "test", + }, + } + + resHandler := func(req *net.Request) (*pool.Message, error) { + resp := pool.NewMessage(req.Context()) + data, err := cbor.Encode(rds.copy()) + if err != nil { + return nil, err + } + resp.SetCode(codes.Changed) + resp.SetContentFormat(message.AppOcfCbor) + resp.SetBody(bytes.NewReader(data)) + return resp, nil + } + + res := resources.NewResource("/test", resHandler, func(req *net.Request) (*pool.Message, error) { + codec := codecOcf.VNDOCFCBORCodec{} + var newData resourceData + err := codec.Decode(req.Message, &newData) + if err != nil { + return nil, err + } + rds.setName(newData.Name) + return resHandler(req) + }, []string{"oic.d.virtual", "oic.d.test"}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + d.AddResource(res) + + cleanup := bridgeTest.RunBridgeService(s) + defer func() { + errC := cleanup() + require.NoError(t, errC) + }() + + c, err := testClient.NewTestSecureClientWithBridgeSupport() + require.NoError(t, err) + defer func() { + errC := c.Close(context.Background()) + require.NoError(t, errC) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*8) + defer cancel() + var got coap.DetailedResponse[interface{}] + err = c.UpdateResource(ctx, d.GetID().String(), "/test", map[string]interface{}{ + "name": "updated", + }, &got) + require.NoError(t, err) + require.Equal(t, codes.Changed, got.Code) + require.Equal(t, "updated", rds.getName()) + + // fail - invalid data + err = c.UpdateResource(ctx, d.GetID().String(), "/test", map[string]interface{}{ + "name": 1, + }, &got) + require.Error(t, err) + + // fail - invalid href + err = c.UpdateResource(ctx, d.GetID().String(), "/invalid", map[string]interface{}{ + "name": "updated", + }, &got) + require.Error(t, err) +} diff --git a/client/app/app.go b/client/app/app.go index afc71fea..f5e32537 100644 --- a/client/app/app.go +++ b/client/app/app.go @@ -21,7 +21,7 @@ import ( "crypto/x509" "fmt" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) type AppConfig struct { @@ -46,13 +46,13 @@ func NewApp(cfg *AppConfig) (*App, error) { var manufacturerCert tls.Certificate var err error if len(cfg.RootCA) != 0 { - rootCA, err = security.ParseX509FromPEM([]byte(cfg.RootCA)) + rootCA, err = pkgX509.ParsePemCertificates([]byte(cfg.RootCA)) if err != nil { return nil, fmt.Errorf("invalid Root CA: %w", err) } } if cfg.Manufacturer != nil && len(cfg.Manufacturer.CA) != 0 { - manufacturerCA, err = security.ParseX509FromPEM([]byte(cfg.Manufacturer.CA)) + manufacturerCA, err = pkgX509.ParsePemCertificates([]byte(cfg.Manufacturer.CA)) if err != nil { return nil, fmt.Errorf("invalid Manufacturer's CA: %w", err) } diff --git a/client/client.go b/client/client.go index df3a2621..c89c7b15 100644 --- a/client/client.go +++ b/client/client.go @@ -27,6 +27,7 @@ import ( "github.com/pion/dtls/v2" "github.com/plgd-dev/device/v2/client/core" "github.com/plgd-dev/device/v2/client/core/otm" + "github.com/plgd-dev/device/v2/pkg/log" "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/go-coap/v3/net/blockwise" "github.com/plgd-dev/go-coap/v3/options" @@ -58,6 +59,8 @@ type Config struct { DisablePeerTCPSignalMessageCSMs bool DefaultTransferDurationSeconds uint64 // 0 means 15 seconds + UseDeviceIDInQuery bool // if true, deviceID is used also in query. Set this option if you use bridged devices. + // specify one of: DeviceOwnershipSDK *DeviceOwnershipSDKConfig `yaml:",omitempty"` DeviceOwnershipBackend *DeviceOwnershipBackendConfig `yaml:",omitempty"` @@ -79,7 +82,7 @@ func NewClientFromConfig(cfg *Config, app ApplicationCallback, logger core.Logge udpDialOpts := make([]udp.Option, 0, 5) if logger == nil { - logger = core.NewNilLogger() + logger = log.NewNilLogger() } errFn := func(err error) { @@ -152,6 +155,7 @@ func NewClientFromConfig(cfg *Config, app ApplicationCallback, logger core.Logge FailureThreshold: cfg.ObserverFailureThreshold, }), WithCacheExpiration(cacheExpiration), + WithUseDeviceIDInQuery(cfg.UseDeviceIDInQuery), } deviceOwner, err := NewDeviceOwnerFromConfig(cfg, dialTLS, dialDTLS, app) @@ -175,6 +179,8 @@ type ClientConfig struct { CacheExpiration time.Duration // Observer is a configuration of the devices observation. Observer ObserverConfig + // UseDeviceIDInQuery if true, deviceID is used also in query. Set this option if you use bridged devices. + UseDeviceIDInQuery bool } type ClientOptionFunc func(ClientConfig) ClientConfig @@ -254,6 +260,14 @@ func WithDialUDP(dial core.DialUDP) ClientOptionFunc { } } +// WithUseDeviceIDInQuery sets the observer config. +func WithUseDeviceIDInQuery(useDeviceIDInQuery bool) ClientOptionFunc { + return func(cfg ClientConfig) ClientConfig { + cfg.UseDeviceIDInQuery = useDeviceIDInQuery + return cfg + } +} + // NewClient constructs a new local client. func NewClient( app ApplicationCallback, @@ -283,7 +297,7 @@ func NewClient( } if coreCfg.Logger == nil { - coreCfg.Logger = core.NewNilLogger() + coreCfg.Logger = log.NewNilLogger() } tls := core.TLSConfig{ GetCertificate: deviceOwner.GetIdentityCertificate, @@ -310,6 +324,7 @@ func NewClient( subscriptions: make(map[string]subscription), observerConfig: clientCfg.Observer, logger: coreCfg.Logger, + useDeviceIDInQuery: clientCfg.UseDeviceIDInQuery, } return &client, nil } @@ -341,6 +356,8 @@ type Client struct { disableUDPEndpoints bool logger core.Logger + + useDeviceIDInQuery bool } func (c *Client) popSubscriptions() map[string]subscription { diff --git a/client/client_test.go b/client/client_test.go index 515e4d0b..9f7a4dd7 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -18,9 +18,6 @@ package client_test import ( "context" - "crypto/tls" - "crypto/x509" - "encoding/pem" "fmt" "log" "testing" @@ -30,17 +27,15 @@ import ( "github.com/plgd-dev/device/v2/client/core" codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" "github.com/plgd-dev/device/v2/pkg/net/coap" - "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/schema/resources" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/stretchr/testify/require" ) -const TestTimeout = time.Second * 8 - var ( ETagSupported = false ETagBatchSupported = false @@ -87,14 +82,14 @@ func init() { } deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() panicIfErr(err) defer func() { errC := c.Close(context.Background()) panicIfErr(errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) @@ -133,101 +128,9 @@ func init() { } } -type testSetupSecureClient struct { - ca []*x509.Certificate - mfgCA []*x509.Certificate - mfgCert tls.Certificate -} - -func (c *testSetupSecureClient) GetManufacturerCertificate() (tls.Certificate, error) { - if c.mfgCert.PrivateKey == nil { - return c.mfgCert, fmt.Errorf("private key not set") - } - return c.mfgCert, nil -} - -func (c *testSetupSecureClient) GetManufacturerCertificateAuthorities() ([]*x509.Certificate, error) { - if len(c.mfgCA) == 0 { - return nil, fmt.Errorf("certificate authority not set") - } - return c.mfgCA, nil -} - -func (c *testSetupSecureClient) GetRootCertificateAuthorities() ([]*x509.Certificate, error) { - if len(c.ca) == 0 { - return nil, fmt.Errorf("certificate authorities not set") - } - return c.ca, nil -} - -func NewTestSecureClient() (*client.Client, error) { - return newTestSecureClient(test.IdentityIntermediateCA, test.IdentityIntermediateCAKey) -} - -func NewTestSecureClientWithGeneratedCertificate() (*client.Client, error) { - var cfgCA generateCertificate.Configuration - cfgCA.Subject.CommonName = "anotherClient" - cfgCA.ValidFrom = "now" - cfgCA.ValidFor = time.Hour - - priv, err := cfgCA.GenerateKey() - if err != nil { - return nil, fmt.Errorf("cannot generate private key: %w", err) - } - cert, err := generateCertificate.GenerateRootCA(cfgCA, priv) - if err != nil { - return nil, fmt.Errorf("cannot generate root ca: %w", err) - } - derKey, err := x509.MarshalECPrivateKey(priv) - if err != nil { - return nil, fmt.Errorf("cannot marhsal private key: %w", err) - } - key := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: derKey}) - return newTestSecureClient(cert, key) -} - -func newTestSecureClient(signerCert, signerKey []byte) (*client.Client, error) { - cfg := client.Config{ - DeviceOwnershipSDK: &client.DeviceOwnershipSDKConfig{ - ID: CertIdentity, - Cert: string(signerCert), - CertKey: string(signerKey), - CreateSignerFunc: test.NewIdentityCertificateSigner, - }, - } - mfgTrustedCABlock, _ := pem.Decode(test.RootCACrt) - if mfgTrustedCABlock == nil { - return nil, fmt.Errorf("mfgTrustedCABlock is empty") - } - mfgCA, err := x509.ParseCertificates(mfgTrustedCABlock.Bytes) - if err != nil { - return nil, err - } - mfgCert, err := tls.X509KeyPair(test.MfgCert, test.MfgKey) - if err != nil { - return nil, fmt.Errorf("cannot X509KeyPair: %w", err) - } - - client, err := client.NewClientFromConfig(&cfg, &testSetupSecureClient{ - mfgCA: mfgCA, - mfgCert: mfgCert, - }, core.NewNilLogger(), - ) - if err != nil { - return nil, err - } - err = client.Initialization(context.Background()) - if err != nil { - return nil, err - } - return client, nil -} - func disown(t *testing.T, c *client.Client, deviceID string) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() err := c.DisownDevice(ctx, deviceID) require.NoError(t, err) } - -var CertIdentity = "00000000-0000-0000-0000-000000000001" diff --git a/client/core/client.go b/client/core/client.go index ff5615c5..e4d6a863 100644 --- a/client/core/client.go +++ b/client/core/client.go @@ -25,6 +25,7 @@ import ( "github.com/pion/dtls/v2" "github.com/pion/logging" pkgError "github.com/plgd-dev/device/v2/pkg/error" + "github.com/plgd-dev/device/v2/pkg/log" "github.com/plgd-dev/device/v2/pkg/net/coap" coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/tcp" @@ -172,7 +173,7 @@ func NewClient(opts ...OptionFunc) *Client { cfg = o(cfg) } if cfg.Logger == nil { - cfg.Logger = NewNilLogger() + cfg.Logger = log.NewNilLogger() } cfg.TLSConfig = checkTLSConfig(cfg.TLSConfig) diff --git a/client/core/client_test.go b/client/core/client_test.go index 9dd0c087..fa5d6185 100644 --- a/client/core/client_test.go +++ b/client/core/client_test.go @@ -30,11 +30,11 @@ import ( "github.com/plgd-dev/device/v2/client/core/otm/manufacturer" pkgError "github.com/plgd-dev/device/v2/pkg/error" "github.com/plgd-dev/device/v2/pkg/net/coap" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/test" "github.com/plgd-dev/go-coap/v3/tcp" "github.com/plgd-dev/go-coap/v3/udp" - "github.com/plgd-dev/kit/v2/security" "github.com/stretchr/testify/require" ) @@ -64,12 +64,12 @@ func NewTestSecureClientWithCert(cert tls.Certificate, disableDTLS, disableTCPTL return nil, err } - mfgCa, err := security.ParseX509FromPEM(test.RootCACrt) + mfgCa, err := pkgX509.ParsePemCertificates(test.RootCACrt) if err != nil { return nil, err } - identityIntermediateCA, err := security.ParseX509FromPEM(test.IdentityIntermediateCA) + identityIntermediateCA, err := pkgX509.ParsePemCertificates(test.IdentityIntermediateCA) if err != nil { return nil, err } diff --git a/client/core/disownDevice.go b/client/core/disownDevice.go index 4180db0e..b2dc7745 100644 --- a/client/core/disownDevice.go +++ b/client/core/disownDevice.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/google/uuid" + "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/pstat" ) @@ -37,6 +38,7 @@ func connectionWasClosed(ctx context.Context, err error) bool { func (d *Device) Disown( ctx context.Context, links schema.ResourceLinks, + options ...coap.OptionFunc, ) error { cannotDisownErr := func(err error) error { return fmt.Errorf("cannot disown: %w", err) @@ -71,7 +73,7 @@ func (d *Device) Disown( } link.Endpoints = link.Endpoints.FilterSecureEndpoints() - err = d.UpdateResource(ctx, link, setResetProvisionState, nil) + err = d.UpdateResource(ctx, link, setResetProvisionState, nil, options...) if err != nil { if connectionWasClosed(ctx, err) { // connection was closed by disown so we don't report error just log it. diff --git a/client/core/getDevice.go b/client/core/getDevice.go index dc7b1c7a..076a3d75 100644 --- a/client/core/getDevice.go +++ b/client/core/getDevice.go @@ -19,9 +19,11 @@ package core import ( "context" "fmt" + "net" "strings" "sync" + "github.com/hashicorp/go-multierror" "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/device" @@ -30,12 +32,58 @@ import ( // GetDeviceByIP gets the device directly via IP address and multicast listen port 5683. func (c *Client) GetDeviceByIP(ctx context.Context, ip string) (*Device, error) { + devices, err := c.GetDevicesByIP(ctx, ip) + if err != nil { + return nil, err + } + if len(devices) == 0 { + return nil, MakeNotFound(fmt.Errorf("no response from the device with ip %s", ip)) + } + for _, d := range devices { + err = d.Close(ctx) + if err != nil { + c.logger.Debugf("get device by ip error: %s", err.Error()) + } + } + return devices[0], nil +} + +func getAddress(ip string) (addr string, isIpv4 bool, _ error) { + if strings.Contains(ip, ".") { + host, port, err := net.SplitHostPort(ip) + if err != nil { + host, port, err = net.SplitHostPort(ip + ":5683") + } + if err != nil { + return "", false, err + } + return host + ":" + port, true, nil + } + if !strings.Contains(ip, "[") { + ip = "[" + ip + "]:5683" + } + host, port, err := net.SplitHostPort(ip) + if err != nil { + host, port, err = net.SplitHostPort(ip + ":5683") + } + if err != nil { + return "", false, err + } + return "[" + host + "]:" + port, false, nil +} + +// GetDevicesByIP gets the devices directly via IP address and multicast listen port 5683. +func (c *Client) GetDevicesByIP(ctx context.Context, ip string) ([]*Device, error) { var discoveryConfiguration DiscoveryConfiguration - if strings.Contains(ip, ":") && !strings.Contains(ip, "[") { - ip = "[" + ip + "]" - discoveryConfiguration.MulticastAddressUDP6 = []string{ip + ":5683"} + + addr, isIPv4, err := getAddress(ip) + if err != nil { + return nil, MakeInvalidArgument(fmt.Errorf("could not get the device via ip %s: %w", ip, err)) + } + if isIPv4 { + discoveryConfiguration.MulticastAddressUDP4 = []string{addr} } else { - discoveryConfiguration.MulticastAddressUDP4 = []string{ip + ":5683"} + discoveryConfiguration.MulticastAddressUDP6 = []string{addr} } findCtx, cancel := context.WithCancel(ctx) @@ -53,18 +101,20 @@ func (c *Client) GetDeviceByIP(ctx context.Context, ip string) (*Device, error) } }() - h := newDeviceHandler(c.getDeviceConfiguration(), ANY_DEVICE, cancel) + h := newDevicesHandler(c.getDeviceConfiguration(), ANY_DEVICE, cancel) // we want to just get "oic.wk.d" resource, because links will be get via unicast to /oic/res err = DiscoverDevices(findCtx, multicastConn, h, coap.WithResourceType(device.ResourceType)) if err != nil { - return nil, MakeDataLoss(fmt.Errorf("could not get the device from ip %s: %w", ip, err)) + return nil, MakeDataLoss(fmt.Errorf("could not get the devices from ip %s: %w", ip, err)) } - d := h.Device() - if d == nil { - return nil, MakeNotFound(fmt.Errorf("no response from the device with ip %s", ip)) + devices := h.Devices() + if len(devices) == 0 { + return nil, MakeNotFound(fmt.Errorf("no response from the devices with ip %s", ip)) + } + for _, d := range devices { + d.setFoundByIP(addr) } - d.setFoundByIP(ip) - return d, nil + return devices, nil } // GetDeviceByMulticast performs a multicast and returns a device object if the device responds. @@ -84,14 +134,14 @@ func (c *Client) GetDeviceByMulticast(ctx context.Context, deviceID string, disc } }() - h := newDeviceHandler(c.getDeviceConfiguration(), deviceID, cancel) + h := newDevicesHandler(c.getDeviceConfiguration(), deviceID, cancel) // we want to just get "oic.wk.d" resource, because links will be get via unicast to /oic/res err = DiscoverDevices(findCtx, multicastConn, h, coap.WithResourceType(device.ResourceType), coap.WithDeviceID(deviceID)) if err != nil { return nil, MakeDataLoss(fmt.Errorf("could not get the device %s: %w", deviceID, err)) } - d := h.Device() - if d == nil { + d := h.Devices() + if len(d) == 0 { err = h.Err() if err != nil { return nil, MakeInternal(fmt.Errorf("no response from the device %s: %w", deviceID, err)) @@ -99,12 +149,12 @@ func (c *Client) GetDeviceByMulticast(ctx context.Context, deviceID string, disc return nil, MakeInternal(fmt.Errorf("no response from the device %s", deviceID)) } - return d, nil + return d[0], nil } const ANY_DEVICE = "anydevice" -func newDeviceHandler( +func newDevicesHandler( deviceCfg DeviceConfiguration, deviceID string, cancel context.CancelFunc, @@ -121,15 +171,17 @@ type deviceHandler struct { deviceID string cancel context.CancelFunc - lock sync.Mutex - device *Device - err error + lock sync.Mutex + devices []*Device + err error } -func (h *deviceHandler) Device() *Device { +func (h *deviceHandler) Devices() []*Device { h.lock.Lock() defer h.lock.Unlock() - return h.device + devices := h.devices + h.devices = nil + return devices } func (h *deviceHandler) Handle(_ context.Context, conn *client.Conn, links schema.ResourceLinks) { @@ -138,29 +190,34 @@ func (h *deviceHandler) Handle(_ context.Context, conn *client.Conn, links schem } h.lock.Lock() defer h.lock.Unlock() - - link, err := GetResourceLink(links, device.ResourceURI) - if err != nil { - h.err = err + links = links.GetResourceLinks(device.ResourceType) + if len(links) == 0 { + h.err = MakeUnavailable(fmt.Errorf("cannot get %v resourceType for device: not found", device.ResourceType)) return } - deviceID := link.GetDeviceID() - if deviceID == "" { - h.err = MakeInternal(fmt.Errorf("cannot determine deviceID")) + if len(h.devices) > 0 { return } - - if h.device != nil || (deviceID != h.deviceID && h.deviceID != ANY_DEVICE) { - return + var errs *multierror.Error + for _, link := range links { + deviceID := link.GetDeviceID() + if deviceID == "" { + errs = multierror.Append(errs, MakeUnavailable(fmt.Errorf("cannot determine deviceID"))) + continue + } + if deviceID != h.deviceID && h.deviceID != ANY_DEVICE { + continue + } + if len(link.ResourceTypes) == 0 { + errs = multierror.Append(errs, MakeUnavailable(fmt.Errorf("cannot get resource types for %v: is empty", deviceID))) + continue + } + h.devices = append(h.devices, NewDevice(h.deviceCfg, deviceID, link.ResourceTypes, link.GetEndpoints)) } - if len(link.ResourceTypes) == 0 { - h.err = MakeDataLoss(fmt.Errorf("cannot get resource types for %v: is empty", deviceID)) - return + h.err = errs.ErrorOrNil() + if len(h.devices) > 0 { + h.cancel() } - d := NewDevice(h.deviceCfg, deviceID, link.ResourceTypes, link.GetEndpoints) - - h.device = d - h.cancel() } func (h *deviceHandler) Error(err error) { diff --git a/client/core/getDevices.go b/client/core/getDevices.go index 257302ea..e4ef7c4e 100644 --- a/client/core/getDevices.go +++ b/client/core/getDevices.go @@ -24,6 +24,7 @@ import ( "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/device" + "github.com/plgd-dev/device/v2/schema/resources" "github.com/plgd-dev/go-coap/v3/udp/client" ) @@ -81,26 +82,35 @@ func (h *discoveryHandler) Handle(ctx context.Context, conn *client.Conn, links if errC := conn.Close(); errC != nil { h.handler.Error(fmt.Errorf("discovery handler cannot close connection: %w", errC)) } - link, err := GetResourceLink(links, device.ResourceURI) - if err != nil { - h.handler.Error(err) - return - } - deviceID := link.GetDeviceID() - if deviceID == "" { - h.handler.Error(fmt.Errorf("cannot determine deviceID")) - return - } - if len(link.ResourceTypes) == 0 { - h.handler.Error(fmt.Errorf("cannot get resource types for %v: is empty", deviceID)) - return + deviceLinks := make(map[string]schema.ResourceLinks) + for _, link := range links { + deviceID := link.GetDeviceID() + if deviceID == "" { + h.handler.Error(fmt.Errorf("cannot determine deviceID")) + continue + } + deviceLinks[deviceID] = append(deviceLinks[deviceID], link) } - _, loaded := h.filterDiscoveredDevices.LoadOrStore(deviceID, true) - if loaded { - return + for deviceID, links := range deviceLinks { + link, err := GetResourceLink(links, device.ResourceURI) + if err != nil { + link, err = GetResourceLink(links, resources.ResourceURI) + } + if err != nil { + h.handler.Error(err) + continue + } + if len(link.ResourceTypes) == 0 { + h.handler.Error(fmt.Errorf("cannot get resource types for %v: is empty", deviceID)) + continue + } + _, loaded := h.filterDiscoveredDevices.LoadOrStore(deviceID, true) + if loaded { + continue + } + d := NewDevice(h.deviceCfg, deviceID, link.ResourceTypes, link.GetEndpoints) + h.handler.Handle(ctx, d) } - d := NewDevice(h.deviceCfg, deviceID, link.ResourceTypes, link.GetEndpoints) - h.handler.Handle(ctx, d) } func (h *discoveryHandler) Error(err error) { diff --git a/client/core/getOwnership.go b/client/core/getOwnership.go index 637e3171..10883c07 100644 --- a/client/core/getOwnership.go +++ b/client/core/getOwnership.go @@ -27,7 +27,7 @@ import ( ) // GetOwnership gets device's ownership resource. -func (d *Device) GetOwnership(ctx context.Context, links schema.ResourceLinks) (doxm.Doxm, error) { +func (d *Device) GetOwnership(ctx context.Context, links schema.ResourceLinks, options ...coap.OptionFunc) (doxm.Doxm, error) { ownLink, ok := links.GetResourceLink(doxm.ResourceURI) if !ok { return doxm.Doxm{}, fmt.Errorf("cannot find %v in links: %+v", doxm.ResourceURI, links) @@ -37,8 +37,11 @@ func (d *Device) GetOwnership(ctx context.Context, links schema.ResourceLinks) ( if len(getOwnlink.Endpoints) == 0 { getOwnlink.Endpoints = ownLink.GetSecureEndpoints() } + opts := make([]coap.OptionFunc, 0, 1+len(options)) + opts = append(opts, coap.WithInterface(interfaces.OC_IF_BASELINE)) + opts = append(opts, options...) var ownership doxm.Doxm - err := d.GetResource(ctx, getOwnlink, &ownership, coap.WithInterface(interfaces.OC_IF_BASELINE)) + err := d.GetResource(ctx, getOwnlink, &ownership, opts...) return ownership, err } diff --git a/client/core/logger.go b/client/core/logger.go index 893f0b53..6692e939 100644 --- a/client/core/logger.go +++ b/client/core/logger.go @@ -16,53 +16,8 @@ package core -type Logger interface { - Debug(string) - Info(string) - Warn(string) - Error(string) - Debugf(template string, args ...interface{}) - Infof(template string, args ...interface{}) - Warnf(template string, args ...interface{}) - Errorf(template string, args ...interface{}) -} +import "github.com/plgd-dev/device/v2/pkg/log" -type NilLogger struct{} - -var nilLogger = &NilLogger{} - -func NewNilLogger() *NilLogger { - return nilLogger -} - -func (*NilLogger) Debug(string) { - // no-op -} - -func (*NilLogger) Info(string) { - // no-op -} - -func (*NilLogger) Warn(string) { - // no-op -} - -func (*NilLogger) Error(string) { - // no-op -} - -func (*NilLogger) Debugf(string, ...interface{}) { - // no-op -} - -func (*NilLogger) Infof(string, ...interface{}) { - // no-op -} - -func (*NilLogger) Warnf(string, ...interface{}) { - // no-op -} - -func (*NilLogger) Errorf(string, ...interface{}) { - // no-op -} +type ( + Logger = log.Logger +) diff --git a/client/core/maintenance.go b/client/core/maintenance.go index 0eb47767..f4d8b875 100644 --- a/client/core/maintenance.go +++ b/client/core/maintenance.go @@ -21,6 +21,7 @@ import ( "fmt" "net/http" + "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/maintenance" ) @@ -28,19 +29,21 @@ import ( func (d *Device) Reboot( ctx context.Context, links schema.ResourceLinks, + options ...coap.OptionFunc, ) error { return d.updateMaintenanceResource(ctx, links, maintenance.MaintenanceUpdateRequest{ Reboot: true, - }) + }, options...) } func (d *Device) FactoryReset( ctx context.Context, links schema.ResourceLinks, + options ...coap.OptionFunc, ) error { err := d.updateMaintenanceResource(ctx, links, maintenance.MaintenanceUpdateRequest{ FactoryReset: true, - }) + }, options...) if connectionWasClosed(ctx, err) { // connection was closed by disown so we don't report error just log it. d.cfg.Logger.Debug(err.Error()) @@ -53,13 +56,14 @@ func (d *Device) updateMaintenanceResource( ctx context.Context, links schema.ResourceLinks, req maintenance.MaintenanceUpdateRequest, + options ...coap.OptionFunc, ) (ret error) { links = links.GetResourceLinks(maintenance.ResourceType) if len(links) == 0 { return MakeUnavailable(fmt.Errorf("cannot find '%v' in %+v", maintenance.ResourceType, links)) } var resp maintenance.Maintenance - err := d.UpdateResource(ctx, links[0], req, &resp) + err := d.UpdateResource(ctx, links[0], req, &resp, options...) if err != nil { return err } diff --git a/client/core/observeResource.go b/client/core/observeResource.go index feced148..5638637a 100644 --- a/client/core/observeResource.go +++ b/client/core/observeResource.go @@ -26,6 +26,7 @@ import ( "github.com/plgd-dev/device/v2/pkg/codec/ocf" "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/go-coap/v3/message" "go.uber.org/atomic" ) @@ -104,6 +105,7 @@ func (d *Device) closeObservations(ctx context.Context) error { type observation struct { id string + options []func(message.Options) message.Options handler *observationHandler client *coap.ClientCloseHandler @@ -139,7 +141,11 @@ func (o *observation) stop(ctx context.Context, byClose bool) error { // stop was called by user so we don't want to call OnClose handler o.handler.disableHandlers() } - err := obs.Cancel(ctx) + opts := make(message.Options, 0, 4) + for _, o := range o.options { + opts = o(opts) + } + err := obs.Cancel(ctx, opts...) if err != nil { return MakeCanceled(fmt.Errorf("cannot cancel observation %s: %w", o.id, err)) } @@ -183,6 +189,7 @@ func (d *Device) observeResource( id: id.String(), handler: &h, client: client, + options: options, } onCloseID := client.RegisterCloseHandler(func(error) { o.handler.OnClose() diff --git a/client/core/otm/client.go b/client/core/otm/client.go index 63fc287e..03c2dfae 100644 --- a/client/core/otm/client.go +++ b/client/core/otm/client.go @@ -22,11 +22,11 @@ import ( "fmt" "github.com/plgd-dev/device/v2/pkg/net/coap" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" "github.com/plgd-dev/device/v2/schema/credential" "github.com/plgd-dev/device/v2/schema/csr" "github.com/plgd-dev/device/v2/schema/doxm" kitNet "github.com/plgd-dev/kit/v2/net" - kitSecurity "github.com/plgd-dev/kit/v2/security" ) // SignFunc handles a certifice signing request (csr), the csr and returned certificate chain are encoded in PEM format @@ -127,7 +127,7 @@ func ProvisionOwnerCredentials(ctx context.Context, tlsClient *coap.ClientCloseH return fmt.Errorf("cannot sign csr for setup device owner credentials: %w", err) } - certsFromChain, err := kitSecurity.ParseX509FromPEM(signedCsr) + certsFromChain, err := pkgX509.ParsePemCertificates(signedCsr) if err != nil { return fmt.Errorf("failed to parse chain of X509 certs: %w", err) } diff --git a/client/core/provisionDevice.go b/client/core/provisionDevice.go index ba4e01e3..1427ede6 100644 --- a/client/core/provisionDevice.go +++ b/client/core/provisionDevice.go @@ -22,6 +22,7 @@ import ( "encoding/pem" "fmt" + "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/acl" "github.com/plgd-dev/device/v2/schema/cloud" @@ -30,10 +31,11 @@ import ( "github.com/plgd-dev/kit/v2/strings" ) -func (d *Device) Provision(ctx context.Context, links schema.ResourceLinks) (*ProvisioningClient, error) { +func (d *Device) Provision(ctx context.Context, links schema.ResourceLinks, options ...coap.OptionFunc) (*ProvisioningClient, error) { p := ProvisioningClient{ - Device: d, - links: links, + Device: d, + links: links, + options: options, } err := p.start(ctx) if err != nil { @@ -44,7 +46,8 @@ func (d *Device) Provision(ctx context.Context, links schema.ResourceLinks) (*Pr type ProvisioningClient struct { *Device - links schema.ResourceLinks + links schema.ResourceLinks + options []coap.OptionFunc } func (c *ProvisioningClient) start(ctx context.Context) error { @@ -60,7 +63,7 @@ func (c *ProvisioningClient) start(ctx context.Context) error { } link.Endpoints = link.GetSecureEndpoints() - err = c.UpdateResource(ctx, link, provisioningState, nil) + err = c.UpdateResource(ctx, link, provisioningState, nil, c.options...) if err != nil { return fmt.Errorf(errMsg, fmt.Errorf("cannot update to provisioning state %+v: %w", link, err)) } @@ -80,7 +83,7 @@ func (c *ProvisioningClient) Close(ctx context.Context) error { } link.Endpoints = link.GetSecureEndpoints() - err = c.UpdateResource(ctx, link, normalOperationState, nil) + err = c.UpdateResource(ctx, link, normalOperationState, nil, c.options...) if err != nil { return fmt.Errorf(errMsg, err) } @@ -94,7 +97,7 @@ func (c *ProvisioningClient) AddCredentials(ctx context.Context, cr credential.C return fmt.Errorf(errMsg, err) } link.Endpoints = link.GetSecureEndpoints() - err = c.UpdateResource(ctx, link, cr, nil) + err = c.UpdateResource(ctx, link, cr, nil, c.options...) if err != nil { return fmt.Errorf(errMsg, err) } @@ -139,7 +142,7 @@ func (c *ProvisioningClient) SetCloudResource(ctx context.Context, r cloud.Confi return fmt.Errorf("could not resolve cloud resource link of device %s", c.DeviceID()) } link.Endpoints = link.GetSecureEndpoints() - err := c.UpdateResource(ctx, link, r, nil) + err := c.UpdateResource(ctx, link, r, nil, c.options...) if err != nil { return fmt.Errorf("could not set cloud resource of device %s: %w", c.DeviceID(), err) } @@ -169,7 +172,7 @@ func (c *ProvisioningClient) SetAccessControl( return fmt.Errorf(errMsg, err) } link.Endpoints = link.GetSecureEndpoints() - err = c.UpdateResource(ctx, link, setACL, nil) + err = c.UpdateResource(ctx, link, setACL, nil, c.options...) if err != nil { return fmt.Errorf(errMsg, err) } diff --git a/client/createResource.go b/client/createResource.go index f882fdf6..3b0b9175 100644 --- a/client/createResource.go +++ b/client/createResource.go @@ -55,5 +55,9 @@ func (c *Client) CreateResource( return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + return d.UpdateResourceWithCodec(ctx, link, cfg.codec, request, response, cfg.opts...) } diff --git a/client/createResource_test.go b/client/createResource_test.go index 8f11ae73..e3c29ae6 100644 --- a/client/createResource_test.go +++ b/client/createResource_test.go @@ -25,6 +25,7 @@ import ( "github.com/plgd-dev/device/v2/schema/device" "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/plgd-dev/device/v2/test/resource/types" "github.com/stretchr/testify/require" ) @@ -78,13 +79,13 @@ func TestClientCreateResource(t *testing.T) { }, } - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID, client.WithOTMs([]client.OTMType{client.OTMType_JustWorks})) require.NoError(t, err) diff --git a/client/deleteDevices_test.go b/client/deleteDevices_internal_test.go similarity index 99% rename from client/deleteDevices_test.go rename to client/deleteDevices_internal_test.go index 57594fbc..feb1b1a1 100644 --- a/client/deleteDevices_test.go +++ b/client/deleteDevices_internal_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/log" "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/device" @@ -182,8 +183,7 @@ func NewTestSecureClient() (*Client, error) { client, err := NewClientFromConfig(&cfg, &testSetupSecureClient{ mfgCA: mfgCA, mfgCert: mfgCert, - }, core.NewNilLogger(), - ) + }, log.NewStdLogger(log.LevelDebug)) if err != nil { return nil, err } diff --git a/client/deleteResource.go b/client/deleteResource.go index 737509f3..a8dd2233 100644 --- a/client/deleteResource.go +++ b/client/deleteResource.go @@ -21,6 +21,7 @@ import ( "github.com/plgd-dev/device/v2/client/core" codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" + "github.com/plgd-dev/device/v2/pkg/net/coap" ) // DeleteResource deletes the resource from the device. @@ -50,5 +51,9 @@ func (c *Client) DeleteResource( return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + return d.DeleteResourceWithCodec(ctx, link, cfg.codec, response, cfg.opts...) } diff --git a/client/deleteResource_test.go b/client/deleteResource_test.go index ffdabc04..0c25a95d 100644 --- a/client/deleteResource_test.go +++ b/client/deleteResource_test.go @@ -28,6 +28,7 @@ import ( "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/schema/resources" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/plgd-dev/device/v2/test/resource/types" "github.com/stretchr/testify/require" ) @@ -88,13 +89,13 @@ func TestClientDeleteResource(t *testing.T) { }, } - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID, client.WithOTMs([]client.OTMType{client.OTMType_JustWorks})) require.NoError(t, err) @@ -117,13 +118,13 @@ func TestClientDeleteResource(t *testing.T) { func TestClientBatchDeleteResources(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout*50) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID, client.WithOTMs([]client.OTMType{client.OTMType_JustWorks})) require.NoError(t, err) diff --git a/client/deviceCache.go b/client/deviceCache.go index 03c468c7..f872951b 100644 --- a/client/deviceCache.go +++ b/client/deviceCache.go @@ -86,8 +86,8 @@ func (c *DeviceCache) GetDevice(deviceID string) (*core.Device, bool) { return d.Data(), true } -func (c *DeviceCache) GetDeviceByFoundIP(ip string) *core.Device { - var d *core.Device +func (c *DeviceCache) GetDeviceByFoundIP(ip string) []*core.Device { + var d []*core.Device now := time.Now() c.devicesCache.Range(func(deviceID string, item *cache.Element[*core.Device]) bool { if item.IsExpired(now) { @@ -95,7 +95,7 @@ func (c *DeviceCache) GetDeviceByFoundIP(ip string) *core.Device { } dev := item.Data() if dev.FoundByIP() == ip { - d = dev + d = append(d, dev) return false } return true diff --git a/client/deviceOwnershipBackend_test.go b/client/deviceOwnershipBackend_test.go index 8bea464e..715087b0 100644 --- a/client/deviceOwnershipBackend_test.go +++ b/client/deviceOwnershipBackend_test.go @@ -10,8 +10,9 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/plgd-dev/device/v2/client" - "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/log" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/require" ) @@ -40,11 +41,9 @@ func TestBackendOwnershipClient(t *testing.T) { mfgCert, err := tls.X509KeyPair(test.MfgCert, test.MfgKey) require.NoError(t, err) - client, err := client.NewClientFromConfig(&cfg, &testSetupSecureClient{ - mfgCA: mfgCA, - mfgCert: mfgCert, - }, core.NewNilLogger(), - ) + client, err := client.NewClientFromConfig(&cfg, + testClient.NewTestSetupSecureClient(nil, mfgCA, mfgCert), + log.NewStdLogger(log.LevelDebug)) require.NoError(t, err) ctxWithToken := test.CtxWithToken(context.Background(), jwtWithSubUserID) diff --git a/client/deviceOwnershipSDK.go b/client/deviceOwnershipSDK.go index 7f49eb8c..9adcca25 100644 --- a/client/deviceOwnershipSDK.go +++ b/client/deviceOwnershipSDK.go @@ -30,7 +30,7 @@ import ( "github.com/plgd-dev/device/v2/client/core/otm" justworks "github.com/plgd-dev/device/v2/client/core/otm/just-works" "github.com/plgd-dev/device/v2/client/core/otm/manufacturer" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) type Signer = interface { @@ -91,7 +91,7 @@ func newDeviceOwnershipSDK(app ApplicationCallback, sdkDeviceID string, dialTLS return nil, fmt.Errorf("invalid validFrom(%v) for device ownership SDK: %w", validFrom, err) } - signerCAs, err := security.ParseX509Certificates(signerCert) + signerCAs, err := pkgX509.ParseCertificates(signerCert) if err != nil { return nil, fmt.Errorf("could not parse signer certificates") } diff --git a/client/disownDevice.go b/client/disownDevice.go index e35871dc..d0c33b19 100644 --- a/client/disownDevice.go +++ b/client/disownDevice.go @@ -20,6 +20,7 @@ import ( "context" "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/net/coap" ) func (c *Client) removeTemporaryDeviceFromCache(ctx context.Context, d *core.Device) { @@ -43,8 +44,12 @@ func (c *Client) DisownDevice(ctx context.Context, deviceID string, opts ...Comm ok := d.IsSecured() if !ok { - return d.FactoryReset(ctx, links) + return d.FactoryReset(ctx, links, cfg.opts...) } - return d.Disown(ctx, links) + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + + return d.Disown(ctx, links, cfg.opts...) } diff --git a/client/getDevice.go b/client/getDevice.go index 981a77b6..63f153c0 100644 --- a/client/getDevice.go +++ b/client/getDevice.go @@ -23,18 +23,24 @@ import ( "fmt" "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/device" + "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/plgd-dev/go-coap/v3/message/status" ) func getLinksDevice(ctx context.Context, dev *core.Device, disableUDPEndpoints bool) (schema.ResourceLinks, error) { endpoints := dev.GetEndpoints() - links, err := dev.GetResourceLinks(ctx, endpoints) + links, err := dev.GetResourceLinks(ctx, endpoints, coap.WithDeviceID(dev.DeviceID())) if err != nil { return nil, err } + links = links.FilterByDeviceID(dev.DeviceID()) + if len(links) == 0 { + return nil, fmt.Errorf("cannot get resource links for device %v: not found", dev.DeviceID()) + } return patchResourceLinksEndpoints(links, disableUDPEndpoints), nil } @@ -80,41 +86,65 @@ func (c *Client) GetDeviceByMulticast(ctx context.Context, deviceID string, opts return retDev, links, nil } -func (c *Client) getDeviceByIP(ctx context.Context, ip string, expectedDeviceID string) (*core.Device, schema.ResourceLinks, error) { - dev, err := c.getDeviceByIPWithUpdateCache(ctx, ip, expectedDeviceID) +type DeviceWithLinks struct { + Device *core.Device + Links schema.ResourceLinks +} + +func (c *Client) getDevicesByIP(ctx context.Context, ip string, expectedDeviceID string) ([]DeviceWithLinks, error) { + devs, err := c.getDeviceByIPWithUpdateCache(ctx, ip, expectedDeviceID) if err != nil { - return nil, nil, err + return nil, err } - links, err := getLinksDevice(ctx, dev, c.disableUDPEndpoints) - if err != nil { - return nil, nil, err + devsLinks := make([]DeviceWithLinks, 0, len(devs)) + for _, dev := range devs { + if expectedDeviceID != "" && dev.DeviceID() != expectedDeviceID { + continue + } + links, err := getLinksDevice(ctx, dev, c.disableUDPEndpoints) + if err != nil { + c.deleteDeviceNotFoundByIP(ctx, dev) + continue + } + devsLinks = append(devsLinks, DeviceWithLinks{ + Device: dev, + Links: links, + }) } - return dev, links, nil + return devsLinks, nil } -func (c *Client) getDeviceByIPWithUpdateCache(ctx context.Context, ip string, expectedDeviceID string) (*core.Device, error) { - newDev, err := c.client.GetDeviceByIP(ctx, ip) +func (c *Client) getDeviceByIPWithUpdateCache(ctx context.Context, ip string, expectedDeviceID string) ([]*core.Device, error) { + newDevices, err := c.client.GetDevicesByIP(ctx, ip) if err != nil { return nil, err } - var oldDev *core.Device + var devs []*core.Device + oldDevs := make(map[string]*core.Device) if expectedDeviceID != "" { - oldDev, _ = c.deviceCache.GetDevice(expectedDeviceID) + d, ok := c.deviceCache.GetDevice(expectedDeviceID) + if ok { + oldDevs[expectedDeviceID] = d + } } else { - oldDev = c.deviceCache.GetDeviceByFoundIP(ip) + for _, d := range c.deviceCache.GetDeviceByFoundIP(ip) { + oldDevs[d.DeviceID()] = d + } + } + for _, newDev := range newDevices { + delete(oldDevs, newDev.DeviceID()) + dev, _ := c.deviceCache.UpdateOrStoreDevice(newDev) + devs = append(devs, dev) } - if oldDev != nil && oldDev.DeviceID() != newDev.DeviceID() { + for _, oldDev := range oldDevs { tmp, ok := c.deviceCache.LoadAndDeleteDevice(oldDev.DeviceID()) - if ok && tmp == oldDev { - oldDev.UpdateBy(newDev) - if errC := newDev.Close(ctx); errC != nil { + if ok { + if errC := tmp.Close(ctx); errC != nil { c.logger.Debugf("get device by ip error: %w", errC) } - newDev = oldDev } } - dev, _ := c.deviceCache.UpdateOrStoreDevice(newDev) - return dev, nil + return devs, nil } func (c *Client) checkAndUpdateCacheByLinks(ctx context.Context, dev *core.Device, links schema.ResourceLinks) (*core.Device, schema.ResourceLinks, error) { @@ -150,28 +180,40 @@ func (c *Client) GetDevice(ctx context.Context, deviceID string, opts ...GetDevi if err == nil { return c.checkAndUpdateCacheByLinks(ctx, dev, links) } - var newDev *core.Device if dev.FoundByIP() != "" { - newDev, links, err = c.getDeviceByIP(ctx, dev.FoundByIP(), deviceID) + devLinks, err := c.getDevicesByIP(ctx, dev.FoundByIP(), deviceID) if err != nil { return nil, nil, err } - if newDev.DeviceID() != deviceID { + if len(devLinks) == 0 { return nil, nil, fmt.Errorf("cannot get device %v: not found", deviceID) } - return dev, links, nil + return devLinks[0].Device, devLinks[0].Links, nil } c.deleteDeviceNotFoundByIP(ctx, dev) return c.GetDevice(ctx, deviceID, opts...) } +// GetDevicesByIP gets devices by IP and store it to cache without expiration. +// To delete device, call DeleteDevices with the deviceID. +func (c *Client) GetDevicesByIP( + ctx context.Context, + ip string, +) ([]DeviceWithLinks, error) { + return c.getDevicesByIP(ctx, ip, "") +} + // GetDeviceByIP gets device by IP and store it to cache without expiration. // To delete device, call DeleteDevices with the deviceID. func (c *Client) GetDeviceByIP( ctx context.Context, ip string, ) (*core.Device, schema.ResourceLinks, error) { - return c.getDeviceByIP(ctx, ip, "") + devices, err := c.GetDevicesByIP(ctx, ip) + if err != nil { + return nil, nil, err + } + return devices[0].Device, devices[0].Links, nil } func isDeviceOwnedByOther(err error) bool { @@ -182,14 +224,14 @@ func isDeviceOwnedByOther(err error) bool { return errors.As(err, &unknownAuth) } -func (c *Client) getDeviceDetails(ctx context.Context, dev *core.Device, links schema.ResourceLinks, getDetails GetDetailsFunc) (DeviceDetails, error) { - devDetails, err := getDeviceDetails(ctx, dev, links, getDetails) +func (c *Client) getDeviceDetails(ctx context.Context, dev *core.Device, links schema.ResourceLinks, getDetails GetDetailsFunc, optsArgs []func(message.Options) message.Options) (DeviceDetails, error) { + devDetails, err := getDeviceDetails(ctx, dev, links, getDetails, optsArgs) if err != nil { return DeviceDetails{}, err } var o ownership if devDetails.IsSecured { - d, ownErr := dev.GetOwnership(ctx, links) + d, ownErr := dev.GetOwnership(ctx, links, optsArgs...) if ownErr != nil { if isDeviceOwnedByOther(ownErr) { o.status = OwnershipStatus_OwnedByOther @@ -220,7 +262,7 @@ func (c *Client) GetDeviceDetailsByMulticast(ctx context.Context, deviceID strin if err != nil { return DeviceDetails{}, err } - return c.getDeviceDetails(ctx, dev, links, cfg.getDetails) + return c.getDeviceDetails(ctx, dev, links, cfg.getDetails, cfg.opts) } func (c *Client) GetAllDeviceIDsFoundByIP() map[string]string { @@ -240,5 +282,5 @@ func (c *Client) GetDeviceDetailsByIP(ctx context.Context, ip string, opts ...Ge if err != nil { return DeviceDetails{}, err } - return c.getDeviceDetails(ctx, dev, links, cfg.getDetails) + return c.getDeviceDetails(ctx, dev, links, cfg.getDetails, cfg.opts) } diff --git a/client/getDevice_test.go b/client/getDevice_test.go index c10a7a72..f231c14c 100644 --- a/client/getDevice_test.go +++ b/client/getDevice_test.go @@ -28,6 +28,7 @@ import ( "github.com/plgd-dev/device/v2/schema/doxm" "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" testTypes "github.com/plgd-dev/device/v2/test/resource/types" "github.com/stretchr/testify/require" ) @@ -120,10 +121,10 @@ func TestClientGetDevice(t *testing.T) { }, } - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) @@ -170,14 +171,14 @@ func TestClientGetDeviceByIP(t *testing.T) { args: args{ ip: ip4, }, - want: NewTestSecureDeviceSimulator(deviceIDip4, test.DevsimName, ip4), + want: NewTestSecureDeviceSimulator(deviceIDip4, test.DevsimName, ip4+":5683"), }, { name: "ip6", args: args{ ip: ip6, }, - want: NewTestSecureDeviceSimulator(deviceIDip6, test.DevsimName, "["+ip6+"]"), + want: NewTestSecureDeviceSimulator(deviceIDip6, test.DevsimName, "["+ip6+"]:5683"), }, { name: "not-found", @@ -188,10 +189,10 @@ func TestClientGetDeviceByIP(t *testing.T) { }, } - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) @@ -234,7 +235,7 @@ func TestClientCheckForDuplicityDeviceInCache(t *testing.T) { ip := test.MustFindDeviceIP(test.DevsimName, test.IP4) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(ctx) @@ -250,14 +251,20 @@ func TestClientCheckForDuplicityDeviceInCache(t *testing.T) { err = c.DisownDevice(ctx, dev.DeviceID()) require.NoError(t, err) time.Sleep(time.Second * 4) - // deviceID was changed after disowning - the call fails, but internally the cache is updated so dev.DeviceID() returns new deviceID - _, _, err = c.GetDevice(ctx, dev.DeviceID()) - require.Error(t, err) - _, _, err = c.GetDevice(ctx, dev.DeviceID()) + deviceNotExist := func(di string) { + ctx1, cancel1 := context.WithTimeout(ctx, time.Second) + defer cancel1() + _, _, err = c.GetDevice(ctx1, di) + require.Error(t, err) + } + // deviceID was changed after disowning - the call fails, because device not exist anymore + deviceNotExist(dev.DeviceID()) + + dev, _, err = c.GetDeviceByIP(ctx, ip) require.NoError(t, err) // change deviceID by another client - c1, err := NewTestSecureClient() + c1, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c1.Close(ctx) @@ -270,8 +277,11 @@ func TestClientCheckForDuplicityDeviceInCache(t *testing.T) { time.Sleep(time.Second * 4) // try get old device again - _, _, err = c.GetDevice(ctx, dev.DeviceID()) - require.Error(t, err) + deviceNotExist(dev.DeviceID()) + + dev, _, err = c.GetDeviceByIP(ctx, ip) + require.NoError(t, err) + // dev has updated deviceID by previous call so we can get the device _, _, err = c.GetDevice(ctx, dev.DeviceID()) require.NoError(t, err) @@ -296,17 +306,17 @@ func TestClientGetDeviceByIPOwnedByOther(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) ip4 := test.MustFindDeviceIP(test.DevsimName, test.IP4) - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errClose := c.Close(context.Background()) require.NoError(t, errClose) }() - c1, err := NewTestSecureClientWithGeneratedCertificate() + c1, err := testClient.NewTestSecureClientWithGeneratedCertificate() require.NoError(t, err) defer func() { errClose := c1.Close(context.Background()) diff --git a/client/getDevices.go b/client/getDevices.go index 5b1f32df..557b7889 100644 --- a/client/getDevices.go +++ b/client/getDevices.go @@ -29,16 +29,21 @@ import ( "github.com/plgd-dev/device/v2/schema/device" "github.com/plgd-dev/device/v2/schema/doxm" "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/go-coap/v3/message" kitStrings "github.com/plgd-dev/kit/v2/strings" ) -func getDetails(ctx context.Context, d *core.Device, links schema.ResourceLinks) (interface{}, error) { +func getDetails(ctx context.Context, d *core.Device, links schema.ResourceLinks, optsArgs ...func(message.Options) message.Options) (interface{}, error) { link := links.GetResourceLinks(device.ResourceType) if len(link) == 0 { return nil, fmt.Errorf("cannot find device resource at links %+v", links) } var dev device.Device - err := d.GetResource(ctx, link[0], &dev, coap.WithInterface(interfaces.OC_IF_BASELINE)) + opts := []func(message.Options) message.Options{ + coap.WithInterface(interfaces.OC_IF_BASELINE), + } + opts = append(opts, optsArgs...) + err := d.GetResource(ctx, link[0], &dev, opts...) if err != nil { return nil, err } @@ -50,7 +55,7 @@ type ownership struct { status OwnershipStatus } -func (c *Client) getDevicesAppendDeviceByIP(ctx context.Context, ip string, resourceTypes []string, getDetails func(ctx context.Context, d *core.Device, links schema.ResourceLinks) (interface{}, error), appendDevice func(d DeviceDetails)) { +func (c *Client) getDevicesAppendDeviceByIP(ctx context.Context, ip string, resourceTypes []string, getDetails func(ctx context.Context, d *core.Device, links schema.ResourceLinks, optsArgs ...func(message.Options) message.Options) (interface{}, error), appendDevice func(d DeviceDetails)) { d, err := c.GetDeviceDetailsByIP(ctx, ip, WithGetDetails(getDetails)) if err != nil { return @@ -80,6 +85,7 @@ func (c *Client) GetDevicesDetails( cfg := getDevicesOptions{ getDetails: getDetails, discoveryConfiguration: core.DefaultDiscoveryConfiguration(), + useDeviceID: c.useDeviceIDInQuery, } for _, o := range opts { cfg = o.applyOnGetDevices(cfg) @@ -92,11 +98,11 @@ func (c *Client) GetDevicesDetails( resOwnerships[deviceID] = d } - getDetails := func(ctx context.Context, d *core.Device, links schema.ResourceLinks) (interface{}, error) { + getDetails := func(ctx context.Context, d *core.Device, links schema.ResourceLinks, optsArgs ...func(message.Options) message.Options) (interface{}, error) { links = patchResourceLinksEndpoints(links, c.disableUDPEndpoints) - details, err := cfg.getDetails(ctx, d, links) + details, err := cfg.getDetails(ctx, d, links, optsArgs...) if err == nil && d.IsSecured() { - doxm, ownErr := d.GetOwnership(ctx, links) + doxm, ownErr := d.GetOwnership(ctx, links, optsArgs...) if ownErr == nil { ownerships(d.DeviceID(), ownership{ doxm: &doxm, @@ -128,7 +134,7 @@ func (c *Client) GetDevicesDetails( }(ip) } - handler := newDiscoveryHandler(cfg.resourceTypes, c.logger, devices, getDetails, c.deviceCache, c.disableUDPEndpoints) + handler := newDiscoveryHandler(cfg.resourceTypes, c.logger, devices, getDetails, c.deviceCache, c.disableUDPEndpoints, cfg.useDeviceID) if err := c.client.GetDevicesByMulticast(ctx, cfg.discoveryConfiguration, handler); err != nil { return nil, err } @@ -193,8 +199,9 @@ func newDiscoveryHandler( getDetails GetDetailsFunc, deviceCache *DeviceCache, disableUDPEndpoints bool, + useDeviceID bool, ) *discoveryHandler { - return &discoveryHandler{typeFilter: typeFilter, logger: logger, devices: devices, getDetails: getDetails, deviceCache: deviceCache, disableUDPEndpoints: disableUDPEndpoints} + return &discoveryHandler{typeFilter: typeFilter, logger: logger, devices: devices, getDetails: getDetails, deviceCache: deviceCache, disableUDPEndpoints: disableUDPEndpoints, useDeviceID: useDeviceID} } type detailsWasSet struct { @@ -209,13 +216,14 @@ type discoveryHandler struct { getDetails GetDetailsFunc deviceCache *DeviceCache disableUDPEndpoints bool + useDeviceID bool getDetailsWasCalled sync.Map } func (h *discoveryHandler) Error(err error) { h.logger.Debug(err.Error()) } -func getDeviceDetails(ctx context.Context, dev *core.Device, links schema.ResourceLinks, getDetails GetDetailsFunc) (out DeviceDetails, _ error) { +func getDeviceDetails(ctx context.Context, dev *core.Device, links schema.ResourceLinks, getDetails GetDetailsFunc, opts []func(message.Options) message.Options) (out DeviceDetails, _ error) { link, ok := links.GetResourceLink(device.ResourceURI) var eps []schema.Endpoint if ok { @@ -225,7 +233,7 @@ func getDeviceDetails(ctx context.Context, dev *core.Device, links schema.Resour isSecured := dev.IsSecured() var details interface{} if getDetails != nil { - d, err := getDetails(ctx, dev, links) + d, err := getDetails(ctx, dev, links, opts...) if err != nil { return DeviceDetails{}, err } @@ -279,7 +287,12 @@ func (h *discoveryHandler) getDeviceDetails(ctx context.Context, d *core.Device, if m.wasSet { getDetails = nil } - devDetails, err := getDeviceDetails(ctx, d, links, getDetails) + opts := make([]func(message.Options) message.Options, 0, 1) + if h.useDeviceID { + opts = append(opts, coap.WithDeviceID(d.DeviceID())) + } + + devDetails, err := getDeviceDetails(ctx, d, links, getDetails, opts) if err == nil { m.wasSet = true } diff --git a/client/getDevices_test.go b/client/getDevices_test.go index eadb7642..6790c246 100644 --- a/client/getDevices_test.go +++ b/client/getDevices_test.go @@ -26,6 +26,7 @@ import ( "github.com/plgd-dev/device/v2/client/core" "github.com/plgd-dev/device/v2/schema/device" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,7 +34,7 @@ import ( func TestDeviceDiscovery(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) secureDeviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) @@ -83,7 +84,7 @@ func TestDeviceDiscovery(t *testing.T) { func TestDeviceDiscoveryWithFilter(t *testing.T) { secureDeviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) @@ -105,7 +106,7 @@ func TestDeviceDiscoveryWithFilter(t *testing.T) { func TestDevicesWithFoundByIP(t *testing.T) { ip4 := test.MustFindDeviceIP(test.DevsimName, test.IP4) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) diff --git a/client/getResource.go b/client/getResource.go index f641b03e..cdafaf6b 100644 --- a/client/getResource.go +++ b/client/getResource.go @@ -21,6 +21,7 @@ import ( "github.com/plgd-dev/device/v2/client/core" codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" + "github.com/plgd-dev/device/v2/pkg/net/coap" ) // GetResource returns the device resource. @@ -49,5 +50,9 @@ func (c *Client) GetResource( return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + return d.GetResourceWithCodec(ctx, link, cfg.codec, response, cfg.opts...) } diff --git a/client/getResource_test.go b/client/getResource_test.go index 48fd0c92..9f2e6b3f 100644 --- a/client/getResource_test.go +++ b/client/getResource_test.go @@ -37,6 +37,7 @@ import ( "github.com/plgd-dev/device/v2/schema/resources" "github.com/plgd-dev/device/v2/schema/softwareupdate" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/stretchr/testify/require" ) @@ -103,13 +104,13 @@ func TestClientGetResource(t *testing.T) { }, } - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) @@ -131,13 +132,13 @@ func TestClientGetResource(t *testing.T) { func TestClientGetDiscoveryResourceWithResourceTypeFilter(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout*8) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout*8) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) @@ -169,13 +170,13 @@ func updateConfigurationResource(ctx context.Context, c *client.Client, deviceID func TestClientGetDiscoveryResourceWithBatchInterface(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) @@ -247,13 +248,13 @@ func TestClientGetDiscoveryResourceWithBatchInterfaceCreateAndDeleteResource(t * } deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) @@ -298,13 +299,13 @@ func TestClientGetDiscoveryResourceWithBatchInterfaceCreateAndDeleteResource(t * func TestClientGetConResourceByETag(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) diff --git a/client/initialization.go b/client/initialization.go index 8043d0fd..bc745210 100644 --- a/client/initialization.go +++ b/client/initialization.go @@ -27,7 +27,7 @@ import ( "fmt" "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" - kitSecurity "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) func generateSDKCertificate(ctx context.Context, csr []byte, sign SignFunc, priv *ecdsa.PrivateKey) (tls.Certificate, []*x509.Certificate, error) { @@ -46,7 +46,7 @@ func generateSDKCertificate(ctx context.Context, csr []byte, sign SignFunc, priv return tls.Certificate{}, nil, fmt.Errorf("cannot create tls certificate: %w", err) } - certsFromChain, err := kitSecurity.ParseX509FromPEM(cert) + certsFromChain, err := pkgX509.ParsePemCertificates(cert) if err != nil { return tls.Certificate{}, nil, fmt.Errorf("cannot parse cert chain: %w", err) } diff --git a/client/maitenance.go b/client/maitenance.go index 8d79287a..4cec10ae 100644 --- a/client/maitenance.go +++ b/client/maitenance.go @@ -18,6 +18,8 @@ package client import ( "context" + + "github.com/plgd-dev/device/v2/pkg/net/coap" ) // FactoryReset factory resets the device. @@ -27,8 +29,12 @@ func (c *Client) FactoryReset(ctx context.Context, deviceID string, opts ...Comm if err != nil { return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + defer c.removeTemporaryDeviceFromCache(ctx, d) - return d.FactoryReset(ctx, links) + return d.FactoryReset(ctx, links, cfg.opts...) } // Reboot reboots the device. @@ -38,6 +44,10 @@ func (c *Client) Reboot(ctx context.Context, deviceID string, opts ...CommonComm if err != nil { return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + defer c.removeTemporaryDeviceFromCache(ctx, d) - return d.Reboot(ctx, links) + return d.Reboot(ctx, links, cfg.opts...) } diff --git a/client/maitenance_test.go b/client/maitenance_test.go index d5c72eef..8be22975 100644 --- a/client/maitenance_test.go +++ b/client/maitenance_test.go @@ -23,6 +23,7 @@ import ( "github.com/plgd-dev/device/v2/client" "github.com/plgd-dev/device/v2/client/core" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/require" ) @@ -54,13 +55,13 @@ func TestClientFactoryReset(t *testing.T) { }, } - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() _, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) diff --git a/client/observeDeviceResources_test.go b/client/observeDeviceResources_test.go index cfd29aa8..2ded68ff 100644 --- a/client/observeDeviceResources_test.go +++ b/client/observeDeviceResources_test.go @@ -18,13 +18,14 @@ package client_test import ( "context" - "fmt" + "errors" "testing" "github.com/plgd-dev/device/v2/client" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/resources" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/require" ) @@ -41,7 +42,7 @@ func isDeviceResourcesObservable(ctx context.Context, t *testing.T, c *client.Cl } func runObserveDeviceResourcesTest(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeMockDeviceResourcesObservationHandler() + h := testClient.MakeMockDeviceResourcesObservationHandler() if !isDeviceResourcesObservable(ctx, t, c, deviceID) { t.Skip("resource is not observable") return @@ -50,7 +51,7 @@ func runObserveDeviceResourcesTest(ctx context.Context, t *testing.T, c *client. ID, err := c.ObserveDeviceResources(ctx, deviceID, h) require.NoError(t, err) - e, err := h.waitForNotification(ctx) + e, err := h.WaitForNotification(ctx) require.NoError(t, err) test.CheckResourceLinks(t, test.DefaultDevsimResourceLinks(), e) @@ -58,58 +59,21 @@ func runObserveDeviceResourcesTest(ctx context.Context, t *testing.T, c *client. err = c.CreateResource(ctx, deviceID, test.TestResourceSwitchesHref, test.MakeSwitchResourceDefaultData(), nil) require.NoError(t, err) - e, err = h.waitForNotification(ctx) + e, err = h.WaitForNotification(ctx) require.NoError(t, err) test.CheckResourceLinks(t, append(test.DefaultDevsimResourceLinks(), test.DefaultSwitchResourceLink("1")), e) err = c.DeleteResource(ctx, deviceID, test.TestResourceSwitchesInstanceHref("1"), nil) require.NoError(t, err) - e, err = h.waitForNotification(ctx) + e, err = h.WaitForNotification(ctx) require.NoError(t, err) test.CheckResourceLinks(t, test.DefaultDevsimResourceLinks(), e) ok, err := c.StopObservingDeviceResources(ctx, ID) require.NoError(t, err) require.True(t, ok) - select { - case <-h.res: - require.NoError(t, fmt.Errorf("unexpected event")) - default: - } -} - -func makeMockDeviceResourcesObservationHandler() *mockDeviceResourcesObservationHandler { - return &mockDeviceResourcesObservationHandler{ - res: make(chan schema.ResourceLinks, 100), - close: make(chan struct{}), - } -} - -type mockDeviceResourcesObservationHandler struct { - res chan schema.ResourceLinks - close chan struct{} -} -func (h *mockDeviceResourcesObservationHandler) Handle(_ context.Context, body schema.ResourceLinks) { - h.res <- body -} - -func (h *mockDeviceResourcesObservationHandler) Error(err error) { - fmt.Println(err) -} - -func (h *mockDeviceResourcesObservationHandler) OnClose() { - close(h.close) -} - -func (h *mockDeviceResourcesObservationHandler) waitForNotification(ctx context.Context) (schema.ResourceLinks, error) { - select { - case e := <-h.res: - return e, nil - case <-ctx.Done(): - return nil, ctx.Err() - case <-h.close: - return nil, fmt.Errorf("unexpected close") - } + err = h.WaitForClose(ctx) + require.True(t, err == nil || errors.Is(err, context.DeadlineExceeded)) } diff --git a/client/observeDevices_test.go b/client/observeDevices_test.go index 9562ae6f..33bcf0dc 100644 --- a/client/observeDevices_test.go +++ b/client/observeDevices_test.go @@ -25,6 +25,7 @@ import ( "github.com/plgd-dev/device/v2/client" "github.com/plgd-dev/device/v2/client/core" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/require" ) @@ -47,14 +48,14 @@ LOOP: func TestObserveDevicesAddedByIP(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) ip := test.MustFindDeviceIP(test.DevsimName, test.IP4) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout*2) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout*2) defer cancel() h := makeTestDevicesObservationHandler() @@ -102,14 +103,14 @@ LOOP: func TestObserveDevices(t *testing.T) { deviceID := test.MustFindDeviceByName(test.DevsimName) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout*2) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout*2) defer cancel() h := makeTestDevicesObservationHandler() diff --git a/client/observeResource.go b/client/observeResource.go index 7c278b4a..8b62d145 100644 --- a/client/observeResource.go +++ b/client/observeResource.go @@ -235,6 +235,10 @@ func (c *Client) ObserveResource( return "", err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + observationID, err = d.ObserveResourceWithCodec(ctx, link, observerCodec{contentFormat: cfg.codec.ContentFormat()}, h, cfg.opts...) if err != nil { return "", err diff --git a/client/observeResource_test.go b/client/observeResource_test.go index 453bd6bb..e6051d00 100644 --- a/client/observeResource_test.go +++ b/client/observeResource_test.go @@ -37,64 +37,22 @@ import ( "github.com/plgd-dev/device/v2/schema/resources" "github.com/plgd-dev/device/v2/schema/softwareupdate" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func makeObservationHandler() *observationHandler { - return &observationHandler{res: make(chan coap.DecodeFunc, 1), close: make(chan struct{})} -} - -type observationHandler struct { - res chan coap.DecodeFunc - close chan struct{} -} - -func (h *observationHandler) Handle(_ context.Context, body coap.DecodeFunc) { - h.res <- body -} - -func (h *observationHandler) Error(err error) { fmt.Println(err) } - -func (h *observationHandler) OnClose() { close(h.close) } - -func (h *observationHandler) waitForNotification(ctx context.Context) (coap.DecodeFunc, error) { - select { - case e := <-h.res: - return e, nil - case <-ctx.Done(): - return nil, ctx.Err() - case <-h.close: - return nil, fmt.Errorf("unexpected close") - } -} - -func (h *observationHandler) waitForClose(ctx context.Context) error { - select { - case e := <-h.res: - var d interface{} - if err := e(d); err != nil { - return fmt.Errorf("unexpected notification: cannot decode: %w", err) - } - return fmt.Errorf("unexpected notification %v", d) - case <-ctx.Done(): - return ctx.Err() - case <-h.close: - return nil - } -} - func testDevice(t *testing.T, name string, runTest func(ctx context.Context, t *testing.T, c *client.Client, deviceID string)) { deviceID := test.MustFindDeviceByName(name) - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) @@ -105,7 +63,7 @@ func testDevice(t *testing.T, name string, runTest func(ctx context.Context, t * } func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() id, err := c.ObserveResource(ctx, deviceID, test.TestResourceLightInstanceHref("1"), h) require.NoError(t, err) defer func(observationID string) { @@ -113,7 +71,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien require.NoError(t, errC) }(id) - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) d := coap.DetailedResponse[map[string]interface{}]{} err = res(&d) @@ -124,7 +82,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien assert.NotEmpty(t, d.ETag) } - h2 := makeObservationHandler() + h2 := testClient.MakeMockResourceObservationHandler() id, err = c.ObserveResource(ctx, deviceID, test.TestResourceLightInstanceHref("1"), h2) require.NoError(t, err) defer func(observationID string) { @@ -132,7 +90,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien require.NoError(t, errC) }(id) - res, err = h2.waitForNotification(ctx) + res, err = h2.WaitForNotification(ctx) require.NoError(t, err) d2 := coap.DetailedResponse[map[string]interface{}]{} err = res(&d2) @@ -149,7 +107,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien }, nil) require.NoError(t, err) - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) d = coap.DetailedResponse[map[string]interface{}]{} err = res(&d) @@ -161,7 +119,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien etag1 = d.ETag } - res, err = h2.waitForNotification(ctx) + res, err = h2.WaitForNotification(ctx) require.NoError(t, err) d2 = coap.DetailedResponse[map[string]interface{}]{} err = res(&d2) @@ -179,7 +137,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien }, nil) assert.NoError(t, err) - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) d = coap.DetailedResponse[map[string]interface{}]{} err = res(&d) @@ -190,7 +148,7 @@ func runObservingResourceTest(ctx context.Context, t *testing.T, c *client.Clien require.NotEqual(t, etag1, d.ETag) } - res, err = h2.waitForNotification(ctx) + res, err = h2.WaitForNotification(ctx) require.NoError(t, err) d2 = coap.DetailedResponse[map[string]interface{}]{} err = res(&d2) @@ -209,11 +167,11 @@ func TestObservingResource(t *testing.T) { func TestObservingDiscoveryResource(t *testing.T) { testDevice(t, test.DevsimName, func(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() id, err := c.ObserveResource(ctx, deviceID, resources.ResourceURI, h) require.NoError(t, err) d := coap.DetailedResponse[schema.ResourceLinks]{} - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d) require.NoError(t, err) @@ -225,7 +183,7 @@ func TestObservingDiscoveryResource(t *testing.T) { d.Body.Sort() checkForNonObservationCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = h.waitForClose(checkForNonObservationCtx) + err = h.WaitForClose(checkForNonObservationCtx) if err == nil { // oic/res doesn't support observation return @@ -238,7 +196,7 @@ func TestObservingDiscoveryResource(t *testing.T) { const switchID = "1" createSwitches(ctx, t, c, deviceID, 1) d1 := coap.DetailedResponse[schema.ResourceLinks]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d1) require.NoError(t, err) @@ -260,7 +218,7 @@ func TestObservingDiscoveryResource(t *testing.T) { err = c.DeleteResource(ctx, deviceID, test.TestResourceSwitchesInstanceHref(switchID), nil) require.NoError(t, err) d2 := coap.DetailedResponse[schema.ResourceLinks]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d2) require.NoError(t, err) @@ -276,11 +234,11 @@ func TestObservingDiscoveryResource(t *testing.T) { func TestObservingDiscoveryResourceWithBaselineInterface(t *testing.T) { testDevice(t, test.DevsimName, func(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() id, err := c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_BASELINE)) require.NoError(t, err) d := coap.DetailedResponse[resources.BaselineResourceDiscovery]{} - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d) require.NoError(t, err) @@ -292,7 +250,7 @@ func TestObservingDiscoveryResourceWithBaselineInterface(t *testing.T) { d.Body[0].Links.Sort() checkForNonObservationCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = h.waitForClose(checkForNonObservationCtx) + err = h.WaitForClose(checkForNonObservationCtx) if err == nil { // oic/res doesn't support observation return @@ -305,7 +263,7 @@ func TestObservingDiscoveryResourceWithBaselineInterface(t *testing.T) { const switchID = "1" createSwitches(ctx, t, c, deviceID, 1) d1 := coap.DetailedResponse[resources.BaselineResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d1) require.NoError(t, err) @@ -328,7 +286,7 @@ func TestObservingDiscoveryResourceWithBaselineInterface(t *testing.T) { err = c.DeleteResource(ctx, deviceID, test.TestResourceSwitchesInstanceHref(switchID), nil) require.NoError(t, err) d2 := coap.DetailedResponse[resources.BaselineResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d2) require.NoError(t, err) @@ -345,21 +303,21 @@ func TestObservingDiscoveryResourceWithBaselineInterface(t *testing.T) { func TestObservingNonObservableResource(t *testing.T) { testDevice(t, test.DevsimName, func(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() _, err := c.ObserveResource(ctx, deviceID, maintenance.ResourceURI, h) require.NoError(t, err) d := coap.DetailedResponse[maintenance.Maintenance]{} // resource is not observable so action (close/event) depends on goroutine scheduler which is not deterministic select { - case e := <-h.res: + case e := <-h.Res: err = e(&d) require.NoError(t, err) if ETagSupported { require.NotEmpty(t, d.ETag) } - err = h.waitForClose(ctx) + err = h.WaitForClose(ctx) require.NoError(t, err) - case <-h.close: + case <-h.Close: // if close comes first, then event is not received case <-ctx.Done(): require.NoError(t, ctx.Err()) @@ -428,11 +386,11 @@ var expectedFirstBatchResources = []batchResource{ func TestObservingDiscoveryResourceWithBatchInterface(t *testing.T) { testDevice(t, test.DevsimName, func(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() id, err := c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B)) require.NoError(t, err) var d coap.DetailedResponse[resources.BatchResourceDiscovery] - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d) require.NoError(t, err) @@ -444,7 +402,7 @@ func TestObservingDiscoveryResourceWithBatchInterface(t *testing.T) { } checkForNonObservationCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = h.waitForClose(checkForNonObservationCtx) + err = h.WaitForClose(checkForNonObservationCtx) if err == nil { // oic/res doesn't support observation return @@ -457,7 +415,7 @@ func TestObservingDiscoveryResourceWithBatchInterface(t *testing.T) { const switchID = "1" createSwitches(ctx, t, c, deviceID, 1) var d1 coap.DetailedResponse[resources.BatchResourceDiscovery] - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d1) require.NoError(t, err) @@ -473,7 +431,7 @@ func TestObservingDiscoveryResourceWithBatchInterface(t *testing.T) { err = c.DeleteResource(ctx, deviceID, test.TestResourceSwitchesInstanceHref(switchID), nil) require.NoError(t, err) var d2 coap.DetailedResponse[resources.BatchResourceDiscovery] - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d2) require.NoError(t, err) @@ -508,11 +466,11 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnUpdate(t *testing.T) { } testDevice(t, test.DevsimName, func(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() obsID, err := c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B)) require.NoError(t, err) d0 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d0) require.NoError(t, err) @@ -530,7 +488,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnUpdate(t *testing.T) { checkForNonObservationCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = h.waitForClose(checkForNonObservationCtx) + err = h.WaitForClose(checkForNonObservationCtx) if err == nil { t.Skip("oic/res doesn't support observation") return @@ -548,7 +506,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnUpdate(t *testing.T) { require.NoError(t, errU) }() - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) d1 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} err = res(&d1) @@ -568,7 +526,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnUpdate(t *testing.T) { obsID, err = c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B), client.WithQuery(queries[0])) require.NoError(t, err) d2 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d2) require.NoError(t, err) @@ -588,7 +546,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnUpdate(t *testing.T) { require.NoError(t, errC) }(obsID) d3 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d3) require.NoError(t, err) @@ -606,11 +564,11 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnCreate(t *testing.T) { } testDevice(t, test.DevsimName, func(ctx context.Context, t *testing.T, c *client.Client, deviceID string) { - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() obsID, err := c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B)) require.NoError(t, err) d0 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d0) require.NoError(t, err) @@ -628,7 +586,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnCreate(t *testing.T) { checkForNonObservationCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = h.waitForClose(checkForNonObservationCtx) + err = h.WaitForClose(checkForNonObservationCtx) if err == nil { t.Skip("oic/res doesn't support observation") return @@ -641,7 +599,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnCreate(t *testing.T) { require.NoError(t, errD) }() - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) d1 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} err = res(&d1) @@ -661,7 +619,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnCreate(t *testing.T) { obsID, err = c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B), client.WithQuery(queries[0])) require.NoError(t, err) d2 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d2) require.NoError(t, err) @@ -681,7 +639,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnCreate(t *testing.T) { require.NoError(t, errC) }(obsID) d3 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d3) require.NoError(t, err) @@ -712,11 +670,11 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnDelete(t *testing.T) { require.NoError(t, errs.ErrorOrNil()) }() - h := makeObservationHandler() + h := testClient.MakeMockResourceObservationHandler() obsID, err := c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B)) require.NoError(t, err) d0 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err := h.waitForNotification(ctx) + res, err := h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d0) require.NoError(t, err) @@ -735,7 +693,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnDelete(t *testing.T) { checkForNonObservationCtx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = h.waitForClose(checkForNonObservationCtx) + err = h.WaitForClose(checkForNonObservationCtx) if err == nil { t.Skip("oic/res doesn't support observation") return @@ -745,7 +703,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnDelete(t *testing.T) { require.NoError(t, errD) toDelete = nil - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) d1 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} err = res(&d1) @@ -765,7 +723,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnDelete(t *testing.T) { obsID, err = c.ObserveResource(ctx, deviceID, resources.ResourceURI, h, client.WithInterface(interfaces.OC_IF_B), client.WithQuery(queries[0])) require.NoError(t, err) d2 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d2) require.NoError(t, err) @@ -785,7 +743,7 @@ func TestObserveDiscoveryResourceWithIncrementalChangesOnDelete(t *testing.T) { require.NoError(t, errC) }(obsID) d3 := coap.DetailedResponse[resources.BatchResourceDiscovery]{} - res, err = h.waitForNotification(ctx) + res, err = h.WaitForNotification(ctx) require.NoError(t, err) err = res(&d3) require.NoError(t, err) diff --git a/client/offboardDevice.go b/client/offboardDevice.go index 0fc613d9..67de91be 100644 --- a/client/offboardDevice.go +++ b/client/offboardDevice.go @@ -18,6 +18,8 @@ package client import ( "context" + + "github.com/plgd-dev/device/v2/pkg/net/coap" ) // OffboardDevice disconnects from the cloud and removes cloud configuration at the device. @@ -29,5 +31,9 @@ func (c *Client) OffboardDevice(ctx context.Context, deviceID string, opts ...Co return err } - return setCloudResource(ctx, links, d, "", "", "", "") + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + + return setCloudResource(ctx, links, d, "", "", "", "", cfg.opts...) } diff --git a/client/onboardDevice.go b/client/onboardDevice.go index 06b068d3..dcc04503 100644 --- a/client/onboardDevice.go +++ b/client/onboardDevice.go @@ -21,14 +21,16 @@ import ( "fmt" "github.com/plgd-dev/device/v2/client/core" + "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/acl" "github.com/plgd-dev/device/v2/schema/cloud" "github.com/plgd-dev/device/v2/schema/maintenance" "github.com/plgd-dev/device/v2/schema/softwareupdate" + "github.com/plgd-dev/go-coap/v3/message" ) -func setCloudResource(ctx context.Context, links schema.ResourceLinks, d *core.Device, authorizationProvider, authorizationCode, cloudURL, cloudID string) error { +func setCloudResource(ctx context.Context, links schema.ResourceLinks, d *core.Device, authorizationProvider, authorizationCode, cloudURL, cloudID string, options ...coap.OptionFunc) error { ob := cloud.ConfigurationUpdateRequest{ AuthorizationProvider: authorizationProvider, AuthorizationCode: authorizationCode, @@ -37,13 +39,13 @@ func setCloudResource(ctx context.Context, links schema.ResourceLinks, d *core.D } for _, l := range links.GetResourceLinks(cloud.ResourceType) { - return d.UpdateResource(ctx, l, ob, nil) + return d.UpdateResource(ctx, l, ob, nil, options...) } return fmt.Errorf("cloud resource not found") } -func setACLForCloud(ctx context.Context, p *core.ProvisioningClient, cloudID string, links schema.ResourceLinks) error { +func setACLForCloud(ctx context.Context, p *core.ProvisioningClient, cloudID string, links schema.ResourceLinks, opts []func(message.Options) message.Options) error { link, err := core.GetResourceLink(links, acl.ResourceURI) if err != nil { return err @@ -77,7 +79,7 @@ func setACLForCloud(ctx context.Context, p *core.ProvisioningClient, cloudID str }, } - return p.UpdateResource(ctx, link, cloudACL, nil) + return p.UpdateResource(ctx, link, cloudACL, nil, opts...) } // OnboardDevice connects device to the cloud. @@ -88,14 +90,18 @@ func (c *Client) OnboardDevice( opts ...CommonCommandOption, ) error { cfg := applyCommonOptions(opts...) - d, links, err := c.GetDevice(ctx, deviceID, WithDiscoveryConfiguration(cfg.discoveryConfiguration)) + d, links, err := c.GetDevice(ctx, deviceID, cfg) if err != nil { return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } + ok := d.IsSecured() if ok { - p, err := d.Provision(ctx, links) + p, err := d.Provision(ctx, links, cfg.opts...) if err != nil { return err } @@ -105,7 +111,7 @@ func (c *Client) OnboardDevice( } }() - err = setACLForCloud(ctx, p, cloudID, links) + err = setACLForCloud(ctx, p, cloudID, links, cfg.opts) if err != nil { return err } @@ -117,5 +123,5 @@ func (c *Client) OnboardDevice( CloudID: cloudID, }) } - return setCloudResource(ctx, links, d, authorizationProvider, authCode, cloudURL, cloudID) + return setCloudResource(ctx, links, d, authorizationProvider, authCode, cloudURL, cloudID, cfg.opts...) } diff --git a/client/onboardDevice_test.go b/client/onboardDevice_test.go index bf32ed87..0f9af3ac 100644 --- a/client/onboardDevice_test.go +++ b/client/onboardDevice_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/require" ) @@ -62,13 +63,13 @@ func TestClientOnboardDevice(t *testing.T) { }, } - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() deviceID, err = c.OwnDevice(ctx, deviceID) require.NoError(t, err) diff --git a/client/options.go b/client/options.go index 8647bd97..ed49f1ec 100644 --- a/client/options.go +++ b/client/options.go @@ -22,6 +22,7 @@ import ( "github.com/plgd-dev/device/v2/client/core" "github.com/plgd-dev/device/v2/pkg/net/coap" "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/go-coap/v3/message" ) // WithQuery updates/gets resource with a query directly from a device. @@ -31,6 +32,14 @@ func WithQuery(resourceQuery string) ResourceQueryOption { } } +// WithDeviceID updates/gets resource with deviceID also in a query parameter for the bridged devices. +// Note: it is not needed when client has been created with UseDeviceIDInQuery. +func WithDeviceID(deviceID string) ResourceQueryOption { + return ResourceQueryOption{ + resourceQuery: "di=" + deviceID, + } +} + // WithInterface updates/gets resource with interface directly from a device. func WithInterface(resourceInterface string) ResourceInterfaceOption { return ResourceInterfaceOption{ @@ -38,12 +47,20 @@ func WithInterface(resourceInterface string) ResourceInterfaceOption { } } -func WithGetDetails(getDetails func(ctx context.Context, d *core.Device, links schema.ResourceLinks) (interface{}, error)) GetDetailsOption { +func WithGetDetails(getDetails func(ctx context.Context, d *core.Device, links schema.ResourceLinks, optsArgs ...func(message.Options) message.Options) (interface{}, error)) GetDetailsOption { return GetDetailsOption{ getDetails: getDetails, } } +// WithUseDeviceID appends deviceID to all get requests during discovery. +// Note: it is not needed when client has been created with UseDeviceIDInQuery. +func WithUseDeviceID(useDeviceID bool) UseDeviceIDInQueryOption { + return UseDeviceIDInQueryOption{ + UseDeviceID: useDeviceID, + } +} + func WithCodec(codec coap.Codec) CodecOption { return CodecOption{ codec: codec, @@ -99,6 +116,15 @@ func WithOTM(otmType OTMType) OwnOption { } } +type UseDeviceIDInQueryOption struct { + UseDeviceID bool +} + +func (r UseDeviceIDInQueryOption) applyOnGetDevices(opts getDevicesOptions) getDevicesOptions { + opts.useDeviceID = r.UseDeviceID + return opts +} + type ResourceETagOption struct { etag []byte } @@ -161,6 +187,48 @@ func (r ResourceQueryOption) applyOnObserve(opts observeOptions) observeOptions return opts } +func (r ResourceQueryOption) applyOnUpdate(opts updateOptions) updateOptions { + if r.resourceQuery != "" { + opts.opts = append(opts.opts, coap.WithQuery(r.resourceQuery)) + } + return opts +} + +func (r ResourceQueryOption) applyOnCreate(opts createOptions) createOptions { + if r.resourceQuery != "" { + opts.opts = append(opts.opts, coap.WithQuery(r.resourceQuery)) + } + return opts +} + +func (r ResourceQueryOption) applyOnDelete(opts deleteOptions) deleteOptions { + if r.resourceQuery != "" { + opts.opts = append(opts.opts, coap.WithQuery(r.resourceQuery)) + } + return opts +} + +func (r ResourceQueryOption) applyOnCommonCommand(opts commonCommandOptions) commonCommandOptions { + if r.resourceQuery != "" { + opts.opts = append(opts.opts, coap.WithQuery(r.resourceQuery)) + } + return opts +} + +func (r ResourceQueryOption) applyOnGetDevice(opts getDeviceOptions) getDeviceOptions { + if r.resourceQuery != "" { + opts.opts = append(opts.opts, coap.WithQuery(r.resourceQuery)) + } + return opts +} + +func (r ResourceQueryOption) applyOnGetDeviceByIP(opts getDeviceByIPOptions) getDeviceByIPOptions { + if r.resourceQuery != "" { + opts.opts = append(opts.opts, coap.WithQuery(r.resourceQuery)) + } + return opts +} + type DiscoveryConfigurationOption struct { cfg core.DiscoveryConfiguration } @@ -294,7 +362,7 @@ type ObserveDevicesOption = interface { applyOnObserveDevices(opts observeDevicesOptions) observeDevicesOptions } -type GetDetailsFunc = func(context.Context, *core.Device, schema.ResourceLinks) (interface{}, error) +type GetDetailsFunc = func(context.Context, *core.Device, schema.ResourceLinks, ...func(message.Options) message.Options) (interface{}, error) type GetDetailsOption struct { getDetails GetDetailsFunc @@ -319,15 +387,18 @@ type getDevicesOptions struct { resourceTypes []string getDetails GetDetailsFunc discoveryConfiguration core.DiscoveryConfiguration + useDeviceID bool } type getDeviceOptions struct { getDetails GetDetailsFunc discoveryConfiguration core.DiscoveryConfiguration + opts []coap.OptionFunc } type getDeviceByIPOptions struct { getDetails GetDetailsFunc + opts []coap.OptionFunc } type getDevicesWithHandlerOptions struct { @@ -451,6 +522,7 @@ func (r otmOption) applyOnOwn(opts ownOptions) ownOptions { type commonCommandOptions struct { discoveryConfiguration core.DiscoveryConfiguration + opts []coap.OptionFunc } // CommonCommandOption option definition. @@ -467,3 +539,8 @@ func applyCommonOptions(opts ...CommonCommandOption) commonCommandOptions { } return cfg } + +func (r commonCommandOptions) applyOnGetDevice(opts getDeviceOptions) getDeviceOptions { + opts.discoveryConfiguration = r.discoveryConfiguration + return opts +} diff --git a/client/ownDevice_test.go b/client/ownDevice_test.go index c7fde194..ec4fcc70 100644 --- a/client/ownDevice_test.go +++ b/client/ownDevice_test.go @@ -24,6 +24,7 @@ import ( "github.com/plgd-dev/device/v2/client" "github.com/plgd-dev/device/v2/schema/device" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/stretchr/testify/require" ) @@ -45,7 +46,7 @@ func TestClientOwnDevice(t *testing.T) { }, } - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) diff --git a/client/updateResource.go b/client/updateResource.go index 0e22d049..7cad1755 100644 --- a/client/updateResource.go +++ b/client/updateResource.go @@ -21,6 +21,7 @@ import ( "github.com/plgd-dev/device/v2/client/core" codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" + "github.com/plgd-dev/device/v2/pkg/net/coap" ) // UpdateResource updates the device resource. @@ -51,6 +52,9 @@ func (c *Client) UpdateResource( if err != nil { return err } + if c.useDeviceIDInQuery { + cfg.opts = append(cfg.opts, coap.WithDeviceID(deviceID)) + } return d.UpdateResourceWithCodec(ctx, link, cfg.codec, request, response, cfg.opts...) } diff --git a/client/updateResource_test.go b/client/updateResource_test.go index 5d0e03f9..137fd196 100644 --- a/client/updateResource_test.go +++ b/client/updateResource_test.go @@ -27,6 +27,7 @@ import ( "github.com/plgd-dev/device/v2/schema/device" "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/test" + testClient "github.com/plgd-dev/device/v2/test/client" "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/stretchr/testify/require" ) @@ -115,10 +116,10 @@ func TestClientUpdateResource(t *testing.T) { wantErr: true, }, } - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) @@ -227,10 +228,10 @@ func TestClientUpdateResourceInRFOTM(t *testing.T) { wantErr: true, }, } - ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + ctx, cancel := context.WithTimeout(context.Background(), test.TestTimeout) defer cancel() - c, err := NewTestSecureClient() + c, err := testClient.NewTestSecureClient() require.NoError(t, err) defer func() { errC := c.Close(context.Background()) diff --git a/cmd/bridge-device/Dockerfile b/cmd/bridge-device/Dockerfile new file mode 100644 index 00000000..4828f1a7 --- /dev/null +++ b/cmd/bridge-device/Dockerfile @@ -0,0 +1,24 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.20.13-alpine AS build +RUN apk add --no-cache curl git build-base +WORKDIR $GOPATH/src/github.com/plgd-dev/device +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +WORKDIR /usr/local/go +RUN patch -p1 < "${GOPATH}/src/github.com/plgd-dev/device/tools/docker/patches/shrink_tls_conn.patch" +WORKDIR $GOPATH/src/github.com/plgd-dev/device +RUN CGO_ENABLED=0 go build -o /go/bin/bridge-device ./cmd/bridge-device + +FROM alpine:3.19 AS security-provider +RUN apk add -U --no-cache ca-certificates +RUN addgroup -S nonroot \ + && adduser -S nonroot -G nonroot + +FROM scratch AS service +COPY --from=security-provider /etc/passwd /etc/passwd +COPY --from=security-provider /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ +COPY --from=build /go/bin/bridge-device /usr/local/bin/bridge-device +USER nonroot +COPY ./cmd/bridge-device/config.yaml /config.yaml +ENTRYPOINT [ "/usr/local/bin/bridge-device" ] diff --git a/cmd/bridge-device/config.go b/cmd/bridge-device/config.go new file mode 100644 index 00000000..d6879a40 --- /dev/null +++ b/cmd/bridge-device/config.go @@ -0,0 +1,64 @@ +package main + +import ( + "fmt" + + "github.com/plgd-dev/device/v2/bridge/service" + "github.com/plgd-dev/device/v2/pkg/log" +) + +type TLSConfig struct { + CAPoolPath string `yaml:"caPoolPath" json:"caPool" description:"file path to the root certificates in PEM format"` + KeyPath string `yaml:"keyPath" json:"keyFile" description:"file path to the private key in PEM format"` + CertPath string `yaml:"certPath" json:"certFile" description:"file path to the certificate in PEM format"` + UseSystemCAPool bool `yaml:"useSystemCAPool" json:"useSystemCaPool" description:"use system certification pool"` +} + +func (c *TLSConfig) Validate() error { + if c.CAPoolPath == "" && !c.UseSystemCAPool { + return fmt.Errorf("caPool is required") + } + if (c.KeyPath == "" && c.CertPath != "") || (c.KeyPath != "" && c.CertPath == "") { + return fmt.Errorf("keyFile and certFile must be set together") + } + return nil +} + +type LogConfig struct { + Level log.Level `yaml:"level" json:"level" description:"log level"` +} + +type CloudConfig struct { + Enabled bool `yaml:"enabled" json:"enabled" description:"enable cloud connection"` + TLS TLSConfig `yaml:"tls" json:"tls"` +} + +type CredentialConfig struct { + Enabled bool `yaml:"enabled" json:"enabled" description:"enable credential manager"` +} + +func (c *CloudConfig) Validate() error { + if c.Enabled { + return c.TLS.Validate() + } + return nil +} + +type Config struct { + service.Config `yaml:",inline"` + Log LogConfig `yaml:"log" json:"log"` + Cloud CloudConfig `yaml:"cloud" json:"cloud"` + Credential CredentialConfig `yaml:"credential" json:"credential"` + NumGeneratedBridgedDevices int `yaml:"numGeneratedBridgedDevices"` + NumResourcesPerDevice int `yaml:"numResourcesPerDevice"` +} + +func (c *Config) Validate() error { + if err := c.Config.Validate(); err != nil { + return err + } + if c.NumGeneratedBridgedDevices <= 0 { + return fmt.Errorf("numGeneratedBridgedDevices - must be > 0") + } + return nil +} diff --git a/cmd/bridge-device/config.yaml b/cmd/bridge-device/config.yaml new file mode 100644 index 00000000..e0b08dad --- /dev/null +++ b/cmd/bridge-device/config.yaml @@ -0,0 +1,16 @@ +apis: + coap: + id: 8f596b43-29c0-4147-8b40-e99268ab30f7 + name: "my-bridge-example" + externalAddresses: + - "127.0.0.1:15683" + - "[::1]:15683" + maxMessageSize: 2097152 +log: + level: "info" +cloud: + enabled: true +credential: + enabled: true +numGeneratedBridgedDevices: 3 +numResourcesPerDevice: 16 diff --git a/cmd/bridge-device/main.go b/cmd/bridge-device/main.go new file mode 100644 index 00000000..bf7fab38 --- /dev/null +++ b/cmd/bridge-device/main.go @@ -0,0 +1,256 @@ +package main + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "flag" + "fmt" + "os" + "os/signal" + "path/filepath" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/resources" + "github.com/plgd-dev/device/v2/bridge/service" + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" + "github.com/plgd-dev/device/v2/pkg/log" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" + "github.com/plgd-dev/device/v2/schema/interfaces" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "gopkg.in/yaml.v3" +) + +func loadConfig(configFile string) (Config, error) { + // Sanitize the configFile variable to ensure it only contains a valid file path + configFile = filepath.Clean(configFile) + f, err := os.Open(configFile) + if err != nil { + return Config{}, err + } + defer func() { + _ = f.Close() + }() + var cfg Config + err = yaml.NewDecoder(f).Decode(&cfg) + if err != nil { + return Config{}, err + } + if err = cfg.Validate(); err != nil { + return Config{}, err + } + return cfg, nil +} + +type resourceData struct { + Name string `json:"name,omitempty"` +} + +type resourceDataSync struct { + resourceData + lock sync.Mutex +} + +func (r *resourceDataSync) setName(name string) { + r.lock.Lock() + defer r.lock.Unlock() + r.Name = name +} + +func (r *resourceDataSync) copy() resourceData { + r.lock.Lock() + defer r.lock.Unlock() + return resourceData{ + Name: r.Name, + } +} + +func addResources(d service.Device, numResources int) { + if numResources <= 0 { + return + } + obsWatcher := coapSync.NewMap[uint64, func()]() + for i := 0; i < numResources; i++ { + addResource(d, i, obsWatcher) + } + go func() { + for range time.After(time.Millisecond * 500) { + obsWatcher.Range(func(key uint64, h func()) bool { + h() + return true + }) + } + }() +} + +func addResource(d service.Device, idx int, obsWatcher *coapSync.Map[uint64, func()]) { + rds := resourceDataSync{ + resourceData: resourceData{ + Name: fmt.Sprintf("test-%v", idx), + }, + } + + resHandler := func(req *net.Request) (*pool.Message, error) { + resp := pool.NewMessage(req.Context()) + switch req.Code() { + case codes.GET: + resp.SetCode(codes.Content) + case codes.POST: + resp.SetCode(codes.Changed) + default: + return nil, fmt.Errorf("invalid method %v", req.Code()) + } + resp.SetContentFormat(message.AppOcfCbor) + data, err := cbor.Encode(rds.copy()) + if err != nil { + return nil, err + } + resp.SetBody(bytes.NewReader(data)) + return resp, nil + } + + var subID atomic.Uint64 + res := resources.NewResource(fmt.Sprintf("/test/%d", idx), resHandler, func(req *net.Request) (*pool.Message, error) { + codec := codecOcf.VNDOCFCBORCodec{} + var newData resourceData + err := codec.Decode(req.Message, &newData) + if err != nil { + return nil, err + } + rds.setName(newData.Name) + return resHandler(req) + }, []string{"x.plgd.test"}, []string{interfaces.OC_IF_BASELINE, interfaces.OC_IF_RW}) + res.SetObserveHandler(func(req *net.Request, handler func(msg *pool.Message, err error)) (cancel func(), err error) { + sub := subID.Add(1) + obsWatcher.Store(sub, func() { + resp, err := resHandler(req) + if err != nil { + handler(nil, err) + return + } + handler(resp, nil) + }) + return func() { + obsWatcher.Delete(sub) + }, nil + }) + d.AddResource(res) +} + +func getCloudTLS(cfg CloudConfig, credentialEnabled bool) (cloud.CAPool, *tls.Certificate, error) { + var ca []*x509.Certificate + var err error + if cfg.TLS.CAPoolPath == "" && !credentialEnabled { + return cloud.CAPool{}, nil, fmt.Errorf("cannot load ca: caPoolPath is empty") + } + if cfg.TLS.CAPoolPath != "" { + ca, err = pkgX509.ReadPemCertificates(cfg.TLS.CAPoolPath) + if err != nil { + return cloud.CAPool{}, nil, fmt.Errorf("cannot load ca('%v'): %w", cfg.TLS.CAPoolPath, err) + } + } + caPool := cloud.MakeCAPool(func() []*x509.Certificate { + return ca + }, cfg.TLS.UseSystemCAPool) + + if cfg.TLS.KeyPath == "" { + return caPool, nil, nil + } + + cert, err := tls.LoadX509KeyPair(cfg.TLS.CertPath, cfg.TLS.KeyPath) + if err != nil { + return cloud.CAPool{}, nil, fmt.Errorf("cannot load cert(%v, %v): %w", cfg.TLS.CertPath, cfg.TLS.KeyPath, err) + } + return caPool, &cert, nil +} + +func handleSignals(s *service.Service) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + for sig := range sigCh { + switch sig { + case syscall.SIGINT: + os.Exit(0) + return + case syscall.SIGTERM: + _ = s.Shutdown() + return + } + } +} + +func main() { + configFile := flag.String("config", "config.yaml", "path to config file") + flag.Parse() + cfg, err := loadConfig(*configFile) + if err != nil { + panic(err) + } + s, err := service.New(cfg.Config, service.WithLogger(log.NewStdLogger(cfg.Log.Level))) + if err != nil { + panic(err) + } + + opts := []device.Option{ + device.WithGetAdditionalPropertiesForResponse(func() map[string]interface{} { + return map[string]interface{}{ + "my-property": "my-value", + } + }), + } + if cfg.Cloud.Enabled { + caPool, cert, errC := getCloudTLS(cfg.Cloud, cfg.Credential.Enabled) + if errC != nil { + panic(errC) + } + opts = append(opts, device.WithCAPool(caPool)) + if cert != nil { + opts = append(opts, device.WithGetCertificates(func(string) []tls.Certificate { + return []tls.Certificate{*cert} + })) + } + } + + for i := 0; i < cfg.NumGeneratedBridgedDevices; i++ { + newDevice := func(id uuid.UUID, piid uuid.UUID) (service.Device, error) { + return device.New(device.Config{ + Name: fmt.Sprintf("bridged-device-%d", i), + ResourceTypes: []string{"oic.d.virtual"}, + ID: id, + ProtocolIndependentID: piid, + MaxMessageSize: cfg.Config.API.CoAP.MaxMessageSize, + Cloud: device.CloudConfig{ + Enabled: cfg.Cloud.Enabled, + }, + Credential: device.CredentialConfig{ + Enabled: cfg.Credential.Enabled, + }, + }, append(opts, device.WithLogger(device.NewLogger(id, cfg.Log.Level)))...) + } + d, errC := s.CreateDevice(uuid.New(), newDevice) + if errC == nil { + addResources(d, cfg.NumResourcesPerDevice) + d.Init() + } + } + + go func() { + handleSignals(s) + }() + + if err = s.Serve(); err != nil { + panic(err) + } +} diff --git a/cmd/ocfclient/main.go b/cmd/ocfclient/main.go index 22a109a6..3458c197 100644 --- a/cmd/ocfclient/main.go +++ b/cmd/ocfclient/main.go @@ -37,7 +37,7 @@ import ( "github.com/plgd-dev/device/v2/pkg/codec/json" "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/device/v2/pkg/security/signer" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) type Options struct { @@ -387,11 +387,11 @@ func generateIntermediateCertificate(certConfig generateCertificate.Configuratio if err != nil { return err } - signerCert, err := security.LoadX509(signCert) + signerCert, err := pkgX509.ReadPemCertificates(signCert) if err != nil { return err } - signerKey, err := security.LoadX509PrivateKey(signKey) + signerKey, err := pkgX509.ReadPemEcdsaPrivateKey(signKey) if err != nil { return err } @@ -415,11 +415,11 @@ func generateIdentityCertificate(certConfig generateCertificate.Configuration, i if err != nil { return err } - signerCert, err := security.LoadX509(signCert) + signerCert, err := pkgX509.ReadPemCertificates(signCert) if err != nil { return err } - signerKey, err := security.LoadX509PrivateKey(signKey) + signerKey, err := pkgX509.ReadPemEcdsaPrivateKey(signKey) if err != nil { return err } diff --git a/go.mod b/go.mod index 363bd6d8..a46dd897 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( go.uber.org/atomic v1.11.0 golang.org/x/sync v0.6.0 google.golang.org/grpc v1.61.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -26,18 +27,16 @@ require ( github.com/dsnet/golib/memfile v1.0.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pion/transport/v3 v3.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect - go.uber.org/multierr v1.11.0 // indirect - go.uber.org/zap v1.26.0 // indirect golang.org/x/crypto v0.18.0 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/sys v0.16.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240125205218-1f4bbc51befe // indirect google.golang.org/protobuf v1.32.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) // note: github.com/pion/dtls/v2/pkg/net package is not yet available in release branches diff --git a/go.sum b/go.sum index a7397a45..3a1984a9 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,7 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -81,6 +82,7 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc= github.com/lestrrat-go/jwx v1.0.2/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0= github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8= @@ -143,16 +145,11 @@ go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= -go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= -go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= -go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.15.0/go.mod h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= -go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= -go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/pkg/codec/cbor/codec.go b/pkg/codec/cbor/codec.go index bfadacf7..937c3eb2 100644 --- a/pkg/codec/cbor/codec.go +++ b/pkg/codec/cbor/codec.go @@ -26,12 +26,24 @@ import ( // Encode encodes v and returns bytes. func Encode(v interface{}) ([]byte, error) { - return cbor.Marshal(v) + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + err := WriteTo(buf, v) + if err != nil { + return nil, err + } + return buf.Bytes(), nil } // WriteTo writes v to writer. func WriteTo(w io.Writer, v interface{}) error { - return cbor.NewEncoder(w).Encode(v) + encOpts := cbor.EncOptions{ + Sort: cbor.SortBytewiseLexical, + } + encMode, err := encOpts.EncMode() + if err != nil { + return err + } + return encMode.NewEncoder(w).Encode(v) } // Decode decodes bytes and stores the result in v. diff --git a/pkg/log/log.go b/pkg/log/log.go new file mode 100644 index 00000000..e8ab4396 --- /dev/null +++ b/pkg/log/log.go @@ -0,0 +1,70 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package log + +type Logger interface { + Debug(string) + Info(string) + Warn(string) + Error(string) + Debugf(template string, args ...interface{}) + Infof(template string, args ...interface{}) + Warnf(template string, args ...interface{}) + Errorf(template string, args ...interface{}) +} + +type NilLogger struct{} + +var nilLogger = &NilLogger{} + +func NewNilLogger() *NilLogger { + return nilLogger +} + +func (*NilLogger) Debug(string) { + // no-op +} + +func (*NilLogger) Info(string) { + // no-op +} + +func (*NilLogger) Warn(string) { + // no-op +} + +func (*NilLogger) Error(string) { + // no-op +} + +func (*NilLogger) Debugf(string, ...interface{}) { + // no-op +} + +func (*NilLogger) Infof(string, ...interface{}) { + // no-op +} + +func (*NilLogger) Warnf(string, ...interface{}) { + // no-op +} + +func (*NilLogger) Errorf(string, ...interface{}) { + // no-op +} diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go new file mode 100644 index 00000000..d53486fb --- /dev/null +++ b/pkg/log/log_test.go @@ -0,0 +1,37 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package log_test + +import ( + "testing" + + "github.com/plgd-dev/device/v2/pkg/log" +) + +func TestNilLogger(*testing.T) { + l := log.NewNilLogger() + l.Debug("debug") + l.Info("info") + l.Warn("warn") + l.Error("error") + l.Debugf("debugf") + l.Infof("infof") + l.Warnf("warnf") + l.Errorf("errorf") +} diff --git a/pkg/log/stdLogger.go b/pkg/log/stdLogger.go new file mode 100644 index 00000000..0e65c5da --- /dev/null +++ b/pkg/log/stdLogger.go @@ -0,0 +1,149 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package log + +import ( + "bytes" + "fmt" + "log" +) + +type Level int8 + +const ( + LevelNone Level = iota + LevelDebug + LevelInfo + LevelWarn + LevelError +) + +func ParseLevel(text string) (Level, error) { + var level Level + err := level.UnmarshalText([]byte(text)) + return level, err +} + +// String returns a lower-case ASCII representation of the log level. +func (l Level) String() string { + switch l { + case LevelNone: + return "none" + case LevelDebug: + return "debug" + case LevelInfo: + return "info" + case LevelWarn: + return "warn" + case LevelError: + return "error" + default: + return fmt.Sprintf("Level(%d)", l) + } +} + +func (l Level) MarshalText() ([]byte, error) { + return []byte(l.String()), nil +} + +func (l *Level) UnmarshalText(text []byte) error { + if !l.unmarshalText(text) && !l.unmarshalText(bytes.ToLower(text)) { + return fmt.Errorf("unrecognized level: %q", text) + } + return nil +} + +func (l *Level) unmarshalText(text []byte) bool { + switch string(text) { + case "", "none": + *l = LevelNone + case "debug": + *l = LevelDebug + case "info": + *l = LevelInfo + case "warn": + *l = LevelWarn + case "error": + *l = LevelError + default: + return false + } + return true +} + +type StdLogger struct { + *log.Logger + level Level +} + +func NewStdLogger(logLevel Level) *StdLogger { + return &StdLogger{ + Logger: log.Default(), + level: logLevel, + } +} + +func (l *StdLogger) checkLevel(level Level) bool { + return l.level != LevelNone && l.level <= level +} + +func (l *StdLogger) LogWithLevel(level Level, msg string) { + if l.checkLevel(level) { + l.Println(msg) + } +} + +func (l *StdLogger) Debug(msg string) { + l.LogWithLevel(LevelDebug, msg) +} + +func (l *StdLogger) Info(msg string) { + l.LogWithLevel(LevelInfo, msg) +} + +func (l *StdLogger) Warn(msg string) { + l.LogWithLevel(LevelWarn, msg) +} + +func (l *StdLogger) Error(msg string) { + l.LogWithLevel(LevelError, msg) +} + +// LogfWithLevel uses fmt.Errorf to construct and log.Printf to log a message. +func (l *StdLogger) LogfWithLevel(level Level, format string, args ...interface{}) { + if l.checkLevel(level) { + l.Printf("%s\n", fmt.Errorf(format, args...)) + } +} + +func (l *StdLogger) Debugf(format string, args ...interface{}) { + l.LogfWithLevel(LevelDebug, format, args...) +} + +func (l *StdLogger) Infof(format string, args ...interface{}) { + l.LogfWithLevel(LevelInfo, format, args...) +} + +func (l *StdLogger) Warnf(format string, args ...interface{}) { + l.LogfWithLevel(LevelWarn, format, args...) +} + +func (l *StdLogger) Errorf(format string, args ...interface{}) { + l.LogfWithLevel(LevelError, format, args...) +} diff --git a/pkg/log/stdLogger_test.go b/pkg/log/stdLogger_test.go new file mode 100644 index 00000000..ab037db8 --- /dev/null +++ b/pkg/log/stdLogger_test.go @@ -0,0 +1,151 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package log_test + +import ( + "bytes" + "testing" + + "github.com/plgd-dev/device/v2/pkg/codec/json" + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/stretchr/testify/require" +) + +func TestStdLogger(t *testing.T) { + var b bytes.Buffer + l := log.NewStdLogger(log.LevelNone) + l.SetOutput(&b) + l.Debug("debug") + l.Info("info") + l.Warn("warn") + l.Error("error") + l.Debugf("debugf") + l.Infof("infof") + l.Warnf("warnf") + l.Errorf("errorf") + require.Empty(t, b.String()) + + b.Reset() + l = log.NewStdLogger(log.LevelDebug) + l.Debug("debug") + l.Info("info") + l.Warn("warn") + l.Error("error") + l.Debugf("debugf") + l.Infof("infof") + l.Warnf("warnf") + l.Errorf("errorf") + require.Contains(t, b.String(), "debug") + require.Contains(t, b.String(), "info") + require.Contains(t, b.String(), "warn") + require.Contains(t, b.String(), "error") + require.Contains(t, b.String(), "debugf") + require.Contains(t, b.String(), "infof") + require.Contains(t, b.String(), "warnf") + require.Contains(t, b.String(), "errorf") + + b.Reset() + l = log.NewStdLogger(log.LevelError) + l.Debug("debug") + l.Info("info") + l.Warn("warn") + l.Error("error") + l.Debugf("debugf") + l.Infof("infof") + l.Warnf("warnf") + l.Errorf("errorf") + require.NotContains(t, b.String(), "debug") + require.NotContains(t, b.String(), "info") + require.NotContains(t, b.String(), "warn") + require.Contains(t, b.String(), "error") + require.NotContains(t, b.String(), "debugf") + require.NotContains(t, b.String(), "infof") + require.NotContains(t, b.String(), "warnf") + require.Contains(t, b.String(), "errorf") +} + +func TestJsonEncodeLevel(t *testing.T) { + b, err := json.Encode(log.LevelNone) + require.NoError(t, err) + require.Equal(t, []byte(`"none"`), b) + b, err = json.Encode(log.LevelDebug) + require.NoError(t, err) + require.Equal(t, []byte(`"debug"`), b) + b, err = json.Encode(log.LevelInfo) + require.NoError(t, err) + require.Equal(t, []byte(`"info"`), b) + b, err = json.Encode(log.LevelWarn) + require.NoError(t, err) + require.Equal(t, []byte(`"warn"`), b) + b, err = json.Encode(log.LevelError) + require.NoError(t, err) + require.Equal(t, []byte(`"error"`), b) +} + +func TestJsonDecodeLevel(t *testing.T) { + var v log.Level + err := json.Decode([]byte(`""`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelNone, v) + err = json.Decode([]byte(`"none"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelNone, v) + err = json.Decode([]byte(`"NONE"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelNone, v) + err = json.Decode([]byte(`"debug"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelDebug, v) + err = json.Decode([]byte(`"DEBUG"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelDebug, v) + err = json.Decode([]byte(`"info"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelInfo, v) + err = json.Decode([]byte(`"INFO"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelInfo, v) + err = json.Decode([]byte(`"warn"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelWarn, v) + err = json.Decode([]byte(`"WARN"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelWarn, v) + err = json.Decode([]byte(`"error"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelError, v) + err = json.Decode([]byte(`"ERROR"`), &v) + require.NoError(t, err) + require.Equal(t, log.LevelError, v) + + err = json.Decode([]byte(`"unknown"`), &v) + require.Error(t, err) +} + +func TestParseLevel(t *testing.T) { + lvl, err := log.ParseLevel("none") + require.NoError(t, err) + require.Equal(t, log.LevelNone, lvl) + lvl, err = log.ParseLevel("error") + require.NoError(t, err) + require.Equal(t, log.LevelError, lvl) + + _, err = log.ParseLevel("unknown") + require.Error(t, err) +} diff --git a/pkg/net/coap/client.go b/pkg/net/coap/client.go index ff0326bc..9af82f48 100644 --- a/pkg/net/coap/client.go +++ b/pkg/net/coap/client.go @@ -21,15 +21,12 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/asn1" "fmt" "io" "net" - "strings" "sync" "time" - "github.com/google/uuid" piondtls "github.com/pion/dtls/v2" codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf" "github.com/plgd-dev/go-coap/v3/dtls" @@ -69,70 +66,6 @@ type Codec interface { Decode(m *pool.Message, v interface{}) error } -var ExtendedKeyUsage_IDENTITY_CERTIFICATE = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44924, 1, 6} - -func GetDeviceIDFromIdentityCertificate(cert *x509.Certificate) (string, error) { - // verify EKU manually - ekuHasClient := false - for _, eku := range cert.ExtKeyUsage { - if eku == x509.ExtKeyUsageClientAuth { - ekuHasClient = true - break - } - } - if !ekuHasClient { - return "", fmt.Errorf("not contains ExtKeyUsageClientAuth") - } - ekuHasOcfID := false - for _, eku := range cert.UnknownExtKeyUsage { - if eku.Equal(ExtendedKeyUsage_IDENTITY_CERTIFICATE) { - ekuHasOcfID = true - break - } - } - if !ekuHasOcfID { - return "", fmt.Errorf("not contains ExtKeyUsage with OCF ID(1.3.6.1.4.1.44924.1.6") - } - cn := strings.Split(cert.Subject.CommonName, ":") - if len(cn) != 2 { - return "", fmt.Errorf("invalid subject common name: %v", cert.Subject.CommonName) - } - if strings.ToLower(cn[0]) != "uuid" { - return "", fmt.Errorf("invalid subject common name %v: 'uuid' - not found", cert.Subject.CommonName) - } - deviceID, err := uuid.Parse(cn[1]) - if err != nil { - return "", fmt.Errorf("invalid subject common name %v: %w", cert.Subject.CommonName, err) - } - return deviceID.String(), nil -} - -func VerifyIdentityCertificate(cert *x509.Certificate) error { - // verify EKU manually - ekuHasClient := false - ekuHasServer := false - for _, eku := range cert.ExtKeyUsage { - if eku == x509.ExtKeyUsageClientAuth { - ekuHasClient = true - } - if eku == x509.ExtKeyUsageServerAuth { - ekuHasServer = true - } - } - if !ekuHasClient { - return fmt.Errorf("not contains ExtKeyUsageClientAuth") - } - if !ekuHasServer { - return fmt.Errorf("not contains ExtKeyUsageServerAuth") - } - _, err := GetDeviceIDFromIdentityCertificate(cert) - if err != nil { - return err - } - - return nil -} - func NewClient(conn ClientConn) *Client { return &Client{conn: conn} } diff --git a/pkg/net/coap/verify.go b/pkg/net/coap/verify.go new file mode 100644 index 00000000..274773f6 --- /dev/null +++ b/pkg/net/coap/verify.go @@ -0,0 +1,120 @@ +// ************************************************************************ +// Copyright (C) 2022 plgd.dev, s.r.o. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ************************************************************************ + +package coap + +import ( + "crypto/x509" + "encoding/asn1" + "fmt" + "strings" + + "github.com/google/uuid" +) + +var ExtendedKeyUsage_IDENTITY_CERTIFICATE = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44924, 1, 6} + +func verifyOcfEKU(cert *x509.Certificate) error { + hasOcfID := false + for _, eku := range cert.UnknownExtKeyUsage { + if eku.Equal(ExtendedKeyUsage_IDENTITY_CERTIFICATE) { + hasOcfID = true + break + } + } + if !hasOcfID { + return fmt.Errorf("certificate does not contain ExtKeyUsage with OCF ID(1.3.6.1.4.1.44924.1.6") + } + return nil +} + +func verifyEKU(cert *x509.Certificate, requireClient, requireServer, requireOcfId bool) error { + hasClient := false + hasServer := false + for _, eku := range cert.ExtKeyUsage { + if eku == x509.ExtKeyUsageClientAuth { + hasClient = true + continue + } + if eku == x509.ExtKeyUsageServerAuth { + hasServer = true + continue + } + } + if requireClient && !hasClient { + return fmt.Errorf("certificate does not contain ExtKeyUsageClientAuth") + } + if requireServer && !hasServer { + return fmt.Errorf("certificate does not contain ExtKeyUsageServerAuth") + } + + if !requireOcfId { + return nil + } + return verifyOcfEKU(cert) +} + +func getUUIDFromSubjectCommonName(cert *x509.Certificate) (uuid.UUID, error) { + cn := strings.Split(cert.Subject.CommonName, ":") + if len(cn) != 2 { + return uuid.UUID{}, fmt.Errorf("invalid subject common name: %v", cert.Subject.CommonName) + } + if strings.ToLower(cn[0]) != "uuid" { + return uuid.UUID{}, fmt.Errorf("invalid subject common name %v: 'uuid' - not found", cert.Subject.CommonName) + } + id, err := uuid.Parse(cn[1]) + if err != nil { + return uuid.UUID{}, fmt.Errorf("invalid subject common name %v: %w", cert.Subject.CommonName, err) + } + return id, nil +} + +func GetDeviceIDFromIdentityCertificate(cert *x509.Certificate) (string, error) { + // verify EKU manually + if err := verifyEKU(cert, true, false, true); err != nil { + return "", err + } + deviceID, err := getUUIDFromSubjectCommonName(cert) + if err != nil { + return "", err + } + return deviceID.String(), nil +} + +func VerifyIdentityCertificate(cert *x509.Certificate) error { + if err := verifyEKU(cert, true, true, false); err != nil { + return err + } + if _, err := GetDeviceIDFromIdentityCertificate(cert); err != nil { + return err + } + return nil +} + +func VerifyCloudCertificate(cert *x509.Certificate, cloudID uuid.UUID) error { + if err := verifyEKU(cert, true, true, false); err != nil { + return err + } + id, err := getUUIDFromSubjectCommonName(cert) + if err != nil { + return err + } + + if id != cloudID { + return fmt.Errorf("invalid cloud certificate: invalid cloudID(%v)", id.String()) + } + return nil +} diff --git a/pkg/net/coap/verify_test.go b/pkg/net/coap/verify_test.go new file mode 100644 index 00000000..6d89c870 --- /dev/null +++ b/pkg/net/coap/verify_test.go @@ -0,0 +1,76 @@ +// ************************************************************************ +// Copyright (C) 2022 plgd.dev, s.r.o. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ************************************************************************ + +package coap_test + +import ( + "crypto/ecdsa" + "crypto/x509" + "testing" + "time" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/pkg/net/coap" + "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" + "github.com/stretchr/testify/require" +) + +func generateRootCA(t *testing.T, cfg generateCertificate.Configuration) ([]*x509.Certificate, *ecdsa.PrivateKey) { + key, err := cfg.GenerateKey() + require.NoError(t, err) + pem, err := generateCertificate.GenerateRootCA(cfg, key) + require.NoError(t, err) + certs, err := pkgX509.ParsePemCertificates(pem) + require.NoError(t, err) + return certs, key +} + +func TestVerifyIdentityCertificate(t *testing.T) { + cfg := generateCertificate.Configuration{ + ValidFor: time.Minute, + } + rootCA, rootCAKey := generateRootCA(t, cfg) + + key, err := cfg.GenerateKey() + require.NoError(t, err) + id := uuid.New().String() + certPem, err := generateCertificate.GenerateIdentityCert(cfg, id, key, rootCA, rootCAKey) + require.NoError(t, err) + cert, err := pkgX509.ParsePemCertificates(certPem) + require.NoError(t, err) + + err = coap.VerifyIdentityCertificate(cert[0]) + require.NoError(t, err) +} + +func TestVerifyCloudCertificate(t *testing.T) { + cfg := generateCertificate.Configuration{ + ValidFor: time.Minute, + } + rootCA, rootCAKey := generateRootCA(t, cfg) + + key, err := cfg.GenerateKey() + require.NoError(t, err) + cloudID := uuid.New() + certPem, err := generateCertificate.GenerateIdentityCert(cfg, cloudID.String(), key, rootCA, rootCAKey) + require.NoError(t, err) + cert, err := pkgX509.ParsePemCertificates(certPem) + require.NoError(t, err) + + err = coap.VerifyCloudCertificate(cert[0], cloudID) + require.NoError(t, err) +} diff --git a/pkg/ocf/cloud/request.go b/pkg/ocf/cloud/request.go new file mode 100644 index 00000000..c72885b0 --- /dev/null +++ b/pkg/ocf/cloud/request.go @@ -0,0 +1,74 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +import ( + "time" + + "github.com/plgd-dev/device/v2/schema" +) + +type CoapSignUpRequest struct { + DeviceID string `json:"di"` + AuthorizationCode string `json:"accesstoken"` + AuthorizationProvider string `json:"authprovider"` +} + +type CoapSignUpResponse struct { + AccessToken string `yaml:"accessToken" json:"accesstoken"` + UserID string `yaml:"userID" json:"uid"` + RefreshToken string `yaml:"refreshToken" json:"refreshtoken"` + RedirectURI string `yaml:"-" json:"redirecturi"` + ExpiresIn int64 `yaml:"-" json:"expiresin"` + ValidUntil time.Time `yaml:"-" jsom:"-"` +} + +type CoapSignInRequest struct { + DeviceID string `json:"di"` + UserID string `json:"uid"` + AccessToken string `json:"accesstoken"` + Login bool `json:"login"` +} + +type CoapSignInResponse struct { + ExpiresIn int64 `json:"expiresin"` +} + +type CoapRefreshTokenRequest struct { + DeviceID string `json:"di"` + UserID string `json:"uid"` + RefreshToken string `json:"refreshtoken"` +} + +type CoapRefreshTokenResponse struct { + AccessToken string `json:"accesstoken"` + RefreshToken string `json:"refreshtoken"` + ExpiresIn int64 `json:"expiresin"` +} + +type PublishResourcesRequest struct { + DeviceID string `json:"di"` + Links schema.ResourceLinks `json:"links"` + TimeToLive int `json:"ttl"` +} + +type UnpublishResourcesRequest struct { + DeviceID string + InstanceIDs []int64 +} diff --git a/pkg/ocf/cloud/uri.go b/pkg/ocf/cloud/uri.go new file mode 100644 index 00000000..e042f5da --- /dev/null +++ b/pkg/ocf/cloud/uri.go @@ -0,0 +1,29 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package cloud + +const ( + Base = "/oic" + Secure = Base + "/sec" + SignUp = Secure + "/account" + RefreshToken = Secure + "/tokenrefresh" + SignIn = Secure + "/session" + ResourceDirectory = Base + "/rd" + ResourceDiscovery = Base + "/res" +) diff --git a/pkg/security/generateCertificate/generateIntermediateCA.go b/pkg/security/generateCertificate/generateIntermediateCA.go index e21a1bf3..1b06d3f2 100644 --- a/pkg/security/generateCertificate/generateIntermediateCA.go +++ b/pkg/security/generateCertificate/generateIntermediateCA.go @@ -7,7 +7,7 @@ import ( "fmt" "math/big" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) func newCert(cfg Configuration) (*x509.Certificate, error) { @@ -57,5 +57,5 @@ func GenerateIntermediateCA(cfg Configuration, privateKey *ecdsa.PrivateKey, sig if err != nil { return nil, err } - return security.CreatePemChain(signerCA, der) + return pkgX509.CreatePemChain(signerCA, der) } diff --git a/pkg/security/generateCertificate/generateIntermediateCA_test.go b/pkg/security/generateCertificate/generateIntermediateCA_test.go index 3914b733..36fb5e4b 100644 --- a/pkg/security/generateCertificate/generateIntermediateCA_test.go +++ b/pkg/security/generateCertificate/generateIntermediateCA_test.go @@ -5,7 +5,7 @@ import ( "crypto/x509" "testing" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" "github.com/stretchr/testify/require" ) @@ -14,7 +14,7 @@ func generateRootCA(t *testing.T, cfg Configuration) ([]*x509.Certificate, *ecds require.NoError(t, err) cert, err := GenerateRootCA(cfg, privateKey) require.NoError(t, err) - crt, err := security.ParseX509FromPEM(cert) + crt, err := pkgX509.ParsePemCertificates(cert) require.NoError(t, err) return crt, privateKey } diff --git a/pkg/security/signer/signer.go b/pkg/security/signer/signer.go index d4f4f719..0207fe17 100644 --- a/pkg/security/signer/signer.go +++ b/pkg/security/signer/signer.go @@ -28,7 +28,7 @@ import ( "time" "github.com/plgd-dev/device/v2/pkg/net/coap" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) type OCFIdentityCertificate struct { @@ -105,5 +105,5 @@ func (s *OCFIdentityCertificate) Sign(_ context.Context, csr []byte) (signedCsr if err != nil { return } - return security.CreatePemChain(s.caCert, signedCsr) + return pkgX509.CreatePemChain(s.caCert, signedCsr) } diff --git a/pkg/security/signer/signer_test.go b/pkg/security/signer/signer_test.go index ae91a163..2d6065b9 100644 --- a/pkg/security/signer/signer_test.go +++ b/pkg/security/signer/signer_test.go @@ -14,7 +14,7 @@ import ( "github.com/google/uuid" "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" "github.com/plgd-dev/device/v2/pkg/security/signer" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" "github.com/stretchr/testify/require" ) @@ -28,9 +28,9 @@ func TestOCFIdentityCertificateSign(t *testing.T) { type args struct { csr []byte } - caCert, err := security.LoadX509(os.Getenv("ROOT_CA_CRT")) + caCert, err := pkgX509.ReadPemCertificates(os.Getenv("ROOT_CA_CRT")) require.NoError(t, err) - caKey, err := security.LoadX509PrivateKey(os.Getenv("ROOT_CA_KEY")) + caKey, err := pkgX509.ReadPemEcdsaPrivateKey(os.Getenv("ROOT_CA_KEY")) require.NoError(t, err) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) diff --git a/pkg/security/x509/parse.go b/pkg/security/x509/parse.go new file mode 100644 index 00000000..93964319 --- /dev/null +++ b/pkg/security/x509/parse.go @@ -0,0 +1,108 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package x509 + +import ( + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" +) + +// ParsePemCertificates parses x509 certificates from PEM format +func ParsePemCertificates(pemBlock []byte) ([]*x509.Certificate, error) { + data := pemBlock + var cas []*x509.Certificate + for { + certDERBlock, tmp := pem.Decode(data) + if certDERBlock == nil { + return nil, fmt.Errorf("cannot decode pem block") + } + certs, err := x509.ParseCertificates(certDERBlock.Bytes) + if err != nil { + return nil, err + } + cas = append(cas, certs...) + if len(tmp) == 0 { + break + } + data = tmp + } + return cas, nil +} + +// ReadPemCertificates reads certificates from file in PEM format +func ReadPemCertificates(path string) ([]*x509.Certificate, error) { + certPEMBlock, err := os.ReadFile(filepath.Clean(path)) + if err != nil { + return nil, err + } + return ParsePemCertificates(certPEMBlock) +} + +// ParsePemEcdsaPrivateKey parses private key from PEM format +func ParsePemEcdsaPrivateKey(pemBlock []byte) (*ecdsa.PrivateKey, error) { + derBlock, _ := pem.Decode(pemBlock) + if derBlock == nil { + return nil, fmt.Errorf("cannot decode pem block") + } + + if key, err := x509.ParsePKCS8PrivateKey(derBlock.Bytes); err == nil { + switch key := key.(type) { + case *ecdsa.PrivateKey: + return key, nil + default: + return nil, fmt.Errorf("found unknown private key type in PKCS#8 wrapping") + } + } + + if key, err := x509.ParseECPrivateKey(derBlock.Bytes); err == nil { + return key, nil + } + + return nil, fmt.Errorf("failed to parse private key") +} + +// ReadPemEcdsaPrivateKey loads private key from file in PEM format +func ReadPemEcdsaPrivateKey(path string) (*ecdsa.PrivateKey, error) { + certPEMBlock, err := os.ReadFile(filepath.Clean(path)) + if err != nil { + return nil, err + } + return ParsePemEcdsaPrivateKey(certPEMBlock) +} + +// ParseCertificates parses the CA chain certificates from the DER data. +func ParseCertificates(cert *tls.Certificate) ([]*x509.Certificate, error) { + caChain := make([]*x509.Certificate, 0, 4) + for _, derBytes := range cert.Certificate { + ca, err := x509.ParseCertificates(derBytes) + if err != nil { + return nil, err + } + caChain = append(caChain, ca...) + } + if len(caChain) == 0 { + return nil, fmt.Errorf("no certificates") + } + return caChain, nil +} diff --git a/pkg/security/x509/parse_test.go b/pkg/security/x509/parse_test.go new file mode 100644 index 00000000..0442610b --- /dev/null +++ b/pkg/security/x509/parse_test.go @@ -0,0 +1,168 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package x509_test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "os" + "testing" + + "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" + "github.com/stretchr/testify/require" +) + +func generateRootCA(t *testing.T, cfg generateCertificate.Configuration) ([]*x509.Certificate, *ecdsa.PrivateKey) { + key, err := cfg.GenerateKey() + require.NoError(t, err) + pem, err := generateCertificate.GenerateRootCA(cfg, key) + require.NoError(t, err) + certs, err := pkgX509.ParsePemCertificates(pem) + require.NoError(t, err) + return certs, key +} + +func TestParsePemCertificates(t *testing.T) { + generateRootCA(t, generateCertificate.Configuration{}) +} + +func TestParsePemCertificatesChain(t *testing.T) { + cfg := generateCertificate.Configuration{} + ca, caKey := generateRootCA(t, generateCertificate.Configuration{}) + + key, err := cfg.GenerateKey() + require.NoError(t, err) + got, err := generateCertificate.GenerateIntermediateCA(cfg, key, ca, caKey) + require.NoError(t, err) + _, err = pkgX509.ParsePemCertificates(got) + require.NoError(t, err) +} + +func encodeKeyToPem(t *testing.T, key *ecdsa.PrivateKey) []byte { + b, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} + pem := pem.EncodeToMemory(block) + return pem +} + +func TestParsePemCertificates_Fail(t *testing.T) { + // invalid input + _, err := pkgX509.ParsePemCertificates(nil) + require.Error(t, err) + + // pem encoded key instead of certificate + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + pem := encodeKeyToPem(t, key) + _, err = pkgX509.ParsePemCertificates(pem) + require.Error(t, err) +} + +func TestReadPemCertificates(t *testing.T) { + cfg := generateCertificate.Configuration{} + key, err := cfg.GenerateKey() + require.NoError(t, err) + pem, err := generateCertificate.GenerateRootCA(cfg, key) + require.NoError(t, err) + + testFilePath := "./test.pem" + err = os.WriteFile(testFilePath, pem, 0o600) + require.NoError(t, err) + defer func() { + _ = os.Remove(testFilePath) + }() + + _, err = pkgX509.ReadPemCertificates(testFilePath) + require.NoError(t, err) +} + +func TestReadPemCertificates_Fail(t *testing.T) { + _, err := pkgX509.ReadPemCertificates("") + require.Error(t, err) +} + +func TestParsePemEcdsaPrivateKey(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + keyPem := encodeKeyToPem(t, key) + _, err = pkgX509.ParsePemEcdsaPrivateKey(keyPem) + require.NoError(t, err) + + pkcsDer, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + block := &pem.Block{Type: "EC PRIVATE KEY", Bytes: pkcsDer} + pkcsPem := pem.EncodeToMemory(block) + _, err = pkgX509.ParsePemEcdsaPrivateKey(pkcsPem) + require.NoError(t, err) +} + +func TestParsePemEcdsaPrivateKey_Fail(t *testing.T) { + _, err := pkgX509.ParsePemEcdsaPrivateKey(nil) + require.Error(t, err) + + cfg := generateCertificate.Configuration{} + ecKey, err := cfg.GenerateKey() + require.NoError(t, err) + ecPem, err := generateCertificate.GenerateRootCA(cfg, ecKey) + require.NoError(t, err) + _, err = pkgX509.ParsePemEcdsaPrivateKey(ecPem) + require.Error(t, err) + + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pkcsDer, err := x509.MarshalPKCS8PrivateKey(rsaKey) + require.NoError(t, err) + block := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkcsDer} + pkcsPem := pem.EncodeToMemory(block) + _, err = pkgX509.ParsePemEcdsaPrivateKey(pkcsPem) + require.Error(t, err) +} + +func TestReadPemEcdsaPrivateKey(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + keyPem := encodeKeyToPem(t, key) + + testFilePath := "./test.pem" + err = os.WriteFile(testFilePath, keyPem, 0o600) + require.NoError(t, err) + defer func() { + _ = os.Remove(testFilePath) + }() + + _, err = pkgX509.ReadPemEcdsaPrivateKey(testFilePath) + require.NoError(t, err) +} + +func TestReadPemEcdsaPrivateKey_Fail(t *testing.T) { + _, err := pkgX509.ReadPemEcdsaPrivateKey("") + require.Error(t, err) +} + +func TestParseCertificates_Fail(t *testing.T) { + _, err := pkgX509.ParseCertificates(&tls.Certificate{}) + require.Error(t, err) +} diff --git a/pkg/security/x509/write.go b/pkg/security/x509/write.go new file mode 100644 index 00000000..5f00b2f6 --- /dev/null +++ b/pkg/security/x509/write.go @@ -0,0 +1,50 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package x509 + +import ( + "bytes" + "crypto/x509" + "encoding/pem" +) + +// CreatePemChain creates chain of PEM certificates. +func CreatePemChain(intermedateCAs []*x509.Certificate, cert []byte) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, 2048)) + + // encode cert + err := pem.Encode(buf, &pem.Block{ + Type: "CERTIFICATE", Bytes: cert, + }) + if err != nil { + return nil, err + } + + // encode intermediates + for _, ca := range intermedateCAs { + err := pem.Encode(buf, &pem.Block{ + Type: "CERTIFICATE", Bytes: ca.Raw, + }) + if err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} diff --git a/pkg/security/x509/write_test.go b/pkg/security/x509/write_test.go new file mode 100644 index 00000000..27e1ae0a --- /dev/null +++ b/pkg/security/x509/write_test.go @@ -0,0 +1,46 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package x509_test + +import ( + "testing" + + "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" + "github.com/stretchr/testify/require" +) + +func TestCreatePemChain(t *testing.T) { + cfg := generateCertificate.Configuration{} + caKey, err := cfg.GenerateKey() + require.NoError(t, err) + caPem, err := generateCertificate.GenerateRootCA(cfg, caKey) + require.NoError(t, err) + ca, err := pkgX509.ParsePemCertificates(caPem) + require.NoError(t, err) + key, err := cfg.GenerateKey() + require.NoError(t, err) + intermediatePem, err := generateCertificate.GenerateIntermediateCA(cfg, key, ca, caKey) + require.NoError(t, err) + intermediate, err := pkgX509.ParsePemCertificates(intermediatePem) + require.NoError(t, err) + + _, err = pkgX509.CreatePemChain(intermediate, caPem) + require.NoError(t, err) +} diff --git a/schema/credential/credential.go b/schema/credential/credential.go index 40b09985..c0f8ca75 100644 --- a/schema/credential/credential.go +++ b/schema/credential/credential.go @@ -31,17 +31,17 @@ const ( ) type Credential struct { - ID int `json:"credid,omitempty"` - Type CredentialType `json:"credtype"` - Subject string `json:"subjectuuid"` - Usage CredentialUsage `json:"credusage,omitempty"` - SupportedRefreshMethods []CredentialRefreshMethod `json:"crms,omitempty"` - OptionalData *CredentialOptionalData `json:"optionaldata,omitempty"` - Period string `json:"period,omitempty"` - PrivateData *CredentialPrivateData `json:"privatedata,omitempty"` - PublicData *CredentialPublicData `json:"publicdata,omitempty"` - RoleID *CredentialRoleID `json:"roleid,omitempty"` - Tag string `json:"tag,omitempty"` + ID int `json:"credid,omitempty" yaml:"id,omitempty"` + Type CredentialType `json:"credtype" yaml:"type"` + Subject string `json:"subjectuuid" yaml:"subject"` + Usage CredentialUsage `json:"credusage,omitempty" yaml:"usage,omitempty"` + SupportedRefreshMethods []CredentialRefreshMethod `json:"crms,omitempty" yaml:"supportedRefreshMethods,omitempty"` + OptionalData *CredentialOptionalData `json:"optionaldata,omitempty" yaml:"optionalData,omitempty"` + Period string `json:"period,omitempty" yaml:"period,omitempty"` + PrivateData *CredentialPrivateData `json:"privatedata,omitempty" yaml:"privateData,omitempty"` + PublicData *CredentialPublicData `json:"publicdata,omitempty" yaml:"publicData,omitempty"` + RoleID *CredentialRoleID `json:"roleid,omitempty" yaml:"roleID,omitempty"` + Tag string `json:"tag,omitempty" yaml:"tag,omitempty"` } type CredentialType uint8 @@ -117,9 +117,9 @@ const ( ) type CredentialOptionalData struct { - DataInternal interface{} `json:"data"` - Encoding CredentialOptionalDataEncoding `json:"encoding"` - IsRevoked bool `json:"revstat"` + DataInternal interface{} `json:"data" yaml:"data"` + Encoding CredentialOptionalDataEncoding `json:"encoding" yaml:"encoding"` + IsRevoked bool `json:"revstat" yaml:"isRevoked,omitempty"` } func toByte(v interface{}) []byte { @@ -181,8 +181,8 @@ const ( ) type CredentialPublicData struct { - DataInternal interface{} `json:"data"` - Encoding CredentialPublicDataEncoding `json:"encoding"` + DataInternal interface{} `json:"data" yaml:"data"` + Encoding CredentialPublicDataEncoding `json:"encoding" yaml:"encoding"` } func (c CredentialPublicData) Data() []byte { @@ -202,16 +202,16 @@ const ( ) type CredentialRoleID struct { - Authority string `json:"authority,omitempty"` - Role string `json:"role,omitempty"` + Authority string `json:"authority,omitempty" yaml:"authority,omitempty"` + Role string `json:"role,omitempty" yaml:"role,omitempty"` } type CredentialResponse struct { - ResourceOwner string `json:"rowneruuid"` - Interfaces []string `json:"if"` - ResourceTypes []string `json:"rt"` - Name string `json:"n"` - Credentials []Credential `json:"creds"` + ResourceOwner string `json:"rowneruuid" yaml:"resourceOwner,omitempty"` + Interfaces []string `json:"if,omitempty" yaml:"-"` + ResourceTypes []string `json:"rt,omitempty" yaml:"-"` + Name string `json:"n,omitempty" yaml:"name,omitempty"` + Credentials []Credential `json:"creds" yaml:"creds"` } type CredentialUpdateRequest struct { diff --git a/schema/link.go b/schema/link.go index c42d21c9..8100cef3 100644 --- a/schema/link.go +++ b/schema/link.go @@ -50,12 +50,12 @@ type ResourceLinks []ResourceLink // https://openconnectivity.org/specs/OCF_Core_Specification_v2.0.0.pdf type Policy struct { BitMask BitMask `json:"bm"` - UDPPort uint16 `json:"port"` - TCPPort uint16 `json:"x.org.iotivity.tcp"` - TCPTLSPort uint16 `json:"x.org.iotivity.tls"` + UDPPort uint16 `json:"port,omitempty"` + TCPPort uint16 `json:"x.org.iotivity.tcp,omitempty"` + TCPTLSPort uint16 `json:"x.org.iotivity.tls,omitempty"` // Secured is true if the resource is only available via an encrypted connection. - Secured bool `json:"sec"` + Secured *bool `json:"sec,omitempty"` } // Endpoint is defined on the line 2439 and 1892, Priority on 2434 of the Core specification: @@ -125,6 +125,17 @@ func (r ResourceLinks) GetResourceLinks(resourceTypes ...string) ResourceLinks { return links } +// FilterByDeviceID filter links by device id. +func (r ResourceLinks) FilterByDeviceID(deviceID string) ResourceLinks { + links := make([]ResourceLink, 0, len(r)) + for _, r := range r { + if r.GetDeviceID() == deviceID { + links = append(links, r) + } + } + return links +} + // PatchEndpoint adds Endpoint information where missing. func (r ResourceLinks) PatchEndpoint(addr kitNet.Addr, deviceEndpoints Endpoints) ResourceLinks { links := make(ResourceLinks, 0, len(r)) @@ -232,7 +243,7 @@ func (r ResourceLink) patchEndpoint(addr kitNet.Addr, deviceEndpoints Endpoints) } r.Endpoints = make([]Endpoint, 0, 4) if r.Policy.UDPPort != 0 { - if r.Policy.Secured { + if r.Policy.Secured != nil && *r.Policy.Secured { r.Endpoints = append(r.Endpoints, udpTlsEndpoint(addr.SetPort(r.Policy.UDPPort))) } else { r.Endpoints = append(r.Endpoints, udpEndpoint(addr.SetPort(r.Policy.UDPPort))) diff --git a/schema/link_test.go b/schema/link_test.go index 54c4f371..efa39c83 100644 --- a/schema/link_test.go +++ b/schema/link_test.go @@ -267,12 +267,13 @@ func TestResourceLinkPatchEndpoint(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + secured := tt.args.secured r := schema.ResourceLink{ Policy: &schema.Policy{ UDPPort: 5683, TCPPort: 5683, TCPTLSPort: 5684, - Secured: tt.args.secured, + Secured: &secured, }, } got := r.PatchEndpoint(tt.args.addr, nil) diff --git a/sonar-project.properties b/sonar-project.properties index 71edc9de..cb2e7bf6 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -17,5 +17,9 @@ sonar.tests=. sonar.test.inclusions=**/*_test.go sonar.test.exclusions= +# excludes duplications check +sonar.cpd.exclusions=cmd/**/*.go + +# exludes from code coverage report sonar.go.coverage.reportPaths=coverage/**/*coverage*.txt -sonar.coverage.exclusions=test/*.go,client/app/app.go,cmd/ocfclient/ocfclient.go,**/main.go,**/*.pb.go,**/*.pb.gw.go,**/*.js,**/*.py +sonar.coverage.exclusions=bridge/test/**/*.go,test/**/*.go,client/app/app.go,cmd/**/*.go,**/main.go,**/*.pb.go,**/*.pb.gw.go,**/*.js,**/*.py diff --git a/test/bridge-device/main.go b/test/bridge-device/main.go new file mode 100644 index 00000000..327f5e6f --- /dev/null +++ b/test/bridge-device/main.go @@ -0,0 +1,129 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "os/signal" + "syscall" + + "github.com/google/uuid" + "github.com/plgd-dev/device/v2/bridge/device" + "github.com/plgd-dev/device/v2/bridge/device/cloud" + "github.com/plgd-dev/device/v2/bridge/net" + "github.com/plgd-dev/device/v2/bridge/service" + "github.com/plgd-dev/device/v2/pkg/log" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" +) + +const ( + numGeneratedBridgedDevices = 3 +) + +func testConfig() service.Config { + return service.Config{ + API: service.APIConfig{ + CoAP: service.CoAPConfig{ + ID: uuid.New().String(), + Config: net.Config{ + ExternalAddresses: []string{"127.0.0.1:15683", "[::1]:15683"}, + MaxMessageSize: 2097152, + }, + }, + }, + } +} + +func getCloudTLS() (cloud.CAPool, *tls.Certificate, error) { + caPath := os.Getenv("CA_POOL") + fmt.Printf("Loading CA(%s)\n", caPath) + ca, err := pkgX509.ReadPemCertificates(caPath) + if err != nil { + return cloud.CAPool{}, nil, fmt.Errorf("cannot load ca: %w", err) + } + caPool := cloud.MakeCAPool(func() []*x509.Certificate { + return ca + }, false) + + certPath := os.Getenv("CERT_FILE") + keyPath := os.Getenv("KEY_FILE") + if keyPath != "" && certPath != "" { + fmt.Printf("Loading certificate(%s) and key(%s)\n", certPath, keyPath) + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return cloud.CAPool{}, nil, fmt.Errorf("cannot load cert: %w", err) + } + return caPool, &cert, nil + } + return caPool, nil, nil +} + +func handleSignals(s *service.Service) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + for sig := range sigCh { + switch sig { + case syscall.SIGINT: + os.Exit(0) + return + case syscall.SIGTERM: + _ = s.Shutdown() + return + } + } +} + +func main() { + cfg := testConfig() + if err := cfg.Validate(); err != nil { + panic(err) + } + s, err := service.New(cfg, service.WithLogger(log.NewStdLogger(log.LevelDebug))) + if err != nil { + panic(err) + } + + opts := []device.Option{} + caPool, cert, errC := getCloudTLS() + if errC != nil { + panic(errC) + } + opts = append(opts, device.WithCAPool(caPool)) + if cert != nil { + opts = append(opts, device.WithGetCertificates(func(string) []tls.Certificate { + return []tls.Certificate{*cert} + })) + } + + for i := 0; i < numGeneratedBridgedDevices; i++ { + newDevice := func(id uuid.UUID, piid uuid.UUID) (service.Device, error) { + return device.New(device.Config{ + Name: fmt.Sprintf("bridged-device-%d", i), + ResourceTypes: []string{"oic.d.virtual"}, + ID: id, + ProtocolIndependentID: piid, + MaxMessageSize: cfg.API.CoAP.MaxMessageSize, + Cloud: device.CloudConfig{ + Enabled: true, + Config: cloud.Config{ + CloudID: os.Getenv("CLOUD_SID"), + }, + }, + }, append(opts, device.WithLogger(device.NewLogger(id, log.LevelDebug)))...) + } + d, errC := s.CreateDevice(uuid.New(), newDevice) + if errC == nil { + d.Init() + } + } + + go func() { + handleSignals(s) + }() + + if err = s.Serve(); err != nil { + panic(err) + } +} diff --git a/test/client/client.go b/test/client/client.go new file mode 100644 index 00000000..3542fbe0 --- /dev/null +++ b/test/client/client.go @@ -0,0 +1,228 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "time" + + "github.com/plgd-dev/device/v2/client" + "github.com/plgd-dev/device/v2/pkg/log" + "github.com/plgd-dev/device/v2/pkg/net/coap" + "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" + "github.com/plgd-dev/device/v2/schema" + "github.com/plgd-dev/device/v2/test" +) + +var CertIdentity = "00000000-0000-0000-0000-000000000001" + +type testSetupSecureClient struct { + ca []*x509.Certificate + mfgCA []*x509.Certificate + mfgCert tls.Certificate +} + +func (c *testSetupSecureClient) GetManufacturerCertificate() (tls.Certificate, error) { + if c.mfgCert.PrivateKey == nil { + return c.mfgCert, fmt.Errorf("private key not set") + } + return c.mfgCert, nil +} + +func (c *testSetupSecureClient) GetManufacturerCertificateAuthorities() ([]*x509.Certificate, error) { + if len(c.mfgCA) == 0 { + return nil, fmt.Errorf("certificate authority not set") + } + return c.mfgCA, nil +} + +func (c *testSetupSecureClient) GetRootCertificateAuthorities() ([]*x509.Certificate, error) { + if len(c.ca) == 0 { + return nil, fmt.Errorf("certificate authorities not set") + } + return c.ca, nil +} + +func NewTestSetupSecureClient(ca, mfgCA []*x509.Certificate, mfgCert tls.Certificate) client.ApplicationCallback { + return &testSetupSecureClient{ + ca: ca, + mfgCA: mfgCA, + mfgCert: mfgCert, + } +} + +func NewTestSecureClient() (*client.Client, error) { + return newTestSecureClient(test.IdentityIntermediateCA, test.IdentityIntermediateCAKey, false) +} + +func NewTestSecureClientWithBridgeSupport() (*client.Client, error) { + return newTestSecureClient(test.IdentityIntermediateCA, test.IdentityIntermediateCAKey, true) +} + +func NewTestSecureClientWithGeneratedCertificate() (*client.Client, error) { + var cfgCA generateCertificate.Configuration + cfgCA.Subject.CommonName = "anotherClient" + cfgCA.ValidFrom = "now" + cfgCA.ValidFor = time.Hour + + priv, err := cfgCA.GenerateKey() + if err != nil { + return nil, fmt.Errorf("cannot generate private key: %w", err) + } + cert, err := generateCertificate.GenerateRootCA(cfgCA, priv) + if err != nil { + return nil, fmt.Errorf("cannot generate root ca: %w", err) + } + derKey, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, fmt.Errorf("cannot marhsal private key: %w", err) + } + key := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: derKey}) + return newTestSecureClient(cert, key, false) +} + +func newTestSecureClient(signerCert, signerKey []byte, useDeviceIDInQuery bool) (*client.Client, error) { + cfg := client.Config{ + DeviceOwnershipSDK: &client.DeviceOwnershipSDKConfig{ + ID: CertIdentity, + Cert: string(signerCert), + CertKey: string(signerKey), + CreateSignerFunc: test.NewIdentityCertificateSigner, + }, + UseDeviceIDInQuery: useDeviceIDInQuery, + } + mfgTrustedCABlock, _ := pem.Decode(test.RootCACrt) + if mfgTrustedCABlock == nil { + return nil, fmt.Errorf("mfgTrustedCABlock is empty") + } + mfgCA, err := x509.ParseCertificates(mfgTrustedCABlock.Bytes) + if err != nil { + return nil, err + } + mfgCert, err := tls.X509KeyPair(test.MfgCert, test.MfgKey) + if err != nil { + return nil, fmt.Errorf("X509KeyPair failed: %w", err) + } + + client, err := client.NewClientFromConfig(&cfg, &testSetupSecureClient{ + mfgCA: mfgCA, + mfgCert: mfgCert, + }, log.NewStdLogger(log.LevelDebug)) + if err != nil { + return nil, err + } + err = client.Initialization(context.Background()) + if err != nil { + return nil, err + } + return client, nil +} + +func MakeMockDeviceResourcesObservationHandler() *MockDeviceResourcesObservationHandler { + return &MockDeviceResourcesObservationHandler{ + res: make(chan schema.ResourceLinks, 100), + close: make(chan struct{}), + } +} + +type MockDeviceResourcesObservationHandler struct { + res chan schema.ResourceLinks + close chan struct{} +} + +func (h *MockDeviceResourcesObservationHandler) Handle(_ context.Context, body schema.ResourceLinks) { + h.res <- body +} + +func (h *MockDeviceResourcesObservationHandler) Error(err error) { + fmt.Println(err) +} + +func (h *MockDeviceResourcesObservationHandler) OnClose() { + close(h.close) +} + +func (h *MockDeviceResourcesObservationHandler) WaitForNotification(ctx context.Context) (schema.ResourceLinks, error) { + select { + case e := <-h.res: + return e, nil + case <-ctx.Done(): + return nil, ctx.Err() + case <-h.close: + return nil, fmt.Errorf("unexpected close") + } +} + +func (h *MockDeviceResourcesObservationHandler) WaitForClose(ctx context.Context) error { + select { + case e := <-h.res: + return fmt.Errorf("unexpected notification %v", e) + case <-ctx.Done(): + return ctx.Err() + case <-h.close: + return nil + } +} + +type MockResourceObservationHandler struct { + Res chan coap.DecodeFunc + Close chan struct{} +} + +func MakeMockResourceObservationHandler() *MockResourceObservationHandler { + return &MockResourceObservationHandler{Res: make(chan coap.DecodeFunc, 10), Close: make(chan struct{})} +} + +func (h *MockResourceObservationHandler) Handle(_ context.Context, body coap.DecodeFunc) { + h.Res <- body +} + +func (h *MockResourceObservationHandler) Error(err error) { fmt.Println(err) } + +func (h *MockResourceObservationHandler) OnClose() { close(h.Close) } + +func (h *MockResourceObservationHandler) WaitForNotification(ctx context.Context) (coap.DecodeFunc, error) { + select { + case e := <-h.Res: + return e, nil + case <-ctx.Done(): + return nil, ctx.Err() + case <-h.Close: + return nil, fmt.Errorf("unexpected close") + } +} + +func (h *MockResourceObservationHandler) WaitForClose(ctx context.Context) error { + select { + case e := <-h.Res: + var d interface{} + if err := e(d); err != nil { + return fmt.Errorf("unexpected notification: cannot decode: %w", err) + } + return fmt.Errorf("unexpected notification %v", d) + case <-ctx.Done(): + return ctx.Err() + case <-h.Close: + return nil + } +} diff --git a/test/coap-gateway/defaultHandler.go b/test/coap-gateway/defaultHandler.go new file mode 100644 index 00000000..af61bcf8 --- /dev/null +++ b/test/coap-gateway/defaultHandler.go @@ -0,0 +1,240 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package test + +import ( + "fmt" + "sync" + "time" + + "github.com/plgd-dev/device/v2/pkg/ocf/cloud" +) + +// DefaultHandler is the default handler for tests +// +// It implements ServiceHandler interface by just logging the called method and +// returning default response and no error (if required). +type DefaultHandler struct { + deviceID string + accessToken string + refreshToken string + accessTokenLifetime int64 // lifetime in seconds or <0 value for a token without expiration +} + +func MakeDefaultHandler(accessTokenLifetime int64) DefaultHandler { + return DefaultHandler{ + accessToken: "access-token", + refreshToken: "refresh-token", + accessTokenLifetime: accessTokenLifetime, + } +} + +func (h *DefaultHandler) GetDeviceID() string { + return h.deviceID +} + +func (h *DefaultHandler) SetDeviceID(deviceID string) { + h.deviceID = deviceID +} + +func (h *DefaultHandler) SetAccessToken(accessToken string) { + h.accessToken = accessToken +} + +func (h *DefaultHandler) SetRefreshToken(refreshToken string) { + h.refreshToken = refreshToken +} + +func (h *DefaultHandler) SignUp(req cloud.CoapSignUpRequest) (cloud.CoapSignUpResponse, error) { + fmt.Printf("SignUp: %v\n", req) + h.SetDeviceID(req.DeviceID) + return cloud.CoapSignUpResponse{ + AccessToken: h.accessToken, + UserID: "1", + RefreshToken: h.refreshToken, + ExpiresIn: h.accessTokenLifetime, + RedirectURI: "", + }, nil +} + +func (h *DefaultHandler) CloseOnError() bool { + return true +} + +func (h *DefaultHandler) SignOff() error { + fmt.Printf("SignOff deviceID:%v\n", h.deviceID) + return nil +} + +func (h *DefaultHandler) SignIn(req cloud.CoapSignInRequest) (cloud.CoapSignInResponse, error) { + fmt.Printf("SignIn: %v\n", req) + return cloud.CoapSignInResponse{ + ExpiresIn: h.accessTokenLifetime, + }, nil +} + +func (h *DefaultHandler) SignOut(req cloud.CoapSignInRequest) error { + fmt.Printf("SignOut: %v\n", req) + return nil +} + +func (h *DefaultHandler) PublishResources(req cloud.PublishResourcesRequest) error { + fmt.Printf("PublishResources: %v\n", req) + return nil +} + +func (h *DefaultHandler) UnpublishResources(req cloud.UnpublishResourcesRequest) error { + fmt.Printf("UnpublishResources: %v\n", req) + return nil +} + +func (h *DefaultHandler) RefreshToken(req cloud.CoapRefreshTokenRequest) (cloud.CoapRefreshTokenResponse, error) { + fmt.Printf("RefreshToken: %v\n", req) + return cloud.CoapRefreshTokenResponse{ + RefreshToken: h.refreshToken, + AccessToken: h.accessToken, + ExpiresIn: h.accessTokenLifetime, + }, nil +} + +const ( + SignUpKey = "SignUp" // register + SignOffKey = "SignOff" // deregister + SignInKey = "SignIn" // log in + SignOutKey = "SignOut" // log out + PublishKey = "Publish" + UnpublishKey = "Unpublish" + RefreshTokenKey = "RefreshToken" +) + +type DefaultHandlerWithCounter struct { + *DefaultHandler + + CallCounter struct { + Data map[string]int + Lock sync.Mutex + } + + signedInChan chan int + signedOffChan chan int +} + +func NewCoapHandlerWithCounter(atLifetime int64) *DefaultHandlerWithCounter { + dh := MakeDefaultHandler(atLifetime) + return &DefaultHandlerWithCounter{ + DefaultHandler: &dh, + CallCounter: struct { + Data map[string]int + Lock sync.Mutex + }{ + Data: make(map[string]int), + }, + + signedInChan: make(chan int, 16), + signedOffChan: make(chan int, 16), + } +} + +func sendToChan(c chan int, v int) { + select { + case c <- v: + default: + } +} + +func waitForAction(c chan int, timeout time.Duration) int { + select { + case v := <-c: + return v + case <-time.After(timeout): + return -1 + } +} + +func (ch *DefaultHandlerWithCounter) SignUp(req cloud.CoapSignUpRequest) (cloud.CoapSignUpResponse, error) { + resp, err := ch.DefaultHandler.SignUp(req) + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[SignUpKey]++ + ch.CallCounter.Lock.Unlock() + return resp, err +} + +func (ch *DefaultHandlerWithCounter) SignOff() error { + err := ch.DefaultHandler.SignOff() + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[SignOffKey]++ + signOffCount, ok := ch.CallCounter.Data[SignOffKey] + if ok { + sendToChan(ch.signedOffChan, signOffCount) + } + ch.CallCounter.Lock.Unlock() + return err +} + +func (ch *DefaultHandlerWithCounter) WaitForSignOff(timeout time.Duration) int { + return waitForAction(ch.signedOffChan, timeout) +} + +func (ch *DefaultHandlerWithCounter) SignIn(req cloud.CoapSignInRequest) (cloud.CoapSignInResponse, error) { + resp, err := ch.DefaultHandler.SignIn(req) + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[SignInKey]++ + signInCount, ok := ch.CallCounter.Data[SignInKey] + if ok { + sendToChan(ch.signedInChan, signInCount) + } + ch.CallCounter.Lock.Unlock() + return resp, err +} + +func (ch *DefaultHandlerWithCounter) WaitForSignIn(timeout time.Duration) int { + return waitForAction(ch.signedInChan, timeout) +} + +func (ch *DefaultHandlerWithCounter) SignOut(req cloud.CoapSignInRequest) error { + err := ch.DefaultHandler.SignOut(req) + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[SignOutKey]++ + ch.CallCounter.Lock.Unlock() + return err +} + +func (ch *DefaultHandlerWithCounter) PublishResources(req cloud.PublishResourcesRequest) error { + err := ch.DefaultHandler.PublishResources(req) + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[PublishKey]++ + ch.CallCounter.Lock.Unlock() + return err +} + +func (ch *DefaultHandlerWithCounter) UnpublishResources(req cloud.UnpublishResourcesRequest) error { + err := ch.DefaultHandler.UnpublishResources(req) + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[UnpublishKey]++ + ch.CallCounter.Lock.Unlock() + return err +} + +func (ch *DefaultHandlerWithCounter) RefreshToken(req cloud.CoapRefreshTokenRequest) (cloud.CoapRefreshTokenResponse, error) { + resp, err := ch.DefaultHandler.RefreshToken(req) + ch.CallCounter.Lock.Lock() + ch.CallCounter.Data[RefreshTokenKey]++ + ch.CallCounter.Lock.Unlock() + return resp, err +} diff --git a/test/coap-gateway/service/client.go b/test/coap-gateway/service/client.go new file mode 100644 index 00000000..5c83f73f --- /dev/null +++ b/test/coap-gateway/service/client.go @@ -0,0 +1,129 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapTcpClient "github.com/plgd-dev/go-coap/v3/tcp/client" + grpcCodes "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Client a setup of connection +type Client struct { + server *Service + coapConn *coapTcpClient.Conn + handler ServiceHandler + + deviceID string +} + +// newClient creates and initializes client +func newClient(server *Service, client *coapTcpClient.Conn, handler ServiceHandler) *Client { + return &Client{ + server: server, + coapConn: client, + handler: handler, + } +} + +func (c *Client) GetCoapConnection() *coapTcpClient.Conn { + return c.coapConn +} + +func (c *Client) GetServiceHandler() ServiceHandler { + return c.handler +} + +func (c *Client) GetDeviceID() string { + return c.deviceID +} + +func (c *Client) SetDeviceID(deviceID string) { + c.deviceID = deviceID +} + +func (c *Client) RemoteAddrString() string { + return c.coapConn.RemoteAddr().String() +} + +func (c *Client) Context() context.Context { + return c.coapConn.Context() +} + +// Close closes coap connection +func (c *Client) Close() error { + if err := c.coapConn.Close(); err != nil { + return fmt.Errorf("cannot close client: %w", err) + } + return nil +} + +// OnClose is invoked when the coap connection was closed. +func (c *Client) OnClose() { + fmt.Printf("close client %v\n", c.coapConn.RemoteAddr()) +} + +type grpcErr interface { + GRPCStatus() *status.Status +} + +func isContextCanceled(err error) bool { + if errors.Is(err, context.Canceled) { + return true + } + var gErr grpcErr + if ok := errors.As(err, &gErr); ok { + return gErr.GRPCStatus().Code() == grpcCodes.Canceled + } + return false +} + +func (c *Client) sendResponse(code codes.Code, token message.Token, payload []byte) { + msg := pool.NewMessage(c.Context()) + msg.SetCode(code) + msg.SetToken(token) + if len(payload) > 0 { + msg.SetContentFormat(message.AppOcfCbor) + msg.SetBody(bytes.NewReader(payload)) + } + if err := c.coapConn.WriteMessage(msg); err != nil { + if !isContextCanceled(err) { + fmt.Printf("cannot send reply to %v: %v\n", c.GetDeviceID(), err) + } + } +} + +func (c *Client) sendErrorResponse(err error, code codes.Code, token message.Token) { + msg := pool.NewMessage(c.Context()) + msg.SetCode(code) + msg.SetToken(token) + // Don't set content format for diagnostic message: https://tools.ietf.org/html/rfc7252#section-5.5.2 + msg.SetBody(bytes.NewReader([]byte(err.Error()))) + if err = c.coapConn.WriteMessage(msg); err != nil { + fmt.Printf("cannot send error to %v: %v\n", c.GetDeviceID(), err) + } +} diff --git a/test/coap-gateway/service/config.go b/test/coap-gateway/service/config.go new file mode 100644 index 00000000..602b75c9 --- /dev/null +++ b/test/coap-gateway/service/config.go @@ -0,0 +1,31 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import "crypto/tls" + +type TLSConfig struct { + Enabled bool + *tls.Config +} + +type Config struct { + Addr string + TLS TLSConfig +} diff --git a/test/coap-gateway/service/refreshToken.go b/test/coap-gateway/service/refreshToken.go new file mode 100644 index 00000000..231846c0 --- /dev/null +++ b/test/coap-gateway/service/refreshToken.go @@ -0,0 +1,70 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "fmt" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + coapCodes "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/mux" +) + +func refreshTokenPostHandler(req *mux.Message, client *Client) { + logErrorAndCloseClient := func(err error, code coapCodes.Code) { + client.sendErrorResponse(fmt.Errorf("cannot handle refresh token: %w", err), code, req.Token()) + if client.handler == nil || client.handler.CloseOnError() { + if err := client.Close(); err != nil { + fmt.Printf("refresh token error: %v\n", err) + } + } + } + + var r cloud.CoapRefreshTokenRequest + err := cbor.ReadFrom(req.Body(), &r) + if err != nil { + logErrorAndCloseClient(err, coapCodes.BadRequest) + return + } + + resp, err := client.handler.RefreshToken(r) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + out, err := cbor.Encode(resp) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + client.sendResponse(coapCodes.Changed, req.Token(), out) +} + +// RefreshToken +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.tokenrefresh.swagger.json +func refreshTokenHandler(req *mux.Message, client *Client) { + if req.Code() == coapCodes.POST { + refreshTokenPostHandler(req, client) + return + } + client.sendErrorResponse(fmt.Errorf("forbidden request from %v", client.RemoteAddrString()), coapCodes.Forbidden, req.Token()) +} diff --git a/test/coap-gateway/service/resourceDirectory.go b/test/coap-gateway/service/resourceDirectory.go new file mode 100644 index 00000000..8fd61ca3 --- /dev/null +++ b/test/coap-gateway/service/resourceDirectory.go @@ -0,0 +1,131 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + coapCodes "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/mux" +) + +// fixHref ensures that href starts with "/" and does not end with "/". +func fixHref(href string) string { + backslash := regexp.MustCompile(`\/+`) + p := backslash.ReplaceAllString(href, "/") + p = strings.TrimRight(p, "/") + if len(p) > 0 && p[0] == '/' { + return p + } + return "/" + p +} + +func publishHandler(req *mux.Message, client *Client) { + p := cloud.PublishResourcesRequest{ + TimeToLive: -1, + } + err := cbor.ReadFrom(req.Body(), &p) + if err != nil { + client.sendErrorResponse(fmt.Errorf("cannot read publish request body received: %w", err), coapCodes.BadRequest, req.Token()) + return + } + + for i, link := range p.Links { + p.Links[i].DeviceID = p.DeviceID + p.Links[i].Href = fixHref(link.Href) + } + + if err = client.handler.PublishResources(p); err != nil { + client.sendErrorResponse(err, coapCodes.InternalServerError, req.Token()) + return + } + + out, err := cbor.Encode(p) + if err != nil { + client.sendErrorResponse(err, coapCodes.InternalServerError, req.Token()) + return + } + + client.sendResponse(coapCodes.Changed, req.Token(), out) +} + +func parseUnpublishRequestFromQuery(queries []string) (cloud.UnpublishResourcesRequest, error) { + req := cloud.UnpublishResourcesRequest{} + for _, q := range queries { + values, err := url.ParseQuery(q) + if err != nil { + return cloud.UnpublishResourcesRequest{}, fmt.Errorf("cannot parse unpublish query: %w", err) + } + if di := values.Get("di"); di != "" { + req.DeviceID = di + } + + if ins := values.Get("ins"); ins != "" { + i, err := strconv.Atoi(ins) + if err != nil { + return cloud.UnpublishResourcesRequest{}, fmt.Errorf("cannot convert %v to number", ins) + } + req.InstanceIDs = append(req.InstanceIDs, int64(i)) + } + } + + if req.DeviceID == "" { + return cloud.UnpublishResourcesRequest{}, fmt.Errorf("deviceID not found") + } + return req, nil +} + +func unpublishHandler(req *mux.Message, client *Client) { + queries, err := req.Options().Queries() + if err != nil { + client.sendErrorResponse(fmt.Errorf("cannot query string from unpublish request from device %v: %w", client.GetDeviceID(), err), coapCodes.BadRequest, req.Token()) + return + } + + r, err := parseUnpublishRequestFromQuery(queries) + if err != nil { + client.sendErrorResponse(fmt.Errorf("unable to parse unpublish request query string from device %v: %w", client.GetDeviceID(), err), coapCodes.BadRequest, req.Token()) + return + } + + err = client.handler.UnpublishResources(r) + if err != nil { + client.sendErrorResponse(err, coapCodes.InternalServerError, req.Token()) + return + } + + client.sendResponse(coapCodes.Deleted, req.Token(), nil) +} + +func resourceDirectoryHandler(req *mux.Message, client *Client) { + switch req.Code() { + case coapCodes.POST: + publishHandler(req, client) + case coapCodes.DELETE: + unpublishHandler(req, client) + default: + client.sendErrorResponse(fmt.Errorf("forbidden request from %v", client.RemoteAddrString()), coapCodes.Forbidden, req.Token()) + } +} diff --git a/test/coap-gateway/service/service.go b/test/coap-gateway/service/service.go new file mode 100644 index 00000000..8473fbc4 --- /dev/null +++ b/test/coap-gateway/service/service.go @@ -0,0 +1,238 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "context" + "fmt" + "os" + "os/signal" + "sync" + "syscall" + + bridgeNet "github.com/plgd-dev/device/v2/bridge/net" + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/mux" + "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/options" + "github.com/plgd-dev/go-coap/v3/tcp" + coapTcpClient "github.com/plgd-dev/go-coap/v3/tcp/client" + coapTcpServer "github.com/plgd-dev/go-coap/v3/tcp/server" +) + +// Service is a configuration of coap-gateway +type Service struct { + coapServer *coapTcpServer.Server + listener coapTcpServer.Listener + closeFn []func() + ctx context.Context + cancel context.CancelFunc + sigs chan os.Signal + getHandler GetServiceHandler + clients []*Client +} + +func newListener(cfg Config) (coapTcpServer.Listener, func(), error) { + if !cfg.TLS.Enabled { + listener, err := net.NewTCPListener("tcp", cfg.Addr) + if err != nil { + return nil, nil, fmt.Errorf("cannot create tcp listener: %w", err) + } + closeListener := func() { + if errC := listener.Close(); errC != nil { + fmt.Printf("failed to close tcp listener: %v\n", errC) + } + } + return listener, closeListener, nil + } + + listener, err := net.NewTLSListener("tcp", cfg.Addr, cfg.TLS.Config) + if err != nil { + return nil, nil, fmt.Errorf("cannot create tcp-tls listener: %w", err) + } + closeFn := (func() { + if errC := listener.Close(); errC != nil { + fmt.Printf("failed to close tcp-tls listener: %v\n", errC) + } + }) + return listener, closeFn, nil +} + +// New creates server. +func New(ctx context.Context, cfg Config, getHandler GetServiceHandler) (*Service, error) { + var closeFn []func() + listener, closeListener, err := newListener(cfg) + if err != nil { + return nil, fmt.Errorf("cannot create listener: %w", err) + } + closeFn = append(closeFn, closeListener) + + ctx, cancel := context.WithCancel(ctx) + + s := Service{ + listener: listener, + closeFn: closeFn, + ctx: ctx, + cancel: cancel, + sigs: make(chan os.Signal, 1), + getHandler: getHandler, + clients: nil, + } + + if err := s.setupCoapServer(); err != nil { + return nil, fmt.Errorf("cannot setup coap server: %w", err) + } + + return &s, nil +} + +const clientKey = "client" + +func (s *Service) coapConnOnNew(coapConn *coapTcpClient.Conn) { + client := newClient(s, coapConn, s.getHandler(s, WithCoapConnectionOpt(coapConn))) + coapConn.SetContextValue(clientKey, client) + coapConn.AddOnClose(func() { + client.OnClose() + }) + s.clients = append(s.clients, client) +} + +func validateCommand(writer mux.ResponseWriter, request *mux.Message, server *Service, fnc func(req *mux.Message, client *Client)) { + request.Hijack() + go func(w mux.ResponseWriter, req *mux.Message) { + client, ok := w.Conn().Context().Value(clientKey).(*Client) + if !ok || client == nil { + con, ok2 := w.Conn().(*coapTcpClient.Conn) + if !ok2 { + panic("invalid connection") + } + client = newClient(server, con, nil) + } + closeClient := func(c *Client) { + if err := c.Close(); err != nil { + fmt.Printf("cannot handle command: %v\n", err) + } + } + + switch req.Code() { + case codes.POST, codes.DELETE, codes.PUT, codes.GET: + fnc(req, client) + case codes.Empty: + if !ok { + client.sendErrorResponse(fmt.Errorf("cannot handle command: client not found"), codes.InternalServerError, req.Token()) + closeClient(client) + return + } + case codes.Content: + // Unregistered observer at a peer send us a notification + default: + fmt.Printf("received invalid code: CoapCode(%v)", req.Code()) + } + }(writer, request) +} + +func defaultHandler(req *mux.Message, client *Client) { + path, _ := req.Options().Path() + client.sendErrorResponse(fmt.Errorf("DeviceId: %v: unknown path %v", client.GetDeviceID(), path), codes.NotFound, req.Token()) +} + +func (s *Service) setupCoapServer() error { + setHandlerError := func(uri string, err error) error { + return fmt.Errorf("failed to set %v handler: %w", uri, err) + } + m := mux.NewRouter() + m.DefaultHandle(mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + validateCommand(w, r, s, defaultHandler) + })) + if err := m.Handle(ocfCloud.SignUp, mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + validateCommand(w, r, s, signUpHandler) + })); err != nil { + return setHandlerError(ocfCloud.SignUp, err) + } + if err := m.Handle(ocfCloud.SignIn, mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + validateCommand(w, r, s, signInHandler) + })); err != nil { + return setHandlerError(ocfCloud.SignIn, err) + } + if err := m.Handle(ocfCloud.ResourceDirectory, mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + validateCommand(w, r, s, resourceDirectoryHandler) + })); err != nil { + return setHandlerError(ocfCloud.ResourceDirectory, err) + } + if err := m.Handle(ocfCloud.RefreshToken, mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + validateCommand(w, r, s, refreshTokenHandler) + })); err != nil { + return setHandlerError(ocfCloud.RefreshToken, err) + } + + opts := make([]coapTcpServer.Option, 0, 5) + opts = append(opts, options.WithOnNewConn(s.coapConnOnNew)) + opts = append(opts, options.WithMux(m)) + opts = append(opts, options.WithContext(s.ctx)) + opts = append(opts, options.WithErrors(func(e error) { + fmt.Printf("test-coap: %v\n", e) + })) + opts = append(opts, options.WithMaxMessageSize(bridgeNet.DefaultMaxMessageSize)) + s.coapServer = tcp.NewServer(opts...) + return nil +} + +func (s *Service) Serve() error { + return s.serveWithHandlingSignal() +} + +func (s *Service) serveWithHandlingSignal() error { + var wg sync.WaitGroup + var err error + wg.Add(1) + go func(server *Service) { + defer wg.Done() + err = server.coapServer.Serve(server.listener) + server.cancel() + for i := range server.closeFn { + server.closeFn[len(server.closeFn)-1-i]() + } + }(s) + + signal.Notify(s.sigs, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM, + syscall.SIGQUIT) + <-s.sigs + + s.coapServer.Stop() + wg.Wait() + + return err +} + +func (s *Service) GetClients() []*Client { + return s.clients +} + +// Close turns off the server. +func (s *Service) Close() error { + select { + case s.sigs <- syscall.SIGTERM: + default: + } + return nil +} diff --git a/test/coap-gateway/service/serviceHandler.go b/test/coap-gateway/service/serviceHandler.go new file mode 100644 index 00000000..a53bf793 --- /dev/null +++ b/test/coap-gateway/service/serviceHandler.go @@ -0,0 +1,65 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +type ServiceHandlerConfig struct { + coapConn *client.Conn +} + +func (s *ServiceHandlerConfig) GetCoapConnection() *client.Conn { + return s.coapConn +} + +type Option interface { + Apply(o *ServiceHandlerConfig) +} + +type CoapConnectionOpt struct { + coapConn *client.Conn +} + +func (o CoapConnectionOpt) Apply(opts *ServiceHandlerConfig) { + opts.coapConn = o.coapConn +} + +func WithCoapConnectionOpt(c *client.Conn) CoapConnectionOpt { + return CoapConnectionOpt{ + coapConn: c, + } +} + +type GetServiceHandler = func(service *Service, opts ...Option) ServiceHandler + +type OnShutdown = func(ServiceHandler) + +type ServiceHandler interface { + CloseOnError() bool + SignUp(req ocfCloud.CoapSignUpRequest) (ocfCloud.CoapSignUpResponse, error) + SignOff() error + SignIn(req ocfCloud.CoapSignInRequest) (ocfCloud.CoapSignInResponse, error) + SignOut(req ocfCloud.CoapSignInRequest) error + PublishResources(req ocfCloud.PublishResourcesRequest) error + UnpublishResources(req ocfCloud.UnpublishResourcesRequest) error + RefreshToken(req ocfCloud.CoapRefreshTokenRequest) (ocfCloud.CoapRefreshTokenResponse, error) +} diff --git a/test/coap-gateway/service/signIn.go b/test/coap-gateway/service/signIn.go new file mode 100644 index 00000000..f1a4e532 --- /dev/null +++ b/test/coap-gateway/service/signIn.go @@ -0,0 +1,95 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "fmt" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + coapCodes "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/mux" +) + +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.session.swagger.json +func signInPostHandler(req *mux.Message, client *Client, signIn cloud.CoapSignInRequest) { + logErrorAndCloseClient := func(err error, code coapCodes.Code) { + client.sendErrorResponse(fmt.Errorf("cannot handle sign in: %w", err), code, req.Token()) + if client.handler == nil || client.handler.CloseOnError() { + if err := client.Close(); err != nil { + fmt.Printf("sign in error: %v\n", err) + } + } + } + + resp, err := client.handler.SignIn(signIn) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + out, err := cbor.Encode(resp) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + client.sendResponse(coapCodes.Changed, req.Token(), out) +} + +// Sign-Out +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.session.swagger.json +func signOutPostHandler(req *mux.Message, client *Client, signOut cloud.CoapSignInRequest) { + logErrorAndCloseClient := func(err error, code coapCodes.Code) { + client.sendErrorResponse(fmt.Errorf("cannot handle sign out: %w", err), code, req.Token()) + if client.handler == nil || client.handler.CloseOnError() { + if err := client.Close(); err != nil { + fmt.Printf("sign out error: %v\n", err) + } + } + } + + err := client.handler.SignOut(signOut) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + client.sendResponse(coapCodes.Changed, req.Token(), []byte{0xA0}) // empty object +} + +// Sign-in +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.session.swagger.json +func signInHandler(req *mux.Message, client *Client) { + if req.Code() == coapCodes.POST { + var r cloud.CoapSignInRequest + err := cbor.ReadFrom(req.Body(), &r) + if err != nil { + client.sendErrorResponse(fmt.Errorf("cannot handle sign in: %w", err), coapCodes.BadRequest, req.Token()) + return + } + if r.Login { + signInPostHandler(req, client, r) + return + } + signOutPostHandler(req, client, r) + return + } + client.sendErrorResponse(fmt.Errorf("forbidden request from %v", client.RemoteAddrString()), coapCodes.Forbidden, req.Token()) +} diff --git a/test/coap-gateway/service/signUp.go b/test/coap-gateway/service/signUp.go new file mode 100644 index 00000000..ee7df5ff --- /dev/null +++ b/test/coap-gateway/service/signUp.go @@ -0,0 +1,96 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package service + +import ( + "fmt" + + "github.com/plgd-dev/device/v2/pkg/codec/cbor" + "github.com/plgd-dev/device/v2/pkg/ocf/cloud" + coapCodes "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/mux" +) + +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.account.swagger.json +func signUpPostHandler(r *mux.Message, client *Client) { + logErrorAndCloseClient := func(err error, code coapCodes.Code) { + client.sendErrorResponse(fmt.Errorf("cannot handle sign up: %w", err), code, r.Token()) + if client.handler == nil || client.handler.CloseOnError() { + if err := client.Close(); err != nil { + fmt.Printf("sign up error: %v\n", err) + } + } + } + + var signUp cloud.CoapSignUpRequest + if err := cbor.ReadFrom(r.Body(), &signUp); err != nil { + logErrorAndCloseClient(err, coapCodes.BadRequest) + return + } + + client.SetDeviceID(signUp.DeviceID) + + resp, err := client.handler.SignUp(signUp) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + out, err := cbor.Encode(resp) + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + client.sendResponse(coapCodes.Changed, r.Token(), out) +} + +// Sign-off +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.account.swagger.json +func signOffHandler(req *mux.Message, client *Client) { + logErrorAndCloseClient := func(err error, code coapCodes.Code) { + client.sendErrorResponse(fmt.Errorf("cannot handle sign off: %w", err), code, req.Token()) + if client.handler == nil || client.handler.CloseOnError() { + if err := client.Close(); err != nil { + fmt.Printf("sign off error: %v\n", err) + } + } + } + + err := client.handler.SignOff() + if err != nil { + logErrorAndCloseClient(err, coapCodes.InternalServerError) + return + } + + client.sendResponse(coapCodes.Deleted, req.Token(), nil) +} + +// Sign-up +// https://github.com/openconnectivityfoundation/security/blob/master/swagger2.0/oic.sec.account.swagger.json +func signUpHandler(r *mux.Message, client *Client) { + switch r.Code() { + case coapCodes.POST: + signUpPostHandler(r, client) + case coapCodes.DELETE: + signOffHandler(r, client) + default: + client.sendErrorResponse(fmt.Errorf("forbidden request from %v", client.RemoteAddrString()), coapCodes.Forbidden, r.Token()) + } +} diff --git a/test/coap-gateway/test.go b/test/coap-gateway/test.go new file mode 100644 index 00000000..2d57543c --- /dev/null +++ b/test/coap-gateway/test.go @@ -0,0 +1,69 @@ +/**************************************************************************** + * + * Copyright (c) 2024 plgd.dev s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"), + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + * + ****************************************************************************/ + +package test + +import ( + "context" + "crypto/tls" + "sync" + "testing" + + "github.com/plgd-dev/device/v2/test" + "github.com/plgd-dev/device/v2/test/coap-gateway/service" + "github.com/stretchr/testify/require" +) + +const ( + COAP_GW_HOST = "localhost:21002" +) + +func MakeConfig(t *testing.T) service.Config { + return service.Config{ + Addr: COAP_GW_HOST, + TLS: service.TLSConfig{ + Enabled: true, + Config: &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + Certificates: []tls.Certificate{test.GetCoapCertificate(t)}, + }, + }, + } +} + +func New(t *testing.T, getHandler service.GetServiceHandler, onShutdown service.OnShutdown) func() { + ctx := context.Background() + s, err := service.New(ctx, MakeConfig(t), getHandler) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = s.Serve() + }() + return func() { + _ = s.Close() + wg.Wait() + if onShutdown != nil { + for _, c := range s.GetClients() { + onShutdown(c.GetServiceHandler()) + } + } + } +} diff --git a/test/config.go b/test/config.go index 8cd83e0f..a518cdf3 100644 --- a/test/config.go +++ b/test/config.go @@ -28,6 +28,7 @@ import ( "time" "github.com/plgd-dev/device/v2/pkg/security/generateCertificate" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/acl" "github.com/plgd-dev/device/v2/schema/ael" @@ -50,7 +51,6 @@ import ( "github.com/plgd-dev/device/v2/schema/softwareupdate" "github.com/plgd-dev/device/v2/schema/sp" testTypes "github.com/plgd-dev/device/v2/test/resource/types" - "github.com/plgd-dev/kit/v2/security" ) var ( @@ -224,11 +224,11 @@ func GenerateIdentityCert(deviceID string) tls.Certificate { if err != nil { log.Fatal(err) } - signerCA, err := security.ParseX509FromPEM(RootCACrt) + signerCA, err := pkgX509.ParsePemCertificates(RootCACrt) if err != nil { log.Fatal(err) } - signerCAKey, err := security.LoadX509PrivateKey(os.Getenv("ROOT_CA_KEY")) + signerCAKey, err := pkgX509.ReadPemEcdsaPrivateKey(os.Getenv("ROOT_CA_KEY")) if err != nil { log.Fatal(err) } diff --git a/test/signer.go b/test/signer.go index 1fb68cee..c6cd2a07 100644 --- a/test/signer.go +++ b/test/signer.go @@ -7,11 +7,11 @@ import ( "time" "github.com/plgd-dev/device/v2/client/core" - "github.com/plgd-dev/kit/v2/security" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" ) func NewTestSigner() (core.CertificateSigner, error) { - identityIntermediateCA, err := security.ParseX509FromPEM(IdentityIntermediateCA) + identityIntermediateCA, err := pkgX509.ParsePemCertificates(IdentityIntermediateCA) if err != nil { return nil, err } diff --git a/test/test.go b/test/test.go index 6d3713ce..b8cc6033 100644 --- a/test/test.go +++ b/test/test.go @@ -19,9 +19,11 @@ package test import ( "context" "crypto" + "crypto/tls" "crypto/x509" "fmt" "os" + "path/filepath" "strings" "sync/atomic" "testing" @@ -29,14 +31,16 @@ import ( "github.com/plgd-dev/device/v2/client/core" "github.com/plgd-dev/device/v2/pkg/security/signer" + pkgX509 "github.com/plgd-dev/device/v2/pkg/security/x509" "github.com/plgd-dev/device/v2/schema" "github.com/plgd-dev/device/v2/schema/device" "github.com/plgd-dev/device/v2/schema/interfaces" "github.com/plgd-dev/device/v2/test/resource/types" - "github.com/plgd-dev/kit/v2/log" "github.com/stretchr/testify/require" ) +const TestTimeout = time.Second * 8 + func MustGetHostname() string { n, err := os.Hostname() if err != nil { @@ -93,7 +97,7 @@ func (h *findDeviceIDByNameHandler) Handle(ctx context.Context, dev *core.Device } func (h *findDeviceIDByNameHandler) Error(err error) { - log.Debug(err) + fmt.Printf("%v\n", err) } func FindDeviceByName(ctx context.Context, name string) (deviceID string, _ error) { @@ -193,7 +197,7 @@ func FindDeviceIP(ctx context.Context, deviceName string, ipType IPType) (string } defer func() { if errC := device.Close(ctx); errC != nil { - log.Errorf("FindDeviceIP: %w", errC) + fmt.Printf("cannot close device: %v\n", errC) } }() return getDeviceIP(device, ipType) @@ -216,7 +220,7 @@ func FindDeviceEndpoints(ctx context.Context, deviceName string, ipType IPType) } defer func() { if errC := device.Close(ctx); errC != nil { - log.Errorf("FindDeviceEndpoints: %w", errC) + fmt.Printf("cannot close device: %v\n", errC) } }() return device.GetEndpoints(), nil @@ -285,3 +289,38 @@ func CheckResourceLinks(t *testing.T, expected, actual schema.ResourceLinks) { } require.Empty(t, expLinks) } + +func CloudSID() string { + return os.Getenv("CLOUD_SID") +} + +func GetRootCApem(t *testing.T) []byte { + certPath := os.Getenv("ROOT_CA_CRT") + require.NotEmpty(t, certPath) + data, err := os.ReadFile(filepath.Clean(certPath)) + require.NoError(t, err) + return data +} + +func GetRootCA(t *testing.T) []*x509.Certificate { + certPem := GetRootCApem(t) + cas, err := pkgX509.ParsePemCertificates(certPem) + require.NoError(t, err) + return cas +} + +func GetMfgCertificate(t *testing.T) tls.Certificate { + ca, err := tls.X509KeyPair(MfgCert, MfgKey) + require.NoError(t, err) + return ca +} + +func GetCoapCertificate(t *testing.T) tls.Certificate { + certPath := os.Getenv("COAP_CRT") + require.NotEmpty(t, certPath) + keyPath := os.Getenv("COAP_KEY") + require.NotEmpty(t, keyPath) + ca, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + return ca +} diff --git a/tools/docker/patches/shrink_tls_conn.patch b/tools/docker/patches/shrink_tls_conn.patch new file mode 100644 index 00000000..3334a466 --- /dev/null +++ b/tools/docker/patches/shrink_tls_conn.patch @@ -0,0 +1,16 @@ +diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go +index 969f357834..63dff2b93e 100644 +--- a/src/crypto/tls/conn.go ++++ b/src/crypto/tls/conn.go +@@ -789,6 +789,11 @@ func (r *atLeastReader) Read(p []byte) (int, error) { + // at least n bytes or else returns an error. + func (c *Conn) readFromUntil(r io.Reader, n int) error { + if c.rawInput.Len() >= n { ++ if c.rawInput.Len() < bytes.MinRead && c.rawInput.Cap() > 4*bytes.MinRead { ++ p := c.rawInput.Bytes() ++ c.rawInput = *bytes.NewBuffer(make([]byte, len(p), bytes.MinRead)) ++ copy(c.rawInput.Bytes(), p) ++ } + return nil + } + needs := n - c.rawInput.Len()