from pathlib import Path
from tempfile import TemporaryDirectory

from unittest import TestCase
from unittest.mock import patch

import numpy as np
from obspy.core import UTCDateTime

from sohstationviewer.conf import constants as const
from sohstationviewer.model.handling_data import (
    downsample,
    trim_downsample_WFChan,
    get_eachDay5MinList,
    get_trimTPSData,
)

# Back up the size limits so tests that temporarily shrink them can restore
# the original values afterward.
ORIGINAL_CHAN_SIZE_LIMIT = const.CHAN_SIZE_LIMIT
ORIGINAL_RECAL_SIZE_LIMIT = const.RECAL_SIZE_LIMIT


class TestTrimDownsampleWfChan(TestCase):
    """Test suite for handling_data.trim_downsample_WFChan."""

    def no_file_memmap(self, file_path: Path, **kwargs):
        """
        Replacement for np.memmap that generates data instead of reading a
        file, so the tests need no real on-disk fixtures.

        Data will look the same as times. This has two benefits:
        - It is a lot easier to inspect what data remains after trimming
        and downsampling, seeing as the remaining data would be the same
        as the remaining times.
        - It is a lot easier to reproducibly create a test data set.

        :param file_path: path handed to np.memmap; its trailing '-<index>'
            suffix determines the generated value range
        :return: an array of 100 consecutive integers starting at
            <index> * 100
        """
        array_size = 100
        file_idx = int(file_path.name.split('-')[-1])
        start = file_idx * array_size
        end = start + array_size
        return np.arange(start, end)

    def setUp(self) -> None:
        """Patch np.memmap and build 100 contiguous fake traces."""
        memmap_patcher = patch.object(np, 'memmap',
                                      side_effect=self.no_file_memmap)
        self.addCleanup(memmap_patcher.stop)
        memmap_patcher.start()

        self.channel_data = {}
        self.traces_info = []
        self.channel_data['tracesInfo'] = self.traces_info
        self.data_folder = TemporaryDirectory()
        # Remove the temporary folder even if a test fails; relying on the
        # object's finalizer would emit a ResourceWarning.
        self.addCleanup(self.data_folder.cleanup)
        for i in range(100):
            trace_size = 100
            start_time = i * trace_size
            trace = {
                'startTmEpoch': start_time,
                'endTmEpoch': start_time + trace_size - 1,
                'size': trace_size,
                # The file names only need to carry the trace index; see
                # no_file_memmap.
                'times_f': Path(self.data_folder.name) / f'times-{i}',
                'data_f': Path(self.data_folder.name) / f'data-{i}',
            }
            self.traces_info.append(trace)
        # Default request range sits strictly inside the generated data's
        # [0, 9999] time span.
        self.start_time = 2500
        self.end_time = 7500

    def test_data_out_of_range(self):
        """Requesting a time range containing no data returns False."""
        with self.subTest('test_start_time_later_than_data_end_time'):
            self.start_time = 12500
            self.end_time = 17500
            self.assertFalse(
                trim_downsample_WFChan(self.channel_data, self.start_time,
                                       self.end_time, True)
            )
        with self.subTest('test_end_time_earlier_than_data_start_time'):
            self.start_time = -7500
            self.end_time = -2500
            self.assertFalse(
                trim_downsample_WFChan(self.channel_data, self.start_time,
                                       self.end_time, True)
            )

    def test_result_is_stored(self):
        """Non-empty trimmed times and data end up in the channel dict."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        self.assertIn('times', self.channel_data)
        self.assertGreater(len(self.channel_data['times']), 0)
        self.assertIn('data', self.channel_data)
        self.assertGreater(len(self.channel_data['data']), 0)

    def test_data_is_trimmed(self):
        """Every remaining point falls inside [start_time, end_time]."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        # Data mirrors times (see no_file_memmap), so the same bound check
        # applies to both arrays.
        is_times_trimmed = ((self.channel_data['times'] >= self.start_time) &
                            (self.channel_data['times'] <= self.end_time))
        self.assertTrue(is_times_trimmed.all())
        is_data_trimmed = ((self.channel_data['data'] >= self.start_time) &
                           (self.channel_data['data'] <= self.end_time))
        self.assertTrue(is_data_trimmed.all())

    @patch('sohstationviewer.model.handling_data.downsample', wraps=downsample)
    def test_data_is_downsampled(self, mock_downsample):
        """Data exceeding CHAN_SIZE_LIMIT is passed through downsample."""
        # patch.object restores the constant even if an assertion fails,
        # unlike manually resetting it after the test body.
        with patch.object(const, 'CHAN_SIZE_LIMIT', 1000):
            trim_downsample_WFChan(self.channel_data, self.start_time,
                                   self.end_time, True)
        self.assertTrue(mock_downsample.called)

    def test_data_small_enough_after_first_trim_flag_is_set(self):
        """The 'fulldata' flag is set once the first trim fits the limit."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        self.assertIn('fulldata', self.channel_data)

    def test_no_additional_work_if_data_small_enough_after_first_trim(self):
        """A second call reuses the already-computed arrays."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        current_times = self.channel_data['times']
        current_data = self.channel_data['data']
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        # assertIs checks object identity, proving nothing was recomputed.
        self.assertIs(current_times, self.channel_data['times'])
        self.assertIs(current_data, self.channel_data['data'])

    def test_data_too_large_after_first_time(self):
        """Oversized results are not stored when recalculation is off."""
        # patch.object restores RECAL_SIZE_LIMIT even on failure; the old
        # manual reset leaked the value 1 into later tests if this failed.
        with patch.object(const, 'RECAL_SIZE_LIMIT', 1):
            trim_downsample_WFChan(self.channel_data, self.start_time,
                                   self.end_time, False)
        self.assertNotIn('times', self.channel_data)
        self.assertNotIn('data', self.channel_data)


