
Astropy: split a FITS table into a training and testing set

I have a FITS table that I am manipulating with astropy. I would like to split it at random into training and testing data, creating two new FITS tables.

I first thought of using the scikit-learn function train_test_split, but then I would have to convert my data back and forth to a numpy.array.

So far, I have read the data from a FITS file into an astropy.table.Table and tried the following:

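import random

training_fraction = 0.5
n = len(data)  # data is the astropy.table.Table read from the FITS file
indexes = random.sample(range(n), k=int(n*training_fraction))
testing_sample = data[indexes]
training_sample = ?  # the rows whose positions are NOT in indexes?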

But then I don't know how to get all the rows whose indexes are not in indexes. Is there perhaps a better way to do this? How can I get a random partition of my Table?
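One way to get the complementary rows is to turn indexes into a boolean mask and invert it. This is a minimal sketch, assuming NumPy is available; the helper name complement_mask is hypothetical, and a plain np.arange stands in for the table's rows. A boolean mask of length len(data) should index an astropy Table the same way it indexes the array here.

```python
import random

import numpy as np


def complement_mask(n, indexes):
    """Hypothetical helper: boolean mask of length n, True at the given positions.

    rows[mask] selects the sampled rows; rows[~mask] selects all the others.
    """
    mask = np.zeros(n, dtype=bool)
    mask[list(indexes)] = True
    return mask


n = 10
training_fraction = 0.5
indexes = random.sample(range(n), k=int(n * training_fraction))
mask = complement_mask(n, indexes)

rows = np.arange(n)      # stand-in for the table's rows
chosen = rows[mask]      # rows whose positions are in indexes
rest = rows[~mask]       # every remaining row
```

With the real table this would read data[mask] and data[~mask], and each result is a new Table that can be written out with Table.write.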


The samples in my table each happen to have a unique integer ID between 1 and len(data), so I figured I could do:

indexes = random.sample(range(1, n+1), k=int(n*training_fraction))
testing_sample = data[data['ID'] in indexes]
training_sample = data[data['ID'] not in indexes]

but the testing_sample line raises ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
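The error comes from Python's in operator: data['ID'] in indexes compares the whole column with each element of the list, which produces a boolean array, and Python then has to reduce that array to a single True/False. For element-wise membership NumPy provides np.isin, which returns one boolean per element. A minimal sketch, with a plain array of IDs standing in for data['ID']:

```python
import random

import numpy as np

n = 8
ids = np.arange(1, n + 1)            # stand-in for data['ID'] (IDs 1..n)
training_fraction = 0.5
indexes = random.sample(range(1, n + 1), k=int(n * training_fraction))

in_sample = np.isin(ids, indexes)    # one boolean per row, no ambiguity
testing_ids = ids[in_sample]         # with the table: data[in_sample]
training_ids = ids[~in_sample]       # with the table: data[~in_sample]
```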

What I eventually got working was:

import random

training_indexes = sorted(random.sample(range(n), k=int(n*training_fraction)))
# every position that was not sampled for training
testing_indexes = [i for i in range(n) if i not in training_indexes]

testing_sample = data[testing_indexes]
training_sample = data[training_indexes]

But I don't know if this is the most efficient way, or the most pythonic way.
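On efficiency: i not in training_indexes scans the whole list on every iteration, so building testing_indexes costs on the order of n times k comparisons. Putting the sample in a set first makes each lookup roughly constant time. A stdlib-only sketch of the same approach; the function name split_indexes and its seed parameter are hypothetical:

```python
import random


def split_indexes(n, training_fraction, seed=None):
    """Randomly split range(n) into sorted training and testing index lists."""
    rng = random.Random(seed)  # seeded generator for reproducible splits
    training = set(rng.sample(range(n), k=int(n * training_fraction)))
    # Membership tests on a set are O(1) on average, unlike on a list.
    testing = [i for i in range(n) if i not in training]
    return sorted(training), testing


training_indexes, testing_indexes = split_indexes(10, 0.5, seed=0)
```

The two lists can then index the table exactly as before: data[training_indexes] and data[testing_indexes].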

You mentioned using the existing train_test_split routine from scikit-learn. If that is the only thing you would use scikit-learn for, it would be overkill, but if you're already using it for other parts of your task you might as well. Astropy Tables are backed by Numpy arrays, so you don't need to "convert your data back and forth".

Since the 'ID' column identifies the rows of your table, it is useful to formally set it as the table's index, so that ID values can be used to look up rows independently of their positional index. For example:

>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
...     'ID': [1, 3, 5, 6, 7, 9],
...     'a': np.random.random(6),
...     'b': np.random.random(6)
... })
>>> t
<Table length=6>
  ID           a                   b         
int64       float64             float64      
----- ------------------- -------------------
    1  0.7285295918917892  0.6180944983953155
    3  0.9273855839237182 0.28085439237508925
    5  0.8677312765220222  0.5996267567496841
    6 0.06182255608446752  0.6604620336092745
    7 0.21450048405835265  0.5351066893214822
    9   0.928930682667869  0.8178640424254757

Then set 'ID' as the table's index:

>>> t.add_index('ID')

Use train_test_split to partition the IDs however you want:

>>> from sklearn.model_selection import train_test_split
>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
  ID           a                  b         
int64       float64            float64      
----- ------------------- ------------------
    7 0.21450048405835265 0.5351066893214822
    9   0.928930682667869 0.8178640424254757
    5  0.8677312765220222 0.5996267567496841
    1  0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
  ID           a                   b         
int64       float64             float64      
----- ------------------- -------------------
    6 0.06182255608446752  0.6604620336092745
    3  0.9273855839237182 0.28085439237508925

Note that a table column is itself a Numpy array, since Column subclasses numpy.ndarray:

>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
 astropy.table.column.BaseColumn,
 astropy.table._column_mixins._ColumnGetitemShim,
 numpy.ndarray,
 object)

For what it's worth, and because it may help you find answers to problems like this more easily in the future, it helps to think about what you're trying to do more abstractly (you may already be doing so, but the phrasing of your question suggests otherwise). The columns of your table are just Numpy arrays; once the data are in that form, it is irrelevant that they were read from a FITS file, and the task no longer has anything specifically to do with Astropy. The question simply becomes how to randomly partition a Numpy array.
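As a concrete instance of that generic view, here is a NumPy-only sketch that shuffles the row positions once and slices the permutation into two disjoint parts. The helper name random_partition is hypothetical, and test_size is taken as a fraction; the test count is rounded up, which gives the same 4/2 split of 6 rows as in the example above.

```python
import numpy as np


def random_partition(n, test_size=0.2, seed=None):
    """Hypothetical helper: return (train_idx, test_idx), disjoint and covering range(n)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)              # each position exactly once, shuffled
    n_test = int(np.ceil(n * test_size))   # round the test count up
    return perm[n_test:], perm[:n_test]


train_idx, test_idx = random_partition(6, test_size=0.2, seed=42)
```

For an astropy Table t, t[train_idx] and t[test_idx] would then be the two subsets, with no conversion step needed.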

You can find generic answers to this problem in, for example, this question. But it's also nice to use an existing, special-purpose utility like train_test_split if you have it.
