简体   繁体   中英

Python create if elif in __init__ for package and function

I am combining all defined function into a class and use if , elif to operate.
I will explain in the following.

First, I have a 3 types of plot, combo , line , and bar .
I know how to define function separately for these three plot.

Second, I want to combine these 3 plots together within a package using if .
The code I tried is:

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


class AP(object):

    def __init__(self, dt, date, group, value, value2, value3, value4, value5, value6, TYPE):
        self.dt = dt
        self.date = date
        self.group= carrier
        self.value = value
        self.col1 = col1
        self.col2 = col2
        self.col3 = col3
        self.col4 = col4
        self.TYPE = TYPE

        if self.TYPE == "combo":

            def ComboChart(self, dt, date, group, value, TYPE):
                dataset = pd.read_csv(dt)
                dataset['date'] = pd.to_datetime(dataset[date])
                dataset['yq'] = pd.PeriodIndex(dataset['date'], freq='Q')
                dataset['qtr'] = dataset['date'].dt.quarter
                dataset = dataset.groupby([carrier, 'yq', 'qtr'])[value].sum().reset_index()
                dataset['total.YQGR'] = dataset[value] / dataset.groupby(['qtr', carrier])[value].transform('shift') - 1
                dataset = dataset[np.isfinite(dataset['total.YQGR'])]
                dataset['total.R'] = dataset[value] / dataset.groupby(group)[value].transform('first')
                dataset.yq = dataset.yq.astype(str)

                fig, ax1 = plt.subplots(figsize=(12,7))
                ax2=ax1.twinx()
                sns.lineplot(x='yq',y='total.R', data=dataset, hue=group, ax=ax1, legend = None, palette = ('navy', 'r'), linewidth=5)
                ax1.set_xticklabels(ax1.get_xticks(), rotation=45, fontsize=15, weight = 'heavy')
                ax1.set_xlabel("", fontsize=15)
                ax1.set_ylabel("")
                ax1.set_ylim((0, max(dataset['total.R']) + 0.05))
                sns.barplot(x='yq', y='total.YQGR', data=dataset, hue=group, ax=ax2, palette = ('navy', 'r'))
                ax2.set_yticklabels(['{:.1f}%'.format(a*100) for a in ax2.get_yticks()])
                ax2.set_ylabel("")
                ax2.set_ylim((min(dataset['total.YQGR']) - 0.01, max(dataset['total.YQGR']) + 0.2))
                ax2.get_legend().remove()
                ax2.legend(bbox_to_anchor=(-0.35, 0.5), loc=2, borderaxespad=0., fontsize = 'xx-large')
                for groups in ax2.containers:
                    for bar in groups:
                        if bar.get_height() >= 0:
                            ax2.text(
                                    bar.get_xy()[0] + bar.get_width()/1.5,
                                    bar.get_height() + 0.003, 
                                '{:.1f}%'.format(round(100*bar.get_height(),2)), 
                                    color='black',
                                    horizontalalignment='center',
                                    fontsize = 12, weight = 'heavy'
                                    )
                        else:
                            ax2.text(
                                    bar.get_xy()[0] + bar.get_width()/1.5,
                                    bar.get_height() - 0.008, 
                                '{:.1f}%'.format(round(100*bar.get_height(),2)), 
                                    color='black',
                                    horizontalalignment='center',
                                    fontsize = 12, weight = 'heavy'
                                    )
                ax1.yaxis.set_visible(False)
                ax2.yaxis.set_visible(False)
                ax2.xaxis.set_visible(False)
                ax1.spines["right"].set_visible(False)
                ax1.spines["left"].set_visible(False)
                ax1.spines["top"].set_visible(False)
                ax1.spines["bottom"].set_visible(False)
                ax2.spines["right"].set_visible(False)
                ax2.spines["left"].set_visible(False)
                ax2.spines["top"].set_visible(False)
                ax2.spines["bottom"].set_visible(False)
                ax1.set_title(TYPE, fontsize=20)
                plt.show()

                fig.savefig(TYPE, bbox_inches='tight', dpi=600)

        elif self.TYPE == "line":

            def line(self, dt, date, carrier, value, value2, TYPE):
                dataset = pd.read_csv(dt)
                dataset['date'] = pd.to_datetime(dataset[date])
                dataset['yq'] = pd.PeriodIndex(dataset['date'], freq='Q')
                dataset = dataset.groupby([group, 'yq'])[value, value2].sum().reset_index()
                dataset['Arate'] = dataset[value2] / dataset[value]
                dataset.yq = dataset.yq.astype(str)

                fig, ax1 = plt.subplots(figsize=(12,7))
                sns.lineplot(x='yq', y='Arate', data=dataset, hue=group, ax=ax1, linewidth=5)
                ax1.set_xticklabels(dataset['yq'], rotation=45, fontsize = 15)
                ax1.set_xlabel("")
                ax1.set_ylabel("")
                ax1.set_ylim((min(dataset['Arate']) - 0.05, max(dataset['Arate']) + 0.05))
                ax1.set_yticklabels(['{:.1f}%'.format(a*100) for a in ax1.get_yticks()], fontsize = 18, weight = 'heavy')
                ax1.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=2, borderaxespad=0., ncol = 6)
                ax1.yaxis.grid(True)
                ax1.spines["right"].set_visible(False)
                ax1.spines["left"].set_visible(False)
                ax1.spines["top"].set_visible(False)
                ax1.spines["bottom"].set_visible(False)
                ax1.set_title(TYPE, fontsize = 20)
                plt.show()

                fig.savefig(TYPE, bbox_inches='tight', dpi=600)

        elif self.TYPE == "bar":

            def Bar(self, dt, date, group, value3, value4, value5, value6, TYPE):
                dataset = pd.read_csv(dt, sep = '|')
                dataset['date'] = pd.to_datetime(dataset[date])
                dataset['yq'] = pd.PeriodIndex(dataset['date'], freq='Q')
                dataset = dataset.groupby([group, 'yq'])[value3, value4, value5, value6].sum().reset_index()
                dataset = dataset.groupby([group]).tail(4)
                dataset.yq = dataset.yq.astype(str)
                dataset = pd.melt(dataset, id_vars = [group, 'yq'], value_vars = [value3, value4, value5, value6])
                dataset = dataset.groupby(['variable', group]).value.sum().reset_index()
                dataset['L4Qtr'] = dataset.value / dataset.groupby([group]).value.transform('sum')

                fig, ax1 = plt.subplots(figsize=(12,7))
                sns.barplot(x='variable', y='L4Qtr', data=dataset, hue=group, ax=ax1)
                ax1.set_xticklabels(ax1.get_xticklabels(), fontsize=17.5, weight = 'heavy')
                ax1.set_xlabel("", fontsize=15)
                ax1.set_ylabel("")
                ax1.yaxis.set_ticks(np.arange(0, max(dataset['L4Qtr']) + 0.1, 0.05), False)
                ax1.set_yticklabels(['{:.1f}%'.format(a*100) for a in ax1.get_yticks()], fontsize = 18, weight = 'heavy')
                ax1.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=2, borderaxespad=0., ncol = 6)
                for groups in ax1.containers:
                    for bar in groups:
                        ax1.text(
                                bar.get_xy()[0] + bar.get_width()/2,
                                bar.get_height() + 0.005, 
                            '{:.1f}%'.format(round(100*bar.get_height(),2)), 
                                color=bar.get_facecolor(),
                                horizontalalignment='center',
                                fontsize = 16, weight = 'heavy'
                                    )
                ax1.spines["right"].set_visible(False)
                ax1.spines["left"].set_visible(False)
                ax1.spines["top"].set_visible(False)
                ax1.spines["bottom"].set_visible(False)
                ax1.set_title(TYPE, fontsize=20)
                plt.show()

                fig.savefig(TYPE, bbox_inches='tight', dpi=600)

