view artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 9360:ddcd52d239cd

Outliers in fixation calculation are now shown within the other 'B' event themes and get a separate symbol (triangle). Removed old outliers theme. Also consider showpoints property. Also consider pointsize property.
author gernotbelger
date Wed, 01 Aug 2018 17:13:52 +0200
parents a3f318347707
children 9744ce3c3853
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.fixings;

import gnu.trove.TDoubleArrayList;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math.MathException;
import org.apache.commons.math.optimization.fitting.CurveFitter;
import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
import org.apache.log4j.Logger;

import org.dive4elements.river.artifacts.math.GrubbsOutlier;
import org.dive4elements.river.artifacts.math.fitting.Function;

/**
 * Result of fitting a {@link Function} to discharge/water-level (Q/W) pairs with a
 * Levenberg-Marquardt optimizer, optionally removing outliers iteratively via
 * {@link GrubbsOutlier}.
 */
public class Fitting {
    private static final Logger log = Logger.getLogger(Fitting.class);

    /**
     * Creates the {@link QWD} record for each point of a fitting (regular points and outliers).
     * {@link Fitting#fit} always passes {@code Double.NaN} as {@code deltaW} and sets the actual
     * deviation afterwards via {@code setDeltaW}.
     */
    public interface QWDFactory {
        QWD create(double q, double w, double deltaW, boolean isOutlier);
    }

    private final double chiSqr;

    private final double[] parameters;

    private final double standardDeviation;

    private final List<QWD> qwds;

    public Fitting(final double[] parameters, final double standardDeviation, final double chiSqr, final List<QWD> qwds) {
        this.parameters = parameters;
        this.standardDeviation = standardDeviation;
        this.chiSqr = chiSqr;
        this.qwds = qwds;
    }

    /** Returns the chi-square reported by the optimizer for the final fit. */
    public double getChiSquare() {
        return chiSqr;
    }

    /**
     * Returns all points of this fitting as one array. When built by {@link #fit}, the outliers
     * come first, followed by the points used for the final fit.
     */
    public QWD[] getFixingsArray() {
        return qwds.toArray(new QWD[qwds.size()]);
    }

    /**
     * Returns the largest Q among the non-outlier points, or {@code -Double.MAX_VALUE} if there is
     * no such point.
     */
    public double getMaxQ() {
        double maxQ = -Double.MAX_VALUE;

        for (final QWD qw : qwds) {
            final double q = qw.getQ();
            if (!qw.isOutlier() && q > maxQ)
                maxQ = q;
        }

        return maxQ;
    }

    /** Returns the fitted function parameters (the internal array, not a copy). */
    public double[] getParameters() {
        return parameters;
    }

    /** Returns the standard deviation of the ΔW values of the non-outlier points. */
    public double getStandardDeviation() {
        return standardDeviation;
    }

    /**
     * Fits {@code function} to the given Q/W pairs.
     *
     * Pairs where either value is NaN are ignored. If {@code checkOutliers} is set, the curve is
     * re-fitted repeatedly; after each fit the Grubbs test may flag one point as an outlier, which
     * is then removed before the next fit.
     *
     * @param function      the function to fit
     * @param qwdFactory    creates the {@link QWD} records of the result
     * @param checkOutliers whether to iteratively remove outliers
     * @param qs            discharges; expected to be at least as long as {@code ws}
     * @param ws            water levels, parallel to {@code qs}
     * @return the fitting, or {@code null} if fewer than two valid pairs exist or the optimizer
     *         fails for every tried tolerance
     */
    public static Fitting fit(final Function function, final QWDFactory qwdFactory, final boolean checkOutliers, final double[] qs, final double[] ws) {

        final TDoubleArrayList xs = new TDoubleArrayList(qs.length);
        final TDoubleArrayList ys = new TDoubleArrayList(ws.length);

        for (int i = 0; i < qs.length; ++i) {
            if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) {
                xs.add(qs[i]);
                ys.add(ws[i]);
            }
        }

        if (xs.size() < 2) {
            log.warn("Too less points.");
            return null;
        }

        final List<QWD> outliers = new ArrayList<>(xs.size());

        Optimum optimum;
        org.dive4elements.river.artifacts.math.Function instance;

