diff flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3188:1e46ced2bb57

SQ: Added fitting shell for SQ curves. flys-artifacts/trunk@4803 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Tue, 26 Jun 2012 17:05:11 +0000
parents
children 49fe2ed03c12
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java	Tue Jun 26 17:05:11 2012 +0000
@@ -0,0 +1,220 @@
+package de.intevation.flys.artifacts.model.sq;
+
+import de.intevation.flys.artifacts.math.fitting.Function;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.math.MathException;
+
+import org.apache.commons.math.optimization.fitting.CurveFitter;
+
+import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
+
+import org.apache.log4j.Logger;
+
+public class Fitting
+{
+    private static Logger log = Logger.getLogger(Fitting.class);
+
+    protected Function function;
+
+    protected double [] parameters;
+
+    protected double stdDevFactor;
+
+    protected double standardDeviation;
+
+    protected double chiSqr;
+
+    protected List<SQ> remaining;
+
+    protected List<List<SQ>> outliers;
+
+    public Fitting() {
+    }
+
+    public Fitting(Function function, double stdDevFactor) {
+        this.function     = function;
+        this.stdDevFactor = stdDevFactor;
+    }
+
+    public Function getFunction() {
+        return function;
+    }
+
+    public void setFunction(Function function) {
+        this.function = function;
+    }
+
+    public double [] getParameters() {
+        return parameters;
+    }
+
+    public void setParameters(double [] parameters) {
+        this.parameters = parameters;
+    }
+
+    public double getStdDevFactor() {
+        return stdDevFactor;
+    }
+
+    public void setStdDevFactor(double stdDevFactor) {
+        this.stdDevFactor = stdDevFactor;
+    }
+
+    public double getStandardDeviation() {
+        return standardDeviation;
+    }
+
+    public void setStandardDeviation(double standardDeviation) {
+        this.standardDeviation = standardDeviation;
+    }
+
+    public double getChiSqr() {
+        return chiSqr;
+    }
+
+    public void setChiSqr(double chiSqr) {
+        this.chiSqr = chiSqr;
+    }
+
+    public List<SQ> getRemaining() {
+        return remaining;
+    }
+
+    public void setRemaining(List<SQ> remaining) {
+        this.remaining = remaining;
+    }
+
+    public List<List<SQ>> getOutliers() {
+        return outliers;
+    }
+
+    public void setOutliers(List<List<SQ>> outliers) {
+        this.outliers = outliers;
+    }
+
+    public void reset() {
+        outliers          = null;
+        remaining         = null;
+        parameters        = null;
+        standardDeviation = 0d;
+        standardDeviation = 0d;
+        chiSqr            = 0d;
+    }
+
+    protected static final List<SQ> onlyValid(List<SQ> sqs) {
+
+        List<SQ> good = new ArrayList<SQ>(sqs.size());
+
+        for (SQ sq: sqs) {
+            if (sq.isValid()) {
+                good.add(sq);
+            }
+        }
+
+        return good;
+    }
+
+    public boolean fit(List<SQ> sqs) {
+
+        sqs = onlyValid(sqs);
+
+        if (sqs.size() < 2) {
+            log.warn("Too less points for fitting.");
+            return false;
+        }
+
+        final LevenbergMarquardtOptimizer lmo =
+            new LevenbergMarquardtOptimizer();
+
+        CurveFitter cf = new CurveFitter(lmo);
+
+        for (SQ sq: sqs) {
+            cf.addObservedPoint(sq.getQ(), sq.getS());
+        }
+
+        try {
+            parameters = cf.fit(function, function.getInitialGuess());
+        }
+        catch (MathException me) {
+            log.warn(me);
+            return false;
+        }
+
+        chiSqr = lmo.getChiSquare();
+
+        final de.intevation.flys.artifacts.math.Function [] instance = {
+            function.instantiate(parameters)
+        };
+
+        try {
+            remaining = Outlier.detectOutliers(
+                new Outlier.Callback() {
+
+                    List<List<SQ>> outliers =
+                        new ArrayList<List<SQ>>();
+
+                    int currentIteration;
+
+                    @Override
+                    public double eval(SQ sq) {
+                        double s = instance[0].value(sq.q);
+                        return s - sq.s;
+                    }
+
+                    @Override
+                    public void iteration(int i) {
+                        currentIteration = i;
+                    }
+
+                    @Override
+                    public void outlier(SQ sq) {
+                        if (currentIteration > outliers.size()) {
+                            outliers.add(new ArrayList<SQ>(2));
+                        }
+                        outliers.get(currentIteration-1).add(sq);
+                    }
+
+                    @Override
+                    public void standardDeviation(double stdDev) {
+                        setStandardDeviation(stdDev);
+                    }
+
+                    @Override
+                    public void reinitialize(Iterator<SQ> good)
+                    throws MathException
+                    {
+                        CurveFitter cf = new CurveFitter(lmo);
+                        while (good.hasNext()) {
+                            SQ sq = good.next();
+                            cf.addObservedPoint(sq.getQ(), sq.getS());
+                        }
+
+                        parameters = cf.fit(
+                            function, function.getInitialGuess());
+
+                        instance[0] = function.instantiate(parameters);
+
+                        chiSqr = lmo.getChiSquare();
+                    }
+
+                    @Override
+                    public void finished() {
+                        setOutliers(outliers);
+                    }
+                },
+                sqs,
+                stdDevFactor);
+        }
+        catch (MathException me) {
+            log.warn(me);
+            return false;
+        }
+
+        return true;
+    }
+}
+// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org