# Source code for mktlib.backtest._engine

from __future__ import annotations

import datetime
import logging
from collections.abc import Mapping
from typing import TYPE_CHECKING, overload

import polars as pl

logger = logging.getLogger(__name__)

from mktlib.backtest._conditions import (
    All,
    Any_,
    ColExpr,
    Condition,
    Custom,
    EntryRef,
    Limit,
    Not,
    Pct,
    ValueGT,
    ValueGTE,
    ValueLT,
    ValueLTE,
    _BinOp,
)
from mktlib.backtest._types import BacktestResult, MultiBacktestResult, Strategy, TradeSide
from mktlib.backtest._weights import (
    INSTRUMENT_COLUMN,
    InvalidPortfolioWeights,
    to_portfolio_weights_df,
)

if TYPE_CHECKING:
    from mktlib.scheduling import ExchangeCalendar


def _to_date(dt: datetime.datetime | datetime.date) -> datetime.date:
    """Convert datetime to date if needed, for calendar API compatibility."""
    if isinstance(dt, datetime.datetime):
        return dt.date()
    return dt


def _tz(series: pl.Series) -> str | None:
    """Extract timezone from a Datetime series, or None."""
    return series.dtype.time_zone  # type: ignore[union-attr]


def _align_tz(target: pl.Series, reference: pl.Series) -> pl.Series:
    """Align *target* timezone to match *reference*.

    Cases, in order:

    - same timezone (or both naive): *target* is returned unchanged;
    - *reference* naive, *target* aware: the zone is removed from *target*
      with ``replace_time_zone(None)``;
    - *reference* aware, *target* naive: *reference*'s zone is attached to
      *target* with ``replace_time_zone``;
    - both aware but in different zones: *target* is converted to
      *reference*'s zone with ``convert_time_zone`` so the two series end up
      with the same dtype.
    """
    ref_tz = _tz(reference)
    tgt_tz = _tz(target)
    if ref_tz == tgt_tz:
        return target
    if ref_tz is None:
        return target.dt.replace_time_zone(None)
    if tgt_tz is None:
        return target.dt.replace_time_zone(ref_tz)
    # Both aware, different zones: previously returned as-is, which left the
    # two series with mismatching Datetime dtypes.
    return target.dt.convert_time_zone(ref_tz)


# ---------------------------------------------------------------------------
# Schedule cache — avoids recomputing calendar.schedule() across mask calls
# ---------------------------------------------------------------------------

# Process-wide memo of calendar schedules, keyed by (calendar name, first day, last day).
_schedule_cache: dict[tuple[str, datetime.date, datetime.date], pl.DataFrame] = {}


def _get_schedule(calendar: ExchangeCalendar, dates: pl.Series) -> pl.DataFrame:
    """Fetch the calendar schedule spanning *dates*, memoized per date range.

    The range is taken from the minimum and maximum of *dates*; the schedule
    is computed with ``calendar.schedule`` only the first time a given
    ``(calendar.name, start, end)`` combination is requested.
    """
    first_day = _to_date(dates.min())  # type: ignore[arg-type]
    last_day = _to_date(dates.max())  # type: ignore[arg-type]
    cache_key = (calendar.name, first_day, last_day)
    if cache_key in _schedule_cache:
        return _schedule_cache[cache_key]
    schedule = calendar.schedule(first_day, last_day)
    _schedule_cache[cache_key] = schedule
    return schedule


def _build_session_last_mask(
    dates: pl.Series,
    calendar: ExchangeCalendar,
) -> pl.Series:
    """Flag the final bar of every trading session.

    Each bar is matched (as-of join on ``market_open``) to a session from the
    calendar schedule, bars at or after that session's ``market_close`` are
    discarded, and the latest remaining timestamp per session is taken as the
    session's last bar. Because the last bar is read from the data itself,
    no particular candle size is assumed.

    Returns a boolean series aligned with *dates*.
    """
    schedule = _get_schedule(calendar, dates)
    session_bounds = pl.DataFrame({
        "market_open": _align_tz(schedule["market_open"], dates),
        "market_close": _align_tz(schedule["market_close"], dates),
    })

    bars = dates.to_frame("date").with_row_index("_idx")
    # Attach the session whose open is the latest one not after the bar,
    # then keep only bars that fall before that session's close.
    in_session = bars.join_asof(
        session_bounds, left_on="date", right_on="market_open",
    ).filter(pl.col("date") < pl.col("market_close"))

    session_ends = in_session.group_by("market_open").agg(
        pl.col("date").max().alias("last_bar"),
    )
    return dates.is_in(session_ends["last_bar"].to_list())


