view flys-artifacts/src/main/java/de/intevation/flys/artifacts/model/fixings/Fitting.java @ 3127:71484036b6ae

FixA: Moved function instantiation to always have a valid function instance. flys-artifacts/trunk@4728 c6561f87-3c4e-4783-a992-168aeb5c3f6f
author Raimund Renkert <raimund.renkert@intevation.de>
date Wed, 20 Jun 2012 15:20:44 +0000
parents cbf308f5c41b
children 307842cf8d9e
line wrap: on
line source
package de.intevation.flys.artifacts.model.fixings;

import de.intevation.flys.artifacts.math.Outlier.IndexedValue;
import de.intevation.flys.artifacts.math.Outlier.Outliers;

import de.intevation.flys.artifacts.math.Outlier;

import de.intevation.flys.artifacts.math.fitting.Function;

import gnu.trove.TDoubleArrayList;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math.MathException;

import org.apache.commons.math.optimization.fitting.CurveFitter;

import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;

import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;

import org.apache.log4j.Logger;

/**
 * Fits a parameterized {@link Function} to a set of (q, w) samples with a
 * Levenberg-Marquardt optimizer and optionally removes outliers iteratively.
 *
 * <p>After a successful {@link #fit(double[], double[])} the fitted
 * parameters, the chi square value of the optimizer, the points that were
 * used for the final fit (with their delta to the fitted curve), the
 * removed outliers and the standard deviation of the deltas are
 * available through the getters of this class.</p>
 *
 * <p>{@link #fit(double[], double[])} does not clear the results of a
 * previous run; call {@link #reset()} before reusing an instance.</p>
 */
public class Fitting
{
    private static final Logger log = Logger.getLogger(Fitting.class);

    /** Use instance of this factory to find meta infos for outliers. */
    public interface QWDFactory {

        /**
         * Creates the QWD for the given sample.
         * @param q The q value of the sample.
         * @param w The w value of the sample.
         * @return The created QWD. Callers in this class tolerate
         *         <code>null</code> and log a warning in that case.
         */
        QWD create(double q, double w);

    } // interface QWDFactory

    /** Default factory which simply wraps q and w into a new QWD. */
    public static final QWDFactory QWD_FACTORY = new QWDFactory() {
        @Override
        public QWD create(double q, double w) {
            return new QWD(q, w);
        }
    };

    /** If true, outliers are searched and removed during fitting. */
    protected boolean       checkOutliers;
    /** The function to be fitted. */
    protected Function      function;
    /** Factory used to create the referenced and removed points. */
    protected QWDFactory    qwdFactory;
    /** Chi square of the optimizer after the last successful fit. */
    protected double        chiSqr;
    /** Parameters found by the last call to the curve fitter. */
    protected double []     parameters;
    /** Points removed as outliers. */
    protected ArrayList<QW> removed;
    /** Points used for the final fit. */
    protected QWD []        referenced;
    /** Standard deviation of the deltas of the referenced points. */
    protected double        standardDeviation;


    /**
     * Creates a fitting without a function. The function has to be set
     * with {@link #setFunction(Function)} before calling
     * {@link #fit(double[], double[])}. The default {@link #QWD_FACTORY}
     * is used to create points.
     */
    public Fitting() {
        removed    = new ArrayList<QW>();
        // Without a factory fit() would fail with a NullPointerException.
        qwdFactory = QWD_FACTORY;
    }

    /**
     * Creates a fitting using the default factory and no outlier check.
     * @param function The function to fit.
     */
    public Fitting(Function function) {
        this(function, QWD_FACTORY);
    }

    /**
     * Creates a fitting without outlier check.
     * @param function   The function to fit.
     * @param qwdFactory The factory to create the points.
     */
    public Fitting(Function function, QWDFactory qwdFactory) {
        this(function, qwdFactory, false);
    }

    /**
     * @param function      The function to fit.
     * @param qwdFactory    The factory to create the points. If
     *                      <code>null</code> the default
     *                      {@link #QWD_FACTORY} is used.
     * @param checkOutliers If true outliers are removed during fitting.
     */
    public Fitting(
        Function   function,
        QWDFactory qwdFactory,
        boolean    checkOutliers
    ) {
        this();
        this.function      = function;
        this.qwdFactory    = qwdFactory != null ? qwdFactory : QWD_FACTORY;
        this.checkOutliers = checkOutliers;
    }

    public Function getFunction() {
        return function;
    }

    public void setFunction(Function function) {
        this.function = function;
    }

    public boolean getCheckOutliers() {
        return checkOutliers;
    }

    public void setCheckOutliers(boolean checkOutliers) {
        this.checkOutliers = checkOutliers;
    }

    /** @return The chi square of the optimizer of the last fit. */
    public double getChiSquare() {
        return chiSqr;
    }

    /** Discards all results of a previous fit. */
    public void reset() {
        chiSqr     = 0.0;
        parameters = null;
        removed.clear();
        referenced = null;
        standardDeviation = 0.0;
    }

    /** @return true if points were removed as outliers. */
    public boolean hasOutliers() {
        return !removed.isEmpty();
    }

    /** @return The internal (live) list of removed outliers. */
    public List<QW> getOutliers() {
        return removed;
    }

    /** @return A new array containing the removed outliers. */
    public QW [] outliersToArray() {
        return removed.toArray(new QW[removed.size()]);
    }

    /**
     * @return A copy of the array of points used for the final fit
     *         or <code>null</code> if there is none.
     */
    public QWD [] referencedToArray() {
        return referenced != null ? (QWD [])referenced.clone() : null;
    }

