728x90
반응형
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
# 1. 하이퍼파라미터 설정
learning_rate = 0.001
gamma = 0.99
n_episodes = 1000
# 2. Actor-Critic 네트워크 정의
# 핵심: Actor와 Critic이 앞단 레이어를 공유하거나, 별도 헤드를 가짐
class ActorCritic(nn.Module):
def __init__(self):
super(ActorCritic, self).__init__()
self.fc1 = nn.Linear(4, 128) # CartPole State: 4
# Actor Head: 액션 확률 반환 (Softmax)
self.actor_fc = nn.Linear(128, 2)
# Critic Head: 상태 가치(V) 반환 (Linear)
self.critic_fc = nn.Linear(128, 1)
def forward(self, x):
x = F.relu(self.fc1(x))
# Actor: 확률 분포 계산
prob = F.softmax(self.actor_fc(x), dim=-1)
# Critic: 가치 계산
value = self.critic_fc(x)
return prob, value
# 3. 학습 함수 (매 스텝 업데이트)
def train():
env = gym.make('CartPole-v1')
model = ActorCritic()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for episode in range(n_episodes):
state, _ = env.reset()
done = False
score = 0
while not done:
# -------------------------------------------------------
# [Step 1] Action 선택
# -------------------------------------------------------
state_tensor = torch.from_numpy(state).float()
prob, value = model(state_tensor)
m = Categorical(prob)
action = m.sample()
# -------------------------------------------------------
# [Step 2] 환경 상호작용
# -------------------------------------------------------
next_state, reward, terminated, truncated, _ = env.step(action.item())
done = terminated or truncated
score += reward
# -------------------------------------------------------
# [Step 3] TD Error 계산 및 업데이트 (핵심!)
# -------------------------------------------------------
# Next State의 가치 계산 (Target 구하기 위함)
next_state_tensor = torch.from_numpy(next_state).float()
_, next_value = model(next_state_tensor)
# TD Target = Reward + Gamma * V(S') (종료 상태면 Reward만)
mask = 0 if done else 1
target_value = reward + gamma * next_value.item() * mask
# Advantage(TD Error) = Target - V(S)
# detach() 주의: Target은 그라디언트가 흐르면 안 되는 상수 취급
advantage = target_value - value
# Critic Loss (MSE): V(S)가 Target에 가까워지도록
critic_loss = F.mse_loss(value, torch.tensor([target_value]))
# Actor Loss: -log_prob * advantage
# Advantage를 detach()해서 Actor 업데이트 시 Critic 쪽으로 그라디언트 전파 방지
actor_loss = -m.log_prob(action) * advantage.detach()
# Total Loss
loss = actor_loss + critic_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
state = next_state
if episode % 50 == 0:
print(f"Episode: {episode}, Score: {score}")
env.close()
if __name__ == "__main__":
train()728x90
반응형
'Study > Python' 카테고리의 다른 글
| 반복 작업은 이제 그만! 파이썬으로 데이터 분석 자동화하는 완벽 가이드 (1) | 2025.05.31 |
|---|---|
| PyTorch 기초부터 완벽 정리! 딥러닝 입문자를 위한 친절한 가이드 (2) | 2025.05.27 |
| 파이썬으로 시작하는 데이터분석 여행 🐍 초보자를 위한 완벽 가이드 (7) | 2025.05.27 |
| [Python] 패키지 목록 requirement.txt 만들기 (0) | 2024.01.18 |
| [Python]샘플용 데이터프레임 쉽게 생성하기 (0) | 2023.11.17 |