
“Matrices are not aligned” error in a toy example of OLS using gneiss

I am attempting to build a trivial example of a linear regression for compositional data. I'm using the following code:

from pandas import DataFrame
import numpy as np
from skbio import TreeNode
from gneiss.regression import ols
from IPython.display import display

# Define the table of compositions (one row per sample, one column per part)
yTrain = DataFrame({'y1': [0.8, 0.3, 0.5], 'y2': [0.2, 0.7, 0.5]})

# Define the predictor for each sample
xTrain = DataFrame({'x1': [1,3,2]})

# Once these variables are defined, a regression can be performed. These proportions
# will be converted to balances according to the tree specified (note that no tree is
# passed in the call below), and the formula 'x1' regresses the proportions on x1
# in a single model.
model = ols('x1', yTrain, xTrain)
model.fit()
xTest = DataFrame({'x1': [1,3]})
yTest = model.predict(xTest)
display(yTest)
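The comment in the code above refers to converting proportions to balances. For a two-part composition such as y1/y2 there is only one balance, and the standard ilr balance formula with one part on each side reduces to sqrt(1/2) * ln(y1 / y2). Below is a minimal plain-Python sketch of that transform and its inverse, assuming y1 is the numerator. The helpers balance and inverse_balance are hypothetical (not part of gneiss), and the sign or orientation gneiss would use depends on the tree, which is not reproduced here.

```python
import math


def balance(num, den):
    """Single ilr balance with one part in the numerator and one in the denominator.

    With r = s = 1 parts, sqrt(r * s / (r + s)) * ln(num / den) becomes
    sqrt(1/2) * ln(num / den). Assumption: y1 is the numerator; a tree
    could orient this the other way, which only flips the sign.
    """
    return math.sqrt(0.5) * math.log(num / den)


def inverse_balance(b):
    """Map a two-part balance back to proportions (y1, y2) that sum to 1."""
    ratio = math.exp(b / math.sqrt(0.5))  # recovers y1 / y2
    y1 = ratio / (1.0 + ratio)
    return y1, 1.0 - y1


# Proportions from the question's yTrain table
y1 = [0.8, 0.3, 0.5]
y2 = [0.2, 0.7, 0.5]
balances = [balance(a, b) for a, b in zip(y1, y2)]
print([round(b, 4) for b in balances])  # [0.9803, -0.5991, 0.0]
```

A balance of 0 corresponds to equal proportions (the third sample), and swapping numerator and denominator only changes the sign.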

I'm getting the error "matrices are not aligned". Any idea how to get this running?

It looks like you have mixed up your x and y matrices between the training and test stages. Your xTest should probably have the same structure as yTrain. In your code, xTest looks like xTrain, which seems to correspond to labels.

The general convention in ML is to use x for inputs and y for outputs. In your case, you have used y for inputs and x for labels during training, and the other way around during testing.
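As a rough illustration of how inputs and outputs enter an ordinary least-squares fit, here is a NumPy sketch on the toy numbers from the question, with x1 as the single input and the two proportion columns as outputs (assuming, as in patsy-style formulas, that the formula 'x1' names the predictor). It is plain OLS on raw proportions, not gneiss's balance-based model, and design is a hypothetical helper standing in for the design matrix a formula produces. The point is that the coefficient matrix has one row per parameter (intercept and slope), so new inputs must be expanded to the same columns before predicting.

```python
import numpy as np

# Input (predictor) and outputs (responses), following the x-in / y-out convention
x_train = np.array([1.0, 3.0, 2.0])
y_train = np.array([[0.8, 0.2],
                    [0.3, 0.7],
                    [0.5, 0.5]])  # one column per response


def design(x):
    """Hypothetical helper: design matrix for 'x1', i.e. an intercept column plus x1."""
    return np.column_stack([np.ones_like(x), x])


# Least-squares coefficients: one row per parameter, one column per response
B, *_ = np.linalg.lstsq(design(x_train), y_train, rcond=None)

# Prediction needs new inputs expanded to the same parameter columns as training
x_test = np.array([1.0, 3.0])
y_pred = design(x_test) @ B
print(np.round(B, 4))
print(np.round(y_pred, 4))
```

Because every training row sums to 1 and the model has an intercept, the predicted rows also sum to 1, but nothing keeps individual predictions inside [0, 1] for inputs far from the training range.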

For instance, try setting xTest to the following:

xTest = DataFrame({'y1': [0.1, 0.4, 0.6], 'y2': [0.4, 0.2, 0.8]})

That should get rid of the error. Ideally, you would do something along these lines:

from pandas import DataFrame
import numpy as np
from skbio import TreeNode
from gneiss.regression import ols
from IPython.display import display

# Define the table of compositions
xTrain = DataFrame({'x1': [0.8, 0.3, 0.5], 'x2': [0.2, 0.7, 0.5]})

# Define the predictor for the compositions
yTrain = DataFrame({'y1': [1,3,2]})

model = ols('y1', xTrain, yTrain)
model.fit()
xTest = DataFrame({'x1': [0.1, 0.4, 0.6], 'x2': [0.4, 0.2, 0.8]})
yTest = model.predict(xTest)
display(yTest)
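One possible source of the message itself, sketched under an assumption about gneiss internals that is not confirmed here: if predict multiplies the covariate frame you pass by the fitted coefficient frame with pandas' DataFrame.dot, then the labels must match, because DataFrame.dot raises ValueError("matrices are not aligned") whenever the left frame's column labels differ from the right frame's index labels. With a formula such as 'x1', the coefficients would typically also include an 'Intercept' row. The coefficient values below are made up; only their labels matter.

```python
import pandas as pd

# Made-up coefficients for a formula like 'x1': one row per model parameter
# (the formula adds an intercept) and one column per response.
betas = pd.DataFrame(
    {"y1": [1.0, -0.25], "y2": [0.0, 0.25]},
    index=["Intercept", "x1"],
)

# New covariates with only the 'x1' column, as in the question's xTest
x_test = pd.DataFrame({"x1": [1, 3]})

try:
    x_test.dot(betas)
except ValueError as err:
    print(err)  # matrices are not aligned

# Adding an intercept column gives x_test the same labels as betas.index;
# pandas aligns by label, so the column order does not matter.
x_test_design = x_test.assign(Intercept=1.0)
prediction = x_test_design.dot(betas)
print(prediction)
```

If this is what happens in your gneiss version, the same check applies to whichever frame is passed to predict: it needs exactly the parameter labels of the fitted model.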
