from pathlib import Path
import tempfile
from math import isclose
from copy import copy as shallow_copy
import tracemalloc

from unittest import TestCase, expectedFailure
from unittest.mock import patch

from obspy.core import Trace, read as read_ms
import numpy as np

from sohstationviewer.model.handling_data import (
    readSOHMSeed,
    readSOHTrace,
    readMPTrace,
    readWaveformTrace,
    readWaveformMSeed,
    readWaveformReftek,
    readASCII,
    readText,
    saveData2File,
    checkChan,
    checkWFChan,
    checkSOHChan
)
from sohstationviewer.model.reftek.from_rt2ms.core import Reftek130

# Sample data sets live in a 'test_data' directory next to this file's parent.
TEST_DATA_DIR = Path(__file__).parent.parent / 'test_data'


# tempfile.tempdir = './tempdir'


class TestHandlingData(TestCase):
    """Tests for the trace-level readers in handling_data.

    Fixtures are read once per class from the RT130 and Q330 sample data
    sets under TEST_DATA_DIR.
    """

    @classmethod
    def setUpClass(cls) -> None:
        rt130_dir = TEST_DATA_DIR.joinpath(
            'RT130-sample/2017149.92EB/2017150/92EB')
        cls.rt130_soh_file = rt130_dir.joinpath('9/054910000_013EE8A0')
        rt130_soh = Reftek130.from_file(cls.rt130_soh_file)
        cls.rt130_soh_stream = Reftek130.to_stream(rt130_soh)
        cls.rt130_soh_trace = cls.rt130_soh_stream[0]
        # Same RT130 waveform file that TestReadWaveformReftek uses; this
        # previously duplicated the SOH file path above, so the "waveform"
        # fixtures were really the SOH fixtures.
        cls.rt130_waveform_file = rt130_dir.joinpath('1/000000015_0036EE80')
        rt130_waveform = Reftek130.from_file(cls.rt130_waveform_file)
        cls.rt130_waveform_stream = Reftek130.to_stream(rt130_waveform)
        cls.rt130_waveform_trace = cls.rt130_waveform_stream[0]

        q330_dir = TEST_DATA_DIR.joinpath('Q330-sample/day_vols_AX08')
        cls.q330_soh_file = q330_dir.joinpath('AX08.XA..VKI.2021.186')
        cls.mseed_soh_stream = read_ms(cls.q330_soh_file)
        cls.mseed_soh_trace = cls.mseed_soh_stream[0]
        cls.q330_waveform_file = q330_dir.joinpath('AX08.XA..LHE.2021.186')
        cls.mseed_waveform_stream = read_ms(cls.q330_waveform_file)
        cls.mseed_waveform_trace = cls.mseed_waveform_stream[0]

    def test_read_soh_trace_processed_trace_have_all_needed_info(self):
        """readSOHTrace's result contains every expected key."""
        processed_trace = readSOHTrace(self.mseed_soh_trace)
        expected_key_list = [
            'chanID',
            'samplerate',
            'startTmEpoch',
            'endTmEpoch',
            'times',
            'data'
        ]
        # One sub-test per key so a failure reports which key is missing.
        for key in expected_key_list:
            with self.subTest(key=key):
                self.assertIn(key, processed_trace)

    def test_read_soh_trace_times_calculated_correctly(self):
        """times[0] is (approximately) zero exactly when startTmEpoch is."""
        processed_trace = readSOHTrace(self.mseed_soh_trace)
        if isclose(processed_trace['startTmEpoch'], 0, abs_tol=0.0001):
            self.assertAlmostEqual(processed_trace['times'][0], 0)
        else:
            self.assertNotAlmostEqual(processed_trace['times'][0], 0)

    @patch('sohstationviewer.model.handling_data.readSOHTrace')
    def test_read_mp_trace(self, mock_read_soh_trace):
        """readMPTrace transforms readSOHTrace's data to the expected values."""
        mock_read_soh_trace.return_value = {
            # Contain five cases:
            # + Small positive
            # + Large positive
            # + Small negative
            # + Large negative
            # + Zero
            'data': np.array([1, 27272, -2, -23526, 0])
        }
        expected = np.array([0, 8.3, 0, -7.2, 0])
        processed_trace = readMPTrace(self.rt130_soh_trace)
        # Same exact comparison as np.array_equal, but reports the mismatching
        # elements on failure.
        np.testing.assert_array_equal(processed_trace['data'], expected)

    @patch('sohstationviewer.model.handling_data.saveData2File')
    def test_read_waveform_trace(self, mock_save_data_2_file):
        """readWaveformTrace returns the expected keys and saves the data."""
        trace = self.rt130_waveform_trace
        station_id = trace.stats['station']
        channel_id = trace.stats['channel']
        # The function itself only cares about the length of this list so we
        # stub it out.
        traces_info = [dict() for _ in range(4)]
        # Context manager guarantees the temporary directory is removed.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            processed_trace = readWaveformTrace(
                trace, station_id, channel_id, traces_info, tmp_dir_name
            )

        expected_key_list = [
            'samplerate',
            'startTmEpoch',
            'endTmEpoch',
            'size',
            'times_f',
            'data_f',
        ]
        for key in expected_key_list:
            with self.subTest(key=key):
                self.assertIn(key, processed_trace)

        self.assertTrue(mock_save_data_2_file.called)

    # TODO: tests still to be written: read_sohmseed, check_sohchan,
    # check_wfchan, sort_data, squash_gaps, downsample, constant_rate,
    # chunk_minmax, trim_downsample_sohchan, trim_downsample_wfchan,
    # get_each_day5min_list, get_trim_tpsdata, find_tpstm.


