[C++] 백준 11049 [행렬 곱셈 순서]

date:

Updated:


Categories:

Tags: , ,

\( {_nC_2}\times{2} + _{n-2}C_2\times{2} + _{n-4}C_2\times{2} + \dots + {_2C_2}\times{2} / 10! \)

접근방식

DP로 접근하면 되는 문제이다.

풀이방법

일단, 행렬의 곱을 했을 때, 연산횟수를 구해봐야한다.

예시 행렬의 크기

A행렬의 크기 : 3 x 6
B행렬의 크기 : 6 x 4
C행렬의 크기 : 4 x 2
D행렬의 크기 : 2 x 7

행렬 곱의 연산횟수

A행렬과 B행렬을 곱한다고 생각해보자.
A행렬의 행 크기 x (A행렬의 열크기 or B행렬의 행 크기) x B행렬의 열 크기가 곱셈의 연산 횟수이다.
그래서 A x B를 했을 때, 총 연산횟수는 3 x 6 x 4 이다.

순서에 따른 행렬 곱의 총 연산횟수

A x B x C x D를 연산했을 때와, (A x (B x (C x D))) 를 했을 때와 연산 횟수가 다르다.
A x B x C x D의 연산횟수는 3 x 6 x 4 + 3 x 4 x 2 + 3 x 2 x 7 이다.
(A x (B x (C x D)))의 연산횟수는 4 x 2 x 7 + 6 x 4 x 7 + 3 x 6 x 7 이다.

점화식 작성

행렬의 크기를 담은 배열이 있다고 하자.
DP(s, e)를 행렬 배열의 s 인덱스부터 e 인덱스까지의 곱 연산횟수라고 하겠다.
s <= k < e 라고 할 때, DP(s, e) = DP(s, k) + DP(k+1, e) + (s부터 k인덱스까지의 행렬을 모두 곱한 결과행렬과 k+1부터 ~ e인덱스까지의 행렬을 모두 곱한 결과행렬의 곱 연산 횟수) 이다.
위 식의 좌변에서 우변으로 가는 과정을 괄호를 추가하는 과정이라고 보면 이해하기 쉽다.

예를들어, A, B ,C, D 행렬이 있다고 하자.
DP(0,3) 에서 DP(0,1) + DP(2,3) 이 되는 것은 (AxB)x(CxD)를 나타내는 것이다.
DP(0,3) 에서 DP(0,0) + DP(1,3) 이 되는 것은 Ax(BxCxD)를 나타내는 것이다.
그리고 A x (BxCxD)처럼 두 부분으로 나누는 것 뿐만 아니라, A와 (BxCxD) 두 행렬의 곱 연산횟수를 더해주면 되는 것이다.

if(ret == MAX){
    for(int k=s; k<e; k++){
        ret = min(ret, DP(s,k) + DP(k+1, e) + Size[s].first*Size[k].second*Size[e].second);
    }
}

점화식의 초기조건

인덱스가 같아지면 더 이상 곱할 행렬이 사라지므로 s == e가 초기조건이다.

소스코드

#include<iostream>
#include<vector>
#include<cmath>

using namespace std;

vector<pair<int, int> > Size;
int cache[501][501];
const int MAX = 987654321;
using namespace std;

int DP(int s, int e){
    if(s == e) return 0;
    int& ret = cache[s][e];
    if(ret == MAX){
        for(int k=s; k<e; k++){
            ret = min(ret, DP(s,k) + DP(k+1, e) + Size[s].first*Size[k].second*Size[e].second);
        }
    }
    return ret;
}

int main(){
    int n;
    cin >> n;
    Size = vector<pair<int, int> >(n);
    for(int i=0; i<501; i++){
        for(int j=0; j<501; j++){
            cache[i][j] = MAX;
        }
    }
    for(int i=0; i<n; i++){
        cin >> Size[i].first >> Size[i].second;
    }
    cout << DP(0, n-1);
}

Leave a comment