view flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3190:49fe2ed03c12

SQ: Refactored fitting to better fit the data types of SQResult. flys-artifacts/trunk@4805 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Tue, 26 Jun 2012 17:20:31 +0000
parents 1e46ced2bb57
children 1df6984628c3
line wrap: on
line source
package de.intevation.flys.artifacts.model.sq;

import de.intevation.flys.artifacts.math.fitting.Function;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.math.MathException;

import org.apache.commons.math.optimization.fitting.CurveFitter;

import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;

import org.apache.log4j.Logger;

/**
 * Fits a parametric {@link Function} to a set of (Q, S) points with a
 * Levenberg-Marquardt optimizer and then runs
 * {@code Outlier.detectOutliers} on the fitted points, refitting in the
 * detector's {@code reinitialize} callback.
 *
 * <p>After a successful {@link #fit(List)} the getters report the fitted
 * parameters, the chi-square of the last fit, the last standard deviation
 * reported by the outlier detection, the points returned by the outlier
 * detection and the reported outliers grouped by iteration.</p>
 */
public class Fitting
{
    private static final Logger log = Logger.getLogger(Fitting.class);

    /** Parametric function whose parameters are fitted. */
    protected Function function;

    /** Parameters produced by the most recent successful curve fit. */
    protected double [] parameters;

    /** Factor passed through to {@code Outlier.detectOutliers}. */
    protected double stdDevFactor;

    /** Last standard deviation reported by the outlier detection. */
    protected double standardDeviation;

    /** Chi-square of the most recent curve fit, read from the optimizer. */
    protected double chiSqr;

    /**
     * Points returned by {@code Outlier.detectOutliers}
     * (presumably the points that were not classified as outliers).
     */
    protected SQ [] remaining;

    /**
     * Outliers grouped by outlier-detection iteration; element {@code i}
     * holds the outliers reported while iteration {@code i + 1} was active.
     */
    protected List<SQ []> outliers;

    /** Creates an unconfigured instance; set function and factor before fitting. */
    public Fitting() {
    }

    /**
     * @param function     the parametric function to fit.
     * @param stdDevFactor factor handed to the outlier detection.
     */
    public Fitting(Function function, double stdDevFactor) {
        this.function     = function;
        this.stdDevFactor = stdDevFactor;
    }

    /** @return the parametric function to fit. */
    public Function getFunction() {
        return function;
    }

    /** @param function the parametric function to fit. */
    public void setFunction(Function function) {
        this.function = function;
    }

    /** @return the fitted parameters, or {@code null} if none are available. */
    public double [] getParameters() {
        return parameters;
    }

    /** @param parameters the fitted parameters. */
    public void setParameters(double [] parameters) {
        this.parameters = parameters;
    }

    /** @return the factor handed to the outlier detection. */
    public double getStdDevFactor() {
        return stdDevFactor;
    }

    /** @param stdDevFactor the factor handed to the outlier detection. */
    public void setStdDevFactor(double stdDevFactor) {
        this.stdDevFactor = stdDevFactor;
    }

    /** @return the last standard deviation reported by the outlier detection. */
    public double getStandardDeviation() {
        return standardDeviation;
    }

    /** @param standardDeviation the standard deviation to store. */
    public void setStandardDeviation(double standardDeviation) {
        this.standardDeviation = standardDeviation;
    }

    /** @return the chi-square of the most recent curve fit. */
    public double getChiSqr() {
        return chiSqr;
    }

    /** @param chiSqr the chi-square to store. */
    public void setChiSqr(double chiSqr) {
        this.chiSqr = chiSqr;
    }

    /** @return the points returned by the outlier detection. */
    public SQ [] getRemaining() {
        return remaining;
    }

    /** @param remaining the points to store as remaining. */
    public void setRemaining(SQ [] remaining) {
        this.remaining = remaining;
    }

    /** @return the outliers grouped by iteration, or {@code null}. */
    public List<SQ []> getOutliers() {
        return outliers;
    }

    /** @param outliers the outliers grouped by iteration. */
    public void setOutliers(List<SQ []> outliers) {
        this.outliers = outliers;
    }

    /**
     * Clears all results of a previous fit. The function and the
     * standard deviation factor are configuration and are kept.
     */
    public void reset() {
        outliers          = null;
        remaining         = null;
        parameters        = null;
        standardDeviation = 0d;
        chiSqr            = 0d;
    }

