# Source code for mktlib.backtest._conditions

from __future__ import annotations

from dataclasses import dataclass

import polars as pl

from mktlib.backtest._types import TradeSide


def _ref(b: str | float | ColExpr) -> pl.Expr:
    """Resolve a column name, literal, or ColExpr to a Polars expression."""
    if isinstance(b, ColExpr):
        return b.resolve()
    return pl.col(b) if isinstance(b, str) else pl.lit(b)


# ---------------------------------------------------------------------------
# Composable column expressions
# ---------------------------------------------------------------------------


def _coerce_cmp(v: ColExpr | str | float | int) -> ColExpr:
    """Normalize the right-hand operand of a ``ColExpr`` comparison.

    A ``ColExpr`` is returned unchanged, a string becomes ``Col`` (a column
    name), and any other value is passed through ``float()`` and wrapped in
    ``Lit``.
    """
    if isinstance(v, ColExpr):
        return v
    return Col(v) if isinstance(v, str) else Lit(float(v))


class ColExpr:
    """Base class for composable column expressions.

    A ``ColExpr`` is a small expression tree that becomes a ``pl.Expr`` via
    :meth:`resolve`. Arithmetic operators (``+ - * / %`` and unary ``-``)
    build ``_BinOp`` nodes, converting non-``ColExpr`` operands with
    ``_coerce``. Comparison operators (``> >= < <=``) build ``ValueGT`` /
    ``ValueGTE`` / ``ValueLT`` / ``ValueLTE`` conditions, converting the
    right-hand side with ``_coerce_cmp`` (so a string means a column name).
    """

    __slots__ = ()

    def resolve(self) -> pl.Expr:
        """Build the Polars expression for this node (subclasses override)."""
        raise NotImplementedError

    def _arith(
        self, other: ColExpr | float | int, op: str, *, reflected: bool = False
    ) -> _BinOp:
        """Combine ``self`` and the coerced ``other`` into a ``_BinOp``.

        With ``reflected=True`` the node is ``other <op> self`` (used by the
        ``__r*__`` methods); otherwise it is ``self <op> other``.
        """
        operand = _coerce(other)
        if reflected:
            return _BinOp(operand, self, op)
        return _BinOp(self, operand, op)

    # --- arithmetic operators ---

    def __add__(self, other: ColExpr | float | int) -> _BinOp:
        return self._arith(other, "+")

    def __radd__(self, other: float | int) -> _BinOp:
        return self._arith(other, "+", reflected=True)

    def __sub__(self, other: ColExpr | float | int) -> _BinOp:
        return self._arith(other, "-")

    def __rsub__(self, other: float | int) -> _BinOp:
        return self._arith(other, "-", reflected=True)

    def __mul__(self, other: ColExpr | float | int) -> _BinOp:
        return self._arith(other, "*")

    def __rmul__(self, other: float | int) -> _BinOp:
        return self._arith(other, "*", reflected=True)

    def __truediv__(self, other: ColExpr | float | int) -> _BinOp:
        return self._arith(other, "/")

    def __rtruediv__(self, other: float | int) -> _BinOp:
        return self._arith(other, "/", reflected=True)

    def __mod__(self, other: ColExpr | float | int) -> _BinOp:
        return self._arith(other, "%")

    def __rmod__(self, other: float | int) -> _BinOp:
        return self._arith(other, "%", reflected=True)

    def __neg__(self) -> _BinOp:
        # Unary minus is represented as ``0.0 - self``.
        return _BinOp(Lit(0.0), self, "-")

    # --- comparison operators (return Conditions) ---

    def __gt__(self, other: ColExpr | str | float | int) -> ValueGT:
        return ValueGT(self, _coerce_cmp(other))

    def __ge__(self, other: ColExpr | str | float | int) -> ValueGTE:
        return ValueGTE(self, _coerce_cmp(other))

    def __lt__(self, other: ColExpr | str | float | int) -> ValueLT:
        return ValueLT(self, _coerce_cmp(other))

    def __le__(self, other: ColExpr | str | float | int) -> ValueLTE:
        return ValueLTE(self, _coerce_cmp(other))