class TestReadWaveformMSeed(TestCase):
    """Tests for readWaveformMSeed, with readWaveformTrace mocked out."""

    # 9999-12-31T23:59:59Z as a UNIX epoch time, i.e. the last second of the
    # year 9999 -- later than the end time of any test data. (The previous
    # value, 253402326000, was actually 10000-01-01T07:00:00Z.)
    LAST_SECOND_OF_9999 = 253402300799

    @classmethod
    def setUpClass(cls) -> None:
        q330_dir = TEST_DATA_DIR.joinpath('Q330-sample/day_vols_AX08')
        cls.q330_waveform_file = q330_dir.joinpath('AX08.XA..LHE.2021.186')
        cls.mseed_waveform_stream = read_ms(cls.q330_waveform_file)
        cls.mseed_waveform_trace = cls.mseed_waveform_stream[0]

    def setUp(self) -> None:
        self.station_id = self.mseed_waveform_trace.stats['station']
        self.channel_id = self.mseed_waveform_trace.stats['channel']
        # This list is only ever written to, so we can keep it empty
        self.traces_info = []
        self.data_time = [0, 0]
        self.temp_dir = tempfile.TemporaryDirectory()

        patcher = patch(
            'sohstationviewer.model.handling_data.readWaveformTrace')
        self.addCleanup(patcher.stop)
        self.mock_read_waveform_trace = patcher.start()
        self._set_processed_time_range(0, 0)

    def tearDown(self) -> None:
        self.temp_dir.cleanup()

    def _set_processed_time_range(self, start, end) -> None:
        """Make the mocked readWaveformTrace report the given time range."""
        self.mock_read_waveform_trace.return_value = {
            'startTmEpoch': start,
            'endTmEpoch': end,
        }

    def _read(self) -> None:
        """Run readWaveformMSeed on the test file with this test's state.

        self.data_time is read at call time, so tests may reassign it first.
        """
        readWaveformMSeed(
            self.q330_waveform_file, self.q330_waveform_file.name,
            self.station_id, self.channel_id, self.traces_info, self.data_time,
            self.temp_dir.name)

    def test_all_traces_are_processed(self):
        self._read()
        self.assertEqual(len(self.traces_info),
                         len(self.mseed_waveform_stream))

    def test_readWaveformTrace_called(self):
        self._read()
        self.assertTrue(self.mock_read_waveform_trace.called)

    def test_start_data_time_earlier_than_earliest_start_time(self):
        """data_time[0] is kept when already earlier than the traces."""
        self._set_processed_time_range(51251, 2623623)
        start_time = 0
        self.data_time = [start_time, 1625532949]

        self._read()
        self.assertEqual(self.data_time[0], start_time)

    def test_end_data_time_later_than_latest_start_time(self):
        """data_time[1] is kept when already later than the traces."""
        self._set_processed_time_range(51251, 2623623)
        end_time = self.LAST_SECOND_OF_9999
        self.data_time = [1625443222, end_time]

        self._read()
        self.assertEqual(self.data_time[1], end_time)

    def test_data_time_is_between_earliest_start_and_latest_end_time(self):
        """data_time is widened to the processed traces' time range."""
        self._set_processed_time_range(51251, 2623623)
        start_time = 512579
        end_time = 2623616
        self.data_time = [start_time, end_time]

        expected_updated_data_time = [51251, 2623623]
        self._read()
        self.assertEqual(self.data_time, expected_updated_data_time)


