from contextlib import contextmanager
import codecs
import functools
import os
import time
from psycopg2.errors import DeadlockDetected, UniqueViolation
from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import sessionmaker
import geopandas as gpd
import pandas as pd
from egon.data import config
[docs]def credentials():
"""Return local database connection parameters.
Returns
-------
dict
Complete DB connection information
"""
translated = {
"--database-name": "POSTGRES_DB",
"--database-password": "POSTGRES_PASSWORD",
"--database-host": "HOST",
"--database-port": "PORT",
"--database-user": "POSTGRES_USER",
}
configuration = config.settings()["egon-data"]
update = {
translated[flag]: configuration[flag]
for flag in configuration
if flag in translated
}
configuration.update(update)
return configuration
[docs]@functools.lru_cache(maxsize=None)
def engine_for(pid): # pylint: disable=unused-argument
db_config = credentials()
return create_engine(
f"postgresql+psycopg2://{db_config['POSTGRES_USER']}:"
f"{db_config['POSTGRES_PASSWORD']}@{db_config['HOST']}:"
f"{db_config['PORT']}/{db_config['POSTGRES_DB']}",
echo=False,
)
[docs]def engine():
"""Engine for local database."""
return engine_for(os.getpid())
[docs]def execute_sql(sql_string):
"""Execute a SQL expression given as string.
The SQL expression passed as plain string is convert to a
`sqlalchemy.sql.expression.TextClause`.
Parameters
----------
sql_string : str
SQL expression
"""
engine_local = engine()
with engine_local.connect().execution_options(autocommit=True) as con:
con.execute(text(sql_string))
[docs]def execute_sql_script(script, encoding="utf-8-sig"):
"""Execute a SQL script given as a file name.
Parameters
----------
script : str
Path of the SQL-script
encoding : str
Encoding which is used for the SQL file. The default is "utf-8-sig".
Returns
-------
None.
"""
with codecs.open(script, "r", encoding) as fd:
sqlfile = fd.read()
execute_sql(sqlfile)
[docs]@contextmanager
def session_scope():
"""Provide a transactional scope around a series of operations."""
Session = sessionmaker(bind=engine())
session = Session()
try:
yield session
session.commit()
except: # noqa: E722 (This is OK, because of the immediate re-raise.)
session.rollback()
raise
finally:
session.close()
[docs]def session_scoped(function):
"""Provide a session scope to a function.
Can be used as a decorator like this:
>>> @session_scoped
... def get_bind(session):
... return session.get_bind()
...
>>> get_bind()
Engine(postgresql+psycopg2://egon:***@127.0.0.1:59734/egon-data)
Note that the decorated function needs to accept a parameter named
`session`, but is called without supplying a value for that parameter
because the parameter's value will be filled in by `session_scoped`.
Using this decorator allows saving an indentation level when defining
such functions but it also has other usages.
"""
@functools.wraps(function)
def wrapped(*xs, **ks):
with session_scope() as session:
return function(session=session, *xs, **ks)
return wrapped
[docs]def select_dataframe(sql, index_col=None, warning=True):
"""Select data from local database as pandas.DataFrame
Parameters
----------
sql : str
SQL query to be executed.
index_col : str, optional
Column(s) to set as index(MultiIndex). The default is None.
Returns
-------
df : pandas.DataFrame
Data returned from SQL statement.
"""
df = pd.read_sql(sql, engine(), index_col=index_col)
if df.size == 0 and warning is True:
print(f"WARNING: No data returned by statement: \n {sql}")
return df
[docs]def select_geodataframe(sql, index_col=None, geom_col="geom", epsg=3035):
"""Select data from local database as geopandas.GeoDataFrame
Parameters
----------
sql : str
SQL query to be executed.
index_col : str, optional
Column(s) to set as index(MultiIndex). The default is None.
geom_col : str, optional
column name to convert to shapely geometries. The default is 'geom'.
epsg : int, optional
EPSG code specifying output projection. The default is 3035.
Returns
-------
gdf : pandas.DataFrame
Data returned from SQL statement.
"""
gdf = gpd.read_postgis(
sql, engine(), index_col=index_col, geom_col=geom_col
)
if gdf.size == 0:
print(f"WARNING: No data returned by statement: \n {sql}")
else:
gdf = gdf.to_crs(epsg=epsg)
return gdf
[docs]def next_etrago_id(component):
"""Select next id value for components in etrago tables
Parameters
----------
component : str
Name of component
Returns
-------
next_id : int
Next index value
Notes
-----
To catch concurrent DB commits, consider to use
:func:`check_db_unique_violation` instead.
"""
if component == "transformer":
id_column = "trafo_id"
else:
id_column = f"{component}_id"
max_id = select_dataframe(
f"""
SELECT MAX({id_column}) FROM grid.egon_etrago_{component}
"""
)["max"][0]
if max_id:
next_id = max_id + 1
else:
next_id = 1
return next_id
[docs]def check_db_unique_violation(func):
"""Wrapper to catch psycopg's UniqueViolation errors during concurrent DB
commits.
Preferrably used with :func:`next_etrago_id`. Retries DB operation 10
times before raising original exception.
Can be used as a decorator like this:
>>> @check_db_unique_violation
... def commit_something_to_database():
... # commit something here
... return
...
>>> commit_something_to_database() # doctest: +SKIP
Examples
--------
Add new bus to eTraGo's bus table:
>>> from egon.data import db
>>> from egon.data.datasets.etrago_setup import EgonPfHvBus
...
>>> @check_db_unique_violation
... def add_etrago_bus():
... bus_id = db.next_etrago_id("bus")
... with db.session_scope() as session:
... emob_bus_id = db.next_etrago_id("bus")
... session.add(
... EgonPfHvBus(
... scn_name="eGon2035",
... bus_id=bus_id,
... v_nom=1,
... carrier="whatever",
... x=52,
... y=13,
... geom="<some_geom>"
... )
... )
... session.commit()
...
>>> add_etrago_bus() # doctest: +SKIP
Parameters
----------
func: func
Function to wrap
Notes
-----
Background: using :func:`next_etrago_id` may cause trouble if tasks are
executed simultaneously, cf.
https://github.com/openego/eGon-data/issues/514
Important: your function requires a way to escape the violation as the
loop will not terminate until the error is resolved! In case of eTraGo
tables you can use :func:`next_etrago_id`, see example above.
"""
def commit(*args, **kwargs):
unique_violation = True
ret = None
ctr = 0
while unique_violation:
try:
ret = func(*args, **kwargs)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
print("Entry is not unique, retrying...")
ctr += 1
time.sleep(3)
if ctr > 10:
print("No success after 10 retries, exiting...")
raise e
else:
raise e
# ===== TESTING ON DEADLOCKS START =====
except OperationalError as e:
if isinstance(e.orig, DeadlockDetected):
print("Deadlock detected, retrying...")
ctr += 1
time.sleep(3)
if ctr > 10:
print("No success after 10 retries, exiting...")
raise e
# ===== TESTING ON DEADLOCKS END =======
else:
unique_violation = False
return ret
return commit
[docs]def assign_gas_bus_id(dataframe, scn_name, carrier):
"""Assign `bus_id`s to points according to location.
The points are taken from the given `dataframe` and the geometries by
which the `bus_id`s are assigned to them are taken from the
`grid.egon_gas_voronoi` table.
Parameters
----------
dataframe : pandas.DataFrame
DataFrame cointaining points
scn_name : str
Name of the scenario
carrier : str
Name of the carrier
Returns
-------
res : pandas.DataFrame
Dataframe including bus_id
"""
voronoi = select_geodataframe(
f"""
SELECT bus_id, geom FROM grid.egon_gas_voronoi
WHERE scn_name = '{scn_name}' AND carrier = '{carrier}';
""",
epsg=4326,
)
res = gpd.sjoin(dataframe, voronoi)
res["bus"] = res["bus_id"]
res = res.drop(columns=["index_right"])
# Assert that all power plants have a bus_id
assert (
res.bus.notnull().all()
), f"Some points are not attached to a {carrier} bus."
return res