changeset 3188:1e46ced2bb57

SQ: Added fitting shell for SQ curves. flys-artifacts/trunk@4803 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Sascha L. Teichmann <sascha.teichmann@intevation.de>
date Tue, 26 Jun 2012 17:05:11 +0000
parents 1e2733f749b5
children 89dc2db3a202
files flys-artifacts/ChangeLog flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Outlier.java flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/SQ.java flys-artifacts/src/main/java/de/intevation/flys/exports/XYChartGenerator.java
diffstat 5 files changed, 276 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/flys-artifacts/ChangeLog	Tue Jun 26 16:00:59 2012 +0000
+++ b/flys-artifacts/ChangeLog	Tue Jun 26 17:05:11 2012 +0000
@@ -1,3 +1,17 @@
+2012-06-26	Sascha L. Teichmann	<sascha.teichmann@intevation.de>
+
+	* src/main/java/de/intevation/flys/artifacts/model/sq/SQ.java:
+	  Added method to validate point.
+
+	* src/main/java/de/intevation/flys/artifacts/model/sq/Outlier.java:
+	  Added method to callback to re-initialize the function to fit.
+
+	* src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java:
+	  New. Shell for fitting of one SQ fraction.
+
+	* src/main/java/de/intevation/flys/exports/XYChartGenerator.java:
+	  Removed superfluous import.
+
 2012-06-26	Sascha L. Teichmann	<sascha.teichmann@intevation.de>
 
 	* src/main/java/de/intevation/flys/artifacts/model/sq/Outlier.java:
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java	Tue Jun 26 17:05:11 2012 +0000
@@ -0,0 +1,220 @@
+package de.intevation.flys.artifacts.model.sq;
+
+import de.intevation.flys.artifacts.math.fitting.Function;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.math.MathException;
+
+import org.apache.commons.math.optimization.fitting.CurveFitter;
+
+import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
+
+import org.apache.log4j.Logger;
+
+public class Fitting
+{
+    private static Logger log = Logger.getLogger(Fitting.class);
+
+    protected Function function;
+
+    protected double [] parameters;
+
+    protected double stdDevFactor;
+
+    protected double standardDeviation;
+
+    protected double chiSqr;
+
+    protected List<SQ> remaining;
+
+    protected List<List<SQ>> outliers;
+
+    public Fitting() {
+    }
+
+    public Fitting(Function function, double stdDevFactor) {
+        this.function     = function;
+        this.stdDevFactor = stdDevFactor;
+    }
+
+    public Function getFunction() {
+        return function;
+    }
+
+    public void setFunction(Function function) {
+        this.function = function;
+    }
+
+    public double [] getParameters() {
+        return parameters;
+    }
+
+    public void setParameters(double [] parameters) {
+        this.parameters = parameters;
+    }
+
+    public double getStdDevFactor() {
+        return stdDevFactor;
+    }
+
+    public void setStdDevFactor(double stdDevFactor) {
+        this.stdDevFactor = stdDevFactor;
+    }
+
+    public double getStandardDeviation() {
+        return standardDeviation;
+    }
+
+    public void setStandardDeviation(double standardDeviation) {
+        this.standardDeviation = standardDeviation;
+    }
+
+    public double getChiSqr() {
+        return chiSqr;
+    }
+
+    public void setChiSqr(double chiSqr) {
+        this.chiSqr = chiSqr;
+    }
+
+    public List<SQ> getRemaining() {
+        return remaining;
+    }
+
+    public void setRemaining(List<SQ> remaining) {
+        this.remaining = remaining;
+    }
+
+    public List<List<SQ>> getOutliers() {
+        return outliers;
+    }
+
+    public void setOutliers(List<List<SQ>> outliers) {
+        this.outliers = outliers;
+    }
+
+    public void reset() {
+        outliers          = null;
+        remaining         = null;
+        parameters        = null;
+        standardDeviation = 0d;
+        standardDeviation = 0d;
+        chiSqr            = 0d;
+    }
+
+    protected static final List<SQ> onlyValid(List<SQ> sqs) {
+
+        List<SQ> good = new ArrayList<SQ>(sqs.size());
+
+        for (SQ sq: sqs) {
+            if (sq.isValid()) {
+                good.add(sq);
+            }
+        }
+
+        return good;
+    }
+
+    public boolean fit(List<SQ> sqs) {
+
+        sqs = onlyValid(sqs);
+
+        if (sqs.size() < 2) {
+            log.warn("Too less points for fitting.");
+            return false;
+        }
+
+        final LevenbergMarquardtOptimizer lmo =
+            new LevenbergMarquardtOptimizer();
+
+        CurveFitter cf = new CurveFitter(lmo);
+
+        for (SQ sq: sqs) {
+            cf.addObservedPoint(sq.getQ(), sq.getS());
+        }
+
+        try {
+            parameters = cf.fit(function, function.getInitialGuess());
+        }
+        catch (MathException me) {
+            log.warn(me);
+            return false;
+        }
+
+        chiSqr = lmo.getChiSquare();
+
+        final de.intevation.flys.artifacts.math.Function [] instance = {
+            function.instantiate(parameters)
+        };
+
+        try {
+            remaining = Outlier.detectOutliers(
+                new Outlier.Callback() {
+
+                    List<List<SQ>> outliers =
+                        new ArrayList<List<SQ>>();
+
+                    int currentIteration;
+
+                    @Override
+                    public double eval(SQ sq) {
+                        double s = instance[0].value(sq.q);
+                        return s - sq.s;
+                    }
+
+                    @Override
+                    public void iteration(int i) {
+                        currentIteration = i;
+                    }
+
+                    @Override
+                    public void outlier(SQ sq) {
+                        if (currentIteration > outliers.size()) {
+                            outliers.add(new ArrayList<SQ>(2));
+                        }
+                        outliers.get(currentIteration-1).add(sq);
+                    }
+
+                    @Override
+                    public void standardDeviation(double stdDev) {
+                        setStandardDeviation(stdDev);
+                    }
+
+                    @Override
+                    public void reinitialize(Iterator<SQ> good)
+                    throws MathException
+                    {
+                        CurveFitter cf = new CurveFitter(lmo);
+                        while (good.hasNext()) {
+                            SQ sq = good.next();
+                            cf.addObservedPoint(sq.getQ(), sq.getS());
+                        }
+
+                        parameters = cf.fit(
+                            function, function.getInitialGuess());
+
+                        instance[0] = function.instantiate(parameters);
+
+                        chiSqr = lmo.getChiSquare();
+                    }
+
+                    @Override
+                    public void finished() {
+                        setOutliers(outliers);
+                    }
+                },
+                sqs,
+                stdDevFactor);
+        }
+        catch (MathException me) {
+            log.warn(me);
+            return false;
+        }
+
+        return true;
+    }
+}
+// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :
--- a/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Outlier.java	Tue Jun 26 16:00:59 2012 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Outlier.java	Tue Jun 26 17:05:11 2012 +0000
@@ -1,8 +1,11 @@
 package de.intevation.flys.artifacts.model.sq;
 
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 
+import org.apache.commons.math.MathException;
+
 import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
 
 public class Outlier
