from tempfile import TemporaryDirectory, NamedTemporaryFile
from pathlib import Path

from unittest import TestCase
from unittest.mock import patch

from obspy import UTCDateTime

from sohstationviewer.controller.processing import (
    read_mseed_channels,
    detect_data_type,
    get_data_type_from_file
)
from sohstationviewer.database.extract_data import get_signature_channels
from PySide2 import QtWidgets


# Root of the sample data: the 'test_data' folder that sits next to the
# directory containing this file.
TEST_DATA_DIR = Path(__file__).resolve().parent.parent / 'test_data'

# Sample data sets, one per directory, used by the tests in this module.
rt130_dir = TEST_DATA_DIR / 'RT130-sample/2017149.92EB/2017150'
q330_dir = TEST_DATA_DIR / 'Q330-sample/day_vols_AX08'
centaur_dir = TEST_DATA_DIR / 'Centaur-sample/SOH'
pegasus_dir = TEST_DATA_DIR / 'Pegasus-sample/Pegasus_SVC4/soh'
multiplex_dir = TEST_DATA_DIR / 'Q330_multiplex'


class TestReadMSeedChannel(TestCase):
    """Test suite for read_mseed_channels."""

    def setUp(self) -> None:
        """Set up test fixtures."""
        # Replace QtWidgets.QTextBrowser with a mock for the duration of each
        # test; the mock is what gets passed to read_mseed_channels as its
        # first argument.
        patcher = patch.object(QtWidgets, 'QTextBrowser')
        self.widget_stub = patcher.start()
        # Register the stop only once the patch is active.
        self.addCleanup(patcher.stop)

        # The actual value can be 'Q330', 'Centaur', or 'Pegasus'.
        # The code that uses this string only cares that it is not 'RT130',
        # though, so we are setting it to a stub value.
        # NOTE(review): no test in this class reads this attribute.
        self.mseed_dtype = 'MSeed'

    def _assert_station_channels(self, ret, station, *, soh, mass_pos,
                                 waveform, soh_spr_gr_1, start_time,
                                 end_time):
        """
        Assert that ret describes exactly one station with the given content.

        :param ret: the value returned by read_mseed_channels
        :param station: the only key expected in ret
        :param soh: expected value of ret[station]['soh']
        :param mass_pos: expected value of ret[station]['mass_pos']
        :param waveform: expected value of ret[station]['waveform']
        :param soh_spr_gr_1: expected value of ret[station]['soh_spr_gr_1']
        :param start_time: expected value of ret[station]['start_time']
        :param end_time: expected value of ret[station]['end_time']
        """
        self.assertListEqual(list(ret.keys()), [station])
        station_info = ret[station]
        self.assertListEqual(station_info['soh'], soh)
        self.assertListEqual(station_info['mass_pos'], mass_pos)
        self.assertListEqual(station_info['waveform'], waveform)
        self.assertListEqual(station_info['soh_spr_gr_1'], soh_spr_gr_1)
        self.assertEqual(station_info['start_time'], start_time)
        self.assertEqual(station_info['end_time'], end_time)

    def test_read_channels_mseed_dir(self):
        """
        Test basic functionality of read_mseed_channels - the given directory
        contains MSeed data.

        Judging by the subtest labels, the third positional argument of
        read_mseed_channels is the is_multiplex flag.
        """
        with self.subTest("q330 - non multiplex"):
            ret = read_mseed_channels(self.widget_stub, [q330_dir],
                                      False, True)
            self._assert_station_channels(
                ret, 'AX08',
                soh=sorted(['LOG', 'VKI']),
                mass_pos=['VM1'],
                waveform=['HHE', 'LHE'],
                soh_spr_gr_1=[],
                start_time=1625443228.34,
                end_time=UTCDateTime(0))

        with self.subTest("centaur - multiplex - is_multiplex=True"):
            ret = read_mseed_channels(self.widget_stub, [centaur_dir],
                                      True, True)
            self._assert_station_channels(
                ret, '3734',
                soh=sorted(
                    ['VDT', 'EX3', 'GEL', 'VEC', 'EX2', 'LCE', 'EX1', 'GLA',
                     'LCQ', 'GPL', 'GNS', 'GST', 'VCO', 'GAN', 'GLO', 'VPB',
                     'VEI']),
                mass_pos=sorted(['VM1', 'VM2', 'VM3']),
                waveform=[],
                soh_spr_gr_1=[],
                start_time=1534464000.0,
                end_time=1534550400.0)

        with self.subTest("centaur - multiplex - is_multiplex=False"):
            # not all channels detected if is_multiplex isn't set
            ret = read_mseed_channels(self.widget_stub, [centaur_dir],
                                      False, True)
            self._assert_station_channels(
                ret, '3734',
                soh=sorted(['GEL']),
                mass_pos=sorted([]),
                waveform=[],
                soh_spr_gr_1=[],
                start_time=1534464000.0,
                end_time=UTCDateTime(0))

        with self.subTest("pegasus - non multiplex"):
            ret = read_mseed_channels(self.widget_stub, [pegasus_dir],
                                      False, True)
            self._assert_station_channels(
                ret, 'KC01',
                soh=sorted(['VDT', 'VE1']),
                mass_pos=sorted(['VM1']),
                waveform=[],
                soh_spr_gr_1=[],
                start_time=1588978560.0,
                end_time=UTCDateTime(0))

        with self.subTest("q330 - multiplex - is_multiplex=True"):
            ret = read_mseed_channels(self.widget_stub, [multiplex_dir],
                                      True, True)
            self._assert_station_channels(
                ret, '3203',
                soh=['LOG'],
                mass_pos=[],
                waveform=sorted(
                    ['BH1', 'BH2', 'BH3', 'BH4', 'BH5', 'BH6',
                     'EL1', 'EL2', 'EL4', 'EL5', 'EL6', 'ELZ']),
                soh_spr_gr_1=sorted(
                    ['BS1', 'BS2', 'BS3', 'BS4', 'BS5', 'BS6',
                     'ES1', 'ES2', 'ES3', 'ES4', 'ES5', 'ES6',
                     'LS1', 'LS2', 'LS3', 'LS4', 'LS5', 'LS6',
                     'SS1', 'SS2', 'SS3', 'SS4', 'SS5', 'SS6']),
                start_time=1671729934.9000392,
                end_time=1671733530.57)

        with self.subTest("q330 - multiplex - is_multiplex=False"):
            # not all channels detected if is_multiplex isn't set
            ret = read_mseed_channels(self.widget_stub, [multiplex_dir],
                                      False, True)
            self._assert_station_channels(
                ret, '3203',
                soh=[],
                mass_pos=[],
                waveform=sorted(['EL1']),
                soh_spr_gr_1=sorted([]),
                start_time=1671730004.145029,
                end_time=UTCDateTime(0))

    def test_read_channels_rt130_dir(self):
        """
        Test basic functionality of read_mseed_channels - the given directory
        contains RT130 data, so an empty dict is expected.
        """
        ret = read_mseed_channels(self.widget_stub, [rt130_dir], True, True)
        self.assertEqual(ret, {})

    def test_read_mseed_channels_no_dir(self):
        """
        Test basic functionality of read_mseed_channels - no directory was
        given.
        """
        no_dir = []
        ret = read_mseed_channels(self.widget_stub, no_dir, True, True)
        self.assertEqual(ret, {})

    def test_read_mseed_channels_dir_does_not_exist(self):
        """
        Test basic functionality of read_mseed_channels - the given directory
        does not exist.
        """
        empty_name_dir = ['']
        ret = read_mseed_channels(self.widget_stub, empty_name_dir, True, True)
        self.assertEqual(ret, {})

        non_existent_dir = ['non_existent_dir']
        ret = read_mseed_channels(self.widget_stub, non_existent_dir,
                                  True, True)
        self.assertEqual(ret, {})

    def test_read_mseed_channels_empty_dir(self):
        """
        Test basic functionality of read_mseed_channels - the given directory
        is empty.
        """
        with TemporaryDirectory() as empty_dir:
            ret = read_mseed_channels(self.widget_stub, [empty_dir],
                                      True, True)
            self.assertEqual(ret, {})

    def test_read_mseed_channels_empty_data_dir(self):
        """
        Test basic functionality of read_mseed_channels - the given directory
        contains a data folder but no data file.
        """
        with TemporaryDirectory() as outer_dir:
            # The nested temporary directory plays the role of the empty
            # data folder.
            with TemporaryDirectory(dir=outer_dir):
                ret = read_mseed_channels(self.widget_stub, [outer_dir],
                                          True, True)
                self.assertEqual(ret, {})


