Skip to content

Commit d4e3483

Browse files
[Example] add Predict_Composite_Strength (#1181)
* 重新提交整理好的代码和数据 * 恢复jointContribution目录的完整内容 * 恢复jointContribution目录的完整内容 * 添加模型文件 * 修改readme.md * 完善文档说明部分,并调整文件结构 * 提交主要代码 * 更新内容 * Delete resnet.yaml * 整理参数文件 * 整理优化文件内容 * Delete Saved_Output.zip * Update mkdocs.yml * 修改文档 并重新提交conf文件 * 创建Saved_Output * 在md文件中添加saveout链接 * Update CNN_UTS.md * Update resnet.yaml * Update main.py * Update CNN_UTS.md * 修改内容 * 将注释改为英文 并修改对应文档 * 更新版本 更新版本 * 更新了一个可运行的版本 * Update mkdocs.yml * Update mkdocs.yml * Update __init__.py --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent ba0bfce commit d4e3483

File tree

9 files changed

+1822
-0
lines changed

9 files changed

+1822
-0
lines changed

docs/zh/examples/CNN_UTS.md

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Predicting the Strength of Composites
2+
3+
=== "模型训练命令"
4+
5+
``` sh
6+
python main.py mode=train
7+
```
8+
9+
=== "模型评估命令"
10+
11+
``` sh
12+
python main.py mode=eval
13+
```
14+
15+
## 下载预训练模型
16+
17+
| [resnet18-v5-fold1](https://paddle-org.bj.bcebos.com/paddlescience/models/CNN_UTS/resnet18-v5-fold1.pdparams) |
18+
[resnet18-v5-fold2](https://paddle-org.bj.bcebos.com/paddlescience/models/CNN_UTS/resnet18-v5-fold1.pdparams) |
19+
[resnet18-v5-fold3](https://paddle-org.bj.bcebos.com/paddlescience/models/CNN_UTS/resnet18-v5-fold1.pdparams) |
20+
[resnet18-v5-fold4](https://paddle-org.bj.bcebos.com/paddlescience/models/CNN_UTS/resnet18-v5-fold1.pdparams) |
21+
[resnet18-v5-fold5](https://paddle-org.bj.bcebos.com/paddlescience/models/CNN_UTS/resnet18-v5-fold1.pdparams) ||
22+
23+
## 下载模型必要参数
24+
25+
| [Saved_Output](https://paddle-org.bj.bcebos.com/paddlescience/models/CNN_UTS/Saved_Output.tar.gz) |
26+
27+
## 背景简介
28+
29+
材料的极限抗拉强度(UTS)是衡量复合材料抗拉伸破坏的核心指标,直接决定其应用安全性与可靠性。它是结构设计的关键依据,确保构件在拉伸载荷下不失效;也是材料选型的重要标准,匹配不同场景的强度需求,最终保障复合材料制品的性能上限。但由于复杂的形态-性能关系,预测其机械性能仍然较为困难,使用传统机器学习方法很难对其做出有效的预测。
30+
31+
针对材料科学领域中材料结构强度预测这一问题,通过X射线CT图像预测聚合物-陶瓷复合材料的极限抗拉强度(UTS)。相较于传统材料强度预测方法对于数据和模型的需求严苛,且需要耗费较长的时间成本,本项目通过深度学习技术,在小样本数据集的条件下,实现了较高精度的UTS值预测,提供了更快速且准确的工具。帮助研究人员快速了解材料的特性,并优化材料设计
32+
33+
本研究中使用卷积神经网络(CNN) 来分析冷烧结聚合物-陶瓷复合材料的 X 射线计算机断层扫描 (CT) 图像来应对这一问题。以形态特征作为输入的传统机器学习模型产生的准确性有限,而使用预训练的卷积神经网络,并使用集成学习进一步优化了模型。使用小型数据集来揭示复合材料中形态-结构-性能关系的替代机器学习方法,为衡量复合材料的性能提供了更精确且高效的解决方案。
34+
35+
## 目录结构
36+
37+
```
38+
CNN_UTS/
39+
40+
├─ conf/
41+
│ └─ resnet.yaml
42+
├─ data_utils.py
43+
├─ model_utils.py
44+
├─ main.py
45+
├─ requirements.txt
46+
├─ readme.md
47+
├─ resnet18-v5-finetune/
48+
├─ outputs/
49+
├─ Saved_Output/
50+
└─ Dataset/
51+
├─ Train_val/
52+
└─ Test/
53+
```
54+
55+
## 2. 模型原理
56+
57+
本章节对基于卷积神经网络的材料拉伸强度预测模型的原理进行介绍。
58+
59+
该方法的主要思想是通过卷积神经网络建立材料微观结构图像与拉伸强度(UTS)之间的非线性映射关系。模型采用ResNet架构,能够有效提取图像中的深层特征信息。
60+
61+
本案例采用ResNet-18作为基础模型架构,主要包括以下几个部分:
62+
63+
1. 输入层:接收 224×224×3 的RGB图像数据
64+
2. 卷积层:多个卷积块,包含残差连接
65+
3. 池化层:最大池化操作,降低特征图尺寸
66+
4. 全连接层:将特征映射到最终的预测值
67+
5. 输出层:输出预测的UTS值(MPa)
68+
69+
通过这种方式,我们可以自动学习材料微观结构图像中的关键特征,建立图像与性能之间的映射关系,实现准确的拉伸强度预测。
70+
71+
## 3. 模型实现
72+
73+
本章节我们讲解如何基于 PaddleScience 代码实现材料拉伸强度预测模型。本案例使用5折交叉验证进行模型训练和评估,并使用 PaddleScience 内置的各种功能模块。
74+
75+
### 3.1 数据格式说明
76+
77+
数据集下载链接:<https://paddle-org.bj.bcebos.com/paddlescience/datasets/CNN_UTS/Dataset.zip>
78+
79+
| Image Name | ...特征列... | UTS (MPa) | ... |
80+
|--------------------|--------------|-----------|-----|
81+
| IPP_10__40060.jpg | ... | 0.56 | ... |
82+
| ... | ... | ... | ... |
83+
84+
本案例使用的数据集包含材料微观结构图像和对应的拉伸强度标签。数据集分为以下几个部分:
85+
86+
1. 训练集:`Dataset/Train_val/`
87+
2. 测试集:`Dataset/Test/`
88+
89+
数据集结构如下:
90+
91+
- 每个样本包含RGB图像和对应的UTS标签
92+
- 图像经过预处理,统一调整为224×224尺寸
93+
- 使用ImageNet预训练权重的标准化参数进行归一化
94+
95+
为了方便数据处理,我们使用了 `make_dataset` 函数来创建数据集:
96+
97+
``` py linenums="73" title="examples/CNN_UTS/main.py"
98+
--8<--
99+
examples/CNN_UTS/main.py:73:74
100+
--8<--
101+
```
102+
103+
### 3.2 模型构建
104+
105+
本案例使用 PaddlePaddle 内置的 `paddle.vision.models.resnet18` 构建ResNet-18模型。模型的主要参数包括:
106+
107+
1. 网络结构:ResNet-18 (2,2,2,2)
108+
2. 输入通道:3(RGB图像)
109+
3. 输出维度:1(UTS预测值)
110+
4. 预训练权重:ImageNet
111+
112+
模型定义代码如下:
113+
114+
``` py linenums="112" title="examples/CNN_UTS/main.py"
115+
--8<--
116+
examples/CNN_UTS/main.py:112:115
117+
--8<--
118+
```
119+
120+
### 3.3 数据增强
121+
122+
为了提高模型的泛化能力,我们实现了多种数据增强策略:
123+
124+
1. 随机水平翻转
125+
2. 随机垂直翻转
126+
3. 中心裁剪到224×224
127+
4. 标准化处理
128+
129+
数据增强配置如下:
130+
131+
``` py linenums="53" title="examples/CNN_UTS/main.py"
132+
--8<--
133+
examples/CNN_UTS/main.py:53:70
134+
--8<--
135+
```
136+
137+
### 3.4 训练策略
138+
139+
本案例采用5折交叉验证策略进行模型训练:
140+
141+
1. 将训练数据分为5个fold
142+
2. 每个fold训练一个独立的模型
143+
3. 最终使用所有fold的预测结果进行集成
144+
145+
训练过程包括:
146+
147+
``` py linenums="85" title="examples/CNN_UTS/main.py"
148+
--8<--
149+
examples/CNN_UTS/main.py:85:98
150+
--8<--
151+
```
152+
153+
### 3.5 损失函数和优化器
154+
155+
使用均方误差损失函数进行回归任务:
156+
157+
``` py linenums="116" title="examples/CNN_UTS/main.py"
158+
--8<--
159+
examples/CNN_UTS/main.py:116:116
160+
--8<--
161+
```
162+
163+
使用Adam优化器进行参数更新:
164+
165+
``` py linenums="117" title="examples/CNN_UTS/main.py"
166+
--8<--
167+
examples/CNN_UTS/main.py:117:119
168+
--8<--
169+
```
170+
171+
### 3.6 模型评估
172+
173+
评估过程包括:
174+
175+
1. 计算MSE和R²指标
176+
2. 生成parity plot和violin plot
177+
3. 进行集成预测
178+
179+
评估器构建代码如下:
180+
181+
``` py linenums="156" title="examples/CNN_UTS/main.py"
182+
--8<--
183+
examples/CNN_UTS/main.py:156:188
184+
--8<--
185+
```
186+
187+
## 4. 完整代码
188+
189+
``` py linenums="1" title="examples/CNN_UTS/main.py"
190+
--8<--
191+
examples/CNN_UTS/main.py
192+
--8<--
193+
```
194+
195+
## 参考文献
196+
197+
- [Predicting the Strength of Composites with Computer Vision Using Small Experimental Datasets](<https://pubs.acs.org/doi/10.1021/acsmaterialslett.4c02424>)

examples/CNN_UTS/conf/resnet.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
mode: "train"
2+
seed: 42
3+
device: "gpu:0"
4+
data:
5+
train_path: "./Dataset/Train_val"
6+
test_path: "./Dataset/Test"
7+
N: 1
8+
train:
9+
epochs: 32
10+
n_splits: 5
11+
batch_size: 32
12+
lr: 0.0009761248347350309
13+
output_dir: "./Saved_Output"

examples/CNN_UTS/data_utils.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# data_utils.py
2+
import os
3+
import random
4+
5+
import paddle
6+
import pandas as pd
7+
from PIL import Image
8+
9+
10+
def device2str(type=None, index=None, *, device=None):
11+
type = device if device else type
12+
if isinstance(type, int):
13+
type = f"gpu:{type}"
14+
elif isinstance(type, str):
15+
if "cuda" in type:
16+
type = type.replace("cuda", "gpu")
17+
if "cpu" in type:
18+
type = "cpu"
19+
elif index is not None:
20+
type = f"{type}:{index}"
21+
elif isinstance(type, paddle.CPUPlace) or (type is None):
22+
type = "cpu"
23+
elif isinstance(type, paddle.CUDAPlace):
24+
type = f"gpu:{type.get_device_id()}"
25+
return type
26+
27+
28+
class CustomDataset(paddle.io.Dataset):
29+
def __init__(self, data, device="cpu"):
30+
self.data = data
31+
self.device = device
32+
self.preload_to_device()
33+
34+
def preload_to_device(self):
35+
self.data = [
36+
(
37+
image.to(self.device),
38+
group,
39+
paddle.to_tensor(data=features).astype(dtype="float32").to(self.device),
40+
)
41+
for image, group, features in self.data
42+
]
43+
44+
def __len__(self):
45+
return len(self.data)
46+
47+
def __getitem__(self, index):
48+
image, group, features = self.data[index]
49+
return image, group, features
50+
51+
52+
image_transforms = paddle.vision.transforms.Compose(
53+
transforms=[
54+
paddle.vision.transforms.CenterCrop(size=224),
55+
paddle.vision.transforms.ToTensor(),
56+
paddle.vision.transforms.Normalize(
57+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
58+
),
59+
]
60+
)
61+
62+
63+
def make_dataset(data_folder, N=1, verbose=False, device="cpu"):
64+
random.seed(16)
65+
this_data = []
66+
all_subfolders = [
67+
f
68+
for f in os.listdir(data_folder)
69+
if os.path.isdir(os.path.join(data_folder, f)) and len(f.split("_")) >= 3
70+
]
71+
72+
def safe_folder_sort_key(x):
73+
parts = x.split("_")
74+
try:
75+
return float(parts[-3])
76+
except Exception:
77+
return float("inf")
78+
79+
subfolders = sorted(all_subfolders, key=safe_folder_sort_key)
80+
grouped_subfolders = [[] for _ in range(5)]
81+
for i, subfolder in enumerate(subfolders):
82+
index = i // (len(subfolders) // 5)
83+
if index >= 5:
84+
index = 4
85+
grouped_subfolders[index].append(subfolder)
86+
if verbose:
87+
print("分组结果:", grouped_subfolders)
88+
chunk_keys = {}
89+
for i, gs in enumerate(grouped_subfolders):
90+
for sf in gs:
91+
chunk_keys[sf] = i
92+
sample_keys = {k: i for i, k in enumerate(subfolders)}
93+
for _ in range(len(subfolders) // 5 + 1):
94+
for k, group in enumerate(grouped_subfolders):
95+
if not group:
96+
continue
97+
selected_subfolder = random.choice(group)
98+
group.remove(selected_subfolder)
99+
folder_path = os.path.join(data_folder, selected_subfolder)
100+
if not os.path.isdir(folder_path):
101+
print(f"Warning: {folder_path} is not a valid directory")
102+
continue
103+
csv_data = None
104+
try:
105+
for file_name in os.listdir(folder_path):
106+
if file_name.endswith(".csv"):
107+
csv_path = os.path.join(folder_path, file_name)
108+
try:
109+
csv_data = pd.read_csv(csv_path)
110+
break
111+
except Exception as e:
112+
print(f"Error reading CSV file {csv_path}: {str(e)}")
113+
continue
114+
except Exception as e:
115+
print(f"Error accessing directory {folder_path}: {str(e)}")
116+
continue
117+
num = 0
118+
try:
119+
image_names = [
120+
image_name
121+
for image_name in os.listdir(folder_path)
122+
if image_name.endswith(".jpg")
123+
]
124+
image_names.sort()
125+
except Exception as e:
126+
print(f"Error reading images from {folder_path}: {str(e)}")
127+
continue
128+
for i, image_name in enumerate(image_names):
129+
if i % N != 0:
130+
continue
131+
num += 1
132+
image_path = os.path.join(folder_path, image_name)
133+
image_data = Image.open(image_path).convert("RGB")
134+
image_data = image_transforms(image_data)
135+
if csv_data is not None:
136+
image_features = (
137+
csv_data.loc[csv_data["Image Name"] == image_name, "UTS (MPa)"]
138+
.values[0]
139+
.astype(float)
140+
)
141+
else:
142+
image_features = None
143+
this_data.append(
144+
(
145+
image_data,
146+
(
147+
chunk_keys[selected_subfolder],
148+
sample_keys[selected_subfolder],
149+
),
150+
image_features,
151+
)
152+
)
153+
if verbose:
154+
print(f"文件夹 {selected_subfolder} 采样图片数: {num}")
155+
return CustomDataset(this_data, device=device)

0 commit comments

Comments
 (0)