Mercurial > dive4elements > river
comparison flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/fixings/Fitting.java @ 3011:ab81ffd1343e
FixA: Reactivated rewrite of the outlier checks.
flys-artifacts/trunk@4576 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Mon, 04 Jun 2012 16:44:56 +0000 |
parents | |
children | 705d2058b682 |
comparison
equal
deleted
inserted
replaced
3010:05a3fe8800b3 | 3011:ab81ffd1343e |
---|---|
1 package de.intevation.flys.artifacts.model.fixings; | |
2 | |
3 import de.intevation.flys.artifacts.math.fitting.Function; | |
4 | |
5 import de.intevation.flys.artifacts.math.Outlier; | |
6 | |
7 import de.intevation.flys.artifacts.math.Outlier.IndexedValue; | |
8 import de.intevation.flys.artifacts.math.Outlier.Outliers; | |
9 | |
10 import org.apache.commons.math.MathException; | |
11 | |
12 import org.apache.commons.math.optimization.fitting.CurveFitter; | |
13 | |
14 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; | |
15 | |
16 import gnu.trove.TDoubleArrayList; | |
17 | |
18 import org.apache.log4j.Logger; | |
19 | |
20 import java.util.ArrayList; | |
21 import java.util.List; | |
22 | |
23 public class Fitting | |
24 { | |
25 private static Logger log = Logger.getLogger(Fitting.class); | |
26 | |
27 /** Use instance of this factory to find meta infos for outliers. */ | |
28 public interface QWFactory { | |
29 | |
30 QW create(double q, double w); | |
31 | |
32 } // interface QWFactory | |
33 | |
34 protected Function function; | |
35 protected QWFactory qwFactory; | |
36 protected double chiSqr; | |
37 protected double [] parameters; | |
38 protected ArrayList<QW> removed; | |
39 | |
40 | |
41 public Fitting() { | |
42 removed = new ArrayList<QW>(); | |
43 } | |
44 | |
45 public Fitting(Function function) { | |
46 this(); | |
47 this.function = function; | |
48 } | |
49 | |
50 public Fitting(Function function, QWFactory qwFactory) { | |
51 this(function); | |
52 this.qwFactory = qwFactory; | |
53 } | |
54 | |
55 public Function getFunction() { | |
56 return function; | |
57 } | |
58 | |
59 public void setFunction(Function function) { | |
60 this.function = function; | |
61 } | |
62 | |
63 public double getChiSquare() { | |
64 return chiSqr; | |
65 } | |
66 | |
67 public void reset() { | |
68 chiSqr = 0.0; | |
69 parameters = null; | |
70 removed.clear(); | |
71 } | |
72 | |
73 public boolean hasOutliers() { | |
74 return !removed.isEmpty(); | |
75 } | |
76 | |
77 public List<QW> getOutliers() { | |
78 return removed; | |
79 } | |
80 | |
81 public QW [] outliersToArray() { | |
82 return removed.toArray(new QW[removed.size()]); | |
83 } | |
84 | |
85 public double [] getParameters() { | |
86 return parameters; | |
87 } | |
88 | |
89 public boolean fit(double [] qs, double [] ws) { | |
90 | |
91 TDoubleArrayList xs = new TDoubleArrayList(qs.length); | |
92 TDoubleArrayList ys = new TDoubleArrayList(ws.length); | |
93 | |
94 for (int i = 0; i < qs.length; ++i) { | |
95 if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { | |
96 xs.add(qs[i]); | |
97 ys.add(ws[i]); | |
98 } | |
99 else { | |
100 log.warn("remove invalid value " + qs[i] + " " + ws[i]); | |
101 } | |
102 } | |
103 | |
104 if (xs.size() < 2) { | |
105 return false; | |
106 } | |
107 | |
108 LevenbergMarquardtOptimizer lmo = new LevenbergMarquardtOptimizer(); | |
109 | |
110 double [] parameters; | |
111 | |
112 List<IndexedValue> inputs = new ArrayList<IndexedValue>(xs.size()); | |
113 | |
114 for (;;) { | |
115 CurveFitter cf = new CurveFitter(lmo); | |
116 | |
117 for (int i = 0, N = xs.size(); i < N; ++i) { | |
118 cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); | |
119 } | |
120 | |
121 try { | |
122 parameters = cf.fit(function, function.getInitialGuess()); | |
123 } | |
124 catch (MathException me) { | |
125 log.warn(me); | |
126 return false; | |
127 } | |
128 | |
129 if (qwFactory == null) { | |
130 break; | |
131 } | |
132 | |
133 inputs.clear(); | |
134 | |
135 // This is the paraterized function for a given km. | |
136 de.intevation.flys.artifacts.math.Function instance = | |
137 function.instantiate(parameters); | |
138 | |
139 for (int i = 0, N = xs.size(); i < N; ++i) { | |
140 double y = instance.value(xs.getQuick(i)); | |
141 if (Double.isNaN(y)) { | |
142 continue; | |
143 } | |
144 inputs.add(new IndexedValue(i, ys.getQuick(i) - y)); | |
145 } | |
146 | |
147 Outliers outliers = Outlier.findOutliers(inputs); | |
148 | |
149 if (!outliers.hasOutliers()) { | |
150 break; | |
151 } | |
152 | |
153 List<IndexedValue> rem = outliers.getRemoved(); | |
154 | |
155 for (int i = rem.size()-1; i >= 0; --i) { | |
156 int idx = rem.get(i).getIndex(); | |
157 removed.add( | |
158 qwFactory.create( | |
159 xs.getQuick(idx), ys.getQuick(idx))); | |
160 xs.remove(idx); | |
161 ys.remove(idx); | |
162 } | |
163 } | |
164 | |
165 chiSqr = lmo.getChiSquare(); | |
166 | |
167 return true; | |
168 } | |
169 } | |
170 // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 : |