[Processing] Support authentication with postgis

This commit is contained in:
arnaud.morvan@camptocamp.com 2016-05-27 15:09:17 +02:00
parent 3d941e3ac9
commit 8ddae27130
6 changed files with 145 additions and 123 deletions

View File

@ -36,6 +36,7 @@ from processing.core.parameters import ParameterTableField
from processing.algs.gdal.GdalAlgorithm import GdalAlgorithm
from processing.algs.gdal.GdalUtils import GdalUtils
from processing.tools.postgis import uri_from_name, GeoDB
from processing.tools.system import isWindows
from processing.tools.vector import ogrConnectionString, ogrLayerName
@ -63,6 +64,10 @@ class Ogr2OgrTableToPostGisList(GdalAlgorithm):
PRECISION = 'PRECISION'
OPTIONS = 'OPTIONS'
def __init__(self):
GdalAlgorithm.__init__(self)
self.processing = False
def dbConnectionNames(self):
settings = QSettings()
settings.beginGroup('/PostgreSQL/connections/')
@ -109,15 +114,18 @@ class Ogr2OgrTableToPostGisList(GdalAlgorithm):
self.addParameter(ParameterString(self.OPTIONS,
self.tr('Additional creation options'), '', optional=True))
def processAlgorithm(self, progress):
self.processing = True
GdalAlgorithm.processAlgorithm(self, progress)
self.processing = False
def getConsoleCommands(self):
connection = self.DB_CONNECTIONS[self.getParameterValue(self.DATABASE)]
settings = QSettings()
mySettings = '/PostgreSQL/connections/' + connection
dbname = settings.value(mySettings + '/database')
user = settings.value(mySettings + '/username')
host = settings.value(mySettings + '/host')
port = settings.value(mySettings + '/port')
password = settings.value(mySettings + '/password')
uri = uri_from_name(connection)
if self.processing:
# to get credentials input when needed
uri = GeoDB(uri=uri).uri
inLayer = self.getParameterValue(self.INPUT_LAYER)
ogrLayer = ogrConnectionString(inLayer)[1:-1]
schema = unicode(self.getParameterValue(self.SCHEMA))
@ -142,19 +150,11 @@ class Ogr2OgrTableToPostGisList(GdalAlgorithm):
arguments.append('--config PG_USE_COPY YES')
arguments.append('-f')
arguments.append('PostgreSQL')
arguments.append('PG:"host=')
arguments.append(host)
arguments.append('port=')
arguments.append(port)
if len(dbname) > 0:
arguments.append('dbname=' + dbname)
if len(password) > 0:
arguments.append('password=' + password)
if len(schema) > 0:
arguments.append('active_schema=' + schema)
else:
arguments.append('active_schema=public')
arguments.append('user=' + user + '"')
arguments.append('PG:"')
for token in uri.connectionInfo(self.processing).split(' '):
arguments.append(token)
arguments.append('active_schema={}'.format(schema or 'public'))
arguments.append('"')
arguments.append(ogrLayer)
arguments.append('-nlt NONE')
arguments.append(ogrLayerName(inLayer))

View File

@ -39,6 +39,7 @@ from processing.core.parameters import ParameterTableField
from processing.algs.gdal.GdalAlgorithm import GdalAlgorithm
from processing.algs.gdal.GdalUtils import GdalUtils
from processing.tools.postgis import uri_from_name, GeoDB
from processing.tools.system import isWindows
from processing.tools.vector import ogrConnectionString, ogrLayerName
@ -80,6 +81,10 @@ class Ogr2OgrToPostGisList(GdalAlgorithm):
PROMOTETOMULTI = 'PROMOTETOMULTI'
OPTIONS = 'OPTIONS'
def __init__(self):
GdalAlgorithm.__init__(self)
self.processing = False
def dbConnectionNames(self):
settings = QSettings()
settings.beginGroup('/PostgreSQL/connections/')
@ -153,15 +158,18 @@ class Ogr2OgrToPostGisList(GdalAlgorithm):
self.addParameter(ParameterString(self.OPTIONS,
self.tr('Additional creation options'), '', optional=True))
def processAlgorithm(self, progress):
self.processing = True
GdalAlgorithm.processAlgorithm(self, progress)
self.processing = False
def getConsoleCommands(self):
connection = self.DB_CONNECTIONS[self.getParameterValue(self.DATABASE)]
settings = QSettings()
mySettings = '/PostgreSQL/connections/' + connection
dbname = settings.value(mySettings + '/database')
user = settings.value(mySettings + '/username')
host = settings.value(mySettings + '/host')
port = settings.value(mySettings + '/port')
password = settings.value(mySettings + '/password')
uri = uri_from_name(connection)
if self.processing:
# to get credentials input when needed
uri = GeoDB(uri=uri).uri
inLayer = self.getParameterValue(self.INPUT_LAYER)
ogrLayer = ogrConnectionString(inLayer)[1:-1]
ssrs = unicode(self.getParameterValue(self.S_SRS))
@ -200,17 +208,11 @@ class Ogr2OgrToPostGisList(GdalAlgorithm):
arguments.append('--config PG_USE_COPY YES')
arguments.append('-f')
arguments.append('PostgreSQL')
arguments.append('PG:"host=' + host)
arguments.append('port=' + port)
if len(dbname) > 0:
arguments.append('dbname=' + dbname)
if len(password) > 0:
arguments.append('password=' + password)
if len(schema) > 0:
arguments.append('active_schema=' + schema)
else:
arguments.append('active_schema=public')
arguments.append('user=' + user + '"')
arguments.append('PG:"')
for token in uri.connectionInfo(self.processing).split(' '):
arguments.append(token)
arguments.append('active_schema={}'.format(schema or 'public'))
arguments.append('"')
arguments.append(dimstring)
arguments.append(ogrLayer)
arguments.append(ogrLayerName(inLayer))

