Mercurial > dive4elements > river
view artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java @ 9646:0380717105ba
Implemented alternative fitting strategy for Log-Linear function.
author | Gernot Belger <g.belger@bjoernsen.de> |
---|---|
date | Mon, 02 Dec 2019 17:56:15 +0100 |
parents | |
children |
line wrap: on
line source
/** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde * Software engineering by * Björnsen Beratende Ingenieure GmbH * Dr. Schumacher Ingenieurbüro für Wasser und Umwelt * * This file is Free Software under the GNU AGPL (>=v3) * and comes with ABSOLUTELY NO WARRANTY! Check out the * documentation coming with Dive4Elements River for details. */ package org.dive4elements.river.artifacts.model.fixings.fitting; import java.io.BufferedInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import java.util.Map.Entry; import java.util.NavigableMap; import java.util.TreeMap; import org.apache.commons.lang.StringUtils; import org.apache.commons.math3.analysis.MultivariateFunction; import org.apache.commons.math3.optim.InitialGuess; import org.apache.commons.math3.optim.MaxEval; import org.apache.commons.math3.optim.MaxIter; import org.apache.commons.math3.optim.PointValuePair; import org.apache.commons.math3.optim.SimpleBounds; import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction; import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer; import org.apache.commons.math3.util.Pair; import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result; import au.com.bytecode.opencsv.CSVReader; /** * @author Gernot Belger */ public class LinearizedFittingTest { public static void main(final String[] args) throws IOException { // read test data final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> testData = readTestData(); for (final Entry<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> entry : testData.entrySet()) { final BigDecimal station = entry.getKey(); final List<Pair<BigDecimal, BigDecimal>> testSample = entry.getValue(); optimizeA(station, testSample); } } private static void optimizeA(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample) { /* extrakt observations */ final double[] obsDischarges = new double[testSample.size()]; final double[] obsWaterlevels = new double[testSample.size()]; for (int i = 0; i < obsWaterlevels.length; i++) { obsDischarges[i] = testSample.get(i).getKey().doubleValue(); obsWaterlevels[i] = testSample.get(i).getValue().doubleValue(); } final TargetFunction targetFunction = new TargetFunction(obsDischarges); final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction); final Result directEstimation = estimateADirect(sqrtErrorFunction); final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize(); final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction); final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction); printResult(station, testSample, directEstimation); printResult(station, testSample, linearEstimation); printResult(station, testSample, directPostEstimation); printResult(station, testSample, linearPostEstimation); } private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) { final double a = estimation.getA(); final double b = estimation.getB(); final double m = estimation.getM(); // final double error = estimation.getError(); final double aIndex = estimation.getBestAindex(); final double aLow = Math.pow(10, aIndex - 1); final double aHigh = Math.pow(10, aIndex + 1); return directOptimize(aLow, aHigh, a, b, m, errorFunction); } private static void printResult(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample, final Result optimize) { final double a = optimize.getA(); final double b = optimize.getB(); final double m = optimize.getM(); final double error = optimize.getError(); System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error); for (final Pair<BigDecimal, BigDecimal> entry : testSample) { final BigDecimal waterlevel = entry.getSecond(); System.out.print(' '); System.out.print(waterlevel); } System.out.println(); } private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) { Result best = null; double leastError = Double.POSITIVE_INFINITY; // iteration über a von 10^^0 bis 10^^20 for (int i = 0; i < 20; i++) { final double aStart = Math.pow(10, i); final double aLow = Math.pow(10, i - 1); final double aHigh = Math.pow(10, i + 1); final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction); final double error = result.getError(); if (error < leastError) { leastError = error; best = result.withBestAIndex(i); } } return best; } private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m, final SqrtErrorFunction sqrtErrorFunction) { // n = 3 // [n+2, (n+1)(n+2)/2] // --> [5, 10] final int interpolationPoints = 10; final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints); /* optimization data */ final MultivariateFunction function = new MultivariateFunction() { @Override public double value(final double[] point) { return sqrtErrorFunction.value(point[0], point[1], point[2]); } }; final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE); final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE); final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 }); final double[] startValues = new double[] { a, b, m }; final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval, maxIter); final Double error = result.getValue(); final double[] point = result.getPoint(); final double aEstimation = point[0]; final double bEstimation = point[1]; final double mEstimation = point[2]; return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN); } private static NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> readTestData() throws IOException { final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> data = new TreeMap<>(); try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))), '\t')) { final String[] header = reader.readNext(); if (header == null) throw new IllegalStateException(); if (header.length < 2) throw new IllegalStateException(); while (true) { final String[] line = reader.readNext(); if (line == null) break; if (line.length != header.length) throw new IllegalStateException(); final BigDecimal discharge = parseDecimal(line[0]); if (discharge == null) continue; for (int column = 1; column < line.length; column++) { final BigDecimal station = parseDecimal(header[column]); if (station == null) continue; if (!data.containsKey(station)) data.put(station, new ArrayList<Pair<BigDecimal, BigDecimal>>()); final List<Pair<BigDecimal, BigDecimal>> points = data.get(station); final BigDecimal waterlevel = parseDecimal(line[column]); if (waterlevel != null) points.add(Pair.create(discharge, waterlevel)); } } } return data; } private static BigDecimal parseDecimal(final String token) { if (StringUtils.isBlank(token)) return null; if ("nan".equalsIgnoreCase(token)) return null; return new BigDecimal(token); } }