Third, I hope others can simply use this module as below:

import sys
sys.path.append(r'\\users\desktop\module')
from AP import AP as ap

Finally, when someone assign TYPE , it will automatically plot and save it.

# This will plot combo chart
ap(r'\\users\desktop\dataset.csv', date = 'DATEVALUE', group = 'GRPS', value = 'total', TYPE = 'combo')

Above is the ideal thought. I do not need to pass value2 ~ value6 in it since combo does not use them.
When I want bar :

# This will plot bar chart
ap(r'\\users\desktop\dataset.csv', date = 'DATEVALUE', group = 'GRPS', value3 = 'col1', value4 = 'col2', value5 = 'col3', value6 = 'col4', TYPE = 'combo')

My code is incorrect since error happened. It seems that I need to pass all parameters in it.

However, even I passed all parameters in it. No error but no output.

Any suggestion?

could you explain, why you don't just create subclasses for the types? Wouldn't that be more straight-forward?

1.) One way would be to make the subclasses visible to the user and if you don't like this,

2.) you could just create a kind of interface class (eg AP that hides the class that is used behind the scenes and for example instanciates as soon as the type is set.

3.) you can work as you began, but then I guess you would have to make the methods visible to the user, because I guess the way you implemented it, the functions are only visible in the init method (maybe your indentaion is not quite correct). For example if your if statements are executed by the init method, then you could assign the methods to instance variables like self.ComboChart= ComboChart to be able to call the method from outside. But imho that would not be very pythonic and a bit more hacky/less object oriented.

