import os
from pathlib import Path
from tempfile import TemporaryDirectory

from unittest import TestCase
from unittest.mock import patch
from contextlib import redirect_stdout
import io

from sohstationviewer.model.handling_data import (
    read_mseed_or_text
)
from sohstationviewer.view.util.enums import LogType
from sohstationviewer.conf import constants


# Sample data lives in the 'test_data' folder located in the parent of the
# directory that contains this file.
TEST_DATA_DIR = Path(__file__).resolve().parent.parent / 'test_data'

# Q330 sample file used by the blockette / time-range / channel tests.
unimplemented_ascii_blockette_file = (
    TEST_DATA_DIR
    / 'Q330_unimplemented_ascii_block/XX-3203_4-20221222190255'
)

# Pegasus sample log file used by the text-file test.
log_text_file = (
    TEST_DATA_DIR
    / 'Pegasus-sample/Pegasus_SVC4/logs/2020/XX/KC01/XX.KC01...D.2020.129'
)


def mock_track_info(text: str, type: LogType) -> None:
    """
    Stand-in for the track_info callback: write the message to stdout so
    that a test can capture it with redirect_stdout.

    :param text: message to report
    :param type: log type of the message; ignored by this mock
    """
    print(text)


class TestReadMseedText(TestCase):
    """
    Tests for read_mseed_or_text() using the Q330 sample file and the
    Pegasus sample log file found under TEST_DATA_DIR.

    Each test instance owns the containers that are handed to
    read_mseed_or_text() (one dict per data category) as well as the
    callbacks passed along with them.
    """

    def populate_cur_key_for_all_data(self, cur_key):
        """
        Set up new data set's key for all data

        :param cur_key: current processing key: DAS SN, experiment number
        """
        if cur_key in self.log_data:
            # The key has already been set up; leave existing data alone.
            return
        self.log_data[cur_key] = {}
        self.soh_data[cur_key] = {}
        self.mass_pos_data[cur_key] = {}
        self.waveform_data[cur_key] = {}
        self.gaps[cur_key] = []
        # [earliest, latest]: start with extreme values so that any real
        # time can replace them.
        self.data_time[cur_key] = [constants.HIGHEST_INT, 0]
        self.stream_header_by_key_chan[cur_key] = {}

    def setUp(self) -> None:
        """
        Create a fresh temporary directory, empty data containers and the
        expected channel ids / data time shared by the tests.
        """
        self.tmp_dir_obj: TemporaryDirectory = TemporaryDirectory()
        # Remove the temporary directory when the test finishes (pass or
        # fail) instead of leaving it to the implicit finalizer.
        self.addCleanup(self.tmp_dir_obj.cleanup)
        self.tmp_dir = self.tmp_dir_obj.name

        self.stream_header_by_key_chan = {}
        self.soh_data = {}
        self.mass_pos_data = {}
        self.waveform_data = {}
        self.data_time = {}
        self.gaps = {}
        self.log_data = {'TEXT': []}
        self.track_info = mock_track_info

        self.expected_soh_chan_ids = sorted([
            'VEP', 'VEC', 'VEA', 'VKI', 'VCO', 'VPB', 'LCQ', 'LCE', ])
        self.expected_mass_pos_chan_ids = sorted([
            'VMU', 'VMW', 'VMV'])
        self.expected_wf_chan_ids = sorted([
            'BH1', 'BH2', 'BH3', 'BH4', 'BH5', 'BH6',
            'EL1', 'EL2', 'EL4', 'EL5', 'EL6', 'ELZ',
            'LL1', 'LL2', 'LL4', 'LL5', 'LL6', 'LLZ',
            ])
        self.expected_log_chan_ids = sorted(['ACE', 'LOG', 'OCF'])

        self.expected_stream_header_chan_ids = sorted(
            self.expected_soh_chan_ids + self.expected_mass_pos_chan_ids +
            self.expected_wf_chan_ids + self.expected_log_chan_ids)

        self.expected_data_time = [1671729273.000004, 1671735657.9099982]

    def _read(self, path2file, start=None, end=None,
              req_soh_chans=None, req_wf_chans=None,
              bool_arg_1=True, bool_arg_2=True):
        """
        Call read_mseed_or_text() with this instance's containers and
        callbacks, forwarding everything positionally.

        NOTE(review): the parameter names below are inferred from how the
        tests use them; confirm against read_mseed_or_text's signature.

        :param path2file: path of the file to read
        :param start: 3rd positional argument (presumably start of the
            requested time range), None by default
        :param end: 4th positional argument (presumably end of the
            requested time range), None by default
        :param req_soh_chans: 11th positional argument (presumably the
            requested SOH channels); [] when None
        :param req_wf_chans: 12th positional argument (presumably the
            requested waveform channels); ['*'] when None
        :param bool_arg_1: 13th positional argument, forwarded unchanged
        :param bool_arg_2: 14th positional argument, forwarded unchanged
        """
        # Build new lists per call so nothing mutable is shared between
        # tests.
        soh_chans = [] if req_soh_chans is None else req_soh_chans
        wf_chans = ['*'] if req_wf_chans is None else req_wf_chans
        read_mseed_or_text(
            path2file, self.tmp_dir, start, end,
            self.stream_header_by_key_chan,
            self.soh_data, self.mass_pos_data, self.waveform_data,
            self.log_data, self.data_time, soh_chans, wf_chans,
            bool_arg_1, bool_arg_2,
            self.track_info, self.populate_cur_key_for_all_data)

    @patch('sohstationviewer.model.mseed.blockettes_reader.'
           'NOT_READ_BLOCKETTES', {1000, 1001})
    def test_unimplemented_ascii_block_in_not_read_blockettes(self):
        """
        With blockettes 1000 and 1001 listed in NOT_READ_BLOCKETTES,
        nothing is expected to be reported through track_info.
        Also checks the ASCII text obtained from the data.
        """
        path2file = unimplemented_ascii_blockette_file
        captured = io.StringIO()
        with redirect_stdout(captured):
            self._read(path2file)
        self.assertEqual(captured.getvalue().strip(), '')

        self.assertEqual(list(self.soh_data.keys()), ['3203'])
        self.assertEqual(sorted(self.soh_data['3203']),
                         self.expected_soh_chan_ids)
        self.assertEqual(list(self.mass_pos_data.keys()), ['3203'])
        self.assertEqual(sorted(self.mass_pos_data['3203']),
                         self.expected_mass_pos_chan_ids)
        self.assertEqual(list(self.waveform_data.keys()), ['3203'])
        self.assertEqual(sorted(self.waveform_data['3203']),
                         self.expected_wf_chan_ids)
        self.assertEqual(list(self.log_data.keys()), ['TEXT', '3203'])
        self.assertListEqual(sorted(self.log_data['3203']),
                             self.expected_log_chan_ids)

        first_log_lines = self.log_data['3203']['LOG'][0].split('\n')
        self.assertEqual(
            first_log_lines[2],
            "STATE OF HEALTH: From:2022-12-22T18:54:33.200000Z  "
            "To:2022-12-22T18:54:33.200000Z")
        self.assertEqual(
            first_log_lines[3],
            "{800}[2022-12-22 18:54:33] PWRMGR: req=0 medtmr=9 ctr=1846 "
            "uvdc=12.977 ud=0.004 pd=0.005 pvdc=13.325 chg=0.027 low=0 "
            "trf=0 ms=0 ac=180 PM=GG\r")

        self.assertEqual(list(self.stream_header_by_key_chan), ['3203'])
        self.assertEqual(
            sorted(self.stream_header_by_key_chan['3203']),
            self.expected_stream_header_chan_ids)

        self.assertEqual(list(self.gaps), ['3203'])
        self.assertEqual(self.gaps['3203'], [])
        self.assertEqual(list(self.data_time.keys()), ['3203'])
        self.assertListEqual(self.data_time['3203'],
                             self.expected_data_time)

        self.assertEqual(len(os.listdir(self.tmp_dir)), 12)

    @patch('sohstationviewer.model.mseed.blockettes_reader.'
           'NOT_READ_BLOCKETTES', {})
    def test_unimplemented_ascii_block_not_in_not_read_blockettes(self):
        """
        With NOT_READ_BLOCKETTES emptied, a "not implemented" message for
        blockette 1001 is expected for channels ACE and OCF, while the data
        that was read stays the same.
        """
        path2file = unimplemented_ascii_blockette_file
        captured = io.StringIO()
        with redirect_stdout(captured):
            self._read(path2file)
        self.assertEqual(
            captured.getvalue().strip(),
            f"{path2file.name}: 3203 - ACE: "
            f"Function to read blockette 1001 isn't implemented yet.\n"
            f"{path2file.name}: 3203 - OCF: "
            f"Function to read blockette 1001 isn't implemented yet.")

        self.assertEqual(list(self.soh_data.keys()), ['3203'])
        self.assertEqual(sorted(self.soh_data['3203']),
                         self.expected_soh_chan_ids)
        self.assertEqual(sorted(self.mass_pos_data['3203']),
                         self.expected_mass_pos_chan_ids)
        self.assertEqual(list(self.waveform_data.keys()), ['3203'])
        self.assertEqual(sorted(self.waveform_data['3203']),
                         self.expected_wf_chan_ids)
        self.assertEqual(list(self.log_data.keys()), ['TEXT', '3203'])
        self.assertListEqual(sorted(self.log_data['3203']),
                             self.expected_log_chan_ids)

        self.assertEqual(list(self.stream_header_by_key_chan), ['3203'])
        self.assertEqual(
            sorted(self.stream_header_by_key_chan['3203']),
            self.expected_stream_header_chan_ids)

        self.assertEqual(list(self.gaps), ['3203'])
        self.assertEqual(self.gaps['3203'], [])

        self.assertEqual(list(self.data_time.keys()), ['3203'])
        self.assertListEqual(self.data_time['3203'],
                             self.expected_data_time)

        self.assertEqual(len(os.listdir(self.tmp_dir)), 12)

    def test_set_time_range(self):
        """
        Read with 1671730000 and 1671730001 as the 3rd and 4th arguments:
        channel ids are still registered, but no log data is kept and no
        file is created in the temporary directory.
        """
        self._read(unimplemented_ascii_blockette_file,
                   1671730000, 1671730001)

        self.assertEqual(list(self.soh_data.keys()), ['3203'])
        self.assertEqual(sorted(self.soh_data['3203']),
                         self.expected_soh_chan_ids)
        self.assertEqual(sorted(self.mass_pos_data['3203']),
                         self.expected_mass_pos_chan_ids)
        self.assertEqual(list(self.waveform_data.keys()), ['3203'])
        self.assertEqual(sorted(self.waveform_data['3203']),
                         self.expected_wf_chan_ids)
        self.assertEqual(list(self.log_data.keys()), ['TEXT', '3203'])
        self.assertEqual(self.log_data['3203'], {})

        self.assertEqual(list(self.stream_header_by_key_chan), ['3203'])
        self.assertEqual(
            sorted(self.stream_header_by_key_chan['3203']),
            self.expected_stream_header_chan_ids)

        self.assertEqual(list(self.gaps), ['3203'])
        self.assertEqual(self.gaps['3203'], [])

        self.assertEqual(list(self.data_time.keys()), ['3203'])
        self.assertListEqual(self.data_time['3203'],
                             [1671729273.000004, 1671735657.000004])

        self.assertEqual(len(os.listdir(self.tmp_dir)), 0)

    def test_set_req_soh(self):
        """
        Read with an explicit channel selection (['ACE', 'VCO', 'BS1'] and
        ['BH*']) and with the last boolean flag set to False; only the
        selected channels are expected in the results.
        """
        expected_soh_chan_ids = ['BS1', 'VCO']
        expected_wf_chan_ids = ['BH1', 'BH2', 'BH3', 'BH4', 'BH5', 'BH6']
        expected_log_chan_ids = ['ACE']
        expected_stream_header_chan_ids = sorted(
            expected_soh_chan_ids + expected_wf_chan_ids +
            expected_log_chan_ids)

        # ACE: ASCII
        # VCO: soh channel with spr=1
        # BS1: soh channel with spr>1
        self._read(unimplemented_ascii_blockette_file,
                   req_soh_chans=['ACE', 'VCO', 'BS1'],
                   req_wf_chans=['BH*'],
                   bool_arg_1=True, bool_arg_2=False)

        self.assertEqual(list(self.soh_data.keys()), ['3203'])
        self.assertEqual(sorted(self.soh_data['3203']),
                         expected_soh_chan_ids)
        self.assertEqual(sorted(self.mass_pos_data), ['3203'])
        self.assertEqual(list(self.waveform_data.keys()), ['3203'])
        self.assertEqual(sorted(self.waveform_data['3203']),
                         expected_wf_chan_ids)
        self.assertEqual(list(self.log_data.keys()), ['TEXT', '3203'])
        self.assertListEqual(sorted(self.log_data['3203']),
                             expected_log_chan_ids)

        self.assertEqual(list(self.stream_header_by_key_chan), ['3203'])
        self.assertEqual(
            sorted(self.stream_header_by_key_chan['3203']),
            expected_stream_header_chan_ids)

        self.assertEqual(list(self.gaps), ['3203'])
        self.assertEqual(self.gaps['3203'], [])

        self.assertEqual(list(self.data_time.keys()), ['3203'])
        self.assertListEqual(self.data_time['3203'],
                             [1671729298.000004, 1671735657.179998])

        self.assertEqual(
            self.log_data['3203']['ACE'][0].split('\n')[2],
            "STATE OF HEALTH: From:2022-12-22T17:14:47.000000Z  "
            "To:2022-12-22T17:14:47.000000Z")
        self.assertEqual(len(os.listdir(self.tmp_dir)), 7)

    def test_text_file(self):
        """
        Read the Pegasus log text file: no data-set key is created in any
        container and nothing is written to the temporary directory.
        """
        self._read(log_text_file)

        self.assertEqual(self.soh_data, {})
        self.assertEqual(self.waveform_data, {})
        self.assertEqual(self.data_time, {})
        self.assertEqual(list(self.log_data.keys()), ['TEXT'])
        self.assertEqual(self.stream_header_by_key_chan, {})
        self.assertEqual(self.gaps, {})
        self.assertEqual(len(os.listdir(self.tmp_dir)), 0)