1+ import ast
2+ import keyword
3+ from pathlib import Path
4+ from typing import List , Dict
5+
6+ import astor
7+ from lionwebpython .language import Language , Concept , Interface , Containment , Property
8+ from lionwebpython .language .classifier import Classifier
9+ from lionwebpython .language .enumeration import Enumeration
10+ from lionwebpython .language .primitive_type import PrimitiveType
11+ from lionwebpython .language .reference import Reference
12+ from lionwebpython .lionweb_version import LionWebVersion
13+
14+ from pylasu .lionweb .starlasu import StarLasuBaseLanguage
15+ from pylasu .lionweb .utils import to_snake_case
16+
17+
18+ def make_cond (enumeration_name : str , member_name : str ):
19+ return ast .Compare (
20+ left = ast .Name (id = "serialized" , ctx = ast .Load ()),
21+ ops = [ast .Eq ()],
22+ comparators = [
23+ ast .Attribute (
24+ value = ast .Attribute (
25+ value = ast .Name (id = enumeration_name , ctx = ast .Load ()),
26+ attr = member_name ,
27+ ctx = ast .Load ()
28+ ),
29+ attr = "value" ,
30+ ctx = ast .Load ()
31+ )
32+ ]
33+ )
34+
35+ # The return: return AssignmentType.Add
36+ def make_return (enumeration_name : str , member_name : str ):
37+ return ast .Return (
38+ value = ast .Attribute (
39+ value = ast .Name (id = enumeration_name , ctx = ast .Load ()),
40+ attr = member_name ,
41+ ctx = ast .Load ()
42+ )
43+ )
44+
45+
46+ def deserializer_generation (click , language : Language , output ):
47+ import_abc = ast .ImportFrom (
48+ module = 'abc' ,
49+ names = [ast .alias (name = 'ABC' , asname = None )],
50+ level = 0
51+ )
52+ import_dataclass = ast .ImportFrom (
53+ module = 'dataclasses' ,
54+ names = [ast .alias (name = 'dataclass' , asname = None )],
55+ level = 0
56+ )
57+ import_enum = ast .ImportFrom (
58+ module = "enum" ,
59+ names = [ast .alias (name = "Enum" , asname = None )],
60+ level = 0
61+ )
62+ import_typing = ast .ImportFrom (
63+ module = 'typing' ,
64+ names = [ast .alias (name = 'Optional' , asname = None )],
65+ level = 0
66+ )
67+ import_starlasu = ast .ImportFrom (
68+ module = 'pylasu.model.metamodel' ,
69+ names = [ast .alias (name = 'Expression' , asname = 'StarLasuExpression' ),
70+ ast .alias (name = 'PlaceholderElement' , asname = 'StarLasuPlaceholderElement' ),
71+ ast .alias (name = 'Named' , asname = 'StarLasuNamed' ),
72+ ast .alias (name = 'TypeAnnotation' , asname = 'StarLasuTypeAnnotation' ),
73+ ast .alias (name = 'Parameter' , asname = 'StarLasuParameter' ),
74+ ast .alias (name = 'Statement' , asname = 'StarLasuStatement' ),
75+ ast .alias (name = 'EntityDeclaration' , asname = 'StarLasuEntityDeclaration' ),
76+ ast .alias (name = 'BehaviorDeclaration' , asname = 'StarLasuBehaviorDeclaration' ),
77+ ast .alias (name = 'Documentation' , asname = 'StarLasuDocumentation' )],
78+ level = 0
79+ )
80+ import_node = ast .ImportFrom (
81+ module = 'pylasu.model' ,
82+ names = [ast .alias (name = 'Node' , asname = None )],
83+ level = 0
84+ )
85+ import_ast = ast .ImportFrom (
86+ module = 'ast' ,
87+ names = [ast .alias (name = e .get_name (), asname = None ) for e in language .get_elements () if not isinstance (e , PrimitiveType )],
88+ level = 0
89+ )
90+ import_primitives = ast .ImportFrom (
91+ module = 'primitive_types' ,
92+ names = [ast .alias (name = e .get_name (), asname = None ) for e in language .get_elements () if isinstance (e , PrimitiveType )],
93+ level = 0
94+ )
95+ module = ast .Module (body = [import_abc , import_dataclass , import_typing , import_enum , import_starlasu , import_node ,
96+ import_ast , import_primitives ],
97+ type_ignores = [])
98+
99+
100+
101+ for e in language .get_elements ():
102+ if isinstance (e , Enumeration ):
103+ arg_serialized = ast .arg (arg = "serialized" , annotation = ast .Name (id = "str" , ctx = ast .Load ()))
104+ # The raise: raise ValueError(f"...")
105+ raise_stmt = ast .Raise (
106+ exc = ast .Call (
107+ func = ast .Name (id = "ValueError" , ctx = ast .Load ()),
108+ args = [
109+ ast .JoinedStr (values = [
110+ ast .Constant (value = f"Invalid value for { e .get_name ()} : " ),
111+ ast .FormattedValue (
112+ value = ast .Name (id = "serialized" , ctx = ast .Load ()),
113+ conversion = - 1
114+ )
115+ ])
116+ ],
117+ keywords = []
118+ ),
119+ cause = None
120+ )
121+ # The function body
122+ literals = e .get_literals ()
123+ current_if = ast .If (
124+ test = make_cond (e .get_name (), literals [0 ].get_name ()),
125+ body = [make_return (e .get_name (), literals [0 ].get_name ())],
126+ orelse = []
127+ )
128+ root_if = current_if
129+
130+ for literal in literals [1 :]:
131+ next_if = ast .If (
132+ test = make_cond (e .get_name (), literal .get_name ()),
133+ body = [make_return (e .get_name (), literal .get_name ())],
134+ orelse = []
135+ )
136+ current_if .orelse = [next_if ]
137+ current_if = next_if
138+
139+ # Final else
140+ current_if .orelse = [raise_stmt ]
141+
142+ # Function definition
143+ func_def = ast .FunctionDef (
144+ name = f"_deserialize_{ to_snake_case (e .get_name ())} " ,
145+ args = ast .arguments (
146+ posonlyargs = [],
147+ args = [arg_serialized ],
148+ kwonlyargs = [],
149+ kw_defaults = [],
150+ defaults = []
151+ ),
152+ body = [root_if ],
153+ decorator_list = [],
154+ returns = ast .Constant (value = e .get_name ())
155+ )
156+ module .body .append (func_def )
157+
158+ generated_code = astor .to_source (module )
159+ output_path = Path (output )
160+ output_path .mkdir (parents = True , exist_ok = True )
161+ click .echo (f"📂 Saving deserializer to: { output } " )
162+ with Path (f"{ output } /deserializer.py" ).open ("w" , encoding = "utf-8" ) as f :
163+ f .write (generated_code )
0 commit comments