Mercurial > dive4elements > river
comparison artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 9415:9744ce3c3853
Rework of fixanalysis computation and dWt and WQ facets. Got rid of strange remapping and bitshifting code by explicitly saving the column information and using it in the facets.
The facets also put the valid station range into their XML metadata.
author | gernotbelger |
---|---|
date | Thu, 16 Aug 2018 16:27:53 +0200 |
parents | ddcd52d239cd |
children | f51e23eb036a |
comparison
equal
deleted
inserted
replaced
9414:096f151a0a9f | 9415:9744ce3c3853 |
---|---|
6 * documentation coming with Dive4Elements River for details. | 6 * documentation coming with Dive4Elements River for details. |
7 */ | 7 */ |
8 | 8 |
9 package org.dive4elements.river.artifacts.model.fixings; | 9 package org.dive4elements.river.artifacts.model.fixings; |
10 | 10 |
11 import gnu.trove.TDoubleArrayList; | |
12 | |
13 import java.util.ArrayList; | 11 import java.util.ArrayList; |
12 import java.util.Date; | |
14 import java.util.List; | 13 import java.util.List; |
15 | 14 |
16 import org.apache.commons.math.MathException; | 15 import org.apache.commons.math.MathException; |
17 import org.apache.commons.math.optimization.fitting.CurveFitter; | 16 import org.apache.commons.math.optimization.fitting.CurveFitter; |
18 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; | 17 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; |
19 import org.apache.commons.math.stat.descriptive.moment.StandardDeviation; | 18 import org.apache.commons.math.stat.descriptive.moment.StandardDeviation; |
20 import org.apache.log4j.Logger; | 19 import org.apache.log4j.Logger; |
21 | |
22 import org.dive4elements.river.artifacts.math.GrubbsOutlier; | 20 import org.dive4elements.river.artifacts.math.GrubbsOutlier; |
23 import org.dive4elements.river.artifacts.math.fitting.Function; | 21 import org.dive4elements.river.artifacts.math.fitting.Function; |
24 | 22 |
25 public class Fitting { | 23 public class Fitting { |
26 private static Logger log = Logger.getLogger(Fitting.class); | 24 private static Logger log = Logger.getLogger(Fitting.class); |
28 /** Use instance of this factory to find meta infos for outliers. */ | 26 /** Use instance of this factory to find meta infos for outliers. */ |
29 public interface QWDFactory { | 27 public interface QWDFactory { |
30 QWD create(double q, double w, double deltaW, boolean isOutlier); | 28 QWD create(double q, double w, double deltaW, boolean isOutlier); |
31 } | 29 } |
32 | 30 |
31 private static final class FittingData { | |
32 public final FixingColumnWithData event; | |
33 public final double w; | |
34 public final double q; | |
35 public final boolean isInterpolated; | |
36 | |
37 public FittingData(final FixingColumnWithData event, final double w, final double q, final boolean isInterpolated) { | |
38 this.event = event; | |
39 this.w = w; | |
40 this.q = q; | |
41 this.isInterpolated = isInterpolated; | |
42 } | |
43 } | |
44 | |
33 private final double chiSqr; | 45 private final double chiSqr; |
34 | 46 |
35 private final double[] parameters; | 47 private final double[] parameters; |
36 | 48 |
37 private final double standardDeviation; | 49 private final double standardDeviation; |
38 | 50 |
39 private final List<QWD> qwds; | 51 private final double maxQ; |
40 | 52 |
41 public Fitting(final double[] parameters, final double standardDeviation, final double chiSqr, final List<QWD> qwds) { | 53 public Fitting(final double[] parameters, final double standardDeviation, final double chiSqr, final double maxQ) { |
42 this.parameters = parameters; | 54 this.parameters = parameters; |
43 this.standardDeviation = standardDeviation; | 55 this.standardDeviation = standardDeviation; |
44 this.chiSqr = chiSqr; | 56 this.chiSqr = chiSqr; |
45 this.qwds = qwds; | 57 this.maxQ = maxQ; |
46 } | 58 } |
47 | 59 |
48 public double getChiSquare() { | 60 public double getChiSquare() { |
49 return chiSqr; | 61 return this.chiSqr; |
50 } | |
51 | |
52 /** | |
53 * Returns all referenced and outliers as one array. | |
54 */ | |
55 public QWD[] getFixingsArray() { | |
56 return qwds.toArray(new QWD[qwds.size()]); | |
57 } | 62 } |
58 | 63 |
59 public double getMaxQ() { | 64 public double getMaxQ() { |
60 double maxQ = -Double.MAX_VALUE; | 65 return this.maxQ; |
61 | |
62 for (QWD qw : qwds) { | |
63 final double q = qw.getQ(); | |
64 if (!qw.isOutlier() && q > maxQ) | |
65 maxQ = q; | |
66 } | |
67 | |
68 return maxQ; | |
69 } | 66 } |
70 | 67 |
71 public double[] getParameters() { | 68 public double[] getParameters() { |
72 return parameters; | 69 return this.parameters; |
73 } | 70 } |
74 | 71 |
75 public double getStandardDeviation() { | 72 public double getStandardDeviation() { |
76 return standardDeviation; | 73 return this.standardDeviation; |
77 } | 74 } |
78 | 75 |
79 public static Fitting fit(final Function function, final QWDFactory qwdFactory, final boolean checkOutliers, final double[] qs, final double[] ws) { | 76 public static Fitting fit(final FixResultColumns resultColumns, final double km, final Function function, final boolean checkOutliers, |
80 | 77 final List<FixingColumnWithData> eventColumns) { |
81 final TDoubleArrayList xs = new TDoubleArrayList(qs.length); | 78 |
82 final TDoubleArrayList ys = new TDoubleArrayList(ws.length); | 79 final int numEvents = eventColumns.size(); |
83 | 80 final List<FittingData> data = new ArrayList<>(numEvents); |
84 for (int i = 0; i < qs.length; ++i) { | 81 |
85 if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { | 82 final double[] wTemp = new double[1]; |
86 xs.add(qs[i]); | 83 |
87 ys.add(ws[i]); | 84 for (final FixingColumnWithData event : eventColumns) { |
88 } | 85 |
89 } | 86 final double q = event.getQ(km); |
90 | 87 if (!Double.isNaN(q)) { |
91 if (xs.size() < 2) { | 88 final boolean isInterpolated = !event.getW(km, wTemp, 0); |
92 log.warn("Too less points."); | 89 final double w = wTemp[0]; |
90 | |
91 if (!Double.isNaN(w)) | |
92 data.add(new FittingData(event, w, q, isInterpolated)); | |
93 } | |
94 } | |
95 | |
96 if (data.size() < 2) { | |
97 log.warn("Not enough data for fitting."); | |
93 return null; | 98 return null; |
94 } | 99 } |
95 | 100 |
96 final List<Double> inputs = new ArrayList<>(xs.size()); | 101 final List<FittingData> outliers = new ArrayList<>(data.size()); |
97 final List<QWD> qwds = new ArrayList<>(xs.size()); | |
98 final List<QWD> outliers = new ArrayList<>(xs.size()); | |
99 | 102 |
100 org.dive4elements.river.artifacts.math.Function instance = null; | 103 org.dive4elements.river.artifacts.math.Function instance = null; |
101 LevenbergMarquardtOptimizer lmo = null; | 104 LevenbergMarquardtOptimizer lmo = null; |
102 double[] parameters = null; | 105 double[] parameters = null; |
103 | 106 |
104 for (;;) { | 107 for (;;) { |
108 | |
105 parameters = null; | 109 parameters = null; |
110 | |
106 for (double tolerance = 1e-10; tolerance < 1e-1; tolerance *= 10d) { | 111 for (double tolerance = 1e-10; tolerance < 1e-1; tolerance *= 10d) { |
107 | 112 |
108 lmo = new LevenbergMarquardtOptimizer(); | 113 lmo = new LevenbergMarquardtOptimizer(); |
109 lmo.setCostRelativeTolerance(tolerance); | 114 lmo.setCostRelativeTolerance(tolerance); |
110 lmo.setOrthoTolerance(tolerance); | 115 lmo.setOrthoTolerance(tolerance); |
111 lmo.setParRelativeTolerance(tolerance); | 116 lmo.setParRelativeTolerance(tolerance); |
112 | 117 |
113 CurveFitter cf = new CurveFitter(lmo); | |
114 | |
115 for (int i = 0, N = xs.size(); i < N; ++i) { | |
116 cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); | |
117 } | |
118 | |
119 try { | 118 try { |
119 final CurveFitter cf = new CurveFitter(lmo); | |
120 | |
121 for (final FittingData fittingData : data) | |
122 cf.addObservedPoint(fittingData.q, fittingData.w); | |
123 | |
120 parameters = cf.fit(function, function.getInitialGuess()); | 124 parameters = cf.fit(function, function.getInitialGuess()); |
121 break; | 125 break; |
122 } | 126 } |
123 catch (MathException me) { | 127 catch (final MathException me) { |
124 if (log.isDebugEnabled()) { | 128 if (log.isDebugEnabled()) { |
125 log.debug("tolerance " + tolerance + " + failed.", me); | 129 log.debug("tolerance " + tolerance + " + failed.", me); |
126 } | 130 } |
127 } | 131 } |
128 } | 132 } |
135 * } | 139 * } |
136 */ | 140 */ |
137 return null; | 141 return null; |
138 } | 142 } |
139 | 143 |
140 // This is the paraterized function for a given km. | 144 // This is the parameterized function for a given km. |
141 instance = function.instantiate(parameters); | 145 instance = function.instantiate(parameters); |
142 | 146 |
143 if (!checkOutliers) | 147 if (!checkOutliers) |
144 break; | 148 break; |
145 | 149 |
146 inputs.clear(); | 150 /* find the outlier */ |
147 | 151 final List<Double> inputs = new ArrayList<>(data.size()); |
148 for (int i = 0, N = xs.size(); i < N; ++i) { | 152 for (final FittingData fittingData : data) { |
149 double y = instance.value(xs.getQuick(i)); | 153 |
150 if (Double.isNaN(y)) { | 154 double y = instance.value(fittingData.q); |
155 if (Double.isNaN(y)) | |
151 y = Double.MAX_VALUE; | 156 y = Double.MAX_VALUE; |
152 } | 157 |
153 inputs.add(Double.valueOf(ys.getQuick(i) - y)); | 158 inputs.add(Double.valueOf(fittingData.w - y)); |
154 } | 159 } |
155 | 160 |
156 final Integer outlier = GrubbsOutlier.findOutlier(inputs); | 161 final Integer outlier = GrubbsOutlier.findOutlier(inputs); |
157 if (outlier == null) | 162 if (outlier == null) |
158 break; | 163 break; |
159 | 164 |
160 final int idx = outlier.intValue(); | 165 final int idx = outlier.intValue(); |
161 outliers.add(qwdFactory.create(xs.getQuick(idx), ys.getQuick(idx), Double.NaN, true)); | 166 |
162 xs.remove(idx); | 167 // outliers.add(qwdFactory.create(xs.getQuick(idx), ys.getQuick(idx), Double.NaN, true)); |
163 ys.remove(idx); | 168 final FittingData removed = data.remove(idx); |
164 } | 169 outliers.add(removed); |
165 | 170 } |
166 for (QWD outlier : outliers) { | 171 |
167 | 172 /* now build result data */ |
168 final double w = outlier.getW(); | 173 final List<QWD> qwds = new ArrayList<>(data.size()); |
169 final double q = outlier.getQ(); | 174 |
170 | 175 /* calculate dW of outliers against the resulting function and add them to results */ |
171 final double dw = (w - instance.value(q)) * 100.0; | 176 for (final FittingData outlier : outliers) { |
172 | 177 final QWD qwd = createQWD(outlier, instance, true); |
173 outlier.setDeltaW(dw); | 178 qwds.add(qwd); |
174 | 179 resultColumns.addQWD(outlier.event, km, qwd); |
175 qwds.add(outlier); | 180 } |
176 } | 181 |
177 | 182 /* |
183 * calculate dW of used values against the resulting function and add them to results , also calculate standard * | |
184 * deviation | |
185 */ | |
178 final StandardDeviation stdDev = new StandardDeviation(); | 186 final StandardDeviation stdDev = new StandardDeviation(); |
179 | 187 double maxQ = -Double.MAX_VALUE; |
180 for (int i = 0; i < xs.size(); ++i) { | 188 |
181 | 189 for (final FittingData fittingData : data) { |
182 final QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i), Double.NaN, false); | 190 |
183 | 191 final QWD qwd = createQWD(fittingData, instance, false); |
184 double dw = (qwd.getW() - instance.value(qwd.getQ())) * 100.0; | |
185 qwd.setDeltaW(dw); | |
186 | |
187 qwds.add(qwd); | 192 qwds.add(qwd); |
188 | 193 resultColumns.addQWD(fittingData.event, km, qwd); |
189 stdDev.increment(dw); | 194 |
195 stdDev.increment(qwd.getDeltaW()); | |
196 | |
197 final double q = qwd.getQ(); | |
198 if (q > maxQ) | |
199 maxQ = q; | |
190 } | 200 } |
191 | 201 |
192 final double standardDeviation = stdDev.getResult(); | 202 final double standardDeviation = stdDev.getResult(); |
193 | 203 |
194 final double chiSqr = lmo.getChiSquare(); | 204 final double chiSqr = lmo.getChiSquare(); |
195 | 205 |
196 return new Fitting(parameters, standardDeviation, chiSqr, qwds); | 206 return new Fitting(parameters, standardDeviation, chiSqr, maxQ); |
207 } | |
208 | |
209 private static QWD createQWD(final FittingData data, final org.dive4elements.river.artifacts.math.Function function, final boolean isOutlier) { | |
210 | |
211 final FixingColumnWithData event = data.event; | |
212 final Date date = event.getDate(); | |
213 final boolean isInterpolated = data.isInterpolated; | |
214 | |
215 final double w = data.w; | |
216 final double q = data.q; | |
217 final double dw = (w - function.value(q)) * 100.0; | |
218 | |
219 return new QWD(q, w, date, isInterpolated, dw, isOutlier); | |
197 } | 220 } |
198 } | 221 } |