class TestReadWaveformReftek(TestCase):
    """Tests for readWaveformReftek, with readWaveformTrace mocked out."""

    # 9999-12-31T23:59:59Z as a UNIX epoch time, i.e. the last second of the
    # year 9999 -- later than the end time of any test data. (The previous
    # value, 253402326000, was actually 10000-01-01T07:00:00Z.)
    LAST_SECOND_OF_9999 = 253402300799

    @classmethod
    def setUpClass(cls) -> None:
        rt130_dir = TEST_DATA_DIR.joinpath(
            'RT130-sample/2017149.92EB/2017150/92EB')
        cls.rt130_waveform_file = rt130_dir.joinpath('1/000000015_0036EE80')
        cls.rt130_waveform = Reftek130.from_file(cls.rt130_waveform_file)
        cls.rt130_waveform_stream = Reftek130.to_stream(cls.rt130_waveform)
        cls.rt130_waveform_trace = cls.rt130_waveform_stream[0]

    def setUp(self) -> None:
        # The first element of the key, the unit ID, should be a character
        # string. However, the unit ID is stored as a byte string in the data
        # stream. While having the unit ID as a byte string should not cause
        # any problem with the test, we convert it to a character string anyway
        # to be consistent with how readWaveformReftek is called.
        self.key = (
            self.rt130_waveform_trace.stats['unit_id'].decode('utf-8'),
            self.rt130_waveform_trace.stats['experiment_number']
        )
        self.read_data = {}
        self.data_time = [0, 0]
        self.temp_dir = tempfile.TemporaryDirectory()

        patcher = patch(
            'sohstationviewer.model.handling_data.readWaveformTrace')
        self.addCleanup(patcher.stop)
        self.mock_read_waveform_trace = patcher.start()
        self._set_processed_time_range(0, 0)

    def tearDown(self) -> None:
        self.temp_dir.cleanup()

    def _set_processed_time_range(self, start, end) -> None:
        """Make the mocked readWaveformTrace report the given time range."""
        self.mock_read_waveform_trace.return_value = {
            'startTmEpoch': start,
            'endTmEpoch': end,
        }

    def _read(self) -> None:
        """Run readWaveformReftek on the test data with this test's state.

        self.read_data and self.data_time are read at call time, so tests may
        reassign them first.
        """
        readWaveformReftek(self.rt130_waveform, self.key, self.read_data,
                           self.data_time, self.temp_dir.name)

    def test_all_traces_are_processed(self):
        self._read()
        num_traces_read = sum(
            len(channel_data['tracesInfo'])
            for channel_data in self.read_data.values()
        )
        # assertEqual, not assertTrue: assertTrue's second argument is only a
        # failure message, so it would pass for any non-zero count.
        self.assertEqual(num_traces_read, len(self.rt130_waveform_stream))

    def test_read_data_updated_with_all_channels(self):
        self._read()
        self.assertTupleEqual(tuple(self.read_data.keys()),
                              ('DS1-1', 'DS1-2', 'DS1-3'))

    def test_read_data_existing_channel_appended_to(self):
        self.read_data = {
            'DS1-1':
                {'tracesInfo': [{'startTmEpoch': 0, 'endTmEpoch': 0}],
                 'samplerate': 40.0}
        }
        self._read()
        self.assertEqual(len(self.read_data['DS1-1']['tracesInfo']), 2)

    def test_readWaveformTrace_called(self):
        self._read()
        self.assertTrue(self.mock_read_waveform_trace.called)

    def test_start_data_time_earlier_than_earliest_start_time(self):
        """data_time[0] is kept when already earlier than the traces."""
        self._set_processed_time_range(51251, 2623623)
        start_time = 0
        self.data_time = [start_time, 1625532949]

        self._read()
        self.assertEqual(self.data_time[0], start_time)

    def test_end_data_time_later_than_latest_start_time(self):
        """data_time[1] is kept when already later than the traces."""
        self._set_processed_time_range(51251, 2623623)
        end_time = self.LAST_SECOND_OF_9999
        self.data_time = [1625443222, end_time]

        self._read()
        self.assertEqual(self.data_time[1], end_time)

    def test_data_time_is_between_earliest_start_and_latest_end_time(self):
        """data_time is widened to the processed traces' time range."""
        self._set_processed_time_range(51251, 2623623)
        start_time = 512579
        end_time = 2623616
        self.data_time = [start_time, end_time]

        expected_updated_data_time = [51251, 2623623]
        self._read()
        self.assertEqual(self.data_time, expected_updated_data_time)


