1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
| from collections.abc import Callable from typing import Generic, TypeVar, Union
S = TypeVar("S")
F = TypeVar("F")


class LazySegtree(Generic[S, F]):
    """Lazy-propagation segment tree over values ``S`` acted on by maps ``F``.

    Positions are 0-based and ranges are half-open ``[l, r)``. Supported
    operations: point assignment (``set``), point lookup (``get``), applying
    a map to every element of a range (``apply``), folding a range with
    ``op`` (``prod``) and folding the whole array (``all_prod``). Each
    operation walks O(log N) nodes.

    Storage uses the recursive "4N" layout: node 1 is the root covering
    ``[0, N - 1]``; node ``i`` covering ``[tl, tr]`` has children ``2i``
    (``[tl, tm]``) and ``2i + 1`` (``[tm + 1, tr]``), ``tm = (tl + tr) // 2``.
    ``_d[i]`` is the fold of node ``i``'s range with every map applied to it
    so far; ``_lz[i]`` is the map still owed to node ``i``'s children.

    The algorithm assumes (but does not check) the usual lazy-segtree laws:
    ``op`` is associative with identity ``e()``; ``mapping(id(), x) == x``;
    ``mapping(f, op(a, b)) == op(mapping(f, a), mapping(f, b))``; and
    ``mapping(composition(f, g), x) == mapping(f, mapping(g, x))``.

    Args:
        arr_or_N: Initial element values, or a length ``N`` for a tree whose
            elements all start as ``e()``.
        op: Binary fold operation on ``S``.
        e: Zero-argument factory for the identity of ``op``.
        mapping: ``mapping(f, x)`` returns ``x`` with map ``f`` applied.
        composition: ``composition(f, g)`` returns the map "apply ``g``,
            then ``f``" -- ``f`` is always the newer map.
        id: Zero-argument factory for the identity map. (Shadows the builtin
            ``id``; the name is kept for keyword-argument compatibility.)

    Raises:
        ValueError: If ``arr_or_N`` is a negative integer.
    """

    def __init__(
        self,
        arr_or_N: Union[list[S], int],
        op: Callable[[S, S], S],
        e: Callable[[], S],
        mapping: Callable[[F, S], S],
        composition: Callable[[F, F], F],
        id: Callable[[], F],
    ) -> None:
        self._op = op
        self._e = e
        self._mapping = mapping
        self._composition = composition
        self._id = id

        if isinstance(arr_or_N, int):
            if arr_or_N < 0:
                raise ValueError(
                    f"LazySegtree length must be non-negative, got {arr_or_N}"
                )
            self._N = arr_or_N
        else:
            self._N = len(arr_or_N)

        # 4 * N slots always suffice for the recursive layout. Allocate for at
        # least one element so the root slot (index 1) exists when N == 0 and
        # all_prod() returns e() instead of raising IndexError.
        size = 4 * max(self._N, 1)
        self._d = [self._e() for _ in range(size)]
        self._lz = [self._id() for _ in range(size)]

        # An empty list has no leaves: building the interval [0, -1] would
        # recurse without ever reaching the l == r base case.
        if not isinstance(arr_or_N, int) and self._N > 0:
            self._build(1, 0, self._N - 1, arr_or_N)

    def set(self, p: int, x: S) -> None:
        """Replace the element at position ``p`` with ``x``.

        Raises:
            IndexError: If ``p`` is not in ``[0, N)``.
        """
        self._check_index(p)
        self._set_at(1, 0, self._N - 1, p, x)

    def get(self, p: int) -> S:
        """Return the element at position ``p`` with all pending maps applied.

        Raises:
            IndexError: If ``p`` is not in ``[0, N)``.
        """
        self._check_index(p)
        return self._get_at(1, 0, self._N - 1, p)

    def apply(self, l: int, r: int, f: F) -> None:
        """Apply map ``f`` to every element in ``[l, r)``.

        Positions outside ``[0, N)`` are ignored; an empty range is a no-op.
        """
        l, r = max(l, 0), min(r, self._N)
        if l < r:
            self._apply_range(1, 0, self._N - 1, l, r - 1, f)

    def prod(self, l: int, r: int) -> S:
        """Return the ``op``-fold of the elements in ``[l, r)``.

        Positions outside ``[0, N)`` are ignored; an empty range yields ``e()``.
        """
        l, r = max(l, 0), min(r, self._N)
        if l >= r:
            return self._e()
        return self._prod_range(1, 0, self._N - 1, l, r - 1)

    def all_prod(self) -> S:
        """Return the ``op``-fold of every element (``e()`` when ``N == 0``)."""
        return self._d[1]

    def _check_index(self, p: int) -> None:
        # Without this guard the descent silently clamps p to the first or
        # last leaf, reading or overwriting the wrong element.
        if not 0 <= p < self._N:
            raise IndexError(
                f"LazySegtree index {p} out of range for size {self._N}"
            )

    def _build(self, i: int, l: int, r: int, arr: list[S]) -> None:
        # Fill node i (covering [l, r], l <= r) and its subtree from arr.
        if l == r:
            self._d[i] = arr[l]
            return
        m = (l + r) // 2
        self._build(2 * i, l, m, arr)
        self._build(2 * i + 1, m + 1, r, arr)
        self._update(i)

    def _set_at(self, i: int, tl: int, tr: int, p: int, x: S) -> None:
        # Descend to leaf p, pushing pending maps on the way, then refold.
        if tl == tr:
            self._d[i] = x
            return
        self._push(i)
        tm = (tl + tr) // 2
        if p <= tm:
            self._set_at(2 * i, tl, tm, p, x)
        else:
            self._set_at(2 * i + 1, tm + 1, tr, p, x)
        self._update(i)

    def _get_at(self, i: int, tl: int, tr: int, p: int) -> S:
        # Descend to leaf p, pushing pending maps so the leaf is up to date.
        if tl == tr:
            return self._d[i]
        self._push(i)
        tm = (tl + tr) // 2
        if p <= tm:
            return self._get_at(2 * i, tl, tm, p)
        return self._get_at(2 * i + 1, tm + 1, tr, p)

    def _apply_range(
        self, i: int, tl: int, tr: int, ql: int, qr: int, f: F
    ) -> None:
        # Apply f to the closed query range [ql, qr] within node i's [tl, tr].
        if qr < tl or tr < ql:
            return
        if ql <= tl and tr <= qr:
            # Fully covered: apply here and defer the children via _lz[i].
            self._all_apply(i, f)
            return
        self._push(i)
        tm = (tl + tr) // 2
        self._apply_range(2 * i, tl, tm, ql, qr, f)
        self._apply_range(2 * i + 1, tm + 1, tr, ql, qr, f)
        self._update(i)

    def _prod_range(self, i: int, tl: int, tr: int, ql: int, qr: int) -> S:
        # Fold the closed query range [ql, qr] within node i's [tl, tr].
        if qr < tl or tr < ql:
            return self._e()
        if ql <= tl and tr <= qr:
            return self._d[i]
        self._push(i)
        tm = (tl + tr) // 2
        return self._op(
            self._prod_range(2 * i, tl, tm, ql, qr),
            self._prod_range(2 * i + 1, tm + 1, tr, ql, qr),
        )

    def _update(self, i: int) -> None:
        # Recompute node i's fold from its two children.
        self._d[i] = self._op(self._d[i * 2], self._d[i * 2 + 1])

    def _all_apply(self, i: int, f: F) -> None:
        # Apply f to node i's fold and queue it (newest first) for its children.
        self._d[i] = self._mapping(f, self._d[i])
        self._lz[i] = self._composition(f, self._lz[i])

    def _push(self, i: int) -> None:
        # Hand node i's pending map down to both children, then clear it.
        self._all_apply(i * 2, self._lz[i])
        self._all_apply(i * 2 + 1, self._lz[i])
        self._lz[i] = self._id()
|