Skip to content

Commit 79d0f31

Browse files
committed
Simplify tabular API
Add test for accepted types and keys. Finish API and update examples. Add keys_accepted to constructor and update docs. Update docs.
1 parent a8b5134 commit 79d0f31

File tree

13 files changed

+245
-513
lines changed

13 files changed

+245
-513
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ pip install dowel
1616
## Usage
1717
```python
1818
import dowel
19-
from dowel import logger, tabular
19+
from dowel import logger
2020

2121
logger.add_output(dowel.StdOutput())
2222
logger.add_output(dowel.TensorBoardOutput('tensorboard_logdir'))
@@ -26,9 +26,8 @@ for i in range(1000):
2626
logger.push_prefix('itr {}'.format(i))
2727
logger.log('Running training step')
2828

29-
tabular.record('itr', i)
30-
tabular.record('loss', 100.0 / (2 + i))
31-
logger.log(tabular)
29+
logger.logkv('itr', i)
30+
logger.logkv('loss', 100.0 / (2 + i))
3231

3332
logger.pop_prefix()
3433
logger.dump_all()

examples/log_progress.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import time
99

1010
import dowel
11-
from dowel import logger, tabular
11+
from dowel import logger
1212

1313
logger.add_output(dowel.StdOutput())
1414
logger.add_output(dowel.CsvOutput('progress.csv'))
@@ -22,9 +22,8 @@
2222

2323
time.sleep(0.01) # Tensorboard doesn't like output to be too fast.
2424

25-
tabular.record('itr', i)
26-
tabular.record('loss', 100.0 / (2 + i))
27-
logger.log(tabular)
25+
logger.logkv('itr', i)
26+
logger.logkv('loss', 100.0 / (2 + i))
2827

2928
logger.pop_prefix()
3029
logger.dump_all()

src/dowel/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
from dowel.tensor_board_output import TensorBoardOutput
1111

1212
logger = Logger()
13-
tabular = TabularInput()
1413

1514
__all__ = [
1615
'Histogram',
@@ -23,5 +22,4 @@
2322
'TabularInput',
2423
'TensorBoardOutput',
2524
'logger',
26-
'tabular',
2725
]

src/dowel/csv_output.py

Lines changed: 46 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -2,59 +2,66 @@
22
import csv
33
import warnings
44

5-
from dowel import TabularInput
5+
import numpy as np
6+
67
from dowel.simple_outputs import FileOutput
8+
from dowel.tabular_input import TabularInput
79
from dowel.utils import colorize
810

911

1012
class CsvOutput(FileOutput):
1113
"""CSV file output for logger.
1214
1315
:param file_name: The file this output should log to.
16+
:param keys_accepted: Regex for which keys this output should accept.
1417
"""
1518

16-
def __init__(self, file_name):
17-
super().__init__(file_name)
19+
def __init__(self, file_name, keys_accepted=r'^'):
20+
super().__init__(file_name, keys_accepted=keys_accepted)
1821
self._writer = None
1922
self._fieldnames = None
2023
self._warned_once = set()
2124
self._disable_warnings = False
25+
self.tabular = TabularInput()
2226

2327
@property
2428
def types_accepted(self):
25-
"""Accept TabularInput objects only."""
26-
return (TabularInput, )
27-
28-
def record(self, data, prefix=''):
29-
"""Log tabular data to CSV."""
30-
if isinstance(data, TabularInput):
31-
to_csv = data.as_primitive_dict
32-
33-
if not to_csv.keys() and not self._writer:
34-
return
35-
36-
if not self._writer:
37-
self._fieldnames = set(to_csv.keys())
38-
self._writer = csv.DictWriter(
39-
self._log_file,
40-
fieldnames=self._fieldnames,
41-
extrasaction='ignore')
42-
self._writer.writeheader()
43-
44-
if to_csv.keys() != self._fieldnames:
45-
self._warn('Inconsistent TabularInput keys detected. '
46-
'CsvOutput keys: {}. '
47-
'TabularInput keys: {}. '
48-
'Did you change key sets after your first '
49-
'logger.log(TabularInput)?'.format(
50-
set(self._fieldnames), set(to_csv.keys())))
51-
52-
self._writer.writerow(to_csv)
53-
54-
for k in to_csv.keys():
55-
data.mark(k)
56-
else:
57-
raise ValueError('Unacceptable type.')
29+
"""Accept str and scalar objects."""
30+
return (str, ) + np.ScalarType
31+
32+
def record(self, key, value, prefix=''):
33+
"""Log data to a csv file."""
34+
self.tabular.record(key, value)
35+
36+
def dump(self, step=None):
37+
"""Flush data to log file."""
38+
if self.tabular.empty:
39+
return
40+
41+
to_csv = self.tabular.as_primitive_dict
42+
43+
if not to_csv.keys() and not self._writer:
44+
return
45+
46+
if not self._writer:
47+
self._fieldnames = set(to_csv.keys())
48+
self._writer = csv.DictWriter(self._log_file,
49+
fieldnames=self._fieldnames,
50+
extrasaction='ignore')
51+
self._writer.writeheader()
52+
53+
if to_csv.keys() != self._fieldnames:
54+
self._warn('Inconsistent TabularInput keys detected. '
55+
'CsvOutput keys: {}. '
56+
'TabularInput keys: {}. '
57+
'Did you change key sets after your first '
58+
'logger.log(TabularInput)?'.format(
59+
set(self._fieldnames), set(to_csv.keys())))
60+
61+
self._writer.writerow(to_csv)
62+
63+
self._log_file.flush()
64+
self.tabular.clear()
5865

5966
def _warn(self, msg):
6067
"""Warns the user using warnings.warn.
@@ -63,8 +70,9 @@ def _warn(self, msg):
6370
is the one printed.
6471
"""
6572
if not self._disable_warnings and msg not in self._warned_once:
66-
warnings.warn(
67-
colorize(msg, 'yellow'), CsvOutputWarning, stacklevel=3)
73+
warnings.warn(colorize(msg, 'yellow'),
74+
CsvOutputWarning,
75+
stacklevel=3)
6876
self._warned_once.add(msg)
6977
return msg
7078

src/dowel/logger.py

Lines changed: 28 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -133,29 +133,37 @@
133133
"""
134134
import abc
135135
import contextlib
136+
import re
136137
import warnings
137138

138139
from dowel.utils import colorize
139140

140141

141142
class LogOutput(abc.ABC):
142-
"""Abstract class for Logger Outputs."""
143+
"""Abstract class for Logger Outputs.
143144
144-
@property
145-
def types_accepted(self):
146-
"""Pass these types to this logger output.
145+
:param keys_accepted: Regex for which keys this output should accept.
146+
"""
147147

148-
The types in this tuple will be accepted by this output.
148+
def __init__(self, keys_accepted=r'^$'):
149+
self._keys_accepted = keys_accepted
149150

150-
:return: A tuple containing all valid input types.
151-
"""
151+
@property
152+
def types_accepted(self):
153+
"""Returns a tuple containing all valid input value types."""
152154
return ()
153155

156+
@property
157+
def keys_accepted(self):
158+
"""Returns a regex string matching keys to be sent to this output."""
159+
return self._keys_accepted
160+
154161
@abc.abstractmethod
155-
def record(self, data, prefix=''):
162+
def record(self, key, value, prefix=''):
156163
"""Pass logger data to this output.
157164
158-
:param data: The data to be logged by the output.
165+
:param key: The key to be logged by the output.
166+
:param value: The value to be logged by the output.
159167
:param prefix: A prefix placed before a log entry in text outputs.
160168
"""
161169
pass
@@ -186,7 +194,7 @@ def __init__(self):
186194
self._warned_once = set()
187195
self._disable_warnings = False
188196

189-
def log(self, data):
197+
def logkv(self, key, value):
190198
"""Magic method that takes in all different types of input.
191199
192200
This method is the main API for the logger. Any data to be logged goes
@@ -195,24 +203,30 @@ def log(self, data):
195203
Any data sent to this method is sent to all outputs that accept its
196204
type (defined in the types_accepted property).
197205
198-
:param data: Data to be logged. This can be any type specified in the
206+
:param key: Key to be logged. This must be a string.
207+
:param value: Value to be logged. This can be any type specified in the
199208
types_accepted property of any of the logger outputs.
200209
"""
201210
if not self._outputs:
202211
self._warn('No outputs have been added to the logger.')
203212

204213
at_least_one_logged = False
205214
for output in self._outputs:
206-
if isinstance(data, output.types_accepted):
207-
output.record(data, prefix=self._prefix_str)
215+
if isinstance(value, output.types_accepted) and re.match(
216+
output.keys_accepted, key):
217+
output.record(key, value, prefix=self._prefix_str)
208218
at_least_one_logged = True
209219

210220
if not at_least_one_logged:
211221
warning = (
212222
'Log data of type {} was not accepted by any output'.format(
213-
type(data).__name__))
223+
type(value).__name__))
214224
self._warn(warning)
215225

226+
def log(self, value):
227+
"""Log just a value without a key."""
228+
self.logkv('', value)
229+
216230
def add_output(self, output):
217231
"""Add a new output to the logger.
218232

src/dowel/simple_outputs.py

Lines changed: 49 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import sys
99

1010
import dateutil.tz
11+
import numpy as np
1112

1213
from dowel import LogOutput
1314
from dowel.tabular_input import TabularInput
@@ -17,46 +18,53 @@
1718
class StdOutput(LogOutput):
1819
"""Standard console output for the logger.
1920
21+
:param keys_accepted: Regex for which keys this output should accept.
2022
:param with_timestamp: Whether to log a timestamp before non-tabular data.
2123
"""
2224

23-
def __init__(self, with_timestamp=True):
25+
def __init__(self, keys_accepted=r'^', with_timestamp=True):
26+
super().__init__(keys_accepted=keys_accepted)
2427
self._with_timestamp = with_timestamp
28+
self.tabular = TabularInput()
2529

2630
@property
2731
def types_accepted(self):
28-
"""Accept str and TabularInput objects."""
29-
return (str, TabularInput)
32+
"""Accept str and scalar objects."""
33+
return (str, ) + np.ScalarType
3034

31-
def record(self, data, prefix=''):
35+
def record(self, key, value, prefix=''):
3236
"""Log data to console."""
33-
if isinstance(data, str):
34-
out = prefix + data
35-
if self._with_timestamp:
36-
now = datetime.datetime.now(dateutil.tz.tzlocal())
37-
timestamp = now.strftime('%Y-%m-%d %H:%M:%S')
38-
out = '%s | %s' % (timestamp, out)
39-
elif isinstance(data, TabularInput):
40-
out = str(data)
41-
data.mark_str()
37+
if not key:
38+
if isinstance(value, str):
39+
out = prefix + value
40+
if self._with_timestamp:
41+
now = datetime.datetime.now(dateutil.tz.tzlocal())
42+
timestamp = now.strftime('%Y-%m-%d %H:%M:%S')
43+
out = '%s | %s' % (timestamp, out)
44+
print(out)
45+
else:
46+
raise ValueError('Unacceptable type')
4247
else:
43-
raise ValueError('Unacceptable type')
44-
45-
print(out)
48+
self.tabular.record(key, value)
4649

4750
def dump(self, step=None):
4851
"""Flush data to standard output stream."""
52+
if not self.tabular.empty:
53+
print(str(self.tabular))
54+
self.tabular.clear()
4955
sys.stdout.flush()
5056

5157

5258
class FileOutput(LogOutput, metaclass=abc.ABCMeta):
5359
"""File output abstract class for logger.
5460
5561
:param file_name: The file this output should log to.
62+
:param keys_accepted: Regex for which keys this output should accept.
5663
:param mode: File open mode ('a', 'w', etc).
5764
"""
5865

59-
def __init__(self, file_name, mode='w'):
66+
def __init__(self, file_name, keys_accepted=r'^', mode='w'):
67+
super().__init__(keys_accepted=keys_accepted)
6068
mkdir_p(os.path.dirname(file_name))
6169
# Open the log file in child class
6270
self._log_file = open(file_name, mode)
@@ -75,31 +83,38 @@ class TextOutput(FileOutput):
7583
"""Text file output for logger.
7684
7785
:param file_name: The file this output should log to.
86+
:param keys_accepted: Regex for which keys this output should accept.
7887
:param with_timestamp: Whether to log a timestamp before the data.
7988
"""
8089

81-
def __init__(self, file_name, with_timestamp=True):
82-
super().__init__(file_name, 'a')
90+
def __init__(self, file_name, keys_accepted=r'^', with_timestamp=True):
91+
super().__init__(file_name, keys_accepted=keys_accepted, mode='a')
8392
self._with_timestamp = with_timestamp
84-
self._delimiter = ' | '
93+
self.tabular = TabularInput()
8594

8695
@property
8796
def types_accepted(self):
88-
"""Accept str objects only."""
89-
return (str, TabularInput)
97+
"""Accept str and scalar objects."""
98+
return (str, ) + np.ScalarType
9099

91-
def record(self, data, prefix=''):
100+
def record(self, key, value, prefix=''):
92101
"""Log data to text file."""
93-
if isinstance(data, str):
94-
out = prefix + data
95-
if self._with_timestamp:
96-
now = datetime.datetime.now(dateutil.tz.tzlocal())
97-
timestamp = now.strftime('%Y-%m-%d %H:%M:%S')
98-
out = '%s | %s' % (timestamp, out)
99-
elif isinstance(data, TabularInput):
100-
out = str(data)
101-
data.mark_str()
102+
if not key:
103+
if isinstance(value, str):
104+
out = prefix + value
105+
if self._with_timestamp:
106+
now = datetime.datetime.now(dateutil.tz.tzlocal())
107+
timestamp = now.strftime('%Y-%m-%d %H:%M:%S')
108+
out = '%s | %s' % (timestamp, out)
109+
self._log_file.write(out + '\n')
110+
else:
111+
raise ValueError('Unacceptable type')
102112
else:
103-
raise ValueError('Unacceptable type.')
113+
self.tabular.record(key, value)
104114

105-
self._log_file.write(out + '\n')
115+
def dump(self, step=None):
116+
"""Flush data to log file."""
117+
if not self.tabular.empty:
118+
self._log_file.write(str(self.tabular) + '\n')
119+
self.tabular.clear()
120+
self._log_file.flush()

0 commit comments

Comments
 (0)