@@ -476,10 +476,10 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr);
476
476
struct AccessTensorNode : public AccessNode {
477
477
AccessTensorNode (TensorBase tensor, const std::vector<IndexVar>& indices)
478
478
: AccessNode(tensor.getTensorVar(), indices, {}, false ),
479
- tensor (tensor) {}
479
+ tensorPtr (tensor.content ) {}
480
480
481
481
AccessTensorNode (TensorBase tensor, const std::vector<std::shared_ptr<IndexVarInterface>>& indices)
482
- : AccessNode(tensor.getTensorVar()), tensor (tensor) {
482
+ : AccessNode(tensor.getTensorVar()), tensorPtr (tensor.content ) {
483
483
// Create the vector of IndexVar to assign to this->indexVars.
484
484
std::vector<IndexVar> ivars (indices.size ());
485
485
for (size_t i = 0 ; i < indices.size (); i++) {
@@ -517,10 +517,21 @@ struct AccessTensorNode : public AccessNode {
517
517
this ->indexVars = std::move (ivars);
518
518
}
519
519
520
- TensorBase tensor;
520
+ // We hold a weak_ptr to the accessed TensorBase's content to avoid a reference
521
+ // cycle between the accessed TensorBase and this AccessTensorNode, since the
522
+ // TensorBase will store the AccessTensorNode (as part of an IndexExpr) as a
523
+ // field on itself. Not using a weak pointer results in leaking TensorBases.
524
+ std::weak_ptr<TensorBase::Content> tensorPtr;
525
+ TensorBase getTensor () const {  // Builds a TensorBase handle from the weakly-held content pointer.
526
+ TensorBase tensor;
527
+ tensor.content = tensorPtr.lock ();  // lock() yields an empty shared_ptr once the content has expired. NOTE(review): no expiry check is visible here -- confirm callers tolerate a null content.
528
+ return tensor;
529
+ }
530
+
521
531
virtual void setAssignment (const Assignment& assignment) {
522
- tensor. syncDependentTensors ();
532
+ auto tensor = this -> getTensor ();
523
533
534
+ tensor.syncDependentTensors ();
524
535
Assignment assign = makeReductionNotation (assignment);
525
536
526
537
tensor.setNeedsPack (false );
@@ -751,7 +762,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
751
762
taco_iassert (isa<AccessTensorNode>(node)) << " Unknown subexpression" ;
752
763
753
764
if (!util::contains (arguments, node->tensorVar )) {
754
- arguments.insert ({node->tensorVar , to<AccessTensorNode>(node)->tensor });
765
+ arguments.insert ({node->tensorVar , to<AccessTensorNode>(node)->getTensor () });
755
766
}
756
767
757
768
// Also add any tensors backing index sets of tensor accesses.
@@ -763,7 +774,7 @@ static inline map<TensorVar, TensorBase> getTensors(const IndexExpr& expr) {
763
774
}
764
775
765
776
// TODO (rohany): This seems like dead code.
766
- TensorBase tensor = to<AccessTensorNode>(node)->tensor ;
777
+ TensorBase tensor = to<AccessTensorNode>(node)->getTensor () ;
767
778
if (!util::contains (inserted, tensor)) {
768
779
inserted.insert (tensor);
769
780
operands.push_back (tensor);
0 commit comments