简体   繁体   中英

Product of all node values on the path between two nodes of a tree

I was learning Mo's algorithm and came across the following problem. The program reads n, the number of nodes in a tree, then n-1 pairs (u, v), each denoting an edge between node u and node v, and finally the n node values.

Then q queries follow. Each query gives two nodes k and l of the tree, and we have to find the product of the values of all nodes on the path between k and l (including k and l themselves).

I want to use Mo's algorithm on trees, as described in https://codeforces.com/blog/entry/43230

But I am unable to write the code. Can anybody help me with this?

The basic code for that would be:

int n, q;
// Arc-list adjacency. Arc id 0 is the "no arc" sentinel in hd[], and the
// undirected edge i (1 <= i < n) is stored as arcs 2*i and 2*i+1. Arc ids
// therefore reach 2*n-1, so the arc arrays need twice the node bound N.
int nxt[ N << 1 ], to[ N << 1 ], hd[ N ];

// One query: the path endpoints u and v, plus its input position id so answers
// can be written back in input order after the queries are re-sorted.
// NOTE(review): que[] has N slots, so this assumes q <= N -- TODO confirm.
struct Que{
    int u, v, id;
} que[ N ];

void init() {
    // read how many nodes and how many queries
    cin >> n >> q;
    // read the edge of tree
    for ( int i = 1 ; i < n ; ++ i ) {
        int u, v; cin >> u >> v;
        // save the tree using adjacency list
        nxt[ i << 1 | 0 ] = hd[ u ];
        to[ i << 1 | 0 ] = v;
        hd[ u ] = i << 1 | 0;

        nxt[ i << 1 | 1 ] = hd[ v ];
        to[ i << 1 | 1 ] = u;
        hd[ v ] = i << 1 | 1;
    }

    for ( int i = 0 ; i < q ; ++ i ) {
        // read queries
        cin >> que[ i ].u >> que[ i ].v;
        que[ i ].id = i;
    }
}

// dfn[u]: DFS preorder index of node u (dfn_ is the running counter).
// block_id[u]: block assigned to u by dfs(); block_ counts the blocks created.
int dfn[ N ], dfn_, block_id[ N ], block_;

// Scratch stack of visited nodes that dfs() has not yet assigned to a block.
int stk[ N ], stk_;

// Numbers nodes in DFS preorder and partitions the tree into blocks for Mo's
// algorithm on trees. After each child subtree returns, if at least SQRT_N
// nodes have piled up on stk[] since u was entered, they are popped into a new
// block. u itself is pushed last, so it is grouped together with nodes handled
// by its ancestors. Nodes still on the stack after the root returns are
// assigned by the caller.
// SQRT_N is defined elsewhere (presumably about sqrt(n) -- TODO confirm).
// f is the parent of u (the edge back to f is skipped).
// Recursion depth equals the tree depth.
void dfs( int u, int f ) {
    dfn[ u ] = dfn_++;

    // stack height on entry: only entries above this mark come from u's subtrees
    int saved_rbp = stk_;

    for ( int v_ = hd[ u ] ; v_ ; v_ = nxt[ v_ ] ) {
        if ( to[ v_ ] == f ) continue;   // don't walk back to the parent
        dfs( to[ v_ ], u );
        // too few pending nodes so far: keep accumulating across children
        if ( stk_ - saved_rbp < SQRT_N ) continue;
        // pop everything above the mark into a fresh block
        for ( ++ block_ ; stk_ != saved_rbp ; )
             block_id[ stk[ -- stk_ ] ] = block_;
    }

    stk[ stk_ ++ ] = u;
}

// inPath[u] is true while node u belongs to the set maintained by Mo's algorithm.
bool inPath[ N ];

// Toggles node u in or out of the maintained set (symmetric difference with {u}).
// The two branches are placeholders for the problem-specific answer update.
void SymmetricDifference( int u ) {
    if ( inPath[ u ] ) {
        // u leaves the set: remove its contribution from the answer
        // (u can equivalently be read as the edge from u to its parent)
    } else {
        // u joins the set: add its contribution to the answer
    }
    inPath[ u ] ^= 1;
}
// Moves one path endpoint from origin_u to u. It toggles every node on the path
// origin_u -> u except lca(origin_u, u), so the maintained set changes by
// symmetric difference with that path. Nodes shared with the other endpoint's
// path cancel out. On return origin_u == u.
// lca() and parent_of[] are not defined in this snippet.
void traverse( int& origin_u, int u ) {
    // climb from the old endpoint up to (but not including) the LCA; origin_u ends at g
    for ( int g = lca( origin_u, u ) ; origin_u != g ; origin_u = parent_of[ origin_u ] )
        SymmetricDifference( origin_u );
    // climb from the new endpoint up to the same LCA (origin_u now equals g)
    for ( int v = u ; v != origin_u ; v = parent_of[ v ] )
        SymmetricDifference( v );
    origin_u = u;
}

void solve() {
    // construct blocks using dfs
    dfs( 1, 1 );
    while ( stk_ ) block_id[ stk[ -- stk_ ] ] = block_;
    // re-order our queries
    sort( que, que + q, [] ( const Que& x, const Que& y ) {
        return tie( block_id[ x.u ], dfn[ x.v ] ) < tie( block_id[ y.u ], dfn[ y.v ] );
    } );
    // apply mo's algorithm on tree
    int U = 1, V = 1;
    for ( int i = 0 ; i < q ; ++ i ) {
        pass( U, que[ i ].u );
        pass( V, que[ i ].v );
        // we could our answer of que[ i ].id
    }
}

This problem is a slight modification of the one in the blog you shared.

Problem tags: Mo's algorithm, trees, LCA, binary lifting, sieve, precomputation, prime factors.

Precomputation: run a Sieve of Eratosthenes that stores the highest prime factor of every value allowed by the input constraints. Using this table, factorize each node value and store its prime factors and their exponents in another matrix.

Observation: under these constraints each element has very few distinct prime factors. An element up to 10^6 has at most 7 distinct prime factors, because 2·3·5·7·11·13·17 = 510510 ≤ 10^6, while multiplying by the next prime, 19, exceeds 10^6.

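Modifying the Mo's algorithm from the blog: in the check() method we maintain the current exponent of every prime in the window (an array indexed by prime, or a map).

The result is the total number of divisors, i.e. the product of (exponent + 1) over all primes. When a node is added or removed, we iterate over its prime factors. For each prime p we:
- divide the result by (old exponent of p + 1), by multiplying with a precomputed modular inverse;
- update the exponent of p;
- multiply the result by (new exponent of p + 1).

This costs at most 7 steps per addition or removal.

Note that this computes the number of divisors of the path product (mod 10^9+7). To get the product itself, keep a running product of node values instead: multiply by a node's value when it is added, and by the value's modular inverse when it is removed. This works only if no value is divisible by the modulus.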

Complexity: $O\big(T \cdot (N + Q)\sqrt{N} \cdot F\big)$, where $F$ is the cost of one call to check(). Here $F$ is at most 7, the maximum number of distinct prime factors.

  • T - number of test cases in the input file.
  • N - number of nodes in the tree.
  • Q - number of queries.

Below is a Java implementation of the above approach; computePrimePowers() and check() are the methods of interest.

import java.util.*;
import java.io.*;

/**
 * Mo's algorithm on trees. For every query (u, v) it prints the number of
 * divisors (mod 1e9+7) of the product of the node values on the path from u
 * to v.
 *
 * <p>Input, read from standard input: the number of test cases. Each test case
 * then gives n, the n-1 edges, the n node values, the query count q and the q
 * query pairs. Node values are assumed to be at most {@code MAX_VAL}
 * (TODO confirm against the problem constraints).
 */
public class Main {

    /** Modulus used for every modular product. */
    private static final int MOD = 1_000_000_007;
    /** Largest node value covered by the sieve and the inverse table. */
    private static final int MAX_VAL = 1_000_000;
    /** Number of binary-lifting levels (supports depths below 2^19). */
    private static final int LOG = 19;

    /** Mo block width over Euler-tour positions (set per test case). */
    static int BLOCK_SIZE;
    /** ar[i]: value of node i (1-based; ar[0] is unused and stays 0). */
    static int[] ar;
    /** Adjacency lists of the tree, indexed by node. */
    static ArrayList<Integer>[] graph;
    /** Output of all test cases, printed once at the end. */
    static StringBuffer sb = new StringBuffer();

    static boolean notPrime[] = new boolean[MAX_VAL + 1];
    /** hpf[v]: largest prime factor of v (0 for v < 2). */
    static int hpf[] = new int[MAX_VAL + 1];

    /**
     * Sieve of Eratosthenes that also records the largest prime factor of every
     * value up to MAX_VAL. Primes are processed in increasing order, so the last
     * prime to mark j is j's largest prime factor.
     */
    static void seive() {
        notPrime[0] = true;
        notPrime[1] = true;
        for (int i = 2; i <= MAX_VAL; i++) {
            if (notPrime[i]) {
                continue;
            }
            hpf[i] = i;
            for (int j = 2 * i; j <= MAX_VAL; j += i) {
                notPrime[j] = true;
                hpf[j] = i;
            }
        }
    }

    /** modI[i]: modular inverse of i modulo MOD, for 1 <= i <= MAX_VAL. */
    static long modI[] = new long[MAX_VAL + 1];

    /**
     * Fills modI[] with the linear-time recurrence
     * inv(i) = -(MOD / i) * inv(MOD % i) (mod MOD). Since MOD % i < i, the
     * right-hand side is already known. This yields the same values as Fermat
     * exponentiation, in O(MAX_VAL) instead of O(MAX_VAL log MOD).
     */
    static void computeModI() {
        modI[1] = 1;
        for (int i = 2; i <= MAX_VAL; i++) {
            modI[i] = (MOD - (MOD / i) * modI[MOD % i] % MOD) % MOD;
        }
    }

    /**
     * Computes x^y mod MOD by iterative square-and-multiply.
     *
     * @param x base (expected to be non-negative)
     * @param y exponent (non-negative)
     * @return x^y mod MOD
     */
    static long pow(long x, long y) {
        long base = x % MOD;
        long result = 1;
        while (y > 0) {
            if ((y & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
            y >>= 1;
        }
        return result;
    }

    public static void main(String[] args) throws Exception {
        Reader s = new Reader();
        int test = s.nextInt();
        seive();
        computeModI();
        for (int ii = 0; ii < test; ii++) {
            int n = s.nextInt();
            lcaTable = new int[LOG][n + 1];
            graph = newGraph(n + 1);
            arrPrimes = new int[n + 1][7][2];
            primeCnt = new int[MAX_VAL + 1];
            visited = new int[n + 1];
            ar = new int[n + 1];
            // Euler-tour arrays are sized per test: every node occupies two tour positions.
            l = new int[n + 1];
            r = new int[n + 1];
            ID = new int[2 * n + 1];

            for (int i = 1; i < n; i++) {
                int u = s.nextInt(), v = s.nextInt();
                graph[u].add(v);
                graph[v].add(u);
            }
            for (int i = 1; i <= n; i++) {
                ar[i] = s.nextInt();
            }

            computePrimePowers();

            int q = s.nextInt();
            LVL = new int[n + 1];

            dfsTime = 0;
            dfs(1, -1);

            BLOCK_SIZE = (int) Math.sqrt(dfsTime);
            int Q[][] = new int[q][4];
            for (int i = 0; i < q; i++) {
                int u = s.nextInt(), v = s.nextInt();
                int anc = lca(u, v);
                if (l[u] > l[v]) {
                    int t = u;
                    u = v;
                    v = t;
                }
                Q[i][0] = anc;
                // If u is an ancestor of v, the tour range [l[u], l[v]] covers the whole
                // path. Otherwise use [r[u], l[v]] and add the LCA separately when answering.
                Q[i][1] = (anc == u) ? l[u] : r[u];
                Q[i][2] = l[v];
                Q[i][3] = i;
            }
            // Mo order: by block of the left end, then by right end.
            Arrays.sort(Q, new Comparator<int[]>() {
                @Override
                public int compare(int[] x, int[] y) {
                    int blockX = (x[1] - 1) / (BLOCK_SIZE + 1);
                    int blockY = (y[1] - 1) / (BLOCK_SIZE + 1);
                    if (blockX != blockY) {
                        return Integer.compare(blockX, blockY);
                    }
                    return Integer.compare(x[2], y[2]);
                }
            });
            solveQueries(Q);
        }
        System.out.println(sb);
    }

    /** Creates {@code size} empty adjacency lists (generic array creation needs an unchecked cast). */
    @SuppressWarnings("unchecked")
    private static ArrayList<Integer>[] newGraph(int size) {
        ArrayList<Integer>[] g = new ArrayList[size];
        for (int i = 0; i < size; i++) {
            g[i] = new ArrayList<>();
        }
        return g;
    }

    /** Running answer for the current Mo window: product of (exponent + 1) mod MOD. */
    static long res;

    /**
     * Answers the Mo-sorted queries and appends the answers to sb in input order.
     *
     * @param Q rows of {lca, left tour index, right tour index, input position}
     */
    private static void solveQueries(int[][] Q) {
        int M = Q.length;
        if (M == 0) {
            return;   // no queries: there is no Q[0] to start the window from
        }
        long results[] = new long[M];
        res = 1;
        int curL = Q[0][1], curR = Q[0][1] - 1;   // start with an empty window
        for (int[] query : Q) {
            int L = query[1], R = query[2];
            while (curL < L) check(ID[curL++]);
            while (curL > L) check(ID[--curL]);
            while (curR < R) check(ID[++curR]);
            while (curR > R) check(ID[curR--]);

            int u = ID[curL], v = ID[curR];
            // When neither end of the window is the LCA, the LCA lies outside it:
            // count it only while reading this answer.
            boolean lcaOutside = query[0] != u && query[0] != v;
            if (lcaOutside) check(query[0]);
            results[query[3]] = res;
            if (lcaOutside) check(query[0]);
        }
        for (long ans : results) {
            sb.append(ans).append('\n');
        }
    }

    /** visited[x]: 1 while node x is counted in the window (a node seen twice cancels out). */
    static int visited[];
    /** primeCnt[p]: total exponent of prime p over the nodes in the window. */
    static int primeCnt[];

    /**
     * Toggles node x in the window and updates res. For each prime p of x's value,
     * res is divided by (old exponent + 1) and multiplied by (new exponent + 1).
     */
    private static void check(int x) {
        int sign = visited[x] == 1 ? -1 : 1;   // leaving: subtract, entering: add
        for (int[] primePower : arrPrimes[x]) {
            int pp = primePower[0];
            if (pp == 0) {
                break;   // no more distinct primes for this value
            }
            res = res * inverse(primeCnt[pp] + 1) % MOD;
            primeCnt[pp] += sign * primePower[1];
            res = res * (primeCnt[pp] + 1) % MOD;
        }
        visited[x] ^= 1;
    }

    /**
     * Modular inverse of k (k >= 1). Uses the precomputed table when possible.
     * Exponents summed along a long path can exceed the table size, in which
     * case it falls back to Fermat exponentiation.
     */
    private static long inverse(int k) {
        return k < modI.length ? modI[k] : pow(k, MOD - 2);
    }

    /** arrPrimes[i][k] = {prime, exponent} of the k-th distinct prime of ar[i]; prime 0 ends the list. */
    static int arrPrimes[][][];

    /**
     * Factorizes every node value with the hpf[] table: repeatedly strips the
     * largest remaining prime factor. A value up to 1e6 has at most 7 distinct
     * primes, which matches the 7 slots per node.
     */
    static void computePrimePowers() {
        for (int i = 0; i < arrPrimes.length; i++) {
            int ele = ar[i];
            int k = 0;
            while (ele > 1) {
                int pp = hpf[ele];
                int c = 0;
                while (hpf[ele] == pp) {
                    c++;
                    ele /= pp;
                }
                arrPrimes[i][k][0] = pp;
                arrPrimes[i][k][1] = c;
                k++;
            }
        }
    }

    /** Euler-tour clock; after dfs it equals 2 * (number of nodes). */
    static int dfsTime;
    /** l[u]/r[u]: entry/exit tour index of u; ID[t]: node at tour index t; LVL[u]: depth. */
    static int[] l, r, ID, LVL;
    /** lcaTable[k][u]: 2^k-th ancestor of u (0 above the root). */
    static int[][] lcaTable;

    /**
     * Euler tour from u, whose parent is p (pass -1 for the root). It records
     * l/r/ID, depths and the binary-lifting table.
     *
     * <p>Iterative, with an explicit stack, so deep trees cannot overflow the
     * Java call stack. Children are visited in adjacency-list order, exactly like
     * the recursive formulation.
     */
    static void dfs(int u, int p) {
        int size = graph.length;
        int[] parent = new int[size];
        int[] nextEdge = new int[size];   // next adjacency index to explore per node
        int[] stack = new int[size];
        int top = 0;

        parent[u] = p;
        enter(u);
        stack[top++] = u;
        while (top > 0) {
            int node = stack[top - 1];
            if (nextEdge[node] < graph[node].size()) {
                int v = graph[node].get(nextEdge[node]++);
                if (v == parent[node]) {
                    continue;
                }
                parent[v] = node;
                LVL[v] = LVL[node] + 1;
                lcaTable[0][v] = node;
                enter(v);
                stack[top++] = v;
            } else {
                r[node] = ++dfsTime;
                ID[dfsTime] = node;
                top--;
            }
        }
    }

    /** Records u's entry time and fills its ancestor table (lcaTable[0][u] must already be set). */
    private static void enter(int u) {
        l[u] = ++dfsTime;
        ID[dfsTime] = u;
        for (int k = 1; k < LOG; k++) {
            lcaTable[k][u] = lcaTable[k - 1][lcaTable[k - 1][u]];
        }
    }

    /** Lowest common ancestor of u and v by binary lifting. */
    static int lca(int u, int v) {
        if (LVL[u] > LVL[v]) {
            int temp = u;
            u = v;
            v = temp;
        }
        // lift the deeper node v to u's depth
        for (int i = LOG - 1; i >= 0; i--) {
            if (LVL[v] - (1 << i) >= LVL[u]) v = lcaTable[i][v];
        }
        if (u == v) return u;
        // lift both while their ancestors differ
        for (int i = LOG - 1; i >= 0; i--) {
            if (lcaTable[i][u] != lcaTable[i][v]) {
                u = lcaTable[i][u];
                v = lcaTable[i][v];
            }
        }
        return lcaTable[0][u];
    }

    /**
     * Minimal buffered reader of whitespace-separated ints from System.in. It
     * shadows java.io.Reader, which is abstract and has no nextInt().
     */
    static final class Reader {
        private final InputStream in = System.in;
        private final byte[] buf = new byte[1 << 16];
        private int len = 0;
        private int pos = 0;

        /**
         * Reads the next (optionally negative) decimal integer.
         *
         * @throws EOFException if the input ends before another integer
         */
        int nextInt() throws IOException {
            int c = read();
            while (c != -1 && c <= ' ') {
                c = read();
            }
            if (c == -1) {
                throw new EOFException("expected an integer but reached end of input");
            }
            boolean negative = c == '-';
            if (negative) {
                c = read();
            }
            int value = 0;
            while (c >= '0' && c <= '9') {
                value = value * 10 + (c - '0');
                c = read();
            }
            return negative ? -value : value;
        }

        /** Next byte as 0..255, or -1 at end of input. */
        private int read() throws IOException {
            if (pos == len) {
                len = in.read(buf, 0, buf.length);
                pos = 0;
                if (len <= 0) {
                    len = 0;
                    return -1;
                }
            }
            return buf[pos++] & 0xFF;
        }
    }
}
// SIMILAR SOLUTION FOR FINDING NUMBER OF DISTINCT ELEMENTS FROM U TO V
// USING MO's ALGORITHM
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 40005;
const int MAXM = 100005;
const int LN = 19;

// N nodes, M queries, K distinct values after compression, cur = Euler-tour clock.
// A[u]: compressed value of node u; LVL[u]: depth; DP[k][u]: 2^k-th ancestor of u.
int N, M, K, cur, A[MAXN], LVL[MAXN], DP[LN][MAXN];
// BL[t]: Mo block of tour position t; ID[t]: node at tour position t (two per node).
// VAL[c]: occurrences of compressed value c in the window; ANS[i]: answer of query i.
int BL[MAXN << 1], ID[MAXN << 1], VAL[MAXN], ANS[MAXM];
// d: sorted copy of the values used for compression; l[u]/r[u]: entry/exit tour positions.
int d[MAXN], l[MAXN], r[MAXN];
bool VIS[MAXN];   // VIS[u]: node u is currently counted in the window
vector < int > adjList[MAXN];

// A query over tour range [l, r]; lc = lca of its endpoints; id = input position.
struct query{
    int id, l, r, lc;
    // Mo ordering: by block of l, then by r. Declared const so it is usable on
    // const operands; comparators taking const references (e.g. libc++'s sort)
    // require this.
    bool operator < (const query& rhs) const {
        return (BL[l] == BL[rhs.l]) ? (r < rhs.r) : (BL[l] < BL[rhs.l]);
    }
}Q[MAXM];

// Euler tour of the subtree rooted at u (parent par). It:
//  - stamps the entry time l[u] and exit time r[u] from the global clock cur;
//  - records in ID which node sits at each time;
//  - fills u's binary-lifting row in DP;
//  - sets each child's depth in LVL.
void dfs(int u, int par){
    cur++;
    l[u] = cur;
    ID[cur] = u;

    for (int k = 1; k < LN; ++k) {
        int half = DP[k - 1][u];
        DP[k][u] = DP[k - 1][half];
    }

    for (int child : adjList[u]) {
        if (child == par) continue;
        DP[0][child] = u;
        LVL[child] = LVL[u] + 1;
        dfs(child, u);
    }

    cur++;
    r[u] = cur;
    ID[cur] = u;
}

// Function returns lca of (u) and (v)
inline int lca(int u, int v){
    if (LVL[u] > LVL[v]) swap(u, v);
    for (int i = LN - 1; i >= 0; i--)
        if (LVL[v] - (1 << i) >= LVL[u]) v = DP[i][v];
    if (u == v) return u;
    for (int i = LN - 1; i >= 0; i--){
        if (DP[i][u] != DP[i][v]){
            u = DP[i][u];
            v = DP[i][v];
        }
    }
    return DP[0][u];
}

// Toggles node x in the current window. res is kept equal to the number of
// distinct values among the counted nodes. A node that appears twice in the
// tour range toggles back out, so it is not counted.
inline void check(int x, int& res){
    int& occurrences = VAL[A[x]];
    if (VIS[x]) {
        if (--occurrences == 0) --res;   // last node with this value left
    } else {
        if (occurrences++ == 0) ++res;   // first node with this value entered
    }
    VIS[x] = !VIS[x];
}

void compute(){

    // Perform standard Mo's Algorithm
    int curL = Q[0].l, curR = Q[0].l - 1, res = 0;

    for (int i = 0; i < M; i++){

        while (curL < Q[i].l) check(ID[curL++], res);
        while (curL > Q[i].l) check(ID[--curL], res);
        while (curR < Q[i].r) check(ID[++curR], res);
        while (curR > Q[i].r) check(ID[curR--], res);

        int u = ID[curL], v = ID[curR];

        // Case 2
        if (Q[i].lc != u and Q[i].lc != v) check(Q[i].lc, res);

        ANS[Q[i].id] = res;

        if (Q[i].lc != u and Q[i].lc != v) check(Q[i].lc, res);
    }

    for (int i = 0; i < M; i++) printf("%d\n", ANS[i]);
}

int main(){

    int u, v, x;

    while (scanf("%d %d", &N, &M) != EOF){

        // Cleanup
        cur = 0;
        memset(VIS, 0, sizeof(VIS));
        memset(VAL, 0, sizeof(VAL));
        for (int i = 1; i <= N; i++) adjList[i].clear();

        // Inputting Values
        for (int i = 1; i <= N; i++) scanf("%d", &A[i]);
        memcpy(d + 1, A + 1, sizeof(int) * N);

        // Compressing Coordinates
        sort(d + 1, d + N + 1);
        K = unique(d + 1, d + N + 1) - d - 1;
        for (int i = 1; i <= N; i++) A[i] = lower_bound(d + 1, d + K + 1, A[i]) - d;

        // Inputting Tree
        for (int i = 1; i < N; i++){
            scanf("%d %d", &u, &v);
            adjList[u].push_back(v);
            adjList[v].push_back(u);
        }

        // Preprocess
        DP[0][1] = 1;
        dfs(1, -1);
        int size = sqrt(cur);

        for (int i = 1; i <= cur; i++) BL[i] = (i - 1) / size + 1;

        for (int i = 0; i < M; i++){
            scanf("%d %d", &u, &v);
            Q[i].lc = lca(u, v);
            if (l[u] > l[v]) swap(u, v);
            if (Q[i].lc == u) Q[i].l = l[u], Q[i].r = l[v];
            else Q[i].l = r[u], Q[i].r = l[v];
            Q[i].id = i;
        }

        sort(Q, Q + M);
        compute();
    }
}

Demo

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM