g@9646: /** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde g@9646: * Software engineering by g@9646: * Björnsen Beratende Ingenieure GmbH g@9646: * Dr. Schumacher Ingenieurbüro für Wasser und Umwelt g@9646: * g@9646: * This file is Free Software under the GNU AGPL (>=v3) g@9646: * and comes with ABSOLUTELY NO WARRANTY! Check out the g@9646: * documentation coming with Dive4Elements River for details. g@9646: */ g@9646: package org.dive4elements.river.artifacts.model.fixings.fitting; g@9646: g@9646: import java.io.BufferedInputStream; g@9646: import java.io.IOException; g@9646: import java.io.InputStreamReader; g@9646: import java.math.BigDecimal; g@9646: import java.util.ArrayList; g@9646: import java.util.List; g@9646: import java.util.Map.Entry; g@9646: import java.util.NavigableMap; g@9646: import java.util.TreeMap; g@9646: g@9646: import org.apache.commons.lang.StringUtils; g@9646: import org.apache.commons.math3.analysis.MultivariateFunction; g@9646: import org.apache.commons.math3.optim.InitialGuess; g@9646: import org.apache.commons.math3.optim.MaxEval; g@9646: import org.apache.commons.math3.optim.MaxIter; g@9646: import org.apache.commons.math3.optim.PointValuePair; g@9646: import org.apache.commons.math3.optim.SimpleBounds; g@9646: import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; g@9646: import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction; g@9646: import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer; g@9646: import org.apache.commons.math3.util.Pair; g@9646: import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result; g@9646: g@9646: import au.com.bytecode.opencsv.CSVReader; g@9646: g@9646: /** g@9646: * @author Gernot Belger g@9646: */ g@9646: public class LinearizedFittingTest { g@9646: g@9646: public static void main(final String[] args) throws IOException { g@9646: g@9646: // read test data g@9646: final NavigableMap>> testData = readTestData(); g@9646: g@9646: for (final Entry>> entry : testData.entrySet()) { g@9646: g@9646: final BigDecimal station = entry.getKey(); g@9646: final List> testSample = entry.getValue(); g@9646: optimizeA(station, testSample); g@9646: } g@9646: } g@9646: g@9646: private static void optimizeA(final BigDecimal station, final List> testSample) { g@9646: g@9646: /* extrakt observations */ g@9646: final double[] obsDischarges = new double[testSample.size()]; g@9646: final double[] obsWaterlevels = new double[testSample.size()]; g@9646: for (int i = 0; i < obsWaterlevels.length; i++) { g@9646: obsDischarges[i] = testSample.get(i).getKey().doubleValue(); g@9646: obsWaterlevels[i] = testSample.get(i).getValue().doubleValue(); g@9646: } g@9646: g@9646: final TargetFunction targetFunction = new TargetFunction(obsDischarges); g@9646: final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction); g@9646: g@9646: final Result directEstimation = estimateADirect(sqrtErrorFunction); g@9646: final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize(); g@9646: g@9646: final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction); g@9646: final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction); g@9646: g@9646: printResult(station, testSample, directEstimation); g@9646: printResult(station, testSample, linearEstimation); g@9646: printResult(station, testSample, directPostEstimation); g@9646: printResult(station, testSample, linearPostEstimation); g@9646: } g@9646: g@9646: private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) { g@9646: g@9646: final double a = estimation.getA(); g@9646: final double b = estimation.getB(); g@9646: final double m = estimation.getM(); g@9646: // final double error = estimation.getError(); g@9646: final double aIndex = estimation.getBestAindex(); g@9646: g@9646: final double aLow = Math.pow(10, aIndex - 1); g@9646: final double aHigh = Math.pow(10, aIndex + 1); g@9646: g@9646: return directOptimize(aLow, aHigh, a, b, m, errorFunction); g@9646: } g@9646: g@9646: private static void printResult(final BigDecimal station, final List> testSample, final Result optimize) { g@9646: g@9646: final double a = optimize.getA(); g@9646: final double b = optimize.getB(); g@9646: final double m = optimize.getM(); g@9646: final double error = optimize.getError(); g@9646: g@9646: System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error); g@9646: g@9646: for (final Pair entry : testSample) { g@9646: g@9646: final BigDecimal waterlevel = entry.getSecond(); g@9646: System.out.print(' '); g@9646: System.out.print(waterlevel); g@9646: } g@9646: System.out.println(); g@9646: } g@9646: g@9646: private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) { g@9646: g@9646: Result best = null; g@9646: double leastError = Double.POSITIVE_INFINITY; g@9646: g@9646: // iteration über a von 10^^0 bis 10^^20 g@9646: for (int i = 0; i < 20; i++) { g@9646: g@9646: final double aStart = Math.pow(10, i); g@9646: final double aLow = Math.pow(10, i - 1); g@9646: final double aHigh = Math.pow(10, i + 1); g@9646: g@9646: final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction); g@9646: final double error = result.getError(); g@9646: g@9646: if (error < leastError) { g@9646: leastError = error; g@9646: best = result.withBestAIndex(i); g@9646: } g@9646: } g@9646: g@9646: return best; g@9646: } g@9646: g@9646: private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m, g@9646: final SqrtErrorFunction sqrtErrorFunction) { g@9646: g@9646: // n = 3 g@9646: // [n+2, (n+1)(n+2)/2] g@9646: // --> [5, 10] g@9646: final int interpolationPoints = 10; g@9646: final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints); g@9646: g@9646: /* optimization data */ g@9646: final MultivariateFunction function = new MultivariateFunction() { g@9646: g@9646: @Override g@9646: public double value(final double[] point) { g@9646: return sqrtErrorFunction.value(point[0], point[1], point[2]); g@9646: } g@9646: }; g@9646: g@9646: final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE); g@9646: final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE); g@9646: g@9646: final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 }); g@9646: final double[] startValues = new double[] { a, b, m }; g@9646: final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval, g@9646: maxIter); g@9646: g@9646: final Double error = result.getValue(); g@9646: final double[] point = result.getPoint(); g@9646: g@9646: final double aEstimation = point[0]; g@9646: final double bEstimation = point[1]; g@9646: final double mEstimation = point[2]; g@9646: g@9646: return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN); g@9646: } g@9646: g@9646: private static NavigableMap>> readTestData() throws IOException { g@9646: g@9646: final NavigableMap>> data = new TreeMap<>(); g@9646: g@9646: try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))), g@9646: '\t')) { g@9646: g@9646: final String[] header = reader.readNext(); g@9646: if (header == null) g@9646: throw new IllegalStateException(); g@9646: g@9646: if (header.length < 2) g@9646: throw new IllegalStateException(); g@9646: g@9646: while (true) { g@9646: final String[] line = reader.readNext(); g@9646: if (line == null) g@9646: break; g@9646: g@9646: if (line.length != header.length) g@9646: throw new IllegalStateException(); g@9646: g@9646: final BigDecimal discharge = parseDecimal(line[0]); g@9646: if (discharge == null) g@9646: continue; g@9646: g@9646: for (int column = 1; column < line.length; column++) { g@9646: g@9646: final BigDecimal station = parseDecimal(header[column]); g@9646: if (station == null) g@9646: continue; g@9646: g@9646: if (!data.containsKey(station)) g@9646: data.put(station, new ArrayList>()); g@9646: final List> points = data.get(station); g@9646: g@9646: final BigDecimal waterlevel = parseDecimal(line[column]); g@9646: if (waterlevel != null) g@9646: points.add(Pair.create(discharge, waterlevel)); g@9646: } g@9646: } g@9646: } g@9646: g@9646: return data; g@9646: } g@9646: g@9646: private static BigDecimal parseDecimal(final String token) { g@9646: g@9646: if (StringUtils.isBlank(token)) g@9646: return null; g@9646: g@9646: if ("nan".equalsIgnoreCase(token)) g@9646: return null; g@9646: g@9646: return new BigDecimal(token); g@9646: } g@9646: }