标签: 线段树

  • CFGYM101982F (2018 ICPC Pacific Northwest Regional Contest – Rectangles)

    题意:给定n(1e5)个矩形,边与坐标轴平行,四个点的坐标都是整数(1e9级别),问被矩形覆盖了奇数次的面积和。

    首先把所有的矩形拆成上边和下边两条边,这样我们就得到了若干条与x轴平行的线段。然后我们把这些线段按照y坐标从小到大排个序。

    现在我们维护一棵线段树,下标表示的是(离散化后的)整个横轴的覆盖情况。以及为了方便起见,整个线段树上所表示的区间都是左闭右开的(一条线段所能表示的实际长度是右端点减左端点)。

    现在我们按照y坐标从小到大枚举每一条横线,对于一条横线,我们把它在线段树上所覆盖的的区域异或上1,同时面积取反(即用线段长度减掉区域内被标记过的线段长度和,此操作即为异或),由于我们存储的都是左闭右开区间,所以直接用右端点对应的数减掉左端点对应的数即可。

    最后每插入一条线段,我们就查一下现在线段树上被覆盖的区间长度,然后乘上下一条线段的高度减掉当前线段的高度(此方法是可以正常处理有相同高度的线段的,因为他们高度相同,所以算出来都是0)。

    #include <algorithm>
    #include <iostream>
    #include <vector>
    
    using int_t = long long int;
    
    using std::cin;
    using std::cout;
    using std::endl;
    
    struct Segment {
        int_t left, right, height;
    };
    struct Node {
        Node *left = nullptr, *right = nullptr;
        int_t begin, end;
        bool flag = false;
        int_t sum = 0;
        const std::vector<int_t>& vals;
        Node(int_t begin, int_t end, const std::vector<int_t>& vals)
            : begin(begin), end(end), vals(vals) {}
        void swap() {
            flag ^= 1;
            sum = vals[end] - vals[begin] - sum;
        }
        void pushdown() {
            if (flag) {
                left->swap();
                right->swap();
                flag = 0;
            }
        }
        void maintain() { sum = left->sum + right->sum; }
        void insert(int_t begin, int_t end) {  //插入一条线段
    #ifdef DEBUG
            cout << "insert " << begin << "," << end << " to " << this->begin << ","
                 << this->end << endl;
    #endif
            if (this->begin >= begin && this->end <= end) {
                this->swap();
                return;
            }
            if (this->begin >= end || this->end <= begin) {
                return;
            }
            int_t mid0 = (this->begin + this->end) / 2;
            pushdown();
            left->insert(begin, end);
            right->insert(begin, end);
            maintain();
        }
    };
    Node* build(int_t begin, int_t end, const std::vector<int_t>& vals) {
        Node* node = new Node(begin, end, vals);
        if (begin + 1 != end) {
            int_t mid = (begin + end) / 2;
            node->left = build(begin, mid, vals);
            node->right = build(mid, end, vals);
        }
        return node;
    }
    int main() {
        std::ios::sync_with_stdio(false);
        std::vector<Segment> segs;
        std::vector<int_t> nums;
        int_t n;
        cin >> n;
        for (int_t i = 1; i <= n; i++) {
            int_t x1, y1, x2, y2;
            cin >> x1 >> y1 >> x2 >> y2;
            segs.push_back(Segment{x1, x2, y1});
            segs.push_back(Segment{x1, x2, y2});
            nums.push_back(x1);
            nums.push_back(x2);
        }
        std::sort(nums.begin(), nums.end());
        nums.resize(std::unique(nums.begin(), nums.end()) - nums.begin());
    
        std::sort(segs.begin(), segs.end(), [](const Segment& a, const Segment& b) {
            return a.height < b.height;
        });
        int_t result = 0;
        Node* root = build(0, nums.size() - 1, nums);
        const auto rank = [&](int_t x) {
            return std::lower_bound(nums.begin(), nums.end(), x) - nums.begin();
        };
    #ifdef DEBUG
        for (const auto& s : segs) {
            cout << "seg " << s.left << " " << s.right << " " << s.height << endl;
        }
    #endif
        for (int_t i = 0; i < segs.size() - 1; i++) {
            //先插入后统计结果
            const auto& curr = segs[i];
            root->insert(rank(curr.left), rank(curr.right));
            result += (segs[i + 1].height - curr.height) * root->sum;
        }
        cout << result << endl;
        return 0;
    }
    
  • BJOI2019 删数

    考虑$p\neq 0$的情况。

    使用$a_i$表示数字i出现的次数。

    构建一棵线段树,令$a_i$覆盖$[i-a_i+1,a_i]$的区间,最终$[1,n]$内为0的位置的个数即为答案。

    很显然一个位置为0,我们必定要调整之后的某个数到这个位置来覆盖它。

    带整体加减?

    区间平移。

    #include <iostream>
    #include <algorithm>
    #include <cstdio>
    #include <vector>
    #include <inttypes.h>
    #define debug(x) std::cout << #x << " = " << x << std::endl;
    
    typedef long long int int_t;
    using std::cin;
    using std::endl;
    using std::cout;
    
    const int_t LARGE = 5e5;
    
    struct State {
    	int_t val,count;
    	State operator+(const State& rhs) const {
    		State next=*this;
    		if(rhs.val==next.val) next.count+=rhs.count;
    		else if(rhs.val<next.val) {
    			next=rhs;
    		}
    		return next;
    	}
    };
    struct Node {
    	Node*left,*right;
    	State state {0,1};
    	int_t mark=0;
    	int_t begin,end;
    	Node(int_t begin,int_t end) {
    		this->begin=begin,this->end=end;
    	}
    	void add(int_t x) {
    //        minval+=x;
    		state.val+=x;
    		mark+=x;
    	}
    	void pushDown() {
    		if(mark) {
    			left->add(mark),right->add(mark);
    			mark=0;
    		}
    	}
    	void maintain() {
    		if(begin==end) return;
    		state=left->state+right->state;
    	}
    	static Node* build(int_t begin,int_t end) {
    		Node* node=new Node(begin,end);
    		if(begin!=end) {
    			int_t mid=(begin+end)/2;
    			node->left=build(begin,mid);
    			node->right=build(mid+1,end);
    		}
    		node->maintain();
    		return node;
    	}
    	void add(int_t begin,int_t end,int_t x) {
    		if(end<this->begin||begin>this->end) return;
    		if(this->begin>=begin&&this->end<=end) {
    			this->add(x);
    			return;
    		}
    		pushDown();
    		left->add(begin,end,x);
    		right->add(begin,end,x);
    		maintain();
    	}
    	State query(int_t begin,int_t end) {
    		if(this->begin>=begin&&this->end<=end) return state;
    		int_t mid=(this->begin+this->end)/2;
    		pushDown();
    		if(end<=mid) return left->query(begin,end);
    		else if(begin>mid) return right->query(begin,end);
    		return left->query(begin,mid)+right->query(mid+1,end);
    	}
    };
    int_t count[LARGE+1];
    int_t seq[LARGE+1];
    int_t n,m;
    Node* root;
    int_t lOff,rOff;
    //ÈÃij¸öÊý½øÈë/Í˳ö
    void modify(int_t x,int_t opt) {
    	if(opt==1) count[x]+=opt;
    #ifdef DEBUG
    	cout<<"exec pos "<<x-count[x]+1<<" "<<opt<<" number "<<x<<endl;
    	cout<<"before exec count = "<<count[x]<<endl;
    #endif
    	root->add(x-count[x]+1,x-count[x]+1,opt);
    	if(opt==-1) count[x]+=opt;
    }
    int main() {
    	scanf("%lld%lld",&n,&m);
    
    	lOff=std::max(n,m)+1;
    	rOff=lOff+n-1;
    	for(int i=1; i<=n; i++) {
    		int_t x;
    		scanf("%lld",&x);
    		x+=std::max(n,m);
    		count[x]++;
    		seq[i]=x;
    	}
    	root=Node::build(1,3*std::max(n,m));
    	for(int_t i=lOff; i<=rOff+1; i++) {
    #ifdef DEBUG
    		cout<<"count "<<i<<" = "<<count[i]<<" cover "<<i-count[i]+1<<" to "<<i<<endl;
    #endif
    		if(count[i])
    			root->add(i-count[i]+1,i,1);
    	}
    	for(int_t i=1; i<=m; i++) {
    		int_t opt,x;
    		scanf("%lld%lld",&opt,&x);
    		if(opt==0) {
    			if(x==1) {
    				if(count[rOff]) {
    					root->add(rOff-count[rOff]+1,rOff,-1);
    				}
    				lOff--,rOff--;
    			} else {
    				lOff++,rOff++;
    				if(count[rOff]) {
    					root->add(rOff-count[rOff]+1,rOff,1);
    				}
    
    			}
    #ifdef DEBUG
    			cout<<"moved to "<<lOff<<" "<<rOff<<endl;
    #endif
    
    		} else {
    			if(seq[opt]<=rOff)
    				root->add(seq[opt]-count[seq[opt]]+1,seq[opt]-count[seq[opt]]+1,-1);
    			count[seq[opt]]-=1;
    			seq[opt]=x+lOff-1;
    			modify(seq[opt],1);
    
    		}
    		auto ret=root->query(lOff,rOff);
    		if(ret.val!=0) {
    			printf("0\n");
    		} else {
    			printf("%lld\n",ret.count);
    		}
    #ifdef DEBUG
    		for(int_t i=1; i<=rOff; i++) cout<<i<<" = "<<root->query(i,i).val<<endl;
    #endif
    
    	}
    	return 0;
    }

     

  • YT2sOJ04 give you a tree

    zzs上辈子就会做的题我现在才会做..

    一个区间能形成一个联通块,当且仅当这个区间内的边有 区间长度-1 条。

    考虑扫描线。

    每个点存下终点编号比它小的边。

    然后从1开始扫,设当前扫到的点为vtx,维护一棵线段树,线段树下标为x的点的值减掉x后表示的是区间[x,vtx]内存在的边数。

    每次扫到先vtx后,首先把vtx的值设置成vtx,然后枚举vtx一条终点编号比vtx小的出边v,显然v可以给左端点区间在[1,v]内的区间贡献一条边(让以这些地方为左端点,vtx为右端点的区间包括的边数多了1)。

    考虑一个左端点什么时候会成为合法的左端点。

    设这个左端点是x,这个左端点在线段树上的值为val,当且仅当$val-x=vtx-x$(根据上面的定义,线段树下标为x的点的值减掉x后表示的是区间[x,vtx]内存在的边数),即$val=vtx$时,这个左端点会成为一个合法的左端点。

    所以只需要维护全局最大值出现的次数即可。

    #include <iostream>
    #include <algorithm>
    #include <cstdio>
    #include <vector>
    #include <inttypes.h>
    #define debug(x) std::cout << #x << " = " << x << std::endl;
    
    typedef int int_t;
    
    using std::cin;
    using std::endl;
    using std::cout;
    const int_t LARGE = 3e5;
    struct Node {
        Node* left, *right;
        int_t begin, end;
        int_t mark;
        int_t count;
        int_t max;
        void maintain() {
            if (begin != end) {
                max = std::max(left->max, right->max);
                count = 0;
                if (left->max == max)
                    count += left->count;
                if (right->max == max)
                    count += right->count;
            }
        }
        void add(int_t x) {
            max += x;
            mark += x;
        }
        void pushDown() {
            if (begin != end) {
                left->add(mark);
                right->add(mark);
                mark = 0;
            }
        }
        Node(int_t begin, int_t end) {
            this->begin = begin;
            this->end = end;
            max = mark = count = 0;
            left = right = nullptr;
        }
        void add(int_t begin, int_t end, int_t val) {
            if (end < this->begin || begin > this->end) {
                return;
            }
            if (this->begin >= begin && this->end <= end) {
                this->add(val);
                return;
            }
            pushDown();
            left->add(begin, end, val);
            right->add(begin, end, val);
            maintain();
        }
        void setPos(int_t pos, int_t val) {
            if (begin == end) {
                this->max = pos;
                this->count = 1;
                return;
            }
            int_t mid = (begin + end) / 2;
            pushDown();
            if (pos <= mid)
                left->setPos(pos, val);
            else
                right->setPos(pos, val);
            maintain();
        }
        static Node* build(int_t begin, int_t end) {
            Node* node = new Node(begin, end);
            if (begin != end) {
                int_t mid = (begin + end) / 2;
                node->left = Node::build(begin, mid);
                node->right = Node::build(mid + 1, end);
                node->maintain();
            }
            return node;
        }
    };
    std::vector<int_t> graph[LARGE + 1];
    int_t n;
    int main() {
        scanf("%d", &n);
        for (int_t i = 1; i <= n - 1; i++) {
            int_t v1, v2;
            scanf("%d%d", &v1, &v2);
            if (v1 < v2)
                std::swap(v1, v2);
            graph[v1].push_back(v2);
        }
        int64_t result = 0;
        Node* root = Node::build(1, n);
        for (int_t i = 1; i <= n; i++) {
            root->setPos(i, i);
            for (int_t to : graph[i]) root->add(1, to, 1);
            result += root->count;
    #ifdef DEBUG
            cout << "got " << root->count << " at vtx " << i << endl;
    #endif
        }
        printf("%lld\n", result);
        return 0;
    }

     

  • LOJ6029 市场

    不会势能分析..

    #include <iostream>
    #include <algorithm>
    #include <cstdio>
    #include <vector>
    #include <inttypes.h>
    #include <cmath>
    #define debug(x) std::cout << #x << " = " << x << std::endl;
    
    using int_t = long long int;
    
    using std::cin;
    using std::endl;
    using std::cout;
    const int_t INF = 0x7fffffff;
    
    const int_t LARGE = 1e5;
    void* allocate();
    struct Node {
    	Node* left = nullptr,*right = nullptr;
    	int begin,end;
    	int_t sum = 0;
    	int minval = INF;
    	int maxval = -INF;
    	//±äÁ¿Î´³õʼ»¯
    	int mark = 0;
    	Node(int begin,int end) {
    		this -> begin = begin;
    		this -> end = end;
    	}
    	void maintain() {
    		if(begin != end) {
    			sum = left -> sum + right -> sum;
    			minval = std::min(left->minval,right -> minval);
    			maxval = std::max(left -> maxval ,right -> maxval);
    		}
    	}
    	void add(int x) {
    		mark += x;
    		minval += x;
    		maxval += x;
    		sum += (int_t)(end - begin + 1) * x;
    #ifdef DEBUG
    		cout<<"adding "<<x<<" at node "<<this->begin<<" "<<this->end<<" "<<sum<<" "<<mark<<" "<<minval<<" "<<maxval<<endl;
    #endif
    	}
    	void pushDown() {
    		if(begin != end) {
    			left -> add(mark);
    			right -> add(mark);
    			mark = 0;
    		}
    	}
    	void add(int begin,int end,int val) {
    		if(end < this -> begin || begin > this -> end) return;
    #ifdef DEBUG
    		cout<<"adding "<<val<<" for "<<begin<<" "<<end<<" at "<<this->begin<<" "<<this->end<<endl;
    #endif
    		if(this -> begin >= begin && this -> end <= end) {
    #ifdef DEBUG
    			cout<<"dir add at "<<this->begin<<" "<<this->end<<" "<<val<<endl;
    #endif
    			this -> add(val);
    			return;
    		}
    		pushDown();
    		left -> add(begin,end,val);
    		right -> add(begin,end,val);
    		maintain();
    	}
    	int_t querySum(int begin,int end) {
    		if(end < this -> begin || begin > this -> end) return 0;
    		if(this -> begin >= begin && this -> end <= end) {
    			return this -> sum;
    		}
    		pushDown();
    		return left -> querySum(begin,end) + right -> querySum(begin,end);
    	}
    	int queryMin(int begin,int end) {
    		if(end < this -> begin || begin > this -> end ) return INF;
    		if(this -> begin >= begin && this -> end <= end) return this -> minval;
    		pushDown();
    		return std::min(left -> queryMin(begin,end),right -> queryMin(begin,end));
    	}
    	void divide(int x) {
    #ifdef DEBUG
    		cout<<"div "<<x<<" at node "<<this->begin<<" "<<this->end<<endl;
    #endif
    		if(int_t(this -> maxval - floor((double)this -> maxval / x)) == int_t(this -> minval - floor((double)this -> minval / x))) {
    			this -> add(-(this -> maxval - floor((double)this -> maxval / x)));
    			return;
    		}
    		pushDown();
    		left -> divide(x);
    		right -> divide(x);
    		maintain();
    	}
    	void divide(int begin,int end,int val) {
    		if(end < this -> begin || begin > this -> end) return;
    		if(this -> begin >= begin && this -> end <= end) {
    			this -> divide(val);
    			return;
    		}
    		pushDown();
    		left -> divide(begin,end,val);
    		right -> divide(begin,end,val);
    		maintain();
    	}
    	static Node* build(int begin,int end,int* val) {
    		Node* node = new (allocate()) Node(begin,end);
    		if(begin != end) {
    			int_t mid = (begin + end) / 2;
    			node -> left = build(begin,mid,val);
    			node -> right = build(mid + 1 ,end ,val);
    			node -> maintain();
    		} else {
    			node -> sum = node -> minval = node -> maxval = val[begin];
    		}
    		return node;
    	}
    };
    void* allocate(){
        const int_t SIZE = 3e5;
        static char buf[SIZE * sizeof(Node)];
        static int_t used = 0;
        return buf + (++used) * sizeof(Node);
    }
    
    int val[LARGE + 1];
    int n,m;
    int main() {
    	scanf("%d%d",&n,&m);
    	for(int_t i = 1; i <= n; i ++) scanf("%d",&val[i]);
    	Node* root = Node::build(1,n,val);
    	for(int_t i = 1; i <= m; i ++) {
    		int opt,left,right,val;
    		scanf("%d%d%d",&opt,&left,&right);
    		left += 1;
    		right += 1;
    		if(opt == 1) {
    			scanf("%d",&val);
    			root -> add(left,right,val);
    		} else if(opt == 2) {
    			scanf("%d",&val);
    			root -> divide(left,right,val);
    		} else if(opt == 3) {
    			printf("%d\n",root->queryMin(left,right));
    		} else if(opt == 4) {
    			printf("%lld\n",root->querySum(left,right));
    		}
    	}
    	return 0;
    }

     

  • SDOI2017 树点涂色

    假设我们可以通过某种东西维护出$f(n)$,表示n号点的树根的路径上所经过的不同颜色的点的个数(初始时$f(n)=depth(n)$),那么询问2,3显然就可以直接回答了。

    怎么维护呢?

    注意到每次操作是给树根到一个点的路径上染上不同的权值.

    这恰好和LCT的access操作很像。

    所以我们用若干棵辅助树来维护若干个若干条链。

    每个辅助树维护一条同色链。

    一次染色操作就相当于一次access操作,即把一条到根的链染成同一种颜色。

    然后考虑如何维护$f(n)$?

    考虑把一个点vtx染成新的颜色会发生什么。

    vtx的一条边从虚边变成了实边:

    这时代表vtx和这条边的另一个端点的颜色一定一样,另一个端点为根的整棵树答案减1.

    因为颜色数少了一种。

    反之,如果一条边从实边变成了虚边。

    这条边的另一个端点所有答案加1,因为这条变得两个端点原来颜色是相同的,但是现在不同了。

    注意别搞混DFS序。

    #include <cmath>
    #include <cstdio>
    #include <iostream>
    #include <string>
    #include <vector>
    using int_t = long long int;
    
    using std::cin;
    using std::cout;
    using std::endl;
    const int_t LARGE = 1e5;
    const int_t INF = 0x7fffffff;
    int_t DFSN[LARGE + 1], size[LARGE + 1], byDFSN[LARGE + 1], parent[LARGE + 1],
        first[LARGE + 1], depth[LARGE + 1];
    std::vector<int_t> graph[LARGE + 1];
    struct SegNode {
        int_t begin, end;
        SegNode *left, *right;
        int_t sum;
        int_t max;
        int_t mark;
        SegNode(int_t begin, int_t end) {
            this->left = this->right = nullptr;
            mark = sum = 0;
            max = -INF;
            this->begin = begin;
            this->end = end;
        }
        void add(int_t x) {
            sum += (end - begin + 1) * x;
            max += x;
            mark += x;
        }
        void pushDown() {
            if (begin != end) {
                if (mark) {
                    left->add(mark);
                    right->add(mark);
                    mark = 0;
                }
            }
        }
        void maintain() {
            if (begin != end) {
                sum = left->sum + right->sum;
                max = std::max(left->max, right->max);
            }
        }
        int_t querySum(int_t begin, int_t end) {
            if (end < this->begin || begin > this->end) return 0;
            if (this->begin >= begin && this->end <= end) {
                return this->sum;
            }
            this->pushDown();
            return left->querySum(begin, end) + right->querySum(begin, end);
        }
        int_t queryMax(int_t begin, int_t end) {
            if (end < this->begin || begin > this->end) return -INF;
            if (this->begin >= begin && this->end <= end) return this->max;
            pushDown();
            return std::max(left->queryMax(begin, end),
                            right->queryMax(begin, end));
        }
        void add(int_t begin, int_t end, int_t val) {
            if (end < this->begin || begin > this->end) return;
            if (this->begin >= begin && this->end <= end) {
                this->add(val);
                return;
            }
            pushDown();
            left->add(begin, end, val);
            right->add(begin, end, val);
            maintain();
        }
        template <class Func>
        static SegNode* build(int_t begin, int_t end, Func f) {
            SegNode* node = new SegNode(begin, end);
            if (begin != end) {
                int_t mid = (begin + end) / 2;
                node->left = build(begin, mid, f);
                node->right = build(mid + 1, end, f);
                node->maintain();
            } else {
                node->max = node->sum = f(begin);
            }
    
    #ifdef DEBUG
            cout << "interval " << begin << " " << end << " sum = " << node->sum
                 << " max = " << node->max << endl;
    #endif
            return node;
        }
    };
    SegNode* root = nullptr;
    struct Node {
        Node* chds[2] = {nullptr, nullptr};
        Node*& left;
        Node*& right;
        Node* parent = nullptr;
        int_t id;
        int_t minid;
        Node(int_t id) : left(chds[0]), right(chds[1]) {
            this->id = this->minid = id;
        }
        int_t chdOf() {
            if (parent == nullptr)
                return -1;
            else if (this == parent->left)
                return 0;
            else if (this == parent->right)
                return 1;
            return -1;
        }
    
        void maintain() {
            minid = id;
            if (left) {
                left->parent = this;
                minid = left->minid;
            }
            if (right) right->parent = this;
        }
        bool isAuxRoot() { return chdOf() == -1; }
        // bool isRoot() { return parent == nullptr; }
    };
    Node* nodes[LARGE + 1];
    void rotate(Node* node) {
        Node* Pr = node->parent;
        Node* Gr = Pr->parent;
        int_t side = node->chdOf();
        int_t sidep = Pr->chdOf();
        Pr->chds[side] = node->chds[side ^ 1];
        node->chds[side ^ 1] = Pr;
        node->parent = Gr;
        Pr->maintain();
        node->maintain();
        if (Gr && sidep != -1) {
            Gr->chds[sidep] = node;
            Gr->maintain();
        }
    }
    void splay(Node* node) {
        while (!(node->isAuxRoot())) {
            if (node->parent->isAuxRoot()) {
                rotate(node);
            } else {
                if (node->chdOf() == node->parent->chdOf()) {
                    rotate(node->parent);
                } else {
                    rotate(node);
                }
                rotate(node);
            }
        }
    }
    
    void access(Node* node) {
        Node* last = nullptr;
        while (node) {
            splay(node);
            //断掉原有链接
            if (node->right) {
                int_t vtx = node->right->minid;
                //重边->轻边 答案+1
                root->add(DFSN[vtx], DFSN[vtx] + size[vtx] - 1, 1);
            }
            node->right = last;
            node->maintain();
            //轻边->重边 答案-1
            if (node->right) {
                int_t vtx = node->right->minid;
                root->add(DFSN[vtx], DFSN[vtx] + size[vtx] - 1, -1);
            }
            last = node;
            node = node->parent;
        }
    }
    struct Data {
        int_t depth;
        int_t vtx;
        bool operator<(const Data& data) const { return depth < data.depth; }
    } seq[20][LARGE * 2 + 2];
    int_t count = 0;
    
    int_t n, m;
    void DFS(int_t vtx, int_t from = -1, int_t depth = 1) {
        parent[vtx] = from;
        size[vtx] = 1;
        seq[0][++count] = Data{depth, vtx};
        if (first[vtx] == 0) first[vtx] = count;
        DFSN[vtx] = ++DFSN[0];
        byDFSN[DFSN[0]] = vtx;
        ::depth[vtx] = depth;
        nodes[vtx] = new Node(vtx);
        for (int_t to : graph[vtx]) {
            if (to == from) continue;
            DFS(to, vtx, depth + 1);
            size[vtx] += size[to];
            seq[0][++count] = Data{depth, vtx};
            nodes[to]->parent = nodes[vtx];
        }
    }
    int_t getLCA(int_t v1, int_t v2) {
        v1 = first[v1];
        v2 = first[v2];
        if (v1 > v2) std::swap(v1, v2);
        int_t len = log2(v2 - v1 + 1);
        return std::min(seq[len][v1], seq[len][v2 - (1 << len) + 1]).vtx;
    }
    
    int main() {
        scanf("%lld%lld", &n, &m);
        for (int_t i = 1; i <= n - 1; i++) {
            int from, to;
            scanf("%d%d", &from, &to);
            graph[from].push_back(to);
            graph[to].push_back(from);
        }
        DFS(1);
    #ifdef DEBUG
        for (int_t i = 1; i <= n; i++) {
            cout << "by dfsn " << i << " = " << byDFSN[i] << endl;
        }
    #endif
        for (int_t i = 1; (1 << i) <= count; i++) {
            for (int_t j = 1; j + (1 << i) - 1 <= count; j++) {
                seq[i][j] = std::min(seq[i - 1][j], seq[i - 1][j + (1 << (i - 1))]);
            }
        }
        root =
            SegNode::build(1, n, [](int_t x) -> int_t { return depth[byDFSN[x]]; });
        for (int_t i = 1; i <= m; i++) {
            int opt, arg1, arg2;
            scanf("%d%d", &opt, &arg1);
            if (opt == 1) {
                access(nodes[arg1]);
            } else if (opt == 2) {
                scanf("%d", &arg2);
                int_t lca = DFSN[getLCA(arg1, arg2)];
                arg1 = DFSN[arg1];
                arg2 = DFSN[arg2];
                printf("%lld\n", root->querySum(arg1, arg1) +
                                     root->querySum(arg2, arg2) -
                                     2 * root->querySum(lca, lca) + 1);
            } else if (opt == 3) {
    
                printf("%lld\n", root->queryMax(DFSN[arg1], DFSN[arg1] + size[arg1] - 1));
            }
        }
        return 0;
    }

     

  • 洛谷4097 Segment

    李超线段树板子。

    线段树以x坐标为下标。

    每个点维护一下这个点上的所谓最优势线段。

    即这个点上的线段,从天上看能看到的长度最长的线段。

    查询经过一个x坐标的最高的线段,即所有覆盖这个x的区间上的最高的最优势线段。

    如何插入一条线段呢?

    把这个线段对应的横坐标区间拆成线段树上的区间,然后考虑怎么处理。

    假设当前要在一个点vtx上处理一条线段seg。

    假设vtx的最优势线段完全被seg所覆盖,就把vtx的最优势线段换成seg。

    如果vtx的最优势线段完全覆盖seg,直接返回。

    有交的情况下,求一下交点,然后判断一下seg和vtx本来的线段哪条比较长,把长的那条视为vtx的最优势线段。

    短的那个呢?

    把短的那个视为新的线段,然后从短的那个线段所在的区间接着往下处理。

    单次插入复杂度$O(log^2n)$

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>
    using int_t = long long int;
    using real_t = double;
    
    using std::cin;
    using std::cout;
    using std::endl;
    const int_t INF = 0x7fffffff;
    
    struct Line {
        real_t k, b;
        int_t x0, x1;
        int_t id;
        Line(int_t x0 = 0, int_t y0 = 0, int_t x1 = 1, int_t y1 = 0,
             int_t id = -1) {
            this->x0 = x0;
            this->x1 = x1;
            k = (real_t)(y1 - y0) / (x1 - x0);
            b = y1 - k * x1;
            this->id = id;
        }
        real_t f(int_t x) const { return (real_t)k * x + b; }
    };
    
    // std::vector<Line> lines;
    real_t cross(const Line& l1, const Line& l2) {
        return (real_t)(-l1.b + l2.b) / (l1.k - l2.k);
    }
    std::ostream& operator<<(std::ostream& os, const Line& line) {
        os << "Line{k=" << line.k << ",b=" << line.b << ",x0=" << line.x0
           << ",x1=" << line.x1 << ",id=" << line.id << "}";
        return os;
    }
    struct Node {
        int_t begin, end;
        Node *left, *right;
        // int_t id = -1;
        Line line = Line(0, -INF, INF, -INF, -1);
        Node(int_t begin, int_t end) {
            this->begin = begin;
            this->end = end;
            left = right = nullptr;
        }
        Line query(int_t pos) {
    #ifdef DEBUG
            cout << "querying at " << begin << "," << end << endl;
    #endif
            if (begin == end) {
    #ifdef DEBUG
                cout << "direct ret " << line << endl;
    #endif
                return line;
            }
            Line next;
            int_t mid = (begin + end) / 2;
            if (pos <= mid)
                next = left->query(pos);
            else
                next = right->query(pos);
    #ifdef DEBUG
            cout << "at " << begin << "," << end << endl;
            cout << "chd ret " << next << endl;
            cout << "thiz " << line << endl;
    #endif
            if (next.id == -1) return line;
            if (line.id == -1) return next;
            if (next.f(pos) == line.f(pos))
                return std::min(next, line,
                                [](const Line& a, const Line& b) -> bool {
                                    return a.id < b.id;
                                });
            return std::max(line, next, [=](const Line& a, const Line& b) -> bool {
                return a.f(pos) < b.f(pos);
            });
        }
        void applyTo(Line next) {
    #ifdef DEBUG
            cout << "applying " << next << " at " << begin << "," << end << endl;
    #endif
            if (this->line.id == -1) {
                this->line = next;
    #ifdef DEBUG
                cout << "direct set " << next << endl;
    #endif
                return;
            }
            if (line.f(begin) > next.f(begin) && line.f(end) > next.f(end)) {
    #ifdef DEBUG
                cout << " all covered ret" << endl;
    #endif
                return;
            } else if (line.f(begin) < next.f(begin) && line.f(end) < next.f(end)) {
                this->line = next;
    #ifdef DEBUG
                cout << "cover all ret " << endl;
    #endif
                return;
            }
            real_t cross = ::cross(line, next);
            int_t mid = (begin + end) / 2;
            if (cross == mid) {
                if (line.f(begin) > next.f(begin)) {
                    if (right) right->applyTo(next);
                } else {
                    if (left) left->applyTo(line);
                }
            } else if (cross < mid) {
                if (next.f(end) > line.f(end)) {
                    std::swap(next, line);
                }
                if (left) left->applyTo(next);
            } else if (cross > mid) {
                if (next.f(begin) > line.f(begin)) {
                    std::swap(next, line);
                }
                if (right) right->applyTo(next);
            }
        }
        void apply(int_t begin, int_t end, const Line& line) {
            if (end < this->begin || begin > this->end) return;
            if (this->begin >= begin && this->end <= end) {
                this->applyTo(line);
                return;
            }
            left->apply(begin, end, line);
            right->apply(begin, end, line);
        }
        static Node* build(int_t begin, int_t end) {
            Node* node = new Node(begin, end);
            if (begin != end) {
                int_t mid = (begin + end) / 2;
                node->left = build(begin, mid);
                node->right = build(mid + 1, end);
            }
            return node;
        }
    };
    
    int main() {
        const int_t mod = 39989;
        Node* root = Node::build(1, mod);
        int_t lastans = 0;
        const auto f = [&](int_t x, int_t mod = 39989) -> int_t {
            return ((x + lastans - 1) % mod + 1);
        };
        int n;
        scanf("%d", &n);
        int_t id = 0;
        for (int i = 1; i <= n; i++) {
            int opt;
            scanf("%d", &opt);
            if (opt == 1) {
                int x0, x1, y0, y1;
                scanf("%d%d%d%d", &x0, &y0, &x1, &y1);
                x0 = f(x0);
                x1 = f(x1);
                y0 = f(y0, 1e9);
                y1 = f(y1, 1e9);
                if (x0 > x1) {
                    std::swap(x0, x1);
                    std::swap(y0, y1);
                }
    #ifdef TEST
                cout << "apply " << x0 << " " << y0 << " " << x1 << " " << y1
                     << endl;
    #endif
                root->apply(x0, x1, Line{x0, y0, x1, y1, id++});
            } else if (opt == 0) {
                int x;
                scanf("%d", &x);
                x = f(x);
                lastans = root->query(x).id + 1;
    #ifdef TEST
                cout << "query " << x << " = " << lastans << endl;
    #endif
                printf("%lld\n", lastans);
            }
        }
        // int_t id = 0;
        // while (true) {
        //     std::string opt;
        //     cin >> opt;
        //     if (opt == "insert") {
        //         int_t x0, y0, x1, y1;
        //         cin >> x0 >> y0 >> x1 >> y1;
        //         // lines.push_back();
        //         root->apply(x0, x1, Line{x0, y0, x1, y1, id++});
        //     } else if (opt == "query") {
        //         int_t x;
        //         cin >> x;
        //         cout << root->query(x) << endl;
        //     }
        // }
        return 0;
    }

    注意纵坐标的模数和横坐标不同。

  • 洛谷4588 数学计算

    线段树板子。

    #include <algorithm>
    #include <iostream>
    
    using int_t = long long int;
    
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 1e6;
    int_t upper2n(int_t x) {
        int_t result = 1;
        while (result < x) result *= 2;
        return result;
    }
    
    int_t arr[LARGE + 1];
    int_t base, q, mod;
    void build() {
        for (int_t i = base * 2; i >= 1; i--) arr[i] = 1;
    }
    int_t query(int_t left, int_t right) {
        left = left - 1 + base;
        right = right + 1 + base;
        int_t result = 1;
        while ((left xor right) != 1) {
            if ((left & 1) == 0) result = result * arr[left ^ 1] % mod;
            if (right & 1) result = result * arr[right ^ 1] % mod;
            left >>= 1;
            right >>= 1;
        }
        return result;
    }
    void modify(int_t pos, int_t x) {
        pos += base;
        arr[pos] = x;
        pos >>= 1;
        while (pos) {
            arr[pos] = arr[pos * 2] * arr[pos * 2 + 1] % mod;
            pos >>= 1;
        }
    }
    void process() {
        cin >> q >> mod;
        base = upper2n(q + 2);
        build();
        for (int_t i = 1; i <= q; i++) {
            int_t opt, val;
            cin >> opt >> val;
            if (opt == 1) {
                modify(i, val);
            } else {
                modify(val, 1);
            }
            cout << arr[1] << endl;
        }
    }
    int main() {
        int_t T;
        cin >> T;
        while (T--) process();
        return 0;
    }

     

  • 九省联考2018 IIIDX

    首先可以根据题意构建出一个森林,其中i号节点是$\frac i k$的子节点,然后要求把给定的n个权值分配到这n个点上,满足每个点的权值不超过其所有子节点的权值。

    考虑一种做法,首先建树,然后求出以每个节点为根的子树大小。

    然后把给定的权值从大到小排序,然后去重并统计一下每一个元素出现的次数,可以得到一个数组$f_i$,i表示严格第i大的权值出现的次数.

    然后把f数组求一个前缀和$F_i$,表示权值严格前i大的元素的出现次数。

    因为$F_i$是一个正序列的前缀和,所以$F_i$一定是递增的.

     

    然后由于题目要求字典序最大,所以我们按照节点编号从小到大的顺序给每一个节点分配权值。

    现在到了i号节点,设$size_i$是以i号节点为根的子树大小,那么这棵子树一共需要$size_i$个数,并且满足其他所有的数都不小于树根的数值。

    所以在$F_i$中倒着找一个位置$p$,满足$F_p \geq size_i$,并且使得$p$尽量的小(因为我们要使得编号尽量的大)。

    即在$F_i$找一个位置$p$,满足min($F_p$,$F_{p+1}$….$F_{n}$)不小于$size_i$,并且$p$尽可能的小.

    这样的意思就是说,我们找到了一个位置,前面有足够i这棵子树用的元素,同时这个值还尽可能的大。

    然后,我们把$p$位置的值留给$i$号点,然后把i到最后的位置的数量全部减掉$size_i$,因为我们给以$i$为根的子树分配了$size_i$个节点。

    但是有一点要注意。

    假设2是1的子节点,那么我们处理完1要处理2之前,需要先撤销掉在处理1时扣掉的$size_1-1$个值,因为那些位置是给1这整颗子树留的,现在我们在处理1的一颗子树,所以我们应该把在处理1时为了1的子节点而多扣掉的$size_1-1$个数加回来。

    至于为什么要在1处扣掉$size_1$而不是1?

    因为我们处理完1之后,接下来不一定要处理1的子节点。

    因为我们是按照编号序贪心,而不是按照DFS序贪心。

    如果是按照DFS序贪心的话,我们直接给1号点扣掉1,然后去处理1的子树即可,但是这样做是错误的,我们这样做只能保证按照DFS序排字典序最大。

    假设2不是1的子节点,而是兄弟节点,或者是在另一颗树里。

    所以我们要在1处减掉整颗子树的size,这样在处理2的时候,就不会影响1的子树内的结果了(因为1的子树需要的位置已经在1分配好了)

    这也解释了为什么要在$F_i$中倒着找的原因,就是尽可能的不去影响前面的子树。

    区间减和最长的最小值大于等于某个数的后缀可以用线段树维护。

    复杂度$O(nlogn)$.

     

  • TJOI2016 排序

    二分一下答案。

    然后把原序列中大于等于当前二分值的数设为1,其他设为0

    然后对这个01序列执行操作。

    如果结果为1,说明真正的答案大于等于当前二分值,反之小于当前二分值。

    #include <iostream>
    #include <algorithm>
    
    using int_t = long long int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 30000;
    
    struct Node
    {
        int_t begin;
        int_t end;
        int_t value = 0;
        int_t mark;
        Node *left = nullptr;
        Node *right = nullptr;
        bool hasMark = false;
        Node(int_t begin, int_t end)
        {
            this->begin = begin;
            this->end = end;
        }
        ~Node()
        {
            if (begin != end)
            {
                delete left;
                delete right;
            }
        }
        void set(int_t x)
        {
            mark = x;
            hasMark = true;
            value = (end - begin + 1) * x;
        }
        void pushDown()
        {
            if (hasMark && begin != end)
            {
                hasMark = false;
                left->set(mark);
                right->set(mark);
            }
        }
        void maintain()
        {
            if (begin != end)
            {
                value = left->value + right->value;
            }
        }
        int_t query(int_t begin, int_t end)
        {
            if (end < this->begin || begin > this->end)
                return 0;
            if (this->begin >= begin && this->end <= end)
                return this->value;
            this->pushDown();
            return left->query(begin, end) + right->query(begin, end);
        }
        void set(int_t begin, int_t end, int_t value)
        {
            if (end < this->begin || begin > this->end)
                return;
            if (this->begin >= begin && this->end <= end)
            {
                this->set(value);
                return;
            }
            this->pushDown();
            left->set(begin, end, value);
            right->set(begin, end, value);
            this->maintain();
        }
        static Node *build(int_t begin, int_t end)
        {
            int_t mid = (begin + end) / 2;
            Node *node = new Node(begin, end);
            if (begin != end)
            {
                node->left = build(begin, mid);
                node->right = build(mid + 1, end);
            }
            return node;
        }
    };
    int_t n, m;
    int_t seq[LARGE + 1];
    int_t q;
    struct Operation
    {
        int_t opt;
        int_t left;
        int_t right;
    } opts[LARGE + 1];
    
    int_t check(int_t x)
    {
        Node *root = Node::build(1, n);
        for (int_t i = 1; i <= n; i++)
        {
            if (seq[i] >= x)
            {
                root->set(i, i, 1);
            }
        }
        for (int_t i = 1; i <= m; i++)
        {
            auto &curr = opts[i];
            int_t sum = root->query(curr.left, curr.right);
            if (curr.opt == 0)
            {
                root->set(curr.left, curr.right - sum, 0);
                root->set(curr.right - sum + 1, curr.right, 1);
            }
            else
            {
                root->set(curr.left, curr.left + sum - 1, 1);
                root->set(curr.left + sum, curr.right, 0);
            }
        }
        int_t ret = root->query(q, q);
        delete root;
        return ret;
    }
    
    int main()
    {
        cin >> n >> m;
        for (int_t i = 1; i <= n; i++)
            cin >> seq[i];
        for (int_t i = 1; i <= m; i++)
            cin >> opts[i].opt >> opts[i].left >> opts[i].right;
        cin >> q;
        int_t left = 1;
        int_t right = n;
        int_t result = 0;
        while (left + 1 < right)
        {
            int_t mid = (left + right) / 2;
            if (check(mid))
            {
                result = std::max(mid, result);
                left = mid + 1;
            }
            else
            {
                right = mid - 1;
            }
        }
        for (int_t i = left; i <= right; i++)
        {
            if (check(i))
                result = std::max(result, i);
        }
        cout << result << endl;
        return 0;
    }

     

  • SDOI2017 相关分析

    $$ \sum_{l\le i\le r}{\left( x_i-\overline{x} \right) \left( y_i-\overline{y} \right)}=\sum_{l\le i\le r}{x_iy_i}-\overline{x}\sum_{l\le i\le r}{y_i}-\overline{y}\sum_{l\le i\le r}{x_i}+\left( r-l+1 \right) \overline{x}\overline{y} \\ \sum_{l\le i\le r}{\left( x_i-\overline{x} \right) =}\sum_{l\le i\le r}{x_{i}^{2}}-2\overline{x}\sum_{l\le i\le r}{x_i}+\left( r-l+1 \right) \overline{x}^2 \\ \text{维护}\sum_{l\le i\le r}{x_i^2},\sum_{l\le i\le r}{x_iy_i},\sum_{l\le i\le r}{x_i},\sum_{l\le i\le r}{y_i} \\ \text{区间加:} \\ \sum_{l\le i\le r}{\left( x_i+S \right) ^2}=\sum_{l\le i\le r}{x_i^2}+2S\sum_{l\le i\le r}{x_i}+\left( r-l+1 \right) S^2 \\ \sum_{l\le i\le r}{\left( x_i+S \right) \left( y_i+T \right) =\sum_{l\le i\le r}{x_iy_i+S\sum_{l\le i\le r}{y_i}+T\sum_{l\le i\le r}{x_i}+\left( r-l+1 \right) ST}} \\ \text{区间覆盖等差数列} \\ \sum_{l\le i\le r}{\left( S+i \right) ^2}=\left( r-l+1 \right) S^2+2S\sum_{l\le i\le r}{i}+\sum_{l\le i\le r}{i^2} \\ \sum_{l\le i\le r}{\left( S+i \right) \left( T+i \right)}=\left( r-l+1 \right) ST+\left( S+T \right) \sum_{l\le i\le r}{i}+\sum_{l\le i\le r}{i^2} \\ $$

    一定要注意覆盖标记和区间加标记的下传顺序!!!!一定一定!!!

    #include <iostream>
    #include <algorithm>
    #include <iomanip>
    using int_t = long long int;
    using real_t = long double;
    
    using std::cin;
    using std::cout;
    using std::endl;
    
    const int_t LARGE = 100000;
    
    real_t S1(real_t left, real_t right)
    {
        const auto Sp1 = [](real_t x) -> real_t {
            if (x == 0)
                return 0;
            return x * (x + 1) / 2;
        };
        return Sp1(right) - Sp1(left - 1);
    }
    real_t S2(real_t left, real_t right)
    {
        const auto Sp2 = [](real_t x) -> real_t {
            if (x == 0)
                return 0;
            return x * (2 * x + 1) * (x + 1) / 6;
        };
        return Sp2(right) - Sp2(left - 1);
    }
    
    struct State
    {
        real_t x;
        real_t y;
        real_t xp2;
        real_t xy;
    
        State(real_t x = 0, real_t y = 0)
        {
            this->x = x;
            this->y = y;
            this->xy = x * y;
            this->xp2 = x * x;
        }
        State operator+(const State &another)
        {
            State result = *this;
            result.x += another.x;
            result.y += another.y;
            result.xp2 += another.xp2;
            result.xy += another.xy;
            return result;
        }
    };
    std::ostream &operator<<(std::ostream &os, const State &state);
    struct Node
    {
        int_t begin;
        int_t end;
        Node *left = nullptr;
        Node *right = nullptr;
        struct
        {
            real_t S = 0;
            real_t T = 0;
            void clear()
            {
                S = T = 0;
            }
        } addMark;
        struct
        {
            real_t S;
            real_t T;
            bool has = false;
        } setMark;
        State state;
        Node(int_t begin, int_t end)
        {
            this->begin = begin;
            this->end = end;
        }
    
        void add(real_t S, real_t T)
        {
            addMark.S += S;
            addMark.T += T;
            //注意要先修改xp2和xy,因为这两个的修改依赖于修改前x y的值
            state.xp2 = state.xp2 + 2 * S * state.x + (end - begin + 1) * S * S;
            state.xy = state.xy + S * state.y + T * state.x + (end - begin + 1) * S * T;
            state.x = state.x + (end - begin + 1) * S;
            state.y = state.y + (end - begin + 1) * T;
        }
        void set(real_t S, real_t T)
        {
            setMark.has = true;
            setMark.S = S;
            setMark.T = T;
            addMark.clear();
            state.x = (end - begin + 1) * S + S1(begin, end);
            state.y = (end - begin + 1) * T + S1(begin, end);
            state.xp2 = (end - begin + 1) * S * S + 2 * S * S1(begin, end) + S2(begin, end);
            state.xy = (end - begin + 1) * S * T + (S + T) * S1(begin, end) + S2(begin, end);
        }
        void pushDown()
        {
            if (begin != end)
            {
                //要先下传覆盖标记,再下传加标记!!
                if (setMark.has)
                {
                    setMark.has = false;
                    left->set(setMark.S, setMark.T);
                    right->set(setMark.S, setMark.T);
                }
                if (true)
                {
                    left->add(addMark.S, addMark.T);
                    right->add(addMark.S, addMark.T);
                    addMark.clear();
                }
            }
        }
        void maintain()
        {
            if (begin != end)
            {
                this->state = left->state + right->state;
            }
        }
        State query(int_t begin, int_t end)
        {
            if (end < this->begin || begin > this->end)
                return State();
            if (this->begin >= begin && this->end <= end)
            {
                return this->state;
            }
            this->pushDown();
            return left->query(begin, end) + right->query(begin, end);
        }
        void add(int_t begin, int_t end, real_t S, real_t T)
        {
            if (end < this->begin || begin > this->end)
                return;
            if (this->begin >= begin && this->end <= end)
            {
                // auto prev = state;
                this->add(S, T);
                // cout << "added at " << this->begin << "," << this->end << " with result " << state << " from " << prev << endl;
                return;
            }
            this->pushDown();
            left->add(begin, end, S, T);
            right->add(begin, end, S, T);
            this->maintain();
        }
        void set(int_t begin, int_t end, real_t S, real_t T)
        {
            if (end < this->begin || begin > this->end)
                return;
            if (this->begin >= begin && this->end <= end)
            {
                this->set(S, T);
                return;
            }
            this->pushDown();
            left->set(begin, end, S, T);
            right->set(begin, end, S, T);
            this->maintain();
        }
    
        static Node *build(int_t begin, int_t end, real_t *x, real_t *y)
        {
            Node *node = new Node(begin, end);
            if (begin != end)
            {
                int_t mid = (begin + end) / 2;
                node->left = build(begin, mid, x, y);
                node->right = build(mid + 1, end, x, y);
                node->maintain();
            }
            else
            {
                node->state = State(x[begin], y[begin]);
            }
            return node;
        }
    };
    
    int main()
    {
        std::cout.setf(std::ios::fixed);
        cout << std::setprecision(20);
        int_t n, m;
        static real_t x[LARGE + 1];
        static real_t y[LARGE + 1];
        cin >> n >> m;
        for (int_t i = 1; i <= n; i++)
            cin >> x[i];
        for (int_t i = 1; i <= n; i++)
            cin >> y[i];
        Node *root = Node::build(1, n, x, y);
        for (int_t i = 1; i <= m; i++)
        {
            int_t opt, left, right;
            real_t S, T;
            cin >> opt >> left >> right;
            if (opt == 1)
            {
                auto query = root->query(left, right);
                // cout << query << endl;
                real_t length = right - left + 1;
                real_t xAverage = query.x / length;
                real_t yAverage = query.y / length;
                cout << (query.xy - xAverage * query.y - yAverage * query.x + length * xAverage * yAverage) / (query.xp2 - 2 * xAverage * query.x + length * xAverage * xAverage) << endl;
                continue;
            }
            cin >> S >> T;
            if (opt == 2)
            {
                root->add(left, right, S, T);
                // for (int_t i = 1; i <= n; i++)
                // {
                //     cout << "pos " << i << " : " << root->query(i, i) << endl;
                // }
            }
            else
            {
                root->set(left, right, S, T);
            }
        }
        return 0;
    }
    
    std::ostream &operator<<(std::ostream &os, const State &state)
    {
        os << "State { x = " << state.x << " , y = " << state.y << " , x^2 = " << state.xp2 << " , xy = " << state.xy << " }";
        return os;
    }