Mercurial > dive4elements > river
comparison artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/fitting/LinearizedFittingTest.java @ 9646:0380717105ba
Implemented alternative fitting strategy for Log-Linear function.
author | Gernot Belger <g.belger@bjoernsen.de> |
---|---|
date | Mon, 02 Dec 2019 17:56:15 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
9645:eb1a29fe823f | 9646:0380717105ba |
---|---|
1 /** Copyright (C) 2017 by Bundesanstalt für Gewässerkunde | |
2 * Software engineering by | |
3 * Björnsen Beratende Ingenieure GmbH | |
4 * Dr. Schumacher Ingenieurbüro für Wasser und Umwelt | |
5 * | |
6 * This file is Free Software under the GNU AGPL (>=v3) | |
7 * and comes with ABSOLUTELY NO WARRANTY! Check out the | |
8 * documentation coming with Dive4Elements River for details. | |
9 */ | |
10 package org.dive4elements.river.artifacts.model.fixings.fitting; | |
11 | |
import java.io.BufferedInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map.Entry;
import java.util.NavigableMap;
import java.util.TreeMap;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.util.Pair;
import org.dive4elements.river.artifacts.model.fixings.fitting.LinearLogLinearizedFitting.Result;

import au.com.bytecode.opencsv.CSVReader;
36 | |
37 /** | |
38 * @author Gernot Belger | |
39 */ | |
40 public class LinearizedFittingTest { | |
41 | |
42 public static void main(final String[] args) throws IOException { | |
43 | |
44 // read test data | |
45 final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> testData = readTestData(); | |
46 | |
47 for (final Entry<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> entry : testData.entrySet()) { | |
48 | |
49 final BigDecimal station = entry.getKey(); | |
50 final List<Pair<BigDecimal, BigDecimal>> testSample = entry.getValue(); | |
51 optimizeA(station, testSample); | |
52 } | |
53 } | |
54 | |
55 private static void optimizeA(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample) { | |
56 | |
57 /* extrakt observations */ | |
58 final double[] obsDischarges = new double[testSample.size()]; | |
59 final double[] obsWaterlevels = new double[testSample.size()]; | |
60 for (int i = 0; i < obsWaterlevels.length; i++) { | |
61 obsDischarges[i] = testSample.get(i).getKey().doubleValue(); | |
62 obsWaterlevels[i] = testSample.get(i).getValue().doubleValue(); | |
63 } | |
64 | |
65 final TargetFunction targetFunction = new TargetFunction(obsDischarges); | |
66 final SqrtErrorFunction sqrtErrorFunction = new SqrtErrorFunction(obsWaterlevels, targetFunction); | |
67 | |
68 final Result directEstimation = estimateADirect(sqrtErrorFunction); | |
69 final Result linearEstimation = new LinearLogLinearizedFitting(obsDischarges, obsWaterlevels).optimize(); | |
70 | |
71 final Result directPostEstimation = postOptimize(directEstimation, sqrtErrorFunction); | |
72 final Result linearPostEstimation = postOptimize(linearEstimation, sqrtErrorFunction); | |
73 | |
74 printResult(station, testSample, directEstimation); | |
75 printResult(station, testSample, linearEstimation); | |
76 printResult(station, testSample, directPostEstimation); | |
77 printResult(station, testSample, linearPostEstimation); | |
78 } | |
79 | |
80 private static Result postOptimize(final Result estimation, final SqrtErrorFunction errorFunction) { | |
81 | |
82 final double a = estimation.getA(); | |
83 final double b = estimation.getB(); | |
84 final double m = estimation.getM(); | |
85 // final double error = estimation.getError(); | |
86 final double aIndex = estimation.getBestAindex(); | |
87 | |
88 final double aLow = Math.pow(10, aIndex - 1); | |
89 final double aHigh = Math.pow(10, aIndex + 1); | |
90 | |
91 return directOptimize(aLow, aHigh, a, b, m, errorFunction); | |
92 } | |
93 | |
94 private static void printResult(final BigDecimal station, final List<Pair<BigDecimal, BigDecimal>> testSample, final Result optimize) { | |
95 | |
96 final double a = optimize.getA(); | |
97 final double b = optimize.getB(); | |
98 final double m = optimize.getM(); | |
99 final double error = optimize.getError(); | |
100 | |
101 System.out.format("%s %.10f %.10f %.10f %.10f", station, a, b, m, error); | |
102 | |
103 for (final Pair<BigDecimal, BigDecimal> entry : testSample) { | |
104 | |
105 final BigDecimal waterlevel = entry.getSecond(); | |
106 System.out.print(' '); | |
107 System.out.print(waterlevel); | |
108 } | |
109 System.out.println(); | |
110 } | |
111 | |
112 private static Result estimateADirect(final SqrtErrorFunction sqrtErrorFunction) { | |
113 | |
114 Result best = null; | |
115 double leastError = Double.POSITIVE_INFINITY; | |
116 | |
117 // iteration über a von 10^^0 bis 10^^20 | |
118 for (int i = 0; i < 20; i++) { | |
119 | |
120 final double aStart = Math.pow(10, i); | |
121 final double aLow = Math.pow(10, i - 1); | |
122 final double aHigh = Math.pow(10, i + 1); | |
123 | |
124 final Result result = directOptimize(aLow, aHigh, aStart, 1, 1, sqrtErrorFunction); | |
125 final double error = result.getError(); | |
126 | |
127 if (error < leastError) { | |
128 leastError = error; | |
129 best = result.withBestAIndex(i); | |
130 } | |
131 } | |
132 | |
133 return best; | |
134 } | |
135 | |
136 private static Result directOptimize(final double aLow, final double aHigh, final double a, final double b, final double m, | |
137 final SqrtErrorFunction sqrtErrorFunction) { | |
138 | |
139 // n = 3 | |
140 // [n+2, (n+1)(n+2)/2] | |
141 // --> [5, 10] | |
142 final int interpolationPoints = 10; | |
143 final BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints); | |
144 | |
145 /* optimization data */ | |
146 final MultivariateFunction function = new MultivariateFunction() { | |
147 | |
148 @Override | |
149 public double value(final double[] point) { | |
150 return sqrtErrorFunction.value(point[0], point[1], point[2]); | |
151 } | |
152 }; | |
153 | |
154 final MaxEval maxEval = new MaxEval(Integer.MAX_VALUE); | |
155 final MaxIter maxIter = new MaxIter(Integer.MAX_VALUE); | |
156 | |
157 final SimpleBounds bounds = new SimpleBounds(new double[] { aLow, -1e3, 0 }, new double[] { aHigh, 1e3, 1e3 }); | |
158 final double[] startValues = new double[] { a, b, m }; | |
159 final PointValuePair result = optimizer.optimize(GoalType.MINIMIZE, new ObjectiveFunction(function), new InitialGuess(startValues), bounds, maxEval, | |
160 maxIter); | |
161 | |
162 final Double error = result.getValue(); | |
163 final double[] point = result.getPoint(); | |
164 | |
165 final double aEstimation = point[0]; | |
166 final double bEstimation = point[1]; | |
167 final double mEstimation = point[2]; | |
168 | |
169 return new Result(aEstimation, bEstimation, mEstimation, error, -1, Double.NaN); | |
170 } | |
171 | |
172 private static NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> readTestData() throws IOException { | |
173 | |
174 final NavigableMap<BigDecimal, List<Pair<BigDecimal, BigDecimal>>> data = new TreeMap<>(); | |
175 | |
176 try (final CSVReader reader = new CSVReader(new InputStreamReader(new BufferedInputStream(LinearizedFittingTest.class.getResourceAsStream("testdata.txt"))), | |
177 '\t')) { | |
178 | |
179 final String[] header = reader.readNext(); | |
180 if (header == null) | |
181 throw new IllegalStateException(); | |
182 | |
183 if (header.length < 2) | |
184 throw new IllegalStateException(); | |
185 | |
186 while (true) { | |
187 final String[] line = reader.readNext(); | |
188 if (line == null) | |
189 break; | |
190 | |
191 if (line.length != header.length) | |
192 throw new IllegalStateException(); | |
193 | |
194 final BigDecimal discharge = parseDecimal(line[0]); | |
195 if (discharge == null) | |
196 continue; | |
197 | |
198 for (int column = 1; column < line.length; column++) { | |
199 | |
200 final BigDecimal station = parseDecimal(header[column]); | |
201 if (station == null) | |
202 continue; | |
203 | |
204 if (!data.containsKey(station)) | |
205 data.put(station, new ArrayList<Pair<BigDecimal, BigDecimal>>()); | |
206 final List<Pair<BigDecimal, BigDecimal>> points = data.get(station); | |
207 | |
208 final BigDecimal waterlevel = parseDecimal(line[column]); | |
209 if (waterlevel != null) | |
210 points.add(Pair.create(discharge, waterlevel)); | |
211 } | |
212 } | |
213 } | |
214 | |
215 return data; | |
216 } | |
217 | |
218 private static BigDecimal parseDecimal(final String token) { | |
219 | |
220 if (StringUtils.isBlank(token)) | |
221 return null; | |
222 | |
223 if ("nan".equalsIgnoreCase(token)) | |
224 return null; | |
225 | |
226 return new BigDecimal(token); | |
227 } | |
228 } |