class TestReadASCII(TestCase):
    """Tests for readASCII using the Q330 LOG channel sample."""

    @classmethod
    def setUpClass(cls) -> None:
        q330_dir = TEST_DATA_DIR.joinpath('Q330-sample/day_vols_AX08')
        cls.q330_log_file = q330_dir.joinpath('AX08.XA..LOG.2021.186')
        cls.mseed_ascii_stream = read_ms(cls.q330_log_file)
        cls.mseed_ascii_trace = cls.mseed_ascii_stream[0]

    def track_info(self, msg, msg_type):
        """Record (msg, msg_type) pairs passed to readASCII's callback."""
        self.info_tracked.append((msg, msg_type))

    def setUp(self) -> None:
        self.station_id = self.mseed_ascii_trace.stats.station
        self.channel_id = self.mseed_ascii_trace.stats.channel
        self.log_data = {}
        self.info_tracked = []

    def test_station_and_channel_inserted_into_log_data(self):
        readASCII(self.q330_log_file, None, self.station_id, self.channel_id,
                  self.mseed_ascii_trace, self.log_data, self.track_info)
        self.assertIn(self.station_id, self.log_data)
        self.assertIn(self.channel_id, self.log_data[self.station_id])

    def test_trace_contains_log_data(self):
        returned_file = readASCII(self.q330_log_file, None, self.station_id,
                                  self.channel_id, self.mseed_ascii_trace,
                                  self.log_data, self.track_info)
        self.assertIsNone(returned_file)

        log_string = self.log_data[self.station_id][self.channel_id][0]
        log_lines = log_string.split('\n')
        # Q330 uses \r\n as the new line character, so we need to remove the \r
        # after splitting the log string based on \n
        log_lines = [line.strip('\r') for line in log_lines]
        # The first four elements of log_lines will be there whether anything
        # is read from the file, so we know that something has been read from
        # the file when log_lines have more than four elements.
        self.assertGreater(len(log_lines), 4)
        # Check that we are not reading in gibberish
        self.assertEqual(
            log_lines[4],
            'Quanterra Packet Baler Model 14 Restart. Version 2.26'
        )

    @expectedFailure
    def test_trace_does_not_contain_log_data(self):
        # Placeholder: the test is not implemented yet, so it raises
        # immediately. The code after the raise is unreachable and only
        # sketches the intended check.
        raise NotImplementedError(
            self.test_trace_does_not_contain_log_data.__qualname__)
        # We are only reassigning the data reference of the trace and are not
        # modifying the stored data. As a result, we only need a shallow copy
        # of the trace.
        trace = shallow_copy(self.mseed_ascii_trace)
        trace.data = np.array([])
        returned_file = readASCII(self.q330_log_file, None, self.station_id,
                                  self.channel_id, trace, self.log_data,
                                  self.track_info)
        self.assertIsNotNone(returned_file)
        log_header = (
            '\n\n**** STATE OF HEALTH: From:2021-07-05T03:37:40.120000Z  '
            'To:2021-07-05T03:37:40.120000Z\n'
        )
        log_info = self.mseed_ascii_trace.data.tobytes().decode()
        log_string = log_header + log_info
        self.assertEqual(log_string,
                         self.log_data[self.station_id][self.channel_id][0])


class TestReadText(TestCase):
    """Tests for readText on a real text file and on non-text input."""

    def setUp(self):
        # readText reopens the file by name. An open NamedTemporaryFile cannot
        # be reopened on Windows, so create it with delete=False, close it
        # once written, and remove it explicitly during cleanup.
        self.text_file = tempfile.NamedTemporaryFile(mode='w+t', delete=False)
        # Cleanups run LIFO: the file is closed first, then deleted.
        self.addCleanup(Path(self.text_file.name).unlink)
        self.addCleanup(self.text_file.close)
        self.text_file.write('Test text')
        self.text_file.close()
        self.text_file_name = Path(self.text_file.name).name
        self.non_text_file = TEST_DATA_DIR.joinpath('Q330-sample/'
                                                    'day_vols_AX08/'
                                                    'AX08.XA..HHE.2021.186')
        self.text_logs = []

    def test_log_appended_to(self):
        readText(self.text_file.name, self.text_file.name, self.text_logs)
        self.assertGreater(len(self.text_logs), 0)

    def test_text_file(self):
        readText(self.text_file.name, self.text_file_name, self.text_logs)
        self.assertEqual(
            self.text_logs[0],
            f'\n\n** STATE OF HEALTH: {self.text_file_name}\nTest text')

    def test_non_text_file(self):
        with self.assertRaises(Exception):
            readText(self.non_text_file, self.non_text_file.name,
                     self.text_logs)

    def test_non_existent_file(self):
        non_existent_file = TEST_DATA_DIR.joinpath('non_existent_file')
        with self.assertRaises(FileNotFoundError):
            readText(str(non_existent_file),
                     non_existent_file.name,
                     self.text_logs)


