view artifacts/src/main/java/org/dive4elements/river/artifacts/model/sq/Fitting.java @ 7300:83bb52fa0c32

(issue1529) Be more tolerant in the fitting. The invalid value warning is removed because invalid data is expected there when datapoints are not valid for this KM
author Andre Heinecke <aheinecke@intevation.de>
date Fri, 11 Oct 2013 18:40:33 +0200
parents 51eb6491c537
children 0a5239a1e46e
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.sq;

import org.dive4elements.artifacts.common.utils.StringUtils;
import org.dive4elements.river.artifacts.math.fitting.Function;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math.MathException;

import org.apache.commons.math.optimization.fitting.CurveFitter;

import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import org.apache.commons.math.stat.regression.SimpleRegression;

import org.apache.log4j.Logger;

/**
 * Fits a parameterized {@link Function} to a list of {@link SQ} values
 * and acts as the {@link Outlier.Callback} of the outlier detection that
 * is started by {@link #fit(List, String, Callback)}.
 *
 * The x/y values of each {@link SQ} are obtained through the configured
 * {@link SQ.View}: {@code getQ} delivers the x value, {@code getS} the
 * y value.
 */
public class Fitting
implements   Outlier.Callback
{
    // XXX: Hack. If the system property "minfo.sq.fitting.nonlinear" is
    // true the non-linear fitting is always used. Otherwise functions
    // with exactly two parameters are fitted by linear regression.
    private static final boolean USE_NON_LINEAR_FITTING =
        Boolean.getBoolean("minfo.sq.fitting.nonlinear");

    private static final Logger log = Logger.getLogger(Fitting.class);

    /** Receives the state of the fitting after each iteration. */
    public interface Callback {

        /**
         * @param parameters        The coefficients of the current fit.
         * @param measurements      The values remaining in the fit.
         * @param outliers          The values removed in this iteration
         *                          (empty if none was removed).
         * @param standardDeviation The standard deviation as reported by
         *                          the outlier detection.
         * @param chiSqr            The chi square of the last non-linear
         *                          fit.
         */
        void afterIteration(
            double [] parameters,
            SQ []     measurements,
            SQ []     outliers,
            double    standardDeviation,
            double    chiSqr);
    } // interface

    /** The function (template) to be fitted. */
    protected Function function;

    /** The coefficients found by the last fit. */
    protected double [] coeffs;

    /** The function instantiated with {@link #coeffs}. */
    protected org.dive4elements.river.artifacts.math.Function instance;

    /** Factor handed over to the outlier detection. */
    protected double stdDevFactor;

    /** Chi square of the last non-linear fit. */
    protected double chiSqr;

    /** Callback informed after each iteration; may be null. */
    protected Callback callback;

    /** Extracts the S and Q values out of the SQ objects. */
    protected SQ.View sqView;

    public Fitting() {
    }

    /**
     * @param function     The function to be fitted.
     * @param stdDevFactor Factor handed over to the outlier detection.
     * @param sqView       View used to extract S and Q from the values.
     */
    public Fitting(Function function, double stdDevFactor, SQ.View sqView) {
        this.function     = function;
        this.stdDevFactor = stdDevFactor;
        this.sqView       = sqView;
    }

    public Function getFunction() {
        return function;
    }

    public void setFunction(Function function) {
        this.function = function;
    }

    public double getStdDevFactor() {
        return stdDevFactor;
    }

    public void setStdDevFactor(double stdDevFactor) {
        this.stdDevFactor = stdDevFactor;
    }

    /**
     * Fits the function against the given values. Uses the non-linear
     * fitting if it is forced by the system property or if the function
     * does not have exactly two parameters, the linear regression
     * otherwise.
     *
     * @param sqs The values to fit against.
     * @throws MathException If the non-linear fitting fails.
     */
    @Override
    public void initialize(List<SQ> sqs) throws MathException {

        if (USE_NON_LINEAR_FITTING
        || function.getParameterNames().length != 2) {
            nonLinearFitting(sqs);
        }
        else {
            linearFitting(sqs);
        }
    }

    /**
     * Fits via linear regression and instantiates the function with
     * the result. Note: chiSqr is not updated by this path.
     */
    protected void linearFitting(List<SQ> sqs) {
        coeffs   = linearRegression(sqs);
        instance = function.instantiate(coeffs);
    }

    /**
     * Calculates a simple linear regression of S over Q.
     *
     * @param sqs The values to fit against.
     * @return The coefficients in the order of the parameter names of
     *         the function: the slope at the index of parameter "m",
     *         the intercept at the index of parameter "b". All
     *         coefficients are zero if there are less than two values
     *         or if the function has no parameters "m" and "b".
     */
    protected double [] linearRegression(List<SQ> sqs) {

        String [] pns = function.getParameterNames();
        double [] result = new double[pns.length];

        if (sqs.size() < 2) {
            log.debug("not enough points");
            return result;
        }

        SimpleRegression reg = new SimpleRegression();

        for (SQ sq: sqs) {
            double s = sqView.getS(sq);
            double q = sqView.getQ(sq);
            // Q is the x value, S is the y value.
            reg.addData(q, s);
        }

        // m is the slope and b the intercept, matching the parameter
        // names they are stored under below.
        double m = reg.getSlope();
        double b = reg.getIntercept();

        if (log.isDebugEnabled()) {
            log.debug("m: " + m);
            log.debug("b: " + b);
        }

        int mIdx = StringUtils.indexOf("m", pns);
        int bIdx = StringUtils.indexOf("b", pns);

        if (mIdx == -1 || bIdx == -1) {
            log.error("index not found: " + mIdx + " " + bIdx);
            return result;
        }

        result[mIdx] = m;
        result[bIdx] = b;

        return result;
    }


    /**
     * Fits via a Levenberg-Marquardt optimization starting at the
     * initial guess of the function. Stores the resulting coefficients,
     * the instantiated function and the chi square of the optimizer.
     *
     * @param sqs The values to fit against.
     * @throws MathException If the curve fitting fails.
     */
    protected void nonLinearFitting(List<SQ> sqs) throws MathException {

        LevenbergMarquardtOptimizer optimizer =
            new LevenbergMarquardtOptimizer();

        CurveFitter cf = new CurveFitter(optimizer);

        for (SQ sq: sqs) {
            cf.addObservedPoint(sqView.getQ(sq), sqView.getS(sq));
        }

        coeffs = cf.fit(
            function, function.getInitialGuess());

        instance = function.instantiate(coeffs);

        chiSqr = optimizer.getChiSquare();
    }

    /**
     * @return The residual of the given value: its S minus the value of
     *         the fitted function at its Q.
     */
    @Override
    public double eval(SQ sq) {
        double s = instance.value(sqView.getQ(sq));
        return sqView.getS(sq) - s;
    }

    /**
     * Forwards the state of the finished iteration to the callback
     * given to {@link #fit(List, String, Callback)}, if there is one.
     *
     * @param standardDeviation The standard deviation of the iteration.
     * @param outlier           The outlier of the iteration; may be null.
     * @param remainings        The values remaining in the fit.
     */
    @Override
    public void iterationFinished(
        double   standardDeviation,
        SQ       outlier,
        List<SQ> remainings
    ) {
        if (log.isDebugEnabled()) {
            log.debug("iterationFinished ----");
            log.debug(" num remainings: " + remainings.size());
            // Parentheses are needed: '+' binds tighter than '!='.
            log.debug(" has outlier: " + (outlier != null));
            log.debug(" standardDeviation: " + standardDeviation);
            log.debug(" Chi^2: " + chiSqr);
            log.debug("---- iterationFinished");
        }

        if (callback == null) {
            // Nobody is interested in the intermediate results.
            return;
        }

        callback.afterIteration(
            coeffs,
            remainings.toArray(new SQ[remainings.size()]),
            outlier != null ? new SQ [] { outlier} : new SQ [] {},
            standardDeviation,
            chiSqr);
    }

    /**
     * Runs the outlier detection with this fitting on a copy of the
     * given values.
     *
     * @param sqs      The values to fit against. The list itself is not
     *                 modified.
     * @param method   The method handed over to the outlier detection.
     * @param callback Informed after each iteration.
     * @return false if there are less than two values or if the outlier
     *         detection failed with a MathException, true otherwise.
     */
    public boolean fit(List<SQ> sqs, String  method, Callback callback) {

        if (sqs.size() < 2) {
            log.warn("Too less points for fitting.");
            return false;
        }

        // Work on a copy to leave the list of the caller untouched.
        sqs = new ArrayList<SQ>(sqs);

        this.callback = callback;

        try {
            Outlier.detectOutliers(this, sqs, stdDevFactor, method);
        }
        catch (MathException me) {
            log.warn(me);
            return false;
        }

        return true;
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org