comparison artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java @ 9646:0380717105ba

Implemented alternative fitting strategy for Log-Linear function.
author Gernot Belger <g.belger@bjoernsen.de>
date Mon, 02 Dec 2019 17:56:15 +0100
parents
children
comparison
equal deleted inserted replaced
9645:eb1a29fe823f 9646:0380717105ba
1 /** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde
2 * Software engineering by
3 * Björnsen Beratende Ingenieure GmbH
4 * Dr. Schumacher Ingenieurbüro für Wasser und Umwelt
5 *
6 * This file is Free Software under the GNU AGPL (>=v3)
7 * and comes with ABSOLUTELY NO WARRANTY! Check out the
8 * documentation coming with Dive4Elements River for details.
9 */
10 package org.dive4elements.river.artifacts.model.fixings.fitting;
11
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.NavigableMap;
import java.util.TreeMap;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.util.Pair;
import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result;

import au.com.bytecode.opencsv.CSVReader;
36
37 /**
38 * @author Gernot Belger
39 */
40 public class LinearizedFittingTest {
41
42 public static void main(final String[] args) throws IOException {
43
44 // read test data
45 final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> testData = readTestData();
46
47 for (final Entry<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> entry : testData.entrySet()) {
48
49 final BigDecimal station = entry.getKey();
50 final List<Pair<BigDecimal, BigDecimal>> testSample = entry.getValue();
51 optimizeA(station, testSample);
52 }
53 }
54
55 private static void optimizeA(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample) {
56
57 /* extrakt observations */
58 final double[] obsDischarges = new double[testSample.size()];
59 final double[] obsWaterlevels = new double[testSample.size()];
60 for (int i = 0; i < obsWaterlevels.length; i++) {
61 obsDischarges[i] = testSample.get(i).getKey().doubleValue();
62 obsWaterlevels[i] = testSample.get(i).getValue().doubleValue();
63 }
64
65 final TargetFunction targetFunction = new TargetFunction(obsDischarges);
66 final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction);
67
68 final Result directEstimation = estimateADirect(sqrtErrorFunction);
69 final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize();
70
71 final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction);
72 final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction);
73
74 printResult(station, testSample, directEstimation);
75 printResult(station, testSample, linearEstimation);
76 printResult(station, testSample, directPostEstimation);
77 printResult(station, testSample, linearPostEstimation);
78 }
79
80 private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) {
81
82 final double a = estimation.getA();
83 final double b = estimation.getB();
84 final double m = estimation.getM();
85 // final double error = estimation.getError();
86 final double aIndex = estimation.getBestAindex();
87
88 final double aLow = Math.pow(10, aIndex - 1);
89 final double aHigh = Math.pow(10, aIndex + 1);
90
91 return directOptimize(aLow, aHigh, a, b, m, errorFunction);
92 }
93
94 private static void printResult(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample, final Result optimize) {
95
96 final double a = optimize.getA();
97 final double b = optimize.getB();
98 final double m = optimize.getM();
99 final double error = optimize.getError();
100
101 System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error);
102
103 for (final Pair<BigDecimal, BigDecimal> entry : testSample) {
104
105 final BigDecimal waterlevel = entry.getSecond();
106 System.out.print(' ');
107 System.out.print(waterlevel);
108 }
109 System.out.println();
110 }
111
112 private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) {
113
114 Result best = null;
115 double leastError = Double.POSITIVE_INFINITY;
116
117 // iteration über a von 10^^0 bis 10^^20
118 for (int i = 0; i < 20; i++) {
119
120 final double aStart = Math.pow(10, i);
121 final double aLow = Math.pow(10, i - 1);
122 final double aHigh = Math.pow(10, i + 1);
123
124 final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction);
125 final double error = result.getError();
126
127 if (error < leastError) {
128 leastError = error;
129 best = result.withBestAIndex(i);
130 }
131 }
132
133 return best;
134 }
135
136 private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m,
137 final SqrtErrorFunction sqrtErrorFunction) {
138
139 // n = 3
140 // [n+2, (n+1)(n+2)/2]
141 // --> [5, 10]
142 final int interpolationPoints = 10;
143 final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints);
144
145 /* optimization data */
146 final MultivariateFunction function = new MultivariateFunction() {
147
148 @Override
149 public double value(final double[] point) {
150 return sqrtErrorFunction.value(point[0], point[1], point[2]);
151 }
152 };
153
154 final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE);
155 final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE);
156
157 final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 });
158 final double[] startValues = new double[] { a, b, m };
159 final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval,
160 maxIter);
161
162 final Double error = result.getValue();
163 final double[] point = result.getPoint();
164
165 final double aEstimation = point[0];
166 final double bEstimation = point[1];
167 final double mEstimation = point[2];
168
169 return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN);
170 }
171
172 private static NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> readTestData() throws IOException {
173
174 final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> data = new TreeMap<>();
175
176 try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))),
177 '\t')) {
178
179 final String[] header = reader.readNext();
180 if (header == null)
181 throw new IllegalStateException();
182
183 if (header.length < 2)
184 throw new IllegalStateException();
185
186 while (true) {
187 final String[] line = reader.readNext();
188 if (line == null)
189 break;
190
191 if (line.length != header.length)
192 throw new IllegalStateException();
193
194 final BigDecimal discharge = parseDecimal(line[0]);
195 if (discharge == null)
196 continue;
197
198 for (int column = 1; column < line.length; column++) {
199
200 final BigDecimal station = parseDecimal(header[column]);
201 if (station == null)
202 continue;
203
204 if (!data.containsKey(station))
205 data.put(station, new ArrayList<Pair<BigDecimal, BigDecimal>>());
206 final List<Pair<BigDecimal, BigDecimal>> points = data.get(station);
207
208 final BigDecimal waterlevel = parseDecimal(line[column]);
209 if (waterlevel != null)
210 points.add(Pair.create(discharge, waterlevel));
211 }
212 }
213 }
214
215 return data;
216 }
217
218 private static BigDecimal parseDecimal(final String token) {
219
220 if (StringUtils.isBlank(token))
221 return null;
222
223 if ("nan".equalsIgnoreCase(token))
224 return null;
225
226 return new BigDecimal(token);
227 }
228 }

http://dive4elements.wald.intevation.org