# ---------------------------------------------------------------------------
# EntryRef tree walker — collects column names needed for entry-bar snapshots
# ---------------------------------------------------------------------------


def _collect_entry_refs(cond: Condition) -> set[str]:
    """Gather the column names of every ``EntryRef`` reachable from *cond*.

    The condition tree is traversed by ``_walk_cond``; the result is empty
    when no ``EntryRef`` is found.
    """
    referenced: set[str] = set()
    _walk_cond(cond, referenced)
    return referenced


def _walk_cond(cond: Condition, cols: set[str]) -> None:
    match cond:
        case All(left, right, _) | Any_(left, right, _):
            _walk_cond(left, cols)
            _walk_cond(right, cols)
        case Not(inner, _):
            _walk_cond(inner, cols)
        case ValueGT(a, b, _) | ValueGTE(a, b, _) | ValueLT(a, b, _) | ValueLTE(a, b, _):
            _walk_expr(a, cols)
            _walk_expr(b, cols)
        case _:
            pass


def _walk_expr(node: str | float | ColExpr, cols: set[str]) -> None:
    match node:
        case EntryRef(col):
            cols.add(col)
        case Pct(base, _):
            _walk_expr(base, cols)
        case _BinOp(left, right, _):
            _walk_expr(left, cols)
            _walk_expr(right, cols)
        case _:
            pass


