view artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 9415:9744ce3c3853

Rework of fixanalysis computation and dWt and WQ facets. Got rid of strange remapping and bit-shifting code by explicitly saving the column information and using it in the facets. The facets also put the valid station range into their XML metadata.
author gernotbelger
date Thu, 16 Aug 2018 16:27:53 +0200
parents ddcd52d239cd
children f51e23eb036a
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.fixings;

import java.util.ArrayList;
import java.util.Date;
import java.util.List;

import org.apache.commons.math.MathException;
import org.apache.commons.math.optimization.fitting.CurveFitter;
import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
import org.apache.log4j.Logger;
import org.dive4elements.river.artifacts.math.GrubbsOutlier;
import org.dive4elements.river.artifacts.math.fitting.Function;

/**
 * Result of fitting a {@link Function} to the (Q, W) values of several fixing events at a single
 * station (km). The result holds the fitted parameters and some quality measures of the fit.
 */
public class Fitting {
    private static final Logger log = Logger.getLogger(Fitting.class);

    /** Use instance of this factory to find meta infos for outliers. */
    public interface QWDFactory {
        QWD create(double q, double w, double deltaW, boolean isOutlier);
    }

    /** One (Q, W) sample of a single event at the station being fitted. */
    private static final class FittingData {
        public final FixingColumnWithData event;
        public final double w;
        public final double q;
        /** True when {@code event.getW(...)} returned {@code false}; presumably the W value was interpolated. */
        public final boolean isInterpolated;

        public FittingData(final FixingColumnWithData event, final double w, final double q, final boolean isInterpolated) {
            this.event = event;
            this.w = w;
            this.q = q;
            this.isInterpolated = isInterpolated;
        }
    }

    private final double chiSqr;

    private final double[] parameters;

    private final double standardDeviation;

    private final double maxQ;

    /**
     * @param parameters
     *            fitted function parameters (stored without copying)
     * @param standardDeviation
     *            standard deviation of the dW values of the non-outlier samples
     * @param chiSqr
     *            chi-square reported by the optimizer for the final fit
     * @param maxQ
     *            largest Q among the non-outlier samples
     */
    public Fitting(final double[] parameters, final double standardDeviation, final double chiSqr, final double maxQ) {
        this.parameters = parameters;
        this.standardDeviation = standardDeviation;
        this.chiSqr = chiSqr;
        this.maxQ = maxQ;
    }

    /** Returns the chi-square of the final fit. */
    public double getChiSquare() {
        return this.chiSqr;
    }

    /** Returns the largest Q of the samples used in the final fit (outliers excluded). */
    public double getMaxQ() {
        return this.maxQ;
    }

    /** Returns the fitted parameters. The internal array is returned as-is, not a copy. */
    public double[] getParameters() {
        return this.parameters;
    }

    /** Returns the standard deviation of the dW values of the samples used in the final fit (outliers excluded). */
    public double getStandardDeviation() {
        return this.standardDeviation;
    }

