comparison flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3552:1df6984628c3

S/Q: Extented the result data model of the S/Q calculation to store the curve coefficients for each iteration step of the outlier elimination. flys-artifacts/trunk@5146 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Fri, 27 Jul 2012 12:36:09 +0000
parents 49fe2ed03c12
children 8d0f06b76e09
comparison
equal deleted inserted replaced
3551:e7f1556192b3 3552:1df6984628c3
13 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; 13 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
14 14
15 import org.apache.log4j.Logger; 15 import org.apache.log4j.Logger;
16 16
17 public class Fitting 17 public class Fitting
18 implements Outlier.Callback
18 { 19 {
19 private static Logger log = Logger.getLogger(Fitting.class); 20 private static Logger log = Logger.getLogger(Fitting.class);
20 21
22 public interface Callback {
23
24 void afterIteration(
25 double [] parameters,
26 SQ [] measurements,
27 SQ [] outliers,
28 double standardDeviation,
29 double chiSqr);
30 } // interfacte
31
21 protected Function function; 32 protected Function function;
22 33
23 protected double [] parameters; 34 protected double [] coeffs;
24 35
25 protected double stdDevFactor; 36 protected de.intevation.flys.artifacts.math.Function instance;
37
38 protected List<SQ> remainings;
39 protected List<SQ> outliers;
26 40
27 protected double standardDeviation; 41 protected double standardDeviation;
28 42 protected double stdDevFactor;
29 protected double chiSqr; 43 protected double chiSqr;
30 44
31 protected SQ [] remaining; 45 protected Callback callback;
32
33 protected List<SQ []> outliers;
34 46
35 public Fitting() { 47 public Fitting() {
48 remainings = new ArrayList<SQ>();
49 outliers = new ArrayList<SQ>();
36 } 50 }
37 51
38 public Fitting(Function function, double stdDevFactor) { 52 public Fitting(Function function, double stdDevFactor) {
53 this();
39 this.function = function; 54 this.function = function;
40 this.stdDevFactor = stdDevFactor; 55 this.stdDevFactor = stdDevFactor;
41 } 56 }
42 57
43 public Function getFunction() { 58 public Function getFunction() {
46 61
47 public void setFunction(Function function) { 62 public void setFunction(Function function) {
48 this.function = function; 63 this.function = function;
49 } 64 }
50 65
51 public double [] getParameters() {
52 return parameters;
53 }
54
55 public void setParameters(double [] parameters) {
56 this.parameters = parameters;
57 }
58
59 public double getStdDevFactor() { 66 public double getStdDevFactor() {
60 return stdDevFactor; 67 return stdDevFactor;
61 } 68 }
62 69
63 public void setStdDevFactor(double stdDevFactor) { 70 public void setStdDevFactor(double stdDevFactor) {
64 this.stdDevFactor = stdDevFactor; 71 this.stdDevFactor = stdDevFactor;
65 } 72 }
66 73
67 public double getStandardDeviation() { 74 @Override
68 return standardDeviation; 75 public void initialize(Iterator<SQ> good) throws MathException {
76
77 LevenbergMarquardtOptimizer lmo =
78 new LevenbergMarquardtOptimizer();
79
80 CurveFitter cf = new CurveFitter(lmo);
81 while (good.hasNext()) {
82 SQ sq = good.next();
83 cf.addObservedPoint(sq.getQ(), sq.getS());
84 }
85
86 coeffs = cf.fit(
87 function, function.getInitialGuess());
88
89 instance = function.instantiate(coeffs);
90
91 chiSqr = lmo.getChiSquare();
92
69 } 93 }
70 94
71 public void setStandardDeviation(double standardDeviation) { 95 @Override
96 public double eval(SQ sq) {
97 double s = instance.value(sq.q);
98 return sq.s - s;
99 }
100
101 @Override
102 public void outlier(SQ sq) {
103 outliers.add(sq);
104 }
105
106 @Override
107 public void remaining(SQ sq) {
108 remainings.add(sq);
109 }
110
111 @Override
112 public void standardDeviation(double standardDeviation) {
72 this.standardDeviation = standardDeviation; 113 this.standardDeviation = standardDeviation;
73 } 114 }
74 115
75 public double getChiSqr() { 116 @Override
76 return chiSqr; 117 public void iterationFinished() {
77 } 118 if (log.isDebugEnabled()) {
78 119 log.debug("iterationFinished ----");
79 public void setChiSqr(double chiSqr) { 120 log.debug(" num remainings: " + remainings.size());
80 this.chiSqr = chiSqr; 121 log.debug(" num outliers: " + outliers.size());
81 } 122 log.debug(" standardDeviation: " + standardDeviation);
82 123 log.debug(" Chi^2: " + chiSqr);
83 public SQ [] getRemaining() { 124 log.debug("---- iterationFinished");
84 return remaining; 125 }
85 } 126 callback.afterIteration(
86 127 coeffs,
87 public void setRemaining(SQ [] remaining) { 128 remainings.toArray(new SQ[remainings.size()]),
88 this.remaining = remaining; 129 outliers.toArray(new SQ[outliers.size()]),
89 } 130 standardDeviation,
90 131 chiSqr);
91 public List<SQ []> getOutliers() { 132 remainings.clear();
92 return outliers; 133 outliers.clear();
93 }
94
95 public void setOutliers(List<SQ []> outliers) {
96 this.outliers = outliers;
97 }
98
99 public void reset() {
100 outliers = null;
101 remaining = null;
102 parameters = null;
103 standardDeviation = 0d;
104 standardDeviation = 0d;
105 chiSqr = 0d;
106 } 134 }
107 135
108 protected static final List<SQ> onlyValid(List<SQ> sqs) { 136 protected static final List<SQ> onlyValid(List<SQ> sqs) {
109 137
110 List<SQ> good = new ArrayList<SQ>(sqs.size()); 138 List<SQ> good = new ArrayList<SQ>(sqs.size());
116 } 144 }
117 145
118 return good; 146 return good;
119 } 147 }
120 148
121 public boolean fit(List<SQ> sqs) { 149 public boolean fit(List<SQ> sqs, Callback callback) {
122 150
123 sqs = onlyValid(sqs); 151 sqs = onlyValid(sqs);
124 152
125 if (sqs.size() < 2) { 153 if (sqs.size() < 2) {
126 log.warn("Too less points for fitting."); 154 log.warn("Too less points for fitting.");
127 return false; 155 return false;
128 } 156 }
129 157
130 final LevenbergMarquardtOptimizer lmo = 158 this.callback = callback;
131 new LevenbergMarquardtOptimizer();
132
133 CurveFitter cf = new CurveFitter(lmo);
134
135 for (SQ sq: sqs) {
136 cf.addObservedPoint(sq.getQ(), sq.getS());
137 }
138 159
139 try { 160 try {
140 parameters = cf.fit(function, function.getInitialGuess()); 161 Outlier.detectOutliers(this, sqs, stdDevFactor);
141 }
142 catch (MathException me) {
143 log.warn(me);
144 return false;
145 }
146
147 chiSqr = lmo.getChiSquare();
148
149 final de.intevation.flys.artifacts.math.Function [] instance = {
150 function.instantiate(parameters)
151 };
152
153 try {
154 remaining = Outlier.detectOutliers(
155 new Outlier.Callback() {
156
157 List<List<SQ>> outliers =
158 new ArrayList<List<SQ>>();
159
160 int currentIteration;
161
162 @Override
163 public double eval(SQ sq) {
164 double s = instance[0].value(sq.q);
165 return s - sq.s;
166 }
167
168 @Override
169 public void iteration(int i) {
170 currentIteration = i;
171 }
172
173 @Override
174 public void outlier(SQ sq) {
175 if (currentIteration > outliers.size()) {
176 outliers.add(new ArrayList<SQ>(2));
177 }
178 outliers.get(currentIteration-1).add(sq);
179 }
180
181 @Override
182 public void standardDeviation(double stdDev) {
183 setStandardDeviation(stdDev);
184 }
185
186 @Override
187 public void reinitialize(Iterator<SQ> good)
188 throws MathException
189 {
190 CurveFitter cf = new CurveFitter(lmo);
191 while (good.hasNext()) {
192 SQ sq = good.next();
193 cf.addObservedPoint(sq.getQ(), sq.getS());
194 }
195
196 parameters = cf.fit(
197 function, function.getInitialGuess());
198
199 instance[0] = function.instantiate(parameters);
200
201 chiSqr = lmo.getChiSquare();
202 }
203
204 @Override
205 public void finished() {
206 List<SQ []> result =
207 new ArrayList<SQ []>(outliers.size());
208
209 for (List<SQ> ols: outliers) {
210 result.add(ols.toArray(new SQ[ols.size()]));
211 }
212
213 setOutliers(result);
214 }
215 },
216 sqs,
217 stdDevFactor);
218 } 162 }
219 catch (MathException me) { 163 catch (MathException me) {
220 log.warn(me); 164 log.warn(me);
221 return false; 165 return false;
222 } 166 }

http://dive4elements.wald.intevation.org