def _run_core(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    trade_side: TradeSide,
    calendar: ExchangeCalendar | None,
    flatten_eod: bool,
) -> BacktestResult:
    """Inner engine: assumes *df* is already calendar-filtered.

    Pipeline, in order:

    1. Call ``strategy.init(df)`` when the strategy defines ``init``.
    2. Resolve the entry and exit conditions to expressions (non-``Condition``
       return values are wrapped in ``Custom``).
    3. Build a 0/1 ``_position`` column from the ``_entry`` / ``_exit``
       signals (entry takes precedence on a bar where both fire), optionally
       forcing 0 on each session's last bar.
    4. Derive per-bar returns from ``_position`` delayed by one bar.
    5. Extract the trade log and drop the internal helper columns.

    Parameters
    ----------
    df
        Price frame; this function reads ``date``, ``open`` and ``close``
        plus whatever columns the resolved conditions reference.
    strategy
        Object providing ``entry()`` and ``exit()``, and optionally ``init``.
    trade_side
        Fallback direction, used when the entry condition's ``trade_side``
        is falsy.
    calendar
        Only consulted when *flatten_eod* is true (to find session-last bars).
    flatten_eod
        When true, positions are forced to 0 on each session's last bar.

    Returns
    -------
    BacktestResult
        ``returns`` is ``(date, return)``; ``signals`` keeps ``_entry``,
        ``_exit``, ``_position``, ``_side`` and any ``_entry_<col>`` snapshot
        columns, with the other helper columns removed.

    Raises
    ------
    NotImplementedError
        If the entry condition is a ``Limit``.
    """
    # Let strategy enrich the DataFrame with indicator columns
    _init = getattr(strategy, "init", None)
    if _init is not None:
        df = _init(df)

    entry_raw = strategy.entry()
    exit_raw = strategy.exit()
    # Anything that is not already a Condition is wrapped in Custom.
    entry_cond = entry_raw if isinstance(entry_raw, Condition) else Custom(entry_raw)
    exit_cond = exit_raw if isinstance(exit_raw, Condition) else Custom(exit_raw)

    # Entry-side limits are not supported yet (v1 scope: exit-only). The
    # wrapper would resolve silently to the inner boolean with the
    # ``price`` kwarg ignored — a footgun. Fail fast with a clear message
    # so callers don't get subtle incorrect fill semantics.
    if isinstance(entry_cond, Limit):
        msg = (
            "Limit(...) is only supported on exit conditions in this "
            "release. Entry-side limit fills are planned for a future "
            "version — for now, use the inner condition directly for "
            "entry and rely on fill-at-next-open semantics."
        )
        raise NotImplementedError(msg)

    entry_expr = entry_cond.resolve()

    # Limit exit: same-bar fill at a specified price. v1 scope is a
    # top-level ``Limit`` wrapper on the exit tree only — nested limit
    # usage is not recognized and behaves as a plain boolean condition.
    is_limit_exit = isinstance(exit_cond, Limit)
    limit_price_expr: pl.Expr | None = (
        exit_cond.resolve_price() if is_limit_exit else None  # type: ignore[attr-defined]
    )

    # The entry condition's own trade_side wins when truthy; otherwise the
    # caller-supplied trade_side is used.
    effective_side = int(entry_cond.trade_side or trade_side)

    # Pass 1: compute _entry
    signals = df.with_columns(entry_expr.alias("_entry"))

    # Create snapshot columns for any EntryRef nodes in exit condition
    # NOTE(review): _walk_cond has no case for Limit, so EntryRef nodes under
    # a top-level Limit exit look like they are not collected — confirm.
    entry_refs = _collect_entry_refs(exit_cond)
    if entry_refs:
        # Each snapshot holds the column's value at the most recent bar where
        # _entry was true (null before the first such bar).
        signals = signals.with_columns(
            pl.when(pl.col("_entry")).then(pl.col(col)).otherwise(None)
            .forward_fill().alias(f"_entry_{col}")
            for col in entry_refs
        )

    # Pass 2: compute _exit (snapshot columns now exist for EntryRef.resolve())
    exit_expr = exit_cond.resolve()
    signals = signals.with_columns(exit_expr.alias("_exit"))

    # For limit exits, also materialize the fill price. Only consumed on
    # bars where the exit fires while we're holding; values elsewhere are
    # harmless.
    if is_limit_exit:
        signals = signals.with_columns(
            limit_price_expr.alias("_limit_price"),  # type: ignore[union-attr]
        )

    # Position tracking: 1 on entry, 0 on exit, forward-fill
    if flatten_eod:
        _session_last = _build_session_last_mask(signals["date"], calendar)  # type: ignore[arg-type]
        signals = signals.with_columns(_session_last.alias("_session_last"))
        # Defer entries on session-last bars to the first bar of the next
        # session (e.g. crossover on 15:59 → enter at next day's 09:30).
        _suppressed = pl.col("_entry") & pl.col("_session_last")
        signals = signals.with_columns(
            (pl.col("_entry") | _suppressed.shift(1).fill_null(False)).alias("_entry"),
        )
        # Suppress entries on session-last bars (position opens and immediately
        # force-closes in the same bar — not a valid trade).
        signals = signals.with_columns(
            pl.when(pl.col("_entry") & ~pl.col("_session_last"))
            .then(pl.lit(1))
            .when(pl.col("_exit") | pl.col("_session_last"))
            .then(pl.lit(0))
            .otherwise(pl.lit(None))
            .forward_fill()
            .fill_null(0)
            .alias("_position"),
        )
    else:
        signals = signals.with_columns(
            pl.when(pl.col("_entry"))
            .then(pl.lit(1))
            .when(pl.col("_exit"))
            .then(pl.lit(0))
            .otherwise(pl.lit(None))
            .forward_fill()
            .fill_null(0)
            .alias("_position"),
        )

    # Materialize shared shifted expressions once
    # (_pos_d1 / _pos_d2 are _position delayed by one and two bars.)
    signals = signals.with_columns(
        pl.col("_position").shift(1).fill_null(0).alias("_pos_d1"),
        pl.col("_position").shift(2).fill_null(0).alias("_pos_d2"),
        pl.col("close").shift(1).alias("_close_prev"),
    )
    # Transition detection (uses materialized _pos_d1)
    # These mark the signal bars (0→1 and 1→0) and feed _extract_trades.
    signals = signals.with_columns(
        ((pl.col("_position") == 1) & (pl.col("_pos_d1") == 0)).alias("_entry_clean"),
        ((pl.col("_position") == 0) & (pl.col("_pos_d1") == 1)).alias("_exit_clean"),
    )

    # _side = position * effective_side (scalar multiply, no forward-fill)
    signals = signals.with_columns(
        (pl.col("_position") * effective_side).cast(pl.Int8).alias("_side"),
    )

    # Detect transition bars (after the 1-bar delay for fill)
    _is_entry_bar = (pl.col("_pos_d1") == 1) & (pl.col("_pos_d2") == 0)
    _is_exit_bar = (pl.col("_pos_d1") == 0) & (pl.col("_pos_d2") == 1)

    # Per-bar returns with fill-at-open adjustment
    _entry_ret = ((pl.col("close") - pl.col("open")) / pl.col("open")) * effective_side
    _normal_ret = (pl.col("close") / pl.col("_close_prev") - 1) * effective_side
    _exit_ret = (
        (pl.col("open") - pl.col("_close_prev")) / pl.col("_close_prev")
    ) * effective_side

    # Limit-exit branches: fill at _limit_price on the same bar the
    # inner condition fires while we were holding. Suppress the normal
    # next-bar exit-open fill that would otherwise apply.
    _limit_ret: pl.Expr = pl.lit(0.0)
    _is_limit_exit_bar: pl.Expr = pl.lit(False)
    _is_post_limit_bar: pl.Expr = pl.lit(False)
    if is_limit_exit:
        _limit_ret = (
            (pl.col("_limit_price") - pl.col("_close_prev"))
            / pl.col("_close_prev")
        ) * effective_side
        _is_limit_exit_bar = pl.col("_exit") & (pl.col("_pos_d1") == 1)
        _is_post_limit_bar = _is_limit_exit_bar.shift(1).fill_null(False)

    # Compute returns + flatten_eod overrides in minimal with_columns calls
    # The when/then chains below are evaluated top-down: the first matching
    # branch decides the bar's return.
    if flatten_eod:
        # Base returns + session-last override in one pass. When limits
        # are active they win against session-last on the same bar
        # (intra-bar fill precedes the session-close flatten).
        if is_limit_exit:
            ret_expr = (
                pl.when(_is_limit_exit_bar).then(_limit_ret)
                .when(_is_post_limit_bar).then(0.0)
                .when(pl.col("_session_last") & _is_entry_bar).then(0.0)
                .when(pl.col("_session_last") & (pl.col("_pos_d1") == 1)).then(_exit_ret)
                .when(_is_entry_bar).then(_entry_ret)
                .when(_is_exit_bar).then(_exit_ret)
                .when(pl.col("_pos_d1") == 1).then(_normal_ret)
                .otherwise(0.0)
                .fill_null(0.0)
            )
        else:
            ret_expr = (
                pl.when(pl.col("_session_last") & _is_entry_bar).then(0.0)
                .when(pl.col("_session_last") & (pl.col("_pos_d1") == 1)).then(_exit_ret)
                .when(_is_entry_bar).then(_entry_ret)
                .when(_is_exit_bar).then(_exit_ret)
                .when(pl.col("_pos_d1") == 1).then(_normal_ret)
                .otherwise(0.0)
                .fill_null(0.0)
            )
        signals = signals.with_columns(ret_expr.alias("return"))
        # Post-session-last bar zeroing
        # (the bar right after a session-last bar gets 0 unless it is itself
        # an entry bar).
        signals = signals.with_columns(
            pl.when(
                pl.col("_session_last").shift(1).fill_null(False)
                & ~_is_entry_bar
            )
            .then(0.0)
            .otherwise(pl.col("return"))
            .alias("return"),
        )
    else:
        if is_limit_exit:
            ret_expr = (
                pl.when(_is_limit_exit_bar).then(_limit_ret)
                .when(_is_post_limit_bar).then(0.0)
                .when(_is_entry_bar).then(_entry_ret)
                .when(_is_exit_bar).then(_exit_ret)
                .when(pl.col("_pos_d1") == 1).then(_normal_ret)
                .otherwise(0.0)
                .fill_null(0.0)
            )
        else:
            ret_expr = (
                pl.when(_is_entry_bar).then(_entry_ret)
                .when(_is_exit_bar).then(_exit_ret)
                .when(pl.col("_pos_d1") == 1).then(_normal_ret)
                .otherwise(0.0)
                .fill_null(0.0)
            )
        signals = signals.with_columns(ret_expr.alias("return"))

    returns = signals.select("date", "return")

    # Build trade log from entry/exit transitions (before dropping internal cols)
    trades = _extract_trades(signals, flatten_eod=flatten_eod)

    # Drop internal columns before return
    _drop_cols = ["_pos_d1", "_pos_d2", "_close_prev", "_entry_clean", "_exit_clean"]
    if flatten_eod:
        _drop_cols.append("_session_last")
    if is_limit_exit:
        _drop_cols.append("_limit_price")
    signals = signals.drop(_drop_cols)

    return BacktestResult(returns=returns, trades=trades, signals=signals)


