Skip to content

Commit c9d3f87

Browse files
committed
Fix memory leak from reference cycle
This commit is taken from an umerged PR in taco (tensor-compiler#520).
1 parent de40a77 commit c9d3f87

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

src/tensor.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -476,10 +476,10 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr);
476476
struct AccessTensorNode : public AccessNode {
477477
AccessTensorNode(TensorBase tensor, const std::vector<IndexVar>& indices)
478478
: AccessNode(tensor.getTensorVar(), indices, {}, false),
479-
tensor(tensor) {}
479+
tensorPtr(tensor.content) {}
480480

481481
AccessTensorNode(TensorBase tensor, const std::vector<std::shared_ptr<IndexVarInterface>>& indices)
482-
: AccessNode(tensor.getTensorVar()), tensor(tensor) {
482+
: AccessNode(tensor.getTensorVar()), tensorPtr(tensor.content) {
483483
// Create the vector of IndexVar to assign to this->indexVars.
484484
std::vector<IndexVar> ivars(indices.size());
485485
for (size_t i = 0; i < indices.size(); i++) {
@@ -517,10 +517,21 @@ struct AccessTensorNode : public AccessNode {
517517
this->indexVars = std::move(ivars);
518518
}
519519

520-
TensorBase tensor;
520+
// We hold a weak_ptr to the accessed TensorBase to avoid creating a reference
521+
// cycle between the accessed TensorBase and this AccessTensorNode, since the
522+
// TensorBase will store the AccessTensorNode (as part of an IndexExpr) as a
523+
// field on itself. Not using a weak pointer results in leaking TensorBases.
524+
std::weak_ptr<TensorBase::Content> tensorPtr;
525+
TensorBase getTensor() const {
526+
TensorBase tensor;
527+
tensor.content = tensorPtr.lock();
528+
return tensor;
529+
}
530+
521531
virtual void setAssignment(const Assignment& assignment) {
522-
tensor.syncDependentTensors();
532+
auto tensor = this->getTensor();
523533

534+
tensor.syncDependentTensors();
524535
Assignment assign = makeReductionNotation(assignment);
525536

526537
tensor.setNeedsPack(false);
@@ -751,7 +762,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
751762
taco_iassert(isa<AccessTensorNode>(node)) << "Unknown subexpression";
752763

753764
if (!util::contains(arguments, node->tensorVar)) {
754-
arguments.insert({node->tensorVar, to<AccessTensorNode>(node)->tensor});
765+
arguments.insert({node->tensorVar, to<AccessTensorNode>(node)->getTensor()});
755766
}
756767

757768
// Also add any tensors backing index sets of tensor accesses.
@@ -763,7 +774,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
763774
}
764775

765776
// TODO (rohany): This seems like dead code.
766-
TensorBase tensor = to<AccessTensorNode>(node)->tensor;
777+
TensorBase tensor = to<AccessTensorNode>(node)->getTensor();
767778
if (!util::contains(inserted, tensor)) {
768779
inserted.insert(tensor);
769780
operands.push_back(tensor);

0 commit comments

Comments
 (0)