So I'd suggest 1.) and if that is not possible for some reason, then I'd go for solution 2. Both solutions also allow you to form a clean class structure and reuse code that way, while you are still able to build your simplified interface class if you like.

An example (pseudo code) for method 1 would look like below. Please note, that I haven't tested it, it is only meant to give you an idea, about splitting logic in an object oriented way. I didn't check your whole solution and so I don't know for example, if you always group your data in the same way. I'd proabably also separate the presentation logic from the data logic. That would especially be a good idea if you plan to display the same data in more ways, because with the current logic, you would reread the csv file and reporcess the data each time you want another represenatiation. So not to make it more complicated while I just want to explain the basic principle I ignored this and gave an example for a base class "Chart" and a subclass "ComboChart". The "ComboChart" class knows how to read/group the data, because it inherits the methods from "Chart", so you only have to implement it once and thus if you find a bug or want to enhance it later, you only need to do it in one place. The draw_chart method then only needs to do what's different according to the chosen representation. A user would have to create the instance of the subclass according the chart type they want to display and call display_chart().

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


class Chart(object):
    def __init__(self, dt, date, group, value, value2, value3, value4, value5, value6):
        self.dt = dt
        self.date = date
        self.group= carrier
        self.value = value
        self.col1 = col1
        self.col2 = col2
        self.col3 = col3
        self.col4 = col4
        self.TYPE = TYPE
        self.dataset= None

    def _read_data_(self)        
        dataset = pd.read_csv(dt)
        dataset['date'] = pd.to_datetime(dataset[self.date])
        dataset['yq'] = pd.PeriodIndex(dataset['date'], freq='Q')
        dataset['qtr'] = dataset['date'].dt.quarter
        dataset = dataset.groupby([carrier, 'yq', 'qtr'])[value].sum().reset_index()
        dataset['total.YQGR'] = dataset[value] / dataset.groupby(['qtr', carrier])[value].transform('shift') - 1
        dataset = dataset[np.isfinite(dataset['total.YQGR'])]
        dataset['total.R'] = dataset[value] / dataset.groupby(group)[value].transform('first')
        dataset.yq = dataset.yq.astype(str)
        self.dataset= dataset
        return dataset

    def get_data(self):
        if self.dataset is None:
            self._read_data_()
        return self.dataset

    def group_data(self):
        dataset= self.get_data()
        dataset = dataset.groupby([carrier, 'yq', 'qtr'])[value].sum().reset_index()
        dataset['total.YQGR'] = dataset[value] / dataset.groupby(['qtr', carrier])[value].transform('shift') - 1
        dataset = dataset[np.isfinite(dataset['total.YQGR'])]
        dataset['total.R'] = dataset[value] / dataset.groupby(group)[value].transform('first')
        dataset.yq = dataset.yq.astype(str)
        return dataset

    def draw_chart(self):
        pass


