본문 바로가기

🧑🏻‍💻 Dev/알고리즘

[백준] 1967번 트리의 지름 (Java)

1967번 트리의 지름

 

1. 문제 접근


  • 특정 노드 2개를 선택해서 쭉 늘렸을 때, 가장 길이가 긴 두 노드를 찾아야 한다. 이 길이를 문제에서는 트리의 지름이라고 칭합니다.
  • 1번 노드부터 탐색해서 가장 가중치가 큰 노드 하나를 선택합니다.
  • 가장 가중치가 큰 노드는 반드시 트리의 지름의 한 점이 된다고 생각했습니다.
  • 가장 가중치가 큰 노드를 하나 선택했으면, 해당 노드부터 BFS로 다시 탐색을 시작해서 가장 길이가 긴 값(트리의 지름)을 반환합니다.

 

 

 

2. 접근 과정에서 실수했던 부분


  • 문제에서 "간선에 대한 정보는 부모 노드의 번호가 작은 것이 먼저 입력되고, 부모 노드의 번호가 같으면 자식 노드의 번호가 작은 것이 먼저 입력된다."라는 부분이 있었습니다.
  • 이 부분을 잘못 이해해서 다음과 같이 문제 접근을 시작했습니다.

 

(1) 입력을 받으면서 가장 가중치가 큰 노드를 선택

Node[] nodes = new Node[n + 1]; // 가장 가중치가 큰 노드를 찾기 위한 배열
nodes[0] = new Node(0, 0); // NullPointerException 방지를 위해
nodes[1] = new Node(1, 0); // 1번은 루트이기 때문에 가중치가 0

for (int i = 1; i < n; i++) {
    StringTokenizer st = new StringTokenizer(br.readLine());
    int first = Integer.parseInt(st.nextToken());
    int second = Integer.parseInt(st.nextToken());
    int weight = Integer.parseInt(st.nextToken());
    map[first].add(new Node(second, weight)); // 트리의 양방향 설정
    map[second].add(new Node(first, weight)); // 트리의 양방향 설정
    int beforeWeight = nodes[first] == null ? 0 : nodes[first].weight;
    nodes[second] = new Node(second, beforeWeight + weight);
}

Arrays.sort(nodes, (n1, n2) -> n2.weight - n1.weight);
// 도달 거리가 가장 큰 놈을 하나 선택
int start = nodes[0].end;
  • O(N)에 모든 것을 다 챙겨가려는 오만한 생각에 코드를 이렇게 짰습니다.
  • 이렇게 해서 문제 제출을 해보니 67%쯤에서 계속 틀렸다고 나왔습니다.

 

 

(2) 도움이 됐던 반례

  • 이진 트리라는 조건이 없어서 생겼던 웬만한 반례들은 모두 통과가 됐습니다.
  • 그래서 처음에는 어느 부분에서 실수를 한지 판단할 수가 없어서 계속 생각해 보다가 알고리즘에 도움을 줬던 반례가 있었습니다.
12
1 3 3
1 12 2
2 5 1
2 11 7
3 2 50
4 8 15
4 9 4
6 7 6
6 10 10
12 4 11
12 6 9

정답 : 88
  • 해당 반례는 딱 제가 실수했던 부분에 대한 반례였습니다.
  • "간선에 대한 정보는 부모 노드의 번호가 작은 것이 먼저 입력되고, 부모 노드의 번호가 같으면 자식 노드의 번호가 작은 것이 먼저 입력된다."라는 조건은 확실하게 만족하는 입력 값들이었습니다.
  • 제 알고리즘대로라면 {2 5 1}이 입력되었을 때, 5라는 노드의 가중치는 2라는 노드의 가중치에 1을 더한 값이 세팅되어야 합니다. 하지만 위 입력에서는 아직 노드 2에 대한 제대로 된 가중치가 저장되기 전이기 때문에 옳지 않은 정답을 출력한 것을 알 수 있습니다.

 

 

 

3. 문제 해결


  • 기존의 나머지 알고리즘은 그대로 두고, 문제가 있던 입력받는 부분에 대해서만 수정을 했습니다.
  • 입력은 그대로 인접 리스트를 사용하여 양방향 정보를 저장하고, BFS 탐색을 2번 했습니다.
  • 첫 번째 탐색에서는 가중치의 최대값을 갱신하면서 그 최대 가중치를 갖는 노드(start)를 찾아냈습니다.
  • 두 번째 탐색에서는 최대 가중치를 갖는 노드(start)부터 다시 탐색을 시작해서 가장 가중치가 긴 값(트리의 지름)을 갱신해서 반환했습니다.
import java.io.*;
import java.util.*;

public class Main {

    static int n, start;
    static List<Node>[] map;
    static boolean[] visited;

    public static void main(String[] args) throws IOException {
        System.setIn(new FileInputStream("src/main.txt"));
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        map = new ArrayList[n + 1];
        for (int i = 0; i <= n; i++) {
            map[i] = new ArrayList<>();
        }

        for (int i = 1; i < n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int first = Integer.parseInt(st.nextToken());
            int second = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            map[first].add(new Node(second, weight));
            map[second].add(new Node(first, weight));
        }
        // 가장 큰 노드를 탐색
        visited = new boolean[n + 1];
        bfs(1);

        // 해당 노드부터 탐색을 시작해서 최장 거리를 찾음
        visited = new boolean[n + 1];
        int result = bfs(start);
        System.out.println(result);
    }

    public static int bfs(int startNodeNumber) {
        int result = 0;
        Queue<Node> queue = new LinkedList<>();
        queue.add(new Node(startNodeNumber, 0));
        visited[startNodeNumber] = true;

        while (!queue.isEmpty()) {
            Node curr = queue.poll();
            if (result < curr.weight) {
                result = curr.weight;
                start = curr.end;
            }

            for (Node node : map[curr.end]) {
                if (visited[node.end]) continue;
                queue.add(new Node(node.end, curr.weight + node.weight));
                visited[node.end] = true;
            }
        }

        return result;
    }

    static class Node {
        int end;
        int weight;

        public Node(int end, int weight) {
            this.end = end;
            this.weight = weight;
        }
    }
}