Compare revisions: software_public/passoft/sohstationviewer
# Drawing waveform and mass position
# TODO: add more descriptive progress message
from __future__ import annotations
from typing import List

from PySide2 import QtCore, QtWidgets

from sohstationviewer.view.plotting.waveform_processor import \
    WaveformChannelProcessor
from sohstationviewer.view.util.plot_func_names import plot_functions
from sohstationviewer.view.plotting.plotting_widget import plotting_widget

from sohstationviewer.controller.plotting_data import get_title
from sohstationviewer.controller.util import display_tracking_info

from sohstationviewer.database import extract_data
@@ -19,11 +23,38 @@ class WaveformWidget(plotting_widget.PlottingWidget):
""" """
Widget to display waveform and mass position data. Widget to display waveform and mass position data.
""" """
def plot_channels(self, start_tm, end_tm, key,
data_time, time_ticks_total, finished = QtCore.Signal()
waveform_data, mass_pos_data):
def __init__(self, parent, tracking_box, *args):
super().__init__(parent, tracking_box)
self.data_processors: List[WaveformChannelProcessor] = []
# Only one data processor can run at a time, so it is not a big problem
#
self.thread_pool = QtCore.QThreadPool()
# Flag to indicate whether the data is being plotted for the first time
# or the plot is being zoomed in.
self.is_first_plotting = True
# Used to ensure that the user cannot read a new data set when we are
# zooming in on the waveform plot.
self.is_working = False
def reset_widget(self):
""" """
Recursively plot each channels for waveform_data and mass_pos_data. Reset the widget in order to plot a new data set.
"""
self.data_processors = []
self.is_first_plotting = True
self.is_working = False
def init_plot(self, start_tm, end_tm, key, data_time, time_ticks_total,
waveform_data, mass_pos_data):
"""
Initialize and configure the plot.
:param start_tm: float - requested start time to read :param start_tm: float - requested start time to read
:param end_tm: float - requested end time to read :param end_tm: float - requested end time to read
@@ -31,79 +62,190 @@ class WaveformWidget(plotting_widget.PlottingWidget):
            or (device's experiment number, experiment number) for reftek
        :param data_time: [float, float] - time range of the selected data set
        :param time_ticks_total: int - max number of ticks to show on time bar
        :param waveform_data: dict - waveform data of selected data set,
            refer to DataTypeModel.__init__.waveform_data[key]['read_data']
        :param mass_pos_data: dict - mass position data of the selected
            data set, refer to DataTypeModel.__init__.mass_pos_data[key]
        """
        # The name of the plotting data is used for compatibility with the API
        # of the parent class PlottingWidget
        self.plotting_data1 = waveform_data
        self.plotting_data2 = mass_pos_data
        self.processing_log = []  # [(message, type)]
        self.gap_bar = None
        if self.axes:
            self.plotting_axes.fig.clear()
            self.draw()
        self.date_mode = self.parent.date_format.upper()
        self.time_ticks_total = time_ticks_total
        self.min_x = max(data_time[0], start_tm)
        self.max_x = min(data_time[1], end_tm)
        self.plot_total = len(self.plotting_data1) + len(self.plotting_data2)
        title = get_title(key, self.min_x, self.max_x, self.date_mode)
        self.plotting_bot = const.BOTTOM
        self.plotting_bot_pixel = const.BOTTOM_PX
        self.axes = []
        self.timestamp_bar_top = self.plotting_axes.add_timestamp_bar(0.003)
        self.plotting_axes.set_title(title)

    def plot_channels(self, start_tm, end_tm, key,
                      data_time, time_ticks_total,
                      waveform_data, mass_pos_data):
        """
        Prepare to plot waveform and mass-position data by creating a data
        processor for each channel in the waveform data (mass-position data
        has already been processed when plotting SOH data). Then, run the
        data processor for each channel.
        :param start_tm: float - requested start time to read
        :param end_tm: float - requested end time to read
        :param key: str or (str, str) station name for mseed,
            or (device's experiment number, experiment number) for reftek
        :param data_time: [float, float] - time range of the selected data set
        :param time_ticks_total: int - max number of ticks to show on time bar
        :param waveform_data: dict - waveform data of selected data set,
            refer to DataTypeModel.__init__.waveform_data[key]['read_data']
        :param mass_pos_data: dict - mass position data of the selected
            data set, refer to DataTypeModel.__init__.mass_pos_data[key]
        """
        if not self.is_working:
            self.reset_widget()
            self.is_working = True
            start_msg = 'Plotting waveform data...'
            display_tracking_info(self.tracking_box, start_msg, 'info')
            self.init_plot(start_tm, end_tm, key,
                           data_time, time_ticks_total,
                           waveform_data, mass_pos_data)
            self.create_waveform_channel_processors()
            self.process_channel()

    def get_zoom_data(self, *args, **kwargs):
        """
        Dummy method to comply with the implementation of self.set_lim in the
        base class PlottingWidget.
        """
        pass

    def set_lim(self, first_time=False):
        """
        The set_lim method of the base class PlottingWidget was not designed
        with multi-threading in mind, so it makes some assumptions that are
        difficult to satisfy in a multi-threaded design. While these
        assumptions do not affect the initial plotting of the data, they make
        designing a system for zooming more difficult.

        Rather than trying to comply with the design of PlottingWidget.set_lim,
        we decided to work around it. This set_lim method still keeps the
        functionality of processing the data based on the zoom range. However,
        it delegates setting the new limits of the x and y axes to
        PlottingWidget.set_lim. Because PlottingWidget.set_lim processes the
        data by calling the method get_zoom_data, we implement an override that
        does nothing, thus disabling that part of PlottingWidget.set_lim.

        :param first_time: flag that indicates whether set_lim is called for
            the first time for a data set.
        """
        self.data_processors = []
        if not self.is_working:
            self.is_working = True
            start_msg = 'Zooming in...'
            display_tracking_info(self.tracking_box, start_msg, 'info')
            self.create_waveform_channel_processors()
            self.process_channel()

    def create_waveform_channel_processors(self):
        """
        Create a data processor for each channel in the waveform data.
        """
        for chan_id in self.plotting_data1:
            chan_db_info = extract_data.get_wf_plot_info(chan_id)
            if chan_db_info['plotType'] == '':
                continue
            self.plotting_data1[chan_id]['chan_db_info'] = chan_db_info
            channel_processor = WaveformChannelProcessor(
                self.plotting_data1[chan_id], chan_id, self.min_x, self.max_x,
                True
            )
            self.data_processors.append(channel_processor)
            channel_processor.finished.connect(self.process_channel)
            channel_processor.stopped.connect(self.stopped)

    def plot_mass_pos_channels(self):
        """
        Plot the mass-position data. Because the mass-position data has
        already been processed in the SOH widget, this method does no further
        processing and instead goes straight to plotting.
        """
        for chan_id in self.plotting_data2:
            chan_db_info = extract_data.get_chan_plot_info(
                chan_id, self.parent.data_type)
            self.plotting_data2[chan_id]['chan_db_info'] = chan_db_info
            self.plot_single_channel(self.plotting_data2[chan_id], chan_id)

    @QtCore.Slot()
    def process_channel(self, channel_data=None, channel_id=None):
        """
        Process a channel of waveform data. If channel_id and channel_data are
        not None, remove the first data processor from the list of processors
        and plot the channel channel_id. No matter what, the next channel
        processor is started. If there is no more channel processor, the
        widget goes into the steps of finalizing the plot.

        Semantically, if this method is called with no argument, the
        processing is being started. Otherwise, if it receives two arguments,
        a data processor just finished its work and is signalling the widget
        to plot the processed channel and start processing the next one.

        This method has been designed to start only one processor at a time.
        This is a deliberate choice in order to avoid running out of memory.
        Due to how they are designed, channel processors use a lot of memory
        for large data sets. If multiple processors are run at once, the
        memory usage of the program becomes untenable. It has been observed
        that a 7.5 GB data set can cause an out-of-memory error.

        :param channel_data: the data of channel_id
        :param channel_id: the name of the channel to be plotted
        """
        if channel_id is not None:
            self.data_processors.pop(0)
            self.plot_single_channel(channel_data, channel_id)
        try:
            channel_processor = self.data_processors[0]
            self.thread_pool.start(channel_processor)
        except IndexError:
            self.plot_mass_pos_channels()
            if self.is_first_plotting:
                self.done()
                finished_msg = 'Waveform plot finished'
            else:
                super().set_lim()
                self.draw()
                finished_msg = 'Zooming in finished'
            display_tracking_info(self.tracking_box, finished_msg, 'info')
            self.is_working = False
            self.is_first_plotting = False
        except Exception as e:
            print(e)
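
The chaining above boils down to a small, reusable pattern: keep a queue of QRunnable workers, start only the head of the queue, and let each worker's finished signal trigger the next start. The sketch below is illustrative only; ChainedWorker, start_next and the channel names are made up and are not sohstationviewer APIs.

import sys

from PySide2 import QtCore


class _WorkerSignals(QtCore.QObject):
    # QRunnable is not a QObject, so signals live on a small helper object.
    finished = QtCore.Signal(str)


class ChainedWorker(QtCore.QRunnable):
    """Toy stand-in for a channel processor: it only reports its name."""
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.signals = _WorkerSignals()

    def run(self):
        # A real processor would trim and downsample the channel here.
        self.signals.finished.emit(self.name)


def start_next(pool, queue):
    """Start only the head of the queue; each finish starts the next one."""
    if not queue:
        return
    worker = queue[0]
    # A direct connection runs the callback in the worker thread, mirroring
    # how the widget chains processors without waiting on the event loop.
    worker.signals.finished.connect(
        lambda name: (queue.pop(0), start_next(pool, queue)),
        QtCore.Qt.DirectConnection)
    pool.start(worker)


if __name__ == '__main__':
    app = QtCore.QCoreApplication(sys.argv)
    pool = QtCore.QThreadPool()
    start_next(pool, [ChainedWorker(c) for c in ('LHZ', 'LH1', 'LH2')])
    pool.waitForDone()

Because at most one worker is in flight, peak memory stays close to that of a single channel, which is the point of the one-at-a-time design described above.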

    def plot_single_channel(self, c_data, chan_id):
        """
        Plot the channel chan_id.
        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to
            DataTypeModel.__init__.waveform_data[key]['read_data'][chan_id] or
            DataTypeModel.__init__.mass_pos_data[key][chan_id]
        :param chan_id: str - name of channel
        """
        if len(c_data['times']) == 0:
            return
        chan_db_info = c_data['chan_db_info']
        plot_type = chan_db_info['plotType']
        # use ax_wf because with mass position, ax has been used
        # in plottingWidget
        if 'ax_wf' not in c_data:
@@ -118,6 +260,52 @@ class WaveformWidget(plotting_widget.PlottingWidget):
        getattr(self.plotting, plot_functions[plot_type][1])(
            c_data, chan_db_info, chan_id, c_data['ax_wf'], None)

    def done(self):
        """
        Finish up the plot after all channels in the data have been plotted.
        Also signal to the main window that the waveform plot is finished so
        that it can update its flags.
        """
        self.axes.append(self.plotting.plot_none())
        self.timestamp_bar_bottom = self.plotting_axes.add_timestamp_bar(
            0.003, top=False)
        super().set_lim(first_time=True)
        self.bottom = self.axes[-1].get_ybound()[0]
        self.ruler = self.plotting_axes.add_ruler(
            self.display_color['time_ruler'])
        self.zoom_marker1 = self.plotting_axes.add_ruler(
            self.display_color['zoom_marker'])
        self.zoom_marker2 = self.plotting_axes.add_ruler(
            self.display_color['zoom_marker'])
        # Set the view size to fit the given data
        if self.main_widget.geometry().height() < self.plotting_bot_pixel:
            self.main_widget.setFixedHeight(self.plotting_bot_pixel)
        self.draw()
        self.finished.emit()

    def request_stop(self):
        """
        Request the widget to stop plotting.
        """
        # The currently running data processor will always be the first one
        # in the list.
        self.data_processors[0].request_stop()
        # Because the processors are started one at a time, we only need to
        # delete the other processors from the queue. However, we need to keep
        # the currently running data processor in memory to ensure that its
        # downsampler terminates all running threads gracefully.
        self.data_processors = [self.data_processors[0]]

    @QtCore.Slot()
    def stopped(self):
        """
        The slot that is called when the last channel processor has terminated
        all running background threads.
        """
        display_tracking_info(self.tracking_box,
                              'Waveform plot stopped', 'info')
        self.is_working = False


class WaveformDialog(QtWidgets.QWidget):
    def __init__(self, parent):
@@ -166,7 +354,9 @@ class WaveformDialog(QtWidgets.QWidget):
            mass position channel
        """
        self.plotting_widget = WaveformWidget(
            self, self.info_text_browser)
        self.plotting_widget.finished.connect(self.plot_finished)
        main_layout.addWidget(self.plotting_widget, 2)

        bottom_layout = QtWidgets.QHBoxLayout()
@@ -207,3 +397,6 @@ class WaveformDialog(QtWidgets.QWidget):
        Save the plotting to a file
        """
        print("save")

    def plot_finished(self):
        self.parent.is_plotting_waveform = False

from typing import List, Dict
from PySide2 import QtCore
from sohstationviewer.conf import constants as const
import numpy as np
from sohstationviewer.model.downsampler import Downsampler


class WaveformChannelProcessorSignals(QtCore.QObject):
    finished = QtCore.Signal(dict, str)
    stopped = QtCore.Signal()


class WaveformChannelProcessor(QtCore.QRunnable):
    """
    The class that handles trimming excess data and interfacing with a
    downsampler for a waveform channel.
    """
    def __init__(self, channel_data, channel_id, start_time, end_time,
                 first_time):
        super().__init__()
        self.signals = WaveformChannelProcessorSignals()
        # Aliasing the signals for ease of use
        self.finished = self.signals.finished
        self.stopped = self.signals.stopped
        self.stop_requested = False
        self.downsampler = Downsampler()
        self.channel_data: dict = channel_data
        self.channel_id = channel_id
        self.start_time = start_time
        self.end_time = end_time
        self.first_time = first_time
        self.trimmed_trace_list = None
        self.downsampled_times_list = []
        self.downsampled_data_list = []
        self.downsampled_list_lock = QtCore.QMutex()

    def trim_waveform_data(self) -> List[Dict]:
        """
        Trim off waveform traces whose times do not intersect the closed
        interval [self.start_time, self.end_time]. Store the traces that are
        not removed in self.trimmed_trace_list.
        """
        data_start_time = self.channel_data['tracesInfo'][0]['startTmEpoch']
        data_end_time = self.channel_data['tracesInfo'][-1]['endTmEpoch']
        if (self.start_time > data_end_time
                or self.end_time < data_start_time):
            return []
        good_start_indices = [index
                              for index, tr
                              in enumerate(self.channel_data['tracesInfo'])
                              if tr['startTmEpoch'] > self.start_time]
        if good_start_indices:
            start_idx = good_start_indices[0]
            if start_idx > 0:
                start_idx -= 1  # start_time in middle of trace
        else:
            start_idx = 0
        good_end_indices = [idx
                            for idx, tr
                            in enumerate(self.channel_data['tracesInfo'])
                            if tr['endTmEpoch'] <= self.end_time]
        if good_end_indices:
            end_idx = good_end_indices[-1]
            if end_idx < len(self.channel_data['tracesInfo']) - 1:
                end_idx += 1  # end_time in middle of trace
        else:
            end_idx = 0
        end_idx += 1  # a[x:y+1] = [a[x], ..., a[y]]
        good_indices = slice(start_idx, end_idx)
        self.trimmed_trace_list = self.channel_data['tracesInfo'][good_indices]
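
For a concrete picture of the trimming rule, here is a tiny self-contained check. The trace dicts and window values are made up; only the startTmEpoch/endTmEpoch keys mirror the real structure, and the list comprehension is a simplified, equivalent-in-spirit form of the index arithmetic above: a trace survives when its time span intersects the requested window, so a trace that merely contains start_time or end_time is kept too.

# Toy traces: three touching 10-second segments (made-up values).
traces = [
    {'startTmEpoch': 0, 'endTmEpoch': 10},
    {'startTmEpoch': 10, 'endTmEpoch': 20},
    {'startTmEpoch': 20, 'endTmEpoch': 30},
]
start_time, end_time = 15, 30

# Keep a trace if [startTmEpoch, endTmEpoch] overlaps [start_time, end_time].
kept = [tr for tr in traces
        if tr['endTmEpoch'] >= start_time and tr['startTmEpoch'] <= end_time]

# The window starts in the middle of the second trace, so only the first
# trace is dropped.
assert kept == traces[1:]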

    def init_downsampler(self):
        """
        Initialize the downsampler by loading the memmapped traces' data and
        creating a downsampler worker for each loaded trace.
        """
        # Calculate the number of requested_points
        total_size = sum([tr['size'] for tr in self.trimmed_trace_list])
        requested_points = 0
        if total_size > const.CHAN_SIZE_LIMIT:
            requested_points = int(
                const.CHAN_SIZE_LIMIT / len(self.trimmed_trace_list)
            )
        # Downsample the data
        for tr_idx, tr in enumerate(self.trimmed_trace_list):
            if not self.stop_requested:
                times = np.memmap(tr['times_f'],
                                  dtype='int64', mode='r',
                                  shape=tr['size'])
                data = np.memmap(tr['data_f'],
                                 dtype='int64', mode='r',
                                 shape=tr['size'])
                indexes = np.where((self.start_time <= times) &
                                   (times <= self.end_time))
                times = times[indexes]
                data = data[indexes]
                do_downsample = (requested_points != 0)
                worker = self.downsampler.add_worker(
                    times, data, rq_points=requested_points,
                    do_downsample=do_downsample
                )
                # We need these connections to run in the background thread.
                # However, their owner (the channel processor) is in the main
                # thread, so the default connection type would make them run
                # in the main thread. Instead, we have to use a direct
                # connection to make these slots run in the background thread.
                worker.signals.finished.connect(
                    self.trace_processed, type=QtCore.Qt.DirectConnection
                )
                worker.signals.stopped.connect(
                    self.stopped, type=QtCore.Qt.DirectConnection
                )
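
The memmap-plus-mask step can be exercised on its own; the sketch below uses throwaway file names and sizes rather than sohstationviewer's on-disk layout, and a boolean mask in place of the np.where call, which selects the same samples.

import numpy as np

# Write two small int64 arrays to disk so they can be memory-mapped
# (throwaway file names, not the project's storage format).
size = 1000
np.arange(size, dtype='int64').tofile('times.bin')
(np.arange(size, dtype='int64') * 2).tofile('data.bin')

# Read-only memory maps keep the full arrays on disk...
times = np.memmap('times.bin', dtype='int64', mode='r', shape=size)
data = np.memmap('data.bin', dtype='int64', mode='r', shape=size)

# ...and the mask pulls only the requested window into memory.
start_time, end_time = 100, 199
mask = (start_time <= times) & (times <= end_time)
assert len(times[mask]) == 100
assert data[mask][0] == 200

The requested_points arithmetic above has a similar intent: assuming, for illustration, a CHAN_SIZE_LIMIT of 1,000,000 samples and four trimmed traces totalling more than that, each downsampler worker would be asked for 250,000 points, so the plotted channel stays near the limit no matter how many traces it spans.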

    @QtCore.Slot()
    def trace_processed(self, times, data):
        """
        The slot called when the downsampler worker of a waveform trace
        finishes its job. Add the downsampled data to the appropriate list.
        If the worker that emitted the signal is the last one, combine and
        store the processed data in self.channel_data and emit the finished
        signal of this class.
        :param times: the downsampled array of time data.
        :param data: the downsampled array of waveform data.
        """
        self.downsampled_list_lock.lock()
        self.downsampled_times_list.append(times)
        self.downsampled_data_list.append(data)
        self.downsampled_list_lock.unlock()
        if len(self.downsampled_times_list) == len(self.trimmed_trace_list):
            self.channel_data['times'] = np.hstack(self.downsampled_times_list)
            self.channel_data['data'] = np.hstack(self.downsampled_data_list)
            self.signals.finished.emit(self.channel_data, self.channel_id)

    def run(self):
        """
        The main method of this class. First, check whether the channel data
        is already small enough after the first trim that no further
        processing is needed. Then, trim the waveform data based on
        self.start_time and self.end_time. Afterwards, do some checks to
        determine whether the data needs to be downsampled. If so, initialize
        and start the downsampler.
        """
        if 'fulldata' in self.channel_data:
            self.finished.emit(self.channel_data, self.channel_id)
            # The data is small; it already has its full data from the first
            # trim.
            return
        self.trim_waveform_data()

        if not self.trimmed_trace_list:
            self.channel_data['times'] = np.array([])
            self.channel_data['data'] = np.array([])
            self.finished.emit(self.channel_data, self.channel_id)
            return False

        total_size = sum([tr['size'] for tr in self.trimmed_trace_list])
        if not self.first_time and total_size > const.RECAL_SIZE_LIMIT:
            # The data is so big that processing it would not make the
            # resulting plot any easier to understand.
            return

        if total_size <= const.CHAN_SIZE_LIMIT and self.first_time:
            self.channel_data['fulldata'] = True
            try:
                del self.channel_data['times']
                del self.channel_data['data']
            except Exception:
                pass

        self.init_downsampler()
        self.downsampler.start()

    def request_stop(self):
        """
        Stop processing the data by requesting the downsampler to stop
        running.
        """
        self.stop_requested = True
        self.downsampler.request_stop()
@@ -713,4 +713,5 @@ class UIMainWindow(object):
        self.prefer_soh_chan_button.clicked.connect(
            main_window.open_channel_preferences)
        self.read_button.clicked.connect(main_window.read_selected_files)
        self.stop_button.clicked.connect(main_window.stop)
@@ -4,73 +4,76 @@ from unittest import TestCase

from obspy import UTCDateTime

from sohstationviewer.controller.plotting_data import (
    get_masspos_value_colors,
    format_time,
    get_title,
    get_gaps,
    get_time_ticks,
    get_day_ticks,
    get_unit_bitweight,
)

class TestGetGaps(TestCase):
    """Test suite for get_gaps."""

    def test_mixed_gap_sizes(self):
        """
        Test get_gaps - the given list of gaps contains both gaps that are too
        short and gaps that are long enough.
        """
        gaps = [(0, 60), (60, 180), (180, 360)]
        min_gap = 3
        self.assertListEqual(get_gaps(gaps, min_gap), [(180, 360)])

    def test_empty_gap_list(self):
        """
        Test get_gaps - the given list of gaps is empty.
        """
        gaps = []
        min_gap = 3
        self.assertListEqual(get_gaps(gaps, min_gap), [])

    def test_all_gaps_are_too_short(self):
        """
        Test get_gaps - the given list of gaps only contains gaps that are too
        short.
        """
        gaps = [(0, 60), (60, 180)]
        min_gap = 3
        self.assertListEqual(get_gaps(gaps, min_gap), [])

    def test_all_gaps_are_long_enough(self):
        """
        Test get_gaps - the given list of gaps only contains gaps that are long
        enough.
        """
        gaps = [(0, 180), (180, 360)]
        min_gap = 3
        self.assertListEqual(get_gaps(gaps, min_gap), [(0, 180), (180, 360)])
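
A note on the units these fixtures imply: each gap is a (start, end) pair in epoch seconds, while min_gap appears to be expressed in minutes. With min_gap = 3, only the 180-second gap survives the mixed case (360 - 180 = 180 s = 3 min), the 60 s and 120 s gaps are dropped, and in the last test both 180 s gaps are kept.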

class TestGetDayTicks(TestCase):
    """Test suite for get_day_ticks."""

    def test_get_day_ticks(self):
        """Test get_day_ticks."""
        expected = (
            [12, 24, 36, 48, 60, 72, 84, 96, 108, 120, 132, 144, 156, 168, 180,
             192, 204, 216, 228, 240, 252, 264, 276],
            [48, 96, 144, 192, 240],
            ['04', '08', '12', '16', '20']
        )
        self.assertTupleEqual(get_day_ticks(), expected)


class TestGetMassposValue(TestCase):
    """Test suite for get_masspos_value_colors."""

    def test_string_output(self):
        """
        Test basic functionality of get_masspos_value_colors - the range option
        and color mode are correct, and the output is a string.
        """
        expected_input_output_pairs = {
@@ -89,14 +92,15 @@ class TestGetMassposValue(TestCase):
        for input_val in expected_input_output_pairs:
            with self.subTest(test_names[idx]):
                self.assertEqual(
                    get_masspos_value_colors(input_val[0], '',
                                             input_val[1], []),
                    expected_input_output_pairs[input_val]
                )
            idx += 1

    def test_list_output(self):
        """
        Test basic functionality of get_masspos_value_colors - the range option
        and color mode are correct, and the output is a list.
        """
        expected_input_output_pairs = {
@@ -118,46 +122,47 @@ class TestGetMassposValue(TestCase):
        for i, input_val in enumerate(expected_input_output_pairs):
            with self.subTest(test_names[i]):
                self.assertListEqual(
                    get_masspos_value_colors(
                        input_val[0], '', input_val[1], [], ret_type=''),
                    expected_input_output_pairs[input_val]
                )

    def test_range_option_not_supported(self):
        """
        Test basic functionality of get_masspos_value_colors - the range option
        is not supported.
        """
        errors = []
        empty_color_option = ''
        self.assertIsNone(
            get_masspos_value_colors(empty_color_option, '', 'B', errors))
        self.assertGreater(len(errors), 0)

        errors = []
        bad_color_option = 'unsupported'
        self.assertIsNone(
            get_masspos_value_colors(bad_color_option, '', 'B', errors))
        self.assertGreater(len(errors), 0)

    def test_color_mode_not_supported(self):
        """
        Test basic functionality of get_masspos_value_colors - color mode is
        not supported.
        """
        errors = []
        empty_color_mode = ''
        with self.assertRaises(KeyError):
            get_masspos_value_colors('regular', '', empty_color_mode, errors)

        errors = []
        bad_color_mode = 'unsupported'
        with self.assertRaises(KeyError):
            get_masspos_value_colors('regular', '', bad_color_mode, errors)


class TestGetTimeTicks(TestCase):
    """Test suite for get_time_ticks."""

    def setUp(self) -> None:
        """Set up test fixtures."""
        self.label_cnt = 5
@@ -165,7 +170,7 @@ class TestGetTimeTicks(TestCase):

    def test_expected_time_range(self):
        """
        Test basic functionality of get_time_ticks - the given time range is
        expected in the data.
        """
@@ -179,7 +184,8 @@ class TestGetTimeTicks(TestCase):
             '19700101 00:00:03', '19700101 00:00:04']
        )
        self.assertTupleEqual(
            get_time_ticks(earliest, latest,
                           self.date_fmt, self.label_cnt),
            expected
        )
@@ -193,7 +199,8 @@ class TestGetTimeTicks(TestCase):
             '19700101 00:04:00', '19700101 00:05:00']
        )
        self.assertTupleEqual(
            get_time_ticks(earliest, latest,
                           self.date_fmt, self.label_cnt),
            expected
        )
@@ -207,7 +214,8 @@ class TestGetTimeTicks(TestCase):
             '19700101 07:00']
        )
        self.assertTupleEqual(
            get_time_ticks(earliest, latest,
                           self.date_fmt, self.label_cnt),
            expected
        )
@@ -222,7 +230,8 @@ class TestGetTimeTicks(TestCase):
            ['19700102', '19700104', '19700106', '19700108', '19700110']
        )
        self.assertTupleEqual(
            get_time_ticks(earliest, latest,
                           self.date_fmt, self.label_cnt),
            expected
        )
@@ -237,7 +246,8 @@ class TestGetTimeTicks(TestCase):
            ['19700102', '19700105', '19700108', '19700111', '19700114']
        )
        self.assertTupleEqual(
            get_time_ticks(earliest, latest,
                           self.date_fmt, self.label_cnt),
            expected
        )
@@ -249,7 +259,8 @@ class TestGetTimeTicks(TestCase):
            [864000.0, 1728000.0],
            ['19700111', '19700121'])
        self.assertTupleEqual(
            get_time_ticks(earliest, latest,
                           self.date_fmt, self.label_cnt),
            expected
        )
@@ -263,15 +274,16 @@ class TestGetTimeTicks(TestCase):
            ['19700111', '19700131', '19700220', '19700312']
        )
        self.assertTupleEqual(
            get_time_ticks(earliest, latest, self.date_fmt,
                           self.label_cnt),
            expected
        )

    def test_boundary_time_range(self):
        """
        Test basic functionality of get_time_ticks - the given time range is
        exactly 1 second, 1 minute, or 1 hour. Test the behavior where these
        time ranges make get_time_ticks return a tuple that contains empty
        lists.
        """
        expected = ([], [], [])
@@ -279,55 +291,59 @@ class TestGetTimeTicks(TestCase):
            earliest = UTCDateTime(1970, 1, 1, 0, 0, 0).timestamp
            latest = UTCDateTime(1970, 1, 1, 0, 0, 1).timestamp
            self.assertTupleEqual(
                get_time_ticks(earliest, latest,
                               self.date_fmt, self.label_cnt),
                expected
            )

        with self.subTest('test_exactly_one_minute'):
            earliest = UTCDateTime(1970, 1, 1, 0, 0, 0).timestamp
            latest = UTCDateTime(1970, 1, 1, 0, 1, 0).timestamp
            self.assertTupleEqual(
                get_time_ticks(earliest, latest,
                               self.date_fmt, self.label_cnt),
                expected
            )

        with self.subTest('test_exactly_one_hour'):
            earliest = UTCDateTime(1970, 1, 1, 0, 0, 0).timestamp
            latest = UTCDateTime(1970, 1, 1, 1, 0, 0).timestamp
            self.assertTupleEqual(
                get_time_ticks(earliest, latest,
                               self.date_fmt, self.label_cnt),
                expected
            )

    def test_earliest_time_later_than_latest_time(self):
        """
        Test basic functionality of get_time_ticks - the given latest time is
        earlier than the earliest time.
        """
        self.assertTupleEqual(
            get_time_ticks(100, 0, self.date_fmt, self.label_cnt),
            ([], [], [])
        )

    def test_time_range_is_zero(self):
        """
        Test basic functionality of get_time_ticks - the given time range is 0.
        """
        self.assertTupleEqual(
            get_time_ticks(0, 0, self.date_fmt, self.label_cnt),
            ([], [], [])
        )

    def test_get_time_ticks_no_label_displayed(self):
        """
        Test basic functionality of get_time_ticks - no time label is requested
        to be displayed.
        """
        zero_label_cnt = 0
        with self.assertRaises(ZeroDivisionError):
            get_time_ticks(0, 1000, self.date_fmt, zero_label_cnt)


class TestFormatTimeAndGetTitle(TestCase):
    """Test suite for format_time and get_title."""

    def setUp(self) -> None:
        """Set up test fixtures."""
        self.positive_epoch_time = 67567567
@@ -348,7 +364,7 @@ class TestFormatTimeAndGetTitle(TestCase):

    def test_format_time_epoch_time_date_only(self):
        """
        Test basic functionality of format_time - given time is epoch time and
        uses only a date format. Tests three cases for each date format: epoch
        time is positive, negative, and zero.
        """
@@ -367,18 +383,18 @@ class TestFormatTimeAndGetTitle(TestCase):
        for test_name, date_mode in test_name_to_date_mode_map.items():
            with self.subTest(test_name):
                self.assertEqual(
                    format_time(self.positive_epoch_time, date_mode),
                    self.positive_formatted_dates[date_mode])
                self.assertEqual(
                    format_time(self.negative_epoch_time, date_mode),
                    self.negative_formatted_dates[date_mode])
                self.assertEqual(
                    format_time(0, date_mode),
                    zero_epoch_formatted[test_name])

    def test_format_time_epoch_time_date_and_time(self):
        """
        Test basic functionality of format_time - given time is epoch time and
        both a time and a date format are used. Tests three cases for each date
        format: epoch time is positive, negative, and zero.
        """
@@ -406,21 +422,21 @@ class TestFormatTimeAndGetTitle(TestCase):
            )

            self.assertEqual(
                format_time(self.positive_epoch_time, date_mode,
                            'HH:MM:SS'),
                positive_time_expected
            )
            self.assertEqual(
                format_time(self.negative_epoch_time, date_mode,
                            'HH:MM:SS'),
                negative_time_expected
            )
            self.assertEqual(format_time(0, date_mode, 'HH:MM:SS'),
                             zero_epoch_formatted[test_name])

    def test_format_time_UTCDateTime_date_only(self):
        """
        Test basic functionality of format_time - given time is an UTCDateTime
        instance and uses only a date format.
        """
        test_name_to_date_mode_map = {
@@ -432,12 +448,12 @@ class TestFormatTimeAndGetTitle(TestCase):
        expected_dates = self.positive_formatted_dates
        for test_name, date_mode in test_name_to_date_mode_map.items():
            with self.subTest(test_name):
                self.assertEqual(format_time(utc_date_time, date_mode),
                                 expected_dates[date_mode])

    def test_format_time_UTCDateTime_date_and_time(self):
        """
        Test basic functionality of format_time - given time is an UTCDateTime
        instance and both a time and a date format are used.
        """
        test_name_to_date_mode_map = {
@@ -451,13 +467,13 @@ class TestFormatTimeAndGetTitle(TestCase):
        for test_name, date_mode in test_name_to_date_mode_map.items():
            with self.subTest('test_year_month_day_format'):
                self.assertEqual(
                    format_time(test_time, date_mode, 'HH:MM:SS'),
                    f'{expected_dates[date_mode]} {expected_time}'
                )

    def test_format_time_unsupported_date_format(self):
        """
        Test basic functionality of format_time - given date format is not
        supported.
        """
        test_time = self.positive_epoch_time
@@ -466,21 +482,21 @@ class TestFormatTimeAndGetTitle(TestCase):
        with self.subTest('test_without_time_format'):
            expected = ''
            self.assertEqual(format_time(test_time, empty_format),
                             expected)
            self.assertEqual(format_time(test_time, bad_format),
                             expected)

        with self.subTest('test_with_time_format'):
            expected = f' {self.positive_formatted_time}'
            self.assertEqual(format_time(test_time, empty_format, 'HH:MM:SS'),
                             expected)
            self.assertEqual(format_time(test_time, bad_format, 'HH:MM:SS'),
                             expected)

    def test_format_time_unsupported_time_format(self):
        """
        Test basic functionality of format_time - given time format is not
        supported.
        """
        test_time = self.positive_epoch_time
@@ -489,14 +505,14 @@ class TestFormatTimeAndGetTitle(TestCase):
        bad_format = 'bad_format'
        expected = self.positive_formatted_dates[date_format]
        self.assertEqual(format_time(test_time, date_format, empty_format),
                         expected)
        self.assertEqual(format_time(test_time, date_format, bad_format),
                         expected)

    def test_format_time_unsupported_date_and_time_format(self):
        """
        Test basic functionality of format_time - both given date and time
        format are unsupported.
        """
        test_time = self.positive_epoch_time
@@ -506,12 +522,12 @@ class TestFormatTimeAndGetTitle(TestCase):
        for date_format in bad_date_formats:
            for time_format in bad_time_format:
                self.assertEqual(
                    format_time(test_time, date_format, time_format),
                    expected
                )

    def test_get_title(self):
        """Test basic functionality of get_title."""
        date_mode = 'YYYYMMDD'
        min_time = 0
        max_time = self.positive_epoch_time
@@ -521,18 +537,18 @@ class TestFormatTimeAndGetTitle(TestCase):
            key = '3734'
            expected = (f'3734 19700101 00:00:00 to '
                        f'{formatted_max_time} (18768.77)')
            self.assertEqual(get_title(key, min_time, max_time, date_mode),
                             expected)

        with self.subTest('test_rt130'):
            key = ('92EB', 25)
            expected = (f"('92EB', 25) 19700101 00:00:00 to "
                        f"{formatted_max_time} (18768.77)")
            self.assertEqual(get_title(key, min_time, max_time, date_mode),
                             expected)

    def test_get_title_max_time_earlier_than_min_time(self):
        """
        Test basic functionality of get_title - the given maximum time is
        chronologically earlier than the given minimum time.
        """
        date_mode = 'YYYYMMDD'
@@ -544,18 +560,18 @@ class TestFormatTimeAndGetTitle(TestCase):
            key = '3734'
            expected = (f'3734 {formatted_max_time} to '
                        f'19700101 00:00:00 (-18768.77)')
            self.assertEqual(get_title(key, min_time, max_time, date_mode),
                             expected)

        with self.subTest('test_rt130'):
            key = ('92EB', 25)
            expected = (f"('92EB', 25) {formatted_max_time} to "
                        f"19700101 00:00:00 (-18768.77)")
            self.assertEqual(get_title(key, min_time, max_time, date_mode),
                             expected)


class TestGetUnitBitweight(TestCase):
    """Test suite for get_unit_bitweight."""

    def setUp(self) -> None:
        """Set up test fixtures."""
@@ -565,68 +581,68 @@ class TestGetUnitBitweight(TestCase):
            'plotType': 'test_plot_type',
            'unit': 'test_unit'
        }
        # In most cases, we do not care about the value of bitweight_opt. So,
        # we give it a default value unless needed.
        self.default_bitweight_opt = 'low'

    def test_soh_channel_linesDots_linesSRate_linesMasspos_plot_type(self):
        """
        Test basic functionality of get_unit_bitweight - the given plot type is
        linesDots, linesSRate, or linesMassposs.
        """
        self.chan_info['plotType'] = 'linesDots'
        with self.subTest('test_no_fix_point'):
            self.assertEqual(
                get_unit_bitweight(self.chan_info, self.default_bitweight_opt),
                '{}test_unit'
            )
        with self.subTest('test_have_fix_point'):
            self.chan_info['fixPoint'] = 1
            self.assertEqual(
                get_unit_bitweight(self.chan_info, self.default_bitweight_opt),
                '{:.1f}test_unit'
            )

    def test_soh_channel_other_plot_type(self):
        """
        Test basic functionality of get_unit_bitweight - the given plot type is
        not linesDots, linesSRate, or linesMassposs and the channel is not a
        seismic data channel.
        """
        self.assertEqual(
            get_unit_bitweight(self.chan_info, self.default_bitweight_opt),
            ''
        )

    def test_seismic_channel_have_fix_point(self):
        """
        Test basic functionality of get_unit_bitweight - the given plot type is
        not linesDots, linesSRate, or linesMassposs, the channel is a
        seismic data channel, and there is a fix point.
        """
        self.chan_info['channel'] = 'SEISMIC'
        self.chan_info['fixPoint'] = 1
        self.assertEqual(
            get_unit_bitweight(self.chan_info, self.default_bitweight_opt),
            '{:.1f}test_unit'
        )

    def test_seismic_channel_no_fix_point(self):
        """
        Test basic functionality of get_unit_bitweight - the given plot type is
        not linesDots, linesSRate, or linesMassposs, the channel is a
        seismic data channel, and there is no fix point.
        """
        self.chan_info['channel'] = 'SEISMIC'
        with self.subTest('test_no_bitweight_option'):
            self.assertEqual(get_unit_bitweight(self.chan_info, 'high'), '{}V')
        with self.subTest('test_have_bitweight_option'):
            self.assertEqual(get_unit_bitweight(self.chan_info, ''),
                             '{}test_unit')

    def test_no_fix_point(self):
        """
        Test basic functionality of get_unit_bitweight - the given channel info
        does not contain a value for fixPoint.
        """
        del self.chan_info['fixPoint']
@@ -637,13 +653,13 @@ class TestGetUnitBitweight(TestCase):
        # https://stackoverflow.com/questions/6181555/pass-a-python-unittest-if-an-exception-isnt-raised?noredirect=1&lq=1  # noqa
        # https://stackoverflow.com/questions/4319825/python-unittest-opposite-of-assertraises/4319870#4319870  # noqa:E501
        # Some context for this code: in getUnitBitWeight, if chan_db_info does
        # not have 'fix_point' as a key, then the value of fix_point defaults
        # to 0. This is implemented by getting the value of 'fix_point' in
        # chan_db_info and catching the resulting KeyError. So, in order to
        # test that get_unit_bitweight handles this case correctly, we assert
        # that no exception was raised.
        try:
            get_unit_bitweight(self.chan_info, '')
        except KeyError:
            self.fail()
...@@ -7,10 +7,10 @@ from contextlib import redirect_stdout ...@@ -7,10 +7,10 @@ from contextlib import redirect_stdout
import io import io
from sohstationviewer.controller.processing import ( from sohstationviewer.controller.processing import (
loadData, load_data,
readChannels, read_channels,
detectDataType, detect_data_type,
getDataTypeFromFile get_data_type_from_file
) )
from sohstationviewer.database.extract_data import get_signature_channels from sohstationviewer.database.extract_data import get_signature_channels
from PySide2 import QtWidgets from PySide2 import QtWidgets
...@@ -25,7 +25,7 @@ pegasus_dir = TEST_DATA_DIR.joinpath('Pegasus-sample/Pegasus_SVC4/soh') ...@@ -25,7 +25,7 @@ pegasus_dir = TEST_DATA_DIR.joinpath('Pegasus-sample/Pegasus_SVC4/soh')
class TestLoadDataAndReadChannels(TestCase): class TestLoadDataAndReadChannels(TestCase):
"""Test suite for loadData and readChannels.""" """Test suite for load_data and read_channels."""
def setUp(self) -> None: def setUp(self) -> None:
"""Set up test fixtures.""" """Set up test fixtures."""
...@@ -40,182 +40,184 @@ class TestLoadDataAndReadChannels(TestCase): ...@@ -40,182 +40,184 @@ class TestLoadDataAndReadChannels(TestCase):
def test_load_data_rt130_good_dir(self): def test_load_data_rt130_good_dir(self):
""" """
Test basic functionality of loadData - the given directory can be Test basic functionality of load_data - the given directory can be
loaded without issues. Test RT130. loaded without issues. Test RT130.
""" """
self.assertIsInstance( self.assertIsInstance(
loadData('RT130', self.widget_stub, [rt130_dir]), load_data('RT130', self.widget_stub, [rt130_dir]),
RT130 RT130
) )
def test_load_data_mseed_q330_good_data_dir(self): def test_load_data_mseed_q330_good_data_dir(self):
""" """
Test basic functionality of loadData - the given directory can be Test basic functionality of load_data - the given directory can be
loaded without issues. Test MSeed. loaded without issues. Test MSeed.
""" """
self.assertIsInstance( self.assertIsInstance(
loadData(self.mseed_dtype, self.widget_stub, [q330_dir]), load_data(self.mseed_dtype, self.widget_stub, [q330_dir]),
MSeed MSeed
) )
self.assertIsInstance( self.assertIsInstance(
loadData(self.mseed_dtype, self.widget_stub, [centaur_dir]), load_data(self.mseed_dtype, self.widget_stub, [centaur_dir]),
MSeed MSeed
) )
self.assertIsInstance( self.assertIsInstance(
loadData(self.mseed_dtype, self.widget_stub, [pegasus_dir]), load_data(self.mseed_dtype, self.widget_stub, [pegasus_dir]),
MSeed MSeed
) )
def test_load_data_no_dir(self): def test_load_data_no_dir(self):
"""Test basic functionality of loadData - no directory was given.""" """Test basic functionality of load_data - no directory was given."""
no_dir_given = [] no_dir_given = []
self.assertIsNone(loadData('RT130', self.widget_stub, no_dir_given)) self.assertIsNone(load_data('RT130', self.widget_stub, no_dir_given))
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, self.widget_stub, no_dir_given)) load_data(self.mseed_dtype, self.widget_stub, no_dir_given))
def test_load_data_dir_does_not_exist(self): def test_load_data_dir_does_not_exist(self):
""" """
Test basic functionality of loadData - the given directory does not Test basic functionality of load_data - the given directory does not
exist. exist.
""" """
empty_name_dir = [''] empty_name_dir = ['']
non_existent_dir = ['dir_that_does_not_exist'] non_existent_dir = ['dir_that_does_not_exist']
self.assertIsNone( self.assertIsNone(
loadData('RT130', self.widget_stub, empty_name_dir)) load_data('RT130', self.widget_stub, empty_name_dir))
self.assertIsNone( self.assertIsNone(
loadData('RT130', self.widget_stub, non_existent_dir)) load_data('RT130', self.widget_stub, non_existent_dir))
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, self.widget_stub, empty_name_dir)) load_data(self.mseed_dtype, self.widget_stub, empty_name_dir))
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, self.widget_stub, non_existent_dir)) load_data(self.mseed_dtype, self.widget_stub, non_existent_dir))
def test_load_data_empty_dir(self): def test_load_data_empty_dir(self):
""" """
Test basic functionality of loadData - the given directory is empty. Test basic functionality of load_data - the given directory is empty.
""" """
with TemporaryDirectory() as empty_dir: with TemporaryDirectory() as empty_dir:
self.assertIsNone( self.assertIsNone(
loadData('RT130', self.widget_stub, [empty_dir])) load_data('RT130', self.widget_stub, [empty_dir]))
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, self.widget_stub, [empty_dir])) load_data(self.mseed_dtype, self.widget_stub, [empty_dir]))
def test_load_data_empty_data_dir(self): def test_load_data_empty_data_dir(self):
""" """
Test basic functionality of loadData - the given directory Test basic functionality of load_data - the given directory
contains a data folder but no data file. contains a data folder but no data file.
""" """
with TemporaryDirectory() as outer_dir: with TemporaryDirectory() as outer_dir:
with TemporaryDirectory(dir=outer_dir) as data_dir: with TemporaryDirectory(dir=outer_dir) as data_dir:
self.assertIsNone( self.assertIsNone(
loadData('RT130', self.widget_stub, [data_dir])) load_data('RT130', self.widget_stub, [data_dir]))
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, self.widget_stub, [outer_dir])) load_data(self.mseed_dtype, self.widget_stub, [outer_dir]))
def test_load_data_data_type_mismatch(self): def test_load_data_data_type_mismatch(self):
""" """
Test basic functionality of loadData - the data type given does not Test basic functionality of load_data - the data type given does not
match the type of the data contained in the given directory. match the type of the data contained in the given directory.
""" """
self.assertIsNone( self.assertIsNone(
loadData('RT130', self.widget_stub, [q330_dir])) load_data('RT130', self.widget_stub, [q330_dir]))
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, self.widget_stub, [rt130_dir])) load_data(self.mseed_dtype, self.widget_stub, [rt130_dir]))
def test_load_data_data_traceback_error(self): def test_load_data_data_traceback_error(self):
""" """
Test basic functionality of loadData - when there is an error Test basic functionality of load_data - when there is an error
while loading data, the traceback info will be printed out while loading data, the traceback info will be printed out
""" """
f = io.StringIO() f = io.StringIO()
with redirect_stdout(f): with redirect_stdout(f):
self.assertIsNone(loadData('RT130', None, [q330_dir])) self.assertIsNone(load_data('RT130', None, [q330_dir]))
output = f.getvalue() output = f.getvalue()
self.assertIn( self.assertIn(
f"WARNING: Dir {q330_dir} " f"Dir {q330_dir} "
f"can't be read due to error: Traceback", f"can't be read due to error: Traceback",
output output
) )
with redirect_stdout(f): with redirect_stdout(f):
self.assertIsNone( self.assertIsNone(
loadData(self.mseed_dtype, None, [rt130_dir])) load_data(self.mseed_dtype, None, [rt130_dir]))
output = f.getvalue() output = f.getvalue()
self.assertIn( self.assertIn(
f"WARNING: Dir {rt130_dir} " f"Dir {rt130_dir} "
f"can't be read due to error: Traceback", f"can't be read due to error: Traceback",
output output
) )
def test_read_channels_mseed_dir(self): def test_read_channels_mseed_dir(self):
""" """
Test basic functionality of readChannels - the given directory contains Test basic functionality of read_channels - the given directory contains
MSeed data. MSeed data.
""" """
q330_channels = {'VKI', 'VM1'} q330_channels = {'VKI', 'VM1'}
self.assertSetEqual(readChannels(self.widget_stub, [q330_dir]), self.assertSetEqual(read_channels(self.widget_stub, [q330_dir]),
q330_channels) q330_channels)
centaur_channels = {'VDT', 'VM3', 'EX3', 'GEL', 'VEC', 'EX2', 'LCE', centaur_channels = {'VDT', 'VM3', 'EX3', 'GEL', 'VEC', 'EX2', 'LCE',
'EX1', 'GLA', 'LCQ', 'GPL', 'GNS', 'GST', 'VCO', 'EX1', 'GLA', 'LCQ', 'GPL', 'GNS', 'GST', 'VCO',
'GAN', 'GLO', 'VPB', 'VEI', 'VM2', 'VM1'} 'GAN', 'GLO', 'VPB', 'VEI', 'VM2', 'VM1'}
self.assertSetEqual(readChannels(self.widget_stub, [centaur_dir]), self.assertSetEqual(read_channels(self.widget_stub, [centaur_dir]),
centaur_channels) centaur_channels)
pegasus_channels = {'VDT', 'VM1', 'VE1'} pegasus_channels = {'VDT', 'VM1', 'VE1'}
self.assertSetEqual(readChannels(self.widget_stub, [pegasus_dir]), self.assertSetEqual(read_channels(self.widget_stub, [pegasus_dir]),
pegasus_channels) pegasus_channels)
def test_read_channels_rt130_dir(self): def test_read_channels_rt130_dir(self):
""" """
Test basic functionality of readChannels - the given directory contains Test basic functionality of read_channels - the given directory contains
RT130 data. RT130 data.
""" """
with self.assertRaises(Exception): with self.assertRaises(Exception):
readChannels(self.widget_stub, [rt130_dir]) read_channels(self.widget_stub, [rt130_dir])
def test_read_channels_no_dir(self): def test_read_channels_no_dir(self):
""" """
Test basic functionality of readChannels - no directory was given. Test basic functionality of read_channels - no directory was given.
""" """
no_dir = [] no_dir = []
with self.assertRaises(Exception): with self.assertRaises(Exception):
readChannels(self.widget_stub, no_dir) read_channels(self.widget_stub, no_dir)
def test_read_channels_dir_does_not_exist(self): def test_read_channels_dir_does_not_exist(self):
""" """
Test basic functionality of readChannels - the given directory does not Test basic functionality of read_channels - the given directory does
exist. not exist.
""" """
empty_name_dir = [''] empty_name_dir = ['']
non_existent_dir = ['non_existent_dir'] non_existent_dir = ['non_existent_dir']
with self.assertRaises(Exception): with self.assertRaises(Exception):
readChannels(self.widget_stub, empty_name_dir) read_channels(self.widget_stub, empty_name_dir)
with self.assertRaises(Exception): with self.assertRaises(Exception):
readChannels(self.widget_stub, non_existent_dir) read_channels(self.widget_stub, non_existent_dir)
def test_read_channels_empty_dir(self): def test_read_channels_empty_dir(self):
""" """
Test basic functionality of readChannels - the given directory is Test basic functionality of read_channels - the given directory is
empty. empty.
""" """
with TemporaryDirectory() as empty_dir: with TemporaryDirectory() as empty_dir:
with self.assertRaises(Exception): with self.assertRaises(Exception):
readChannels(self.widget_stub, [empty_dir]) read_channels(self.widget_stub, [empty_dir])
def test_read_channels_empty_data_dir(self): def test_read_channels_empty_data_dir(self):
""" """
Test basic functionality of readChannels - the given directory Test basic functionality of read_channels - the given directory
contains a data folder but no data file. contains a data folder but no data file.
""" """
with TemporaryDirectory() as outer_dir: with TemporaryDirectory() as outer_dir:
with TemporaryDirectory(dir=outer_dir): with TemporaryDirectory(dir=outer_dir):
with self.assertRaises(Exception): with self.assertRaises(Exception):
readChannels(self.widget_stub, [outer_dir]) read_channels(self.widget_stub, [outer_dir])
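For orientation, here is a minimal sketch of how the renamed loader helpers are driven outside the test suite, based only on the calls exercised above. The directory paths are placeholders, and passing None for the tracking widget mirrors the traceback test; treat it as an assumption rather than the supported API.

from sohstationviewer.controller.processing import load_data, read_channels

data_dirs = ['/path/to/rt130/data/set']    # placeholder path
# load_data() returns an RT130 or MSeed object on success and None when
# the directory is missing, empty, or holds a different data type.
data_set = load_data('RT130', None, data_dirs)
if data_set is None:
    print('Nothing could be loaded from', data_dirs)

# read_channels() lists the channels found in an MSeed directory and
# raises for anything else.
try:
    channels = read_channels(None, ['/path/to/mseed/data'])  # placeholder
    print(sorted(channels))
except Exception:
    print('Not a readable MSeed directory')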
class TestDetectDataType(TestCase): class TestDetectDataType(TestCase):
"""Test suite for detectDataType and getDataTypeFromFile functions.""" """
Test suite for detect_data_type and get_data_type_from_file functions.
"""
def setUp(self) -> None: def setUp(self) -> None:
"""Set up text fixtures.""" """Set up text fixtures."""
...@@ -224,7 +226,7 @@ class TestDetectDataType(TestCase): ...@@ -224,7 +226,7 @@ class TestDetectDataType(TestCase):
self.widget_stub = widget_patcher.start() self.widget_stub = widget_patcher.start()
func_patcher = patch('sohstationviewer.controller.processing.' func_patcher = patch('sohstationviewer.controller.processing.'
'getDataTypeFromFile') 'get_data_type_from_file')
self.addCleanup(func_patcher.stop) self.addCleanup(func_patcher.stop)
self.mock_get_data_type_from_file = func_patcher.start() self.mock_get_data_type_from_file = func_patcher.start()
...@@ -241,20 +243,20 @@ class TestDetectDataType(TestCase): ...@@ -241,20 +243,20 @@ class TestDetectDataType(TestCase):
def test_one_directory_not_unknown_data_type(self): def test_one_directory_not_unknown_data_type(self):
""" """
Test basic functionality of detectDataType - only one directory was Test basic functionality of detect_data_type - only one directory was
given and the data type it contains can be detected. given and the data type it contains can be detected.
""" """
expected_data_type = ('RT130', '_') expected_data_type = ('RT130', '_')
self.mock_get_data_type_from_file.return_value = expected_data_type self.mock_get_data_type_from_file.return_value = expected_data_type
self.assertEqual( self.assertEqual(
detectDataType(self.widget_stub, [self.dir1.name]), detect_data_type(self.widget_stub, [self.dir1.name]),
expected_data_type[0] expected_data_type[0]
) )
def test_same_data_type_and_channel(self): def test_same_data_type_and_channel(self):
""" """
Test basic functionality of detectDataType - the given directories Test basic functionality of detect_data_type - the given directories
contain the same data type and the data type was detected using the contain the same data type and the data type was detected using the
same channel. same channel.
""" """
...@@ -262,13 +264,15 @@ class TestDetectDataType(TestCase): ...@@ -262,13 +264,15 @@ class TestDetectDataType(TestCase):
self.mock_get_data_type_from_file.return_value = expected_data_type self.mock_get_data_type_from_file.return_value = expected_data_type
self.assertEqual( self.assertEqual(
detectDataType(self.widget_stub, [self.dir1.name, self.dir2.name]), detect_data_type(
self.widget_stub, [self.dir1.name, self.dir2.name]
),
expected_data_type[0] expected_data_type[0]
) )
def test_same_data_type_different_channel(self): def test_same_data_type_different_channel(self):
""" """
Test basic functionality of detectDataType - the given directories Test basic functionality of detect_data_type - the given directories
contain the same data type but the data type was detected using contain the same data type but the data type was detected using
different channels. different channels.
""" """
...@@ -276,61 +280,65 @@ class TestDetectDataType(TestCase): ...@@ -276,61 +280,65 @@ class TestDetectDataType(TestCase):
self.mock_get_data_type_from_file.side_effect = returned_data_types self.mock_get_data_type_from_file.side_effect = returned_data_types
self.assertEqual( self.assertEqual(
detectDataType(self.widget_stub, [self.dir1.name, self.dir2.name]), detect_data_type(
self.widget_stub, [self.dir1.name, self.dir2.name]
),
returned_data_types[0][0] returned_data_types[0][0]
) )
def test_different_data_types(self): def test_different_data_types(self):
""" """
Test basic functionality of detectDataType - the given directories Test basic functionality of detect_data_type - the given directories
contain different data types. contain different data types.
""" """
returned_data_types = [('RT130', '_'), ('Q330', 'VEP')] returned_data_types = [('RT130', '_'), ('Q330', 'VEP')]
self.mock_get_data_type_from_file.side_effect = returned_data_types self.mock_get_data_type_from_file.side_effect = returned_data_types
self.assertIsNone( self.assertIsNone(
detectDataType(self.widget_stub, [self.dir1.name, self.dir2.name]) detect_data_type(
self.widget_stub, [self.dir1.name, self.dir2.name]
)
) )
def test_unknown_data_type(self): def test_unknown_data_type(self):
""" """
Test basic functionality of detectDataType - can't detect any data Test basic functionality of detect_data_type - can't detect any data
type. type.
""" """
unknown_data_type = ('Unknown', '_') unknown_data_type = ('Unknown', '_')
self.mock_get_data_type_from_file.return_value = unknown_data_type self.mock_get_data_type_from_file.return_value = unknown_data_type
self.assertIsNone(detectDataType(self.widget_stub, [self.dir1.name])) self.assertIsNone(detect_data_type(self.widget_stub, [self.dir1.name]))
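A sketch of the detection contract these tests pin down: a single data-type string when every directory agrees, None for mixed or unknown types. The directory path and the None tracking widget are assumptions made for brevity.

from sohstationviewer.controller.processing import detect_data_type

data_type = detect_data_type(None, ['/path/to/data/set'])  # placeholders
if data_type is None:
    print('Could not settle on a single data type')
else:
    print('Detected data type:', data_type)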
class TestGetDataTypeFromFile(TestCase): class TestGetDataTypeFromFile(TestCase):
"""Test suite for getDataTypeFromFile""" """Test suite for get_data_type_from_file"""
def test_rt130_data(self): def test_rt130_data(self):
""" """
Test basic functionality of getDataTypeFromFile - given file contains Test basic functionality of get_data_type_from_file - given file
RT130 data. contains RT130 data.
""" """
rt130_file = Path(rt130_dir).joinpath( rt130_file = Path(rt130_dir).joinpath(
'92EB/0/000000000_00000000') '92EB/0/000000000_00000000')
expected_data_type = ('RT130', '_') expected_data_type = ('RT130', '_')
self.assertTupleEqual( self.assertTupleEqual(
getDataTypeFromFile(rt130_file, get_signature_channels()), get_data_type_from_file(rt130_file, get_signature_channels()),
expected_data_type expected_data_type
) )
def test_cannot_detect_data_type(self): def test_cannot_detect_data_type(self):
""" """
Test basic functionality of getDataTypeFromFile - cannot detect data Test basic functionality of get_data_type_from_file - cannot detect
type contained in given file. data type contained in given file.
""" """
test_file = NamedTemporaryFile() test_file = NamedTemporaryFile()
self.assertIsNone( self.assertIsNone(
getDataTypeFromFile(test_file.name, get_signature_channels())) get_data_type_from_file(test_file.name, get_signature_channels()))
def test_mseed_data(self): def test_mseed_data(self):
""" """
Test basic functionality of getDataTypeFromFile - given file contains Test basic functionality of get_data_type_from_file - given file
MSeed data. contains MSeed data.
""" """
q330_file = q330_dir.joinpath('AX08.XA..VKI.2021.186') q330_file = q330_dir.joinpath('AX08.XA..VKI.2021.186')
centaur_file = centaur_dir.joinpath( centaur_file = centaur_dir.joinpath(
...@@ -343,21 +351,22 @@ class TestGetDataTypeFromFile(TestCase): ...@@ -343,21 +351,22 @@ class TestGetDataTypeFromFile(TestCase):
sig_chan = get_signature_channels() sig_chan = get_signature_channels()
self.assertTupleEqual(getDataTypeFromFile(q330_file, sig_chan), self.assertTupleEqual(get_data_type_from_file(q330_file, sig_chan),
q330_data_type) q330_data_type)
self.assertTupleEqual(getDataTypeFromFile(centaur_file, sig_chan), self.assertTupleEqual(get_data_type_from_file(centaur_file, sig_chan),
centaur_data_type) centaur_data_type)
self.assertTupleEqual(getDataTypeFromFile(pegasus_file, sig_chan), self.assertTupleEqual(get_data_type_from_file(pegasus_file, sig_chan),
pegasus_data_type) pegasus_data_type)
def test_file_does_not_exist(self): def test_file_does_not_exist(self):
""" """
Test basic functionality of getDataTypeFromFile - given file does not Test basic functionality of get_data_type_from_file - given file does
exist. not exist.
""" """
empty_name_file = '' empty_name_file = ''
non_existent_file = 'non_existent_dir' non_existent_file = 'non_existent_dir'
with self.assertRaises(FileNotFoundError): with self.assertRaises(FileNotFoundError):
getDataTypeFromFile(empty_name_file, get_signature_channels()) get_data_type_from_file(empty_name_file, get_signature_channels())
with self.assertRaises(FileNotFoundError): with self.assertRaises(FileNotFoundError):
getDataTypeFromFile(non_existent_file, get_signature_channels()) get_data_type_from_file(non_existent_file,
get_signature_channels())
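The per-file detector is the easiest piece to show in isolation; a short sketch following the cases above, with a placeholder path:

from sohstationviewer.controller.processing import get_data_type_from_file
from sohstationviewer.database.extract_data import get_signature_channels

data_file = '/path/to/one/data/file'    # placeholder
try:
    # Returns a (data_type, signature_channel) tuple such as
    # ('RT130', '_'), or None when the file cannot be classified.
    detected = get_data_type_from_file(data_file, get_signature_channels())
except FileNotFoundError:
    detected = None
if detected is not None:
    data_type, signature_channel = detected
    print(data_type, 'detected via channel', signature_channel)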
...@@ -9,14 +9,14 @@ import string ...@@ -9,14 +9,14 @@ import string
from sohstationviewer.controller.util import ( from sohstationviewer.controller.util import (
validateFile, validate_file,
getDirSize, get_dir_size,
getTime6, get_time_6,
getTime6_2y, get_time_6_2y,
getTime6_4y, get_time_6_4y,
getTime4, get_time_4,
getVal, get_val,
rtnPattern, rtn_pattern,
add_thousand_separator add_thousand_separator
) )
...@@ -28,7 +28,9 @@ TEST_DATA_DIR = os.path.realpath(os.path.join( ...@@ -28,7 +28,9 @@ TEST_DATA_DIR = os.path.realpath(os.path.join(
class TestGetTime(TestCase): class TestGetTime(TestCase):
"""Test suite for getTime6, getTime6_2y, getTime6_4y, and getTime4.""" """
Test suite for get_time_6, get_time_6_2y, get_time_6_4y, and get_time_4.
"""
def setUp(self): def setUp(self):
"""Set up text fixtures.""" """Set up text fixtures."""
self.time6_2y = '01:251:09:41:35:656' self.time6_2y = '01:251:09:41:35:656'
...@@ -36,55 +38,55 @@ class TestGetTime(TestCase): ...@@ -36,55 +38,55 @@ class TestGetTime(TestCase):
self.time4_day_1 = '1:09:41:35' self.time4_day_1 = '1:09:41:35'
self.time4 = '251:09:41:35' self.time4 = '251:09:41:35'
@patch('sohstationviewer.controller.util.getTime6_4y') @patch('sohstationviewer.controller.util.get_time_6_4y')
@patch('sohstationviewer.controller.util.getTime6_2y') @patch('sohstationviewer.controller.util.get_time_6_2y')
def test_get_time6(self, mock_2y, mock_4y): def test_get_time6(self, mock_2y, mock_4y):
""" """
Test getTime6 - check that getTime6 delegates work to the appropriate Test get_time_6 - check that get_time_6 delegates work to the
helper function depending on the input. appropriate helper function depending on the input.
""" """
with self.subTest('test_2_digit_year'): with self.subTest('test_2_digit_year'):
getTime6(self.time6_2y) get_time_6(self.time6_2y)
self.assertTrue(mock_2y.called) self.assertTrue(mock_2y.called)
self.assertFalse(mock_4y.called) self.assertFalse(mock_4y.called)
mock_2y.reset_mock() mock_2y.reset_mock()
mock_4y.reset_mock() mock_4y.reset_mock()
with self.subTest('test_4_digit_year'): with self.subTest('test_4_digit_year'):
getTime6(self.time6_4y) get_time_6(self.time6_4y)
self.assertTrue(mock_4y.called) self.assertTrue(mock_4y.called)
self.assertFalse(mock_2y.called) self.assertFalse(mock_2y.called)
def test_get_time6_invalid_input(self): def test_get_time6_invalid_input(self):
"""Test getTime6 - the input is not one of the expected formats.""" """Test get_time_6 - the input is not one of the expected formats."""
with self.subTest('test_input_contains_colon'): with self.subTest('test_input_contains_colon'):
bad_inputs = [':523:531:', 'fs:523:531:', 'towe:523:531:'] bad_inputs = [':523:531:', 'fs:523:531:', 'towe:523:531:']
for input_str in bad_inputs: for input_str in bad_inputs:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
getTime6(input_str) get_time_6(input_str)
with self.subTest('test_input_does_not_contain_colon'): with self.subTest('test_input_does_not_contain_colon'):
input_str = 'fq31dqrt63' input_str = 'fq31dqrt63'
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
getTime6(input_str) get_time_6(input_str)
def test_get_time6_2y(self): def test_get_time6_2y(self):
"""Test getTime6_2y.""" """Test get_time_6_2y."""
epoch_time, year = getTime6_2y(self.time6_2y) epoch_time, year = get_time_6_2y(self.time6_2y)
self.assertAlmostEqual(epoch_time, 999942095.656) self.assertAlmostEqual(epoch_time, 999942095.656)
self.assertEqual(year, 2001) self.assertEqual(year, 2001)
def test_get_time6_4y(self): def test_get_time6_4y(self):
"""Test getTime6_4y.""" """Test get_time_6_4y."""
epoch_time, year = getTime6_4y(self.time6_4y) epoch_time, year = get_time_6_4y(self.time6_4y)
self.assertAlmostEqual(epoch_time, 999942095.656) self.assertAlmostEqual(epoch_time, 999942095.656)
self.assertEqual(year, 2001) self.assertEqual(year, 2001)
def test_get_time4_year_added(self): def test_get_time4_year_added(self):
"""Test getTime4 - a year has been added.""" """Test get_time_4 - a year has been added."""
year = 2001 year = 2001
year_added = True year_added = True
with self.subTest('test_first_day_of_year'): with self.subTest('test_first_day_of_year'):
epoch_time, ret_year, ret_year_added = ( epoch_time, ret_year, ret_year_added = (
getTime4(self.time4_day_1, year, year_added) get_time_4(self.time4_day_1, year, year_added)
) )
self.assertEqual(epoch_time, 978342095) self.assertEqual(epoch_time, 978342095)
self.assertEqual(ret_year, year) self.assertEqual(ret_year, year)
...@@ -92,19 +94,19 @@ class TestGetTime(TestCase): ...@@ -92,19 +94,19 @@ class TestGetTime(TestCase):
with self.subTest('test_other_days'): with self.subTest('test_other_days'):
epoch_time, ret_year, ret_year_added = ( epoch_time, ret_year, ret_year_added = (
getTime4(self.time4, year, year_added) get_time_4(self.time4, year, year_added)
) )
self.assertEqual(epoch_time, 999942095) self.assertEqual(epoch_time, 999942095)
self.assertEqual(ret_year, year) self.assertEqual(ret_year, year)
self.assertTrue(ret_year_added) self.assertTrue(ret_year_added)
def test_get_time4_year_not_added(self): def test_get_time4_year_not_added(self):
"""Test getTime4 - a year has not been added.""" """Test get_time_4 - a year has not been added."""
year = 2001 year = 2001
year_added = False year_added = False
with self.subTest('test_first_day_of_year'): with self.subTest('test_first_day_of_year'):
epoch_time, ret_year, ret_year_added = ( epoch_time, ret_year, ret_year_added = (
getTime4(self.time4_day_1, year, year_added) get_time_4(self.time4_day_1, year, year_added)
) )
self.assertEqual(epoch_time, 1009878095) self.assertEqual(epoch_time, 1009878095)
self.assertEqual(ret_year, year + 1) self.assertEqual(ret_year, year + 1)
...@@ -112,7 +114,7 @@ class TestGetTime(TestCase): ...@@ -112,7 +114,7 @@ class TestGetTime(TestCase):
with self.subTest('test_other_days'): with self.subTest('test_other_days'):
epoch_time, ret_year, ret_year_added = ( epoch_time, ret_year, ret_year_added = (
getTime4(self.time4, year, year_added) get_time_4(self.time4, year, year_added)
) )
self.assertEqual(epoch_time, 999942095) self.assertEqual(epoch_time, 999942095)
self.assertEqual(ret_year, year) self.assertEqual(ret_year, year)
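For reference, a short sketch of the renamed time helpers using the fixture values from these tests. get_time_6 dispatches to the two- or four-digit-year parser on its own and raises ValueError for anything it cannot parse.

from sohstationviewer.controller.util import get_time_6_2y, get_time_4

# '01:251:09:41:35:656' is year 2001, day 251, 09:41:35.656 UTC.
epoch, year = get_time_6_2y('01:251:09:41:35:656')
print(epoch, year)                  # 999942095.656 2001

# get_time_4 carries the year along and reports whether a year has been
# added while walking through a data set.
epoch, year, year_added = get_time_4('251:09:41:35', year, True)
print(epoch)                        # 999942095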
...@@ -120,21 +122,21 @@ class TestGetTime(TestCase): ...@@ -120,21 +122,21 @@ class TestGetTime(TestCase):
class TestValidateFile(TestCase): class TestValidateFile(TestCase):
"""Test suite for validateFile.""" """Test suite for validate_file."""
def test_valid_file(self): def test_valid_file(self):
""" """
Test basic functionality of validateFile - given file exists and is not Test basic functionality of validate_file - given file exists and is
an info file. not an info file.
""" """
with NamedTemporaryFile() as valid_file: with NamedTemporaryFile() as valid_file:
self.assertTrue( self.assertTrue(
validateFile(valid_file.name, validate_file(valid_file.name,
os.path.basename(valid_file.name)) os.path.basename(valid_file.name))
) )
def test_info_file(self): def test_info_file(self):
""" """
Test basic functionality of validateFile - given file exists and is an Test basic functionality of validate_file - given file exists and is an
info file. info file.
""" """
with self.subTest('test_dot_DS_Store'): with self.subTest('test_dot_DS_Store'):
...@@ -158,38 +160,40 @@ class TestValidateFile(TestCase): ...@@ -158,38 +160,40 @@ class TestValidateFile(TestCase):
ds_store_path = os.path.join(temp_dir, '.DS_Store') ds_store_path = os.path.join(temp_dir, '.DS_Store')
with open(ds_store_path, 'w+') as ds_store_file: with open(ds_store_path, 'w+') as ds_store_file:
self.assertFalse( self.assertFalse(
validateFile(ds_store_path, validate_file(ds_store_path,
os.path.basename(ds_store_file.name)) os.path.basename(ds_store_file.name))
) )
with self.subTest('test_dot_underscore'): with self.subTest('test_dot_underscore'):
with NamedTemporaryFile(prefix='._') as info_file: with NamedTemporaryFile(prefix='._') as info_file:
self.assertFalse( self.assertFalse(
validateFile(info_file.name, validate_file(info_file.name,
os.path.basename(info_file.name)) os.path.basename(info_file.name))
) )
def test_file_does_not_exist(self): def test_file_does_not_exist(self):
""" """
Test basic functionality of validateFile - given file does not exist. Test basic functionality of validate_file - given file does not exist.
""" """
empty_name_file = '' empty_name_file = ''
self.assertFalse(validateFile(empty_name_file, empty_name_file)) self.assertFalse(validate_file(empty_name_file, empty_name_file))
not_exist_file = 'file_does_not_exist' not_exist_file = 'file_does_not_exist'
self.assertFalse(validateFile(not_exist_file, not_exist_file)) self.assertFalse(validate_file(not_exist_file, not_exist_file))
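A sketch of how calling code might use validate_file, per the behaviour covered above; the path is a placeholder.

import os
from sohstationviewer.controller.util import validate_file

path = '/path/to/a/data/file'    # placeholder
# True for an existing, ordinary file; False for missing files and for
# info files such as .DS_Store or '._'-prefixed companions.
if validate_file(path, os.path.basename(path)):
    print('read this file')
else:
    print('skip this file')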
class TestGetDirSize(TestCase): class TestGetDirSize(TestCase):
"""Test suite for getDirSize.""" """Test suite for get_dir_size."""
def test_files_have_size_zero(self): def test_files_have_size_zero(self):
"""Test getDirSize - all files in the given directory has size zero.""" """
Test get_dir_size - all files in the given directory has size zero.
"""
expected_file_count = 10 expected_file_count = 10
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
files = [] files = []
for i in range(expected_file_count): for i in range(expected_file_count):
files.append(NamedTemporaryFile(dir=directory)) files.append(NamedTemporaryFile(dir=directory))
dir_size, dir_file_count = getDirSize(directory) dir_size, dir_file_count = get_dir_size(directory)
self.assertEqual(dir_size, 0) self.assertEqual(dir_size, 0)
self.assertEqual(dir_file_count, expected_file_count) self.assertEqual(dir_file_count, expected_file_count)
# Explicitly clean up the temporary files. If we don't do this, # Explicitly clean up the temporary files. If we don't do this,
...@@ -202,8 +206,8 @@ class TestGetDirSize(TestCase): ...@@ -202,8 +206,8 @@ class TestGetDirSize(TestCase):
def test_files_have_size_greater_than_zero(self): def test_files_have_size_greater_than_zero(self):
""" """
Test getDirSize - all files in the given directory have a size greater Test get_dir_size - all files in the given directory have a size
than zero. greater than zero.
""" """
expected_file_count = 10 expected_file_count = 10
size_one_file = 10 size_one_file = 10
...@@ -221,7 +225,7 @@ class TestGetDirSize(TestCase): ...@@ -221,7 +225,7 @@ class TestGetDirSize(TestCase):
# size of the file on disk will stay 0. # size of the file on disk will stay 0.
temp_file.flush() temp_file.flush()
files.append(temp_file) files.append(temp_file)
dir_size, dir_file_count = getDirSize(temp_dir.name) dir_size, dir_file_count = get_dir_size(temp_dir.name)
self.assertEqual(dir_size, size_one_file * expected_file_count) self.assertEqual(dir_size, size_one_file * expected_file_count)
self.assertEqual(dir_file_count, expected_file_count) self.assertEqual(dir_file_count, expected_file_count)
# Explicitly clean up the temporary files. If we don't do this, # Explicitly clean up the temporary files. If we don't do this,
...@@ -234,60 +238,60 @@ class TestGetDirSize(TestCase): ...@@ -234,60 +238,60 @@ class TestGetDirSize(TestCase):
def test_nested_folder_structure(self): def test_nested_folder_structure(self):
""" """
Test getDirSize - the given directory contains nested directories. Test get_dir_size - the given directory contains nested directories.
""" """
test_folder = os.path.join(TEST_DATA_DIR, 'Pegasus-sample') test_folder = os.path.join(TEST_DATA_DIR, 'Pegasus-sample')
dir_size, dir_file_count = getDirSize(test_folder) dir_size, dir_file_count = get_dir_size(test_folder)
self.assertEqual(dir_size, 7974651) self.assertEqual(dir_size, 7974651)
self.assertEqual(dir_file_count, 10) self.assertEqual(dir_file_count, 10)
def test_empty_directory(self): def test_empty_directory(self):
"""Test getDirSize - the given directory contains no file.""" """Test get_dir_size - the given directory contains no file."""
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
dir_size, dir_file_count = getDirSize(temp_dir) dir_size, dir_file_count = get_dir_size(temp_dir)
self.assertEqual(dir_size, 0) self.assertEqual(dir_size, 0)
self.assertEqual(dir_file_count, 0) self.assertEqual(dir_file_count, 0)
def test_directory_does_not_exist(self): def test_directory_does_not_exist(self):
"""Test getDirSize - the given directory does not exist.""" """Test get_dir_size - the given directory does not exist."""
empty_name_dir = '' empty_name_dir = ''
dir_size, dir_file_count = getDirSize(empty_name_dir) dir_size, dir_file_count = get_dir_size(empty_name_dir)
self.assertEqual(dir_size, 0) self.assertEqual(dir_size, 0)
self.assertEqual(dir_file_count, 0) self.assertEqual(dir_file_count, 0)
non_existent_dir = 'directory does not exist' non_existent_dir = 'directory does not exist'
dir_size, dir_file_count = getDirSize(non_existent_dir) dir_size, dir_file_count = get_dir_size(non_existent_dir)
self.assertEqual(dir_size, 0) self.assertEqual(dir_size, 0)
self.assertEqual(dir_file_count, 0) self.assertEqual(dir_file_count, 0)
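get_dir_size is summarized most clearly by a direct call; a sketch based on the contract these tests exercise, with a placeholder path:

from sohstationviewer.controller.util import get_dir_size

# Walks the directory (nested folders included) and returns the total
# size in bytes together with the file count; a missing or empty
# directory yields (0, 0) rather than raising.
total_bytes, file_count = get_dir_size('/path/to/data/set')
print(file_count, 'files,', total_bytes, 'bytes')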
class TestRtnPattern(TestCase): class TestRtnPattern(TestCase):
"""Test suite for rtnPattern.""" """Test suite for rtn_pattern."""
def test_no_upper(self): def test_no_upper(self):
"""Test rtnPattern - characters are not converted to uppercase.""" """Test rtn_pattern - characters are not converted to uppercase."""
with self.subTest('test_digit'): with self.subTest('test_digit'):
digits = '123456789' digits = '123456789'
self.assertEqual(rtnPattern(digits), '0' * len(digits)) self.assertEqual(rtn_pattern(digits), '0' * len(digits))
with self.subTest('test_lowercase'): with self.subTest('test_lowercase'):
lowercase_chars = string.ascii_lowercase lowercase_chars = string.ascii_lowercase
self.assertEqual(rtnPattern(lowercase_chars), self.assertEqual(rtn_pattern(lowercase_chars),
'a' * len(lowercase_chars)) 'a' * len(lowercase_chars))
with self.subTest('test_uppercase'): with self.subTest('test_uppercase'):
uppercase_chars = string.ascii_uppercase uppercase_chars = string.ascii_uppercase
self.assertEqual(rtnPattern(uppercase_chars), self.assertEqual(rtn_pattern(uppercase_chars),
'A' * len(uppercase_chars)) 'A' * len(uppercase_chars))
def test_with_upper(self): def test_with_upper(self):
"""Test rtnPattern - all characters are converted to uppercase.""" """Test rtn_pattern - all characters are converted to uppercase."""
lowercase_chars = string.ascii_lowercase lowercase_chars = string.ascii_lowercase
self.assertEqual(rtnPattern(lowercase_chars, upper=True), self.assertEqual(rtn_pattern(lowercase_chars, upper=True),
'A' * len(lowercase_chars)) 'A' * len(lowercase_chars))
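The mapping these assertions describe, shown directly: digits become '0', lower-case letters 'a', upper-case letters 'A', and upper=True upper-cases the whole pattern.

from sohstationviewer.controller.util import rtn_pattern

print(rtn_pattern('20210705'))           # '00000000'
print(rtn_pattern('abc'))                # 'aaa'
print(rtn_pattern('abc', upper=True))    # 'AAA'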
class TestGetVal(TestCase): class TestGetVal(TestCase):
"""Test suite for getVal.""" """Test suite for get_val."""
def test_normal_case(self): def test_normal_case(self):
"""Test getVal - the input is of an expected value.""" """Test get_val - the input is of an expected value."""
# formatter:off # formatter:off
test_name_to_test_map = { test_name_to_test_map = {
'test_with_decimal_point': ('60.3V', 60.3), 'test_with_decimal_point': ('60.3V', 60.3),
...@@ -300,20 +304,20 @@ class TestGetVal(TestCase): ...@@ -300,20 +304,20 @@ class TestGetVal(TestCase):
# formatter:on # formatter:on
for test_name, inout_pair in test_name_to_test_map.items(): for test_name, inout_pair in test_name_to_test_map.items():
with self.subTest(test_name): with self.subTest(test_name):
self.assertEqual(getVal(inout_pair[0]), inout_pair[1]) self.assertEqual(get_val(inout_pair[0]), inout_pair[1])
def test_positive_negative_sign_in_front(self): def test_positive_negative_sign_in_front(self):
""" """
Test getVal - the input has both a positive sign and a negative Test get_val - the input has both a positive sign and a negative
sign in the front. sign in the front.
""" """
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
getVal('+-1.0V') get_val('+-1.0V')
def test_bad_input(self): def test_bad_input(self):
"""Test rtnPattern - the input has a value that is not expected.""" """Test rtn_pattern - the input has a value that is not expected."""
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
getVal('') get_val('')
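And get_val in isolation, following the cases listed above: it strips the unit suffix and returns the numeric value, raising for doubly signed or empty input.

from sohstationviewer.controller.util import get_val

print(get_val('60.3V'))       # 60.3
try:
    get_val('+-1.0V')         # both signs in front -> ValueError
except ValueError:
    print('rejected: conflicting signs in front')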
class TestFmti(TestCase): class TestFmti(TestCase):
......
...@@ -79,7 +79,7 @@ class TestExtractData(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestExtractData(unittest.TestCase):
'valueColors': None} 'valueColors': None}
# Data type has None value. None value comes from # Data type has None value. None value comes from
# controller.processing.detectDataType. # controller.processing.detect_data_type.
expected_result['label'] = 'DEFAULT-SOH/Data Def' expected_result['label'] = 'DEFAULT-SOH/Data Def'
self.assertDictEqual(get_chan_plot_info('SOH/Data Def', None), self.assertDictEqual(get_chan_plot_info('SOH/Data Def', None),
expected_result) expected_result)
......
...@@ -9,13 +9,12 @@ import numpy as np ...@@ -9,13 +9,12 @@ import numpy as np
from sohstationviewer.conf import constants as const from sohstationviewer.conf import constants as const
from sohstationviewer.model.handling_data import ( from sohstationviewer.model.handling_data import (
downsample, trim_downsample_soh_chan,
chunk_minmax, trim_downsample_wf_chan,
trim_downsample_SOHChan,
trim_downsample_WFChan,
trim_waveform_data, trim_waveform_data,
downsample_waveform_data downsample_waveform_data
) )
from sohstationviewer.model.downsampler import downsample, chunk_minmax
ORIGINAL_CHAN_SIZE_LIMIT = const.CHAN_SIZE_LIMIT ORIGINAL_CHAN_SIZE_LIMIT = const.CHAN_SIZE_LIMIT
ORIGINAL_RECAL_SIZE_LIMIT = const.RECAL_SIZE_LIMIT ORIGINAL_RECAL_SIZE_LIMIT = const.RECAL_SIZE_LIMIT
...@@ -63,15 +62,15 @@ class TestTrimWfData(TestCase): ...@@ -63,15 +62,15 @@ class TestTrimWfData(TestCase):
self.start_time = 12500 self.start_time = 12500
self.end_time = 17500 self.end_time = 17500
self.assertFalse( self.assertFalse(
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, True) self.end_time, True)
) )
with self.subTest('test_end_time_earlier_than_data_start_time'): with self.subTest('test_end_time_earlier_than_data_start_time'):
self.start_time = -7500 self.start_time = -7500
self.end_time = -2500 self.end_time = -2500
self.assertFalse( self.assertFalse(
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, True) self.end_time, True)
) )
def test_no_data(self): def test_no_data(self):
...@@ -222,7 +221,7 @@ class TestDownsampleWaveformData(TestCase): ...@@ -222,7 +221,7 @@ class TestDownsampleWaveformData(TestCase):
class TestDownsample(TestCase): class TestDownsample(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
patcher = patch('sohstationviewer.model.handling_data.chunk_minmax') patcher = patch('sohstationviewer.model.downsampler.chunk_minmax')
self.addCleanup(patcher.stop) self.addCleanup(patcher.stop)
self.mock_chunk_minmax = patcher.start() self.mock_chunk_minmax = patcher.start()
self.times = np.arange(1000) self.times = np.arange(1000)
...@@ -339,7 +338,7 @@ class TestChunkMinmax(TestCase): ...@@ -339,7 +338,7 @@ class TestChunkMinmax(TestCase):
self.assertEqual(data.size, req_points) self.assertEqual(data.size, req_points)
self.assertEqual(log_idx.size, req_points) self.assertEqual(log_idx.size, req_points)
@patch('sohstationviewer.model.handling_data.downsample', wraps=downsample) @patch('sohstationviewer.model.downsampler.downsample', wraps=downsample)
def test_data_size_is_not_multiple_of_requested_points( def test_data_size_is_not_multiple_of_requested_points(
self, mock_downsample): self, mock_downsample):
req_points = 102 req_points = 102
...@@ -386,8 +385,8 @@ class TestTrimDownsampleSohChan(TestCase): ...@@ -386,8 +385,8 @@ class TestTrimDownsampleSohChan(TestCase):
def test_start_time_later_than_times_data(self): def test_start_time_later_than_times_data(self):
self.start_time = 250 self.start_time = 250
self.end_time = 1250 self.end_time = 1250
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertGreaterEqual(self.channel_info['times'].min(), self.assertGreaterEqual(self.channel_info['times'].min(),
self.start_time) self.start_time)
self.assertEqual( self.assertEqual(
...@@ -398,8 +397,8 @@ class TestTrimDownsampleSohChan(TestCase): ...@@ -398,8 +397,8 @@ class TestTrimDownsampleSohChan(TestCase):
def test_end_time_earlier_than_times_data(self): def test_end_time_earlier_than_times_data(self):
self.start_time = -250 self.start_time = -250
self.end_time = 750 self.end_time = 750
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertLessEqual(self.channel_info['times'].max(), self.assertLessEqual(self.channel_info['times'].max(),
self.end_time) self.end_time)
self.assertEqual( self.assertEqual(
...@@ -410,8 +409,8 @@ class TestTrimDownsampleSohChan(TestCase): ...@@ -410,8 +409,8 @@ class TestTrimDownsampleSohChan(TestCase):
def test_start_time_earlier_than_times_data(self): def test_start_time_earlier_than_times_data(self):
self.start_time = -250 self.start_time = -250
self.end_time = 750 self.end_time = 750
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertEqual(self.channel_info['times'].min(), self.assertEqual(self.channel_info['times'].min(),
self.channel_info['orgTrace']['times'].min()) self.channel_info['orgTrace']['times'].min())
self.assertEqual( self.assertEqual(
...@@ -422,8 +421,8 @@ class TestTrimDownsampleSohChan(TestCase): ...@@ -422,8 +421,8 @@ class TestTrimDownsampleSohChan(TestCase):
def test_end_time_later_than_times_data(self): def test_end_time_later_than_times_data(self):
self.start_time = 250 self.start_time = 250
self.end_time = 1250 self.end_time = 1250
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertEqual(self.channel_info['times'].max(), self.assertEqual(self.channel_info['times'].max(),
self.channel_info['orgTrace']['times'].max()) self.channel_info['orgTrace']['times'].max())
self.assertEqual( self.assertEqual(
...@@ -434,35 +433,35 @@ class TestTrimDownsampleSohChan(TestCase): ...@@ -434,35 +433,35 @@ class TestTrimDownsampleSohChan(TestCase):
def test_times_data_contained_in_time_range(self): def test_times_data_contained_in_time_range(self):
self.start_time = -250 self.start_time = -250
self.end_time = 1250 self.end_time = 1250
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
np.testing.assert_array_equal(self.channel_info['times'], np.testing.assert_array_equal(self.channel_info['times'],
self.org_trace['times']) self.org_trace['times'])
def test_time_range_is_the_same_as_times_data(self): def test_time_range_is_the_same_as_times_data(self):
self.start_time = ZERO_EPOCH_TIME self.start_time = ZERO_EPOCH_TIME
self.end_time = 999 self.end_time = 999
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
np.testing.assert_array_equal(self.channel_info['times'], np.testing.assert_array_equal(self.channel_info['times'],
self.org_trace['times']) self.org_trace['times'])
def test_time_range_does_not_overlap_times_data(self): def test_time_range_does_not_overlap_times_data(self):
self.start_time = 2000 self.start_time = 2000
self.end_time = 3000 self.end_time = 3000
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertEqual(self.channel_info['times'].size, 0) self.assertEqual(self.channel_info['times'].size, 0)
self.assertEqual(self.channel_info['data'].size, 0) self.assertEqual(self.channel_info['data'].size, 0)
def test_data_is_downsampled(self): def test_data_is_downsampled(self):
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertTrue(self.mock_downsample.called) self.assertTrue(self.mock_downsample.called)
def test_processed_data_is_stored_in_appropriate_location(self): def test_processed_data_is_stored_in_appropriate_location(self):
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
expected_keys = ('orgTrace', 'times', 'data') expected_keys = ('orgTrace', 'times', 'data')
self.assertTupleEqual(tuple(self.channel_info.keys()), self.assertTupleEqual(tuple(self.channel_info.keys()),
expected_keys) expected_keys)
...@@ -470,8 +469,8 @@ class TestTrimDownsampleSohChan(TestCase): ...@@ -470,8 +469,8 @@ class TestTrimDownsampleSohChan(TestCase):
@patch('sohstationviewer.model.handling_data.downsample') @patch('sohstationviewer.model.handling_data.downsample')
def test_arguments_sent_to_downsample(self, mock_downsample): def test_arguments_sent_to_downsample(self, mock_downsample):
mock_downsample.return_value = (1, 2, 3) mock_downsample.return_value = (1, 2, 3)
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
positional_args, named_args = mock_downsample.call_args positional_args, named_args = mock_downsample.call_args
self.assertEqual(len(positional_args), 2) self.assertEqual(len(positional_args), 2)
...@@ -493,15 +492,15 @@ class TestTrimDownsampleSohChanWithLogidx(TestCase): ...@@ -493,15 +492,15 @@ class TestTrimDownsampleSohChanWithLogidx(TestCase):
def test_time_range_does_not_overlap_times_data(self): def test_time_range_does_not_overlap_times_data(self):
self.start_time = 2000 self.start_time = 2000
self.end_time = 3000 self.end_time = 3000
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
self.assertEqual(self.channel_info['times'].size, 0) self.assertEqual(self.channel_info['times'].size, 0)
self.assertEqual(self.channel_info['data'].size, 0) self.assertEqual(self.channel_info['data'].size, 0)
self.assertEqual(self.channel_info['logIdx'].size, 0) self.assertEqual(self.channel_info['logIdx'].size, 0)
def test_processed_data_is_stored_in_appropriate_location(self): def test_processed_data_is_stored_in_appropriate_location(self):
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
expected_keys = ('orgTrace', 'times', 'data', 'logIdx') expected_keys = ('orgTrace', 'times', 'data', 'logIdx')
self.assertTupleEqual(tuple(self.channel_info.keys()), self.assertTupleEqual(tuple(self.channel_info.keys()),
expected_keys) expected_keys)
...@@ -509,8 +508,8 @@ class TestTrimDownsampleSohChanWithLogidx(TestCase): ...@@ -509,8 +508,8 @@ class TestTrimDownsampleSohChanWithLogidx(TestCase):
@patch('sohstationviewer.model.handling_data.downsample') @patch('sohstationviewer.model.handling_data.downsample')
def test_arguments_sent_to_downsample(self, mock_downsample): def test_arguments_sent_to_downsample(self, mock_downsample):
mock_downsample.return_value = (1, 2, 3) mock_downsample.return_value = (1, 2, 3)
trim_downsample_SOHChan(self.channel_info, self.start_time, trim_downsample_soh_chan(self.channel_info, self.start_time,
self.end_time, self.first_time) self.end_time)
positional_args, named_args = mock_downsample.call_args positional_args, named_args = mock_downsample.call_args
self.assertEqual(len(positional_args), 3) self.assertEqual(len(positional_args), 3)
...@@ -558,32 +557,32 @@ class TestTrimDownsampleWfChan(TestCase): ...@@ -558,32 +557,32 @@ class TestTrimDownsampleWfChan(TestCase):
self.end_time = 7500 self.end_time = 7500
def test_result_is_stored(self): def test_result_is_stored(self):
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, True) self.end_time, True)
self.assertTrue('times' in self.channel_data) self.assertTrue('times' in self.channel_data)
self.assertGreater(len(self.channel_data['times']), 0) self.assertGreater(len(self.channel_data['times']), 0)
self.assertTrue('data' in self.channel_data) self.assertTrue('data' in self.channel_data)
self.assertGreater(len(self.channel_data['data']), 0) self.assertGreater(len(self.channel_data['data']), 0)
def test_data_small_enough_after_first_trim_flag_is_set(self): def test_data_small_enough_after_first_trim_flag_is_set(self):
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, True) self.end_time, True)
self.assertTrue('fulldata' in self.channel_data) self.assertTrue('fulldata' in self.channel_data)
def test_no_additional_work_if_data_small_enough_after_first_trim(self): def test_no_additional_work_if_data_small_enough_after_first_trim(self):
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, True) self.end_time, True)
current_times = self.channel_data['times'] current_times = self.channel_data['times']
current_data = self.channel_data['data'] current_data = self.channel_data['data']
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, True) self.end_time, True)
self.assertIs(current_times, self.channel_data['times']) self.assertIs(current_times, self.channel_data['times'])
self.assertIs(current_data, self.channel_data['data']) self.assertIs(current_data, self.channel_data['data'])
def test_data_too_large_after_trimming(self): def test_data_too_large_after_trimming(self):
const.RECAL_SIZE_LIMIT = 1 const.RECAL_SIZE_LIMIT = 1
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, False) self.end_time, False)
self.assertTrue('times' not in self.channel_data) self.assertTrue('times' not in self.channel_data)
self.assertTrue('data' not in self.channel_data) self.assertTrue('data' not in self.channel_data)
const.RECAL_SIZE_LIMIT = ORIGINAL_RECAL_SIZE_LIMIT const.RECAL_SIZE_LIMIT = ORIGINAL_RECAL_SIZE_LIMIT
...@@ -593,7 +592,7 @@ class TestTrimDownsampleWfChan(TestCase): ...@@ -593,7 +592,7 @@ class TestTrimDownsampleWfChan(TestCase):
@patch('sohstationviewer.model.handling_data.downsample_waveform_data', @patch('sohstationviewer.model.handling_data.downsample_waveform_data',
wraps=downsample_waveform_data) wraps=downsample_waveform_data)
def test_data_trim_and_downsampled(self, mock_downsample, mock_trim): def test_data_trim_and_downsampled(self, mock_downsample, mock_trim):
trim_downsample_WFChan(self.channel_data, self.start_time, trim_downsample_wf_chan(self.channel_data, self.start_time,
self.end_time, False) self.end_time, False)
self.assertTrue(mock_trim.called) self.assertTrue(mock_trim.called)
self.assertTrue(mock_downsample.called) self.assertTrue(mock_downsample.called)