Skip to content
Snippets Groups Projects
Commit 16a930fa authored by Kien Le's avatar Kien Le
Browse files

Add test for readWaveformTrace

parent e3ae443b
No related branches found
No related tags found
1 merge request!27Draft: Add tests for functions in handling_data.py
from pathlib import Path from pathlib import Path
import tempfile
from math import isclose from math import isclose
from unittest import TestCase from unittest import TestCase
...@@ -11,10 +12,12 @@ from sohstationviewer.model.handling_data import ( ...@@ -11,10 +12,12 @@ from sohstationviewer.model.handling_data import (
readSOHMSeed, readSOHMSeed,
readSOHTrace, readSOHTrace,
readMPTrace, readMPTrace,
readWaveformTrace,
) )
from sohstationviewer.model.reftek.from_rt2ms.core import Reftek130 from sohstationviewer.model.reftek.from_rt2ms.core import Reftek130
TEST_DATA_DIR = Path(__file__).parent.parent.joinpath('test_data') TEST_DATA_DIR = Path(__file__).parent.parent.joinpath('test_data')
# tempfile.tempdir = './tempdir'
class TestHandlingData(TestCase): class TestHandlingData(TestCase):
...@@ -49,7 +52,6 @@ class TestHandlingData(TestCase): ...@@ -49,7 +52,6 @@ class TestHandlingData(TestCase):
] ]
self.assertTrue( self.assertTrue(
all(key in processed_trace for key in expected_key_list), all(key in processed_trace for key in expected_key_list),
msg='Processed trace is missing some fields.'
) )
def test_read_soh_trace_times_calculated_correctly(self): def test_read_soh_trace_times_calculated_correctly(self):
...@@ -76,10 +78,32 @@ class TestHandlingData(TestCase): ...@@ -76,10 +78,32 @@ class TestHandlingData(TestCase):
np.array_equal(processed_trace['data'], expected) np.array_equal(processed_trace['data'], expected)
) )
# @skip @patch('sohstationviewer.model.handling_data.saveData2File')
# def test_read_waveform_trace(self): def test_read_waveform_trace(self, mock_save_data_2_file):
# self.fail() station_id = self.rt130_trace.stats['station']
# channel_id = self.rt130_trace.stats['channel']
# The function itself only cares about the length of this list so we
# stub it out.
traces_info = [dict() for _ in range(4)]
tmp_dir = tempfile.TemporaryDirectory()
processed_trace = readWaveformTrace(
self.rt130_trace, station_id, channel_id, traces_info, tmp_dir.name
)
expected_key_list = [
'samplerate',
'startTmEpoch',
'endTmEpoch',
'size',
'times_f',
'data_f',
]
self.assertTrue(
all(key in processed_trace for key in expected_key_list),
)
self.assertTrue(mock_save_data_2_file.called)
# @skip # @skip
# def test_read_waveform_mseed(self): # def test_read_waveform_mseed(self):
# self.fail() # self.fail()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment