from pathlib import Path
import tempfile
from math import isclose

from unittest import TestCase
from unittest.mock import patch

from obspy.core import Stream, read as read_ms
import numpy as np

from sohstationviewer.model.handling_data import (
    readSOHMSeed,
    readSOHTrace,
    readMPTrace,
    readWaveformTrace,
    readWaveformMSeed,
    readWaveformReftek
)
from sohstationviewer.model.reftek.from_rt2ms.core import Reftek130

TEST_DATA_DIR = Path(__file__).parent.parent.joinpath('test_data')


# tempfile.tempdir = './tempdir'


class TestHandlingData(TestCase):
    """Tests for the trace-level readers of ``handling_data``.

    The RT130 and Q330 sample files under ``TEST_DATA_DIR`` are read once
    per class and shared by all the tests.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Read the RT130 and Q330 sample streams used as fixtures."""
        rt130_dir = TEST_DATA_DIR.joinpath(
            'RT130-sample/2017149.92EB/2017150/92EB')
        cls.rt130_soh_file = rt130_dir.joinpath('9/054910000_013EE8A0')
        rt130_soh = Reftek130.from_file(cls.rt130_soh_file)
        cls.rt130_soh_stream = Reftek130.to_stream(rt130_soh)
        cls.rt130_soh_trace = cls.rt130_soh_stream[0]
        # NOTE(review): this is the same path as rt130_soh_file, whereas
        # TestReadWaveformReftek reads '1/000000015_0036EE80' -- confirm
        # whether a dedicated waveform file was intended here.
        cls.rt130_waveform_file = rt130_dir.joinpath('9/054910000_013EE8A0')
        rt130_waveform = Reftek130.from_file(cls.rt130_waveform_file)
        cls.rt130_waveform_stream = Reftek130.to_stream(rt130_waveform)
        cls.rt130_waveform_trace = cls.rt130_waveform_stream[0]

        q330_dir = TEST_DATA_DIR.joinpath('Q330-sample/day_vols_AX08')
        cls.q330_soh_file = q330_dir.joinpath('AX08.XA..VKI.2021.186')
        cls.mseed_soh_stream = read_ms(cls.q330_soh_file)
        cls.mseed_soh_trace = cls.mseed_soh_stream[0]
        cls.q330_waveform_file = q330_dir.joinpath('AX08.XA..LHE.2021.186')
        cls.mseed_waveform_stream = read_ms(cls.q330_waveform_file)
        cls.mseed_waveform_trace = cls.mseed_waveform_stream[0]

    # @expectedFailure
    # def test_read_sohmseed(self):
    #     self.fail()

    def test_read_soh_trace_processed_trace_have_all_needed_info(self):
        """readSOHTrace's result contains every expected key."""
        processed_trace = readSOHTrace(self.mseed_soh_trace)
        expected_key_list = [
            'chanID',
            'samplerate',
            'startTmEpoch',
            'endTmEpoch',
            'times',
            'data'
        ]
        # One sub-test per key so that a failure names the missing key.
        for key in expected_key_list:
            with self.subTest(key=key):
                self.assertIn(key, processed_trace)

    def test_read_soh_trace_times_calculated_correctly(self):
        """The first time is (almost) zero only if the start epoch is.

        'startTmEpoch' counts as zero when it is within 1e-4 of it.
        """
        processed_trace = readSOHTrace(self.mseed_soh_trace)
        if isclose(processed_trace['startTmEpoch'], 0, abs_tol=0.0001):
            self.assertAlmostEqual(processed_trace['times'][0], 0)
        else:
            self.assertNotAlmostEqual(processed_trace['times'][0], 0)

    @patch('sohstationviewer.model.handling_data.readSOHTrace')
    def test_read_mp_trace(self, mock_read_soh_trace):
        """readMPTrace maps the patched raw data to the expected values.

        readSOHTrace (presumably called by readMPTrace) is patched so that
        the raw values below are the ones being converted.
        """
        mock_read_soh_trace.return_value = {
            # Contain five cases:
            # + Small positive
            # + Large positive
            # + Small negative
            # + Large negative
            # + Zero
            'data': np.array([1, 27272, -2, -23526, 0])
        }
        expected = np.array([0, 8.3, 0, -7.2, 0])
        processed_trace = readMPTrace(self.rt130_soh_trace)
        # Same exact comparison as np.array_equal, but a failure reports
        # the mismatching elements.
        np.testing.assert_array_equal(processed_trace['data'], expected)

    @patch('sohstationviewer.model.handling_data.saveData2File')
    def test_read_waveform_trace(self, mock_save_data_2_file):
        """readWaveformTrace returns the expected keys and saves the data.

        saveData2File is patched, so the test only checks that it has been
        called.
        """
        station_id = self.rt130_soh_trace.stats['station']
        channel_id = self.rt130_soh_trace.stats['channel']
        # The function itself only cares about the length of this list so we
        # stub it out.
        traces_info = [dict() for _ in range(4)]
        expected_key_list = [
            'samplerate',
            'startTmEpoch',
            'endTmEpoch',
            'size',
            'times_f',
            'data_f',
        ]
        # The context manager removes the temporary directory even when an
        # assertion fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            processed_trace = readWaveformTrace(
                self.rt130_soh_trace, station_id, channel_id, traces_info,
                tmp_dir
            )

            for key in expected_key_list:
                with self.subTest(key=key):
                    self.assertIn(key, processed_trace)

            mock_save_data_2_file.assert_called()

    # @skip
    # def test_read_ascii(self):
    #     self.fail()
    #
    # @skip
    # def test_read_text(self):
    #     self.fail()
    #
    # @skip
    # def test_save_data2file(self):
    #     self.fail()
    #
    # @skip
    # def test_check_chan(self):
    #     self.fail()
    #
    # @skip
    # def test_check_sohchan(self):
    #     self.fail()
    #
    # @skip
    # def test_check_wfchan(self):
    #     self.fail()
    #
    # @skip
    # def test_sort_data(self):
    #     self.fail()
    #
    # @skip
    # def test_squash_gaps(self):
    #     self.fail()
    #
    # @skip
    # def test_downsample(self):
    #     self.fail()
    #
    # @skip
    # def test_constant_rate(self):
    #     self.fail()
    #
    # @skip
    # def test_chunk_minmax(self):
    #     self.fail()
    #
    # @skip
    # def test_trim_downsample_sohchan(self):
    #     self.fail()
    #
    # @skip
    # def test_trim_downsample_wfchan(self):
    #     self.fail()
    #
    # @skip
    # def test_get_each_day5min_list(self):
    #     self.fail()
    #
    # @skip
    # def test_get_trim_tpsdata(self):
    #     self.fail()
    #
    # @skip
    # def test_find_tpstm(self):
    #     self.fail()


