""" Module to perform operations on the yahoo finance API data (tickers) """
from typing import List, Literal, Union
import pandas as pd
import yfinance as yf
from sqlalchemy import MetaData, Table, func, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql import null
from src.postgres_interface import PostgresInterface
from src.utils import custom_logger
[docs]class Ticker:
def __init__(
    self,
    countries: Union[str, List[str]] = None,
    chunksize: int = 20,
    frequency: Literal["annual", "quarterly"] = "annual",
    schema: str = "stocks",
):
    """
    Store the run configuration and open the database engine

    Parameters
    ----------
    countries : Union[str, List[str]]
        Kept on the instance unchanged (default None)
    chunksize : int
        Kept on the instance unchanged (default 20)
    frequency : Literal["annual", "quarterly"]
        Reporting frequency; it is part of the availability column name
        written by update_validity_status
    schema : str
        Database schema used when the tables are reflected
        (default "stocks")
    """
    # logger first, so it exists before any database work starts
    self.logger = custom_logger(logger_name="ticker")

    # plain configuration values
    self.schema = schema
    self.frequency = frequency
    self.chunksize = chunksize
    self.countries = countries

    # database handle and the engine created from it
    interface = PostgresInterface()
    self.postgres_interface = interface
    self.engine = interface.create_engine()
def _create_yf_ticker(self, ticker_symbol: str) -> yf.Ticker:
    """
    Build the yfinance handle for a single stock symbol

    Parameters
    ----------
    ticker_symbol : str
        ticker symbol of the stock

    Returns
    -------
    yf.Ticker
        yfinance.Ticker object created from ``ticker_symbol``
    """
    return yf.Ticker(ticker_symbol)
def update_tickers_list_table(self):
    """
    Method to update the tickers_list table in postgres

    Reads all rows of the ``src/data/tickers_list.xlsx`` spreadsheet (all
    available tickers), renames its columns to the database naming and
    writes the result to the ``tickers_list`` table of ``self.schema``.
    An existing ``tickers_list`` table is replaced, not appended to.
    """
    # spreadsheet header -> database column name
    column_names = {
        "Ticker": "ticker",
        "Name": "name",
        "Exchange": "exchange",
        "Category Name": "category_name",
        "Country": "country",
    }
    tickers_df = pd.read_excel("src/data/tickers_list.xlsx").rename(
        columns=column_names
    )

    # write into the schema this instance was configured with, like every
    # other method of the class, instead of a hard-coded schema name
    tickers_df.to_sql(
        name="tickers_list",
        con=self.engine,
        if_exists="replace",
        schema=self.schema,
        index=False,
        method="multi",
        chunksize=1000,
    )
def load_valid_tickers(self, sink_table: str) -> List[str]:
    """
    Return the valid tickers that have no row in ``sink_table`` yet

    Both ``valid_tickers`` and ``sink_table`` are reflected from
    ``self.schema``. A ticker is returned when its ``validity`` column in
    ``valid_tickers`` is true and the left outer join on ``ticker`` finds
    no matching row in ``sink_table``.

    Parameters
    ----------
    sink_table : str
        The name of the table the tickers are checked against

    Returns
    -------
    List[str]
        A list of the valid tickers
    """
    engine = self.engine
    schema = self.schema
    source = Table(
        "valid_tickers", MetaData(), autoload_with=engine, schema=schema
    )
    sink = Table(sink_table, MetaData(), autoload_with=engine, schema=schema)

    same_ticker = source.c.ticker == sink.c.ticker
    # a NULL sink ticker marks a row the outer join could not match
    absent_from_sink = sink.c.ticker == null()
    query = (
        select(source.c.ticker)
        .outerjoin(sink, same_ticker)
        .where(absent_from_sink)
        .where(source.c.validity)
    )

    with engine.connect() as conn:
        rows = conn.execute(query).fetchall()
    return [row[0] for row in rows]
def flush_records(self, table_name: str, records: list):
    """
    Write a batch of records to a table with a PostgreSQL upsert

    Rows that collide on (ticker, report_date, frequency) are not
    overwritten; only their ``insert_date`` is set to the current date.

    Parameters
    ----------
    table_name: str
        The name of the table to flush the records to
    records: list
        The records to flush to the table; nothing happens when it is empty
    """
    if not records:
        # nothing to write: skip the table lookup and the round trip
        return

    target = self.postgres_interface.create_table_object(
        table_name=table_name, engine=self.engine, schema=self.schema
    )
    upsert = insert(target).values(records)
    upsert = upsert.on_conflict_do_update(
        index_elements=["ticker", "report_date", "frequency"],
        set_={"insert_date": func.current_date()},
    )

    with self.engine.connect() as conn:
        conn.execute(upsert)
        conn.commit()

    self.logger.warning(
        f"Data flushed with {len(records)} records inserted into {table_name}"
    )
