""" Copyright start
  Copyright (C) 2008 - 2025 Fortinet Inc.
  All rights reserved.
  FORTINET CONFIDENTIAL & FORTINET PROPRIETARY SOURCE CODE
Copyright end """

import os
import re
import sys
import shlex
from helpers.logger import Logger
from helpers.cmd_utils import CmdUtils
from framework.base.tasks import Tasks
from constants import REPORT_FILE, STEP_RESULT_FILE, LOG_FILE
try:
    # This import will fail in case of external SME
    sys.path.append("/opt/cyops-auth")
    from utilities.ha.node import is_primary, is_secondary
    from utilities.ha.consts import BKP_CYOPS_CONN_FILE, CONF_FILE
except ImportError:
    pass

# Step outcome labels shown in the bracketed status column on the console.
TASK_STATUS = dict(DONE="DONE", FAILED="FAILED", SKIPPED="SKIPPED")
# Phase labels written into the log-file banners.
TASK_LOG_STATUS = dict(STARTED="STARTED", COMPLETED="COMPLETED")
# ANSI escape sequences for console foreground colors.
TEXT_COLOR = dict(GREEN='\033[92m', RED='\033[91m', YELLOW='\033[93m', RESET='\033[0m')
# ANSI escape sequences for console text decoration.
TEXT_DECORATION = dict(BLINK='\033[5m', BOLD='\033[1m', RESET='\033[0m')