def _run_dual(
    df: pl.DataFrame,
    long_strategy: Strategy,
    short_strategy: Strategy,
    *,
    calendar: ExchangeCalendar | None,
    flatten_eod: bool,
) -> BacktestResult:
    """Backtest a long and a short strategy on the same data and merge them.

    Each side is submitted to a two-worker thread pool as its own
    ``_run_core`` call. After both finish, the results are checked for
    bars where both sides hold a position, then combined:

    - returns are summed bar by bar;
    - trades are concatenated and sorted by ``entry_date``;
    - signals start from the long side's frame, with ``_entry`` / ``_exit``
      OR-ed and ``_position`` / ``_side`` summed across the two sides.

    Raises
    ------
    ValueError
        If any bar has both the long and the short position open.
    """
    from concurrent.futures import ThreadPoolExecutor

    shared_kwargs = {"calendar": calendar, "flatten_eod": flatten_eod}
    with ThreadPoolExecutor(max_workers=2) as executor:
        long_job = executor.submit(
            _run_core, df, long_strategy, trade_side=TradeSide.LONG, **shared_kwargs,
        )
        short_job = executor.submit(
            _run_core, df, short_strategy, trade_side=TradeSide.SHORT, **shared_kwargs,
        )
        long_res = long_job.result()
        short_res = short_job.result()

    long_sig = long_res.signals
    short_sig = short_res.signals

    # The two sides must never be in the market on the same bar.
    both_open = (long_sig["_position"] == 1) & (short_sig["_position"] == 1)
    if both_open.any():
        msg = "Long and short strategies have overlapping positions"
        raise ValueError(msg)

    # A flat side contributes 0 on a bar, so summing is enough.
    merged_returns = pl.DataFrame({
        "date": long_res.returns["date"],
        "return": long_res.returns["return"] + short_res.returns["return"],
    })

    merged_trades = pl.concat([long_res.trades, short_res.trades]).sort("entry_date")

    merged_signals = long_sig.with_columns(
        (long_sig["_entry"] | short_sig["_entry"]).alias("_entry"),
        (long_sig["_exit"] | short_sig["_exit"]).alias("_exit"),
        (long_sig["_position"] + short_sig["_position"]).alias("_position"),
        (long_sig["_side"] + short_sig["_side"]).alias("_side"),
    )

    return BacktestResult(
        returns=merged_returns, trades=merged_trades, signals=merged_signals,
    )