class TestSaveData2File(TestCase):
    """Tests for saveData2File, which writes a trace's data to a memmap."""

    def setUp(self) -> None:
        # Register each cleanup right after acquiring the resource so that a
        # failure later in setUp (when tearDown would not run) cannot leave
        # the temporary directory behind or tracemalloc running.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.temp_dir.cleanup)
        self.time_data = 'timedata'
        self.station_id = 'station'
        self.channel_id = 'channel'
        self.trace_idx = 0
        self.trace_size = 100
        tracemalloc.start()
        self.addCleanup(tracemalloc.stop)
        # We do not need the metadata of a data set (stored in Trace.stats), so
        # we can create a trace without reading any data set. This approach
        # allows us to create a trace that contains a large amount of data
        # without having to store a large data set in the GitLab server.
        self.trace = Trace(np.ones(self.trace_size))

        # The current implementation stores the data in a memmap that contains
        # 64-bit integers, and so each data point has a size of 8 bytes in the
        # memmap
        self.data_point_size = 8

    def test_file_created(self):
        mem_file_name = saveData2File(self.temp_dir.name, self.time_data,
                                      self.station_id, self.channel_id,
                                      self.trace, self.trace_idx,
                                      self.trace_size)

        self.assertTrue(mem_file_name.is_file())

        expected_mem_file_name = "station-channel-timedata-0"
        self.assertEqual(mem_file_name.name, expected_mem_file_name)

        memmap_size = self.data_point_size * self.trace_size
        self.assertEqual(mem_file_name.stat().st_size, memmap_size)

    @patch.object(np, 'memmap')
    def test_memmap_called(self, mock_memmap):
        saveData2File(self.temp_dir.name, self.time_data, self.station_id,
                      self.channel_id, self.trace, self.trace_idx,
                      self.trace_size)
        self.assertTrue(mock_memmap.called)

    @expectedFailure
    def test_memory_freed(self):
        # This test works by calculating the amount of memory freed after
        # saveData2File is called. Because there are operations happening
        # between the memory being freed and the calculation, the calculated
        # amount of memory won't be exact. Fixing this would involve removing
        # all the aforementioned intermediate operations, which is (very
        # likely) impossible and does not provide much value due to the actual
        # deviation being small (consistently at most 1000 bytes).

        # The difference between the calculated amount of freed memory and the
        # size of the memmap is around 1000 bytes, we want data that is at
        # least an order of magnitude bigger to ensure that this difference is
        # obvious in our calculation.
        self.trace_size = 10000
        self.trace = Trace(np.ones(self.trace_size))
        memmap_size = self.trace_size * self.data_point_size

        current_mem_used = tracemalloc.get_traced_memory()[0]
        saveData2File(self.temp_dir.name, self.time_data, self.station_id,
                      self.channel_id, self.trace, self.trace_idx,
                      self.trace_size)
        freed_mem = current_mem_used - tracemalloc.get_traced_memory()[0]

        freed_mem_max_error = 1000
        self.assertTrue(isclose(memmap_size, freed_mem,
                                abs_tol=freed_mem_max_error))

    def test_empty_trace(self):
        self.trace_size = 0
        self.trace = Trace(np.ones(self.trace_size))
        with self.assertRaises(ValueError):
            saveData2File(self.temp_dir.name, self.time_data, self.station_id,
                          self.channel_id, self.trace, self.trace_idx,
                          self.trace_size)


