简体   繁体   中英

Matplotlib scatter plot doesn't take strings on x-axis?

I have data in the format shown below. I can get a bar plot out of it, but not a scatter plot, because the x-axis values are strings. I looked at a couple of other posts but couldn't gather much from them.

  import pandas as pd
  import numpy as np

  # Sample data: one row per id, with its mean and standard deviation (sd).
  df = pd.DataFrame({"id": ["ssa", "ssb", "ssc", "xxa", "xxb", "xxc"],
                     "mean": [1.3, 1.5, 5.2, 3.1, 2.1, 3.2],
                     "sd": [0.9, 0.5, 0.3, 0.1, 0.2, 0.3]})
  df  # bare expression: displays the frame when run in a notebook cell

I get a bar plot with error bars using the following command:

import matplotlib.pyplot as plt
# Bar plot of the "mean" column per "id", using the "sd" column as the
# y error bars.
# NOTE(review): `ax` is rebound by every assignment below; only the value
# returned by df.plot() is the one used for the two label calls.
ax = plt.figure()
ax = df.plot(kind='bar',x='id', y='mean',figsize=[15,6], yerr='sd')
ax.set_xlabel("id")
ax.set_ylabel("mean")
ax = plt.tight_layout()
ax = plt.show()

[Image: the resulting bar plot with error bars]

But I get an error when I try to do a scatter plot of the same df.

ax = plt.figure()
ax = df.plot(kind='scatter',x='id', y='mean',figsize=[15,6], yerr='sd')  # fails with KeyError: 'id' (see the traceback below)
ax.set_xlabel("id")  # not reached: the df.plot call above raises first
ax.set_ylabel("mean")
ax = plt.tight_layout()
ax = plt.show()

Error traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-10-b3ab7237d4f1> in <module>()
      1 ax = plt.figure()
----> 2 ax = df.plot(kind='scatter',x='id', y='mean',figsize=[15,6], yerr='sd', style='.')
      3 ax.set_xlabel("id")
      4 ax.set_ylabel("mean")
      5 ax = plt.tight_layout()

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\plotting\_core.pyc in __call__(self, x, y, kind, ax, subplots, sharex, sharey, layout, figsize, use_index, title, grid, legend, style, logx, logy, loglog, xticks, yticks, xlim, ylim, rot, fontsize, colormap, table, yerr, xerr, secondary_y, sort_columns, **kwds)
   2618                           fontsize=fontsize, colormap=colormap, table=table,
   2619                           yerr=yerr, xerr=xerr, secondary_y=secondary_y,
-> 2620                           sort_columns=sort_columns, **kwds)
   2621     __call__.__doc__ = plot_frame.__doc__
   2622 

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\plotting\_core.pyc in plot_frame(data, x, y, kind, ax, subplots, sharex, sharey, layout, figsize, use_index, title, grid, legend, style, logx, logy, loglog, xticks, yticks, xlim, ylim, rot, fontsize, colormap, table, yerr, xerr, secondary_y, sort_columns, **kwds)
   1855                  yerr=yerr, xerr=xerr,
   1856                  secondary_y=secondary_y, sort_columns=sort_columns,
-> 1857                  **kwds)
   1858 
   1859 

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\plotting\_core.pyc in _plot(data, x, y, subplots, ax, kind, **kwds)
   1680         plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
   1681 
-> 1682     plot_obj.generate()
   1683     plot_obj.draw()
   1684     return plot_obj.result

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\plotting\_core.pyc in generate(self)
    236         self._compute_plot_data()
    237         self._setup_subplots()
--> 238         self._make_plot()
    239         self._add_table()
    240         self._make_legend()

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\plotting\_core.pyc in _make_plot(self)
    829         else:
    830             label = None
--> 831         scatter = ax.scatter(data[x].values, data[y].values, c=c_values,
    832                              label=label, cmap=cmap, **self.kwds)
    833         if cb:

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\core\frame.pyc in __getitem__(self, key)
   2060             return self._getitem_multilevel(key)
   2061         else:
-> 2062             return self._getitem_column(key)
   2063 
   2064     def _getitem_column(self, key):

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\core\frame.pyc in _getitem_column(self, key)
   2067         # get column
   2068         if self.columns.is_unique:
-> 2069             return self._get_item_cache(key)
   2070 
   2071         # duplicate columns & possible reduce dimensionality

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\core\generic.pyc in _get_item_cache(self, item)
   1532         res = cache.get(item)
   1533         if res is None:
-> 1534             values = self._data.get(item)
   1535             res = self._box_item_values(item, values)
   1536             cache[item] = res

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\core\internals.pyc in get(self, item, fastpath)
   3588 
   3589             if not isnull(item):
-> 3590                 loc = self.items.get_loc(item)
   3591             else:
   3592                 indexer = np.arange(len(self.items))[isnull(self.items)]

C:\Users\AppData\Local\Continuum\Anaconda2\lib\site-packages\pandas\core\indexes\base.pyc in get_loc(self, key, method, tolerance)
   2393                 return self._engine.get_loc(key)
   2394             except KeyError:
-> 2395                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2396 
   2397         indexer = self.get_indexer([key], method=method, tolerance=tolerance)

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc (pandas\_libs\index.c:5239)()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc (pandas\_libs\index.c:5085)()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item (pandas\_libs\hashtable.c:20405)()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item (pandas\_libs\hashtable.c:20359)()

KeyError: 'id'

So, instead I used seaborn to plot and it plots perfectly. However, I am not sure how to pass "sd" column to plot error bars in it.

# Point plot of "mean" per "id" with seaborn (points only, join=False).
# NOTE(review): `sns` is not imported anywhere in the snippets shown —
# presumably `import seaborn as sns` was run earlier; confirm.
fig, ax = plt.subplots(figsize=(5,3))
ax = sns.pointplot(x="id", y="mean",  data=df, join=False)
# From here on `ax` is rebound to the return values of the pyplot calls.
ax = plt.xticks(rotation=90)
ax = plt.tight_layout()
ax = plt.show()

[Image: the resulting seaborn point plot]

fig, ax = plt.subplots(figsize=(25,5))
ax = sns.pointplot(x="id", y="mean", data=df, join=False)
ax.map(plt.errorbar, "id", "mean", "sd", marker="o")  # fails: the object returned by sns.pointplot has no .map (see the AttributeError below)
ax = plt.xticks(rotation=90)
ax = plt.tight_layout()
ax = plt.show()

The above code throws the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-21-18652e3e8b12> in <module>()
      1 fig, ax = plt.subplots(figsize=(25,5))
      2 ax = sns.pointplot(x="id", y="mean", data=df, join=False)
----> 3 ax.map(plt.errorbar, "id", "mean", "sd", marker="o")
      4 ax = plt.xticks(rotation=90)
      5 ax = plt.tight_layout()

AttributeError: 'AxesSubplot' object has no attribute 'map'

What I would ideally like to have is a plot similar to the pointplot but with each point having a different size (as specified by its corresponding sd) or with each point having an error bar (given by sd). Can someone tell me how to do this?

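Just use matplotlib.axes.Axes.errorbar() as shown in this example, by adding this line right after the sns.pointplot call (while `ax` still refers to the Axes it returned, i.e. before `ax` is reassigned by the plt.xticks line):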

# Draw the sd error bars at integer x positions 0..n-1 (one per id) rather
# than at the string ids themselves.
ax.errorbar(np.arange(len(df['id'])), df['mean'], yerr=df['sd'], ls='None')

But I don't think you have to use seaborn:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Sample data: one row per id, with its mean and standard deviation (sd).
df = pd.DataFrame({
    "id": ["ssa", "ssb", "ssc", "xxa", "xxb", "xxc"],
    "mean": [1.3, 1.5, 5.2, 3.1, 2.1, 3.2],
    "sd": [0.9, 0.5, 0.3, 0.1, 0.2, 0.3],
})

# Plot against integer positions 0..n-1 instead of the string ids.
x_positions = np.arange(len(df['id']))
plt.errorbar(x_positions, df['mean'], yerr=df['sd'], ls='None', marker='o')

# Put one tick at every position and label it with the matching id.
ax = plt.gca()
ax.xaxis.set_ticks(x_positions)
ax.xaxis.set_ticklabels(df['id'], rotation=90)
ax.set_xlabel("id")
ax.set_ylabel("mean")

plt.show()

[Image: the resulting error-bar plot]

Edit: OP now wants "to remove error bars and draw your point with size proportional to sd".

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Sample data: one row per id, with its mean and standard deviation (sd).
df = pd.DataFrame({
    "id": ["ssa", "ssb", "ssc", "xxa", "xxb", "xxc"],
    "mean": [1.3, 1.5, 5.2, 3.1, 2.1, 3.2],
    "sd": [0.9, 0.5, 0.3, 0.1, 0.2, 0.3],
})

size_scaler = 300  # Your points will be too small if you just use sd
x_positions = np.arange(len(df['id']))  # integer stand-ins for the string ids
marker_sizes = df['sd'] * size_scaler   # passed to scatter() as `s`

fig, ax = plt.subplots()
ax.scatter(x_positions, df['mean'], s=marker_sizes, marker='o')

# Put one tick at every position and label it with the matching id.
ax.xaxis.set_ticks(x_positions)
ax.xaxis.set_ticklabels(df['id'], rotation=90)
ax.set_xlabel("id")
ax.set_ylabel("mean")

plt.show()

[Image: the resulting scatter plot with point sizes scaled by sd]

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM