From 16a930fa8f3482c7ad3a62c8f08392c1fc558685 Mon Sep 17 00:00:00 2001
From: kienle <kienle@passcal.nmt.edu>
Date: Mon, 19 Sep 2022 17:25:30 -0600
Subject: [PATCH] Add test for readWaveformTrace

---
 tests/test_model/test_handling_data.py | 34 ++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/tests/test_model/test_handling_data.py b/tests/test_model/test_handling_data.py
index 14ecbfaaf..d92a87533 100644
--- a/tests/test_model/test_handling_data.py
+++ b/tests/test_model/test_handling_data.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+import tempfile
 from math import isclose
 
 from unittest import TestCase
@@ -11,10 +12,12 @@ from sohstationviewer.model.handling_data import (
     readSOHMSeed,
     readSOHTrace,
     readMPTrace,
+    readWaveformTrace,
 )
 from sohstationviewer.model.reftek.from_rt2ms.core import Reftek130
 
 TEST_DATA_DIR = Path(__file__).parent.parent.joinpath('test_data')
+# tempfile.tempdir = './tempdir'
 
 
 class TestHandlingData(TestCase):
@@ -49,7 +52,6 @@ class TestHandlingData(TestCase):
             ]
             self.assertTrue(
                 all(key in processed_trace for key in expected_key_list),
-                msg='Processed trace is missing some fields.'
             )
 
     def test_read_soh_trace_times_calculated_correctly(self):
@@ -76,10 +78,32 @@ class TestHandlingData(TestCase):
             np.array_equal(processed_trace['data'], expected)
         )
 
-    # @skip
-    # def test_read_waveform_trace(self):
-    #     self.fail()
-    #
+    @patch('sohstationviewer.model.handling_data.saveData2File')
+    def test_read_waveform_trace(self, mock_save_data_2_file):
+        station_id = self.rt130_trace.stats['station']
+        channel_id = self.rt130_trace.stats['channel']
+        # The function itself only cares about the length of this list so we
+        # stub it out.
+        traces_info = [dict() for _ in range(4)]
+        tmp_dir = tempfile.TemporaryDirectory()
+        processed_trace = readWaveformTrace(
+            self.rt130_trace, station_id, channel_id, traces_info, tmp_dir.name
+        )
+
+        expected_key_list = [
+            'samplerate',
+            'startTmEpoch',
+            'endTmEpoch',
+            'size',
+            'times_f',
+            'data_f',
+        ]
+        self.assertTrue(
+            all(key in processed_trace for key in expected_key_list),
+        )
+
+        self.assertTrue(mock_save_data_2_file.called)
+
     # @skip
     # def test_read_waveform_mseed(self):
     #     self.fail()
-- 
GitLab