
Weighted Linear Regression in Java

Does anyone know of a scientific/mathematical library in Java that has a straightforward implementation of weighted linear regression? Something along the lines of a function that takes 3 arguments and returns the corresponding coefficients:

linearRegression(x,y,weights)

This seems fairly straightforward, so I imagine it exists somewhere.

PS) I've tried Flanagan's library: http://www.ee.ucl.ac.uk/~mflanaga/java/Regression.html. It has the right idea, but it seems to crash sporadically and complains about my degrees of freedom.


Not a library, but the code is posted: http://www.codeproject.com/KB/recipes/LinReg.aspx (and includes the mathematical explanation for the code, which is a huge plus). Also, it seems that there is another implementation of the same algorithm here: http://sin-memories.blogspot.com/2009/04/weighted-linear-regression-in-java-and.html

Finally, there is a lib from a University in New Zealand that seems to have it implemented: http://www.cs.waikato.ac.nz/~ml/weka/ (pretty decent javadocs). The specific method is described here: http://weka.sourceforge.net/doc/weka/classifiers/functions/LinearRegression.html


I was also searching for this, but I couldn't find anything. The reason might be that you can simplify the problem to the standard regression as follows:

A weighted linear regression without a residual term can be written as diag(sqrt(weights))·y = diag(sqrt(weights))·X·b, where multiplying by diag(sqrt(weights)) means scaling each row of the matrix (or vector) by the square root of that row's weight. Therefore, translating between weighted and unweighted regression without residual is trivial.

To translate a regression with residual, y = Xb + u, into a regression without residual, y = Xb, you append an extra column to X consisting entirely of ones.

Now that you know how to simplify the problem, you can use any library to solve the standard linear regression.

Here's an example, using Apache Commons Math:

// Commons Math 3: import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

void linearRegression(double[] xUnweighted, double[] yUnweighted, double[] weights) {
    double[] y = new double[yUnweighted.length];
    double[][] x = new double[xUnweighted.length][2];

    // Scale each observation by sqrt(weight); the second column of ones
    // (also scaled) plays the role of the intercept.
    for (int i = 0; i < y.length; i++) {
        double w = Math.sqrt(weights[i]);
        y[i] = w * yUnweighted[i];
        x[i][0] = w * xUnweighted[i];
        x[i][1] = w;
    }

    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
    regression.setNoIntercept(true); // the intercept column is already in x
    regression.newSampleData(y, x);

    double[] regressionParameters = regression.estimateRegressionParameters();
    double slope = regressionParameters[0];
    double intercept = regressionParameters[1];

    System.out.println("y = " + slope + "*x + " + intercept);
}

This can be explained intuitively: in a linear regression with u = 0, if you take any point (x, y) and scale it to (Cx, Cy), the residual at the new point is also multiplied by C. In other words, scaling a row up-weights that observation. Since we are minimizing the squared error, we scale by the square roots of the weights, so that each squared residual ends up multiplied by the weight itself.
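If you'd rather avoid a library dependency entirely, the same one-predictor weighted fit can be computed in closed form from the weighted normal equations. This is a minimal self-contained sketch (the class and method names are my own, not from any of the libraries mentioned above):

```java
// Closed-form weighted least squares for y = slope*x + intercept.
// Derived from the weighted normal equations; no external dependencies.
public class WeightedFit {

    // Returns {slope, intercept} minimizing sum_i w_i * (y_i - slope*x_i - intercept)^2.
    public static double[] fit(double[] x, double[] y, double[] w) {
        double sw = 0, swx = 0, swy = 0, swxx = 0, swxy = 0;
        for (int i = 0; i < x.length; i++) {
            sw   += w[i];
            swx  += w[i] * x[i];
            swy  += w[i] * y[i];
            swxx += w[i] * x[i] * x[i];
            swxy += w[i] * x[i] * y[i];
        }
        double slope = (sw * swxy - swx * swy) / (sw * swxx - swx * swx);
        double intercept = (swy - slope * swx) / sw;
        return new double[] { slope, intercept };
    }

    public static void main(String[] args) {
        // Points lying exactly on y = 2x + 1; any positive weights recover it.
        double[] x = { 0, 1, 2, 3 };
        double[] y = { 1, 3, 5, 7 };
        double[] w = { 1, 4, 1, 9 };
        double[] b = fit(x, y, w);
        System.out.println("slope=" + b[0] + " intercept=" + b[1]);
    }
}
```

Because the sample points lie exactly on a line, the weights don't change the answer here; with noisy data, increasing a point's weight pulls the fitted line toward it.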


I personally used the org.apache.commons.math.stat.regression.SimpleRegression class from the Apache Commons Math library (note that it fits a plain, unweighted simple regression).

I also found a more lightweight class from Princeton University, but I didn't test it:

http://introcs.cs.princeton.edu/java/97data/LinearRegression.java.html
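For reference, the core of an unweighted simple regression like the Princeton class linked above boils down to a few sums around the sample means. A stripped-down sketch (my own naming, not that class's exact code):

```java
// Unweighted simple linear regression computed via sample means,
// in the spirit of the Princeton LinearRegression class (sketch only).
public class SimpleFit {

    // Returns {slope, intercept} of the least-squares line through (x_i, y_i).
    public static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double xbar = 0, ybar = 0;
        for (int i = 0; i < n; i++) {
            xbar += x[i];
            ybar += y[i];
        }
        xbar /= n;
        ybar /= n;

        // Centered sums: sxx = sum (x - xbar)^2, sxy = sum (x - xbar)(y - ybar).
        double sxx = 0, sxy = 0;
        for (int i = 0; i < n; i++) {
            sxx += (x[i] - xbar) * (x[i] - xbar);
            sxy += (x[i] - xbar) * (y[i] - ybar);
        }

        double slope = sxy / sxx;
        double intercept = ybar - slope * xbar;
        return new double[] { slope, intercept };
    }
}
```

Combined with the sqrt-weight row scaling from the earlier answer, even a plain routine like this can be adapted to the weighted case.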


Here's a direct Java port of the C# code for weighted linear regression from the first link in Aleadam's answer:

https://github.com/lukehutch/WeightedLinearRegression.java
