view artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java @ 9646:0380717105ba

Implemented alternative fitting strategy for Log-Linear function.
author Gernot Belger <g.belger@bjoernsen.de>
date Mon, 02 Dec 2019 17:56:15 +0100
parents
children
line wrap: on
line source
/** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde
 * Software engineering by
 *  Björnsen Beratende Ingenieure GmbH
 *  Dr. Schumacher Ingenieurbüro für Wasser und Umwelt
 *
 * This file is Free Software under the GNU AGPL (>=v3)
 * and comes with ABSOLUTELY NO WARRANTY! Check out the
 * documentation coming with Dive4Elements River for details.
 */
package org.dive4elements.river.artifacts.model.fixings.fitting;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.NavigableMap;
import java.util.TreeMap;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.util.Pair;
import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result;

import au.com.bytecode.opencsv.CSVReader;

/**
 * @author Gernot Belger
 */
public class LinearizedFittingTest {

    public static void main(final String[] args) throws IOException {

        // read test data
        final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> testData = readTestData();

        for (final Entry<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> entry : testData.entrySet()) {

            final BigDecimal station = entry.getKey();
            final List<Pair<BigDecimal, BigDecimal>> testSample = entry.getValue();
            optimizeA(station, testSample);
        }
    }

    private static void optimizeA(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample) {

        /* extrakt observations */
        final double[] obsDischarges = new double[testSample.size()];
        final double[] obsWaterlevels = new double[testSample.size()];
        for (int i = 0; i < obsWaterlevels.length; i++) {
            obsDischarges[i] = testSample.get(i).getKey().doubleValue();
            obsWaterlevels[i] = testSample.get(i).getValue().doubleValue();
        }

        final TargetFunction targetFunction = new TargetFunction(obsDischarges);
        final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction);

        final Result directEstimation = estimateADirect(sqrtErrorFunction);
        final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize();

        final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction);
        final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction);

        printResult(station, testSample, directEstimation);
        printResult(station, testSample, linearEstimation);
        printResult(station, testSample, directPostEstimation);
        printResult(station, testSample, linearPostEstimation);
    }

    private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) {

        final double a = estimation.getA();
        final double b = estimation.getB();
        final double m = estimation.getM();
        // final double error = estimation.getError();
        final double aIndex = estimation.getBestAindex();

        final double aLow = Math.pow(10, aIndex - 1);
        final double aHigh = Math.pow(10, aIndex + 1);

        return directOptimize(aLow, aHigh, a, b, m, errorFunction);
    }

    private static void printResult(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample, final Result optimize) {

        final double a = optimize.getA();
        final double b = optimize.getB();
        final double m = optimize.getM();
        final double error = optimize.getError();

        System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error);

        for (final Pair<BigDecimal, BigDecimal> entry : testSample) {

            final BigDecimal waterlevel = entry.getSecond();
            System.out.print(' ');
            System.out.print(waterlevel);
        }
        System.out.println();
    }

    private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) {

        Result best = null;
        double leastError = Double.POSITIVE_INFINITY;

        // iteration über a von 10^^0 bis 10^^20
        for (int i = 0; i < 20; i++) {

            final double aStart = Math.pow(10, i);
            final double aLow = Math.pow(10, i - 1);
            final double aHigh = Math.pow(10, i + 1);

            final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction);
            final double error = result.getError();

            if (error < leastError) {
                leastError = error;
                best = result.withBestAIndex(i);
            }
        }

        return best;
    }

    private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m,
            final SqrtErrorFunction sqrtErrorFunction) {

        // n = 3
        // [n+2, (n+1)(n+2)/2]
        // --> [5, 10]
        final int interpolationPoints = 10;
        final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints);

        /* optimization data */
        final MultivariateFunction function = new MultivariateFunction() {

            @Override
            public double value(final double[] point) {
                return sqrtErrorFunction.value(point[0], point[1], point[2]);
            }
        };

        final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE);
        final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE);

        final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 });
        final double[] startValues = new double[] { a, b, m };
        final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval,
                maxIter);

        final Double error = result.getValue();
        final double[] point = result.getPoint();

        final double aEstimation = point[0];
        final double bEstimation = point[1];
        final double mEstimation = point[2];

        return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN);
    }

    private static NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> readTestData() throws IOException {

        final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> data = new TreeMap<>();

        try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))),
                '\t')) {

            final String[] header = reader.readNext();
            if (header == null)
                throw new IllegalStateException();

            if (header.length < 2)
                throw new IllegalStateException();

            while (true) {
                final String[] line = reader.readNext();
                if (line == null)
                    break;

                if (line.length != header.length)
                    throw new IllegalStateException();

                final BigDecimal discharge = parseDecimal(line[0]);
                if (discharge == null)
                    continue;

                for (int column = 1; column < line.length; column++) {

                    final BigDecimal station = parseDecimal(header[column]);
                    if (station == null)
                        continue;

                    if (!data.containsKey(station))
                        data.put(station, new ArrayList<Pair<BigDecimal, BigDecimal>>());
                    final List<Pair<BigDecimal, BigDecimal>> points = data.get(station);

                    final BigDecimal waterlevel = parseDecimal(line[column]);
                    if (waterlevel != null)
                        points.add(Pair.create(discharge, waterlevel));
                }
            }
        }

        return data;
    }

    private static BigDecimal parseDecimal(final String token) {

        if (StringUtils.isBlank(token))
            return null;

        if ("nan".equalsIgnoreCase(token))
            return null;

        return new BigDecimal(token);
    }
}

http://dive4elements.wald.intevation.org