Mercurial > dive4elements > river
comparison flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3552:1df6984628c3
S/Q: Extended the result data model of the S/Q calculation to
store the curve coefficients for each iteration step
of the outlier elimination.
flys-artifacts/trunk@5146 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Fri, 27 Jul 2012 12:36:09 +0000 |
parents | 49fe2ed03c12 |
children | 8d0f06b76e09 |
comparison
equal
deleted
inserted
replaced
3551:e7f1556192b3 | 3552:1df6984628c3 |
---|---|
13 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; | 13 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; |
14 | 14 |
15 import org.apache.log4j.Logger; | 15 import org.apache.log4j.Logger; |
16 | 16 |
17 public class Fitting | 17 public class Fitting |
18 implements Outlier.Callback | |
18 { | 19 { |
19 private static Logger log = Logger.getLogger(Fitting.class); | 20 private static Logger log = Logger.getLogger(Fitting.class); |
20 | 21 |
22 public interface Callback { | |
23 | |
24 void afterIteration( | |
25 double [] parameters, | |
26 SQ [] measurements, | |
27 SQ [] outliers, | |
28 double standardDeviation, | |
29 double chiSqr); | |
30 } // interface | |
31 | |
21 protected Function function; | 32 protected Function function; |
22 | 33 |
23 protected double [] parameters; | 34 protected double [] coeffs; |
24 | 35 |
25 protected double stdDevFactor; | 36 protected de.intevation.flys.artifacts.math.Function instance; |
37 | |
38 protected List<SQ> remainings; | |
39 protected List<SQ> outliers; | |
26 | 40 |
27 protected double standardDeviation; | 41 protected double standardDeviation; |
28 | 42 protected double stdDevFactor; |
29 protected double chiSqr; | 43 protected double chiSqr; |
30 | 44 |
31 protected SQ [] remaining; | 45 protected Callback callback; |
32 | |
33 protected List<SQ []> outliers; | |
34 | 46 |
35 public Fitting() { | 47 public Fitting() { |
48 remainings = new ArrayList<SQ>(); | |
49 outliers = new ArrayList<SQ>(); | |
36 } | 50 } |
37 | 51 |
38 public Fitting(Function function, double stdDevFactor) { | 52 public Fitting(Function function, double stdDevFactor) { |
53 this(); | |
39 this.function = function; | 54 this.function = function; |
40 this.stdDevFactor = stdDevFactor; | 55 this.stdDevFactor = stdDevFactor; |
41 } | 56 } |
42 | 57 |
43 public Function getFunction() { | 58 public Function getFunction() { |
46 | 61 |
47 public void setFunction(Function function) { | 62 public void setFunction(Function function) { |
48 this.function = function; | 63 this.function = function; |
49 } | 64 } |
50 | 65 |
51 public double [] getParameters() { | |
52 return parameters; | |
53 } | |
54 | |
55 public void setParameters(double [] parameters) { | |
56 this.parameters = parameters; | |
57 } | |
58 | |
59 public double getStdDevFactor() { | 66 public double getStdDevFactor() { |
60 return stdDevFactor; | 67 return stdDevFactor; |
61 } | 68 } |
62 | 69 |
63 public void setStdDevFactor(double stdDevFactor) { | 70 public void setStdDevFactor(double stdDevFactor) { |
64 this.stdDevFactor = stdDevFactor; | 71 this.stdDevFactor = stdDevFactor; |
65 } | 72 } |
66 | 73 |
67 public double getStandardDeviation() { | 74 @Override |
68 return standardDeviation; | 75 public void initialize(Iterator<SQ> good) throws MathException { |
76 | |
77 LevenbergMarquardtOptimizer lmo = | |
78 new LevenbergMarquardtOptimizer(); | |
79 | |
80 CurveFitter cf = new CurveFitter(lmo); | |
81 while (good.hasNext()) { | |
82 SQ sq = good.next(); | |
83 cf.addObservedPoint(sq.getQ(), sq.getS()); | |
84 } | |
85 | |
86 coeffs = cf.fit( | |
87 function, function.getInitialGuess()); | |
88 | |
89 instance = function.instantiate(coeffs); | |
90 | |
91 chiSqr = lmo.getChiSquare(); | |
92 | |
69 } | 93 } |
70 | 94 |
71 public void setStandardDeviation(double standardDeviation) { | 95 @Override |
96 public double eval(SQ sq) { | |
97 double s = instance.value(sq.q); | |
98 return sq.s - s; | |
99 } | |
100 | |
101 @Override | |
102 public void outlier(SQ sq) { | |
103 outliers.add(sq); | |
104 } | |
105 | |
106 @Override | |
107 public void remaining(SQ sq) { | |
108 remainings.add(sq); | |
109 } | |
110 | |
111 @Override | |
112 public void standardDeviation(double standardDeviation) { | |
72 this.standardDeviation = standardDeviation; | 113 this.standardDeviation = standardDeviation; |
73 } | 114 } |
74 | 115 |
75 public double getChiSqr() { | 116 @Override |
76 return chiSqr; | 117 public void iterationFinished() { |
77 } | 118 if (log.isDebugEnabled()) { |
78 | 119 log.debug("iterationFinished ----"); |
79 public void setChiSqr(double chiSqr) { | 120 log.debug(" num remainings: " + remainings.size()); |
80 this.chiSqr = chiSqr; | 121 log.debug(" num outliers: " + outliers.size()); |
81 } | 122 log.debug(" standardDeviation: " + standardDeviation); |
82 | 123 log.debug(" Chi^2: " + chiSqr); |
83 public SQ [] getRemaining() { | 124 log.debug("---- iterationFinished"); |
84 return remaining; | 125 } |
85 } | 126 callback.afterIteration( |
86 | 127 coeffs, |
87 public void setRemaining(SQ [] remaining) { | 128 remainings.toArray(new SQ[remainings.size()]), |
88 this.remaining = remaining; | 129 outliers.toArray(new SQ[outliers.size()]), |
89 } | 130 standardDeviation, |
90 | 131 chiSqr); |
91 public List<SQ []> getOutliers() { | 132 remainings.clear(); |
92 return outliers; | 133 outliers.clear(); |
93 } | |
94 | |
95 public void setOutliers(List<SQ []> outliers) { | |
96 this.outliers = outliers; | |
97 } | |
98 | |
99 public void reset() { | |
100 outliers = null; | |
101 remaining = null; | |
102 parameters = null; | |
103 standardDeviation = 0d; | |
104 standardDeviation = 0d; | |
105 chiSqr = 0d; | |
106 } | 134 } |
107 | 135 |
108 protected static final List<SQ> onlyValid(List<SQ> sqs) { | 136 protected static final List<SQ> onlyValid(List<SQ> sqs) { |
109 | 137 |
110 List<SQ> good = new ArrayList<SQ>(sqs.size()); | 138 List<SQ> good = new ArrayList<SQ>(sqs.size()); |
116 } | 144 } |
117 | 145 |
118 return good; | 146 return good; |
119 } | 147 } |
120 | 148 |
121 public boolean fit(List<SQ> sqs) { | 149 public boolean fit(List<SQ> sqs, Callback callback) { |
122 | 150 |
123 sqs = onlyValid(sqs); | 151 sqs = onlyValid(sqs); |
124 | 152 |
125 if (sqs.size() < 2) { | 153 if (sqs.size() < 2) { |
126 log.warn("Too less points for fitting."); | 154 log.warn("Too less points for fitting."); |
127 return false; | 155 return false; |
128 } | 156 } |
129 | 157 |
130 final LevenbergMarquardtOptimizer lmo = | 158 this.callback = callback; |
131 new LevenbergMarquardtOptimizer(); | |
132 | |
133 CurveFitter cf = new CurveFitter(lmo); | |
134 | |
135 for (SQ sq: sqs) { | |
136 cf.addObservedPoint(sq.getQ(), sq.getS()); | |
137 } | |
138 | 159 |
139 try { | 160 try { |
140 parameters = cf.fit(function, function.getInitialGuess()); | 161 Outlier.detectOutliers(this, sqs, stdDevFactor); |
141 } | |
142 catch (MathException me) { | |
143 log.warn(me); | |
144 return false; | |
145 } | |
146 | |
147 chiSqr = lmo.getChiSquare(); | |
148 | |
149 final de.intevation.flys.artifacts.math.Function [] instance = { | |
150 function.instantiate(parameters) | |
151 }; | |
152 | |
153 try { | |
154 remaining = Outlier.detectOutliers( | |
155 new Outlier.Callback() { | |
156 | |
157 List<List<SQ>> outliers = | |
158 new ArrayList<List<SQ>>(); | |
159 | |
160 int currentIteration; | |
161 | |
162 @Override | |
163 public double eval(SQ sq) { | |
164 double s = instance[0].value(sq.q); | |
165 return s - sq.s; | |
166 } | |
167 | |
168 @Override | |
169 public void iteration(int i) { | |
170 currentIteration = i; | |
171 } | |
172 | |
173 @Override | |
174 public void outlier(SQ sq) { | |
175 if (currentIteration > outliers.size()) { | |
176 outliers.add(new ArrayList<SQ>(2)); | |
177 } | |
178 outliers.get(currentIteration-1).add(sq); | |
179 } | |
180 | |
181 @Override | |
182 public void standardDeviation(double stdDev) { | |
183 setStandardDeviation(stdDev); | |
184 } | |
185 | |
186 @Override | |
187 public void reinitialize(Iterator<SQ> good) | |
188 throws MathException | |
189 { | |
190 CurveFitter cf = new CurveFitter(lmo); | |
191 while (good.hasNext()) { | |
192 SQ sq = good.next(); | |
193 cf.addObservedPoint(sq.getQ(), sq.getS()); | |
194 } | |
195 | |
196 parameters = cf.fit( | |
197 function, function.getInitialGuess()); | |
198 | |
199 instance[0] = function.instantiate(parameters); | |
200 | |
201 chiSqr = lmo.getChiSquare(); | |
202 } | |
203 | |
204 @Override | |
205 public void finished() { | |
206 List<SQ []> result = | |
207 new ArrayList<SQ []>(outliers.size()); | |
208 | |
209 for (List<SQ> ols: outliers) { | |
210 result.add(ols.toArray(new SQ[ols.size()])); | |
211 } | |
212 | |
213 setOutliers(result); | |
214 } | |
215 }, | |
216 sqs, | |
217 stdDevFactor); | |
218 } | 162 } |
219 catch (MathException me) { | 163 catch (MathException me) { |
220 log.warn(me); | 164 log.warn(me); |
221 return false; | 165 return false; |
222 } | 166 } |