diff --git a/docs/mqtt_implementation.md b/docs/mqtt_implementation.md index 56ba2f6..0f68b6f 100644 --- a/docs/mqtt_implementation.md +++ b/docs/mqtt_implementation.md @@ -26,7 +26,7 @@ The implementation provides the following features: ### MQTTClient -The `MQTTClient` class is the core implementation of the MQTT protocol. It handles the low-level details of the MQTT protocol, including packet formatting, socket communication, and protocol state management. +The `MQTTClient` class is the core implementation of the MQTT protocol. It handles the low-level details of the MQTT protocol, including packet formatting, socket communication, and protocol state management. This class is now in its own module `mqtt_client.py` for better testability and separation of concerns. #### Methods @@ -105,7 +105,7 @@ if client.connect(): For more control over the MQTT protocol, you can use the `MQTTClient` class directly: ```python -from esp_sensors.mqtt import MQTTClient +from esp_sensors.mqtt_client import MQTTClient import time # Create client diff --git a/src/esp_sensors/mqtt.py b/src/esp_sensors/mqtt.py index 4c8e90f..19a1e96 100644 --- a/src/esp_sensors/mqtt.py +++ b/src/esp_sensors/mqtt.py @@ -4,363 +4,19 @@ MQTT module for ESP sensors. This module provides functionality to connect to an MQTT broker and publish sensor data. It supports both real hardware and simulation mode. -This is a custom implementation of the MQTT protocol from scratch, without relying on umqtt. +This module uses the MQTTClient class from mqtt_client.py for the core MQTT implementation. """ import time import json -import socket -import struct - -# MQTT Protocol Constants -MQTT_PROTOCOL_LEVEL = 4 # MQTT 3.1.1 -MQTT_CLEAN_SESSION = 1 - -# MQTT Control Packet Types -CONNECT = 0x10 -CONNACK = 0x20 -PUBLISH = 0x30 -PUBACK = 0x40 -SUBSCRIBE = 0x80 -SUBACK = 0x90 -UNSUBSCRIBE = 0xA0 -UNSUBACK = 0xB0 -PINGREQ = 0xC0 -PINGRESP = 0xD0 -DISCONNECT = 0xE0 - -# MQTT Connection Return Codes -CONN_ACCEPTED = 0 -CONN_REFUSED_PROTOCOL = 1 -CONN_REFUSED_IDENTIFIER = 2 -CONN_REFUSED_SERVER = 3 -CONN_REFUSED_USER_PASS = 4 -CONN_REFUSED_AUTH = 5 - -class MQTTException(Exception): - """MQTT Exception class for handling MQTT-specific errors""" - pass - -class MQTTClient: - """ - A basic MQTT client implementation from scratch. - """ - def __init__( - self, - client_id, - server, - port=1883, - user=None, - password=None, - keepalive=60, - ssl=False, - ): - self.client_id = client_id - self.server = server - self.port = port - self.user = user - self.password = password - self.keepalive = keepalive - self.ssl = ssl - self.sock: socket.socket | None = None - self.connected = False - self.callback = None - self.pid = 0 # Packet ID for message tracking - self.subscriptions = {} # Track subscribed topics - self.last_ping = 0 - - def _generate_packet_id(self): - """Generate a unique packet ID for MQTT messages""" - self.pid = (self.pid + 1) % 65536 - return self.pid - - def _encode_length(self, length): - """Encode the remaining length field in the MQTT packet""" - result = bytearray() - while True: - byte = length % 128 - length = length // 128 - if length > 0: - byte |= 0x80 - result.append(byte) - if length == 0: - break - return result - - def _encode_string(self, string): - """Encode a string for MQTT packet""" - if isinstance(string, str): - string = string.encode('utf-8') - return bytearray(struct.pack("!H", len(string)) + string) - - def _send_packet(self, packet_type, payload=b''): - """Send an MQTT packet to the broker""" - if self.sock is None: - raise MQTTException("Not connected to broker (_send_packet)") - - # Construct the packet - packet = bytearray() - packet.append(packet_type) - - # Add remaining length - packet.extend(self._encode_length(len(payload))) - - # Add payload - if payload: - packet.extend(payload) - - # Send the packet - try: - self.sock.send(packet) - except Exception as e: - self.connected = False - raise MQTTException(f"Failed to send packet: {e}") - - def _recv_packet(self, timeout=1.0): - """Receive an MQTT packet from the broker""" - - if self.sock is None: - raise MQTTException("Not connected to broker (_recv_packet)") - - # Set socket timeout - self.sock.settimeout(timeout) - - try: - # Read packet type - packet_type = self.sock.recv(1) - if not packet_type: - return None, None - - # Read remaining length - remaining_length = 0 - multiplier = 1 - while True: - byte = self.sock.recv(1)[0] - remaining_length += (byte & 0x7F) * multiplier - multiplier *= 128 - if not (byte & 0x80): - break - - # Read the payload - payload = self.sock.recv(remaining_length) if remaining_length else b'' - - return packet_type[0], payload - except socket.timeout: - return None, None - except Exception as e: - self.connected = False - raise MQTTException(f"Failed to receive packet: {e}") - - def connect(self): - """Connect to the MQTT broker""" - - # Create socket - try: - self.sock = socket.socket() - print(f"[MQTT] Connecting to Socket {self.server}:{self.port} as {self.client_id}") - self.sock.connect((self.server, self.port)) - print(f"[MQTT] Connected to {self.server}:{self.port}") - except Exception as e: - print(f"Error connecting to MQTT broker: {e}") - raise MQTTException(f"Failed to connect to {self.server}:{self.port}: {e}") - - # Construct CONNECT packet - payload = bytearray() - - # Protocol name and level - payload.extend(self._encode_string("MQTT")) - payload.append(MQTT_PROTOCOL_LEVEL) - - # Connect flags - connect_flags = 0 - if self.user: - connect_flags |= 0x80 - if self.password: - connect_flags |= 0x40 - connect_flags |= MQTT_CLEAN_SESSION << 1 - payload.append(connect_flags) - - # Keepalive (in seconds) - payload.extend(struct.pack("!H", self.keepalive)) - - # Client ID - payload.extend(self._encode_string(self.client_id)) - - # Username and password if provided - if self.user: - payload.extend(self._encode_string(self.user)) - if self.password: - payload.extend(self._encode_string(self.password)) - - # Send CONNECT packet - self._send_packet(CONNECT, payload) - - # Wait for CONNACK - packet_type, payload = self._recv_packet() - if packet_type != CONNACK: - raise MQTTException(f"Unexpected response from broker: {packet_type}") - - # Check connection result - if len(payload) != 2: - raise MQTTException("Invalid CONNACK packet") - - if payload[1] != CONN_ACCEPTED: - raise MQTTException(f"Connection refused: {payload[1]}") - - self.connected = True - self.last_ping = time.time() - return 0 - - def disconnect(self): - """Disconnect from the MQTT broker""" - - if self.connected: - try: - self._send_packet(DISCONNECT) - self.sock.close() - except Exception as e: - print(f"Error during disconnect: {e}") - finally: - self.connected = False - self.sock = None - - def ping(self): - """Send PINGREQ to keep the connection alive""" - - if self.connected: - self._send_packet(PINGREQ) - packet_type, _ = self._recv_packet() - if packet_type != PINGRESP: - self.connected = False - raise MQTTException("No PINGRESP received") - self.last_ping = time.time() - - def publish(self, topic, msg, retain=False, qos=0): - """Publish a message to a topic""" - - if not self.connected: - raise MQTTException("Not connected to broker (publish)") - - # Check if we need to ping to keep connection alive - if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive: - self.ping() - - # Convert topic and message to bytes if they're not already - if isinstance(topic, str): - topic = topic.encode('utf-8') - if isinstance(msg, str): - msg = msg.encode('utf-8') - - # Construct PUBLISH packet - packet_type = PUBLISH - if retain: - packet_type |= 0x01 - if qos: - packet_type |= (qos << 1) - - # Payload: topic + message - payload = self._encode_string(topic) - - # Add packet ID for QoS > 0 - if qos > 0: - pid = self._generate_packet_id() - payload.extend(struct.pack("!H", pid)) - - payload.extend(msg) - - # Send PUBLISH packet - self._send_packet(packet_type, payload) - - # For QoS 1, wait for PUBACK - if qos == 1: - packet_type, _ = self._recv_packet() - if packet_type != PUBACK: - raise MQTTException(f"No PUBACK received: {packet_type}") - - return - - def subscribe(self, topic, qos=0): - """Subscribe to a topic""" - - if not self.connected: - raise MQTTException("Not connected to broker (subscribe)") - - # Check if we need to ping to keep connection alive - if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive: - self.ping() - - # Convert topic to bytes if it's not already - if isinstance(topic, str): - topic = topic.encode('utf-8') - - # Generate packet ID - pid = self._generate_packet_id() - - # Construct SUBSCRIBE packet - payload = bytearray(struct.pack("!H", pid)) - payload.extend(self._encode_string(topic)) - payload.append(qos) - - # Send SUBSCRIBE packet - self._send_packet(SUBSCRIBE | 0x02, payload) - - # Wait for SUBACK - packet_type, payload = self._recv_packet() - if packet_type != SUBACK: - raise MQTTException(f"No SUBACK received: {packet_type}") - - # Store subscription - topic_str = topic.decode('utf-8') if isinstance(topic, bytes) else topic - self.subscriptions[topic_str] = qos - - return - - def set_callback(self, callback): - """Set callback for received messages""" - self.callback = callback - - def check_msg(self): - """Check for pending messages from the broker""" - - if not self.connected: - return - - # Check if we need to ping to keep connection alive - if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive: - self.ping() - - # Try to receive a packet with a short timeout - packet_type, payload = self._recv_packet(timeout=0.1) - - if packet_type is None: - return - - if packet_type & 0xF0 == PUBLISH: - # Extract flags - dup = (packet_type & 0x08) >> 3 - qos = (packet_type & 0x06) >> 1 - retain = packet_type & 0x01 - - # Extract topic - topic_len = struct.unpack("!H", payload[0:2])[0] - topic = payload[2:2+topic_len] - - # Extract packet ID for QoS > 0 - if qos > 0: - pid = struct.unpack("!H", payload[2+topic_len:2+topic_len+2])[0] - message = payload[2+topic_len+2:] - - # Send PUBACK for QoS 1 - if qos == 1: - self._send_packet(PUBACK, struct.pack("!H", pid)) - else: - message = payload[2+topic_len:] - - # Call the callback if set - if self.callback: - self.callback(topic, message) - - return +from .mqtt_client import ( + MQTTClient, MQTTException, + CONNECT, CONNACK, PUBLISH, PUBACK, SUBSCRIBE, SUBACK, + UNSUBSCRIBE, UNSUBACK, PINGREQ, PINGRESP, DISCONNECT, + CONN_ACCEPTED, CONN_REFUSED_PROTOCOL, CONN_REFUSED_IDENTIFIER, + CONN_REFUSED_SERVER, CONN_REFUSED_USER_PASS, CONN_REFUSED_AUTH, + MQTT_PROTOCOL_LEVEL, MQTT_CLEAN_SESSION +) class ESP32MQTTClient: diff --git a/src/esp_sensors/mqtt_client.py b/src/esp_sensors/mqtt_client.py new file mode 100644 index 0000000..2450e6a --- /dev/null +++ b/src/esp_sensors/mqtt_client.py @@ -0,0 +1,475 @@ +""" +MQTT Client module for ESP sensors. + +This module provides a basic MQTT client implementation from scratch, +without relying on umqtt. It handles the low-level details of the MQTT protocol +and provides a simple interface for connecting to an MQTT broker, publishing +messages, and subscribing to topics. +""" + +import time +import json +import socket +import struct + +# MQTT Protocol Constants +MQTT_PROTOCOL_LEVEL = 4 # MQTT 3.1.1 +MQTT_CLEAN_SESSION = 1 + +# MQTT Control Packet Types +CONNECT = 0x10 +CONNACK = 0x20 +PUBLISH = 0x30 +PUBACK = 0x40 +SUBSCRIBE = 0x80 +SUBACK = 0x90 +UNSUBSCRIBE = 0xA0 +UNSUBACK = 0xB0 +PINGREQ = 0xC0 +PINGRESP = 0xD0 +DISCONNECT = 0xE0 + +# MQTT Connection Return Codes +CONN_ACCEPTED = 0 +CONN_REFUSED_PROTOCOL = 1 +CONN_REFUSED_IDENTIFIER = 2 +CONN_REFUSED_SERVER = 3 +CONN_REFUSED_USER_PASS = 4 +CONN_REFUSED_AUTH = 5 + +class MQTTException(Exception): + """MQTT Exception class for handling MQTT-specific errors""" + pass + +class MQTTClient: + """ + A basic MQTT client implementation from scratch. + + This class implements the MQTT protocol directly using socket communication. + It provides functionality for connecting to an MQTT broker, publishing messages, + subscribing to topics, and receiving messages. + + Attributes: + client_id (str): Unique identifier for this client + server (str): MQTT broker address + port (int): MQTT broker port + user (str): Username for authentication + password (str): Password for authentication + keepalive (int): Keepalive interval in seconds + ssl (bool): Whether to use SSL/TLS + sock (socket.socket): Socket connection to the broker + connected (bool): Whether the client is connected to the broker + callback (callable): Callback function for received messages + pid (int): Packet ID for message tracking + subscriptions (dict): Dictionary of subscribed topics + last_ping (float): Timestamp of the last ping + """ + def __init__( + self, + client_id, + server, + port=1883, + user=None, + password=None, + keepalive=60, + ssl=False, + ): + """ + Initialize the MQTT client. + + Args: + client_id (str): Unique identifier for this client + server (str): MQTT broker address + port (int): MQTT broker port + user (str): Username for authentication + password (str): Password for authentication + keepalive (int): Keepalive interval in seconds + ssl (bool): Whether to use SSL/TLS + """ + self.client_id = client_id + self.server = server + self.port = port + self.user = user + self.password = password + self.keepalive = keepalive + self.ssl = ssl + self.sock: socket.socket | None = None + self.connected = False + self.callback = None + self.pid = 0 # Packet ID for message tracking + self.subscriptions = {} # Track subscribed topics + self.last_ping = 0 + + def _generate_packet_id(self): + """ + Generate a unique packet ID for MQTT messages. + + Returns: + int: A unique packet ID between 1 and 65535 + """ + self.pid = (self.pid + 1) % 65536 + return self.pid + + def _encode_length(self, length): + """ + Encode the remaining length field in the MQTT packet. + + Args: + length (int): The length to encode + + Returns: + bytearray: The encoded length + """ + result = bytearray() + while True: + byte = length % 128 + length = length // 128 + if length > 0: + byte |= 0x80 + result.append(byte) + if length == 0: + break + return result + + def _encode_string(self, string): + """ + Encode a string for MQTT packet. + + Args: + string (str or bytes): The string to encode + + Returns: + bytearray: The encoded string + """ + if isinstance(string, str): + string = string.encode('utf-8') + return bytearray(struct.pack("!H", len(string)) + string) + + def _send_packet(self, packet_type, payload=b''): + """ + Send an MQTT packet to the broker. + + Args: + packet_type (int): The MQTT packet type + payload (bytes): The packet payload + + Raises: + MQTTException: If the client is not connected or sending fails + """ + if self.sock is None: + raise MQTTException("Not connected to broker (_send_packet)") + + # Construct the packet + packet = bytearray() + packet.append(packet_type) + + # Add remaining length + packet.extend(self._encode_length(len(payload))) + + # Add payload + if payload: + packet.extend(payload) + + # Send the packet + try: + self.sock.send(packet) + except Exception as e: + self.connected = False + raise MQTTException(f"Failed to send packet: {e}") + + def _recv_packet(self, timeout=1.0): + """ + Receive an MQTT packet from the broker. + + Args: + timeout (float): Socket timeout in seconds + + Returns: + tuple: (packet_type, payload) or (None, None) if no packet received + + Raises: + MQTTException: If the client is not connected or receiving fails + """ + if self.sock is None: + raise MQTTException("Not connected to broker (_recv_packet)") + + # Set socket timeout + self.sock.settimeout(timeout) + + try: + # Read packet type + packet_type = self.sock.recv(1) + if not packet_type: + return None, None + + # Read remaining length + remaining_length = 0 + multiplier = 1 + while True: + byte = self.sock.recv(1)[0] + remaining_length += (byte & 0x7F) * multiplier + multiplier *= 128 + if not (byte & 0x80): + break + + # Read the payload + payload = self.sock.recv(remaining_length) if remaining_length else b'' + + return packet_type[0], payload + except socket.timeout: + return None, None + except Exception as e: + self.connected = False + raise MQTTException(f"Failed to receive packet: {e}") + + def connect(self): + """ + Connect to the MQTT broker. + + Returns: + int: 0 if successful, otherwise an error code + + Raises: + MQTTException: If connection fails + """ + # Create socket + try: + self.sock = socket.socket() + print(f"[MQTT] Connecting to Socket {self.server}:{self.port} as {self.client_id}") + self.sock.connect((self.server, self.port)) + print(f"[MQTT] Connected to {self.server}:{self.port}") + except Exception as e: + print(f"Error connecting to MQTT broker: {e}") + raise MQTTException(f"Failed to connect to {self.server}:{self.port}: {e}") + + # Construct CONNECT packet + payload = bytearray() + + # Protocol name and level + payload.extend(self._encode_string("MQTT")) + payload.append(MQTT_PROTOCOL_LEVEL) + + # Connect flags + connect_flags = 0 + if self.user: + connect_flags |= 0x80 + if self.password: + connect_flags |= 0x40 + connect_flags |= MQTT_CLEAN_SESSION << 1 + payload.append(connect_flags) + + # Keepalive (in seconds) + payload.extend(struct.pack("!H", self.keepalive)) + + # Client ID + payload.extend(self._encode_string(self.client_id)) + + # Username and password if provided + if self.user: + payload.extend(self._encode_string(self.user)) + if self.password: + payload.extend(self._encode_string(self.password)) + + # Send CONNECT packet + self._send_packet(CONNECT, payload) + + # Wait for CONNACK + packet_type, payload = self._recv_packet() + if packet_type != CONNACK: + raise MQTTException(f"Unexpected response from broker: {packet_type}") + + # Check connection result + if len(payload) != 2: + raise MQTTException("Invalid CONNACK packet") + + if payload[1] != CONN_ACCEPTED: + raise MQTTException(f"Connection refused: {payload[1]}") + + self.connected = True + self.last_ping = time.time() + return 0 + + def disconnect(self): + """ + Disconnect from the MQTT broker. + """ + if self.connected: + try: + self._send_packet(DISCONNECT) + self.sock.close() + except Exception as e: + print(f"Error during disconnect: {e}") + finally: + self.connected = False + self.sock = None + + def ping(self): + """ + Send PINGREQ to keep the connection alive. + + Raises: + MQTTException: If no PINGRESP is received + """ + if self.connected: + self._send_packet(PINGREQ) + packet_type, _ = self._recv_packet() + if packet_type != PINGRESP: + self.connected = False + raise MQTTException("No PINGRESP received") + self.last_ping = time.time() + + def publish(self, topic, msg, retain=False, qos=0): + """ + Publish a message to a topic. + + Args: + topic (str or bytes): The topic to publish to + msg (str or bytes): The message to publish + retain (bool): Whether the message should be retained by the broker + qos (int): Quality of Service level (0 or 1) + + Raises: + MQTTException: If the client is not connected or publishing fails + """ + if not self.connected: + raise MQTTException("Not connected to broker (publish)") + + # Check if we need to ping to keep connection alive + if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive: + self.ping() + + # Convert topic and message to bytes if they're not already + if isinstance(topic, str): + topic = topic.encode('utf-8') + if isinstance(msg, str): + msg = msg.encode('utf-8') + + # Construct PUBLISH packet + packet_type = PUBLISH + if retain: + packet_type |= 0x01 + if qos: + packet_type |= (qos << 1) + + # Payload: topic + message + payload = self._encode_string(topic) + + # Add packet ID for QoS > 0 + if qos > 0: + pid = self._generate_packet_id() + payload.extend(struct.pack("!H", pid)) + + payload.extend(msg) + + # Send PUBLISH packet + self._send_packet(packet_type, payload) + + # For QoS 1, wait for PUBACK + if qos == 1: + packet_type, _ = self._recv_packet() + if packet_type != PUBACK: + raise MQTTException(f"No PUBACK received: {packet_type}") + + return + + def subscribe(self, topic, qos=0): + """ + Subscribe to a topic. + + Args: + topic (str or bytes): The topic to subscribe to + qos (int): Quality of Service level + + Raises: + MQTTException: If the client is not connected or subscription fails + """ + if not self.connected: + raise MQTTException("Not connected to broker (subscribe)") + + # Check if we need to ping to keep connection alive + if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive: + self.ping() + + # Convert topic to bytes if it's not already + if isinstance(topic, str): + topic = topic.encode('utf-8') + + # Generate packet ID + pid = self._generate_packet_id() + + # Construct SUBSCRIBE packet + payload = bytearray(struct.pack("!H", pid)) + payload.extend(self._encode_string(topic)) + payload.append(qos) + + # Send SUBSCRIBE packet + self._send_packet(SUBSCRIBE | 0x02, payload) + + # Wait for SUBACK + packet_type, payload = self._recv_packet() + if packet_type != SUBACK: + raise MQTTException(f"No SUBACK received: {packet_type}") + + # Store subscription + topic_str = topic.decode('utf-8') if isinstance(topic, bytes) else topic + self.subscriptions[topic_str] = qos + + return + + def set_callback(self, callback): + """ + Set callback for received messages. + + Args: + callback (callable): Function to call when a message is received. + The callback should accept two parameters: + topic (bytes) and message (bytes). + """ + self.callback = callback + + def check_msg(self): + """ + Check for pending messages from the broker. + + This method should be called regularly to process incoming messages. + If a callback is set, it will be called with the topic and message. + """ + if not self.connected: + return + + # Check if we need to ping to keep connection alive + if self.keepalive > 0 and time.time() - self.last_ping >= self.keepalive: + self.ping() + + # Try to receive a packet with a short timeout + packet_type, payload = self._recv_packet(timeout=0.1) + + if packet_type is None: + return + + if packet_type & 0xF0 == PUBLISH: + # Extract flags + dup = (packet_type & 0x08) >> 3 + qos = (packet_type & 0x06) >> 1 + retain = packet_type & 0x01 + + # Extract topic + topic_len = struct.unpack("!H", payload[0:2])[0] + topic = payload[2:2+topic_len] + + # Extract packet ID for QoS > 0 + if qos > 0: + pid = struct.unpack("!H", payload[2+topic_len:2+topic_len+2])[0] + message = payload[2+topic_len+2:] + + # Send PUBACK for QoS 1 + if qos == 1: + self._send_packet(PUBACK, struct.pack("!H", pid)) + else: + message = payload[2+topic_len:] + + # Call the callback if set + if self.callback: + self.callback(topic, message) + + return \ No newline at end of file diff --git a/tests/test_mqtt.py b/tests/test_mqtt.py index 58588d1..38127a1 100644 --- a/tests/test_mqtt.py +++ b/tests/test_mqtt.py @@ -5,7 +5,8 @@ Tests for the MQTT module. import pytest import json from unittest.mock import patch, MagicMock -from src.esp_sensors.mqtt import setup_mqtt, publish_sensor_data, MQTTClient +from src.esp_sensors.mqtt_client import MQTTClient +from src.esp_sensors.mqtt import setup_mqtt, publish_sensor_data class TestSensor: diff --git a/tests/test_mqtt_client.py b/tests/test_mqtt_client.py new file mode 100644 index 0000000..f369ab7 --- /dev/null +++ b/tests/test_mqtt_client.py @@ -0,0 +1,266 @@ +""" +Tests for the MQTT Client module. + +This module contains tests for the MQTTClient class in the mqtt_client.py module. +""" + +import pytest +import socket +import struct +from unittest.mock import patch, MagicMock, call +from src.esp_sensors.mqtt_client import ( + MQTTClient, MQTTException, + CONNECT, CONNACK, PUBLISH, PUBACK, SUBSCRIBE, SUBACK, + PINGREQ, PINGRESP, DISCONNECT, + CONN_ACCEPTED +) + + +class TestMQTTClient: + """Tests for the MQTTClient class.""" + + @pytest.fixture + def mqtt_client(self): + """Fixture providing a basic MQTTClient instance.""" + return MQTTClient( + client_id="test_client", + server="test.mosquitto.org", + port=1883, + user="test_user", + password="test_pass", + keepalive=60, + ssl=False + ) + + def test_init(self, mqtt_client): + """Test that the MQTTClient initializes with the correct attributes.""" + assert mqtt_client.client_id == "test_client" + assert mqtt_client.server == "test.mosquitto.org" + assert mqtt_client.port == 1883 + assert mqtt_client.user == "test_user" + assert mqtt_client.password == "test_pass" + assert mqtt_client.keepalive == 60 + assert mqtt_client.ssl is False + assert mqtt_client.sock is None + assert mqtt_client.connected is False + assert mqtt_client.callback is None + assert mqtt_client.pid == 0 + assert mqtt_client.subscriptions == {} + assert mqtt_client.last_ping == 0 + + def test_generate_packet_id(self, mqtt_client): + """Test that _generate_packet_id returns sequential IDs and wraps around.""" + # First call should return 1 + assert mqtt_client._generate_packet_id() == 1 + assert mqtt_client.pid == 1 + + # Second call should return 2 + assert mqtt_client._generate_packet_id() == 2 + assert mqtt_client.pid == 2 + + # Set pid to 65535 (max value) + mqtt_client.pid = 65535 + + # Next call should wrap around to 0 + assert mqtt_client._generate_packet_id() == 0 + assert mqtt_client.pid == 0 + + def test_encode_length(self, mqtt_client): + """Test that _encode_length correctly encodes MQTT remaining length.""" + # Test small length (< 128) + assert list(mqtt_client._encode_length(64)) == [64] + + # Test medium length (128-16383) + assert list(mqtt_client._encode_length(128)) == [128 & 0x7F | 0x80, 1] + assert list(mqtt_client._encode_length(8192)) == [0x80, 0x40] + + # Test large length (16384-2097151) + assert list(mqtt_client._encode_length(2097151)) == [0xFF, 0xFF, 0x7F] + + def test_encode_string(self, mqtt_client): + """Test that _encode_string correctly encodes strings for MQTT packets.""" + # Test with string input + result = mqtt_client._encode_string("test") + assert len(result) == 6 # 2 bytes length + 4 bytes string + assert result[0:2] == b'\x00\x04' # Length (4) in network byte order + assert result[2:] == b'test' # String content + + # Test with bytes input + result = mqtt_client._encode_string(b"test") + assert len(result) == 6 + assert result[0:2] == b'\x00\x04' + assert result[2:] == b'test' + + @patch('socket.socket') + def test_connect_success(self, mock_socket, mqtt_client): + """Test successful connection to MQTT broker.""" + # Configure the mock socket + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + + # Configure the mock socket to return a successful CONNACK + mock_sock.recv.side_effect = [ + b'\x20', # CONNACK packet type + b'\x02', # Remaining length + b'\x00\x00' # Session present flag (0) + return code (0 = accepted) + ] + + # Call connect + result = mqtt_client.connect() + + # Verify socket was created and connected + mock_socket.assert_called_once() + mock_sock.connect.assert_called_once_with(("test.mosquitto.org", 1883)) + + # Verify CONNECT packet was sent + mock_sock.send.assert_called_once() + + # Verify result + assert result == 0 + assert mqtt_client.connected is True + assert mqtt_client.sock is mock_sock + + @patch('socket.socket') + def test_connect_failure(self, mock_socket, mqtt_client): + """Test connection failure to MQTT broker.""" + # Configure the mock socket + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + + # Configure the mock socket to return a failed CONNACK + mock_sock.recv.side_effect = [ + b'\x20', # CONNACK packet type + b'\x02', # Remaining length + b'\x00\x01' # Session present flag (0) + return code (1 = refused, protocol version) + ] + + # Call connect and verify it raises an exception + with pytest.raises(MQTTException, match="Connection refused: 1"): + mqtt_client.connect() + + @patch('socket.socket') + def test_disconnect(self, mock_socket, mqtt_client): + """Test disconnection from MQTT broker.""" + # Configure the mock socket + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + + # Set up the client as connected + mqtt_client.sock = mock_sock + mqtt_client.connected = True + + # Call disconnect + mqtt_client.disconnect() + + # Verify DISCONNECT packet was sent + mock_sock.send.assert_called_once() + + # Verify socket was closed + mock_sock.close.assert_called_once() + + # Verify client state + assert mqtt_client.connected is False + assert mqtt_client.sock is None + + @patch('socket.socket') + def test_publish(self, mock_socket, mqtt_client): + """Test publishing a message to a topic.""" + # Configure the mock socket + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + + # Set up the client as connected + mqtt_client.sock = mock_sock + mqtt_client.connected = True + mqtt_client.last_ping = 0 + + # Call publish + mqtt_client.publish("test/topic", "test message") + + # Verify PUBLISH packet was sent + mock_sock.send.assert_called_once() + + # Test with QoS 1 + mock_sock.reset_mock() + mock_sock.recv.side_effect = [ + b'\x40', # PUBACK packet type + b'\x02', # Remaining length + b'\x00\x01' # Packet ID + ] + + mqtt_client.publish("test/topic", "test message", qos=1) + + # Verify PUBLISH packet was sent + assert mock_sock.send.call_count == 1 + + @patch('socket.socket') + def test_subscribe(self, mock_socket, mqtt_client): + """Test subscribing to a topic.""" + # Configure the mock socket + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + + # Set up the client as connected + mqtt_client.sock = mock_sock + mqtt_client.connected = True + mqtt_client.last_ping = 0 + + # Configure the mock socket to return a successful SUBACK + mock_sock.recv.side_effect = [ + b'\x90', # SUBACK packet type + b'\x03', # Remaining length + b'\x00\x01\x00' # Packet ID + return code (0 = success) + ] + + # Call subscribe + mqtt_client.subscribe("test/topic") + + # Verify SUBSCRIBE packet was sent + mock_sock.send.assert_called_once() + + # Verify subscription was stored + assert "test/topic" in mqtt_client.subscriptions + assert mqtt_client.subscriptions["test/topic"] == 0 + + @patch('socket.socket') + def test_check_msg(self, mock_socket, mqtt_client): + """Test checking for messages.""" + # Configure the mock socket + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + + # Set up the client as connected + mqtt_client.sock = mock_sock + mqtt_client.connected = True + mqtt_client.last_ping = 0 + + # Set up a mock callback + mock_callback = MagicMock() + mqtt_client.set_callback(mock_callback) + + # Configure the mock socket to return a PUBLISH packet + topic = "test/topic" + message = "test message" + topic_encoded = struct.pack("!H", len(topic)) + topic.encode() + mock_sock.recv.side_effect = [ + bytes([PUBLISH]), # PUBLISH packet type + bytes([len(topic_encoded) + len(message)]), # Remaining length + topic_encoded + message.encode() # Topic + message + ] + + # Call check_msg + mqtt_client.check_msg() + + # Verify callback was called with correct parameters + mock_callback.assert_called_once_with(topic.encode(), message.encode()) + + def test_set_callback(self, mqtt_client): + """Test setting a callback function.""" + # Create a mock callback + mock_callback = MagicMock() + + # Set the callback + mqtt_client.set_callback(mock_callback) + + # Verify callback was set + assert mqtt_client.callback is mock_callback \ No newline at end of file