Mercurial > dive4elements > river
diff artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java @ 9646:0380717105ba
Implemented alternative fitting strategy for Log-Linear function.
author | Gernot Belger <g.belger@bjoernsen.de> |
---|---|
date | Mon, 02 Dec 2019 17:56:15 +0100 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java Mon Dec 02 17:56:15 2019 +0100 @@ -0,0 +1,228 @@ +/** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde + * Software engineering by + * Björnsen Beratende Ingenieure GmbH + * Dr. Schumacher Ingenieurbüro für Wasser und Umwelt + * + * This file is Free Software under the GNU AGPL (>=v3) + * and comes with ABSOLUTELY NO WARRANTY! Check out the + * documentation coming with Dive4Elements River for details. + */ +package org.dive4elements.river.artifacts.model.fixings.fitting; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.Map.Entry; +import java.util.NavigableMap; +import java.util.TreeMap; + +import org.apache.commons.lang.StringUtils; +import org.apache.commons.math3.analysis.MultivariateFunction; +import org.apache.commons.math3.optim.InitialGuess; +import org.apache.commons.math3.optim.MaxEval; +import org.apache.commons.math3.optim.MaxIter; +import org.apache.commons.math3.optim.PointValuePair; +import org.apache.commons.math3.optim.SimpleBounds; +import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; +import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction; +import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer; +import org.apache.commons.math3.util.Pair; +import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result; + +import au.com.bytecode.opencsv.CSVReader; + +/** + * @author Gernot Belger + */ +public class LinearizedFittingTest { + + public static void main(final String[] args) throws IOException { + + // read test data + final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> testData = readTestData(); + + for (final Entry<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> entry : testData.entrySet()) { + + final BigDecimal station = entry.getKey(); + final List<Pair<BigDecimal, BigDecimal>> testSample = entry.getValue(); + optimizeA(station, testSample); + } + } + + private static void optimizeA(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample) { + + /* extrakt observations */ + final double[] obsDischarges = new double[testSample.size()]; + final double[] obsWaterlevels = new double[testSample.size()]; + for (int i = 0; i < obsWaterlevels.length; i++) { + obsDischarges[i] = testSample.get(i).getKey().doubleValue(); + obsWaterlevels[i] = testSample.get(i).getValue().doubleValue(); + } + + final TargetFunction targetFunction = new TargetFunction(obsDischarges); + final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction); + + final Result directEstimation = estimateADirect(sqrtErrorFunction); + final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize(); + + final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction); + final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction); + + printResult(station, testSample, directEstimation); + printResult(station, testSample, linearEstimation); + printResult(station, testSample, directPostEstimation); + printResult(station, testSample, linearPostEstimation); + } + + private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) { + + final double a = estimation.getA(); + final double b = estimation.getB(); + final double m = estimation.getM(); + // final double error = estimation.getError(); + final double aIndex = estimation.getBestAindex(); + + final double aLow = Math.pow(10, aIndex - 1); + final double aHigh = Math.pow(10, aIndex + 1); + + return directOptimize(aLow, aHigh, a, b, m, errorFunction); + } + + private static void printResult(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample, final Result optimize) { + + final double a = optimize.getA(); + final double b = optimize.getB(); + final double m = optimize.getM(); + final double error = optimize.getError(); + + System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error); + + for (final Pair<BigDecimal, BigDecimal> entry : testSample) { + + final BigDecimal waterlevel = entry.getSecond(); + System.out.print(' '); + System.out.print(waterlevel); + } + System.out.println(); + } + + private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) { + + Result best = null; + double leastError = Double.POSITIVE_INFINITY; + + // iteration über a von 10^^0 bis 10^^20 + for (int i = 0; i < 20; i++) { + + final double aStart = Math.pow(10, i); + final double aLow = Math.pow(10, i - 1); + final double aHigh = Math.pow(10, i + 1); + + final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction); + final double error = result.getError(); + + if (error < leastError) { + leastError = error; + best = result.withBestAIndex(i); + } + } + + return best; + } + + private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m, + final SqrtErrorFunction sqrtErrorFunction) { + + // n = 3 + // [n+2, (n+1)(n+2)/2] + // --> [5, 10] + final int interpolationPoints = 10; + final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints); + + /* optimization data */ + final MultivariateFunction function = new MultivariateFunction() { + + @Override + public double value(final double[] point) { + return sqrtErrorFunction.value(point[0], point[1], point[2]); + } + }; + + final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE); + final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE); + + final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 }); + final double[] startValues = new double[] { a, b, m }; + final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval, + maxIter); + + final Double error = result.getValue(); + final double[] point = result.getPoint(); + + final double aEstimation = point[0]; + final double bEstimation = point[1]; + final double mEstimation = point[2]; + + return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN); + } + + private static NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> readTestData() throws IOException { + + final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> data = new TreeMap<>(); + + try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))), + '\t')) { + + final String[] header = reader.readNext(); + if (header == null) + throw new IllegalStateException(); + + if (header.length < 2) + throw new IllegalStateException(); + + while (true) { + final String[] line = reader.readNext(); + if (line == null) + break; + + if (line.length != header.length) + throw new IllegalStateException(); + + final BigDecimal discharge = parseDecimal(line[0]); + if (discharge == null) + continue; + + for (int column = 1; column < line.length; column++) { + + final BigDecimal station = parseDecimal(header[column]); + if (station == null) + continue; + + if (!data.containsKey(station)) + data.put(station, new ArrayList<Pair<BigDecimal, BigDecimal>>()); + final List<Pair<BigDecimal, BigDecimal>> points = data.get(station); + + final BigDecimal waterlevel = parseDecimal(line[column]); + if (waterlevel != null) + points.add(Pair.create(discharge, waterlevel)); + } + } + } + + return data; + } + + private static BigDecimal parseDecimal(final String token) { + + if (StringUtils.isBlank(token)) + return null; + + if ("nan".equalsIgnoreCase(token)) + return null; + + return new BigDecimal(token); + } +} \ No newline at end of file