/ SeriousOJ /

Record Detail

Wrong Answer


  
# Status Time Cost Memory Cost
#1 Accepted 4ms 580.0 KiB
#2 Accepted 4ms 328.0 KiB
#3 Wrong Answer 4ms 576.0 KiB
#4 Wrong Answer 5ms 580.0 KiB

Code

// I AM A MUSLIM

#include "bits/stdc++.h"

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")

#define fastIO std::ios::sync_with_stdio(0);std::cin.tie(0)
#define ll long long int
#define flush fflush(stdout)
#define bl printf("\n")
#define yn(a, b) printf("%s\n", a >= b ? "Yes":"No")
// #define int ll

using pii = std::pair<int,int>;

const int MOD = 1000000007;
// const int MOD = 998244353;
const int mxN = 200100;

int N, a[mxN];
std::vector<int> g[mxN];
int dis[mxN], par[mxN];

void dfs(int u, int p, int d) {
    if (a[u]) dis[u] = d;
    for (auto &v : g[u]) {
        if (v == p) continue;
        dfs(v, u, d+1);
    }
}

void dfs_par(int u, int p, int d) {
    if (a[u]) dis[u] = d;
    for (auto &v : g[u]) {
        if (v == p) continue;
        par[v] = u;
        dfs_par(v, u, d+1);
    }
}

ll dfs_calc(int u, int p, int d) {
    ll cur = 0;
    if (a[u]) cur += 2*d;
    for (auto &v : g[u]) {
        if (v == p) continue;
        cur += dfs_calc(v, u, d+1);
    }
    return cur;
}

signed main() {
    // fastIO;
    int testCases=1;
    scanf("%d",&testCases);
    // std::cin>>testCases;
    
    for (int TC = 1; TC <= testCases; TC++) {
        scanf("%d",&N);
        for (int i = 1; i <= N; i++) {
            scanf("%d",&a[i]);
        }
        for (int i = 0, u,v; i < N-1; i++) {
            scanf("%d%d",&u,&v);
            g[u].push_back(v);
            g[v].push_back(u);
        }
        
        for (int i = 1; i <= N; i++) {
            dis[i] = -1;
        }
        dfs(1, -1, 0);
        int L = -1, val = 0;
        for (int i = 1; i <= N; i++) {
            if (dis[i] != -1 && dis[i] > val) {
                val = dis[i];
                L = i;
            }
        }
        
        if (L == -1) {
            puts("0");
            continue;
        }
        for (int i = 1; i <= N; i++) {
            dis[i] = -1;
        }
        par[L] = -1;
        dfs_par(L, -1, 0);
        int R = -1; val = -1;
        for (int i = 1; i <= N; i++) {
            if (dis[i] != -1 && dis[i] > val) {
                val = dis[i];
                R = i;
            }
        }
        
        ll ans = 0;
        assert(R != -1);
        int cur = R;
        std::vector<int> path;
        while (cur != -1) {
            path.push_back(cur);
            cur = par[cur];
        } reverse(path.begin(), path.end());
        // puts("HERE");
        for (int i = 1; i < (int)path.size()-1; i++) {
            int u = path[i];
            for (auto &v : g[u]) {
                if (v == path[i-1]) continue;
                if (v == path[i+1]) continue;
                ans += dfs_calc(v, u, 1);
            }
        }
        
        printf("%lld\n", ans+val);
        for (int i = 1; i <= N; i++) g[i].clear();
        
    }
    
    return 0;
}

/*

*/

Information

Submit By
Type
Submission
Problem
P1078 Apple on Tree
Language
C++20 (G++ 13.2.0)
Submit At
2024-08-16 17:55:21
Judged At
2024-11-11 03:10:59
Judged By
Score
4
Total Time
5ms
Peak Memory
580.0 KiB