class TestReadWaveformMSeed(TestCase):
    """Tests for readWaveformMSeed using a Q330 waveform sample file.

    readWaveformTrace is patched for every test, so each test controls the
    start and end epochs reported for the traces that are read.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Read the Q330 waveform sample file once for the whole class."""
        sample_dir = TEST_DATA_DIR.joinpath('Q330-sample/day_vols_AX08')
        cls.q330_waveform_file = sample_dir.joinpath('AX08.XA..LHE.2021.186')
        cls.mseed_waveform_stream = read_ms(cls.q330_waveform_file)
        cls.mseed_waveform_trace = cls.mseed_waveform_stream[0]

    def setUp(self) -> None:
        """Reset the call arguments and patch readWaveformTrace."""
        trace_stats = self.mseed_waveform_trace.stats
        self.station_id = trace_stats['station']
        self.channel_id = trace_stats['channel']
        # This list is only ever written to, so we can keep it empty
        self.traces_info = []
        self.data_time = [0, 0]
        self.temp_dir = tempfile.TemporaryDirectory()

        trace_reader_patch = patch(
            'sohstationviewer.model.handling_data.readWaveformTrace')
        self.addCleanup(trace_reader_patch.stop)
        self.mock_readWaveformTrace = trace_reader_patch.start()
        self._set_mock_trace_epochs(0, 0)

    def tearDown(self) -> None:
        """Remove the temporary directory created in setUp."""
        self.temp_dir.cleanup()

    def _set_mock_trace_epochs(self, start, end) -> None:
        """Make the patched readWaveformTrace report the given epochs."""
        self.mock_readWaveformTrace.return_value = {
            'startTmEpoch': start,
            'endTmEpoch': end,
        }

    def _read_sample_file(self) -> None:
        """Call readWaveformMSeed with the current fixture attributes."""
        waveform_file = self.q330_waveform_file
        readWaveformMSeed(
            waveform_file, waveform_file.name,
            self.station_id, self.channel_id, self.traces_info,
            self.data_time, self.temp_dir.name)

    def test_all_traces_are_processed(self):
        """traces_info gets one entry per trace of the sample stream."""
        self._read_sample_file()
        expected_count = len(self.mseed_waveform_stream)
        self.assertEqual(len(self.traces_info), expected_count)

    def test_readWaveformTrace_called(self):
        """Reading the file calls readWaveformTrace."""
        self._read_sample_file()
        self.assertTrue(self.mock_readWaveformTrace.called)

    def test_start_data_time_earlier_than_earliest_start_time(self):
        """data_time[0] is kept when it precedes the mocked start epoch."""
        self._set_mock_trace_epochs(51251, 2623623)
        initial_start, initial_end = 0, 1625532949
        self.data_time = [initial_start, initial_end]

        self._read_sample_file()
        self.assertEqual(self.data_time[0], initial_start)

    def test_end_data_time_later_than_latest_start_time(self):
        """data_time[1] is kept when it is after the mocked end epoch."""
        self._set_mock_trace_epochs(51251, 2623623)
        # A far-future epoch (around the end of year 9999) so as to be later
        # than the end time of any test data.
        initial_start, initial_end = 1625443222, 253402326000
        self.data_time = [initial_start, initial_end]

        self._read_sample_file()
        self.assertEqual(self.data_time[1], initial_end)

    def test_data_time_is_between_earliest_start_and_latest_end_time(self):
        """data_time is widened to the mocked start and end epochs."""
        mocked_start, mocked_end = 51251, 2623623
        self._set_mock_trace_epochs(mocked_start, mocked_end)
        self.data_time = [512579, 2623616]

        self._read_sample_file()
        self.assertEqual(self.data_time, [mocked_start, mocked_end])