    /**
     * @return The maximum q of the referenced points or
     *         <code>-Double.MAX_VALUE</code> if there are none.
     */
    public double getMaxQ() {
        double maxQ = -Double.MAX_VALUE;
        if (referenced != null) {
            for (QW qw: referenced) {
                // Slots stay null if the factory failed to create a point.
                if (qw != null && qw.getQ() > maxQ) {
                    maxQ = qw.getQ();
                }
            }
        }
        return maxQ;
    }

    /** @return The internal array of the fitted parameters. */
    public double [] getParameters() {
        return parameters;
    }

    /** @return Standard deviation of the deltas of the referenced points. */
    public double getStandardDeviation() {
        return standardDeviation;
    }

    /**
     * Fits the function to the given samples.
     *
     * <p>Pairs containing NaN are ignored. At least two valid pairs are
     * required. If outlier checking is enabled the fit is repeated after
     * removing the outliers reported by
     * {@link Outlier#findOutliers(List)} until no more outliers are
     * found. Removed points are appended to the list of outliers.
     * For each remaining point the difference between its w and the
     * fitted function at its q, multiplied by 100, is stored as its
     * delta w (presumably a conversion from m to cm; verify against the
     * consumers of QWD).</p>
     *
     * @param qs The q values of the samples.
     * @param ws The w values of the samples. If the arrays differ in
     *           length only the leading pairs present in both are used.
     * @return true if the fit succeeded, false if there were too few
     *         valid points or the curve fitter failed.
     */
    public boolean fit(double [] qs, double [] ws) {

        // Only pairs present in both arrays can be used.
        int numPairs = Math.min(qs.length, ws.length);

        if (qs.length != ws.length) {
            log.warn("qs and ws differ in length: "
                + qs.length + " != " + ws.length);
        }

        TDoubleArrayList xs = new TDoubleArrayList(numPairs);
        TDoubleArrayList ys = new TDoubleArrayList(numPairs);

        for (int i = 0; i < numPairs; ++i) {
            if (!Double.isNaN(qs[i]) && !Double.isNaN(ws[i])) {
                xs.add(qs[i]);
                ys.add(ws[i]);
            }
            else {
                log.warn("remove invalid value " + qs[i] + " " + ws[i]);
            }
        }

        if (xs.size() < 2) {
            log.warn("Too less points.");
            return false;
        }

        LevenbergMarquardtOptimizer lmo = new LevenbergMarquardtOptimizer();

        List<IndexedValue> inputs = new ArrayList<IndexedValue>(xs.size());

        de.intevation.flys.artifacts.math.Function instance = null;

        for (;;) {
            // A fresh fitter per round because the point set may shrink.
            CurveFitter cf = new CurveFitter(lmo);

            for (int i = 0, N = xs.size(); i < N; ++i) {
                cf.addObservedPoint(xs.getQuick(i), ys.getQuick(i));
            }

            try {
                parameters = cf.fit(function, function.getInitialGuess());
            }
            catch (MathException me) {
                log.warn(me);
                return false;
            }

            // This is the parameterized function for a given km.
            instance = function.instantiate(parameters);

            if (!checkOutliers) {
                break;
            }

            inputs.clear();

            // Residuals of all points the fitted function is defined for.
            for (int i = 0, N = xs.size(); i < N; ++i) {
                double y = instance.value(xs.getQuick(i));
                if (Double.isNaN(y)) {
                    continue;
                }
                inputs.add(new IndexedValue(i, ys.getQuick(i) - y));
            }

            Outliers outliers = Outlier.findOutliers(inputs);

            if (!outliers.hasOutliers()) {
                break;
            }

            List<IndexedValue> rem = outliers.getRemoved();

            // Nothing to remove means refitting would not change anything.
            if (rem == null || rem.isEmpty()) {
                break;
            }

            // Remove from the highest index downwards so that the
            // remaining indices stay valid regardless of the order
            // in which the outliers were reported.
            int [] indices = new int[rem.size()];
            for (int i = 0; i < indices.length; ++i) {
                indices[i] = rem.get(i).getIndex();
            }
            java.util.Arrays.sort(indices);

            for (int i = indices.length-1; i >= 0; --i) {
                int idx = indices[i];
                if (i < indices.length-1 && idx == indices[i+1]) {
                    // Same index reported twice, already removed.
                    continue;
                }
                QWD outlier = qwdFactory.create(
                    xs.getQuick(idx), ys.getQuick(idx));
                if (outlier == null) {
                    log.warn("QW creation failed!");
                }
                else {
                    removed.add(outlier);
                }
                xs.remove(idx);
                ys.remove(idx);
            }
        }

        StandardDeviation stdDev = new StandardDeviation();

        referenced = new QWD[xs.size()];
        for (int i = 0; i < referenced.length; ++i) {
            QWD qwd = qwdFactory.create(xs.getQuick(i), ys.getQuick(i));

            if (qwd == null) {
                log.warn("QW creation failed!");
            }
            else {
                referenced[i] = qwd;
                double dw = (qwd.getW() - instance.value(qwd.getQ()))*100.0;
                qwd.setDeltaW(dw);
                stdDev.increment(dw);
            }
        }

        standardDeviation = stdDev.getResult();

        chiSqr = lmo.getChiSquare();

        return true;
    }
}
// vim:set ts=4 sw=4 si et sta sts=4 fenc=utf8 :

http://dive4elements.wald.intevation.org