View File

@ -86,6 +86,8 @@ class ImportIntoPostGIS(GeoAlgorithm):
def processAlgorithm(self, progress):
connection = self.DB_CONNECTIONS[self.getParameterValue(self.DATABASE)]
db = postgis.GeoDB.from_name(connection)
schema = self.getParameterValue(self.SCHEMA)
overwrite = self.getParameterValue(self.OVERWRITE)
createIndex = self.getParameterValue(self.CREATEINDEX)
@ -94,17 +96,6 @@ class ImportIntoPostGIS(GeoAlgorithm):
forceSinglePart = self.getParameterValue(self.FORCE_SINGLEPART)
primaryKeyField = self.getParameterValue(self.PRIMARY_KEY)
encoding = self.getParameterValue(self.ENCODING)
settings = QSettings()
mySettings = '/PostgreSQL/connections/' + connection
try:
database = settings.value(mySettings + '/database')
username = settings.value(mySettings + '/username')
host = settings.value(mySettings + '/host')
port = settings.value(mySettings + '/port', type=int)
password = settings.value(mySettings + '/password')
except Exception as e:
raise GeoAlgorithmExecutionException(
self.tr('Wrong database connection name: %s' % connection))
layerUri = self.getParameterValue(self.INPUT)
layer = dataobjects.getObjectFromUri(layerUri)
@ -115,13 +106,6 @@ class ImportIntoPostGIS(GeoAlgorithm):
table = table.replace(' ', '').lower()[0:62]
providerName = 'postgres'
try:
db = postgis.GeoDB(host=host, port=port, dbname=database,
user=username, passwd=password)
except postgis.DbError as e:
raise GeoAlgorithmExecutionException(
self.tr("Couldn't connect to database:\n%s") % unicode(e))
geomColumn = self.getParameterValue(self.GEOMETRY_COLUMN)
if not geomColumn:
geomColumn = 'the_geom'
@ -141,8 +125,7 @@ class ImportIntoPostGIS(GeoAlgorithm):
if not layer.hasGeometryType():
geomColumn = None
uri = QgsDataSourceURI()
uri.setConnection(host, unicode(port), database, username, password)
uri = db.uri
if primaryKeyField:
uri.setDataSource(schema, table, geomColumn, '', primaryKeyField)
else:

View File

@ -46,24 +46,7 @@ class PostGISExecuteSQL(GeoAlgorithm):
def processAlgorithm(self, progress):
connection = self.getParameterValue(self.DATABASE)
settings = QSettings()
mySettings = '/PostgreSQL/connections/' + connection
try:
database = settings.value(mySettings + '/database')
username = settings.value(mySettings + '/username')
host = settings.value(mySettings + '/host')
port = settings.value(mySettings + '/port', type=int)
password = settings.value(mySettings + '/password')
except Exception as e:
raise GeoAlgorithmExecutionException(
self.tr('Wrong database connection name: %s' % connection))
try:
self.db = postgis.GeoDB(host=host, port=port,
dbname=database, user=username, passwd=password)
except postgis.DbError as e:
raise GeoAlgorithmExecutionException(
self.tr("Couldn't connect to database:\n%s") % unicode(e))
self.db = postgis.GeoDB.from_name(connection)
sql = self.getParameterValue(self.SQL).replace('\n', ' ')
try:
self.db._exec_sql_and_commit(unicode(sql))

View File

@ -100,23 +100,10 @@ class ConnectionItem(QTreeWidgetItem):
def populateSchemas(self):
if self.childCount() != 0:
return
settings = QSettings()
connSettings = '/PostgreSQL/connections/' + self.connection
database = settings.value(connSettings + '/database')
user = settings.value(connSettings + '/username')
host = settings.value(connSettings + '/host')
port = settings.value(connSettings + '/port')
passwd = settings.value(connSettings + '/password')
uri = QgsDataSourceURI()
uri.setConnection(host, str(port), database, user, passwd)
connInfo = uri.connectionInfo()
(success, user, passwd) = QgsCredentials.instance().get(connInfo, None, None)
if success:
QgsCredentials.instance().put(connInfo, user, passwd)
geodb = GeoDB(host, int(port), database, user, passwd)
schemas = geodb.list_schemas()
for oid, name, owner, perms in schemas:
item = QTreeWidgetItem()
item.setText(0, name)
item.setIcon(0, self.schemaIcon)
self.addChild(item)
geodb = GeoDB.from_name(self.connection)
schemas = geodb.list_schemas()
for oid, name, owner, perms in schemas:
item = QTreeWidgetItem()
item.setText(0, name)
item.setIcon(0, self.schemaIcon)
self.addChild(item)