# Backward-compat alias
PriceExpr = ColExpr
@dataclass(frozen=True, slots=True)
class Col(ColExpr):
    """Reference to a column by name.

    ``Col("close").resolve()`` is ``pl.col("close")``.
    """

    name: str

    def resolve(self) -> pl.Expr:
        column_name = self.name
        return pl.col(column_name)
@dataclass(frozen=True, slots=True)
class Lit(ColExpr):
    """Constant value; ``Lit(v).resolve()`` is ``pl.lit(v)``."""

    value: float

    def resolve(self) -> pl.Expr:
        constant = self.value
        return pl.lit(constant)
@dataclass(frozen=True, slots=True) class _BinOp(ColExpr): """Binary arithmetic node (internal).""" left: ColExpr right: ColExpr op: str def resolve(self) -> pl.Expr: l = self.left.resolve() r = self.right.resolve() match self.op: case "+": return l + r case "-": return l - r case "*": return l * r case "/": return l / r case "%": return l % r case _: # pragma: no cover msg = f"Unknown op: {self.op}" raise ValueError(msg) def _coerce(v: ColExpr | float | int) -> ColExpr: """Wrap a raw numeric into ``Lit`` if needed.""" if isinstance(v, ColExpr): return v return Lit(float(v)) def _coerce_base(v: ColExpr | str | float) -> ColExpr: """Coerce a base price to ColExpr: str -> Col, float -> Lit.""" if isinstance(v, ColExpr): return v if isinstance(v, str): return Col(v) return Lit(float(v))
@dataclass(frozen=True, slots=True)
class Pct(ColExpr):
    """``base`` shifted by ``pct`` percent.

    A positive ``pct`` lies above the base and a negative one below it:

    - ``Pct("close", 1.0)``  -> ``close * 1.01``  (1% above)
    - ``Pct("close", -0.5)`` -> ``close * 0.995`` (0.5% below)

    ``base`` may be a column name, a constant, or another ``ColExpr``; it is
    converted with ``_coerce_base`` before resolving.
    """

    base: ColExpr | str | float
    pct: float

    def resolve(self) -> pl.Expr:
        base_expr = _coerce_base(self.base).resolve()
        multiplier = 1.0 + self.pct / 100.0
        return base_expr * multiplier
@dataclass(frozen=True, slots=True)
class EntryRef(ColExpr):
    """Value of ``col`` captured at the entry-signal bar, forward-filled.

    Resolves to the column ``_entry_{col}``: ``EntryRef("close")`` becomes
    ``pl.col("_entry_close")``. The engine creates these ``_entry_{col}``
    columns automatically when it finds ``EntryRef`` nodes in the exit
    condition tree.
    """

    col: str

    def resolve(self) -> pl.Expr:
        snapshot_column = f"_entry_{self.col}"
        return pl.col(snapshot_column)
class Condition:
    """Base class for signal conditions.

    Subclasses implement :meth:`resolve` to return a boolean ``pl.Expr``.
    Conditions compose with ``&`` (``All``), ``|`` (``Any_``) and ``~``
    (``Not``).
    """

    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        """Build the boolean Polars expression (subclasses override)."""
        raise NotImplementedError

    def __and__(self, other: Condition) -> All:
        combined = All(self, other)
        return combined

    def __or__(self, other: Condition) -> Any_:
        combined = Any_(self, other)
        return combined

    def __invert__(self) -> Not:
        negated = Not(self)
        return negated
