/ SeriousOJ /

Record Detail

Accepted


  
# Status Time Cost Memory Cost
#1 Accepted 2ms 540.0 KiB
#2 Accepted 2ms 540.0 KiB
#3 Accepted 1ms 540.0 KiB
#4 Accepted 16ms 540.0 KiB
#5 Accepted 5ms 540.0 KiB
#6 Accepted 19ms 788.0 KiB
#7 Accepted 11ms 540.0 KiB
#8 Accepted 14ms 540.0 KiB
#9 Accepted 20ms 540.0 KiB
#10 Accepted 21ms 284.0 KiB
#11 Accepted 20ms 540.0 KiB
#12 Accepted 14ms 540.0 KiB
#13 Accepted 26ms 768.0 KiB
#14 Accepted 25ms 328.0 KiB
#15 Accepted 25ms 540.0 KiB
#16 Accepted 25ms 540.0 KiB
#17 Accepted 26ms 540.0 KiB
#18 Accepted 25ms 540.0 KiB
#19 Accepted 25ms 328.0 KiB
#20 Accepted 26ms 540.0 KiB
#21 Accepted 21ms 540.0 KiB
#22 Accepted 20ms 540.0 KiB
#23 Accepted 14ms 540.0 KiB
#24 Accepted 16ms 540.0 KiB
#25 Accepted 18ms 540.0 KiB
#26 Accepted 1ms 540.0 KiB
#27 Accepted 2ms 540.0 KiB

Code

// BISMILLAH

#include "bits/stdc++.h"

// Disables C/C++ stream syncing and unties cin from cout. Its only use (first
// line of main) is commented out; main reads and writes via scanf/printf.
#define fastIO std::ios::sync_with_stdio(0);std::cin.tie(0)
// Shorthand for long long int; used for k, the input values, and the sums.
#define ll long long int
// Flushes C stdout. Not referenced in this listing.
// NOTE(review): as a macro, every later `flush` token is replaced by
// `fflush(stdout)`, so code below such as `std::cout.flush()` would break.
#define flush fflush(stdout)
// #define int ll

// Unused in this listing.
using pii = std::pair<int,int>;

// MOD, mxN and inf are template constants not referenced in this listing.
const int MOD = 1000000007;
// const int MOD = 998244353;
// bit_len: number of low-order bit positions examined. Bits >= bit_len of the
// inputs and of k are ignored by the counting loop and the search for k's top bit.
const int mxN = 200005, inf = 1000000005, bit_len = 60;

// Number of input values a_i (read in main).
int N;
// Inclusive upper bound on the chosen value X; fn() keeps X <= k bit by bit.
ll k;
// bit[b] = number of input values whose bit b is set.
int bit[bit_len];

// Memo for fn(b, small); -1 means "not computed yet" (main memsets it to -1).
ll dp[bit_len][2];

// Digit DP over the bits of X, from bit b down to bit 0.
// Returns the largest possible value of sum_i (bits b..0 of (a_i XOR X)).
// For bit b, X's bit = 0 keeps the bit[b] inputs that have it set, and
// X's bit = 1 turns on the N - bit[b] inputs that have it clear.
//   small == true  : a higher bit of X is already below k's, so the remaining
//                    bits of X are unconstrained.
//   small == false : X still equals k on every higher bit, so X's bit b may be
//                    1 only where k's bit b is 1 (this keeps X <= k).
// Results are cached in dp[b][small] (-1 = not yet computed).
// NOTE(review): no overflow guard on weight * count; relies on the problem's
// limits keeping the total within long long.
ll fn(int b, bool small) {
    // All bit positions decided: nothing left to add.
    if (b == -1) return 0;

    ll &memo = dp[b][small];
    if (memo != -1) return memo;

    const ll weight = 1ll << b;   // value of bit position b
    const ll ones = bit[b];       // gain count if X's bit b is 0
    const ll zeros = N - bit[b];  // gain count if X's bit b is 1

    ll best;
    if (small) {
        // Free choice: take whichever setting of bit b lights up more inputs.
        best = weight * std::max(ones, zeros) + fn(b - 1, 1);
    } else if (k & weight) {
        // k has a 1 here: a 0 in X drops X below k; a 1 keeps X tied with k.
        const ll takeZero = weight * ones + fn(b - 1, 1);
        const ll takeOne = weight * zeros + fn(b - 1, 0);
        best = takeZero > takeOne ? takeZero : takeOne;
    } else {
        // k has a 0 here while X is still tied with k: X's bit must be 0.
        best = weight * ones + fn(b - 1, 0);
    }

    memo = best;
    return best;
}

// Reads N and k, then N values a_i, and prints the maximum over 0 <= X <= k of
// sum_i (a_i XOR X), considering only bits 0..bit_len-1.
signed main() {
    // fastIO;  // left disabled: all I/O below goes through scanf/printf
    int testCases = 1;
    // scanf("%d", &testCases);  // enable for multi-test input (testCases is an int, hence %d)

    for (int T = 1; T <= testCases; T++) {
        scanf("%d%lld", &N, &k);
        // Per-test state: both the memo and the bit counts must start fresh,
        // otherwise a later test case would accumulate counts from earlier ones.
        memset(dp, -1, sizeof(dp));
        memset(bit, 0, sizeof(bit));

        // bit[b] = how many inputs have bit b set.
        for (int i = 0; i < N; i++) {
            ll a;
            scanf("%lld", &a);
            for (int b = 0; b < bit_len; b++) {
                if ((1ll << b) & a) {
                    bit[b]++;
                }
            }
        }

        // Highest set bit of k, or -1 when k == 0. The -1 default matters:
        // without it, k == 0 would read an uninitialized st_bit (undefined behavior).
        int st_bit = -1;
        for (int i = bit_len - 1; i >= 0; i--) {
            if ((1ll << i) & k) {
                st_bit = i;
                break;
            }
        }

        // Above st_bit, k's bits are 0, so X's bits must be 0 as well; every
        // input keeps its own bit there, contributing (1 << i) * bit[i].
        ll ans = 0;
        for (int i = bit_len - 1; i > st_bit; i--) {
            ans += (1ll << i) * bit[i];
        }
        // fn(-1, 0) returns 0, so k == 0 yields sum of a_i (X = 0).
        printf("%lld\n", ans + fn(st_bit, 0));
    }

    return 0;
}

/*

*/

Information

Submit By
Type
Submission
Problem
P1054 Yet another challenge for Roy!
Language
C++17 (G++ 13.2.0)
Submit At
2024-05-07 20:08:29
Judged At
2024-05-14 20:41:11
Judged By
Score
100
Total Time
26ms
Peak Memory
788.0 KiB