from pathlib import Path
from tempfile import TemporaryDirectory

from unittest import TestCase
from unittest.mock import patch

import numpy as np

from sohstationviewer.conf import constants as const
from sohstationviewer.model.handling_data import (
    downsample,
    trim_downsample_WFChan,
)

# Snapshot of the configured size limits, taken when this module is imported,
# i.e. before any test in this module has a chance to override them.
ORIGINAL_CHAN_SIZE_LIMIT, ORIGINAL_RECAL_SIZE_LIMIT = (
    const.CHAN_SIZE_LIMIT,
    const.RECAL_SIZE_LIMIT,
)


class TestTrimDownsampleWfChan(TestCase):
    """Tests for ``trim_downsample_WFChan`` run against fake, in-memory traces.

    ``np.memmap`` is patched for every test so no real file is ever read: the
    times and data of each trace are synthesized from the index encoded at
    the end of the trace's file name (see ``no_file_memmap``).
    """

    # Number of fake traces built in setUp and number of samples per trace.
    # no_file_memmap uses TRACE_SIZE too, so the synthesized arrays always
    # line up with the startTmEpoch/endTmEpoch recorded in tracesInfo.
    NUM_TRACES = 100
    TRACE_SIZE = 100

    def no_file_memmap(self, file_path: Path, **kwargs) -> np.ndarray:
        """Stand-in for ``np.memmap`` that derives its values from the name.

        Data will look the same as times. This has two benefits:
        - It is a lot easier to inspect what data remains after trimming
          and downsampling, seeing as the remaining data would be the same
          as the remaining times.
        - It is a lot easier to reproducibly create a test data set.

        :param file_path: path whose name ends in ``-<index>`` (e.g.
            ``times-3``); the index selects which run of values is returned.
        :param kwargs: any other arguments passed to ``np.memmap``; ignored.
        :return: ``np.arange(index * TRACE_SIZE, (index + 1) * TRACE_SIZE)``.
        """
        file_idx = int(file_path.name.split('-')[-1])
        start = file_idx * self.TRACE_SIZE
        return np.arange(start, start + self.TRACE_SIZE)

    def setUp(self) -> None:
        memmap_patcher = patch.object(np, 'memmap',
                                      side_effect=self.no_file_memmap)
        memmap_patcher.start()
        self.addCleanup(memmap_patcher.stop)

        # The files inside are never created (np.memmap is patched), but the
        # directory itself exists on disk and must be removed after the test.
        self.data_folder = TemporaryDirectory()
        self.addCleanup(self.data_folder.cleanup)
        data_folder_path = Path(self.data_folder.name)

        # Trace i covers times [i * TRACE_SIZE, (i + 1) * TRACE_SIZE - 1],
        # matching what no_file_memmap returns for 'times-i' / 'data-i'.
        self.traces_info = []
        for i in range(self.NUM_TRACES):
            start_time = i * self.TRACE_SIZE
            self.traces_info.append({
                'startTmEpoch': start_time,
                'endTmEpoch': start_time + self.TRACE_SIZE - 1,
                'size': self.TRACE_SIZE,
                'times_f': data_folder_path / f'times-{i}',
                'data_f': data_folder_path / f'data-{i}',
            })
        self.channel_data = {'tracesInfo': self.traces_info}

        # Default window, well inside the fake data's [0, 9999] range.
        self.start_time = 2500
        self.end_time = 7500

    def test_data_out_of_range(self):
        """A window that does not overlap any trace yields a falsy result."""
        with self.subTest('test_start_time_later_than_data_end_time'):
            start_time, end_time = 12500, 17500
            self.assertFalse(
                trim_downsample_WFChan(self.channel_data, start_time,
                                       end_time, True)
            )
        with self.subTest('test_end_time_earlier_than_data_start_time'):
            start_time, end_time = -7500, -2500
            self.assertFalse(
                trim_downsample_WFChan(self.channel_data, start_time,
                                       end_time, True)
            )

    def test_result_is_stored(self):
        """Non-empty 'times' and 'data' are written into channel_data."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        self.assertIn('times', self.channel_data)
        self.assertGreater(len(self.channel_data['times']), 0)
        self.assertIn('data', self.channel_data)
        self.assertGreater(len(self.channel_data['data']), 0)

    def test_data_is_trimmed(self):
        """Every stored time/data value lies within [start_time, end_time]."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        times = self.channel_data['times']
        is_times_trimmed = (times >= self.start_time) & \
            (times <= self.end_time)
        self.assertTrue(is_times_trimmed.all())
        # Fake data mirrors the times, so the same bounds apply.
        data = self.channel_data['data']
        is_data_trimmed = (data >= self.start_time) & (data <= self.end_time)
        self.assertTrue(is_data_trimmed.all())

    @patch('sohstationviewer.model.handling_data.downsample', wraps=downsample)
    def test_data_is_downsampled(self, mock_downsample):
        """Data larger than CHAN_SIZE_LIMIT is passed through downsample."""
        # patch.object restores the original limit even if the call below
        # raises, so the override cannot leak into other tests.
        with patch.object(const, 'CHAN_SIZE_LIMIT', 1000):
            trim_downsample_WFChan(self.channel_data, self.start_time,
                                   self.end_time, True)
        mock_downsample.assert_called()

    def test_data_small_enough_after_first_trim_flag_is_set(self):
        """The 'fulldata' flag is set when the first trim is small enough."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        self.assertIn('fulldata', self.channel_data)

    def test_no_additional_work_if_data_small_enough_after_first_trim(self):
        """A second call leaves the already-stored arrays untouched."""
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        current_times = self.channel_data['times']
        current_data = self.channel_data['data']
        trim_downsample_WFChan(self.channel_data, self.start_time,
                               self.end_time, True)
        self.assertIs(current_times, self.channel_data['times'])
        self.assertIs(current_data, self.channel_data['data'])

    def test_data_too_large_after_first_time(self):
        """Nothing is stored when a re-trim exceeds RECAL_SIZE_LIMIT."""
        # Restored automatically on exit, even if the call raises.
        with patch.object(const, 'RECAL_SIZE_LIMIT', 1):
            trim_downsample_WFChan(self.channel_data, self.start_time,
                                   self.end_time, False)
        self.assertNotIn('times', self.channel_data)
        self.assertNotIn('data', self.channel_data)