@dataclass(frozen=True, slots=True)
class Crossover(Condition):
    """``a`` crosses above ``b``.

    True on a bar where ``a > b`` while ``a <= b`` held on the previous bar.
    ``a`` is a column name or ``ColExpr``; ``b`` is a column name, a
    constant, or a ``ColExpr``. Never true on the first bar, because the
    shifted comparison is null there.
    """

    a: str | ColExpr
    b: str | float | ColExpr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        lhs = _ref(self.a)
        rhs = _ref(self.b)
        # Shift the whole comparison rather than each operand. This yields
        # ``a[t-1] <= b[t-1]`` for every kind of ``b`` (column, constant, or
        # ColExpr tree such as ``Pct``) without special-casing literals, and
        # a ColExpr ``b`` is never compared unshifted against the previous ``a``.
        return (lhs > rhs) & (lhs <= rhs).shift(1)
@dataclass(frozen=True, slots=True)
class Crossunder(Condition):
    """``a`` crosses below ``b``.

    True on a bar where ``a < b`` while ``a >= b`` held on the previous bar.
    ``a`` is a column name or ``ColExpr``; ``b`` is a column name, a
    constant, or a ``ColExpr``. Never true on the first bar, because the
    shifted comparison is null there.
    """

    a: str | ColExpr
    b: str | float | ColExpr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        lhs = _ref(self.a)
        rhs = _ref(self.b)
        # Shift the whole comparison rather than each operand. This yields
        # ``a[t-1] >= b[t-1]`` for every kind of ``b`` (column, constant, or
        # ColExpr tree such as ``Pct``) without special-casing literals, and
        # a ColExpr ``b`` is never compared unshifted against the previous ``a``.
        return (lhs < rhs) & (lhs >= rhs).shift(1)
@dataclass(frozen=True, slots=True)
class ValueGT(Condition):
    """True where ``a > b``.

    Each operand may be a column name, a constant, or a ``ColExpr``; both
    are converted with ``_ref``.
    """

    a: str | ColExpr
    b: str | float | ColExpr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        lhs, rhs = _ref(self.a), _ref(self.b)
        return lhs > rhs
@dataclass(frozen=True, slots=True)
class ValueGTE(Condition):
    """True where ``a >= b``.

    Each operand may be a column name, a constant, or a ``ColExpr``; both
    are converted with ``_ref``.
    """

    a: str | ColExpr
    b: str | float | ColExpr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        lhs, rhs = _ref(self.a), _ref(self.b)
        return lhs >= rhs
@dataclass(frozen=True, slots=True)
class ValueLT(Condition):
    """True where ``a < b``.

    Each operand may be a column name, a constant, or a ``ColExpr``; both
    are converted with ``_ref``.
    """

    a: str | ColExpr
    b: str | float | ColExpr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        lhs, rhs = _ref(self.a), _ref(self.b)
        return lhs < rhs
@dataclass(frozen=True, slots=True)
class ValueLTE(Condition):
    """True where ``a <= b``.

    Each operand may be a column name, a constant, or a ``ColExpr``; both
    are converted with ``_ref``.
    """

    a: str | ColExpr
    b: str | float | ColExpr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        lhs, rhs = _ref(self.a), _ref(self.b)
        return lhs <= rhs
# Backward-compat aliases for the older price-specific condition names.
PriceIsAbove, PriceIsBelow = ValueGT, ValueLT
@dataclass(frozen=True, slots=True)
class IsRising(Condition):
    """Value is greater than its value ``period`` bars ago.

    ``col`` is a column name or a ``ColExpr``. ``period`` must be at least 1.
    ``0`` would compare each bar with itself (never true), and a negative
    shift would compare against a *future* bar (look-ahead bias).

    Raises:
        ValueError: if ``period < 1``.
    """

    col: str | ColExpr
    period: int = 1
    trade_side: TradeSide | None = None

    def __post_init__(self) -> None:
        if self.period < 1:
            msg = f"IsRising period must be >= 1, got {self.period}"
            raise ValueError(msg)

    def resolve(self) -> pl.Expr:
        value = _ref(self.col)
        return value > value.shift(self.period)
