Editorial
Let's use the disjoint-set data structure. Initially, each vertex is placed in a separate set. We iterate through the edges in decreasing order of their weights. For each edge $(u, v)$, we call the Union function, which merges the sets containing vertices $u$ and $v$.
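Consider the current edge $(u, v)$ of the tree with weight $w$, and let it merge sets of sizes $a$ and $b$. Every edge inside these two sets has already been processed, so its weight is greater than or equal to $w$. The edge $(u, v)$ lies on exactly $a \cdot b$ different paths: one for each pair of vertices taken from different sets. Every other edge of such a path lies inside one of the two sets (each set is a connected subtree), so $w$ is the minimum weight on the path, and the distance of each of these paths equals $w$. Therefore, these paths contribute $a \cdot b \cdot w$ to the sum of distances over all pairs of vertices. Each pair of vertices ends up in a common set at exactly one merge, so every path is counted exactly once, and the answer is the sum of $a \cdot b \cdot w$ over all edges.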
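The counting argument can be checked mechanically. The sketch below, written in Python like the second implementation further down, computes the sum twice for small trees: once with the merge formula $a \cdot b \cdot w$ and once by brute force, walking the path between every pair of vertices and taking its minimum edge weight. The function names `dsu_sum` and `brute_sum` and the `(u, v, w)` edge-tuple format are illustrative assumptions, not part of the original solution.

```python
from itertools import combinations

def dsu_sum(n, edges):
    """Sum over all pairs of the minimum edge weight on their path (DSU method)."""
    parent = list(range(n + 1))
    size = [1] * (n + 1)

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    total = 0
    for u, v, w in sorted(edges, key=lambda e: e[2], reverse=True):
        a, b = find(u), find(v)
        total += size[a] * size[b] * w  # a*b new paths, each with minimum w
        if size[a] < size[b]:
            a, b = b, a
        parent[b] = a
        size[a] += size[b]
    return total

def brute_sum(n, edges):
    """Same sum by explicitly walking the path between every pair (O(n^2))."""
    adj = {v: [] for v in range(1, n + 1)}
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    def min_on_path(s, t):
        # DFS from s, carrying the minimum edge weight seen so far
        stack = [(s, 0, float("inf"))]
        while stack:
            x, par, best = stack.pop()
            if x == t:
                return best
            for y, w in adj[x]:
                if y != par:
                    stack.append((y, x, min(best, w)))
        raise ValueError("graph is not connected")

    return sum(min_on_path(s, t) for s, t in combinations(range(1, n + 1), 2))
```

On the hypothetical tree with edges (1, 2, 3), (2, 3, 1), (2, 4, 5), both functions return 14.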
Example
Consider the first test.
The graph contains three paths:
(distance );
(distance );
(distance );
The sum of distances between all pairs of vertices is equal to .
Consider the second test.
Assign each vertex to a separate set.
Iterate through the edges in descending order of weight. The first edge is . Call Union on its endpoints. The distance has path (from to ).
The next edge is . Call Union on its endpoints. The distance has path (from to ).
The next edge is . Call Union on its endpoints. The distance has paths (from to and from to ).
The next edge is . Call Union on its endpoints. The distance has paths .
The desired sum of distances is equal to .
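A walkthrough like the one for the second test can be generated for any input. The sketch below is a hypothetical Python helper `merge_trace`; it processes the edges in descending order of weight and reports, for each merge, the two set sizes $a$ and $b$, that is, how many paths receive that edge's weight as their distance. The tree in the example is made up for illustration; it is not the judge's second test.

```python
def merge_trace(n, edges):
    """Yield (u, v, w, a, b) for each merge, in the order the algorithm performs them.

    a and b are the sizes of the two sets joined by edge (u, v) of weight w,
    so a * b paths get distance w at this step.
    """
    parent = list(range(n + 1))
    size = [1] * (n + 1)

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    for u, v, w in sorted(edges, key=lambda e: e[2], reverse=True):
        ru, rv = find(u), find(v)
        yield u, v, w, size[ru], size[rv]
        # merge only after the sizes have been reported
        if size[ru] < size[rv]:
            ru, rv = rv, ru
        parent[rv] = ru
        size[ru] += size[rv]

edges = [(1, 2, 3), (2, 3, 1), (2, 4, 5)]  # hypothetical tree, not the judge's test
total = 0
for u, v, w, a, b in merge_trace(4, edges):
    print(f"edge ({u}, {v}), weight {w}: {a} * {b} = {a * b} path(s) with distance {w}")
    total += a * b * w
print("sum =", total)
```

For this tree the trace reports 1, 2 and 3 paths for weights 5, 3 and 1, and the sum is 14.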
Algorithm implementation
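Declare the edge structure, the arrays used by the disjoint-set data structure, and the global variables of the program.
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

struct Edge
{
  int u, v, cost;
} temp;

vector<Edge> e;
vector<int> parent, ssize;
int n, i;
long long res;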
The cmpGreater function is a comparator that sorts the edges in descending order of weight.
bool cmpGreater(Edge a, Edge b)
{
  return a.cost > b.cost;
}
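The Repr function returns the representative of the set containing vertex $v$. Along the way it performs path compression: every visited vertex is linked directly to the representative.
int Repr(int v)
{
  if (v == parent[v]) return v;
  return parent[v] = Repr(parent[v]);
}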
The Union function merges the sets containing elements $x$ and $y$. It uses the union-by-size heuristic: the smaller set is attached to the root of the larger one.
void Union(int x, int y)
{
  x = Repr(x); y = Repr(y);
  if (x == y) return;
  if (ssize[x] < ssize[y]) swap(x, y);
  parent[y] = x;
  ssize[x] += ssize[y];
}
The main part of the program. Read the number of vertices in the graph.
scanf("%d", &n);
Initialize the arrays.
parent.resize(n + 1);
ssize.resize(n + 1);
for (i = 0; i <= n; i++)
{
  parent[i] = i;
  ssize[i] = 1;
}
Read the edges of the tree and store them in the array e.
for (i = 0; i < n - 1; i++)
{
  scanf("%d %d %d", &temp.u, &temp.v, &temp.cost);
  e.push_back(temp);
}
Sort the edges in descending order of weights.
sort(e.begin(), e.end(), cmpGreater);
Compute the sum of distances between all pairs of vertices in the variable res.
res = 0;
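Iterate through the edges. For each edge $(u, v)$ with weight $w$, first add the contribution $a \cdot b \cdot w$ of all paths with distance $w$ to res, where $a$ and $b$ are the sizes of the sets containing $u$ and $v$, and only then call Union(u, v).
for (i = 0; i < e.size(); i++)
{
  res += 1LL * e[i].cost * ssize[Repr(e[i].u)] * ssize[Repr(e[i].v)];
  Union(e[i].u, e[i].v);
}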
Print the answer.
printf("%lld\n", res);
Python implementation
Declare a class Edge that describes an edge of the graph.
class Edge:
    def __init__(self, u, v, cost):
        self.u = u
        self.v = v
        self.cost = cost
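The cmpGreater function compares two edges so that heavier edges come first. It returns a negative number, zero, or a positive number, which is the form functools.cmp_to_key expects; the program below sorts with a key function instead, so this comparator is optional.
def cmpGreater(a, b):
    # negative when a should come first, i.e. when a is heavier than b
    return b.cost - a.cost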
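Python's list.sort and sorted take a key function, not a comparator. A two-argument comparator can still be used by wrapping it with functools.cmp_to_key, provided it returns a negative number, zero, or a positive number rather than a bool. The sketch below uses a namedtuple as a stand-in for the Edge class and a hypothetical three-way `cmp_greater`, and checks that the wrapped comparator gives the same order as the key-based sort used in this program.

```python
from collections import namedtuple
from functools import cmp_to_key

Edge = namedtuple("Edge", "u v cost")  # stand-in for the Edge class above

def cmp_greater(a, b):
    """Three-way comparator: negative if a should come first (larger cost first)."""
    return b.cost - a.cost

edges = [Edge(1, 2, 3), Edge(2, 3, 1), Edge(2, 4, 5)]
by_cmp = sorted(edges, key=cmp_to_key(cmp_greater))
by_key = sorted(edges, key=lambda e: e.cost, reverse=True)
print([e.cost for e in by_cmp])  # [5, 3, 1]
```

The key-based form is usually preferred: it computes each key once instead of calling the comparator for every comparison.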
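The Repr function returns the representative of the set containing vertex $v$. Along the way it performs path compression: every visited vertex is linked directly to the representative.
def Repr(v):
    if v == parent[v]:
        return v
    parent[v] = Repr(parent[v])
    return parent[v]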
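Repr is recursive. Union by size keeps the trees shallow, so the recursion depth stays small, but an iterative version avoids the question entirely. Below is a sketch of a hypothetical two-pass `find(parent, v)` that performs the same full path compression without recursion; it is an alternative, not what the original program uses.

```python
def find(parent, v):
    """Return the representative of v's set, compressing the path iteratively."""
    root = v
    while parent[root] != root:   # first pass: locate the root
        root = parent[root]
    while parent[v] != root:      # second pass: point every node on the path at the root
        nxt = parent[v]
        parent[v] = root
        v = nxt
    return root

# Hypothetical chain 4 -> 3 -> 2 -> 1, where 1 is the root (index 0 is unused)
chain = [0, 1, 1, 2, 3]
print(find(chain, 4))  # 1
print(chain)           # [0, 1, 1, 1, 1]
```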
The Union function merges the sets containing elements $x$ and $y$. It uses the union-by-size heuristic: the smaller set is attached to the root of the larger one.
def Union(x, y):
    x = Repr(x)
    y = Repr(y)
    if x == y:
        return
    if ssize[x] < ssize[y]:
        x, y = y, x
    parent[y] = x
    ssize[x] += ssize[y]
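Union by size has a useful side effect: a vertex gets deeper only when its set is attached under a set at least as large, so its set at least doubles each time, and no vertex can be deeper than $\log_2 n$. That is why the recursion in Repr stays shallow. The sketch below checks this bound empirically with random unions and no path compression; the function `max_depth_after_unions` and its random merge order are assumptions made for the experiment.

```python
import random

def max_depth_after_unions(n, seed=0):
    """Union by size, no path compression; return the deepest node's depth."""
    rng = random.Random(seed)
    parent = list(range(n))
    size = [1] * n

    def root(v):
        while parent[v] != v:
            v = parent[v]
        return v

    for _ in range(n - 1):
        # pick two random vertices from different sets
        while True:
            x, y = root(rng.randrange(n)), root(rng.randrange(n))
            if x != y:
                break
        if size[x] < size[y]:   # attach the smaller set under the larger
            x, y = y, x
        parent[y] = x
        size[x] += size[y]

    def depth(v):
        d = 0
        while parent[v] != v:
            v = parent[v]
            d += 1
        return d

    return max(depth(v) for v in range(n))
```

With path compression added, depths can only shrink, so the same bound applies to Repr.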
The main part of the program. Read the number of vertices in the graph.
n = int(input().strip())
Initialize the lists.
parent = list(range(n + 1))
ssize = [1] * (n + 1)
e = []
Read the edges of the tree and store them in the list e.
for _ in range(n - 1):
    u, v, cost = map(int, input().split())
    e.append(Edge(u, v, cost))
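For large trees, calling input() once per line can be slow in Python. A common alternative, sketched below, reads the whole input at once and splits it into tokens; it assumes the same whitespace-separated format that the scanf calls in the C++ version read. The helper `read_tree` and its tuple output are hypothetical; the program above stores Edge objects instead.

```python
import io
import sys

def read_tree(stream=None):
    """Read n, then n - 1 triples "u v cost", from whitespace-separated bytes."""
    if stream is None:
        stream = sys.stdin.buffer
    data = stream.read().split()
    n = int(data[0])
    edges = []
    for i in range(n - 1):
        # tokens 1 + 3*i .. 3 + 3*i hold the i-th edge
        u, v, cost = map(int, data[1 + 3 * i : 4 + 3 * i])
        edges.append((u, v, cost))
    return n, edges

# Example with an in-memory stream instead of real standard input
sample = io.BytesIO(b"4\n1 2 3\n2 3 1\n2 4 5\n")
print(read_tree(sample))  # (4, [(1, 2, 3), (2, 3, 1), (2, 4, 5)])
```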
Sort the edges in descending order of weights.
e.sort(key=lambda x: x.cost, reverse=True)
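Edges of equal weight may be processed in any order: the argument for the contribution $a \cdot b \cdot w$ only needs the already merged edges to weigh at least $w$, and ties satisfy that. The sketch below, with a hypothetical function `sum_min_distances` working on `(u, v, w)` tuples, checks this by feeding every permutation of a small tree's edge list (which changes how Python's stable sort orders the ties) and collecting the results.

```python
from itertools import permutations

def sum_min_distances(n, edges):
    """DSU formula: add a*b*w before merging each edge, heaviest edges first."""
    parent = list(range(n + 1))
    size = [1] * (n + 1)

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    res = 0
    for u, v, w in sorted(edges, key=lambda e: e[2], reverse=True):
        ru, rv = find(u), find(v)
        res += size[ru] * size[rv] * w  # sizes are read BEFORE the merge
        if size[ru] < size[rv]:
            ru, rv = rv, ru
        parent[rv] = ru
        size[ru] += size[rv]
    return res

# Hypothetical tree with repeated weights: a star centred at 1 plus one extra leaf
edges = [(1, 2, 2), (1, 3, 2), (1, 4, 2), (4, 5, 7)]
results = {sum_min_distances(5, list(p)) for p in permutations(edges)}
print(results)  # {25}
```

All 24 orderings give 25: nine pairs have distance 2 and the pair (4, 5) has distance 7.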
Compute the sum of distances between all pairs of vertices in the variable res.
res = 0
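Iterate through the edges. For each edge $(u, v)$ with weight $w$, first add the contribution $a \cdot b \cdot w$ of all paths with distance $w$ to res, where $a$ and $b$ are the sizes of the sets containing $u$ and $v$, and only then call Union(u, v).
for edge in e:
    res += edge.cost * ssize[Repr(edge.u)] * ssize[Repr(edge.v)]
    Union(edge.u, edge.v)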
Print the answer.
print(res)