Mercurial > dive4elements > river
view flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/fixings/Fitting.java @ 4515:17d896822d70
Added methods to sediment load object to check fractions.
author | Raimund Renkert <rrenkert@intevation.de> |
---|---|
date | Wed, 14 Nov 2012 17:02:51 +0100 |
parents | e727e3ebdf85 |
children | a7d080347ac3 |
line wrap: on
line source
package de.intevation.flys.artifacts.model.fixings; import de.intevation.flys.artifacts.math.Outlier; import de.intevation.flys.artifacts.math.fitting.Function; import gnu.trove.TDoubleArrayList; import java.util.ArrayList; import java.util.List; import org.apache.commons.math.MathException; import org.apache.commons.math.optimization.fitting.CurveFitter; import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer; import org.apache.commons.math.stat.descriptive.moment.StandardDeviation; import org.apache.log4j.Logger; public class Fitting { private static Logger log = Logger.getLogger(Fitting.class); /** Use instance of this factory to find meta infos for outliers. */ public interface QWDFactory { QWD create(double q, double w); } // interface QWFactory public static final QWDFactory QWD_FACTORY = new QWDFactory() { @Override public QWD create(double q, double w) { return new QWD(q, w); } }; protected boolean checkOutliers; protected Function function; protected QWDFactory qwdFactory; protected double chiSqr; protected double [] parameters; protected ArrayList<QWI> removed; protected QWD [] referenced; protected double standardDeviation; public Fitting() { removed = new ArrayList<QWI>(); } public Fitting(Function function) { this(function, QWD_FACTORY); } public Fitting(Function function, QWDFactory qwdFactory) { this(function, qwdFactory, false); } public Fitting( Function function, QWDFactory qwdFactory, boolean checkOutliers ) { this(); this.function = function; this.qwdFactory = qwdFactory; this.checkOutliers = checkOutliers; } public Function getFunction() { return function; } public void setFunction(Function function) { this.function = function; } public boolean getCheckOutliers() { return checkOutliers; } public void setCheckOutliers(boolean checkOutliers) { this.checkOutliers = checkOutliers; } public double getChiSquare() { return chiSqr; } public void reset() { chiSqr = 0.0; parameters = null; removed.clear(); referenced = null; standardDeviation = 0.0; } public boolean hasOutliers() { return !removed.isEmpty(); } public List<QWI> getOutliers() { return removed; } public QWI [] outliersToArray() { return removed.toArray(new QWI[removed.size()]); } public QWD [] referencedToArray() { return referenced != null ? (QWD [])referenced.clone() : null; } public double getMaxQ() { double maxQ = -Double.MAX_VALUE; if (referenced != null) { for (QWI qw: referenced) { if (qw.getQ() > maxQ) { maxQ = qw.getQ(); } } } return maxQ; } public double [] getParameters() { return parameters; } public double getStandardDeviation() { return standardDeviation; } public boolean fit(double [] qs, double [] ws) { TDoubleArrayList xs = new TDoubleArrayList(qs.length); TDoubleArrayList ys = new TDoubleArrayList(ws.length); for (int i = 0; i < qs.length; ++i) { if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) { xs.add(qs[i]); ys.add(ws[i]); } else { log.warn("remove invalid value " + qs[i] + " " + ws[i]); } } if (xs.size() < 2) { log.warn("Too less points."); return false; } List<Double> inputs = new ArrayList<Double>(xs.size()); de.intevation.flys.artifacts.math.Function instance = null; LevenbergMarquardtOptimizer lmo = null; for (;;) { parameters = null; for (double tolerance = 1e-10; tolerance < 1e-3; tolerance *= 10d) { lmo = new LevenbergMarquardtOptimizer(); lmo.setCostRelativeTolerance(tolerance); lmo.setOrthoTolerance(tolerance); lmo.setParRelativeTolerance(tolerance); CurveFitter cf = new CurveFitter(lmo); for (int i = 0, N = xs.size(); i < N; ++i) { cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i)); } try { parameters = cf.fit(function, function.getInitialGuess()); break; } catch (MathException me) { if (log.isDebugEnabled()) { log.debug("tolerance " + tolerance + " + failed."); } } } if (parameters == null) { return false; } // This is the paraterized function for a given km. instance = function.instantiate(parameters); if (!checkOutliers) { break; } inputs.clear(); for (int i = 0, N = xs.size(); i < N; ++i) { double y = instance.value(xs.getQuick(i)); if (Double.isNaN(y)) { y = Double.MAX_VALUE; } inputs.add(Double.valueOf(ys.getQuick(i) - y)); } Integer outlier = Outlier.findOutlier(inputs); if (outlier == null) { break; } int idx = outlier.intValue(); removed.add( qwdFactory.create( xs.getQuick(idx), ys.getQuick(idx))); xs.remove(idx); ys.remove(idx); } StandardDeviation stdDev = new StandardDeviation(); referenced = new QWD[xs.size()]; for (int i = 0; i < referenced.length; ++i) { QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i)); if (qwd == null) { log.warn("QW creation failed!"); } else { referenced[i] = qwd; double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0; qwd.setDeltaW(dw); stdDev.increment(dw); } } standardDeviation = stdDev.getResult(); chiSqr = lmo.getChiSquare(); return true; } } // vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :