Mercurial > dive4elements > river
comparison flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3188:1e46ced2bb57
SQ: Added fitting shell for SQ curves.
flys-artifacts/trunk@4803 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Tue, 26 Jun 2012 17:05:11 +0000 |
parents | |
children | 49fe2ed03c12 |
comparison
equal
deleted
inserted
replaced
3187:1e2733f749b5 | 3188:1e46ced2bb57 |
---|---|
1 package de.intevation.flys.artifacts.model.sq; | |
2 | |
3 import de.intevation.flys.artifacts.math.fitting.Function; | |
4 | |
5 import java.util.ArrayList; | |
6 import java.util.Iterator; | |
7 import java.util.List; | |
8 | |
9 import org.apache.commons.math.MathException; | |
10 | |
11 import org.apache.commons.math.optimization.fitting.CurveFitter; | |
12 | |
13 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; | |
14 | |
15 import org.apache.log4j.Logger; | |
16 | |
17 public class Fitting | |
18 { | |
19 private static Logger log = Logger.getLogger(Fitting.class); | |
20 | |
21 protected Function function; | |
22 | |
23 protected double [] parameters; | |
24 | |
25 protected double stdDevFactor; | |
26 | |
27 protected double standardDeviation; | |
28 | |
29 protected double chiSqr; | |
30 | |
31 protected List<SQ> remaining; | |
32 | |
33 protected List<List<SQ>> outliers; | |
34 | |
35 public Fitting() { | |
36 } | |
37 | |
38 public Fitting(Function function, double stdDevFactor) { | |
39 this.function = function; | |
40 this.stdDevFactor = stdDevFactor; | |
41 } | |
42 | |
43 public Function getFunction() { | |
44 return function; | |
45 } | |
46 | |
47 public void setFunction(Function function) { | |
48 this.function = function; | |
49 } | |
50 | |
51 public double [] getParameters() { | |
52 return parameters; | |
53 } | |
54 | |
55 public void setParameters(double [] parameters) { | |
56 this.parameters = parameters; | |
57 } | |
58 | |
59 public double getStdDevFactor() { | |
60 return stdDevFactor; | |
61 } | |
62 | |
63 public void setStdDevFactor(double stdDevFactor) { | |
64 this.stdDevFactor = stdDevFactor; | |
65 } | |
66 | |
67 public double getStandardDeviation() { | |
68 return standardDeviation; | |
69 } | |
70 | |
71 public void setStandardDeviation(double standardDeviation) { | |
72 this.standardDeviation = standardDeviation; | |
73 } | |
74 | |
75 public double getChiSqr() { | |
76 return chiSqr; | |
77 } | |
78 | |
79 public void setChiSqr(double chiSqr) { | |
80 this.chiSqr = chiSqr; | |
81 } | |
82 | |
83 public List<SQ> getRemaining() { | |
84 return remaining; | |
85 } | |
86 | |
87 public void setRemaining(List<SQ> remaining) { | |
88 this.remaining = remaining; | |
89 } | |
90 | |
91 public List<List<SQ>> getOutliers() { | |
92 return outliers; | |
93 } | |
94 | |
95 public void setOutliers(List<List<SQ>> outliers) { | |
96 this.outliers = outliers; | |
97 } | |
98 | |
99 public void reset() { | |
100 outliers = null; | |
101 remaining = null; | |
102 parameters = null; | |
103 standardDeviation = 0d; | |
104 standardDeviation = 0d; | |
105 chiSqr = 0d; | |
106 } | |
107 | |
108 protected static final List<SQ> onlyValid(List<SQ> sqs) { | |
109 | |
110 List<SQ> good = new ArrayList<SQ>(sqs.size()); | |
111 | |
112 for (SQ sq: sqs) { | |
113 if (sq.isValid()) { | |
114 good.add(sq); | |
115 } | |
116 } | |
117 | |
118 return good; | |
119 } | |
120 | |
121 public boolean fit(List<SQ> sqs) { | |
122 | |
123 sqs = onlyValid(sqs); | |
124 | |
125 if (sqs.size() < 2) { | |
126 log.warn("Too less points for fitting."); | |
127 return false; | |
128 } | |
129 | |
130 final LevenbergMarquardtOptimizer lmo = | |
131 new LevenbergMarquardtOptimizer(); | |
132 | |
133 CurveFitter cf = new CurveFitter(lmo); | |
134 | |
135 for (SQ sq: sqs) { | |
136 cf.addObservedPoint(sq.getQ(), sq.getS()); | |
137 } | |
138 | |
139 try { | |
140 parameters = cf.fit(function, function.getInitialGuess()); | |
141 } | |
142 catch (MathException me) { | |
143 log.warn(me); | |
144 return false; | |
145 } | |
146 | |
147 chiSqr = lmo.getChiSquare(); | |
148 | |
149 final de.intevation.flys.artifacts.math.Function [] instance = { | |
150 function.instantiate(parameters) | |
151 }; | |
152 | |
153 try { | |
154 remaining = Outlier.detectOutliers( | |
155 new Outlier.Callback() { | |
156 | |
157 List<List<SQ>> outliers = | |
158 new ArrayList<List<SQ>>(); | |
159 | |
160 int currentIteration; | |
161 | |
162 @Override | |
163 public double eval(SQ sq) { | |
164 double s = instance[0].value(sq.q); | |
165 return s - sq.s; | |
166 } | |
167 | |
168 @Override | |
169 public void iteration(int i) { | |
170 currentIteration = i; | |
171 } | |
172 | |
173 @Override | |
174 public void outlier(SQ sq) { | |
175 if (currentIteration > outliers.size()) { | |
176 outliers.add(new ArrayList<SQ>(2)); | |
177 } | |
178 outliers.get(currentIteration-1).add(sq); | |
179 } | |
180 | |
181 @Override | |
182 public void standardDeviation(double stdDev) { | |
183 setStandardDeviation(stdDev); | |
184 } | |
185 | |
186 @Override | |
187 public void reinitialize(Iterator<SQ> good) | |
188 throws MathException | |
189 { | |
190 CurveFitter cf = new CurveFitter(lmo); | |
191 while (good.hasNext()) { | |
192 SQ sq = good.next(); | |
193 cf.addObservedPoint(sq.getQ(), sq.getS()); | |
194 } | |
195 | |
196 parameters = cf.fit( | |
197 function, function.getInitialGuess()); | |
198 | |
199 instance[0] = function.instantiate(parameters); | |
200 | |
201 chiSqr = lmo.getChiSquare(); | |
202 } | |
203 | |
204 @Override | |
205 public void finished() { | |
206 setOutliers(outliers); | |
207 } | |
208 }, | |
209 sqs, | |
210 stdDevFactor); | |
211 } | |
212 catch (MathException me) { | |
213 log.warn(me); | |
214 return false; | |
215 } | |
216 | |
217 return true; | |
218 } | |
219 } | |
220 // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 : |