@@ -17,6 +20,10 @@
 
         void outlier(SQ sq);
 
+        void standardDeviation(double stdDev);
+
+        void reinitialize(Iterator<SQ> good) throws MathException;
+
         void finished();
 
     } // interface Callback
@@ -34,7 +41,9 @@
         Callback callback,
         List<SQ> sqs,
         double   stdDevFactor
-    ) {
+    )
+    throws MathException
+    {
         List<EvalSQ> data = new ArrayList<EvalSQ>(sqs.size());
 
         for (SQ sq: sqs) {
@@ -43,7 +52,7 @@
 
         List<EvalSQ> good = new ArrayList<EvalSQ>(sqs.size());
 
-        for (int i = 0; i < MAX_ITERATIONS && data.size() > 2; ++i) {
+        for (int i = 1; i <= MAX_ITERATIONS && data.size() > 2; ++i) {
 
             StandardDeviation stdDev = new StandardDeviation();
 
@@ -51,7 +60,11 @@
                 stdDev.increment(esq.value = callback.eval(esq.sq));
             }
 
-            double accepted = stdDevFactor * stdDev.getResult();
+            double sd = stdDev.getResult();
+
+            callback.standardDeviation(sd);
+
+            double accepted = stdDevFactor * sd;
 
             callback.iteration(i);
 
@@ -68,6 +81,8 @@
                 break;
             }
 
+            callback.reinitialize(asSQIterator(good));
+
             List<EvalSQ> tmp = good;
             good = data;
             data = tmp;
@@ -84,5 +99,25 @@
 
         return result;
     }
+
+    protected static Iterator<SQ> asSQIterator(List<EvalSQ> esqs) {
+        final Iterator<EvalSQ> parent = esqs.iterator();
+        return new Iterator<SQ>() {
+            @Override
+            public boolean hasNext() {
+                return parent.hasNext();
+            }
+
+            @Override
+            public SQ next() {
+                return parent.next().sq;
+            }
+
+            @Override
+            public void remove() {
+                throw new UnsupportedOperationException();
+            }
+        };
+    }
 }
 // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :
--- a/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/SQ.java	Tue Jun 26 16:00:59 2012 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/SQ.java	Tue Jun 26 17:05:11 2012 +0000
@@ -33,5 +33,9 @@
     public void setQ(double q) {
         this.q = q;
     }
+
+    public boolean isValid() {
+        return !Double.isNaN(s) && !Double.isNaN(q);
+    }
 }
 // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :
--- a/flys-artifacts/src/main/java/de/intevation/flys/exports/XYChartGenerator.java	Tue Jun 26 16:00:59 2012 +0000
+++ b/flys-artifacts/src/main/java/de/intevation/flys/exports/XYChartGenerator.java	Tue Jun 26 17:05:11 2012 +0000
@@ -5,7 +5,6 @@
 import java.awt.Font;
 import java.awt.Paint;
 import java.awt.Stroke;
-import java.awt.geom.Line2D;
 
 import java.text.NumberFormat;
 

http://dive4elements.wald.intevation.org