
Drawing a 3D scatter plot with data-dependent markers and colors

I have a DataFrame df with 5 columns, f1, f2, f3, f4 and y. All values in all columns come from a finite set of integers; in fact, all columns are categorical columns converted to integers. I would like to draw a 3D scatter plot with f1, f2 and f3 on the axes, the marker style determined by f4, and the color determined by the y column.

The following code handles the axes and the colors:

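import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2 to register the '3d' projection
fig = plt.figure(figsize=(20, 16))
ax = fig.add_subplot(projection='3d')  # gca(projection=...) no longer works on recent Matplotlib
ax.scatter(df['f1'], df['f2'], df['f3'], c=df['y'], s=100)
plt.show()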

However, I'm not sure how to set the marker style based on column f4. Loosely inspired by this post, I would probably define a list of possible markers:

marker_styles = ['.','o','v','^','>','<','s','p','*','h','H','D','d','1']

Then I would group my data by f4 and use the next marker for each group, cycling back to the beginning of the marker_styles list if there are more groups than markers.

I'm not sure how to execute this idea, or whether there are better alternatives.
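The grouping-and-cycling idea from the question can be expressed as a small mapping from f4 values to markers. This is a minimal sketch, assuming f4 holds sortable category codes; `marker_map` is a hypothetical helper, not a pandas or Matplotlib function. It pairs each distinct value with a marker via `itertools.cycle`, so markers are reused once the list runs out, and `Series.map` then turns the dict into a per-row marker column.

```python
import itertools

import pandas as pd


def marker_map(values, markers):
    """Map each distinct value, in sorted order, to a marker style.

    itertools.cycle restarts at the first marker once the list is exhausted,
    so having more categories than markers is fine (markers are reused).
    """
    distinct = sorted(set(values))
    return dict(zip(distinct, itertools.cycle(markers)))


marker_styles = ['.', 'o', 'v', '^', '>', '<', 's', 'p', '*', 'h', 'H', 'D', 'd', '1']

# Example on hypothetical toy data: build a per-row marker column from f4
demo = pd.DataFrame({'f4': [3, 1, 3, 20]})
demo['marker'] = demo['f4'].map(marker_map(demo['f4'], ['o', 's']))
print(demo['marker'].tolist())  # ['s', 'o', 's', 'o']
```

Sorting the distinct values keeps the assignment stable between runs, so the same f4 code always gets the same marker for a given list.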

You could iterate over the distinct values of f4, build a boolean mask for each value, and pair it with a marker:

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2 to register the '3d' projection
import numpy as np
import pandas as pd
import itertools

N = 100
# Random example data: integer category codes 1..10 in every column
df = pd.DataFrame({'f1': np.random.randint(1, 11, N),
                   'f2': np.random.randint(1, 11, N),
                   'f3': np.random.randint(1, 11, N),
                   'f4': np.random.randint(1, 11, N),
                   'y': np.random.randint(1, 11, N)})
marker_styles = ['.', 'o', 'v', '^', '>', '<', 's', 'p', '*', 'h', 'H', 'D', 'd', '1']

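fig = plt.figure(figsize=(20, 16))
ax = fig.add_subplot(projection='3d')
# Use one color range for every group so that equal y values get equal colors
ymin, ymax = df['y'].min(), df['y'].max()
# Pair each distinct f4 value with a marker, cycling if there are more values than markers
for f, m in zip(sorted(df['f4'].unique()), itertools.cycle(marker_styles)):
    mask = df['f4'] == f
    ax.scatter(df.loc[mask, 'f1'], df.loc[mask, 'f2'], df.loc[mask, 'f3'],
               c=df.loc[mask, 'y'], vmin=ymin, vmax=ymax, cmap='plasma', s=100, marker=m)
plt.show()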
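As an alternative sketch (not the answer's original code), the same loop can be written with `df.groupby('f4')`, a single shared `matplotlib.colors.Normalize` so that colors stay comparable across marker groups (each `scatter` call otherwise autoscales its colormap to its own data), a legend for the markers and a colorbar for y. The function names `scatter_3d_by_category` and `make_example_df` are hypothetical, and the example data is random.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import Normalize


def make_example_df(n=100, seed=None):
    """Hypothetical stand-in data: integer category codes 1..10 in every column."""
    rng = np.random.default_rng(seed)
    return pd.DataFrame({col: rng.integers(1, 11, n) for col in ['f1', 'f2', 'f3', 'f4', 'y']})


def scatter_3d_by_category(df, marker_styles, cmap='plasma'):
    """3D scatter of f1/f2/f3 with the marker chosen by f4 and the color by y.

    Assumes df has integer columns f1, f2, f3, f4 and y and at least one row.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(projection='3d')
    # One shared Normalize: without it, each scatter call would scale colors to
    # its own group's y range, so equal y values could get different colors.
    norm = Normalize(vmin=df['y'].min(), vmax=df['y'].max())
    # groupby('f4') yields groups in ascending f4 order; cycle() reuses markers if needed
    for (f4_value, group), marker in zip(df.groupby('f4'), itertools.cycle(marker_styles)):
        points = ax.scatter(group['f1'], group['f2'], group['f3'],
                            c=group['y'], cmap=cmap, norm=norm,
                            marker=marker, s=100, label=str(f4_value))
    ax.set_xlabel('f1')
    ax.set_ylabel('f2')
    ax.set_zlabel('f3')
    ax.legend(title='f4')
    # All groups share one norm and cmap, so a single colorbar describes every point
    fig.colorbar(points, ax=ax, label='y')
    return fig, ax
```

With the marker_styles list from the question, call `fig, ax = scatter_3d_by_category(make_example_df(), marker_styles)` and then `plt.show()`. Passing `vmin`/`vmax` to each `scatter` call, as in the loop above, achieves the same shared color scale; a `Normalize` object is simply one value to pass around.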

[Example output plot]
