view artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 9628:bc50ecfc58c5

undo unwanted commit
author dnt_bjoernsen <d.tironi@bjoernsen.de>
date Mon, 14 Oct 2019 16:40:47 +0200
parents f51e23eb036a
children 0380717105ba
line wrap: on
line source
/* Copyright (C) 2011, 2012, 2013 by Bundesanstalt für Gewässerkunde
 * Software engineering by Intevation GmbH
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */

package org.dive4elements.river.artifacts.model.fixings;

import java.util.ArrayList;
import java.util.Date;
import java.util.List;

import org.apache.commons.math.MathException;
import org.apache.commons.math.optimization.fitting.CurveFitter;
import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
import org.apache.log4j.Logger;
import org.dive4elements.river.artifacts.math.GrubbsOutlier;
import org.dive4elements.river.artifacts.math.fitting.Function;

/**
 * Result of fitting a {@link Function} against the (Q, W) values that a set of
 * fixing events have at one station (km).
 *
 * <p>Instances are normally created by {@link #fit}. They hold the fitted
 * function parameters, the standard deviation of the delta-W values of the
 * points used for the final fit, the chi-square reported by the optimizer of
 * the final fit, and the largest Q among the points used for the final fit.</p>
 */
public class Fitting {
    private static final Logger log = Logger.getLogger(Fitting.class);

    /**
     * Tolerances (cost-relative, ortho and parameter-relative) tried in turn
     * until the Levenberg-Marquardt optimizer succeeds. They are listed
     * explicitly rather than accumulated by repeated multiplication, so the
     * number of attempts cannot depend on floating-point rounding of the
     * accumulated product.
     */
    private static final double[] TOLERANCES = { 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2 };

    /** Factory creating {@link QWD} instances; intended for attaching meta information to outliers. */
    public interface QWDFactory {
        QWD create(double q, double w, double deltaW, boolean isOutlier);
    }

    /** One (Q, W) point taken from a fixing event at the station being fitted. */
    private static final class FittingData {
        public final FixingColumnWithData event;
        public final double w;
        public final double q;
        public final boolean isInterpolated;

        public FittingData(final FixingColumnWithData event, final double w, final double q, final boolean isInterpolated) {
            this.event = event;
            this.w = w;
            this.q = q;
            this.isInterpolated = isInterpolated;
        }
    }

    /** Parameters found by a successful curve fit, together with the optimizer that found them. */
    private static final class CurveFit {
        final double[] parameters;
        final LevenbergMarquardtOptimizer optimizer;

        CurveFit(final double[] parameters, final LevenbergMarquardtOptimizer optimizer) {
            this.parameters = parameters;
            this.optimizer = optimizer;
        }
    }

    private final double chiSqr;

    private final double[] parameters;

    private final double standardDeviation;

    private final double maxQ;

    /**
     * Creates a fitting result.
     *
     * @param parameters
     *            the fitted function parameters (stored as given, not copied)
     * @param standardDeviation
     *            standard deviation of the delta-W values of the fitted points
     * @param chiSqr
     *            chi-square reported by the optimizer
     * @param maxQ
     *            largest Q among the fitted points
     */
    public Fitting(final double[] parameters, final double standardDeviation, final double chiSqr, final double maxQ) {
        this.parameters = parameters;
        this.standardDeviation = standardDeviation;
        this.chiSqr = chiSqr;
        this.maxQ = maxQ;
    }

    /** @return the chi-square passed at construction (for {@link #fit}: that of the final fit) */
    public double getChiSquare() {
        return this.chiSqr;
    }

    /** @return the largest Q passed at construction (for {@link #fit}: outliers excluded) */
    public double getMaxQ() {
        return this.maxQ;
    }

    /** @return the fitted parameters; this is the internal array, not a copy */
    public double[] getParameters() {
        return this.parameters;
    }

    /** @return the standard deviation of the delta-W values passed at construction */
    public double getStandardDeviation() {
        return this.standardDeviation;
    }