class RabbitMQUpgradeTask(Tasks):
    """Post-upgrade task that refreshes the RabbitMQ environment settings.

    The task comments out any active ``NODENAME`` entry in
    ``rabbitmq-env.conf``, re-runs ``csadm mq --set-up-msg-broker`` and then
    restarts services with ``csadm services --restart``. On an HA secondary
    node, ``CONF_FILE`` is backed up before the broker set-up and restored
    afterwards, because the set-up overwrites it.
    """

    TASK_STATUS_MSG = "Update RabbitMQ environment settings."
    MSG_NODENAME_UNCHANGED = "No change in node name."
    MSG_ENABLE_LONGNAME = "Upgrade RabbitMQ use longname setting."
    MSG_RESTART_SERVICE = "Restarting service."
    MSG_UPDATE_NODENAME = "Update RabbitMQ node name."
    RABBITMQ_ENV_CONF = "/etc/rabbitmq/rabbitmq-env.conf"
    # Matches an active ``NODENAME=`` assignment. Only leading whitespace may
    # precede it, so lines that are already commented out never match.
    _NODENAME_PATTERN = re.compile(r"^\s*NODENAME\s*=")

    def __init__(self) -> None:
        super().__init__()
        self.logger = Logger.get_logger(__name__)
        self.cmd_line_utilities = CmdUtils()
        self.store_result_path = STEP_RESULT_FILE.format(self.current_version)
        self.report_path = REPORT_FILE.format(self.current_version)

    @property
    def tags(self) -> str:
        return 'post-upgrade'

    def get_description(self) -> str:
        return "Updates RabbitMQ configuration settings."

    @staticmethod
    def _version_key(version: str) -> tuple:
        """Turn a dotted version string into a comparable tuple of ints.

        Trailing zero components are dropped, so "7.6" and "7.6.0" compare
        equal. Raises ValueError if any component is not an integer.
        """
        parts = [int(part) for part in version.split('.')]
        while parts and parts[-1] == 0:
            parts.pop()
        return tuple(parts)

    def is_supported(self) -> bool:
        """Return True on enterprise installs whose target version >= current.

        Versions are compared component by component. Concatenating their
        digits would compare "7.7.0" (770) below "7.6.10" (7610).
        Unparseable versions are logged and treated as unsupported.
        """
        try:
            current_version = self._version_key(self.current_version)
            target_upgrade_version = self._version_key(self.target_upgrade_version)
            return target_upgrade_version >= current_version and bool(self.is_enterprise())
        except ValueError as e:
            self.logger.error(f"Version parsing error: {e}")
            return False

    def execute(self):
        """Comment out NODENAME, re-run the broker set-up, then restart services.

        On a secondary node, the connection config is restored even if one of
        the update steps raises. The restore therefore never leaves the node
        pointing at the wrong broker config.
        """
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["STARTED"])
        on_secondary = self.is_secondary_node()
        if on_secondary:
            # Take backup of db_config.yml before updating rabbitmq settings
            self.copy_file(CONF_FILE, BKP_CYOPS_CONN_FILE)
        try:
            self.disable_rabbitmq_nodename()
            self.setup_rabbitmq_with_new_settings()
        finally:
            if on_secondary:
                # The set-up rewrites the default file with the current
                # mq_password, pointing at the primary MQ. Restore this
                # node's own config and drop the backup.
                self.copy_file(BKP_CYOPS_CONN_FILE, CONF_FILE)
                self.set_file_permissions(CONF_FILE)
                self.delete_file(BKP_CYOPS_CONN_FILE)
        self.restart_services()
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["COMPLETED"])

    def validate(self) -> bool:
        return True

    def _print_status_msg(self, msg, status):
        """Print an indented step line followed by a colored ``[STATUS]`` tag.

        DONE is green, SKIPPED is yellow and anything else is red. The message
        is indented by 4 spaces. If the indented text exceeds 65 characters,
        it is cut to 65 and suffixed with "...". The text is left-aligned in a
        70-column field.
        """
        msg = "    " + msg
        reset = TEXT_COLOR["RESET"]
        if status == TASK_STATUS["DONE"]:
            color = TEXT_COLOR["GREEN"]
        elif status == TASK_STATUS["SKIPPED"]:
            color = TEXT_COLOR["YELLOW"]
        else:
            color = TEXT_COLOR["RED"]

        truncated_message = msg[:65] + "..." if len(msg) > 65 else msg
        width = 8
        status_text = f"{status:^{width}}"
        colored_status = f"{color}{status_text}{reset}"
        final_msg = "{:<70}{}[{}]".format(truncated_message, " ", colored_status)
        print(final_msg)

    def add_banner_in_log_file(self, msg: str, status: str) -> None:
        """Append a ``[ STATUS ] : msg`` line to LOG_FILE.

        Any banner other than STARTED gets an extra blank line after it. Write
        failures are logged and swallowed, so logging cannot abort the task.
        """
        status_msg = " [{:^11}] {} {} ".format(status, ":", msg)
        line_end = "\n" if status == TASK_LOG_STATUS["STARTED"] else "\n\n"
        try:
            with open(LOG_FILE, 'a') as log_file:
                log_file.write(f"{status_msg}{line_end}")
        except Exception as e:
            self.logger.error(f"Failed to write log file: {e}")

    def disable_rabbitmq_nodename(self, config_path=RABBITMQ_ENV_CONF):
        """Comment out every active ``NODENAME=`` line in *config_path*.

        Returns True if at least one line was commented out. Returns False if
        the file is missing, has no active entry, or the update failed. The
        file is rewritten only when something changed, so an unchanged config
        is never truncated. A DONE/SKIPPED/FAILED status line is printed for
        every outcome except a missing file.
        """
        try:
            if not os.path.exists(config_path):
                self.logger.error(f"Config file not found: {config_path}")
                return False

            with open(config_path, "r") as file:
                lines = file.readlines()
            updated_lines = [
                f"# {line}" if self._NODENAME_PATTERN.match(line) else line
                for line in lines
            ]
            nodename_changed = updated_lines != lines
            if nodename_changed:
                with open(config_path, "w") as file:
                    file.writelines(updated_lines)
                status, msg = TASK_STATUS["DONE"], self.MSG_UPDATE_NODENAME
            else:
                status, msg = TASK_STATUS["SKIPPED"], self.MSG_NODENAME_UNCHANGED
            self._print_status_msg(msg, status)
            return nodename_changed
        except Exception as error:
            self._print_status_msg(self.MSG_UPDATE_NODENAME, TASK_STATUS['FAILED'])
            self.logger.error(f"Failed to comment NODENAME in rabbitmq-env.conf: {error}")
            return False

    def setup_rabbitmq_with_new_settings(self):
        """Run ``csadm mq --set-up-msg-broker`` and print the step status.

        A non-zero return code is treated as failure and its std_out is
        logged. Failures are reported but not re-raised.
        """
        try:
            cmd = "csadm mq --set-up-msg-broker"
            result = self.cmd_line_utilities.execute_cmd(cmd, True)
            if result.get('return_code') != 0:
                raise Exception(result.get('std_out'))
            self._print_status_msg(self.MSG_ENABLE_LONGNAME, TASK_STATUS['DONE'])
        except Exception as ex:
            self._print_status_msg(self.MSG_ENABLE_LONGNAME, TASK_STATUS["FAILED"])
            err_msg = f"Failed to update RabbitMQ settings: {ex}"
            self.logger.exception(err_msg)
            print(f"Exception occurred in post-upgrade task. Check logs at '{LOG_FILE}'")

    def restart_services(self):
        """Reload systemd units, then restart services with ``csadm``.

        Prints DONE on success, consistent with the other steps. A failed
        ``daemon-reload`` is logged as a warning, and the restart is still
        attempted. A failed restart is reported but not re-raised.
        """
        try:
            # Reload systemd so any changed unit/environment files are picked up.
            reload_result = self.cmd_line_utilities.execute_cmd("systemctl daemon-reload", True)
            if reload_result.get('return_code') != 0:
                self.logger.warning(f"systemctl daemon-reload failed: {reload_result.get('std_out')}")

            # Restart all services managed by csadm (RabbitMQ included).
            cmd = "csadm services --restart"
            result = self.cmd_line_utilities.execute_cmd(cmd, True)

            if result.get('return_code') != 0:
                raise Exception(result.get('std_out'))
            self._print_status_msg(self.MSG_RESTART_SERVICE, TASK_STATUS["DONE"])
        except Exception as ex:
            self._print_status_msg(self.MSG_RESTART_SERVICE, TASK_STATUS["FAILED"])
            err_msg = f"Failed to restart services: {ex}"
            self.logger.exception(err_msg)
            print(f"Exception occurred in post-upgrade task. Check logs at '{LOG_FILE}'")

    def copy_file(self, source_file, destination_file):
        """Copy *source_file* to *destination_file* via ``cp`` (paths shell-quoted)."""
        cmd = f"cp {shlex.quote(source_file)} {shlex.quote(destination_file)}"
        self.cmd_line_utilities.execute_cmd(cmd)

    def set_file_permissions(self, file_path, permissions="644"):
        """Apply ``chmod permissions file_path`` (both arguments shell-quoted)."""
        cmd = f"chmod {shlex.quote(permissions)} {shlex.quote(file_path)}"
        self.cmd_line_utilities.execute_cmd(cmd)

    def delete_file(self, file_path):
        """Remove *file_path* with ``rm -f`` if it is an existing regular file."""
        file_path = os.path.abspath(file_path)
        if not os.path.isfile(file_path):
            return
        cmd = f"rm -f {shlex.quote(file_path)}"
        self.cmd_line_utilities.execute_cmd(cmd)

    def is_secondary_node(self):
        """Return True only on an enterprise install whose HA role is secondary.

        Any error is logged and treated as "not secondary". That includes the
        HA helpers being unavailable, when their import failed at module load.
        """
        try:
            if self.is_enterprise():
                return is_secondary()
            return False
        except Exception as error:
            self.logger.exception(error)
            return False

    def is_enterprise(self):
        """Return the ``flag_is_enterprise`` value recorded by the pre-upgrade initialize step.

        NOTE(review): raises if that step result or key is missing. Confirm
        that ``get_step_results`` always provides it.
        """
        step_result = self.get_step_results("pre-upgrade", "initialize")
        return step_result["flag_is_enterprise"]
