Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Break Versioning](https://www.taoensso.com/break-ve

## [Unreleased]

### Fixed

- Support for intersection types (created with `&` operator) in schema definitions (fixes #494) (@baweaver)

## [1.14.1] - 2025-03-03

Expand Down
4 changes: 4 additions & 0 deletions lib/dry/schema/macros/dsl.rb
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def extract_type_spec(args, nullable: false, set_type: true)
type_rule = [type_spec.left, type_spec.right].map { |ts|
new(klass: Core, chain: false).value(ts)
}.reduce(:|)
elsif type_spec.is_a?(Dry::Types::Intersection) && set_type
type_rule = [type_spec.left, type_spec.right].map { |ts|
new(klass: Core, chain: false).value(ts)
}.reduce(:&)
else
type_predicates = predicate_inferrer[resolved_type]

Expand Down
11 changes: 10 additions & 1 deletion lib/dry/schema/predicate_inferrer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@ module Dry
module Schema
# @api private
class PredicateInferrer < ::Dry::Types::PredicateInferrer
Compiler = ::Class.new(superclass::Compiler)
Compiler = ::Class.new(superclass::Compiler) do
# @api private
def visit_intersection(node)
left_node, right_node, = node
left = visit(left_node)
right = visit(right_node)

[left, right].flatten.compact
end
end

def initialize(registry = PredicateRegistry.new)
super
Expand Down
8 changes: 7 additions & 1 deletion lib/dry/schema/primitive_inferrer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ module Dry
module Schema
# @api private
class PrimitiveInferrer < ::Dry::Types::PrimitiveInferrer
Compiler = ::Class.new(superclass::Compiler)
Compiler = ::Class.new(superclass::Compiler) do
# @api private
def visit_intersection(node)
left, right = node
[visit(left), visit(right)].flatten(1)
end
end

def initialize
super
Expand Down
72 changes: 72 additions & 0 deletions spec/integration/schema/intersection_types_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# frozen_string_literal: true

RSpec.describe "Intersection types" do
context "with hash schemas" do
let(:schema) do
Dry::Schema.Params do
required(:body).value(
Types::Hash.schema(a: Types::String) &
(Types::Hash.schema(b: Types::String) | Types::Hash.schema(c: Types::String))
)
end
end

it "validates intersection of hash schemas successfully" do
result = schema.call(body: {a: "test", b: "value"})

expect(result).to be_success
expect(result.to_h).to eq(body: {a: "test", b: "value"})
end

it "validates intersection with alternative branch" do
result = schema.call(body: {a: "test", c: "value"})

expect(result).to be_success
expect(result.to_h).to eq(body: {a: "test", c: "value"})
end

it "fails when intersection requirements not met" do
result = schema.call(body: {b: "value"})

expect(result).to be_failure
expect(result.errors.to_h).to eq(body: {a: ["is missing"]})
end
end

context "with simple type intersection" do
let(:schema) do
Dry::Schema.Params do
required(:value).value(Types::String & Types::Params::String)
end
end

it "validates simple intersection types" do
result = schema.call(value: "test")

expect(result).to be_success
expect(result.to_h).to eq(value: "test")
end
end

context "with DSL predicates and intersection" do
let(:schema) do
Dry::Schema.Params do
required(:name).value(Types::String & Types::Params::String) { filled? & min_size?(2) }
end
end

it "combines type intersection with predicate rules" do
result = schema.call(name: "John")

expect(result).to be_success
expect(result.to_h).to eq(name: "John")
end

it "fails when predicate rules not met" do
result = schema.call(name: "J")

expect(result).to be_failure
expect(result.errors.to_h).to eq(name: ["size cannot be less than 2"])
end
end
end