mirror of
https://github.com/OMGeeky/homecontrol.esp-sensors.git
synced 2026-01-21 01:51:47 +01:00
refactor: move MQTT client implementation to a separate module for better organization
This commit is contained in:
@@ -26,7 +26,7 @@ The implementation provides the following features:
|
||||
|
||||
### MQTTClient
|
||||
|
||||
The `MQTTClient` class is the core implementation of the MQTT protocol. It handles the low-level details of the MQTT protocol, including packet formatting, socket communication, and protocol state management.
|
||||
The `MQTTClient` class is the core implementation of the MQTT protocol. It handles the low-level details of the MQTT protocol, including packet formatting, socket communication, and protocol state management. This class is now in its own module `mqtt_client.py` for better testability and separation of concerns.
|
||||
|
||||
#### Methods
|
||||
|
||||
@@ -105,7 +105,7 @@ if client.connect():
|
||||
For more control over the MQTT protocol, you can use the `MQTTClient` class directly:
|
||||
|
||||
```python
|
||||
from esp_sensors.mqtt import MQTTClient
|
||||
from esp_sensors.mqtt_client import MQTTClient
|
||||
import time
|
||||
|
||||
# Create client
|
||||
|
||||
@@ -4,363 +4,19 @@ MQTT module for ESP sensors.
|
||||
This module provides functionality to connect to an MQTT broker and publish sensor data.
|
||||
It supports both real hardware and simulation mode.
|
||||
|
||||
This is a custom implementation of the MQTT protocol from scratch, without relying on umqtt.
|
||||
This module uses the MQTTClient class from mqtt_client.py for the core MQTT implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import socket
|
||||
import struct
|
||||
|
||||
# MQTT Protocol Constants
|
||||
MQTT_PROTOCOL_LEVEL = 4 # MQTT 3.1.1
|
||||
MQTT_CLEAN_SESSION = 1
|
||||
|
||||
# MQTT Control Packet Types
|
||||
CONNECT = 0x10
|
||||
CONNACK = 0x20
|
||||
PUBLISH = 0x30
|
||||
PUBACK = 0x40
|
||||
SUBSCRIBE = 0x80
|
||||
SUBACK = 0x90
|
||||
UNSUBSCRIBE = 0xA0
|
||||
UNSUBACK = 0xB0
|
||||
PINGREQ = 0xC0
|
||||
PINGRESP = 0xD0
|
||||
DISCONNECT = 0xE0
|
||||
|
||||
# MQTT Connection Return Codes
|
||||
CONN_ACCEPTED = 0
|
||||
CONN_REFUSED_PROTOCOL = 1
|
||||
CONN_REFUSED_IDENTIFIER = 2
|
||||
CONN_REFUSED_SERVER = 3
|
||||
CONN_REFUSED_USER_PASS = 4
|
||||
CONN_REFUSED_AUTH = 5
|
||||
|
||||
class MQTTException(Exception):
|
||||
"""MQTT Exception class for handling MQTT-specific errors"""
|
||||
pass
|
||||
|
||||
class MQTTClient:
|
||||
"""
|
||||
A basic MQTT client implementation from scratch.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
client_id,
|
||||
server,
|
||||
port=1883,
|
||||
user=None,
|
||||
password=None,
|
||||
keepalive=60,
|
||||
ssl=False,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.server = server
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.keepalive = keepalive
|
||||
self.ssl = ssl
|
||||
self.sock: socket.socket | None = None
|
||||
self.connected = False
|
||||
self.callback = None
|
||||
self.pid = 0 # Packet ID for message tracking
|
||||
self.subscriptions = {} # Track subscribed topics
|
||||
self.last_ping = 0
|
||||
|
||||
def _generate_packet_id(self):
|
||||
"""Generate a unique packet ID for MQTT messages"""
|
||||
self.pid = (self.pid + 1) % 65536
|
||||
return self.pid
|
||||
|
||||
def _encode_length(self, length):
|
||||
"""Encode the remaining length field in the MQTT packet"""
|
||||
result = bytearray()
|
||||
while True:
|
||||
byte = length % 128
|
||||
length = length // 128
|
||||
if length > 0:
|
||||
byte |= 0x80
|
||||
result.append(byte)
|
||||
if length == 0:
|
||||
break
|
||||
return result
|
||||
|
||||
def _encode_string(self, string):
|
||||
"""Encode a string for MQTT packet"""
|
||||
if isinstance(string, str):
|
||||
string = string.encode('utf-8')
|
||||
return bytearray(struct.pack("!H", len(string)) + string)
|
||||
|
||||
def _send_packet(self, packet_type, payload=b''):
|
||||
"""Send an MQTT packet to the broker"""
|
||||
if self.sock is None:
|
||||
raise MQTTException("Not connected to broker (_send_packet)")
|
||||
|
||||
# Construct the packet
|
||||
packet = bytearray()
|
||||
packet.append(packet_type)
|
||||
|
||||
# Add remaining length
|
||||
packet.extend(self._encode_length(len(payload)))
|
||||
|
||||
# Add payload
|
||||
if payload:
|
||||
packet.extend(payload)
|
||||
|
||||
# Send the packet
|
||||
try:
|
||||
self.sock.send(packet)
|
||||
except Exception as e:
|
||||
self.connected = False
|
||||
raise MQTTException(f"Failed to send packet: {e}")
|
||||
|
||||
def _recv_packet(self, timeout=1.0):
|
||||
"""Receive an MQTT packet from the broker"""
|
||||
|
||||
if self.sock is None:
|
||||
raise MQTTException("Not connected to broker (_recv_packet)")
|
||||
|
||||
# Set socket timeout
|
||||
self.sock.settimeout(timeout)
|
||||
|
||||
try:
|
||||
# Read packet type
|
||||
packet_type = self.sock.recv(1)
|
||||
if not packet_type:
|
||||
return None, None
|
||||
|
||||
# Read remaining length
|
||||
remaining_length = 0
|
||||
multiplier = 1
|
||||
while True:
|
||||
byte = self.sock.recv(1)[0]
|
||||
remaining_length += (byte & 0x7F) * multiplier
|
||||
multiplier *= 128
|
||||
if not (byte & 0x80):
|
||||
break
|
||||
|
||||
# Read the payload
|
||||
payload = self.sock.recv(remaining_length) if remaining_length else b''
|
||||
|
||||
return packet_type[0], payload
|
||||
except socket.timeout:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.connected = False
|
||||
raise MQTTException(f"Failed to receive packet: {e}")
|
||||
|
||||
def connect(self):
|
||||
"""Connect to the MQTT broker"""
|
||||
|
||||
# Create socket
|
||||
try:
|
||||
self.sock = socket.socket()
|
||||
print(f"[MQTT] Connecting to Socket {self.server}:{self.port} as {self.client_id}")
|
||||
self.sock.connect((self.server, self.port))
|
||||
print(f"[MQTT] Connected to {self.server}:{self.port}")
|
||||
except Exception as e:
|
||||
print(f"Error connecting to MQTT broker: {e}")
|
||||
raise MQTTException(f"Failed to connect to {self.server}:{self.port}: {e}")
|
||||
|
||||
# Construct CONNECT packet
|
||||
payload = bytearray()
|
||||
|
||||
# Protocol name and level
|
||||
payload.extend(self._encode_string("MQTT"))
|
||||
payload.append(MQTT_PROTOCOL_LEVEL)
|
||||
|
||||
# Connect flags
|
||||
connect_flags = 0
|
||||
if self.user:
|
||||
connect_flags |= 0x80
|
||||
if self.password:
|
||||
connect_flags |= 0x40
|
||||
connect_flags |= MQTT_CLEAN_SESSION << 1
|
||||
payload.append(connect_flags)
|
||||
|
||||
# Keepalive (in seconds)
|
||||
payload.extend(struct.pack("!H", self.keepalive))
|
||||
|
||||
# Client ID
|
||||
payload.extend(self._encode_string(self.client_id))
|
||||
|
||||
# Username and password if provided
|
||||
if self.user:
|
||||
payload.extend(self._encode_string(self.user))
|
||||
if self.password:
|
||||
payload.extend(self._encode_string(self.password))
|
||||
|
||||
# Send CONNECT packet
|
||||
self._send_packet(CONNECT, payload)
|
||||
|
||||
# Wait for CONNACK
|
||||
packet_type, payload = self._recv_packet()
|
||||
if packet_type != CONNACK:
|
||||
raise MQTTException(f"Unexpected response from broker: {packet_type}")
|
||||
|
||||
# Check connection result
|
||||
if len(payload) != 2:
|
||||
raise MQTTException("Invalid CONNACK packet")
|
||||
|
||||
if payload[1] != CONN_ACCEPTED:
|
||||
raise MQTTException(f"Connection refused: {payload[1]}")
|
||||
|
||||
self.connected = True
|
||||
self.last_ping = time.time()
|
||||
return 0
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from the MQTT broker"""
|
||||
|
||||
if self.connected:
|
||||
try:
|
||||
self._send_packet(DISCONNECT)
|
||||
self.sock.close()
|
||||
except Exception as e:
|
||||
print(f"Error during disconnect: {e}")
|
||||
finally:
|
||||
self.connected = False
|
||||
self.sock = None
|
||||
|
||||
def ping(self):
|
||||
"""Send PINGREQ to keep the connection alive"""
|
||||
|
||||
if self.connected:
|
||||
self._send_packet(PINGREQ)
|
||||
packet_type, _ = self._recv_packet()
|
||||
if packet_type != PINGRESP:
|
||||
self.connected = False
|
||||
raise MQTTException("No PINGRESP received")
|
||||
self.last_ping = time.time()
|
||||
|
||||
def publish(self, topic, msg, retain=False, qos=0):
|
||||
"""Publish a message to a topic"""
|
||||
|
||||
if not self.connected:
|
||||
raise MQTTException("Not connected to broker (publish)")
|
||||
|
||||
# Check if we need to ping to keep connection alive
|
||||
if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive:
|
||||
self.ping()
|
||||
|
||||
# Convert topic and message to bytes if they're not already
|
||||
if isinstance(topic, str):
|
||||
topic = topic.encode('utf-8')
|
||||
if isinstance(msg, str):
|
||||
msg = msg.encode('utf-8')
|
||||
|
||||
# Construct PUBLISH packet
|
||||
packet_type = PUBLISH
|
||||
if retain:
|
||||
packet_type |= 0x01
|
||||
if qos:
|
||||
packet_type |= (qos << 1)
|
||||
|
||||
# Payload: topic + message
|
||||
payload = self._encode_string(topic)
|
||||
|
||||
# Add packet ID for QoS > 0
|
||||
if qos > 0:
|
||||
pid = self._generate_packet_id()
|
||||
payload.extend(struct.pack("!H", pid))
|
||||
|
||||
payload.extend(msg)
|
||||
|
||||
# Send PUBLISH packet
|
||||
self._send_packet(packet_type, payload)
|
||||
|
||||
# For QoS 1, wait for PUBACK
|
||||
if qos == 1:
|
||||
packet_type, _ = self._recv_packet()
|
||||
if packet_type != PUBACK:
|
||||
raise MQTTException(f"No PUBACK received: {packet_type}")
|
||||
|
||||
return
|
||||
|
||||
def subscribe(self, topic, qos=0):
|
||||
"""Subscribe to a topic"""
|
||||
|
||||
if not self.connected:
|
||||
raise MQTTException("Not connected to broker (subscribe)")
|
||||
|
||||
# Check if we need to ping to keep connection alive
|
||||
if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive:
|
||||
self.ping()
|
||||
|
||||
# Convert topic to bytes if it's not already
|
||||
if isinstance(topic, str):
|
||||
topic = topic.encode('utf-8')
|
||||
|
||||
# Generate packet ID
|
||||
pid = self._generate_packet_id()
|
||||
|
||||
# Construct SUBSCRIBE packet
|
||||
payload = bytearray(struct.pack("!H", pid))
|
||||
payload.extend(self._encode_string(topic))
|
||||
payload.append(qos)
|
||||
|
||||
# Send SUBSCRIBE packet
|
||||
self._send_packet(SUBSCRIBE | 0x02, payload)
|
||||
|
||||
# Wait for SUBACK
|
||||
packet_type, payload = self._recv_packet()
|
||||
if packet_type != SUBACK:
|
||||
raise MQTTException(f"No SUBACK received: {packet_type}")
|
||||
|
||||
# Store subscription
|
||||
topic_str = topic.decode('utf-8') if isinstance(topic, bytes) else topic
|
||||
self.subscriptions[topic_str] = qos
|
||||
|
||||
return
|
||||
|
||||
def set_callback(self, callback):
|
||||
"""Set callback for received messages"""
|
||||
self.callback = callback
|
||||
|
||||
def check_msg(self):
|
||||
"""Check for pending messages from the broker"""
|
||||
|
||||
if not self.connected:
|
||||
return
|
||||
|
||||
# Check if we need to ping to keep connection alive
|
||||
if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive:
|
||||
self.ping()
|
||||
|
||||
# Try to receive a packet with a short timeout
|
||||
packet_type, payload = self._recv_packet(timeout=0.1)
|
||||
|
||||
if packet_type is None:
|
||||
return
|
||||
|
||||
if packet_type & 0xF0 == PUBLISH:
|
||||
# Extract flags
|
||||
dup = (packet_type & 0x08) >> 3
|
||||
qos = (packet_type & 0x06) >> 1
|
||||
retain = packet_type & 0x01
|
||||
|
||||
# Extract topic
|
||||
topic_len = struct.unpack("!H", payload[0:2])[0]
|
||||
topic = payload[2:2+topic_len]
|
||||
|
||||
# Extract packet ID for QoS > 0
|
||||
if qos > 0:
|
||||
pid = struct.unpack("!H", payload[2+topic_len:2+topic_len+2])[0]
|
||||
message = payload[2+topic_len+2:]
|
||||
|
||||
# Send PUBACK for QoS 1
|
||||
if qos == 1:
|
||||
self._send_packet(PUBACK, struct.pack("!H", pid))
|
||||
else:
|
||||
message = payload[2+topic_len:]
|
||||
|
||||
# Call the callback if set
|
||||
if self.callback:
|
||||
self.callback(topic, message)
|
||||
|
||||
return
|
||||
from .mqtt_client import (
|
||||
MQTTClient, MQTTException,
|
||||
CONNECT, CONNACK, PUBLISH, PUBACK, SUBSCRIBE, SUBACK,
|
||||
UNSUBSCRIBE, UNSUBACK, PINGREQ, PINGRESP, DISCONNECT,
|
||||
CONN_ACCEPTED, CONN_REFUSED_PROTOCOL, CONN_REFUSED_IDENTIFIER,
|
||||
CONN_REFUSED_SERVER, CONN_REFUSED_USER_PASS, CONN_REFUSED_AUTH,
|
||||
MQTT_PROTOCOL_LEVEL, MQTT_CLEAN_SESSION
|
||||
)
|
||||
|
||||
|
||||
class ESP32MQTTClient:
|
||||
|
||||
475
src/esp_sensors/mqtt_client.py
Normal file
475
src/esp_sensors/mqtt_client.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""
|
||||
MQTT Client module for ESP sensors.
|
||||
|
||||
This module provides a basic MQTT client implementation from scratch,
|
||||
without relying on umqtt. It handles the low-level details of the MQTT protocol
|
||||
and provides a simple interface for connecting to an MQTT broker, publishing
|
||||
messages, and subscribing to topics.
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import socket
|
||||
import struct
|
||||
|
||||
# MQTT Protocol Constants
|
||||
MQTT_PROTOCOL_LEVEL = 4 # MQTT 3.1.1
|
||||
MQTT_CLEAN_SESSION = 1
|
||||
|
||||
# MQTT Control Packet Types
|
||||
CONNECT = 0x10
|
||||
CONNACK = 0x20
|
||||
PUBLISH = 0x30
|
||||
PUBACK = 0x40
|
||||
SUBSCRIBE = 0x80
|
||||
SUBACK = 0x90
|
||||
UNSUBSCRIBE = 0xA0
|
||||
UNSUBACK = 0xB0
|
||||
PINGREQ = 0xC0
|
||||
PINGRESP = 0xD0
|
||||
DISCONNECT = 0xE0
|
||||
|
||||
# MQTT Connection Return Codes
|
||||
CONN_ACCEPTED = 0
|
||||
CONN_REFUSED_PROTOCOL = 1
|
||||
CONN_REFUSED_IDENTIFIER = 2
|
||||
CONN_REFUSED_SERVER = 3
|
||||
CONN_REFUSED_USER_PASS = 4
|
||||
CONN_REFUSED_AUTH = 5
|
||||
|
||||
class MQTTException(Exception):
|
||||
"""MQTT Exception class for handling MQTT-specific errors"""
|
||||
pass
|
||||
|
||||
class MQTTClient:
|
||||
"""
|
||||
A basic MQTT client implementation from scratch.
|
||||
|
||||
This class implements the MQTT protocol directly using socket communication.
|
||||
It provides functionality for connecting to an MQTT broker, publishing messages,
|
||||
subscribing to topics, and receiving messages.
|
||||
|
||||
Attributes:
|
||||
client_id (str): Unique identifier for this client
|
||||
server (str): MQTT broker address
|
||||
port (int): MQTT broker port
|
||||
user (str): Username for authentication
|
||||
password (str): Password for authentication
|
||||
keepalive (int): Keepalive interval in seconds
|
||||
ssl (bool): Whether to use SSL/TLS
|
||||
sock (socket.socket): Socket connection to the broker
|
||||
connected (bool): Whether the client is connected to the broker
|
||||
callback (callable): Callback function for received messages
|
||||
pid (int): Packet ID for message tracking
|
||||
subscriptions (dict): Dictionary of subscribed topics
|
||||
last_ping (float): Timestamp of the last ping
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
client_id,
|
||||
server,
|
||||
port=1883,
|
||||
user=None,
|
||||
password=None,
|
||||
keepalive=60,
|
||||
ssl=False,
|
||||
):
|
||||
"""
|
||||
Initialize the MQTT client.
|
||||
|
||||
Args:
|
||||
client_id (str): Unique identifier for this client
|
||||
server (str): MQTT broker address
|
||||
port (int): MQTT broker port
|
||||
user (str): Username for authentication
|
||||
password (str): Password for authentication
|
||||
keepalive (int): Keepalive interval in seconds
|
||||
ssl (bool): Whether to use SSL/TLS
|
||||
"""
|
||||
self.client_id = client_id
|
||||
self.server = server
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.keepalive = keepalive
|
||||
self.ssl = ssl
|
||||
self.sock: socket.socket | None = None
|
||||
self.connected = False
|
||||
self.callback = None
|
||||
self.pid = 0 # Packet ID for message tracking
|
||||
self.subscriptions = {} # Track subscribed topics
|
||||
self.last_ping = 0
|
||||
|
||||
def _generate_packet_id(self):
|
||||
"""
|
||||
Generate a unique packet ID for MQTT messages.
|
||||
|
||||
Returns:
|
||||
int: A unique packet ID between 1 and 65535
|
||||
"""
|
||||
self.pid = (self.pid + 1) % 65536
|
||||
return self.pid
|
||||
|
||||
def _encode_length(self, length):
|
||||
"""
|
||||
Encode the remaining length field in the MQTT packet.
|
||||
|
||||
Args:
|
||||
length (int): The length to encode
|
||||
|
||||
Returns:
|
||||
bytearray: The encoded length
|
||||
"""
|
||||
result = bytearray()
|
||||
while True:
|
||||
byte = length % 128
|
||||
length = length // 128
|
||||
if length > 0:
|
||||
byte |= 0x80
|
||||
result.append(byte)
|
||||
if length == 0:
|
||||
break
|
||||
return result
|
||||
|
||||
def _encode_string(self, string):
|
||||
"""
|
||||
Encode a string for MQTT packet.
|
||||
|
||||
Args:
|
||||
string (str or bytes): The string to encode
|
||||
|
||||
Returns:
|
||||
bytearray: The encoded string
|
||||
"""
|
||||
if isinstance(string, str):
|
||||
string = string.encode('utf-8')
|
||||
return bytearray(struct.pack("!H", len(string)) + string)
|
||||
|
||||
def _send_packet(self, packet_type, payload=b''):
|
||||
"""
|
||||
Send an MQTT packet to the broker.
|
||||
|
||||
Args:
|
||||
packet_type (int): The MQTT packet type
|
||||
payload (bytes): The packet payload
|
||||
|
||||
Raises:
|
||||
MQTTException: If the client is not connected or sending fails
|
||||
"""
|
||||
if self.sock is None:
|
||||
raise MQTTException("Not connected to broker (_send_packet)")
|
||||
|
||||
# Construct the packet
|
||||
packet = bytearray()
|
||||
packet.append(packet_type)
|
||||
|
||||
# Add remaining length
|
||||
packet.extend(self._encode_length(len(payload)))
|
||||
|
||||
# Add payload
|
||||
if payload:
|
||||
packet.extend(payload)
|
||||
|
||||
# Send the packet
|
||||
try:
|
||||
self.sock.send(packet)
|
||||
except Exception as e:
|
||||
self.connected = False
|
||||
raise MQTTException(f"Failed to send packet: {e}")
|
||||
|
||||
def _recv_packet(self, timeout=1.0):
|
||||
"""
|
||||
Receive an MQTT packet from the broker.
|
||||
|
||||
Args:
|
||||
timeout (float): Socket timeout in seconds
|
||||
|
||||
Returns:
|
||||
tuple: (packet_type, payload) or (None, None) if no packet received
|
||||
|
||||
Raises:
|
||||
MQTTException: If the client is not connected or receiving fails
|
||||
"""
|
||||
if self.sock is None:
|
||||
raise MQTTException("Not connected to broker (_recv_packet)")
|
||||
|
||||
# Set socket timeout
|
||||
self.sock.settimeout(timeout)
|
||||
|
||||
try:
|
||||
# Read packet type
|
||||
packet_type = self.sock.recv(1)
|
||||
if not packet_type:
|
||||
return None, None
|
||||
|
||||
# Read remaining length
|
||||
remaining_length = 0
|
||||
multiplier = 1
|
||||
while True:
|
||||
byte = self.sock.recv(1)[0]
|
||||
remaining_length += (byte & 0x7F) * multiplier
|
||||
multiplier *= 128
|
||||
if not (byte & 0x80):
|
||||
break
|
||||
|
||||
# Read the payload
|
||||
payload = self.sock.recv(remaining_length) if remaining_length else b''
|
||||
|
||||
return packet_type[0], payload
|
||||
except socket.timeout:
|
||||
return None, None
|
||||
except Exception as e:
|
||||
self.connected = False
|
||||
raise MQTTException(f"Failed to receive packet: {e}")
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Connect to the MQTT broker.
|
||||
|
||||
Returns:
|
||||
int: 0 if successful, otherwise an error code
|
||||
|
||||
Raises:
|
||||
MQTTException: If connection fails
|
||||
"""
|
||||
# Create socket
|
||||
try:
|
||||
self.sock = socket.socket()
|
||||
print(f"[MQTT] Connecting to Socket {self.server}:{self.port} as {self.client_id}")
|
||||
self.sock.connect((self.server, self.port))
|
||||
print(f"[MQTT] Connected to {self.server}:{self.port}")
|
||||
except Exception as e:
|
||||
print(f"Error connecting to MQTT broker: {e}")
|
||||
raise MQTTException(f"Failed to connect to {self.server}:{self.port}: {e}")
|
||||
|
||||
# Construct CONNECT packet
|
||||
payload = bytearray()
|
||||
|
||||
# Protocol name and level
|
||||
payload.extend(self._encode_string("MQTT"))
|
||||
payload.append(MQTT_PROTOCOL_LEVEL)
|
||||
|
||||
# Connect flags
|
||||
connect_flags = 0
|
||||
if self.user:
|
||||
connect_flags |= 0x80
|
||||
if self.password:
|
||||
connect_flags |= 0x40
|
||||
connect_flags |= MQTT_CLEAN_SESSION << 1
|
||||
payload.append(connect_flags)
|
||||
|
||||
# Keepalive (in seconds)
|
||||
payload.extend(struct.pack("!H", self.keepalive))
|
||||
|
||||
# Client ID
|
||||
payload.extend(self._encode_string(self.client_id))
|
||||
|
||||
# Username and password if provided
|
||||
if self.user:
|
||||
payload.extend(self._encode_string(self.user))
|
||||
if self.password:
|
||||
payload.extend(self._encode_string(self.password))
|
||||
|
||||
# Send CONNECT packet
|
||||
self._send_packet(CONNECT, payload)
|
||||
|
||||
# Wait for CONNACK
|
||||
packet_type, payload = self._recv_packet()
|
||||
if packet_type != CONNACK:
|
||||
raise MQTTException(f"Unexpected response from broker: {packet_type}")
|
||||
|
||||
# Check connection result
|
||||
if len(payload) != 2:
|
||||
raise MQTTException("Invalid CONNACK packet")
|
||||
|
||||
if payload[1] != CONN_ACCEPTED:
|
||||
raise MQTTException(f"Connection refused: {payload[1]}")
|
||||
|
||||
self.connected = True
|
||||
self.last_ping = time.time()
|
||||
return 0
|
||||
|
||||
def disconnect(self):
|
||||
"""
|
||||
Disconnect from the MQTT broker.
|
||||
"""
|
||||
if self.connected:
|
||||
try:
|
||||
self._send_packet(DISCONNECT)
|
||||
self.sock.close()
|
||||
except Exception as e:
|
||||
print(f"Error during disconnect: {e}")
|
||||
finally:
|
||||
self.connected = False
|
||||
self.sock = None
|
||||
|
||||
def ping(self):
|
||||
"""
|
||||
Send PINGREQ to keep the connection alive.
|
||||
|
||||
Raises:
|
||||
MQTTException: If no PINGRESP is received
|
||||
"""
|
||||
if self.connected:
|
||||
self._send_packet(PINGREQ)
|
||||
packet_type, _ = self._recv_packet()
|
||||
if packet_type != PINGRESP:
|
||||
self.connected = False
|
||||
raise MQTTException("No PINGRESP received")
|
||||
self.last_ping = time.time()
|
||||
|
||||
def publish(self, topic, msg, retain=False, qos=0):
|
||||
"""
|
||||
Publish a message to a topic.
|
||||
|
||||
Args:
|
||||
topic (str or bytes): The topic to publish to
|
||||
msg (str or bytes): The message to publish
|
||||
retain (bool): Whether the message should be retained by the broker
|
||||
qos (int): Quality of Service level (0 or 1)
|
||||
|
||||
Raises:
|
||||
MQTTException: If the client is not connected or publishing fails
|
||||
"""
|
||||
if not self.connected:
|
||||
raise MQTTException("Not connected to broker (publish)")
|
||||
|
||||
# Check if we need to ping to keep connection alive
|
||||
if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive:
|
||||
self.ping()
|
||||
|
||||
# Convert topic and message to bytes if they're not already
|
||||
if isinstance(topic, str):
|
||||
topic = topic.encode('utf-8')
|
||||
if isinstance(msg, str):
|
||||
msg = msg.encode('utf-8')
|
||||
|
||||
# Construct PUBLISH packet
|
||||
packet_type = PUBLISH
|
||||
if retain:
|
||||
packet_type |= 0x01
|
||||
if qos:
|
||||
packet_type |= (qos << 1)
|
||||
|
||||
# Payload: topic + message
|
||||
payload = self._encode_string(topic)
|
||||
|
||||
# Add packet ID for QoS > 0
|
||||
if qos > 0:
|
||||
pid = self._generate_packet_id()
|
||||
payload.extend(struct.pack("!H", pid))
|
||||
|
||||
payload.extend(msg)
|
||||
|
||||
# Send PUBLISH packet
|
||||
self._send_packet(packet_type, payload)
|
||||
|
||||
# For QoS 1, wait for PUBACK
|
||||
if qos == 1:
|
||||
packet_type, _ = self._recv_packet()
|
||||
if packet_type != PUBACK:
|
||||
raise MQTTException(f"No PUBACK received: {packet_type}")
|
||||
|
||||
return
|
||||
|
||||
def subscribe(self, topic, qos=0):
|
||||
"""
|
||||
Subscribe to a topic.
|
||||
|
||||
Args:
|
||||
topic (str or bytes): The topic to subscribe to
|
||||
qos (int): Quality of Service level
|
||||
|
||||
Raises:
|
||||
MQTTException: If the client is not connected or subscription fails
|
||||
"""
|
||||
if not self.connected:
|
||||
raise MQTTException("Not connected to broker (subscribe)")
|
||||
|
||||
# Check if we need to ping to keep connection alive
|
||||
if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive:
|
||||
self.ping()
|
||||
|
||||
# Convert topic to bytes if it's not already
|
||||
if isinstance(topic, str):
|
||||
topic = topic.encode('utf-8')
|
||||
|
||||
# Generate packet ID
|
||||
pid = self._generate_packet_id()
|
||||
|
||||
# Construct SUBSCRIBE packet
|
||||
payload = bytearray(struct.pack("!H", pid))
|
||||
payload.extend(self._encode_string(topic))
|
||||
payload.append(qos)
|
||||
|
||||
# Send SUBSCRIBE packet
|
||||
self._send_packet(SUBSCRIBE | 0x02, payload)
|
||||
|
||||
# Wait for SUBACK
|
||||
packet_type, payload = self._recv_packet()
|
||||
if packet_type != SUBACK:
|
||||
raise MQTTException(f"No SUBACK received: {packet_type}")
|
||||
|
||||
# Store subscription
|
||||
topic_str = topic.decode('utf-8') if isinstance(topic, bytes) else topic
|
||||
self.subscriptions[topic_str] = qos
|
||||
|
||||
return
|
||||
|
||||
def set_callback(self, callback):
|
||||
"""
|
||||
Set callback for received messages.
|
||||
|
||||
Args:
|
||||
callback (callable): Function to call when a message is received.
|
||||
The callback should accept two parameters:
|
||||
topic (bytes) and message (bytes).
|
||||
"""
|
||||
self.callback = callback
|
||||
|
||||
def check_msg(self):
|
||||
"""
|
||||
Check for pending messages from the broker.
|
||||
|
||||
This method should be called regularly to process incoming messages.
|
||||
If a callback is set, it will be called with the topic and message.
|
||||
"""
|
||||
if not self.connected:
|
||||
return
|
||||
|
||||
# Check if we need to ping to keep connection alive
|
||||
if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive:
|
||||
self.ping()
|
||||
|
||||
# Try to receive a packet with a short timeout
|
||||
packet_type, payload = self._recv_packet(timeout=0.1)
|
||||
|
||||
if packet_type is None:
|
||||
return
|
||||
|
||||
if packet_type & 0xF0 == PUBLISH:
|
||||
# Extract flags
|
||||
dup = (packet_type & 0x08) >> 3
|
||||
qos = (packet_type & 0x06) >> 1
|
||||
retain = packet_type & 0x01
|
||||
|
||||
# Extract topic
|
||||
topic_len = struct.unpack("!H", payload[0:2])[0]
|
||||
topic = payload[2:2+topic_len]
|
||||
|
||||
# Extract packet ID for QoS > 0
|
||||
if qos > 0:
|
||||
pid = struct.unpack("!H", payload[2+topic_len:2+topic_len+2])[0]
|
||||
message = payload[2+topic_len+2:]
|
||||
|
||||
# Send PUBACK for QoS 1
|
||||
if qos == 1:
|
||||
self._send_packet(PUBACK, struct.pack("!H", pid))
|
||||
else:
|
||||
message = payload[2+topic_len:]
|
||||
|
||||
# Call the callback if set
|
||||
if self.callback:
|
||||
self.callback(topic, message)
|
||||
|
||||
return
|
||||
@@ -5,7 +5,8 @@ Tests for the MQTT module.
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.esp_sensors.mqtt import setup_mqtt, publish_sensor_data, MQTTClient
|
||||
from src.esp_sensors.mqtt_client import MQTTClient
|
||||
from src.esp_sensors.mqtt import setup_mqtt, publish_sensor_data
|
||||
|
||||
|
||||
class TestSensor:
|
||||
|
||||
266
tests/test_mqtt_client.py
Normal file
266
tests/test_mqtt_client.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Tests for the MQTT Client module.
|
||||
|
||||
This module contains tests for the MQTTClient class in the mqtt_client.py module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import socket
|
||||
import struct
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
from src.esp_sensors.mqtt_client import (
|
||||
MQTTClient, MQTTException,
|
||||
CONNECT, CONNACK, PUBLISH, PUBACK, SUBSCRIBE, SUBACK,
|
||||
PINGREQ, PINGRESP, DISCONNECT,
|
||||
CONN_ACCEPTED
|
||||
)
|
||||
|
||||
|
||||
class TestMQTTClient:
|
||||
"""Tests for the MQTTClient class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mqtt_client(self):
|
||||
"""Fixture providing a basic MQTTClient instance."""
|
||||
return MQTTClient(
|
||||
client_id="test_client",
|
||||
server="test.mosquitto.org",
|
||||
port=1883,
|
||||
user="test_user",
|
||||
password="test_pass",
|
||||
keepalive=60,
|
||||
ssl=False
|
||||
)
|
||||
|
||||
def test_init(self, mqtt_client):
|
||||
"""Test that the MQTTClient initializes with the correct attributes."""
|
||||
assert mqtt_client.client_id == "test_client"
|
||||
assert mqtt_client.server == "test.mosquitto.org"
|
||||
assert mqtt_client.port == 1883
|
||||
assert mqtt_client.user == "test_user"
|
||||
assert mqtt_client.password == "test_pass"
|
||||
assert mqtt_client.keepalive == 60
|
||||
assert mqtt_client.ssl is False
|
||||
assert mqtt_client.sock is None
|
||||
assert mqtt_client.connected is False
|
||||
assert mqtt_client.callback is None
|
||||
assert mqtt_client.pid == 0
|
||||
assert mqtt_client.subscriptions == {}
|
||||
assert mqtt_client.last_ping == 0
|
||||
|
||||
def test_generate_packet_id(self, mqtt_client):
|
||||
"""Test that _generate_packet_id returns sequential IDs and wraps around."""
|
||||
# First call should return 1
|
||||
assert mqtt_client._generate_packet_id() == 1
|
||||
assert mqtt_client.pid == 1
|
||||
|
||||
# Second call should return 2
|
||||
assert mqtt_client._generate_packet_id() == 2
|
||||
assert mqtt_client.pid == 2
|
||||
|
||||
# Set pid to 65535 (max value)
|
||||
mqtt_client.pid = 65535
|
||||
|
||||
# Next call should wrap around to 0
|
||||
assert mqtt_client._generate_packet_id() == 0
|
||||
assert mqtt_client.pid == 0
|
||||
|
||||
def test_encode_length(self, mqtt_client):
|
||||
"""Test that _encode_length correctly encodes MQTT remaining length."""
|
||||
# Test small length (< 128)
|
||||
assert list(mqtt_client._encode_length(64)) == [64]
|
||||
|
||||
# Test medium length (128-16383)
|
||||
assert list(mqtt_client._encode_length(128)) == [128 & 0x7F | 0x80, 1]
|
||||
assert list(mqtt_client._encode_length(8192)) == [0x80, 0x40]
|
||||
|
||||
# Test large length (16384-2097151)
|
||||
assert list(mqtt_client._encode_length(2097151)) == [0xFF, 0xFF, 0x7F]
|
||||
|
||||
def test_encode_string(self, mqtt_client):
|
||||
"""Test that _encode_string correctly encodes strings for MQTT packets."""
|
||||
# Test with string input
|
||||
result = mqtt_client._encode_string("test")
|
||||
assert len(result) == 6 # 2 bytes length + 4 bytes string
|
||||
assert result[0:2] == b'\x00\x04' # Length (4) in network byte order
|
||||
assert result[2:] == b'test' # String content
|
||||
|
||||
# Test with bytes input
|
||||
result = mqtt_client._encode_string(b"test")
|
||||
assert len(result) == 6
|
||||
assert result[0:2] == b'\x00\x04'
|
||||
assert result[2:] == b'test'
|
||||
|
||||
@patch('socket.socket')
|
||||
def test_connect_success(self, mock_socket, mqtt_client):
|
||||
"""Test successful connection to MQTT broker."""
|
||||
# Configure the mock socket
|
||||
mock_sock = MagicMock()
|
||||
mock_socket.return_value = mock_sock
|
||||
|
||||
# Configure the mock socket to return a successful CONNACK
|
||||
mock_sock.recv.side_effect = [
|
||||
b'\x20', # CONNACK packet type
|
||||
b'\x02', # Remaining length
|
||||
b'\x00\x00' # Session present flag (0) + return code (0 = accepted)
|
||||
]
|
||||
|
||||
# Call connect
|
||||
result = mqtt_client.connect()
|
||||
|
||||
# Verify socket was created and connected
|
||||
mock_socket.assert_called_once()
|
||||
mock_sock.connect.assert_called_once_with(("test.mosquitto.org", 1883))
|
||||
|
||||
# Verify CONNECT packet was sent
|
||||
mock_sock.send.assert_called_once()
|
||||
|
||||
# Verify result
|
||||
assert result == 0
|
||||
assert mqtt_client.connected is True
|
||||
assert mqtt_client.sock is mock_sock
|
||||
|
||||
@patch('socket.socket')
|
||||
def test_connect_failure(self, mock_socket, mqtt_client):
|
||||
"""Test connection failure to MQTT broker."""
|
||||
# Configure the mock socket
|
||||
mock_sock = MagicMock()
|
||||
mock_socket.return_value = mock_sock
|
||||
|
||||
# Configure the mock socket to return a failed CONNACK
|
||||
mock_sock.recv.side_effect = [
|
||||
b'\x20', # CONNACK packet type
|
||||
b'\x02', # Remaining length
|
||||
b'\x00\x01' # Session present flag (0) + return code (1 = refused, protocol version)
|
||||
]
|
||||
|
||||
# Call connect and verify it raises an exception
|
||||
with pytest.raises(MQTTException, match="Connection refused: 1"):
|
||||
mqtt_client.connect()
|
||||
|
||||
@patch('socket.socket')
|
||||
def test_disconnect(self, mock_socket, mqtt_client):
|
||||
"""Test disconnection from MQTT broker."""
|
||||
# Configure the mock socket
|
||||
mock_sock = MagicMock()
|
||||
mock_socket.return_value = mock_sock
|
||||
|
||||
# Set up the client as connected
|
||||
mqtt_client.sock = mock_sock
|
||||
mqtt_client.connected = True
|
||||
|
||||
# Call disconnect
|
||||
mqtt_client.disconnect()
|
||||
|
||||
# Verify DISCONNECT packet was sent
|
||||
mock_sock.send.assert_called_once()
|
||||
|
||||
# Verify socket was closed
|
||||
mock_sock.close.assert_called_once()
|
||||
|
||||
# Verify client state
|
||||
assert mqtt_client.connected is False
|
||||
assert mqtt_client.sock is None
|
||||
|
||||
@patch('socket.socket')
|
||||
def test_publish(self, mock_socket, mqtt_client):
|
||||
"""Test publishing a message to a topic."""
|
||||
# Configure the mock socket
|
||||
mock_sock = MagicMock()
|
||||
mock_socket.return_value = mock_sock
|
||||
|
||||
# Set up the client as connected
|
||||
mqtt_client.sock = mock_sock
|
||||
mqtt_client.connected = True
|
||||
mqtt_client.last_ping = 0
|
||||
|
||||
# Call publish
|
||||
mqtt_client.publish("test/topic", "test message")
|
||||
|
||||
# Verify PUBLISH packet was sent
|
||||
mock_sock.send.assert_called_once()
|
||||
|
||||
# Test with QoS 1
|
||||
mock_sock.reset_mock()
|
||||
mock_sock.recv.side_effect = [
|
||||
b'\x40', # PUBACK packet type
|
||||
b'\x02', # Remaining length
|
||||
b'\x00\x01' # Packet ID
|
||||
]
|
||||
|
||||
mqtt_client.publish("test/topic", "test message", qos=1)
|
||||
|
||||
# Verify PUBLISH packet was sent
|
||||
assert mock_sock.send.call_count == 1
|
||||
|
||||
@patch('socket.socket')
|
||||
def test_subscribe(self, mock_socket, mqtt_client):
|
||||
"""Test subscribing to a topic."""
|
||||
# Configure the mock socket
|
||||
mock_sock = MagicMock()
|
||||
mock_socket.return_value = mock_sock
|
||||
|
||||
# Set up the client as connected
|
||||
mqtt_client.sock = mock_sock
|
||||
mqtt_client.connected = True
|
||||
mqtt_client.last_ping = 0
|
||||
|
||||
# Configure the mock socket to return a successful SUBACK
|
||||
mock_sock.recv.side_effect = [
|
||||
b'\x90', # SUBACK packet type
|
||||
b'\x03', # Remaining length
|
||||
b'\x00\x01\x00' # Packet ID + return code (0 = success)
|
||||
]
|
||||
|
||||
# Call subscribe
|
||||
mqtt_client.subscribe("test/topic")
|
||||
|
||||
# Verify SUBSCRIBE packet was sent
|
||||
mock_sock.send.assert_called_once()
|
||||
|
||||
# Verify subscription was stored
|
||||
assert "test/topic" in mqtt_client.subscriptions
|
||||
assert mqtt_client.subscriptions["test/topic"] == 0
|
||||
|
||||
@patch('socket.socket')
|
||||
def test_check_msg(self, mock_socket, mqtt_client):
|
||||
"""Test checking for messages."""
|
||||
# Configure the mock socket
|
||||
mock_sock = MagicMock()
|
||||
mock_socket.return_value = mock_sock
|
||||
|
||||
# Set up the client as connected
|
||||
mqtt_client.sock = mock_sock
|
||||
mqtt_client.connected = True
|
||||
mqtt_client.last_ping = 0
|
||||
|
||||
# Set up a mock callback
|
||||
mock_callback = MagicMock()
|
||||
mqtt_client.set_callback(mock_callback)
|
||||
|
||||
# Configure the mock socket to return a PUBLISH packet
|
||||
topic = "test/topic"
|
||||
message = "test message"
|
||||
topic_encoded = struct.pack("!H", len(topic)) + topic.encode()
|
||||
mock_sock.recv.side_effect = [
|
||||
bytes([PUBLISH]), # PUBLISH packet type
|
||||
bytes([len(topic_encoded) + len(message)]), # Remaining length
|
||||
topic_encoded + message.encode() # Topic + message
|
||||
]
|
||||
|
||||
# Call check_msg
|
||||
mqtt_client.check_msg()
|
||||
|
||||
# Verify callback was called with correct parameters
|
||||
mock_callback.assert_called_once_with(topic.encode(), message.encode())
|
||||
|
||||
def test_set_callback(self, mqtt_client):
|
||||
"""Test setting a callback function."""
|
||||
# Create a mock callback
|
||||
mock_callback = MagicMock()
|
||||
|
||||
# Set the callback
|
||||
mqtt_client.set_callback(mock_callback)
|
||||
|
||||
# Verify callback was set
|
||||
assert mqtt_client.callback is mock_callback
|
||||
Reference in New Issue
Block a user