Skip to content

Commit

Permalink
Merge pull request #121 from ti-mo/tb/mqtt-port
Browse files Browse the repository at this point in the history
Support connecting to MQTT brokers on non-standard ports
  • Loading branch information
unixorn authored Sep 23, 2023
2 parents e72d672 + f3b1be4 commit 9901a43
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 26 deletions.
54 changes: 37 additions & 17 deletions ha_mqtt_discoverable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ class DeviceInfo(BaseModel):
"""A list of connections of the device to the outside world as a list of tuples\
[connection_type, connection_identifier]"""
configuration_url: Optional[str] = None
"""A link to the webpage that can manage the configuration of this device. Can be either an HTTP or HTTPS link."""
"""A link to the webpage that can manage the configuration of this device.
Can be either an HTTP or HTTPS link."""

@root_validator
def must_have_identifiers_or_connection(cls, values):
Expand All @@ -501,14 +502,16 @@ class EntityInfo(BaseModel):
device: Optional[DeviceInfo] = None
"""Information about the device this sensor belongs to"""
device_class: Optional[str] = None
"""Sets the class of the device, changing the device state and icon that is displayed on the frontend."""
"""Sets the class of the device, changing the device state and icon that is
displayed on the frontend."""
enabled_by_default: Optional[bool] = None
"""Flag which defines if the entity should be enabled when first added."""
entity_category: Optional[str] = None
"""Classification of a non-primary entity."""
expire_after: Optional[int] = None
"""If set, it defines the number of seconds after the sensor’s state expires, if it’s not updated.\
After expiry, the sensor’s state becomes unavailable. Default the sensors state never expires."""
"""If set, it defines the number of seconds after the sensor’s state expires,
if it’s not updated. After expiry, the sensor’s state becomes unavailable.
Default the sensors state never expires."""
force_update: Optional[bool] = None
"""Sends update events even if the value hasn’t changed.\
Useful if you want to have meaningful value graphs in history."""
Expand All @@ -520,7 +523,8 @@ class EntityInfo(BaseModel):
qos: Optional[int] = None
"""The maximum QoS level to be used when receiving messages."""
unique_id: Optional[str] = None
"""Set this to enable editing sensor from the HA ui and to integrate with a device"""
"""Set this to enable editing sensor from the HA ui and to integrate with a
device"""

@root_validator
def device_need_unique_id(cls, values):
Expand All @@ -540,6 +544,7 @@ class MQTT(BaseModel):
"""Connection settings for the MQTT broker"""

