Skip to content

Commit cb97d29

Browse files
authored
Merge pull request #41 from lincc-frameworks/method_types
Fix the types for cite_function
2 parents 5e23c4d + ab3ea20 commit cb97d29

File tree

4 files changed

+128
-35
lines changed

4 files changed

+128
-35
lines changed

benchmarks/benchmarks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
import citation_compass as cc
77

88

9-
@cc.cite_function("fake")
9+
@cc.cite_function(label="fake")
1010
def fake_function():
1111
"""A fake function to demonstrate the use of the citation_compass package."""
1212
return 1
1313

1414

15-
@cc.cite_function("fake", track_used=False)
15+
@cc.cite_function(label="fake", track_used=False)
1616
def fake_function2():
1717
"""A fake function to demonstrate the use of the citation_compass package."""
1818
return 1
@@ -28,7 +28,7 @@ def __init__(self):
2828
def time_create_function():
2929
"""Time the use of a wrapper with a label."""
3030

31-
@cc.cite_function("example")
31+
@cc.cite_function(label="example")
3232
def test_function():
3333
return 1
3434

docs/notebooks/timing.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@
2525
" return 1\n",
2626
"\n",
2727
"\n",
28-
"@cc.cite_function(\"test_func2\")\n",
28+
"@cc.cite_function(label=\"test_func2\")\n",
2929
"def test_func2():\n",
3030
" \"\"\"A test function with a citation\"\"\"\n",
3131
" return 1\n",
3232
"\n",
3333
"\n",
34-
"@cc.cite_function(\"test_func3\", track_used=False)\n",
34+
"@cc.cite_function(label=\"test_func3\", track_used=False)\n",
3535
"def test_func3():\n",
3636
" \"\"\"A test function with a citation that will not be tracked\"\"\"\n",
3737
" return 1"
@@ -94,7 +94,7 @@
9494
" return (-b + math.sqrt(inner)) / (2 * a), (-b - math.sqrt(inner)) / (2 * a)\n",
9595
"\n",
9696
"\n",
97-
"@cc.cite_function(\"test_func4\")\n",
97+
"@cc.cite_function\n",
9898
"def test_func4(a, b, c):\n",
9999
" \"\"\"A test function with a citation\"\"\"\n",
100100
" inner = b**2 - 4 * a * c\n",
@@ -127,7 +127,7 @@
127127
],
128128
"metadata": {
129129
"kernelspec": {
130-
"display_name": "citation",
130+
"display_name": "citation (3.13.8)",
131131
"language": "python",
132132
"name": "python3"
133133
},
@@ -141,7 +141,7 @@
141141
"name": "python",
142142
"nbconvert_exporter": "python",
143143
"pygments_lexer": "ipython3",
144-
"version": "3.10.4"
144+
"version": "3.13.8"
145145
}
146146
},
147147
"nbformat": 4,

src/citation_compass/citation.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import wraps
44
from os import urandom
55
import sys
6+
import types
67

