Mercurial > dive4elements > river
diff flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java @ 3188:1e46ced2bb57
SQ: Added fitting shell for SQ curves.
flys-artifacts/trunk@4803 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author | Sascha L. Teichmann <sascha.teichmann@intevation.de> |
---|---|
date | Tue, 26 Jun 2012 17:05:11 +0000 |
parents | |
children | 49fe2ed03c12 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/sq/Fitting.java Tue Jun 26 17:05:11 2012 +0000 @@ -0,0 +1,220 @@ +package de.intevation.flys.artifacts.model.sq; + +import de.intevation.flys.artifacts.math.fitting.Function; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.commons.math.MathException; + +import org.apache.commons.math.optimization.fitting.CurveFitter; + +import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; + +import org.apache.log4j.Logger; + +public class Fitting +{ + private static Logger log = Logger.getLogger(Fitting.class); + + protected Function function; + + protected double [] parameters; + + protected double stdDevFactor; + + protected double standardDeviation; + + protected double chiSqr; + + protected List<SQ> remaining; + + protected List<List<SQ>> outliers; + + public Fitting() { + } + + public Fitting(Function function, double stdDevFactor) { + this.function = function; + this.stdDevFactor = stdDevFactor; + } + + public Function getFunction() { + return function; + } + + public void setFunction(Function function) { + this.function = function; + } + + public double [] getParameters() { + return parameters; + } + + public void setParameters(double [] parameters) { + this.parameters = parameters; + } + + public double getStdDevFactor() { + return stdDevFactor; + } + + public void setStdDevFactor(double stdDevFactor) { + this.stdDevFactor = stdDevFactor; + } + + public double getStandardDeviation() { + return standardDeviation; + } + + public void setStandardDeviation(double standardDeviation) { + this.standardDeviation = standardDeviation; + } + + public double getChiSqr() { + return chiSqr; + } + + public void setChiSqr(double chiSqr) { + this.chiSqr = chiSqr; + } + + public List<SQ> getRemaining() { + return remaining; + } + + public void setRemaining(List<SQ> remaining) { + this.remaining = remaining; + } + + public List<List<SQ>> getOutliers() { + return outliers; + } + + public void setOutliers(List<List<SQ>> outliers) { + this.outliers = outliers; + } + + public void reset() { + outliers = null; + remaining = null; + parameters = null; + standardDeviation = 0d; + standardDeviation = 0d; + chiSqr = 0d; + } + + protected static final List<SQ> onlyValid(List<SQ> sqs) { + + List<SQ> good = new ArrayList<SQ>(sqs.size()); + + for (SQ sq: sqs) { + if (sq.isValid()) { + good.add(sq); + } + } + + return good; + } + + public boolean fit(List<SQ> sqs) { + + sqs = onlyValid(sqs); + + if (sqs.size() < 2) { + log.warn("Too less points for fitting."); + return false; + } + + final LevenbergMarquardtOptimizer lmo = + new LevenbergMarquardtOptimizer(); + + CurveFitter cf = new CurveFitter(lmo); + + for (SQ sq: sqs) { + cf.addObservedPoint(sq.getQ(), sq.getS()); + } + + try { + parameters = cf.fit(function, function.getInitialGuess()); + } + catch (MathException me) { + log.warn(me); + return false; + } + + chiSqr = lmo.getChiSquare(); + + final de.intevation.flys.artifacts.math.Function [] instance = { + function.instantiate(parameters) + }; + + try { + remaining = Outlier.detectOutliers( + new Outlier.Callback() { + + List<List<SQ>> outliers = + new ArrayList<List<SQ>>(); + + int currentIteration; + + @Override + public double eval(SQ sq) { + double s = instance[0].value(sq.q); + return s - sq.s; + } + + @Override + public void iteration(int i) { + currentIteration = i; + } + + @Override + public void outlier(SQ sq) { + if (currentIteration > outliers.size()) { + outliers.add(new ArrayList<SQ>(2)); + } + outliers.get(currentIteration-1).add(sq); + } + + @Override + public void standardDeviation(double stdDev) { + setStandardDeviation(stdDev); + } + + @Override + public void reinitialize(Iterator<SQ> good) + throws MathException + { + CurveFitter cf = new CurveFitter(lmo); + while (good.hasNext()) { + SQ sq = good.next(); + cf.addObservedPoint(sq.getQ(), sq.getS()); + } + + parameters = cf.fit( + function, function.getInitialGuess()); + + instance[0] = function.instantiate(parameters); + + chiSqr = lmo.getChiSquare(); + } + + @Override + public void finished() { + setOutliers(outliers); + } + }, + sqs, + stdDevFactor); + } + catch (MathException me) { + log.warn(me); + return false; + } + + return true; + } +} +// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :