BitMasking DP

Updated:

애초에 DP를 너무너무너무 못하는지라 참 쓰기가 뭐하지만 대표적인 비트마스킹 DP문제가 있습니다

https://www.acmicpc.net/problem/2098

문제를 요약하면, 1번마을에서 시작해서 다시 1번마을까지 돌아오는데, 최소 비용을 출력하는 문제이다

​근데 문제를 보면, n 사이즈가 16이다.

​우선은 일반적으로 생각을 해보면 아 1번 마을에서 시작해서 다음 마을로 가는 경우의수가 15가지이고, 그다음은 14가지, … 1가지

이렇게 되니까 15! 즉 O(n!)의 시간이 나오네, 한번 해볼까해서 계산을 해보는데, 15!이 생각보다 너무 큰숫자임을 깨닫고,

이게 아니라면 어떻게 풀어야 할까라는 생각을 하는 순간, 이것을 그냥 완탐이 아니라, DP로 한번 접근을 해보면 어떨까 라고 생각을 해보면 어떨까라는 문제이다

사실 DP로 접근이라기보다는 DP의 메모이제이션 특성을 가져온다고 보는게 정확하다고 생각을 한다

그러면 이 문제는 어떻게 메모이제이션을 할것인가? 라는 생각에 휩싸이는데, 이렇게 한번 정의를 내려보면 어떨까 라는 생각이 든다 dp[here][via]: here라는 정점에서 via를 통해 오는 최소 비용 이라고 정의를 해본다

via가 무슨 말일지 모를 수 있는데, via를 비트마스킹을 통해 접근을 할것이다

1에서 4로 가는 방법이

1 3 4, 1 2 4, 1 2 3 4가 있다고 할때

via는 1 3 4의 경우는 1을 방문하고,3을 방문하고, 4를 방문하기 때문에 2^1+2^3+2^4 값일 것이다 똑같이 1 2 4의 경우는 2^1+2^3+2^4가 채워진다고 보면된다 비트마스킹을 하도 못하는 지라 이 정도만 간략하게 설명을 하도록 하겠습니다 ㅠㅠ

​어쨌든 완전 탐색을 진행하면서, DP의 메모이제이션을 활용하여 최소 비용을 계속 갱신시켜주면 되는 문제입니다

그다음 모든 지점을 다 간다면 비트가 전부다 채워질 것이기 때문에,

종료 조건은 via==(1«n)-1일때, 그 here위치에서 1번까지 가는 길이 있을때 그값까지 전달을 해준다면 될것 입니다 문제가 좀 어려우니 코드를 보면서 진행하겠습니다

by C++

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#include <queue>
#include <map>
#include <iostream>
#include <string>
#include <math.h>
#include <set>
#include <list>
#include <climits>
#include <string.h>
#include <deque>
#include <functional>
#include <stack>
using namespace std;
typedef long long ll;
#define INF 1000000000
typedef pair<int, int> P;
typedef pair<pair<int, int>, int> PP;
typedef pair<int, pair<int, int>> PPP;
int gox[4] = { 0,1,-1,0 };
int goy[4] = { 1,0,0,-1 };
int dp[17][1 << 17];
int n;
int board[17][17];
int go(int here, int via) {
	if (via == (1 << n) - 1) { // 종료조건(전부다 방문을 했을때)
		if (board[here][0] == 0) return INF; // 원래 시작점으로 돌아갈수 없다면 불가능
		else return board[here][0]; // 돌아갈수있다면 그 가중치를 전달
	}
	int &ret = dp[here][via];
	if (ret != -1) return ret;
	ret = 1e9;
	for (int i = 0; i < n; i++) {
		int next = i;
		if (board[here][next] == 0) continue; // here에서 next로 가는 길이 없다면 갈수가 없다
		if (here != next && !(via & (1 << i))) {
			ret = min(ret, go(next, via | (1 << i)) + board[here][next]); // 방문 할때마다 가중치를 누적해준다
		}
	}
	return ret;
}
int main() {
	scanf("%d", &n);
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			scanf(" %d", &board[i][j]);
		}
	}
	memset(dp, -1, sizeof(dp));
	printf("%d\n", go(0, 1)); // 시작점이 0번지점이므로 2^0인 1을 채우고 시작해야한다
	return 0;

}

