# Source code for mktlib.rates

"""Treasury yield curve rates for risk-free rate estimation."""

from __future__ import annotations

import operator
from collections.abc import Sequence
from datetime import date
from enum import StrEnum
from functools import partial, reduce
from typing import TYPE_CHECKING

from . import _treasury

if TYPE_CHECKING:
    import polars as pl

__all__ = [
    "MeanMethod",
    "TreasuryRate",
    "get_mean_treasury_rate",
    "get_risk_free_rate",
    "get_treasury_rates",
    "get_treasury_spread",
]


[docs] class MeanMethod(StrEnum): """Averaging method for rate aggregation.""" ARITHMETIC = "arithmetic" GEOMETRIC = "geometric"
[docs] class TreasuryRate(StrEnum): """Treasury yield curve instruments available from Treasury.gov.""" ONE_MONTH = "BC_1MONTH" ONE_AND_HALF_MONTH = "BC_1_5MONTH" TWO_MONTH = "BC_2MONTH" THREE_MONTH = "BC_3MONTH" FOUR_MONTH = "BC_4MONTH" SIX_MONTH = "BC_6MONTH" ONE_YEAR = "BC_1YEAR" TWO_YEAR = "BC_2YEAR" THREE_YEAR = "BC_3YEAR" FIVE_YEAR = "BC_5YEAR" SEVEN_YEAR = "BC_7YEAR" TEN_YEAR = "BC_10YEAR" TWENTY_YEAR = "BC_20YEAR" THIRTY_YEAR = "BC_30YEAR" THIRTY_YEAR_DISPLAY = "BC_30YEARDISPLAY"
def _parse_date(d: date | str) -> date:
    """Coerce *d* to a ``date``.

    Strings are parsed with ``date.fromisoformat``; any other value is
    handed back untouched.
    """
    if not isinstance(d, str):
        return d
    return date.fromisoformat(d)
def get_risk_free_rate(
    start: date | str,
    end: date | str,
    instrument: TreasuryRate = TreasuryRate.THREE_MONTH,
) -> float:
    """Return the average annualised risk-free rate over a date range.

    The result is the arithmetic mean of the daily Treasury yields,
    expressed as a decimal (e.g., 0.0436 for 4.36%).

    Parameters
    ----------
    start, end
        Inclusive bounds of the range, given either as ``date`` objects or
        as ISO strings (``"2024-01-01"``).
    instrument
        Treasury yield to average. The default, the 3-month T-bill, is the
        standard academic proxy for the risk-free rate.
    """
    first_day = _parse_date(start)
    last_day = _parse_date(end)
    return _treasury.fetch_average_rate(first_day, last_day, instrument.value)
def get_mean_treasury_rate(
    start: date | str,
    end: date | str,
    instrument: TreasuryRate = TreasuryRate.THREE_MONTH,
    method: MeanMethod = MeanMethod.ARITHMETIC,
) -> float:
    """Return the mean annualised Treasury rate over a date range.

    Parameters
    ----------
    start, end
        Inclusive bounds of the range, given either as ``date`` objects or
        as ISO strings.
    instrument
        Treasury yield to average.
    method
        How to average: arithmetic (the default) or geometric.
    """
    first_day = _parse_date(start)
    last_day = _parse_date(end)
    return _treasury.fetch_mean_rate(
        first_day,
        last_day,
        instrument.value,
        method.value,
    )
# Lazily applies ``_treasury.fetch_year`` to every item of the iterable it is
# called with (``get_treasury_rates`` passes a range of calendar years).
# ``partial`` captures the function object once, when this module is imported.
_fetch_years = partial(map, _treasury.fetch_year)
def get_treasury_rates(
    start: date | str,
    end: date | str,
    instrument: TreasuryRate | Sequence[TreasuryRate] | None = None,
) -> pl.DataFrame:
    """Fetch daily Treasury rates as a Polars DataFrame.

    Parameters
    ----------
    start, end
        Date range (inclusive). Accepts ``date`` objects or ISO strings.
    instrument
        Single instrument → 2-column DataFrame (``date``, ``rate``); rows
        with no rate for that instrument are dropped.
        Sequence or ``None`` (all) → wide DataFrame with one column per
        instrument, named by the enum member in lowercase
        (e.g. ``"three_month"``). A requested instrument that is absent from
        the fetched data is returned as an all-null column.

    Returns
    -------
    pl.DataFrame
        ``date`` is cast to ``pl.Date`` and every rate column to
        ``pl.Float64``. When nothing is available for the range, an empty
        frame with that same schema is returned.
    """
    # Imported here because the module only imports polars under TYPE_CHECKING.
    import polars as pl

    start, end = _parse_date(start), _parse_date(end)

    # Concatenate the per-year row lists into one fresh list.
    rows = reduce(
        operator.iadd, _fetch_years(range(start.year, end.year + 1)), []
    )
    df = pl.DataFrame(rows)
    # A frame built from zero rows has no columns at all, so filtering on
    # "date" would raise instead of reaching the empty-result branches below.
    if not df.is_empty():
        df = df.filter(pl.col("date").is_between(start, end))

    # Single instrument → 2-column DataFrame (date, rate)
    if isinstance(instrument, TreasuryRate):
        key = instrument.value
        if df.is_empty() or key not in df.columns:
            return pl.DataFrame(schema={"date": pl.Date, "rate": pl.Float64})
        return df.select(
            pl.col("date").cast(pl.Date),
            pl.col(key).cast(pl.Float64).alias("rate"),
        ).drop_nulls("rate")

    # Multi-instrument or all: map raw field key → lowercase member name,
    # preserving the requested (or enum definition) order.
    requested = TreasuryRate if instrument is None else instrument
    rename = {member.value: member.name.lower() for member in requested}

    if df.is_empty():
        schema = {"date": pl.Date} | {name: pl.Float64 for name in rename.values()}
        return pl.DataFrame(schema=schema)

    # Only rename columns that exist in the data.
    df = df.rename({raw: name for raw, name in rename.items() if raw in df.columns})

    # date + every requested column; instruments missing from the data are
    # filled in as null Float64 columns so the output shape is predictable.
    available = set(df.columns)
    select_exprs: list[pl.Expr] = [pl.col("date").cast(pl.Date)]
    select_exprs.extend(
        pl.col(name).cast(pl.Float64)
        if name in available
        else pl.lit(None, dtype=pl.Float64).alias(name)
        for name in rename.values()
    )
    return df.select(select_exprs)
def get_treasury_spread(
    start: date | str,
    end: date | str,
    long: TreasuryRate = TreasuryRate.TEN_YEAR,
    short: TreasuryRate = TreasuryRate.TWO_YEAR,
) -> pl.DataFrame:
    """Compute the daily yield spread between two Treasury instruments.

    The result has two columns, ``date`` and ``spread`` (long minus short).
    Days on which either instrument has no value are left out.

    Parameters
    ----------
    start, end
        Inclusive bounds of the range, given either as ``date`` objects or
        as ISO strings.
    long
        Longer-maturity instrument (default: 10-year).
    short
        Shorter-maturity instrument (default: 2-year).
    """
    import polars as pl

    # Wide-frame columns are named after the lowercase enum member names.
    difference = pl.col(long.name.lower()) - pl.col(short.name.lower())

    rates = get_treasury_rates(start, end, [long, short])
    spread_frame = rates.select("date", difference.alias("spread"))
    return spread_frame.drop_nulls("spread")