def get_data_df(
    self, table_name: str, frequency: str, ticker: yf.Ticker
) -> pd.DataFrame:
    """
    Method that returns a df based on the name of the table and frequency

    Parameters
    ----------
    table_name: str
        The name of the table that is going to be filled
        One of income_stmt, balance_sheet, cashflow or financials
    frequency: str
        The frequency of the data to be extracted
        Either annual or quarterly
    ticker: yf.Ticker
        The object the statement attribute is read from

    Returns
    -------
    pd.DataFrame
        The transposed statement found on ``ticker``

    Raises
    ------
    KeyError
        If there is no statement for the (table_name, frequency) pair
    """
    # (table, frequency) -> name of the attribute to read from the ticker
    attribute_names = {
        ("income_stmt", "annual"): "income_stmt",
        ("income_stmt", "quarterly"): "quarterly_income_stmt",
        ("balance_sheet", "annual"): "balance_sheet",
        ("balance_sheet", "quarterly"): "quarterly_balance_sheet",
        ("cashflow", "annual"): "cashflow",
        ("cashflow", "quarterly"): "quarterly_cashflow",
        ("financials", "annual"): "financials",
        ("financials", "quarterly"): "quarterly_financials",
    }
    try:
        attribute_name = attribute_names[(table_name, frequency)]
    except KeyError:
        # same exception type as the plain lookup, with a readable message
        raise KeyError(
            f"No statement available for table {table_name!r} "
            f"with frequency {frequency!r}"
        ) from None
    return getattr(ticker, attribute_name).T
def _adjust_df_columns(
self, df: pd.DataFrame, table_name: str, table_columns: list
) -> pd.DataFrame:
"""
Function that gets a df that contains data for a specific ticker, and
adjusts the columns of the df to match the columns of the table in the
database
Parameters
----------
df: pd.DataFrame
The dataframe that contains the data for a specific ticker
table_name: str
The name of the table that the data is extracted for
table_columns: list
The columns of the table
Returns
-------
pd.DataFrame
The dataframe with the adjusted columns ready to be inserted into the database
"""
# make column names all lower case and replace spaces with underscores
df.columns = [i.replace(" ", "_").lower() for i in list(df.columns)]
missed_columns = []
# if a column does not exist in the stocks.table_name table, drop it from the df
for column in [i.replace(" ", "_") for i in list(df.columns)]:
if column not in table_columns:
self.logger.warning(f"Column {column} not in {table_name} columns")
missed_columns.append(column)
df.drop(columns=column, inplace=True)
# if a column does not exist in the df, It will be added with null values
for column in table_columns:
if column not in df.columns:
df[column] = None
return df
[docs] def get_columns_names(self, table_name: str):
"""
Method that returns the columns names of a table
"""
table = self.postgres_interface.create_table_object(
table_name=table_name, engine=self.engine
)
columns = [column.name for column in table.columns]
return columns
def update_validity_status(
    self, table_name: str, tickers: list[str], availability: bool = False
):
    """
    Method that gets a list of tickers and updates their availability flag
    for a specific criteria (e.g. balance_sheet_annual_available) in the
    valid_tickers table, e.g. if the tickers have no balance sheet data for
    the quarterly frequency, the balance_sheet_quarterly_available column
    in the valid_tickers table will be updated to False

    The column that is written is ``{table_name}_{self.frequency}_available``.

    Parameters
    ----------
    table_name: str
        The name of the table which the tickers were supposed to be updated in
    tickers: list[str]
        The tickers whose flag is updated; nothing happens when it is empty
    availability: bool
        The value written to the availability column
        default: False

    Returns
    -------
    None
    """
    if not tickers:
        # nothing to update: skip the table reflection and the UPDATE
        return

    # get the table object
    valid_tickers = Table(
        "valid_tickers",
        MetaData(),
        autoload_with=self.engine,
        schema=self.schema,
    )
    # update the flag of all the tickers at once
    flag_column = f"{table_name}_{self.frequency}_available"
    query = (
        valid_tickers.update()
        .where(valid_tickers.c.ticker.in_(tickers))
        .values({flag_column: availability})
    )
    with self.engine.connect() as conn:
        conn.execute(query)
        conn.commit()
    self.logger.warning(
        f"Validity status updated to {availability} for {len(tickers)} tickers"
    )
def get_currency_code(self, ticker: str) -> str:
    """
    Method that gets the currency code of a ticker from valid_tickers table
    in the database

    Parameters
    ----------
    ticker: str
        The ticker symbol

    Returns
    -------
    str
        The currency code of the ticker

    Raises
    ------
    ValueError
        If the valid_tickers table has no row for ``ticker``
    """
    table = Table(
        "valid_tickers",
        MetaData(),
        autoload_with=self.engine,
        schema=self.schema,
    )
    query = select(table.c.currency_code).where(table.c.ticker == ticker)
    with self.engine.connect() as conn:
        row = conn.execute(query).fetchone()
    if row is None:
        # fetchone() gives None for an empty result; report the missing
        # ticker instead of failing on None[0]
        raise ValueError(
            f"Ticker {ticker!r} not found in {self.schema}.valid_tickers"
        )
    return row[0]