Skip to content

Commit f879c3e

Browse files
avisquidGiovanniCanali
authored andcommitted
Added EGNO block
1 parent dc808c1 commit f879c3e

File tree

9 files changed

+5850
-0
lines changed

9 files changed

+5850
-0
lines changed

pina/egno_data/1.amc

Lines changed: 4443 additions & 0 deletions
Large diffs are not rendered by default.

pina/egno_data/amc_parser.py

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
# From official implementation of the Equivariant Graph Neural Operator.
2+
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
from transforms3d.euler import euler2mat
6+
from mpl_toolkits.mplot3d import Axes3D
7+
8+
9+
class Joint:
10+
def __init__(self, name, direction, length, axis, dof, limits):
11+
"""
12+
Definition of basic joint. The joint also contains the information of the
13+
bone between it's parent joint and itself. Refer
14+
[here](https://research.cs.wisc.edu/graphics/Courses/cs-838-1999/Jeff/ASF-AMC.html)
15+
for detailed description for asf files.
16+
Parameter
17+
---------
18+
name: Name of the joint defined in the asf file. There should always be one
19+
root joint. String.
20+
direction: Default direction of the joint(bone). The motions are all defined
21+
based on this default pose.
22+
length: Length of the bone.
23+
axis: Axis of rotation for the bone.
24+
dof: Degree of freedom. Specifies the number of motion channels and in what
25+
order they appear in the AMC file.
26+
limits: Limits on each of the channels in the dof specification
27+
"""
28+
self.name = name
29+
self.direction = np.reshape(direction, [3, 1])
30+
self.length = length
31+
axis = np.deg2rad(axis)
32+
self.C = euler2mat(*axis)
33+
self.Cinv = np.linalg.inv(self.C)
34+
self.limits = np.zeros([3, 2])
35+
for lm, nm in zip(limits, dof):
36+
if nm == 'rx':
37+
self.limits[0] = lm
38+
elif nm == 'ry':
39+
self.limits[1] = lm
40+
else:
41+
self.limits[2] = lm
42+
self.parent = None
43+
self.children = []
44+
self.coordinate = None
45+
self.matrix = None
46+
47+
def set_motion(self, motion):
48+
if self.name == 'root':
49+
self.coordinate = np.reshape(np.array(motion['root'][:3]), [3, 1])
50+
rotation = np.deg2rad(motion['root'][3:])
51+
self.matrix = self.C.dot(euler2mat(*rotation)).dot(self.Cinv)
52+
else:
53+
idx = 0
54+
rotation = np.zeros(3)
55+
for axis, lm in enumerate(self.limits):
56+
if not np.array_equal(lm, np.zeros(2)):
57+
rotation[axis] = motion[self.name][idx]
58+
idx += 1
59+
rotation = np.deg2rad(rotation)
60+
self.matrix = self.parent.matrix.dot(self.C).dot(euler2mat(*rotation)).dot(self.Cinv)
61+
self.coordinate = self.parent.coordinate + self.length * self.matrix.dot(self.direction)
62+
for child in self.children:
63+
child.set_motion(motion)
64+
65+
def get_name_to_idx(self):
66+
joints = self.to_dict()
67+
name_to_idx = {}
68+
for idx, joint in enumerate(joints.values()):
69+
assert joint.name not in name_to_idx
70+
name_to_idx[joint.name] = idx
71+
self.name_to_idx = name_to_idx
72+
73+
def output_edges(self):
74+
joints = self.to_dict()
75+
name_to_idx = self.name_to_idx
76+
edges = []
77+
for idx, joint in enumerate(joints.values()):
78+
child = joint
79+
if child.parent is not None:
80+
parent = child.parent
81+
edges.append([name_to_idx[child.name], name_to_idx[parent.name]])
82+
return edges
83+
84+
def output_coord(self):
85+
N = len(self.name_to_idx)
86+
X = np.zeros((N, 3))
87+
joints = self.to_dict()
88+
name_to_idx = self.name_to_idx
89+
for idx, joint in enumerate(joints.values()):
90+
X[name_to_idx[joint.name]] = joint.coordinate.reshape(-1)
91+
return X
92+
93+
def draw(self):
94+
joints = self.to_dict()
95+
fig = plt.figure()
96+
ax = Axes3D(fig)
97+
98+
ax.set_xlim3d(-50, 10)
99+
ax.set_ylim3d(-20, 40)
100+
ax.set_zlim3d(-20, 40)
101+
102+
xs, ys, zs = [], [], []
103+
for joint in joints.values():
104+
xs.append(joint.coordinate[0, 0])
105+
ys.append(joint.coordinate[1, 0])
106+
zs.append(joint.coordinate[2, 0])
107+
plt.plot(zs, xs, ys, 'b.')
108+
109+
for joint in joints.values():
110+
child = joint
111+
if child.parent is not None:
112+
parent = child.parent
113+
xs = [child.coordinate[0, 0], parent.coordinate[0, 0]]
114+
ys = [child.coordinate[1, 0], parent.coordinate[1, 0]]
115+
zs = [child.coordinate[2, 0], parent.coordinate[2, 0]]
116+
plt.plot(zs, xs, ys, 'r')
117+
plt.show()
118+
119+
def to_dict(self):
120+
ret = {self.name: self}
121+
for child in self.children:
122+
ret.update(child.to_dict())
123+
return ret
124+
125+
def pretty_print(self):
126+
print('===================================')
127+
print('joint: %s' % self.name)
128+
print('direction:')
129+
print(self.direction)
130+
print('limits:', self.limits)
131+
print('parent:', self.parent)
132+
print('children:', self.children)
133+
134+
135+
def read_line(stream, idx):
136+
if idx >= len(stream):
137+
return None, idx
138+
line = stream[idx].strip().split()
139+
idx += 1
140+
return line, idx
141+
142+
143+
def parse_asf(file_path):
144+
'''read joint data only'''
145+
with open(file_path) as f:
146+
content = f.read().splitlines()
147+
148+
for idx, line in enumerate(content):
149+
# meta infomation is ignored
150+
if line == ':bonedata':
151+
content = content[idx + 1:]
152+
break
153+
154+
# read joints
155+
joints = {'root': Joint('root', np.zeros(3), 0, np.zeros(3), [], [])}
156+
idx = 0
157+
while True:
158+
# the order of each section is hard-coded
159+
160+
line, idx = read_line(content, idx)
161+
162+
if line[0] == ':hierarchy':
163+
break
164+
165+
assert line[0] == 'begin'
166+
167+
line, idx = read_line(content, idx)
168+
assert line[0] == 'id'
169+
170+
line, idx = read_line(content, idx)
171+
assert line[0] == 'name'
172+
name = line[1]
173+
174+
line, idx = read_line(content, idx)
175+
assert line[0] == 'direction'
176+
direction = np.array([float(axis) for axis in line[1:]])
177+
178+
# skip length
179+
line, idx = read_line(content, idx)
180+
assert line[0] == 'length'
181+
length = float(line[1])
182+
183+
line, idx = read_line(content, idx)
184+
assert line[0] == 'axis'
185+
assert line[4] == 'XYZ'
186+
187+
axis = np.array([float(axis) for axis in line[1:-1]])
188+
189+
dof = []
190+
limits = []
191+
192+
line, idx = read_line(content, idx)
193+
if line[0] == 'dof':
194+
dof = line[1:]
195+
for i in range(len(dof)):
196+
line, idx = read_line(content, idx)
197+
if i == 0:
198+
assert line[0] == 'limits'
199+
line = line[1:]
200+
assert len(line) == 2
201+
mini = float(line[0][1:])
202+
maxi = float(line[1][:-1])
203+
limits.append((mini, maxi))
204+
205+
line, idx = read_line(content, idx)
206+
207+
assert line[0] == 'end'
208+
joints[name] = Joint(
209+
name,
210+
direction,
211+
length,
212+
axis,
213+
dof,
214+
limits
215+
)
216+
217+
# read hierarchy
218+
assert line[0] == ':hierarchy'
219+
220+
line, idx = read_line(content, idx)
221+
222+
assert line[0] == 'begin'
223+
224+
while True:
225+
line, idx = read_line(content, idx)
226+
if line[0] == 'end':
227+
break
228+
assert len(line) >= 2
229+
for joint_name in line[1:]:
230+
joints[line[0]].children.append(joints[joint_name])
231+
for nm in line[1:]:
232+
joints[nm].parent = joints[line[0]]
233+
234+
return joints
235+
236+
237+
def parse_amc(file_path):
238+
with open(file_path) as f:
239+
content = f.read().splitlines()
240+
241+
for idx, line in enumerate(content):
242+
if line == ':DEGREES':
243+
content = content[idx + 1:]
244+
break
245+
246+
frames = []
247+
idx = 0
248+
line, idx = read_line(content, idx)
249+
assert line[0].isnumeric(), line
250+
EOF = False
251+
while not EOF:
252+
joint_degree = {}
253+
while True:
254+
line, idx = read_line(content, idx)
255+
if line is None:
256+
EOF = True
257+
break
258+
if line[0].isnumeric():
259+
break
260+
joint_degree[line[0]] = [float(deg) for deg in line[1:]]
261+
frames.append(joint_degree)
262+
return frames

