view flys-backend/contrib/shpimporter/importer.py @ 5123:64766b89beb6

WQAdaptedInputPanel: Refactored to remove code dupes.
author Felix Wolfsteller <felix.wolfsteller@intevation.de>
date Wed, 27 Feb 2013 12:19:41 +0100
parents 8e99d2d9364d
children c0a58558b817
line wrap: on
line source
# Prefer the namespaced GDAL bindings and fall back to the legacy
# top-level modules shipped with old GDAL releases.
try:
    from osgeo import ogr
    from osgeo import osr
except ImportError:
    import ogr
    import osr
import shpimporter
import utils

class Importer:
    """
    Abstract base class for importing shapefile layers into the database.

    Subclasses have to implement getKind, getPath, getTablename and getName.
    shape2Database additionally calls self.isGeometryValid(geomType) and
    self.createNewFeature(featureDef, feat, name=..., path=...); neither is
    defined in this class, so subclasses have to provide them as well.
    """

    def __init__(self, config, dbconn):
        """
        Remember config and database connection and set up the target SRS.

        config  -- read for river_id, target_srs (passed to ImportFromEPSG)
                   and, later on, dry_run.
        dbconn  -- asked for the destination layer via GetLayerByName.
        """
        self.config = config
        self.dbconn = dbconn
        self.river_id = config.river_id
        self.dest_srs = osr.SpatialReference()
        self.dest_srs.ImportFromEPSG(config.target_srs)
        # Names of source fields that were taken care of during the import.
        self.handled_fields = []
        # Switched on by copyFields; enables the unhandled-fields report.
        self.tracking_import = False

    def getKind(self, path):
        """Abstract; has to be overridden by subclasses."""
        raise NotImplementedError("Importer.getKind is abstract!")

    def getPath(self, base):
        """Abstract; has to be overridden by subclasses."""
        raise NotImplementedError("Importer.getPath is abstract!")

    def getTablename(self):
        """Abstract; has to return the name of the destination layer."""
        raise NotImplementedError("Importer.getTablename is abstract!")

    def getName(self):
        """Abstract; has to be overridden by subclasses."""
        raise NotImplementedError("Importer.getName is abstract!")

    def IsFieldSet(self, feat, name):
        """
        Return whether the field called name exists in feat and is set.

        A missing field yields False instead of an error.
        """
        index = feat.GetFieldIndex(name)
        if index == -1:
            return False # Avoids an Error in IsFieldSet
        return feat.IsFieldSet(index)

    def IsDoubleFieldSet(self, feat, name):
        """
        Return True if feat.GetFieldAsDouble(name) yields a value that is
        not None without raising, False otherwise.
        """
        try:
            value = feat.GetFieldAsDouble(name)
        except Exception:
            # Only regular errors count as "not set"; KeyboardInterrupt and
            # SystemExit are no longer swallowed here.
            return False
        return value is not None

    def isShapeRelevant(self, name, path):
        """Hook to skip shapefiles; the default accepts every shapefile."""
        return True

    def walkOverShapes(self, shape):
        """
        Open the shapefile described by the (name, path) tuple and import
        the layer called name.

        Returns the result of shape2Database, or None if the shapefile is
        skipped, cannot be opened or does not contain the layer.
        """
        (name, path) = shape
        if not self.isShapeRelevant(name, path):
            shpimporter.INFO("Skip shapefile '%s'" % path)
            return

        shp = ogr.Open(path)
        if shp is None:
            shpimporter.ERROR("Shapefile '%s' could not be opened!" % path)
            return

        shpimporter.INFO("Processing shapefile '%s'" % path)
        srcLayer = shp.GetLayerByName(name)

        if srcLayer is None:
            shpimporter.ERROR("Layer '%s' was not found!" % name)
            return

        return self.shape2Database(srcLayer, name, path)

    def transform(self, feat):
        """
        Transform the geometry of feat in place into the destination SRS.

        The feature is returned untouched (after logging an error) if its
        geometry carries no source SRS.
        NOTE(review): a feature without any geometry is not handled here.
        """
        geometry = feat.GetGeometryRef()
        src_srs  = geometry.GetSpatialReference()

        if src_srs is None:
            shpimporter.ERROR("No source SRS given! No transformation possible!")
            return feat

        transformer = osr.CoordinateTransformation(src_srs, self.dest_srs)
        geometry.Transform(transformer)

        return feat

    def handled(self, field):
        """
        Register a field as handled during the import.

        There is a warning printed after the import for each unhandled field!
        """
        if field not in self.handled_fields:
            self.handled_fields.append(field)

    def copyFields(self, src, target, mapping):
        """
        Checks the mapping dictionary for key value pairs to
        copy from the source to the destination feature.

        The Key is the attribute of the source feature to be copied
        into the target attribute named by the dict's value.
        """
        self.tracking_import = True
        # Register each key only once. This method may be called many times
        # with the same mapping and the list must not grow with every call.
        for key in mapping:
            self.handled(key)

        for key, value in mapping.items():
            index = src.GetFieldIndex(key)
            if index == -1:
                continue
            if not src.IsFieldSet(index):
                continue
            # OGR field type codes:
            # 0 OFTInteger, Simple 32bit integer
            # 1 OFTIntegerList, List of 32bit integers
            # 2 OFTReal, Double Precision floating point
            # 3 OFTRealList, List of doubles
            # 4 OFTString, String of ASCII chars
            # 5 OFTStringList, Array of strings
            # 6 OFTWideString, deprecated
            # 7 OFTWideStringList, deprecated
            # 8 OFTBinary, Raw Binary data
            # 9 OFTDate, Date
            # 10 OFTTime, Time
            # 11 OFTDateTime, Date and Time
            if src.GetFieldType(key) == 2:
                target.SetField(value, src.GetFieldAsDouble(key))
            else:
                target.SetField(value, src.GetField(key))

    def shape2Database(self, srcLayer, name, path):
        """
        Copy the features of srcLayer into the destination layer named by
        getTablename().

        Returns -1 if the source or destination layer is missing, otherwise
        the geometry type of the last feature that had a geometry (-1 if
        there was none). The transaction is not committed when
        config.dry_run is greater than zero.
        """
        destLayer = self.dbconn.GetLayerByName(self.getTablename())

        if srcLayer is None:
            shpimporter.ERROR("Shapefile is None!")
            return -1

        if destLayer is None:
            shpimporter.ERROR("No destination layer given!")
            return -1

        count = srcLayer.GetFeatureCount()
        shpimporter.DEBUG("Try to add %i features to database." % count)

        srcLayer.ResetReading()

        geomType    = -1
        success     = 0
        unsupported = 0
        creationFailed = 0
        featureDef  = destLayer.GetLayerDefn()

        for feat in srcLayer:
            geom = feat.GetGeometryRef()

            if geom is None:
                shpimporter.DEBUG("Unkown Geometry reference for feature")
                continue

            geomType = geom.GetGeometryType()

            if not self.isGeometryValid(geomType):
                unsupported += 1
                continue

            newFeat = self.createNewFeature(featureDef,
                                            feat,
                                            name=name,
                                            path=path)

            if newFeat is None:
                creationFailed += 1
                continue

            newFeat.SetField("path", utils.getUTF8Path(path))
            newFeat = self.transform(newFeat)
            res = destLayer.CreateFeature(newFeat)
            if res is None or res > 0:
                shpimporter.ERROR("Unable to insert feature. Error: %r" % res)
            else:
                success += 1

        shpimporter.INFO("Inserted %i features" % success)
        shpimporter.INFO("Failed to create %i features" % creationFailed)
        shpimporter.INFO("Found %i unsupported features" % unsupported)

        if self.tracking_import:
            # Report source fields nobody registered via handled/copyFields.
            layerDefn = srcLayer.GetLayerDefn()
            unhandled = []
            for i in range(0, layerDefn.GetFieldCount()):
                act_field = layerDefn.GetFieldDefn(i).GetNameRef()
                if act_field not in self.handled_fields:
                    unhandled.append(act_field)

            if unhandled:
                shpimporter.INFO("Did not import values from fields: %s " % \
                        " ".join(unhandled))

        try:
            if self.config.dry_run > 0:
                return geomType
            destLayer.CommitTransaction()
        except Exception:
            # Was "except e:", which raised a NameError instead of logging.
            shpimporter.ERROR("Exception while committing transaction.")

        return geomType

http://dive4elements.wald.intevation.org