# Source code for finance.src.postgres_interface

"""
Module to interact with postgres databases

It contains generic methods to interact with postgres databases regardless of
the data they contain
"""

import os

import pandas as pd
import sqlalchemy
from dotenv import load_dotenv
from sqlalchemy import MetaData, Table, select, text
from sqlalchemy.dialects.postgresql import insert
from src.utils import custom_logger


class PostgresInterface:
    """
    Class to interact with postgres databases
    """

    def __init__(self):
        # Populate os.environ from a .env file before any variables are read.
        load_dotenv()
        self.provider = os.environ.get("PROVIDER")
        self.logger = custom_logger(logger_name="postgres_interface")
[docs] def create_engine(self) -> sqlalchemy.engine.Engine: """ function that creates engines to connect to postgres databases Returns ------- dict dictionary with the engines to connect to the databases """ user = os.environ.get(f"{self.provider}_POSTGRES_USER") password = os.environ.get(f"{self.provider}_POSTGRES_PASSWORD") host = os.environ.get(f"{self.provider}_POSTGRES_HOST") port = os.environ.get(f"{self.provider}_POSTGRES_PORT") db = os.environ.get(f"{self.provider}_POSTGRES_DB") ssl_mode = "?sslmode=" + os.environ.get(f"{self.provider}_SSL_MODE") or "" engine = sqlalchemy.create_engine( f"postgresql://{user}:{password}@{host}:{port}/{db}{ssl_mode}" ) return engine
[docs] def create_table_object( self, table_name: str, engine: sqlalchemy.engine.Engine, schema: str = "stocks" ) -> sqlalchemy.Table: """ Method to create a table object Parameters ---------- table_name : str name of the table to create the object for engine : sqlalchemy.engine.Engine engine to connect to the database schema : str schema of the table default: stocks Returns ------- sqlalchemy.Table table object """ metadata = MetaData() table = Table(table_name, metadata, autoload_with=engine, schema=schema) return table
[docs] def insert_batch( self, table: sqlalchemy.Table, batch: list, conn: sqlalchemy.engine.Connection ) -> None: """ Method to insert a batch of data into a table Parameters ---------- table : str table to insert data into data : list list of tuples with the data to insert into the table Returns ------- None """ # statement to insert data into neon database self.logger.warning(f"Inserting batch of {len(batch)} rows") insert_statement = insert(table).values(batch).on_conflict_do_nothing() conn.execute(insert_statement) conn.commit()
[docs] def read_table_to_df(self, table: str, schema: str = "stocks") -> pd.DataFrame: """ Method to read a table into a dataframe Parameters ---------- table : tabble name to read Returns ------- pd.DataFrame dataframe with the data from the table """ engine = self.create_engine() table = self.create_table_object(table, engine, schema) query = select(table) with engine.connect() as conn: result = conn.execute(query).fetchall() df = pd.DataFrame(result, columns=table.columns.keys()) return df
[docs] def migrate_dbs( self, batch_size: int = 5000, tap_cloud_provider: str = "NEON", target_cloud_provider: str = "AVN", ) -> None: """ Method to migrate a database to another one Supposed to be used only once to migrate data to a target database Parameters ---------- batch_size : int number of rows to insert in each batch default: 5000 tap_cloud_provider : str cloud provider of the tap database default: NEON target_cloud_provider : str cloud provider of the target database default: AVN Returns ------- None """ tap_user = os.environ.get(f"{tap_cloud_provider}_POSTGRES_USER") tap_password = os.environ.get(f"{tap_cloud_provider}_POSTGRES_PASSWORD") tap_host = os.environ.get(f"{tap_cloud_provider}_POSTGRES_HOST") tap_port = os.environ.get(f"{tap_cloud_provider}_POSTGRES_PORT") tap_db = os.environ.get(f"{tap_cloud_provider}_POSTGRES_DB") engine_tap = sqlalchemy.create_engine( f"postgresql://{tap_user}:{tap_password}@{tap_host}:{tap_port}/{tap_db}" ) target_user = os.environ.get(f"{target_cloud_provider}_POSTGRES_USER") target_password = os.environ.get(f"{target_cloud_provider}_POSTGRES_PASSWORD") target_host = os.environ.get(f"{target_cloud_provider}_POSTGRES_HOST") target_port = os.environ.get(f"{target_cloud_provider}_POSTGRES_PORT") target_db = os.environ.get(f"{target_cloud_provider}_POSTGRES_DB") engine_target = sqlalchemy.create_engine( f"postgresql://{target_user}:{target_password}@{target_host}:{target_port}/{target_db}" ) # List of tables in local postgres database in stocks schema metadata = MetaData() information_schema_tables = Table( "tables", metadata, autoload_with=engine_tap, schema="information_schema" ) query = ( select(information_schema_tables.c.table_name) .where(information_schema_tables.c.table_schema == "stocks") .order_by(information_schema_tables.c.table_name) ) with engine_tap.connect() as conn_local: tables = [table[0] for table in conn_local.execute(query).fetchall()] # if "alembic" of "dbt" are in table's name, remove them blacklist = 
["alembic", "dbt"] tables = [table for table in tables if not any(x in table for x in blacklist)] # insert table's data into neon database for table in tables: self.logger.warning(f"Inserting data from {table} into target database") with engine_tap.connect() as conn_tap: with engine_target.connect() as conn_target: table_tap = self.create_table_object(table, engine_tap) total_rows = conn_tap.execute( text("SELECT COUNT(*) FROM stocks." + table) ).scalar() self.logger.warning( f"Total rows in {table_tap} from tap database: {total_rows}" ) # Calculate how many iterations you will need iterations = total_rows // batch_size + ( 1 if total_rows % batch_size else 0 ) for i in range(iterations): offset = i * batch_size # Create a SELECT statement with LIMIT and OFFSET select_stmt = conn_tap.execute( text( "SELECT * FROM stocks." + table + f" LIMIT {batch_size} OFFSET {offset}" ) ) result_set = select_stmt.fetchall() # convert result set to list of dicts result_set = [tuple(row) for row in result_set] self.logger.warning( f"""Selected {len(result_set)} rows from {table} from tap database with offset {offset}""" ) table_target = self.create_table_object(table, engine_target) # statement to insert data into neon database self.logger.warning( f"Inserting batch of {len(result_set)} rows" ) self.insert_batch(table_target, result_set, conn_target) self.logger.warning( f"Inserted {len(result_set)} rows from {table} into target database" ) del result_set