とある京大生の作業ログと日々の雑記

コンピュータサイエンスについて学んだことを可視化したり日々の雑記をまとめてます。

年が明けたけど進捗が特にないのでなんとなくDQNの解説をしてみた話

こんばんは〜


年明け最初の更新から10日経ちまして、まあ色々ごちゃごちゃやってたんですけど、掲題の通りどうも進捗が良くないですよね....


書類作業もひととおり終わって、テストまでもある程度時間的に余裕があるので機械学習系で何か仕上げようと思ってたんですけど、なんかこううまくいかないんです。


ちょっと機械学習系の話題は後に回すとして、一応生きた証として最近何やってたのかを可視化するためにちょこっとまとめると


・一応毎日フランス語はやってて、もうすぐ前に買ったフランス語の文法書が一周終わるって感じ

・留学と奨学金の書類、推薦状の用意は今日でほぼ終わって、あとは向こうの大学に出願するだけ

・レポート類も今のところ全部片付いてる

・お仕事は年度末の調整に入りかけてる感じで、今月で締め


って感じですかね?


ほんとにいろんなことで時間が取られてて、どうも機械学習系のことに没頭できてなくて...


(そもそも最近あんまりパソコン触ってない...)


明日も午前中は働いて、夕方には東京に移動しなきゃいけなくて。
(土曜日に文科省に行かなきゃいけない)


今回のブログ記事は「集中して実装にじっくり取り組める時間があんまりなくてしょんぼりしてる」という萎えぽよインフィニティみたいな記事です(笑)


ちょっと気合い入れなきゃなぁって毎日思ってるんですけど思うように時間が取れないから、うまいことがんばらなきゃって思ってます。

そんなこんなでちょっと機械学習系の話題を

あんまりうだうだ言ってても仕方ないので、最近簡単に実装したDQNの簡単な解説を書きます(唐突)


DQNってのはDeep Q Networkのことで、2013年くらい?に出てきた手法でAIがゲームをやってるってことで一世風靡したやつですね。


なんで今回これの話題かっていうと、最近実装したネタがこれくらいしかないからです(笑)


(Value Iteration Networkの実装やるって言ってたのに実装してないのは許してください)


まずは簡単にNetwork部分とReplayMemoryの部分。

class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.obs2value = nn.Sequential(
            nn.Conv2d(state_num, 16, kernel_size=5, stride=2),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.Linear(448, action_num)
        )
 
        self.experience_replay = deque()
        self.epsilon = 1
        self.action_num = output_size
 
        self.batch_size = 16
        self.memory_size = 10000
        self.gamma = 0.9
        self.mse = criterion = nn.MSELoss()
 
    def forward(self, x):
        return self.obs2value(x)
 
    def sample_action(self, epoch, state):
        if epoch == 0:
            return np.argmax(self.forward(state).data.numpy())
 
        self.epsilon /= epoch
 
        greedy = np.random.rand()
        if greedy < self.epsilon:
            action = np.random.randint(self.action_num)
        else:
            action = np.argmax(self.forward(state).data.numpy())
        return action
     
    def compute(self, state, action, reward, new_state, done, optimizer):
        self.experience_replay.append((state, action, reward, new_state, done))
        if len(self.experience_replay) > self.memory_size:
            self.experience_replay.popleft()
        if len(self.experience_replay) > self.batch_size:
            self.train(optimizer)
 
    def train(self, optimizer):
        minibatch = random.sample(self.experience_replay, self.batch_size)
         
        state = [data[0] for data in minibatch]
        action = [data[1] for data in minibatch]
        reward = [data[2] for data in minibatch]
        new_state = [data[3] for data in minibatch]
        done = [data[4] for data in minibatch]
 
        y_label = []
        q_prime = self.forward(Variable(torch.from_numpy(np.array(new_state)).float())).data.numpy()
        #get the y_label e.t. the r+Q(s',a',w-)
        for i in xrange(self.batch_size):
            if done[i]:
                y_label.append(reward[i])
            else:
                y_label.append(reward[i] + np.max(q_prime[i]))
 
        state_input = torch.from_numpy(np.array(state)).float()
        action_input = torch.from_numpy(np.array(action))
        out = self.forward(Variable(state_input))
        y_out = out.gather(1, Variable(action_input.unsqueeze(1)))
 
        optimizer.zero_grad()
        loss = self.mse(y_out, Variable(torch.from_numpy(np.array(y_label)).float()))
        loss.backward()
        optimizer.step()


DQNってただ画面をNNに投げればいいってもんじゃなくて、ReplayMemoryという{(S_t, A_t, R_t, S_{t+1})}の遷移過程をストックしておいたものをQ-valueの更新の際にランダムサンプリングしてターゲットネットワークとすることで過学習の防止策になる、というものなんですよね。


まああと報酬のクリッピングとか色々手法というかテクニックがあるんですけど、とりあえずDQNにおいて象徴的なのがReplayMemoryだと個人的に思っているので、とりあえずここはプッシュして説明だけしときます。


まああと実装の際に気をつけるのが、PyTorchをぼくが使ってるってのもあるんですけど、入力の際にチャンネルがいくつかとかバッチがいくつかとかの、入力情報の並びだと思ってます。


OpenAIGymでCartPoleは状態空間は{\mathbb{R}^1}なのに対してBreakoutは状態空間が\mathbb{R}^3で、PyTorchでのConv層(torch.nn.Conv2d)への入力は
Conv2d(N, C_{in}, W_{in}, H_{in})
で、バッチ数、チャネルの数、幅、高さとなっています。


これらの入力の仕方を気をつければ、多分実装のときにつまづくことはないかなぁ、と。


そんなこんなでtipsを書きつつ、あとは実行部分を示します。

for i in range(1000):
    state = env.reset()
    while True:
        tensor_state = torch.from_numpy(np.expand_dims(state, axis = 0)).float()
        action = agent.sample_action(i + 1, Variable(tensor_state))
        state_new, reward, done, info = env.step(action)
        agent.compute(state, action, reward, state_new, done, optimizer)
        state = state_new
        if done:
            break


まあここら辺は基本的な流れで、状態を初期化して行動を起こし、逐次得た報酬などに基づいてQ-valueを更新するというものです。


ここら辺は割とテキトーに書いとけばだいたいなんとかなります(暴論)

まとめ

なんとなくふわっとだけDQNの解説を書いときました。


(果たしてこれが解説と言えるのだろうか....)


まあぶっちゃけたことを言えば、「進捗ないです」


そろそろテスト忙しくなるし今のうちにがんばらないとなぁ〜〜〜


明日東京行くのめんどくさい.....がんばります.......