Decision trees, also known as classification and regression trees (CARTs), are among the most intuitive supervised learning models used in financial risk management. They work by recursively splitting the data at nodes so that each split maximizes information gain, that is, the reduction in an impurity measure such as Gini impurity or entropy. In a typical FRM example, features such as revenue growth, debt-to-equity, and operating cash flow margin are used to build a decision tree that predicts whether earnings will drop.
Below is a very brief review of the core ideas:
Decision Trees (CARTs): Supervised models that split data based on features to make predictions.
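Splitting Principle: At each node, select the feature and threshold that maximize information gain (the reduction in impurity from the parent node to its children).
Key Metrics (where \(p_i\) is the share of samples at the node that belong to class \(i\), out of \(M\) classes):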
Entropy: \(E = -\sum_{i=1}^{M} p_i \log_2(p_i)\)
Gini: \(G = 1 - \sum_{i=1}^{M} p_i^2\)
Example Application: Using quarterly financial statements from Apple (AAPL), we derived key metrics—such as revenue growth, debt-to-equity ratio, and operating cash flow margin—to predict whether earnings will drop from one period to the next.
Algorithm: The decision tree is grown top-down and greedily, in the spirit of the Iterative Dichotomiser (ID3) family of algorithms; scikit-learn's DecisionTreeClassifier itself implements an optimized version of CART. At each node, the algorithm selects the feature and threshold that most reduce impurity (here measured by Gini impurity) and splits the data. Splitting continues recursively until no further split reduces impurity or a stopping criterion (such as a maximum depth) is met.
Overfitting Management: In our example, we apply pre‑pruning by setting a maximum tree depth (max_depth=3) to avoid overfitting to the limited dataset.
Ensemble Methods: Techniques like bagging, pasting, and random forests aggregate multiple trees to improve performance.
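As a rough illustration of the metrics and the splitting principle above, the sketch below computes entropy and Gini impurity for a node, then the impurity reduction achieved by one candidate split. The helper names (`class_probs`, `entropy`, `gini`, `impurity_reduction`) and the toy labels are hypothetical and exist only for this example; they are not part of the pipeline that follows.

```python
import numpy as np

def class_probs(labels):
    """Return the class proportions p_i for an array of labels."""
    _, counts = np.unique(labels, return_counts=True)
    return counts / counts.sum()

def entropy(labels):
    """E = -sum(p_i * log2(p_i)) over the classes present at the node."""
    p = class_probs(labels)
    return float(-(p * np.log2(p)).sum())

def gini(labels):
    """G = 1 - sum(p_i^2) over the classes present at the node."""
    p = class_probs(labels)
    return float(1.0 - (p ** 2).sum())

def impurity_reduction(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted

# Toy node (made-up labels): 4 quarters without a drop (0) and 4 with a drop (1)
parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
left = np.array([0, 0, 0, 1])    # mostly "no drop"
right = np.array([0, 1, 1, 1])   # mostly "drop"

print("Entropy of parent:", entropy(parent))
print("Gini of parent:", gini(parent))
print("Gini reduction from the split:", impurity_reduction(parent, left, right))
print("Entropy reduction from the split:", impurity_reduction(parent, left, right, entropy))
```

A tree builder would evaluate this reduction for every candidate feature and threshold at a node and keep the split with the largest value.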
Below is the complete Python code that illustrates the end-to-end pipeline:
# --- Imports ---
import requests
import pandas as pd
import numpy as np
import yfinance as yf
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
import re
# --- Step 1: Download and save financial statements to Excel ---
# Stock symbol and URL mapping for the various statements
ticker = 'aapl'
url_sheet_map = {
    f"https://stockanalysis.com/stocks/{ticker}/?p=quarterly": "Overview",
    f"https://stockanalysis.com/stocks/{ticker}/financials/?p=quarterly": "Income Statement",
    f"https://stockanalysis.com/stocks/{ticker}/financials/balance-sheet/?p=quarterly": "Balance Sheet",
    f"https://stockanalysis.com/stocks/{ticker}/financials/cash-flow-statement/?p=quarterly": "Cash Flow",
    f"https://stockanalysis.com/stocks/{ticker}/financials/ratios/?p=quarterly": "Ratios"
}
excel_file = f"{ticker}_financial_statements.xlsx"
from io import StringIO  # newer pandas versions expect file-like input in read_html

# Write each table found on each URL to a separate sheet in an Excel file.
with pd.ExcelWriter(excel_file) as writer:
    # Loop through each URL and its corresponding sheet name
    for url, sheet_name in url_sheet_map.items():
        print(f"Processing: {url}")
        response = requests.get(url)
        response.raise_for_status()  # Ensure the request was successful
        # Parse all tables from the current URL
        tables = pd.read_html(StringIO(response.text))
        print(f"Found {len(tables)} tables at {url}.")
        # If there are multiple tables, write them one after another in the same sheet
        startrow = 0  # Initial row position for writing
        for idx, table in enumerate(tables):
            # Add a header row in the Excel sheet to mark where each table starts
            header = pd.DataFrame({f"Table {idx} from {sheet_name}": []})
            header.to_excel(writer, sheet_name=sheet_name, startrow=startrow)
            startrow += 1  # Move down one row for the table data
            # Write the table to the current sheet starting at the designated row
            table.to_excel(writer, sheet_name=sheet_name, startrow=startrow)
            # Update the startrow for the next table (current table rows + 2 extra rows as spacer)
            startrow += len(table.index) + 2
print(f"All tables have been saved into '{excel_file}', each URL in its own sheet.")
# --- Step 2: Clean and load the sheets from Excel ---
TICKER = "AAPL"
EXCEL = f"{TICKER.lower()}_financial_statements.xlsx"  # Must match the lowercase file name written in Step 1
FY_COL = "FY2024" # Adjust as necessary
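# --- Helper Functions ---
def parse_value(val):
    """Convert a scraped cell such as '1,234', '12.5%', or '3.2B' to a float."""
    if isinstance(val, str):
        val = val.replace(",", "").strip()
        if val in ['-', '', 'NA', 'N/A']:
            return np.nan
        # Percentages become fractions (e.g. '12.5%' -> 0.125)
        if "%" in val:
            try:
                return float(val.replace("%", "").strip()) / 100
            except ValueError:
                return np.nan
        # Scale values that carry a magnitude suffix
        m = {'B': 1e9, 'M': 1e6, 'T': 1e12}
        if val[-1] in m:
            try:
                return float(val[:-1].strip()) * m[val[-1]]
            except ValueError:
                return np.nan
        # Plain numbers are treated as millions
        try:
            return float(val) * 1e6 if val[-1].isdigit() else np.nan
        except ValueError:
            return np.nan
    return np.nan if pd.isna(val) else val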
def clean_sheet(sheet, file):
    # Read the sheet using rows 1 and 2 as a two-level header; row 0 holds the
    # "Table N from ..." label written in Step 1.
    # (Adjust header=[1, 2] if the date information sits in different rows of your file.)
    df = pd.read_excel(file, sheet_name=sheet, header=[1, 2])
    df = df.iloc[1:].reset_index(drop=True)
    # Flatten the multi-index columns by joining the two header levels.
    # For each column (a tuple), join non-empty parts with a space.
    df.columns = [
        ' '.join(str(part).strip() for part in col if pd.notna(part) and str(part).strip() != '')
        for col in df.columns.values
    ]
    return df
def get_val(df, key, col=FY_COL, default=None):
    row = df[df["Item"].str.contains(key, case=False, na=False)]
    return row[col].values[0] if not row.empty else default
# --- Load Data ---
fin = clean_sheet("Income Statement", EXCEL)
fin = fin.set_index(fin.columns[1])
fin = fin.drop(fin.columns[0], axis=1)
fin.index.name = 'Item'
bal = clean_sheet("Balance Sheet", EXCEL)
bal = bal.set_index(bal.columns[1])
bal = bal.drop(bal.columns[0], axis=1)
bal.index.name = 'Item'
cf = clean_sheet("Cash Flow", EXCEL)
cf = cf.set_index(cf.columns[1])
cf = cf.drop(cf.columns[0], axis=1)
cf.index.name = 'Item'
# --- Step 3: Process and extract metrics from each statement ---
# For the Income Statement, extract rows for "Revenue" and "Net Income".
revenue_row = fin[fin.index.str.strip().str.match(r"^Revenue$", case=False, na=False)]
rev_df = revenue_row.T.reset_index().rename(
    columns={'index': 'Period', revenue_row.index[0]: 'Revenue'}
)
net_income_row = fin[fin.index.str.strip().str.match(r"^Net Income$", case=False, na=False)]
ni_df = net_income_row.T.reset_index().rename(
    columns={'index': 'Period', net_income_row.index[0]: 'Net Income'}
)
# For the Balance Sheet, extract rows for "Total Liabilities" and "Shareholders' Equity"
# (the latter is renamed to "Total Equity" after the merge).
liab_row = bal.loc[bal.index.str.strip().str.match(r"^Total Liabilities$", case=False, na=False)]
liab_df = liab_row.T.reset_index().rename(
    columns={'index': 'Period', liab_row.index[0]: 'Total Liabilities'}
)
equity_row = bal.loc[bal.index.str.strip().str.match(r"^Shareholders' Equity$", case=False, na=False)]
equity_df = equity_row.T.reset_index().rename(
    columns={'index': 'Period', equity_row.index[0]: "Shareholders' Equity"}
)
# For the Cash Flow, extract the row for "Operating Cash Flow"
ocf_row = cf.loc[cf.index.str.strip().str.match(r"^Operating Cash Flow$", case=False, na=False)]
ocf_df = ocf_row.T.reset_index().rename(
    columns={'index': 'Period', ocf_row.index[0]: "Operating Cash Flow"}
)
# --- Step 4: Merge the metrics by period ---
# All these DataFrames have column names like FY2024, FY2023, etc.
# Merge on the "Period" column
merged = rev_df.merge(ni_df, on="Period", suffixes=("_Revenue", "_NetIncome"))
merged = merged.merge(liab_df, on="Period")
merged = merged.merge(equity_df, on="Period")
merged = merged.merge(ocf_df, on="Period")
# Rename columns by position; this relies on the merge order above
# (Revenue, Net Income, Total Liabilities, Shareholders' Equity, Operating Cash Flow).
merged.rename(columns={
    merged.columns[1]: "Revenue",
    merged.columns[2]: "Net Income",
    merged.columns[3]: "Total Liabilities",
    merged.columns[4]: "Total Equity",
    merged.columns[5]: "Operating Cash Flow"
}, inplace=True)
# Convert key columns to numeric
for col in ["Revenue", "Net Income", "Total Liabilities", "Total Equity", "Operating Cash Flow"]:
    merged[col] = pd.to_numeric(merged[col], errors='coerce')
# Define a function to extract the full period-ending date.
def extract_date(period_str):
    # This regex attempts to capture the date at the end of the string.
    # It expects a pattern like "Sep 28, 2024" at the very end.
    m = re.search(r"([A-Za-z]{3} \d{1,2}, \d{4})$", period_str)
    return m.group(1) if m else None
# Create a new Date column in 'merged' by extracting and converting the date.
merged["Date"] = merged["Period"].apply(lambda x: pd.to_datetime(extract_date(x), errors='coerce'))
# Sort the merged DataFrame in chronological order (oldest to newest) based on the extracted Date.
merged = merged.sort_values("Date", ascending=True).reset_index(drop=True)
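print("Merged financial data after processing dates (oldest to newest):")
# Only the raw columns exist at this point; the derived features are computed in Step 5.
print(merged[['Period', 'Date', 'Revenue', 'Net Income', 'Total Liabilities',
              'Total Equity', 'Operating Cash Flow']].to_string(index=False))
# --- Step 5: Compute the features and the target ---
# Compute debt-to-equity ratio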
merged["Debt_to_Equity"] = merged["Total Liabilities"] / merged["Total Equity"]
# Compute revenue growth as percent change (assuming periods are in chronological order)
merged["Revenue_growth"] = merged["Revenue"].pct_change()
# Compute Operating Cash Flow margin
merged["OCF_margin"] = merged["Operating Cash Flow"] / merged["Revenue"]
# Define Earnings Drop: 1 if current period's Net Income is lower than previous period's
merged["Earnings_drop"] = (merged["Net Income"] < merged["Net Income"].shift(1)).astype(int)
# Drop the first period (which has no previous period to compare)
model_df = merged.dropna(subset=["Revenue_growth", "Debt_to_Equity", "OCF_margin", "Earnings_drop"])
print("Merged financial data for modeling:")
print(model_df[['Period', 'Date', 'Revenue', 'Net Income', 'Debt_to_Equity',
                'Revenue_growth', 'OCF_margin', 'Earnings_drop']].to_string(index=False))
# --- Step 6: Build a multi-feature decision tree ---
features = ['Revenue_growth', 'Debt_to_Equity', 'OCF_margin']
target = 'Earnings_drop'
X = model_df[features]
y = model_df[target]
clf = DecisionTreeClassifier(criterion='gini', random_state=0, max_depth=3)
clf.fit(X, y)
plt.figure(figsize=(10, 8))
plot_tree(clf, feature_names=features, class_names=["No Drop", "Drop"],
          filled=True, rounded=True)
plt.title("Decision Tree: Earnings Drop Prediction (Multi-Feature Model)")
plt.show()
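The pre-pruning and ensemble ideas from the review can be sketched on their own. The example below is a minimal illustration, assuming synthetic data from scikit-learn's `make_classification` rather than the AAPL dataset, which has too few quarters for a meaningful train/test comparison. The `_demo` variable names, the sample size, and the estimator settings are all hypothetical choices for this sketch.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (assumption): 3 informative features, noisy binary labels
X_demo, y_demo = make_classification(
    n_samples=400, n_features=3, n_informative=3, n_redundant=0,
    flip_y=0.1, random_state=0,
)
X_train, X_test, y_train, y_test = train_test_split(
    X_demo, y_demo, test_size=0.25, random_state=0, stratify=y_demo
)

models = {
    # No depth limit: the tree keeps splitting until its leaves are pure
    "full tree": DecisionTreeClassifier(random_state=0),
    # Pre-pruning: same depth cap as in the pipeline above
    "pruned tree": DecisionTreeClassifier(max_depth=3, random_state=0),
    # Bagging: each tree is fit on a bootstrap sample (drawn with replacement)
    "bagging": BaggingClassifier(
        DecisionTreeClassifier(max_depth=3), n_estimators=50, random_state=0
    ),
    # Pasting: each tree is fit on a subsample drawn without replacement
    "pasting": BaggingClassifier(
        DecisionTreeClassifier(max_depth=3), n_estimators=50,
        max_samples=0.8, bootstrap=False, random_state=0
    ),
    # Random forest: bootstrap samples plus a random feature subset at each split
    "random forest": RandomForestClassifier(
        n_estimators=50, max_depth=3, random_state=0
    ),
}

scores = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    scores[name] = (model.score(X_train, y_train), model.score(X_test, y_test))
    print(f"{name:>13}: train acc = {scores[name][0]:.3f}, test acc = {scores[name][1]:.3f}")
```

The gap between training and test accuracy for each model is a rough indicator of overfitting. The results depend on the data and the random seed, so they say nothing about how the AAPL model itself would generalize.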