    /**
     * Fits {@code function} against the (Q, W) values that the given events
     * have at station {@code km}.
     *
     * <p>Events whose Q or W at {@code km} is NaN are ignored. If
     * {@code checkOutliers} is set, the point reported by
     * {@link GrubbsOutlier#findOutlier} on the residuals is removed and the
     * fit is repeated, until no further outlier is reported.</p>
     *
     * <p>A {@link QWD} is registered in {@code resultColumns} for every
     * removed outlier (flagged as outlier) and for every point used in the
     * final fit. Both are measured against the final fitted function.</p>
     *
     * @param resultColumns
     *            receives one QWD per outlier and per used point
     * @param km
     *            station at which the events are evaluated
     * @param function
     *            the function to fit
     * @param checkOutliers
     *            whether outliers are removed iteratively
     * @param eventColumns
     *            the fixing events providing Q and W
     * @return the fitting result, or {@code null} if fewer than two usable
     *         points exist or the optimizer failed for every tolerance
     */
    public static Fitting fit(final FixResultColumns resultColumns, final double km, final Function function, final boolean checkOutliers,
            final List<FixingColumnWithData> eventColumns) {

        final List<FittingData> data = collectFittingData(km, eventColumns);

        if (data.size() < 2) {
            log.warn("Not enough data for fitting.");
            return null;
        }

        final List<FittingData> outliers = new ArrayList<>(data.size());

        CurveFit curveFit;
        org.dive4elements.river.artifacts.math.Function instance;

        for (;;) {
            curveFit = fitCurve(function, data);
            if (curveFit == null)
                return null;

            // This is the parameterized function for a given km.
            instance = function.instantiate(curveFit.parameters);

            if (!checkOutliers)
                break;

            final Integer outlier = findOutlier(data, instance);
            if (outlier == null)
                break;

            // Drop the outlier from the fitting set and refit without it.
            outliers.add(data.remove(outlier.intValue()));
        }

        /* calculate dW of outliers against the resulting function and add them to results */
        for (final FittingData outlier : outliers)
            resultColumns.addQWD(outlier.event, km, createQWD(outlier, instance, true));

        /* calculate dW of used values against the resulting function, add them to results and gather statistics */
        final StandardDeviation stdDev = new StandardDeviation();
        double maxQ = -Double.MAX_VALUE;

        for (final FittingData fittingData : data) {

            final QWD qwd = createQWD(fittingData, instance, false);
            resultColumns.addQWD(fittingData.event, km, qwd);

            stdDev.increment(qwd.getDeltaW());

            final double q = qwd.getQ();
            if (q > maxQ)
                maxQ = q;
        }

        final double standardDeviation = stdDev.getResult();

        final double chiSqr = curveFit.optimizer.getChiSquare();

        return new Fitting(curveFit.parameters, standardDeviation, chiSqr, maxQ);
    }

    /**
     * Collects the (Q, W) points of all events that have a non-NaN Q and a
     * non-NaN W at {@code km}.
     */
    private static List<FittingData> collectFittingData(final double km, final List<FixingColumnWithData> eventColumns) {

        final List<FittingData> data = new ArrayList<>(eventColumns.size());

        // Reused output buffer for getW.
        final double[] wTemp = new double[1];

        for (final FixingColumnWithData event : eventColumns) {

            final double q = event.getQ(km);
            if (Double.isNaN(q))
                continue;

            // A false result of getW is recorded as "interpolated".
            final boolean isInterpolated = !event.getW(km, wTemp, 0);
            final double w = wTemp[0];

            if (!Double.isNaN(w))
                data.add(new FittingData(event, w, q, isInterpolated));
        }

        return data;
    }

    /**
     * Fits {@code function} to the given points, trying each of
     * {@link #TOLERANCES} in turn until the optimizer succeeds.
     *
     * @return the successful fit, or {@code null} if it failed for every tolerance
     */
    private static CurveFit fitCurve(final Function function, final List<FittingData> data) {

        for (final double tolerance : TOLERANCES) {

            final LevenbergMarquardtOptimizer lmo = new LevenbergMarquardtOptimizer();
            lmo.setCostRelativeTolerance(tolerance);
            lmo.setOrthoTolerance(tolerance);
            lmo.setParRelativeTolerance(tolerance);

            try {
                final CurveFitter cf = new CurveFitter(lmo);

                for (final FittingData fittingData : data)
                    cf.addObservedPoint(fittingData.q, fittingData.w);

                final double[] parameters = cf.fit(function, function.getInitialGuess());
                return new CurveFit(parameters, lmo);
            }
            catch (final MathException me) {
                // Expected for too-strict tolerances; retry with the next, looser one.
                if (log.isDebugEnabled()) {
                    log.debug("tolerance " + tolerance + " failed.", me);
                }
            }
        }

        return null;
    }

    /**
     * Computes the residual (W minus fitted W) of every point and asks
     * {@link GrubbsOutlier} for the index of an outlier among them.
     * If the fitted function yields NaN for a point, Double.MAX_VALUE is used
     * as the fitted value, so that point gets a residual of huge magnitude.
     *
     * @return index into {@code data} of the outlier, or {@code null} if none was reported
     */
    private static Integer findOutlier(final List<FittingData> data, final org.dive4elements.river.artifacts.math.Function instance) {

        final List<Double> residuals = new ArrayList<>(data.size());

        for (final FittingData fittingData : data) {

            double y = instance.value(fittingData.q);
            if (Double.isNaN(y))
                y = Double.MAX_VALUE;

            residuals.add(Double.valueOf(fittingData.w - y));
        }

        return GrubbsOutlier.findOutlier(residuals);
    }

    /**
     * Builds the QWD for one point, with delta-W = (W - function(Q)) * 100.
     * NOTE(review): the factor 100 presumably converts metres to centimetres; confirm.
     */
    private static QWD createQWD(final FittingData data, final org.dive4elements.river.artifacts.math.Function function, final boolean isOutlier) {

        final FixingColumnWithData event = data.event;
        final Date date = event.getDate();
        final boolean isInterpolated = data.isInterpolated;

        final double w = data.w;
        final double q = data.q;
        final double dw = (w - function.value(q)) * 100.0;

        return new QWD(q, w, date, isInterpolated, dw, isOutlier);
    }
}

http://dive4elements.wald.intevation.org