class TestCheckChan(TestCase):
    """checkChan should classify channels via checkWFChan / checkSOHChan.

    Both helpers are replaced by spies that wrap the real functions, so the
    results are real while calls can still be observed.
    """

    def _spy(self, target, wrapped):
        """Patch *target* with a mock that delegates to *wrapped*."""
        spy_patcher = patch(target, wraps=wrapped)
        self.addCleanup(spy_patcher.stop)
        return spy_patcher.start()

    def _check(self, channel):
        """Call checkChan with this test's requested channel lists."""
        return checkChan(channel, self.req_soh_chans, self.req_wf_chans)

    def setUp(self) -> None:
        self.req_soh_chans = ['LCE', 'LCQ']
        self.req_wf_chans = ['LHE', 'HHE']
        self.mock_check_soh_chan = self._spy(
            'sohstationviewer.model.handling_data.checkSOHChan',
            checkSOHChan
        )
        self.mock_check_wf_chan = self._spy(
            'sohstationviewer.model.handling_data.checkWFChan',
            checkWFChan
        )

    def test_channel_is_waveform_and_is_requested(self):
        self.assertEqual(self._check('LHE'), 'WF')
        self.assertTrue(self.mock_check_wf_chan.called)

    def test_channel_is_waveform_but_not_requested(self):
        self.assertFalse(self._check('HH1'))
        self.assertTrue(self.mock_check_wf_chan.called)

    def test_channel_is_soh_and_is_requested(self):
        cases = (
            ('test_normal_channel', 'LCE'),
            ('test_mass_position_channel', 'VM1'),
        )
        for sub_test_name, soh_chan in cases:
            # Fresh call record for each sub-test.
            self.mock_check_soh_chan.reset_mock()
            with self.subTest(sub_test_name):
                self.assertEqual(self._check(soh_chan), 'SOH')
                self.assertTrue(self.mock_check_soh_chan.called)

    def test_channel_is_soh_but_not_requested(self):
        self.assertFalse(self._check('VKI'))
        self.assertTrue(self.mock_check_soh_chan.called)


class TestCheckSohChan(TestCase):
    """checkSOHChan decides whether an SOH channel was requested."""

    def setUp(self) -> None:
        self.req_soh_chans = ['LCE', 'LCQ', 'EX?']
        # Random channel IDs generated with the builtin random library, seeded
        # with 22589824271860044 (the string PASSCAL encoded to bytes and read
        # as a big-endian integer).
        self.sample_channel_ids = ['ODV', 'QGA', 'NF4', 'OLY', 'UZM']

    def test_all_channels_requested(self):
        # An empty request list is expected to accept every channel.
        self.req_soh_chans = []
        for chan in self.sample_channel_ids:
            self.assertTrue(checkSOHChan(chan, self.req_soh_chans))

    def test_channel_is_requested(self):
        with self.subTest('test_normal_channels'):
            self.assertTrue(checkSOHChan('LCE', self.req_soh_chans))
        with self.subTest('test_mass_position_channels'):
            # VM0-VM6 are not listed in req_soh_chans but are expected to be
            # accepted anyway.
            for mp_chan in ('VM0', 'VM1', 'VM2', 'VM3', 'VM4', 'VM5', 'VM6'):
                self.assertTrue(checkSOHChan(mp_chan, self.req_soh_chans))
        with self.subTest('test_external_soh_channels'):
            # Matched by the 'EX?' entry in req_soh_chans.
            for ex_chan in ('EX1', 'EX2', 'EX3'):
                self.assertTrue(checkSOHChan(ex_chan, self.req_soh_chans))

    def test_channel_not_requested(self):
        for chan in self.sample_channel_ids:
            self.assertFalse(checkSOHChan(chan, self.req_soh_chans))


class TestCheckWfChan(TestCase):
    """checkWFChan returns a (channel type, is requested) pair."""

    def setUp(self) -> None:
        self.req_wf_chans = ['LHE', 'HHE']
        self.sample_channel_ids = ['LHE', 'HHE', 'AL2', 'MNZ', 'VNN']

    def _assert_check_result(self, channel_id, expected):
        """Assert checkWFChan(channel_id, self.req_wf_chans) == expected."""
        result = checkWFChan(channel_id, self.req_wf_chans)
        self.assertTupleEqual(result, expected)

    def test_all_channels_requested(self):
        # A '*' entry is expected to mark every channel as requested.
        self.req_wf_chans = ['*']
        with self.subTest('test_waveform_channel'):
            for wf_chan in self.sample_channel_ids:
                self._assert_check_result(wf_chan, ('WF', True))
        with self.subTest('test_non_waveform_channel'):
            self._assert_check_result('Not a waveform channel', ('', True))

    def test_channel_is_requested(self):
        self._assert_check_result('LHE', ('WF', True))

    def test_channel_not_requested(self):
        with self.subTest('test_waveform_channel'):
            self._assert_check_result('AHE', ('WF', False))
        with self.subTest('test_non_waveform_channel'):
            self._assert_check_result('Not a waveform channel', ('', False))