-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
125 lines (99 loc) · 4.26 KB
/
Copy pathtrain.py
File metadata and controls
125 lines (99 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from drivingModel import DribingResNet
from drivingDataset import DrivingDataset
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: Network to optimize; switched to train mode here.
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: Loss function; assumed to return a batch-mean loss
            (true for the default ``nn.MSELoss()`` used by ``main``).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The batch losses weighted by batch size, summed, and divided by
        ``len(loader.dataset)``. The caller must ensure the dataset is non-empty.
    """
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the average.
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Compute the mean per-sample loss of ``model`` over ``loader`` without gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: Loss function; assumed to return a batch-mean loss.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The batch losses weighted by batch size, summed, and divided by
        ``len(loader.dataset)``. The caller must ensure the dataset is non-empty.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)


def main():
    """Train ``DribingResNet`` on ``DrivingDataset`` and save the best weights.

    Loads data from ``<Project>/DrivingDataset``, resolved relative to this
    file. Splits it 80/20 at random into train/validation sets and trains
    for ``NUM_EPOCHS`` epochs with MSE loss and Adam. Whenever the validation
    loss improves, writes the model's ``state_dict`` to
    ``best_dirving_model.pth`` in the current working directory.

    Returns early (after printing an error) if ``labels.csv`` is missing, or
    if the dataset has fewer than 2 samples and so cannot fill both splits.
    """
    # -----------------------------
    # 1. Configuration
    # -----------------------------
    # Locate the data folder relative to this file, not the working directory.
    # Expected layout:
    # Project/
    # ├── DrivingProcess/  (this script)
    # └── DrivingDataset/  (data)
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)  # DrivingProcess
    parent_dir = os.path.dirname(current_dir)  # Project
    DATA_ROOT = os.path.join(parent_dir, "DrivingDataset")

    # Debug output: show the resolved paths in the terminal.
    print(f"Current Dir: {current_dir}")
    print(f"Data Root Path: {DATA_ROOT}")

    # Bail out early if the label file is missing.
    if not os.path.exists(os.path.join(DATA_ROOT, "labels.csv")):
        print(f"\n❌ [오류] '{DATA_ROOT}' 폴더 안에 labels.csv 가 없습니다!")
        print(" -> 폴더명이나 구조를 다시 확인해주세요.")
        return

    BATCH_SIZE = 64
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 20
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"{DEVICE}")

    # -----------------------------
    # 2. Data preparation
    # -----------------------------
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    full_dataset = DrivingDataset(root_dir=DATA_ROOT, transform=data_transform)
    print(f"{len(full_dataset)}")

    # 80/20 train/validation split.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    # Both splits must be non-empty: the per-epoch losses are averaged over
    # each split's size, so fewer than 2 samples would divide by zero.
    if train_size == 0 or val_size == 0:
        print(f"\n❌ [오류] 학습/검증에 사용할 데이터가 부족합니다 (Train: {train_size} | Val: {val_size})")
        return

    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"Total Data: {len(full_dataset)} | Train: {len(train_dataset)} | Val: {len(val_dataset)}")

    # -----------------------------
    # 3. Model, loss, and optimizer
    # -----------------------------
    model = DribingResNet(num_classes=3).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # -----------------------------
    # 4. Training loop
    # -----------------------------
    best_val_loss = float("inf")
    for epoch in range(NUM_EPOCHS):
        epoch_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        epoch_val_loss = _evaluate(model, val_loader, criterion, DEVICE)
        print(f"Epoch [{epoch + 1} / {NUM_EPOCHS}] Train Loss: {epoch_train_loss:.5f} | Val Loss: {epoch_val_loss:.5f}")

        # Checkpoint only when validation loss improves on the best seen so far.
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            # NOTE(review): "dirving" looks like a typo. The filename is kept
            # as-is because other scripts may load this exact path. It is
            # written relative to the current working directory, not DATA_ROOT.
            torch.save(model.state_dict(), "best_dirving_model.pth")
            print(f" >>> Best Model Saved (Val Loss: {best_val_loss:.5f})")

    print(f"\nTraining Complete. Best Validation Loss: {best_val_loss:.5f}")
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()