view artifacts/src/main/java/org/dive4elements/river/artifacts/model/sq/Fitting.java @ 6761:9479cb7c8cd5

flys/issue748: Force linear curve fitting. This is a real hack! Set the system property "minfo.sq.fitting.nonlinear" to re-enable the old behavior.
author Sascha L. Teichmann <teichmann@intevation.de>
date Tue, 06 Aug 2013 17:00:49 +0200
parents af13ceeba52a
children 48f6780c372d
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.sq;

import org.dive4elements.river.artifacts.math.fitting.Function;
import org.dive4elements.river.artifacts.math.fitting.Linear;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math.MathException;

import org.apache.commons.math.optimization.fitting.CurveFitter;

import org.apache.commons.math.optimization.general.AbstractLeastSquaresOptimizer;
import org.apache.commons.math.optimization.general.GaussNewtonOptimizer;
import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;

import org.apache.log4j.Logger;

/**
 * Fits a parametric curve S(Q) to a list of {@link SQ} measurements and
 * acts as the {@link Outlier.Callback} driving the iterative outlier
 * detection performed by {@link Outlier#detectOutliers}.
 */
public class Fitting
implements   Outlier.Callback
{
    /**
     * XXX: Hack to force linear fitting!
     * <p>
     * When the system property {@code minfo.sq.fitting.nonlinear} is
     * {@code true}, the configured function is fitted directly with a
     * Levenberg-Marquardt optimizer (the old behavior). Otherwise the
     * observations are log-transformed, a linear function is fitted with a
     * Gauss-Newton optimizer and the first coefficient is transformed back
     * with {@code exp}.
     */
    private static final boolean USE_NON_LINEAR_FITTING =
        Boolean.getBoolean("minfo.sq.fitting.nonlinear");

    private static final Logger log = Logger.getLogger(Fitting.class);

    /** Receives the fitting result after each outlier iteration. */
    public interface Callback {

        /**
         * Called once per finished outlier iteration.
         *
         * @param parameters        the fitted (back-transformed) coefficients.
         * @param measurements      the measurements still in use.
         * @param outliers          the outlier removed in this iteration
         *                          (empty if none).
         * @param standardDeviation the standard deviation reported by
         *                          {@link Outlier}.
         * @param chiSqr            the chi-square of the last fit.
         */
        void afterIteration(
            double [] parameters,
            SQ []     measurements,
            SQ []     outliers,
            double    standardDeviation,
            double    chiSqr);
    } // interface

    /** The configured function whose parameters are estimated. */
    protected Function function;

    /** Coefficients of the last fit (back-transformed). */
    protected double [] coeffs;

    /** {@link #function} instantiated with {@link #coeffs}. */
    protected org.dive4elements.river.artifacts.math.Function instance;

    protected double stdDevFactor;
    /** Chi-square of the last fit, as reported by the optimizer. */
    protected double chiSqr;

    protected Callback callback;

    public Fitting() {
    }

    public Fitting(Function function, double stdDevFactor) {
        this();
        this.function     = function;
        this.stdDevFactor = stdDevFactor;
    }

    public Function getFunction() {
        return function;
    }

    public void setFunction(Function function) {
        this.function = function;
    }

    public double getStdDevFactor() {
        return stdDevFactor;
    }

    public void setStdDevFactor(double stdDevFactor) {
        this.stdDevFactor = stdDevFactor;
    }

    /**
     * Fits the curve to the given measurements and stores the coefficients,
     * the instantiated function and the chi-square of the fit.
     *
     * @param sqs the measurements to fit.
     * @throws MathException if the curve fitter fails.
     */
    @Override
    public void initialize(List<SQ> sqs) throws MathException {

        AbstractLeastSquaresOptimizer optimizer = getOptimizer();

        CurveFitter cf = new CurveFitter(optimizer);
        double [] values = new double[2];
        for (SQ sq: sqs) {
            values[0] = sq.getQ();
            values[1] = sq.getS();
            transformInputValues(values);
            cf.addObservedPoint(values[0], values[1]);
        }

        // In linear mode the observations are in log space, so the linear
        // substitute has to be fitted there -- not the configured function.
        Function fitFunction = getFunction(function);

        coeffs = cf.fit(fitFunction, fitFunction.getInitialGuess());

        // NOTE(review): in linear mode this assumes that the first Linear
        // parameter is the intercept ln(a) and that the configured function
        // expects its parameters in the same order (e.g. a*Q^b) -- confirm
        // against Linear and the configured Function.
        transformCoeffsBack(coeffs);

        // Evaluation always uses the configured function.
        instance = function.instantiate(coeffs);

        // In linear mode this chi-square refers to the log-space fit.
        chiSqr = optimizer.getChiSquare();
    }

    /**
     * Returns the function that is actually handed to the curve fitter:
     * {@code function} itself in non-linear mode, {@link Linear#INSTANCE}
     * otherwise.
     */
    protected Function getFunction(Function function) {
        return USE_NON_LINEAR_FITTING
            ? function
            : Linear.INSTANCE;
    }

    /**
     * In linear mode replaces every value by its natural logarithm
     * (non-positive values yield NaN or -Infinity per {@link Math#log}).
     * In non-linear mode the values are left untouched.
     */
    protected void transformInputValues(double [] values) {
        if (!USE_NON_LINEAR_FITTING) {
            for (int i = 0; i < values.length; ++i) {
                values[i] = Math.log(values[i]);
            }
        }
    }

    /**
     * Returns a fresh optimizer: Levenberg-Marquardt in non-linear mode,
     * Gauss-Newton (flag {@code false}: QR instead of LU decomposition)
     * in linear mode.
     */
    protected AbstractLeastSquaresOptimizer getOptimizer() {
        return USE_NON_LINEAR_FITTING
            ? new LevenbergMarquardtOptimizer()
            : new GaussNewtonOptimizer(false);
    }

    /**
     * Undoes the log transformation for the first coefficient in linear
     * mode by replacing it with {@code exp(coeffs[0])}.
     */
    protected void transformCoeffsBack(double [] coeffs) {
        if (!USE_NON_LINEAR_FITTING && coeffs.length > 0) {
            coeffs[0] = Math.exp(coeffs[0]);
        }
    }

    /**
     * Returns the residual of a measurement: its S value minus the value
     * of the fitted function at its Q.
     */
    @Override
    public double eval(SQ sq) {
        double s = instance.value(sq.q);
        return sq.s - s;
    }

    /**
     * Logs the iteration state and forwards the current result to the
     * callback passed to {@link #fit}, if any.
     */
    @Override
    public void iterationFinished(
        double   standardDeviation,
        SQ       outlier,
        List<SQ> remainings
    ) {
        if (log.isDebugEnabled()) {
            log.debug("iterationFinished ----");
            log.debug(" num remainings: " + remainings.size());
            // Parentheses required: '+' binds tighter than '!='.
            log.debug(" has outlier: " + (outlier != null));
            log.debug(" standardDeviation: " + standardDeviation);
            log.debug(" Chi^2: " + chiSqr);
            log.debug("---- iterationFinished");
        }
        if (callback == null) {
            return;
        }
        callback.afterIteration(
            coeffs,
            remainings.toArray(new SQ[remainings.size()]),
            outlier != null ? new SQ [] { outlier} : new SQ [] {},
            standardDeviation,
            chiSqr);
    }

    /** Returns a new list containing only the measurements that are valid. */
    protected static final List<SQ> onlyValid(List<SQ> sqs) {

        List<SQ> good = new ArrayList<SQ>(sqs.size());

        for (SQ sq: sqs) {
            if (sq.isValid()) {
                good.add(sq);
            }
        }

        return good;
    }

    /**
     * Fits the curve to the valid measurements with iterative outlier
     * detection.
     *
     * @param sqs      the measurements; invalid ones are ignored.
     * @param method   the outlier detection method passed to {@link Outlier}.
     * @param callback notified after every iteration (may be {@code null}).
     * @return {@code false} if fewer than two valid points remain or the
     *         fitting failed, {@code true} otherwise.
     */
    public boolean fit(List<SQ> sqs, String method, Callback callback) {

        sqs = onlyValid(sqs);

        if (sqs.size() < 2) {
            log.warn("Too less points for fitting.");
            return false;
        }

        this.callback = callback;

        try {
            Outlier.detectOutliers(this, sqs, stdDevFactor, method);
        }
        catch (MathException me) {
            // Keep the stack trace of the failure.
            log.warn(me, me);
            return false;
        }

        return true;
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org