Skip to content

Commit be6887c

Browse files
committed
Simplify tabular API
Add test for accepted types and keys. Finish API and update examples. Add keys_accepted to constructor and update docs. Update docs. Renamed tabular_input to tabular.
1 parent a8b5134 commit be6887c

15 files changed

+326
-596
lines changed

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pip install dowel
1616
## Usage
1717
```python
1818
import dowel
19-
from dowel import logger, tabular
19+
from dowel import logger
2020

2121
logger.add_output(dowel.StdOutput())
2222
logger.add_output(dowel.TensorBoardOutput('tensorboard_logdir'))
@@ -26,9 +26,8 @@ for i in range(1000):
2626
logger.push_prefix('itr {}'.format(i))
2727
logger.log('Running training step')
2828

29-
tabular.record('itr', i)
30-
tabular.record('loss', 100.0 / (2 + i))
31-
logger.log(tabular)
29+
logger.logkv('itr', i)
30+
logger.logkv('loss', 100.0 / (2 + i))
3231

3332
logger.pop_prefix()
3433
logger.dump_all()

examples/log_progress.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import time
99

1010
import dowel
11-
from dowel import logger, tabular
11+
from dowel import logger
1212

1313
logger.add_output(dowel.StdOutput())
1414
logger.add_output(dowel.CsvOutput('progress.csv'))
@@ -22,9 +22,8 @@
2222

2323
time.sleep(0.01) # Tensorboard doesn't like output to be too fast.
2424

25-
tabular.record('itr', i)
26-
tabular.record('loss', 100.0 / (2 + i))
27-
logger.log(tabular)
25+
logger.logkv('itr', i)
26+
logger.logkv('loss', 100.0 / (2 + i))
2827

2928
logger.pop_prefix()
3029
logger.dump_all()

src/dowel/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
from dowel.histogram import Histogram
66
from dowel.logger import Logger, LoggerWarning, LogOutput
77
from dowel.simple_outputs import StdOutput, TextOutput
8-
from dowel.tabular_input import TabularInput
8+
from dowel.tabular import Tabular
99
from dowel.csv_output import CsvOutput # noqa: I100
1010
from dowel.tensor_board_output import TensorBoardOutput
1111

1212
logger = Logger()
13-
tabular = TabularInput()
1413

1514
__all__ = [
1615
'Histogram',
@@ -20,8 +19,7 @@
2019
'TextOutput',
2120
'LogOutput',
2221
'LoggerWarning',
23-
'TabularInput',
22+
'Tabular',
2423
'TensorBoardOutput',
2524
'logger',
26-
'tabular',
2725
]

src/dowel/csv_output.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,59 +2,66 @@
22
import csv
33
import warnings
44

5-
from dowel import TabularInput
5+
import numpy as np
6+
67
from dowel.simple_outputs import FileOutput
8+
from dowel.tabular import Tabular
79
from dowel.utils import colorize
810

911

1012
class CsvOutput(FileOutput):
1113
"""CSV file output for logger.
1214
1315
:param file_name: The file this output should log to.
16+
:param keys_accepted: Regex for which keys this output should accept.
1417
"""
1518

16-
def __init__(self, file_name):
17-
super().__init__(file_name)
19+
def __init__(self, file_name, keys_accepted=r'^\S+$'):
20+
super().__init__(file_name, keys_accepted=keys_accepted)
1821
self._writer = None
1922
self._fieldnames = None
2023
self._warned_once = set()
2124
self._disable_warnings = False
25+
self.tabular = Tabular()
2226

2327
@property
2428
def types_accepted(self):
25-
"""Accept TabularInput objects only."""
26-
return (TabularInput, )
27-
28-
def record(self, data, prefix=''):
29-
"""Log tabular data to CSV."""
30-
if isinstance(data, TabularInput):
31-
to_csv = data.as_primitive_dict
32-
33-
if not to_csv.keys() and not self._writer:
34-
return
35-
36-
if not self._writer:
37-
self._fieldnames = set(to_csv.keys())
38-
self._writer = csv.DictWriter(
39-
self._log_file,
40-
fieldnames=self._fieldnames,
41-
extrasaction='ignore')
42-
self._writer.writeheader()
43-
44-
if to_csv.keys() != self._fieldnames:
45-
self._warn('Inconsistent TabularInput keys detected. '
46-
'CsvOutput keys: {}. '
47-
'TabularInput keys: {}. '
48-
'Did you change key sets after your first '
49-
'logger.log(TabularInput)?'.format(
50-
set(self._fieldnames), set(to_csv.keys())))
51-
52-
self._writer.writerow(to_csv)
53-
54-
for k in to_csv.keys():
55-
data.mark(k)
56-
else:
57-
raise ValueError('Unacceptable type.')
29+
"""Accept str and scalar objects."""
30+
return (str, ) + np.ScalarType
31+
32+
def record(self, key, value, prefix=''):
33+
"""Log data to a csv file."""
34+
self.tabular.record(key, value)
35+
36+
def dump(self, step=None):
37+
"""Flush data to log file."""
38+
if self.tabular.empty:
39+
return
40+
41+
to_csv = self.tabular.as_primitive_dict
42+
43+
if not to_csv.keys() and not self._writer:
44+
return
45+
46+
if not self._writer:
47+
self._fieldnames = set(to_csv.keys())
48+
self._writer = csv.DictWriter(self._log_file,
49+
fieldnames=self._fieldnames,
50+
extrasaction='ignore')
51+
self._writer.writeheader()
52+
53+
if to_csv.keys() != self._fieldnames:
54+
self._warn('Inconsistent Tabular keys detected. '
55+
'CsvOutput keys: {}. '
56+
'Tabular keys: {}. '
57+
'Did you change key sets after your first '
58+
'logger.log(Tabular)?'.format(set(self._fieldnames),
59+
set(to_csv.keys())))
60+
61+
self._writer.writerow(to_csv)
62+
63+
self._log_file.flush()
64+
self.tabular.clear()
5865

5966
def _warn(self, msg):
6067
"""Warns the user using warnings.warn.
@@ -63,8 +70,9 @@ def _warn(self, msg):
6370
is the one printed.
6471
"""
6572
if not self._disable_warnings and msg not in self._warned_once:
66-
warnings.warn(
67-
colorize(msg, 'yellow'), CsvOutputWarning, stacklevel=3)
73+
warnings.warn(colorize(msg, 'yellow'),
74+
CsvOutputWarning,
75+
stacklevel=3)
6876
self._warned_once.add(msg)
6977
return msg
7078

src/dowel/logger.py

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
66
The logger has 4 major steps:
77
8-
1. Inputs, such as a simple string or something more complicated like
9-
TabularInput, are passed to the log() method of an instantiated Logger.
8+
1. Inputs, such as a simple string or something more complicated like
9+
a distribution, are passed to the log() or logkv() method of an
10+
instantiated Logger.
1011
11-
2. The Logger class checks for any outputs that have been added to it, and
12-
calls the record() method of any outputs that accept the type of input.
12+
2. The Logger class checks for any outputs that have been added to it, and
13+
calls the record() method of any outputs that accept the type of input.
1314
14-
3. The output (a subclass of LogOutput) receives the input via its record()
15-
method and handles it in whatever way is expected.
15+
3. The output (a subclass of LogOutput) receives the input via its record()
16+
method and handles it in whatever way is expected.
1617
17-
4. (only in some cases) The dump method is used to dump the output to file.
18-
It is necessary for some LogOutput subclasses, like TensorBoardOutput.
18+
4. (only in some cases) The dump method is used to dump the output to file
19+
and to log any key-value pairs that have been stored.
1920
2021
2122
# Here's a demonstration of dowel:
@@ -61,8 +62,8 @@
6162
6263
# And another output.
6364
64-
from dowel import CsvOutput
65-
logger.add_output(CsvOutput('log_folder/table.csv'))
65+
from dowel import TensorBoardOutput
66+
logger.add_output(TensorBoardOutput('log_folder/tensorboard'))
6667
6768
+---------+
6869
+------>StdOutput|
@@ -72,13 +73,16 @@
7273
|logger+------>TextOutput|
7374
+------+ +----------+
7475
|
75-
| +---------+
76-
+------>CsvOutput|
77-
+---------+
76+
| +-----------------+
77+
+------>TensorBoardOutput|
78+
+-----------------+
7879
7980
# The logger will record anything passed to logger.log to all outputs that
8081
# accept its type.
8182
83+
84+
# Now let's try logging a string again.
85+
8286
logger.log('test')
8387
8488
+---------+
@@ -89,38 +93,36 @@
8993
|logger+---'test'--->TextOutput|
9094
+------+ +----------+
9195
|
92-
| +---------+
93-
+-----!!----->CsvOutput|
94-
+---------+
96+
| +-----------------+
97+
+-----!!----->TensorBoardOutput|
98+
+-----------------+
9599
96-
# !! Note that the logger knows not to send CsvOutput the string 'test'
97-
# Similarly, more complex objects like tf.tensor won't be sent to (for
100+
# !! Note that the logger knows not to send 'test' to TensorBoardOutput.
101+
# Similarly, more complex objects like tf.Graph won't be sent to (for
98102
# example) TextOutput.
99103
# This behavior is defined in each output's types_accepted property
100104
101105
# Here's a more complex example.
102-
# TabularInput, instantiated for you as the tabular, can log key/value pairs.
106+
# We can log key-value pairs using logger.logkv
103107
104-
from dowel import tabular
105-
tabular.record('key', 72)
106-
tabular.record('foo', 'bar')
107-
logger.log(tabular)
108+
logger.logkv('key', 72)
109+
logger.logkv('foo', 'bar')
110+
logger.dump_all()
108111
109-
+---------+
110-
+---tabular--->StdOutput|
111-
| +---------+
112+
+---------+
113+
+------>StdOutput|
114+
| +---------+
112115
|
113-
+------+ +----------+
114-
|logger+---tabular--->TextOutput|
115-
+------+ +----------+
116+
+------+ +----------+
117+
|logger+------>TextOutput|
118+
+------+ +----------+
116119
|
117-
| +---------+
118-
+---tabular--->CsvOutput|
119-
+---------+
120+
| +---------+
121+
+------>CsvOutput|
122+
+---------+
120123
121-
# Note that LogOutputs which consume TabularInputs must call
122-
# TabularInput.mark() on each key they log. This helps the logger detect when
123-
# tabular data is not logged.
124+
# Note that the key-value pairs are saved in each output until we call
125+
# dump_all().
124126
125127
# Console Output:
126128
--- ---
@@ -133,29 +135,37 @@
133135
"""
134136
import abc
135137
import contextlib
138+
import re
136139
import warnings
137140

138141
from dowel.utils import colorize
139142

140143

141144
class LogOutput(abc.ABC):
142-
"""Abstract class for Logger Outputs."""
145+
"""Abstract class for Logger Outputs.
143146
144-
@property
145-
def types_accepted(self):
146-
"""Pass these types to this logger output.
147+
:param keys_accepted: Regex for which keys this output should accept.
148+
"""
147149

148-
The types in this tuple will be accepted by this output.
150+
def __init__(self, keys_accepted=r'^$'):
151+
self._keys_accepted = keys_accepted
149152

150-
:return: A tuple containing all valid input types.
151-
"""
153+
@property
154+
def types_accepted(self):
155+
"""Returns a tuple containing all valid input value types."""
152156
return ()
153157

158+
@property
159+
def keys_accepted(self):
160+
"""Returns a regex string matching keys to be sent to this output."""
161+
return self._keys_accepted
162+
154163
@abc.abstractmethod
155-
def record(self, data, prefix=''):
164+
def record(self, key, value, prefix=''):
156165
"""Pass logger data to this output.
157166
158-
:param data: The data to be logged by the output.
167+
:param key: The key to be logged by the output.
168+
:param value: The value to be logged by the output.
159169
:param prefix: A prefix placed before a log entry in text outputs.
160170
"""
161171
pass
@@ -186,7 +196,7 @@ def __init__(self):
186196
self._warned_once = set()
187197
self._disable_warnings = False
188198

189-
def log(self, data):
199+
def logkv(self, key, value):
190200
"""Magic method that takes in all different types of input.
191201
192202
This method is the main API for the logger. Any data to be logged goes
@@ -195,24 +205,30 @@ def log(self, data):
195205
Any data sent to this method is sent to all outputs that accept its
196206
type (defined in the types_accepted property).
197207
198-
:param data: Data to be logged. This can be any type specified in the
208+
:param key: Key to be logged. This must be a string.
209+
:param value: Value to be logged. This can be any type specified in the
199210
types_accepted property of any of the logger outputs.
200211
"""
201212
if not self._outputs:
202213
self._warn('No outputs have been added to the logger.')
203214

204215
at_least_one_logged = False
205216
for output in self._outputs:
206-
if isinstance(data, output.types_accepted):
207-
output.record(data, prefix=self._prefix_str)
217+
if isinstance(value, output.types_accepted) and re.match(
218+
output.keys_accepted, key):
219+
output.record(key, value, prefix=self._prefix_str)
208220
at_least_one_logged = True
209221

210222
if not at_least_one_logged:
211223
warning = (
212224
'Log data of type {} was not accepted by any output'.format(
213-
type(data).__name__))
225+
type(value).__name__))
214226
self._warn(warning)
215227

228+
def log(self, value):
229+
"""Log just a value without a key."""
230+
self.logkv('', value)
231+
216232
def add_output(self, output):
217233
"""Add a new output to the logger.
218234

0 commit comments

Comments
 (0)