# class with all plotting functions
import numpy as np

from sohstationviewer.controller.util import get_val
from sohstationviewer.controller.plotting_data import get_masspos_value_colors

from sohstationviewer.view.util.color import clr
from sohstationviewer.conf import constants


class Plotting:
    """
    Class that includes different methods to plot channels on a figure.
    """
    def __init__(self, parent, plotting_axes, main_window):
        """
        :param parent: PlottingWidget - widget to plot channels
        :param plotting_axes: PlottingAxes - widget that includes a figure
            and methods related to create axes
        :param main_window: QApplication - Main Window to access user's
            setting parameters
        """
        super().__init__()
        self.parent = parent
        self.main_window = main_window
        self.plotting_axes = plotting_axes

    @staticmethod
    def _collect_times(times_list, data_list, predicate):
        """
        Gather, over every (times, data) segment pair, the times whose data
        value satisfies predicate.

        :param times_list: list of sequences - time segments
        :param data_list: list of sequences - data segments parallel to
            times_list
        :param predicate: callable - takes one data value, returns bool
        :return: list - matching times of all segments, in segment order
        """
        points = []
        for times, data in zip(times_list, data_list):
            points.extend(t for t, d in zip(times, data) if predicate(d))
        return points

    @staticmethod
    def _concat_segments(segments):
        """
        Join a list of array segments into a single numpy array.

        :param segments: list of array-likes
        :return: np.ndarray - all segments concatenated in order; an empty
            array when segments is empty
        """
        if len(segments) == 0:
            return np.array([])
        return np.concatenate(segments)

    @staticmethod
    def _get_masspos_color_codes(y, value_colors):
        """
        Map each mass position value to a color code.

        The entries of value_colors are checked in order. A value takes the
        color of the first entry whose range (previous bound, bound]
        contains abs(value). The last entry catches every value that matched
        no earlier entry.

        :param y: sequence of numbers - mass position values
        :param value_colors: list of (bound, color_code) tuples
        :return: list of str - one color code per value in y
        """
        codes = []
        last_idx = len(value_colors) - 1
        for value in y:
            abs_val = abs(value)
            # Start below zero so that abs(value) == 0 falls into the first
            # range instead of dropping through to the catch-all entry.
            prev_v = -1
            for idx, (v, c) in enumerate(value_colors):
                if idx == last_idx or prev_v < abs_val <= v:
                    codes.append(c)
                    break
                prev_v = v
        return codes

    def plot_none(self):
        """
        Plot an empty axes of almost-zero height so that rulers can be shown.
        :return ax: matplotlib.axes.Axes - axes of the empty plot
        """
        plot_h = 0.00001
        bw_plots_distance = 0.0001
        self.parent.plotting_bot -= plot_h + bw_plots_distance
        ax = self.plotting_axes.create_axes(
            self.parent.plotting_bot, plot_h, has_min_max_lines=False)
        ax.x = None
        ax.plot([0], [0], linestyle="")
        ax.chan_db_info = None
        return ax

    def plot_multi_color_dots(self, c_data, chan_db_info, chan_id,
                              ax, linked_ax):
        """
        Plot scattered dots with colors defined by valueColors in database:
            Ex: *:W  means everything with white color
            Ex: -1:_|0:R|2.3:Y|+2.3:G  means
                value <= -1  => not plot
                value <= 0   => plot with R color
                value <= 2.3 => plot with Y color
                value > 2.3  => plot with G color
        Color codes are defined in colorSettings and limited in 'valColRE'
            in dbSettings.py
        Dots from every time segment in c_data are plotted.
        If valueColors is empty, '*:W' is used and written back into
            chan_db_info.

        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to DataTypeModel.__init__.
            soh_data[key][chan_id]
        :param chan_db_info: dict - info of channel from DB
        :param chan_id: str - name of channel
        :param ax: matplotlib.axes.Axes - axes to draw plot of channel
        :param linked_ax: matplotlib.axes.Axes/None - axes of another channel
            linked to this channel => both channels' will be plotted on the
            same axes
        :return ax: matplotlib.axes.Axes - axes of the channel
        """
        if linked_ax is not None:
            ax = linked_ax
        if ax is None:
            plot_h = self.plotting_axes.get_height(chan_db_info['height'])
            ax = self.plotting_axes.create_axes(
                self.parent.plotting_bot, plot_h,
                has_min_max_lines=False)

        if chan_db_info['valueColors'] in [None, 'None', '']:
            chan_db_info['valueColors'] = '*:W'
        times_list, data_list = c_data['times'], c_data['data']

        total_samples = 0
        colors = []
        prev_val = -constants.HIGHEST_INT
        for vc in chan_db_info['valueColors'].split('|'):
            v, c = vc.split(':')
            colors.append(c)
            val = get_val(v)
            if c == '_':
                # Values in this range are not plotted; the entry only moves
                # the lower bound of the next range.
                prev_val = val
                continue
            if v.startswith('+'):
                points = self._collect_times(
                    times_list, data_list, lambda d, lo=val: d > lo)
            elif v == '*':
                points = [t for times, _ in zip(times_list, data_list)
                          for t in times]
            else:
                points = self._collect_times(
                    times_list, data_list,
                    lambda d, lo=prev_val, hi=val: lo < d <= hi)
            total_samples += len(points)

            ax.plot(points, len(points) * [0], linestyle="",
                    marker='s', markersize=2,
                    zorder=constants.Z_ORDER['DOT'],
                    color=clr[c], picker=True, pickradius=3)
            prev_val = val

        # The sample count is shown in the channel's color when exactly one
        # color is defined, otherwise in white.
        if len(colors) != 1:
            sample_no_colors = [clr['W']]
        else:
            sample_no_colors = [clr[colors[0]]]

        self.plotting_axes.set_axes_info(
            ax, [total_samples], sample_no_colors=sample_no_colors,
            chan_db_info=chan_db_info, linked_ax=linked_ax)

        ax.x_list = c_data['times']
        ax.chan_db_info = chan_db_info
        return ax

    def plot_up_down_dots(self, c_data, chan_db_info, chan_id, ax, linked_ax):
        """
        Plot channel with 2 different values, one above, one under center line.
        Each value has corresponding color defined in valueColors in database.
        Ex: 1:Y|0:R  means
            value == 1 => plot above center line with Y color
            value == 0 => plot under center line with R color
        The lower of the first two values is drawn under the center line and
        the higher one above it, whatever their order in valueColors.
        Color codes are defined in colorSettings

        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to DataTypeModel.__init__.
            soh_data[key][chan_id]
        :param chan_db_info: dict - info of channel from DB
        :param chan_id: str - name of channel
        :param ax: matplotlib.axes.Axes - axes to draw plot of channel
        :param linked_ax: matplotlib.axes.Axes/None - axes of another channel
            linked to this channel => both channels' will be plotted on the
            same axes
        :return ax: matplotlib.axes.Axes - axes of the channel
        """
        if linked_ax is not None:
            ax = linked_ax
        if ax is None:
            plot_h = self.plotting_axes.get_height(chan_db_info['height'])
            ax = self.plotting_axes.create_axes(
                self.parent.plotting_bot, plot_h,
                has_min_max_lines=False)

        entries = []
        for vc in chan_db_info['valueColors'].split('|'):
            v, c = vc.split(':')
            entries.append((get_val(v), c))
        # Only the first two entries are used. Order them by value so the
        # plot matches the documented layout (lower value at the bottom).
        ordered = sorted(entries[:2], key=lambda entry: entry[0])
        down_val, down_color = ordered[0]
        up_val, up_color = ordered[1]

        times_list, data_list = c_data['times'], c_data['data']
        down_points = self._collect_times(
            times_list, data_list, lambda d, target=down_val: d == target)
        up_points = self._collect_times(
            times_list, data_list, lambda d, target=up_val: d == target)

        # down dots
        ax.plot(down_points, len(down_points) * [-0.5], linestyle="",
                marker='s', markersize=2, zorder=constants.Z_ORDER['DOT'],
                color=clr[down_color], picker=True, pickradius=3)
        # up dots
        ax.plot(up_points, len(up_points) * [0.5], linestyle="",
                marker='s', markersize=2, zorder=constants.Z_ORDER['DOT'],
                color=clr[up_color], picker=True, pickradius=3)

        ax.set_ylim(-2, 2)
        self.plotting_axes.set_axes_info(
            ax, [len(down_points), len(up_points)],
            sample_no_colors=[clr[down_color], clr[up_color]],
            sample_no_pos=[0.25, 0.75],
            chan_db_info=chan_db_info, linked_ax=linked_ax)

        # x_bottom, x_top are the times of data points to be displayed at
        # bottom or top of the plot
        ax.x_bottom = np.array(down_points)
        ax.x_top = np.array(up_points)

        ax.chan_db_info = chan_db_info
        return ax

    def plot_time_dots(self, c_data, chan_db_info, chan_id, ax, linked_ax):
        """
        Plot times only, as dots on the center line, in the single color
        given by valueColors (white if it is empty).

        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to DataTypeModel.__init__.
            soh_data[key][chan_id]
        :param chan_db_info: dict - info of channel from DB
        :param chan_id: str - name of channel
        :param ax: matplotlib.axes.Axes - axes to draw plot of channel
        :param linked_ax: matplotlib.axes.Axes/None - axes of another channel
            linked to this channel => both channels' will be plotted on the
            same axes
        :return ax: matplotlib.axes.Axes - axes of the channel
        """
        if linked_ax is not None:
            ax = linked_ax
        if ax is None:
            plot_h = self.plotting_axes.get_height(chan_db_info['height'])
            ax = self.plotting_axes.create_axes(
                self.parent.plotting_bot, plot_h)

        color = 'W'
        if chan_db_info['valueColors'] not in [None, 'None', '']:
            color = chan_db_info['valueColors'].strip()
        x_list = c_data['times']
        total_x = sum(len(x) for x in x_list)
        self.plotting_axes.set_axes_info(
            ax, [total_x], sample_no_colors=[clr[color]],
            chan_db_info=chan_db_info, linked_ax=linked_ax)

        for x in x_list:
            ax.plot(x, [0] * len(x), marker='s', markersize=1.5,
                    linestyle='', zorder=constants.Z_ORDER['LINE'],
                    color=clr[color], picker=True,
                    pickradius=3)
        ax.x_list = x_list
        ax.chan_db_info = chan_db_info
        return ax

    def plot_lines_dots(self, c_data, chan_db_info, chan_id,
                        ax, linked_ax, info=''):
        """
        Plot lines with dots at the data points. Colors of dot and lines are
        defined in valueColors in database.
        Ex: L:G|D:W  means
            Lines are plotted with color G
            Dots are plotted with color W
        If D is not defined, dots won't be displayed (unless the channel has
            at most one sample, so that it stays visible).
        If L is not defined, lines will be plotted with color G
        Color codes are defined in colorSettings
        For channel 'GPS Lk/Unlk', the times of samples equal to -1 / 1 over
            all segments are stored in ax.x_bottom / ax.x_top.

        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to DataTypeModel.__init__.
            soh_data[key][chan_id] or DataTypeModel.__init__.
            waveform_data[key]['read_data'][chan_id] for waveform data
        :param chan_db_info: dict - info of channel from DB
        :param chan_id: str - name of channel
        :param ax: matplotlib.axes.Axes - axes to draw plot of channel
        :param linked_ax: matplotlib.axes.Axes/None - axes of another channel
            linked to this channel => both channels' will be plotted on the
            same axes
        :param info: str - additional info to be displayed on sub-title under
            main-title
        :return ax: matplotlib.axes.Axes - axes of the channel
        """
        if linked_ax is not None:
            ax = linked_ax
        if ax is None:
            plot_h = self.plotting_axes.get_height(chan_db_info['height'])
            ax = self.plotting_axes.create_axes(
                self.parent.plotting_bot, plot_h)

        x_list, y_list = c_data['times'], c_data['data']

        colors = {}
        if chan_db_info['valueColors'] not in [None, 'None', '']:
            for color_part in chan_db_info['valueColors'].split('|'):
                obj, c = color_part.split(':')
                colors[obj] = c
        l_color = colors.get('L', 'G')
        has_dot = 'D' in colors
        d_color = colors.get('D', l_color)

        total_samples = sum(len(x) for x in x_list)
        if chan_id == 'GPS Lk/Unlk':
            # Use every segment, not just the first one.
            all_x = self._concat_segments(x_list)
            all_y = self._concat_segments(y_list)
            ax.x_bottom = all_x[np.where(all_y == -1)[0]]
            ax.x_top = all_x[np.where(all_y == 1)[0]]
            sample_no_list = [ax.x_bottom.size, ax.x_top.size]
            sample_no_colors = [clr[d_color], clr[d_color]]
        else:
            sample_no_list = [total_samples]
            sample_no_colors = [clr[d_color]]
        self.plotting_axes.set_axes_info(
            ax, sample_no_list, sample_no_colors=sample_no_colors,
            chan_db_info=chan_db_info,
            info=info, y_list=y_list, linked_ax=linked_ax)

        for x, y in zip(x_list, y_list):
            if not has_dot and total_samples > 1:
                # set marker to be able to click point for info
                # but marker's size is small to not show dot.
                ax.myPlot = ax.plot(x, y, marker='o', markersize=0.01,
                                    linestyle='-', linewidth=0.7,
                                    zorder=constants.Z_ORDER['LINE'],
                                    color=clr[l_color],
                                    picker=True, pickradius=2)
            else:
                ax.myPlot = ax.plot(x, y, marker='s', markersize=1.5,
                                    linestyle='-', linewidth=0.7,
                                    zorder=constants.Z_ORDER['LINE'],
                                    color=clr[l_color],
                                    mfc=clr[d_color],
                                    mec=clr[d_color],
                                    picker=True, pickradius=3)

        if chan_id != 'GPS Lk/Unlk':
            ax.x_list = x_list
            ax.y_list = y_list

        ax.chan_db_info = chan_db_info
        return ax

    def plot_lines_s_rate(self, c_data, chan_db_info, chan_id, ax, linked_ax):
        """
        Plot line only for waveform data channel (seismic data). Sample rate
        unit will be displayed: integer format ("%d") for rates >= 1,
        general format ("%g") for lower rates.

        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to DataTypeModel.__init__.
            waveform_data[key]['read_data'][chan_id]
        :param chan_db_info: dict - info of channel from DB
        :param chan_id: str - name of channel
        :param ax: matplotlib.axes.Axes - axes to draw plot of channel
        :param linked_ax: matplotlib.axes.Axes/None - axes of another channel
            linked to this channel => both channels' will be plotted on the
            same axes
        :return ax: matplotlib.axes.Axes - axes of the channel
        """
        if c_data['samplerate'] >= 1.0:
            info = "%dsps" % c_data['samplerate']
        else:
            info = "%gsps" % c_data['samplerate']
        return self.plot_lines_dots(c_data, chan_db_info, chan_id,
                                    ax, linked_ax, info=info)

    def plot_lines_mass_pos(self, c_data, chan_db_info, chan_id,
                            ax, linked_ax):
        """
        Plot multi-color dots with grey line for mass position channel.
        Use get_masspos_value_colors() to get value_colors map based on
            Menu - MP Coloring selected from Main Window.
        Nothing is plotted (and None is returned) when no value_colors map
            is available.

        :param c_data: dict - data of the channel which includes down-sampled
            data in keys 'times' and 'data'. Refer to DataTypeModel.__init__.
            mass_pos_data[key][chan_id]
        :param chan_db_info: dict - info of channel from DB
        :param chan_id: str - name of channel
        :param ax: matplotlib.axes.Axes - axes to draw plot of channel
        :param linked_ax: matplotlib.axes.Axes/None - axes of another channel
            linked to this channel (not used by this method)
        :return ax: matplotlib.axes.Axes/None - axes of the channel
        """
        value_colors = get_masspos_value_colors(
            self.main_window.mass_pos_volt_range_opt, chan_id,
            self.parent.c_mode, self.parent.processing_log,
            ret_type='tupleList')

        if not value_colors:
            return None

        if ax is None:
            plot_h = self.plotting_axes.get_height(chan_db_info['height'])
            ax = self.plotting_axes.create_axes(
                self.parent.plotting_bot, plot_h)

        x_list, y_list = c_data['times'], c_data['data']
        total_x = sum(len(x) for x in x_list)
        self.plotting_axes.set_axes_info(
            ax, [total_x], chan_db_info=chan_db_info, y_list=y_list)
        for x, y in zip(x_list, y_list):
            ax.myPlot = ax.plot(x, y,
                                linestyle='-', linewidth=0.7,
                                color=self.parent.display_color['sub_basic'],
                                zorder=constants.Z_ORDER['LINE'])[0]
            colors = [clr[c]
                      for c in self._get_masspos_color_codes(y, value_colors)]
            sizes = [1.5] * len(y)
            ax.scatter(x, y, marker='s', c=colors, s=sizes,
                       picker=True, pickradius=15,
                       zorder=constants.Z_ORDER['DOT'],)
        ax.x_list = x_list
        ax.y_list = y_list
        ax.chan_db_info = chan_db_info
        return ax