I'm running into an odd issue when plotting data with Seaborn's lineplot.
I have item sales for 5 items over a period of time. I would like to see the sales after each product's introduction.
Here is my code:
items = ['Item 1', 'Item 2', 'Item 3', 'Item 4', 'Item 5']
fig, ax = plt.subplots(squeeze=False)
ax[0] = sns.lineplot(x=item_sales.index, y='Item 1', data=item_sales, alpha=0.2)
ax[1] = sns.lineplot(x=item_sales.index, y='Item 2', data=item_sales, alpha=0.2)
ax[2] = sns.lineplot(x=item_sales.index, y='Item 3', data=item_sales, alpha=0.2)
ax[3] = sns.lineplot(x=item_sales.index, y='Item 4', data=item_sales, alpha=0.4)
ax[4] = sns.lineplot(x=item_sales.index, y='Item 5', data=item_sales, alpha=0.2)
ax.set_ylabel('')
ax.set_yticks([])
plt.title('Timeline of item sales')
plt.show()
This code fails on the following line, although two lines are drawn before the error:
ax[1] = sns.lineplot(x=item_sales.index, y='Item 2', data=item_sales, alpha=0.2)
IndexError: index 1 is out of bounds for axis 0 with size 1
However, the following line displays the plot without any errors:
item_sales.plot()
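For comparison, here is a minimal sketch of why the pandas call has no such problem, using made-up random data in place of the real item_sales: DataFrame.plot() draws every numeric column as a line on a single Axes and returns that one Axes object, so nothing needs to be indexed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up sample data with the same five item columns as in the question
items = ['Item 1', 'Item 2', 'Item 3', 'Item 4', 'Item 5']
rng = np.random.default_rng(0)
item_sales = pd.DataFrame(rng.integers(0, 25, (36, 5)), columns=items)

# DataFrame.plot() draws every numeric column on ONE Axes and returns that Axes
ax = item_sales.plot()
n_lines = len(ax.get_lines())
print(n_lines)  # 5: one line per item column
```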
What could be causing this error? The data is clean and there are no missing values:
item_sales.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36 entries, 0 to 35
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   Date Created  36 non-null     object
 1   Item 1        36 non-null     int64
 2   Item 2        36 non-null     int64
 3   Item 3        36 non-null     int64
 4   Item 4        36 non-null     int64
 5   Item 5        36 non-null     int64
dtypes: int64(5), object(1)
memory usage: 1.8+ KB
Thank you.
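You're getting the IndexError because, with squeeze=False, plt.subplots returns ax as a 2-dimensional NumPy array of Axes with shape (1, 1), and you're indexing along its first dimension, which has length 1. ax[0] = ... succeeds (it just overwrites that row of the array), but ax[1] is out of bounds. The two lines still appear because each sns.lineplot call runs, and draws on the current Axes, before its result is assigned. From the matplotlib documentation for plt.subplots: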
squeeze : bool, default: True
If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.
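A quick way to see this is to inspect what plt.subplots returns. This sketch (using the non-interactive Agg backend so it runs without a display; the variable names are illustrative) shows the (1, 1) shape, the only valid position axes[0, 0], and the same IndexError for axes[1].

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# squeeze=False always returns a 2D NumPy array of Axes, here with shape (1, 1)
fig, axes = plt.subplots(squeeze=False)
print(axes.shape)  # (1, 1)

# The single Axes lives at row 0, column 0
ax = axes[0, 0]

# Indexing row 1 fails because the first dimension has length 1
error_message = None
try:
    axes[1]
except IndexError as err:
    error_message = str(err)
print(error_message)  # index 1 is out of bounds for axis 0 with size 1

# With the default squeeze=True, a 1x1 grid returns a single Axes object instead
fig2, single_ax = plt.subplots()
is_single = isinstance(single_ax, Axes)
```

So in the question's code, either dropping squeeze=False or taking ax = ax[0, 0] gives a single Axes that can be passed to every plotting call.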
If you want all the lines on the same plot, you don't need to index into an array at all: create a single Axes and pass it to each seaborn call through the ax parameter, like so:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# prepare sample data: 30 random sales values for each of the 5 items
items = ['Item 1', 'Item 2', 'Item 3', 'Item 4', 'Item 5']
sales_data = dict(zip(items, np.random.randint(0, 25, (5, 30))))
item_sales = pd.DataFrame(sales_data)
# set the palette before creating the Axes so the new Axes picks up its color cycle
sns.set_palette("tab10", n_colors=5)
fig, ax = plt.subplots(figsize=(8, 4))
# pass the same ax to every call so all five lines share one plot (Item 4 is emphasised)
sns.lineplot(x=item_sales.index, y='Item 1', data=item_sales, alpha=0.3, ax=ax)
sns.lineplot(x=item_sales.index, y='Item 2', data=item_sales, alpha=0.3, ax=ax)
sns.lineplot(x=item_sales.index, y='Item 3', data=item_sales, alpha=0.3, ax=ax)
sns.lineplot(x=item_sales.index, y='Item 4', data=item_sales, alpha=1, ax=ax)
sns.lineplot(x=item_sales.index, y='Item 5', data=item_sales, alpha=0.3, ax=ax)
# hide the y-axis label and ticks, as in the question
ax.set_ylabel('')
ax.set_yticks([])
plt.title('Timeline of item sales')
plt.show()