fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4. vector<vector<int>> buildGraph(int n, const vector<pair<int, int>>& edges) {
  5. vector<vector<int>> g(n + 1);
  6. for (auto [u, v] : edges) {
  7. g[u].push_back(v);
  8. g[v].push_back(u);
  9. }
  10. return g;
  11. }
  12.  
  13. vector<int> bfs(int root, const vector<vector<int>>& g) {
  14. int n = g.size() - 1;
  15. vector<int> par(n + 1, -1);
  16. queue<int> q;
  17. q.push(root);
  18. par[root] = 0;
  19.  
  20. // build parent pointers from root so we can trace paths back later
  21. while (!q.empty()) {
  22. int u = q.front();
  23. q.pop();
  24. for (int v : g[u]) {
  25. if (par[v] == -1) {
  26. par[v] = u;
  27. q.push(v);
  28. }
  29. }
  30. }
  31. return par;
  32. }
  33.  
  34. set<int> getPath(int root, vector<int> targets, const vector<int>& par) {
  35. set<int> res;
  36. res.insert(root);
  37.  
  38. // walk up from each target to root and grab everything in between
  39. for (int t : targets) {
  40. res.insert(t);
  41. while (t != root && t != 0) {
  42. t = par[t];
  43. res.insert(t);
  44. }
  45. }
  46. return res;
  47. }
  48.  
  49. int main() {
  50. int n = 8;
  51. vector<int> targets = {5, 6};
  52. vector<pair<int, int>> edges = {
  53. {1, 2}, {2, 5}, {2, 3}, {2, 6},
  54. {1, 4}, {4, 7}, {7, 8}
  55. };
  56.  
  57. auto g = buildGraph(n, edges);
  58. auto par = bfs(1, g);
  59.  
  60. // start and end at 1
  61. auto nodes = getPath(1, targets, par);
  62. int cost = nodes.size() == 1 ? 0 : 2 * (nodes.size() - 1);
  63.  
  64. // or try starting at n and ending at 1
  65. // this means we have to visit n, so add it to our targets
  66. auto targets2 = targets;
  67. if (find(targets2.begin(), targets2.end(), n) == targets2.end()) {
  68. targets2.push_back(n);
  69. }
  70. auto nodes2 = getPath(1, targets2, par);
  71.  
  72. // calculate how far n is from the root
  73. // if we start at n and end at 1, we save traversing this path back
  74. int dist = 0;
  75. int u = n;
  76. while (u != 1 && u != 0) {
  77. u = par[u];
  78. dist++;
  79. }
  80.  
  81. int cost2 = nodes2.size() == 1 ? 0 : 2 * (nodes2.size() - 1) - dist;
  82.  
  83. cout << min(cost, cost2) << endl;
  84.  
  85. return 0;
  86. }
Success #stdin #stdout 0.01s 5320KB
stdin
Standard input is empty
stdout
6