    /**
     * Fits {@code function} to the (Q, W) values of the given events at station {@code km}.
     * <p>
     * Every sample, outliers included, is reported to {@code resultColumns} as a {@link QWD}. Each QWD carries its dW
     * against the final fitted function.
     *
     * @param resultColumns
     *            receives one QWD per sample via {@code addQWD}: outliers first, then the samples used by the fit
     * @param km
     *            the station to fit
     * @param function
     *            the function to fit
     * @param checkOutliers
     *            if {@code true}, the fit is repeated. Each round, the outlier detected by {@link GrubbsOutlier} on
     *            the W residuals is removed. This continues until no outlier is found.
     * @param eventColumns
     *            the events. Events without a Q or W value (NaN) at {@code km} are ignored.
     * @return the fitting result. Returns {@code null} if fewer than two usable samples exist, or if the optimizer did
     *         not converge for any tolerance.
     */
    public static Fitting fit(final FixResultColumns resultColumns, final double km, final Function function, final boolean checkOutliers,
            final List<FixingColumnWithData> eventColumns) {

        final List<FittingData> data = collectFittingData(eventColumns, km);
        if (data.size() < 2) {
            log.warn("Not enough data for fitting.");
            return null;
        }

        final List<FittingData> outliers = new ArrayList<>(data.size());

        org.dive4elements.river.artifacts.math.Function instance = null;
        LevenbergMarquardtOptimizer lmo = null;
        double[] parameters = null;

        for (;;) {

            parameters = null;

            // Retry with tolerances growing by a factor of 10 (starting at 1e-10) until the optimizer converges.
            for (double tolerance = 1e-10; tolerance < 1e-1; tolerance *= 10d) {

                lmo = new LevenbergMarquardtOptimizer();
                lmo.setCostRelativeTolerance(tolerance);
                lmo.setOrthoTolerance(tolerance);
                lmo.setParRelativeTolerance(tolerance);

                try {
                    final CurveFitter cf = new CurveFitter(lmo);

                    for (final FittingData fittingData : data)
                        cf.addObservedPoint(fittingData.q, fittingData.w);

                    parameters = cf.fit(function, function.getInitialGuess());
                    break;
                }
                catch (final MathException me) {
                    if (log.isDebugEnabled()) {
                        log.debug("tolerance " + tolerance + " failed.", me);
                    }
                }
            }

            if (parameters == null) {
                if (log.isDebugEnabled())
                    log.debug("Fitting failed at km " + km + " with " + data.size() + " data points.");
                return null;
            }

            // This is the parameterized function for a given km.
            instance = function.instantiate(parameters);

            if (!checkOutliers)
                break;

            final Integer outlier = findOutlier(data, instance);
            if (outlier == null)
                break;

            // Move the outlier out of the fitting data and refit without it.
            outliers.add(data.remove(outlier.intValue()));
        }

        // Outliers are reported with their dW against the final function, but do not enter the statistics.
        for (final FittingData outlier : outliers)
            resultColumns.addQWD(outlier.event, km, createQWD(outlier, instance, true));

        // Report the used samples and compute the standard deviation of their dW and their maximum Q.
        final StandardDeviation stdDev = new StandardDeviation();
        double maxQ = -Double.MAX_VALUE;

        for (final FittingData fittingData : data) {

            final QWD qwd = createQWD(fittingData, instance, false);
            resultColumns.addQWD(fittingData.event, km, qwd);

            stdDev.increment(qwd.getDeltaW());

            final double q = qwd.getQ();
            if (q > maxQ)
                maxQ = q;
        }

        final double standardDeviation = stdDev.getResult();

        // lmo is the optimizer of the last, successful fit.
        final double chiSqr = lmo.getChiSquare();

        return new Fitting(parameters, standardDeviation, chiSqr, maxQ);
    }

    /** Collects the (Q, W) samples of all events that have both a non-NaN Q and a non-NaN W at station {@code km}. */
    private static List<FittingData> collectFittingData(final List<FixingColumnWithData> eventColumns, final double km) {

        final List<FittingData> data = new ArrayList<>(eventColumns.size());

        final double[] wTemp = new double[1];

        for (final FixingColumnWithData event : eventColumns) {

            final double q = event.getQ(km);
            if (Double.isNaN(q))
                continue;

            // Reset the slot before each call. If getW ever leaves it unwritten, the sample is then skipped below
            // instead of silently reusing a previous event's W (or the array default 0.0).
            wTemp[0] = Double.NaN;
            final boolean isInterpolated = !event.getW(km, wTemp, 0);
            final double w = wTemp[0];

            if (!Double.isNaN(w))
                data.add(new FittingData(event, w, q, isInterpolated));
        }

        return data;
    }

    /**
     * Runs the Grubbs outlier test on the W residuals ({@code w - f(q)}) of {@code data} against {@code instance}.
     *
     * @return index into {@code data} of the detected outlier, or {@code null} if none was found
     */
    private static Integer findOutlier(final List<FittingData> data, final org.dive4elements.river.artifacts.math.Function instance) {

        final List<Double> residuals = new ArrayList<>(data.size());
        for (final FittingData fittingData : data) {

            double y = instance.value(fittingData.q);
            // A NaN prediction becomes Double.MAX_VALUE, so the residual is huge and negative rather than NaN.
            if (Double.isNaN(y))
                y = Double.MAX_VALUE;

            residuals.add(Double.valueOf(fittingData.w - y));
        }

        return GrubbsOutlier.findOutlier(residuals);
    }

    /**
     * Builds the QWD for one sample. Its dW is {@code (w - function(q)) * 100}. Presumably this converts a W given
     * in m to cm (TODO confirm units).
     */
    private static QWD createQWD(final FittingData data, final org.dive4elements.river.artifacts.math.Function function, final boolean isOutlier) {

        final FixingColumnWithData event = data.event;
        final Date date = event.getDate();
        final boolean isInterpolated = data.isInterpolated;

        final double w = data.w;
        final double q = data.q;
        final double dw = (w - function.value(q)) * 100.0;

        return new QWD(q, w, date, isInterpolated, dw, isOutlier);
    }
}

http://dive4elements.wald.intevation.org