1- import  re 
2- from  langchain_openai  import  ChatOpenAI 
3- from  langchain_core .messages  import  HumanMessage , SystemMessage 
4- import  pandas  as  pd 
1+ """ 
2+ SQL 쿼리 결과를 Plotly로 시각화하는 모듈 
3+ 
4+ 이 모듈은 Lang2SQL 실행 결과를 다양한 형태의 차트로 시각화하는 기능을 제공합니다. 
5+ LLM을 활용하여 적절한 Plotly 코드를 생성하고 실행합니다. 
6+ """ 
7+ 
58import  os 
9+ import  re 
10+ from  typing  import  Optional 
611
12+ import  pandas  as  pd 
713import  plotly 
814import  plotly .express  as  px 
915import  plotly .graph_objects  as  go 
16+ from  langchain_core .messages  import  HumanMessage , SystemMessage 
17+ from  langchain_openai  import  ChatOpenAI 
1018
1119
1220class  DisplayChart :
@@ -17,12 +25,29 @@ class DisplayChart:
1725    plotly코드를 출력하여 excute한 결과를 fig 객체로 반환합니다. 
1826    """ 
1927
20-     def  __init__ (self , question , sql , df_metadata ):
28+     def  __init__ (self , question : str , sql : str , df_metadata : str ):
29+         """ 
30+         DisplayChart 인스턴스를 초기화합니다. 
31+ 
32+         Args: 
33+             question (str): 사용자 질문 
34+             sql (str): 실행된 SQL 쿼리 
35+             df_metadata (str): 데이터프레임 메타데이터 
36+         """ 
2137        self .question  =  question 
2238        self .sql  =  sql 
2339        self .df_metadata  =  df_metadata 
2440
25-     def  llm_model_for_chart (self , message_log ):
41+     def  llm_model_for_chart (self , message_log ) ->  Optional [str ]:
42+         """ 
43+         LLM 모델을 사용하여 차트 생성 코드를 생성합니다. 
44+ 
45+         Args: 
46+             message_log: LLM에 전달할 메시지 목록 
47+ 
48+         Returns: 
49+             Optional[str]: 생성된 차트 코드 또는 None 
50+         """ 
2651        provider  =  os .getenv ("LLM_PROVIDER" )
2752        if  provider  ==  "openai" :
2853            llm  =  ChatOpenAI (
@@ -31,18 +56,29 @@ def llm_model_for_chart(self, message_log):
3156            )
3257            result  =  llm .invoke (message_log )
3358            return  result 
59+         return  None 
3460
3561    def  _extract_python_code (self , markdown_string : str ) ->  str :
62+         """ 
63+         마크다운 문자열에서 Python 코드 블록을 추출합니다. 
64+ 
65+         Args: 
66+             markdown_string: 마크다운 형식의 문자열 
67+ 
68+         Returns: 
69+             str: 추출된 Python 코드 
70+         """ 
3671        # Strip whitespace to avoid indentation errors in LLM-generated code 
37-         markdown_string  =  markdown_string .content .split ("```" )[1 ][6 :].strip ()
72+         if  hasattr (markdown_string , "content" ):
73+             markdown_string  =  markdown_string .content .split ("```" )[1 ][6 :].strip ()
74+         else :
75+             markdown_string  =  str (markdown_string )
3876
3977        # Regex pattern to match Python code blocks 
40-         pattern  =  r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"    # 여러 문자와 공백 뒤에 python이 나오고, 줄바꿈 이후의 모든 내용 
78+         pattern  =  r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" 
4179
4280        # Find all matches in the markdown string 
43-         matches  =  re .findall (
44-             pattern , markdown_string , re .IGNORECASE 
45-         )  # 대소문자 구분 안함 
81+         matches  =  re .findall (pattern , markdown_string , re .IGNORECASE )
4682
4783        # Extract the Python code from the matches 
4884        python_code  =  []
@@ -55,13 +91,27 @@ def _extract_python_code(self, markdown_string: str) -> str:
5591
5692        return  python_code [0 ]
5793
58-     def  _sanitize_plotly_code (self , raw_plotly_code ):
94+     def  _sanitize_plotly_code (self , raw_plotly_code : str ) ->  str :
95+         """ 
96+         Plotly 코드에서 불필요한 부분을 제거합니다. 
97+ 
98+         Args: 
99+             raw_plotly_code: 원본 Plotly 코드 
100+ 
101+         Returns: 
102+             str: 정리된 Plotly 코드 
103+         """ 
59104        # Remove the fig.show() statement from the plotly code 
60105        plotly_code  =  raw_plotly_code .replace ("fig.show()" , "" )
61- 
62106        return  plotly_code 
63107
64108    def  generate_plotly_code (self ) ->  str :
109+         """ 
110+         LLM을 사용하여 Plotly 코드를 생성합니다. 
111+ 
112+         Returns: 
113+             str: 생성된 Plotly 코드 
114+         """ 
65115        if  self .question  is  not None :
66116            system_msg  =  f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{ self .question }  
67117        else :
@@ -82,20 +132,33 @@ def generate_plotly_code(self) -> str:
82132        ]
83133
84134        plotly_code  =  self .llm_model_for_chart (message_log )
135+         if  plotly_code  is  None :
136+             return  "" 
85137
86138        return  self ._sanitize_plotly_code (self ._extract_python_code (plotly_code ))
87139
88140    def  get_plotly_figure (
89141        self , plotly_code : str , df : pd .DataFrame , dark_mode : bool  =  True 
90-     ) ->  plotly .graph_objs .Figure :
91- 
142+     ) ->  Optional [plotly .graph_objs .Figure ]:
143+         """ 
144+         Plotly 코드를 실행하여 Figure 객체를 생성합니다. 
145+ 
146+         Args: 
147+             plotly_code: 실행할 Plotly 코드 
148+             df: 데이터프레임 
149+             dark_mode: 다크 모드 사용 여부 
150+ 
151+         Returns: 
152+             Optional[plotly.graph_objs.Figure]: 생성된 Figure 객체 또는 None 
153+         """ 
92154        ldict  =  {"df" : df , "px" : px , "go" : go }
155+         fig  =  None 
156+ 
93157        try :
94-             exec (plotly_code , globals (), ldict )
158+             exec (plotly_code , globals (), ldict )   # noqa: S102 
95159            fig  =  ldict .get ("fig" , None )
96160
97-         except  Exception  as  e :
98- 
161+         except  Exception :
99162            # Inspect data types 
100163            numeric_cols  =  df .select_dtypes (include = ["number" ]).columns .tolist ()
101164            categorical_cols  =  df .select_dtypes (
0 commit comments