# Source code for pandas_alive._base_chart

""" Implementation of the BaseChart class that all chart types should inherit from.

Methods and attributes defined here are shared by all other chart types; subclasses can extend them by calling `super().method()`.

"""

import datetime
import typing

import attr
import matplotlib
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.units as munits
import numpy as np
import pandas as pd
from matplotlib import ticker
from matplotlib.animation import FuncAnimation, PillowWriter
from matplotlib.colors import Colormap, to_rgba

# Register matplotlib's ConciseDateConverter for every datetime-like type so all plots
# get concise date tick labels, see
# https://matplotlib.org/3.1.0/gallery/ticks_and_spines/date_concise_formatter.html
converter = mdates.ConciseDateConverter()
munits.registry.update(
    {
        np.datetime64: converter,
        datetime.date: converter,
        datetime.datetime: converter,
    }
)


# 24-colour qualitative palette returned by `get_colors` when ``cmap="dark24"``
# (the values appear to match Plotly's "Dark24" palette).
DARK24 = (
    "#2E91E5 #E15F99 #1CA71C #FB0D0D #DA16FF #222A2A "
    "#B68100 #750D86 #EB663B #511CFB #00A08B #FB00D1 "
    "#FC0080 #B2828D #6C7C32 #778AAE #862A16 #A777F1 "
    "#620042 #1616A7 #DA60CA #6C4516 #0D2A63 #AF0038"
).split()


