view artifacts/src/main/java/org/dive4elements/river/artifacts/model/sq/Fitting.java @ 8605:b0e5a2ce0b09

(issue1750) Fix id handling with Oracle. On Oracle the id is returned as a BigDecimal. This could probably be done better with Hibernate, but checking the Object also works.
author Andre Heinecke <andre.heinecke@intevation.de>
date Fri, 20 Mar 2015 18:42:00 +0100
parents 51eb6491c537
children 0a5239a1e46e
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.sq;

import org.dive4elements.artifacts.common.utils.StringUtils;
import org.dive4elements.river.artifacts.math.fitting.Function;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math.MathException;

import org.apache.commons.math.optimization.fitting.CurveFitter;

import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import org.apache.commons.math.stat.regression.SimpleRegression;

import org.apache.log4j.Logger;

/**
 * Fits a {@link Function} to a list of {@link SQ} measurements and serves as
 * the {@link Outlier.Callback} handed to
 * {@code Outlier.detectOutliers(...)} by {@link #fit}.
 *
 * <p>Two strategies are available, selected in {@link #initialize}:
 * a simple linear regression (only for functions with exactly two
 * parameters) and a Levenberg-Marquardt curve fit.</p>
 *
 * <p>The S and Q values of a measurement are never read directly from the
 * {@link SQ} but always through the configured {@link SQ.View}.</p>
 */
public class Fitting
implements   Outlier.Callback
{
    /**
     * XXX: Hack. When the system property
     * {@code minfo.sq.fitting.nonlinear} is {@code true} the non-linear
     * fitting is used for every function. Otherwise functions with exactly
     * two parameters are fitted by linear regression.
     */
    private static final boolean USE_NON_LINEAR_FITTING =
        Boolean.getBoolean("minfo.sq.fitting.nonlinear");

    private static final Logger log = Logger.getLogger(Fitting.class);

    /** Receiver of the per-iteration results reported by {@link #fit}. */
    public interface Callback {

        /**
         * Called from {@link Fitting#iterationFinished}.
         *
         * @param parameters        coefficients of the most recent fit.
         * @param measurements      the measurements passed in as remaining.
         * @param outliers          one-element array holding the outlier
         *                          of this iteration, or an empty array
         *                          if there was none.
         * @param standardDeviation standard deviation as passed to
         *                          {@code iterationFinished}.
         * @param chiSqr            chi square of the last non-linear fit
         *                          (see {@link Fitting#chiSqr}).
         */
        void afterIteration(
            double [] parameters,
            SQ []     measurements,
            SQ []     outliers,
            double    standardDeviation,
            double    chiSqr);
    } // interface

    /** The function (family) to be fitted. */
    protected Function function;

    /** Coefficients found by the most recent fit. */
    protected double [] coeffs;

    /** {@link #function} instantiated with {@link #coeffs}. */
    protected org.dive4elements.river.artifacts.math.Function instance;

    /** Factor passed on to {@code Outlier.detectOutliers}. */
    protected double stdDevFactor;

    /**
     * Chi square of the most recent non-linear fit.
     * NOTE(review): the linear fitting never writes this field, so it
     * keeps its previous value (initially 0) on that path -- confirm
     * that consumers of the callback expect this.
     */
    protected double chiSqr;

    /** Callback of the currently running {@link #fit}; may be null. */
    protected Callback callback;

    /** Accessor used to read the S and Q values of a measurement. */
    protected SQ.View sqView;

    /**
     * Creates an unconfigured fitting. Nothing is initialized here;
     * {@link #function} can be set later via {@link #setFunction}, but
     * this class offers no setter for {@link #sqView}.
     */
    public Fitting() {
    }

    /**
     * @param function     the function to be fitted.
     * @param stdDevFactor factor passed on to the outlier detection.
     * @param sqView       accessor for the S and Q values of the
     *                     measurements.
     */
    public Fitting(Function function, double stdDevFactor, SQ.View sqView) {
        this.function     = function;
        this.stdDevFactor = stdDevFactor;
        this.sqView       = sqView;
    }

    /** @return the function to be fitted. */
    public Function getFunction() {
        return function;
    }

    /** @param function the function to be fitted. */
    public void setFunction(Function function) {
        this.function = function;
    }

    /** @return the factor passed on to the outlier detection. */
    public double getStdDevFactor() {
        return stdDevFactor;
    }

    /** @param stdDevFactor the factor passed on to the outlier detection. */
    public void setStdDevFactor(double stdDevFactor) {
        this.stdDevFactor = stdDevFactor;
    }

    /**
     * Fits the function to the given measurements and stores the result
     * in {@link #coeffs} and {@link #instance}.
     *
     * @param sqs the measurements to fit against.
     * @throws MathException if the non-linear fitting fails.
     */
    @Override
    public void initialize(List<SQ> sqs) throws MathException {

        if (USE_NON_LINEAR_FITTING
        || function.getParameterNames().length != 2) {
            nonLinearFitting(sqs);
        }
        else {
            linearFitting(sqs);
        }
    }

    /**
     * Fits by linear regression and instantiates the function with
     * the resulting coefficients.
     *
     * @param sqs the measurements to fit against.
     */
    protected void linearFitting(List<SQ> sqs) {
        coeffs   = linearRegression(sqs);
        instance = function.instantiate(coeffs);
    }

    /**
     * Runs a simple regression over the measurements with Q as the
     * independent and S as the dependent value.
     *
     * @param sqs the measurements to fit against.
     * @return an array as long as the parameter names of the function.
     *         The slope is stored at the index of the parameter "m", the
     *         intercept at the index of the parameter "b". The array is
     *         all zeros if there are fewer than two measurements or if
     *         the function has no parameters named "m" and "b".
     */
    protected double [] linearRegression(List<SQ> sqs) {

        String [] pns = function.getParameterNames();
        double [] result = new double[pns.length];

        if (sqs.size() < 2) {
            log.debug("not enough points");
            return result;
        }

        SimpleRegression reg = new SimpleRegression();

        for (SQ sq: sqs) {
            double s = sqView.getS(sq);
            double q = sqView.getQ(sq);
            reg.addData(q, s);
        }

        // Read in the same order as before: intercept first, then slope.
        double intercept = reg.getIntercept();
        double slope     = reg.getSlope();

        if (log.isDebugEnabled()) {
            // Log each value under the name of the parameter it is
            // stored in below ("m" takes the slope, "b" the intercept).
            log.debug("m: " + slope);
            log.debug("b: " + intercept);
        }

        int mIdx = StringUtils.indexOf("m", pns);
        int bIdx = StringUtils.indexOf("b", pns);

        if (mIdx == -1 || bIdx == -1) {
            log.error("index not found: " + mIdx + " " + bIdx);
            return result;
        }

        result[bIdx] = intercept;
        result[mIdx] = slope;

        return result;
    }


    /**
     * Fits with a Levenberg-Marquardt optimizer starting from the initial
     * guess of the function. Stores the result in {@link #coeffs},
     * {@link #instance} and {@link #chiSqr}.
     *
     * @param sqs the measurements to fit against.
     * @throws MathException if the curve fitting fails.
     */
    protected void nonLinearFitting(List<SQ> sqs) throws MathException {

        LevenbergMarquardtOptimizer optimizer =
            new LevenbergMarquardtOptimizer();

        CurveFitter cf = new CurveFitter(optimizer);

        for (SQ sq: sqs) {
            cf.addObservedPoint(sqView.getQ(sq), sqView.getS(sq));
        }

        coeffs = cf.fit(
            function, function.getInitialGuess());

        instance = function.instantiate(coeffs);

        chiSqr = optimizer.getChiSquare();
    }

    /**
     * @param sq the measurement to evaluate.
     * @return the S value of the measurement minus the value of the
     *         fitted function at the Q value of the measurement.
     */
    @Override
    public double eval(SQ sq) {
        double s = instance.value(sqView.getQ(sq));
        return sqView.getS(sq) - s;
    }

    /**
     * Forwards the result of an iteration to the {@link Callback} given
     * to {@link #fit}. Does nothing besides debug logging if there is
     * no callback.
     *
     * @param standardDeviation standard deviation of this iteration.
     * @param outlier           the outlier of this iteration or null.
     * @param remainings        the remaining measurements.
     */
    @Override
    public void iterationFinished(
        double   standardDeviation,
        SQ       outlier,
        List<SQ> remainings
    ) {
        boolean hasOutlier = outlier != null;

        if (log.isDebugEnabled()) {
            log.debug("iterationFinished ----");
            log.debug(" num remainings: " + remainings.size());
            // The comparison has to be evaluated before the
            // concatenation: '+' binds tighter than '!='.
            log.debug(" has outlier: " + hasOutlier);
            log.debug(" standardDeviation: " + standardDeviation);
            log.debug(" Chi^2: " + chiSqr);
            log.debug("---- iterationFinished");
        }

        if (callback == null) {
            // Nobody to report to, e.g. fit() was called without callback.
            return;
        }

        callback.afterIteration(
            coeffs,
            remainings.toArray(new SQ[remainings.size()]),
            hasOutlier ? new SQ [] { outlier } : new SQ [] {},
            standardDeviation,
            chiSqr);
    }

    /**
     * Runs the outlier detection with this fitting on a copy of the
     * given measurements.
     *
     * @param sqs      the measurements; the list itself is not passed on,
     *                 only a copy of it.
     * @param method   passed through to {@code Outlier.detectOutliers}.
     * @param callback receives the result of each iteration; may be null.
     * @return false if there are fewer than two measurements or if the
     *         outlier detection failed with a {@link MathException},
     *         true otherwise.
     */
    public boolean fit(List<SQ> sqs, String  method, Callback callback) {

        if (sqs.size() < 2) {
            log.warn("Too less points for fitting.");
            return false;
        }

        sqs = new ArrayList<SQ>(sqs);

        this.callback = callback;

        try {
            Outlier.detectOutliers(this, sqs, stdDevFactor, method);
        }
        catch (MathException me) {
            log.warn(me);
            return false;
        }

        return true;
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org