host: str
port: Optional[int] = 1883
username: Optional[str] = None
password: Optional[str] = None
client_name: Optional[str] = None
Expand Down Expand Up @@ -588,8 +593,10 @@ def __init__(
Args:
settings: Settings for the entity we want to create in Home Assistant.
See the `Settings` class for the available options.
on_connect: Optional callback function invoked when the MQTT client successfully connects to the broker.
If defined, you need to call `_connect_client() to establish the connection manually.`
on_connect: Optional callback function invoked when the MQTT client \
successfully connects to the broker.
If defined, you need to call `_connect_client()` to establish the \
connection manually.
"""
# Import here to avoid circular dependency on imports
# TODO how to better handle this?
Expand Down Expand Up @@ -639,7 +646,8 @@ def __init__(

# Create the MQTT client, registering the user `on_connect` callback
self._setup_client(on_connect)
# If there is a callback function defined, the user must manually connect to the MQTT client
# If there is a callback function defined, the user must manually connect
# to the MQTT client
if not on_connect:
self._connect_client()

Expand All @@ -660,11 +668,14 @@ def _setup_client(self, on_connect: Optional[Callable] = None) -> None:
"""Create an MQTT client and setup some basic properties on it"""
mqtt_settings = self._settings.mqtt
logger.debug(
f"Creating mqtt client({mqtt_settings.client_name}) for {mqtt_settings.host}"
f"Creating mqtt client({mqtt_settings.client_name}) for "
"{mqtt_settings.host}:{mqtt_settings.port}"
)
self.mqtt_client = mqtt.Client(mqtt_settings.client_name)
if mqtt_settings.tls_key:
logger.info(f"Connecting to {mqtt_settings.host} with SSL")
logger.info(
f"Connecting to {mqtt_settings.host}:{mqtt_settings.port} with SSL"
)
logger.debug(f"ca_certs={mqtt_settings.tls_ca_cert}")
logger.debug(f"certfile={mqtt_settings.tls_certfile}")
logger.debug(f"keyfile={mqtt_settings.tls_key}")
Expand All @@ -676,7 +687,9 @@ def _setup_client(self, on_connect: Optional[Callable] = None) -> None:
tls_version=ssl.PROTOCOL_TLS,
)
else:
logger.debug(f"Connecting to {mqtt_settings.host} without SSL")
logger.debug(
f"Connecting to {mqtt_settings.host}:{mqtt_settings.port} without SSL"
)
if mqtt_settings.username:
self.mqtt_client.username_pw_set(
mqtt_settings.username, password=mqtt_settings.password
Expand All @@ -689,13 +702,17 @@ def _setup_client(self, on_connect: Optional[Callable] = None) -> None:
self.mqtt_client.will_set(self.availability_topic, "offline", retain=True)

def _connect_client(self) -> None:
"""Connect the client to the MQTT broker, start its onw internal loop in a separate thread"""
result = self.mqtt_client.connect(self._settings.mqtt.host)
"""Connect the client to the MQTT broker, start its onw internal loop in
a separate thread"""
result = self.mqtt_client.connect(
self._settings.mqtt.host, self._settings.mqtt.port
)
# Check if we have established a connection
if result != mqtt.MQTT_ERR_SUCCESS:
raise RuntimeError("Error while connecting to MQTT broker")

# Start the internal network loop of the MQTT library to handle incoming messages in a separate thread
# Start the internal network loop of the MQTT library to handle incoming
# messages in a separate thread
self.mqtt_client.loop_start()

def _state_helper(
Expand Down Expand Up @@ -738,7 +755,8 @@ def delete(self) -> None:

config_message = ""
logger.info(
f"Writing '{config_message}' to topic {self.config_topic} on {self._settings.mqtt.host}"
f"Writing '{config_message}' to topic {self.config_topic} on "
"{self._settings.mqtt.host}:{self._settings.mqtt.port}"
)
self.mqtt_client.publish(self.config_topic, config_message, retain=True)

Expand Down Expand Up @@ -771,7 +789,8 @@ def write_config(self):
config_message = json.dumps(self.generate_config())

logger.debug(
f"Writing '{config_message}' to topic {self.config_topic} on {self._settings.mqtt.host}"
f"Writing '{config_message}' to topic {self.config_topic} on "
"{self._settings.mqtt.host}:{self._settings.mqtt.port}"
)
self.wrote_configuration = True
self.config_message = config_message
Expand All @@ -786,7 +805,8 @@ def set_attributes(self, attributes: dict[str, Any]):
"""Update the attributes of the entity
Args:
attributes: dictionary containing all the attributes that will be set for this entity
attributes: dictionary containing all the attributes that will be \
set for this entity
"""
# HA expects a JSON object in the attribute topic
json_attributes = json.dumps(attributes)
Expand Down
1 change: 1 addition & 0 deletions ha_mqtt_discoverable/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def create_base_parser(description: str = "Base parser"):
parser.add_argument("--mqtt-user", type=str, help="MQTT user.")
parser.add_argument("--mqtt-password", type=str, help="MQTT password.")
parser.add_argument("--mqtt-server", type=str, help="MQTT server.")
parser.add_argument("--mqtt-port", type=str, help="MQTT port.", default=1883)
parser.add_argument("--settings-file", type=str, help="Settings file.")

parser.add_argument("--use-tls", "--use-ssl", action="store_true", help="Use TLS.")
Expand Down
8 changes: 7 additions & 1 deletion ha_mqtt_discoverable/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from ha_mqtt_discoverable.utils import read_yaml_file


logger = logging.getLogger(__name__)


Expand All @@ -43,6 +42,7 @@ def load_mqtt_settings(path: str = None, cli=None) -> dict:
settings["mqtt_password"] = cli.mqtt_password
settings["mqtt_prefix"] = cli.mqtt_prefix
settings["mqtt_server"] = cli.mqtt_server
settings["mqtt_port"] = cli.mqtt_port
settings["mqtt_user"] = cli.mqtt_user

# Optional settings - make sure we don't raise an exception if they're unset
Expand Down Expand Up @@ -73,6 +73,8 @@ def load_mqtt_settings(path: str = None, cli=None) -> dict:
raise RuntimeError("No device_name was specified")
if "mqtt_prefix" not in settings:
raise RuntimeError("You need to specify an mqtt prefix")
if "mqtt_port" not in settings:
raise RuntimeError("You need to specify an mqtt port")
if "mqtt_user" not in settings:
raise RuntimeError("No mqtt_user was specified")
if "mqtt_password" not in settings:
Expand Down Expand Up @@ -105,6 +107,8 @@ def sensor_delete_settings(path: str = None, cli=None) -> dict:
settings["mqtt_prefix"] = cli.mqtt_prefix
if cli.mqtt_server:
settings["mqtt_server"] = cli.mqtt_server
if cli.mqtt_port:
settings["mqtt_port"] = cli.mqtt_port
if cli.mqtt_user:
settings["mqtt_user"] = cli.mqtt_user

Expand All @@ -117,6 +121,8 @@ def sensor_delete_settings(path: str = None, cli=None) -> dict:
raise RuntimeError("No device_name was specified")
if "mqtt_prefix" not in settings:
raise RuntimeError("You need to specify an mqtt prefix")
if "mqtt_port" not in settings:
raise RuntimeError("You need to specify an mqtt port")
if "mqtt_user" not in settings:
raise RuntimeError("No mqtt_user was specified")
if "mqtt_password" not in settings:
Expand Down
18 changes: 11 additions & 7 deletions tests/test_discoverable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@
# limitations under the License.
#
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
from concurrent.futures import ThreadPoolExecutor
from threading import Event
from unittest.mock import MagicMock

import paho.mqtt.subscribe as subscribe
import pytest
from paho.mqtt.client import (
MQTT_ERR_SUCCESS,
Client,
MQTTMessage,
MQTTv5,
SubscribeOptions,
MQTT_ERR_SUCCESS,
)
import paho.mqtt.subscribe as subscribe
import pytest
from pytest_mock import MockerFixture
from ha_mqtt_discoverable import DeviceInfo, Discoverable, Settings, EntityInfo

from ha_mqtt_discoverable import DeviceInfo, Discoverable, EntityInfo, Settings


@pytest.fixture
Expand Down Expand Up @@ -232,7 +234,8 @@ def test_str(discoverable: Discoverable[EntityInfo]):
# Define a callback function to be invoked when we receive a message on the topic
def message_callback(client: Client, userdata, message: MQTTMessage, tmp=None):
logging.info("Received %s", message)
# If the broker is `dirty` and contains messages send by other test functions, skip these retained messages
# If the broker is `dirty` and contains messages send by other test functions,
# skip these retained messages
if message.retain:
logging.warn("Skipping retained message")
return
Expand Down Expand Up @@ -339,7 +342,8 @@ def test_set_availability(discoverable_availability: Discoverable):


def test_set_availability_wrong_config(discoverable: Discoverable):
"""A discoverable that has not set availability to manual cannot invoke the methods"""
"""A discoverable that has not set availability to manual cannot invoke the \
methods"""
with pytest.raises(RuntimeError):
discoverable.set_availability(True)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
import pytest

from ha_mqtt_discoverable import Settings
from ha_mqtt_discoverable.sensors import Sensor, SensorInfo

Expand All @@ -38,7 +39,8 @@ def test_generate_config(sensor: Sensor):
config = sensor.generate_config()

assert config is not None
# If we have defined a custom unit of measurement, check that is part of the output config
# If we have defined a custom unit of measurement, check that is part of the
# output config
if sensor._entity.unit_of_measurement:
assert config["unit_of_measurement"] == sensor._entity.unit_of_measurement

Expand Down

0 comments on commit 9901a43

Please sign in to comment.