예시 1 답: (트리의 지름) - 1

예시 2 답: (트리의 지름)

 

 

1차 시도 - 시간초과

트리의 지름을 가지는 노드가 2쌍 이상이면 2번 예시에 들어가고, 한 쌍이면 1번 예시에 들어가게 된다.

트리의 지름을 구하는 방법은

1. 임의의 노드에서 가장 먼 거리의 노드을 찾는다.

2. 찾은 노드에서 또 가장 먼 거리의 노드를 찾는다. => 이 값이 트리의 지름, 두 노드가 트리의 끝점이다.

class Solution {
	public static int radius = 0;
	public static HashSet<NodePair> endPoints = new HashSet<NodePair>();

	public int solution(int n, int[][] edges) {
		int[] distanceFromNodeOne = new int[n + 1];

		LinkedList<Integer> connectedNodes[] = new LinkedList[n + 1];
		for (int i = 1; i <= n; i++) {
			connectedNodes[i] = new LinkedList<Integer>();
		}
		for (int i = 0; i < edges.length; i++) {
			int start = edges[i][0];
			int end = edges[i][1];

			connectedNodes[start].add(end);
			connectedNodes[end].add(start);
		}

		findRadius(1, connectedNodes, n, true);
		
		int maxSize = 0;
		for(NodePair nodepair:endPoints) {
			if(radius<=nodepair.distance) {
				maxSize++;
			}
		}

		if (maxSize >= 2) {
			return radius;
		}
		return radius - 1;
	}

	public void findRadius(int startPoint, LinkedList<Integer> connectedNodes[], int n, boolean isFirstTurn) {
		Queue<Node> queue = new LinkedList<Node>();
		queue.add(new Node(startPoint, 0));

		int maxDistance = 0;
		int distanceFromNodeOne[] = new int[n + 1];

		// 1번 노드에서 가장 먼 노드 구하기
		while (!queue.isEmpty()) {
			Node node = queue.poll();
			distanceFromNodeOne[node.name] = node.distance;
			maxDistance = Math.max(maxDistance, node.distance);

			LinkedList<Integer> connection = connectedNodes[node.name];
			for (int nextNode : connection) {
				if (distanceFromNodeOne[nextNode] == 0 && nextNode != startPoint) {
					queue.add(new Node(nextNode, node.distance + 1));
				}
			}
		}

		LinkedList<Integer> maxNodes = new LinkedList<Integer>();
		for (int i = 1; i <= n; i++) {
			if (distanceFromNodeOne[i] == maxDistance) {
				maxNodes.add(i);
			}
		}

		if (!isFirstTurn) {
			for (int nextPoint : maxNodes) {
				if (radius <= maxDistance) {
					endPoints.add(new NodePair(Math.min(startPoint, nextPoint), Math.max(startPoint, nextPoint),
							maxDistance));
					radius = maxDistance;
				}
			}
			return;
		}

		HashSet<Integer> nodeSet = new HashSet<Integer>();
		nodeSet.addAll(endPoints.stream().map(nodepair -> nodepair.node1).collect(Collectors.toSet()));
		nodeSet.addAll(endPoints.stream().map(nodepair -> nodepair.node2).collect(Collectors.toSet()));
		for (int nextPoint : maxNodes) {
			if (!nodeSet.contains(nextPoint)) {
				findRadius(nextPoint, connectedNodes, n, false);
			}
		}
	}
}

class Node {
	int name;
	int distance;

	public Node(int name, int distance) {
		this.name = name;
		this.distance = distance;
	}

	public String toString() {
		return name + " " + distance;
	}
}

class NodePair {
	int node1;
	int node2;
	int distance;

	public NodePair(int node1, int node2, int distance) {
		this.node1 = node1;
		this.node2 = node2;
		this.distance = distance;
	}

	public String toString() {
		return node1 + "-" + node2 + " " + distance;
	}
}

정확성: 84.6

합계: 84.6 / 100.0

시간초과가 난 이유는 아무래도 nodeSet을 계속 계산해 주어서가 아닐까 생각해본다.

 

 

2차 시도 - nodeSet을 static으로 빼기, 쓸데없는 객체 안쓰기 => 테스트케이스19 TL

import java.util.*;
class Solution {
	public static int radius = 0;// 트리의 지름
	public static HashSet<Integer> nodeSet = new HashSet<Integer>();// 지름이 되는 노드의 집합

	public int solution(int n, int[][] edges) {
		// 인접 노드 전처리
		LinkedList<Integer> connectedNodes[] = new LinkedList[n + 1];
		for (int i = 1; i <= n; i++) {
			connectedNodes[i] = new LinkedList<Integer>();
		}
		for (int i = 0; i < edges.length; i++) {
			int start = edges[i][0];
			int end = edges[i][1];
			connectedNodes[start].add(end);
			connectedNodes[end].add(start);
		}
		findRadius(1, connectedNodes, n, true);
		if (nodeSet.size() == 2) {
			return radius - 1;
		}
		return radius;
	}

