Editorial
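Using depth-first search, compute the XOR sum between vertex 1 and every other vertex. Store the XOR sum between vertex 1 and vertex $i$ in $x_i$ (so $x_1 = 0$). The XOR sum of the path between any two vertices $u$ and $v$ equals $x_u \oplus x_v$, because the edges shared by the paths from vertex 1 to $u$ and from vertex 1 to $v$ are counted twice and cancel out. Edge weights are 0 or 1, so every $x_i$ is also 0 or 1. Let $\mathit{ones}$ be the number of ones and $\mathit{zeroes}$ be the number of zeros in the array $x$. Then the answer to the problem is $\mathit{ones} \cdot \mathit{zeroes}$.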
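The key identity can be written out as a short derivation. Here $r$ is the vertex where the paths from vertex 1 to $u$ and to $v$ diverge (their lowest common ancestor when the tree is rooted at 1), and $\operatorname{xor}(a, b)$ is the XOR sum of the path between $a$ and $b$; both symbols are introduced only for this derivation. The bracket $[x_u \ne x_v]$ is 1 when the condition holds and 0 otherwise, and the last line relies on every $x_i$ being 0 or 1.

```latex
\begin{aligned}
x_u \oplus x_v
  &= \bigl(\operatorname{xor}(1, r) \oplus \operatorname{xor}(r, u)\bigr)
     \oplus \bigl(\operatorname{xor}(1, r) \oplus \operatorname{xor}(r, v)\bigr) \\
  &= \operatorname{xor}(r, u) \oplus \operatorname{xor}(r, v)
   = \operatorname{xor}(u, v), \\
\sum_{1 \le u < v \le n} \operatorname{xor}(u, v)
  &= \sum_{1 \le u < v \le n} [x_u \ne x_v]
   = \mathit{ones} \cdot \mathit{zeroes}.
\end{aligned}
```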
Example
Consider the graph provided in the example.
Next to each vertex $v$, the value $x_v$ (the XOR sum between vertex 1 and $v$) is written. If $x_u = x_v$ for some vertices $u$ and $v$, then the XOR sum between them is $x_u \oplus x_v = 0$, so this pair contributes 0 to the total sum. Each pair of vertices $(u, v)$ with $x_u \ne x_v$ has XOR sum $x_u \oplus x_v = 1$ and contributes 1 to the total sum.
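Therefore, the answer equals the number of pairs of vertices $(u, v)$ for which $x_u \ne x_v$. Every such pair joins a vertex with $x = 1$ to a vertex with $x = 0$, so this number equals $\mathit{ones} \cdot \mathit{zeroes}$. The pairs of vertices that contribute to the total sum are: .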
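The counting argument can be checked by brute force on a small tree. The sketch below uses a hypothetical 5-vertex tree (its edge list is invented for illustration and is not the sample from the problem) and compares the direct sum of path XORs over all pairs with $\mathit{ones} \cdot \mathit{zeroes}$. The helper names build_adj, path_xor, brute_force and by_formula are hypothetical.

```python
from itertools import combinations


def build_adj(n, edges):
    """Adjacency list for vertices 1..n: adj[v] holds (neighbor, weight) pairs."""
    adj = [[] for _ in range(n + 1)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    return adj


def path_xor(adj, start, target):
    """XOR of the edge weights on the unique tree path start -> target."""
    stack = [(start, 0, -1)]  # (vertex, XOR so far, parent)
    while stack:
        v, acc, parent = stack.pop()
        if v == target:
            return acc
        for to, w in adj[v]:
            if to != parent:
                stack.append((to, acc ^ w, v))
    raise ValueError("target is not reachable from start")


def brute_force(n, edges):
    """Sum of path XORs over all unordered pairs of distinct vertices."""
    adj = build_adj(n, edges)
    return sum(path_xor(adj, a, b) for a, b in combinations(range(1, n + 1), 2))


def by_formula(n, edges):
    """ones * zeroes, where x[i] is the path XOR from vertex 1 to i."""
    adj = build_adj(n, edges)
    ones = sum(path_xor(adj, 1, i) for i in range(1, n + 1))
    return ones * (n - ones)


# Hypothetical tree, not the problem's sample: (u, v, weight), weights 0 or 1.
edges = [(1, 2, 1), (1, 3, 0), (3, 4, 1), (3, 5, 0)]
print(brute_force(5, edges), by_formula(5, edges))  # 6 6
```

Here $x = (0, 1, 0, 1, 0)$ for vertices 1 to 5, so $\mathit{ones} = 2$, $\mathit{zeroes} = 3$, and both methods give 6.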
Algorithm implementation
Store the input tree in the adjacency list $g$, where each entry is a pair (neighbor, edge weight). Declare the array $x$.
// g[v] holds pairs (neighbor, edge weight)
vector<vector<pair<int,int> > > g;
// x[v] is the XOR sum of the path from vertex 1 to v
vector<int> x;
The dfs function implements a depth-first search that computes the XOR sum between vertex 1 and each vertex $v$. The argument cur_xor holds the XOR sum between 1 and $v$, and $p$ is the parent of $v$.
void dfs(int v, int cur_xor, int p = -1)
{
  x[v] = cur_xor;  // XOR sum of the path 1 -> v
  for (auto z : g[v])
  {
    int to = z.first;
    int w = z.second;
    // skip the parent; otherwise extend the path by the edge (v, to)
    if (to != p) dfs(to, cur_xor ^ w, v);
  }
}
The main part of the program. Read the input graph.
int n, i, u, v, d;
scanf("%d", &n);
g.resize(n + 1);
x.resize(n + 1);
// the tree has n - 1 undirected edges (u, v) with weight d
for (i = 1; i < n; i++)
{
  scanf("%d %d %d", &u, &v, &d);
  g[u].push_back(make_pair(v, d));
  g[v].push_back(make_pair(u, d));
}
Start the depth-first search from vertex 1.
dfs(1, 0, -1);
Count the ones ($\mathit{ones}$) and zeros ($\mathit{zeroes}$) in the array $x$.
int ones = 0;
for (i = 1; i <= n; i++)
  if (x[i] == 1) ones++;
int zeroes = n - ones;  // every x[i] is 0 or 1
Print the answer.
printf("%lld\n", 1LL * ones * zeroes);
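The 1LL factor matters because for large $n$ the product $\mathit{ones} \cdot \mathit{zeroes}$ can exceed the range of a 32-bit int. Since $\mathit{ones} + \mathit{zeroes} = n$, the product is largest when $\mathit{ones} = \lfloor n/2 \rfloor$. The quick check below uses a hypothetical $n = 100\,000$, since the problem's actual limit is not stated here; max_answer is a hypothetical helper.

```python
INT32_MAX = 2**31 - 1


def max_answer(n):
    """Largest possible ones * zeroes when ones + zeroes == n."""
    ones = n // 2
    return ones * (n - ones)


n = 100_000  # hypothetical vertex count, not taken from the problem statement
print(max_answer(n), max_answer(n) > INT32_MAX)  # 2500000000 True
```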
Python implementation
The dfs function implements a depth-first search that computes the XOR sum between vertex 1 and each vertex $v$. The argument cur_xor holds the XOR sum between 1 and $v$, and $p$ is the parent of $v$.
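import sys
sys.setrecursionlimit(10**6)  # assumed bound: a path-shaped tree recurses about n levels deep

def dfs(v, cur_xor = 0, p = -1):
    x[v] = cur_xor  # XOR sum of the path 1 -> v
    for to, w in g[v]:
        if to != p:  # do not go back to the parent
            dfs(to, cur_xor ^ w, v)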
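The recursion depth of a recursive dfs grows with the height of the tree, up to about $n$ on a path, while Python's default recursion limit is usually 1000. As an alternative sketch (not the author's code), the same values can be computed with an explicit stack; the name dfs_iterative and its parameters are hypothetical.

```python
def dfs_iterative(g, n, root=1):
    """Return x with x[v] = XOR of edge weights on the path root -> v, without recursion.

    g is an adjacency list indexed 1..n whose entries are (neighbor, weight) pairs.
    """
    x = [0] * (n + 1)
    stack = [(root, 0, -1)]  # (vertex, XOR from root, parent)
    while stack:
        v, cur_xor, p = stack.pop()
        x[v] = cur_xor
        for to, w in g[v]:
            if to != p:  # in a tree, the parent is the only already-visited neighbor
                stack.append((to, cur_xor ^ w, v))
    return x


# Usage on a hypothetical path 1 - 2 - 3 whose two edges both have weight 1.
g = [[], [(2, 1)], [(1, 1), (3, 1)], [(2, 1)]]
print(dfs_iterative(g, 3))  # [0, 0, 1, 0]
```

The stack only stores (vertex, XOR, parent) triples, so memory stays proportional to $n$ regardless of the tree's shape.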
The main part of the program. Read the input graph.
n = int(input())
g = [[] for _ in range(n + 1)]
for _ in range(n - 1):
    u, v, d = map(int, input().split())
    g[u].append((v, d))
    g[v].append((u, d))
Initialize the list $x$.
x = [0] * (n + 1)
Start the depth-first search from vertex 1.
dfs(1)
Count the ones ($\mathit{ones}$) and zeros ($\mathit{zeroes}$) in the list $x$.
ones = sum(x)  # every x[i] is 0 or 1, and the unused x[0] stays 0
zeroes = n - ones
Print the answer.
print(ones * zeroes)
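For large inputs, a variant of the program can read all tokens at once and replace recursion with a breadth-first search (a swapped-in traversal; any traversal from vertex 1 yields the same $x$ on a tree). The sketch below is an alternative, not the reference solution; solve is a hypothetical helper that takes the whole input as a string so it can be tested without standard input, and it assumes the input format read above (n, then n − 1 lines "u v d").

```python
from collections import deque


def solve(data: str) -> int:
    """Return ones * zeroes for a tree given as 'n' followed by n - 1 lines 'u v d'."""
    tokens = data.split()
    n = int(tokens[0])
    g = [[] for _ in range(n + 1)]
    for i in range(n - 1):
        u, v, d = (int(t) for t in tokens[1 + 3 * i : 4 + 3 * i])
        g[u].append((v, d))
        g[v].append((u, d))
    # BFS from vertex 1: x[v] = XOR of edge weights on the path 1 -> v.
    x = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    queue = deque([1])
    while queue:
        v = queue.popleft()
        for to, w in g[v]:
            if not seen[to]:
                seen[to] = True
                x[to] = x[v] ^ w
                queue.append(to)
    ones = sum(x)  # x[0] stays 0; every other entry is 0 or 1
    return ones * (n - ones)


# Usage on a hypothetical 5-vertex tree (not the problem's sample input).
print(solve("5\n1 2 1\n1 3 0\n3 4 1\n3 5 0\n"))  # 6
```

To run it on real input, call print(solve(sys.stdin.read())) after import sys.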