        for (;;) {
            optimum = optimize(function, xs, ys);
            if (optimum == null)
                return null;

            // This is the parameterized function for a given km.
            instance = function.instantiate(optimum.parameters);

            if (!checkOutliers)
                break;

            final Integer outlier = GrubbsOutlier.findOutlier(residuals(instance, xs, ys));
            if (outlier == null)
                break;

            // Remember the flagged point and re-fit without it.
            final int idx = outlier.intValue();
            outliers.add(qwdFactory.create(xs.getQuick(idx), ys.getQuick(idx), Double.NaN, true));
            xs.remove(idx);
            ys.remove(idx);
        }

        return buildFitting(optimum, instance, qwdFactory, outliers, xs, ys);
    }

    /** Parameters and chi-square of one successful optimizer run. */
    private static final class Optimum {
        final double[] parameters;
        final double chiSquare;

        Optimum(final double[] parameters, final double chiSquare) {
            this.parameters = parameters;
            this.chiSquare = chiSquare;
        }
    }

    /**
     * Runs the Levenberg-Marquardt curve fit, starting with tolerance 1e-10 and relaxing it by a
     * factor of ten per attempt while it stays below 1e-1, until one attempt succeeds.
     *
     * @return the first successful optimum, or {@code null} if every attempt threw
     */
    private static Optimum optimize(final Function function, final TDoubleArrayList xs, final TDoubleArrayList ys) {
        for (double tolerance = 1e-10; tolerance < 1e-1; tolerance *= 10d) {

            final LevenbergMarquardtOptimizer lmo = new LevenbergMarquardtOptimizer();
            lmo.setCostRelativeTolerance(tolerance);
            lmo.setOrthoTolerance(tolerance);
            lmo.setParRelativeTolerance(tolerance);

            final CurveFitter cf = new CurveFitter(lmo);
            for (int i = 0, n = xs.size(); i < n; ++i) {
                cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i));
            }

            try {
                final double[] parameters = cf.fit(function, function.getInitialGuess());
                // Read chi-square from the optimizer that produced these parameters.
                return new Optimum(parameters, lmo.getChiSquare());
            }
            catch (final MathException me) {
                if (log.isDebugEnabled()) {
                    log.debug("tolerance " + tolerance + " failed.", me);
                }
            }
        }
        return null;
    }

    /**
     * Computes observed-minus-fitted residuals for the outlier test. Where the fitted curve yields
     * NaN, {@code Double.MAX_VALUE} is used as the fitted value, which produces a huge negative
     * residual (presumably so such points get flagged — confirm intent).
     */
    private static List<Double> residuals(final org.dive4elements.river.artifacts.math.Function instance, final TDoubleArrayList xs, final TDoubleArrayList ys) {
        final int n = xs.size();
        final List<Double> residuals = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            double y = instance.value(xs.getQuick(i));
            if (Double.isNaN(y)) {
                y = Double.MAX_VALUE;
            }
            residuals.add(Double.valueOf(ys.getQuick(i) - y));
        }
        return residuals;
    }

    /**
     * Deviation of {@code w} from the fitted curve at {@code q}, multiplied by 100
     * (presumably converting metres to centimetres — confirm against callers).
     */
    private static double deltaW(final org.dive4elements.river.artifacts.math.Function instance, final double q, final double w) {
        return (w - instance.value(q)) * 100.0;
    }

    /**
     * Assembles the result: the outliers first, then the remaining points, all with their ΔW against
     * the final curve. Only the remaining points contribute to the standard deviation.
     */
    private static Fitting buildFitting(final Optimum optimum, final org.dive4elements.river.artifacts.math.Function instance, final QWDFactory qwdFactory,
            final List<QWD> outliers, final TDoubleArrayList xs, final TDoubleArrayList ys) {

        final List<QWD> qwds = new ArrayList<>(outliers.size() + xs.size());

        for (final QWD outlier : outliers) {
            outlier.setDeltaW(deltaW(instance, outlier.getQ(), outlier.getW()));
            qwds.add(outlier);
        }

        final StandardDeviation stdDev = new StandardDeviation();

        for (int i = 0, n = xs.size(); i < n; ++i) {
            final QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i), Double.NaN, false);

            final double dw = deltaW(instance, qwd.getQ(), qwd.getW());
            qwd.setDeltaW(dw);

            qwds.add(qwd);
            stdDev.increment(dw);
        }

        return new Fitting(optimum.parameters, stdDev.getResult(), optimum.chiSquare, qwds);
    }
}

http://dive4elements.wald.intevation.org