출처 : 종만복 07 분할정복 - 행렬의 거듭제곱

코드설명

nxn 크기의 행렬 A가 주어질 때, A의 거듭제곱(power) A^m은 A를 연속해서 m번 곱한 것 입니다.

이것을 계산하는 알고리즘 자체는 어려울 것이 없지만, m이 매우 클때(m === 1,000,000 (10^6)) A^m을 구하는 것은 꽤나 시간이 오래 걸리는 작업입니다. 행렬의 곱셈에는 O(n^3)의 시간이 들기에 곧대로 m-1번의 곱셈을 통해 A^m을 구하려면 모두 O(n^3m)번의 연산이 필요합니다. n=100, m=1,000,000 이라고 하면 필요한 연산의 수는 대략 1조 정도가 됩니다. 1억연산을 1초만에 하는 컴퓨터가 1조 정도의 연산을 하려면 1000초가 걸립니다.  

 

그러나 분할 정복을 활용하면 순식간에 구할 수 있습니다.

 

먼저, 분할정복의 기본인 문제를 둘 이상의 부분문제로 나눕니다.

A^m을 구하는데 필요한 m개의 조각을 절반으로 나눕니다. 즉, 두개로 부분문제로 나눌 것 입니다.

 

A^m = A^(m/2) x A^(m/2) 

 

반으로 자르기만 하면 절반 크기의 부분 문제가 갑자기 툭 튀어나오니 fastSum()보다 간단하면 더 간단했지 그다지 다를 것이 없습니다. 

이렇게 할경우 시간복잡도는 O(n^3 * lg m)이 될 것 입니다. 곱셉의 경우 여전히 n^3이고, m 곱셈의 개수가 lg m 으로 줄어듭니다.

 

행렬의 거듭제곱을 구하는 분할정복 알고리즘을 구현합니다.

수도코드입니다.

//정방 행렬을 표현하는 SquareMatrix 클래스가 있다고 가정하자.
class SquareMatrix;
//n*n 크기의 항등행렬(identity matrix)을 반환하는 함수
SquareMatrix identity(int n);
//A^m을 반환한다.
SquareMatrix pow(SquareMatrix A, int m) {
    //기저 사례 : A^0 = 1
    if(m == 0) return identity(A.size());
    if(m % 2 > 0) return pow(A, m-1) * A;
    SquareMatrix half = pow(A, m/2);
    //A^m = (A^(m/2)) * (A^(m/2))
    return half*half;
}

 

시간복잡도

 

m이 홀수일떄(if (m%2) > 0), A^m = A*A^(m-1) (return pow(A, m-1) * A) 로 나누지 않고, A^(m/2) * A^(m/2 + 1) 좀 더 절반에 가까운 형식으로 나누는것이 더 적합하지 않을까 생각할 수 있습니다.예를 들어 A^7을 A*A^6 으로 나누는 것이 아니라, A^3 * A^4 로 나누는 것의 방안입니다. 실제로 문제의 크기가 매번 절반에 가깝게 줄어들면 기저 사례에 도달하기까지 걸리는 분할의 횟수가 줄어들기에 대부분의 분할정복 알고리즘은 가능한 한 절반에 가깝게 문제를 나누고자 합니다. 퀵 정렬에서 pivot 값을 찾아 분할하며 최대한 분할의 횟수를 줄이려는 것과 같습니다. 그러나 이 문제에서 이 방식의 분할(절반대로 나누는 분할)은 오히려 알고리즘을 더 느리게 만듭니다. 이런 식으로 문제를 나누면 A^m 을 찾기 위해 계산해야 할 부분 문제의 수가 늘어납니다.

 

만약 pow(A, 31)이 주어졌다고 가정하고, 두가지 분할방식으로 시뮬레이션 해보면 알 수 있습니다.

- 31을 절반으로 나누는 경우입니다.

pow(A, 31) => ...    분할되면서 pow(A, 8)는 pow(A, 15)를 계산할때도 호출되고 pow(A, 16)을 계산할때도 호출되므로 모두 두번 호출된다는 것을 알 수 있습니다.

따라서 pow(A, 8)과 pow(A, 7)을 계산할떄 사용하는 pow(A, 4)는 모두 세번 호출됩니다. 이와 같이 같은 값을 중복을 계산하는 일이 많기에, m이 증가함에 따라 pow(A, m)을 계산하는데 필요한 pow()의 호출횟수는 m에 대해 선형적으로 증가합니다. pow()가 한번 호출될때마다 행렬 곱셈을 한번 하기 때문에, 그림(a)가 보여주는 분할 방식은 대문자 O 표기법으로 보면 결국 m-1번 곱셈을 하는 것과 같습니다. 

 

- 31을 30 1 로 나눌 경우입니다.

lg m 개의 거듭제곱에 대해 한번씩만 호출합니다.

 

어떻게 분할하느냐에 따라 시간복잡도 차이가 커집니다. 절반으로 나누는 알고리즘이 비효율적인 이유는, 여러 번 중복되어 계산되며 시간을 소모하는 부분문제,부분문제가 중복되는 속성이 있기에 그렇습니다.

입력예시

N과 M 을 입력받습니다. N은 matic의 Size, M은 거듭제곱의 횟수입니다.

그 이후에 NxN의 matric의 값을 입력받습니다.

입력 예시 1

3 2
1 2 3
4 5 6
7 8 9

결과 예시 1

 30 36 42
 66 81 96
 102 126 150

코드

package Main;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
	public static int N, M, C;
	public static int answer = 0;
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringTokenizer st = new StringTokenizer(br.readLine());
		N = Integer.parseInt(st.nextToken()); M = Integer.parseInt(st.nextToken());
		
		int[][] matrix = new int[N][N];
		for(int i=0;i<N;i++) {
			st = new StringTokenizer(br.readLine());
			for(int j=0;j<N;j++) {
				matrix[i][j] = Integer.parseInt(st.nextToken());
			}
		}
		
		SquareMatrix squareMatrix = pow(new SquareMatrix(matrix), M);
		squareMatrix.print();
		
	}
	
	public static SquareMatrix pow(SquareMatrix A, int m) {
		//기저 사례 : A^0  1 (항등행렬 반환)
		if(m == 0) return A.identity(A);
		if(m % 2 == 1) return pow(A, m - 1).multiply(A);
		SquareMatrix half = pow(A, m / 2);
		// A^m = (A^(m/2)) * (A^(m/2))
		return half.multiply(half);
	}
	
	
	static class SquareMatrix{
		public int[][] matrix;
		
		public SquareMatrix(int[][] matrix) {
			this.matrix = matrix;
		}
		
		public SquareMatrix multiply(SquareMatrix other) {
			int size = this.matrix.length;
			int[][] newMatrix = new int[size][size];
			for(int i=0;i<size;i++) {
				for(int j=0;j<size;j++) {
					for(int k=0;k<size;k++) {
						newMatrix[i][j] += this.matrix[i][k] * other.matrix[k][j];
					}
				}
			}
			return new SquareMatrix(newMatrix);
		}
		
		//M^0 은 단위행렬, 항등행렬, identityMatrix가 반환됩니다.
		public SquareMatrix identity(SquareMatrix other) {
			int size = other.matrix.length;
			int[][] newMatrix = new int[size][size];
			for(int i=0;i<size;i++) {
				newMatrix[i][i] = 1;
			}
			return new SquareMatrix(newMatrix);
		}
		
		public void print() {
			for(int[] r : matrix) {
				for(int c : r) {
					System.out.print(" "+c);
				}
				System.out.println();
			}
		}
	}	
	
}

 

+ Recent posts