# -*- coding: utf-8 -*-
"""CNN_Alexnet.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1UkBoflNjUldTRoyBaU7T6wlQZNMf0Fan

Trains an AlexNet-style CNN from scratch on a four-class image dataset laid
out as ``<data_dir>/<Training|Testing>/<class_name>/<image files>``, keeps
the weights with the best validation accuracy, reports test metrics and saves
the resulting checkpoint.
"""

# ---------------------------------------------------------------------------
# Step 1: Import Libraries
# ---------------------------------------------------------------------------

import copy
import os

from PIL import Image  # noqa: F401  (kept from the original notebook export)
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torchvision.transforms as t

import numpy as np
import cv2
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
)
from tqdm import tqdm

try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    drive.mount('/content/drive')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.manual_seed(42)
if device.type == "cuda":
    torch.cuda.manual_seed_all(42)


# ---------------------------------------------------------------------------
# Step 2: Create the Custom Dataset Class
# ---------------------------------------------------------------------------

class CustomDataset(Dataset):
    """Image-folder dataset that crops each image to its largest contour.

    Images are discovered under ``<data_dir>/<data_type>/<class_name>/`` for a
    fixed list of four class names. Each item is read with OpenCV, cropped to
    the bounding box of the largest external contour of its Otsu mask,
    resized to ``size`` and then passed through ``transform``.

    Args:
        data_dir: Root directory of the dataset.
        data_type: Sub-directory of ``data_dir`` to read (e.g. ``"Training"``).
        size: ``(width, height)`` handed to ``cv2.resize`` for the crop.
        transform: Optional callable applied to the preprocessed RGB uint8
            array. When ``None``, ``torchvision.transforms.ToTensor`` is used.
    """

    def __init__(self, data_dir, data_type="Training", size=(224, 224), transform=None):
        super().__init__()
        self.data_dir = data_dir
        self.data_type = data_type
        self.size = size
        self.transform = transform

        self.samples, self.class_names = self.__process_data()
        print(f"[CustomDataset] {data_type}: {len(self.samples)} images")
        print("  classes:", self.class_names)

    def __len__(self):
        """Return the number of discovered ``(path, label)`` samples."""
        return len(self.samples)

    def __preprocess_data(self, image_rgb, size=(224, 224)):
        """Crop ``image_rgb`` to its largest contour and resize it.

        Args:
            image_rgb: RGB uint8 image array as produced in ``__getitem__``.
            size: Target ``(width, height)`` for ``cv2.resize``.

        Returns:
            A contiguous RGB uint8 array of the requested size. If no contour
            is found the whole image is resized instead.
        """
        # Work directly in RGB: RGB2GRAY on the RGB image yields the same
        # grayscale as BGR2GRAY on its channel-reversed copy, and it avoids
        # returning a negative-stride view, which torch.from_numpy (used by
        # ToTensor on ndarrays) cannot consume.
        gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
        _, mask = cv2.threshold(
            gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Morphological closing fills small holes in the mask before the
        # contour search.
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        region = image_rgb
        if len(contours) > 0:
            largest = max(contours, key=cv2.contourArea)
            x, y, w, h = cv2.boundingRect(largest)
            region = image_rgb[y:y + h, x:x + w]

        resized = cv2.resize(region, size, interpolation=cv2.INTER_CUBIC)
        return np.ascontiguousarray(resized)

    def __getitem__(self, idx):
        """Load, preprocess and transform the sample at ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is the transform output and
            ``label`` is a scalar ``torch.long`` tensor.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_path, label = self.samples[idx]

        image_bgr = cv2.imread(img_path)
        if image_bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read image file: {img_path}")
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        preprocessed = self.__preprocess_data(image_rgb, size=self.size)

        if self.transform is not None:
            image = self.transform(preprocessed)
        else:
            image = t.ToTensor()(preprocessed)

        return image, torch.tensor(label, dtype=torch.long)

    def __process_data(self):
        """Scan the class sub-directories and collect image paths.

        Returns:
            ``(samples, class_names)`` where ``samples`` is a list of
            ``(image_path, label_index)`` tuples and ``class_names`` is the
            fixed list of class directory names (index == label).
        """
        base_dir = os.path.join(self.data_dir, self.data_type)

        class_names = [
            "Chorionic_villi",
            "Decidual_tissue",
            "Hemorrhage",
            "Trophoblastic_tissue",
        ]

        samples = []
        exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")

        for label_idx, cls in enumerate(class_names):
            cls_dir = os.path.join(base_dir, cls)
            if not os.path.isdir(cls_dir):
                print(f"[CustomDataset] warning: missing class directory: {cls_dir}")
                continue

            # os.listdir order is arbitrary; sorting keeps sample indices (and
            # therefore the seeded train/val split) reproducible.
            for fname in sorted(os.listdir(cls_dir)):
                if fname.lower().endswith(exts):
                    samples.append((os.path.join(cls_dir, fname), label_idx))

        return samples, class_names


def plot_grid_images(x, y, class_names=None, max_n=9):
    """Show up to ``max_n`` images of a batch in a square grid.

    Args:
        x: ``[B, C, H, W]`` tensor of images.
        y: ``[B]`` tensor of integer labels.
        class_names: Optional sequence indexed by label, used for titles.
        max_n: Maximum number of images to draw.
    """
    n = min(max_n, x.size(0))
    if n <= 0:
        # Nothing to draw; a 0x0 subplot grid is not valid.
        return

    rows = cols = int(np.ceil(n ** 0.5))

    # squeeze=False always yields a 2-D array of Axes, including the 1x1 case
    # where the default would return a single Axes object.
    _fig, axes = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)
    axes = axes.flatten()

    for ax in axes:
        ax.axis("off")

    for i in range(n):
        img = x[i].permute(1, 2, 0).cpu().numpy()
        # Min-max rescale to [0, 1] so normalized tensors display sensibly.
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        axes[i].imshow(img)
        if class_names is not None:
            axes[i].set_title(class_names[y[i].item()], fontsize=8)

    plt.tight_layout()
    plt.show()


# ---------------------------------------------------------------------------
# Step 3: Instantiate the Dataset and DataLoader
# ---------------------------------------------------------------------------

if IN_COLAB:
    data_dir = "/content/drive/MyDrive/POC_Dataset"
else:
    data_dir = "/path/to/POC_Dataset"

# NOTE(review): the mean/std below match the commonly used ImageNet channel
# statistics; the model here is trained from scratch, so confirm they are the
# intended normalization for this dataset.
train_transform = t.Compose(
    [
        t.ToPILImage(),
        t.Resize((224, 224)),
        t.RandomHorizontalFlip(p=0.5),
        t.RandomRotation(degrees=10),
        t.ToTensor(),
        t.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

test_transform = t.Compose(
    [
        t.ToPILImage(),
        t.Resize((224, 224)),
        t.ToTensor(),
        t.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

full_train_dataset = CustomDataset(
    data_dir=data_dir,
    data_type="Training",
    size=(224, 224),
    transform=train_transform,
)

val_ratio = 0.2
num_total = len(full_train_dataset)
if num_total == 0:
    raise FileNotFoundError(
        f"No training images found under {os.path.join(data_dir, 'Training')}"
    )
num_val = int(num_total * val_ratio)
num_train = num_total - num_val

train_dataset, _val_split = random_split(
    full_train_dataset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(42),
)

# random_split subsets share the parent dataset and therefore its transform.
# Validation must not see the random flip/rotation, so evaluate the same split
# indices through a shallow copy of the dataset that carries test_transform.
_val_source = copy.copy(full_train_dataset)
_val_source.transform = test_transform
val_dataset = Subset(_val_source, _val_split.indices)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

test_dataset = CustomDataset(
    data_dir=data_dir,
    data_type="Testing",
    size=(224, 224),
    transform=test_transform,
)

batch_size = 32
num_workers = 2 if device.type == "cuda" else 0
pin = (device.type == "cuda")

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin,
)

images, labels = next(iter(train_dataloader))
print("Batch image shape:", images.shape)
print("Batch label shape:", labels.shape)

# Single source of truth for label names: the list the dataset indexes by.
class_names = list(full_train_dataset.class_names)
plot_grid_images(images, labels, class_names=class_names, max_n=9)


# ---------------------------------------------------------------------------
# Model Definition (Alexnet)
# ---------------------------------------------------------------------------

num_classes = 4


class AlexNet(nn.Module):
    """AlexNet-style classifier: five conv layers followed by three linear layers.

    Args:
        num_classes: Size of the final linear layer (number of logits).
    """

    def __init__(self, num_classes=4):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Pools to a fixed 6x6 map so the classifier input size is constant.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),

            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


model = AlexNet(num_classes=num_classes).to(device)
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

# Multiplies the learning rate by 0.1 every 10 epochs. Set to None to disable.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)


# ---------------------------------------------------------------------------
# train / evaluate
# ---------------------------------------------------------------------------

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: Module to train (switched to train mode here).
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen, or
        ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for images, labels in tqdm(dataloader, desc="Train", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        # criterion returns a batch mean; weight by batch size so the epoch
        # average is per-sample even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == labels).item()
        total += images.size(0)

    if total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total
    return epoch_loss, epoch_acc


def evaluate(model, dataloader, criterion, device):
    """Compute mean loss and accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode and no gradients are tracked.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen, or
        ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Eval", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(preds == labels).item()
            total += images.size(0)

    if total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total
    return epoch_loss, epoch_acc


# ---------------------------------------------------------------------------
# Training Loop
# ---------------------------------------------------------------------------

num_epochs = 40
best_val_acc = 0.0
best_model_state = None
best_epoch = -1

# Early stopping: stop after this many consecutive epochs without a new best
# validation accuracy.
patience = 10
epochs_no_improve = 0

for epoch in range(num_epochs):
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"\nEpoch [{epoch + 1}/{num_epochs}] (lr={current_lr:.6f})")

    train_loss, train_acc = train_one_epoch(
        model, train_dataloader, criterion, optimizer, device
    )

    val_loss, val_acc = evaluate(
        model, val_dataloader, criterion, device
    )

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc * 100:.2f}%")
    print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc * 100:.2f}%")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns references to the live parameter tensors, which
        # later optimizer steps overwrite in place; snapshot them with a copy.
        best_model_state = copy.deepcopy(model.state_dict())
        best_epoch = epoch + 1
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if scheduler is not None:
        scheduler.step()

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping at epoch {epoch + 1} (no improvement for {patience} epochs)")
        break

print(f"\nBest Val Acc: {best_val_acc * 100:.2f}% (at epoch {best_epoch})")


# ---------------------------------------------------------------------------
# Test / Save Model
# ---------------------------------------------------------------------------

# Restore the best-validation weights before testing and saving.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

test_loss, test_acc = evaluate(model, test_dataloader, criterion, device)
print("\n[Evaluate() 기준]")
print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

# Second pass over the test set to collect per-sample predictions for the
# scikit-learn macro-averaged metrics.
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, preds = torch.max(outputs, 1)

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

acc = accuracy_score(all_labels, all_preds)
prec = precision_score(all_labels, all_preds, average="macro", zero_division=0)
rec = recall_score(all_labels, all_preds, average="macro", zero_division=0)
f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)

print(f"Accuracy : {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall   : {rec:.4f}")
print(f"F1-score : {f1:.4f}")

save_path = "alexnet_poc_best.pth"
torch.save(model.state_dict(), save_path)
print("\nModel saved to:", save_path)