1# Copyright: (c)  2025, Intel Corporation
2# Author: Arkadiusz Cholewinski <arkadiuszx.cholewinski@intel.com>
3
4import csv
5import logging
6import queue
7import re
8import threading
9import time
10
11import utils.UtilityFunctions as UtilityFunctions
12from abstract.PowerMonitor import PowerMonitor
13from stm32l562e_dk.PowerShieldConfig import PowerShieldConf
14from stm32l562e_dk.PowerShieldData import PowerShieldData
15from stm32l562e_dk.SerialHandler import SerialHandler
16
17
class PowerShield(PowerMonitor):
    """
    Power monitor driver for a PowerShield device controlled over a serial link.

    Commands are sent as text through ``SerialHandler`` and, when requested,
    checked against an ``ack <command>`` response. During a measurement the
    device streams two-byte samples interleaved with metadata frames that start
    with the marker byte ``0xF0``. Samples are converted to amps and written as
    CSV rows to ``power_shield_conf.output_file`` by a background writer thread.
    """

    # Serial baud rate used by connect() when no explicit rate is given.
    DEFAULT_BAUDRATE = 3686400

    def __init__(self):
        """
        Initializes the PowerShield.

        No I/O is performed here; call connect() and then init()/measure().
        """
        self.handler = None  # SerialHandler, created by connect()
        self.dataQueue = queue.Queue()  # [amps] rows handed from reader to CSV writer
        self.acqComplete = False  # set once the end-of-acquisition summary was read
        self.acqStart = False  # set after the timestamp metadata; gates sample decoding
        self.target_voltage = None  # cached result of the 'volt get' query (volts)
        self.target_temperature = None  # cached result of the last 'temp' query
        self.acqTimeoutThread = None
        self.power_shield_conf = PowerShieldConf()
        self.power_shield_data = PowerShieldData()

    def init(self):
        """
        Initializes the power monitor.

        Takes control of the device and applies the configured target voltage,
        data format and function mode from ``power_shield_conf``.
        """
        self.__take_control()
        self.__set_voltage(self.power_shield_conf.target_voltage)
        self.__set_format(self.power_shield_conf.data_format)
        self.__set_func_mode(self.power_shield_conf.function_mode)

    def connect(self, power_device_path: str, baudrate: int = DEFAULT_BAUDRATE):
        """
        Opens the connection using the SerialHandler.

        :param power_device_path: Path of the serial device to open.
        :param baudrate: Serial baud rate; defaults to DEFAULT_BAUDRATE (3686400).
        """
        self.handler = SerialHandler(power_device_path, baudrate)
        self.handler.open()

    def disconnect(self):
        """Closes the connection using the SerialHandler, if one was created."""
        if self.handler is None:
            logging.warning("Warning: disconnect() called before connect().")
            return
        self.handler.close()

    def __is_connected(self) -> bool:
        """Returns True when a handler exists and reports an open connection."""
        return self.handler is not None and self.handler.is_open()

    def __send_command(self, command: str, expected_ack: str = None, ack: bool = False) -> str:
        """
        Sends a command to the device, retrieves the response,
        and optionally verifies the acknowledgment.

        :param command: The command to send.
        :param expected_ack: The expected acknowledgment response (e.g., "ack htc").
                             Only checked when ``ack`` is True.
        :param ack: When True, read one response from the device and return it.
        :return: The response received from the device; "" when the connection is
                 not open, the expected acknowledgment is missing, or ``ack`` is
                 False (nothing is read in that case).
        """
        if not self.__is_connected():
            logging.error(f"Error: Connection is not open. Cannot send command: {command}")
            return ""

        logging.debug(f"Sending command: {command}")
        self.handler.send_cmd(command)
        if not ack:
            return ""

        # Guard against a handler that returns None on timeout.
        response = self.handler.receive_cmd() or ""
        logging.debug(f"Response: {response}")

        # Check if the response contains the expected acknowledgment
        if expected_ack and expected_ack not in response:
            logging.error(f"Error: Expected acknowledgment '{expected_ack}' not found.")
            return ""

        return response

    def __test_communication(self):
        """
        Sends a version command to the device.

        :return: The device response, or "" if the connection is not open.
        """
        if not self.__is_connected():
            logging.error("Error: Connection is not open. Cannot send version command.")
            return ""
        command = 'version'
        logging.info(f"Sending command: {command}")
        self.handler.send_cmd(command)
        response = self.handler.receive_cmd()
        logging.info(f"Response: {response}")
        return response

    def __reset(self):
        """
        Sends the reset command ('psrst') to the power monitor device,
        closes the connection, waits 5 seconds for the reset to complete,
        and makes a single attempt to reopen the connection.
        """
        command = "psrst"

        if not self.__is_connected():
            logging.error("Error: Connection is not open. Cannot reset the device.")
            return

        logging.info(f"Sending reset command: {command}")
        self.handler.send_cmd(command)

        # Close the connection; the device drops the link while it resets.
        self.handler.close()
        self.handler.serial_connection = None

        time.sleep(5)
        # Attempt to reopen the connection
        try:
            self.handler.open()
            logging.info("Connection reopened after reset.")
        except Exception as e:
            logging.error(f"Failed to reopen connection after reset: {e}")

    def __get_voltage_level(self) -> float:
        """
        Sends the 'volt get' command and returns the voltage value as a float.

        The value is read from the sixth whitespace-separated token (index 5),
        the same position the temperature reply uses
        ("PowerShield > ack temp degc 28.0"), and is expected in the form
        '<mantissa>-<exponent>' (e.g. '3292-03' -> 3.292).

        :return: The voltage level in volts (V), rounded to 3 decimals, also cached
                 in ``self.target_voltage``; NaN on any failure.
        """
        response = self.__send_command('volt get', expected_ack="ack volt get", ack=True)
        if not response:
            logging.error("Error: No response for voltage command.")
            return float('nan')

        parts = response.split()
        # parts[5] is read below, so at least six tokens are required.
        if len(parts) < 6:
            logging.error("Error: Voltage value not found in response.")
            return float('nan')

        # Use regex to find a string that matches the pattern, e.g., "3292-03"
        match = re.search(r'(\d+)-(\d+)', parts[5])
        if not match:
            logging.error("Error: Voltage value not found in response.")
            return float('nan')

        base, exponent = match.groups()
        try:
            # Build scientific notation, e.g. "3292e-03".
            voltage = float(f"{base}e-{exponent}")
        except ValueError:
            logging.error("Error: Could not convert voltage value.")
            return float('nan')

        self.target_voltage = round(voltage, 3)
        return self.target_voltage

    def __get_temperature(self, unit: str = PowerShieldConf.TemperatureUnit.CELSIUS) -> float:
        """
        Sends the temperature command and returns the temperature as a float.

        Example response format: "PowerShield > ack temp degc 28.0"; the value
        is the sixth whitespace-separated token (index 5).

        :param unit: The unit to request the temperature in, either 'degc' or 'degf'.
        :return: The temperature value in the requested unit (also cached in
                 ``self.target_temperature``), or None on failure.
        """
        # Send the temp command with the unit
        response = self.__send_command(f"temp {unit}", expected_ack=f"ack temp {unit}", ack=True)
        if not response:
            logging.error("Error: No response for temp command.")
            return None

        parts = response.split()
        # parts[5] is read below, so at least six tokens are required.
        if len(parts) < 6:
            logging.error("Error: Temperature value not found in response.")
            return None

        try:
            # float() also accepts negative readings, which isdigit() rejected.
            temperature = float(parts[5])
        except ValueError:
            logging.error("Error: Could not convert temperature value.")
            return None

        self.target_temperature = temperature
        logging.info(f"Temperature: {self.target_temperature} {unit}")
        return self.target_temperature

    def __take_control(self) -> str:
        """
        Sends the 'htc' command and verifies the acknowledgment.

        :return: The acknowledgment response, or "" on failure.
        """
        return self.__send_command("htc", expected_ack="ack htc", ack=True)

    def __set_format(self, data_format: str = PowerShieldConf.DataFormat.ASCII_DEC):
        """
        Sets the measurement data format.
        The format can be either ASCII (decimal) or Binary (hexadecimal).

        :param data_format: The data format to set.
                            Options are 'ascii_dec' or 'bin_hexa'.
        :return: None
        """
        # Validate the input format
        if data_format not in vars(PowerShieldConf.DataFormat).values():
            logging.error(
                f"Error: Invalid format '{data_format}'. "
                "Valid options are 'ascii_dec' or 'bin_hexa'."
            )
            return

        command = f"format {data_format}"
        response = self.__send_command(command, expected_ack=f"ack format {data_format}", ack=True)

        # If response contains the expected acknowledgment, the format was set successfully
        if response:
            logging.info(f"Data format set to {data_format}.")
        else:
            logging.error(f"Error: Failed to set data format to {data_format}.")

    def __set_frequency(self, frequency: str):
        """
        Sets the sampling frequency for the measurement.
        The frequency can be any value defined in PowerShieldConf.SamplingFrequency.

        :param frequency: The sampling frequency to set.
        Valid options include:
        {100k, 50k, 20k, 10k, 5k, 2k, 1k, 500, 200, 100, 50, 20, 10, 5, 2, 1}.

        :return: None
        """
        # Validate the input frequency
        if frequency not in vars(PowerShieldConf.SamplingFrequency).values():
            logging.error(
                f"Error: Invalid frequency '{frequency}'. "
                "Valid options are: "
                "100k, 50k, 20k, 10k, 5k, 2k, 1k, 500, 200, 100, 50, 20, 10, 5, 2, 1."
            )
            return

        command = f"freq {frequency}"
        response = self.__send_command(command, expected_ack=f"ack freq {frequency}", ack=True)

        if response:
            logging.info(f"Sampling frequency set to {frequency}.")
        else:
            logging.error(f"Error: Failed to set sampling frequency to {frequency}.")

    def __set_acquisition_time(self, acquisition_time: str = '0'):
        """
        Sends 'acqtime <acquisition_time>' and verifies the acknowledgment.

        :param acquisition_time: Value passed verbatim to the device; measure()
                                 supplies the output of
                                 UtilityFunctions.convert_to_scientific_notation().
        :return: None
        """
        command = f"acqtime {acquisition_time}"
        response = self.__send_command(
            command, expected_ack=f"ack acqtime {acquisition_time}", ack=True
        )

        if response:
            logging.info(f"Acquisition time set to {acquisition_time}.")
        else:
            logging.error(f"Error: Failed to set acquisition time to {acquisition_time}.")

    def __set_voltage(self, voltage):
        """
        Sends 'volt <voltage>' and verifies the acknowledgment.

        :param voltage: Value passed verbatim to the device (init() uses
                        power_shield_conf.target_voltage).
        :return: None
        """
        command = f"volt {voltage}"
        response = self.__send_command(command, expected_ack=f"ack volt {voltage}", ack=True)

        if response:
            logging.info(f"Voltage set to {voltage}.")
        else:
            logging.error(f"Error: Failed to set voltage to {voltage}.")

    def __set_func_mode(self, function_mode: str = PowerShieldConf.FunctionMode.HIGH):
        """
        Sets the acquisition mode for current measurement.
        The function_mode can be either 'optim' or 'high'.

        - 'optim': Priority on current resolution (100 nA - 10 mA) with max freq at 100 kHz.
        - 'high': High current (30 µA - 10 mA), high frequency (50-100 kHz), high resolution.

        :param function_mode: The acquisition mode. Must be either 'optim' or 'high'.
        :return: None
        """
        # Validate the input mode
        if function_mode not in vars(PowerShieldConf.FunctionMode).values():
            logging.error(
                f"Error: Invalid function mode '{function_mode}'. "
                "Valid options are 'optim' or 'high'."
            )
            return

        command = f"funcmode {function_mode}"
        response = self.__send_command(
            command, expected_ack=f"ack funcmode {function_mode}", ack=True
        )

        if response:
            logging.info(f"Function mode set to {function_mode}.")
        else:
            logging.error(f"Error: Failed to set function mode to {function_mode}.")

    def __acq_data(self):
        """
        Continuously reads data from the serial port and puts it
        into a queue until acquisition is complete.

        Stops when ``acqComplete`` is set (by the end-of-acquisition summary)
        or when a read returns no bytes. Sample bytes that arrive before the
        timestamp metadata (``acqStart`` still False) are discarded.
        """
        logging.info("Started data acquisition...")
        # Check completion before reading so no extra blocking read happens
        # after the summary has been consumed.
        while not self.acqComplete:
            first_byte = self.handler.read_bytes(1)
            if len(first_byte) < 1:
                break

            if first_byte == b'\xf0':  # Metadata marker
                second_byte = self.handler.read_bytes(1)
                if len(second_byte) < 1:
                    logging.error("Error: Incomplete metadata received.")
                    break
                self.__handle_metadata(second_byte[0])
                continue

            # Not metadata, treat as data
            if not self.acqStart:
                continue

            second_byte = self.handler.read_bytes(1)
            if len(second_byte) < 1:
                break
            amps = UtilityFunctions.convert_to_amps(
                UtilityFunctions.bytes_to_twobyte_values([first_byte, second_byte])
            )
            self.dataQueue.put([amps])

        logging.info("Stopping data acquisition...")

    def __handle_metadata(self, metadata_type):
        """
        Dispatches on the metadata type byte that follows a 0xF0 marker.

        0xF3 (timestamp) enables sample decoding; 0xF4 (end of acquisition)
        reads the trailing summary, which marks the acquisition complete.
        """
        if metadata_type == 0xF1:
            logging.info("Received Metadata: ASCII error message.")
        elif metadata_type == 0xF2:
            logging.info("Received Metadata: ASCII information message.")
        elif metadata_type == 0xF3:
            logging.info("Received Metadata: Timestamp message.")
            self.__handle_metadata_timestamp()
            self.acqStart = True
        elif metadata_type == 0xF4:
            logging.info("Received Metadata: End of acquisition tag.")
            self.__handle_metadata_end()
            summary = self.__handle_summary()
            if summary:
                logging.info(f"Acquisition summary:\n{summary}")
        elif metadata_type == 0xF5:
            logging.info("Received Metadata: Overcurrent detected.")
        else:
            logging.error(f"Error: Unknown Metadata Type: {metadata_type:#04x}")

    def __handle_summary(self) -> str:
        """
        Reads the ASCII summary sent after the end-of-acquisition tag.

        Reading stops at the next 0xF0 marker or when a read returns no bytes;
        either way the acquisition is marked complete.

        :return: The cleaned-up summary text (NULs and CRs removed).
        """
        raw = bytearray()
        while True:
            x = self.handler.read_bytes(1)
            # read_bytes returns bytes, so compare against a bytes literal.
            if len(x) < 1 or x == b'\xf0':
                break
            raw += x
        self.acqComplete = True
        text = raw.decode('ascii', errors='ignore')
        return text.replace("\0", "").strip().replace("\r", "").replace("\n\n\n", "\n")

    def __handle_metadata_end(self):
        """
        Handle metadata end of acquisition message.

        Reads the two trailing bytes and verifies they are both 0xFF.
        """
        # Read the next 2 bytes
        metadata_bytes = self.handler.read_bytes(2)
        if len(metadata_bytes) < 2:
            logging.error("Error: Incomplete end of acquisition metadata received.")
            return
        # Check for end tags (last 2 bytes)
        if metadata_bytes[0] != 0xFF or metadata_bytes[1] != 0xFF:
            logging.error("Error: Invalid metadata end tags received.")

    def __handle_metadata_timestamp(self):
        """
        Handle metadata timestamp message. Parses and logs the timestamp and buffer load.

        Layout of the 7 bytes read: 4-byte big-endian timestamp (ms),
        1-byte Tx buffer load (%), then two 0xFF end tags.
        """
        # Read the next 7 bytes (timestamp + buffer load + end tags)
        metadata_bytes = self.handler.read_bytes(7)
        if len(metadata_bytes) < 7:
            logging.error("Error: Incomplete timestamp metadata received.")
            return

        # Parse the timestamp (4 bytes, big-endian)
        timestamp_ms = int.from_bytes(metadata_bytes[0:4], byteorder='big', signed=False)
        # Parse the buffer Tx load value (1 byte)
        buffer_load = metadata_bytes[4]
        # Check for end tags (last 2 bytes)
        if metadata_bytes[5] != 0xFF or metadata_bytes[6] != 0xFF:
            logging.error("Error: Invalid metadata end tags received.")
            return

        logging.info(f"Metadata Timestamp: {timestamp_ms} ms")
        logging.info(f"Buffer Tx Load: {buffer_load}%")

    def __start_measurement(self):
        """
        Starts the measurement by sending the 'start' command and blocks until
        the acquisition ends.

        Samples are decoded on the calling thread and written to
        ``power_shield_conf.output_file`` by a writer thread. The writer is always
        stopped and joined, even if the stream ends without an end tag or
        decoding raises, so this call cannot hang on ``join()``.

        :return: None
        """
        # Reset per-measurement state so a previous run cannot leak into this one.
        self.acqComplete = False
        self.acqStart = False
        self.dataQueue = queue.Queue()
        self.__send_command("start")

        stop_writer = threading.Event()
        raw_to_file_thread = threading.Thread(
            target=self.__raw_to_file,
            args=(self.power_shield_conf.output_file, stop_writer),
        )
        raw_to_file_thread.start()
        logging.info("Measurement started. Receiving data...")
        try:
            self.__acq_data()
        finally:
            # Set only after the producer has returned, so no more rows follow.
            stop_writer.set()
            raw_to_file_thread.join()

    def __raw_to_file(self, outputFilePath: str, stop_event: threading.Event = None):
        """
        Drains ``self.dataQueue`` into a CSV file until acquisition has finished.

        :param outputFilePath: Path of the CSV file to (over)write.
        :param stop_event: Set by the producer once no more rows will be queued;
                           when None, ``self.acqComplete`` alone ends the loop.
        """
        with open(outputFilePath, 'w', newline='') as outputFile:
            writer = csv.writer(outputFile)
            while True:
                # Sample the stop condition BEFORE polling the queue: if it was
                # already true, every row has been queued, so an empty get()
                # means the queue is fully drained.
                finished = self.acqComplete or (
                    stop_event is not None and stop_event.is_set()
                )
                try:
                    data = self.dataQueue.get(timeout=0.1)
                except queue.Empty:
                    if finished:
                        break
                    continue
                writer.writerow(data)
                outputFile.flush()

    def measure(self, time: int, freq: str = None, reset: bool = False):
        """
        Configures the device and runs one blocking acquisition.

        :param time: Acquisition time, stored in power_shield_conf.acquisition_time
                     and converted by UtilityFunctions.convert_acq_time(). (The name
                     shadows the ``time`` module inside this method; it is kept
                     for caller compatibility.)
        :param freq: Sampling frequency; defaults to
                     power_shield_conf.sampling_frequency when None.
        :param reset: When True, reset the device before configuring it.
        """
        self.power_shield_conf.acquisition_time = time
        _time, self.power_shield_conf.acquisition_time_unit = UtilityFunctions.convert_acq_time(
            time
        )

        if reset:
            self.__reset()
        self.__take_control()
        self.__set_format(self.power_shield_conf.data_format)
        self.__set_frequency(freq if freq is not None else self.power_shield_conf.sampling_frequency)
        self.__set_acquisition_time(
            UtilityFunctions.convert_to_scientific_notation(
                time=_time, unit=self.power_shield_conf.acquisition_time_unit
            )
        )
        self.__start_measurement()

    def get_data(self, unit: str = PowerShieldConf.MeasureUnit.RAW_DATA):
        """
        Returns data from the last completed acquisition in the requested unit.

        The first CSV column of ``power_shield_conf.output_file`` is re-read on
        every call and replaces ``power_shield_data.data``, so repeated calls do
        not accumulate duplicate samples.

        :param unit: A PowerShieldConf.MeasureUnit value:
                     RAW_DATA returns the samples as read from the CSV (strings);
                     CURRENT_RMS returns UtilityFunctions.calculate_rms(samples);
                     POWER returns the sum computed below.
        :return: The requested value, or None if the acquisition is not complete,
                 the unit is unknown, or (POWER) the target voltage is unavailable.
        """
        if not self.acqComplete:
            logging.info("Acquisition not complete.")
            return None

        with open(self.power_shield_conf.output_file, newline='') as file:
            # Skip blank rows so row[0] cannot raise IndexError.
            self.power_shield_data.data = [row[0] for row in csv.reader(file) if row]

        if unit == PowerShieldConf.MeasureUnit.CURRENT_RMS:
            self.power_shield_data.current_RMS = UtilityFunctions.calculate_rms(
                self.power_shield_data.data
            )
            return self.power_shield_data.current_RMS

        if unit == PowerShieldConf.MeasureUnit.POWER:
            # target_voltage is only populated by the 'volt get' query; fetch it
            # lazily instead of failing with float(None).
            if self.target_voltage is None:
                self.__get_voltage_level()
            if self.target_voltage is None:
                logging.error("Error: Target voltage unknown; cannot compute power.")
                return None
            delta_time = float(self.power_shield_conf.acquisition_time)
            voltage = float(self.target_voltage)
            # NOTE(review): each sample is multiplied by the WHOLE acquisition
            # time, so the result scales with the number of samples. Confirm
            # whether a per-sample interval (acquisition_time / len(data)) was
            # intended.
            self.power_shield_data.power = sum(
                float(sample) * delta_time * voltage for sample in self.power_shield_data.data
            )
            return self.power_shield_data.power

        if unit == PowerShieldConf.MeasureUnit.RAW_DATA:
            return self.power_shield_data.data

        logging.error("Error: Unknown unit of requested data")
        return None
485