Mercurial > dive4elements > river
comparison artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java @ 5838:5aa05a7a34b7
Rename modules to more fitting names.
author | Sascha L. Teichmann <teichmann@intevation.de> |
---|---|
date | Thu, 25 Apr 2013 15:23:37 +0200 |
parents | flys-artifacts/src/main/java/org/dive4elements/river/artifacts/model/fixings/Fitting.java@bd047b71ab37 |
children | 4897a58c8746 |
comparison
equal
deleted
inserted
replaced
5837:d9901a08d0a6 | 5838:5aa05a7a34b7 |
---|---|
1 package org.dive4elements.river.artifacts.model.fixings; | |
2 | |
3 import gnu.trove.TDoubleArrayList; | |
4 | |
5 import java.util.ArrayList; | |
6 import java.util.List; | |
7 | |
8 import org.apache.commons.math.MathException; | |
9 import org.apache.commons.math.optimization.fitting.CurveFitter; | |
10 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; | |
11 import org.apache.commons.math.stat.descriptive.moment.StandardDeviation; | |
12 import org.apache.log4j.Logger; | |
13 | |
14 import org.dive4elements.river.artifacts.math.GrubbsOutlier; | |
15 import org.dive4elements.river.artifacts.math.fitting.Function; | |
16 | |
17 public class Fitting | |
18 { | |
19 private static Logger log = Logger.getLogger(Fitting.class); | |
20 | |
21 /** Use instance of this factory to find meta infos for outliers. */ | |
22 public interface QWDFactory { | |
23 | |
24 QWD create(double q, double w); | |
25 | |
26 } // interface QWFactory | |
27 | |
28 public static final QWDFactory QWD_FACTORY = new QWDFactory() { | |
29 @Override | |
30 public QWD create(double q, double w) { | |
31 return new QWD(q, w); | |
32 } | |
33 }; | |
34 | |
35 protected boolean checkOutliers; | |
36 protected Function function; | |
37 protected QWDFactory qwdFactory; | |
38 protected double chiSqr; | |
39 protected double [] parameters; | |
40 protected ArrayList<QWI> removed; | |
41 protected QWD [] referenced; | |
42 protected double standardDeviation; | |
43 | |
44 | |
45 public Fitting() { | |
46 removed = new ArrayList<QWI>(); | |
47 } | |
48 | |
49 public Fitting(Function function) { | |
50 this(function, QWD_FACTORY); | |
51 } | |
52 | |
53 public Fitting(Function function, QWDFactory qwdFactory) { | |
54 this(function, qwdFactory, false); | |
55 } | |
56 | |
57 public Fitting( | |
58 Function function, | |
59 QWDFactory qwdFactory, | |
60 boolean checkOutliers | |
61 ) { | |
62 this(); | |
63 this.function = function; | |
64 this.qwdFactory = qwdFactory; | |
65 this.checkOutliers = checkOutliers; | |
66 } | |
67 | |
68 public Function getFunction() { | |
69 return function; | |
70 } | |
71 | |
72 public void setFunction(Function function) { | |
73 this.function = function; | |
74 } | |
75 | |
76 public boolean getCheckOutliers() { | |
77 return checkOutliers; | |
78 } | |
79 | |
80 public void setCheckOutliers(boolean checkOutliers) { | |
81 this.checkOutliers = checkOutliers; | |
82 } | |
83 | |
84 public double getChiSquare() { | |
85 return chiSqr; | |
86 } | |
87 | |
88 public void reset() { | |
89 chiSqr = 0.0; | |
90 parameters = null; | |
91 removed.clear(); | |
92 referenced = null; | |
93 standardDeviation = 0.0; | |
94 } | |
95 | |
96 public boolean hasOutliers() { | |
97 return !removed.isEmpty(); | |
98 } | |
99 | |
100 public List<QWI> getOutliers() { | |
101 return removed; | |
102 } | |
103 | |
104 public QWI [] outliersToArray() { | |
105 return removed.toArray(new QWI[removed.size()]); | |
106 } | |
107 | |
108 public QWD [] referencedToArray() { | |
109 return referenced != null ? (QWD [])referenced.clone() : null; | |
110 } | |
111 | |
112 public double getMaxQ() { | |
113 double maxQ = -Double.MAX_VALUE; | |
114 if (referenced != null) { | |
115 for (QWI qw: referenced) { | |
116 if (qw.getQ() > maxQ) { | |
117 maxQ = qw.getQ(); | |
118 } | |
119 } | |
120 } | |
121 return maxQ; | |
122 } | |
123 | |
124 public double [] getParameters() { | |
125 return parameters; | |
126 } | |
127 | |
128 public double getStandardDeviation() { | |
129 return standardDeviation; | |
130 } | |
131 | |
132 public boolean fit(double [] qs, double [] ws) { | |
133 | |
134 TDoubleArrayList xs = new TDoubleArrayList(qs.length); | |
135 TDoubleArrayList ys = new TDoubleArrayList(ws.length); | |
136 | |
137 for (int i = 0; i < qs.length; ++i) { | |
138 if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { | |
139 xs.add(qs[i]); | |
140 ys.add(ws[i]); | |
141 } | |
142 else { | |
143 log.warn("remove invalid value " + qs[i] + " " + ws[i]); | |
144 } | |
145 } | |
146 | |
147 if (xs.size() < 2) { | |
148 log.warn("Too less points."); | |
149 return false; | |
150 } | |
151 | |
152 List<Double> inputs = new ArrayList<Double>(xs.size()); | |
153 | |
154 org.dive4elements.river.artifacts.math.Function instance = null; | |
155 | |
156 LevenbergMarquardtOptimizer lmo = null; | |
157 | |
158 for (;;) { | |
159 parameters = null; | |
160 for (double tolerance = 1e-10; tolerance < 1e-3; tolerance *= 10d) { | |
161 | |
162 lmo = new LevenbergMarquardtOptimizer(); | |
163 lmo.setCostRelativeTolerance(tolerance); | |
164 lmo.setOrthoTolerance(tolerance); | |
165 lmo.setParRelativeTolerance(tolerance); | |
166 | |
167 CurveFitter cf = new CurveFitter(lmo); | |
168 | |
169 for (int i = 0, N = xs.size(); i < N; ++i) { | |
170 cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); | |
171 } | |
172 | |
173 try { | |
174 parameters = cf.fit(function, function.getInitialGuess()); | |
175 break; | |
176 } | |
177 catch (MathException me) { | |
178 if (log.isDebugEnabled()) { | |
179 log.debug("tolerance " + tolerance + " + failed."); | |
180 } | |
181 } | |
182 } | |
183 if (parameters == null) { | |
184 return false; | |
185 } | |
186 | |
187 // This is the paraterized function for a given km. | |
188 instance = function.instantiate(parameters); | |
189 | |
190 if (!checkOutliers) { | |
191 break; | |
192 } | |
193 | |
194 inputs.clear(); | |
195 | |
196 for (int i = 0, N = xs.size(); i < N; ++i) { | |
197 double y = instance.value(xs.getQuick(i)); | |
198 if (Double.isNaN(y)) { | |
199 y = Double.MAX_VALUE; | |
200 } | |
201 inputs.add(Double.valueOf(ys.getQuick(i) - y)); | |
202 } | |
203 | |
204 Integer outlier = GrubbsOutlier.findOutlier(inputs); | |
205 | |
206 if (outlier == null) { | |
207 break; | |
208 } | |
209 | |
210 int idx = outlier.intValue(); | |
211 removed.add( | |
212 qwdFactory.create( | |
213 xs.getQuick(idx), ys.getQuick(idx))); | |
214 xs.remove(idx); | |
215 ys.remove(idx); | |
216 } | |
217 | |
218 StandardDeviation stdDev = new StandardDeviation(); | |
219 | |
220 referenced = new QWD[xs.size()]; | |
221 for (int i = 0; i < referenced.length; ++i) { | |
222 QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i)); | |
223 | |
224 if (qwd == null) { | |
225 log.warn("QW creation failed!"); | |
226 } | |
227 else { | |
228 referenced[i] = qwd; | |
229 double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0; | |
230 qwd.setDeltaW(dw); | |
231 stdDev.increment(dw); | |
232 } | |
233 } | |
234 | |
235 standardDeviation = stdDev.getResult(); | |
236 | |
237 chiSqr = lmo.getChiSquare(); | |
238 | |
239 return true; | |
240 } | |
241 } | |
242 // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 : |