
Seaborn swarmplot: break into lines

I'm trying to make this swarmplot with seaborn.

My problem is that the swarms are too wide. I want to be able to break them up into rows of at most 3 dots per row.

This is my code:

# Import modules
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
###

# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
Books = Books[Books['Date Read'].notna()]   # Remove NA

Books['Year'] = pd.DatetimeIndex(Books['Date Read']).year  # Extract the year read
Books[['Year', 'Date Read']]                 # Preview both columns (notebook only)
###

# Calculate mean rate by year
RateMeans = (Books["My Rating"].groupby(Books["Year"]).mean())
Years = list(RateMeans.index.values)
Rates = list(RateMeans)
RateMeans = pd.DataFrame(
    {'Years': Years,
     'Rates': Rates
    })
###

# Plot
fig,ax = plt.subplots(figsize=(20,10))

## Violin Plot:
plot = sns.violinplot(
    data=Books, 
    x = "Year", 
    y = 'My Rating', 
    ax=ax,
    color = "white", 
    inner=None,
    #palette=colors_from_values(ArrayRates[:,1], "Blues")
    )

## Swarm Plot
plot = sns.swarmplot(
    data=Books, 
    x = "Year", 
    y = 'My Rating', 
    ax=ax,
    hue = "My Rating",
    size = 10
    )
    
## Style
    
### Title
ax.text(x=0.5, y=1.1, s='Book Ratings: Distribution per Year', fontsize=32, weight='bold', ha='center', va='bottom', transform=ax.transAxes)
ax.text(x=0.5, y=1.05, s='Source: Goodreads.com (amirnakar)', fontsize=24, alpha=0.75, ha='center', va='bottom', transform=ax.transAxes)



### Axis
ax.set(xlim=(4.5, None), ylim=(0,6))
#ax.set_title('Book Ratings: Distribution per Year \n', fontsize = 32)
ax.set_ylabel('Rating (out of 5 stars)', fontsize = 24)
ax.set_xlabel('Year', fontsize = 24)
ax.tick_params(axis='both', labelsize=20)  # Enlarge tick labels, keeping the year labels on x

### Legend
plot.legend(loc="lower center", ncol = 5 )

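### Colour palette
# Note: sns.set_palette only affects plots drawn after this call,
# so it does not recolour the swarmplot above.
colorset = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"]
colorset.reverse()
sns.set_palette(sns.color_palette(colorset))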


# Save the plot
#plt.show(plot)
plt.savefig("Rate-Python.svg", format="svg")

This is the output:

[Image: the resulting plot]

What I want to have happen:

I want to be able to specify that each row of dots holds a maximum of 3; if there are more, they should break into a new row. I demonstrate it here (done manually in PowerPoint) on two groups, but I want it for the entire plot.

BEFORE:

[Image: before]

AFTER:

[Image: after]

Here is an attempt to relocate the dots slightly upward or downward. The value of delta was found by experimenting.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Import and clean dataset
url = "https://raw.githubusercontent.com/amirnakar/scratchboard/master/Goodreads/goodreads_library_export.csv"
Books = pd.read_csv(url)
Books = Books[Books['Date Read'].notna()]  # Remove NA
Books['Year'] = pd.DatetimeIndex(Books['Date Read']).year  # Take only years
# Count how often each rating occurs per year
RatePerYear = Books[["My Rating", "Year"]].groupby("Year")["My Rating"].value_counts()

modified_ratings = []
delta = 0.2  # distance to move overlapping ratings
for (year, rating), count in RatePerYear.items():  # Series.iteritems() was removed in pandas 2.0
    # Keep at most 3 dots at the rating itself; split the surplus above and below
    higher = max(0, ((count - 3) + 1) // 2)
    lower = max(0, (count - 3) // 2)
    modified_ratings.append([year, rating, count - higher - lower])
    # Stack the surplus in rows of up to 3 dots, each row delta further away
    for k in range((higher + 2) // 3):
        modified_ratings.append([year, rating + (k + 1) * delta, 3 if (k + 1) * 3 <= higher else higher % 3])
    for k in range((lower + 2) // 3):
        modified_ratings.append([year, rating - (k + 1) * delta, 3 if (k + 1) * 3 <= lower else lower % 3])
modified_ratings = np.array(modified_ratings)
modified_ratings_df = pd.DataFrame(
    {'Year': np.repeat(modified_ratings[:, 0].astype(int), modified_ratings[:, 2].astype(int)),
     'My Rating': np.repeat(modified_ratings[:, 1], modified_ratings[:, 2].astype(int))})
modified_ratings_df['Rating'] = modified_ratings_df['My Rating'].round().astype(int)

fig, ax = plt.subplots(figsize=(20, 10))
sns.violinplot(data=Books, x="Year", y='My Rating', ax=ax, color="white", inner=None)
palette = ["#FAFF04", "#FFD500", "#9BFF00", "#0099FF", "#000BFF"][::-1]  # list.reverse() returns None, so slice instead
sns.swarmplot(data=modified_ratings_df, x="Year", y='My Rating', ax=ax, hue="Rating", size=10, palette=palette)

ax.set(xlim=(4.5, None), ylim=(0, 6))
ax.legend(loc="lower center", ncol=5)
plt.tight_layout()
plt.show()

[Image: sns.swarmplot with the row width limited to 3]
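
The splitting rule used in the loop above can be isolated into a small helper, which makes it easier to check by hand. This is a minimal sketch and not part of the original answer: the name `split_into_rows` and its parameters `max_per_row` and `delta` are hypothetical, and it assumes the same layout as above (a central row at the rating itself, with surplus dots shared between rows above and below).

```python
def split_into_rows(count, max_per_row=3, delta=0.2):
    """Split `count` dots into rows of at most `max_per_row` dots.

    Returns a list of (offset, n_dots) pairs, where `offset` is the vertical
    shift from the original rating. Hypothetical helper, for illustration only.
    """
    # Dots beyond the central row are shared between the upper and lower side;
    # the upper side receives the extra dot when the surplus is odd.
    surplus = max(0, count - max_per_row)
    higher = (surplus + 1) // 2
    lower = surplus // 2

    rows = [(0.0, count - higher - lower)]  # central row stays at the rating

    for side, sign in ((higher, 1), (lower, -1)):
        full_rows, remainder = divmod(side, max_per_row)
        # Full rows, each one `delta` further away from the rating
        for k in range(full_rows):
            rows.append((sign * (k + 1) * delta, max_per_row))
        # A final, partially filled row if dots are left over
        if remainder:
            rows.append((sign * (full_rows + 1) * delta, remainder))
    return rows
```

Calling `split_into_rows(count)` for each (year, rating) pair yields the offsets and row sizes that the loop writes into `modified_ratings`; changing `max_per_row` or `delta` then only requires touching one place.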