class TestDetectDataType(TestCase):
    """
    Test suite for detect_data_type.

    get_data_type_from_file is replaced with a mock in every test, so each
    test controls what detect_data_type is told about the files it finds.
    """

    def setUp(self) -> None:
        """Set up test fixtures."""
        func_patcher = patch('sohstationviewer.controller.processing.'
                             'get_data_type_from_file')
        self.mock_get_data_type_from_file = func_patcher.start()
        self.addCleanup(func_patcher.stop)

        # Each resource is registered for cleanup as soon as it is created,
        # so it is released even if a later step of setUp fails. Cleanups run
        # in reverse order of registration: the files are closed (and thereby
        # deleted) before the directories that contain them are removed.
        self.dir1 = TemporaryDirectory()
        self.addCleanup(self.dir1.cleanup)
        self.dir2 = TemporaryDirectory()
        self.addCleanup(self.dir2.cleanup)
        self.file1 = NamedTemporaryFile(dir=self.dir1.name)
        self.addCleanup(self.file1.close)
        self.file2 = NamedTemporaryFile(dir=self.dir2.name)
        self.addCleanup(self.file2.close)

    def test_one_directory_not_unknown_data_type(self):
        """
        Test basic functionality of detect_data_type - only one directory was
        given and the data type it contains can be detected.
        """
        expected_data_type = ('RT130', False)
        self.mock_get_data_type_from_file.return_value = expected_data_type

        self.assertEqual(
            detect_data_type([self.dir1.name]),
            expected_data_type
        )

    def test_same_data_type_not_multiplex(self):
        """
        Test basic functionality of detect_data_type - the given directories
        contain the same data type and neither is reported as multiplexed.
        """
        expected_data_type = ('RT130', False)
        self.mock_get_data_type_from_file.return_value = expected_data_type

        self.assertEqual(
            detect_data_type([self.dir1.name, self.dir2.name]),
            expected_data_type
        )

    def test_same_data_type_multiplex(self):
        """
        Test basic functionality of detect_data_type - the given directories
        contain the same data type and both are reported as multiplexed.
        """
        returned_data_types = [('Q330', True), ('Q330', True)]
        self.mock_get_data_type_from_file.side_effect = returned_data_types

        self.assertEqual(
            detect_data_type([self.dir1.name, self.dir2.name]),
            returned_data_types[0]
        )

    def test_different_data_types(self):
        """
        Test basic functionality of detect_data_type - the given directories
        contain different data types, which is reported through an exception
        listing the type found in each directory.
        """
        returned_data_types = [('RT130', False), ('Q330', False)]
        self.mock_get_data_type_from_file.side_effect = returned_data_types

        with self.assertRaises(Exception) as context:
            detect_data_type([self.dir1.name, self.dir2.name])
        self.assertEqual(
            str(context.exception),
            "There are more than one types of data detected:\n"
            f"{self.dir1.name}: RT130, "
            f"{self.dir2.name}: Q330\n\n"
            "Please have only data that related to each other.")

    def test_unknown_data_type(self):
        """
        Test basic functionality of detect_data_type - can't detect any data
        type.
        """
        unknown_data_type = ('Unknown', False)
        self.mock_get_data_type_from_file.return_value = unknown_data_type
        with self.assertRaises(Exception) as context:
            detect_data_type([self.dir1.name])
        self.assertEqual(
            str(context.exception),
            "There are no known data detected.\n\n"
            "Do you want to cancel to select different folder(s)\n"
            "Or continue to read any available mseed file?")

    def test_multiplex_none(self):
        """
        Test basic functionality of detect_data_type - the multiplex flag
        reported for the data is None, which is expected to raise an
        exception saying that no channel was found.
        """
        unknown_data_type = ('Unknown', None)
        self.mock_get_data_type_from_file.return_value = unknown_data_type
        with self.assertRaises(Exception) as context:
            detect_data_type([self.dir1.name])
        self.assertEqual(
            str(context.exception),
            "No channel found for the data set")


class TestGetDataTypeFromFile(TestCase):
    """Test suite for get_data_type_from_file"""

    def test_rt130_data(self):
        """
        Test basic functionality of get_data_type_from_file - given file
        contains RT130 data.
        """
        rt130_file = rt130_dir.joinpath('92EB/0/000000000_00000000')
        expected_data_type = ('RT130', False)
        self.assertTupleEqual(
            get_data_type_from_file(rt130_file, get_signature_channels()),
            expected_data_type
        )

    def test_cannot_detect_data_type(self):
        """
        Test basic functionality of get_data_type_from_file - cannot detect
        data type contained in given file.
        """
        # The context manager guarantees the (empty) temporary file is closed
        # and removed even when the assertion fails.
        with NamedTemporaryFile() as test_file:
            ret = get_data_type_from_file(
                Path(test_file.name), get_signature_channels())
            self.assertEqual(ret, (None, False))

    def test_mseed_data(self):
        """
        Test basic functionality of get_data_type_from_file - given file
        contains MSeed data.
        """
        q330_file = q330_dir.joinpath('AX08.XA..VKI.2021.186')
        centaur_file = centaur_dir.joinpath(
            'XX.3734.SOH.centaur-3_3734..20180817_000000.miniseed.miniseed')
        pegasus_file = pegasus_dir.joinpath(
            '2020/XX/KC01/VE1.D/XX.KC01..VE1.D.2020.129')
        cases = (
            (q330_file, ('Q330', False)),
            (centaur_file, ('Centaur', True)),
            (pegasus_file, ('Pegasus', False)),
        )

        sig_chan = get_signature_channels()

        # One subtest per sample file, so a failure for one data type does
        # not hide the results for the others.
        for data_file, expected_data_type in cases:
            with self.subTest(data_type=expected_data_type[0]):
                self.assertTupleEqual(
                    get_data_type_from_file(data_file, sig_chan),
                    expected_data_type)

    def test_file_does_not_exist(self):
        """
        Test basic functionality of get_data_type_from_file - given file does
        not exist.

        An empty path (which pathlib treats as the current directory) is
        expected to raise IsADirectoryError; a path that does not exist is
        expected to raise FileNotFoundError.
        """
        empty_name_file = Path('')
        non_existent_file = Path('non_existent_dir')
        with self.assertRaises(IsADirectoryError):
            get_data_type_from_file(empty_name_file, get_signature_channels())
        with self.assertRaises(FileNotFoundError):
            get_data_type_from_file(non_existent_file,
                                    get_signature_channels())

    def test_non_data_binary_file(self):
        """
        Test basic functionality of get_data_type_from_file - given file is
        a binary file (a PNG image) that holds no data; None is expected.
        """
        # NOTE(review): unlike test_cannot_detect_data_type, which expects
        # the tuple (None, False), this case expects a bare None - confirm
        # that the difference is intended.
        binary_file = Path(__file__).resolve().parent.parent.parent.joinpath(
            'images', 'home.png')
        ret = get_data_type_from_file(binary_file, get_signature_channels())
        self.assertIsNone(ret)