
How to convert a matplotlib spectrogram image into a torch tensor

import numpy as np
from numpy import asarray
from matplotlib import pyplot as plt
import torch

# generate a signal
fs = 50 # sampling freq
ts = np.arange(0, 10, 1/fs) # times at which signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts) # 2 hz
s2 = np.sin(2 * np.pi * 3 * ts) # 3 hz
s3 = np.sin(2 * np.pi * 6 * ts) # 6 hz
s = s1 + s2 + s3 # aggregate signal

# generate specgram
spectrum, freqs, t, im = plt.specgram(s, Fs=fs, xextent=((0, len(s)/fs)))

# convert matplotlib image to torch tensor
# bypassing the numpy part would be even better!
torch_tensor = torch.from_numpy(asarray(im, np.float32))

print(torch_tensor)

>>> TypeError: float() argument must be a string or a number, not 'AxesImage'
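The error occurs because `im` is a matplotlib `AxesImage` artist, not an array, so NumPy cannot cast it to `float32`. The values the image displays can be read with `AxesImage.get_array()`. The sketch below assumes a headless Agg backend. It also reflects our reading of matplotlib's `Axes.specgram` source: the stored array is `10*log10(spectrum)`, flipped vertically so that row 0 is the top (highest frequency) of the image. The final check relies on that reading.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window is needed
import numpy as np
import torch
from matplotlib import pyplot as plt

# Same test signal as in the question
fs = 50
ts = np.arange(0, 10, 1 / fs)
s = (np.sin(2 * np.pi * 2 * ts)
     + np.sin(2 * np.pi * 3 * ts)
     + np.sin(2 * np.pi * 6 * ts))

spectrum, freqs, t, im = plt.specgram(s, Fs=fs, xextent=(0, len(s) / fs))

# `im` is an AxesImage artist; get_array() returns the data it displays
# (possibly a masked array, so ascontiguousarray keeps only the plain values)
image_data = np.ascontiguousarray(im.get_array(), dtype=np.float32)
torch_tensor = torch.from_numpy(image_data)
print(type(im).__name__, tuple(torch_tensor.shape), torch_tensor.dtype)

# The displayed data is the dB spectrum, flipped so low frequencies are at the bottom
expected = np.flipud(10 * np.log10(spectrum))
print(np.allclose(image_data, expected, rtol=1e-5))
```

This returns the image's data, not its rendered pixels. Rendered RGB pixels would depend on the colormap and figure size, which is rarely what a model should see.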

I should add that the 'spectrum' variable is close to what I am looking for, but it confuses me: it has only two time columns, while I think the specgram image has many more than two time steps. If the spectrum variable can be used to represent the whole image as a torch tensor, that would also work for me.
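For this signal, two time columns are the expected result. With plt.specgram's defaults (NFFT=256 samples per window, noverlap=128), a 500-sample signal fits only (500 - 128) // 128 = 2 full windows, centred at 2.56 s and 5.12 s. The image then stretches those two columns across the 0 to 10 s `xextent`. The sketch below calls `matplotlib.mlab.specgram`, which computes the same three arrays without drawing anything. `n_segments` is a hypothetical helper. The values NFFT=64 and noverlap=48 are illustrative only: they show that shorter windows give more time columns at the cost of coarser frequency bins.

```python
import numpy as np
from matplotlib import mlab

# Same test signal as in the question: 10 s at 50 Hz = 500 samples
fs = 50
ts = np.arange(0, 10, 1 / fs)
s = (np.sin(2 * np.pi * 2 * ts)
     + np.sin(2 * np.pi * 3 * ts)
     + np.sin(2 * np.pi * 6 * ts))

def n_segments(n_samples, nfft=256, noverlap=128):
    """Hypothetical helper: number of full windows of length nfft
    taken every (nfft - noverlap) samples."""
    return (n_samples - noverlap) // (nfft - noverlap)

# Defaults used by plt.specgram: NFFT=256, noverlap=128
spectrum, freqs, t = mlab.specgram(s, Fs=fs)
print(spectrum.shape, n_segments(len(s)), t)

# Shorter, more overlapping windows: more time columns, fewer frequency bins
spectrum2, freqs2, t2 = mlab.specgram(s, NFFT=64, Fs=fs, noverlap=48)
print(spectrum2.shape, n_segments(len(s), nfft=64, noverlap=48))
```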

plt.specgram returns the spectrogram itself in the spectrum variable, so that is the array to pass to torch.from_numpy. Additionally, specgram displays 10*log10(spectrum) by default (its dB scale), so you may want to apply the same operation before comparing the specgram image with a plot of your tensor. See the code below:
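The sketch below shows only the tensor conversion, with no figure. It computes `spectrum` with `matplotlib.mlab.specgram`, using the same defaults as plt.specgram. It wraps the array with `torch.from_numpy`, which shares memory with the NumPy array and keeps its float64 dtype. It then applies the 10*log10 scaling in torch and casts to float32, as the question's code attempted. The `eps` guard against log10(0) and the extra (batch, channel) dimensions are assumptions for typical model input. They are not part of the answer.

```python
import numpy as np
import torch
from matplotlib import mlab

# Same test signal as in the question
fs = 50
ts = np.arange(0, 10, 1 / fs)
s = (np.sin(2 * np.pi * 2 * ts)
     + np.sin(2 * np.pi * 3 * ts)
     + np.sin(2 * np.pi * 6 * ts))

spectrum, freqs, t = mlab.specgram(s, Fs=fs)  # shape (n_freqs, n_times), float64

# Zero-copy: the tensor views the same memory as `spectrum`
linear = torch.from_numpy(spectrum)

# Same dB scale specgram uses for display; eps (an assumption) avoids log10(0)
eps = 1e-12
db = (10 * torch.log10(linear + eps)).to(torch.float32)

# Optional, assumed layout for image-style models: (batch, channel, freq, time)
batch = db.unsqueeze(0).unsqueeze(0)
print(linear.dtype, db.dtype, tuple(batch.shape))
```

Because `linear` shares memory with `spectrum`, modifying one in place also modifies the other. Call `.clone()` if you need an independent copy.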

import numpy as np
from matplotlib import pyplot as plt
import torch

# generate a signal
fs = 50 # sampling freq
ts = np.arange(0, 10, 1/fs) # times at which signal is sampled
s1 = np.sin(2 * np.pi * 2 * ts) # 2 hz
s2 = np.sin(2 * np.pi * 3 * ts) # 3 hz
s3 = np.sin(2 * np.pi * 6 * ts) # 6 hz
s = s1 + s2 + s3 # aggregate signal

# generate specgram
ax1=plt.subplot(121)
ax1.set_title('Specgram image')
spectrum, freqs, t, im = ax1.specgram(s, Fs=fs, xextent=((0, len(s)/fs)))
ax1.axis('tight')

torch_tensor = torch.from_numpy(spectrum)  # zero-copy: shares memory with spectrum, dtype float64

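# Plot the tensor on the same dB scale that specgram uses for display
ax2=plt.subplot(122)
ax2.set_title('Torch tensor image')
# origin='lower' puts 0 Hz at the bottom; recent matplotlib accepts only 'upper' or 'lower'
ax2.imshow((10*torch.log10(torch_tensor)).numpy(), origin='lower', extent=[0, len(s)/fs, freqs[0], freqs[-1]])
ax2.axis('tight')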

plt.show()
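The question also asks about skipping NumPy. One option is to compute the spectrogram directly in PyTorch with `torch.stft`. This is a different technique from the answer above. The sketch assumes matplotlib's defaults: NFFT=256, noverlap=128, a symmetric Hann window and no detrending. It also assumes our reading of matplotlib's one-sided PSD scaling. Under those assumptions the result has the same (129, 2) shape as `spectrum` and matching values, and the last line checks this against `mlab.specgram`. If you drop the two scaling lines, you get a plain power spectrogram, which is often enough as model input.

```python
import math
import numpy as np
import torch
from matplotlib import mlab

fs = 50
nfft, noverlap = 256, 128  # plt.specgram defaults

# Same test signal as in the question, built directly as a tensor
ts = torch.arange(500, dtype=torch.float64) / fs  # 10 s sampled at 50 Hz
s = (torch.sin(2 * math.pi * 2 * ts)
     + torch.sin(2 * math.pi * 3 * ts)
     + torch.sin(2 * math.pi * 6 * ts))

# Symmetric Hann window, intended to match matplotlib's default (np.hanning)
window = torch.hann_window(nfft, periodic=False, dtype=torch.float64)

# center=False: frames start at sample 0 and advance by nfft - noverlap samples
X = torch.stft(s, n_fft=nfft, hop_length=nfft - noverlap, window=window,
               center=False, return_complex=True)
psd = X.abs() ** 2  # shape (nfft // 2 + 1, n_frames)

# Assumed one-sided PSD scaling (matplotlib with scale_by_freq=True):
# double all bins except DC and Nyquist, then divide by Fs and window energy
psd[1:-1] *= 2
psd /= fs * (window ** 2).sum()

db = (10 * torch.log10(psd)).to(torch.float32)  # dB, as specgram displays
print(tuple(db.shape), db.dtype)

# Cross-check against matplotlib's own computation on the same samples
ref, _, _ = mlab.specgram(s.numpy(), Fs=fs)
print(np.allclose(psd.numpy(), ref, rtol=1e-6, atol=1e-10))
```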

The output is:

[Image: output figure with the 'Specgram image' and 'Torch tensor image' panels side by side]

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM