编辑
2020-04-26
算法
0
请注意,本文编写于 1819 天前,最后修改于 129 天前,其中某些信息可能已经过时。

目录

简介
工作原理
基础
查询
观察这棵树,可以发现有这两个特点
查询方法
模拟一下查询的过程(查询区间 [2,4] 的最大值)
可以由此写出查询代码(最大值)
修改
单点修改
修改方法
模拟一下单点修改的过程(将[1,1]这个数修改为9)
由此可以写出单点修改代码(最大值)
区间修改
思路
修改方法
后续查询方法
模拟一下修改的过程(将[0,3]区间的数全部+1)
再模拟一下查询的过程(查询[2,3]的最大值)
由此可以写出修改代码
构建线段树
进阶
区间和
方法
查询区间和代码
更新代码
区间乘法
方法
传递标记代码
区间乘法代码
完整代码
指针版(理解下蒟蒻的思路)
数组版(知道能拿数组写就好)
压行版(最好多打几遍背下来)

简介

对于RMQ问题,暴力计算时间复杂度踏大了,所以要预处理
预处理全部子区间空间复杂度踏大了,预处理的时间复杂度踏大了,还贼难修改
如果能预处理一部分子区间,正好对上查询区间最好,对不上也能用预处理过的区间“拼凑”出查询区间,时间空间复杂度还小就完美了
区间怎么选呢?当然是二分啦 多么优秀的想法

线段树是一种数据结构,用来解决RMQ(区间最值)问题 当然拓展一下也可以有别的用法
顾名思义,线段树是一棵二叉树 ,所谓线段,特别之处是线段树上的每个结点代表一个区间
其查询和修改的时间复杂度为Olog(n)

工作原理

对于一棵线段树来说,每个节点有两个子节点,每个节点的区间长度都是其父节点的一半,并提前预处理好了每个节点代表区间的待查询数据。这样在查询和修改时,就可以分为三类:

  1. 在区间的左半边
  2. 在区间的右半边
  3. 在左半边和右半边都有分布

这样在查询和修改时就可以按这三种情况进行操作,直到找到的区间


  1. 就像在欧亚大陆上找一个地方,可以分为三种情况:
  2. 全部在欧洲,如英国
  3. 全部在亚洲,如中国
  4. 横跨欧亚大陆,如俄罗斯

这样在查询的时候,每次将查询的区间缩小一半,复杂度就达到了log(n)log(n)级别

基础

查询

比如有一段序列:1 , 3 , 5 , 2 ,2
建立的线段树(每个结点保存其区间的最大值)如下 暂且先不管它是怎么建立的

观察这棵树,可以发现有这两个特点

  • 叶结点 区间长度为1,代表这段序列中的一个个数字
  • 其它结点 的的区间是其两个子节点区间的并集

查询方法

  • 待查询区间位于左子节点,向左子节点递归查询,并返回结果
  • 待查询区间位于右子节点,向右子节点递归查询,并返回结果
  • 待查询区间分布于两个子节点,拆分区间使其包含于被两个子节点的区间,向两个子节点递归查询,比较,并返回结果

模拟一下查询的过程(查询区间 [2,4] 的最大值)

  1. 访问根节点,其区间为[0,4],不是待查询区间,且待查询区间分布于其两个子节点
  2. 将待查询区间分为[2,2]和[3,4]向左子节点查询[2,2]
  3. 左子节点:其区间为[0,2],不是待查询区间,待查询区间位于其右子节点
  4. 访问左子节点的右子节点,其区间为[2,2],是待查询区间,返回结果:5
  5. 左子节点返回结果:5
  6. 向右子节点查询区间[3,4]
  7. 右子节点:其区间为[3,4],是待查询区间,返回结果:2
  8. 比较两子节点返回值:5>2,查返回询结果:5

可以由此写出查询代码(最大值)