def _run_multi(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    short_strategy: Strategy | None = None,
    instrument_col: str,
    trade_side: TradeSide,
    calendar: ExchangeCalendar | None,
    flatten_eod: bool,
    weights: pl.DataFrame | None = None,
) -> MultiBacktestResult:
    """Backtest every instrument in *df* separately and bundle the results.

    The frame is market-hours filtered once (when a calendar is given), split
    by *instrument_col* in first-appearance order, and each partition — with
    the instrument column removed — is run through ``_run_dual`` (when
    *short_strategy* is given) or ``_run_core``. When *weights* is supplied
    it is cross-checked against the backtested symbols before the result is
    built.

    Raises
    ------
    ValueError
        If *instrument_col* is not a column of *df*.
    """
    if instrument_col not in df.columns:
        msg = f"instrument_col={instrument_col!r} not found in DataFrame columns"
        raise ValueError(msg)

    # Filter the whole frame once rather than per partition.
    if calendar is not None:
        df = calendar.filter_market_hours(df, "date")

    def _backtest_one(frame: pl.DataFrame) -> BacktestResult:
        if short_strategy is None:
            return _run_core(
                frame,
                strategy,
                trade_side=trade_side,
                calendar=calendar,
                flatten_eod=flatten_eod,
            )
        return _run_dual(
            frame,
            strategy,
            short_strategy,
            calendar=calendar,
            flatten_eod=flatten_eod,
        )

    by_instrument: dict[str, BacktestResult] = {}
    for part in df.partition_by(instrument_col, maintain_order=True):
        symbol = part[instrument_col][0]
        by_instrument[symbol] = _backtest_one(part.drop(instrument_col))

    if weights is not None:
        _cross_validate_weights(by_instrument, weights)

    return MultiBacktestResult(
        by_instrument, instrument_col=instrument_col, weights=weights,
    )