비트마스크DP의 경우는 일반적으로 배열 인덱스를 0부터 채우는 것이 일반적입니다

왜냐하면, 비트가 전부다 들어와있음을 확인할수 있는방법이 2^n 에서 1을 뺀경우인데

예를 들면, 이진수 1111은 십진수로 15임을 알수 있다. 즉 2^4 -1 로 표현이 가능하다는 점이다

모든 비트가 다 채워지면 그 2^(채워진 비트의 갯수) -1 로 표현하기 위해서

일반적으로 인덱스를 0부터 표현하는것이라고 생각합니다

저기서 이해해야 할 부분의 코드가 (via & (1«i)) 인 점인데,
여기서 via &&(1«i)이렇게 쓰면 절대로 안된다는 점입니다.
제가 알기로는 비트연산을 할때는 &&인 AND 연산자가 아니라, 비트 연산자와 관련된 & 연산, | 연산 을 해야 합니다.

via가 0 1 2을 방문했다면 00111 의상태일것입니다

다음 방문 지점이 2라면 00100 일것입니다.

이 상태에서 비트연산하면 00100 이 나오고 0이 아니기 때문에, 아 2번은 방문을 한거구나 라고 생각을 할것입니다

그래서 방문 지점이 3이면 01000 일것이고, 연산을 해보면
via: 00111
next: 01000
result: 00000
이므로 0 이나온다. AND연산을 하게 되면 둘다 1이어야 1이나오고, 하나라도 0이면 0이나옴을 알수 있다

​그래서 0이면 방문을 하자는 의미이다

방문을 하는데,
via | (1«i) 라는 코드는 via의 상태에서 i번째 자리의 비트를 가지고 전달을 하겠다는 의미입니다

그래서 위의 내용을 이으면 via: 00111
next: 01000
result: 01111
의 상태로 전달을 하게 됩니다

그리고 재귀함수를 진행하고 돌아 오게 된다면 3번비트를 키고 전달했기때문에, 다시 3번비트는 꺼진상태로 남아있게 됩니다 이와 같이 모든 경우의수를 비교하는데, 메모이제이션을 이용해서 불필요한 연산을 엄청나게 줄일수 있습니다

by Java

import java.io.*;
import java.util.StringTokenizer;

import static java.lang.Integer.min;

public class Main {
    static BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());
        int[][] dp = new int[n][1 << n];
        int[][] board = new int[n][n];
        for (int i = 0; i < n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            for (int j = 0; j < n; j++) {
                board[i][j] = Integer.parseInt(st.nextToken());
            }
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < (1 << n); j++) {
                dp[i][j] = -1;
            }
        }

        System.out.println(go(0, 1, n, board, dp));

    }

    private static int go(int here, int via, int n, int[][] board, int[][] dp) {
        if (via == (1 << n) - 1) {
            if (board[here][0] == 0) return (int) 1e9;
            else return board[here][0];
        }
        if (dp[here][via] != -1) return dp[here][via];
        dp[here][via] = (int) 1e9;
        for (int i = 0; i < n; i++) {
            int next = i;
            if (board[here][next] == 0) continue;
            if (here != next && (via & (1 << i)) == 0) {
                dp[here][via] = min(dp[here][via], go(next, via | (1 << i), n, board, dp) + board[here][next]);
            }
        }
        return dp[here][via];
    }
}

궁금하신 사항은 댓글을 남겨주시면 감사하겠습니다^^
잘못 된점 이 있다면 댓글 남겨주시면 감사드리겠습니다

Leave a comment