본문 바로가기
백준_JAVA

[백준] 10830번 - 행렬 제곱

by stonage 2023. 6. 17.

 

 

10830번: 행렬 제곱

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

www.acmicpc.net

 

 

 

 

 

문제 

 

크기가 N*N 인 행렬 A를 B번 제곱한 결과를 출력하시오. 단, A^B의 각 원소를 1,000으로 나누어 출력한다.

 

 

 

 

 

조건 

 

2 <= N <= 5

1 <= B <= 100,000,000,000

행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0

시간 제한 1초

 

 

 

 

 

풀이 방향

 

필요한 개념

- 행렬의 곱

- 분할 정복

- 모듈러 연산

 

간단해 보이는 문제이지만 입력으로 주어지는 B의 범위에 의해서 단순하게 풀이하면 시간 제한 조건에 발목을 잡힌다. 행렬을 N번 제곱할 경우 시간 복잡도는 log(N)인데 B의 최대값이 1,000 억 이기 때문에 대략 1000초의 시간이 필요하다. 따라서 보다 효율적인 알고리즘을 사용해야 한다. 

행렬의 제곱에 대해서 생각해보자. 행렬의 곱은 교환법칙은 적용되지 않지만 결합법칙은 적용 가능하다. 해당 문제의 경우 입력으로 주어지는 하나의 행렬에 대한 거듭 제곱을 수행하면 되기 때문에 전체 거듭 제곱을 작은 부분으로 나누어서 먼저 계산해도 결과는 동일하다. 따라서 행렬 A에 대해서 아래와 같은 연산이 가능하다.

 

A^4 = (A^2) * (A^2)

A^5 = (A^2) * (A^2) * A

 

따라서 연쇄적으로 A의 (B / 2) 제곱 한 결과를 제곱해주면서 제곱 연산을 절반씩만 수행하는 분할 정복을 적용하면 시간 복잡도를 log 로 줄일 수 있다. 단, 주의할 점은 지수가 홀수인 경우 위 과정에서 A를 한 번 더 곱해주는 과정을 추가해준다. 

B의 최대값이 큰 만큼 자료형에도 신경써줘야 하는데, 아예 편하게 사용하는 정수 자료형을 long으로 지정해도 되고, 만약 int 범위를 초과하는 데이터가 필요 없는 변수에는 int으로 선언해주어도 된다. 나의 경우에는 입력으로 주어지는 지수인 B만 long으로 지정해주었는데, 분할 정복의 중간마다 각 원소에 1,000으로 나눈 값을 저장하기 때문에 굳이 long 타입을 사용할 필요가 없었다. 

 

private int[][] getSquare (long exp) {
        
    int[][] result = {};


    if (exp == 1) {
    
    	// square은 입력으로 주어진 행렬 A를 의미한다. 
        return square;
    }

    int[][] log_square = getSquare(exp / 2);
	
    
    // 지수가 짝수인 경우 (지수 / 2)만큼 제곱한 결과를 한 번 더 제곱해준다.
    if (exp % 2 == 0) 
        result = multipleSquare(log_square, log_square);
        
    // 지수가 홀수인 경우 짝수인 경우의 결과에 처음 입력으로 주어진 행렬 A를 한 번 더 곱해준다. 
    else 
        result = multipleSquare(multipleSquare(log_square, log_square), square);

    return result;
}


// 매개변수로 받은 두 행렬 A와 B에 대한 곱을 반환해주는 함수. 지수가 홀수인 경우를 아우르기 위해 단순히 제곱하는 경우에도 독립적인 행렬로 취급한다. 
public int[][] multipleSquare (int[][] A, int[][] B) {

    int r_leng = A.length;
    int c_leng = B[0].length;
    int s_leng = A[0].length;

    int[][] result = new int[r_leng][c_leng];

    for (int i = 0; i < r_leng; i++) {
        for (int j = 0; j < c_leng; j++) {
            for (int k = 0; k < s_leng; k++) {
                result[i][j] += A[i][k] * B[k][j];
            }
            result[i][j] %= MOD;
        }
    }

    return result;
}

 

 

 

그리고 또 하나 주의해야할 부분은 입력으로 주어지는 B가 1인 경우이다. 만약 B == 1이라면 getSquare 함수는 곧바로 행렬 A를 반환하게 되는데, 만약 A의 원소에 1,000이 포함되어 있다면 1,000으로 나눈 나머지인 0이 아니라 1,000이 곧바로 출력되기 때문에 이 부분을 처리하지 않는다면 '틀렸습니다'를 마주하기 때문이다. 나는 행렬 A에 대한 입력을 받을 때 각 원소를 1,000으로 나눈 나머지를 행렬에 저장하여 해결하였다. 

 

 

 

 

 

 

 

 

 

 

전체 코드

import java.util.*;

public class Main {
    public static void main (String[] args) {
        new Main().solution();        
    }
    
    
    int N, MOD = 1000, square[][];
    long B;
    
    private void solution () {
        
        Scanner sc = new Scanner(System.in);
        
        N = sc.nextInt();
        B = sc.nextLong();
        
        square = new int[N][N];
        
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
            
            	// (B == 1 && 행렬 A가 1,000을 포함) 의 케이스에 대한 처리로 1,000으로 나눈 값을 행렬에 저장한다. 
                square[i][j] = sc.nextInt() % MOD;
            }
        }
        
        
        int[][] result = getSquare(B);
        
        
        for (int[] arr : result) {
            for (int num : arr) System.out.print(num + " ");
            System.out.println();        
        }
        
         
        sc.close();
    }
    
    
    private int[][] getSquare (long exp) {
        
        int[][] result = {};
        
        
        if (exp == 1) {
            return square;
        }
        
        int[][] log_square = getSquare(exp / 2);
        
        if (exp % 2 == 0) 
            result = multipleSquare(log_square, log_square);
        else 
            result = multipleSquare(multipleSquare(log_square, log_square), square);
        
        return result;
    }
    
    
    public int[][] multipleSquare (int[][] A, int[][] B) {
        
        int r_leng = A.length;
        int c_leng = B[0].length;
        int s_leng = A[0].length;
        
        int[][] result = new int[r_leng][c_leng];
        
        for (int i = 0; i < r_leng; i++) {
            for (int j = 0; j < c_leng; j++) {
                for (int k = 0; k < s_leng; k++) {
                    result[i][j] += A[i][k] * B[k][j];
                }
                result[i][j] %= MOD;
            }
        }
        
        return result;
    }
}