三种一维树状数组
单点修改+区间查询
最基本的树状数组
不解释
模板(洛谷P3374
【模板】树状数组1) 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
using namespace std;
int BIT[500010],n,m;
inline int getint(){
register int mark=1,ret=0;
register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')mark=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){ret=ret*10+ch-'0';ch=getchar();}
return mark*ret;
}
inline void add(int pos,int num){
while(pos<=n){
BIT[pos]+=num;
pos+=lowbit(pos);
}
}
inline int getsum(int pos){
register int ret(0);
while(pos>0){
ret+=BIT[pos];
pos-=lowbit(pos);
}
return ret;
}
int main(int argc,char **argv){
n=getint(),m=getint();
for(register int i(1),tmp;i<=n;++i){
tmp=getint();
add(i,tmp);
}
register int cmd,x,y;
while(m--){
cmd=getint(),x=getint(),y=getint();
switch(cmd){
case 1:
add(x,y);
break;
case 2:
printf("%d\n",getsum(y)-getsum(x-1));
}
}
return 0;
}
区间修改+单点查询
通常区间修改大家都会选择用线段树了,但实际上树状数组也能解决这类问题。
区间修改+单点查询的树状数组用到了差分的思想。什么是差分呢?例如:
1 | int a[]={0,2,3,5,3,8,20,23,56}; |
不难观察,b数组的每一位存的是a每一位与前面一位的差值。这时区间修改就很容易了,例如给a的区间[2,4]加上2:
1 | a[]={0,2,5,7,5,8,20,23,56}; |
可以看到b只变化了两位,也就是b[3]:1->3和b[5]:5->3。这样只需要修改两位就可以实现区间修改,但单点查询时间是O(n),这显然也就很不好,这时就想到树状数组实现。
代码中的三个函数均不改变。但是主函数中: 1. 输入模块变为了存查值 2. 修改模块对修改区间后的区间造成了影响,所以此时再减去增加的数 3. 查询函数并没有改变,但它本身是用来求前n项和的,由前面所讲的差分可以知道它现在求的实际上是第n项
这样就可以实现区间修改和单点查询了
模板(洛谷P3368
【模板】树状数组2) 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
using namespace std;
int BIT[500010],n,m;
inline int getint(){
register int mark=1,ret=0;
register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')mark=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){ret=ret*10+ch-'0';ch=getchar();}
return mark*ret;
}
inline void add(int pos,int num){
while(pos<=n){
BIT[pos]+=num;
pos+=lowbit(pos);
}
}
inline int getnum(int pos){
register int ret(0);
while(pos>0){
ret+=BIT[pos];
pos-=lowbit(pos);
}
return ret;
}
int main(int argc,char **argv){
n=getint(),m=getint();
register int x(0),y;
for(register int i(1);i<=n;++i){
y=getint();
add(i,y-x);
x=y;
}
register int cmd,k;
while(m--){
cmd=getint();
switch(cmd){
case 1:
x=getint(),y=getint(),k=getint();
add(x,k);
add(y+1,-k);
break;
case 2:
x=getint();
printf("%d\n",getnum(x));
}
}
return 0;
}
区间修改+区间查询
学到这里我整匹马都惊了
前面讲到用差分来维护树状数组,这里思想差不多。
设原数组为a,使得数组b
\[b_i=a_i+a_{i-1}\]
可以得到
\[ \begin{aligned} S(n)&=b_1+b_2+b_3+···+b_n\\\\ & = a_1+(a_1+a_2)+(a_1+a_2+a_3)+···+(a_1+a_2+a_3+···+a_n)\\\\ & = n*a_1+(n-1)*a_2+···+1*a_n\\\\ & = (n+1-1)*a_1+(n+1-2)*a_2+···+(n+1-n)*a_n\\\\ & = (n+1)*(a_1+a_2+···+a_n)-(1*a_1+2*a_2+···+n*a_n)\\\\ & = (n+1)*\sum_{i=1}^n a_i-\sum_{i=1}^n i*a_i \end{aligned} \]
从而只需要求出\(\sum_{i=1}^n a_i\)和\(\sum_{i=1}^n i*a_i\)就可以了,所以我们维护两个数组分别代表\(a_i\)和\(i*a_i\),其他的函数全部沿用就可以了
模板(洛谷P3372
【模板】线段树1) 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
using namespace std;
long long BIT1[100010],BIT2[100010];
int n,m;
inline int getint(){
register int mark(1),ret=(0);
register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')mark=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){ret=ret*10+ch-'0';ch=getchar();}
return mark*ret;
}
inline long long getll(){
register long long mark(1),ret(0);
register char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')mark=-1ll;ch=getchar();}
while(ch>='0'&&ch<='9'){ret=ret*10ll+ch-'0';ch=getchar();}
return mark*ret;
}
inline void add(long long *BIT,int pos,long long num){
while(pos<=n){
BIT[pos]+=num;
pos+=lowbit(pos);
}
}
inline long long getsum(long long *BIT,int pos){
register long long ret(0);
while(pos>0){
ret+=BIT[pos];
pos-=lowbit(pos);
}
return ret;
}
int main(int argc,char **argv){
n=getint(),m=getint();
register long long tmp1,tmp2(0);
for(register int i(1);i<=n;++i){
tmp1=getll();
add(BIT1,i,tmp1-tmp2);
add(BIT2,i,(long long)i*(tmp1-tmp2));
tmp2=tmp1;
}
register int cmd,x,y;
register long long k;
while(m--){
cmd=getint();
switch(cmd){
case 1:
x=getint(),y=getint(),k=getll();
add(BIT1,x,k);
add(BIT1,y+1,-k);
add(BIT2,x,(long long)x*k);
add(BIT2,y+1,(long long)(y+1)*(-k));
break;
case 2:
x=getint(),y=getint();
register long long suml=(long long)x*getsum(BIT1,x-1)-getsum(BIT2,x-1);
register long long sumr=(long long)(y+1)*getsum(BIT1,y)-getsum(BIT2,y);
printf("%lld\n",sumr-suml);
}
}
return 0;
}