
Actor-Critic Sample Code

SigmoidFunction 2025. 12. 4. 21:11
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

# 1. Hyperparameter settings
learning_rate = 0.001
gamma = 0.99
n_episodes = 1000

# 2. Define the Actor-Critic network
# Key idea: the Actor and Critic share the front layer and branch into separate heads
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(4, 128)  # CartPole State: 4
        
        # Actor head: returns action probabilities (Softmax)
        self.actor_fc = nn.Linear(128, 2) 
        # Critic head: returns the state value V(s) (Linear)
        self.critic_fc = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        
        # Actor: compute the action probability distribution
        prob = F.softmax(self.actor_fc(x), dim=-1)
        
        # Critic: compute the state value
        value = self.critic_fc(x)
        
        return prob, value
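
# Quick shape check (illustrative, not part of the original post):
#   prob, value = ActorCritic()(torch.zeros(4))
#   prob.shape  -> torch.Size([2])  (action probabilities, sum to 1)
#   value.shape -> torch.Size([1])  (scalar state-value estimate V(s))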

# 3. Training function (updates at every step)
def train():
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        score = 0

        while not done:
            # -------------------------------------------------------
            # [Step 1] Select an action
            # -------------------------------------------------------
            state_tensor = torch.from_numpy(state).float()
            prob, value = model(state_tensor)
            
            m = Categorical(prob)
            action = m.sample()
            
            # -------------------------------------------------------
            # [Step 2] Interact with the environment
            # -------------------------------------------------------
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            score += reward
            
            # -------------------------------------------------------
            # [Step 3] Compute the TD error and update (the key part!)
            # -------------------------------------------------------
            # Compute the value of the next state (needed to build the target)
            next_state_tensor = torch.from_numpy(next_state).float()
            _, next_value = model(next_state_tensor)
            
            # TD target = reward + gamma * V(s') (reward only if the episode terminated)
            # Bootstrapping is cut only on true termination; on truncation (time limit) V(s') is still used
            mask = 0.0 if terminated else 1.0
            target_value = reward + gamma * next_value.item() * mask
            
            # Advantage (TD error) = target - V(s)
            # The target is a plain float (via next_value.item()), so no gradient flows through it
            advantage = target_value - value
            
            # Critic loss (MSE): pushes V(s) toward the target
            critic_loss = F.mse_loss(value, torch.tensor([target_value]))
            
            # Actor loss: -log_prob * advantage
            # detach() the advantage so the actor update does not send gradients into the critic head
            actor_loss = -m.log_prob(action) * advantage.detach()
            
            # Total Loss
            loss = actor_loss + critic_loss
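            # A single backward pass through this combined loss updates the actor head,
            # the critic head, and the shared fc1 layer in one optimizer step.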
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            state = next_state

        if episode % 50 == 0:
            print(f"Episode: {episode}, Score: {score}")

    env.close()

if __name__ == "__main__":
    train()
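
To sanity-check what the policy has learned, a rollout helper like the one below can be added. This is a minimal sketch, not part of the original post: it assumes train() is changed to return the trained model (it currently returns nothing), and the function name evaluate and the n_episodes argument are illustrative.

def evaluate(model, n_episodes=5):
    env = gym.make('CartPole-v1')
    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        score = 0
        while not done:
            with torch.no_grad():
                prob, _ = model(torch.from_numpy(state).float())
            action = prob.argmax().item()  # greedy action instead of sampling
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            score += reward
        print(f"Eval episode {episode}, score: {score}")
    env.close()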