@dataclass(frozen=True, slots=True)
class IsFalling(Condition):
    """Value is less than its value ``period`` bars ago.

    ``col`` is a column name or a ``ColExpr``. ``period`` must be at least 1.
    ``0`` would compare each bar with itself (never true), and a negative
    shift would compare against a *future* bar (look-ahead bias).

    Raises:
        ValueError: if ``period < 1``.
    """

    col: str | ColExpr
    period: int = 1
    trade_side: TradeSide | None = None

    def __post_init__(self) -> None:
        if self.period < 1:
            msg = f"IsFalling period must be >= 1, got {self.period}"
            raise ValueError(msg)

    def resolve(self) -> pl.Expr:
        value = _ref(self.col)
        return value < value.shift(self.period)
@dataclass(frozen=True, slots=True)
class Custom(Condition):
    """Wrap an arbitrary Polars expression as a condition.

    :meth:`resolve` returns ``expr`` unchanged. The caller must make sure it
    evaluates to a boolean column; this class does not check it.
    """

    expr: pl.Expr
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        user_expr = self.expr
        return user_expr
# --- Combinators ---
@dataclass(frozen=True, slots=True)
class All(Condition):
    """Logical AND of two conditions; produced by ``cond_a & cond_b``."""

    left: Condition
    right: Condition
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        first = self.left.resolve()
        second = self.right.resolve()
        return first & second
@dataclass(frozen=True, slots=True)
class Any_(Condition):
    """Logical OR of two conditions; produced by ``cond_a | cond_b``."""

    left: Condition
    right: Condition
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        first = self.left.resolve()
        second = self.right.resolve()
        return first | second
@dataclass(frozen=True, slots=True)
class Not(Condition):
    """Negation of a condition; produced by ``~cond``."""

    inner: Condition
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        wrapped = self.inner.resolve()
        return ~wrapped
@dataclass(frozen=True, slots=True)
class Limit(Condition):
    """Exit condition that fills on the signal bar at a known limit price.

    Wraps an exit condition (typically ``ValueGT`` / ``ValueGTE`` /
    ``ValueLT`` / ``ValueLTE``). When ``inner`` fires on bar ``t``, the
    position exits on that same bar at ``price``. Without ``Limit``, mktlib
    fills at the next bar's open by default. Intended for take-profit /
    stop-loss strategies where the fill price is known in advance.

    If ``price`` is omitted, the fill price is taken from the right-hand side
    of the wrapped comparison. This matches the usual TP/SL idiom
    ``high >= TP`` -> fill at ``TP``. Pass ``price`` explicitly for trailing
    stops or a decoupled trigger/fill, e.g. ``price=Col("trailing_stop")``
    with the comparison also made against that column.

    v1 scope: same-bar semantics are honored only at the top level of
    ``exit_cond``. Nested inside ``All`` / ``Any_`` / ``Not``, a ``Limit``
    acts as a plain boolean condition. Bracket patterns (``Any_(TP, SL)``)
    are planned for a later release.
    """

    inner: Condition
    price: ColExpr | str | float | int | None = None
    trade_side: TradeSide | None = None

    def resolve(self) -> pl.Expr:
        """Trigger expression: the wrapped condition, unchanged."""
        return self.inner.resolve()

    def resolve_price(self) -> pl.Expr:
        """Return the fill-price expression.

        Uses ``price`` when given. Otherwise it uses the ``b`` operand of the
        wrapped ``ValueGT`` / ``ValueGTE`` / ``ValueLT`` / ``ValueLTE``.

        Raises:
            ValueError: if ``price`` is None and ``inner`` is not one of
                those four comparison classes.
        """
        explicit_price = self.price
        if explicit_price is not None:
            return _ref(explicit_price)

        comparison_types = (ValueGT, ValueGTE, ValueLT, ValueLTE)
        if not isinstance(self.inner, comparison_types):
            msg = (
                f"Limit auto-extract requires inner to be "
                f"ValueGT/GTE/LT/LTE; got {type(self.inner).__name__}. "
                "Pass price= explicitly."
            )
            raise ValueError(msg)
        return _ref(self.inner.b)