    /**
     * Returns a new list containing only the points of {@code sqs}
     * for which {@code SQ.isValid()} is true, in their original order.
     *
     * @param sqs the points to filter.
     * @return a new list with the valid points.
     */
    protected static final List<SQ> onlyValid(List<SQ> sqs) {

        List<SQ> good = new ArrayList<SQ>(sqs.size());

        for (SQ sq: sqs) {
            if (sq.isValid()) {
                good.add(sq);
            }
        }

        return good;
    }

    /**
     * Fits {@link #function} to the valid points of {@code sqs}, with Q as
     * the x value and S as the y value, and then runs the outlier detection.
     *
     * <p>Results of an earlier call are not cleared beforehand. If this
     * method returns {@code false}, the getters may still report values
     * from a previous run. Call {@link #reset()} first if that matters.</p>
     *
     * @param sqs the points to fit; invalid points are ignored.
     * @return {@code true} on success; {@code false} if fewer than two valid
     *         points are given or the curve fitting or outlier detection
     *         failed.
     */
    public boolean fit(List<SQ> sqs) {

        sqs = onlyValid(sqs);

        if (sqs.size() < 2) {
            log.warn("Too less points for fitting.");
            return false;
        }

        // The same optimizer is reused for every (re)fit; its chi-square
        // is read right after each fit.
        final LevenbergMarquardtOptimizer lmo =
            new LevenbergMarquardtOptimizer();

        try {
            parameters = fitCurve(lmo, sqs.iterator());
        }
        catch (MathException me) {
            log.warn("Curve fitting failed.", me);
            return false;
        }

        chiSqr = lmo.getChiSquare();

        // Single-element holder so the callback can swap in the function
        // instantiated from refitted parameters (captured locals are final).
        final de.intevation.flys.artifacts.math.Function [] instance = {
            function.instantiate(parameters)
        };

        try {
            remaining = Outlier.detectOutliers(
                new Outlier.Callback() {

                    /** Outliers per iteration; index = iteration - 1. */
                    List<List<SQ>> groups = new ArrayList<List<SQ>>();

                    int currentIteration;

                    @Override
                    public double eval(SQ sq) {
                        // Residual of the current fit, computed with the
                        // same accessors that supplied the fitted data.
                        double s = instance[0].value(sq.getQ());
                        return s - sq.getS();
                    }

                    @Override
                    public void iteration(int i) {
                        currentIteration = i;
                    }

                    @Override
                    public void outlier(SQ sq) {
                        // Iterations are assumed to be numbered from 1.
                        // Pad with empty groups so that an iteration without
                        // outliers cannot lead to an IndexOutOfBoundsException.
                        while (groups.size() < currentIteration) {
                            groups.add(new ArrayList<SQ>(2));
                        }
                        groups.get(currentIteration - 1).add(sq);
                    }

                    @Override
                    public void standardDeviation(double stdDev) {
                        setStandardDeviation(stdDev);
                    }

                    @Override
                    public void reinitialize(Iterator<SQ> good)
                    throws MathException
                    {
                        parameters  = fitCurve(lmo, good);
                        instance[0] = function.instantiate(parameters);
                        chiSqr      = lmo.getChiSquare();
                    }

                    @Override
                    public void finished() {
                        List<SQ []> result =
                            new ArrayList<SQ []>(groups.size());

                        for (List<SQ> group: groups) {
                            result.add(group.toArray(new SQ[group.size()]));
                        }

                        setOutliers(result);
                    }
                },
                sqs,
                stdDevFactor);
        }
        catch (MathException me) {
            log.warn("Outlier detection failed.", me);
            return false;
        }

        return true;
    }

    /**
     * Fits {@link #function} to the given points, starting from the
     * function's initial guess.
     *
     * @param lmo    optimizer used for the fit; afterwards its
     *               {@code getChiSquare()} describes this fit.
     * @param points the points to fit (Q as x, S as y).
     * @return the fitted parameters.
     * @throws MathException if the curve fitting fails.
     */
    private double [] fitCurve(
        LevenbergMarquardtOptimizer lmo,
        Iterator<SQ>                points
    )
    throws MathException
    {
        CurveFitter cf = new CurveFitter(lmo);
        while (points.hasNext()) {
            SQ sq = points.next();
            cf.addObservedPoint(sq.getQ(), sq.getS());
        }
        return cf.fit(function, function.getInitialGuess());
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org