ZKW线段树

多年以前,我曾经写过一篇关于线段树的文章https://blog.kuludu.net/article/线段树

而现在,我们重新探讨一种传统线段树的改良版——ZKW线段树。

为啥要叫ZKW线段树呢?因为它出自THU张昆玮大佬的PPT《统计的力量》

其核心思想就是利用K叉树的性质直接计算出子节点的位置,而非自顶而下递归求解。这样的线段树不仅简单而且效率更高。

下面给出一个区间修改与区间查询的ZKW线段树写法。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int MAXN = 100001;

int n,m;

int siz;
ll tree[MAXN << 2];
ll val[MAXN << 2];
ll sizes[MAXN << 2];

void add(int l, int r, int v) {
    for ( l += siz - 1, r += siz + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
        if ( ~l & 1) {
            val[l + 1] += v;
            tree[l + 1] += v * sizes[l + 1];
        }
 
        if ( r & 1 ) {
            val[r - 1] += v;
            tree[r - 1] += v * sizes[r - 1];
        }

        tree[l >> 1] = tree[l] + tree[l ^ 1] + val[l >> 1] * sizes[l >> 1];
        tree[r >> 1] = tree[r] + tree[r ^ 1] + val[r >> 1] * sizes[r >> 1];
    }

    while ( l > 1 ) {
        tree[l >> 1] = tree[l] + tree[l ^ 1] + val[l >> 1] * sizes[l >> 1];
        l >>= 1;
    }
}

ll query(int l, int r) {
    ll retl = 0, retr = 0;
    ll lenl = 0, lenr = 0;

  // l=l+M-1->将查询区间改为L-1,r=r+M+1->将查询区间改为R+1
  // l^r^1 -> 相当于判断l与r是否是兄弟节点
    for ( l += siz - 1, r += siz + 1; l ^ r ^ 1; l >>= 1, r >>= 1 ) {
    // l % 2 == 0 即l是l/2的左儿子
        if ( ~l & 1 ) {
            retl += tree[l + 1];
            lenl += sizes[l + 1];
        }

    // r % 2 == 1 即r是r/2的右儿子
        if ( r & 1 ) {
            retr += tree[r - 1];
            lenr += sizes[r - 1];
        }

        retl += lenl * val[l >> 1];
        retr += lenr * val[r >> 1];
    }

    retl += retr;
    lenl += lenr;
    
    while ( l > 1 ) {
        retl += lenl * val[l >> 1];
        l >>= 1;
    }

    return retl;
}

void build() {
    for ( siz = 1; siz < n + 2; siz <<= 1 );

    for ( int i = siz + 1; i <= siz + n; ++i ) {
        cin >> val[i];
        tree[i] = val[i];
    }

    for ( int i = siz; i; --i )
        tree[i] = tree[i << 1] + tree[i << 1 | 1]; 

    for ( int i = siz * 2 - 2; i - siz; --i )
        sizes[i] = 1;

    for ( int i = siz - 1; i; --i )
        sizes[i] = sizes[i << 1] + sizes[i << 1 | 1];
}

int main() {
    ios::sync_with_stdio(false);

    cin >> n >> m;

    build();

    int op;
    int a, b, c;
    for ( int i = 1; i <= m; ++i ) {
        cin >> op;
        if ( op == 1 ) {
            cin >> a >> b >> c;
            add(a, b, c);
        } else {
            cin >> a >> b;
            cout << query(a, b) << endl;
        }
    }

    return 0;
}

建议阅读

Last modification:September 8th, 2019 at 03:16 pm
博客维护不易,如果你觉得我的文章有用,请随意赞赏

Leave a Comment