"""Expression classes for LumiX optimization models."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Tuple, TypeVar
from typing_extensions import Self
from .variables import LXVariable
TModel = TypeVar("TModel")


@dataclass
class LXLinearExpression(Generic[TModel]):
    """
    Type-safe linear expression builder with multi-model support.

    Represents: sum(coeff[i] * var[i]) + constant

    Examples::

        expr = LXLinearExpression()
        expr.add_term(production, 1.0)
        expr.add_term(inventory, -1.0)
        expr.add_constant(100)

        # Multi-model
        expr = LXLinearExpression()
        expr.sum_over(duty, where=lambda driver, date: date.is_weekend)

    Note:
        ``+``, ``*`` and the ``add_*`` methods update the expression in place
        and return it for chaining. Call :meth:`copy` first if the original
        must be preserved.
    """

    # Single-model terms keyed by variable name:
    # name -> (variable, coefficient function, filter function)
    terms: Dict[
        str,
        Tuple[LXVariable, Callable[[TModel], float], Callable[[TModel], bool]],
    ] = field(default_factory=dict)
    constant: float = 0.0
    # Multi-model terms: (variable, coefficient function, optional filter).
    # Per add_multi_term, the callables receive all dimension models positionally.
    _multi_terms: List[
        Tuple[LXVariable, Callable[..., float], Optional[Callable[..., bool]]]
    ] = field(default_factory=list)

    @staticmethod
    def _is_scalar(value: Any) -> bool:
        """Return True for real numeric operands (int, float, numpy scalars, ...)."""
        import numbers

        return isinstance(value, numbers.Real)

    def __deepcopy__(self, memo):
        """Custom deepcopy that handles variables and lambda functions.

        This method enables what-if analysis on expressions by:
        1. Deep copying all variables in the expression
        2. Safely copying coefficient and filter lambda functions
        3. Preserving the expression structure

        Args:
            memo: Dictionary for tracking circular references during deepcopy

        Returns:
            Deep copy of this expression with all dependencies resolved
        """
        from copy import deepcopy

        from ..utils.copy_utils import copy_function_detaching_closure

        def _copy_func(func):
            # Filters may be None (e.g. multi-terms added without `where`).
            if func is None:
                return None
            return copy_function_detaching_closure(func, memo)

        cls = self.__class__
        result = cls.__new__(cls)
        # Register before recursing so references back to this expression resolve
        memo[id(self)] = result
        result.constant = self.constant
        # Each term: (LXVariable, coeff_func, where_func)
        result.terms = {
            var_name: (deepcopy(var, memo), _copy_func(coeff), _copy_func(where))
            for var_name, (var, coeff, where) in self.terms.items()
        }
        result._multi_terms = [
            (deepcopy(var, memo), _copy_func(coeff), _copy_func(where))
            for var, coeff, where in self._multi_terms
        ]
        return result

    def add_term(
        self,
        var: LXVariable[TModel, Any],
        coeff: float | Callable[[TModel], float] = 1.0,
        where: Callable[[TModel], bool] | None = None,
    ) -> Self:
        """
        Add term with coefficient (constant or function).

        A term for a variable with the same name replaces any existing one.

        Args:
            var: Variable family to add
            coeff: Coefficient (constant or function of the model instance)
            where: Optional filter function; defaults to "always include"

        Returns:
            Self for chaining
        """
        coeff_func = coeff if callable(coeff) else lambda _: coeff
        if where is None:
            where = lambda _: True
        self.terms[var.name] = (var, coeff_func, where)
        return self

    def add_multi_term(
        self,
        var: LXVariable,
        coeff: Callable[..., float] = lambda *args: 1.0,
        where: Optional[Callable[..., bool]] = None,
    ) -> Self:
        """
        Add multi-indexed variable to expression.

        Args:
            var: Multi-indexed variable family
            coeff: Coefficient function receiving all dimension models
            where: Optional filter function

        Returns:
            Self for chaining

        Example::

            expr.add_multi_term(
                duty,
                coeff=lambda driver, date: driver.cost * date.multiplier,
                where=lambda driver, date: date.is_weekend
            )
        """
        self._multi_terms.append((var, coeff, where))
        return self

    def sum_over(
        self,
        var: LXVariable,
        where: Optional[Callable[..., bool]] = None,
    ) -> Self:
        """
        Syntactic sugar for summing over all dimensions of a variable.

        Args:
            var: Variable to sum over all its dimensions
            where: Optional filter function to selectively include terms

        Returns:
            Self for chaining

        Example::

            # Sum all driver duties (over all drivers and dates)
            expr.sum_over(duty)

            # Sum duties for all drivers on a specific date
            expr.sum_over(duty, where=lambda d, dt: dt == specific_date)

        Note:
            Currently sums over all dimensions with unit coefficients. Future
            enhancement could add selective dimension summing.
        """
        return self.add_multi_term(var, where=where)

    def add_constant(self, value: float) -> Self:
        """
        Add constant to expression.

        Args:
            value: Constant value

        Returns:
            Self for chaining
        """
        self.constant += value
        return self

    def __add__(self, other: float | Self) -> Self:
        """
        Enable ``expr1 + expr2`` or ``expr + constant``.

        The left operand is updated in place and returned. If the same
        single-model variable appears in both expressions, the two terms are
        merged. The merged filter is ``where1 or where2``. The merged
        coefficient adds each side's coefficient only where that side's filter
        holds, so the sum of both original terms is preserved exactly.

        Args:
            other: Expression or numeric constant to add

        Returns:
            Self for chaining, or ``NotImplemented`` for unsupported operands
        """
        if self._is_scalar(other):
            self.constant += other
            return self
        if not isinstance(other, LXLinearExpression):
            return NotImplemented

        for var_name, term in other.terms.items():
            if var_name not in self.terms:
                self.terms[var_name] = term
                continue
            var1, coeff1, where1 = self.terms[var_name]
            _var2, coeff2, where2 = term

            # Gate each coefficient by its own filter. An AND of the filters
            # would drop instances where only one side applies.
            def combined_coeff(m, cf1=coeff1, w1=where1, cf2=coeff2, w2=where2):
                total = 0.0
                if w1(m):
                    total += cf1(m)
                if w2(m):
                    total += cf2(m)
                return total

            def combined_where(m, w1=where1, w2=where2):
                return bool(w1(m) or w2(m))

            self.terms[var_name] = (var1, combined_coeff, combined_where)

        self.constant += other.constant
        # Multi-terms are appended; the same variable may appear more than once
        self._multi_terms.extend(other._multi_terms)
        return self

    def __radd__(self, other: float) -> Self:
        """Enable ``constant + expr`` (and therefore ``sum(expressions)``)."""
        if self._is_scalar(other):
            return self.__add__(other)
        return NotImplemented

    def __mul__(self, scalar: float) -> Self:
        """
        Enable ``expression * scalar`` and ``scalar * expression``.

        Every single-model and multi-model coefficient is scaled, filters are
        kept unchanged, and the constant is scaled. The expression is updated
        in place.

        Args:
            scalar: Numeric multiplier

        Returns:
            Self for chaining, or ``NotImplemented`` for non-numeric operands
        """
        if not self._is_scalar(scalar):
            return NotImplemented
        for var_name, (var, coeff, where) in list(self.terms.items()):
            self.terms[var_name] = (
                var,
                lambda m, _c=coeff, _s=scalar: _c(m) * _s,
                where,
            )
        # Slice-assign to keep the same list object
        self._multi_terms[:] = [
            (var, lambda *idx, _c=coeff, _s=scalar: _c(*idx) * _s, where)
            for var, coeff, where in self._multi_terms
        ]
        self.constant *= scalar
        return self

    __rmul__ = __mul__

    def copy(self) -> Self:
        """
        Create a copy of this expression with independent containers.

        The ``terms`` dict and ``_multi_terms`` list are new objects, so adding
        terms to the copy does not affect the original. Variables and the
        coefficient/filter callables are shared, not duplicated; use
        ``copy.deepcopy`` for that.

        Returns:
            New expression with the same terms and constant
        """
        new_expr = type(self)()
        new_expr.terms = dict(self.terms)
        new_expr.constant = self.constant
        new_expr._multi_terms = list(self._multi_terms)
        return new_expr
@dataclass
class LXQuadraticTerm:
    """
    A single quadratic term: ``coefficient * var1 * var2``.

    Used in portfolio optimization, risk modeling, etc. When both factors
    share a variable name, the term is a square (``coefficient * x^2``).
    """

    var1: LXVariable
    var2: LXVariable
    coefficient: float = 1.0

    def __deepcopy__(self, memo):
        """Deep-copy both variables while keeping the numeric coefficient.

        Args:
            memo: Dictionary for tracking circular references during deepcopy

        Returns:
            Deep copy of this quadratic term
        """
        from copy import deepcopy

        klass = self.__class__
        clone = klass.__new__(klass)
        # Register first so cycles that lead back here reuse the clone
        memo[id(self)] = clone
        clone.var1, clone.var2 = deepcopy(self.var1, memo), deepcopy(self.var2, memo)
        clone.coefficient = self.coefficient
        return clone

    def is_squared_term(self) -> bool:
        """Return True when both factors have the same variable name (x^2)."""
        first_name, second_name = self.var1.name, self.var2.name
        return first_name == second_name
@dataclass
class LXQuadraticExpression:
    """
    Quadratic expression: linear_terms + quadratic_terms + constant.

    Each entry of ``quadratic_terms`` contributes ``coefficient * var1 * var2``,
    ``linear_terms`` holds the linear part, and ``constant`` is an offset.
    This corresponds to the usual ``0.5 * x^T Q x + c^T x + constant`` form.
    NOTE(review): ``add_quadratic`` stores the coefficient exactly as given,
    so whether the 0.5 convention applies depends on how the solver backend
    reads ``quadratic_terms``. Confirm there.

    Example::

        # Portfolio variance: sum(w[i] * w[j] * cov[i,j])
        # Plus linear returns: sum(return[i] * w[i])
        quad_expr = LXQuadraticExpression()
        quad_expr.add_quadratic(w[0], w[1], cov[0, 1])
        quad_expr.linear_terms.add_term(w[0], returns[0])
    """

    linear_terms: LXLinearExpression = field(default_factory=LXLinearExpression)
    quadratic_terms: List[LXQuadraticTerm] = field(default_factory=list)
    constant: float = 0.0

    def __deepcopy__(self, memo):
        """Custom deepcopy that handles linear and quadratic terms.

        This method enables what-if analysis on quadratic expressions by:
        1. Deep copying the linear expression component
        2. Deep copying all quadratic terms
        3. Preserving the expression structure

        Args:
            memo: Dictionary for tracking circular references during deepcopy

        Returns:
            Deep copy of this quadratic expression with all dependencies resolved
        """
        from copy import deepcopy

        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        result.linear_terms = deepcopy(self.linear_terms, memo)
        result.quadratic_terms = [deepcopy(term, memo) for term in self.quadratic_terms]
        result.constant = self.constant
        return result

    def add_quadratic(self, var1: LXVariable, var2: LXVariable, coeff: float = 1.0) -> Self:
        """
        Add quadratic term ``coeff * var1 * var2``.

        Args:
            var1: First variable
            var2: Second variable
            coeff: Coefficient

        Returns:
            Self for chaining
        """
        self.quadratic_terms.append(LXQuadraticTerm(var1, var2, coeff))
        return self

    def add_squared(self, var: LXVariable, coeff: float = 1.0) -> Self:
        """
        Add ``coeff * var^2`` term.

        Args:
            var: Variable to square
            coeff: Coefficient

        Returns:
            Self for chaining
        """
        return self.add_quadratic(var, var, coeff)

    def __add__(self, other: LXQuadraticExpression | LXLinearExpression | float) -> Self:
        """
        Enable ``quad + quad``, ``quad + linear_expr`` or ``quad + constant``.

        The left operand is updated in place and returned. Unsupported operand
        types now produce ``NotImplemented``, so Python raises TypeError. The
        previous behavior ignored them silently, which dropped the other
        operand's terms.

        Args:
            other: Quadratic expression, linear expression, or numeric constant

        Returns:
            Self for chaining, or ``NotImplemented`` for unsupported operands
        """
        import numbers

        if isinstance(other, LXQuadraticExpression):
            self.linear_terms = self.linear_terms + other.linear_terms
            # Snapshot first so that ``q + q`` reads a stable list
            self.quadratic_terms.extend(list(other.quadratic_terms))
            self.constant += other.constant
        elif isinstance(other, LXLinearExpression):
            self.linear_terms = self.linear_terms + other
        elif isinstance(other, numbers.Real):
            self.constant += other
        else:
            return NotImplemented
        return self

    def __radd__(self, other: LXLinearExpression | float) -> Self:
        """Enable ``constant + quad`` (so ``sum()`` works) and ``linear + quad``.

        Addition is commutative, so this delegates to ``__add__``. Note that
        the quadratic expression (the right operand) is the one updated in place.
        """
        return self.__add__(other)
@dataclass
class LXNonLinearExpression:
    """
    Non-linear expression containing arbitrary non-linear terms.

    Supports:
    - Bilinear terms (x * y)
    - Absolute value (|x|)
    - Min/max functions
    - Piecewise-linear approximations
    - Conditional expressions
    - Custom non-linear functions

    These terms are intended to be linearized by the Linearizer engine.

    Example::

        # Create nonlinear expression
        expr = LXNonLinearExpression()
        # Add bilinear product
        expr.add_product(length, width)
        # Add absolute value
        expr.add_abs(deviation)
        # Add piecewise function
        expr.add_piecewise(time, lambda t: math.exp(t), num_segments=30)
    """

    linear_terms: LXLinearExpression = field(default_factory=LXLinearExpression)
    # Term objects from ..nonlinear.terms (or any custom term object)
    nonlinear_terms: List[Any] = field(default_factory=list)
    constant: float = 0.0

    def add_linear(self, expr: LXLinearExpression) -> Self:
        """
        Add linear terms by merging ``expr`` into ``linear_terms`` with ``+``.

        Args:
            expr: Linear expression to add

        Returns:
            Self for chaining
        """
        self.linear_terms = self.linear_terms + expr
        return self

    def add_abs(self, var: LXVariable, coeff: float = 1.0) -> Self:
        """
        Add absolute value term: coeff * |var|

        Args:
            var: Variable to take absolute value of
            coeff: Coefficient (default: 1.0)

        Returns:
            Self for chaining

        Example::

            # Minimize absolute deviation
            expr.add_abs(deviation)
        """
        from ..nonlinear.terms import LXAbsoluteTerm

        self.nonlinear_terms.append(LXAbsoluteTerm(var, coeff))
        return self

    def _add_min_max(
        self,
        kind: Literal["min", "max"],
        variables: Tuple[LXVariable, ...],
        coefficients: Optional[List[float]],
    ) -> Self:
        """Validate the inputs and append an LXMinMaxTerm of the given kind."""
        from ..nonlinear.terms import LXMinMaxTerm

        if not variables:
            raise ValueError(f"add_{kind} requires at least one variable")
        # None or empty coefficients mean "all 1.0", as before
        if not coefficients:
            coeffs = [1.0] * len(variables)
        else:
            coeffs = list(coefficients)
            if len(coeffs) != len(variables):
                raise ValueError(
                    f"add_{kind}: got {len(coeffs)} coefficients for "
                    f"{len(variables)} variables"
                )
        self.nonlinear_terms.append(LXMinMaxTerm(list(variables), kind, coeffs))
        return self

    def add_min(self, *variables: LXVariable, coefficients: Optional[List[float]] = None) -> Self:
        """
        Add minimum function: min(variables)

        Args:
            *variables: Variables to take minimum of
            coefficients: Optional coefficients, one per variable (default: all 1.0)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no variables are given or the coefficient count
                does not match the variable count

        Example::

            # Minimum of three costs
            expr.add_min(cost_a, cost_b, cost_c)
        """
        return self._add_min_max("min", variables, coefficients)

    def add_max(self, *variables: LXVariable, coefficients: Optional[List[float]] = None) -> Self:
        """
        Add maximum function: max(variables)

        Args:
            *variables: Variables to take maximum of
            coefficients: Optional coefficients, one per variable (default: all 1.0)

        Returns:
            Self for chaining

        Raises:
            ValueError: If no variables are given or the coefficient count
                does not match the variable count

        Example::

            # Maximum capacity
            expr.add_max(capacity_1, capacity_2, capacity_3)
        """
        return self._add_min_max("max", variables, coefficients)

    def add_product(self, var1: LXVariable, var2: LXVariable, coeff: float = 1.0) -> Self:
        """
        Add bilinear product: coeff * var1 * var2

        Automatically linearized based on variable types:
        - Binary × Binary: AND logic
        - Binary × Continuous: Big-M method
        - Continuous × Continuous: McCormick envelopes

        Args:
            var1: First variable
            var2: Second variable
            coeff: Coefficient (default: 1.0)

        Returns:
            Self for chaining

        Example::

            # Rectangle area
            expr.add_product(length, width)
            # Facility open × flow amount
            expr.add_product(is_open, flow_amount)
        """
        from ..nonlinear.terms import LXBilinearTerm

        self.nonlinear_terms.append(LXBilinearTerm(var1, var2, coeff))
        return self

    def add_indicator(
        self, binary_var: LXVariable, condition: bool, linear_expr: LXLinearExpression
    ) -> Self:
        """
        Add conditional constraint: if binary_var == condition then linear_expr

        Args:
            binary_var: Binary variable
            condition: Condition value (True or False)
            linear_expr: Expression to apply when condition is met

        Returns:
            Self for chaining

        Example::

            # If warehouse is open, then demand must be met
            expr.add_indicator(
                is_open,
                True,
                LXLinearExpression().add_term(supply, 1.0)
            )
        """
        from ..nonlinear.terms import LXIndicatorTerm

        self.nonlinear_terms.append(LXIndicatorTerm(binary_var, condition, linear_expr))
        return self

    def add_piecewise(
        self,
        var: LXVariable,
        func: Callable[[float], float],
        num_segments: int = 20,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        adaptive: bool = True,
        method: Literal["sos2", "incremental", "logarithmic"] = "sos2",
    ) -> Self:
        """
        Add piecewise-linear approximation of arbitrary function.

        Args:
            var: Input variable
            func: Function to approximate (e.g., lambda x: math.exp(x))
            num_segments: Number of linear segments
            x_min: Minimum domain value (default: var.lower_bound)
            x_max: Maximum domain value (default: var.upper_bound)
            adaptive: Use adaptive breakpoint generation
            method: Linearization method ("sos2", "incremental", "logarithmic")

        Returns:
            Self for chaining

        Example::

            # Exponential growth
            expr.add_piecewise(time, lambda t: math.exp(t), num_segments=30)

            # Custom discount curve
            expr.add_piecewise(
                quantity,
                lambda q: 1.0 if q < 100 else 0.9 if q < 1000 else 0.8,
                num_segments=50
            )
        """
        from ..nonlinear.terms import LXPiecewiseLinearTerm

        self.nonlinear_terms.append(
            LXPiecewiseLinearTerm(var, func, num_segments, x_min, x_max, adaptive, method)
        )
        return self

    def add_nonlinear_term(self, term: Any) -> Self:
        """
        Add pre-constructed non-linear term.

        Args:
            term: Non-linear term object

        Returns:
            Self for chaining
        """
        self.nonlinear_terms.append(term)
        return self

    def add_nonlinear_terms(self, terms: List[Any]) -> Self:
        """
        Add multiple non-linear terms.

        Args:
            terms: List of non-linear terms

        Returns:
            Self for chaining
        """
        self.nonlinear_terms.extend(terms)
        return self
# Public names exported by ``from <this module> import *``.
__all__ = [
    "LXLinearExpression",
    "LXQuadraticTerm",
    "LXQuadraticExpression",
    "LXNonLinearExpression",
]