view artifacts/src/main/java/org/dive4elements/river/artifacts/model/sq/Fitting.java @ 6777:48f6780c372d

S/Q relation: More Excel compat.
author Sascha L. Teichmann <teichmann@intevation.de>
date Wed, 07 Aug 2013 19:36:05 +0200
parents 9479cb7c8cd5
children b8f94e865875
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.sq;

import org.dive4elements.river.artifacts.math.fitting.Function;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math.MathException;

import org.apache.commons.math.optimization.fitting.CurveFitter;

import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import org.apache.commons.math.stat.regression.SimpleRegression;

import org.apache.log4j.Logger;

/**
 * Fits a parameterised S/Q {@link Function} to a list of {@link SQ}
 * measurements and acts as the model for the {@link Outlier} detection.
 *
 * <p>The fitted relation predicts S from Q: {@link #eval(SQ)} evaluates the
 * fitted function at a measurement's Q and returns the residual in S.
 * Two-parameter functions are fitted by a linear least-squares regression of
 * log(S) on log(Q) unless the system property
 * {@code minfo.sq.fitting.nonlinear} is set to {@code true}; functions with
 * any other number of parameters are always fitted with a
 * Levenberg-Marquardt optimizer.</p>
 */
public class Fitting
implements   Outlier.Callback
{
    // XXX: Hack to force linear fitting! Two-parameter functions are only
    // fitted non-linearly if the system property
    // "minfo.sq.fitting.nonlinear" is set to "true".
    private static final boolean USE_NON_LINEAR_FITTING =
        Boolean.getBoolean("minfo.sq.fitting.nonlinear");

    private static final Logger log = Logger.getLogger(Fitting.class);

    /**
     * Receives the state of the fit after each iteration of the outlier
     * detection.
     */
    public interface Callback {

        /**
         * Called once per finished iteration.
         *
         * @param parameters        the coefficients of the current fit.
         * @param measurements      the measurements remaining after the
         *                          iteration.
         * @param outliers          a one-element array holding the outlier
         *                          of this iteration, or an empty array if
         *                          there was none.
         * @param standardDeviation the standard deviation reported by the
         *                          outlier detection.
         * @param chiSqr            the chi-square of the last non-linear fit
         *                          (the linear fitting path does not update
         *                          it).
         */
        void afterIteration(
            double [] parameters,
            SQ []     measurements,
            SQ []     outliers,
            double    standardDeviation,
            double    chiSqr);
    } // interface Callback

    /** The parameterised function to fit. */
    protected Function function;

    /** The coefficients of the most recent fit. */
    protected double [] coeffs;

    /** {@link #function} instantiated with {@link #coeffs}. */
    protected org.dive4elements.river.artifacts.math.Function instance;

    /** Passed through to the outlier detection. */
    protected double stdDevFactor;

    /** Chi-square of the last non-linear fit. */
    protected double chiSqr;

    /** Notified after every outlier detection iteration; may be null. */
    protected Callback callback;

    public Fitting() {
    }

    /**
     * @param function     the function to fit.
     * @param stdDevFactor the factor handed to the outlier detection.
     */
    public Fitting(Function function, double stdDevFactor) {
        this();
        this.function     = function;
        this.stdDevFactor = stdDevFactor;
    }

    public Function getFunction() {
        return function;
    }

    public void setFunction(Function function) {
        this.function = function;
    }

    public double getStdDevFactor() {
        return stdDevFactor;
    }

    public void setStdDevFactor(double stdDevFactor) {
        this.stdDevFactor = stdDevFactor;
    }

    /**
     * Fits {@link #function} to the given measurements and stores the
     * resulting coefficients and function instance.
     *
     * @param sqs the measurements to fit.
     * @throws MathException if the non-linear fitting fails.
     */
    @Override
    public void initialize(List<SQ> sqs) throws MathException {

        // The log-log regression yields exactly two coefficients, so it can
        // only serve two-parameter functions.
        if (USE_NON_LINEAR_FITTING
        || function.getInitialGuess().length != 2) {
            nonLinearFitting(sqs);
        }
        else {
            linearFitting(sqs);
        }
    }

    /**
     * Fits via {@link #linearRegression(List)}. Does not touch
     * {@link #chiSqr}.
     */
    protected void linearFitting(List<SQ> sqs) {

        coeffs = linearRegression(sqs);

        instance = function.instantiate(coeffs);
    }

    /**
     * Fits log(s) = log(a) + b * log(q) by ordinary least squares.
     * Measurements with a non-positive S or Q are skipped because their
     * logarithm is undefined.
     *
     * @param sqs the measurements.
     * @return {@code {a, b}}, or {@code {0, 0}} if fewer than two usable
     *         measurements remain.
     */
    protected double [] linearRegression(List<SQ> sqs) {

        SimpleRegression reg = new SimpleRegression();

        int invalidPoints = 0;
        for (SQ sq: sqs) {
            double s = sq.getS();
            double q = sq.getQ();
            if (s <= 0d || q <= 0d) {
                ++invalidPoints;
                continue;
            }
            reg.addData(Math.log(q), Math.log(s));
        }

        if (sqs.size() - invalidPoints < 2) {
            log.debug("not enough points");
            return new double [] { 0, 0 };
        }

        // The intercept lives in log space; transform it back.
        double a = Math.exp(reg.getIntercept());
        double b = reg.getSlope();

        if (log.isDebugEnabled()) {
            log.debug("invalid points: " +
                invalidPoints + " (" + sqs.size() + ")");
            log.debug("a: " + a + " (" + Math.log(a) + ")");
            log.debug("b: " + b);
        }

        return new double [] { a, b };
    }

    /**
     * Fits {@link #function} with a Levenberg-Marquardt optimizer, starting
     * from the function's initial guess, and records the chi-square.
     *
     * @param sqs the measurements.
     * @throws MathException if the optimizer fails.
     */
    protected void nonLinearFitting(List<SQ> sqs) throws MathException {

        LevenbergMarquardtOptimizer optimizer =
            new LevenbergMarquardtOptimizer();

        CurveFitter cf = new CurveFitter(optimizer);

        // Q is the independent variable (x) and S the observed value (y),
        // consistent with linearRegression() and eval(), which predict S
        // from Q.
        for (SQ sq: sqs) {
            cf.addObservedPoint(sq.getQ(), sq.getS());
        }

        coeffs = cf.fit(
            function, function.getInitialGuess());

        instance = function.instantiate(coeffs);

        chiSqr = optimizer.getChiSquare();
    }

    /**
     * Returns the residual of a measurement: the observed S minus the S
     * predicted by the fitted function at the measurement's Q.
     */
    @Override
    public double eval(SQ sq) {
        double s = instance.value(sq.q);
        return sq.s - s;
    }

    /**
     * Logs the iteration state and forwards it to the {@link Callback}, if
     * one is set.
     */
    @Override
    public void iterationFinished(
        double   standardDeviation,
        SQ       outlier,
        List<SQ> remainings
    ) {
        // Computed separately: inlining `"..." + outlier != null` would
        // compare the concatenated string with null and always log "true".
        boolean hasOutlier = outlier != null;

        if (log.isDebugEnabled()) {
            log.debug("iterationFinished ----");
            log.debug(" num remainings: " + remainings.size());
            log.debug(" has outlier: " + hasOutlier);
            log.debug(" standardDeviation: " + standardDeviation);
            log.debug(" Chi^2: " + chiSqr);
            log.debug("---- iterationFinished");
        }

        if (callback == null) {
            return;
        }

        callback.afterIteration(
            coeffs,
            remainings.toArray(new SQ[remainings.size()]),
            hasOutlier ? new SQ [] { outlier } : new SQ [] {},
            standardDeviation,
            chiSqr);
    }

    /**
     * Returns a new list with only the measurements whose
     * {@code isValid()} is true.
     */
    protected static final List<SQ> onlyValid(List<SQ> sqs) {

        List<SQ> good = new ArrayList<SQ>(sqs.size());

        for (SQ sq: sqs) {
            if (sq.isValid()) {
                good.add(sq);
            }
        }

        return good;
    }

    /**
     * Runs the outlier detection with this fitting as model.
     *
     * @param sqs      the measurements; invalid ones are ignored.
     * @param method   the outlier detection method, passed through to
     *                 {@link Outlier}.
     * @param callback notified after every iteration; may be null.
     * @return false if fewer than two valid measurements are given or the
     *         fitting failed, true otherwise.
     */
    public boolean fit(List<SQ> sqs, String method, Callback callback) {

        sqs = onlyValid(sqs);

        if (sqs.size() < 2) {
            log.warn("Too less points for fitting.");
            return false;
        }

        this.callback = callback;

        try {
            Outlier.detectOutliers(this, sqs, stdDevFactor, method);
        }
        catch (MathException me) {
            // Pass the exception as throwable so the stack trace is kept.
            log.warn("S/Q fitting failed: " + me.getMessage(), me);
            return false;
        }

        return true;
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org