aheinecke@4872: try:
aheinecke@5004:     from osgeo import ogr, osr
aheinecke@5077: except ImportError:
aheinecke@5004:     import ogr, osr
aheinecke@4877: import utils
aheinecke@4935: import re
aheinecke@5006: import logging
aheinecke@5006: 
aheinecke@5006: logger = logging.getLogger("importer")
ingo@2853: 
ingo@2853: class Importer:
ingo@2853: 
aheinecke@4970:     def __init__(self, river_id, dbconn, dry_run):
aheinecke@4970:         self.river_id = river_id
aheinecke@4872:         self.dbconn = dbconn
aheinecke@4970:         self.dry_run = dry_run
ingo@2861:         self.dest_srs = osr.SpatialReference()
aheinecke@4970:         self.dest_srs.ImportFromEPSG(31467)
aheinecke@4884:         self.handled_fields = []
aheinecke@4885:         self.tracking_import = False
aheinecke@4935:         self.srcLayer = None
ingo@2853: 
ingo@2853:     def getKind(self, path):
ingo@2853:         raise NotImplementedError("Importer.getKind is abstract!")
ingo@2853: 
ingo@2853:     def getPath(self, base):
ingo@2853:         raise NotImplementedError("Importer.getPath is abstract!")
ingo@2853: 
ingo@2853:     def getTablename(self):
ingo@2853:         raise NotImplementedError("Importer.getTablename is abstract!")
ingo@2853: 
ingo@3654:     def getName(self):
aheinecke@5160:         raise NotImplementedError("Importer.getName is abstract!")
aheinecke@5160: 
aheinecke@5160:     def isGeometryValid(self, geomType):
aheinecke@5160:         raise NotImplementedError("Importer.isGeometryValid is abstract!")
aheinecke@5160: 
aheinecke@5160:     def createNewFeature(self, featureDef, feat, **args):
aheinecke@5160:         raise NotImplementedError("Importer.createNewFeature is abstract!")
ingo@3654: 
ingo@2853:     def IsFieldSet(self, feat, name):
aheinecke@4935:         if not name:
aheinecke@4935:             return False
aheinecke@4878:         if feat.GetFieldIndex(name) == -1:
aheinecke@4878:             return False # Avoids an Error in IsFieldSet
aheinecke@4878:         return feat.IsFieldSet(feat.GetFieldIndex(name))
ingo@2853: 
aheinecke@5386:     def searchValue(self, feat, regex):
aheinecke@5386:         """
aheinecke@5401:         Searches for a value that matches regex in all attribute
aheinecke@5386:         fields of a feature.
aheinecke@5401: 
aheinecke@5401:         @returns the name of the field where a match was found or None
aheinecke@5386:         """
aheinecke@5386:         for val in feat.items():
aheinecke@5407:             if not isinstance(feat.items()[val], basestring):
aheinecke@5407:                 continue
aheinecke@5401:             match = re.match(regex, feat.items()[val], re.IGNORECASE)
aheinecke@5401:             if match:
aheinecke@5401:                 return val
aheinecke@5386: 
aheinecke@4935:     def searchField(self, regex):
aheinecke@4935:         """
aheinecke@4935:         Searches for a field in the current src layer that matches
aheinecke@4935:         the expression regex.
tom@5174:         Throws an exception if more than one field matches
aheinecke@4935:         @param feat: The feature to search for attributes
aheinecke@4935:         @param regex: The regex to look for
aheinecke@4935: 
aheinecke@4935:         @returns: The field name as a string
aheinecke@4935:         """
aheinecke@4935: 
aheinecke@4935:         if not hasattr(self.srcLayer, "fieldnames"):
aheinecke@4935:             self.srcLayer.fieldnames = []
aheinecke@4935:             for i in range(0, self.srcLayer.GetLayerDefn().GetFieldCount()):
aheinecke@4935:                 self.srcLayer.fieldnames.append(
aheinecke@4935:                     self.srcLayer.GetLayerDefn().GetFieldDefn(i).GetNameRef())
aheinecke@4935: 
aheinecke@4935:         result = None
aheinecke@4935:         for name in self.srcLayer.fieldnames:
aheinecke@4935:             match = re.match(regex, name, re.IGNORECASE)
aheinecke@4935:             if match:
aheinecke@4935:                 if result:
tom@5174:                     raise Exception("More than one field matches: %s" % regex)
aheinecke@4935:                 else:
aheinecke@4935:                     result = match.group(0)
aheinecke@4935:         return result
aheinecke@4935: 
ingo@2853:     def IsDoubleFieldSet(self, feat, name):
aheinecke@5365:         if not self.IsFieldSet(feat, name):
aheinecke@5365:             return False
ingo@2853:         try:
ingo@2853:             isset = feat.GetFieldAsDouble(name)
ingo@2853:             return isset is not None
ingo@2853:         except:
ingo@2853:             return False
ingo@2853: 
ingo@2853:     def isShapeRelevant(self, name, path):
ingo@2853:         return True
ingo@2853: 
ingo@2853:     def walkOverShapes(self, shape):
ingo@2853:         (name, path) = shape
ingo@2853: 
ingo@2853:         shp = ogr.Open(shape[1])
ingo@2853:         if shp is None:
aheinecke@5006:             logger.error("Shapefile '%s' could not be opened!" % path)
ingo@2853:             return
ingo@2853: 
aheinecke@4995:         if not self.isShapeRelevant(name, path):
aheinecke@5006:             logger.info("Skip shapefile: '%s' of Type: %s" % (path,
aheinecke@4995:                 utils.getWkbString(shp.GetLayerByName(name).GetGeomType())))
aheinecke@4995:             return
aheinecke@4995: 
aheinecke@4995: 
aheinecke@5006:         logger.info("Processing shapefile '%s'" % path)
ingo@2853:         srcLayer = shp.GetLayerByName(name)
ingo@2853: 
ingo@2853:         if srcLayer is None:
aheinecke@5006:             logger.error("Layer '%s' was not found!" % name)
ingo@2853:             return
ingo@2853: 
ingo@2853:         return self.shape2Database(srcLayer, name, path)
ingo@2853: 
ingo@2861:     def transform(self, feat):
ingo@2861:         geometry = feat.GetGeometryRef()
ingo@2861:         src_srs  = geometry.GetSpatialReference()
ingo@2861: 
ingo@2861:         if src_srs is None:
aheinecke@5006:             logger.error("No source SRS given! No transformation possible!")
ingo@2861:             return feat
ingo@2861: 
ingo@2861:         transformer = osr.CoordinateTransformation(src_srs, self.dest_srs)
aheinecke@4974:         if geometry.Transform(transformer):
aheinecke@4974:             return None
ingo@2861: 
ingo@2861:         return feat
ingo@2861: 
aheinecke@4884:     def handled(self, field):
aheinecke@4884:         """
aheinecke@4884:         Register a field or a map of as handled during the import.
aheinecke@4884: 
aheinecke@4884:         There is a warning printed after the import for each unhandled field!
aheinecke@4884:         """
aheinecke@4884:         if not field in self.handled_fields:
aheinecke@4884:             self.handled_fields.append(field)
aheinecke@4884: 
aheinecke@4884:     def copyFields(self, src, target, mapping):
aheinecke@4884:         """
aheinecke@4884:         Checks the mapping dictonary for key value pairs to
aheinecke@4884:         copy from the source to the destination feature.
aheinecke@4935:         The keys can be reguar expressions that are matched
aheinecke@4935:         agains the source fieldnames
aheinecke@4884: 
aheinecke@4884:         The Key is the attribute of the source feature to be copied
aheinecke@4884:         into the target attribute named by the dict's value.
aheinecke@4884:         """
aheinecke@4885:         self.tracking_import = True
aheinecke@4884:         for key, value in mapping.items():
aheinecke@4935:             realname = self.searchField(key)
aheinecke@4935:             if realname == None:
aheinecke@4884:                 continue
aheinecke@4935:             if not realname in self.handled_fields:
aheinecke@4935:                 self.handled_fields.append(realname)
aheinecke@4884:             # 0 OFTInteger, Simple 32bit integer
aheinecke@4884:             # 1 OFTIntegerList, List of 32bit integers
aheinecke@4884:             # 2 OFTReal, Double Precision floating point
aheinecke@4884:             # 3 OFTRealList, List of doubles
aheinecke@4884:             # 4 OFTString, String of ASCII chars
aheinecke@4884:             # 5 OFTStringList, Array of strings
aheinecke@4884:             # 6 OFTWideString, deprecated
aheinecke@4884:             # 7 OFTWideStringList, deprecated
aheinecke@4884:             # 8 OFTBinary, Raw Binary data
aheinecke@4884:             # 9 OFTDate, Date
aheinecke@4884:             # 10 OFTTime, Time
aheinecke@4884:             # 11 OFTDateTime, Date and Time
aheinecke@4935:             if src.IsFieldSet(src.GetFieldIndex(realname)):
aheinecke@4935:                 if src.GetFieldType(realname) == 2:
aheinecke@4935:                     target.SetField(value, src.GetFieldAsDouble(realname))
aheinecke@4884:                 else:
aheinecke@4935:                     target.SetField(value, utils.getUTF8(src.GetField(realname)))
aheinecke@4884: 
ingo@2853:     def shape2Database(self, srcLayer, name, path):
aheinecke@4872:         destLayer = self.dbconn.GetLayerByName(self.getTablename())
ingo@2853: 
ingo@2853:         if srcLayer is None:
aheinecke@5006:             logger.error("Shapefile is None!")
ingo@2853:             return -1
ingo@2853: 
ingo@2853:         if destLayer is None:
aheinecke@5006:             logger.error("No destination layer given!")
ingo@2853:             return -1
ingo@2853: 
ingo@2853:         count = srcLayer.GetFeatureCount()
aheinecke@5006:         logger.debug("Try to add %i features to database." % count)
ingo@2853: 
ingo@2853:         srcLayer.ResetReading()
aheinecke@4935:         self.srcLayer = srcLayer
ingo@2853: 
ingo@2853:         geomType    = -1
ingo@2853:         success     = 0
aheinecke@4995:         unsupported = {}
ingo@2861:         creationFailed = 0
ingo@2853:         featureDef  = destLayer.GetLayerDefn()
ingo@2853: 
ingo@2853:         for feat in srcLayer:
ingo@2853:             geom     = feat.GetGeometryRef()
ingo@2861: 
ingo@2861:             if geom is None:
aheinecke@5006:                 logger.debug("Unkown Geometry reference for feature")
ingo@2861:                 continue
ingo@2861: 
ingo@2853:             geomType = geom.GetGeometryType()
ingo@2853: 
ingo@2853:             if self.isGeometryValid(geomType):
ingo@2853:                 newFeat = self.createNewFeature(featureDef,
ingo@2853:                                                 feat,
aheinecke@4951:                                                 name=utils.getUTF8(name),
ingo@2853:                                                 path=path)
ingo@2853: 
ingo@2853:                 if newFeat is not None:
aheinecke@4877:                     newFeat.SetField("path", utils.getUTF8Path(path))
ingo@2861:                     newFeat = self.transform(newFeat)
aheinecke@4974:                     if newFeat:
aheinecke@4974:                         res = destLayer.CreateFeature(newFeat)
aheinecke@4974:                         if res is None or res > 0:
aheinecke@5006:                             logger.error("Unable to insert feature. Error: %r" % res)
aheinecke@4974:                         else:
aheinecke@4974:                             success = success + 1
ingo@2853:                     else:
aheinecke@5006:                         logger.error("Could not transform feature: %s " % feat.GetFID())
aheinecke@4974:                         creationFailed += 1
ingo@2861:                 else:
ingo@2861:                     creationFailed = creationFailed + 1
ingo@2853:             else:
aheinecke@4995:                 unsupported[utils.getWkbString(geomType)] = \
aheinecke@4995:                         unsupported.get(utils.getWkbString(geomType), 0) + 1
ingo@2853: 
aheinecke@5006:         logger.info("Inserted %i features" % success)
aheinecke@5006:         logger.info("Failed to create %i features" % creationFailed)
aheinecke@5001:         for key, value in unsupported.items():
aheinecke@5006:             logger.info("Found %i unsupported features of type: %s" % (value, key))
ingo@2853: 
aheinecke@4886:         if self.tracking_import:
aheinecke@4886:             unhandled = []
aheinecke@4886:             for i in range(0, srcLayer.GetLayerDefn().GetFieldCount()):
aheinecke@4886:                 act_field = srcLayer.GetLayerDefn().GetFieldDefn(i).GetNameRef()
aheinecke@4886:                 if not act_field in self.handled_fields:
aheinecke@4886:                     unhandled.append(act_field)
aheinecke@4884: 
aheinecke@4886:             if len(unhandled):
aheinecke@5006:                 logger.info("Did not import values from fields: %s " % \
aheinecke@4886:                         " ".join(unhandled))
aheinecke@4884: 
ingo@2853:         try:
aheinecke@4970:             if self.dry_run:
ingo@3655:                 return geomType
ingo@2853:             destLayer.CommitTransaction()
aheinecke@5160:         except:
aheinecke@5006:             logger.error("Exception while committing transaction.")
ingo@2853: 
ingo@2853:         return geomType