	/*
	 * 1. 임의의 노드에서 가장 먼 거리의 노드을 찾는다. 
	 * 2. 찾은 노드에서 또 가장 먼 거리의 노드를 찾는다. => 값이 트리의 지름, 두 노드가 트리의 끝점이다.
	 */
	public void findRadius(int startPoint, LinkedList<Integer> connectedNodes[], int n, boolean isFirstTurn) {
		Queue<Integer> queue = new LinkedList<Integer>();
		queue.add(startPoint);
		int maxDistance = 0;// startPoint에서 가장 먼 길이
		int distanceFromStartNode[] = new int[n + 1];// startPoint에서 다른 노드들까지 거리
		// startPoint 노드에서 가장 먼 노드 구하기 - bfs
		LinkedList<Integer> maxNodes = new LinkedList<Integer>();// 가장 먼 노드의 집합
		while (!queue.isEmpty()) {
			int node = queue.poll();

			// 가장 먼 노드 집합 업데이트
			if (maxDistance < distanceFromStartNode[node]) {
				maxDistance = distanceFromStartNode[node];
				maxNodes.clear();
				maxNodes.add(node);
			} else if (maxDistance == distanceFromStartNode[node]) {
				maxNodes.add(node);
			}
			LinkedList<Integer> connection = connectedNodes[node];
			for (int nextNode : connection) {
				if (distanceFromStartNode[nextNode] == 0 && nextNode != startPoint) {
					queue.add(nextNode);
					distanceFromStartNode[nextNode] = distanceFromStartNode[node] + 1;
				}
			}
		}
		if (!isFirstTurn) { // 첫번째 턴이 아니면 = 지름 다 구했으니 리턴
			if (radius < maxDistance) {
				nodeSet.clear();
				nodeSet.add(startPoint);
				radius = maxDistance;
				for (int nextPoint : maxNodes) {
					nodeSet.add(nextPoint);
				}
			} else if (radius == maxDistance) {
				nodeSet.add(startPoint);
				radius = maxDistance;
				for (int nextPoint : maxNodes) {
					nodeSet.add(nextPoint);
				}
			}
			return;
		}
		
		// 첫번째 턴이면 여기서 다시 가장 먼 노드를 구한다
		for (int nextPoint : maxNodes) {
			if (!nodeSet.contains(nextPoint)) {
				findRadius(nextPoint, connectedNodes, n, false);
			}
		}
	}
}

채점 결과

정확성: 96.2

합계: 96.2 / 100.0

 

 

3차 시도 - 첫번째 턴에서 구한 모든 노드에서 다시 가장 먼 노드를 구할 때 중복된 연산이 존재한다는 것을 깨달았다.

그래서 기존 로직에서 살짝 바꿔서 트리의 지름이 될 수 있는 쌍이 몇쌍인지 먼저 생각해 보기로 했다.

1. 임의의 노드 A 에서 가장 먼 거리의 노드을 찾는다.

2. 찾은 노드 B 에서 또 가장 먼 거리의 노드 C 를 찾는다. => 이 값이 트리의 지름, 두 노드가 트리의 끝점이다.

    => C 노드가 2개 이상 == 예시 2

3. C 노드가 1개인 경우 다시 가장 먼 거리의 노드 D 를 찾는다 => D 노드가 2개 이상 == 예시 2 / 1개 == 예시 1

class Solution {
	public static int radius = 0;

	public int solution(int n, int[][] edges) {
		LinkedList<Integer> connectedNodes[] = new LinkedList[n + 1];
		for (int i = 1; i <= n; i++) {
			connectedNodes[i] = new LinkedList<Integer>();
		}
		for (int i = 0; i < edges.length; i++) {
			int start = edges[i][0];
			int end = edges[i][1];

			connectedNodes[start].add(end);
			connectedNodes[end].add(start);
		}

		boolean isonepair = isOnePair(1, connectedNodes, n, 0);

		if (isonepair) {
			return radius - 1;
		}
		return radius;
	}

	public boolean isOnePair(int startPoint, LinkedList<Integer> connectedNodes[], int n, int turn) {

		int[] distanceFromStartPoint = calculateDistanceBFS(startPoint, connectedNodes, n);
		int maxNode = findMaxNode(distanceFromStartPoint, n);

		if (turn >= 1) {
			int maxNodeNum = 0;
			radius = Math.max(radius, distanceFromStartPoint[maxNode]);
			for (int i = 1; i <= n; i++) {
				if (distanceFromStartPoint[maxNode] == distanceFromStartPoint[i]) {
					maxNodeNum++;
				}
			}

			if (maxNodeNum >= 2) {
				return false;
			}

			if (turn > 1) {
				return true;
			}

			return isOnePair(maxNode, connectedNodes, n, 2);
		}

		return isOnePair(maxNode, connectedNodes, n, 1);
	}

	public static int findMaxNode(int[] distance, int n) {
		int maxNode = 1;
		for (int i = 1; i <= n; i++) {
			if (distance[i] > distance[maxNode]) {
				maxNode = i;
			}
		}
		return maxNode;
	}

	public static int[] calculateDistanceBFS(int startPoint, LinkedList<Integer> connectedNodes[], int n) {
		Queue<Integer> queue = new LinkedList<Integer>();
		queue.add(startPoint);

		int distanceFromNode[] = new int[n + 1];
		while (!queue.isEmpty()) {
			int node = queue.poll();

			LinkedList<Integer> connection = connectedNodes[node];
			for (int nextNode : connection) {
				if (distanceFromNode[nextNode] == 0 && nextNode != startPoint) {
					queue.add(nextNode);
					distanceFromNode[nextNode] = distanceFromNode[node] + 1;
				}
			}
		}

		return distanceFromNode;
	}
}

 

 

+ Recent posts