78
from citation_compass.citation_registry import (
89
CitationEntry,
@@ -74,14 +75,18 @@ def init_wrapper(*args, **kwargs):
7475
cls.__init__ = init_wrapper
7576

7677

77-
def cite_function(label=None, track_used=True):
78+
def cite_function(callable=None, *, label=None, track_used=True):
7879
"""A function wrapper for adding a citation to a function or
7980
class method.
8081
8182
Parameters
8283
----------
84+
callable : function or method, optional
85+
The function or method to add a citation to. This is automatically passed as
86+
the first argument when using the decorator without parentheses.
8387
label : str, optional
84-
The (optional) user-defined label for the citation.
88+
The (optional) user-defined label for the citation. If not provided,
89+
the label will be auto-extracted from the function's docstring.
8590
track_used : bool
8691
If True, the function will be marked as used when it is called.
8792
This adds a small amount of overhead to each function call.
@@ -92,35 +97,47 @@ class method.
9297
function
9398
The wrapped function or method.
9499
"""
95-
# If the label is callable, there were no parentheses on the
96-
# dectorator and it passed in the function instead. So use None
97-
# as the label.
98-
use_label = label if not callable(label) else None
99-
100-
def decorator(func):
101-
entry = CitationEntry.from_object(func, label=use_label)
102-
CITATION_COMPASS_REGISTRY.add(entry)
100+
# This decorator is designed as a two-layer decorator. The first layer handles the (optional)
101+
# arguments. The second handles the actual function wrapping.
103102

104-
# Wrap the function so it is marked as USED when it is called.
105-
if track_used:
103+
def _inner_decorator(callable):
104+
# The inner decorator is used to set the "all" citations entry and handle
105+
# the correct return types when wrapping the function.
106106

107-
@wraps(func)
108-
def fun_wrapper(*args, **kwargs):
109-
# Save the citation as USED when it is first called.
110-
CITATION_COMPASS_REGISTRY.mark_used(entry.key)
111-
return func(*args, **kwargs)
112-
else:
113-
# We do not wrap the function, but just return the original function.
114-
fun_wrapper = func
107+
# Add the function to the registry (for the "all" citations).
108+
entry = CitationEntry.from_object(callable, label=label)
109+
CITATION_COMPASS_REGISTRY.add(entry)
115110

116-
# We mark as used be default so the citation does not get dropped.
111+
# If we are not tracking when the function is used, we don't need to wrap it.
112+
# We can just return the original callable.
113+
if not track_used:
114+
# We mark as used by default so the citation does not get dropped.
117115
CITATION_COMPASS_REGISTRY.mark_used(entry.key)
116+
return callable
118117

119-
return fun_wrapper
118+
# If the callable is a classmethod or method, we need to get the function that has
119+
# the self or cls argument.
120+
func = callable.__func__ if isinstance(callable, (classmethod, types.MethodType)) else callable
120121

121-
if callable(label):
122-
return decorator(label)
123-
return decorator
122+
# Define the actual wrapper for the callable we passed in. This wrapper function will
123+
# be called each time the internal function is called.
124+
@wraps(func)
125+
def citation_wrapper(*args, **kwargs):
126+
# Save the citation as USED when it is first called.
127+
CITATION_COMPASS_REGISTRY.mark_used(entry.key)
128+
return func(*args, **kwargs)
129+
130+
# We cast the wrapped function as the correct type.
131+
if isinstance(callable, classmethod):
132+
return classmethod(citation_wrapper)
133+
elif isinstance(callable, staticmethod):
134+
return staticmethod(citation_wrapper)
135+
elif isinstance(callable, types.MethodType):
136+
return types.MethodType(citation_wrapper, callable.__self__)
137+
return citation_wrapper
138+
139+
# Handle the optional parentheses in the decorator.
140+
return _inner_decorator if callable is None else _inner_decorator(callable)
124141

125142

126143
def cite_object(obj, label=None):

tests/citation_compass/test_citation.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import fake_module
22
import pytest
3+
import types
34

45
from citation_compass import (
56
cite_function,
@@ -31,13 +32,44 @@ def example_function_x(x):
3132
return x
3233

3334

35+
class _FakeTestingClass:
36+
"""A fake class for testing."""
37+
38+
def __init__(self, data=1):
39+
self.data = data
40+
41+
@classmethod
42+
@cite_function
43+
def fake_class_classmethod(cls):
44+
"""A fake classmethod for testing."""
45+
return cls(data=1)
46+
47+
@staticmethod
48+
@cite_function
49+
def fake_class_staticmethod():
50+
"""A fake staticmethod for testing."""
51+
return 0
52+
53+
@cite_function
54+
def fake_class_normal_method(self):
55+
"""A fake normal class method for testing."""
56+
return self.data
57+
58+
def uncited_method(self):
59+
"""A method that is not cited."""
60+
return self.data
61+
62+
3463
def test_citations_all():
3564
"""Check that all the citations are registered."""
3665
known_citations = [
3766
# The functions defined in this file.
3867
"test_citation.example_function_1: function_citation_1",
3968
"test_citation.example_function_2: function_citation_2",
4069
"test_citation.example_function_x: function_citation_x",
70+
"test_citation._FakeTestingClass.fake_class_classmethod: A fake classmethod for testing.",
71+
"test_citation._FakeTestingClass.fake_class_staticmethod: A fake staticmethod for testing.",
72+
"test_citation._FakeTestingClass.fake_class_normal_method: A fake normal class method for testing.",
4173
# The items defined in fake_module.
4274
"fake_module: CitationCompass, 2025.",
4375
"fake_module.FakeClass.fake_method: A fake class method for testing.",
@@ -60,8 +92,8 @@ def test_citations_all():
6092
obj = fake_module.FakeCitedClass()
6193
assert isinstance(obj, fake_module.FakeCitedClass)
6294

63-
# A citation with no docstring, but a label.
64-
@cite_function("function_citation_3")
95+
# A citation with no docstring, but a manual label.
96+
@cite_function(label="function_citation_3")
6597
def example_function_3():
6698
return 3
6799

@@ -238,6 +270,50 @@ def test_find_in_citations():
238270
assert len(find_in_citations("FakeCitedClass", True)) == 1
239271

240272

273+
def test_functions_in_class():
274+
"""Test that we correctly handle methods in a class including static and class methods."""
275+
obj = _FakeTestingClass(data=5)
276+
277+
# Nothing is used.
278+
assert len(find_in_citations("fake_class_normal_method", True)) == 0
279+
assert len(find_in_citations("fake_class_staticmethod", True)) == 0
280+
assert len(find_in_citations("fake_class_classmethod", True)) == 0
281+
282+
# All the functions are usable.
283+
assert obj.fake_class_normal_method() == 5
284+
assert obj.fake_class_staticmethod() == 0
285+
286+
obj2 = obj.fake_class_classmethod()
287+
assert isinstance(obj2, _FakeTestingClass)
288+
assert obj2.data == 1
289+
290+
# Everything is now cited.
291+
assert len(find_in_citations("fake_class_normal_method", True)) == 1
292+
assert len(find_in_citations("fake_class_staticmethod", True)) == 1
293+
assert len(find_in_citations("fake_class_classmethod", True)) == 1
294+
295+
# We preserve the types of each method when called from the class. The class method
296+
# static method should be those types.
297+
assert isinstance(_FakeTestingClass.__dict__["fake_class_normal_method"], types.FunctionType)
298+
assert isinstance(_FakeTestingClass.__dict__["fake_class_staticmethod"], staticmethod)
299+
assert isinstance(_FakeTestingClass.__dict__["fake_class_classmethod"], classmethod)
300+
301+
# Check the types when accessing an instance.
302+
assert isinstance(obj.fake_class_normal_method, types.MethodType)
303+
assert isinstance(obj.fake_class_staticmethod, types.FunctionType)
304+
assert isinstance(obj.fake_class_classmethod, types.MethodType)
305+
306+
# We preserve the names of the methods.
307+
assert obj.fake_class_classmethod.__name__ == "fake_class_classmethod"
308+
assert obj.fake_class_staticmethod.__name__ == "fake_class_staticmethod"
309+
assert obj.fake_class_normal_method.__name__ == "fake_class_normal_method"
310+
311+
# We preserve the docstring of the methods.
312+
assert obj.fake_class_classmethod.__doc__ == "A fake classmethod for testing."
313+
assert obj.fake_class_staticmethod.__doc__ == "A fake staticmethod for testing."
314+
assert obj.fake_class_normal_method.__doc__ == "A fake normal class method for testing."
315+
316+
241317
def test_citation_context():
242318
"""Test the CitationContext class."""
243319
reset_used_citations()

0 commit comments

Comments
 (0)