@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546546 case "NewSet" :
547547 pset , errs := oc .processNewSet (info , pkgPath , call , nil , varName )
548548 return pset , notePositionAll (exprPos , errs )
549+ case "Subtract" :
550+ pset , errs := oc .processSubtract (info , pkgPath , call , nil , varName )
551+ return pset , notePositionAll (exprPos , errs )
549552 case "Bind" :
550553 b , err := processBind (oc .fset , info , call )
551554 if err != nil {
@@ -880,6 +883,116 @@ func isPrevented(tag string) bool {
880883 return reflect .StructTag (tag ).Get ("wire" ) == "-"
881884}
882885
886+ func (oc * objectCache ) processSubtract (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (interface {}, []error ) {
887+ // Assumes that call.Fun is wire.Subtract.
888+ if len (call .Args ) < 2 {
889+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
890+ errors .New ("call to Subtract must specify types to be subtracted" ))}
891+ }
892+ firstArg , errs := oc .processExpr (info , pkgPath , call .Args [0 ], "" )
893+ if len (errs ) > 0 {
894+ return nil , errs
895+ }
896+ set , ok := firstArg .(* ProviderSet )
897+ if ! ok {
898+ return nil , []error {
899+ notePosition (oc .fset .Position (call .Pos ()),
900+ fmt .Errorf ("first argument to Subtract must be a Set" )),
901+ }
902+ }
903+ pset := & ProviderSet {
904+ Pos : call .Pos (),
905+ InjectorArgs : args ,
906+ PkgPath : pkgPath ,
907+ VarName : varName ,
908+ // Copy the other fields.
909+ Providers : set .Providers ,
910+ Bindings : set .Bindings ,
911+ Values : set .Values ,
912+ Fields : set .Fields ,
913+ Imports : set .Imports ,
914+ }
915+ ec := new (errorCollector )
916+ for _ , arg := range call .Args [1 :] {
917+ ptr , ok := info .TypeOf (arg ).(* types.Pointer )
918+ if ! ok {
919+ ec .add (notePosition (oc .fset .Position (arg .Pos ()),
920+ fmt .Errorf ("argument to Subtract must be a pointer" ),
921+ ))
922+ continue
923+ }
924+ ec .add (oc .filterType (pset , ptr .Elem ())... )
925+ }
926+ if len (ec .errors ) > 0 {
927+ return nil , ec .errors
928+ }
929+ return pset , nil
930+ }
931+
932+ func (oc * objectCache ) filterType (set * ProviderSet , t types.Type ) []error {
933+ hasType := func (outs []types.Type ) bool {
934+ for _ , o := range outs {
935+ if types .Identical (o , t ) {
936+ return true
937+ }
938+ pt , ok := o .(* types.Pointer )
939+ if ok && types .Identical (pt .Elem (), t ) {
940+ return true
941+ }
942+ }
943+ return false
944+ }
945+
946+ providers := make ([]* Provider , 0 , len (set .Providers ))
947+ for _ , p := range set .Providers {
948+ if ! hasType (p .Out ) {
949+ providers = append (providers , p )
950+ }
951+ }
952+ set .Providers = providers
953+
954+ bindings := make ([]* IfaceBinding , 0 , len (set .Bindings ))
955+ for _ , i := range set .Bindings {
956+ if ! types .Identical (i .Iface , t ) {
957+ bindings = append (bindings , i )
958+ }
959+ }
960+ set .Bindings = bindings
961+
962+ values := make ([]* Value , 0 , len (set .Values ))
963+ for _ , v := range set .Values {
964+ if ! types .Identical (v .Out , t ) {
965+ values = append (values , v )
966+ }
967+ }
968+ set .Values = values
969+
970+ fields := make ([]* Field , 0 , len (set .Fields ))
971+ for _ , f := range set .Fields {
972+ if ! hasType (f .Out ) {
973+ fields = append (fields , f )
974+ }
975+ }
976+ set .Fields = fields
977+
978+ imports := make ([]* ProviderSet , 0 , len (set .Imports ))
979+ for _ , p := range set .Imports {
980+ clone := * p
981+ if errs := oc .filterType (& clone , t ); len (errs ) > 0 {
982+ return errs
983+ }
984+ imports = append (imports , & clone )
985+ }
986+ set .Imports = imports
987+
988+ var errs []error
989+ set .providerMap , set .srcMap , errs = buildProviderMap (oc .fset , oc .hasher , set )
990+ if len (errs ) > 0 {
991+ return errs
992+ }
993+ return nil
994+ }
995+
883996// processBind creates an interface binding from a wire.Bind call.
884997func processBind (fset * token.FileSet , info * types.Info , call * ast.CallExpr ) (* IfaceBinding , error ) {
885998 // Assumes that call.Fun is wire.Bind.
@@ -1122,7 +1235,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
11221235 default :
11231236 invalid = true
11241237 }
1125-
11261238 }
11271239 if wireBuildCall == nil {
11281240 return nil , nil
0 commit comments