class TestReadWaveformReftek(TestCase):
    """Tests for readWaveformReftek using an RT130 waveform sample file.

    readWaveformTrace is patched for every test, so each test controls the
    start and end epochs reported for the traces that are read.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Read the RT130 waveform sample file once for the whole class."""
        rt130_dir = TEST_DATA_DIR.joinpath(
            'RT130-sample/2017149.92EB/2017150/92EB')
        cls.rt130_waveform_file = rt130_dir.joinpath('1/000000015_0036EE80')
        cls.rt130_waveform = Reftek130.from_file(cls.rt130_waveform_file)
        cls.rt130_waveform_stream = Reftek130.to_stream(cls.rt130_waveform)
        cls.rt130_waveform_trace = cls.rt130_waveform_stream[0]

    def setUp(self) -> None:
        """Reset the call arguments and patch readWaveformTrace."""
        # The first element of the key, the unit ID, should be a character
        # string. However, the unit ID is stored as a byte string in the data
        # stream. While having the unit ID as a byte string should not cause
        # any problem with the test, we convert it to a character string anyway
        # to be consistent with how readWaveformReftek is called.
        self.key = (
            self.rt130_waveform_trace.stats['unit_id'].decode('utf-8'),
            self.rt130_waveform_trace.stats['experiment_number']
        )
        self.read_data = {}
        self.data_time = [0, 0]
        self.temp_dir = tempfile.TemporaryDirectory()

        patcher = patch(
            'sohstationviewer.model.handling_data.readWaveformTrace')
        self.addCleanup(patcher.stop)
        self.mock_readWaveformTrace = patcher.start()
        self.mock_readWaveformTrace.return_value = {
            'startTmEpoch': 0,
            'endTmEpoch': 0,
        }

    def tearDown(self) -> None:
        """Remove the temporary directory created in setUp."""
        self.temp_dir.cleanup()

    def test_all_traces_are_processed(self):
        """The 'tracesInfo' lists hold one entry per trace of the stream."""
        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        num_traces_read = sum(
            len(channel_data['tracesInfo'])
            for channel_data in self.read_data.values()
        )
        # assertEqual, not assertTrue: assertTrue would treat the expected
        # count as the failure message and pass for any non-zero count.
        self.assertEqual(num_traces_read, len(self.rt130_waveform_stream))

    def test_read_data_updated_with_all_channels(self):
        """read_data gets the keys 'DS1-1', 'DS1-2', 'DS1-3', in order."""
        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        self.assertTupleEqual(tuple(self.read_data.keys()),
                              ('DS1-1', 'DS1-2', 'DS1-3'))

    def test_read_data_existing_channel_appended_to(self):
        """A channel already in read_data gains an entry; none is lost."""
        self.read_data = {
            'DS1-1': {
                'tracesInfo': [{'startTmEpoch': 0, 'endTmEpoch': 0}],
                'samplerate': 40.0,
            }
        }
        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        self.assertEqual(len(self.read_data['DS1-1']['tracesInfo']), 2)

    def test_readWaveformTrace_called(self):
        """Reading the data calls readWaveformTrace."""
        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        self.assertTrue(self.mock_readWaveformTrace.called)

    def test_start_data_time_earlier_than_earliest_start_time(self):
        """data_time[0] is kept when it precedes the mocked start epoch."""
        self.mock_readWaveformTrace.return_value = {
            'startTmEpoch': 51251,
            'endTmEpoch': 2623623,
        }
        start_time = 0
        self.data_time = [start_time, 1625532949]

        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        self.assertEqual(self.data_time[0], start_time)

    def test_end_data_time_later_than_latest_start_time(self):
        """data_time[1] is kept when it is after the mocked end epoch."""
        # A far-future epoch (around the end of year 9999) so as to be later
        # than the end time of any test data.
        self.mock_readWaveformTrace.return_value = {
            'startTmEpoch': 51251,
            'endTmEpoch': 2623623,
        }
        end_time = 253402326000
        self.data_time = [1625443222, end_time]

        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        self.assertEqual(self.data_time[1], end_time)

    def test_data_time_is_between_earliest_start_and_latest_end_time(self):
        """data_time is widened to the mocked start and end epochs."""
        self.mock_readWaveformTrace.return_value = {
            'startTmEpoch': 51251,
            'endTmEpoch': 2623623,
        }
        start_time = 512579
        end_time = 2623616
        self.data_time = [start_time, end_time]

        expected_updated_data_time = [51251, 2623623]
        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)
        self.assertEqual(self.data_time, expected_updated_data_time)