diff flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/fixings/Fitting.java @ 3011:ab81ffd1343e

FixA: Reactivated rewrite of the outlier checks. flys-artifacts/trunk@4576 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Mon, 04 Jun 2012 16:44:56 +0000
parents
children 705d2058b682
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/fixings/Fitting.java	Mon Jun 04 16:44:56 2012 +0000
@@ -0,0 +1,170 @@
+package de.intevation.flys.artifacts.model.fixings;
+
+import de.intevation.flys.artifacts.math.fitting.Function;
+
+import de.intevation.flys.artifacts.math.Outlier;
+
+import de.intevation.flys.artifacts.math.Outlier.IndexedValue;
+import de.intevation.flys.artifacts.math.Outlier.Outliers;
+
+import org.apache.commons.math.MathException;
+
+import org.apache.commons.math.optimization.fitting.CurveFitter;
+
+import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
+
+import gnu.trove.TDoubleArrayList;
+
+import org.apache.log4j.Logger;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class Fitting
+{
+    private static Logger log = Logger.getLogger(Fitting.class);
+
+    /** Use instance of this factory to find meta infos for outliers. */
+    public interface QWFactory {
+
+        QW create(double q, double w);
+
+    } // interface QWFactory
+
+    protected Function      function;
+    protected QWFactory     qwFactory;
+    protected double        chiSqr;
+    protected double []     parameters;
+    protected ArrayList<QW> removed;
+
+
+    public Fitting() {
+        removed = new ArrayList<QW>();
+    }
+
+    public Fitting(Function function) {
+        this();
+        this.function = function;
+    }
+
+    public Fitting(Function function, QWFactory qwFactory) {
+        this(function);
+        this.qwFactory = qwFactory;
+    }
+
+    public Function getFunction() {
+        return function;
+    }
+
+    public void setFunction(Function function) {
+        this.function = function;
+    }
+
+    public double getChiSquare() {
+        return chiSqr;
+    }
+
+    public void reset() {
+        chiSqr     = 0.0;
+        parameters = null;
+        removed.clear();
+    }
+
+    public boolean hasOutliers() {
+        return !removed.isEmpty();
+    }
+
+    public List<QW> getOutliers() {
+        return removed;
+    }
+
+    public QW [] outliersToArray() {
+        return removed.toArray(new QW[removed.size()]);
+    }
+
+    public double [] getParameters() {
+        return parameters;
+    }
+
+    public boolean fit(double [] qs, double [] ws) {
+
+        TDoubleArrayList xs = new TDoubleArrayList(qs.length);
+        TDoubleArrayList ys = new TDoubleArrayList(ws.length);
+
+        for (int i = 0; i < qs.length; ++i) {
+            if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) {
+                xs.add(qs[i]);
+                ys.add(ws[i]);
+            }
+            else {
+                log.warn("remove invalid value " + qs[i] + " " + ws[i]);
+            }
+        }
+
+        if (xs.size() < 2) {
+            return false;
+        }
+
+        LevenbergMarquardtOptimizer lmo = new LevenbergMarquardtOptimizer();
+
+        double [] parameters;
+
+        List<IndexedValue> inputs = new ArrayList<IndexedValue>(xs.size());
+
+        for (;;) {
+            CurveFitter cf = new CurveFitter(lmo);
+
+            for (int i = 0, N = xs.size(); i < N; ++i) {
+                cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i));
+            }
+
+            try {
+                parameters = cf.fit(function, function.getInitialGuess());
+            }
+            catch (MathException me) {
+                log.warn(me);
+                return false;
+            }
+
+            if (qwFactory == null) {
+                break;
+            }
+
+            inputs.clear();
+
+            // This is the paraterized function for a given km.
+            de.intevation.flys.artifacts.math.Function instance =
+                function.instantiate(parameters);
+
+            for (int i = 0, N = xs.size(); i < N; ++i) {
+                double y = instance.value(xs.getQuick(i));
+                if (Double.isNaN(y)) {
+                    continue;
+                }
+                inputs.add(new IndexedValue(i, ys.getQuick(i) - y));
+            }
+
+            Outliers outliers = Outlier.findOutliers(inputs);
+
+            if (!outliers.hasOutliers()) {
+                break;
+            }
+
+            List<IndexedValue> rem = outliers.getRemoved();
+
+            for (int i = rem.size()-1; i >= 0; --i) {
+                int idx = rem.get(i).getIndex();
+                removed.add(
+                    qwFactory.create(
+                        xs.getQuick(idx), ys.getQuick(idx)));
+                xs.remove(idx);
+                ys.remove(idx);
+            }
+        }
+
+        chiSqr = lmo.getChiSquare();
+
+        return true;
+    }
+}
+// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org