comparison artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 5838:5aa05a7a34b7

Rename modules to more fitting names.
author Sascha L. Teichmann <teichmann@intevation.de>
date Thu, 25 Apr 2013 15:23:37 +0200
parents flys-artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java@bd047b71ab37
children 4897a58c8746
comparison
equal deleted inserted replaced
5837:d9901a08d0a6 5838:5aa05a7a34b7
1 package org.dive4elements.river.artifacts.model.fixings;
2
3 import gnu.trove.TDoubleArrayList;
4
5 import java.util.ArrayList;
6 import java.util.List;
7
8 import org.apache.commons.math.MathException;
9 import org.apache.commons.math.optimization.fitting.CurveFitter;
10 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
11 import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
12 import org.apache.log4j.Logger;
13
14 import org.dive4elements.river.artifacts.math.GrubbsOutlier;
15 import org.dive4elements.river.artifacts.math.fitting.Function;
16
17 public class Fitting
18 {
19 private static Logger log = Logger.getLogger(Fitting.class);
20
21 /** Use instance of this factory to find meta infos for outliers. */
22 public interface QWDFactory {
23
24 QWD create(double q, double w);
25
26 } // interface QWFactory
27
28 public static final QWDFactory QWD_FACTORY = new QWDFactory() {
29 @Override
30 public QWD create(double q, double w) {
31 return new QWD(q, w);
32 }
33 };
34
35 protected boolean checkOutliers;
36 protected Function function;
37 protected QWDFactory qwdFactory;
38 protected double chiSqr;
39 protected double [] parameters;
40 protected ArrayList<QWI> removed;
41 protected QWD [] referenced;
42 protected double standardDeviation;
43
44
45 public Fitting() {
46 removed = new ArrayList<QWI>();
47 }
48
49 public Fitting(Function function) {
50 this(function, QWD_FACTORY);
51 }
52
53 public Fitting(Function function, QWDFactory qwdFactory) {
54 this(function, qwdFactory, false);
55 }
56
57 public Fitting(
58 Function function,
59 QWDFactory qwdFactory,
60 boolean checkOutliers
61 ) {
62 this();
63 this.function = function;
64 this.qwdFactory = qwdFactory;
65 this.checkOutliers = checkOutliers;
66 }
67
68 public Function getFunction() {
69 return function;
70 }
71
72 public void setFunction(Function function) {
73 this.function = function;
74 }
75
76 public boolean getCheckOutliers() {
77 return checkOutliers;
78 }
79
80 public void setCheckOutliers(boolean checkOutliers) {
81 this.checkOutliers = checkOutliers;
82 }
83
84 public double getChiSquare() {
85 return chiSqr;
86 }
87
88 public void reset() {
89 chiSqr = 0.0;
90 parameters = null;
91 removed.clear();
92 referenced = null;
93 standardDeviation = 0.0;
94 }
95
96 public boolean hasOutliers() {
97 return !removed.isEmpty();
98 }
99
100 public List<QWI> getOutliers() {
101 return removed;
102 }
103
104 public QWI [] outliersToArray() {
105 return removed.toArray(new QWI[removed.size()]);
106 }
107
108 public QWD [] referencedToArray() {
109 return referenced != null ? (QWD [])referenced.clone() : null;
110 }
111
112 public double getMaxQ() {
113 double maxQ = -Double.MAX_VALUE;
114 if (referenced != null) {
115 for (QWI qw: referenced) {
116 if (qw.getQ() > maxQ) {
117 maxQ = qw.getQ();
118 }
119 }
120 }
121 return maxQ;
122 }
123
124 public double [] getParameters() {
125 return parameters;
126 }
127
128 public double getStandardDeviation() {
129 return standardDeviation;
130 }
131
132 public boolean fit(double [] qs, double [] ws) {
133
134 TDoubleArrayList xs = new TDoubleArrayList(qs.length);
135 TDoubleArrayList ys = new TDoubleArrayList(ws.length);
136
137 for (int i = 0; i < qs.length; ++i) {
138 if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) {
139 xs.add(qs[i]);
140 ys.add(ws[i]);
141 }
142 else {
143 log.warn("remove invalid value " + qs[i] + " " + ws[i]);
144 }
145 }
146
147 if (xs.size() < 2) {
148 log.warn("Too less points.");
149 return false;
150 }
151
152 List<Double> inputs = new ArrayList<Double>(xs.size());
153
154 org.dive4elements.river.artifacts.math.Function instance = null;
155
156 LevenbergMarquardtOptimizer lmo = null;
157
158 for (;;) {
159 parameters = null;
160 for (double tolerance = 1e-10; tolerance < 1e-3; tolerance *= 10d) {
161
162 lmo = new LevenbergMarquardtOptimizer();
163 lmo.setCostRelativeTolerance(tolerance);
164 lmo.setOrthoTolerance(tolerance);
165 lmo.setParRelativeTolerance(tolerance);
166
167 CurveFitter cf = new CurveFitter(lmo);
168
169 for (int i = 0, N = xs.size(); i < N; ++i) {
170 cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i));
171 }
172
173 try {
174 parameters = cf.fit(function, function.getInitialGuess());
175 break;
176 }
177 catch (MathException me) {
178 if (log.isDebugEnabled()) {
179 log.debug("tolerance " + tolerance + " + failed.");
180 }
181 }
182 }
183 if (parameters == null) {
184 return false;
185 }
186
187 // This is the paraterized function for a given km.
188 instance = function.instantiate(parameters);
189
190 if (!checkOutliers) {
191 break;
192 }
193
194 inputs.clear();
195
196 for (int i = 0, N = xs.size(); i < N; ++i) {
197 double y = instance.value(xs.getQuick(i));
198 if (Double.isNaN(y)) {
199 y = Double.MAX_VALUE;
200 }
201 inputs.add(Double.valueOf(ys.getQuick(i) - y));
202 }
203
204 Integer outlier = GrubbsOutlier.findOutlier(inputs);
205
206 if (outlier == null) {
207 break;
208 }
209
210 int idx = outlier.intValue();
211 removed.add(
212 qwdFactory.create(
213 xs.getQuick(idx), ys.getQuick(idx)));
214 xs.remove(idx);
215 ys.remove(idx);
216 }
217
218 StandardDeviation stdDev = new StandardDeviation();
219
220 referenced = new QWD[xs.size()];
221 for (int i = 0; i < referenced.length; ++i) {
222 QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i));
223
224 if (qwd == null) {
225 log.warn("QW creation failed!");
226 }
227 else {
228 referenced[i] = qwd;
229 double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0;
230 qwd.setDeltaW(dw);
231 stdDev.increment(dw);
232 }
233 }
234
235 standardDeviation = stdDev.getResult();
236
237 chiSqr = lmo.getChiSquare();
238
239 return true;
240 }
241 }
242 // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org