class ComboChart(Chart):
    def draw_chart(self):
        dataset = self.group_data()
        fig, ax1 = plt.subplots(figsize=(12,7))
        ax2=ax1.twinx()
        sns.lineplot(x='yq',y='total.R', data=dataset, hue=group, ax=ax1, legend = None, palette = ('navy', 'r'), linewidth=5)
        ax1.set_xticklabels(ax1.get_xticks(), rotation=45, fontsize=15, weight = 'heavy')
        ax1.set_xlabel("", fontsize=15)
        ax1.set_ylabel("")
        ax1.set_ylim((0, max(dataset['total.R']) + 0.05))
        sns.barplot(x='yq', y='total.YQGR', data=dataset, hue=group, ax=ax2, palette = ('navy', 'r'))
        ax2.set_yticklabels(['{:.1f}%'.format(a*100) for a in ax2.get_yticks()])
        ax2.set_ylabel("")
        ax2.set_ylim((min(dataset['total.YQGR']) - 0.01, max(dataset['total.YQGR']) + 0.2))
        ax2.get_legend().remove()
        ax2.legend(bbox_to_anchor=(-0.35, 0.5), loc=2, borderaxespad=0., fontsize = 'xx-large')
        for groups in ax2.containers:
            for bar in groups:
                if bar.get_height() >= 0:
                    ax2.text(
                            bar.get_xy()[0] + bar.get_width()/1.5,
                            bar.get_height() + 0.003, 
                        '{:.1f}%'.format(round(100*bar.get_height(),2)), 
                            color='black',
                            horizontalalignment='center',
                            fontsize = 12, weight = 'heavy'
                            )
                else:
                    ax2.text(
                            bar.get_xy()[0] + bar.get_width()/1.5,
                            bar.get_height() - 0.008, 
                        '{:.1f}%'.format(round(100*bar.get_height(),2)), 
                            color='black',
                            horizontalalignment='center',
                            fontsize = 12, weight = 'heavy'
                            )
        ax1.yaxis.set_visible(False)
        ax2.yaxis.set_visible(False)
        ax2.xaxis.set_visible(False)
        ax1.spines["right"].set_visible(False)
        ax1.spines["left"].set_visible(False)
        ax1.spines["top"].set_visible(False)
        ax1.spines["bottom"].set_visible(False)
        ax2.spines["right"].set_visible(False)
        ax2.spines["left"].set_visible(False)
        ax2.spines["top"].set_visible(False)
        ax2.spines["bottom"].set_visible(False)
        ax1.set_title(TYPE, fontsize=20)
        plt.show()

        fig.savefig(TYPE, bbox_inches='tight', dpi=600)

The second method (with the interface class) would just look the same, only that you have a forth class that is known to the user and knows how to call the real implementation. Like this:

class YourInterface:
    def __init__(self, your_arguments, TYPE):
        if TYPE == __ 'ComboChart':
            self.client= ComboChart(your_arguments)
        elif TYPE == ....

    def display_chart(self):
        self.client.display_chart()

But it's a pretty boring class, isnt't it? I'd only do this if your class hierarchy is very technical and could change over time if you want to avoid that the users of your library build up dependencies on the real class hierarchy that would probably be broken as soon as you change your hierarchy. For most cases I guess, class hierarchies stay relatively stable, so you don't need such an extra level of abstraction created by an interface class.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM