diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 0000000..0128dbd
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1 @@
+* @synapsec-ai/subnet-owners
\ No newline at end of file
diff --git a/.github/workflows/subnet-containers-latest.yml b/.github/workflows/subnet-containers-latest.yml
new file mode 100644
index 0000000..7d88252
--- /dev/null
+++ b/.github/workflows/subnet-containers-latest.yml
@@ -0,0 +1,38 @@
+name: Push Docker image to latest
+
+on:
+ push:
+ branches:
+ - 'main'
+
+jobs:
+ build-miner-container-latest:
+ runs-on:
+ group: synapsec-larger-runners
+ permissions:
+ contents: read # Default permission to read repository contents
+ packages: write # Permission to write to GitHub Packages
+
+ env:
+ IMAGE_NAME: soundsright-miner
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Podman
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y podman
+
+ - name: Build Podman image
+ run: podman build -t ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }}:latest -f miner.Dockerfile
+
+ - name: Log in to GitHub Container Registry
+ run: echo "${{ secrets.GITHUB_TOKEN }}" | podman login ghcr.io -u ${{ github.repository_owner }} --password-stdin
+
+ - name: Push Podman image
+ run: podman push ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }}:latest
+
+ - name: Store image details
+ run: echo "{name}={image::ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }}:latest}" >> $GITHUB_OUTPUTS
\ No newline at end of file
diff --git a/.github/workflows/subnet-containers-tagged.yml b/.github/workflows/subnet-containers-tagged.yml
new file mode 100644
index 0000000..e15553c
--- /dev/null
+++ b/.github/workflows/subnet-containers-tagged.yml
@@ -0,0 +1,44 @@
+name: Push Docker image with version tag
+
+on:
+ push:
+ tags:
+ - 'v*.*.*'
+
+jobs:
+ build-miner-container-tagged:
+ runs-on:
+ group: synapsec-larger-runners
+ permissions:
+ contents: read # Default permission to read repository contents
+ packages: write # Permission to write to GitHub Packages
+
+ env:
+ IMAGE_NAME: soundsright-miner
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Podman
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y podman
+
+ - name: Determine Image Tag
+ id: tag
+ run: |
+ TAG=$(echo $GITHUB_REF | sed 's/refs\/tags\///')
+ echo "tag=${TAG}" >> $GITHUB_ENV
+
+ - name: Build Podman image
+ run: podman build -t ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }}:${{ env.tag }} -f miner.Dockerfile
+
+ - name: Log in to GitHub Container Registry
+ run: echo "${{ secrets.GITHUB_TOKEN }}" | podman login ghcr.io -u ${{ github.repository_owner }} --password-stdin
+
+ - name: Push Podman image
+ run: podman push ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }}:${{ env.tag }}
+
+ - name: Store image details
+ run: echo "{name}={image::ghcr.io/${{ github.repository_owner }}/${{ env.IMAGE_NAME }}:${{ env.tag }}}" >> $GITHUB_OUTPUT
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b0a363f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,164 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+test_data
\ No newline at end of file
diff --git a/.miner-env.sample b/.miner-env.sample
new file mode 100644
index 0000000..5201f82
--- /dev/null
+++ b/.miner-env.sample
@@ -0,0 +1,28 @@
+NETUID=DO THIS ONCE WE REGISTER
+SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
+WALLET=coldkey_name
+HOTKEY=hotkey_name
+
+# Available: INFO, INFOX, DEBUG, DEBUGX, TRACE, TRACEX
+LOG_LEVEL=INFO
+
+# Necessary for dataset generation
+OPENAI_API_KEY=
+
+# Miner model specification by task and sample rate.
+# If you have not fine-tuned a model for a specific task and sample rate, just leave it blank.
+# NOTE: EACH MINER CAN ONLY RESPOND FOR ONE TASK AND ONE SAMPLE RATE.
+# PLEASE REGISTER ANOTHER MINER IF YOU HAVE ANOTHER MODEL FOR ANOTHER TASK OR SAMPLE RATE.
+# 16kHz Sample Rate, Denoising Task
+DENOISING_16000HZ_HF_MODEL_NAMESPACE=
+DENOISING_16000HZ_HF_MODEL_NAME=
+DENOISING_16000HZ_HF_MODEL_REVISION=
+
+# 16kHz Sample Rate, Dereverberation Task
+DEREVERBERATION_16000HZ_HF_MODEL_NAMESPACE=
+DEREVERBERATION_16000HZ_HF_MODEL_NAME=
+DEREVERBERATION_16000HZ_HF_MODEL_REVISION=
+
+# HealthCheck API
+HEALTHCHECK_API_HOST=0.0.0.0
+HEALTHCHECK_API_PORT=6000
\ No newline at end of file
diff --git a/.validator-env.sample b/.validator-env.sample
new file mode 100644
index 0000000..d0522c7
--- /dev/null
+++ b/.validator-env.sample
@@ -0,0 +1,12 @@
+NETUID=DO THIS ONCE WE REGISTER
+SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
+WALLET=coldkey_name
+HOTKEY=hotkey_name
+OPENAI_API_KEY=
+
+# Available: INFO, INFOX, DEBUG, DEBUGX, TRACE, TRACEX
+LOG_LEVEL=INFO
+
+# HealthCheck API
+HEALTHCHECK_API_HOST=0.0.0.0
+HEALTHCHECK_API_PORT=6000
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..e79da55
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,41 @@
+MIT License
+
+Copyright (c) 2024 synapsec.ai
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+---
+
+### Third-Party Code:
+
+Portions of this software are derived from code in the following project(s):
+
+- taoverse by macrocosm-os
+ - Repository: https://github.com/macrocosm-os/taoverse/
+ - Copyright (c) macrocosm-os
+ - Licensed under the MIT License (included in the `THIRD_PARTY_LICENSES` file)
+
+- python-acoustics by python-acoustics
+ - Repository: https://github.com/python-acoustics/python-acoustics
+ - Copyright (c) Python Acoustics
+ - Licensed under the BSD 3-Clause Licence (included in the `THIRD_PARTY_LICENSES` file)
+
+- ears_benchmark by sp-uhh
+ - Repository: https://github.com/sp-uhh/ears_benchmark/tree/main
+ - Licensed under the CC-NC 4.0 International License (included in the `THIRD_PARTY_LICENSES` file)
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..209b637
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+
SoundsRight (SN)
+
+
+Bittensor's Speech Enhancement Subnet
+
+If you are unfamiliar with how Bittensor works, please check out [this primer](https://docs.bittensor.com/learn/bittensor-building-blocks) first!
+
+SoundsRight is dedicated to incentivizing the research and development of open-source models for speech enhancement through daily fine-tuning competitions, powered by the decentralized Bittensor ecosystem.
+
+Miners in the subnet will upload their fine-tuned models to HuggingFace, and the subnet's validators are in charge of downloading the models, benchmarking their performance and rewarding miners accordingly.
+
+**Each competition is winner-takes-all.**
+
+Fine-Tuning Competitions
+
+The table below outlines competitions currently being held by the subnet. Competitions are distinguished by the sample rate of the testing data, the task and the metric used for benchmarking.
+
+| Sample Rate | Task | Benchmarking Metric | % of Total Miner Incentives |
+| ----------- | ---- | ------ | --------------------------- |
+| 16 kHz | Denoising | PESQ | 15 |
+| 16 kHz | Denoising | ESTOI | 12.5 |
+| 16 kHz | Denoising | SI-SDR | 7.5 |
+| 16 kHz | Denoising | SI-SAR | 7.5 |
+| 16 kHz | Denoising | SI-SIR | 7.5 |
+| 16 kHz | Dereverberation | PESQ | 15 |
+| 16 kHz | Dereverberation | ESTOI | 12.5 |
+| 16 kHz | Dereverberation | SI-SDR | 7.5 |
+| 16 kHz | Dereverberation | SI-SAR | 7.5 |
+| 16 kHz | Dereverberation | SI-SIR | 7.5 |
+
+For more details about sample rates, tasks and metrics, please reference the [competition docs](docs/subnet/competitions.md).
+
+Getting Started
+
+To get started with mining or validating in the subnet, please reference the following documentation:
+
+Subnet Documentation
+
+
+
+Miner Documentation
+
+
+
+Validator Documentation
+
+
\ No newline at end of file
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000..c43e8f1
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,9 @@
+# Security Policy
+
+Security issues related to the subnet should be responsibly disclosed to the subnet development team via e-mail: bounty@synapsec.ai.
+
+In your submission, please include the following:
+1) Brief summary of the vulnerability including the potential impact
+2) Steps to reproduce
+3) Any code samples or other supplementary material
+4) (Optional) Your contact details (if other than e-mail) for further discussions
\ No newline at end of file
diff --git a/THIRD_PARTY_LICENSES b/THIRD_PARTY_LICENSES
new file mode 100644
index 0000000..022754d
--- /dev/null
+++ b/THIRD_PARTY_LICENSES
@@ -0,0 +1,453 @@
+MIT License
+
+Copyright (c) 2024 Taoverse
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+---
+
+Copyright (c) 2013, Python Acoustics
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+ list of conditions and the following disclaimer in the documentation and/or
+ other materials provided with the distribution.
+
+* Neither the name of the {organization} nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+---
+
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..f86d7e2
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,31 @@
+services:
+ common-miner: &common-miner
+ image: ghcr.io/synapsec-ai/soundsright-miner:v1.0.0
+ restart: unless-stopped
+ pull_policy: always
+ ports:
+ - "6000:6001"
+ volumes:
+ - SoundsRightSubnet:/root/.SoundsRightSubnet
+ - ${HOME}/.bittensor:/root/.bittensor
+
+ soundsright-miner:
+ <<: *common-miner
+ command: /bin/bash -c "source /SoundsRightSubnet/.venv/bin/activate && python3 /SoundsRightSubnet/soundsright/neurons/miner.py --netuid ${NETUID} --subtensor.chain_endpoint ${SUBTENSOR_CHAIN_ENDPOINT} --wallet.name ${WALLET} --wallet.hotkey ${HOTKEY} --log_level ${LOG_LEVEL} --healthcheck_host ${HEALTHCHECK_API_HOST} --healthcheck_port ${HEALTHCHECK_API_PORT}"
+
+ soundsright-miner-dev:
+ restart: unless-stopped
+ pull_policy: always
+ ports:
+ - "6000:6000"
+ - "6001:6001"
+ volumes:
+ - SoundsRightSubnet:/root/.SoundsRightSubnet
+ - ${HOME}/.bittensor:/root/.bittensor
+ build:
+ context: .
+ dockerfile: miner.Dockerfile
+ command: /bin/bash -c "source /SoundsRightSubnet/.venv/bin/activate && python3 /SoundsRightSubnet/soundsright/neurons/miner.py --netuid ${NETUID} --subtensor.chain_endpoint ${SUBTENSOR_CHAIN_ENDPOINT} --wallet.name ${WALLET} --wallet.hotkey ${HOTKEY} --log_level ${LOG_LEVEL} --healthcheck_host ${HEALTHCHECK_API_HOST} --healthcheck_port ${HEALTHCHECK_API_PORT} --validator_min_stake 0"
+
+volumes:
+ SoundsRightSubnet:
\ No newline at end of file
diff --git a/docs/mining/generate_data.md b/docs/mining/generate_data.md
new file mode 100644
index 0000000..41759f4
--- /dev/null
+++ b/docs/mining/generate_data.md
@@ -0,0 +1,33 @@
+# Fine-Tuning Dataset Generation
+
+Miners are able to generate datasets of their own to fine-tune their models, though you will need an OpenAI API key to do so.
+
+You will also need a bit of storage to download the noise/reverb datasets. To download the WHAM (noise) dataset requires 76GB of storage, and to download the ARNI (reverb) dataset requires 51 GB of storage.
+
+First, create a .env file with:
+```
+cp .env.sample .env
+```
+Next, add your OpenAI API key to the OPENAI_API_KEY variable in the .env file you have created.
+
+Then, navigate to the scripts directory in the SoundsRight repository with:
+```
+cd scripts
+```
+From there, use the `generate_dataset.py` script to generate your dataset with the following command line arguments:
+
+| Argument | Description | Type | Always Required |
+| :------: | :---------: | :--: | :------: |
+| --clean_dir | Path of directory where you want your clean data to go. | str | Yes |
+| --sample_rate | Sample rate of the dataset, defaults to 16000. | int | Yes |
+| --n | Dataset size. | int | Yes |
+| --task | What task you want the dataset for. One of: denoising, dereverberation, both | str | Yes |
+| --noise_dir | The directory where the noisy dataset will be stored. You only need to input this if you want to generate a dataset for the denoising task. | str | No |
+| --noise_data_dir | The directory where the data to generate noisy datasets will be stored. You only need to input this if you want to generate a dataset for the denoising task. | str | No |
+| ---reverb_dir | The directory where the reverberation dataset will be stored. You only need to input this if you want to generate a dataset for the dereverberation task. | str | No |
+| --reverb_data_dir | The directory where data to generate reverberation datasets will be stored. You only need to input this if you want to generate a dataset for the dereverberation task. | str | No |
+
+Here is an example of how to call the script:
+```
+python3 generate_dataset.py --clean_dir my_clean_dir --sample_rate 16000 --n 5000 --task denoising --noise_dir my_noise_dir --noise_data_dir my_noise_data_dir
+```
\ No newline at end of file
diff --git a/docs/mining/model_formatting.md b/docs/mining/model_formatting.md
new file mode 100644
index 0000000..d365a0c
--- /dev/null
+++ b/docs/mining/model_formatting.md
@@ -0,0 +1,41 @@
+# Model Formatting
+
+Models submitted to validators must follow a few formatting guidelines, and we have provided a [template](https://huggingface.co/synapsecai/soundsright-template) for miners to use. Your model will not be scored by validators unless it follows the guidelines exactly.
+
+The `main` branch of this template is what should be modified by miners to create their own models. The branches `DENOISING_16000Hz` and `DEREVERBERATION_16000HZ` serve as tutorials, being fitted with different pretrained checkpoints of [SGMSE+](https://huggingface.co/sp-uhh/speech-enhancement-sgmse).
+
+For detailed instructions on how to format your model, please reference the `README.md` in the `main` branch.
+
+# Model Testing
+
+A script has been provided to test that your model is compatible with the validator architecture.
+
+To run the script, first make sure you have completed the following installations:
+
+1. Podman
+
+```
+$ apt-get update
+$ apt-get -y install podman
+```
+
+2. Python venv
+```
+$ cd SoundsRightSubnet
+$ python3 -m venv .venv
+$ source .venv/bin/activate
+```
+
+3. Python dependencies
+```
+(.venv) $ pip install --use-pep517 pesq==0.0.4 && pip install -e .[validator] && pip install httpx==0.27.2
+```
+
+Once the installation is complete, run your script with the following command:
+```
+(.venv) $ python3 scripts/verify_miner_model.py --model_namespace --model_name --model_revision
+```
+
+If `MODEL VERIFICATION SUCCESSFUL.` appears in the logs, then your model is ready to be submitted to validators!
+
+Note that this may take a while depending on the machine you run the script with (especially if you do not have a GPU). Please reference the documentation on running a validator if you wish to mirror the hardware exactly.
\ No newline at end of file
diff --git a/docs/mining/running_miner.md b/docs/mining/running_miner.md
new file mode 100644
index 0000000..7e995c1
--- /dev/null
+++ b/docs/mining/running_miner.md
@@ -0,0 +1,127 @@
+# Mining in the SoundsRight Subnet
+
+## Overview
+
+Generally, mining on the subnet looks like this:
+
+1. Miner fine-tunes a model. We recommend visiting the [website](https://www.soundsright.ai) and basing your model off of the best model from the previous competition, but ultimately it is up to you.
+2. Miner uploads the model to HuggingFace and makes it publicly available.
+3. Miner ensures that their model is compatible with the validator script used to benchmark their model. See the [model tutorial doc](model_tutorial.md) for more details.
+4. Miner updates their .env file with the model's data and restarts their miner neuron. The miner will automatically trigger the process of communicating the model data with validators upon restarting.
+
+Note that there is **no fine-tuning script contained within the miner neuron itself**--all miners are responsible for fine-tuning their models externally. Miner neurons are only used to communicate model data to validators.
+
+However this repository does contain scripts which can be used to generate fine-tuning datasets. Note that miners will need to have an OpenAI API key in order for this to work. Please reference the [dataset generation docs](generate_data.md) for more information.
+
+Also, **each miner can only submit models for one specific task and sample rate**. If you wish to provide models for multiple tasks and/or sample rates, you will need to register multiple miners.
+
+## Running a Miner
+
+### 1. Machine Specifications
+
+As the miner's only function is to upload model metadata to the Bittensor chain and send model information to validators, it is quite lightweight and should work on most configurations.
+
+We have been testing miners on machines running on both **Ubuntu 24.04** and **Python 3.12** with the following hardware configurations:
+
+- 16 GB RAM
+- 4 vCPU
+- 50 GB SSD
+
+We also highly recommend that you use a dedicated server to run the SoundsRight miner.
+
+### 2. Installation of Mandatory Packages
+
+The following sections will assume you are running as root.
+
+#### 2.1 Install Docker Engine for Ubuntu
+For installing the Docker Engine for Ubuntu, follow the official instructions: [Install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/).
+
+#### 2.2 Validate installation
+After installation is done, validate the docker engine has been installed correctly:
+```
+$ docker run hello-world
+```
+
+#### 2.3 Install the mandatory packages
+
+Run the following command:
+```
+$ apt-get install python3.12-venv
+```
+
+### 3. Preparation
+
+#### 3.1 Setup the GitHub repository and python virtualenv
+To clone the repository and setup the Python virtualenv, execute the following commands:
+```
+$ git clone https://github.com/synapsec-ai/SoundsRightSubnet.git
+$ cd SoundsRightSubnet
+$ python3 -m venv .venv
+$ source .venv/bin/activate
+(.venv) $ pip install bittensor
+```
+
+#### 3.2 Regenerate the miner wallet
+
+The private portion of the coldkey is not needed to run the subnet miner. **Never have your private miner coldkey or hotkeys not used to run the miner stored on the server**.
+
+To regenerate the keys on the host, execute the following commands:
+```
+(.venv) $ btcli wallet regen_coldkeypub
+(.venv) $ btcli wallet regen_hotkey
+```
+
+#### 3.3 Setup .env
+
+Create the .env from the .env.sample file provided with the following:
+
+```
+cp .miner-env.sample .env
+```
+
+The contents of the .env file must then be adjusted. The following variables apply for miners:
+
+| Variable | Meaning |
+| :------: | :-----: |
+| NETUID | The subnet's netuid. For mainnet this value is , and for testnet this value is 271. |
+| SUBTENSOR_CHAIN_ENDPOINT | The Bittensor chain endpoint. Please make sure to always use your own endpoint. For mainnnet, the default endpoint is: wss://finney.opentensor.ai:443 and for testnet the default endpoint is: wss://test.finney.opentensor.ai:443 |
+| WALLET | The name of your coldkey. |
+| HOTKEY | The name of your hotkey. |
+| LOG_LEVEL | Specifies the level of logging you will see on the validator. Choose between INFO, INFOX, DEBUG. DEBUGX, TRACE, and TRACEX. |
+| OPENAI_API_KEY | Your OpenAI API key. This is not needed to run the miner, only to generate training datasets. |
+| HEALTHCHECK_API_HOST | Host for HealthCheck API, default is 0.0.0.0. There is no need to adjust this value unless you want to. |
+| HEALTHCHECK_API_PORT | Port for HealthCheck API, default is 6000. There is no need to adjust this value unless you want to, and you will have to modify the ports in the docker-compose.yml file if you choose to do so. |
+
+In addition to this, the model being submitted to the competition must be specified in the .env file. Specifically, the model namespace, name, and revision must be specified in the .env for the particular competition being entered in by the miner.
+
+For example, if we want to submit the `main` branch of the HuggingFace model `synapsecai/my_speech_enhancement_model` to be evalauted, we designate the following:
+
+| Variable | Designation |
+| :------: | :-----: |
+| HF_MODEL_NAMESPACE | synapsecai |
+| HF_MODEL_NAME | my_speech_enhancement_model |
+| HF_MODEL_REVISION | main |
+
+### 4. Running the Miner
+
+Run the miner with this command:
+
+```
+$ docker compose up soundsright-miner -d
+```
+To see the logs, execute the following command:
+
+```
+$ docker compose logs soundsright-miner -f
+```
+
+### 5. Updating the Miner
+
+Updating the miner is done by re-launching the docker compose with the `--force-recreate` flag enabled after the git repository has been updated. This will re-create the containers and download the latest images from the container registry.
+
+```
+$ cd SoundsRightSubnet
+$ git pull
+$ docker compose up soundsright-miner -d --force-recreate
+```
+
diff --git a/docs/subnet/citations.md b/docs/subnet/citations.md
new file mode 100644
index 0000000..1a67f0d
--- /dev/null
+++ b/docs/subnet/citations.md
@@ -0,0 +1,34 @@
+# Citations
+
+Portions of the subnet code are derived from the following sources:
+
+1. **Prawda, Karolina, Schlecht, Sebastian J., & Välimäki, Vesa.**
+ **Dataset of impulse responses from variable acoustics room Arni at Aalto Acoustic Labs [Data set]**.
+ *Zenodo*. [https://doi.org/10.5281/zenodo.6985104](https://doi.org/10.5281/zenodo.6985104), 2022.
+
+2. **Richter, Julius, de Oliveira, Danilo, & Gerkmann, Timo.**
+ **Investigating Training Objectives for Generative Speech Enhancement.**
+ *arXiv preprint*, [arXiv:2409.10753](https://arxiv.org/abs/2409.10753), 2024.
+
+3. **Richter, Julius, Welker, Simon, Lemercier, Jean-Marie, Lay, Bunlong, & Gerkmann, Timo.**
+ **Speech Enhancement and Dereverberation with Diffusion-based Generative Models.**
+ *IEEE/ACM Transactions on Audio, Speech, and Language Processing*, 31, 2351–2364.
+ [https://doi.org/10.1109/TASLP.2023.3285241](https://doi.org/10.1109/TASLP.2023.3285241), 2023.
+
+4. **Richter, Julius, Wu, Yi-Chiao, Krenn, Steven, Welker, Simon, Lay, Bunlong, Watanabe, Shinji, Richard, Alexander, & Gerkmann, Timo.**
+ **EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation**.
+ In *ISCA Interspeech*, 2024.
+
+5. **Schroeder, M. R.**
+ **New Method of Measuring Reverberation Time**.
+ *Journal of the Acoustical Society of America* March 1968, 37(3), 409–412.
+ [https://doi.org/10.1121/1.1938260](https://doi.org/10.1121/1.1938260), March 1968.
+
+6. **Welker, Simon, Richter, Julius, & Gerkmann, Timo.**
+ **Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain.**
+ In *Proceedings of Interspeech 2022*, 2928–2932.
+ [https://doi.org/10.21437/Interspeech.2022-10653](https://doi.org/10.21437/Interspeech.2022-10653), 2022.
+
+7. **Wichern, Gordon, Antognini, Joe, Flynn, Michael, Zhu, Licheng Richard, McQuinn, Emmett, Crow, Dwight, Manilow, Ethan, & Le Roux, Jonathan.**
+ **WHAM!: Extending Speech Separation to Noisy Environments**.
+ In *Proceedings of Interspeech*, September 2019.
diff --git a/docs/subnet/competitions.md b/docs/subnet/competitions.md
new file mode 100644
index 0000000..8191e32
--- /dev/null
+++ b/docs/subnet/competitions.md
@@ -0,0 +1,65 @@
+# SoundsRight Competitions
+
+Each individual competition in the subnet is denoted by a unique **sample rate**, **task** and **benchmarking metric**. This doc serves to explain what each of these components are.
+
+## Sample Rate
+
+What we percieve as sound is fundamentally a wave propogating through the air. To represent this digitally we take samples of the signal and mesh them together, much like how a video is comprised of individual frames. Continuing this analogy, the sample rate of digital audio is akin to the frame rate of a video.
+
+Where sample rates differ from frame rates are in how many samples are taken per second--so much so that sample rates are often represented in kHz. Another unique property of the sample rate is that the higher the sample rate, the higher the sound frequencies that can be digitally represented.
+
+The table below denotes a few commonly used sample rates and their applications.
+
+| Sample Rate | Details | Common Applications |
+| ----------- | ------- | ------------------- |
+| 8 kHz | The minimum sample rate for intelligible human speech, often used in applications with limited bandwith. Also known as narrowband audio. | Telephone calls, intercom systems, VoIP |
+| 16 kHz | A good sample rate for capturing human speech while maintaining smaller file sizes. Also known as wideband audio. | Speech recognition and transcription, VoIP |
+| 44.1 kHz | A sample rate that covers the entire range of human hearing. | CD's, Spotify, audiobooks |
+| 48 kHz | A sample rate that covers the entire range of human hearing and also divides evenly with video frame rates to make syncing easier. | Films, television, high-quality digital media (live-streaming, Youtube, etc.) |
+
+Currently, the subnet only hosts competitions for 16kHz sample rate, but this will change with future releases. Please reference the [roadmap](roadmap.md) for more information.
+
+## Tasks
+
+The subnet currently hosts competitions for two tasks--**denoising** and **dereverberation**. Examples have been provided to better illustrate the types of speech enhancement the subnet promotes. First, here is the clean version of the audio, generated with text-to-speech (as is done in the evaluation datasets generated by validators):
+
+https://github.com/user-attachments/assets/e44773fe-c494-426b-ad30-dae4443efa15
+
+### Denoising
+
+The task of denoising involves isolating speech from any background noise present in the recording. The subnet uses the [WHAM! noise dataset](http://wham.whisper.ai/) to add noise to clean text-to-speech outputs in evaluation datasets, as in the example below:
+
+https://github.com/user-attachments/assets/8f6ce652-4bf7-41b2-80b5-e10eb6d78c1a
+
+### Dereverberation
+
+The task of dereverberation involves removing any reverberation from speech (an echo from a large room, etc.). The subnet convolves text-to-speech outputs with room impulse responses from the [Arni dataset](https://zenodo.org/records/6985104) to generate reverberant speech, as in the example below:
+
+https://github.com/user-attachments/assets/e739cda6-ac50-40f6-b8ef-62447a3484bf
+
+
+## Evaluation Metrics
+
+There are a multitude of metrics to assess the quality of audio. Below are the metrics used in the subnet's competitions:
+
+### PESQ (Perceptual Evaluation of Speech Quality)
+
+This metric's aim is to quanitify a person's percieved quality of speech, and is useful as a holistic determination of the quality of speech enhancement performed.
+
+It is important to note that PESQ only works for 8kHz and 16kHz audio.
+
+### ESTOI (Extended Short-Time Objective Intelligibility)
+
+This metric's aim is to quantify the intelligibility of speech--how easy it is the understand the speech itself.
+
+### SI-SDR (Scale-Invariant Signal-to-Distortion Ratio)
+
+This metric determines how much distortion is present in the audio. Distortion can be thought of as unwanted changes to the speech signal as a result of the enhancement operation.
+
+### SI-SAR (Scale-Invariant Signal-to-Artifacts Ratio)
+
+This metric determines the level of artifacts present in the audio. Artifacts can be thought of as new, unwanted components introduced as a result of the speech enhancement operation.
+
+### SI-SIR (Scale-Invariant Signal-to-Interference Ratio)
+
+This metric determines the level of interference present in the audio. Interference can be thought of as unwanted audio from outside sources still present in the recording, such as the noise from a crowded room.
\ No newline at end of file
diff --git a/docs/subnet/roadmap.md b/docs/subnet/roadmap.md
new file mode 100644
index 0000000..e26f69a
--- /dev/null
+++ b/docs/subnet/roadmap.md
@@ -0,0 +1,48 @@
+# Subnet Roadmap
+
+The current goal for the subnet is to facilitate the open-source research and development of state-of-the-art speech enhancement models. We recognize that there is potential to create far more open-source work in this field.
+
+The ultimate goal of the subnet is to create a monetized product in the form of an API. However, in order to make the product as competetive as possible, the subnet's first goal is to create a large body of work for miners to draw their inspiration from.
+
+The following roadmap outlines our plans to bring a SoTA speech enhancement API into fruition:
+
+## Versioning and release management
+In order to ensure the subnet users can prepare in advance we have defined a formal patching policy for the subnet components.
+
+The subnet uses **semantic versioning** in which the version number consists of three parts (Major.Minor.Patch) and an optional pre-release tag (-beta, -alpha). Depending on the type of release, there are a few things that the subnet users should be aware of.
+
+- Major Releases (**X**.0.0)
+ - There can be breaking changes and updates are mandatory for all subnet users.
+ - After the update is released, the `weights_version` hyperparameter is adjusted immediately after release such that in order to set the weights in the subnet, the neurons must be running the latest version.
+ - Major releases are communicated in the Subnet's Discord channel at least 1 week in advance.
+ - Registration may be disabled for up to 24 hours.
+
+- Minor releases (0.**X**.0)
+ - There can be breaking changes.
+ - In case there are breaking changes, the update will be announced in the Subnet's Discord channel at least 48 hours in advance. Otherwise a minimum of 24 hour notice is given.
+ - If there are breaking changes, the `weights_version` hyperparameter is adjusted immediately after release such that in order to set the weights in the subnet, the neurons must be running the latest version.
+ - If there are no breaking changes, the `weights_version` hyperparameter will be adjusted 24 hours after the launch.
+ - Minor releases are mandatory for all subnet users.
+ - Registration may be disabled for up to 24 hours.
+
+- Patch releases (0.0.**X**)
+ - Patch releases do not contain breaking changes and updates will not be mandatory unless there is a need to hotfix either scoring or penalty algorithms.
+ - Patch releases without changes to scoring or penalty algorithms are pushed to production without prior notice.
+
+## SoundsRight v1.0.0
+- Register on testnet
+- 16 kHz competitions for denoising and dereverberation tasks
+
+## SoundsRight v1.1.0
+- Register on mainnet
+
+## SoundsRight v2.0.0
+- TTS generation upgrade
+- 48 kHz competitions for denoising and dereverberation tasks
+
+## SoundsRight v3.0.0
+- More utilities provided to miners and validators
+- Validator performance dashboards
+
+## SoundsRight v4.0.0
+- Complete subnet overhaul to focus on monetization via API
\ No newline at end of file
diff --git a/docs/subnet/subnet_architecture.md b/docs/subnet/subnet_architecture.md
new file mode 100644
index 0000000..0d4c876
--- /dev/null
+++ b/docs/subnet/subnet_architecture.md
@@ -0,0 +1,31 @@
+# Subnet Architecture
+
+There are two main entities in the subnet:
+
+1. **Miners** upload fine-tuned speech enhancement models to HuggingFace.
+2. **Validators** benchmark models and reward miners whose models perform the best.
+
+Here is a diagram of the overarching process:
+
+```mermaid
+sequenceDiagram
+ participant Miner
+ participant HuggingFace
+ participant Bittensor Chain
+ participant Validator
+ participant Subnet Website
+
+ Miner->>Miner: Fine-tunes a speech enhancement model
+ Miner->>HuggingFace: Uploads model
+ Miner->>Bittensor Chain: Writes model metadata
+ Validator->>Miner: Sends Synapse requesting model information
+ Miner->>Validator: Returns Synapse containing model information
+ Validator->>Bittensor Chain: References model metadata
+ Bittensor Chain-->>Validator: Confirms model ownership
+ HuggingFace->>Validator: Downloads model
+ Validator->>Validator: Benchmarks model on locally generated dataset
+ Validator->>Subnet Website: Reports benchmarking results
+ Subnet Website->>Subnet Website: Constructs competition leaderboards
+ Validator->>Bittensor Chain: Sets weights for miners
+ Bittensor Chain->>Miner: Assigns incentives
+```
diff --git a/docs/validating/running_validator.md b/docs/validating/running_validator.md
new file mode 100644
index 0000000..634c4b9
--- /dev/null
+++ b/docs/validating/running_validator.md
@@ -0,0 +1,200 @@
+# Validating in the SoundsRight Subnet
+
+## Summary
+Running a validator the in Subnet requires **3,000 staked TAO**.
+
+**We also implore validators to run:**
+1. **In a separate environment dedicated to validating for only the SoundsRight subnet.**
+2. **Using a child hotkey.**
+
+## Validator deployment
+
+### 1. Virtual machine deployment
+The subnet requires **Ubuntu 24.04** and **Python 3.12** with at least the following hardware configuration:
+
+- 16 GB VRAM
+- 23 GB RAM
+- 512 GB storage (1000 IOPS)
+- 5 gbit/s network bandwidth
+- 6 vCPU
+
+When running the subnet validator, we are highly recommending that you run the subnet validator with DataCrunch.io using the **1x Tesla V100** instance type with **Ubuntu 24.04** and **CUDA 12.6**.
+
+This is the setup we are performing our testing and development with; as a result, they are being used as the performance baseline for the subnet validators.
+
+Running the validator with DataCrunch.io is not mandatory and the subnet validator should work on other environments as well, though the exact steps for setup may vary depending on the service used. This guide assumes you're running Ubuntu 24.04 provided by DataCrunch.io, and thus skips steps that might be mandatory in other environments (for example, installing the NVIDIA and CUDA drivers).
+
+### 2. Installation of mandatory packages
+
+Note that for the following steps, it will be assumed that you will be running the validator fully as root and as such, any action that needs to be performed as root will not be denoted with sudo.
+
+#### 2.1 Install Podman for Ubuntu
+
+For installing Podman for Ubuntu, run the following command:
+```
+$ apt-get update
+$ apt-get -y install podman
+```
+
+#### 2.2 Install the mandatory packages
+
+Run the following command:
+```
+$ apt update && apt-get install python3.12-venv && apt install jq && apt install npm && npm install pm2 -g && pm2 update && apt install -y python3.12-dev build-essential gcc g++
+```
+
+#### 2.3 Configure NVIDIA Container Toolkit and CDI
+
+Follow the instructions to download the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) with Apt.
+
+Modify `/etc/nvidia-container-runtime/config.toml` and set the following parameters if you're running docker as non-root user:
+```
+[nvidia-container-cli]
+no-cgroups = true
+
+[nvidia-container-runtime]
+debug = "/tmp/nvidia-container-runtime.log"
+```
+You can also run the following command to achieve the same result:
+```
+$ sudo nvidia-ctk config --set nvidia-container-cli.no-cgroups --in-place
+```
+
+Next, follow the instructions for [generating a CDI specification](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/cdi-support.html).
+
+Verify that the CDI specification was done correctly with:
+```
+$ nvidia-ctk cdi list
+```
+You should see this in your output:
+```
+nvidia.com/gpu=all
+nvidia.com/gpu=0
+```
+
+### 3. Preparation
+
+This section covers setting up the repository, virtual environment, regenerating wallets, and setting up environmental variables.
+
+#### 3.1 Setup the GitHub repository and python virtualenv
+To clone the repository and setup the Python virtualenv, execute the following commands:
+```
+$ git clone https://github.com/synapsec-ai/SoundsRightSubnet.git
+$ cd SoundsRightSubnet
+$ python3 -m venv .venv
+$ source .venv/bin/activate
+(.venv) $ pip install bittensor
+```
+
+#### 3.2 Regenerate the validator wallet
+
+The private portion of the coldkey is not needed to run the subnet validator. **Never have your private validator coldkey or hotkeys not used to run the validator stored on the server**. Please use a dedicated server for each subnet to minimize impact of potential security issues.
+
+To regenerate the keys on the host, execute the following commands:
+```
+(.venv) $ btcli wallet regen_coldkeypub
+(.venv) $ btcli wallet regen_hotkey
+```
+
+#### 3.3 Setup the environmental variables
+The subnet repository contains a sample validator env (`.env.sample`) file that is used to pass the correct parameters to the docker compose file.
+
+Create a new file in the root of the repository called `.env` based on the given sample.
+```
+(.venv) $ cp .validator-env.sample .env
+```
+The contents of the `.env` file must be adjusted according to the validator configuration. Below is a table explaining what each variable in the .env file represents (note that the .env variables that do not apply for validators are not listed here):
+
+| Variable | Meaning |
+| :------: | :-----: |
+| NETUID | The subnet's netuid. For mainnet this value is , and for testnet this value is 271. |
+| SUBTENSOR_CHAIN_ENDPOINT | The Bittensor chain endpoint. Please make sure to always use your own endpoint. For mainnnet, the default endpoint is: wss://finney.opentensor.ai:443, and for testnet the default endpoint is: wss://test.finney.opentensor.ai:443. |
+| WALLET | The name of your coldkey. |
+| HOTKEY | The name of your hotkey. |
+| LOG_LEVEL | Specifies the level of logging you will see on the validator. Choose between INFO, INFOX, DEBUG. DEBUGX, TRACE, and TRACEX. |
+| OPENAI_API_KEY | Your OpenAI API key. |
+| HEALTHCHECK_API_HOST | Host for HealthCheck API, default is 0.0.0.0. There is no need to adjust this value unless you want to. |
+| HEALTHCHECK_API_PORT | Port for HealthCheck API, default is 6000. There is no need to adjust this value unless you want to. |
+
+.env example:
+```
+NETUID=
+SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
+WALLET=my_coldkey
+HOTKEY=my_hotkey
+OPENAI_API_KEY=THIS-IS-AN-OPENAI-API-KEY-wfhwe78r78frfg7e8ghrveh78ehrg
+
+# Available: INFO, INFOX, DEBUG, DEBUGX, TRACE, TRACEX
+LOG_LEVEL=TRACE
+
+# HealthCheck API
+HEALTHCHECK_API_HOST=0.0.0.0
+HEALTHCHECK_API_PORT=6000
+```
+
+#### 3.4 Installing Python Dependencies
+
+Run the following commands:
+
+```
+(.venv) $ pip install --use-pep517 pesq==0.0.4 && pip install -e .[validator] && pip install httpx==0.27.2
+```
+
+### 4. Running the validator
+
+Run the validator with this command:
+```
+$ bash scripts/run_validator.sh --name soundsright-validator --max_memory_restart 50G --branch main
+```
+To see the logs, execute the following command:
+```
+$ pm2 logs
+```
+
+### 5. Updating validator
+
+To update the validator, pull the newest changes to main and restart the pm2 process:
+
+```
+$ cd SoundsRightSubnet
+$ git pull && pm2 restart
+```
+
+### 6. Assessing validator health
+
+A HealthCheck API is built into the validator, which can be queried for an assessment of the validator's performance. Note that the commands in this section assume default values for the `healthcheck_host` and `healthcheck_port` arguments of `0.0.0.0` and `6000` respectively. The following endpoints are available:
+
+#### 6.1 Healthcheck
+
+This endpoint offers an overview of validator performance. It can be queried with:
+
+```
+$ curl http://127.0.0.1:6000/healthcheck | jq
+```
+
+#### 6.2 Metrics
+
+This endpoint offers a view of all of the metrics tabulated by the Healthcheck API. It can be queried with:
+```
+$ curl http://127.0.0.1:6000/healthcheck/metrics | jq
+```
+
+#### 6.3 Events
+
+This endpoint offers insight into WARNING, SUCCESS and ERROR logs in the validator. It can be queried with:
+```
+$ curl http://127.0.0.1:6000/healthcheck/events | jq
+```
+
+#### 6.4 Best Models by Competition
+
+This endpoint offers insight into the best models known by the validator for the previous day's competition. It can be queried with:
+```
+$ curl http://127.0.0.1:6000/healthcheck/best_models | jq
+```
+
+#### 6.5 Models for Current Competitions
+This endpoint offers insight into the best models known by the validator for the previous day's competition. It can be queried with:
+```
+$ curl http://127.0.0.1:6000/healthcheck/current_models | jq
+```
\ No newline at end of file
diff --git a/docs/validating/validator_architecture.md b/docs/validating/validator_architecture.md
new file mode 100644
index 0000000..b165177
--- /dev/null
+++ b/docs/validating/validator_architecture.md
@@ -0,0 +1,29 @@
+# Validator Architecture Overview
+
+Validators on the subnet are in charge of benchmarking miner models and assigning weights for miners who submit the top performing models for each competition. Competitions span one day, and below is a diagram illustrating what the validator does during each:
+
+```mermaid
+sequenceDiagram
+ participant Miner
+ participant HuggingFace
+ participant Bittensor Chain
+ participant Validator
+ participant Subnet Website
+
+ Validator->>Validator: Generate new benchmarking dataset
+ Validator->>Validator: Benchmark SGMSE+ on new dataset
+ Validator->>Subnet Website: Report SGMSE+ benchmark results
+ Validator->>Miner: Send Synapse requesting model information
+ Miner->>Validator: Return Synapse containing model information
+ Bittensor Chain->>Validator: Obtain model metadata
submitted by miner
+ HuggingFace->>Validator: Download model
+ Validator->>Validator: Obtain hash of model directory
+ Validator->>Validator: Confirm model ownership by miner
using chain metadata and model hash
+ Validator->>Validator: Confirms model container is safe to run
+ Validator->>Validator: Runs model container
and benchmarks model
+ Validator->>Validator: Iterates through all miners
and assigns scores per competition
+ Validator->>Subnet Website: Submits miner model
benchmarking results
+ Validator->>Bittensor Chain: Sets weights for miners
+ Bittensor Chain->>Miner: Assigns incentives
+ Subnet Website->>Subnet Website: Generates leaderboard and
results of miner benchmarks
against standard (SGMSE+)
+```
\ No newline at end of file
diff --git a/miner.Dockerfile b/miner.Dockerfile
new file mode 100644
index 0000000..28a599a
--- /dev/null
+++ b/miner.Dockerfile
@@ -0,0 +1,9 @@
+FROM python:3.10.14-bookworm
+
+# Copy required files
+RUN mkdir -p /SoundsRightSubnet && mkdir -p /home/$USERNAME/.bittensor && mkdir -p /home/$USERNAME/.SoundsRightSubnet
+COPY soundsright /SoundsRightSubnet/soundsright
+COPY pyproject.toml /SoundsRightSubnet
+COPY .env /SoundsRightSubnet
+
+RUN /bin/bash -c "python3 -m venv /SoundsRightSubnet/.venv && source /SoundsRightSubnet/.venv/bin/activate && pip3 install pesq==0.0.4 && pip3 install -e /SoundsRightSubnet/.[validator]"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..de49e24
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,61 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "SoundsRight"
+version = "1.0.0"
+description = "This project implements the SoundsRight Bittensor subnet."
+authors = [
+ { name = "synapsec.ai", email = "contact@synapsec.ai" }
+]
+readme = { file = "README.md", content-type = "text/markdown" }
+license = { file = "LICENSE" }
+classifiers = [
+ "Development Status :: 3 - Beta",
+ "Intended Audience :: Developers",
+ "Topic :: Software Development :: Build Tools",
+ "License :: OSI Approved :: MIT License",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.10",
+ "Topic :: Scientific/Engineering",
+ "Topic :: Scientific/Engineering :: Mathematics",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries",
+ "Topic :: Software Development :: Libraries :: Python Modules"
+]
+requires-python = ">=3.10,<3.13"
+
+dependencies = [
+ "bittensor==8.5.1",
+ "python-dotenv==1.0.1",
+ "fastapi==0.110.1",
+ "pydantic==2.9.2",
+ "uvicorn==0.30.0",
+ "numpy==2.0.1",
+]
+
+[project.urls]
+homepage = "https://github.com/synapsec-ai/SoundsRightSubnet"
+
+[project.optional-dependencies]
+validator = [
+ "PyYAML==6.0.2",
+ "requests==2.32.3",
+ "librosa==0.10.2.post1",
+ "scipy==1.14.1",
+ "GitPython==3.1.43",
+ "soundfile==0.12.1",
+ "pyloudnorm==0.1.1",
+ "openai==1.54.5",
+ "torch==2.5.1",
+ "torchaudio==2.5.1",
+ "pystoi==0.4.1",
+]
+testing = [
+ "pytest==8.3.3",
+]
+
+[tool.setuptools.packages.find]
+include = ["soundsright"]
\ No newline at end of file
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/assets/clean/1.wav b/scripts/assets/clean/1.wav
new file mode 100755
index 0000000..5925baa
Binary files /dev/null and b/scripts/assets/clean/1.wav differ
diff --git a/scripts/assets/clean/2.wav b/scripts/assets/clean/2.wav
new file mode 100755
index 0000000..5925baa
Binary files /dev/null and b/scripts/assets/clean/2.wav differ
diff --git a/scripts/assets/reverb/1.wav b/scripts/assets/reverb/1.wav
new file mode 100755
index 0000000..78694c0
Binary files /dev/null and b/scripts/assets/reverb/1.wav differ
diff --git a/scripts/assets/reverb/2.wav b/scripts/assets/reverb/2.wav
new file mode 100755
index 0000000..78694c0
Binary files /dev/null and b/scripts/assets/reverb/2.wav differ
diff --git a/scripts/generate_dataset.py b/scripts/generate_dataset.py
new file mode 100644
index 0000000..6554a41
--- /dev/null
+++ b/scripts/generate_dataset.py
@@ -0,0 +1,68 @@
+import argparse
+from soundsright.base.data import generate_dataset_for_miner
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--clean_dir",
+ type=str,
+ help="Path of directory where you want your clean data to go.",
+ required=True
+)
+parser.add_argument(
+ "--sample_rate",
+ type=int,
+ help="Sample rate must be an int, default is 16000.",
+ required=False,
+ choices=[16000],
+ default=16000
+)
+parser.add_argument(
+ "--n",
+ type=int,
+ help="The number of data files you want to generate.",
+ required=True,
+)
+parser.add_argument(
+ "--task",
+ type=str,
+ help="The task you want to generate a dataset for. One of: 'denoising', 'dereverberation' or 'both'.",
+ required=True,
+ choices=['denoising', 'dereverberation', 'both'],
+)
+parser.add_argument(
+ "--noise_dir",
+ type=str,
+ help="The directory where the noisy dataset will be stored. You only need to input this if you want to generate a dataset for the denoising task.",
+ default=None,
+)
+parser.add_argument(
+ "--noise_data_dir",
+ type=str,
+ help="The directory where the data to generate noisy datasets will be stored. You only need to input this if you want to generate a dataset for the denoising task.",
+ default=None,
+)
+parser.add_argument(
+ "--reverb_dir",
+ type=str,
+ help="The directory where the reverberation dataset will be stored. You only need to input this if you want to generate a dataset for the dereverberation task.",
+ default=None,
+)
+parser.add_argument(
+ "--reverb_data_dir",
+ type=str,
+ help="The directory where data to generate reverberation datasets will be stored. You only need to input this if you want to generate a dataset for the dereverberation task.",
+ default=None,
+)
+
+args = parser.parse_args()
+
+generate_dataset_for_miner(
+ clean_dir=args.clean_dir,
+ sample_rate=args.sample_rate,
+ n=args.n,
+ task=args.task,
+ reverb_data_dir=args.reverb_data_dir,
+ noise_data_dir=args.noise_data_dir,
+ reverb_dir=args.reverb_dir,
+ noise_dir=args.noise_dir,
+)
\ No newline at end of file
diff --git a/scripts/run_validator.sh b/scripts/run_validator.sh
new file mode 100644
index 0000000..0d72f9c
--- /dev/null
+++ b/scripts/run_validator.sh
@@ -0,0 +1,218 @@
+#!/bin/bash
+declare -A args
+
+check_runtime_environment() {
+ if ! python --version "$1" &>/dev/null; then
+ echo "ERROR: Python is not available. Make sure Python is installed and venv has been activated."
+ exit 1
+ fi
+
+ # Get Python version
+ python_version=$(python -c 'import sys; print(sys.version_info[:])')
+ IFS=', ' read -r -a values <<< "$(sed 's/[()]//g; s/,//g' <<< "$python_version")"
+
+ # Validate that we are on a version greater than 3
+ if ! [[ ${values[0]} -ge 3 ]]; then
+ echo "ERROR: The current major version of python "${values[0]}" is less than required: 3"
+ exit 1
+ fi
+
+ # Validate that the minor version is at least 10
+ if ! [[ ${values[1]} -ge 12 ]]; then
+ echo "ERROR: The current minor version of python "${values[1]}" is less than required: 12"
+ exit 1
+ fi
+
+ echo "The installed python version "${values[0]}"."${values[1]}" meets the minimum requirement (3.12)."
+
+ # Check that the required packages are installed. These should be bundled with the OS and/or Python version.
+ # If they do not exists, they should be installed manually. We do not want to install these in the run script,
+ # as it could mess up the local system
+
+ package_list=("libssl-dev" "python"${values[0]}"."${values[1]}"-dev")
+
+ error=0
+ for package_name in "${package_list[@]}"; do
+ if ! dpkg -l | grep -q -w "^ii $package_name"; then
+ echo "ERROR: $package_name is not installed. Please install it manually."
+ error=1
+ fi
+ done
+
+ if [[ $error -eq 1 ]]; then
+ exit 1
+ fi
+
+ if [ -n "$VIRTUAL_ENV" ]; then
+ echo "Virtual environment is activated: $VIRTUAL_ENV"
+ else
+ echo "WARNING: Virtual environment is not activated. It is recommended to run this script in a python venv."
+ fi
+}
+
+parse_arguments() {
+
+ while [[ $# -gt 0 ]]; do
+ if [[ $1 == "--"* ]]; then
+ arg_name=${1:2} # Remove leading "--" from the argument name
+
+ # Special handling for logging argument
+ if [[ "$arg_name" == "logging"* ]]; then
+ shift
+ if [[ $1 != "--"* ]]; then
+ IFS='.' read -ra parts <<< "$arg_name"
+ args[${parts[0]}]=${parts[1]}
+ fi
+ else
+ shift
+ args[$arg_name]="$1" # Assign the argument value to the argument name
+ fi
+ fi
+ shift
+ done
+
+ for key in "${!args[@]}"; do
+ echo "Argument: $key, Value: ${args[$key]}"
+ done
+}
+
+pull_repo_and_checkout_branch() {
+ local branch="${args['branch']}"
+
+ # Pull the latest repository
+ git pull --all
+
+ # Change to the specified branch if provided
+ if [[ -n "$branch" ]]; then
+ echo "Switching to branch: $branch"
+ git checkout "$branch" || { echo "Branch '$branch' does not exist."; exit 1; }
+ fi
+
+ local current_branch=$(git symbolic-ref --short HEAD)
+ git fetch &>/dev/null # Silence output from fetch command
+ if ! git rev-parse --quiet --verify "origin/$current_branch" >/dev/null; then
+ echo "You are using a branch that does not exists in remote. Make sure your local branch is up-to-date with the latest version in the main branch."
+ fi
+}
+
+install_packages() {
+ # local cfg_version=$(grep -oP 'version\s*=\s*\K[^ ]+' setup.cfg)
+ local installed_version=$(pip show SoundsRight | grep -oP 'Version:\s*\K[^ ]+')
+
+ # Load dotenv configuration
+ DOTENV_FILE=".env"
+ if [ -f "$DOTENV_FILE" ]; then
+ # Load environment variables from .env file
+ export $(grep -v '^#' $DOTENV_FILE | xargs)
+ echo "Environment variables loaded from $DOTENV_FILE"
+ fi
+
+ # if [[ "$cfg_version" == "$installed_version" ]]; then
+ # echo "Subnet versions "$cfg_version" and "$installed_version" are matching: No installation is required."
+ # else
+ echo "Installing python package with pip with validator extras"
+ pip install -e .[validator]
+
+ # fi
+
+ # Uvloop re-implements asyncio module which breaks bittensor. It is
+ # not needed by the default implementation of the
+ # soundsright-subnet, so we can uninstall it.
+ if pip show uvloop &>/dev/null; then
+ echo "Uninstalling conflicting module uvloop"
+ pip uninstall -y uvloop
+ fi
+
+ echo "All python packages are installed"
+}
+
+generate_pm2_launch_file() {
+ echo "Generating PM2 launch file"
+ local cwd=$(pwd)
+ local neuron_script="${cwd}/soundsright/neurons/validator.py"
+ local interpreter="${VIRTUAL_ENV}/bin/python"
+ local branch="${args['branch']}"
+ local name="${args['name']}"
+ local dataset_size="${args['dataset_size']}"
+ local max_memory_restart="${args['max_memory_restart']}"
+ # Script arguments
+ local netuid="${NETUID}"
+ local subtensor_chain_endpoint="${SUBTENSOR_CHAIN_ENDPOINT}"
+ local wallet_name="${WALLET}"
+ local wallet_hotkey="${HOTKEY}"
+ local logging_value="${LOG_LEVEL}"
+ local healthcheck_api_host="$HEALTHCHECK_API_HOST"
+ local healthcheck_api_port="$HEALTHCHECK_API_PORT"
+
+ # Construct argument list for the neuron
+ if [[ -z "$netuid" || -z "$wallet_name" || -z "$wallet_hotkey" || -z "$name" || -z "$max_memory_restart" ]]; then
+ echo "name, max_memory_restart, netuid, wallet.name, and wallet.hotkey are mandatory arguments."
+ exit 1
+ fi
+
+ local launch_args="--netuid $netuid --wallet.name $wallet_name --wallet.hotkey $wallet_hotkey"
+
+ if [[ -n "$subtensor_chain_endpoint" ]]; then
+ launch_args+=" --subtensor.chain_endpoint $subtensor_chain_endpoint"
+ fi
+
+ if [[ -n "$logging_value" ]]; then
+ launch_args+=" --log_level $logging_value"
+ fi
+
+ if [[ -n "$healthcheck_host" ]]; then
+ launch_args+=" --healthcheck_host $healthcheck_host"
+ fi
+
+ if [[ -n "$healthcheck_port" ]]; then
+ launch_args+=" --healthcheck_port $healthcheck_port"
+ fi
+
+ if [[ -n "$dataset_size" ]]; then
+ dataset_size_value="${arg#*=}"
+ launch_args+=" --dataset_size $dataset_size"
+ fi
+
+ if [[ -v args['debug_mode'] ]]; then
+ launch_args+=" --debug_mode"
+ fi
+
+ echo "Launch arguments: $launch_args"
+
+ cat < ${name}.config.js
+module.exports = {
+ apps: [
+ {
+ "name" : "${name}",
+ "script" : "${neuron_script}",
+ "interpreter" : "${interpreter}",
+ "args" : "${launch_args}",
+ "max_memory_restart" : "${max_memory_restart}"
+ }
+ ]
+}
+EOF
+}
+
+launch_pm2_instance() {
+ local name="${args['name']}"
+ eval "pm2 start ${name}.config.js"
+}
+
+echo "### START OF EXECUTION ###"
+# Parse arguments and assign to associative array
+parse_arguments "$@"
+
+check_runtime_environment
+echo "Python venv checks completed. Sleeping 2 seconds."
+sleep 2
+pull_repo_and_checkout_branch
+echo "Repo pulled and branch checkout done. Sleeping 2 seconds."
+sleep 2
+install_packages
+echo "Installation done. Sleeping 2 seconds."
+sleep 2
+echo "Generating PM2 ecosystem file"
+generate_pm2_launch_file
+echo "Launching PM instance"
+launch_pm2_instance
\ No newline at end of file
diff --git a/scripts/verify_miner_model.py b/scripts/verify_miner_model.py
new file mode 100644
index 0000000..8354ac3
--- /dev/null
+++ b/scripts/verify_miner_model.py
@@ -0,0 +1,199 @@
+import argparse
+import logging
+import os
+import shutil
+import time
+import glob
+
+import soundsright.base.utils as Utils
+import soundsright.base.models as Models
+import soundsright.base.benchmarking as Benchmarking
+
+logging.basicConfig(
+ level=logging.DEBUG,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ handlers=[
+ logging.StreamHandler()
+ ]
+)
+
+def validate_all_reverb_files_are_enhanced(reverb_dir, enhanced_dir):
+ reverb_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(reverb_dir, '*.wav'))])
+ enhanced_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(enhanced_dir, '*.wav'))])
+ return reverb_files == enhanced_files
+
+def initialize_run_and_benchmark_model(model_namespace, model_name, model_revision):
+
+ clean_dir = os.path.join(os.getcwd(), "assets", "clean")
+ reverb_dir = os.path.join(os.getcwd(), "assets", "reverb")
+
+ output_base_path = os.getcwd()
+ model_dir = os.path.join(output_base_path, "model")
+ model_output_dir = os.path.join(output_base_path, "model_output")
+ for d in [model_dir, model_output_dir]:
+ if not os.path.exists(d):
+ os.makedirs(d)
+
+ logging.info(f"{model_dir} exists: {os.path.exists(model_dir)}\n{model_output_dir} exists: {os.path.exists(model_output_dir)}")
+
+ logging.info("Downloading model:")
+ try:
+ model_hash = Models.get_model_content_hash(
+ model_id = f"{model_namespace}/{model_name}",
+ revision=model_revision,
+ local_dir=model_dir,
+ log_level="TRACE"
+ )
+ logging.info(f"Model downloaded. Model hash:")
+ print(model_hash)
+ except Exception as e:
+ logging.error(f"Model download failed because: {e}")
+
+ logging.info("Validating container configuration:")
+ if not Utils.validate_container_config(model_dir):
+ logging.error("Container config validation failed.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Container validation succeeded.")
+
+ Utils.delete_container(log_level="TRACE")
+
+ logging.info("Starting container:")
+ if not Utils.start_container(directory=model_dir, log_level="TRACE"):
+ logging.error("Container could not be started.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Container started.")
+
+ time.sleep(10)
+
+ logging.info("Checking container status:")
+ if not Utils.check_container_status(log_level="TRACE"):
+ logging.error("Container status check failed. Please check your /status/ endpoint.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Container status check successful.")
+
+ time.sleep(1)
+
+ logging.info("Preparing model:")
+ if not Utils.prepare(log_level="TRACE"):
+ logging.error("Model preparation failed. Please check your /prepare/ endpoint.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Container preparation successful.")
+
+ time.sleep(10)
+
+ logging.info("Uploading audio:")
+ if not Utils.upload_audio(noisy_dir=reverb_dir, log_level="TRACE"):
+ logging.error("Reverb audio upload failed. Please check your /upload-audio/ endpoint.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Audio upload successful.")
+
+ time.sleep(5)
+
+ logging.info("Enhancing audio:")
+ if not Utils.enhance_audio(log_level="TRACE"):
+ logging.error("Audio enhancement failed. Please check your /enhance/ endpoint.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Audio enhancement successful.")
+
+ time.sleep(5)
+
+ logging.info("Downloading enhanced files:")
+ if not Utils.download_enhanced(enhanced_dir=model_output_dir,log_level="TRACE"):
+ logging.error("Could not download enhanced files. Please check your /download-enhanced/ endpoint.")
+ Utils.delete_container(log_level="TRACE")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("Enhanced audio download successful.")
+
+ Utils.delete_container(log_level="TRACE")
+
+ logging.info("Checking to make sure that all files were enhanced:")
+ if not validate_all_reverb_files_are_enhanced(reverb_dir=reverb_dir, enhanced_dir=model_output_dir):
+ logging.error("Mismatch between reverb files and enhanced files. Your model did not return all of the audio files it was expected to.")
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+ return False
+ logging.info("File validation successful.")
+
+ clean_files = sorted([f for f in os.listdir(clean_dir) if f.lower().endswith('.wav')])
+ enhanced_files = sorted([f for f in os.listdir(model_output_dir) if f.lower().endswith('.wav')])
+ noisy_files = sorted([f for f in os.listdir(reverb_dir) if f.lower().endswith('.wav')])
+
+ logging.info(f"Clean files: {clean_files}\nNoisy files: {noisy_files}\nEnhanced files: {enhanced_files}")
+
+ logging.info("Calculating metrics:")
+ try:
+ metrics_dict = Benchmarking.calculate_metrics_dict(
+ sample_rate=16000,
+ clean_directory=clean_dir,
+ enhanced_directory=model_output_dir,
+ noisy_directory=reverb_dir,
+ log_level="TRACE",
+ )
+ logging.info(f"Calculated model performance benchmarks: {metrics_dict}")
+ except Exception as e:
+ logging.error(f"Benchmarking metrics could not be calculated because: {e}")
+
+ shutil.rmtree(model_dir)
+ shutil.rmtree(model_output_dir)
+
+ return True
+
+def verify_miner_model(model_namespace, model_name, model_revision):
+
+ logging.info(f"Starting verificaiton for model: huggingface.co/{model_namespace}/{model_name}/tree/{model_revision}")
+
+ if not initialize_run_and_benchmark_model(model_namespace=model_namespace, model_name=model_name, model_revision=model_revision):
+ logging.critical(f"MODEL VERIFICATION FAILED.")
+
+ logging.info("\n\nMODEL VERIFICATION SUCCESSFUL.")
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--model_namespace",
+ type=str,
+ help="HuggingFace namespace (user/org name).",
+ required=True,
+ )
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ help="HuggingFace model name.",
+ required=True,
+ )
+ parser.add_argument(
+ "--model_revision",
+ type=str,
+ help="HuggingFace model revision (branch name).",
+ required=True,
+ )
+
+ args = parser.parse_args()
+
+ verify_miner_model(
+ model_namespace=args.model_namespace,
+ model_name=args.model_name,
+ model_revision=args.model_revision,
+ )
\ No newline at end of file
diff --git a/soundsright/__init__.py b/soundsright/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/soundsright/base/__init__.py b/soundsright/base/__init__.py
new file mode 100644
index 0000000..94c37a2
--- /dev/null
+++ b/soundsright/base/__init__.py
@@ -0,0 +1,6 @@
+from .neuron import BaseNeuron
+
+from .protocol import (
+ Denoising_16kHz_Protocol,
+ Dereverberation_16kHz_Protocol,
+)
\ No newline at end of file
diff --git a/soundsright/base/benchmarking/__init__.py b/soundsright/base/benchmarking/__init__.py
new file mode 100644
index 0000000..3829fb2
--- /dev/null
+++ b/soundsright/base/benchmarking/__init__.py
@@ -0,0 +1,24 @@
+from .metrics import (
+ calculate_si_sir_for_directories,
+ calculate_si_sar_for_directories,
+ calculate_si_sdr_for_directories,
+ calculate_pesq_for_directories,
+ calculate_estoi_for_directories,
+ calculate_metrics_dict,
+)
+
+from .scoring import (
+ calculate_improvement_factor,
+ new_model_surpasses_historical_model,
+ get_best_model_from_list,
+ determine_competition_scores,
+ calculate_overall_scores,
+ filter_models_with_same_hash,
+ filter_models_with_same_metadata,
+ filter_models_for_deregistered_miners,
+)
+
+from .remote_logging import (
+ miner_models_remote_logging,
+ sgmse_remote_logging,
+)
\ No newline at end of file
diff --git a/soundsright/base/benchmarking/metrics.py b/soundsright/base/benchmarking/metrics.py
new file mode 100644
index 0000000..5b714b0
--- /dev/null
+++ b/soundsright/base/benchmarking/metrics.py
@@ -0,0 +1,616 @@
+import librosa
+import soundfile as sf
+from pesq import pesq
+from pystoi import stoi
+import numpy as np
+from scipy import stats
+import os
+from typing import List, Tuple
+
+import soundsright.base.utils as Utils
+
+def si_sdr_components(s_hat, s, n):
+ """
+ Adapted from SGMSE+ [1,2,3,4]
+
+ [1] Richter, Julius, de Oliveira, Danilo, & Gerkmann, Timo.
+ Investigating Training Objectives for Generative Speech Enhancement.
+ arXiv preprint, https://arxiv.org/abs/2409.10753, 2024.
+
+ [2] Richter, Julius, Welker, Simon, Lemercier, Jean-Marie, Lay, Bunlong, & Gerkmann, Timo.
+ Speech Enhancement and Dereverberation with Diffusion-based Generative Models.
+ IEEE/ACM Transactions on Audio, Speech, and Language Processing, 31, 2351–2364.
+ https://doi.org/10.1109/TASLP.2023.3285241, 2023.
+
+ [3] Richter, Julius, Wu, Yi-Chiao, Krenn, Steven, Welker, Simon, Lay, Bunlong, Watanabe, Shinji, Richard, Alexander, & Gerkmann, Timo.
+ EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation.
+ In ISCA Interspeech, 2024.
+
+ [4] Welker, Simon, Richter, Julius, & Gerkmann, Timo.
+ Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain.
+ In Proceedings of Interspeech 2022, 2928–2932.
+ https://doi.org/10.21437/Interspeech.2022-10653, 2022.
+ """
+ # s_target
+ alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
+ s_target = alpha_s * s
+
+ # e_noise
+ alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
+ e_noise = alpha_n * n
+
+ # e_art
+ e_art = s_hat - s_target - e_noise
+
+ return s_target, e_noise, e_art
+
+def energy_ratios(s_hat, s, n):
+ """
+ Adapted from SGMSE+ [1,2,3,4]
+
+ [1] Richter, Julius, de Oliveira, Danilo, & Gerkmann, Timo.
+ Investigating Training Objectives for Generative Speech Enhancement.
+ arXiv preprint, https://arxiv.org/abs/2409.10753, 2024.
+
+ [2] Richter, Julius, Welker, Simon, Lemercier, Jean-Marie, Lay, Bunlong, & Gerkmann, Timo.
+ Speech Enhancement and Dereverberation with Diffusion-based Generative Models.
+ IEEE/ACM Transactions on Audio, Speech, and Language Processing, 31, 2351–2364.
+ https://doi.org/10.1109/TASLP.2023.3285241, 2023.
+
+ [3] Richter, Julius, Wu, Yi-Chiao, Krenn, Steven, Welker, Simon, Lay, Bunlong, Watanabe, Shinji, Richard, Alexander, & Gerkmann, Timo.
+ EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation.
+ In ISCA Interspeech, 2024.
+
+ [4] Welker, Simon, Richter, Julius, & Gerkmann, Timo.
+ Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain.
+ In Proceedings of Interspeech 2022, 2928–2932.
+ https://doi.org/10.21437/Interspeech.2022-10653, 2022.
+ """
+ s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
+
+ si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
+ si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
+ si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)
+
+ return si_sdr, si_sir, si_sar
+
+def calculate_si_sdr_for_directories(clean_directory: str, enhanced_directory: str, noisy_directory: str, sample_rate: int, log_level: str, confidence_level: float = 0.95) -> dict:
+ """
+ Calculate SI_SDR scores for all matching audio files in the given directories and compute the average SI_SDR score with a confidence interval.
+
+ Parameters:
+ :param clean_directory: (str): Path to the directory containing clean reference audio files.
+ :param enhanced_directory: (str): Path to the directory containing enhanced degraded audio files.
+ :param noisy_directory: (str): Path to the directory containing the noisy audio files.
+ :param sample_rate: {int): Sampling rate to use for SI_SDR calculation (8000 or 16000 Hz).
+ :param confidence_level: (float): Confidence level for the confidence interval (default is 0.95 for 95%).
+
+ Returns:
+ dict
+ """
+
+ # Get list of audio files in both directories
+ clean_files = sorted([f for f in os.listdir(clean_directory) if f.lower().endswith('.wav')])
+ enhanced_files = sorted([f for f in os.listdir(enhanced_directory) if f.lower().endswith('.wav')])
+ noisy_files = sorted([f for f in os.listdir(noisy_directory) if f.lower().endswith('.wav')])
+ # Match files by name
+ matched_files = set(clean_files).intersection(enhanced_files, noisy_files)
+ if not matched_files or len(matched_files) <= (0.95*len(clean_files)):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Error calculating SI_SDR: No matching audio files found in the provided directories.",
+ log_level=log_level
+ )
+ raise ValueError("No matching audio files found in the provided directories.")
+
+ si_sdr_scores = []
+
+ for file_name in matched_files:
+ try:
+ clean_audio_path = os.path.join(clean_directory, file_name)
+ enhanced_audio_path = os.path.join(enhanced_directory, file_name)
+ noisy_audio_path = os.path.join(noisy_directory, file_name)
+
+ # Load the clean audio file
+ clean_audio, clean_sr = sf.read(clean_audio_path)
+ # Load the enhanced audio file
+ enhanced_audio, enhanced_sr = sf.read(enhanced_audio_path)
+ # Load the noisy audio file
+ noisy_audio, noisy_sr = sf.read(noisy_audio_path)
+
+ if clean_sr != enhanced_sr or clean_sr != noisy_sr:
+ continue
+
+ # Ensure the signals have the same length
+ if len(clean_audio) != len(enhanced_audio) or len(noisy_audio) != len(enhanced_audio):
+ continue
+
+ # Convert to float32 type
+ clean_audio = clean_audio.astype(np.float32)
+ enhanced_audio = enhanced_audio.astype(np.float32)
+ noisy_audio = noisy_audio.astype(np.float32)
+ noise = noisy_audio - clean_audio
+
+ # Calculate the SI_SDR score
+ si_sdr_score = float(energy_ratios(enhanced_audio, clean_audio, noise)[0])
+ si_sdr_scores.append(si_sdr_score)
+
+ except:
+ continue
+
+ if not si_sdr_scores:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="No SI_SDR scores were calculated. Check your audio files and directories.",
+ log_level=log_level,
+ )
+ raise ValueError("No SI_SDR scores were calculated. Check your audio files and directories.")
+
+ # Calculate average SI_SDR score
+ average_si_sdr = np.mean(si_sdr_scores)
+
+ # Calculate confidence interval
+ n = len(si_sdr_scores)
+ stderr = stats.sem(si_sdr_scores)
+ t_score = stats.t.ppf((1 + confidence_level) / 2.0, df=n - 1)
+ margin_of_error = t_score * stderr
+ confidence_interval = (average_si_sdr - margin_of_error, average_si_sdr + margin_of_error)
+
+ output = {
+ "scores":si_sdr_scores,
+ "average":average_si_sdr,
+ "confidence_interval":confidence_interval,
+ }
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"SI_SDR metrics output: {output}",
+ log_level=log_level,
+ )
+
+ return output
+
+def calculate_si_sir_for_directories(clean_directory: str, enhanced_directory: str, noisy_directory: str, sample_rate: int, log_level: str, confidence_level: float = 0.95) -> dict:
+ """
+ Calculate SI_SIR scores for all matching audio files in the given directories and compute the average SI_SIR score with a confidence interval.
+
+ Parameters:
+ :param clean_directory: (str): Path to the directory containing clean reference audio files.
+ :param enhanced_directory: (str): Path to the directory containing enhanced degraded audio files.
+ :param noisy_directory: (str): Path to the directory containing the noisy audio files.
+ :param sample_rate: {int): Sampling rate to use for SI_SIR calculation (8000 or 16000 Hz).
+ :param confidence_level: (float): Confidence level for the confidence interval (default is 0.95 for 95%).
+
+ Returns:
+ dict
+ """
+
+ # Get list of audio files in both directories
+ clean_files = sorted([f for f in os.listdir(clean_directory) if f.lower().endswith('.wav')])
+ enhanced_files = sorted([f for f in os.listdir(enhanced_directory) if f.lower().endswith('.wav')])
+ noisy_files = sorted([f for f in os.listdir(noisy_directory) if f.lower().endswith('.wav')])
+ # Match files by name
+ matched_files = set(clean_files).intersection(enhanced_files, noisy_files)
+ if not matched_files:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Error calculating SI_SIR: No matching audio files found in the provided directories.",
+ log_level=log_level
+ )
+ raise ValueError("No matching audio files found in the provided directories.")
+
+ si_sir_scores = []
+
+ for file_name in matched_files or len(matched_files) <= (0.95*len(clean_files)):
+ try:
+ clean_audio_path = os.path.join(clean_directory, file_name)
+ enhanced_audio_path = os.path.join(enhanced_directory, file_name)
+ noisy_audio_path = os.path.join(noisy_directory, file_name)
+
+ # Load the clean audio file
+ clean_audio, clean_sr = sf.read(clean_audio_path)
+ # Load the enhanced audio file
+ enhanced_audio, enhanced_sr = sf.read(enhanced_audio_path)
+ # Load the noisy audio file
+ noisy_audio, noisy_sr = sf.read(noisy_audio_path)
+
+ if clean_sr != enhanced_sr or clean_sr != noisy_sr:
+ continue
+
+ # Ensure the signals have the same length
+ if len(clean_audio) != len(enhanced_audio) or len(noisy_audio) != len(enhanced_audio):
+ continue
+
+ # Convert to float32 type
+ clean_audio = clean_audio.astype(np.float32)
+ enhanced_audio = enhanced_audio.astype(np.float32)
+ noisy_audio = noisy_audio.astype(np.float32)
+ noise = noisy_audio - clean_audio
+
+ # Calculate the SI_SIR score
+ si_sir_score = float(energy_ratios(enhanced_audio, clean_audio, noise)[1])
+ si_sir_scores.append(si_sir_score)
+
+ except:
+ continue
+
+ if not si_sir_scores:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="No SI_SIR scores were calculated. Check your audio files and directories.",
+ log_level=log_level,
+ )
+ raise ValueError("No SI_SIR scores were calculated. Check your audio files and directories.")
+
+ # Calculate average SI_SIR score
+ average_si_sir = np.mean(si_sir_scores)
+
+ # Calculate confidence interval
+ n = len(si_sir_scores)
+ stderr = stats.sem(si_sir_scores)
+ t_score = stats.t.ppf((1 + confidence_level) / 2.0, df=n - 1)
+ margin_of_error = t_score * stderr
+ confidence_interval = (average_si_sir - margin_of_error, average_si_sir + margin_of_error)
+
+ output = {
+ "scores":si_sir_scores,
+ "average":average_si_sir,
+ "confidence_interval":confidence_interval,
+ }
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"SI_SIR metrics output: {output}",
+ log_level=log_level,
+ )
+
+ return output
+
+def calculate_si_sar_for_directories(clean_directory: str, enhanced_directory: str, noisy_directory: str, sample_rate: int, log_level: str, confidence_level: float = 0.95) -> dict:
+ """
+ Calculate SI_SAR scores for all matching audio files in the given directories and compute the average SI_SAR score with a confidence interval.
+
+ Parameters:
+ :param clean_directory: (str): Path to the directory containing clean reference audio files.
+ :param enhanced_directory: (str): Path to the directory containing enhanced degraded audio files.
+ :param noisy_directory: (str): Path to the directory containing the noisy audio files.
+ :param sample_rate: {int): Sampling rate to use for SI_SAR calculation (8000 or 16000 Hz).
+ :param confidence_level: (float): Confidence level for the confidence interval (default is 0.95 for 95%).
+
+ Returns:
+ dict
+ """
+
+ # Get list of audio files in both directories
+ clean_files = sorted([f for f in os.listdir(clean_directory) if f.lower().endswith('.wav')])
+ enhanced_files = sorted([f for f in os.listdir(enhanced_directory) if f.lower().endswith('.wav')])
+ noisy_files = sorted([f for f in os.listdir(noisy_directory) if f.lower().endswith('.wav')])
+ # Match files by name
+ matched_files = set(clean_files).intersection(enhanced_files, noisy_files)
+ if not matched_files or len(matched_files) <= (0.95*len(clean_files)):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Error calculating SI_SAR: No matching audio files found in the provided directories.",
+ log_level=log_level
+ )
+ raise ValueError("No matching audio files found in the provided directories.")
+
+ si_sar_scores = []
+
+ for file_name in matched_files:
+ try:
+ clean_audio_path = os.path.join(clean_directory, file_name)
+ enhanced_audio_path = os.path.join(enhanced_directory, file_name)
+ noisy_audio_path = os.path.join(noisy_directory, file_name)
+
+ # Load the clean audio file
+ clean_audio, clean_sr = sf.read(clean_audio_path)
+ # Load the enhanced audio file
+ enhanced_audio, enhanced_sr = sf.read(enhanced_audio_path)
+ # Load the noisy audio file
+ noisy_audio, noisy_sr = sf.read(noisy_audio_path)
+
+ if clean_sr != enhanced_sr or clean_sr != noisy_sr:
+ continue
+
+ # Ensure the signals have the same length
+ if len(clean_audio) != len(enhanced_audio) or len(noisy_audio) != len(enhanced_audio):
+ continue
+
+ # Convert to float32 type
+ clean_audio = clean_audio.astype(np.float32)
+ enhanced_audio = enhanced_audio.astype(np.float32)
+ noisy_audio = noisy_audio.astype(np.float32)
+ noise = noisy_audio - clean_audio
+
+ # Calculate the SI_SAR score
+ si_sar_score = float(energy_ratios(enhanced_audio, clean_audio, noise)[2])
+ si_sar_scores.append(si_sar_score)
+ except:
+ continue
+
+ if not si_sar_scores:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="No SI_SAR scores were calculated. Check your audio files and directories.",
+ log_level=log_level,
+ )
+ raise ValueError("No SI_SAR scores were calculated. Check your audio files and directories.")
+
+ # Calculate average SI_SAR score
+ average_si_sar = np.mean(si_sar_scores)
+
+ # Calculate confidence interval
+ n = len(si_sar_scores)
+ stderr = stats.sem(si_sar_scores)
+ t_score = stats.t.ppf((1 + confidence_level) / 2.0, df=n - 1)
+ margin_of_error = t_score * stderr
+ confidence_interval = (average_si_sar - margin_of_error, average_si_sar + margin_of_error)
+
+ output = {
+ "scores":si_sar_scores,
+ "average":average_si_sar,
+ "confidence_interval":confidence_interval,
+ }
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"SI_SAR metrics output: {output}",
+ log_level=log_level,
+ )
+
+ return output
+
+def calculate_pesq_for_directories(clean_directory: str, enhanced_directory: str, sample_rate: int, log_level: str, confidence_level: float = 0.95) -> dict:
+ """
+ Calculate PESQ scores for all matching audio files in the given directories and compute the average PESQ score with a confidence interval.
+
+ Parameters:
+ :param clean_directory: (str): Path to the directory containing clean reference audio files.
+ :param enhanced_directory: (str): Path to the directory containing enhanced degraded audio files.
+ :param sample_rate: {int): Sampling rate to use for PESQ calculation (8000 or 16000 Hz).
+ :param confidence_level: (float): Confidence level for the confidence interval (default is 0.95 for 95%).
+
+ Returns:
+ dict
+ """
+
+ # Get list of audio files in both directories
+ clean_files = sorted([f for f in os.listdir(clean_directory) if f.lower().endswith('.wav')])
+
+ enhanced_files = sorted([f for f in os.listdir(enhanced_directory) if f.lower().endswith('.wav')])
+
+ # Match files by name
+ matched_files = set(clean_files).intersection(enhanced_files)
+ if not matched_files or len(matched_files) <= (0.95*len(clean_files)):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Error calculating PESQ: No matching audio files found in the provided directories.",
+ log_level=log_level
+ )
+ raise ValueError("No matching audio files found in the provided directories.")
+
+ pesq_scores = []
+
+ for file_name in matched_files:
+ try:
+ clean_audio_path = os.path.join(clean_directory, file_name)
+ enhanced_audio_path = os.path.join(enhanced_directory, file_name)
+
+ # Load the clean audio file
+ clean_audio, clean_sr = sf.read(clean_audio_path)
+
+ # Load the enhanced audio file
+ enhanced_audio, enhanced_sr = sf.read(enhanced_audio_path)
+
+ if clean_sr != enhanced_sr:
+ continue
+
+ # Ensure the signals have the same length
+ if len(clean_audio) != len(enhanced_audio):
+ continue
+
+ # Convert to float32 type
+ clean_audio = clean_audio.astype(np.float32)
+ enhanced_audio = enhanced_audio.astype(np.float32)
+
+ # Set the mode based on the sample rate
+ if sample_rate == 8000:
+ mode = 'nb' # Narrow-band
+ elif sample_rate == 16000:
+ mode = 'wb' # Wide-band
+ else:
+ raise ValueError("Unsupported sample rate. Use 8000 or 16000 Hz.")
+
+ # Calculate the PESQ score
+ pesq_score = float(pesq(sample_rate, clean_audio, enhanced_audio, mode))
+
+ pesq_scores.append(pesq_score)
+ except Exception as e:
+ continue
+
+ if not pesq_scores:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="No PESQ scores were calculated. Check your audio files and directories.",
+ log_level=log_level,
+ )
+ raise ValueError("No PESQ scores were calculated. Check your audio files and directories.")
+
+ # Calculate average PESQ score
+ average_pesq = np.mean(pesq_scores)
+
+ # Calculate confidence interval
+ n = len(pesq_scores)
+ stderr = stats.sem(pesq_scores)
+ t_score = stats.t.ppf((1 + confidence_level) / 2.0, df=n - 1)
+ margin_of_error = t_score * stderr
+ confidence_interval = (average_pesq - margin_of_error, average_pesq + margin_of_error)
+
+ output = {
+ "scores":pesq_scores,
+ "average":average_pesq,
+ "confidence_interval":confidence_interval
+ }
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"PESQ metrics output: {output}",
+ log_level=log_level,
+ )
+
+ return output
+
+def calculate_estoi_for_directories(clean_directory: str, enhanced_directory: str, sample_rate: int, log_level: str, confidence_level: float = 0.95) -> dict:
+ """
+ Calculate ESTOI scores for all matching audio files in the given directories and compute the average ESTOI score with a confidence interval.
+
+ Parameters:
+ :param clean_directory: (str): Path to the directory containing clean reference audio files.
+ :param enhanced_directory: (str): Path to the directory containing enhanced degraded audio files.
+ :param sample_rate: {int): Sampling rate to use for ESTOI calculation (8000 or 16000 Hz).
+ :param confidence_level: (float): Confidence level for the confidence interval (default is 0.95 for 95%).
+
+ Returns:
+ dict
+ """
+
+ # Get list of audio files in both directories
+ clean_files = sorted([f for f in os.listdir(clean_directory) if f.lower().endswith('.wav')])
+ enhanced_files = sorted([f for f in os.listdir(enhanced_directory) if f.lower().endswith('.wav')])
+
+ # Match files by name
+ matched_files = set(clean_files).intersection(enhanced_files)
+ if not matched_files or len(matched_files) <= (0.95*len(clean_files)):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Error calculating ESTOI: No matching audio files found in the provided directories.",
+ log_level=log_level
+ )
+ raise ValueError("No matching audio files found in the provided directories.")
+
+ estoi_scores = []
+
+ for file_name in matched_files:
+ try:
+ clean_audio_path = os.path.join(clean_directory, file_name)
+ enhanced_audio_path = os.path.join(enhanced_directory, file_name)
+
+ # Load the clean audio file
+ clean_audio, clean_sr = sf.read(clean_audio_path)
+ # Load the enhanced audio file
+ enhanced_audio, enhanced_sr = sf.read(enhanced_audio_path)
+
+ if clean_sr != enhanced_sr:
+ continue
+
+ # Ensure the signals have the same length
+ if len(clean_audio) != len(enhanced_audio):
+ continue
+
+ # Convert to float32 type
+ clean_audio = clean_audio.astype(np.float32)
+ enhanced_audio = enhanced_audio.astype(np.float32)
+
+ # Calculate the ESTOI score
+ estoi_score = float(stoi(x=clean_audio, y=enhanced_audio, fs_sig=sample_rate))
+ estoi_scores.append(estoi_score)
+ except:
+ continue
+
+ if not estoi_scores:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="No ESTOI scores were calculated. Check your audio files and directories.",
+ log_level=log_level,
+ )
+ raise ValueError("No ESTOI scores were calculated. Check your audio files and directories.")
+
+ # Calculate average ESTOI score
+ average_estoi = np.mean(estoi_scores)
+
+ # Calculate confidence interval
+ n = len(estoi_scores)
+ stderr = stats.sem(estoi_scores)
+ t_score = stats.t.ppf((1 + confidence_level) / 2.0, df=n - 1)
+ margin_of_error = t_score * stderr
+ confidence_interval = (average_estoi - margin_of_error, average_estoi + margin_of_error)
+
+ output = {
+ "scores":estoi_scores,
+ "average":average_estoi,
+ "confidence_interval":confidence_interval,
+ }
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"ESTOI metrics output: {output}",
+ log_level=log_level,
+ )
+
+ return output
+
+def calculate_metrics_dict(sample_rate: int, clean_directory: str, enhanced_directory: str, noisy_directory: str, log_level: str) -> dict:
+
+ metrics_dict = {}
+
+ if sample_rate == 16000:
+ try:
+ # SI_SDR
+ metrics_dict['SI_SDR'] = calculate_si_sdr_for_directories(
+ clean_directory=clean_directory,
+ enhanced_directory=enhanced_directory,
+ noisy_directory=noisy_directory,
+ sample_rate=sample_rate,
+ log_level=log_level,
+ )
+ except:
+ metrics_dict['SI_SDR'] = {}
+
+ try:
+ # SI_SIR
+ metrics_dict['SI_SIR'] = calculate_si_sir_for_directories(
+ clean_directory=clean_directory,
+ enhanced_directory=enhanced_directory,
+ noisy_directory=noisy_directory,
+ sample_rate=sample_rate,
+ log_level=log_level,
+ )
+ except:
+ metrics_dict['SI_SIR'] = {}
+
+ try:
+ # SI_SAR
+ metrics_dict['SI_SAR'] = calculate_si_sar_for_directories(
+ clean_directory=clean_directory,
+ enhanced_directory=enhanced_directory,
+ noisy_directory=noisy_directory,
+ sample_rate=sample_rate,
+ log_level=log_level,
+ )
+ except:
+ metrics_dict['SI_SAR'] = {}
+
+ try:
+ # PESQ
+ metrics_dict['PESQ'] = calculate_pesq_for_directories(
+ clean_directory=clean_directory,
+ enhanced_directory=enhanced_directory,
+ sample_rate=sample_rate,
+ log_level=log_level,
+ )
+ except:
+ metrics_dict['PESQ'] = {}
+
+ try:
+ # ESTOI
+ metrics_dict['ESTOI'] = calculate_estoi_for_directories(
+ clean_directory=clean_directory,
+ enhanced_directory=enhanced_directory,
+ sample_rate=sample_rate,
+ log_level=log_level,
+ )
+ except:
+ metrics_dict['ESTOI'] = {}
+
+ return metrics_dict
\ No newline at end of file
diff --git a/soundsright/base/benchmarking/remote_logging.py b/soundsright/base/benchmarking/remote_logging.py
new file mode 100644
index 0000000..2f43e00
--- /dev/null
+++ b/soundsright/base/benchmarking/remote_logging.py
@@ -0,0 +1,154 @@
+import secrets
+import requests
+import time
+import json
+import bittensor as bt
+
+import soundsright.base.utils as Utils
+
+def requests_post(url, headers: dict, data: dict, log_level: str, timeout: int = 12) -> dict:
+ """Handles sending remote logs to SYNAPSEC remote logging API"""
+ try:
+ # Get prompt
+ res = requests.post(url=url, headers=headers, data=json.dumps(data), timeout=timeout)
+ # Check for correct status code
+ if res.status_code == 201:
+ return res
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Unable to connect to remote host: {url}: HTTP/{res.status_code} - {res.json()}",
+ log_level=log_level,
+ )
+
+ return res
+
+ except requests.exceptions.ReadTimeout as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Remote API request timed out: {e}",
+ log_level=log_level,
+ )
+ except requests.exceptions.JSONDecodeError as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Unable to read the response from the remote API: {e}",
+ log_level=log_level,
+ )
+ except requests.exceptions.ConnectionError as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Unable to connect to the remote API: {e}",
+ log_level=log_level,
+ )
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f'Generic error during request: {e}',
+ log_level=log_level,
+ )
+
+ return {}
+
+def miner_models_remote_logging(hotkey: bt.Keypair, current_miner_models: dict, log_level: str) -> bool:
+ """
+ Attempts to log the best models from current competition.
+
+ Returns:
+ bool: True if logging was successful, False otherwise
+ """
+ nonce = str(secrets.token_hex(24))
+ timestamp = str(int(time.time()))
+
+ signature = Utils.sign_data(hotkey=hotkey, data=f'{nonce}-{timestamp}')
+
+ headers = {
+ "X-Hotkey": hotkey.ss58_address,
+ "X-Signature": signature,
+ "X-Nonce": nonce,
+ "X-Timestamp": timestamp,
+ "X-API-Key":hotkey.ss58_address
+ }
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message=f"Sending current models to remote logger. Model data: {current_miner_models}. Headers: {headers}",
+ log_level=log_level,
+ )
+
+ body = {
+ "models":current_miner_models,
+ "category":"current"
+ }
+
+ res = requests_post(url="https://logs.soundsright.ai/", headers=headers, data=body, log_level=log_level)
+
+ if res and res.status_code == 201:
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message="Current model remote logging successful.",
+ log_level=log_level,
+ )
+
+ return True
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="Current model remote logging unsuccessful. Please contact subnet owners if issue persists.",
+ log_level=log_level,
+ )
+
+ return False
+
+def sgmse_remote_logging(hotkey: bt.Keypair, sgmse_benchmarks: dict, log_level: str) -> bool:
+
+ """
+ Attempts to log the best models from current competition.
+
+ Returns:
+ bool: True if logging was successful, False otherwise
+ """
+ nonce = str(secrets.token_hex(24))
+ timestamp = str(int(time.time()))
+
+ signature = Utils.sign_data(hotkey=hotkey, data=f'{nonce}-{timestamp}')
+
+ headers = {
+ "X-Hotkey": hotkey.ss58_address,
+ "X-Signature": signature,
+ "X-Nonce": nonce,
+ "X-Timestamp": timestamp,
+ "X-API-Key":hotkey.ss58_address
+ }
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message=f"Sending SGMSE+ benchmarks for all competitions on new dataset to remote logger. Model data: {sgmse_benchmarks}. Headers: {headers}",
+ log_level=log_level,
+ )
+
+ body = {
+ "models":sgmse_benchmarks,
+ "category":"sgmse"
+ }
+
+ res = requests_post(url="https://logs.soundsright.ai/", headers=headers, data=body, log_level=log_level)
+
+ if res and res.status_code == 201:
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message="SGMSE+ benchmark remote logging successful.",
+ log_level=log_level,
+ )
+
+ return True
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="SGMSE+ benchmark remote logging unsuccessful. Please contact subnet owners if issue persists.",
+ log_level=log_level,
+ )
+
+ return False
\ No newline at end of file
diff --git a/soundsright/base/benchmarking/scoring.py b/soundsright/base/benchmarking/scoring.py
new file mode 100644
index 0000000..30b63f4
--- /dev/null
+++ b/soundsright/base/benchmarking/scoring.py
@@ -0,0 +1,286 @@
+import bittensor as bt
+import numpy as np
+from typing import List
+
+import soundsright.base.utils as Utils
+
+# This function is adapted from the LinearDecay.compute_epsilon function in the taoverse repository
+# by macrocosm-os, available at https://github.com/macrocosm-os/taoverse/
+def calculate_improvement_factor(new_model_block, old_model_block, start_improvement = 0.0035, end_improvement = 0.0015, decay_block = 50400) -> float:
+ block_difference = max(new_model_block - old_model_block, 0)
+ block_adjustment = min(block_difference/decay_block, 1)
+ improvement_adjustment = block_adjustment * (start_improvement-end_improvement)
+ return start_improvement - improvement_adjustment
+
+def new_model_surpasses_historical_model(new_model_metric, new_model_block, old_model_metric, old_model_block) -> bool:
+ """
+ It is assumed that the higher the metric value, the better the performance.
+
+ A new model must have a performance metric that is higher than the current
+ best performing model by an improvement factor.
+ """
+ # Return False is new model underperforms old model
+ if new_model_metric <= old_model_metric:
+ return False
+ # Otherwise, we want to calculate the improvement factor based on block differential
+ improvement_factor = calculate_improvement_factor(new_model_block, old_model_block)
+ # If the new model has performance better or equal to the improvement factor return True
+ if (new_model_metric / old_model_metric) >= (improvement_factor + 1):
+ return True
+ # Othewrwise, return False
+ return False
+
+def get_best_model_from_list(models_data: List[dict], metric_name: str) -> dict:
+ """Gets the best model submitted during today's competition for a specific metric
+
+ Args:
+ current_models_data (List[dict]): List of model performance logs
+ metric_name (str): The metric we want to find the best model for
+
+ Returns:
+ dict: The dictionary representing the model with the highest average value for the specified metric.
+ """
+ best_model = None
+ highest_average = float('-inf')
+
+ for model in models_data:
+ metrics = model.get('metrics', {})
+ metric_data = metrics.get(metric_name, {})
+
+ # Ensure the metric_data contains 'average' and it is a number
+ if 'average' in metric_data and isinstance(metric_data['average'], (int, float)):
+ if metric_data['average'] > highest_average:
+ highest_average = metric_data['average']
+ best_model = model
+
+ return best_model
+
+def determine_competition_scores(
+ competition_scores: dict,
+ competition_max_scores: dict,
+ metric_proportions: dict,
+ best_miner_models: dict,
+ miner_models: dict,
+ metagraph: bt.metagraph,
+ log_level: str,
+):
+
+ # Construct new log of best performing models to update as we iterate
+ new_best_miner_models = {}
+ for competition in competition_scores.keys():
+ new_best_miner_models[competition] = []
+
+ # Iterate through competitions
+ for competition in competition_scores.keys():
+
+ # Iterate through metrics in each competition
+ for metric_name in metric_proportions[competition].keys():
+
+ # Determine the score to assign to the best miner
+ competition_metric_score = competition_max_scores[competition] * metric_proportions[competition][metric_name]
+
+ # Find best current model
+ current_models = miner_models[competition]
+ best_current_model = get_best_model_from_list(models_data=current_models, metric_name=metric_name)
+
+ # Continue to next iteration in loop in the case that no miner models have been submitted
+ if not best_current_model:
+ continue
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Best model for metric: {metric_name} in current competition: {competition} is: {best_current_model}",
+ log_level=log_level,
+ )
+
+ # Obtain best historical model
+ best_models = best_miner_models[competition]
+ best_historical_model = get_best_model_from_list(models_data=best_models, metric_name=metric_name)
+
+ # Assign score to the best current model if best historical model does not exist
+ if not best_historical_model:
+
+ uid = metagraph.hotkeys.index(best_current_model['hotkey'])
+ competition_scores[competition][uid] += competition_metric_score
+ new_best_miner_models[competition].append(best_current_model)
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Competition winner for metric: {metric_name} in current competition: {competition} is: {best_current_model}. Assigning score: {competition_metric_score}",
+ log_level=log_level,
+ )
+
+ continue
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Best historical model for metric: {metric_name} in current competition: {competition} is: {best_historical_model}",
+ log_level=log_level,
+ )
+
+ # Determine actual metric average values
+ best_current_model_metric_value = best_current_model['metrics'][metric_name]['average']
+ best_historical_model_metric_value = best_historical_model['metrics'][metric_name]['average']
+
+ # Determine metadata upload block
+ best_current_model_block = best_current_model['block']
+ best_historical_model_block = best_historical_model['block']
+
+ # Determine if new model beats historical model performance by signficiant margin
+ if new_model_surpasses_historical_model(
+ new_model_metric = best_current_model_metric_value,
+ new_model_block = best_current_model_block,
+ old_model_metric = best_historical_model_metric_value,
+ old_model_block = best_historical_model_block,
+ ):
+
+ # If so, assign score to new model
+ uid = metagraph.hotkeys.index(best_current_model['hotkey'])
+ competition_scores[competition][uid] += competition_metric_score
+
+ # Append to new best performing model knowledge
+ new_best_miner_models[competition].append(best_current_model)
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Competition winner for metric: {metric_name} in current competition: {competition} is: {best_current_model}. Assigning score: {competition_metric_score}",
+ log_level=log_level,
+ )
+
+ # Otherwise, assign score to old model
+ else:
+
+ uid = metagraph.hotkeys.index(best_historical_model['hotkey'])
+ competition_scores[competition][uid] += competition_metric_score
+
+ # Append to new best performing model knowledge
+ new_best_miner_models[competition].append(best_historical_model)
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Competition winner for metric: {metric_name} in current competition: {competition} is: {best_historical_model}. Assigning score: {competition_metric_score}",
+ log_level=log_level,
+ )
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"New best performing models: {new_best_miner_models}.",
+ log_level=log_level,
+ )
+
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"New competition scores: {competition_scores}",
+ log_level=log_level,
+ )
+
+ return new_best_miner_models, competition_scores
+
+def calculate_overall_scores(
+ competition_scores: dict,
+ scores: np.ndarray,
+ log_level: str
+):
+ for competition in competition_scores:
+ for i, _ in enumerate(competition_scores[competition]):
+ scores[i] += competition_scores[competition][i]
+
+ return scores
+
+def filter_models_with_same_hash(new_competition_miner_models: list) -> list:
+ """
+ Filter out model results if there are two models with the same directory hash.
+
+ We keep the model whose metadata was uploaded to the chain first.
+
+ Args:
+ :param new_competition_miner_models: (List[dict]): List of benchmarking results for models in current competition
+
+ Returns:
+ List[dict]: Filtered list of benchmarking results for models in current competition
+ """
+ # Dictionary to store the minimum 'block' for each unique 'model_hash'
+ unique_models = {}
+ blacklisted_models = []
+
+ for item in new_competition_miner_models:
+ if item and isinstance(item, dict) and 'model_hash' in item.keys() and 'block' in item.keys() and 'hf_model_namespace' in item.keys() and 'hf_model_name' in item.keys() and 'hf_model_revision' in item.keys():
+ model_hash = item['model_hash']
+ block = item['block']
+
+ # Check if this model_hash is already in the unique_models dictionary
+ if model_hash in unique_models:
+ # Keep the entry with the lowest 'block' value
+ if block < unique_models[model_hash]['block']:
+ blacklist_model = unique_models[model_hash]
+ filtered_blacklist_model = {
+ 'hf_model_namespace':blacklist_model['hf_model_namespace'],
+ 'hf_model_name':blacklist_model['hf_model_name'],
+ 'hf_model_revision':blacklist_model['hf_model_revision'],
+ }
+ blacklisted_models.append(filtered_blacklist_model)
+ unique_models[model_hash] = item
+ else:
+ # If model_hash not seen before, add it to unique_models
+ unique_models[model_hash] = item
+
+ # Return a list of unique items with the lowest 'block' value for each 'model_hash'
+ return list(unique_models.values()), blacklisted_models
+
+def filter_models_with_same_metadata(new_competition_miner_models: list) -> list:
+ """Filter out model results if there are two models with the same model (namspace, name, revision and class).
+
+ We keep the model whose metadata was uploaded to the chain first.
+
+ Args:
+ :param new_competition_miner_models: (List[dict]): List of benchmarking results for models in current competition
+
+ Returns:
+ List[dict]: Filtered list of benchmarking results for models in current competition
+ """
+ unique_models = {}
+ blacklisted_models = []
+
+ for item in new_competition_miner_models:
+ if item and isinstance(item, dict) and 'block' in item.keys() and 'hf_model_namespace' in item.keys() and 'hf_model_name' in item.keys() and 'hf_model_revision' in item.keys():
+
+ model_id = f"{item['hf_model_namespace']}{item['hf_model_name']}{item['hf_model_revision']}"
+ block = item['block']
+
+ # Check if this model_hash is already in the unique_models dictionary
+ if model_id in unique_models:
+ # Keep the entry with the lowest 'block' value
+ if block < unique_models[model_id]['block']:
+ blacklist_model = unique_models[model_id]
+ filtered_blacklist_model = {
+ 'hf_model_namespace':blacklist_model['hf_model_namespace'],
+ 'hf_model_name':blacklist_model['hf_model_name'],
+ 'hf_model_revision':blacklist_model['hf_model_revision'],
+ }
+ blacklisted_models.append(filtered_blacklist_model)
+ unique_models[model_id] = item
+ else:
+ # If model_hash not seen before, add it to unique_models
+ unique_models[model_id] = item
+
+ # Return a list of unique items with the lowest 'block' value for each 'model_hash'
+ return list(unique_models.values()), blacklisted_models
+
+def filter_models_for_deregistered_miners(miner_models, hotkeys):
+ """Removes models from list if the miner who submitted it has deregistered.
+
+ Args:
+ :param new_competition_miner_models: (List[dict]): List of new models
+ hotkeys (List[str]): List of currently registered miner hotkeys
+
+ Returns:
+ List[dict]: List of models submitted by miners with registered hotkeys
+ """
+ registered_models = []
+
+ for model in miner_models:
+ if model and isinstance(model, dict) and 'hotkey' in model.keys():
+ if 'hotkey' in model.keys() and model['hotkey'] in hotkeys:
+ registered_models.append(model)
+
+ return registered_models
\ No newline at end of file
diff --git a/soundsright/base/data/__init__.py b/soundsright/base/data/__init__.py
new file mode 100644
index 0000000..4d0562c
--- /dev/null
+++ b/soundsright/base/data/__init__.py
@@ -0,0 +1,12 @@
+from .download import (
+ dataset_download,
+ download_arni,
+ download_wham,
+)
+
+from .tts import TTSHandler
+
+from .generate import (
+ reset_all_data_directories,
+ create_noise_and_reverb_data_for_all_sampling_rates
+)
\ No newline at end of file
diff --git a/soundsright/base/data/download.py b/soundsright/base/data/download.py
new file mode 100644
index 0000000..9e953e8
--- /dev/null
+++ b/soundsright/base/data/download.py
@@ -0,0 +1,219 @@
+import os
+import requests
+import zipfile
+import shutil
+from soundsright.base.utils import subnet_logger
+
+def download_arni(arni_path: str, log_level: str = "INFO", partial: bool = False) -> None:
+ """Downloads ARNI RIR dataset. [1]
+
+ [1] Prawda, Karolina, Schlecht, Sebastian J., & Välimäki, Vesa.
+ Dataset of impulse responses from variable acoustics room Arni at Aalto Acoustic Labs [Data set].
+ Zenodo, 2022.
+ https://doi.org/10.5281/zenodo.6985104
+
+ Args:
+ :param arni_path: (str): Output path to save ARNI files.
+ :param log_level: (str, optional): Log level for operations. Defaults to "INFO".
+ :param partial: (bool, optional): Set to True if you only want to partially download the dataset for testing purposes. Defaults to False.
+
+ Raises:
+ Exception: Raised if download fails in any way.
+ """
+ # URLs and filenames
+ files = [
+ "https://zenodo.org/records/6985104/files/IR_Arni_upload_numClosed_0-5.zip?download=1",
+ "https://zenodo.org/records/6985104/files/IR_Arni_upload_numClosed_6-15.zip?download=1",
+ "https://zenodo.org/records/6985104/files/IR_Arni_upload_numClosed_16-25.zip?download=1",
+ "https://zenodo.org/records/6985104/files/IR_Arni_upload_numClosed_26-35.zip?download=1",
+ "https://zenodo.org/records/6985104/files/IR_Arni_upload_numClosed_36-45.zip?download=1",
+ "https://zenodo.org/records/6985104/files/IR_Arni_upload_numClosed_46-55.zip?download=1",
+ ]
+
+ if partial:
+ files = files[0:2]
+
+ if not os.path.exists(arni_path):
+ os.makedirs(arni_path)
+
+ # Download each file
+ for url in files:
+ try:
+ # Get the file name from the URL
+ zip_filename = url.split('/')[-1].split('?')[0]
+ zip_filepath = os.path.join(arni_path, zip_filename)
+
+ # Download the file
+ response = requests.get(url, stream=True)
+ response.raise_for_status()
+
+ # Save the file to disk
+ with open(zip_filepath, 'wb') as file:
+ for chunk in response.iter_content(chunk_size=8192):
+ file.write(chunk)
+
+ # Extract the zip file
+ with zipfile.ZipFile(zip_filepath, 'r') as zip_ref:
+ extract_dir = os.path.join(arni_path, zip_filename.replace('.zip', ''))
+ zip_ref.extractall(extract_dir)
+
+ # Move all .wav files to arni_path (including subdirectories)
+ for root, _, files_in_dir in os.walk(extract_dir):
+ for file in files_in_dir:
+ if file.endswith('.wav'):
+ source_file = os.path.join(root, file)
+ destination_file = os.path.join(arni_path, file)
+ shutil.move(source_file, destination_file)
+
+ # Clean up: Delete the extracted directory and the .zip file
+ shutil.rmtree(extract_dir)
+ os.remove(zip_filepath)
+ subnet_logger(
+ severity="TRACE",
+ message=f"Downloaded portion of Arni dataset from url: {url}",
+ log_level=log_level
+ )
+
+ except Exception as e:
+ subnet_logger(
+ severity="ERROR",
+ message=f"Error downloading or processing {url}: {e}",
+ log_level=log_level
+ )
+ raise e
+
+def download_wham(wham_path: str, log_level:str = "INFO") -> None:
+ """Downloads WHAM! 48kHz noise dataset. [2]
+
+ [2] Wichern, Gordon, Antognini, Joe, Flynn, Michael, Zhu, Licheng Richard,
+ McQuinn, Emmett, Crow, Dwight, Manilow, Ethan, & Le Roux, Jonathan.
+ WHAM!: Extending Speech Separation to Noisy Environments.
+ In Proceedings of Interspeech, September 2019.
+ http://wham.whisper.ai/
+
+ Args:
+ :param wham_path: (str): Path to save WHAM! 48kHz dataset
+ :param log_level: (str, optional): Log level for operations. Defaults to "INFO".
+
+ Raises:
+ Exception: Raised if download fails in any way.
+ """
+ try:
+ url = 'https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/high_res_wham.zip'
+ file_name = 'high_res_wham.zip'
+ file_path = os.path.join(wham_path, file_name)
+
+ # Send a GET request to the URL and stream the content
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status() # Check if the request was successful
+ # Open a local file in write-binary mode
+ with open(file_path, 'wb') as f:
+ # Write the content in chunks
+ for chunk in r.iter_content(chunk_size=8192):
+ if chunk: # Filter out keep-alive chunks
+ f.write(chunk)
+
+ # Unzip the file
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
+ zip_ref.extractall(wham_path)
+
+ # Define the directory containing the extracted audio files
+ extracted_audio_dir = os.path.join(wham_path, 'high_res_wham', 'audio')
+
+ # Move all .wav files from the extracted directory to wham_path
+ for file_name in os.listdir(extracted_audio_dir):
+ if file_name.endswith('.wav'):
+ src_path = os.path.join(extracted_audio_dir, file_name)
+ dest_path = os.path.join(wham_path, file_name)
+ shutil.move(src_path, dest_path)
+
+ # Remove the high_res_wham directory after moving the .wav files
+ shutil.rmtree(os.path.join(wham_path, 'high_res_wham'))
+
+ # Delete the zip file after extraction
+ os.remove(file_path)
+
+ except Exception as e:
+ subnet_logger(
+ severity="ERROR",
+ message=f"WHAM download failed. Exception: {e}",
+ log_level=log_level
+ )
+ raise e
+
+# Check dataset directories, create if they do not exist. Then download WHAM and Arni datasets
+def dataset_download(wham_path: str, arni_path: str, log_level: str = "INFO", partial: bool = False) -> bool:
+ """Downloads ARNI and WHAM! 48kHz datasets.
+
+ Args:
+ :param wham_path: (str): Path to save WHAM! 48kHz dataset
+ :param arni_path: (str): Output path to save ARNI files.
+ :param log_level: (str, optional): Log level for operations. Defaults to "INFO".
+ :param partial: (bool, optional): Set to True if you only want to partially download the dataset for testing purposes. Defaults to False.
+
+ Returns:
+ bool: _description_
+ """
+ # Check if dataset directories exist, create if not
+ for directory in [wham_path, arni_path]:
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ # Check if Arni dataset already downloaded, download if not
+ if not any(file.endswith(".wav") for file in os.listdir(arni_path)):
+ try:
+ subnet_logger(
+ severity="INFO",
+ message="Starting download of ARNI dataset.",
+ log_level=log_level
+ )
+
+ download_arni(arni_path=arni_path, log_level=log_level, partial=partial)
+ subnet_logger(
+ severity="INFO",
+ message="Arni dataset download complete.",
+ log_level=log_level
+ )
+ except:
+ subnet_logger(
+ severity="ERROR",
+ message="Arni datset download failed. Please contact subnet owners if error persists. Exiting neuron.",
+ log_level=log_level
+ )
+ return False
+ else:
+ subnet_logger(
+ severity="INFO",
+ message="Arni dataset has already been downloaded.",
+ log_level=log_level
+ )
+
+ # Check if WHAM dataset already downloaded, download if not
+ if not any(file.endswith(".wav") for file in os.listdir(wham_path)):
+ try:
+ subnet_logger(
+ severity="INFO",
+ message="Starting download of WHAM dataset.",
+ log_level=log_level
+ )
+ download_wham(wham_path=wham_path)
+ subnet_logger(
+ severity="INFO",
+ message="WHAM dataset download complete.",
+ log_level=log_level
+ )
+ except:
+ subnet_logger(
+ severity="ERROR",
+ message="WHAM datset download failed. Please contact subnet owners if error persists. Exiting neuron.",
+ log_level=log_level
+ )
+ return False
+ else:
+ subnet_logger(
+ severity="INFO",
+ message="WHAM dataset has already been downloaded.",
+ log_level=log_level
+ )
+
+ return True
\ No newline at end of file
diff --git a/soundsright/base/data/generate.py b/soundsright/base/data/generate.py
new file mode 100644
index 0000000..fffc808
--- /dev/null
+++ b/soundsright/base/data/generate.py
@@ -0,0 +1,489 @@
+import os
+import random
+import numpy as np
+import librosa
+import soundfile as sf
+from scipy.signal import convolve
+import pyloudnorm as pyln
+from typing import List
+from scipy import stats
+
+import soundsright.base.utils as Utils
+import soundsright.base.data as Data
+
+def _obtain_random_rir_from_arni(arni_dir_path: str) -> str:
+ """Returns random RIR from Arni dataset as a list.
+
+ Args:
+ :param arni_dir_path: (str): Path to ARNI dataset.
+
+ Returns:
+ str: Path to .wav file in ARNI dataset.
+ """
+ # Get all .wav files in the ARNI directory (including subdirectories if needed)
+ wav_files = [os.path.join(root, f) for root, dirs, files in os.walk(arni_dir_path) for f in files if f.endswith('.wav')]
+
+ # Raise an error if no .wav files are found
+ if not wav_files:
+ raise ValueError(f"No .wav files found in the directory {arni_dir_path}.")
+
+ # Select and return a random .wav file
+ return random.choice(wav_files)
+
+def calc_rt60(h, sr, rt='t30') -> float:
+ """
+ RT60 measurement routine acording to Schroeder's method [1].
+
+ [1] M. R. Schroeder, "New Method of Measuring Reverberation Time," J. Acoust. Soc. Am., vol. 37, no. 3, pp. 409-412, Mar. 1968.
+
+ Adapted from https://github.com/python-acoustics/python-acoustics/blob/99d79206159b822ea2f4e9d27c8b2fbfeb704d38/acoustics/room.py#L156
+
+ Args:
+ :param h: (np.ndarray): The RIR signal.
+ :param sr: (int): The sample rate.
+ :param rt: (str): The RT60 calculation to make. Default is 't30'
+ """
+ rt = rt.lower()
+ if rt == 't30':
+ init = -5.0
+ end = -35.0
+ factor = 2.0
+ elif rt == 't20':
+ init = -5.0
+ end = -25.0
+ factor = 3.0
+ elif rt == 't10':
+ init = -5.0
+ end = -15.0
+ factor = 6.0
+ elif rt == 'edt':
+ init = 0.0
+ end = -10.0
+ factor = 6.0
+
+ h_abs = np.abs(h) / np.max(np.abs(h))
+
+ # Schroeder integration
+ sch = np.cumsum(h_abs[::-1]**2)[::-1]
+ sch_db = 10.0 * np.log10(sch / np.max(sch)+1e-20)
+
+ # Linear regression
+ sch_init = sch_db[np.abs(sch_db - init).argmin()]
+ sch_end = sch_db[np.abs(sch_db - end).argmin()]
+ init_sample = np.where(sch_db == sch_init)[0][0]
+ end_sample = np.where(sch_db == sch_end)[0][0]
+ x = np.arange(init_sample, end_sample + 1) / sr
+ y = sch_db[init_sample:end_sample + 1]
+ slope, intercept = stats.linregress(x, y)[0:2]
+
+ # Reverberation time (T30, T20, T10 or EDT)
+ db_regress_init = (init - intercept) / slope
+ db_regress_end = (end - intercept) / slope
+ t60 = factor * (db_regress_end - db_regress_init)
+ return t60
+
+def _convolve_tts_with_random_rir(
+ tts_path: str,
+ arni_dir_path: str,
+ reverb_dir_path: str,
+ max_rt60: float = 2.0,
+) -> None:
+ """
+ Convolves a mono audio file with a random RIR from the Arni dataset and saves the output.
+ The RIR is resampled to match the audio's sample rate before convolution.
+
+ This method was adapted from the generation of the EARS-Reverb dataset [2]
+
+ [2] J. Richter, Y.-C. Wu, S. Krenn, S. Welker, B. Lay, S. Watanabe, A. Richard, and T. Gerkmann,
+ "EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation,"
+ in Proc. ISCA Intertts, pp. 4873-4877, 2024.
+
+ Here is a link to the code: https://github.com/sp-uhh/ears_benchmark/blob/main/generate_ears_reverb.py
+
+ Args:
+ :param tts_path: (str): Path to the input TTS audio file.
+ :param arni_dir_path: (str); Path to the directory containing the Arni RIR dataset.
+ :param reverb_dir_path: (str): Path to save the convolved output file.
+
+ Returns:
+ None
+ """
+ # Obtain output path from tts_path
+ output_path = os.path.join(reverb_dir_path, os.path.basename(tts_path))
+
+ tts, tts_sr = sf.read(tts_path)
+ meter = pyln.Meter(tts_sr)
+
+ # Sample RIRs until RT60 is below max_rt60 and pre_samples are below max_pre_samples
+ rt60 = np.inf
+ while rt60 > max_rt60:
+ rir_file = _obtain_random_rir_from_arni(arni_dir_path=arni_dir_path)
+
+ rir, sr = sf.read(rir_file, always_2d=True)
+
+ # Take random channel if file is multi-channel
+ channel = np.random.randint(0, rir.shape[1])
+ rir = rir[:,channel]
+ assert sr == 44100
+ rir = librosa.resample(rir, orig_sr=sr, target_sr=tts_sr)
+
+ # Cut RIR to get direct path at the beginning
+ max_index = np.argmax(np.abs(rir))
+ rir = rir[max_index:]
+
+ # Normalize RIRs in range [0.1, 0.7]
+ if np.max(np.abs(rir)) < 0.1:
+ rir = 0.1 * rir / np.max(np.abs(rir))
+ elif np.max(np.abs(rir)) > 0.7:
+ rir = 0.7 * rir / np.max(np.abs(rir))
+
+ rt60 = calc_rt60(rir, sr=sr)
+
+ mixture = convolve(tts, rir)[:len(tts)]
+
+ # normalize mixture
+ loudness_tts = meter.integrated_loudness(tts)
+ loudness_mixture = meter.integrated_loudness(mixture)
+ delta_loudness = loudness_tts - loudness_mixture
+ gain = np.power(10.0, delta_loudness/20.0)
+ # if gain is inf sample again
+ if np.isinf(gain):
+ rt60 = np.inf
+ mixture = gain * mixture
+
+ if np.max(np.abs(mixture)) > 1.0:
+ mixture = mixture / np.max(np.abs(mixture))
+
+ sf.write(output_path, mixture, tts_sr)
+
+def convolve_all_tts_with_random_rir(tts_dir_path: str, arni_dir_path: str, reverb_dir_path: str) -> None:
+ """Generates the entire reverberant database.
+
+ Args:
+ :param tts_dir_path: (str): Path to clean TTS dataset.
+ :param arni_dir_path: (str): Path to ARNI datasrt.
+ :param reverb_dir_path: (str): Path to save reverberant dataset.
+
+ Returns:
+ None
+ """
+ tts_paths = [os.path.join(tts_dir_path, f) for f in os.listdir(tts_dir_path) if f.endswith('.wav')]
+ for tts_path in tts_paths:
+ _convolve_tts_with_random_rir(tts_path=tts_path, arni_dir_path=arni_dir_path, reverb_dir_path=reverb_dir_path)
+
+def _obtain_random_noise_from_wham(wham_dir_path: str) -> str:
+ """
+ Returns the full path of a randomly selected .wav file from the specified directory.
+
+ Args:
+ :param directory: (str): The path to the directory containing .wav files.
+
+ Returns:
+ str: The full path to the randomly selected .wav file.
+
+ Raises:
+ ValueError: If no .wav files are found in the directory.
+ """
+ # List all files in the directory
+ files_in_directory = os.listdir(wham_dir_path)
+
+ # Filter out only .wav files
+ wav_files = [file for file in files_in_directory if file.lower().endswith('.wav')]
+
+ if not wav_files:
+ raise ValueError(f"No .wav files found in the directory: {wham_dir_path}")
+
+ # Choose a random .wav file
+ random_wav_file = random.choice(wav_files)
+
+ # Get the full path
+ full_path = os.path.join(wham_dir_path, random_wav_file)
+
+ return full_path
+
+def _add_random_wham_noise_to_tts(
+ tts_path: str,
+ wham_dir_path: str,
+ noise_dir_path: str,
+ min_snr: float = -2.5,
+ max_snr: float = 17.5,
+ ramp_time_in_ms: int = 10
+ ) -> None:
+ """
+ Adds random WHAM noise to a TTS audio file with a specified SNR.
+
+ This method was adapted from the generation of the EARS-WHAM dataset [2]
+
+ [2] J. Richter, Y.-C. Wu, S. Krenn, S. Welker, B. Lay, S. Watanabe, A. Richard, and T. Gerkmann,
+ "EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation,"
+ in Proc. ISCA Intertts, pp. 4873-4877, 2024.
+
+ Here is a link to the code: https://github.com/sp-uhh/ears_benchmark/blob/main/generate_ears_wham.py
+
+ Args:
+ :param tts_path: (str): Path to the TTS .wav file.
+ :param wham_dir_path: (str): Path to the directory containing WHAM .wav files.
+ :param noise_dir_path: (str): Path where the output noisy TTS should be saved.
+ :param min_snr: (float): Minimum SNR (in dB).
+ :param max_snr: (float): Maximum SNR (in dB).
+ :param ramp_time_in_ms: (float): Duration of the ramp at the start and end in milliseconds.
+
+ Returns:
+ None
+ """
+ # Get a random noise .wav file path
+ noise_wav_path = _obtain_random_noise_from_wham(wham_dir_path)
+
+ # Load the TTS and noise audio files
+ tts_audio, tts_sr = librosa.load(tts_path, sr=None)
+ noise_audio, noise_sr = librosa.load(noise_wav_path, sr=None)
+
+ # Resample noise if needed to match the TTS sampling rate
+ if noise_sr != tts_sr:
+ noise_audio = librosa.resample(noise_audio, orig_sr=noise_sr, target_sr=tts_sr)
+
+ # If noise is longer than the TTS audio, select a random segment
+ if len(noise_audio) > len(tts_audio):
+ max_start = len(noise_audio) - len(tts_audio)
+ start_idx = np.random.randint(0, max_start)
+ noise_audio = noise_audio[start_idx:start_idx + len(tts_audio)]
+ else:
+ # Ensure noise is the same length as the TTS audio (loop if necessary)
+ repeats = int(np.ceil(len(tts_audio) / len(noise_audio)))
+ noise_audio = np.tile(noise_audio, repeats)[:len(tts_audio)]
+
+ # Choose a random SNR value between min_snr and max_snr
+ snr_dB = random.uniform(min_snr, max_snr)
+
+ # Perform loudness normalization to match the target SNR
+ meter = pyln.Meter(tts_sr)
+ loudness_tts = meter.integrated_loudness(tts_audio)
+ loudness_noise = meter.integrated_loudness(noise_audio)
+
+ # Calculate the required gain for the noise
+ target_loudness = loudness_tts - snr_dB
+ delta_loudness = target_loudness - loudness_noise
+ gain = np.power(10.0, delta_loudness / 20.0)
+ noise_scaled = gain * noise_audio
+
+ # Mix the TTS audio with the scaled noise
+ mixture = tts_audio + noise_scaled
+
+ # Adjust for clipping by increasing SNR if needed
+ while np.max(np.abs(mixture)) >= 1.0:
+ snr_dB += 1 # Increase SNR to reduce noise level
+ target_loudness = loudness_tts - snr_dB
+ delta_loudness = target_loudness - loudness_noise
+ gain = np.power(10.0, delta_loudness / 20.0)
+ noise_scaled = gain * noise_audio
+ mixture = tts_audio + noise_scaled
+
+ # Apply ramps at beginning and end
+ ramp_duration = ramp_time_in_ms / 1000.0 # Convert ramp time to seconds
+ ramp_samples = int(ramp_duration * tts_sr)
+ ramp = np.linspace(0, 1, ramp_samples)
+
+ # Apply ramps to the mixture
+ mixture[:ramp_samples] *= ramp
+ mixture[-ramp_samples:] *= ramp[::-1]
+
+ # Apply ramps to the original tts for consistency
+ tts_audio[:ramp_samples] *= ramp
+ tts_audio[-ramp_samples:] *= ramp[::-1]
+
+ filename = os.path.basename(tts_path)
+ noise_path = os.path.join(noise_dir_path, filename)
+
+ # Save the resulting mixture to the specified noise_dir_path
+ sf.write(noise_path, mixture, tts_sr)
+
+def add_random_wham_noise_to_all_tts(tts_dir_path: str, wham_dir_path: str, noise_dir_path: str) -> None:
+ """Generates full noisy dataset.
+
+ Args:
+ :param tts_dir_path: (str): Path to clean TTS dataset.
+ :param wham_dir_path: (str): Path to WHAM! dataset.
+ :param noise_dir_path: (str): Path to noisy dataset.
+
+ Returns:
+ None
+ """
+ tts_paths = [os.path.join(tts_dir_path, f) for f in os.listdir(tts_dir_path) if f.endswith('.wav')]
+ for tts_path in tts_paths:
+ _add_random_wham_noise_to_tts(tts_path=tts_path, wham_dir_path=wham_dir_path, noise_dir_path=noise_dir_path)
+
+def create_noise_and_reverb_data_for_all_sampling_rates(
+ tts_base_path: str,
+ arni_dir_path: str,
+ reverb_base_path: str,
+ wham_dir_path: str,
+ noise_base_path: str,
+ tasks: List[str],
+ log_level: str) -> None:
+ """Takes TTS dataset and applies noise and/or reverb
+ to generate noise and/or reverb datasets.
+
+ Args:
+ :param tts_base_path: (str): Path to clean TTs dataset.
+ :param arni_dir_path: (str): _description_
+ :param reverb_base_path: (str): _description_
+ :param wham_dir_path: (str): _description_
+ :param noise_base_path: (str): _description_
+ :param tasks: (List[str]): _description_
+
+ Returns:
+ None
+ """
+
+ # Iterate through each sub-directory in the TTS base path
+ for dir_name in os.listdir(tts_base_path):
+ tts_dir_path = os.path.join(tts_base_path, dir_name)
+
+ # Ensure it is a directory
+ if os.path.isdir(tts_dir_path):
+ # Define the corresponding reverb and noise sub-directory paths
+ reverb_dir_path = os.path.join(reverb_base_path, dir_name)
+ noise_dir_path = os.path.join(noise_base_path, dir_name)
+
+ if 'DENOISING' in tasks:
+ # Create the noise sub-directory if it does not exist
+ if not os.path.exists(noise_dir_path):
+ os.makedirs(noise_dir_path)
+ add_random_wham_noise_to_all_tts(tts_dir_path=tts_dir_path, wham_dir_path=wham_dir_path, noise_dir_path=noise_dir_path)
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message=f"Denoising dataset created in directory: {noise_dir_path}",
+ log_level=log_level,
+ )
+
+ if 'DEREVERBERATION' in tasks:
+ # Create the reverb sub-directory if it does not exist
+ if not os.path.exists(reverb_dir_path):
+ os.makedirs(reverb_dir_path)
+ convolve_all_tts_with_random_rir(tts_dir_path=tts_dir_path, arni_dir_path=arni_dir_path, reverb_dir_path=reverb_dir_path)
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message=f"Dereverberation dataset created in directory: {reverb_dir_path}",
+ log_level=log_level,
+ )
+
+def reset_all_data_directories(tts_base_path: str, reverb_base_path: str, noise_base_path: str) -> bool:
+ """
+ Removes all .wav files from the subdirectories of the specified base paths.
+
+ Args:
+ :param tts_base_path: Base path containing subdirectories with .wav files.
+ :param reverb_base_path: Base path containing subdirectories with .wav files.
+ :param noise_base_path: Base path containing subdirectories with .wav files.
+
+ Returns:
+ bool: True if operation was successful, False if not.
+ """
+ # List of all base paths to process
+ base_paths = [tts_base_path, reverb_base_path, noise_base_path]
+
+ for base_path in base_paths:
+ # Iterate through each subdirectory within the base path
+ for subdir_name in os.listdir(base_path):
+ subdir_path = os.path.join(base_path, subdir_name)
+
+ # Check if it is indeed a directory
+ if os.path.isdir(subdir_path):
+ # Iterate through all files in the subdirectory
+ for file_name in os.listdir(subdir_path):
+ # Check if the file is a .wav file
+ if file_name.endswith('.wav'):
+ file_path = os.path.join(subdir_path, file_name)
+
+ # Remove the .wav file
+ try:
+ os.remove(file_path)
+ except OSError as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Error removing file located at: {file_path}"
+ )
+ return False
+
+ return True
+
+def generate_dataset_for_miner(
+ clean_dir: str,
+ sample_rate: int,
+ n: int,
+ task: str,
+ reverb_data_dir: str | None = None,
+ noise_data_dir: str | None = None,
+ noise_dir: str | None = None,
+ reverb_dir: str | None = None
+) -> None:
+ """Function to generate fine-tuning datasets for miners.
+
+ Args:
+ clean_dir (str): Path to clean TTS dataset.
+ sample_rate (int): Sample rate.
+ n (int): Number of elements in each dataset.
+ task (str): "denoising" or "dereverberation".
+ reverb_data_dir (str | None, optional): ARNI dataset path. Defaults to None (for if you are only looking to generate a noisy dataset).
+ noise_data_dir (str | None, optional): WHAM! dataset path. Defaults to None (for if you are only looking to generate a reverberant dataset).
+ noise_dir (str | None, optional): Noisy dataset path. Defaults to None (for if you are only looking to generate a reverberant dataset).
+ reverb_dir (str | None, optional): Reverb dataset path. Defaults to None (for if you are only looking to generate a noisy dataset).
+
+ Raises:
+ Exception: Raised if there is an issue during either download or generation.
+
+ Returns:
+ None
+ """
+ assert task in ['denoising', 'dereverberation', 'both'], "Input argument: task must be one of: 'denoising', 'dereverberation', 'both'"
+ assert isinstance(sample_rate, int), "Input argument: sample_rate must be of type int"
+ assert sample_rate in [16000], "Input argument: sample_rate must be 16000"
+ assert reverb_data_dir or noise_data_dir, "At least one of input arguments: reverb_data_dir or noise_data_dir must be specified. If you want to generate both reverb and noise datasets (inputting 'both' into task), then both must be specified."
+
+ dirs_to_make = []
+ for d in [clean_dir, noise_dir, reverb_dir, noise_data_dir, reverb_data_dir]:
+ if d:
+ dirs_to_make.append(d)
+
+ for directory in dirs_to_make:
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ tts_handler = Data.TTSHandler(
+ tts_base_path=clean_dir,
+ sample_rates = [sample_rate]
+ )
+
+ tts_handler.create_openai_tts_dataset(
+ sample_rate = sample_rate,
+ n=n,
+ for_miner=True
+ )
+
+ if task.lower() == "denoising":
+ if not any(file.endswith(".wav") for file in os.listdir(noise_data_dir)):
+ try:
+ Data.download_wham(wham_path=noise_data_dir)
+ except Exception as e:
+ raise e("Noise dataset download failed.")
+
+ add_random_wham_noise_to_all_tts(tts_dir_path=clean_dir, wham_dir_path=noise_data_dir, noise_dir_path=noise_dir)
+
+ elif task.lower == 'dereverberation':
+ if not any(file.endswith(".wav") for file in os.listdir(reverb_data_dir)):
+ try:
+ Data.download_arni(arni_path=reverb_data_dir)
+ except Exception as e:
+ raise e("Reverb dataset download failed")
+
+ convolve_all_tts_with_random_rir(tts_dir_path=clean_dir, arni_dir_path=reverb_data_dir, reverb_dir_path=reverb_dir)
+
+ else:
+ Data.dataset_download(
+ wham_path=noise_data_dir,
+ arni_path=reverb_data_dir,
+ )
+ add_random_wham_noise_to_all_tts(tts_dir_path=clean_dir, wham_dir_path=noise_data_dir, noise_dir_path=noise_dir)
+ convolve_all_tts_with_random_rir(tts_dir_path=clean_dir, arni_dir_path=reverb_data_dir, reverb_dir_path=reverb_dir)
\ No newline at end of file
diff --git a/soundsright/base/data/tts.py b/soundsright/base/data/tts.py
new file mode 100644
index 0000000..adb00f7
--- /dev/null
+++ b/soundsright/base/data/tts.py
@@ -0,0 +1,136 @@
+import random
+import os
+from openai import OpenAI
+import librosa
+import soundfile as sf
+from typing import List
+
+from soundsright.base.utils import subnet_logger
+from soundsright.base.templates import (
+ TOPICS,
+ EMOTIONS
+)
+
+from dotenv import load_dotenv
+load_dotenv()
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+# Handles all TTS-related operations
+class TTSHandler:
+
+ def __init__(self, tts_base_path: str, sample_rates: List[int], log_level: str = "INFO"):
+ self.tts_base_path = tts_base_path
+ self.sample_rates = sample_rates
+ api_key = os.getenv("OPENAI_API_KEY")
+ self.openai_client = OpenAI(api_key=api_key)
+ self.openai_voices = ['alloy','echo','fable','onyx','nova','shimmer']
+ self.log_level = log_level
+
+ def _generate_prompt(self) -> str:
+ topic = random.choice(tuple(TOPICS)).lower()
+ emotion = random.choice(tuple(EMOTIONS)).lower()
+ prompt_selection = random.randint(0,3)
+
+ if prompt_selection == 0:
+ return f"You are a somebody who knows a lot of facts. Please provide an output related to the following topic: '{topic}' with a tone of '{emotion}'. Do not start your response with the topic. Respond in English. Do not include anything at the start, or the end, but just the sentence, as your reply will be formatted into a larger block of text and it needs to flow smoothly."
+
+ elif prompt_selection == 1:
+ return f"You are a somebody who asks a lot of thought-provoking questions. Please provide an output related to the following topic: '{topic}' with a tone of '{emotion}'. Do not start your response with the topic. Respond in English. Do not include anything at the start, or the end, but just the sentence, as your reply will be formatted into a larger block of text and it needs to flow smoothly."
+
+ return f"You are a somebody who provides commentary on a wide variety of topics. Please provide an output related to the following topic: '{topic}' with a tone of: '{emotion}'. Do not start your response with the topic. Respond in English. Do not include anything at the start, or the end, but just the sentence, as your reply will be formatted into a larger block of text and it needs to flow smoothly."
+
+ # Generates unique sentences for TTS
+ def _generate_random_sentence(self, n: int=8) -> str:
+ prompt = self._generate_prompt()
+ messages = [
+ {
+ "role":"system",
+ "content":"You will be asked to partake in different forms of conversation. Do not start your response with the topic. Respond in English. Do not include anything at the start, or the end, but just the sentence, as your reply will be formatted into a larger block of text and it needs to flow smoothly."
+ },
+ {
+ "role":"user",
+ "content":prompt
+ }
+ ]
+
+ completion = self.openai_client.chat.completions.create(model='gpt-4o', messages=messages)
+ return completion.choices[0].message.content
+
+ # Generates one output TTS file at correct sample rate
+ def _do_single_openai_tts_query(self, tts_file_path: str, sample_rate: int, voice: str = 'random'):
+ # voice control
+ if voice == 'random' or voice not in self.openai_voices:
+ voice = random.choice(self.openai_voices)
+ # define openai call params
+ params = {
+ 'model':'tts-1-hd',
+ 'voice':voice,
+ 'input':self._generate_random_sentence()
+ }
+ # call openai with client
+ try:
+ response=self.openai_client.audio.speech.create(**params)
+ response.stream_to_file(tts_file_path)
+ subnet_logger(
+ severity="TRACE",
+ message=f"Obtained TTS audio file: {tts_file_path} from OpenAI.",
+ log_level=self.log_level
+ )
+ # raise error if it fails
+ except Exception as e:
+ subnet_logger(
+ severity="ERROR",
+ message="Could not get TTS audio file from OpenAI, please check configuration.",
+ log_level=self.log_level
+ )
+ # resample in place if necessary
+ try:
+ # Load the generated TTS audio file
+ audio_data, sr = librosa.load(tts_file_path, sr=None)
+
+ # Check if sample rate matches
+ if sr != sample_rate:
+ # Resample the audio
+ audio_data_resampled = librosa.resample(audio_data, orig_sr=sr, target_sr=sample_rate)
+
+ # Write the resampled audio back to the same file
+ sf.write(tts_file_path, audio_data_resampled, sample_rate)
+
+ # Log the resampling action
+ subnet_logger(
+ severity="TRACE",
+ message=f"Resampled audio file '{tts_file_path}' from {sr} Hz to {sample_rate} Hz.",
+ log_level=self.log_level
+ )
+ except Exception as e:
+ subnet_logger(
+ severity="ERROR",
+ message=f"Error during resampling of '{tts_file_path}': {e}",
+ log_level=self.log_level
+ )
+
+ # Creates TTS dataset of length n at specified sample rate
+ def create_openai_tts_dataset(self, sample_rate: int, n:int, for_miner: bool = False):
+ # define output file location and make directory if it doesn't exist
+ if for_miner:
+ output_dir = self.tts_base_path
+ else:
+ output_dir = os.path.join(self.tts_base_path, str(sample_rate))
+ os.makedirs(output_dir, exist_ok=True)
+ # count to n and make files
+ for i in range(n):
+ self._do_single_openai_tts_query(
+ tts_file_path = os.path.join(output_dir, (str(i) + ".wav")),
+ sample_rate=sample_rate
+ )
+
+ # Create TTS dataset of length n for all sample rates
+ def create_openai_tts_dataset_for_all_sample_rates(self, n:int):
+
+ for sample_rate in self.sample_rates:
+ self.create_openai_tts_dataset(
+ sample_rate=sample_rate,
+ n=n
+ )
\ No newline at end of file
diff --git a/soundsright/base/models/__init__.py b/soundsright/base/models/__init__.py
new file mode 100644
index 0000000..1475f74
--- /dev/null
+++ b/soundsright/base/models/__init__.py
@@ -0,0 +1,10 @@
+from .evaluate import ModelEvaluationHandler
+
+from .metadata import ModelMetadataHandler
+
+from .validation import (
+ get_directory_content_hash,
+ get_model_content_hash,
+)
+
+from .sgmse import SGMSEHandler
\ No newline at end of file
diff --git a/soundsright/base/models/evaluate.py b/soundsright/base/models/evaluate.py
new file mode 100644
index 0000000..27eb204
--- /dev/null
+++ b/soundsright/base/models/evaluate.py
@@ -0,0 +1,480 @@
+import os
+import time
+import glob
+import asyncio
+import shutil
+import hashlib
+import bittensor as bt
+from typing import List
+
+# Import custom modules
+import soundsright.base.benchmarking as Benchmarking
+import soundsright.base.models as Models
+import soundsright.base.utils as Utils
+
+class ModelEvaluationHandler:
+
+ def __init__(
+ self,
+ tts_base_path: str,
+ noise_base_path: str,
+ reverb_base_path: str,
+ model_output_path: str,
+ model_path: str,
+ sample_rate: int,
+ task: str,
+ hf_model_namespace: str,
+ hf_model_name: str,
+ hf_model_revision: str,
+ log_level: str,
+ subtensor: bt.subtensor,
+ subnet_netuid: int,
+ miner_hotkey: str,
+ miner_models: List[dict],
+ ):
+ """Initializes ModelEvaluationHandler
+
+ Args:
+ :param tts_base_path: (str): Base directory for TTS dataset
+ :param noise_base_path: (str): Base directory for denoising dataset
+ :param reverb_base_path: (str): Base directory for dereverberation dataset
+ :param model_output_path: (str): Directory for model outputs to be temporarily stored for benchmarking
+ :param model_path: (str): Directory for model to be temporarily stored for benchmarking
+ :param sample_rate: (int): Sample rate
+ :param task: (str): DENOISING or DEREVERBERATION
+ :param hf_model_namespace: (str): Namespace from synapse
+ :param hf_model_name: (str): Name from synapse
+ :param hf_model_revision: (str): Revision from synapse
+ :param log_level: (str): Log level from .env
+ :param subtensor: (bt.subtensor): Subtensor from validator
+ :param subnet_netuid: (int): Netuid from .env
+ :param miner_hotkey: (str): ss58 address
+ :param miner_models: (List[dict]): Most recent benchmarked model/empty response for each miner for the competition
+ """
+ # Paths
+ self.tts_path = os.path.join(tts_base_path, str(sample_rate))
+ self.noise_path = os.path.join(noise_base_path, str(sample_rate))
+ self.reverb_path = os.path.join(reverb_base_path, str(sample_rate))
+ self.model_output_path = model_output_path
+ self.model_path = model_path
+ # Competition
+ self.sample_rate = sample_rate
+ self.task = task
+ self.task_path = self.noise_path if self.task == "DENOISING" else self.reverb_path
+ self.competition = f"{task}_{sample_rate}HZ"
+ # Model
+ self.hf_model_namespace = hf_model_namespace
+ self.hf_model_name = hf_model_name
+ self.hf_model_id = f"{hf_model_namespace}/{hf_model_name}"
+ self.hf_model_revision = hf_model_revision
+ self.hf_model_block = None
+ self.model_hash = ''
+ self.forbidden_model_hashes = [
+ "ENZIdw0H8Vbb79lXDQKBqqReXIj2ycgOX1Ob0QoexAU=",
+ "Mbx0++bk5q6n+rdVlUblElnturj/zRobTk61WFVHmgg=",
+ ]
+ # Misc
+ self.log_level = log_level
+ self.miner_hotkey = miner_hotkey
+ self.miner_models = miner_models
+ self.metadata_handler = Models.ModelMetadataHandler(
+ subtensor=subtensor,
+ subnet_netuid=subnet_netuid,
+ log_level=self.log_level
+ )
+
+ def obtain_model_metadata(self):
+ """
+ Validates that the model provided by the miner matches the metadata uploaded to the chain.
+
+ Updates ModelEvaluationHandler.model_metadata and ModelEvaluationHandler.hf_model_block with
+ on-chain data.
+
+ Returns:
+ bool: True if model metadata could be obtained, False otherwise
+ """
+ try:
+ outcome = asyncio.run(self.metadata_handler.obtain_model_metadata_from_chain(
+ hotkey=self.miner_hotkey,
+ ))
+
+ if not outcome or not self.metadata_handler.metadata or not self.metadata_handler.metadata_block:
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not obtain model metadata from chain for hotkey: {self.miner_hotkey}",
+ log_level=self.log_level
+ )
+
+ return False
+
+ else:
+
+ self.model_metadata = self.metadata_handler.metadata
+ self.hf_model_block = self.metadata_handler.metadata_block
+
+ Utils.subnet_logger(
+ severity="DEBUG",
+ message=f"Recieved model metadata from chain: {self.model_metadata} on block: {self.hf_model_block} for hotkey: {self.miner_hotkey}",
+ log_level=self.log_level
+ )
+
+ return True
+
+ except:
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not obtain model metadata from chain for hotkey: {self.miner_hotkey}",
+ log_level=self.log_level
+ )
+
+ return False
+
+ def validate_model_metadata(self):
+ """Validates that the model metadata is for a model belonging to the mienr with the following steps:
+
+ 1. Re-create model metadata string and confirm that its hash matches metadata uploaded to chain
+ 2. Make sure that model name is unique among models submitted. If it is not, it checks the block that
+ metadata was uploaded to the chain. If the metadata was uploaded first, we assume that this is the
+ miner that originally uploaded the model to Huggingface
+ 3. Download model and calculate hash of model directory
+ 4. Make sure that model hash is unique among models submitted. If it is not, it checks the block that
+ metadata was uploaded to the chain. If the metadata was uploaded first, we assume that this is the
+ miner that originally uploaded the model to Huggingface
+
+ Returns:
+ bool: True if model metadata checks out, False if otherwise
+ """
+ # Obtain competition id from model and miner data
+ competition_id = self.metadata_handler.get_competition_id_from_competition_name(self.competition)
+
+ # Determine miner metadata
+ metadata_str = f"{self.hf_model_namespace}:{self.hf_model_name}:{self.hf_model_revision}:{self.miner_hotkey}:{competition_id}"
+
+ # Hash it and compare to hash uploaded to chain
+ if hashlib.sha256(metadata_str.encode()).hexdigest() != self.model_metadata:
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"Model: {self.hf_model_id} metadata could not be validated with on-chain metadata. Exiting model evaluation.",
+ log_level=self.log_level
+ )
+ return False
+
+ # Check to make sure that namespace, name and revision are unique among submitted models and if not, that it was submitted first
+ for model_dict in self.miner_models:
+ if (
+ model_dict['hf_model_namespace'] == self.hf_model_namespace
+ ) and (
+ model_dict['hf_model_name'] == self.hf_model_name
+ ) and (
+ model_dict['hf_model_revision'] == self.hf_model_revision
+ ) and (
+ model_dict['block'] < self.hf_model_block
+ ):
+ return False
+
+ # Download model to path and obtain model hash
+ self.model_hash, _ = Models.get_model_content_hash(
+ model_id=self.hf_model_id,
+ revision=self.hf_model_revision,
+ local_dir=self.model_path,
+ log_level=self.log_level
+ )
+
+ # Make sure model hash is unique
+ if self.model_hash in [model_data['model_hash'] for model_data in self.miner_models] and self.model_hash not in self.forbidden_model_hashes:
+
+ # Find block that metadata was uploaded to chain for all models with identical directory hash
+ model_blocks_with_same_hash = []
+ for model_data in self.miner_models:
+ if model_data['model_hash'] == self.model_hash:
+ model_blocks_with_same_hash.append(model_data['block'])
+
+ # Append current model block for comparison
+ model_blocks_with_same_hash.append(self.hf_model_block)
+
+ # If it's not unique, don't return False only if this model is the earliest one uploaded to chain
+ if min(model_blocks_with_same_hash) != self.hf_model_block:
+ Utils.subnet_logger(
+ severity="INFO",
+ message=f"Current model: {self.hf_model_id} has identical hash with another model and was not uploaded first. Exiting model evaluation.",
+ log_level=self.log_level
+ )
+ return False
+
+ return True
+
+ def validate_all_noisy_files_are_enhanced(self):
+ noisy_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(self.task_path, '*.wav'))])
+ enhanced_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(self.model_output_path, '*.wav'))])
+ return noisy_files == enhanced_files
+
+ def initialize_and_run_model(self):
+ """_summary_
+
+ Returns:
+ bool:
+ """
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Validating container configuration for model: {self.hf_model_namespace}/{self.hf_model_name}.",
+ log_level=self.log_level
+ )
+
+ # Validate container
+ if not Utils.validate_container_config(self.model_path):
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Validating container configuration for model failed: {self.hf_model_namespace}/{self.hf_model_name}.",
+ log_level=self.log_level,
+ )
+
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Validating container configuration for model succeeded: {self.hf_model_namespace}/{self.hf_model_name}. Now starting container",
+ log_level=self.log_level
+ )
+
+ # Delete any existing containers before starting new one
+ Utils.delete_container(log_level=self.log_level)
+
+ # Start container
+ if not Utils.start_container(directory=self.model_path, log_level=self.log_level):
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Container could not be started",
+ log_level=self.log_level
+ )
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Container started.",
+ log_level=self.log_level
+ )
+
+ time.sleep(10)
+
+ if not Utils.check_container_status(log_level=self.log_level):
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Could not establish connection with API.",
+ log_level=self.log_level
+ )
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Connection established with API and status was verified. Commencing model preparation.",
+ log_level=self.log_level
+ )
+
+ time.sleep(1)
+
+ if not Utils.prepare(log_level=self.log_level):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not prepare the model.",
+ log_level=self.log_level
+ )
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Model preparation successful. Commencing transfer of noisy files.",
+ log_level=self.log_level
+ )
+
+ time.sleep(10)
+
+ if not Utils.upload_audio(noisy_dir=self.task_path, log_level=self.log_level):
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Noisy files could not be uploaded to model container. Ending benchmarking of model.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Noisy files were transferred to model container. Commencing enhancement.",
+ log_level=self.log_level
+ )
+
+ time.sleep(5)
+
+ if not Utils.enhance_audio(log_level=self.log_level):
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Noisy files could not be enhanced by model. Ending benchmarking of model.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Enhancement complete. Downloading enhanced files from model container.",
+ log_level=self.log_level
+ )
+
+ time.sleep(5)
+
+ if not Utils.download_enhanced(enhanced_dir=self.model_output_path, log_level=self.log_level):
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Enhanced files could not be downloaded. Ending benchmarking of model.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Downloaded and unzipped files. Validating that all noisy files have been enhanced.",
+ log_level=self.log_level
+ )
+
+ if not self.validate_all_noisy_files_are_enhanced():
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Mismatch detected between noisy and enhanced files. Ending benchmarking of model.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Validation successful. Deleting containers.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+
+ return True
+
+ def _reset_dir(self, directory: str) -> None:
+ """Removes all files and sub-directories in an inputted directory
+
+ Args:
+ directory (str): Directory to reset.
+ """
+ # Check if the directory exists
+ if not os.path.exists(directory):
+ return
+
+ # Loop through all the files and subdirectories in the directory
+ for filename in os.listdir(directory):
+ file_path = os.path.join(directory, filename)
+
+ # Check if it's a file or directory and remove accordingly
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.unlink(file_path) # Remove the file or link
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path) # Remove the directory and its contents
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Failed to delete {file_path}. Reason: {e}",
+ log_level=self.log_level
+ )
+
+ def reset_model_dirs(self):
+ """
+ Removes files and sub-directories in ModelEvaluationHandler.model_path and
+ ModelEvaluationHandler.model_output_path to make sure all is clear for
+ next model to be benchmarked.
+ """
+ self._reset_dir(directory=self.model_path)
+ self._reset_dir(directory=self.model_output_path)
+
+ def download_run_and_evaluate(self):
+ """
+ Overarching function to verify, download, execute and evaluate model performance.
+
+ Returns:
+ :param metric_average: (float): Average value for the evaluation metric for the model
+ :param confidence_interval: (List[float]) 95% CI for metric score
+ :param metric_values: (List[float]) List of all metric scores for the model evaluation
+ :param metric: (str): Name of the metric being used, determined by competition.
+ :param model_hash:
+ :param hf_model_block:
+ """
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Checking if model: {self.hf_model_id} from miner: {self.miner_hotkey} has metadata that can be obtained from chain.",
+ log_level=self.log_level
+ )
+
+ # Attempt to obtain the model metadata stored on-chain
+ if not self.obtain_model_metadata():
+ # Remove all files from model-based directories (model files and model outcome files)
+ self.reset_model_dirs()
+ # Return zero for the output metric if the model could not be obtained
+ return {}, self.model_hash, self.hf_model_block
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Checking if model: {self.hf_model_id} from miner: {self.miner_hotkey} has metadata that can be validated and a unique hash for the model.",
+ log_level=self.log_level
+ )
+
+ # Attempt to validate the model with the data stored on-chain. This step also downloads the model to self.model_path
+ if not self.validate_model_metadata():
+ # Remove all files from model-based directories (model files and model outcome files)
+ self.reset_model_dirs()
+ # Return zero for the output metric if the model metadata could not be validated
+ return {}, self.model_hash, self.hf_model_block
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Running model: {self.hf_model_id} from miner: {self.miner_hotkey} on validator benchmarking dataset.",
+ log_level=self.log_level
+ )
+
+ # Initialize and run the model on the dataset
+ if not self.initialize_and_run_model():
+ # Remove all files from model-based directories (model files and model outcome files)
+ self.reset_model_dirs()
+ # Return zero for the output metric if the model could not be initialized or run
+ return {}, self.model_hash, self.hf_model_block
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Calculating metrics for benchmarking dataset for model: {self.hf_model_id} from miner: {self.miner_hotkey}.",
+ log_level=self.log_level
+ )
+
+ # Calculate metrics (metrics vary depending on sample rate)
+ metrics_dict = Benchmarking.calculate_metrics_dict(
+ clean_directory=self.tts_path,
+ enhanced_directory=self.model_output_path,
+ noisy_directory=self.task_path,
+ sample_rate=self.sample_rate,
+ log_level=self.log_level,
+ )
+
+ # Remove all files from model-based directories (model files and model outcome files)
+ self.reset_model_dirs()
+
+ return metrics_dict, self.model_hash, self.hf_model_block
\ No newline at end of file
diff --git a/soundsright/base/models/metadata.py b/soundsright/base/models/metadata.py
new file mode 100644
index 0000000..02c4c09
--- /dev/null
+++ b/soundsright/base/models/metadata.py
@@ -0,0 +1,109 @@
+import bittensor as bt
+import soundsright.base.utils as Utils
+
+class ModelMetadataHandler:
+
+ def __init__(
+ self,
+ subtensor: bt.subtensor,
+ subnet_netuid: int,
+ log_level: str,
+ wallet: bt.wallet | None = None,
+ ):
+
+ self.subtensor = subtensor
+ self.subnet_netuid = subnet_netuid
+ self.wallet = wallet
+ self.log_level = log_level
+ self.metadata = ''
+ self.metadata_block = 0
+
+ Utils.timeout_decorator(timeout=60)
+ async def upload_model_metadata_to_chain(self, metadata: str):
+ """_summary_
+
+ Args:
+ metadata (str): Hash of metadata string
+
+ Returns:
+ bool: True if metadata could be uploaded to chain, False otherwise
+ """
+ outcome = bt.core.extrinsics.serving.publish_metadata(
+ self.subtensor,
+ wallet=self.wallet,
+ netuid=self.subnet_netuid,
+ data_type=f"Raw{len(metadata)}",
+ data=metadata.encode(),
+ wait_for_inclusion=True,
+ wait_for_finalization=True,
+ )
+
+ return outcome
+
+ Utils.timeout_decorator(timeout=60)
+ async def obtain_model_metadata_from_chain(self, hotkey: str):
+ """_summary_
+
+ Args:
+ hotkey (str): ss58_address of miner hotkey
+
+ Returns:
+ bool: True if model metadata could be obtained from chain, False otherwise
+ """
+ try:
+
+ metadata = bt.core.extrinsics.serving.get_metadata(
+ self=self.subtensor,
+ netuid=self.subnet_netuid,
+ hotkey=hotkey
+ )
+
+ commitment = metadata["info"]["fields"][0]
+ hex_data = commitment[list(commitment.keys())[0]][2:]
+ self.metadata = bytes.fromhex(hex_data).decode()
+
+ self.metadata_block = metadata['block']
+
+ return True
+
+ except Exception as e:
+ raise e
+ return False
+
+ def get_competition_id_from_competition_name(self, competition_name):
+ """Obtains competition id from competition name for metadata purposes
+
+ Args:
+ :param competition_name: (str): Name of competition
+
+ Returns:
+ int | None: int if competition_name is valid, None otherwise
+ """
+ conversion_dict={
+ "DENOISING_16000HZ":1,
+ "DEREVERBERATION_16000HZ":2,
+ }
+
+ if competition_name in conversion_dict.keys():
+ return conversion_dict[competition_name]
+
+ return None
+
+ def get_competition_name_from_competition_id(self, competition_id):
+ """Get competition string from numerical id
+
+ Args:
+ :param competition_id: (int | str): id of competition as used in metadata string
+
+ Returns:
+ :param competition_name" (str): name of competition as used in dict keys in the rest of the repo
+ """
+ conversion_dict = {
+ "1":"DENOISING_16000HZ",
+ "2":"DEREVERBERATION_16000HZ",
+ }
+
+ if str(competition_id) in conversion_dict.keys():
+ return conversion_dict[str(competition_id)]
+
+ return None
\ No newline at end of file
diff --git a/soundsright/base/models/sgmse.py b/soundsright/base/models/sgmse.py
new file mode 100644
index 0000000..1d6e854
--- /dev/null
+++ b/soundsright/base/models/sgmse.py
@@ -0,0 +1,223 @@
+import os
+import yaml
+import glob
+import shutil
+from typing import List
+from git import Repo
+
+import soundsright.base.utils as Utils
+
+
+class SGMSEHandler:
+
+ def __init__(self, task: str, sample_rate: int, task_path: str, sgmse_path: str, sgmse_output_path: str, log_level: str) -> None:
+
+ self.hf_model_url = "https://huggingface.co/synapsecai/SoundsRightModelTemplate"
+ self.task = task
+ self.sample_rate = sample_rate
+ self.competition = f"{task}_{sample_rate}HZ"
+ self.task_path = task_path
+ self.sgmse_path = sgmse_path
+ self.sgmse_output_path = sgmse_output_path
+ self.log_level = log_level
+
+ def download_model_container(self) -> bool:
+ try:
+ Repo.clone_from(self.hf_model_url, self.sgmse_path, branch=self.competition)
+ return True
+ except:
+ return False
+
+ def _reset_dir(self, directory):
+ """Removes all files and sub-directories in an inputted directory
+
+ Args:
+ directory (_type_): _description_
+ """
+ # Check if the directory exists
+ if not os.path.exists(directory):
+ return
+
+ # Loop through all the files and subdirectories in the directory
+ for filename in os.listdir(directory):
+ file_path = os.path.join(directory, filename)
+
+ # Check if it's a file or directory and remove accordingly
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.unlink(file_path) # Remove the file or link
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path) # Remove the directory and its contents
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Failed to delete {file_path}. Reason: {e}",
+ log_level=self.log_level
+ )
+
+ def reset_model_dirs(self):
+ """
+ Removes files and sub-directories in ModelEvaluationHandler.sgmse_path and
+ ModelEvaluationHandler.sgmse_output_path to make sure all is clear for
+ next model to be benchmarked.
+ """
+ self._reset_dir(directory=self.sgmse_path)
+ self._reset_dir(directory=self.sgmse_output_path)
+
+ def initialize_and_run_model(self):
+ """Initializes model and runs the container to enhance audio
+
+ Returns:
+ bool: True if operations were successful, False otherwise
+ """
+ # Delete everything before starting container
+ Utils.delete_container(log_level=self.log_level)
+
+ # Start container
+ if not Utils.start_container(directory=self.sgmse_path, log_level=self.log_level):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="SGMSE+ container could not be started. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="SGMSE+ Container started.",
+ log_level=self.log_level,
+ )
+
+ if not Utils.check_container_status(log_level=self.log_level):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not establish connection with SGMSE+ API. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Connection established with SGMSE+ API and status was verified. Commencing model preparation.",
+ log_level=self.log_level
+ )
+
+ if not Utils.prepare(log_level=self.log_level):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not prepare the SGMSE+ model. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="SGMSE+ model preparation successful. Commencing transfer of noisy files.",
+ log_level=self.log_level
+ )
+
+ if not Utils.upload_audio(noisy_dir=self.task_path, log_level=self.log_level):
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="Noisy files could not be uploaded to SGMSE+ model container. Ending benchmarking of SGMSE+ model. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Noisy files were transferred to SGMSE+ model container. Commencing enhancement.",
+ log_level=self.log_level,
+ )
+
+ if not Utils.enhance_audio(log_level=self.log_level):
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="Noisy files could not be enhanced by SGMSE+ model. Ending benchmarking of model. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Enhancement complete. Downloading enhanced files from SGMSE+ model container.",
+ log_level=self.log_level,
+ )
+
+ if not Utils.download_enhanced(enhanced_dir=self.sgmse_output_path, log_level=self.log_level):
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="Enhanced files could not be downloaded. Ending benchmarking of SGMSE+ model. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message="Downloaded and unzipped files from SGMSE+ model API. Validating that all noisy files have been enhanced.",
+ log_level=self.log_level,
+ )
+
+ if not self.validate_all_noisy_files_are_enhanced():
+
+ Utils.subnet_logger(
+ severity="ERROR",
+ message="Mismatch detected between noisy and enhanced files. Ending benchmarking of SGMSE+ model. Please contact subnet owners if issue persists.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+ return False
+
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Validation successful. Deleting SGMSE+ container.",
+ log_level=self.log_level
+ )
+
+ Utils.delete_container(log_level=self.log_level)
+
+ return True
+
+ def validate_all_noisy_files_are_enhanced(self):
+ noisy_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(self.task_path, '*.wav'))])
+ enhanced_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(self.sgmse_output_path, '*.wav'))])
+ return noisy_files == enhanced_files
+
+ def download_start_and_enhance(self) -> bool:
+ # Initial cleaning of model dirs
+ self.reset_model_dirs()
+
+ # Download model
+ if not self.download_model_container():
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not download SGMSE container (branch: {self.competition}. Please contact subnet owners if issue persists.)",
+ log_level=self.log_level
+ )
+ self.reset_model_dirs()
+ return False
+
+ # Initialize and run model
+ if not self.initialize_and_run_model():
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Could not enhance files using SGMSE container (branch: {self.competition}. Please contact subnet owners if issue persists.)",
+ log_level=self.log_level
+ )
+ self.reset_model_dirs()
+ return False
+
+ return True
\ No newline at end of file
diff --git a/soundsright/base/models/validation.py b/soundsright/base/models/validation.py
new file mode 100644
index 0000000..1634606
--- /dev/null
+++ b/soundsright/base/models/validation.py
@@ -0,0 +1,88 @@
+import hashlib
+import os
+import base64
+import shutil
+from git import Repo
+
+import soundsright.base.utils as Utils
+
+def get_directory_content_hash(directory: str):
+ """
+ Computes a single hash of the combined contents of all files in a directory,
+ excluding certain unnecessary files and directories.
+
+ Args:
+ :param directory: (str): Path to the directory.
+
+ Returns:
+ str: A base64-encoded hash representing the combined contents of all files.
+ """
+ hash_obj = hashlib.sha256()
+ excluded_dirs = {'.git'}
+ excluded_files = {'.lock', '.metadata'}
+
+ # Traverse the directory in a consistent order
+ for root, dirs, files in os.walk(directory):
+ # Exclude specified directories
+ dirs[:] = [d for d in dirs if d not in excluded_dirs]
+ sorted_files = sorted(files)
+
+ for file_name in sorted_files: # Sort files to ensure consistent order
+ if file_name in excluded_files:
+ continue # Skip excluded files
+
+ file_path = os.path.join(root, file_name)
+
+ # Update the hash with the relative file path to capture structure
+ rel_path = os.path.relpath(file_path, directory).replace(os.sep, '/')
+ hash_obj.update(rel_path.encode())
+
+ try:
+ # Read the entire content to confirm there’s no issue
+ with open(file_path, "rb") as f:
+ file_contents = f.read()
+
+ # If reading was successful, update the hash with file contents
+ hash_obj.update(file_contents)
+
+ except Exception as e:
+ continue
+
+ # Encode the final hash in base64 and return it
+ return base64.b64encode(hash_obj.digest()).decode(), sorted_files
+
+def get_model_content_hash(
+ model_id: str,
+ revision: str,
+ local_dir: str,
+ log_level: str,
+):
+ """
+ Downloads the model and computes the hash of its entire contents.
+
+ Args:
+ :param model_id: (str): The repository ID of the Hugging Face model (e.g., 'bert-base-uncased').
+ :param revision: (str): The specific branch, tag, or commit hash (default is 'main').
+ :param local_dir: (str): Local directory to download the model to.
+ :param log_level: (str): One of: INFO, INFOX, DEBUG, DEBUGX, TRACE, TRACEX.
+
+ Returns:
+ str: The combined hash of the model's contents.
+ """
+ # Remove all directory contents if it doesn't exist
+ try:
+ shutil.rmtree(local_dir)
+ except:
+ Utils.subnet_logger(
+ severity="TRACE",
+ message=f"Model directory already deleted: {local_dir}",
+ log_level=log_level
+ )
+
+ repo_url = f"https://huggingface.co/{model_id}"
+
+ # Download the model files for the specified revision
+ Repo.clone_from(repo_url, local_dir, branch=revision)
+
+ # Compute the hash of the model's contents
+ return get_directory_content_hash(directory=local_dir)
\ No newline at end of file
diff --git a/soundsright/base/neuron.py b/soundsright/base/neuron.py
new file mode 100644
index 0000000..563317b
--- /dev/null
+++ b/soundsright/base/neuron.py
@@ -0,0 +1,163 @@
+"""
+Module for SoundsRight subnet neurons.
+
+Neurons are the backbone of the subnet and are providing the subnet
+users tools to interact with the subnet and participate in the
+value-creation chain. There are two primary neuron classes: validator and miner.
+"""
+
+from argparse import ArgumentParser
+import os
+from datetime import datetime
+import bittensor as bt
+import numpy as np
+import pickle
+
+# Import custom modules
+import soundsright.base.utils as Utils
+import soundsright.base.data as Data
+
+def convert_data(data):
+ if isinstance(data, dict):
+ return {key: convert_data(value) for key, value in data.items()}
+ elif isinstance(data, list):
+ return [convert_data(item) for item in data]
+ elif isinstance(data, np.ndarray):
+ return data.item() if data.size == 1 else data.tolist()
+ elif isinstance(data, np.float32):
+ return float(data.item()) if data.size == 1 else data.tolist()
+ else:
+ return data
+
+class BaseNeuron:
+ """Base neuron class for the SoundsRight Subnet.
+
+ This class handles base operations for both the miner and validator.
+
+ Attributes:
+ parser:
+ Instance of ArgumentParser with the arguments given as
+ command-line arguments in the execution script
+ profile:
+ Instance of str depicting the profile for the neuron
+ """
+
+ def __init__(self, parser: ArgumentParser, profile: str) -> None:
+ self.parser = parser
+ self.path_hotkey = None
+ self.profile = profile
+ self.step = 0
+ self.last_updated_block = 0
+ self.subnet_version = Utils.config["module_version"]
+ self.score_version = Utils.config["score_version"]
+ self.base_path = os.path.join(os.path.expanduser('~'), ".SoundsRight")
+ self.cache_path = None
+ self.log_path = None
+ self.tts_path = None # Where clean TTS datasets are stored
+ self.noise_data_path = None # Where the noise dataset is stored
+ self.rir_data_path = None # Where the RIR dataset is stored
+ self.reverb_path = None # Where the TTS with reverb added is stored
+ self.noise_path = None # Where the TTS with noise added is stored
+ self.model_output_path = None # Where the model's outputs are stored
+ self.model_path = None # Where the model is stored
+ self.sgmse_path = None # Where the SGMSE+ model and its checkpoints will be stored
+ self.sgmse_output_path = None # Where the SGMSE+ model outputs will be stored
+ self.healthcheck_api = None
+ self.log_level = "INFO"
+
+ def config(self, bt_classes: list) -> bt.config:
+ """Applies neuron configuration.
+
+ This function attaches the configuration parameters to the
+ necessary bittensor classes and initializes the logging for the
+ neuron.
+
+ Args:
+ bt_classes:
+ A list of Bittensor classes the apply the configuration
+ to
+
+ Returns:
+ config:
+ An instance of Bittensor config class containing the
+ neuron configuration
+
+ Raises:
+ AttributeError:
+ An error occurred during the configuration process
+ OSError:
+ Unable to create a log path.
+
+ """
+ try:
+ for bt_class in bt_classes:
+ bt_class.add_args(self.parser)
+ except AttributeError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Unable to attach ArgumentParsers to Bittensor classes: {e}"
+ )
+ raise AttributeError from e
+
+ config = bt.config(self.parser)
+
+ # Construct log path
+ self.path_hotkey = config.wallet.hotkey
+ self.log_path = os.path.join(self.base_path, "logs", config.wallet.name, config.wallet.hotkey, str(config.netuid), self.profile)
+
+ # Construct cache path
+ self.cache_path = os.path.join(self.base_path, "cache", config.wallet.name, config.wallet.hotkey, str(config.netuid), self.profile, self.score_version)
+
+ # Construct data paths
+ self.noise_path = os.path.join(self.base_path, "data", "noise")
+ self.reverb_path = os.path.join(self.base_path, "data", "reverb")
+ self.rir_data_path = os.path.join(self.base_path, "data", "rir_data")
+ self.noise_data_path = os.path.join(self.base_path, "data", "noise_data")
+ self.tts_path = os.path.join(self.base_path, "data", "tts")
+ self.model_output_path = os.path.join(self.base_path, "models", "model_output")
+ self.model_path = os.path.join(self.base_path, "models", "model")
+ self.sgmse_path = os.path.join(self.base_path, "models", "sgmse")
+ self.sgmse_output_path = os.path.join(self.base_path, "models", "sgmse_output")
+ self.sgmse_ckpt_files = {
+ "DENOISING_16000HZ":"train_wsj0_2cta4cov_epoch=159.ckpt",
+ "DEREVERBERATION_16000HZ":"epoch=326-step=408750.ckpt",
+ }
+
+ # Create the OS paths if they do not exists
+ try:
+ for os_path in [self.log_path, self.cache_path, self.noise_path, self.reverb_path, self.rir_data_path, self.noise_data_path, self.tts_path, self.model_output_path, self.model_path, self.sgmse_path, self.sgmse_output_path]:
+ full_path = os.path.expanduser(os_path)
+ if not os.path.exists(full_path):
+ os.makedirs(full_path, exist_ok=True)
+
+ if os_path == self.log_path:
+ config.full_path = os.path.expanduser(os_path)
+ except OSError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Unable to create log path: {e}"
+ )
+ raise OSError from e
+
+ return config
+
+ def neuron_logger(self, severity: str, message: str):
+ """This method is a wrapper for the bt.logging function to add extra
+ functionality around the native logging capabilities"""
+
+ Utils.subnet_logger(severity=severity, message=message, log_level=self.log_level)
+
+ # Append extra information to to the logs if healthcheck API is enabled
+ if self.healthcheck_api and severity.upper() in ("SUCCESS", "ERROR", "WARNING"):
+
+ event_severity = severity.lower()
+
+ # Metric
+ self.healthcheck_api.append_metric(
+ metric_name=f"log_entries.{event_severity}", value=1
+ )
+
+ # Store event
+ self.healthcheck_api.add_event(
+ event_name=f"{event_severity}", event_data=message
+ )
\ No newline at end of file
diff --git a/soundsright/base/protocol.py b/soundsright/base/protocol.py
new file mode 100644
index 0000000..08f90ca
--- /dev/null
+++ b/soundsright/base/protocol.py
@@ -0,0 +1,42 @@
+import bittensor as bt
+import pydantic
+
+class Denoising_16kHz_Protocol(bt.Synapse):
+ """
+ This class is used for miners to report to validators
+ their model for the 16kHz denoising task competition.
+ """
+ data: dict | None = pydantic.Field(
+ default=None,
+ description = "HuggingFace model identfication",
+ )
+
+ subnet_version: int = pydantic.Field(
+ ...,
+ description="Subnet version provides information about the subnet version the Synapse creator is running at",
+ allow_mutation=False,
+ )
+
+ def deserialize(self) -> bt.Synapse:
+ """Deserialize the instance of the protocol"""
+ return self
+
+class Dereverberation_16kHz_Protocol(bt.Synapse):
+ """
+ This class is used for miners to report to validators
+ their model for the 16kHz denoising task competition.
+ """
+ data: dict | None = pydantic.Field(
+ default=None,
+ description = "HuggingFace model identfication",
+ )
+
+ subnet_version: int = pydantic.Field(
+ ...,
+ description="Subnet version provides information about the subnet version the Synapse creator is running at",
+ allow_mutation=False,
+ )
+
+ def deserialize(self) -> bt.Synapse:
+ """Deserialize the instance of the protocol"""
+ return self
\ No newline at end of file
diff --git a/soundsright/base/templates/__init__.py b/soundsright/base/templates/__init__.py
new file mode 100644
index 0000000..29e9ea2
--- /dev/null
+++ b/soundsright/base/templates/__init__.py
@@ -0,0 +1,3 @@
+from .topics import TOPICS
+
+from .emotions import EMOTIONS
\ No newline at end of file
diff --git a/soundsright/base/templates/emotions.py b/soundsright/base/templates/emotions.py
new file mode 100644
index 0000000..a2bbdc8
--- /dev/null
+++ b/soundsright/base/templates/emotions.py
@@ -0,0 +1,72 @@
+EMOTIONS = {
+ "happy",
+ "affectionate",
+ "fascinated",
+ "joyous",
+ "excited",
+ "exhilirated",
+ "overjoyed",
+ "amused",
+ "inquisitive",
+ "surprised",
+ "in awe",
+ "calm",
+ "empathetic",
+ "grateful",
+ "blissful",
+ "fulfilled",
+ "patient",
+ "peaceful",
+ "serene",
+ "trusting",
+ "delighted",
+ "inspired",
+ "passionate",
+ "playful",
+ "refreshed",
+ "ecstatic",
+ "engaged",
+ "enthusiastic",
+ "confident",
+ "brave",
+ "optimistic",
+ "agitated",
+ "aggravated",
+ "bitter",
+ "cynical",
+ "disdainful",
+ "edgy",
+ "disturbed",
+ "disgruntled",
+ "furious",
+ "grouchy",
+ "frustrated",
+ "hostile",
+ "irritated",
+ "irate",
+ "moody",
+ "on edge",
+ "outraged",
+ "upset",
+ "hopeless",
+ "gloomy",
+ "dissapointed",
+ "sorrowful",
+ "upset",
+ "bored",
+ "confused",
+ "distant",
+ "indifferent",
+ "ashamed",
+ "humiliated",
+ "mortified",
+ "afraid",
+ "nervous",
+ "anxious",
+ "apprehensive",
+ "scared",
+ "sensitive",
+ "regretful",
+ "remorseful",
+ "concerned",
+}
\ No newline at end of file
diff --git a/soundsright/base/templates/topics.py b/soundsright/base/templates/topics.py
new file mode 100644
index 0000000..fb8497b
--- /dev/null
+++ b/soundsright/base/templates/topics.py
@@ -0,0 +1,311 @@
+TOPICS = {
+ "art history",
+ "marine conservation",
+ "aviation history",
+ "mythical creatures",
+ "dairy farming",
+ "humanitarian aid",
+ "environmental science",
+ "gardening tools",
+ "world music",
+ "waste reduction",
+ "reptiles",
+ "sustainable agriculture",
+ "endangered species",
+ "asian culture",
+ "biodiversity",
+ "organic farming",
+ "canine behavior",
+ "latin american culture",
+ "traditional crafts",
+ "primatology",
+ "human rights",
+ "brewing history",
+ "capital punishment",
+ "natural disasters",
+ "pedagogy",
+ "pottery",
+ "glass art",
+ "monarchies",
+ "wildlife",
+ "burial practices",
+ "card games",
+ "mindfulness",
+ "green living",
+ "historic inventions",
+ "ocean life",
+ "botanical science",
+ "performing arts",
+ "gourmet cooking",
+ "wild canines",
+ "material science",
+ "playground equipment",
+ "societal issues",
+ "orchestral music",
+ "ornithology",
+ "ancient myths",
+ "fiber arts",
+ "soccer",
+ "logic puzzles",
+ "comparative religion",
+ "macroeconomics",
+ "ancient cultures",
+ "hammock making",
+ "cotton production",
+ "cartography",
+ "urban development",
+ "gemology",
+ "cultural studies",
+ "european history",
+ "bohemian lifestyle",
+ "civil engineering",
+ "spa culture",
+ "digital communications",
+ "mineral extraction",
+ "woodworking",
+ "broadcast media",
+ "human anatomy",
+ "equine science",
+ "mathematical theory",
+ "international sports",
+ "board games",
+ "world history",
+ "computational theory",
+ "lighthouse engineering",
+ "carpentry",
+ "match manufacturing",
+ "athletic training",
+ "airport logistics",
+ "clown performance",
+ "beach volleyball",
+ "spelunking",
+ "ceramic arts",
+ "swine farming",
+ "quantum mechanics",
+ "immunology",
+ "automotive technology",
+ "professional wrestling",
+ "romani culture",
+ "recycling systems",
+ "fossil fuel impact",
+ "joinery",
+ "dietary science",
+ "large mammals",
+ "automation",
+ "salt mining",
+ "language studies",
+ "railway systems",
+ "fashion design",
+ "earth sciences",
+ "poetic forms",
+ "sheep herding",
+ "children's playgrounds",
+ "barbering",
+ "internet security",
+ "mule breeding",
+ "baking",
+ "currency systems",
+ "stone masonry",
+ "gaming culture",
+ "entomology",
+ "penology",
+ "comic book art",
+ "magic tricks",
+ "nautical vessels",
+ "wind energy",
+ "organ music",
+ "bicycle sports",
+ "surf culture",
+ "circus arts",
+ "environmental protection",
+ "plant sciences",
+ "sugar production",
+ "rail transport",
+ "mass media",
+ "shoe repair",
+ "wildlife conservation",
+ "gorilla behavior",
+ "local food markets",
+ "competitive sports",
+ "theatrical performance",
+ "rhinoceros conservation",
+ "seismology",
+ "tropical fruits",
+ "shark biology",
+ "firefighting",
+ "craft brewing",
+ "cinematography",
+ "pharmaceuticals",
+ "amphibians",
+ "maternal health",
+ "apiology",
+ "wildlife preservation",
+ "animation techniques",
+ "coffee culture",
+ "fine arts",
+ "textile arts",
+ "skydiving",
+ "collectible cards",
+ "digital photography",
+ "documentary production",
+ "astronautics",
+ "canoeing",
+ "fencing",
+ "puppet theater",
+ "reptilian species",
+ "ice hockey",
+ "ballet dancing",
+ "legal studies",
+ "modern technology",
+ "root vegetables",
+ "volcanology",
+ "tobacco industry",
+ "weather forecasting",
+ "banana cultivation",
+ "psychological studies",
+ "mechanical engineering",
+ "creative inventions",
+ "rugby",
+ "wildlife biology",
+ "musical instruments",
+ "airship history",
+ "parachuting",
+ "tobacco products",
+ "acrobatics",
+ "medical imaging",
+ "textile manufacturing",
+ "electronic sports",
+ "cricket",
+ "ignition devices",
+ "leather crafting",
+ "journalism",
+ "vanilla farming",
+ "cue sports",
+ "maritime history",
+ "golf",
+ "leisure activities",
+ "marine ecosystems",
+ "social sciences",
+ "medical practices",
+ "cheese making",
+ "fortune telling",
+ "buddhist studies",
+ "martial arts",
+ "amphibian species",
+ "news delivery",
+ "disease research",
+ "pack animals",
+ "bread making",
+ "printing technology",
+ "weaving",
+ "dental care",
+ "boxing",
+ "indigenous cultures",
+ "aviation",
+ "sailing",
+ "fashion history",
+ "caprine animals",
+ "bison",
+ "primates",
+ "hunting practices",
+ "hospitality",
+ "chiropterology",
+ "bakery operations",
+ "nuclear power",
+ "association football",
+ "mountaineering",
+ "transport logistics",
+ "news reporting",
+ "philosophy",
+ "motorcycling",
+ "communication devices",
+ "military history",
+ "criminal behavior",
+ "legal systems",
+ "basketball",
+ "archery",
+ "criminology",
+ "ferry transportation",
+ "literature",
+ "funerary customs",
+ "pharmacy operations",
+ "political science",
+ "desert environments",
+ "public transit",
+ "crime prevention",
+ "brewing science",
+ "leatherworking",
+ "pharmacology",
+ "mountain climbing",
+ "political ideologies",
+ "mummification",
+ "stadium design",
+ "financial systems",
+ "armed forces",
+ "big cats",
+ "water management",
+ "mortality",
+ "hot air balloons",
+ "mountain ranges",
+ "world war i",
+ "leather tanning",
+ "chess strategy",
+ "culinary techniques",
+ "pork production",
+ "occupational studies",
+ "architectural design",
+ "literacy",
+ "dance styles",
+ "cinema history",
+ "insect study",
+ "yoga practice",
+ "carousel design",
+ "comedy performance",
+ "scouting",
+ "angling",
+ "medical research",
+ "forest ecology",
+ "elephant behavior",
+ "plant biology",
+ "historical figures",
+ "reptile studies",
+ "documentary filmmaking",
+ "coal mining",
+ "natural events",
+ "scientific research",
+ "baseball",
+ "metalwork",
+ "aerospace engineering",
+ "cycling",
+ "distillation",
+ "jazz music",
+ "african cultures",
+ "smoking cessation",
+ "seashells",
+ "horticulture",
+ "community service",
+ "numismatics",
+ "smoking pipes",
+ "veterinary science",
+ "astronomy",
+ "radio broadcasting",
+ "river ecosystems",
+ "rice cultivation",
+ "funerary practices",
+ "camel breeding",
+ "accordion music",
+ "climate science",
+ "gold mining",
+ "marine archaeology",
+ "zoology",
+ "building construction",
+ "renewable energy",
+ "blockchain",
+ "cryptocurrency",
+ "bitcoin",
+ "professional e-sports",
+ "professional cricket",
+ "youe favorite bars in the area",
+ "your favorite workout routine",
+ "your favorite restaurant in the area",
+}
\ No newline at end of file
diff --git a/soundsright/base/utils/__init__.py b/soundsright/base/utils/__init__.py
new file mode 100644
index 0000000..f56787b
--- /dev/null
+++ b/soundsright/base/utils/__init__.py
@@ -0,0 +1,26 @@
+from .logging import subnet_logger
+
+from .config import ModuleConfig
+
+config = ModuleConfig().get_full_config()
+
+from .healthcheck import HealthCheckAPI
+
+from .utils import (
+ timeout_decorator,
+ validate_uid,
+ validate_miner_response,
+ validate_model_benchmark,
+ sign_data
+)
+
+from .container import (
+ validate_container_config,
+ start_container,
+ check_container_status,
+ prepare,
+ upload_audio,
+ enhance_audio,
+ download_enhanced,
+ delete_container,
+)
\ No newline at end of file
diff --git a/soundsright/base/utils/config.py b/soundsright/base/utils/config.py
new file mode 100644
index 0000000..1277797
--- /dev/null
+++ b/soundsright/base/utils/config.py
@@ -0,0 +1,59 @@
+"""This module is responsble for managing the configuration parameters
+used by the soundsright module"""
+
+from os import environ
+from dotenv import load_dotenv
+
+load_dotenv()
+
+class ModuleConfig:
+ """This class is used to standardize the presentation of
+ configuration parameters used throughout the soundsright module"""
+
+ def __init__(self):
+
+ # Determine module code version
+ self.__version__ = "1.0.0"
+
+ # Determine the score version
+ self.__score_version__ = "1"
+
+ # Convert the version into a single integer
+ self.__version_split__ = self.__version__.split(".")
+ self.__spec_version__ = (
+ (1000 * int(self.__version_split__[0]))
+ + (10 * int(self.__version_split__[1]))
+ + (1 * int(self.__version_split__[2]))
+ )
+
+ # Initialize with default values
+ self.__config__ = {
+ "module_version": self.__spec_version__,
+ "score_version": self.__score_version__,
+ }
+
+ def get_full_config(self) -> dict:
+ """Returns the full configuration data"""
+ return self.__config__
+
+ def set_config(self, key, value) -> dict:
+ """Updates the configuration value of a particular key and
+ returns updated configuration"""
+
+ if key and value:
+ self.__config__[key] = value
+ elif key and isinstance(value, bool):
+ self.__config__[key] = value
+ else:
+ raise ValueError(f"Unable to set the value: {value} for key: {key}")
+ return self.get_full_config()
+
+ def get_config(self, key):
+ """Returns the configuration for a particular key"""
+
+ value = (self.get_full_config())[key]
+
+ if not value and not isinstance(value, bool):
+ raise ValueError(f"Unable to get the value: {value} for key: {key}")
+
+ return value
\ No newline at end of file
diff --git a/soundsright/base/utils/container.py b/soundsright/base/utils/container.py
new file mode 100644
index 0000000..9848a23
--- /dev/null
+++ b/soundsright/base/utils/container.py
@@ -0,0 +1,425 @@
+import os
+import yaml
+import subprocess
+import requests
+import zipfile
+import time
+import glob
+
+import soundsright.base.utils as Utils
+
+def check_dockerfile_for_root_user(dockerfile_path):
+ """
+ Checks if a Dockerfile configures the container to run as a root user,
+ considering ARG definitions for the user ID.
+
+ Args:
+ directory (str): The directory to search for a Dockerfile.
+ dockerfile_path (str): The specific path to the Dockerfile.
+
+ Returns:
+ bool: True if the Dockerfile configures the container to run as root, False otherwise.
+
+ Raises:
+ FileNotFoundError: If no Dockerfile is found at the specified path.
+ """
+ try:
+ user_line_exists = False
+ arg_definitions = {}
+ env_definitions = {}
+
+ with open(dockerfile_path, "r") as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.strip()
+
+ # Parse ARG directives
+ if line.startswith("ARG"):
+ parts = line.split("=")
+ if len(parts) == 2:
+ arg_name = parts[0].split()[1].strip()
+ arg_value = parts[1].strip()
+ arg_definitions[arg_name] = arg_value
+
+ if line.startswith("ENV"):
+ parts = line.split("=")
+ if len(parts) == 2:
+ env_name = parts[0].split()[1].strip()
+ env_value = parts[1].strip()
+ env_definitions[env_name] = env_value
+
+ # Parse USER directive
+ if line.startswith("USER"):
+ user_line_exists = True
+ user = line.split()[1]
+
+ # Resolve ARG references in the USER directive
+ if user.startswith("$"):
+ user = arg_definitions.get(user[1:], "0")
+ if user == "root" or str(user) == "0" or user.startswith("$"):
+ return True
+ user = env_definitions.get(user[1:], "0")
+ if user == "root" or str(user) == "0" or user.startswith("$"):
+ return True
+
+ # Check if the resolved user is root
+ if user == "root" or str(user) == "0":
+ return True
+
+ # Check for conflicts with specific UID (e.g., validator UID)
+ elif "10001" in line:
+ return True
+
+ # If no USER directive is found, the default is root
+ if not user_line_exists:
+ return True
+
+ except Exception as e:
+ return True # Default to True if an error occurs to err on the side of caution
+
+ return False # Returns False if no root configuration is detected
+
+def check_dockerfile_for_sensitive_config(dockerfile_path, sensitive_directories):
+ """
+ Finds a Dockerfile in the specified directory or its subdirectories and checks
+ if the `.bittensor` directory is mounted as a volume.
+
+ Args:
+ directory (str): The directory to search for a Dockerfile. Defaults to the current directory.
+
+ Returns:
+ bool: True if the `.bittensor` directory is mounted, False otherwise.
+
+ Raises:
+ FileNotFoundError: If no Dockerfile is found in the specified directory or its subdirectories.
+ """
+ try:
+ with open(dockerfile_path, "r") as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.strip()
+ if any(sensitive in line for sensitive in sensitive_directories):
+ return True
+
+ except Exception as e:
+ return True
+
+ # If no VOLUME directive mentions .bittensor, return False
+ return False
+
+def validate_container_config(directory) -> bool:
+ """
+ Makes sure that both the Dockerfile and docker-compose.yml files
+ do not run the container as root,
+ Args:
+ directory (str): Repository of docker container
+
+ Returns:
+ """
+ # Find dockerfile path
+ dockerfile_path = None
+
+ # Search for the Dockerfile in the directory and subdirectories
+ for root, _, files in os.walk(directory):
+ for file in files:
+ if file == "Dockerfile":
+ dockerfile_path = os.path.join(root, file)
+ break
+ if dockerfile_path:
+ break
+
+ if not dockerfile_path:
+ return False
+
+ # Find docker-compose.yml path
+ dockerfile_path = None
+
+ # Search for docker-compose.yml in the directory and its subdirectories
+ for root, _, files in os.walk(directory):
+ for file in files:
+ if file in ["docker-compose.yml", "docker-compose.yaml"]:
+ dockerfile_path = os.path.join(root, file)
+ break
+ if dockerfile_path:
+ break
+
+ if not dockerfile_path:
+ return True
+
+ # Define sensitive host directories to look for
+ sensitive_directories = [
+ "docker.sock",
+ "var",
+ "etc",
+ "proc",
+ "sys",
+ "dev",
+ "root",
+ "home",
+ "boot",
+ "lib",
+ "lib64",
+ "opt",
+ "mnt",
+ "media",
+ "proc",
+ ".bittensor",
+ "bittensor"
+ ]
+
+ if check_dockerfile_for_root_user(dockerfile_path):
+ return False
+
+ if check_dockerfile_for_sensitive_config(dockerfile_path, sensitive_directories):
+ return False
+
+ return True
+
+def start_container(directory, log_level) -> bool:
+ """Runs the container with docker compose
+
+ Args:
+ directory (str): Directory containing the container
+ """
+ dockerfile_path = None
+
+ # Search for docker-compose.yml in the directory and its subdirectories
+ for root, _, files in os.walk(directory):
+ for file in files:
+ if file == "Dockerfile":
+ dockerfile_path = os.path.join(root, file)
+ break
+ if dockerfile_path:
+ break
+
+ if not dockerfile_path:
+ return False
+
+ if not os.path.isfile(dockerfile_path):
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"No `Dockerfile` file found in the specified directory: {directory}",
+ log_level=log_level,
+ )
+ return False
+
+ try:
+ result0 = subprocess.run(["podman", "build", "-t", "modelapi", "--file", dockerfile_path], check=True)
+ if result0.returncode != 0:
+ return False
+ result1 = subprocess.run(["podman", "run", "-d", "--device", "nvidia.com/gpu=all", "--volume", "/usr/local/cuda-12.6:/usr/local/cuda-12.6", "--user", "10002:10002", "--name", "modelapi", "-p", "6500:6500", "modelapi"], check=True)
+ if result1.returncode != 0:
+ return False
+ return True
+
+ except subprocess.CalledProcessError as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Container could not be started due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Container could not be started due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+
+def check_container_status(log_level, timeout=5) -> bool:
+
+ url = f"http://127.0.0.1:6500/status/"
+ try:
+ start_time = int(time.time())
+ current_time = start_time
+ while start_time + 100 >= current_time:
+ try:
+ res = requests.get(url, timeout=timeout)
+ if res.status_code == 200:
+ data=res.json()
+ if "container_running" in data.keys() and data['container_running']:
+ return True
+ current_time = int(time.time())
+ except:
+ current_time = int(time.time())
+
+ return False
+
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Container status could not be determiend due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+
+def upload_audio(noisy_dir, log_level, timeout=500,) -> bool:
+ """
+ Upload audio files to the API.
+
+ Returns:
+ bool: True if operation was successful, False otherwise
+ """
+ url = f"http://127.0.0.1:6500/upload-audio/"
+
+ files = sorted(glob.glob(os.path.join(noisy_dir, "*.wav")))
+
+ try:
+ with requests.Session() as session:
+ file_payload = [
+ ("files", (os.path.basename(file), open(file, "rb"), "audio/wav"))
+ for file in files
+ ]
+
+ response = session.post(url, files=file_payload, timeout=timeout)
+
+ for _, file in file_payload:
+ file[1].close() # Ensure all files are closed after the request
+
+ response.raise_for_status()
+ data = response.json()
+
+ sorted_files = sorted([file[1][0] for file in file_payload])
+ sorted_response = sorted(data["uploaded_files"])
+ return sorted_files == sorted_response and data["status"]
+
+ except requests.RequestException as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Uploading audio to model failed because: {e}",
+ log_level=log_level
+ )
+ return False
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Uploading audio to model failed because: {e}",
+ log_level=log_level
+ )
+ return False
+
+def prepare(log_level, timeout=300) -> bool:
+
+ url = f"http://127.0.0.1:6500/prepare/"
+ try:
+ res = requests.post(url, timeout=timeout)
+ if res.status_code==200:
+ data = res.json()
+ return data['preparations']
+ return False
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Container model could not be prepared due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+
+def enhance_audio(log_level, timeout=3600) -> bool:
+ """
+ Trigger audio enhancement on the API.
+
+ Returns:
+ bool: True if enhancement was successful, False otherwise
+ """
+ url = f"http://127.0.0.1:6500/enhance/"
+
+ try:
+ response = requests.post(url, timeout=timeout)
+ response.raise_for_status()
+ data = response.json()
+ return data['status']
+ except requests.RequestException as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Audio could not be enhanced due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Audio could not be enhanced due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+
+def download_enhanced(enhanced_dir, log_level, timeout=500) -> bool:
+ """
+ Download the zip file containing enhanced audio files, extract its contents,
+ and remove the zip file.
+
+ Args:
+ enhanced_dir (str): Directory to save and extract the downloaded zip file.
+
+ Returns:
+ bool: True if successful, False otherwise.
+ """
+ url = "http://127.0.0.1:6500/download-enhanced/"
+ zip_file_path = os.path.join(enhanced_dir, "enhanced_audio_files.zip")
+
+ try:
+ # Download the ZIP file
+ response = requests.get(url, stream=True, timeout=timeout)
+ response.raise_for_status()
+
+ # Save the ZIP file to enhanced_dir
+ with open(zip_file_path, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ # Extract the ZIP file contents to enhanced_dir
+ with zipfile.ZipFile(zip_file_path, "r") as zip_file:
+ zip_file.extractall(enhanced_dir)
+
+ # Delete the ZIP file after extraction
+ os.remove(zip_file_path)
+
+ return True
+ except requests.RequestException as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Enhanced audio could not be downloaded due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Enhanced audio could not be downloaded due to error: {e}",
+ log_level=log_level,
+ )
+ return False
+
+def delete_container(log_level) -> bool:
+ """Deletes a specified Docker container by name or ID.
+
+ Returns:
+ bool: True if the container was successfully deleted, False otherwise.
+ """
+ try:
+ # Delete container
+ subprocess.run(
+ ["podman", "rm", "-f", "modelapi"],
+ check=True
+ )
+ # Remove all images
+ subprocess.run(
+ ["podman", "rmi", "-a", "-f"],
+ check=True,
+ )
+ # System prune
+ subprocess.run(
+ ["podman", "system", "prune", "-a", "-f"],
+ check=True,
+ )
+ return True
+
+ except Exception as e:
+ Utils.subnet_logger(
+ severity="ERROR",
+ message=f"Container deletion failed due to error: {e}",
+ log_level=log_level,
+ )
+ return False
\ No newline at end of file
diff --git a/soundsright/base/utils/healthcheck.py b/soundsright/base/utils/healthcheck.py
new file mode 100644
index 0000000..0de30dc
--- /dev/null
+++ b/soundsright/base/utils/healthcheck.py
@@ -0,0 +1,331 @@
+"""
+This module implements a health check API for the LLM Defender Subnet
+neurons. The purpose of the health check API is to provide key
+information about the health of the neuron to enable easier
+troubleshooting.
+
+It is highly recommended to connect the health check API into the
+monitoring tools used to monitor the server. The health metrics are not
+persistent and will be lost if neuron is restarted.
+
+Endpoints:
+ /healthcheck
+ Returns boolean depicting the health of the neuron based on the
+ health metrics
+ /healthcheck/metrics
+ Returns a dictionary of the metrics the health is derived from
+ /healthcheck/events
+ Returns list of relevant events related to the health metrics
+ (error and warning)
+
+Validator Endpoints:
+
+ /healthcheck/current_models
+ Returns information on models in current competition
+ /healthcheck/best_models
+ Returns information on best models from last competition
+ /healthcheck/competitions
+ Returns information on which competitions are currently being
+ hosted by the subnet
+ /healthcheck/competition_scores
+ Returns information about current competition-specific miner scores
+ /healthcheck/scores
+ Returns information about current overall miner scores
+
+Port and host can be controlled with --healthcheck_port and
+--healthcheck_host parameters.
+"""
+
+from fastapi import FastAPI
+from pydantic import BaseModel, ConfigDict
+from typing import Dict
+import datetime
+import uvicorn
+import threading
+import numpy as np
+
+class HealthCheckResponse(BaseModel):
+ status: bool
+ checks: Dict
+ timestamp: str
+
+class HealthCheckDataResponse(BaseModel):
+ data: Dict | None
+ timestamp: str
+
+class HealthCheckScoreResponse(BaseModel):
+ data: np.ndarray | None
+ timestamp: str
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+class HealthCheckAPI:
+ def __init__(self, host: str, port: int, is_validator: bool, current_models: dict = None, best_models: dict = None):
+
+ # Variables
+ self.host = host
+ self.port = port
+ self.is_validator = is_validator
+
+ self.current_models = current_models
+ self.best_models = best_models
+ self.competition_scores = None
+ self.scores = None
+
+ # Status variables
+ self.health_metrics = {
+ "start_time": datetime.datetime.now().timestamp(),
+ "neuron_running": False,
+ "iterations": 0,
+ "datasets_generatred":0,
+ "competitions_judged":0,
+ "log_entries.success": 0,
+ "log_entries.warning": 0,
+ "log_entries.error": 0,
+ "axons.total_filtered_axons": 0,
+ "axons.total_queried_axons": 0,
+ "axons.queries_per_second":0.0,
+ "responses.total_valid_responses": 0,
+ "responses.total_invalid_responses": 0,
+ "responses.valid_responses_per_second":0.0,
+ "responses.invalid_responses_per_second":0.0,
+ "weights.targets": 0,
+ "weights.last_set_timestamp": None,
+ "weights.last_committed_timestamp":None,
+ "weights.last_revealed_timestamp":None,
+ "weights.total_count_set":0,
+ "weights.total_count_committed":0,
+ "weights.total_count_revealed":0,
+ "weights.set_per_second":0.0,
+ "weights.committed_per_second":0.0,
+ "weights.revealed_per_second":0.0
+ }
+
+ self.health_events = {
+ "warning": [],
+ "error": [],
+ "success": []
+ }
+
+ # App
+ self.app = FastAPI()
+ self._setup_routes()
+
+ def _setup_routes(self):
+ self.app.add_api_route(
+ "/healthcheck",
+ self._healthcheck,
+ response_model=HealthCheckResponse
+ )
+ self.app.add_api_route(
+ "/healthcheck/metrics",
+ self._healthcheck_metrics,
+ response_model=HealthCheckDataResponse,
+ )
+ self.app.add_api_route(
+ "/healthcheck/events",
+ self._healthcheck_events,
+ response_model=HealthCheckDataResponse,
+ )
+ self.app.add_api_route(
+ "healthcheck/current_models",
+ self._healthcheck_current_models,
+ response_model=HealthCheckDataResponse,
+ )
+ self.app.add_api_route(
+ "healthcheck/best_models",
+ self._healthcheck_best_models,
+ response_model=HealthCheckDataResponse,
+ )
+ self.app.add_api_route(
+ "healthcheck/competitions",
+ self._healthcheck_competitions,
+ response_model=HealthCheckDataResponse,
+ )
+ self.app.add_api_route(
+ "healthcheck/competition_scores",
+ self._healthcheck_competition_scores,
+ response_model=HealthCheckDataResponse,
+ )
+ self.app.add_api_route(
+ "healthcheck/scores",
+ self._healthcheck_scores,
+ response_model=HealthCheckDataResponse,
+ )
+
+ def _healthcheck(self):
+ try:
+ # Update health status when the /healthcheck API is invoked
+ self.healthy, checks = self.get_health()
+
+ # Return status
+ return {"status": self.healthy, "checks": checks, "timestamp": str(datetime.datetime.now())}
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_metrics(self):
+ try:
+ # Return the metrics collected by the HealthCheckAPI
+ return {
+ "data": self.health_metrics,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_events(self):
+ try:
+ # Return the events collected by the HealthCheckAPI
+ return {
+ "data": self.health_events,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_current_models(self):
+ try:
+ return {
+ "data": self.current_models if self.current_models else None,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_best_models(self):
+ try:
+ return {
+ "data": self.best_models if self.best_models else None,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_competitions(self):
+ try:
+ competitions = [k for k in self.best_models]
+ return {
+ "data": competitions if competitions else None,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_competition_scores(self):
+ try:
+ return {
+ "data":self.competition_scores,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def _healthcheck_scores(self):
+ try:
+ return {
+ "data":self.scores,
+ "timestamp": str(datetime.datetime.now()),
+ }
+ except Exception:
+ return {"status": False, "timestamp": str(datetime.datetime.now())}
+
+ def run(self):
+ """This method runs the HealthCheckAPI"""
+ threading.Thread(
+ target=uvicorn.run,
+ args=(self.app,),
+ kwargs={"host": self.host, "port": self.port},
+ daemon=True,
+ ).start()
+
+ def add_event(self, event_name: str, event_data: str) -> bool:
+ """This method adds an event to self.health_events dictionary"""
+ if isinstance(event_name, str) and event_name.upper() in (
+ "SUCCESS",
+ "ERROR",
+ "WARNING",
+ ):
+
+ # Append the received event under the correct key if it is str
+ if isinstance(event_data, str) and not isinstance(event_data, bool):
+ event_severity = event_name.lower()
+ self.health_events[event_severity].append(
+ {"timestamp": str(datetime.datetime.now()), "message": event_data}
+ )
+
+ # Reduce the number of events if more than 250
+ if len(self.health_events[event_severity]) > 250:
+ self.health_events[event_severity] = self.health_events[
+ event_severity
+ ][-250:]
+
+ return True
+
+ return True
+
+ def append_metric(self, metric_name: str, value: int | bool) -> bool:
+ """This method increases the metric counter by the value defined
+ in the counter. If the counter is bool, sets the metric value to
+ the provided value. This function must be executed whenever the
+ counters for the given metrics wants to be updated"""
+
+ if metric_name in self.health_metrics.keys() and value > 0:
+ if isinstance(value, bool):
+ self.health_metrics[metric_name] = value
+ else:
+ self.health_metrics[metric_name] += value
+ else:
+ return False
+
+ return True
+
+ def update_metric(self, metric_name: str, value: str | int | float):
+ """This method updates a value for a metric that renews every iteration."""
+ if metric_name in self.health_metrics.keys():
+ self.health_metrics[metric_name] = value
+ return True
+ else:
+ return False
+
+ def update_rates(self):
+ """This method updates the rate-based parameters within the
+ healthcheck API--prompts generated per second, axons queried per
+ second, valid responses per second and invalid responses per second."""
+
+ time_passed = datetime.datetime.now().timestamp() - self.health_metrics['start_time']
+
+ if time_passed > 0:
+
+ # Calculate queries per second
+ self.health_metrics['axons.queries_per_second'] = self.health_metrics['axons.total_queried_axons'] / time_passed
+
+ # Calculate valid responses per second
+ self.health_metrics['responses.valid_responses_per_second'] = self.health_metrics['responses.total_valid_responses'] / time_passed
+
+ # Calculate invalid responses per second
+ self.health_metrics['responses.invalid_responses_per_second'] = self.health_metrics['responses.total_invalid_responses'] / time_passed
+
+ # Calculate weight set events per second
+ self.health_metrics['weights.set_per_second'] = self.health_metrics['weights.total_count_set'] / time_passed
+
+ # Calculate weight commit events per second
+ self.health_metrics['weights.committed_per_second'] = self.health_metrics["weights.total_count_committed"] / time_passed
+
+ # Calculate weight reveal events per second
+ self.health_metrics['weights.revealed_per_second'] = self.health_metrics["weights.total_count_revealed"] / time_passed
+
+ return True
+
+ else:
+ return False
+
+ def update_current_models(self, current_models):
+ self.current_models = current_models
+
+ def update_best_models(self, best_models):
+ self.best_models = best_models
+
+ def update_competition_scores(self, competition_scores):
+ self.competition_scores = competition_scores
+
+ def update_scores(self, scores):
+ self.scores = scores
\ No newline at end of file
diff --git a/soundsright/base/utils/logging.py b/soundsright/base/utils/logging.py
new file mode 100644
index 0000000..372ecc7
--- /dev/null
+++ b/soundsright/base/utils/logging.py
@@ -0,0 +1,63 @@
+import bittensor as bt
+
+def subnet_logger(severity: str, message: str, log_level: str):
+ """This method is a wrapper for the bt.logging function to add extra
+ functionality around the native logging capabilities. This method is
+ used together with the neuron_logger() method."""
+
+ if (isinstance(severity, str) and not isinstance(severity, bool)) and (
+ isinstance(message, str) and not isinstance(message, bool) and (isinstance(log_level, str) and not isinstance(log_level, bool))
+ ):
+ # Do mapping of custom log levels
+ log_levels = {
+ "INFO": 0,
+ "INFOX": 1,
+ "DEBUG": 2,
+ "DEBUGX": 3,
+ "TRACE": 4,
+ "TRACEX": 5
+ }
+
+ bittensor_severities = {
+ "SUCCESS": "SUCCESS",
+ "WARNING": "WARNING",
+ "ERROR": "ERROR",
+ "INFO": "INFO",
+ "INFOX": "INFO",
+ "DEBUG": "DEBUG",
+ "DEBUGX": "DEBUG",
+ "TRACE": "TRACE",
+ "TRACEX": "TRACE"
+ }
+
+ severity_emoji = {
+ "SUCCESS": chr(0x2705),
+ "ERROR": chr(0x274C),
+ "WARNING": chr(0x1F6A8),
+ "INFO": chr(0x1F4A1),
+ "DEBUG": chr(0x1F527),
+ "TRACE": chr(0x1F50D),
+ }
+
+ # Use utils.subnet_logger() to write the logs
+ if severity.upper() in ("SUCCESS", "ERROR", "WARNING") or log_levels[log_level] >= log_levels[severity.upper()]:
+
+ general_severity=bittensor_severities[severity.upper()]
+
+ if general_severity.upper() == "SUCCESS":
+ bt.logging.success(msg=message, prefix=severity_emoji["SUCCESS"])
+
+ elif general_severity.upper() == "ERROR":
+ bt.logging.error(msg=message, prefix=severity_emoji["ERROR"])
+
+ elif general_severity.upper() == "WARNING":
+ bt.logging.warning(msg=message, prefix=severity_emoji["WARNING"])
+
+ elif general_severity.upper() == "INFO":
+ bt.logging.info(msg=message, prefix=severity_emoji["INFO"])
+
+ elif general_severity.upper() == "DEBUG":
+ bt.logging.debug(msg=message, prefix=severity_emoji["DEBUG"])
+
+ if general_severity.upper() == "TRACE":
+ bt.logging.trace(msg=message, prefix=severity_emoji["TRACE"])
\ No newline at end of file
diff --git a/soundsright/base/utils/utils.py b/soundsright/base/utils/utils.py
new file mode 100644
index 0000000..b9dbcac
--- /dev/null
+++ b/soundsright/base/utils/utils.py
@@ -0,0 +1,118 @@
+import asyncio
+import bittensor as bt
+
+def timeout_decorator(timeout):
+ """
+ Uses asyncio to create an arbitrary timeout for an asynchronous
+ function call. This function is used for ensuring a stuck function
+ call does not block the execution indefinitely.
+
+ Inputs:
+ timeout:
+ The amount of seconds to allow the function call to run
+ before timing out the execution.
+
+ Returns:
+ decorator:
+ A function instance which itself contains an asynchronous
+ wrapper().
+
+ Raises:
+ TimeoutError:
+ Function call has timed out.
+ """
+
+ def decorator(func):
+ async def wrapper(*args, **kwargs):
+ try:
+ # Schedule execution of the coroutine with a timeout
+ return await asyncio.wait_for(func(*args, **kwargs), timeout)
+ except asyncio.TimeoutError:
+ # Raise a TimeoutError with a message indicating which function timed out
+ raise TimeoutError(
+ f"Function '{func.__name__}' execution timed out after {timeout} seconds."
+ )
+
+ return wrapper
+
+ return decorator
+
+def validate_uid(uid):
+ """
+ This method makes sure that a uid is an int instance between 0 and
+ 255. It also makes sure that boolean inputs are filtered out as
+ non-valid uid's.
+
+ Arguments:
+ uid:
+ A unique user id that we are checking to make sure is valid.
+ (integer between 0 and 255).
+
+ Returns:
+ True:
+ uid is valid--it is an integer between 0 and 255, True and
+ False excluded.
+ False:
+ uid is NOT valid.
+ """
+ # uid must be an integer instance between 0 and 255
+ if not isinstance(uid, int) or isinstance(uid, bool):
+ return False
+ if uid < 0 or uid > 255:
+ return False
+ return True
+
+def validate_miner_response(response):
+
+ validation_dict = {
+ 'hf_model_namespace':str,
+ 'hf_model_name':str,
+ 'hf_model_revision':str,
+ }
+
+ for k in response.keys():
+ if not isinstance(response[k], validation_dict[k]) or k not in validation_dict.keys():
+ return False
+
+ return True
+
+def validate_model_benchmark(model_benchmark):
+
+ validation_dict = {
+ 'hf_model_namespace':str,
+ 'hf_model_name':str,
+ 'hf_model_revision':str,
+ 'hotkey':str,
+ 'model_hash':str,
+ 'block':int,
+ 'metrics':dict,
+ }
+
+ for k in model_benchmark.keys():
+ if k not in validation_dict.keys() or not isinstance(model_benchmark[k], validation_dict[k]):
+ return False
+
+ return True
+
+def sign_data(hotkey: bt.Keypair, data: str) -> str:
+ """Signs the given data with the wallet hotkey
+
+ Arguments:
+ wallet:
+ The wallet used to sign the Data
+ data:
+ Data to be signed
+
+ Returns:
+ signature:
+ Signature of the key signing for the data
+ """
+ try:
+ signature = hotkey.sign(data.encode()).hex()
+ return signature
+ except TypeError as e:
+ bt.logging.error(f'Unable to sign data: {data} with wallet hotkey: {hotkey.ss58_address} due to error: {e}')
+ raise TypeError from e
+ except AttributeError as e:
+ bt.logging.error(f'Unable to sign data: {data} with wallet hotkey: {hotkey.ss58_address} due to error: {e}')
+ raise AttributeError from e
\ No newline at end of file
diff --git a/soundsright/core/__init__.py b/soundsright/core/__init__.py
new file mode 100644
index 0000000..05ecc4c
--- /dev/null
+++ b/soundsright/core/__init__.py
@@ -0,0 +1,3 @@
+from .validator.validator import SubnetValidator
+
+from .miner.miner import SubnetMiner
\ No newline at end of file
diff --git a/soundsright/core/miner/__init__.py b/soundsright/core/miner/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/soundsright/core/miner/miner.py b/soundsright/core/miner/miner.py
new file mode 100644
index 0000000..d6582fb
--- /dev/null
+++ b/soundsright/core/miner/miner.py
@@ -0,0 +1,585 @@
+from argparse import ArgumentParser
+from typing import Tuple
+import sys
+import bittensor as bt
+import hashlib
+import json
+import asyncio
+import os
+import traceback
+import time
+from dotenv import load_dotenv
+load_dotenv()
+
+# Import custom modules
+import soundsright.base as Base
+import soundsright.base.utils as Utils
+import soundsright.base.models as Models
+
+class SubnetMiner(Base.BaseNeuron):
+ """SubnetMiner class for SoundsRight Subnet"""
+
+ def __init__(self, parser: ArgumentParser):
+ """
+ Initializes the SubnetMiner class with attributes
+ neuron_config, model, tokenizer, wallet, subtensor, metagraph,
+ miner_uid
+
+ Arguments:
+ parser:
+ An ArgumentParser instance.
+
+ Returns:
+ None
+ """
+ super().__init__(parser=parser, profile="miner")
+
+ self.neuron_config = self.config(
+ bt_classes=[bt.subtensor, bt.logging, bt.wallet, bt.axon]
+ )
+
+ # Read command line arguments and perform actions based on them
+ args = parser.parse_args()
+ self.log_level = args.log_level
+
+ # Setup logging
+ bt.logging(config=self.neuron_config, logging_dir=self.neuron_config.full_path)
+ if args.log_level in ("DEBUG", "DEBUGX"):
+ bt.logging.enable_debug()
+ elif args.log_level in ("TRACE", "TRACEX"):
+ bt.logging.enable_trace()
+ else:
+ bt.logging.enable_default()
+
+ # Healthcheck API
+ self.healthcheck_api = Utils.HealthCheckAPI(
+ host=args.healthcheck_host, port=args.healthcheck_port, is_validator = False
+ )
+
+ # Run healthcheck API
+ self.healthcheck_api.run()
+
+ self.validator_min_stake = args.validator_min_stake
+
+ self.wallet, self.subtensor, self.metagraph, self.miner_uid = self.setup()
+
+ self.hotkey = self.wallet.hotkey.ss58_address
+
+ self.metadata_handler = Models.ModelMetadataHandler(
+ subtensor=self.subtensor,
+ subnet_netuid=self.neuron_config.netuid,
+ log_level=self.log_level,
+ wallet=self.wallet,
+ )
+
+ self.validator_stats = {}
+
+ self.miner_model_data = None
+
+ def save_state(self):
+ """Save miner state to models.json file
+ """
+ self.neuron_logger(
+ severity="INFO",
+ message="Saving miner state."
+ )
+
+ filename = os.path.join(self.cache_path, "models.json")
+
+ with open(filename,"w") as json_file:
+ json.dump(self.miner_model_data, json_file)
+
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Saved the following state to file: {filename} models: {self.miner_model_data}"
+ )
+
+ def load_state(self):
+ """Load miner state from models.json file if it exists
+ """
+ filename = os.path.join(self.cache_path, "models.json")
+
+ # If save file exists:
+ if os.path.exists(filename):
+ # Load save file data for miner models
+ with open(filename, "r") as json_file:
+ self.miner_model_data = json.load(json_file)
+
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Loaded the following state from file: {filename} models: {self.miner_model_data}"
+ )
+
+ # Otherwise start with a blank canvas and load from .env
+ else:
+ self.miner_model_data = {
+ "DENOISING_16000HZ":None,
+ "DEREVERBERATION_16000HZ":None,
+ }
+
+ def update_miner_model_data(self):
+ """Updates miner's models with new model data"""
+ # Model counter (this cannot be more than 1 or it will cause an error)
+ model_counter = 0
+
+ # New miner model data, used as reference with existing model data to see if chain needs to be updated
+ new_miner_model_data = {}
+
+ # Iterate through competitions
+ for sample_rate in ["16000HZ"]:
+ for task in ["DENOISING","DEREVERBERATION"]:
+ # Get .env params
+ namespace = os.getenv(f"{task}_{sample_rate}_HF_MODEL_NAMESPACE")
+ name = os.getenv(f"{task}_{sample_rate}_HF_MODEL_NAME")
+ revision = os.getenv(f"{task}_{sample_rate}_HF_MODEL_REVISION")
+
+ # If model is specified for this competition
+ if namespace and name and revision:
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Found specified model for competition: huggingface.co/{namespace}/{name}/{revision}"
+ )
+
+ # Update new model data dict with information
+ new_miner_model_data[f'{task}_{sample_rate}'] = {
+ 'hf_model_namespace':namespace,
+ 'hf_model_name':name,
+ 'hf_model_revision':revision,
+ }
+ # Add 1 to model counter
+ model_counter+=1
+
+ # Set to None if no model data providd
+ else:
+ new_miner_model_data[f'{task}_{sample_rate}'] = None
+
+ # Exit miner if model data for more than one competition detected
+ if model_counter > 1:
+ self.neuron_logger(
+ severity="ERROR",
+ message="Model data for multiple tasks and/or sample rates detected. Please register a new miner for each new task or sample rate you want to partake in. Exiting miner."
+ )
+ sys.exit()
+
+ upload_outcome = False
+
+ # Iterate through competitions to see if metadata has to be updated
+ for competition in new_miner_model_data.keys():
+
+ # Check that there is new model data and that it differs from old model data loaded from state
+ if new_miner_model_data[competition] and competition in self.miner_model_data.keys() and new_miner_model_data[competition] != self.miner_model_data[competition]:
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Uploading model metadata to chain: {new_miner_model_data[competition]}"
+ )
+
+ # Obtain competition id
+ competition_id = self.metadata_handler.get_competition_id_from_competition_name(competition)
+
+ # Get string of un-hashed metadata
+ unhashed_metadata = f"{new_miner_model_data[competition]['hf_model_namespace']}:{new_miner_model_data[competition]['hf_model_name']}:{new_miner_model_data[competition]['hf_model_revision']}:{self.hotkey}:{competition_id}"
+
+ # Hash it
+ metadata = hashlib.sha256(unhashed_metadata.encode()).hexdigest()
+
+ # Upload to chain
+ upload_outcome = asyncio.run(self.metadata_handler.upload_model_metadata_to_chain(metadata=metadata))
+
+ # Case that we had to upload metadata to chain for new model data and it was successful
+ if upload_outcome:
+ # Update miner model data so it can be saved
+ self.miner_model_data = new_miner_model_data
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"New model data has been uploaded to chain: {self.miner_model_data}. Sleeping for 60 seconds before starting miner operations."
+ )
+
+ # Sleep for a minute to guarantee that model metadata is uploaded to chain before the miner responds to validators
+ time.sleep(60)
+
+ # Case that we did not have to upload metadata to chain, or upload was not successful. In either case we default to model information saved to state
+ else:
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Loaded miner model data from state: {self.miner_model_data} with no upload to chain necessary."
+ )
+
+ def _update_validator_stats(self, hotkey, stat_type):
+ """Helper function to update the validator stats"""
+ if hotkey in self.validator_stats:
+ if stat_type in self.validator_stats[hotkey]:
+ self.validator_stats[hotkey][stat_type] += 1
+ else:
+ self.validator_stats[hotkey][stat_type] = 1
+ else:
+ self.validator_stats[hotkey] = {}
+ self.validator_stats[hotkey][stat_type] = 1
+
+ def setup(self) -> Tuple[bt.wallet, bt.subtensor, bt.metagraph, str]:
+ """This function setups the neuron.
+
+ The setup function initializes the neuron by registering the
+ configuration.
+
+ Arguments:
+ None
+
+ Returns:
+ wallet:
+ An instance of bittensor.wallet containing information about
+ the wallet
+ subtensor:
+ An instance of bittensor.subtensor
+ metagraph:
+ An instance of bittensor.metagraph
+ miner_uid:
+ An instance of int consisting of the miner UID
+
+ Raises:
+ AttributeError:
+ The AttributeError is raised if wallet, subtensor & metagraph cannot be logged.
+ """
+ bt.logging(config=self.neuron_config, logging_dir=self.neuron_config.full_path)
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Initializing miner for subnet: {self.neuron_config.netuid} on network: {self.neuron_config.subtensor.chain_endpoint} with config:\n {self.neuron_config}"
+ )
+
+ # Setup the bittensor objects
+ try:
+ wallet = bt.wallet(config=self.neuron_config)
+ subtensor = bt.subtensor(config=self.neuron_config)
+ metagraph = subtensor.metagraph(self.neuron_config.netuid)
+ except AttributeError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Unable to setup bittensor objects: {e}"
+ )
+ sys.exit()
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Bittensor objects initialized:\nMetagraph: {metagraph}\
+ \nSubtensor: {subtensor}\nWallet: {wallet}"
+ )
+
+ # Validate that our hotkey can be found from metagraph
+ if wallet.hotkey.ss58_address not in metagraph.hotkeys:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Your miner: {wallet} is not registered to chain connection: {subtensor}. Run btcli register and try again"
+ )
+ sys.exit()
+
+ # Get the unique identity (UID) from the network
+ miner_uid = metagraph.hotkeys.index(wallet.hotkey.ss58_address)
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Miner is running with UID: {miner_uid}"
+ )
+
+ return wallet, subtensor, metagraph, miner_uid
+
+ def check_whitelist(self, hotkey):
+ """
+ Checks if a given validator hotkey has been whitelisted.
+
+ Arguments:
+ hotkey:
+ A str instance depicting a hotkey.
+
+ Returns:
+ True:
+ True is returned if the hotkey is whitelisted.
+ False:
+ False is returned if the hotkey is not whitelisted.
+ """
+
+ if isinstance(hotkey, bool) or not isinstance(hotkey, str):
+ return False
+
+ whitelisted_hotkeys = [
+ "5G4gJgvAJCRS6ReaH9QxTCvXAuc4ho5fuobR7CMcHs4PRbbX", # sn14 dev team test validator
+ ]
+
+ if hotkey in whitelisted_hotkeys:
+ return True
+
+ return False
+
+ def blacklist_fn(self, synapse: Base.Denoising_16kHz_Protocol | Base.Dereverberation_16kHz_Protocol) -> Tuple[bool, str]:
+ """
+ This function is executed before the synapse data has been
+ deserialized.
+
+ On a practical level this means that whatever blacklisting
+ operations we want to perform, it must be done based on the
+ request headers or other data that can be retrieved outside of
+ the request data.
+
+ As it currently stands, we want to blacklist requests that are
+ not originating from valid validators. This includes:
+ - unregistered hotkeys
+ - entities which are not validators
+ - entities with insufficient stake
+
+ Returns:
+ [True, ""] for blacklisted requests where the reason for
+ blacklisting is contained in the quotes.
+ [False, ""] for non-blacklisted requests, where the quotes
+ contain a formatted string (f"Hotkey {synapse.dendrite.hotkey}
+ has insufficient stake: {stake}",)
+ """
+
+ # Check whitelisted hotkeys (queries should always be allowed)
+ if self.check_whitelist(hotkey=synapse.dendrite.hotkey):
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Accepted whitelisted hotkey: {synapse.dendrite.hotkey})"
+ )
+ return (False, f"Accepted whitelisted hotkey: {synapse.dendrite.hotkey}")
+
+ # Blacklist entities that have not registered their hotkey
+ if synapse.dendrite.hotkey not in self.metagraph.hotkeys:
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Blacklisted unknown hotkey: {synapse.dendrite.hotkey}"
+ )
+ return (
+ True,
+ f"Hotkey {synapse.dendrite.hotkey} was not found from metagraph.hotkeys",
+ )
+
+ # Blacklist entities that are not validators
+ uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey)
+ if not self.metagraph.validator_permit[uid]:
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Blacklisted non-validator: {synapse.dendrite.hotkey}"
+ )
+ return (True, f"Hotkey {synapse.dendrite.hotkey} is not a validator")
+
+ # Blacklist entities that have insufficient stake
+ stake = float(self.metagraph.S[uid])
+ if stake < self.validator_min_stake:
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Blacklisted validator {synapse.dendrite.hotkey} with insufficient stake: {stake}"
+ )
+ return (
+ True,
+ f"Hotkey {synapse.dendrite.hotkey} has insufficient stake: {stake}",
+ )
+
+ # Allow all other entities
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Accepted hotkey: {synapse.dendrite.hotkey} (UID: {uid} - Stake: {stake})"
+ )
+ return (False, f"Accepted hotkey: {synapse.dendrite.hotkey}")
+
+ def blacklist_16kHz_denoising(self, synapse: Base.Denoising_16kHz_Protocol) -> Tuple[bool, str]:
+ """Wrapper for the blacklist function to avoid repetition in code"""
+ return self.blacklist_fn(synapse=synapse)
+
+ def blacklist_16kHz_dereverberation(self, synapse: Base.Dereverberation_16kHz_Protocol) -> Tuple[bool, str]:
+ """Wrapper for the blacklist function to avoid repetition in code"""
+ return self.blacklist_fn(synapse=synapse)
+
+ def priority_fn(self, synapse: Base.Denoising_16kHz_Protocol | Base.Dereverberation_16kHz_Protocol) -> float:
+ """
+ This function defines the priority based on which the validators
+ are selected. Higher priority value means the input from the
+ validator is processed faster.
+ """
+
+ # Prioritize whitelisted validators
+ if self.check_whitelist(hotkey=synapse.dendrite.hotkey):
+ return 10000000.0
+
+ # Otherwise prioritize validators based on their stake
+ uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey)
+ stake = float(self.metagraph.S[uid])
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Prioritized: {synapse.dendrite.hotkey} (UID: {uid} - Stake: {stake})"
+ )
+
+ return stake
+
+ def priority_16kHz_denoising(self, synapse: Base.Denoising_16kHz_Protocol) -> float:
+ """Wrapper for the priority function to avoid repetition in code"""
+ return self.priority_fn(synapse=synapse)
+
+ def priority_16kHz_dereverberation(self, synapse: Base.Dereverberation_16kHz_Protocol) -> float:
+ """Wrapper for the priority function to avoid repetition in code"""
+ return self.priority_fn(synapse=synapse)
+
+ def forward(self, synapse: Base.Denoising_16kHz_Protocol | Base.Dereverberation_16kHz_Protocol, competition: str) -> Base.Denoising_16kHz_Protocol | Base.Dereverberation_16kHz_Protocol:
+ """This function responds to validators with the miner's model data"""
+
+ hotkey = synapse.dendrite.hotkey
+
+ self._update_validator_stats(hotkey, f"received_{competition}_competition_synapse_count")
+
+ # Print version information and perform version checks
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Synapse version: {synapse.subnet_version}, our version: {self.subnet_version}"
+ )
+ if synapse.subnet_version > self.subnet_version:
+ self.neuron_logger(
+ severity="WARNING",
+ message=f"Received a synapse from a validator with higher subnet version ({synapse.subnet_version}) than yours ({self.subnet_version}). Please update the miner."
+ )
+
+ # Set data output (None is returned if no model data is provided since it is a default in the init)
+ synapse.data = self.miner_model_data[competition]
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Processed synapse from validator: {hotkey} for competition: {competition}"
+ )
+
+ self._update_validator_stats(hotkey, f"processed_{competition}_competition_synapse_count")
+
+ return synapse
+
+ def forward_16kHz_denoising(self, synapse: Base.Denoising_16kHz_Protocol) -> Base.Denoising_16kHz_Protocol:
+ """Wrapper for the forward function to avoid repetition in code"""
+ return self.forward(synapse=synapse, competition='DENOISING_16000HZ')
+
+ def forward_16kHz_dereverberation(self, synapse: Base.Dereverberation_16kHz_Protocol) -> Base.Dereverberation_16kHz_Protocol:
+ """Wrapper for the forward function to avoid repetition in code"""
+ return self.forward(synapse=synapse, competition='DEREVERBERATION_16000HZ')
+
+ def run(self):
+
+ # Load existing model data or start with a blank slate
+ self.load_state()
+
+ # Update known miner model data with .env params
+ self.update_miner_model_data()
+
+ # Save this updated state
+ self.save_state()
+
+ # Link the miner to the Axon
+ axon = bt.axon(wallet=self.wallet, config=self.neuron_config)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Linked miner to Axon: {axon}"
+ )
+
+ # Attach the miner functions to the Axon
+ axon.attach(
+ forward_fn=self.forward_16kHz_denoising,
+ blacklist_fn=self.blacklist_16kHz_denoising,
+ priority_fn=self.priority_16kHz_denoising,
+ ).attach(
+ forward_fn=self.forward_16kHz_dereverberation,
+ blacklist_fn=self.blacklist_16kHz_dereverberation,
+ priority_fn=self.priority_16kHz_dereverberation
+ )
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Attached functions to Axon: {axon}"
+ )
+
+ # Pass the Axon information to the network
+ axon.serve(netuid=self.neuron_config.netuid, subtensor=self.subtensor)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Axon served on network: {self.neuron_config.subtensor.chain_endpoint} with netuid: {self.neuron_config.netuid}"
+ )
+ # Activate the Miner on the network
+ axon.start()
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Axon started on port: {self.neuron_config.axon.port}"
+ )
+
+ # This loop maintains the miner's operations until intentionally stopped.
+ self.neuron_logger(
+ severity="INFO",
+ message="Miner has been initialized and we are connected to the network. Start main loop."
+ )
+
+ # Get module version
+ version = Utils.config["module_version"]
+
+ # When we init, set last_updated_block to current_block
+ self.last_updated_block = self.subtensor.get_current_block()
+
+ self.healthcheck_api.append_metric(metric_name="neuron_running", value=True)
+
+ while True:
+ try:
+ # Below: Periodically update our knowledge of the network graph.
+ if self.step % 600 == 0:
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Syncing metagraph: {self.metagraph} with subtensor: {self.subtensor}"
+ )
+
+ self.metagraph.sync(subtensor=self.subtensor)
+
+ # Check registration status
+ if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys:
+ self.neuron_logger(
+ severity="SUCCESS",
+ message=f"Hotkey is not registered on metagraph: {self.wallet.hotkey.ss58_address}."
+ )
+
+ if self.step % 60 == 0:
+ self.metagraph = self.subtensor.metagraph(self.neuron_config.netuid)
+ log = (
+ f"Version:{version} | "
+ f"Step:{self.step} | "
+ f"Block:{self.metagraph.block.item()} | "
+ f"Stake:{self.metagraph.S[self.miner_uid]} | "
+ f"Rank:{self.metagraph.R[self.miner_uid]} | "
+ f"Trust:{self.metagraph.T[self.miner_uid]} | "
+ f"Consensus:{self.metagraph.C[self.miner_uid] } | "
+ f"Incentive:{self.metagraph.I[self.miner_uid]} | "
+ f"Emission:{self.metagraph.E[self.miner_uid]}"
+ )
+
+ self.neuron_logger(
+ severity="INFO",
+ message=log
+ )
+
+ # Print validator stats
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Validator stats: {self.validator_stats}"
+ )
+
+ self.step += 1
+ time.sleep(1)
+
+ # If someone intentionally stops the miner, it'll safely terminate operations.
+ except KeyboardInterrupt:
+ axon.stop()
+ self.neuron_logger(
+ severity="SUCCESS",
+ message="Miner killed by keyboard interrupt."
+ )
+ self.healthcheck_api.append_metric(metric_name="neuron_running", value=False)
+ break
+ # In case of unforeseen errors, the miner will log the error and continue operations.
+ except Exception:
+ self.neuron_logger(
+ severity="SUCCESS",
+ message=traceback.format_exc()
+ )
+ continue
\ No newline at end of file
diff --git a/soundsright/core/validator/__init__.py b/soundsright/core/validator/__init__.py
new file mode 100644
index 0000000..cd98610
--- /dev/null
+++ b/soundsright/core/validator/__init__.py
@@ -0,0 +1 @@
+from .validator import SubnetValidator
\ No newline at end of file
diff --git a/soundsright/core/validator/validator.py b/soundsright/core/validator/validator.py
new file mode 100644
index 0000000..b339409
--- /dev/null
+++ b/soundsright/core/validator/validator.py
@@ -0,0 +1,1442 @@
+import copy
+import argparse
+from datetime import datetime, timedelta, timezone
+from typing import List
+import os
+import traceback
+import secrets
+import time
+import bittensor as bt
+import numpy as np
+import asyncio
+import sys
+import logging
+import pickle
+
+# Import custom modules
+import soundsright.base.benchmarking as Benchmarking
+import soundsright.base.data as Data
+import soundsright.base.utils as Utils
+import soundsright.base.models as Models
+import soundsright.base as Base
+
+class SuppressPydanticFrozenFieldFilterDereverberation_16kHz_Protocol(logging.Filter):
+ def filter(self, record):
+ return 'Ignoring error when setting attribute: 1 validation error for Dereverberation_16kHz_Protocol' not in record.getMessage()
+
+class SuppressPydanticFrozenFieldFilterDenoising_16kHz_Protocol(logging.Filter):
+ def filter(self, record):
+ return 'Ignoring error when setting attribute: 1 validation error for Denoising_16kHz_Protocol' not in record.getMessage()
+
+class SubnetValidator(Base.BaseNeuron):
+ """
+ Main class for the SoundsRight subnet validator.
+ """
+
+ def __init__(self, parser: argparse.ArgumentParser):
+
+ super().__init__(parser=parser, profile="validator")
+
+ self.version = Utils.config["module_version"]
+ self.neuron_config = None
+ self.wallet = None
+ self.subtensor = None
+ self.dendrite = None
+ self.metagraph: bt.metagraph | None = None
+ self.scores = None
+ self.hotkeys = None
+ self.load_validator_state = None
+ self.query = None
+ self.debug_mode = True
+ self.dataset_size = 1
+ self.weights_objects = []
+ self.sample_rates = [16000]
+ self.tasks = ['DENOISING','DEREVERBERATION']
+ self.miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+ self.best_miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+ self.blacklisted_miner_models = {
+ "DENOISING_16000HZ":[],
+ "DEREVERBERATION_16000HZ":[],
+ }
+ self.competition_max_scores = {
+ 'DENOISING_16000HZ':50,
+ 'DEREVERBERATION_16000HZ':50,
+ }
+ self.metric_proportions = {
+ "DENOISING_16000HZ":{
+ "PESQ":0.3,
+ "ESTOI":0.25,
+ "SI_SDR":0.15,
+ "SI_SIR":0.15,
+ "SI_SAR":0.15,
+ },
+ "DEREVERBERATION_16000HZ":{
+ "PESQ":0.3,
+ "ESTOI":0.25,
+ "SI_SDR":0.15,
+ "SI_SIR":0.15,
+ "SI_SAR":0.15,
+ },
+ }
+ self.competition_scores = {
+ 'DENOISING_16000HZ':None,
+ 'DEREVERBERATION_16000HZ':None,
+ }
+ self.sgmse_benchmarks = {
+ "DENOISING_16000HZ":None,
+ "DEREVERBERATION_16000HZ":None,
+ }
+
+ self.remote_logging_interval = 3600
+ self.last_remote_logging_timestamp = 0
+
+ self.apply_config(bt_classes=[bt.subtensor, bt.logging, bt.wallet])
+ self.initialize_neuron()
+ self.TTSHandler = Data.TTSHandler(
+ tts_base_path=self.tts_path,
+ sample_rates=self.sample_rates
+ )
+ dataset_download_outcome = Data.dataset_download(
+ wham_path = self.noise_data_path,
+ arni_path = self.rir_data_path,
+ log_level = self.log_level
+ )
+ if not dataset_download_outcome:
+ sys.exit()
+
+ self.generate_new_dataset(override=False)
+
+ if not self.check_wav_files():
+ self.benchmark_sgmse_for_all_competitions()
+
+ def check_wav_files(self):
+ directories = [self.tts_path, self.reverb_path, self.noise_path]
+
+ for dir_path in directories:
+ if not os.path.isdir(dir_path):
+ return False
+
+ wav_files = [f for f in os.listdir(dir_path) if f.endswith('.wav')]
+ if not wav_files:
+ return False
+
+ return True
+
+ def generate_new_dataset(self, override=True) -> None:
+
+ # Check to see if we need to generate a new dataset
+ if override or not self.check_wav_files():
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Generating new dataset."
+ )
+
+ # Clear existing datasets
+ Data.reset_all_data_directories(
+ tts_base_path=self.tts_path,
+ reverb_base_path=self.reverb_path,
+ noise_base_path=self.noise_path
+ )
+
+ # Generate new TTS data
+ self.TTSHandler.create_openai_tts_dataset_for_all_sample_rates(n=(3 if self.debug_mode else self.dataset_size))
+
+ tts_16000 = os.path.join(self.tts_path, "16000")
+ tts_files_16000 = [f for f in os.listdir(tts_16000)]
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"TTS files generated in directory: {tts_16000} are: {tts_files_16000}"
+ )
+
+ # Generate new noise/reverb data
+ Data.create_noise_and_reverb_data_for_all_sampling_rates(
+ tts_base_path=self.tts_path,
+ arni_dir_path=self.rir_data_path,
+ reverb_base_path=self.reverb_path,
+ wham_dir_path=self.noise_data_path,
+ noise_base_path=self.noise_path,
+ tasks=self.tasks,
+ log_level=self.log_level,
+ )
+
+ noise_16000 = os.path.join(self.noise_path, "16000")
+ noise_files_16000 = [f for f in os.listdir(noise_16000)]
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Noise files generated in directory: {noise_16000}: {noise_files_16000}"
+ )
+
+ reverb_16000 = os.path.join(self.reverb_path, "16000")
+ reverb_files_16000 = [f for f in os.listdir(reverb_16000)]
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Reverb files generated in directory: {reverb_16000}: {reverb_files_16000}"
+ )
+
+ self.healthcheck_api.append_metric(metric_name="datasets_generated", value=1)
+
+ def get_next_competition_timestamp(self) -> int:
+ """
+ Finds the Unix timestamp for the next day at 9:00 AM GMT.
+ """
+ # Current time in GMT
+ now = datetime.now(timezone.utc)
+
+ # Find the next day at 9:00 AM
+ next_day = now + timedelta(days=1)
+ next_day_at_nine = next_day.replace(hour=9, minute=0, second=0, microsecond=0)
+
+ # Return Unix timestamp
+ return int(next_day_at_nine.timestamp())
+
+ def update_next_competition_timestamp(self) -> None:
+ """
+ Updates the next competition timestamp to the 9:00 AM GMT of the following day.
+ """
+ # Add 1 day to the current competition time
+ next_competition_time = datetime.fromtimestamp(self.next_competition_timestamp, tz=timezone.utc)
+ next_competition_time += timedelta(days=1)
+
+ # Set the new timestamp
+ self.next_competition_timestamp = int(next_competition_time.timestamp())
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Next competition will be at {datetime.fromtimestamp(self.next_competition_timestamp, tz=timezone.utc)}"
+ )
+
+ self.healthcheck_api.append_metric(metric_name="competitions_judged", value=1)
+
+ def apply_config(self, bt_classes) -> bool:
+ """This method applies the configuration to specified bittensor classes"""
+ try:
+ self.neuron_config = self.config(bt_classes=bt_classes)
+ except AttributeError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Unable to apply validator configuration: {e}"
+ )
+ raise AttributeError from e
+ except OSError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Unable to create logging directory: {e}"
+ )
+ raise OSError from e
+
+ return True
+
+ def validator_validation(self, metagraph, wallet, subtensor) -> bool:
+ """This method validates the validator has registered correctly"""
+ if wallet.hotkey.ss58_address not in metagraph.hotkeys:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Your validator: {wallet} is not registered to chain connection: {subtensor}. Run btcli register and try again"
+ )
+ return False
+
+ return True
+
+ def setup_bittensor_objects(self, neuron_config) -> tuple[bt.wallet, bt.subtensor, bt.dendrite, bt.metagraph]:
+ """Setups the bittensor objects"""
+ try:
+ wallet = bt.wallet(config=neuron_config)
+ subtensor = bt.subtensor(config=neuron_config)
+ dendrite = bt.dendrite(wallet=wallet)
+ metagraph = subtensor.metagraph(neuron_config.netuid)
+ except AttributeError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Unable to setup bittensor objects: {e}"
+ )
+ raise AttributeError from e
+
+ self.hotkeys = copy.deepcopy(metagraph.hotkeys)
+
+ self.wallet = wallet
+ self.subtensor = subtensor
+ self.dendrite = dendrite
+ self.metagraph = metagraph
+
+ return self.wallet, self.subtensor, self.dendrite, self.metagraph
+
+ def initialize_neuron(self) -> bool:
+ """This function initializes the neuron.
+
+ The setup function initializes the neuron by registering the
+ configuration.
+
+ Args:
+ None
+
+ Returns:
+ Bool:
+ A boolean value indicating success/failure of the initialization.
+ Raises:
+ AttributeError:
+ AttributeError is raised if the neuron initialization failed
+ IndexError:
+ IndexError is raised if the hotkey cannot be found from the metagraph
+ """
+ # Read command line arguments and perform actions based on them
+ args = self._parse_args(parser=self.parser)
+ self.log_level = args.log_level
+
+ # Setup logging
+ bt.logging(config=self.neuron_config, logging_dir=self.neuron_config.full_path)
+ if args.log_level in ("DEBUG", "DEBUGX"):
+ bt.logging.enable_debug()
+ elif args.log_level in ("TRACE", "TRACEX"):
+ bt.logging.enable_trace()
+ else:
+ bt.logging.enable_default()
+
+ # Suppress specific validation errors from pydantic
+ bt.logging._logger.addFilter(SuppressPydanticFrozenFieldFilterDereverberation_16kHz_Protocol())
+ bt.logging._logger.addFilter(SuppressPydanticFrozenFieldFilterDenoising_16kHz_Protocol())
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Initializing validator for subnet: {self.neuron_config.netuid} on network: {self.neuron_config.subtensor.chain_endpoint} with config: {self.neuron_config}"
+ )
+
+ # Setup the bittensor objects
+ self.setup_bittensor_objects(self.neuron_config)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Bittensor objects initialized:\nMetagraph: {self.metagraph}\nSubtensor: {self.subtensor}\nWallet: {self.wallet}"
+ )
+
+ if not args.debug_mode:
+ # Validate that the validator has registered to the metagraph correctly
+ if not self.validator_validation(self.metagraph, self.wallet, self.subtensor):
+ raise IndexError("Unable to find validator key from metagraph")
+
+ # Get the unique identity (UID) from the network
+ validator_uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address)
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Validator is running with UID: {validator_uid}"
+ )
+
+ # Disable debug mode
+ self.debug_mode = False
+
+ self.dataset_size = args.dataset_size
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Debug mode: {self.debug_mode}"
+ )
+
+ self.next_competition_timestamp = self.get_next_competition_timestamp()
+
+ if args.load_state == "False":
+ self.load_validator_state = False
+ else:
+ self.load_validator_state = True
+
+ if self.load_validator_state:
+ self.load_state()
+ else:
+ self.init_default_scores()
+
+ # Healthcheck API
+ self.healthcheck_api = Utils.HealthCheckAPI(
+ host=args.healthcheck_host, port=args.healthcheck_port, is_validator = True, current_models=self.miner_models, best_models=self.best_miner_models
+ )
+
+ # Run healthcheck API
+ self.healthcheck_api.run()
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"HealthCheck API running at: http://{args.healthcheck_host}:{args.healthcheck_port}"
+ )
+
+ return True
+
+ def _parse_args(self, parser) -> argparse.Namespace:
+ return parser.parse_args()
+
+ def check_hotkeys(self) -> None:
+ """Checks if some hotkeys have been replaced in the metagraph"""
+ if self.hotkeys is not None and np.size(self.hotkeys) > 0:
+ # Check if known state len matches with current metagraph hotkey length
+ if len(self.hotkeys) == len(self.metagraph.hotkeys):
+ current_hotkeys = self.metagraph.hotkeys
+ for i, hotkey in enumerate(current_hotkeys):
+ if self.hotkeys[i] != hotkey:
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Index '{i}' has mismatching hotkey. Old hotkey: '{self.hotkeys[i]}', new hotkey: '{hotkey}. Resetting score to 0.0"
+ )
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Score before reset: {self.scores[i]}, competition scores: {self.competition_scores}"
+ )
+ self.reset_hotkey_scores(i)
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Score after reset: {self.scores[i]}, competition scores: {self.competition_scores}"
+ )
+ # Case that there are more/less hotkeys in metagraph
+ else:
+ # Add new zero-score values
+ self.neuron_logger(
+ severity="INFO",
+ message=f"State and metagraph hotkey length mismatch. Metagraph: {len(self.metagraph.hotkeys)} State: {len(self.hotkeys)}. Adjusting scores accordingly."
+ )
+ self.adjust_scores_length(
+ metagraph_len=len(self.metagraph.hotkeys),
+ state_len=len(self.hotkeys)
+ )
+
+ self.hotkeys = copy.deepcopy(self.metagraph.hotkeys)
+
+ for competition in self.miner_models:
+ self.miner_models[competition] = Benchmarking.filter_models_for_deregistered_miners(
+ miner_models=self.miner_models[competition],
+ hotkeys=self.hotkeys
+ )
+
+ def reset_hotkey_scores(self, hotkey_index) -> None:
+ self.scores[hotkey_index] = 0.0
+ for competition in self.competition_scores:
+ self.competition_scores[competition][hotkey_index] = 0.0
+
+ def adjust_scores_length(self, metagraph_len, state_len) -> None:
+ if metagraph_len > state_len:
+ additional_zeros = np.zeros(
+ (metagraph_len-state_len),
+ dtype=np.float32,
+ )
+
+ self.scores = np.concatenate((self.scores, additional_zeros))
+ for competition in self.competition_scores:
+ self.competition_scores[competition] = np.concatenate((self.competition_scores[competition], additional_zeros))
+
+ async def send_competition_synapse(self, uid_to_query: int, sample_rate: int, task: str, timeout: int = 5) -> List[bt.synapse]:
+ # Broadcast query to valid Axons
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Sent competition synapse for {task} at {sample_rate/1000} kHz to UID: {uid_to_query}."
+ )
+
+ axon_to_query = self.metagraph.axons[uid_to_query]
+
+ if sample_rate == 16000 and task == 'DENOISING':
+ return await self.dendrite.forward(
+ axon_to_query,
+ Base.Denoising_16kHz_Protocol(subnet_version=self.subnet_version),
+ timeout=timeout,
+ deserialize=True,
+ )
+
+ elif sample_rate == 16000 and task == 'DEREVERBERATION':
+ return await self.dendrite.forward(
+ axon_to_query,
+ Base.Dereverberation_16kHz_Protocol(subnet_version=self.subnet_version),
+ timeout=timeout,
+ deserialize=True,
+ )
+
+ def save_state(self) -> None:
+ """Saves the state of the validator to a file."""
+
+ state_filename = os.path.join(self.cache_path, "state.npz")
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Saving validator state to file: {state_filename}."
+ )
+
+ # Save the state of the validator to file.
+ np.savez_compressed(
+ state_filename,
+ step=self.step,
+ scores=self.scores,
+ competition_scores_DENOISING_16000HZ=self.competition_scores['DENOISING_16000HZ'],
+ competition_scores_DEREVERBERATION_16000HZ=self.competition_scores['DEREVERBERATION_16000HZ'],
+ hotkeys=self.hotkeys,
+ last_updated_block=self.last_updated_block,
+ next_competition_timestamp=self.next_competition_timestamp,
+ )
+
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Saved the following state to file: {state_filename} step: {self.step}, scores: {self.scores}, competition_scores: {self.competition_scores}, hotkeys: {self.hotkeys}, last_updated_block: {self.last_updated_block}"
+ )
+
+ miner_models_pickle_filename = os.path.join(self.cache_path, "miner_models.pickle")
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Saving miner models to pickle file: {miner_models_pickle_filename}"
+ )
+
+ with open(miner_models_pickle_filename, "wb") as pickle_file:
+ pickle.dump(self.miner_models, pickle_file)
+
+ best_miner_models_pickle_filename = os.path.join(self.cache_path, "best_miner_models.pickle")
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Saving best miner models to pickle file: {best_miner_models_pickle_filename}"
+ )
+
+ with open(best_miner_models_pickle_filename, "wb") as pickle_file:
+ pickle.dump(self.best_miner_models, pickle_file)
+
+ blacklisted_miner_models_pickle_filename = os.path.join(self.cache_path, "blacklisted_miner_models.pickle")
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Saving blacklisted miner models to pickle file: {blacklisted_miner_models_pickle_filename}"
+ )
+
+ with open(blacklisted_miner_models_pickle_filename, "wb") as pickle_file:
+ pickle.dump(self.blacklisted_miner_models, pickle_file)
+
+ def init_default_scores(self) -> None:
+ """Validators without previous validation knowledge should start
+ with default score of 0.0 for each UID. The method can also be
+ used to reset the scores in case of an internal error"""
+
+ self.neuron_logger(
+ severity="INFO",
+ message="Initiating validator with default overall scores for each UID"
+ )
+ self.scores = np.zeros_like(self.metagraph.S, dtype=np.float32)
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Overall weights for validation have been initialized: {self.scores}"
+ )
+ self.competition_scores = {
+ "DENOISING_16000HZ":None,
+ "DEREVERBERATION_16000HZ":None,
+ }
+ for competition in self.competition_scores.keys():
+ self.competition_scores[competition] = np.zeros_like(self.metagraph.S, dtype=np.float32)
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Scores for competition: {competition} have been initialized: {self.competition_scores[competition]}"
+ )
+
+ def reset_validator_state(self, state_path) -> None:
+ """Inits the default validator state. Should be invoked only
+ when an exception occurs and the state needs to reset."""
+
+ # Rename current state file in case manual recovery is needed
+ os.rename(
+ state_path,
+ f"{state_path}-{int(datetime.now().timestamp())}.autorecovery",
+ )
+
+ self.init_default_scores()
+ self.step = 0
+ self.last_updated_block = 0
+ self.hotkeys = None
+ self.next_competition_timestamp = self.get_next_competition_timestamp()
+
+ def load_state(self) -> None:
+ """Loads the state of the validator from a file."""
+
+ # Load the state of the validator from file.
+ state_path = os.path.join(self.cache_path, "state.npz")
+
+ if os.path.exists(state_path):
+ try:
+ self.neuron_logger(
+ severity="INFO",
+ message="Loading validator state."
+ )
+ state = np.load(state_path, allow_pickle=True)
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Loaded the following state from file: {state}"
+ )
+
+ self.step = state["step"]
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Step loaded from file: {self.step}"
+ )
+
+ self.scores = state["scores"]
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Scores loaded from saved file: {self.scores}"
+ )
+
+ self.competition_scores = {
+ "DENOISING_16000HZ": state['competition_scores_DENOISING_16000HZ'],
+ "DEREVERBERATION_16000HZ": state['competition_scores_DEREVERBERATION_16000HZ']
+ }
+
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Competition scores loaded from file: {self.competition_scores}"
+ )
+
+ self.hotkeys = state["hotkeys"]
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Hotkeys loaded from file: {self.hotkeys}"
+ )
+
+ self.last_updated_block = state["last_updated_block"]
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Last updated block loaded from file: {self.last_updated_block}"
+ )
+
+ self.next_competition_timestamp = state['next_competition_timestamp']
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Next competition timestamp loaded from file: {self.next_competition_timestamp}"
+ )
+
+ except Exception as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Validator state reset because an exception occurred: {e}"
+ )
+ self.reset_validator_state(state_path=state_path)
+ else:
+ self.init_default_scores()
+ self.step = 0
+ self.last_updated_block = 0
+ self.hotkeys = None
+ self.next_competition_timestamp = self.get_next_competition_timestamp()
+
+ miner_models_filepath = os.path.join(self.cache_path, "miner_models.pickle")
+ if os.path.exists(miner_models_filepath):
+ try:
+ with open(miner_models_filepath, "rb") as pickle_file:
+ self.miner_models = pickle.load(pickle_file)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Loaded miner models from {miner_models_filepath}: {self.miner_models}"
+ )
+ except Exception as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Could not load miner models from {miner_models_filepath} because: {e}"
+ )
+ self.miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+ else:
+ self.miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+
+ best_miner_models_filepath = os.path.join(self.cache_path, "best_miner_models.pickle")
+ if os.path.exists(best_miner_models_filepath):
+ try:
+ with open(best_miner_models_filepath, "rb") as pickle_file:
+ self.best_miner_models = pickle.load(pickle_file)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Loaded best miner models from {best_miner_models_filepath}: {self.best_miner_models}"
+ )
+ except Exception as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Could not load best miner models from {best_miner_models_filepath} because: {e}"
+ )
+ self.best_miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+ else:
+ self.best_miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+
+ blacklisted_miner_models_filepath = os.path.join(self.cache_path, "blacklisted_miner_models.pickle")
+ if os.path.exists(blacklisted_miner_models_filepath):
+ try:
+ with open(blacklisted_miner_models_filepath, "rb") as pickle_file:
+ self.blacklisted_miner_models = pickle.load(pickle_file)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Loaded blacklisted miner models from {blacklisted_miner_models_filepath}: {self.blacklisted_miner_models}"
+ )
+ except Exception as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Could not load blacklisted miner models from {blacklisted_miner_models_filepath} because: {e}"
+ )
+ self.blacklisted_miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+ else:
+ self.blacklisted_miner_models = {
+ 'DENOISING_16000HZ':[],
+ 'DEREVERBERATION_16000HZ':[],
+ }
+
+ @Utils.timeout_decorator(timeout=30)
+ async def sync_metagraph(self) -> None:
+ """Syncs the metagraph"""
+
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Attempting sync of metagraph: {self.metagraph} with subtensor: {self.subtensor}"
+ )
+
+ # Sync the metagraph
+ self.metagraph.sync(subtensor=self.subtensor)
+
+ def handle_metagraph_sync(self) -> None:
+ try:
+ asyncio.run(self.sync_metagraph())
+ self.neuron_logger(
+ severity="INFOX",
+ message=f"Metagraph synced: {self.metagraph}"
+ )
+ except TimeoutError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Metagraph sync timed out: {e}"
+ )
+
+ def handle_weight_setting(self) -> None:
+ """
+ Checks if setting/committing/revealing weights is appropriate, triggers the process if so.
+ """
+ # Check if it's time to set/commit new weights
+ if self.subtensor.get_current_block() >= self.last_updated_block + 100 and not self.debug_mode:
+
+ # Try set/commit weights
+ try:
+ asyncio.run(self.commit_weights())
+ self.last_updated_block = self.subtensor.get_current_block()
+
+ except TimeoutError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Committing weights timed out: {e}"
+ )
+
+ # If commit reveal is enabled, reveal weights in queue
+ if self.subtensor.get_subnet_hyperparameters(netuid=self.neuron_config.netuid).commit_reveal_weights_enabled:
+
+ # Reveal weights stored in queue
+ self.reveal_weights_in_queue()
+
+ @Utils.timeout_decorator(timeout=30)
+ async def commit_weights(self) -> None:
+ """Sets the weights for the subnet"""
+
+ def normalize_weights_list(weights):
+ max_value = self.subtensor.get_subnet_hyperparameters(netuid=self.neuron_config.netuid).max_weight_limit
+ if all(x==1 for x in weights):
+ return [(x/max_value) for x in weights]
+ elif all(x==0 for x in weights):
+ return weights
+ else:
+ return [(x/max(weights)) for x in weights]
+
+ self.healthcheck_api.update_metric(metric_name='weights.targets', value=np.count_nonzero(self.scores))
+
+ weights = self.scores
+ salt=secrets.randbelow(2**16)
+ block = self.subtensor.get_current_block()
+ uids = [int(uid) for uid in self.metagraph.uids]
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Committing weights: {weights}"
+ )
+ if not self.debug_mode:
+ # Commit reveal if it is enabled
+ if self.subtensor.get_subnet_hyperparameters(netuid=self.neuron_config.netuid).commit_reveal_weights_enabled:
+
+ self.neuron_logger(
+ severity="DEBUGX",
+ message=f"Committing weights with the following parameters: netuid={self.neuron_config.netuid}, wallet={self.wallet}, uids={uids}, weights={weights}, version_key={self.subnet_version}"
+ )
+ # This is a crucial step that updates the incentive mechanism on the Bittensor blockchain.
+ # Miners with higher scores (or weights) receive a larger share of TAO rewards on this subnet.
+ result, msg = self.subtensor.commit_weights(
+ netuid=self.neuron_config.netuid, # Subnet to set weights on.
+ wallet=self.wallet, # Wallet to sign set weights using hotkey.
+ uids=uids, # Uids of the miners to set weights for.
+ weights=weights, # Weights to set for the miners.
+ salt=[salt],
+ max_retries=5,
+ )
+ # For successful commits
+ if result:
+
+ self.neuron_logger(
+ severity="SUCCESS",
+ message=f"Successfully committed weights: {weights}. Message: {msg}"
+ )
+
+ self.healthcheck_api.update_metric(metric_name='weights.last_committed_timestamp', value=time.strftime("%H:%M:%S", time.localtime()))
+ self.healthcheck_api.append_metric(metric_name="weights.total_count_committed", value=1)
+
+ self._store_weight_metadata(
+ salt=salt,
+ uids=uids,
+ weights=weights,
+ block=block
+ )
+
+ # For unsuccessful commits
+ else:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Failed to commit weights: {weights}. Message: {msg}"
+ )
+ else:
+ self.neuron_logger(
+ severity="DEBUGX",
+ message=f"Setting weights with the following parameters: netuid={self.neuron_config.netuid}, wallet={self.wallet}, uids={self.metagraph.uids}, weights={weights}, version_key={self.subnet_version}"
+ )
+
+ weights = normalize_weights_list(weights)
+
+ # This is a crucial step that updates the incentive mechanism on the Bittensor blockchain.
+ # Miners with higher scores (or weights) receive a larger share of TAO rewards on this subnet.
+ result = self.subtensor.set_weights(
+ netuid=self.neuron_config.netuid, # Subnet to set weights on.
+ wallet=self.wallet, # Wallet to sign set weights using hotkey.
+ uids=self.metagraph.uids, # Uids of the miners to set weights for.
+ weights=weights, # Weights to set for the miners.
+ wait_for_inclusion=False,
+ version_key=self.subnet_version,
+ )
+ if result:
+ self.neuron_logger(
+ severity="SUCCESS",
+ message=f"Successfully set weights: {weights}"
+ )
+
+ self.healthcheck_api.update_metric(metric_name='weights.last_set_timestamp', value=time.strftime("%H:%M:%S", time.localtime()))
+ self.healthcheck_api.append_metric(metric_name="weights.total_count_set", value=1)
+
+ else:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Failed to set weights: {weights}"
+ )
+ else:
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Skipped setting weights due to debug mode"
+ )
+
+ def _store_weight_metadata(self, salt, uids, weights, block) -> None:
+ """Stores weight metadata as part of the SubnetValidator.weights_objects attribute
+
+ Args:
+ salt (int): Unique salt for weights.
+ uids (list): Uids to set weights for
+ weights (np.ndarray)): Weights array
+ block (int): What block weights were initially committed to chain
+ """
+ # Construct weight object
+ data = {
+ "salt": salt,
+ "uids": uids,
+ "weights": weights,
+ "block": block
+ }
+
+ # Store weight object
+ self.weights_objects.append(data)
+
+ self.neuron_logger(
+ severity='TRACE',
+ message=f'Weight data appended to weights_objects for future reveal: {data}'
+ )
+
+ @Utils.timeout_decorator(timeout=30)
+ async def reveal_weights(self, weight_object) -> bool:
+ """
+ Reveals weights (in the case that commit reveal is enabled for the subnet)
+
+ Args:
+ :param weight_object: (dict): Validator's local log of weights to be revealed
+
+ Returns:
+ bool: True if weights were revealed successfully, False otherwise
+ """
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Revealing weights: {weight_object}"
+ )
+
+ status, msg = self.subtensor.reveal_weights(
+ wallet=self.wallet,
+ netuid=self.neuron_config.netuid,
+ uids=weight_object["uids"],
+ weights=weight_object["weights"],
+ salt=np.array([weight_object["salt"]]),
+ max_retries=5
+ )
+
+ if status:
+ self.neuron_logger(
+ severity="SUCCESS",
+ message=f'Weight reveal succeeded for weights: {weight_object} Status message: {msg}'
+ )
+ self.healthcheck_api.update_metric(metric_name='weights.last_revealed_timestamp', value=time.strftime("%H:%M:%S", time.localtime()))
+ self.healthcheck_api.append_metric(metric_name="weights.total_count_revealed", value=1)
+
+ else:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f'Weight reveal failed. Status message: {msg}'
+ )
+
+ return status
+
+ def reveal_weights_in_queue(self) -> None:
+ """
+ Looks through queue, sees if any weight objects are at/past the time to reveal them. Reveals them if this is the case
+ """
+ current_block = self.subtensor.get_current_block()
+ commit_reveal_weights_interval = self.subtensor.get_subnet_hyperparameters(netuid=self.neuron_config.netuid).commit_reveal_weights_interval
+ new_weights_objects = []
+
+ for weight_object in self.weights_objects:
+ if (current_block - weight_object['block']) >= commit_reveal_weights_interval:
+ try:
+ status = asyncio.run(self.reveal_weights(weight_object=weight_object))
+ if not status:
+ new_weights_objects.append(weight_object)
+ except TimeoutError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Revealing weights timed out: {e}"
+ )
+ new_weights_objects.append(weight_object)
+
+ else:
+ new_weights_objects.append(weight_object)
+
+ self.weights_objects = new_weights_objects
+
+ self.neuron_logger(
+ severity="TRACE",
+ message=f"Weights objects in queue to be revealed: {self.weights_objects}"
+ )
+
+ def handle_remote_logging(self) -> None:
+ """
+ References last updated timestamp and specified interval
+ to see if remote logging needs to be done for best models
+ from previous competition and current best models. If
+ logging is successful it updates the timestamp
+ """
+ current_timestamp = int(time.time())
+
+ if (self.last_remote_logging_timestamp + self.remote_logging_interval) <= current_timestamp and not self.debug_mode:
+
+ # Log models for current competition
+ current_models_outcome = Benchmarking.miner_models_remote_logging(
+ hotkey=self.wallet.hotkey,
+ current_miner_models=self.miner_models,
+ log_level=self.log_level,
+ )
+
+ sgmse_outcome = Benchmarking.sgmse_remote_logging(
+ hotkey=self.wallet.hotkey,
+ sgmse_benchmarks=self.sgmse_benchmarks,
+ log_level=self.log_level,
+ )
+
+ if current_models_outcome and sgmse_outcome:
+ self.last_remote_logging_timestamp = int(time.time())
+
+ def get_uids_to_query(self) -> List[int]:
+ """This function determines valid axon to send the query to--
+ they must have valid ips """
+ axons = self.metagraph.axons
+ # Clear axons that do not have an IP
+ axons_with_valid_ip = [axon for axon in axons if axon.ip != "0.0.0.0"]
+
+ # Clear axons with duplicate IP/Port
+ axon_ips = set()
+ filtered_axons = [
+ axon
+ for axon in axons_with_valid_ip
+ if axon.ip_str() not in axon_ips and not axon_ips.add(axon.ip_str())
+ ]
+
+ self.neuron_logger(
+ severity="TRACEX",
+ message=f"Filtered out axons. Original list: {len(axons)}, filtered list: {len(filtered_axons)}"
+ )
+
+ self.healthcheck_api.append_metric(metric_name="axons.total_filtered_axons", value=len(filtered_axons))
+
+ return [self.hotkeys.index(axon.hotkey) for axon in filtered_axons]
+
+ def find_dict_by_hotkey(self, dict_list, hotkey) -> dict | None:
+ """_summary_
+
+ Args:
+ :param dict_list: (List[dict]): List of dictionaries
+ :param hotkey: (str): ss58_adr
+
+ Returns:
+ dict: if hotkey in dict_list. None otherwise
+ """
+ for d in dict_list:
+ if d.get('hotkey') == hotkey:
+ return d
+ return {}
+
+ def benchmark_sgmse(self, sample_rate: int, task: str) -> None:
+ """Runs benchmarking for SGMSE for competition based on current dataset
+
+ Args:
+ sample_rate (int): Sample rate
+ task (str): DENOISING/DEREVERBERATION
+ """
+
+ competition = f"{task}_{sample_rate}HZ"
+ task_path = os.path.join(self.noise_path, str(sample_rate)) if task == "DENOISING" else os.path.join(self.reverb_path, str(sample_rate))
+
+ sgmse_handler = Models.SGMSEHandler(
+ task = task,
+ sample_rate = sample_rate,
+ task_path = task_path,
+ sgmse_path = self.sgmse_path,
+ sgmse_output_path = self.sgmse_output_path,
+ log_level=self.log_level,
+ )
+
+ sgmse_benchmarking_outcome = sgmse_handler.download_start_and_enhance()
+
+ # Calculate metrics
+ if sgmse_benchmarking_outcome:
+ metrics_dict = Benchmarking.calculate_metrics_dict(
+ clean_directory=os.path.join(self.tts_path, str(sample_rate)),
+ enhanced_directory=self.sgmse_output_path,
+ noisy_directory=task_path,
+ sample_rate=sample_rate,
+ log_level=self.log_level,
+ )
+
+ # Append metrics to dict
+ self.sgmse_benchmarks[competition] = metrics_dict
+
+ self.neuron_logger(
+ severity="TRACE",
+ message=f"Determined SGMSE+ benchmarks for {competition} competition: {self.sgmse_benchmarks[competition]}"
+ )
+
+ def benchmark_sgmse_for_all_competitions(self) -> None:
+ """Runs benchmarking for SGMSE+ for all competitions, sends results to remote logger
+ """
+ self.neuron_logger(
+ severity="INFO",
+ message="Benchmarking SGMSE+ on today's dataset."
+ )
+
+ # Reset benchmarking dic
+ for competition_key in self.sgmse_benchmarks.keys():
+ self.sgmse_benchmarks[competition_key] = None
+
+ for sample_rate in self.sample_rates:
+ for task in self.tasks:
+ # Benchmark SGMSE+ on dataset
+ self.benchmark_sgmse(sample_rate=sample_rate, task=task)
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"SGMSE+ benchmarks: {self.sgmse_benchmarks}"
+ )
+
+ def benchmark_model(self, model_metadata: dict, sample_rate: int, task: str, hotkey: str) -> dict:
+ """Runs benchmarking for miner-submitted model using Models.ModelEvaluationHandler
+
+ Args:
+ :param model_metadata: (dict): Model metadata submitted by miner via synapse
+ :param sample_rate: (int): Sample rate
+ :param task: (str): DENOISING/DEREVERBERATIOn
+ :param hotkey: (str): ss58_address
+
+ Returns:
+ dict: model benchmarking results. If model benchmarking could not be performed, returns an empty (no-response) dict
+ """
+ # Validate that miner data is formatted correctly
+ if not Utils.validate_miner_response(model_metadata):
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Miner with hotkey: {hotkey} has response that was not properly formatted, cannot benchmark: {model_metadata}"
+ )
+
+ return {
+ 'hotkey':hotkey,
+ 'hf_model_name':'',
+ 'hf_model_namespace':'',
+ 'hf_model_revision':'',
+ 'model_hash':'',
+ 'block':10000000000000000,
+ 'metrics':{}
+ }
+
+ # Initialize model evaluation handler
+ eval_handler = Models.ModelEvaluationHandler(
+ tts_base_path=self.tts_path,
+ noise_base_path=self.noise_path,
+ reverb_base_path=self.reverb_path,
+ model_output_path=self.model_output_path,
+ model_path=self.model_path,
+ sample_rate=sample_rate,
+ task=task,
+ hf_model_namespace=model_metadata['hf_model_namespace'],
+ hf_model_name=model_metadata['hf_model_name'],
+ hf_model_revision=model_metadata['hf_model_revision'],
+ log_level=self.log_level,
+ subtensor=self.subtensor,
+ subnet_netuid=self.neuron_config.netuid,
+ miner_hotkey=hotkey,
+ miner_models=self.miner_models[f'{task}_{sample_rate}HZ']
+ )
+
+ metrics_dict, model_hash, model_block = eval_handler.download_run_and_evaluate()
+
+ model_benchmark = {
+ 'hotkey':hotkey,
+ 'hf_model_name':model_metadata['hf_model_name'],
+ 'hf_model_namespace':model_metadata['hf_model_namespace'],
+ 'hf_model_revision':model_metadata['hf_model_revision'],
+ 'model_hash':model_hash,
+ 'block':model_block,
+ 'metrics':metrics_dict,
+ }
+
+ if not Utils.validate_model_benchmark(model_benchmark):
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Model benchmark: {model_benchmark} for task: {task} and sample rate: {sample_rate} is invalidly formatted."
+ )
+
+ return {
+ 'hotkey':hotkey,
+ 'hf_model_name':'',
+ 'hf_model_namespace':'',
+ 'hf_model_revision':'',
+ 'model_hash':'',
+ 'metrics':{},
+ 'block':10000000000000000,
+ }
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Model benchmark for task: {task} and sample rate: {sample_rate}: {model_benchmark}"
+ )
+
+ return model_benchmark
+
+ def run_competition(self, sample_rate, task) -> None:
+ """
+ Runs a competition (a competition is a unique combination of sample rate and task).
+
+ 1. Queries all miners for their models.
+ 2. If miner submits a new model, benchmarks it with SubnetValidator.benchmark_model
+ 3. Updates knowledge of miner model benchmarking results.
+ """
+ # Obtain existing list of miner model data for this competition
+ competition_miner_models = self.miner_models[f"{task}_{sample_rate}HZ"]
+ blacklisted_miner_models = self.blacklisted_miner_models[f"{task}_{sample_rate}HZ"]
+
+ # Create new list which we will gradually append to and eventually replace self.miner_models with
+ new_competition_miner_models = []
+
+ # Iterate through UIDs to query
+ for uid_to_query in self.get_uids_to_query():
+
+ if Utils.validate_uid(uid_to_query):
+
+ # Send synapse
+ response = asyncio.run(self.send_competition_synapse(
+ uid_to_query=uid_to_query,
+ sample_rate=sample_rate,
+ task=task
+ ))
+
+ # Add this data to the HealthCheck API
+ self.healthcheck_api.append_metric(metric_name="axons.total_queried_axons", value=1)
+
+ # Check that the miner has responded with a model for this competition. If not, skip it
+ if response.data:
+
+ self.neuron_logger(
+ severity="TRACE",
+ message=f"Recieved response from miner with UID: {uid_to_query}: {response.data}"
+ )
+
+ # Add this data to HealthCheck API
+ self.healthcheck_api.append_metric(metric_name="responses.total_valid_responses", value=1)
+
+ # Find existing information on miner model
+ miner_model_all_data = copy.deepcopy(self.find_dict_by_hotkey(competition_miner_models, self.hotkeys[uid_to_query]))
+
+ # Construct a dict that has only the keys needed from previously known model data
+ miner_model_data = {}
+ if miner_model_all_data and 'hf_model_namespace' in miner_model_all_data.keys() and 'hf_model_name' in miner_model_all_data.keys() and 'hf_model_revision' in miner_model_all_data.keys():
+ for k in ['hf_model_namespace','hf_model_name','hf_model_revision']:
+ miner_model_data[k] = miner_model_all_data[k]
+
+ # Check that the synapse response is validly formatted
+ valid_model=False
+ if isinstance(response.data, dict) and 'hf_model_namespace' in response.data and 'hf_model_name' in response.data and 'hf_model_revision' in response.data and response.data['hf_model_namespace'] != "synapsecai":
+ valid_model=True
+
+ # In case that synapse response is not formatted correctly:
+ if not valid_model:
+ # Continue to next miner if there is no known historical data
+ if not miner_model_all_data:
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Miner response is invalid: {response.data}"
+ )
+ continue
+ # Keep using historical data if it exists before continuing to next miner
+ else:
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Miner response is invalid: {response.data}. Keeping old miner model data: {miner_model_all_data}"
+ )
+ new_competition_miner_models.append(miner_model_all_data)
+ continue
+
+ # If the model in the synapse has never been evaluated by the validator:
+ if (not miner_model_data or response.data != miner_model_data) and (miner_model_data not in blacklisted_miner_models) and valid_model:
+
+ # Create a dictionary logging miner model metadata & benchmark values
+ model_data = self.benchmark_model(
+ model_metadata = response.data,
+ sample_rate = sample_rate,
+ task = task,
+ hotkey = self.hotkeys[uid_to_query],
+ )
+
+ # Append to the list
+ new_competition_miner_models.append(model_data)
+
+ # If the model in the synapse has already been evaluated by the validator:
+ else:
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Model has already been evaluated: {miner_model_all_data}"
+ )
+
+ # Append existing data the list
+ new_competition_miner_models.append(miner_model_all_data)
+
+ # In the case of empty rersponse:
+ else:
+
+ # Add this data to the HealthCheck API
+ self.healthcheck_api.append_metric(metric_name="responses.total_invalid_responses", value=1)
+
+ # Find miner model data
+ miner_model_data = self.find_dict_by_hotkey(competition_miner_models, self.hotkeys[uid_to_query])
+
+ # If any existing model exists for the hotkey:
+ if miner_model_data:
+
+ # Append existing model data to list
+ new_competition_miner_models.append(miner_model_data)
+
+ # In the case that multiple models have the same hash, we only want to include the model with the earliest block when the metadata was uploaded to the chain
+ hash_filtered_new_competition_miner_models, same_hash_blacklist = Benchmarking.filter_models_with_same_hash(
+ new_competition_miner_models=new_competition_miner_models
+ )
+
+ # In the case that multiple models have the same metadata, we only want to include the model with the earliest block when the metadata was uploaded to the chain
+ hash_metadata_filtered_new_competition_miner_models, same_metadata_blacklist = Benchmarking.filter_models_with_same_metadata(
+ new_competition_miner_models=hash_filtered_new_competition_miner_models
+ )
+
+ self.blacklisted_miner_models[f"{task}_{sample_rate}HZ"].extend(same_hash_blacklist)
+ self.blacklisted_miner_models[f"{task}_{sample_rate}HZ"].extend(same_metadata_blacklist)
+ self.miner_models[f"{task}_{sample_rate}HZ"] = hash_metadata_filtered_new_competition_miner_models
+
+ competition = f"{task}_{sample_rate}HZ"
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Models for competition: {competition}: {self.miner_models[competition]}"
+ )
+
+ def run(self) -> None:
+ """
+ Main validator loop.
+ """
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Starting validator loop with version: {self.version}"
+ )
+ self.healthcheck_api.append_metric(metric_name="neuron_running", value=True)
+
+ while True:
+ try:
+ # Update knowledge of metagraph and save state before going onto a new competition
+ # First, sync metagraph
+ self.handle_metagraph_sync()
+
+ # Then, check that hotkey knowledge matches
+ self.check_hotkeys()
+
+ # Check to see if validator is still registered on metagraph
+ if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys:
+ self.neuron_logger(
+ severity="ERROR",
+ message=f"Hotkey is not registered on metagraph: {self.wallet.hotkey.ss58_address}."
+ )
+
+ # Save validator state
+ self.save_state()
+
+ # Iterate through sample rates
+ for sample_rate in self.sample_rates:
+
+ # Iterate through tasks
+ for task in self.tasks:
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Judging for competition: {task}_{sample_rate}HZ"
+ )
+
+ # Query and evaluate models for this competition, generate the data necessary to run the scoring algorithm
+ self.run_competition(sample_rate=sample_rate, task=task)
+
+ # Check if it's time for a new competition
+ if int(time.time()) >= self.next_competition_timestamp or self.debug_mode:
+
+ self.neuron_logger(
+ severity="INFO",
+ message="Starting new competition."
+ )
+
+ # First, sync metagraph
+ self.handle_metagraph_sync()
+
+ # Then, check that hotkey knowledge matches
+ self.check_hotkeys()
+
+ # First reset competition scores and overall scores so that we can re-calculate them from validator model data
+ self.init_default_scores()
+
+ # Calculate scores for each competition
+ self.best_miner_models, self.competition_scores = Benchmarking.determine_competition_scores(
+ competition_scores = self.competition_scores,
+ competition_max_scores = self.competition_max_scores,
+ metric_proportions = self.metric_proportions,
+ best_miner_models = self.best_miner_models,
+ miner_models = self.miner_models,
+ metagraph = self.metagraph,
+ log_level = self.log_level,
+ )
+
+ # Update validator.scores
+ self.scores = Benchmarking.calculate_overall_scores(
+ competition_scores = self.competition_scores,
+ scores = self.scores,
+ log_level = self.log_level,
+ )
+
+ self.neuron_logger(
+ severity="INFO",
+ message=f"Overall miner scores: {self.scores}"
+ )
+
+ # Update HealthCheck API
+ self.healthcheck_api.update_competition_scores(self.competition_scores)
+ self.healthcheck_api.update_scores(self.scores)
+
+ # Update timestamp to next day's 9AM (GMT)
+ self.update_next_competition_timestamp()
+
+ # Update dataset for next day's competition
+ self.generate_new_dataset()
+
+ # Benchmark SGMSE+ for new dataset as a comparison for miner models
+ self.benchmark_sgmse_for_all_competitions()
+
+ # Handle setting of weights
+ self.handle_weight_setting()
+
+ # Handle remote logging
+ self.handle_remote_logging()
+
+ self.neuron_logger(
+ severity="TRACE",
+ message=f"Updating HealthCheck API."
+ )
+
+ # Update metrics in healthcheck API at end of each iteration
+ self.healthcheck_api.update_current_models(self.miner_models)
+ self.healthcheck_api.update_best_models(self.best_miner_models)
+ self.healthcheck_api.append_metric(metric_name='iterations', value=1)
+ self.healthcheck_api.update_rates()
+
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Competition scores: {self.competition_scores}. Scores: {self.scores}"
+ )
+
+ # Sleep for a duration equivalent to 1/3 of the block time (i.e., time between successive blocks).
+ self.neuron_logger(
+ severity="DEBUG",
+ message=f"Sleeping for: {bt.BLOCKTIME/3} seconds"
+ )
+ time.sleep(bt.BLOCKTIME / 3)
+
+ # If we encounter an unexpected error, log it for debugging.
+ except RuntimeError as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=e
+ )
+ traceback.print_exc()
+
+ # If the user interrupts the program, gracefully exit.
+ except KeyboardInterrupt:
+ self.neuron_logger(
+ severity="SUCCESS",
+ message="Keyboard interrupt detected. Exiting validator.")
+ sys.exit()
+
+ # If we encounter a general unexpected error, log it for debugging.
+ except Exception as e:
+ self.neuron_logger(
+ severity="ERROR",
+ message=e
+ )
+ traceback.print_exc()
\ No newline at end of file
diff --git a/soundsright/neurons/__init__.py b/soundsright/neurons/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/soundsright/neurons/miner.py b/soundsright/neurons/miner.py
new file mode 100644
index 0000000..2c3afb5
--- /dev/null
+++ b/soundsright/neurons/miner.py
@@ -0,0 +1,69 @@
+"""
+This miner script executes the main loop for the miner and keeps the
+miner active in the Bittensor network.
+"""
+
+from argparse import ArgumentParser
+import soundsright.core as SoundsRightCore
+
+# This is the main function, which runs the miner.
+if __name__ == "__main__":
+ # Parse command line arguments
+ parser = ArgumentParser()
+
+ ###########################################
+ # THIS MUST BE ADJUSTED UPON REGISTRATION #
+ ###########################################
+ parser.add_argument(
+ "--netuid",
+ type=int,
+ default=0,
+ help="The chain subnet uid"
+ )
+
+ parser.add_argument(
+ "--logging.logging_dir",
+ type=str,
+ default="/var/log/bittensor",
+ help="Provide the log directory",
+ )
+
+ parser.add_argument(
+ "--validator_min_stake",
+ type=float,
+ default=10000.0,
+ help="Determine the minimum stake the validator should have to accept requests",
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=str,
+ default="INFO",
+ choices=["INFO", "INFOX", "DEBUG", "DEBUGX", "TRACE", "TRACEX"],
+ help="Determine the logging level used by the subnet modules",
+ )
+
+ parser.add_argument(
+ "--healthcheck_host",
+ type=str,
+ default="0.0.0.0",
+ help="Set the healthcheck API host. Defaults to 0.0.0.0 to expose it outside of the container.",
+ )
+
+ parser.add_argument(
+ "--healthcheck_port",
+ type=int,
+ default=6000,
+ help="Determine the port used by the healthcheck API.",
+ )
+
+ parser.add_argument(
+ "--axon.port",
+ type=int,
+ default=6001,
+ help="Axon port, default is 6001. If you want to alter this value you will also need to adjust the exposed ports in the docker-compose.yml file."
+ )
+
+ # Create a miner based on the Class definitions
+ subnet_miner = SoundsRightCore.SubnetMiner(parser=parser)
+ subnet_miner.run()
\ No newline at end of file
diff --git a/soundsright/neurons/validator.py b/soundsright/neurons/validator.py
new file mode 100644
index 0000000..380e830
--- /dev/null
+++ b/soundsright/neurons/validator.py
@@ -0,0 +1,72 @@
+"""
+Main script for running SoundsRight validator
+"""
+# Import standard modules
+from argparse import ArgumentParser
+import os
+from dotenv import load_dotenv
+load_dotenv()
+
+# Import subnet modules
+import soundsright.core as SoundsRightCore
+
+# The main function parses the configuration and runs the validator.
+if __name__ == "__main__":
+ # Parse command line arguments
+ parser = ArgumentParser()
+
+ ###########################################
+ # THIS MUST BE ADJUSTED UPON REGISTRATION #
+ ###########################################
+ parser.add_argument(
+ "--netuid",
+ type=int,
+ default=0,
+ help="The chain subnet uid."
+ )
+
+ parser.add_argument(
+ "--load_state",
+ type=str,
+ default="True",
+ help="WARNING: Setting this value to False clears the old state.",
+ )
+
+ parser.add_argument(
+ "--debug_mode",
+ action="store_true",
+ help="Running the validator in debug mode ignores selected validity checks. Not to be used in production.",
+ )
+
+ parser.add_argument(
+ "--dataset_size",
+ default=2000,
+ type=int,
+ help="Size of evaluation dataset."
+ )
+
+ parser.add_argument(
+ "--log_level",
+ type=str,
+ default="INFO",
+ choices=["INFO", "INFOX", "DEBUG", "DEBUGX", "TRACE", "TRACEX"],
+ help="Determine the logging level used by the subnet modules",
+ )
+
+ parser.add_argument(
+ "--healthcheck_host",
+ type=str,
+ default="0.0.0.0",
+ help="Set the healthcheck API host. Defaults to 0.0.0.0 to expose it outside of the container.",
+ )
+
+ parser.add_argument(
+ "--healthcheck_port",
+ type=int,
+ default=6000,
+ help="Determine the port used by the healthcheck API.",
+ )
+
+ # Create a validator based on the Class definitions and initialize it
+ subnet_validator = SoundsRightCore.SubnetValidator(parser=parser)
+ subnet_validator.run()
\ No newline at end of file
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/benchmarking/__init__.py b/test/benchmarking/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/benchmarking/scoring_test.py b/test/benchmarking/scoring_test.py
new file mode 100644
index 0000000..73182b8
--- /dev/null
+++ b/test/benchmarking/scoring_test.py
@@ -0,0 +1,232 @@
+import pytest
+from unittest.mock import Mock, patch
+import numpy as np
+import soundsright.base.benchmarking as Benchmarking
+
+@pytest.mark.parametrize("new_model_block, old_model_block, start_improvement, end_improvement, decay_block, expected", [
+ (0, 0, 0.0035, 0.0015, 50400, 0.0035), # No difference
+ (50400, 0, 0.0035, 0.0015, 50400, 0.0015), # Full decay, should reach end_improvement
+ (25200, 0, 0.0035, 0.0015, 50400, 0.0025), # Halfway through decay
+ (50400, 50400, 0.0035, 0.0015, 50400, 0.0035), # Same blocks, no improvement adjustment
+ (100000, 0, 0.0035, 0.0015, 50400, 0.0015), # Beyond decay, should cap at end_improvement
+ (30000, 10000, 0.0035, 0.0015, 50400, 0.002706), # Partial decay with a difference of 20,000 blocks
+])
+
+def test_calculate_improvement_factor(new_model_block, old_model_block, start_improvement, end_improvement, decay_block, expected):
+ result = Benchmarking.calculate_improvement_factor(
+ new_model_block,
+ old_model_block,
+ start_improvement=start_improvement,
+ end_improvement=end_improvement,
+ decay_block=decay_block
+ )
+ assert result == pytest.approx(expected, rel=1e-2), f"Expected {expected}, got {result}"
+
+
+@pytest.mark.parametrize("new_model_metric, new_model_block, old_model_metric, old_model_block, expected", [
+ (1.01, 0, 1.0, 0, True), # New model surpasses old with improvement factor at start value
+ (1.007, 25200, 1.0, 0, True), # Halfway decay improvement factor
+ (1.0015, 50400, 1.0, 0, True), # Full decay improvement factor met
+ (1.0014, 50400, 1.0, 0, False), # Full decay but not enough improvement
+ (1.0, 0, 1.0, 0, False), # No improvement
+ (1.02, 0, 1.01, 10000, True), # New model surpasses, partial decay
+ (1.008, 100000, 1.0, 0, True), # New model exceeds best model with block difference beyond decay
+ (1.005, 100000, 1.0, 0, True), # Exactly at end improvement after max decay
+ (1.0002, 50400, 1.0, 0, False), # New model slightly better but not enough
+ (0.99, 0, 1.0, 0, False), # New model worse than best model
+])
+
+def test_new_model_surpasses_historical_model(new_model_metric, new_model_block, old_model_metric, old_model_block, expected):
+ result = Benchmarking.new_model_surpasses_historical_model(new_model_metric, new_model_block, old_model_metric, old_model_block)
+ assert result == expected, f"Expected {expected} but got {result}"
+
+
+def test_get_best_model_from_list():
+ models_data = [
+ {
+ "name": "Model A",
+ "metrics": {
+ "PESQ": {"average": 1.1},
+ "ESTOI": {"average": 1.2},
+ },
+ },
+ {
+ "name": "Model B",
+ "metrics": {
+ "PESQ": {"average": 1.3},
+ "ESTOI": {"average": 1.0},
+ },
+ },
+ {
+ "name": "Model C",
+ "metrics": {
+ "PESQ": {"average": 0.9},
+ "ESTOI": {"average": 1.4},
+ },
+ },
+ ]
+
+ # Test for PESQ metric
+ best_model_pesq = Benchmarking.get_best_model_from_list(models_data, "PESQ")
+ assert best_model_pesq["name"] == "Model B"
+
+ # Test for ESTOI metric
+ best_model_estoi = Benchmarking.get_best_model_from_list(models_data, "ESTOI")
+ assert best_model_estoi["name"] == "Model C"
+
+ # Test for non-existent metric
+ best_model_nonexistent = Benchmarking.get_best_model_from_list(models_data, "SI_SDR")
+ assert best_model_nonexistent is None
+
+ # Test with empty models_data
+ best_model_empty = Benchmarking.get_best_model_from_list([], "PESQ")
+ assert best_model_empty is None
+
+ # Test with invalid metric data
+ models_data_invalid = [
+ {
+ "name": "Model D",
+ "metrics": {
+ "PESQ": {"average": "invalid"},
+ },
+ },
+ {
+ "name": "Model E",
+ "metrics": {
+ "PESQ": {"average": None},
+ },
+ },
+ ]
+ best_model_invalid = Benchmarking.get_best_model_from_list(models_data_invalid, "PESQ")
+ assert best_model_invalid is None
+
+@pytest.mark.parametrize(
+ "competition_scores, competition_max_scores, metric_proportions, best_miner_models, miner_models, expected_scores, expected_best_models",
+ [
+ # Edge case: No models submitted for a competition
+ (
+ {"competition1": np.array([0, 0])},
+ {"competition1": 100},
+ {"competition1": {"PESQ": 1.0}},
+ {"competition1": []},
+ {"competition1": []},
+ {"competition1": np.array([0, 0])},
+ {"competition1": []},
+ ),
+ # Edge case: Single model submitted
+ (
+ {"competition1": np.array([0, 0])},
+ {"competition1": 100},
+ {"competition1": {"PESQ": 1.0}},
+ {"competition1": []},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.5}}, "block": 10},
+ ]},
+ {"competition1": np.array([100, 0])},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.5}}, "block": 10},
+ ]},
+ ),
+ # Edge case: Multiple models submitted, same metric values
+ (
+ {"competition1": np.array([0, 0])},
+ {"competition1": 100},
+ {"competition1": {"PESQ": 1.0}},
+ {"competition1": []},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.1}}, "block": 10},
+ {"hotkey": "miner_hotkey_ss58adr2", "metrics": {"PESQ": {"average": 1.1}}, "block": 15},
+ ]},
+ {"competition1": np.array([100, 0])},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.1}}, "block": 10},
+ ]},
+ ),
+ # Edge case: Historical model outperforms current model
+ (
+ {"competition1": np.array([0, 0])},
+ {"competition1": 100},
+ {"competition1": {"PESQ": 1.0}},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.5}}, "block": 5},
+ ]},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr2", "metrics": {"PESQ": {"average": 1.2}}, "block": 10},
+ ]},
+ {"competition1": np.array([100, 0])},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.5}}, "block": 5},
+ ]},
+ ),
+ # Edge case: Historical model outperforms current model in one metric and current model outperforms historical model in the other metric
+ (
+ {"competition1": np.array([0, 0])},
+ {"competition1": 100},
+ {"competition1": {"PESQ": 0.3, "ESTOI": 0.7}},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.5}, "ESTOI": {"average": 0.5}}, "block": 5},
+ ]},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr2", "metrics": {"PESQ": {"average": 1.2}, "ESTOI": {"average": 0.9}}, "block": 10},
+ ]},
+ {"competition1": np.array([30, 70])},
+ {"competition1": [
+ {"hotkey": "miner_hotkey_ss58adr", "metrics": {"PESQ": {"average": 1.5}, "ESTOI": {"average": 0.5}}, "block": 5},
+ {"hotkey": "miner_hotkey_ss58adr2", "metrics": {"PESQ": {"average": 1.2}, "ESTOI": {"average": 0.9}}, "block": 10},
+ ]},
+ ),
+ ],
+)
+def test_determine_competition_scores(
+ competition_scores,
+ competition_max_scores,
+ metric_proportions,
+ best_miner_models,
+ miner_models,
+ expected_scores,
+ expected_best_models,
+):
+ metagraph = Mock()
+ metagraph.hotkeys = ["miner_hotkey_ss58adr", "miner_hotkey_ss58adr2"]
+
+ new_best_miner_models, updated_competition_scores = Benchmarking.determine_competition_scores(
+ competition_scores,
+ competition_max_scores,
+ metric_proportions,
+ best_miner_models,
+ miner_models,
+ metagraph,
+ log_level="INFO",
+ )
+
+ for competition, scores in updated_competition_scores.items():
+ assert np.array_equal(scores, expected_scores[competition]), (
+ f"updated_competition_scores[{competition}]: {scores} "
+ f"is not equal to expected_scores[{competition}]: {expected_scores[competition]}"
+ )
+
+ # Comparison for best-performing models
+ assert new_best_miner_models == expected_best_models, (
+ f"new_best_miner_models: {new_best_miner_models} "
+ f"is not equal to expected_best_models: {expected_best_models}"
+ )
+
+
+
+@pytest.mark.parametrize("competition_scores, initial_scores, expected", [
+ # Basic case
+ ({"comp1": np.array([0,0,15,0]), "comp2": np.array([0,0,0,10]), "comp3": np.array([0,5,0,0])}, np.zeros(4), np.array([0, 5, 15, 10])),
+
+ # No updates (empty competition_scores)
+ ({"comp1":np.zeros(3), "comp2":np.zeros(3)}, np.zeros(3), np.array([0, 0, 0])),
+])
+
+def test_calculate_overall_scores_varied(competition_scores, initial_scores, expected):
+ log_level = "INFO"
+
+ # Call the function
+ updated_scores = Benchmarking.calculate_overall_scores(competition_scores, initial_scores, log_level)
+
+ # Assertions
+ np.testing.assert_array_equal(updated_scores, expected, "The overall scores were not calculated as expected.")
+
\ No newline at end of file
diff --git a/test/benchmarking/test_remote_logging.py b/test/benchmarking/test_remote_logging.py
new file mode 100644
index 0000000..5147d87
--- /dev/null
+++ b/test/benchmarking/test_remote_logging.py
@@ -0,0 +1,183 @@
+import pytest
+import os
+import time
+import bittensor as bt
+
+import soundsright.base.benchmarking as Benchmarking
+
+from dotenv import load_dotenv
+load_dotenv()
+
+def get_hk():
+ ck_name = os.getenv("WALLET")
+ hk_name = os.getenv("HOTKEY")
+ wallet = bt.wallet(name=ck_name, hotkey=hk_name)
+ return wallet.hotkey.ss58_address
+
+def test_miner_models_remote_logging():
+
+ hk = get_hk()
+
+ miner_models = {
+ "category": "current",
+ "validator":hk,
+ "timestamp":int(time.time()*1000),
+ "models": {
+ "DENOISING_16000HZ":[
+ {
+ "hotkey":"miner_hotkey_ss58adr",
+ "hf_model_name":"SoundsRightModelTemplate",
+ "hf_model_namespace":"synapsecai",
+ "hf_model_revision":"main",
+ "model_hash":"aaaaaaaaaaaaaaaaa11111aaaaa",
+ "block":17873471294,
+ "metrics":{
+ "PESQ":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "ESTOI":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SDR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SAR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SIR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ }
+ }
+ ],
+ "DEREVERBERATION_16000HZ":[
+ {
+ "hotkey":"miner_hotkey_ss58adr",
+ "hf_model_name":"SoundsRightModelTemplate",
+ "hf_model_namespace":"synapsecai",
+ "hf_model_revision":"main",
+ "model_hash":"aaaaaaaaaaaaaaaaa11111aaaaa",
+ "block":17873471294,
+ "metrics":{
+ "PESQ":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "ESTOI":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SDR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SAR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SIR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ }
+ }
+ ],
+ }
+ }
+
+ logging_outcome = Benchmarking.miner_models_remote_logging(
+ hotkey=hk,
+ current_miner_models=miner_models,
+ log_level="TRACE"
+ )
+
+ assert logging_outcome, "Miner model logging failed."
+
+def test_sgmse_remote_logging():
+
+ hk = get_hk()
+
+ sgmse_benchmark = {
+ "category": "sgmse",
+ "validator":hk,
+ "timestamp":int(time.time()*1000),
+ "models": {
+ "DENOISING_16000HZ":{
+ "PESQ":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "ESTOI":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SDR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SAR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SIR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ },
+ "DEREVERBERATION_16000HZ":{
+ "PESQ":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "ESTOI":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SDR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SAR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ "SI_SIR":{
+ "scores":[1.0,1.1,1.2],
+ "average":1.1,
+ "confidence_interval":[1.05,1.15],
+ },
+ },
+ }
+ }
+
+ logging_outcome = Benchmarking.sgmse_remote_logging(
+ hotkey=hk,
+ sgmse_benchmarks=sgmse_benchmark,
+ log_level="TRACE"
+ )
+
+ assert logging_outcome, "SGMSE+ benchmark failed"
\ No newline at end of file
diff --git a/test/bittensor/test_commit_reveal.py b/test/bittensor/test_commit_reveal.py
new file mode 100644
index 0000000..5d45083
--- /dev/null
+++ b/test/bittensor/test_commit_reveal.py
@@ -0,0 +1,115 @@
+import bittensor as bt
+import secrets
+import numpy as np
+import os
+import time
+from dotenv import load_dotenv
+load_dotenv()
+
+class CommitMachine:
+
+ def __init__(self):
+
+ # Setup Bittensor objects
+ self.subtensor = bt.subtensor(network="test")
+ self.metagraph = self.subtensor.metagraph(netuid=os.environ.get("NETUID"))
+ self.wallet = bt.wallet(name="testnet_validator", hotkey="validator")
+
+ self.commit_reveal_weights_interval = self.subtensor.get_subnet_hyperparameters(netuid=os.environ.get("NETUID")).commit_reveal_weights_interval
+
+ # Store salts to be revealed
+ self.weight_objects = []
+
+ def _get_random_weights(self, metagraph) -> np.ndarray:
+
+ # Generate random weights
+ weights = np.array([secrets.randbelow(1025) for _ in range(metagraph.size)])
+
+ return weights.tolist()
+
+ def _store_weight_metadata(self, salt, uids, weights, block):
+
+ # Construct weight object
+ data = {
+ "salt": np.array([salt]),
+ "uids": uids,
+ "weights": weights,
+ "block": block
+ }
+
+ # Store weight object
+ self.weight_objects.append(data)
+
+ def reveal_weights(self, weight_object):
+
+ status, msg = self.subtensor.reveal_weights(
+ wallet=self.wallet,
+ netuid=38,
+ uids=weight_object["uids"],
+ weights=weight_object["weights"],
+ salt=weight_object["salt"],
+ max_retries=5
+ )
+
+ print(f'Weight reveal status: {status} - Status message: {msg}')
+
+ return status
+
+ def commit_weights(self):
+
+ # Generate metadata
+ salt = secrets.randbelow(2**16)
+ uids = self.metagraph.uids.tolist()
+ weights = self._get_random_weights(self.metagraph.uids)
+
+ # Store metadata
+ self._store_weight_metadata(salt, uids, weights, self.subtensor.block)
+
+ # Commit weights
+ status, msg = self.subtensor.commit_weights(
+ wallet=self.wallet,
+ netuid=38,
+ salt=[salt],
+ uids=uids,
+ weights=weights,
+ max_retries=5
+ )
+
+ print(f'Weight commit status: {status} - Status message: {msg}')
+
+ return status
+
+def test_commit_reveal():
+ commit_machine = CommitMachine()
+
+ # Add few commits to the list, sleep to separate the commits into different blocks
+ assert commit_machine.commit_weights(), "Commit weights failed"
+ time.sleep(bt.BLOCKTIME * 2)
+ assert commit_machine.commit_weights(), "Commit weights failed"
+ time.sleep(bt.BLOCKTIME * 2)
+ assert commit_machine.commit_weights(), "Commit weights failed"
+
+ # Reveal all weight commits
+ while(len(commit_machine.weight_objects) > 0):
+
+ # Get current block
+ current_block = commit_machine.subtensor.block
+ print(f'Current block: {current_block}')
+
+ # Iterate all weight objects
+ for i,obj in enumerate(commit_machine.weight_objects):
+ diff = (current_block - obj["block"])
+ print(f'Difference is {diff} blocks')
+ if diff > commit_machine.commit_reveal_weights_interval:
+ print(f'Revealing: {obj}')
+ assert commit_machine.reveal_weights(obj), "Reveal weights failed"
+
+ # Remove from objects list
+ commit_machine.weight_objects.pop(i)
+
+ # Sleep for blocktime before next iteration
+ time.sleep(bt.BLOCKTIME)
+
+
+if __name__ == "__main__":
+ test_commit_reveal()
\ No newline at end of file
diff --git a/test/data/data_gen_loop.py b/test/data/data_gen_loop.py
new file mode 100644
index 0000000..c2130bd
--- /dev/null
+++ b/test/data/data_gen_loop.py
@@ -0,0 +1,44 @@
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+from soundsright.base.data import create_noise_and_reverb_data_for_all_sampling_rates, TTSHandler, reset_all_data_directories, dataset_download
+
+base_path = os.path.join(os.path.expanduser("~"), ".SoundsRight")
+tts_base_path = os.path.join(base_path,'data/tts')
+noise_base_path = os.path.join(base_path,'data/noise')
+reverb_base_path = os.path.join(base_path,'data/reverb')
+arni_path = os.path.join(base_path,'data/rir_data')
+wham_path = os.path.join(base_path,'data/noise_data')
+sample_rates = [16000]
+
+for directory in [tts_base_path, noise_base_path, reverb_base_path, arni_path, wham_path]:
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ print(f"Created directory: {directory}")
+
+reset_all_data_directories(tts_base_path=tts_base_path, reverb_base_path=reverb_base_path, noise_base_path=noise_base_path)
+
+if False:
+ print("Downloading datasets")
+ dataset_download(wham_path=wham_path, arni_path=arni_path, partial=True)
+
+tts_handler = TTSHandler(tts_base_path=tts_base_path, sample_rates=sample_rates)
+print("TTSHandler initialized")
+for sr in sample_rates:
+ print("Creating TTS dataset")
+ tts_handler.create_openai_tts_dataset_for_all_sample_rates(
+ n=3
+ )
+
+create_noise_and_reverb_data_for_all_sampling_rates(
+ tts_base_path = tts_base_path,
+ arni_dir_path = arni_path,
+ reverb_base_path=reverb_base_path,
+ wham_dir_path=wham_path,
+ noise_base_path=noise_base_path,
+ tasks=['denoising', 'dereverberation']
+)
+
+remove=input("remove all files? y/n")
+if remove=='y':
+ reset_all_data_directories(tts_base_path=tts_base_path, reverb_base_path=reverb_base_path, noise_base_path=noise_base_path)
\ No newline at end of file
diff --git a/test/models/__init__.py b/test/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/models/metadata_handler_test.py b/test/models/metadata_handler_test.py
new file mode 100644
index 0000000..bd40fb8
--- /dev/null
+++ b/test/models/metadata_handler_test.py
@@ -0,0 +1,81 @@
+import bittensor as bt
+import pytest
+import argparse
+import os
+import hashlib
+import time
+import random
+import string
+import asyncio
+from dotenv import load_dotenv
+load_dotenv()
+
+import soundsright.base.models as Models
+
+def init_metadata_handler(subnet_netuid = 38):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--netuid", type=int, default=subnet_netuid)
+ parser.add_argument("--subtensor.network", type=str, default="test")
+ config = bt.config(parser=parser)
+ subtensor = bt.subtensor(config=config, network = "test")
+ ck=os.environ.get("WALLET")
+ hk=os.environ.get("HOTKEY")
+ wallet=bt.wallet(name=ck,hotkey=hk)
+
+ return Models.ModelMetadataHandler(
+ subtensor=subtensor,
+ subnet_netuid=subnet_netuid,
+ log_level="INFO",
+ wallet=wallet
+ ), wallet.hotkey.ss58_address
+
+@pytest.mark.parametrize("competition_id, expected", [
+ (1, "DENOISING_16000HZ"),
+ (2, "DEREVERBERATION_16000HZ"),
+ ("1", "DENOISING_16000HZ"),
+ ("2", "DEREVERBERATION_16000HZ"),
+ (5, None),
+ ("5", None),
+ (None, None),
+ ({}, None),
+ (True, None),
+ (False, None)
+])
+
+def test_get_competition_name_from_competition_id(competition_id, expected):
+ metadata_handler, _ = init_metadata_handler()
+ assert metadata_handler.get_competition_name_from_competition_id(competition_id) == expected
+
+@pytest.mark.parametrize("competition_name, expected", [
+ ("DENOISING_16000HZ", 1),
+ ("DEREVERBERATION_16000HZ", 2),
+])
+
+def test_get_competition_id_from_competition_name(competition_name, expected):
+ metadata_handler, _ = init_metadata_handler()
+ assert metadata_handler.get_competition_id_from_competition_name(competition_name) == expected
+
+def generate_random_string(N):
+ """Generate a random alphanumeric string of length N."""
+ if N < 1:
+ raise ValueError("Length N must be at least 1.")
+
+ characters = string.ascii_letters + string.digits
+ return ''.join(random.choice(characters) for _ in range(N))
+
+def test_chain_metadata_upload_round_trip():
+ metadata_handler, hk = init_metadata_handler()
+
+ metadata_str = f"{generate_random_string(10)}:{generate_random_string(25)}:{generate_random_string(5)}:{generate_random_string(2)}:{generate_random_string(48)}:1"
+ hashed_metadata_str = hashlib.sha256(metadata_str.encode()).hexdigest()
+
+ upload_outcome = asyncio.run(metadata_handler.upload_model_metadata_to_chain(metadata=hashed_metadata_str))
+ assert upload_outcome == True
+ block_uploaded = metadata_handler.subtensor.get_current_block()
+
+ assert block_uploaded
+
+ download_outcome = asyncio.run(metadata_handler.obtain_model_metadata_from_chain(hotkey=hk))
+ assert download_outcome == True
+ assert block_uploaded >= metadata_handler.metadata_block
+ assert hashlib.sha256(metadata_str.encode()).hexdigest() == metadata_handler.metadata
\ No newline at end of file
diff --git a/test/models/validation_test.py b/test/models/validation_test.py
new file mode 100644
index 0000000..f7b8d28
--- /dev/null
+++ b/test/models/validation_test.py
@@ -0,0 +1,60 @@
+import soundsright.base.models as Models
+import os
+import shutil
+import pytest
+
+def remove_all_in_path(path):
+ """
+ Removes all files and directories located at the specified path.
+
+ Args:
+ path (str): The path to the directory to clear.
+ """
+ if not os.path.isdir(path):
+ raise ValueError(f"The specified path '{path}' is not a directory.")
+
+ for filename in os.listdir(path):
+ file_path = os.path.join(path, filename)
+
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.remove(file_path) # Remove file or symlink
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path) # Remove directory and its contents
+ except Exception as e:
+ print(f"Failed to delete {file_path}. Reason: {e}")
+
+@pytest.mark.parametrize("model_id", [
+ ("huseinzol05/speech-enhancement-mask-unet"),
+ ("sp-uhh/speech-enhancement-sgmse"),
+ ("rollingkevin/speech-enhancement-unet"),
+])
+def test_get_model_content_hash(model_id):
+
+ model_path=f"{os.path.expanduser('~')}/.soundsright/model_test"
+ model_dir = os.path.join(os.getcwd(), model_path)
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+
+ model_hash_1, sorted_files_1 = Models.get_model_content_hash(
+ model_id=model_id,
+ revision="main",
+ local_dir=model_dir,
+ log_level="INFO"
+ )
+
+ remove_all_in_path(model_dir)
+
+ model_hash_2, sorted_files_2 = Models.get_model_content_hash(
+ model_id=model_id,
+ revision="main",
+ local_dir=model_dir,
+ log_level="INFO"
+ )
+
+ remove_all_in_path(model_dir)
+ shutil.rmtree(model_dir)
+
+ assert len(sorted_files_1) == len(sorted_files_2), "File lengths different"
+ assert sorted_files_1 == sorted_files_2, "File order different"
+ assert model_hash_1 == model_hash_2, "Hashes different"
diff --git a/test/utils/__init__.py b/test/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/utils/test_container.py b/test/utils/test_container.py
new file mode 100644
index 0000000..3890749
--- /dev/null
+++ b/test/utils/test_container.py
@@ -0,0 +1,84 @@
+from git import Repo
+import time
+import os
+import shutil
+import yaml
+import glob
+import pytest
+
+import soundsright.base.utils as Utils
+
+def remove_all_in_path(path):
+ """
+ Removes all files and directories located at the specified path.
+
+ Args:
+ path (str): The path to the directory to clear.
+ """
+ if not os.path.isdir(path):
+ raise ValueError(f"The specified path '{path}' is not a directory.")
+
+ for filename in os.listdir(path):
+ file_path = os.path.join(path, filename)
+
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.remove(file_path) # Remove file or symlink
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path) # Remove directory and its contents
+ except Exception as e:
+ print(f"Failed to delete {file_path}. Reason: {e}")
+
+
+def validate_all_noisy_files_are_enhanced(noisy_dir, enhanced_dir):
+ noisy_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(noisy_dir, '*.wav'))])
+ enhanced_files = sorted([os.path.basename(f) for f in glob.glob(os.path.join(enhanced_dir, '*.wav'))])
+ return noisy_files == enhanced_files
+
+@pytest.mark.parametrize("model_id, revision",[
+ ("synapsecai/SoundsRightModelTemplate", "DENOISING_16000HZ"),
+ ("synapsecai/SoundsRightModelTemplate", "DEREVERBERATION_16000HZ")
+])
+def test_container(model_id, revision):
+
+ model_dir=f"{os.path.expanduser('~')}/.SoundsRight/data/model"
+ noisy_dir = f"{os.path.expanduser('~')}/.SoundsRight/data/noise/16000"
+ reverb_dir = f"{os.path.expanduser('~')}/.SoundsRight/data/reverb/16000"
+ enhanced_noise_dir = f"{os.path.expanduser('~')}/.SoundsRight/test_data/enhanced_noise/16000"
+ enhanced_reverb_dir = f"{os.path.expanduser('~')}/.SoundsRight/test_data/enhanced_reverb/16000"
+
+ for directory in [model_dir, noisy_dir, reverb_dir, enhanced_noise_dir, enhanced_reverb_dir]:
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ repo_url = f"https://huggingface.co/{model_id}"
+
+ shutil.rmtree(model_dir)
+
+ # Download the model files for the specified revision
+ Repo.clone_from(repo_url, model_dir, branch=revision)
+
+ assert Utils.validate_container_config(model_dir) == True, "Container config invalid"
+
+ assert Utils.start_container(directory=model_dir, log_level="INFO"), "Container could not be started"
+ time.sleep(5)
+ assert Utils.check_container_status(log_level="INFO", timeout=None), "Container status invalid"
+ assert Utils.prepare(log_level="INFO",timeout=None), "Model prepration failed"
+ time.sleep(5)
+ if "DENOISING" in revision:
+ assert Utils.upload_audio(noisy_dir, log_level="INFO", timeout=None), "Audio could not be uploaded"
+ else:
+ assert Utils.upload_audio(reverb_dir, log_level="INFO", timeout=None), "Audio could not be uploaded"
+ time.sleep(5)
+ assert Utils.enhance_audio(log_level="INFO", timeout=None), "Files could not be enhanced"
+ if "DENOISING" in revision:
+ assert Utils.download_enhanced(enhanced_noise_dir, log_level="INFO", timeout=None), "Enhanced files could not be downloaded"
+ assert validate_all_noisy_files_are_enhanced(noisy_dir, enhanced_noise_dir), "Mismatch between noisy files and enhanced files"
+ else:
+ assert Utils.download_enhanced(enhanced_reverb_dir, log_level="INFO", timeout=None), "Enhanced files could not be downloaded"
+ assert validate_all_noisy_files_are_enhanced(reverb_dir, enhanced_reverb_dir), "Mismatch between noisy files and enhanced files"
+ time.sleep(5)
+ assert Utils.delete_container(log_level="INFO"), "Container could not be deleted"
+
+ remove_all_in_path(model_dir)
+ shutil.rmtree(model_dir)
\ No newline at end of file
diff --git a/test/utils/upload_audio.py b/test/utils/upload_audio.py
new file mode 100644
index 0000000..fc59c00
--- /dev/null
+++ b/test/utils/upload_audio.py
@@ -0,0 +1,58 @@
+import glob
+import requests
+import os
+
+def upload_audio(noisy_dir, timeout=500,) -> bool:
+ """
+ Upload audio files to the API.
+
+ Returns:
+ bool: True if operation was successful, False otherwise
+ """
+ url = f"http://127.0.0.1:6500/upload-audio/"
+
+ files = sorted(glob.glob(os.path.join(noisy_dir, "*.wav")))
+
+ print(f"files: {files}")
+
+ try:
+ with requests.Session() as session:
+ file_payload = [
+ ("files", (os.path.basename(file), open(file, "rb"), "audio/wav"))
+ for file in files
+ ]
+
+ print(f"files_payload: {file_payload}")
+
+ response = session.post(url, files=file_payload, timeout=timeout)
+
+ for _, file in file_payload:
+ file[1].close() # Ensure all files are closed after the request
+
+ response.raise_for_status()
+ data = response.json()
+
+ print(f"response data: {data}")
+
+ sorted_files = sorted([file[1][0] for file in file_payload])
+ print(f"sorted_files: {sorted_files}")
+ sorted_response = sorted(data["uploaded_files"])
+ print(f"sorted_response: {sorted_response}")
+ outcome = sorted_files == sorted_response and data["status"]
+ print(f"sorted_files == sorted_response: {outcome}")
+ return sorted_files == sorted_response and data["status"]
+
+ except requests.RequestException as e:
+ print(f"Uploading audio to model failed because: {e}")
+ return False
+ except Exception as e:
+ print(f"Uploading audio to model failed because: {e}")
+ return False
+
+def main():
+ noisy_dir = f"{os.path.expanduser('~')}/.SoundsRight/data/noise/16000"
+
+ upload_audio(noisy_dir=noisy_dir)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file