@attr.s()
class _BaseChart:
    """
    BaseChart constructor for attributes and methods for all chart types to share

    See :func: pandas_alive.plotting.plot for more details on input requirements

    Chart subclasses must implement `init_func` and `anim_func`; this base class
    provides data preparation (copying, numeric column selection, interpolation),
    colours, figure creation and styling, axis limits, period labels and saving.
    """

    # Data to animate; a Series is converted to a single-column DataFrame in post-init
    df: pd.DataFrame = attr.ib()
    # Whether the index (period) is interpolated between rows, not only the values
    interpolate_period: bool = attr.ib()
    # Number of animation frames generated between consecutive rows of `df`
    steps_per_period: int = attr.ib()
    # Time spent on one period (row) of data, in milliseconds
    period_length: int = attr.ib()
    # strftime format for datetime indexes, otherwise a `str.format` template using `{x}`
    period_fmt: str = attr.ib()
    # Figure size in inches; replaced by the created figure's size when `fig` is None
    figsize: typing.Tuple[float, float] = attr.ib()
    title: str = attr.ib()
    # Optional user-supplied Figure; its first axes is used for plotting
    fig: plt.Figure = attr.ib()
    # Colormap name/instance, single colour name, "dark24" or a list of colours
    cmap: typing.Union[str, Colormap, typing.List[str]] = attr.ib()
    tick_label_size: typing.Union[int, float, str] = attr.ib()
    # False, True (default placement) or dict of `Axes.text` kwargs with at least "x" and "y"
    period_label: typing.Union[
        bool, typing.Dict[str, typing.Union[int, float, str]]
    ] = attr.ib()
    # Optional callable given the current row, returning `Axes.text` kwargs ("x", "y", "s")
    period_summary_func: typing.Callable = attr.ib()
    # Use limits of the whole dataset instead of the data shown up to the current frame
    fixed_max: bool = attr.ib()
    dpi: int = attr.ib()
    # Name of a matplotlib animation writer; when empty, `.gif` files are written with Pillow
    writer: str = attr.ib()
    enable_progress_bar: bool = attr.ib()
    # Extra keyword arguments (not read in this class; presumably used by subclasses)
    kwargs = attr.ib()

    def __attrs_post_init__(self):
        """
        Post initialisation steps to run

        Functionality from attrs to calculate new attributes based on input args and kwargs:
        copies the input data, resolves colours and data columns, interpolates the data
        across all animation frames and creates (or adopts) the figure and axes.

        Raises:
            TypeError: If `fig` is given but is not a matplotlib Figure instance
            RuntimeError: If an unsupported or unavailable `writer` is requested
        """
        # Check the user figure before any expensive work and before it is used below;
        # otherwise a wrong type surfaces as an obscure AttributeError.
        if self.fig is not None and not isinstance(self.fig, plt.Figure):
            raise TypeError("`fig` must be a matplotlib Figure instance")

        if isinstance(self.df, pd.Series):
            self.df = pd.DataFrame(self.df)

        # Work on a copy so the caller's DataFrame is never mutated
        self.df = self.df.copy()

        self.orig_df = self.df.copy()
        self.colors = self.get_colors(self.cmap)  # Get colors for plotting
        if not isinstance(self.df.columns, pd.MultiIndex):
            # Column names with valid (numeric) data
            self.data_cols = self.get_data_cols(self.df)
            # Force data column names to be string
            self.df = self.rename_data_columns(self.df)
        else:
            self.data_cols = self.df.columns.get_level_values(level=0).unique().tolist()

        print(f"Generating {self.__class__.__name__}, plotting {self.data_cols}")

        # Careful to use self.df in later calculations (eg, df_rank), use orig_df if needed
        self.df = self.get_interpolated_df(
            self.df, self.steps_per_period, self.interpolate_period
        )
        if self.fig is None:
            self.fig, self.ax = self.create_figure()
            self.figsize = self.fig.get_size_inches()
        else:
            # Use the `fig=` input by the user and its first axes, creating one if it has none
            axes = self.fig.get_axes()
            self.ax = axes[0] if axes else self.fig.add_subplot()
            self.ax.tick_params(labelsize=self.tick_label_size)

        if self.figsize is not None:
            self.fig.set_size_inches(self.figsize)
        self.fig.set_tight_layout(False)
        if self.title:
            self.ax.set_title(self.title)

        self.validate_params()

    def validate_params(self):
        """ Validate user supplied parameters

        Checks that `fig` (if given) is a matplotlib Figure instance and that `writer`
        (if given) is an available matplotlib animation writer other than "pillow".

        Raises:
            TypeError: figure provided is not a matplotlib Figure instance
            RuntimeError: `writer="pillow"` was requested (GIFs are handled internally)
                or the requested writer is not available
        """
        if self.fig is not None and not isinstance(self.fig, plt.Figure):
            raise TypeError("`fig` must be a matplotlib Figure instance")
        if self.writer:
            import matplotlib.animation as manimation

            if self.writer == "pillow":
                raise RuntimeError(
                    f"Pandas_Alive interfaces with Pillow for you, simply use the extension `.gif` and don't specify a writer"
                )
            available_writers = manimation.writers.list()
            if self.writer not in available_writers:
                raise RuntimeError(
                    f"Ensure that a matplotlib writer library is installed, list of available writer librarys {available_writers}, see https://github.com/JackMcKew/pandas_alive/blob/main/README.md#requirements for more details"
                )

    def get_period_label(
        self,
        period_label: typing.Union[
            bool, typing.Dict[str, typing.Union[int, float, str]]
        ],
    ) -> typing.Union[bool, typing.Dict[str, typing.Union[int, float, str]]]:
        """ Parameters for period annotation on charts, dict will be passed to kwargs in matplotlib.ax.text()

        Args:
            period_label (typing.Union[bool,typing.Dict[str,typing.Union[int,float,str]]]): If `True` or dict, use the index as the text label

        Raises:
            ValueError: `x` and `y` must be supplied as a minimum

        Returns:
            typing.Union[bool,typing.Dict[str,typing.Union[int,float,str]]]: `False` when disabled,
            otherwise a dict to be passed as kwargs to matplotlib.ax.text()
        """
        if not period_label:
            return False
        if period_label is True:
            # Default to bottom right corner
            return {"size": 12, "x": 0.9, "y": 0.1, "ha": "right"}
        if "x" not in period_label or "y" not in period_label:
            raise ValueError('`period_label` dictionary must have keys for "x" and "y"')
        return period_label

    def get_colors(
        self, cmap: typing.Union[Colormap, str, typing.List[str]]
    ) -> typing.List[str]:
        """ Get colours for plotting data

        Args:
            cmap (typing.Union[colors.Colormap,str]): Provide string of colormap name, "dark24",
                colormap instance, single color name or list of colors as supported by
                https://matplotlib.org/2.0.2/api/colors_api.html

        Raises:
            ValueError: If no supported colors are found
            TypeError: Type of colors is not supported

        Returns:
            typing.List[str]: New list of colors (RGBA sequences or color strings)
        """
        if isinstance(cmap, str):
            if cmap == "dark24":
                cmap = DARK24
            else:
                try:
                    # `matplotlib.cm.get_cmap` (aka `plt.cm.get_cmap`) was removed in
                    # matplotlib 3.9; `pyplot.get_cmap` is available across versions.
                    cmap = plt.get_cmap(cmap)
                except ValueError:
                    # Not a colormap name: repeat a single colour for every data column
                    try:
                        cmap = [to_rgba(cmap)] * len(self.get_data_cols(self.df))
                    except ValueError as err:
                        raise ValueError(
                            "Provide a suitable color name or color map as per matplotlib"
                        ) from err
        if isinstance(cmap, Colormap):
            chart_colors = cmap(range(cmap.N)).tolist()
        elif isinstance(cmap, list):
            # Copy so neither the module-level DARK24 palette nor the caller's list is aliased
            chart_colors = list(cmap)
        elif hasattr(cmap, "tolist"):
            chart_colors = cmap.tolist()
        else:
            raise TypeError(
                "`cmap` must be a string name of a color, a matplotlib colormap instance or a list of colors"
            )

        return chart_colors

    def get_single_color(
        self, color_string: str
    ) -> typing.Tuple[float, float, float, float]:
        """
        Get single RGBA value from string

        From provided string return the RGBA value from `to_rgba`, see more details at https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.colors.to_rgba.html

        Args:
            color_string (str): Must be a part of named colors in matplotlib https://matplotlib.org/3.1.1/gallery/color/named_colors.html#sphx-glr-gallery-color-named-colors-py

        Returns:
            typing.Tuple[float,float,float,float]: Tuple of (r, g, b, a) scalars.
        """
        return to_rgba(color_string)

    @staticmethod
    def _finite_min_max(
        frame: pd.DataFrame,
    ) -> typing.Optional[typing.Tuple[float, float]]:
        """ Return (min, max) over the finite values of `frame`, or None if it has none """
        values = frame.to_numpy(dtype=float, na_value=np.nan)
        values = values[np.isfinite(values)]
        if values.size == 0:
            return None
        return values.min(), values.max()

    def set_x_y_limits(self, df: pd.DataFrame, i: int, ax: matplotlib.pyplot.Axes):
        """
        Set axis limits for both x and y of passed axes object

        For use with fixed_max to set the figure as the highest/lowest value in entire dataframe otherwise takes min/max from each timestep
        Fixed_max applies to both x & y. NaN/inf values are ignored; the y limits are
        left untouched while there is no finite data yet.

        Args:
            df (pd.DataFrame): DataFrame to take min/max from
            i (int): Frame number to slice DataFrame on if used without fixed_max
            ax (matplotlib.pyplot.Axes): Axes to apply limits to
        """
        index = df.index if self.fixed_max else df.index[: i + 1]
        xlim_start, xlim_end = index.min(), index.max()
        # The end is padded slightly to avoid a UserWarning on the first frame, where
        # start and end limits are identical
        if isinstance(xlim_start, pd.Timestamp):
            # ufunc error occurs in anaconda environments if not converted to datetime instead of Timestamp
            ax.set_xlim(
                xlim_start.to_pydatetime(),
                (xlim_end + pd.Timedelta(seconds=1)).to_pydatetime(),
            )
        else:
            ax.set_xlim(xlim_start, xlim_end + 1e-6)

        # Avoid lines/scatter crossing vertical ylim and looking cut off: pad by 5% of the
        # full data range, but not below when all data is >= 0 nor above when all data is <= 0
        overall = self._finite_min_max(df)
        if overall is None:
            return
        data_min, data_max = overall
        ylim_scale = (data_max - data_min) * 0.05
        ylim_bot_scale = 0 if data_min >= 0 else ylim_scale
        ylim_top_scale = 0 if data_max <= 0 else ylim_scale

        if self.fixed_max:
            ylim_low, ylim_high = data_min, data_max
        else:
            shown_so_far = self._finite_min_max(df.iloc[: i + 1])
            if shown_so_far is None:
                return
            ylim_low, ylim_high = shown_so_far
        ax.set_ylim(ylim_low - ylim_bot_scale, ylim_high + ylim_top_scale)

    def rename_data_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Converts all column names to string

        Args:
            df (pd.DataFrame): DataFrame to rename columns on (modified in place)

        Returns:
            pd.DataFrame: DataFrame with converted columns
        """
        df.columns = df.columns.astype(str)
        return df

    def get_data_cols(self, df: pd.DataFrame) -> typing.List[str]:
        """ Get list of columns containing plottable numeric data to plot

        Args:
            df (pd.DataFrame): DataFrame whose column dtypes are inspected

        Raises:
            Exception: If no numeric data was found to be plotted

        Returns:
            typing.List[str]: Names (converted to str) of columns containing numeric data
        """
        # Iterate dtypes rather than `df[col]`, which returns a DataFrame (no `.dtype`)
        # when column names are duplicated
        data_cols = [
            str(col)
            for col, dtype in df.dtypes.items()
            if np.issubdtype(dtype, np.number)
        ]
        if not data_cols:
            raise Exception("No numeric data columns found for plotting.")
        return data_cols

    def get_interpolated_df(
        self, df: pd.DataFrame, steps_per_period: int, interpolate_period: bool
    ) -> pd.DataFrame:
        """ Get interpolated dataframe to span total animation

        Args:
            df (pd.DataFrame): Input dataframe
            steps_per_period (int): The number of steps to go from one period to the next. Data will show linearly between each period
            interpolate_period (bool): Whether to interpolate the period (index) as well;
                otherwise the index is forward-filled

        Returns:
            pd.DataFrame: Interpolated dataframe with `steps_per_period` rows per original period
        """

        # Period interpolated to match other charts for multiple plotting
        # https://stackoverflow.com/questions/52701330/pandas-reindex-and-interpolate-time-series-efficiently-reindex-drops-data

        interpolated_df = df.reset_index()
        # Spread original rows `steps_per_period` apart and insert empty rows in between
        interpolated_df.index = interpolated_df.index * steps_per_period
        new_index = range(interpolated_df.index[-1] + 1)
        interpolated_df = interpolated_df.reindex(new_index)
        if interpolate_period:
            if interpolated_df.iloc[:, 0].dtype.kind == "M":
                first, last = interpolated_df.iloc[[0, -1], 0]
                dr = pd.date_range(first, last, periods=len(interpolated_df.index))
                interpolated_df.iloc[:, 0] = dr
            else:
                interpolated_df.iloc[:, 0] = interpolated_df.iloc[:, 0].interpolate()
        else:
            # `.ffill()` replaces `fillna(method="ffill")` (deprecated in pandas 2.1)
            interpolated_df.iloc[:, 0] = interpolated_df.iloc[:, 0].ffill()

        interpolated_df = interpolated_df.set_index(interpolated_df.columns[0])
        if interpolate_period and isinstance(df.index, pd.DatetimeIndex):
            interpolated_df = interpolated_df.interpolate(method="time")
        else:
            interpolated_df = interpolated_df.interpolate()

        return interpolated_df

    def init_func(self) -> None:
        """ Initializing method for animation, to be overridden by extended classes

        Raises:
            NotImplementedError: Method to be overridden has not been implemented
        """
        raise NotImplementedError("Initializing method not yet implemented")

    def anim_func(self, frame: int) -> None:
        """ Animation method, to be overridden by extended chart class

        Args:
            frame (int): Frame to be animated

        Raises:
            NotImplementedError: Animation method not yet implemented in extended chart class
        """
        raise NotImplementedError("Animation method not yet implemented")

    def get_frames(self) -> typing.Iterable:
        """ Method for determining range of frames to animate.

        Returns:
            typing.Iterable: Range with length of index in DataFrame
        """
        return range(len(self.df.index))

    def make_animation(
        self, frames: typing.Union[typing.Iterable, int], init_func: typing.Callable
    ) -> FuncAnimation:
        """ Method for creating animation

        Args:
            frames (typing.Union[typing.Iterable, int]): Frames (or number of frames) to animate
            init_func (function): Initialization function for chart

        Returns:
            FuncAnimation: FuncAnimation instance for extending with save, etc
        """
        # Milliseconds between frames so one period lasts `period_length`
        interval = self.period_length / self.steps_per_period
        return FuncAnimation(
            self.fig,
            self.anim_func,
            frames=frames,
            init_func=init_func,
            interval=interval,
        )

    def calculate_new_figsize(self, real_fig: plt.figure) -> typing.List[float]:
        """ Calculate figure size to allow for labels, etc

        Renders a scratch figure with the widest expected tick labels, measures how far
        the axes are pushed, resizes `real_fig` accordingly and returns the axes rect.

        Args:
            real_fig (plt.figure): Figure before calculation (resized in place)

        Returns:
            typing.List[float]: The dimensions [left, bottom, width, height] of the new axes. All quantities are in fractions of figure width and height.
        """
        import io

        fig = plt.Figure(figsize=self.figsize)

        ax = fig.add_subplot()

        # nan-aware: interpolated data may contain NaN, which would otherwise be measured as "nan"
        max_val = np.nanmax(self.df.values)
        ax.tick_params(labelrotation=0, labelsize=self.tick_label_size)

        fig.canvas.print_figure(io.BytesIO())
        orig_pos = ax.get_position()
        ax.set_yticklabels(self.df.columns)
        ax.set_xticklabels([max_val] * len(ax.get_xticks()))

        fig.canvas.print_figure(io.BytesIO(), format="png")
        new_pos = ax.get_position()

        coordx, prev_coordx = new_pos.x0, orig_pos.x0
        coordy, prev_coordy = new_pos.y0, orig_pos.y0
        # Read the size from the figure itself so this also works when `figsize` is None
        old_w, old_h = fig.get_size_inches()

        prev_w_inches = prev_coordx * old_w
        total_w_inches = coordx * old_w
        extra_w_inches = total_w_inches - prev_w_inches
        new_w_inches = extra_w_inches + old_w

        prev_h_inches = prev_coordy * old_h
        total_h_inches = coordy * old_h
        extra_h_inches = total_h_inches - prev_h_inches
        new_h_inches = extra_h_inches + old_h

        real_fig.set_size_inches(new_w_inches, new_h_inches)
        left = total_w_inches / new_w_inches
        bottom = total_h_inches / new_h_inches
        width = orig_pos.x1 - left
        height = orig_pos.y1 - bottom
        return [left, bottom, width, height]

    def apply_style(self, ax: matplotlib.pyplot.Axes) -> matplotlib.pyplot.Axes:
        """
        Apply styling to axes with spines and grid, can be overridden

        Args:
            ax (matplotlib.pyplot.Axes): Axes to apply styling to

        Returns:
            matplotlib.pyplot.Axes: Styled Axes object
        """

        ax.grid(True, axis="x", color="white")
        ax.set_axisbelow(True)
        ax.tick_params(length=0, labelsize=self.tick_label_size, pad=2)
        ax.set_facecolor(".9")
        for spine in ax.spines.values():
            spine.set_visible(False)
        return ax

    def create_figure(self) -> typing.Tuple[plt.figure, plt.axes]:
        """ Create base figure with styling, can be overridden if styling unwanted

        Returns:
            typing.Tuple[plt.figure,plt.figure.axes]: Returns Figure instance and the axes initialized within
        """

        fig = plt.Figure(figsize=self.figsize, dpi=self.dpi)
        rect = self.calculate_new_figsize(fig)
        ax = fig.add_axes(rect)

        ax = self.apply_style(ax)

        return fig, ax

    def _get_period_strings(self) -> pd.Index:
        """ String form of `self.df.index`, cached until the index object changes """
        # Converting the whole index on every frame made rendering O(frames^2)
        cached = getattr(self, "_period_strings_cache", None)
        if cached is None or cached[0] is not self.df.index:
            cached = (self.df.index, self.df.index.astype(str))
            self._period_strings_cache = cached
        return cached[1]

    def show_period(self, i: int) -> None:
        """
        Show period label (and optional period summary) on plot

        The period label is kept as the first text of the axes and the summary text
        directly after it; both are created on the first call and updated afterwards.

        Args:
            i (int): Frame number of animation to take slice of DataFrame and retrieve current index for show as period

        Raises:
            ValueError: If a custom `period_label` dict lacks `x` or `y`, or the dict returned
                by `period_summary_func` lacks `x`, `y` or `s`.
        """
        if self.period_label:
            if self.period_fmt:
                idx_val = self.df.index[i]
                if self.df.index.dtype.kind == "M":  # Date time
                    s = idx_val.strftime(self.period_fmt)
                else:
                    s = self.period_fmt.format(x=idx_val)
            else:
                s = self._get_period_strings()[i]
            if len(self.ax.texts) == 0:
                # first frame
                self.ax.text(
                    s=s,
                    transform=self.ax.transAxes,
                    **self.get_period_label(self.period_label),
                )
            else:
                self.ax.texts[0].set_text(s)

        if self.period_summary_func:
            values = self.df.iloc[i]
            text_dict = self.period_summary_func(values)
            if "x" not in text_dict or "y" not in text_dict or "s" not in text_dict:
                name = self.period_summary_func.__name__
                raise ValueError(
                    f"The dictionary returned from `{name}` must contain "
                    '"x", "y", and "s"'
                )
            # The summary follows the period label text when there is one; without a
            # period label it is the first text (previously a duplicate was added)
            summary_idx = 1 if self.period_label else 0
            if len(self.ax.texts) != summary_idx + 1:
                self.ax.text(transform=self.ax.transAxes, **text_dict)
            else:
                self.ax.texts[summary_idx].set_text(text_dict["s"])

    def clearing(self):
        """
        Function to remove all existing lines, collections, etc at the start
        of the animation AND after a save. This reduces the danger with old
        contents to be passed onto new multiple animations.
        """
        for ax in self.fig.axes:
            # Containers first: `Container.remove()` also removes the artists it holds,
            # which would otherwise be removed a second time from the lists below
            for container in list(ax.containers):
                container.remove()
            # Snapshot as a plain list, since removing artists mutates these collections
            for item in [*ax.lines, *ax.collections, *ax.texts]:
                item.remove()

    def _save_gif(self, filename: str, duration: float) -> None:
        """ Render every frame with `anim_func` and write them as a looping GIF via Pillow

        Args:
            filename (str): Destination file
            duration (float): Display time of each frame in milliseconds

        Raises:
            ValueError: If `get_frames` yields no frames
        """
        import io

        matplotlib.use("Agg")
        from PIL import Image

        images = []
        for frame in self.get_frames():
            self.anim_func(frame)
            buffer = io.BytesIO()
            self.fig.savefig(buffer, format="png")
            buffer.seek(0)
            images.append(Image.open(buffer))
            # NOTE(review): closes pyplot's *current* figure, not necessarily `self.fig`
            plt.close()
        if not images:
            raise ValueError("No frames to save, the animation is empty")
        first_image, *other_images = images
        # Only the remaining frames are appended; appending all of them duplicated frame 0
        first_image.save(
            filename,
            save_all=True,
            append_images=other_images,
            optimize=False,
            duration=duration,
            loop=0,
        )

    def save(self, filename: str) -> None:
        """ Save method for FuncAnimation.

        With no `writer`, `.gif` files (any case) are rendered frame by frame and written
        with Pillow; everything else goes through `FuncAnimation.save`.

        Args:
            filename (str): File name (or path-like) with extension to save animation to, supported formats at https://matplotlib.org/3.1.1/api/animation_api.html

        Raises:
            RuntimeError: If saving raises a TypeError, typically because no suitable
                matplotlib writer library is installed (original error is chained)
            ValueError: If a GIF is requested but there are no frames
        """
        import os

        # Inspiration for design pattern https://github.com/altair-viz/altair/blob/c55707730935159e4e2d2c789a6dd2bc3f1ec0f2/altair/utils/save.py
        # https://altair-viz.github.io/user_guide/saving_charts.html

        if self.enable_progress_bar:
            self.setup_progress_bar()

        # Created for every output path (as before), and kept referenced during GIF rendering
        anim = self.make_animation(self.get_frames(), self.init_func)
        self.fps = 1000 / self.period_length * self.steps_per_period
        interval = self.period_length / self.steps_per_period  # ms per frame

        extension = os.path.splitext(os.fspath(filename))[1].lower()
        try:
            if self.writer:
                anim.save(filename, fps=self.fps, dpi=self.dpi, writer=self.writer)
            elif extension == ".gif":
                self._save_gif(filename, duration=interval)
            else:
                anim.save(filename, fps=self.fps, dpi=self.dpi)
        except TypeError as err:
            raise RuntimeError(
                "Ensure that a matplotlib writer library is installed, see https://github.com/JackMcKew/pandas_alive/blob/main/README.md#requirements for more details"
            ) from err
        finally:
            # Close the progress bar even when saving fails
            if self.enable_progress_bar:
                self.progress_bar.close()

        # Clearing axes contents after save, so that fig's axes can be re-used in a
        # consequent multiple plot after this one in a notebook.
        self.clearing()

    def get_html5_video(self):
        """ Convert the animation to an HTML5 <video> tag.

        This saves the animation as an h264 video, encoded in base64 directly into the HTML5 video tag. This respects the rc parameters for the writer as well as the bitrate. This also makes use of the interval to control the speed, and uses the repeat parameter to decide whether to loop.

        Returns:
            HTML5 <video> tag: Encoded h264 video
        """

        anim = self.make_animation(self.get_frames(), self.init_func)

        html_tag = anim.to_html5_video()
        if "too large to embed" in html_tag:
            import warnings

            warnings.warn(
                "HTML5 Tag is too large to embed, try another format such as mp4, GIF or otherwise."
            )
        return html_tag

    def update_progress_bar(self) -> None:
        """
        Update TQDM instance by 1
        """
        self.progress_bar.update(1)

    def setup_progress_bar(self):
        """
        Create a tqdm progress bar (`tqdm.auto`) sized to the number of frames

        Stores it as `self.progress_bar`; `update_progress_bar` advances it by one.

        Raises:
            ImportError: If tqdm (or, inside Jupyter, ipywidgets for `tqdm.auto`) is unavailable
        """
        try:
            from tqdm.auto import tqdm

            # Construction stays inside the try: tqdm.auto can raise ImportError here
            # when running in Jupyter without ipywidgets
            self.progress_bar = tqdm(total=len(self.get_frames()))
        except ImportError as err:
            raise ImportError(
                "Install tqdm bar with `pip install tqdm` or `conda install -c conda-forge tqdm`, see more details at https://github.com/tqdm/tqdm. Additionally, for `tqdm.auto` to work in Jupyter install `ipywidgets`, more details at https://ipywidgets.readthedocs.io/en/stable/user_install.html."
            ) from err

    # Possibly include image
    # background method?
    # def show_image(
    #     self,
    #     ax,
    #     path_to_image: str,
    #     extent: typing.Tuple[float],
    #     zorder: int = 0,
    #     aspect: str = "equal",
    # ):
    #     image = plt.imread(path_to_image)

    #     ax.imshow(image, zorder=zorder, extent=extent, aspect=aspect)

    #     return ax