def _cross_validate_weights(
    by_instrument: dict[str, BacktestResult],
    weights: pl.DataFrame,
) -> None:
    """Reconcile the backtested symbols with the symbols that carry a weight.

    A backtested symbol without a weight is an error. A weighted symbol that
    was never backtested is only reported through a log warning.

    Raises
    ------
    InvalidPortfolioWeights
        If at least one backtested symbol is absent from *weights*.
    """
    backtested = set(by_instrument)
    weighted = set(weights[INSTRUMENT_COLUMN].to_list())

    if unweighted := backtested - weighted:
        msg = (
            f"symbols backtested but not in instrument_weights: {sorted(unweighted)}. "
            "Either add them to the weights dict or filter them out of the input "
            "DataFrame before calling run()."
        )
        raise InvalidPortfolioWeights(msg)

    if unused := weighted - backtested:
        logger.warning(
            "instrument_weights contains symbols not present in the data; "
            "these will be ignored: %s",
            sorted(unused),
        )


# Typing-only overloads for ``run``. They let a type checker pick the return
# type from the call shape; the implementation follows the last stub.
# NOTE(review): no stub covers short_strategy together with instrument_weights
# when instrument_col is omitted — confirm whether that call shape is intended.

# Long + short strategies, instrument_col named → MultiBacktestResult.
@overload
def run(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    short_strategy: Strategy,
    calendar: ExchangeCalendar | None = ...,
    flatten_eod: bool = ...,
    instrument_col: str,
    instrument_weights: Mapping[str, float] | pl.DataFrame | None = ...,
) -> MultiBacktestResult: ...


# Long + short strategies, no instrument_col and no weights → BacktestResult.
@overload
def run(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    short_strategy: Strategy,
    calendar: ExchangeCalendar | None = ...,
    flatten_eod: bool = ...,
    instrument_col: None = ...,
) -> BacktestResult: ...


# Single strategy, instrument_col named → MultiBacktestResult.
@overload
def run(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    trade_side: TradeSide = ...,
    calendar: ExchangeCalendar | None = ...,
    flatten_eod: bool = ...,
    instrument_col: str,
    instrument_weights: Mapping[str, float] | pl.DataFrame | None = ...,
) -> MultiBacktestResult: ...


# Single strategy, no instrument_col and no weights → BacktestResult.
@overload
def run(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    trade_side: TradeSide = ...,
    calendar: ExchangeCalendar | None = ...,
    flatten_eod: bool = ...,
    instrument_col: None = ...,
) -> BacktestResult: ...


# Single strategy, weights given without instrument_col → MultiBacktestResult
# (the implementation falls back to the "instrument" column in this case).
@overload
def run(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    trade_side: TradeSide = ...,
    calendar: ExchangeCalendar | None = ...,
    flatten_eod: bool = ...,
    instrument_col: None = ...,
    instrument_weights: Mapping[str, float] | pl.DataFrame,
) -> MultiBacktestResult: ...