View File

@ -29,10 +29,41 @@ import psycopg2
import psycopg2.extensions # For isolation levels
import re
from qgis.PyQt.QtCore import QSettings
from qgis.core import QgsDataSourceURI, QgsCredentials
# Use unicode!
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
def uri_from_name(conn_name):
settings = QSettings()
settings.beginGroup(u"/PostgreSQL/connections/%s" % conn_name)
if not settings.contains("database"): # non-existent entry?
raise DbError('There is no defined database connection "%s".' % conn_name)
uri = QgsDataSourceURI()
settingsList = ["service", "host", "port", "database", "username", "password", "authcfg"]
service, host, port, database, username, password, authcfg = [settings.value(x, "", type=str) for x in settingsList]
useEstimatedMetadata = settings.value("estimatedMetadata", False, type=bool)
sslmode = settings.value("sslmode", QgsDataSourceURI.SSLprefer, type=int)
settings.endGroup()
if service:
uri.setConnection(service, database, username, password, sslmode, authcfg)
else:
uri.setConnection(host, port, database, username, password, sslmode, authcfg)
uri.setUseEstimatedMetadata(useEstimatedMetadata)
return uri
class TableAttribute:
def __init__(self, row):
@ -100,7 +131,10 @@ class DbError(Exception):
return unicode(self).encode('utf-8')
def __unicode__(self):
return u'MESSAGE: %s\nQUERY: %s' % (self.message, self.query)
text = u'MESSAGE: %s' % self.message
if self.query:
text += u'\nQUERY: %s' % self.query
return text
class TableField:
@ -138,41 +172,74 @@ class TableField:
class GeoDB:
@classmethod
def from_name(cls, conn_name):
uri = uri_from_name(conn_name)
return cls(uri=uri)
def __init__(self, host=None, port=None, dbname=None, user=None,
passwd=None):
passwd=None, service=None, uri=None):
# Regular expression for identifiers without need to quote them
self.re_ident_ok = re.compile(r"^\w+$")
self.host = host
self.port = port
self.dbname = dbname
self.user = user
self.passwd = passwd
if uri:
self.uri = uri
else:
self.uri = QgsDataSourceURI()
if service:
self.uri.setConnection(service, dbname, user, passwd)
else:
self.uri.setConnection(host, port, dbname, user, passwd)
if self.dbname == '' or self.dbname is None:
self.dbname = self.user
conninfo = self.uri.connectionInfo(False)
err = None
for i in range(4):
expandedConnInfo = uri.connectionInfo(True)
try:
self.con = psycopg2.connect(expandedConnInfo.encode('utf-8'))
if err is not None:
QgsCredentials.instance().put(conninfo,
self.uri.username(),
self.uri.password())
break
except psycopg2.OperationalError as e:
if i == 3:
raise DbError(unicode(e))
try:
self.con = psycopg2.connect(self.con_info())
except psycopg2.OperationalError as e:
raise DbError(unicode(e))
err = unicode(e)
user = self.uri.username()
password = self.uri.password()
(ok, user, password) = QgsCredentials.instance().get(conninfo,
user,
password,
err)
if not ok:
raise DbError(u'Action cancelled by user')
if user:
self.uri.setUsername(user)
if password:
self.uri.setPassword(password)
finally:
# remove certs (if any) of the expanded connectionInfo
expandedUri = QgsDataSourceURI(expandedConnInfo)
sslCertFile = expandedUri.param("sslcert")
if sslCertFile:
sslCertFile = sslCertFile.replace("'", "")
os.remove(sslCertFile)
sslKeyFile = expandedUri.param("sslkey")
if sslKeyFile:
sslKeyFile = sslKeyFile.replace("'", "")
os.remove(sslKeyFile)
sslCAFile = expandedUri.param("sslrootcert")
if sslCAFile:
sslCAFile = sslCAFile.replace("'", "")
os.remove(sslCAFile)
self.has_postgis = self.check_postgis()
def con_info(self):
con_str = ''
if self.host:
con_str += "host='%s' " % self.host
if self.port:
con_str += 'port=%d ' % self.port
if self.dbname:
con_str += "dbname='%s' " % self.dbname
if self.user:
con_str += "user='%s' " % self.user
if self.passwd:
con_str += "password='%s' " % self.passwd
return con_str
def get_info(self):
c = self.con.cursor()
self._exec_sql(c, 'SELECT version()')