pina/egno_data/egno_training.ipynb

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "4c55a527",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"from pina.problem.zoo import SupervisedProblem\n",
11+
"import matplotlib.pyplot as plt\n",
12+
"from pina.graph import Graph\n",
13+
"import amc_parser as amc\n",
14+
"import torch"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": 2,
20+
"id": "b71cb46f",
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"# Edges -- comuni a tutti i grafi\n",
25+
"joints = amc.parse_asf('run.asf')\n",
26+
"joints['root'].get_name_to_idx()\n",
27+
"edge_idx = torch.tensor(joints['root'].output_edges()).T\n",
28+
"\n",
29+
"graphs = []\n",
30+
"\n",
31+
"# Motions -- diversi per ogni grafo\n",
32+
"motions = amc.parse_amc('1.amc')\n",
33+
"for time, motion in enumerate(motions):\n",
34+
" joints['root'].set_motion(motion)\n",
35+
" pos = torch.tensor(joints['root'].output_coord(), dtype=torch.float32)\n",
36+
" graph = Graph(pos=pos, edge_index=edge_idx, time=torch.tensor(time, dtype=torch.float32), velocity=torch.zeros_like(pos))\n",
37+
" graphs.append(graph)\n",
38+
"\n",
39+
"for i in range(len(graphs)-1):\n",
40+
" graphs[i].velocity = graphs[i+1].pos - graphs[i].pos\n",
41+
" graphs[-1].velocity = graphs[-2].velocity"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 4,
47+
"id": "090078da",
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"# Plot each structure with different colors/markers\n",
52+
"for i, graph in enumerate(graphs):\n",
53+
" coords = graphs[i].pos.numpy()\n",
54+
" fig = plt.figure()\n",
55+
" ax = fig.add_subplot(111, projection='3d') \n",
56+
" ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2])\n",
57+
" ax.set_xlabel(\"X\")\n",
58+
" ax.set_ylabel(\"Y\")\n",
59+
" ax.set_zlabel(\"Z\")\n",
60+
"\n",
61+
" ax.set_xlim(-5, 5)\n",
62+
" ax.set_ylim(0, 35)\n",
63+
" ax.set_zlim(0, 45)\n",
64+
"\n",
65+
" # Draw edges\n",
66+
" for k, j in graph.edge_index.T.numpy():\n",
67+
" x = [coords[k, 0], coords[j, 0]]\n",
68+
" y = [coords[k, 1], coords[j, 1]]\n",
69+
" z = [coords[k, 2], coords[j, 2]]\n",
70+
" ax.plot(x, y, z, c=\"k\")\n",
71+
" ax.view_init(elev=0, azim=0)\n",
72+
" ax.set_axis_off()\n",
73+
"\n",
74+
" plt.savefig(f\"walk_{i}\")\n",
75+
" plt.close()"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": 5,
81+
"id": "ecd595d4",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"input = graphs[:-1]\n",
86+
"target = graphs[1:]\n",
87+
"problem = SupervisedProblem(input, target)"
88+
]
89+
}
90+
],
91+
"metadata": {
92+
"kernelspec": {
93+
"display_name": "deep",
94+
"language": "python",
95+
"name": "python3"
96+
},
97+
"language_info": {
98+
"codemirror_mode": {
99+
"name": "ipython",
100+
"version": 3
101+
},
102+
"file_extension": ".py",
103+
"mimetype": "text/x-python",
104+
"name": "python",
105+
"nbconvert_exporter": "python",
106+
"pygments_lexer": "ipython3",
107+
"version": "3.12.11"
108+
}
109+
},
110+
"nbformat": 4,
111+
"nbformat_minor": 5
112+
}

0 commit comments

Comments
 (0)