LeetCode Entry

3068. Find the Maximum Sum of Node Values

23.05.2025 hard 2025 kotlin rust

Max sum by k-xoring edges


Join me on Telegram

https://t.me/leetcode_daily_unstoppable/997

Problem TLDR

Max sum of node values; one operation xors both ends of an edge with k #hard #tree #dp

Intuition

Didn’t solve it the second time either (the previous attempt was in 2024)


    // 1 - 0 - 2
    // a - b - c
    // ^k  ^k
    //     ^k  ^k   a^k - b - c^k
    // a - b - c - d
    // ^   ^
    //     ^   ^
    //         ^   ^   a^k - b - c - d^k
    //
    // a^k b   c   d^k
    // a^k b   c^k d
    // a^k b^k c   d
    // a^k b^k c^k d^k
    // a   b^k c^k d
    // a   b   c^k d^k
    // a   b^k c   d^k

    // a - b - c    a*- b*- c - e*
    //     |            |
    //     d            d*
    // didn't see any simple law
    // maybe full search?
    // wrong answer: careful with flipping the last (it flips the previous?)
    // 0-2-4-3
    //   |
    //   1
    //
    // 5-0-1*-3*-6-2*
    //     |
    //     4*
    // flip current without flipping previous:
    // 1. if has next
    // 51 minutes, used hints, looks like the same dp (and what is parity?)
    // looks like I made the same mistake in 2024 (and didn't finish the dp)

What was missing:

  • I didn’t pay attention to one detail: only an even number of flipped nodes is possible (each operation flips exactly two nodes)
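
To make the parity detail concrete, here is a minimal Rust sketch. The helper names `apply_edge` and `flip_path` are hypothetical, made up for this illustration. It applies the operation to every edge along a path, as in the `a - b - c - d` scratch notes: inner nodes get xored twice and return to their values, so only the two endpoints stay flipped.

```rust
/// One operation from the problem: xor both endpoints of an edge with k.
fn apply_edge(nums: &mut [i32], u: usize, v: usize, k: i32) {
    nums[u] ^= k;
    nums[v] ^= k;
}

/// Hypothetical helper: apply the operation to every consecutive pair
/// of nodes on a path. Inner nodes are xored twice, which cancels out.
fn flip_path(nums: &mut [i32], path: &[usize], k: i32) {
    for w in path.windows(2) {
        apply_edge(nums, w[0], w[1], k);
    }
}
```

Since a tree has a path between any two nodes, this suggests any pair of nodes can be flipped together, and so any even-sized set of nodes.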

Why my DFS+cache simulation didn’t work:

  • when children count > 1, we can’t flip them all simultaneously

Approach

  • attention to details: how many flips can be done, how flips happen when a node has many children
  • can you rewrite DFS dp to return a single Long result?
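
One way to gain confidence in the "how many flips" detail is to compare the parity greedy against a full search on tiny trees. The sketch below is only a sanity check under my own assumptions: `brute_force` and `greedy` are hypothetical names, edges are passed as `(usize, usize)` pairs instead of the LeetCode signature, and the brute force is exponential in the number of edges.

```rust
/// Full search: try every subset of edges. Using an edge twice cancels out,
/// so each edge is either applied once or not at all.
fn brute_force(nums: &[i32], k: i32, edges: &[(usize, usize)]) -> i64 {
    let mut best = i64::MIN;
    for mask in 0u32..(1u32 << edges.len()) {
        let mut cur = nums.to_vec();
        for (i, &(u, v)) in edges.iter().enumerate() {
            if ((mask >> i) & 1) == 1 {
                cur[u] ^= k;
                cur[v] ^= k;
            }
        }
        best = best.max(cur.iter().map(|&x| x as i64).sum::<i64>());
    }
    best
}

/// Parity greedy: take the better value for every node, then, if an odd
/// number of nodes were taken flipped, give back the smallest difference.
fn greedy(nums: &[i32], k: i32) -> i64 {
    let (mut sum, mut odd, mut diff) = (0i64, false, i64::MAX);
    for &x in nums {
        let y = x ^ k;
        sum += x.max(y) as i64;
        if y > x {
            odd = !odd;
        }
        diff = diff.min((y as i64 - x as i64).abs());
    }
    if odd { sum - diff } else { sum }
}
```

For example, with `nums = [1, 4, 4]`, `k = 2` and the path `0 - 1 - 2`, every node prefers its flipped value, but three flips are impossible, so the greedy gives back the smallest difference and agrees with the full search.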

Complexity

  • Time complexity: \(O(n)\)

  • Space complexity: \(O(1)\) for the greedy, \(O(n)\) for the dp

Code


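    fun maximumValueSum(nums: IntArray, k: Int, edges: Array<IntArray>): Long {
        // edges are unused: in a tree any even-sized set of nodes can be flipped
        var sum = 0L; var xored = 0; var diff = Int.MAX_VALUE / 2
        for (x in nums) {
            sum += 1L * max(x, x xor k) // take the better of the two values
            if (x xor k > x) xored++ // count nodes taken in flipped form
            diff = min(diff, abs((x xor k) - x)) // cheapest single choice to reverse
        }
        // odd count of flipped nodes: pay the smallest difference to make it even
        return sum - diff * (xored % 2)
    }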


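    fun maximumValueSum(nums: IntArray, k: Int, edges: Array<IntArray>): Long {
        val g = Array(nums.size) { ArrayList<Int>() }
        for ((u, v) in edges) { g[u] += v; g[v] += u }
        val dp = HashMap<Int, Pair<Long, Long>>()
        // dfs returns (best sum with an even number of flipped nodes in the subtree, best with an odd number)
        fun dfs(u: Int, p: Int): Pair<Long, Long> = dp.getOrPut(u) {
            var sumFlip = Long.MIN_VALUE / 2 // odd parity over the children so far (impossible at start)
            var sumStay = 0L // even parity over the children so far
            for (v in g[u]) if (v != p) {
                val (flip, stay) = dfs(v, u) // flip = child's even result, stay = child's odd result
                // `also` runs before the assignment, so both updates read the old sumFlip
                sumFlip = max(sumStay + stay, sumFlip + flip).also {
                sumStay = max(sumStay + flip, sumFlip + stay) }
            }
            val stay = nums[u]
            val flip = stay xor k
            // adding the node itself swaps the roles: sumFlip now holds the even-parity total
            sumFlip = max(sumStay + stay, sumFlip + flip).also {
            sumStay = max(sumStay + flip, sumFlip + stay) }
            sumFlip to sumStay
        }
        return dfs(0, -1).first
    }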


    fun maximumValueSum(nums: IntArray, k: Int, edges: Array<IntArray>): Long {
        val g = Array(nums.size) { ArrayList<Int>() }
        for ((u, v) in edges) { g[u] += v; g[v] += u }
        val dp = HashMap<Pair<Int, Int>, Long>()
        fun dfs(u: Int, p: Int, f: Int): Long = dp.getOrPut(u to f) {
            val flip = (nums[u] xor k xor f).toLong()
            val stay = (nums[u] xor f).toLong()
            var sum = max(flip, stay)
            var flips = if (flip > stay) 1 else 0
            var diff = abs(flip - stay)
            for (v in g[u]) if (v != p) {
                val flip = dfs(v, u, k)
                val stay = dfs(v, u, 0)
                if (flip > stay) flips = flips xor 1
                diff = min(diff, abs(flip - stay))
                sum += max(flip, stay)
            }
            sum - diff * flips
        }
        return dfs(0, -1, 0)
    }


    pub fn maximum_value_sum(n: Vec<i32>, k: i32, edges: Vec<Vec<i32>>) -> i64 {
        let (mut s, mut c, mut d) = (0, 0, i32::MAX);
        for x in n {
            s += x.max(x ^ k) as i64;
            if x ^ k > x { c ^= 1 }
            d = d.min(((x ^ k) - x).abs())
        } s - d as i64 * c
    }


    long long maximumValueSum(vector<int>& n, int k, vector<vector<int>>& edges) {
        long long s = 0, c = 0, d = 1e9;
        for (int& x: n) {
            s += max(x, x ^ k);
            if ((x ^ k) > x) c ^= 1;
            d = min(d, 1LL * abs((x ^ k) - x));
        } return s - d * c;
    }