def run(
    df: pl.DataFrame,
    strategy: Strategy,
    *,
    short_strategy: Strategy | None = None,
    trade_side: TradeSide = TradeSide.LONG,
    calendar: ExchangeCalendar | None = None,
    flatten_eod: bool = False,
    instrument_col: str | None = None,
    instrument_weights: Mapping[str, float] | pl.DataFrame | None = None,
) -> BacktestResult | MultiBacktestResult:
    """Run a vectorized backtest with fill-at-next-open semantics.

    Parameters
    ----------
    df
        Must contain ``date``, ``open``, ``close``, and any indicator
        columns referenced by the strategy.
    strategy
        Object with ``entry()`` and ``exit()`` returning Conditions.
        May optionally define ``init(df) -> pl.DataFrame`` to enrich
        the DataFrame with indicator columns before signal evaluation.
    short_strategy
        When provided, *strategy* is used as the long strategy and
        *short_strategy* as the short strategy. They are run
        independently, validated for mutual exclusivity (no overlapping
        positions), and merged into a single result.
    trade_side
        Trade direction (single-strategy mode only). Overridden by the
        entry condition's ``trade_side`` if set. Must be left at its
        default when *short_strategy* is provided; any other value
        raises ``ValueError``.
    calendar
        Exchange calendar for market-hours filtering. When provided,
        the DataFrame is filtered to market hours before signal
        computation.
    flatten_eod
        Force-close positions at each session's last bar, eliminating
        overnight exposure. Requires *calendar*.
    instrument_col
        Column name identifying the symbol/ticker in a multi-symbol
        DataFrame. When provided, returns a
        :class:`~mktlib.backtest.MultiBacktestResult` that stores
        per-symbol results for O(1) access (``result["AAPL"]``).
        Combined DataFrames with the symbol column prepended are
        available via ``.returns``, ``.trades``, ``.signals`` properties.
    instrument_weights
        Optional portfolio weights for multi-instrument runs. Accepts
        either a ``Mapping[str, float]`` (``{"TQQQ": 0.5, "AAPL": 0.1}``)
        or a ``pl.DataFrame`` with columns ``(instrument, weight)``.
        When supplied, :attr:`MultiBacktestResult.returns` is the weighted
        portfolio time series (``(date, return)``) instead of the
        per-symbol concatenation. Supplying weights always selects the
        multi-instrument path: if *instrument_col* is omitted it defaults
        to ``"instrument"``. See :mod:`mktlib.backtest._weights` for
        schema and renormalization semantics.

    Returns
    -------
    BacktestResult
        When neither *instrument_col* nor *instrument_weights* is given
        (default).
    MultiBacktestResult
        When *instrument_col* is set or *instrument_weights* is supplied.
        Supports ``result[symbol]`` for O(1) per-symbol access, iteration,
        and lazy-cached combined views.

    Raises
    ------
    ValueError
        If ``flatten_eod=True`` is passed without a *calendar*, or if
        *trade_side* is not ``TradeSide.LONG`` while *short_strategy* is
        provided.

    Notes
    -----
    Signal at bar *t* → market order fills at bar *t+1*'s open.

    - **Entry bar** (*t+1*): return = ``(close - open) / open``
    - **Middle bars**: return = ``close / prev_close - 1``
    - **Exit bar** (first bar where position drops to 0):
      return = ``(open - prev_close) / prev_close`` (gap to fill price only)

    When *instrument_col* is set, each symbol is backtested independently —
    indicators (e.g. rolling SMA) do not bleed across symbols. Calendar
    filtering is applied once on the full DataFrame before partitioning.

    If *instrument_weights* is omitted, aggregation stays per-symbol and
    the caller decides how to combine::

        result.returns.group_by("date").agg(pl.col("return").mean())
    """
    if flatten_eod and calendar is None:
        msg = "flatten_eod=True requires a calendar"
        raise ValueError(msg)

    if instrument_weights is not None and instrument_col is None:
        # Canonical column name for multi-instrument runs. Callers in the
        # quant-finance convention use ``instrument`` for the ticker column;
        # default to that when they've signaled multi via weights but not
        # named the column explicitly.
        instrument_col = "instrument"

    weights_df: pl.DataFrame | None = None
    if instrument_weights is not None:
        # Validate early so input errors surface before expensive compute.
        weights_df = to_portfolio_weights_df(instrument_weights)

    if short_strategy is not None:
        # Dual mode fixes the sides itself (long for *strategy*, short for
        # *short_strategy*), so a non-default trade_side is rejected.
        if trade_side is not TradeSide.LONG:
            msg = "trade_side is ignored when short_strategy is provided"
            raise ValueError(msg)
        if instrument_col is not None:
            return _run_multi(
                df,
                strategy,
                short_strategy=short_strategy,
                instrument_col=instrument_col,
                trade_side=TradeSide.LONG,
                calendar=calendar,
                flatten_eod=flatten_eod,
                weights=weights_df,
            )
        if calendar is not None:
            df = calendar.filter_market_hours(df, "date")
        return _run_dual(
            df,
            strategy,
            short_strategy,
            calendar=calendar,
            flatten_eod=flatten_eod,
        )

    if instrument_col is not None:
        return _run_multi(
            df,
            strategy,
            instrument_col=instrument_col,
            trade_side=trade_side,
            calendar=calendar,
            flatten_eod=flatten_eod,
            weights=weights_df,
        )

    # Single-symbol path: calendar filter → _run_core
    if calendar is not None:
        df = calendar.filter_market_hours(df, "date")
    return _run_core(
        df,
        strategy,
        trade_side=trade_side,
        calendar=calendar,
        flatten_eod=flatten_eod,
    )


