""" /*************************************************************************** Name : DB Manager Description : Database manager plugin for QGIS Date : May 23, 2011 copyright : (C) 2011 by Giuseppe Sucameli email : brush.tyler@gmail.com ***************************************************************************/ /*************************************************************************** * * * This program is free software; you can redistribute it and/or modify * * it under the terms of the GNU General Public License as published by * * the Free Software Foundation; either version 2 of the License, or * * (at your option) any later version. * * * ***************************************************************************/ """ from functools import partial from qgis.PyQt.QtCore import Qt, QObject, qDebug, QByteArray, QMimeData, QDataStream, QIODevice, QFileInfo, QAbstractItemModel, QModelIndex, pyqtSignal from qgis.PyQt.QtWidgets import QApplication, QMessageBox from qgis.PyQt.QtGui import QIcon from .db_plugins import supportedDbTypes, createDbPlugin from .db_plugins.plugin import BaseError, Table, Database from .dlg_db_error import DlgDbError from .gui_utils import GuiUtils from qgis.core import ( QgsApplication, QgsDataSourceUri, QgsVectorLayer, QgsRasterLayer, QgsMimeDataUtils, QgsProviderConnectionException, QgsProviderRegistry, QgsAbstractDatabaseProviderConnection, QgsMessageLog, ) from qgis.utils import OverrideCursor try: from qgis.core import QgsVectorLayerExporter # NOQA isImportVectorAvail = True except: isImportVectorAvail = False from osgeo import gdal class TreeItem(QObject): deleted = pyqtSignal() changed = pyqtSignal() def __init__(self, data, parent=None): QObject.__init__(self, parent) self.populated = False self.itemData = data self.childItems = [] if parent: parent.appendChild(self) def childRemoved(self): self.itemChanged() def itemChanged(self): self.changed.emit() def itemDeleted(self): self.deleted.emit() def populate(self): self.populated = True return True def getItemData(self): return self.itemData def appendChild(self, child): self.childItems.append(child) child.deleted.connect(self.childRemoved) def child(self, row): return self.childItems[row] def removeChild(self, row): if row >= 0 and row < len(self.childItems): self.childItems[row].itemData.deleteLater() self.childItems[row].deleted.disconnect(self.childRemoved) del self.childItems[row] def childCount(self): return len(self.childItems) def columnCount(self): return 1 def row(self): if self.parent(): for row, item in enumerate(self.parent().childItems): if item is self: return row return 0 def data(self, column): return "" if column == 0 else None def icon(self): return None def path(self): pathList = [] if self.parent(): pathList.extend(self.parent().path()) pathList.append(self.data(0)) return pathList class PluginItem(TreeItem): def __init__(self, dbplugin, parent=None): TreeItem.__init__(self, dbplugin, parent) def populate(self): if self.populated: return True # create items for connections for c in self.getItemData().connections(): ConnectionItem(c, self) self.populated = True return True def data(self, column): if column == 0: return self.getItemData().typeNameString() return None def icon(self): return self.getItemData().icon() def path(self): return [self.getItemData().typeName()] class ConnectionItem(TreeItem): def __init__(self, connection, parent=None): TreeItem.__init__(self, connection, parent) connection.changed.connect(self.itemChanged) connection.deleted.connect(self.itemDeleted) # load (shared) icon with first instance of table item if not hasattr(ConnectionItem, 'connectedIcon'): ConnectionItem.connectedIcon = GuiUtils.get_icon("plugged") ConnectionItem.disconnectedIcon = GuiUtils.get_icon("unplugged") def data(self, column): if column == 0: return self.getItemData().connectionName() return None def icon(self): return self.getItemData().connectionIcon() def populate(self): if self.populated: return True connection = self.getItemData() if connection.database() is None: # connect to database try: if not connection.connect(): return False except BaseError as e: DlgDbError.showError(e, None) return False database = connection.database() database.changed.connect(self.itemChanged) database.deleted.connect(self.itemDeleted) schemas = database.schemas() if schemas is not None: for s in schemas: SchemaItem(s, self) else: tables = database.tables() for t in tables: TableItem(t, self) self.populated = True return True def isConnected(self): return self.getItemData().database() is not None # def icon(self): # return self.connectedIcon if self.isConnected() else self.disconnectedIcon class SchemaItem(TreeItem): def __init__(self, schema, parent): TreeItem.__init__(self, schema, parent) schema.changed.connect(self.itemChanged) schema.deleted.connect(self.itemDeleted) # load (shared) icon with first instance of schema item if not hasattr(SchemaItem, 'schemaIcon'): SchemaItem.schemaIcon = GuiUtils.get_icon("namespace") def data(self, column): if column == 0: return self.getItemData().name return None def icon(self): return self.schemaIcon def populate(self): if self.populated: return True for t in self.getItemData().tables(): TableItem(t, self) self.populated = True return True class TableItem(TreeItem): def __init__(self, table, parent): TreeItem.__init__(self, table, parent) table.changed.connect(self.itemChanged) table.deleted.connect(self.itemDeleted) self.populate() # load (shared) icon with first instance of table item if not hasattr(TableItem, 'tableIcon'): TableItem.tableIcon = QgsApplication.getThemeIcon("/mIconTableLayer.svg") TableItem.viewIcon = GuiUtils.get_icon("view") TableItem.viewMaterializedIcon = GuiUtils.get_icon("view_materialized") TableItem.layerPointIcon = QgsApplication.getThemeIcon("/mIconPointLayer.svg") TableItem.layerLineIcon = QgsApplication.getThemeIcon("/mIconLineLayer.svg") TableItem.layerPolygonIcon = QgsApplication.getThemeIcon("/mIconPolygonLayer.svg") TableItem.layerRasterIcon = QgsApplication.getThemeIcon("/mIconRasterLayer.svg") TableItem.layerUnknownIcon = GuiUtils.get_icon("layer_unknown") def data(self, column): if column == 0: return self.getItemData().name elif column == 1: if self.getItemData().type == Table.VectorType: return self.getItemData().geomType return None def icon(self): if self.getItemData().type == Table.VectorType: geom_type = self.getItemData().geomType if geom_type is not None: if geom_type.find('POINT') != -1: return self.layerPointIcon elif geom_type.find('LINESTRING') != -1 or geom_type in ('CIRCULARSTRING', 'COMPOUNDCURVE', 'MULTICURVE'): return self.layerLineIcon elif geom_type.find('POLYGON') != -1 or geom_type == 'MULTISURFACE': return self.layerPolygonIcon return self.layerUnknownIcon elif self.getItemData().type == Table.RasterType: return self.layerRasterIcon if self.getItemData().isView: if hasattr(self.getItemData(), '_relationType') and self.getItemData()._relationType == 'm': return self.viewMaterializedIcon else: return self.viewIcon return self.tableIcon def path(self): pathList = [] if self.parent(): pathList.extend(self.parent().path()) if self.getItemData().type == Table.VectorType: pathList.append("%s::%s" % (self.data(0), self.getItemData().geomColumn)) else: pathList.append(self.data(0)) return pathList class DBModel(QAbstractItemModel): importVector = pyqtSignal(QgsVectorLayer, Database, QgsDataSourceUri, QModelIndex) notPopulated = pyqtSignal(QModelIndex) def __init__(self, parent=None): global isImportVectorAvail QAbstractItemModel.__init__(self, parent) self.treeView = parent self.header = [self.tr('Databases')] if isImportVectorAvail: self.importVector.connect(self.vectorImport) self.hasSpatialiteSupport = "spatialite" in supportedDbTypes() self.hasGPKGSupport = "gpkg" in supportedDbTypes() self.rootItem = TreeItem(None, None) for dbtype in supportedDbTypes(): dbpluginclass = createDbPlugin(dbtype) item = PluginItem(dbpluginclass, self.rootItem) item.changed.connect(partial(self.refreshItem, item)) def refreshItem(self, item): if isinstance(item, TreeItem): # find the index for the tree item using the path index = self._rPath2Index(item.path()) else: # find the index for the db item index = self._rItem2Index(item) if index.isValid(): self._refreshIndex(index) else: qDebug("invalid index") def _rItem2Index(self, item, parent=None): if parent is None: parent = QModelIndex() if item == self.getItem(parent): return parent if not parent.isValid() or parent.internalPointer().populated: for i in range(self.rowCount(parent)): index = self.index(i, 0, parent) index = self._rItem2Index(item, index) if index.isValid(): return index return QModelIndex() def _rPath2Index(self, path, parent=None, n=0): if parent is None: parent = QModelIndex() if path is None or len(path) == 0: return parent for i in range(self.rowCount(parent)): index = self.index(i, 0, parent) if self._getPath(index)[n] == path[0]: return self._rPath2Index(path[1:], index, n + 1) return parent def getItem(self, index): if not index.isValid(): return None return index.internalPointer().getItemData() def _getPath(self, index): if not index.isValid(): return None return index.internalPointer().path() def columnCount(self, parent): return 1 def data(self, index, role): if not index.isValid(): return None if role == Qt.ItemDataRole.DecorationRole and index.column() == 0: icon = index.internalPointer().icon() if icon: return icon if role != Qt.ItemDataRole.DisplayRole and role != Qt.ItemDataRole.EditRole: return None retval = index.internalPointer().data(index.column()) return retval def flags(self, index): global isImportVectorAvail if not index.isValid(): return Qt.ItemFlag.NoItemFlags flags = Qt.ItemFlag.ItemIsEnabled | Qt.ItemFlag.ItemIsSelectable if index.column() == 0: item = index.internalPointer() if isinstance(item, SchemaItem) \ or (isinstance(item, TableItem) and not (self.hasGPKGSupport and item.getItemData().type == Table.RasterType and int(gdal.VersionInfo()) < 3100000)): flags |= Qt.ItemFlag.ItemIsEditable if isinstance(item, TableItem): flags |= Qt.ItemFlag.ItemIsDragEnabled # vectors/tables can be dropped on connected databases to be imported if isImportVectorAvail: if isinstance(item, ConnectionItem) and item.populated: flags |= Qt.ItemFlag.ItemIsDropEnabled if isinstance(item, (SchemaItem, TableItem)): flags |= Qt.ItemFlag.ItemIsDropEnabled # SL/Geopackage db files can be dropped everywhere in the tree if self.hasSpatialiteSupport or self.hasGPKGSupport: flags |= Qt.ItemFlag.ItemIsDropEnabled return flags def headerData(self, section, orientation, role): if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole and section < len(self.header): return self.header[section] return None def index(self, row, column, parent): if not self.hasIndex(row, column, parent): return QModelIndex() parentItem = parent.internalPointer() if parent.isValid() else self.rootItem childItem = parentItem.child(row) if childItem: return self.createIndex(row, column, childItem) return QModelIndex() def parent(self, index): if not index.isValid(): return QModelIndex() childItem = index.internalPointer() parentItem = childItem.parent() if parentItem == self.rootItem: return QModelIndex() return self.createIndex(parentItem.row(), 0, parentItem) def rowCount(self, parent): parentItem = parent.internalPointer() if parent.isValid() else self.rootItem if not parentItem.populated: self._refreshIndex(parent, True) return parentItem.childCount() def hasChildren(self, parent): parentItem = parent.internalPointer() if parent.isValid() else self.rootItem return parentItem.childCount() > 0 or not parentItem.populated def setData(self, index, value, role): if role != Qt.ItemDataRole.EditRole or index.column() != 0: return False item = index.internalPointer() new_value = str(value) if isinstance(item, SchemaItem) or isinstance(item, TableItem): obj = item.getItemData() # rename schema or table or view if new_value == obj.name: return False with OverrideCursor(Qt.CursorShape.WaitCursor): try: obj.rename(new_value) self._onDataChanged(index) except BaseError as e: DlgDbError.showError(e, self.treeView) return False else: return True return False def removeRows(self, row, count, parent): self.beginRemoveRows(parent, row, count + row - 1) item = parent.internalPointer() for i in range(row, count + row): item.removeChild(row) self.endRemoveRows() def _refreshIndex(self, index, force=False): with OverrideCursor(Qt.CursorShape.WaitCursor): try: item = index.internalPointer() if index.isValid() else self.rootItem prevPopulated = item.populated if prevPopulated: self.removeRows(0, self.rowCount(index), index) item.populated = False if prevPopulated or force: if item.populate(): for child in item.childItems: child.changed.connect(partial(self.refreshItem, child)) self._onDataChanged(index) else: self.notPopulated.emit(index) except BaseError: item.populated = False def _onDataChanged(self, indexFrom, indexTo=None): if indexTo is None: indexTo = indexFrom self.dataChanged.emit(indexFrom, indexTo) QGIS_URI_MIME = "application/x-vnd.qgis.qgis.uri" def mimeTypes(self): return ["text/uri-list", self.QGIS_URI_MIME] def mimeData(self, indexes): mimeData = QMimeData() encodedData = QByteArray() stream = QDataStream(encodedData, QIODevice.OpenModeFlag.WriteOnly) for index in indexes: if not index.isValid(): continue if not isinstance(index.internalPointer(), TableItem): continue table = self.getItem(index) stream.writeQString(table.mimeUri()) mimeData.setData(self.QGIS_URI_MIME, encodedData) return mimeData def dropMimeData(self, data, action, row, column, parent): global isImportVectorAvail if action == Qt.DropAction.IgnoreAction: return True # vectors/tables to be imported must be dropped on connected db, schema or table canImportLayer = isImportVectorAvail and parent.isValid() and \ (isinstance(parent.internalPointer(), (SchemaItem, TableItem)) or (isinstance(parent.internalPointer(), ConnectionItem) and parent.internalPointer().populated)) added = 0 if data.hasUrls(): for u in data.urls(): filename = u.toLocalFile() if filename == "": continue if self.hasSpatialiteSupport: from .db_plugins.spatialite.connector import SpatiaLiteDBConnector if SpatiaLiteDBConnector.isValidDatabase(filename): # retrieve the SL plugin tree item using its path index = self._rPath2Index(["spatialite"]) if not index.isValid(): continue item = index.internalPointer() conn_name = QFileInfo(filename).fileName() uri = QgsDataSourceUri() uri.setDatabase(filename) item.getItemData().addConnection(conn_name, uri) item.changed.emit() added += 1 continue if canImportLayer: if QgsRasterLayer.isValidRasterFileName(filename): layerType = 'raster' providerKey = 'gdal' else: layerType = 'vector' providerKey = 'ogr' layerName = QFileInfo(filename).completeBaseName() if self.importLayer(layerType, providerKey, layerName, filename, parent): added += 1 if data.hasFormat(self.QGIS_URI_MIME): for uri in QgsMimeDataUtils.decodeUriList(data): if canImportLayer: if self.importLayer(uri.layerType, uri.providerKey, uri.name, uri.uri, parent): added += 1 return added > 0 def importLayer(self, layerType, providerKey, layerName, uriString, parent): global isImportVectorAvail if not isImportVectorAvail: return False if layerType == 'raster': return False # not implemented yet inLayer = QgsRasterLayer(uriString, layerName, providerKey) else: inLayer = QgsVectorLayer(uriString, layerName, providerKey) if not inLayer.isValid(): # invalid layer QMessageBox.warning(None, self.tr("Invalid layer"), self.tr("Unable to load the layer {0}").format(inLayer.name())) return False # retrieve information about the new table's db and schema outItem = parent.internalPointer() outObj = outItem.getItemData() outDb = outObj.database() outSchema = None if isinstance(outItem, SchemaItem): outSchema = outObj elif isinstance(outItem, TableItem): outSchema = outObj.schema() # toIndex will point to the parent item of the new table toIndex = parent if isinstance(toIndex.internalPointer(), TableItem): toIndex = toIndex.parent() if inLayer.type() == inLayer.VectorLayer: # create the output uri schema = outSchema.name if outDb.schemas() is not None and outSchema is not None else "" pkCol = geomCol = "" # default pk and geom field name value if providerKey in ['postgres', 'spatialite']: inUri = QgsDataSourceUri(inLayer.source()) pkCol = inUri.keyColumn() geomCol = inUri.geometryColumn() outUri = outDb.uri() outUri.setDataSource(schema, layerName, geomCol, "", pkCol) self.importVector.emit(inLayer, outDb, outUri, toIndex) return True return False def vectorImport(self, inLayer, outDb, outUri, parent): global isImportVectorAvail if not isImportVectorAvail: return False try: from .dlg_import_vector import DlgImportVector dlg = DlgImportVector(inLayer, outDb, outUri) QApplication.restoreOverrideCursor() if dlg.exec(): self._refreshIndex(parent) finally: inLayer.deleteLater()