checking system…
Docs / back / src/maf/triggers/safe_eval.py · line 71
Python · 145 lines
  1"""Restricted Python-expression evaluator for trigger ``when:`` rules.
  2
  3What's allowed
  4--------------
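* Names, resolved from the supplied context dict only (unknown names → ``None``)
* Attribute access on dict values (``payload.symbol`` means ``payload.get("symbol")``)
* Subscript (``payload["tickers"][0]``); a missing key, out-of-range index or
  type mismatch yields ``None``
* Comparisons, including chains (``a < b < c``): ``==`` ``!=`` ``<`` ``>``
  ``<=`` ``>=`` ``in`` ``not in`` (a falsy container such as ``None`` counts
  as empty)
* Boolean ops: ``and`` ``or`` ``not``
* Unary ops: ``-`` ``+`` ``not``
* Arithmetic ops: ``+`` ``-`` ``*`` ``/`` ``%`` ``//``
* String / number / bool / None literals
* Tuple / list literals (both evaluate to a list)
* The single safe builtin :func:`abs`, as ``abs(x)``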
 15
 16What isn't
 17----------
 18Function calls beyond ``abs``. Attribute access on non-dict values (so no
 19``__import__``, no ``os.system``, etc.). Lambdas, comprehensions, generator
 20expressions, walrus, f-strings — all rejected with :class:`SafeEvalError`.
 21
 22The implementation is ~80 lines of pure stdlib :mod:`ast` walking. It's
 23designed for human-readable trigger rules; if you need real DSL muscle,
 24swap it for something like ``simpleeval`` later — the interface
 25(``safe_eval(expr, ctx) -> Any``) stays the same.
 26"""
 27
 28from __future__ import annotations
 29
 30import ast
 31import operator
 32from typing import Any
 33
 34
 35class SafeEvalError(Exception):
 36    """Raised when an expression contains a disallowed construct."""
 37
 38
 39_CMP_OPS: dict[type, Any] = {
 40    ast.Eq:    operator.eq,
 41    ast.NotEq: operator.ne,
 42    ast.Lt:    operator.lt,
 43    ast.LtE:   operator.le,
 44    ast.Gt:    operator.gt,
 45    ast.GtE:   operator.ge,
 46    ast.In:    lambda a, b: a in (b or []),
 47    ast.NotIn: lambda a, b: a not in (b or []),
 48}
 49
 50_BIN_OPS: dict[type, Any] = {
 51    ast.Add:      operator.add,
 52    ast.Sub:      operator.sub,
 53    ast.Mult:     operator.mul,
 54    ast.Div:      operator.truediv,
 55    ast.FloorDiv: operator.floordiv,
 56    ast.Mod:      operator.mod,
 57}
 58
 59_UNARY_OPS: dict[type, Any] = {
 60    ast.USub: operator.neg,
 61    ast.UAdd: operator.pos,
 62    ast.Not:  operator.not_,
 63}
 64
 65_BOOL_OPS: dict[type, Any] = {
 66    ast.And: all,
 67    ast.Or:  any,
 68}
 69
 70
 71def safe_eval(expr: str, ctx: dict[str, Any]) -> Any:
 72    """Evaluate ``expr`` in the supplied ``ctx`` dict.
 73
 74    Names in the expression are resolved against ``ctx`` (and only ``ctx``).
 75    Returns the value the expression evaluates to. Raises
 76    :class:`SafeEvalError` on any disallowed construct.
 77    """
 78    try:
 79        tree = ast.parse(expr, mode="eval").body
 80    except SyntaxError as exc:
 81        raise SafeEvalError(f"syntax: {exc}") from exc
 82    return _eval(tree, ctx)
 83
 84
 85def _eval(node: Any, ctx: dict[str, Any]) -> Any:
 86    # Literals
 87    if isinstance(node, ast.Constant):
 88        return node.value
 89    # Names
 90    if isinstance(node, ast.Name):
 91        return ctx.get(node.id)
 92    # Attribute access — only on dicts, returns dict.get(attr) to keep things
 93    # boring (no descriptor weirdness).
 94    if isinstance(node, ast.Attribute):
 95        target = _eval(node.value, ctx)
 96        if isinstance(target, dict):
 97            return target.get(node.attr)
 98        # Allow ``len`` of a string/list via .__len__? No — keep it tight.
 99        raise SafeEvalError(f"attribute access on non-dict: {type(target).__name__}")
100    # Subscript — payload["tickers"][0]
101    if isinstance(node, ast.Subscript):
102        target = _eval(node.value, ctx)
103        key = _eval(node.slice, ctx)
104        try:
105            return target[key]
106        except (KeyError, IndexError, TypeError):
107            return None
108    # Comparisons (chained: a < b < c)
109    if isinstance(node, ast.Compare):
110        left = _eval(node.left, ctx)
111        for op_node, right_node in zip(node.ops, node.comparators):
112            op = _CMP_OPS.get(type(op_node))
113            if op is None:
114                raise SafeEvalError(f"unsupported comparison {type(op_node).__name__}")
115            right = _eval(right_node, ctx)
116            if not op(left, right):
117                return False
118            left = right
119        return True
120    if isinstance(node, ast.BoolOp):
121        op = _BOOL_OPS.get(type(node.op))
122        if op is None:
123            raise SafeEvalError(f"unsupported boolop {type(node.op).__name__}")
124        return op(_eval(v, ctx) for v in node.values)
125    if isinstance(node, ast.UnaryOp):
126        op = _UNARY_OPS.get(type(node.op))
127        if op is None:
128            raise SafeEvalError(f"unsupported unaryop {type(node.op).__name__}")
129        return op(_eval(node.operand, ctx))
130    if isinstance(node, ast.BinOp):
131        op = _BIN_OPS.get(type(node.op))
132        if op is None:
133            raise SafeEvalError(f"unsupported binop {type(node.op).__name__}")
134        return op(_eval(node.left, ctx), _eval(node.right, ctx))
135    if isinstance(node, (ast.Tuple, ast.List)):
136        return [_eval(e, ctx) for e in node.elts]
137    if isinstance(node, ast.Call):
138        if isinstance(node.func, ast.Name) and node.func.id == "abs" and len(node.args) == 1:
139            return abs(_eval(node.args[0], ctx))
140        raise SafeEvalError("only abs(x) calls are allowed")
141    raise SafeEvalError(f"unsupported node: {type(node).__name__}")
142
143
144__all__ = ["SafeEvalError", "safe_eval"]