def _extract_trades(
    signals: pl.DataFrame,
    *,
    flatten_eod: bool = False,
) -> pl.DataFrame:
    """Extract per-trade PnL from position transitions.

    Entries are the rows flagged ``_entry_clean`` and exits the rows flagged
    ``_exit_clean``; the n-th entry is paired with the n-th exit, and an
    unmatched trailing entry is dropped. ``entry_date`` / ``exit_date`` are
    the dates of those transition rows.

    Fill prices use the *next* bar's open (fill-at-next-open model), with
    two overrides for the exit price, in priority order:

    1. ``_limit_price`` when the column exists, is non-null, and the exit
       signal (``_exit``) fired on that bar — a same-bar limit fill;
    2. the session-last bar's own open when *flatten_eod* is true and the
       bar is flagged ``_session_last`` (forced session exit);
    3. otherwise the next bar's open.

    Side is extracted from ``_side`` at each entry bar.

    Parameters
    ----------
    signals
        Frame with ``date``, ``open``, ``_side``, ``_entry_clean`` and
        ``_exit_clean``; ``_session_last`` is read when *flatten_eod* is
        true, ``_limit_price`` / ``_exit`` when present.
    flatten_eod
        Whether session-last bars exit at their own open.

    Returns
    -------
    pl.DataFrame
        Columns ``entry_date``, ``exit_date``, ``side``, ``pnl`` and
        ``bars_held``. ``pnl`` is ``side * (exit_price / entry_price - 1)``.
        Despite its name, ``bars_held`` is the difference in calendar days
        between ``exit_date`` and ``entry_date``, not a bar count. An empty
        frame with the same columns is returned when there is no complete
        trade.
    """
    # Pre-compute next bar's open for fill price
    signals_with_next = signals.with_columns(
        pl.col("open").shift(-1).alias("_next_open"),
    )

    entries = signals_with_next.filter(pl.col("_entry_clean")).select(
        pl.col("date").alias("entry_date"),
        pl.col("_next_open").alias("entry_price"),
        pl.col("_side").alias("_trade_side"),
    )

    # Default exit fill: session-last bar's own open on the flatten_eod
    # path, next bar's open otherwise.
    if flatten_eod:
        base_expr = (
            pl.when(pl.col("_session_last"))
            .then(pl.col("open"))
            .otherwise(pl.col("_next_open"))
        )
    else:
        base_expr = pl.col("_next_open")

    if "_limit_price" in signals.columns:
        limit_fill = pl.col("_limit_price").is_not_null()
        if "_exit" in signals.columns:
            # _limit_price is populated on every bar, so non-null alone does
            # not mean a limit fill happened. Require the exit signal to have
            # fired on this bar; a session-forced exit then falls through to
            # base_expr instead of being priced at the limit.
            limit_fill = pl.col("_exit").fill_null(False) & limit_fill
        exit_price_expr = (
            pl.when(limit_fill)
            .then(pl.col("_limit_price"))
            .otherwise(base_expr)
            .alias("exit_price")
        )
    else:
        exit_price_expr = base_expr.alias("exit_price")

    exits = signals_with_next.filter(pl.col("_exit_clean")).select(
        pl.col("date").alias("exit_date"),
        exit_price_expr,
    )

    # Pair entries with exits by ordinal position
    n_trades = min(entries.height, exits.height)
    if n_trades == 0:
        return pl.DataFrame(
            schema={
                "entry_date": signals["date"].dtype,
                "exit_date": signals["date"].dtype,
                "side": pl.Int8,
                "pnl": pl.Float64,
                "bars_held": pl.Int64,
            }
        )

    entries = entries.head(n_trades)
    exits = exits.head(n_trades)

    trades = pl.DataFrame(
        {
            "entry_date": entries["entry_date"],
            "exit_date": exits["exit_date"],
            "side": entries["_trade_side"],
            "entry_price": entries["entry_price"],
            "exit_price": exits["exit_price"],
        }
    )

    trades = trades.with_columns(
        (pl.col("side") * (pl.col("exit_price") / pl.col("entry_price") - 1)).alias("pnl"),
        (
            (pl.col("exit_date").cast(pl.Date) - pl.col("entry_date").cast(pl.Date)).dt.total_days()
        ).alias("bars_held"),
    ).select("entry_date", "exit_date", "side", "pnl", "bars_held")

    return trades