cpp
double finda(node *p, int l, int r) { int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { if (p->l + mid >= r) return finda(p->left, l, r);//待查询区间位于其左子树.查询左子树 if (p->l + mid < l) return finda(p->right, l, r);//待查询区间位于其右子树,查询右子树 return max(finda(p->left, l, p->l + mid), finda(p->right, p->l + mid + 1, r));//从中间切开,递归分别查询,返回最大值 } }

修改

单点修改

顾名思义,单点修改就是修改序列中的一个数
还是这段序列和这棵线段树

修改方法

  1. 找到待修改的叶结点
  2. 修改叶结点的值
  3. 更新其父节点的值

模拟一下单点修改的过程(将[1,1]这个数修改为9)

  1. 访问根节点,待修改节点位于其左子树
  2. 访问其左子节点,待修改节点位于其左子树
  3. 访问其左子节点,待修改节点是其右子节点
  4. 访问其右子节点,修改其值为9
  5. 返回其父节点,9>1,更新其值为9
  6. 返回其父节点,9>5,更新其值为9
  7. 返回其父节点,9>2,更新其值为9
  8. 到达根节点,修改结束

由此可以写出单点修改代码(最大值)

cpp
void change(node *p, int s, int v) { if(p->l==p->r==s) //找到待修改节点,修改 { p->v=v; return; } int mid = (p->r - p->l) / 2; if (p->l + mid >= s) { change(p->left, s, v);//待查询区间位于其左子树.查询左子树 update(p);//更新该节点的值 return; } if (p->l + mid < s) { change(p->right, s, v);//待查询区间位于其右子树,查询右子树 update(p);//更新该节点的值 return; } }

区间修改

顾名思义,区间修改就是对指定的区间进行 + - * / 操作

思路

一种想法是对区间中的每个数进行单点修改操作,然而这样做复杂度是Onlog(n),太大,不可取
上文提到线段树的区间修改可以做到Olog(n)的复杂度,这是怎么实现的呢?
回顾简单思路的问题:不难发现,这样把线段树全部相关节点都修改了,然而有些节点可能不会被使用到
举个栗子:如果按照这种思路将[0,4]全部+1,而后查询区间却没有包含[0,2],那么[0,2]这个子树就等于白修改了
不难想到解决方法:标记相关子树的根结点,表示这棵子树上的全部结点的值都被修改,使用时再把标记传递给子节点就好了

修改方法

  • 待查询区间就是此区间,打上标记,返回
  • 待查询区间位于左子节点,向左子节点递归修改,更新
  • 待查询区间位于右子节点,向右子节点递归修改,更新
  • 待查询区间分布于两个子节点,拆分区间使其包含于被两个子节点的区间,向两个子节点递归修改,更新

后续查询方法

  • 待查询区间位于左子节点,把标记传递给左子节点并向其递归查询,并返回结果
  • 待查询区间位于右子节点,把标记传递给右子节点并向其递归查询,并返回结果
  • 待查询区间分布于两个子节点,拆分区间使其包含于被两个子节点的区间,把标记传递给两个子节点并向其递归查询,比较,并返回结果

模拟一下修改的过程(将[0,3]区间的数全部+1)

  1. 访问根节点,不是待修改区间,拆分待修改区间为[0,2]和[3,3]
  2. 向左子树查询,是待修改区间,修改标记为1,更新值为6,返回
  3. 向右子树查询,不是待修改区间,待修改区间位于其左子树
  4. 向其左子树查询,是待修改区间,且是叶结点,直接修改,返回
  5. 3>2,更新值为3,返回
  6. 6>3,更新值为6,结束

再模拟一下查询的过程(查询[2,3]的最大值)

  1. 访问根节点,不是待查询区间,拆分其为[2,2]和[3,3]
  2. 向左子树查询,不是待查询区间,进一步查询前把标记传递给两个子节点并更新它们的值
  3. 待查询区间位于其右子树,向右子树查询
  4. 待查询区间是此区间,返回结果:6
  5. 返回结果:6
  6. 向右子树查询(此处省略查询步骤),返回结果:3
  7. 6>3,返回查询结果:6

由此可以写出修改代码

更新

cpp
void updata(node *p) { p->max = max(p->left->max, p->right->max);//更新最大值 }

传递标记

cpp
void down(node* p) { if (p->lazy_add == 0) return;//若没有标记,传递啥? if (p->left->l == p->left->r)//若为叶子结点,无需修改其标记(根本就没标记...) { p->left->max += p->lazy_add; } else//否则,先修改标记,再修改值 { p->left->lazy_add += p->lazy_add; p->left->max += p->lazy_add; } if (p->right->l == p->right->r)//同上(本来就是复制下来的...) { p->right->max += p->lazy_add; } else { p->right->lazy_add += p->lazy_add; p->right->max +=p->lazy_add; } p->lazy_add = 0;//删除加法标记 }

区间加法

cpp
void _add(node *p, int l, int r, double k) { if (p->l == l && p->r == r)//若区间完全重合 { if (p->l != p->r) p->lazy_add += k;//若不是叶子结点,更新加法标记 p->min += k;//更新区间最小值 p->max += k;//更新区间最大值 p->sum += k * (p->r - p->l + 1);//更新区间和 return; } int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被包含 { down(p);//使用子树前先传递标记 if (p->l + mid >= r) _add(p->left, l, r, k);//若区间完全位于其左子树,向左子树查找区间l~r并执行加法 else if (p->l + mid < l) _add(p->right, l, r, k);//若区间完全位于其右子树,向右子树查找区间l~r并执行加法 else//从中间切开,分别执行加法 { _add(p->left, l, p->l + mid, k); _add(p->right, p->l + mid + 1, r, k); } } updata(p);//更新 }

查询最大值

cpp
double finda(node *p, int l, int r) { if (p->l == l && p->r == r) return p->max;//若区间正正好完全重合,返回最大值 down(p);//使用子树前先传递标记 int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { if (p->l + mid >= r) return finda(p->left, l, r);//待查询区间位于其左子树.查询左子树 if (p->l + mid < l) return finda(p->right, l, r);//待查询区间位于其右子树,查询右子树 return max(finda(p->left, l, p->l + mid), finda(p->right, p->l + mid + 1, r));//从中间切开,分别查询,返回最大值 } }

构建线段树

其实跟建立普通二叉树一样,只不过多了初始化区间和值以及更新的操作,递归建树即可

cpp
node* _build(node *t, int l, int r) { node* p = new node;//建立新节点 p->l = l;//区间左端点 p->r = r;//区右端点 p->father = t;//建立父子关系 if (t->left == NULL) t->left = p;//若t没有左子树,说明要构建的树为其左子树 else t->right = p;//反之,则为右子树 if (l == r)//若为叶子节点 { p->max = nums[l];//三个一样 p->min = nums[l]; p->sum = nums[l]; return p;//返回节点地址 } int mid = (r - l) / 2; p->left=_build(p, l, l + mid);//构造左子树 p->right=_build(p, l + mid + 1, r);//构造右子树 updata(p);//更新 return p; }

进阶

区间和

线段树除了可以维护最值,还可以维护区间和

方法

其实也很简单,每一个结点的值就是其子节点的值之和,单点修改只需减掉原来的在加上修改后的就行了,区间加法只需加上加数×区间长度即可

查询区间和代码

cpp
double findu(node *p, int l, int r) { if (p->l == l && p->r == r) return p->sum;//若区间正正好完全重合,返回区间和 down(p);//使用子树前先传递标记 int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { if (p->l + mid >= r) return findu(p->left, l, r);//若待查询区间位于其左子树,查询左子树 if (p->l + mid < l) return findu(p->right, l, r);//若待查询区间位于其右子树,查询右子树 return findu(p->left, l, p->l + mid) + findu(p->right, p->l + mid + 1, r);//从中间切开,分别查询,返回和 } }

更新代码

cpp
void updata(node *p) { p->max = max(p->left->max, p->right->max);//更新最大值 p->min = min(p->left->min, p->right->min);//更新最小值 p->sum = p->left->sum + p->right->sum;//更新区间和 }

区间乘法

线段树除了可以进行区间加法,还可以进行区间乘法

方法

  • 进行区间乘法时,最大值,最小值 都只需乘上乘数即可,然后再修改乘法标记。
  • 此处注意,传递标记时要在修改子节点乘法标记的同时将加法标记乘以乘数,然后加上父节点的加法标记,即先乘后加原则

传递标记代码

cpp
void down(node* p) { if (p->lazy_add == 0 && p->lazy_multiply == 1) return;//若没有标记,传递啥? if (p->left->l == p->left->r)//若为叶子结点,无需修改其标记(根本就没标记...) { p->left->max *= p->lazy_multiply;//先乘后加 p->left->min *= p->lazy_multiply; p->left->sum *= p->lazy_multiply; p->left->max += p->lazy_add; p->left->min += p->lazy_add; p->left->sum += p->lazy_add; } else//否则,先修改标记,再修改值 { p->left->lazy_multiply *= p->lazy_multiply;//先乘后加 p->left->lazy_add *= p->lazy_multiply; p->left->lazy_add += p->lazy_add; p->left->max *= p->lazy_multiply; p->left->min *= p->lazy_multiply; p->left->sum *= p->lazy_multiply; p->left->max += p->lazy_add; p->left->min += p->lazy_add; p->left->sum += p->lazy_add * (p->left->r - p->left->l + 1);//区间和增加传递的标记乘以区间节点个数 } if (p->right->l == p->right->r)//同上(本来就是复制下来的...) { p->right->max *= p->lazy_multiply; p->right->min *= p->lazy_multiply; p->right->sum *= p->lazy_multiply; p->right->max += p->lazy_add; p->right->min += p->lazy_add; p->right->sum += p->lazy_add; } else { p->right->lazy_multiply *= p->lazy_multiply; p->right->lazy_add *= p->lazy_multiply; p->right->lazy_add += p->lazy_add; p->right->max *=p->lazy_multiply; p->right->min *=p->lazy_multiply; p->right->sum *=p->lazy_multiply; p->right->max +=p->lazy_add; p->right->min +=p->lazy_add; p->right->sum +=p->lazy_add * (p->right->r - p->right->l + 1); } p->lazy_add = 0;//删除加法标记 p->lazy_multiply = 1;//删除乘法标记 }

区间乘法代码

cpp
void _multiply(node *p, int l, int r, double k) { if (p->l == l && p->r == r)//若区间完全重合 { if (p->l != p->r) p->lazy_multiply *= k;//若不是叶子结点,更新乘法标记 if (k >= 0)//若乘数为正 { p->max *= k;//原来最大的还是最大的 p->min *= k;//原来最小的还是最小的 } else { double max = p->max, min = p->min;//暂时储存 p->max = min * k;//原来最大的符号相反后变成最小的 p->min = max * k;//原来最小的符号相反后变为最大的 } p->sum *= k;//更新区间和 p->lazy_add *= k;//更新加法标记(下面每个数要加的数也乘了k) return; } int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { down(p);//使用子节点前先更新标记 if (p->l + mid >= r) _multiply(p->left, l, r, k);//若待操作区间完全位于左子树,向左子树查询并进行区间乘法 else if (p->l + mid < l) _multiply(p->right, l, r, k);//若待操作区间完全位于右子树,向右子树查询并进行区间乘法 else//从中间切开,分别进行操作 { _multiply(p->left, l, p->l + mid, k); _multiply(p->right, p->l + mid + 1, r, k); } } updata(p);//更新 }

完整代码

指针版(理解下蒟蒻的思路)

cpp
#include<iostream> #include<vector> #include<algorithm> using namespace std; #define inf 0x3f3f3f3f class tree { public: struct node { node* father;//父节点 node* left;//左子树 node* right;//右子树 double max;//区间最大值 double min;//区间最小值 double sum;//区间和 int l, r;//区间端点 double lazy_add;//加法标记 double lazy_multiply;//乘法标记 node()//初始化 { father = NULL; left = NULL; right = NULL; l = -1; r = -1; max = -inf; min = inf; sum = inf; lazy_add = 0;//即什么都没加 lazy_multiply = 1;//即什么都没乘 } }; ~tree()//删除整棵树 { d_t(root); } //---------------------------------------------构造线段树-------------------------------------------------- void build(vector<double> num) { d_t(root);//删除旧树 nums = num;//赋新值 int size = nums.size() - 1; int mid = size / 2; root = new node; root->l = 0;//根节点区间左端为0 root->r = size;//根节点区间右端为size _build(root, 0, mid);//构建左子树 _build(root, mid + 1, size);//构建右子树 updata(root);//更新根节点的值 } //-----------------------------------------------查询----------------------------------------------------- double find(int l, int r, string mode) { if (mode == "max")//返回区间最大值 return finda(root, l, r); if (mode == "min")//返回区间最小值 return findi(root, l, r); if (mode == "sum")//返回区间和 return findu(root, l, r); } //----------------------------------------------区间加法-------------------------------------------------- void add(int l, int r, double k)//区间l~r内每个数加k { _add(root, l, r, k); } //----------------------------------------------区间减法-------------------------------------------------- void multiply(int l, int r, double k)//区间l~r内.每个数乘k { _multiply(root, l, r, k); } //--------------------------------------------输出整个区间------------------------------------------------- void out()//输出线段树包含的数 { _out(root); cout << "\n"; } private: node* root;//根节点 vector<double>nums;//数据 //----------------------------------------------删除------------------------------------------------------ void d_t(node* p) { if (p == NULL) return;//啥毛没有,咋删? if (p->left != NULL) d_t(p->left);//删除左子树 if (p->right != NULL) d_t(p->right);//删除右子树 delete p;//删除自己 } //----------------------------------------------更新-------------------------------------------------------- void updata(node *p) { p->max = max(p->left->max, p->right->max);//更新最大值 p->min = min(p->left->min, p->right->min);//更新最小值 p->sum = p->left->sum + p->right->sum;//更新区间和 } //----------------------------------------------构造--------------------------------------------------------- node* _build(node *t, int l, int r) { node* p = new node;//建立新节点 p->l = l;//区间左端点 p->r = r;//区右端点 p->father = t;//建立父子关系 if (t->left == NULL) t->left = p;//若t没有左子树,说明要构建的树为其左子树 else t->right = p;//反之,则为右子树 if (l == r)//若为叶子节点 { p->max = nums[l];//三个一样 p->min = nums[l]; p->sum = nums[l]; return p;//返回节点地址 } int mid = (r - l) / 2; p->left=_build(p, l, l + mid);//构造左子树 p->right=_build(p, l + mid + 1, r);//构造右子树 updata(p);//更新 return p; } //---------------------------------------------传递标记--------------------------------------------------- void down(node* p) { if (p->lazy_add == 0 && p->lazy_multiply == 1) return;//若没有标记,传递啥? if (p->left->l == p->left->r)//若为叶子结点,无需修改其标记(根本就没标记...) { p->left->max *= p->lazy_multiply;//先乘后加 p->left->min *= p->lazy_multiply; p->left->sum *= p->lazy_multiply; p->left->max += p->lazy_add; p->left->min += p->lazy_add; p->left->sum += p->lazy_add; } else//否则,先修改标记,再修改值 { p->left->lazy_multiply *= p->lazy_multiply;//先乘后加 p->left->lazy_add *= p->lazy_multiply; p->left->lazy_add += p->lazy_add; p->left->max *= p->lazy_multiply; p->left->min *= p->lazy_multiply; p->left->sum *= p->lazy_multiply; p->left->max += p->lazy_add; p->left->min += p->lazy_add; p->left->sum += p->lazy_add * (p->left->r - p->left->l + 1);//区间和增加传递的标记乘以区间节点个数 } if (p->right->l == p->right->r)//同上(本来就是复制下来的...) { p->right->max *= p->lazy_multiply; p->right->min *= p->lazy_multiply; p->right->sum *= p->lazy_multiply; p->right->max += p->lazy_add; p->right->min += p->lazy_add; p->right->sum += p->lazy_add; } else { p->right->lazy_multiply *= p->lazy_multiply; p->right->lazy_add *= p->lazy_multiply; p->right->lazy_add += p->lazy_add; p->right->max *=p->lazy_multiply; p->right->min *=p->lazy_multiply; p->right->sum *=p->lazy_multiply; p->right->max +=p->lazy_add; p->right->min +=p->lazy_add; p->right->sum +=p->lazy_add * (p->right->r - p->right->l + 1); } p->lazy_add = 0;//删除加法标记 p->lazy_multiply = 1;//删除乘法标记 } //------------------------------------------------输出---------------------------------------------------- void _out(node* p) { if (p->l == p->r)//若为叶子节点,输出 { cout << p->max << " "; return; } down(p);//使用子树前先传递标记 _out(p->left);//输出左子树 _out(p->right);//输出右子树 } //---------------------------------------------查询区间最大值---------------------------------------------- double finda(node *p, int l, int r) { if (p->l == l && p->r == r) return p->max;//若区间正正好完全重合,返回最大值 down(p);//使用子树前先传递标记 int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { if (p->l + mid >= r) return finda(p->left, l, r);//待查询区间位于其左子树.查询左子树 if (p->l + mid < l) return finda(p->right, l, r);//待查询区间位于其右子树,查询右子树 return max(finda(p->left, l, p->l + mid), finda(p->right, p->l + mid + 1, r));//从中间切开,分别查询,返回最大值 } return -inf;//防止破坏 } //--------------------------------------------查询区间最小值----------------------------------------------- double findi(node *p, int l, int r) { if (p->l == l && p->r == r) return p->min;//若区间正正好完全重合,返回最小值 down(p);//使用子树前先传递标记 int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { if (p->l + mid >= r) return findi(p->left, l, r);//待查询区间位于其左子树,查询左子树 if (p->l + mid < l) return findi(p->right, l, r);//待查询区间位于其右子树,查询右子树 return min(findi(p->left, l, p->l + mid), findi(p->right, p->l + mid + 1, r));//从中间切开,分别查询,返回最小值 } return inf;//防止破坏 } //----------------------------------------------查询区间和------------------------------------------------ double findu(node *p, int l, int r) { if (p->l == l && p->r == r) return p->sum;//若区间正正好完全重合,返回区间和 down(p);//使用子树前先传递标记 int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { if (p->l + mid >= r) return findu(p->left, l, r);//若待查询区间位于其左子树,查询左子树 if (p->l + mid < l) return findu(p->right, l, r);//若待查询区间位于其右子树,查询右子树 return findu(p->left, l, p->l + mid) + findu(p->right, p->l + mid + 1, r);//从中间切开,分别查询,返回和 } return inf;//防止破坏 } //---------------------------------------------区间加法--------------------------------------------------- void _add(node *p, int l, int r, double k) { if (p->l == l && p->r == r)//若区间完全重合 { if (p->l != p->r) p->lazy_add += k;//若不是叶子结点,更新加法标记 p->min += k;//更新区间最小值 p->max += k;//更新区间最大值 p->sum += k * (p->r - p->l + 1);//更新区间和 return; } int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被包含 { down(p);//使用子树前先传递标记 if (p->l + mid >= r) _add(p->left, l, r, k);//若区间完全位于其左子树,向左子树查找区间l~r并执行加法 else if (p->l + mid < l) _add(p->right, l, r, k);//若区间完全位于其右子树,向右子树查找区间l~r并执行加法 else//从中间切开,分别执行加法 { _add(p->left, l, p->l + mid, k); _add(p->right, p->l + mid + 1, r, k); } } updata(p);//更新 } //----------------------------------------------区间乘法-------------------------------------------------- void _multiply(node *p, int l, int r, double k) { if (p->l == l && p->r == r)//若区间完全重合 { if (p->l != p->r) p->lazy_multiply *= k;//若不是叶子结点,更新乘法标记 if (k >= 0)//若乘数为正 { p->max *= k;//原来最大的还是最大的 p->min *= k;//原来最小的还是最小的 } else { double max = p->max, min = p->min;//暂时储存 p->max = min * k;//原来最大的符号相反后变成最小的 p->min = max * k;//原来最小的符号相反后变为最大的 } p->sum *= k;//更新区间和 p->lazy_add *= k;//更新加法标记(下面每个数要加的数也乘了k) return; } int mid = (p->r - p->l) / 2; if (p->l <= l && p->r >= r)//若区间被完全包含 { down(p);//使用子节点前先更新标记 if (p->l + mid >= r) _multiply(p->left, l, r, k);//若待操作区间完全位于左子树,向左子树查询并进行区间乘法 else if (p->l + mid < l) _multiply(p->right, l, r, k);//若待操作区间完全位于右子树,向右子树查询并进行区间乘法 else//从中间切开,分别进行操作 { _multiply(p->left, l, p->l + mid, k); _multiply(p->right, p->l + mid + 1, r, k); } } updata(p);//更新 } }; tree Te; int main() { char D; aaa:cin >> D; switch (D) { case 'B'://构造 { int n; cin >> n; vector<double> muns; double num; for (int i = 0; i < n; i++) { cin >> num; muns.push_back(num); } Te.build(muns); break; } case 'N'://查询区间最小值 { int l, r; cin >> l >> r; cout << Te.find(l, r, "min") << endl; break; } case 'M'://查询区间和 { int l, r; cin >> l >> r; cout << Te.find(l, r, "sum") << endl; break; } case 'X'://查询区间最大值 { int l, r; cin >> l >> r; cout << Te.find(l, r, "max") << endl; break; } case 'J'://区间加法 { int l, r; double k; cin >> l >> r >> k; Te.add(l, r, k); break; } case 'C'://区间乘法 { int l, r; double k; cin >> l >> r >> k; Te.multiply(l, r, k); break; } case 'O'://输出 { Te.out(); break; } case 'E'://结束 { return 0; } } goto aaa; return 0; }

数组版(知道能拿数组写就好)

cpp
#include<iostream> #include<vector> #include<algorithm> using namespace std; #define inf 0x3f3f3f3f class tree { public: struct node { double max; double min; double sum; int l, r; double lazy_add; double lazy_multiply; node() { l = -1; r = -1; max = -inf; min = inf; sum = inf; lazy_add = 0; lazy_multiply = 1; } }T[10000];//T[n*2+1]是T[n]的左子树,T[n*2+2]是T[n]的右子树 void build(vector<double> num) { nums = num; _build(0, 0, nums.size() - 1); } double find(int l, int r, string mode) { if (mode == "max") return finda(0, l, r); if (mode == "min") return findi(0, l, r); if (mode == "sum") return findu(0, l, r); } void add(int l, int r, double k) { _add(0, l, r, k); } void multiply(int l, int r, double k) { _multiply(0, l, r, k); } private: vector<double>nums; void updata(int n) { T[n].max = max(T[n * 2 + 1].max, T[n * 2 + 2].max); T[n].min = min(T[n * 2 + 1].min, T[n * 2 + 2].min); T[n].sum = T[n * 2 + 1].sum + T[n * 2 + 2].sum; } void _build(int n, int l, int r) { T[n].l = l; T[n].r = r; if (l == r) { T[n].max = nums[l]; T[n].min = nums[l]; T[n].sum = nums[l]; return; } int mid = (r - l) / 2; _build(n * 2 + 1, l, l + mid); _build(n * 2 + 2, l + mid + 1, r); updata(n); } void down(int n) { if (T[n * 2 + 1].l == T[n * 2 + 1].r) { T[n * 2 + 1].max *= T[n].lazy_multiply; T[n * 2 + 1].min *= T[n].lazy_multiply; T[n * 2 + 1].sum *= T[n].lazy_multiply; T[n * 2 + 1].max += T[n].lazy_add; T[n * 2 + 1].min += T[n].lazy_add; T[n * 2 + 1].sum += T[n].lazy_add; } else { T[n * 2 + 1].lazy_multiply *= T[n].lazy_multiply; T[n * 2 + 1].lazy_add *= T[n].lazy_multiply; T[n * 2 + 1].lazy_add += T[n].lazy_add; T[n * 2 + 1].max *= T[n].lazy_multiply; T[n * 2 + 1].min *= T[n].lazy_multiply; T[n * 2 + 1].sum *= T[n].lazy_multiply; T[n * 2 + 1].max += T[n].lazy_add; T[n * 2 + 1].min += T[n].lazy_add; T[n * 2 + 1].sum += T[n].lazy_add * (T[n * 2 + 1].r - T[n * 2 + 1].l + 1); } if (T[n * 2 + 2].l == T[n * 2 + 2].r) { T[n * 2 + 2].max *= T[n].lazy_multiply; T[n * 2 + 2].min *= T[n].lazy_multiply; T[n * 2 + 2].sum *= T[n].lazy_multiply; T[n * 2 + 2].max += T[n].lazy_add; T[n * 2 + 2].min += T[n].lazy_add; T[n * 2 + 2].sum += T[n].lazy_add; } else { T[n * 2 + 2].lazy_multiply *= T[n].lazy_multiply; T[n * 2 + 2].lazy_add *= T[n].lazy_multiply; T[n * 2 + 2].lazy_add += T[n].lazy_add; T[n * 2 + 2].max *= T[n].lazy_multiply; T[n * 2 + 2].min *= T[n].lazy_multiply; T[n * 2 + 2].sum *= T[n].lazy_multiply; T[n * 2 + 2].max += T[n].lazy_add; T[n * 2 + 2].min += T[n].lazy_add; T[n * 2 + 2].sum += T[n].lazy_add * (T[n * 2 + 2].r - T[n * 2 + 2].l + 1); } T[n].lazy_add = 0; T[n].lazy_multiply = 1; } double finda(int n, int l, int r) { if (T[n].l == l && T[n].r == r) return T[n].max; down(n); int mid = (T[n].r - T[n].l) / 2; if (T[n].l <= l && T[n].r >= r) { if (T[n].l + mid >= r) return finda(n * 2 + 1, l, r); if (T[n].l + mid < l) return finda(n * 2 + 2, l, r); return max(finda(n * 2 + 1, l, T[n].l + mid), finda(n * 2 + 2, T[n].l + mid + 1, r)); } return -inf; } double findi(int n, int l, int r) { if (T[n].l == l && T[n].r == r) return T[n].min; down(n); int mid = (T[n].r - T[n].l) / 2; if (T[n].l <= l && T[n].r >= r) { if (T[n].l + mid >= r) return findi(n * 2 + 1, l, r); if (T[n].l + mid < l) return findi(n * 2 + 2, l, r); return min(findi(n * 2 + 1, l, T[n].l + mid), findi(n * 2 + 2, T[n].l + mid + 1, r)); } return inf; } double findu(int n, int l, int r) { if (T[n].l == l && T[n].r == r) return T[n].sum; down(n); int mid = (T[n].r - T[n].l) / 2; if (T[n].l <= l && T[n].r >= r) { if (T[n].l + mid >= r) return findu(n * 2 + 1, l, r); if (T[n].l + mid < l) return findu(n * 2 + 2, l, r); return findu(n * 2 + 1, l, T[n].l + mid) + findu(n * 2 + 2, T[n].l + mid + 1, r); } return inf; } void _add(int n, int l, int r, double k) { if (T[n].l == l && T[n].r == r) { if (T[n].l != T[n].r) T[n].lazy_add += k; T[n].min += k; T[n].max += k; T[n].sum += k * (T[n].r - T[n].l + 1); return; } int mid = (T[n].r - T[n].l) / 2; if (T[n].l <= l && T[n].r >= r) { down(n); if (T[n].l + mid >= r) _add(n * 2 + 1, l, r, k); else if (T[n].l + mid < l) _add(n * 2 + 2, l, r, k); else { _add(n * 2 + 1, l, T[n].l + mid, k); _add(n * 2 + 2, T[n].l + mid + 1, r, k); } } updata(n); } void _multiply(int n, int l, int r, double k) { if (T[n].l == l && T[n].r == r) { if (T[n].l != T[n].r)T[n].lazy_multiply *= k; if (k >= 0) { T[n].max *= k; T[n].min *= k; } else { double max = T[n].max, min = T[n].min; T[n].max = min * k; T[n].min = max * k; } T[n].sum *= k; T[n].lazy_add *= k; return; } int mid = (T[n].r - T[n].l) / 2; if (T[n].l <= l && T[n].r >= r) { down(n); if (T[n].l + mid >= r) _multiply(n * 2 + 1, l, r, k); else if (T[n].l + mid < l) _multiply(n * 2 + 2, l, r, k); else { _multiply(n * 2 + 1, l, T[n].l + mid, k); _multiply(n * 2 + 2, T[n].l + mid + 1, r, k); } } updata(n); } }; tree Te; int main() { char D; aaa:cin >> D; switch (D) { case 'B': { int n; cin >> n; vector<double> muns; double num; for (int i = 0; i < n; i++) { cin >> num; muns.push_back(num); } Te.build(muns); break; } case 'N': { int l, r; cin >> l >> r; cout << Te.find(l, r, "min") << endl; break; } case 'M': { int l, r; cin >> l >> r; cout << Te.find(l, r, "sum") << endl; break; } case 'X': { int l, r; cin >> l >> r; cout << Te.find(l, r, "max") << endl; break; } case 'J': { int l, r; double k; cin >> l >> r >> k; Te.add(l, r, k); break; } case 'C': { int l, r; double k; cin >> l >> r >> k; Te.multiply(l, r, k); } } goto aaa; return 0; }

压行版(最好多打几遍背下来)

cpp
#include<iostream> #include<algorithm> using namespace std; int la[100000];//加法标记 int lm[100000];//乘法标记 int x[100000];//节点最大值 int n[100000];//节点最小值 int s[100000];//节点区间和 int ln[100000];//左端点 int rn[100000];//右端点 int num[100000];//临时数据 #define ls (p<<1)//左儿子 #define rs (ls|1)//右儿子 #define mid (ln[p]+rn[p]>>1) #define a la[p]//当前节点加法标记 #define m lm[p]//当前节点减法标记 inline void updata(int p) { x[p] = max(x[ls], x[rs]); n[p] = min(n[ls], n[rs]); s[p] = s[ls] + s[rs]; } void build(int p, int l, int r) { ln[p] = l; rn[p] = r; if (l == r) { x[p] = n[p] = s[p] = num[l]; return; } build(ls, l, mid); build(rs, mid + 1, r); updata(p); } inline void down(int p) { if (m != 1) { la[ls] *= m; lm[ls] *= m; la[rs] *= m; lm[rs] *= m; x[ls] *= m; n[ls] *= m; s[ls] *= m; x[rs] *= m; n[rs] *= m; s[rs] *= m; lm[p] = 1; } if (a != 0) { la[ls] += a; la[rs] += a; x[ls] += a; n[ls] += a; s[ls] += (rn[ls] - ln[ls] + 1) * a; x[rs] += a; n[rs] += a; s[rs] += (rn[rs] - ln[rs] + 1) * a; a = 0; } } void change(int p, int l, int r, int add/*加法*/, int mut/*乘法*/)//乘法优先 { if (ln[p] == l && rn[p] == r) { x[p] *= mut; n[p] *= mut; s[p] *= mut; x[p] += add; n[p] += add; s[p] += (rn[p] - ln[p] + 1) * add; if (l != r) { a *= mut; a += add; m *= mut; } return; } down(p); if (l > mid) change(rs, l, r, add, mut); else if (r <= mid) change(ls, l, r, add, mut); else { change(ls, l, mid, add, mut); change(rs, mid + 1, r, add, mut); } updata(p); } int mx(int p, int l, int r) { if (ln[p] == l && rn[p] == r) return x[p]; down(p); if (l > mid) return mx(rs, l, r); if (r <= mid) return mx(ls, l, r); return max(mx(ls, l, mid), mx(rs, mid + 1, r)); } int mn(int p, int l, int r) { if (ln[p] == l && rn[p] == r) return n[p]; down(p); if (l > mid) return mn(rs, l, r); if (r <= mid) return mn(ls, l, r); return min(mn(ls, l, mid), mn(rs, mid + 1, r)); } int sum(int p, int l, int r) { if (ln[p] == l && rn[p] == r) return s[p]; down(p); if (l > mid) return sum(rs, l, r); if (r <= mid) return sum(ls, l, r); return sum(ls, l, mid) + sum(rs, mid + 1, r); } int main() { int n; cin >> n; for (int i = 0; i < n; i++) cin >> num[i]; build(1, 0, n - 1); //change(1, 0, n - 1, 1, 2); cout << sum(1, 0, n - 1); return 0; }

本文作者:GBwater

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!