diff artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java @ 9646:0380717105ba

Implemented alternative fitting strategy for Log-Linear function.
author Gernot Belger <g.belger@bjoernsen.de>
date Mon, 02 Dec 2019 17:56:15 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java	Mon Dec 02 17:56:15 2019 +0100
@@ -0,0 +1,228 @@
+/** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde
+ * Software engineering by
+ *  Björnsen Beratende Ingenieure GmbH
+ *  Dr. Schumacher Ingenieurbüro für Wasser und Umwelt
+ *
+ * This file is Free Software under the GNU AGPL (>=v3)
+ * and comes with ABSOLUTELY NO WARRANTY! Check out the
+ * documentation coming with Dive4Elements River for details.
+ */
+package org.dive4elements.river.artifacts.model.fixings.fitting;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.NavigableMap;
+import java.util.TreeMap;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.commons.math3.analysis.MultivariateFunction;
+import org.apache.commons.math3.optim.InitialGuess;
+import org.apache.commons.math3.optim.MaxEval;
+import org.apache.commons.math3.optim.MaxIter;
+import org.apache.commons.math3.optim.PointValuePair;
+import org.apache.commons.math3.optim.SimpleBounds;
+import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
+import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
+import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
+import org.apache.commons.math3.util.Pair;
+import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result;
+
+import au.com.bytecode.opencsv.CSVReader;
+
+/**
+ * @author Gernot Belger
+ */
+public class LinearizedFittingTest {
+
+    public static void main(final String[] args) throws IOException {
+
+        // read test data
+        final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> testData = readTestData();
+
+        for (final Entry<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> entry : testData.entrySet()) {
+
+            final BigDecimal station = entry.getKey();
+            final List<Pair<BigDecimal, BigDecimal>> testSample = entry.getValue();
+            optimizeA(station, testSample);
+        }
+    }
+
+    private static void optimizeA(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample) {
+
+        /* extrakt observations */
+        final double[] obsDischarges = new double[testSample.size()];
+        final double[] obsWaterlevels = new double[testSample.size()];
+        for (int i = 0; i < obsWaterlevels.length; i++) {
+            obsDischarges[i] = testSample.get(i).getKey().doubleValue();
+            obsWaterlevels[i] = testSample.get(i).getValue().doubleValue();
+        }
+
+        final TargetFunction targetFunction = new TargetFunction(obsDischarges);
+        final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction);
+
+        final Result directEstimation = estimateADirect(sqrtErrorFunction);
+        final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize();
+
+        final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction);
+        final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction);
+
+        printResult(station, testSample, directEstimation);
+        printResult(station, testSample, linearEstimation);
+        printResult(station, testSample, directPostEstimation);
+        printResult(station, testSample, linearPostEstimation);
+    }
+
+    private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) {
+
+        final double a = estimation.getA();
+        final double b = estimation.getB();
+        final double m = estimation.getM();
+        // final double error = estimation.getError();
+        final double aIndex = estimation.getBestAindex();
+
+        final double aLow = Math.pow(10, aIndex - 1);
+        final double aHigh = Math.pow(10, aIndex + 1);
+
+        return directOptimize(aLow, aHigh, a, b, m, errorFunction);
+    }
+
+    private static void printResult(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample, final Result optimize) {
+
+        final double a = optimize.getA();
+        final double b = optimize.getB();
+        final double m = optimize.getM();
+        final double error = optimize.getError();
+
+        System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error);
+
+        for (final Pair<BigDecimal, BigDecimal> entry : testSample) {
+
+            final BigDecimal waterlevel = entry.getSecond();
+            System.out.print(' ');
+            System.out.print(waterlevel);
+        }
+        System.out.println();
+    }
+
+    private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) {
+
+        Result best = null;
+        double leastError = Double.POSITIVE_INFINITY;
+
+        // iteration über a von 10^^0 bis 10^^20
+        for (int i = 0; i < 20; i++) {
+
+            final double aStart = Math.pow(10, i);
+            final double aLow = Math.pow(10, i - 1);
+            final double aHigh = Math.pow(10, i + 1);
+
+            final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction);
+            final double error = result.getError();
+
+            if (error < leastError) {
+                leastError = error;
+                best = result.withBestAIndex(i);
+            }
+        }
+
+        return best;
+    }
+
+    private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m,
+            final SqrtErrorFunction sqrtErrorFunction) {
+
+        // n = 3
+        // [n+2, (n+1)(n+2)/2]
+        // --> [5, 10]
+        final int interpolationPoints = 10;
+        final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints);
+
+        /* optimization data */
+        final MultivariateFunction function = new MultivariateFunction() {
+
+            @Override
+            public double value(final double[] point) {
+                return sqrtErrorFunction.value(point[0], point[1], point[2]);
+            }
+        };
+
+        final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE);
+        final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE);
+
+        final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 });
+        final double[] startValues = new double[] { a, b, m };
+        final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval,
+                maxIter);
+
+        final Double error = result.getValue();
+        final double[] point = result.getPoint();
+
+        final double aEstimation = point[0];
+        final double bEstimation = point[1];
+        final double mEstimation = point[2];
+
+        return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN);
+    }
+
+    private static NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> readTestData() throws IOException {
+
+        final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> data = new TreeMap<>();
+
+        try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))),
+                '\t')) {
+
+            final String[] header = reader.readNext();
+            if (header == null)
+                throw new IllegalStateException();
+
+            if (header.length < 2)
+                throw new IllegalStateException();
+
+            while (true) {
+                final String[] line = reader.readNext();
+                if (line == null)
+                    break;
+
+                if (line.length != header.length)
+                    throw new IllegalStateException();
+
+                final BigDecimal discharge = parseDecimal(line[0]);
+                if (discharge == null)
+                    continue;
+
+                for (int column = 1; column < line.length; column++) {
+
+                    final BigDecimal station = parseDecimal(header[column]);
+                    if (station == null)
+                        continue;
+
+                    if (!data.containsKey(station))
+                        data.put(station, new ArrayList<Pair<BigDecimal, BigDecimal>>());
+                    final List<Pair<BigDecimal, BigDecimal>> points = data.get(station);
+
+                    final BigDecimal waterlevel = parseDecimal(line[column]);
+                    if (waterlevel != null)
+                        points.add(Pair.create(discharge, waterlevel));
+                }
+            }
+        }
+
+        return data;
+    }
+
+    private static BigDecimal parseDecimal(final String token) {
+
+        if (StringUtils.isBlank(token))
+            return null;
+
+        if ("nan".equalsIgnoreCase(token))
+            return null;
+
+        return new BigDecimal(token);
+    }
+}
\ No newline at end of file

http://dive4elements.wald.intevation.org