class TestGetTrimTpsData(TestCase):
    """Test suite for handling_data.get_trimTPSData."""

    def no_file_memmap(self, file_path: Path, **kwargs):
        """
        Replacement for np.memmap that generates data instead of reading a
        file, so the tests need no real on-disk fixtures.

        Data will look the same as times. This has two benefits:
        - It is a lot easier to inspect what data remains after trimming
        and downsampling, seeing as the remaining data would be the same
        as the remaining times.
        - It is a lot easier to reproducibly create a test data set.

        :param file_path: path handed to np.memmap; its trailing '-<index>'
            suffix determines the generated value range
        :return: an array of consecutive integers derived from the index
        """
        file_idx = int(file_path.name.split('-')[-1])
        if file_idx < const.SEC_DAY:
            # Traces built in setUp are named by index 0..99 and hold 100
            # points each.
            array_size = 100
            start = file_idx * array_size
            end = start + array_size
        else:
            # Traces added by _add_late_trace are named by their epoch start
            # time and hold SEC_5M * 10 points.
            array_size = const.SEC_5M * 10
            start = file_idx
            end = start + array_size
        return np.arange(start, end)

    def setUp(self) -> None:
        """Patch np.memmap and build 100 contiguous fake traces."""
        memmap_patcher = patch.object(np, 'memmap',
                                      side_effect=self.no_file_memmap)
        self.addCleanup(memmap_patcher.stop)
        memmap_patcher.start()

        self.channel_data = {'samplerate': 1}
        self.traces_info = []
        self.channel_data['tracesInfo'] = self.traces_info
        self.data_folder = TemporaryDirectory()
        # Remove the temporary folder even if a test fails; relying on the
        # object's finalizer would emit a ResourceWarning.
        self.addCleanup(self.data_folder.cleanup)
        for i in range(100):
            trace_size = 100
            start_time = i * trace_size
            trace = {
                'startTmEpoch': start_time,
                'endTmEpoch': start_time + trace_size - 1,
                'size': trace_size,
                # The file names only need to carry the trace index; see
                # no_file_memmap.
                'times_f': Path(self.data_folder.name) / f'times-{i}',
                'data_f': Path(self.data_folder.name) / f'data-{i}',
            }
            self.traces_info.append(trace)

        data_start_time = 0
        data_end_time = 10000
        self.each_day_5_min_list = get_eachDay5MinList(data_start_time,
                                                       data_end_time)
        # Default request range sits strictly inside the generated data's
        # time span.
        self.start_time = 2500
        self.end_time = 7500

    def _add_late_trace(self, trace_start_time: int) -> dict:
        """
        Append a trace of const.SEC_5M * 10 points starting at
        trace_start_time (an epoch time past the first day) and return it.

        The backing file names encode the start time so that no_file_memmap
        can generate matching data.
        """
        trace_size = const.SEC_5M * 10
        folder = Path(self.data_folder.name)
        trace = {
            'startTmEpoch': trace_start_time,
            'endTmEpoch': trace_start_time + trace_size - 1,
            'size': trace_size,
            'times_f': folder / f'times-{trace_start_time}',
            'data_f': folder / f'data-{trace_start_time}',
        }
        self.traces_info.append(trace)
        return trace

    def test_data_out_of_range(self):
        """Requesting a time range containing no data returns False."""
        with self.subTest('test_start_time_later_than_data_end_time'):
            self.start_time = 12500
            self.end_time = 17500
            self.assertFalse(
                get_trimTPSData(self.channel_data, self.start_time,
                                self.end_time, self.each_day_5_min_list)
            )
        with self.subTest('test_end_time_earlier_than_data_start_time'):
            self.start_time = -7500
            self.end_time = -2500
            self.assertFalse(
                get_trimTPSData(self.channel_data, self.start_time,
                                self.end_time, self.each_day_5_min_list)
            )

    def test_result_is_stored_one_day_of_data(self):
        """A single day of data yields one entry in tps_data."""
        num_day = 1
        get_trimTPSData(self.channel_data, self.start_time,
                        self.end_time, self.each_day_5_min_list)
        self.assertIn('tps_data', self.channel_data)
        self.assertEqual(len(self.channel_data['tps_data']), num_day)

    def test_result_is_stored_multiple_days_of_data(self):
        """Data spanning two days yields two entries in tps_data."""
        second_day_data_start_time = int(
            UTCDateTime(1970, 1, 2, 18, 0, 0).timestamp
        )
        self.end_time = UTCDateTime(1970, 1, 2, 6, 0, 0).timestamp
        trace = self._add_late_trace(second_day_data_start_time)
        self.each_day_5_min_list = get_eachDay5MinList(0, trace['endTmEpoch'])

        num_day = 2
        get_trimTPSData(self.channel_data, self.start_time,
                        self.end_time, self.each_day_5_min_list)
        self.assertIn('tps_data', self.channel_data)
        self.assertEqual(len(self.channel_data['tps_data']), num_day)

    def test_data_has_gaps(self):
        """A day with no data still gets its own entry in tps_data."""
        # The added trace starts on day 3, leaving day 2 without any data.
        third_day_data_start_time = int(
            UTCDateTime(1970, 1, 3, 18, 0, 0).timestamp
        )
        self.end_time = UTCDateTime(1970, 1, 2, 6, 0, 0).timestamp
        trace = self._add_late_trace(third_day_data_start_time)
        self.each_day_5_min_list = get_eachDay5MinList(0, trace['endTmEpoch'])

        num_day = 3
        get_trimTPSData(self.channel_data, self.start_time,
                        self.end_time, self.each_day_5_min_list)
        self.assertIn('tps_data', self.channel_data)
        self.assertEqual(len(self.channel_data['tps_data']), num_day)