1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
| from collections.abc import Callable from typing import Callable, Generic, TypeVar, Union
S = TypeVar("S")


class Segtree(Generic[S]):
    """Segment tree over a monoid ``(S, op, e)`` with a fixed length ``N``.

    Supports point assignment (``set``), point lookup (``get``) and products
    over half-open ranges (``prod``), each in O(log N).  Products are formed
    left-to-right in index order, so ``op`` only needs to be associative, not
    commutative.  ``e()`` is expected to return the identity element of ``op``.

    Nodes are stored 1-indexed in ``self._d``: node ``i`` covers an inclusive
    index range ``[tl, tr]``; its children ``2*i`` and ``2*i + 1`` cover
    ``[tl, tm]`` and ``[tm + 1, tr]`` with ``tm = (tl + tr) // 2``, and the
    parent of node ``i`` is ``i // 2``.  ``4 * N`` slots are allocated, the
    usual bound for this layout.
    """

    def __init__(
        self,
        arr_or_N: Union[list[S], int],
        op: Callable[[S, S], S],
        e: Callable[[], S],
    ) -> None:
        """Create the tree.

        Args:
            arr_or_N: Either the initial values (N becomes their length) or an
                int N, in which case every position starts as ``e()``.
            op: Associative binary operation combining two values.
            e: Zero-argument factory returning the identity of ``op``.

        Raises:
            ValueError: If N is not positive.
        """
        self._op = op
        self._e = e
        self._N = arr_or_N if isinstance(arr_or_N, int) else len(arr_or_N)
        if self._N <= 0:
            raise ValueError("N must be positive")
        # Every slot starts as the identity.  With an int size nothing else is
        # needed, since combining identities yields the identity.
        self._d = [self._e() for _ in range(4 * self._N)]
        if not isinstance(arr_or_N, int):
            self._build(arr_or_N, 1, 0, self._N - 1)

    def set(self, p: int, x: S) -> None:
        """Assign ``x`` to position ``p`` and refresh every ancestor node.

        Raises:
            IndexError: If ``p`` is outside ``[0, N)``.
        """
        if not (0 <= p < self._N):
            raise IndexError("p out of range")
        i = self._leaf(p)
        self._d[i] = x
        # Recompute ancestors bottom-up; the root (node 1) is the last one.
        i //= 2
        while i:
            self._update(i)
            i //= 2

    def get(self, p: int) -> S:
        """Return the value stored at position ``p``.

        Raises:
            IndexError: If ``p`` is outside ``[0, N)``.
        """
        if not (0 <= p < self._N):
            raise IndexError("p out of range")
        return self._d[self._leaf(p)]

    def prod(self, l: int, r: int) -> S:
        """Return ``op(a[l], ..., a[r - 1])`` over the half-open range ``[l, r)``.

        An empty range (``l == r``) yields ``e()``.

        Raises:
            IndexError: Unless ``0 <= l <= r <= N``.
        """
        if not (0 <= l <= r <= self._N):
            raise IndexError("invalid range")
        if l == r:
            return self._e()
        # The recursive helper works on inclusive bounds.
        return self._prod_rec(1, 0, self._N - 1, l, r - 1)

    def all_prod(self) -> S:
        """Return the product of the whole array (the root node)."""
        return self._d[1]

    def _update(self, i: int) -> None:
        """Recompute internal node ``i`` from its two children."""
        self._d[i] = self._op(self._d[i * 2], self._d[i * 2 + 1])

    # ----- private helpers (defined once instead of per call) -----

    def _build(self, arr: list[S], i: int, l: int, r: int) -> None:
        """Fill node ``i`` (covering ``[l, r]``) and its subtree from ``arr``."""
        if l == r:
            self._d[i] = arr[l]
            return
        m = (l + r) // 2
        self._build(arr, i * 2, l, m)
        self._build(arr, i * 2 + 1, m + 1, r)
        self._update(i)

    def _leaf(self, p: int) -> int:
        """Return the index in ``self._d`` of the leaf holding position ``p``."""
        i, tl, tr = 1, 0, self._N - 1
        while tl != tr:
            tm = (tl + tr) // 2
            if p <= tm:
                i, tr = i * 2, tm
            else:
                i, tl = i * 2 + 1, tm + 1
        return i

    def _prod_rec(self, i: int, tl: int, tr: int, ql: int, qr: int) -> S:
        """Product of node ``i`` (covering ``[tl, tr]``) restricted to ``[ql, qr]``."""
        if qr < tl or tr < ql:
            return self._e()
        if ql <= tl and tr <= qr:
            return self._d[i]
        tm = (tl + tr) // 2
        # Left child first so non-commutative ops keep index order.
        return self._op(
            self._prod_rec(i * 2, tl, tm, ql, qr),
            self._prod_rec(i * 2 + 1, tm + 1, tr, ql, qr),
        )
|