From 68213657bb2a0c9a6257192beff9ff0ccfc83182 Mon Sep 17 00:00:00 2001 From: Takeshi KOMIYA Date: Wed, 19 Feb 2025 04:04:19 +0900 Subject: [PATCH] Implement user-defined type guard method This implements the minimal version of user-defined type guard method. A user-defined type guard methos is a method defined by user that is able to guarantee the type of the objects (receiver or arguments). At present, Steep supports type guard methods provided by ruby-core (ex. `#is_a?`, `#nil?`, and so on). But we also have many kinds of user-defined methods that are able to check the type of the objects. Therefore user-defined type guard will help checking the type of these applications by narrowing types. This implementation uses an annotation to declare user-defined type guard method. ``` class Example < Integer %a{guard:self is Integer} def integer?: () -> bool end ``` For example, the above method `Example#integer?` is a user-defined type guard method that narrows the Example object itself to an Integer if the conditional branch passed. ``` example = Example.new if example.integer? example #=> Integer end ``` In this PR, the predicate of type guards only supports "self is TYPE" statement. I have a plan to extend it: * `%a{guard:self is arg}` * `%a{guard:self is_a arg}` * `%a{guard:self is TYPE_PARAM}` * `%a{guard:arg is TYPE}` Note: The compatibility of RBS syntax is the large reason of using annotations. I'm afraid that adding a new syntax to define it will bring breaking change to the RBS, and difficult to use it on common repository or generators (ex. gem_rbs_collection and rbs_rails). --- lib/steep/ast/types/logic.rb | 24 ++- lib/steep/interface/builder.rb | 36 ++++ .../type_inference/logic_type_interpreter.rb | 54 ++++- sig/steep/ast/types.rbs | 2 +- sig/steep/ast/types/logic.rbs | 13 ++ sig/steep/interface/builder.rbs | 4 + .../type_inference/logic_type_interpreter.rbs | 4 +- sig/test/type_check_test.rbs | 8 + test/type_check_test.rb | 188 ++++++++++++++++++ 9 files changed, 325 insertions(+), 8 deletions(-) diff --git a/lib/steep/ast/types/logic.rb b/lib/steep/ast/types/logic.rb index 6a9408379..c9a4dfedc 100644 --- a/lib/steep/ast/types/logic.rb +++ b/lib/steep/ast/types/logic.rb @@ -4,7 +4,7 @@ module Types module Logic class Base extend SharedInstance - + def subst(s) self end @@ -53,6 +53,28 @@ class ArgEqualsReceiver < Base class ArgIsAncestor < Base end + class Guard < Base + PATTERN = /\Aguard:\s*(self)\s+(is)\s+(\S+)\s*\Z/ + + attr_reader :subject + attr_reader :operator + attr_reader :type + + def initialize(subject:, operator:, type:) + @subject = subject + @operator = operator + @type = type + end + + def ==(other) + super && subject == other.subject && operator == other.operator && type == other.type + end + + def hash + self.class.hash ^ subject.hash ^ operator.hash ^ type.hash + end + end + class Env < Base attr_reader :truthy, :falsy, :type diff --git a/lib/steep/interface/builder.rb b/lib/steep/interface/builder.rb index f24a0e294..02852afbe 100644 --- a/lib/steep/interface/builder.rb +++ b/lib/steep/interface/builder.rb @@ -281,6 +281,7 @@ def singleton_shape(type_name) overloads = method.defs.map do |type_def| method_name = method_name_for(type_def, name) method_type = factory.method_type(type_def.type) + method_type = replace_guard_method(definition, type_def, method_type) method_type = replace_primitive_method(method_name, type_def, method_type) method_type = replace_kernel_class(method_name, type_def, method_type) { AST::Builtin::Class.instance_type } method_type = add_implicitly_returns_nil(type_def.annotations, method_type) @@ -313,6 +314,7 @@ def object_shape(type_name) overloads = method.defs.map do |type_def| method_name = method_name_for(type_def, name) method_type = factory.method_type(type_def.type) + method_type = replace_guard_method(definition, type_def, method_type) method_type = replace_primitive_method(method_name, type_def, method_type) if type_name.class? method_type = replace_kernel_class(method_name, type_def, method_type) { AST::Types::Name::Singleton.new(name: type_name) } @@ -725,6 +727,40 @@ def proc_shape(proc, proc_shape) shape end + def replace_guard_method(definition, method_def, method_type) + match = method_def.annotations.filter_map { AST::Types::Logic::Guard::PATTERN.match(_1.string) }.first + if match + subject = match[1] or raise + operator = match[2] or raise + type_name = match[3] or raise + + type = RBS::Parser.parse_type(type_name) + raise "Unknown type: #{type_name}" unless type + + context = context_from(definition.type_name) + type = type.map_type_name { factory.absolute_type_name(_1, context: context) or raise "Unknown type: #{_1}" } + guard = AST::Types::Logic::Guard.new(subject: subject, operator: operator, type: type) + definition.type_name + method_type.with( + type: method_type.type.with(return_type: guard) + ) + else + method_type + end + rescue => exn + Steep.logger.error { exn.message } + method_type + end + + def context_from(type_name) + if type_name.namespace == RBS::Namespace.root + [nil, type_name] + else + parent = context_from(type_name.namespace.to_type_name) + [parent, type_name] + end + end + def replace_primitive_method(method_name, method_def, method_type) defined_in = method_def.defined_in member = method_def.member diff --git a/lib/steep/type_inference/logic_type_interpreter.rb b/lib/steep/type_inference/logic_type_interpreter.rb index 052b4eec0..4f39a6a91 100644 --- a/lib/steep/type_inference/logic_type_interpreter.rb +++ b/lib/steep/type_inference/logic_type_interpreter.rb @@ -405,6 +405,42 @@ def evaluate_method_call(env:, type:, receiver:, arguments:) truthy_result.update_type { FALSE } ] end + + when AST::Types::Logic::Guard + if receiver + receiver_type = factory.deep_expand_alias(typing.type_of(node: receiver)) || raise + + # TODO: Expand the type params to the actual types + # TODO: Support argument types (ex. `self is arg1`) + # TODO: Support class' type param types (ex. `self is T`) + # TODO: Support method's type param types (ex. `self is T`) + # TODO: Support is_a operator + # TODO: Ensure the type exists + + sub_type = factory.type(type.type) + if no_subtyping?(sub_type: sub_type, super_type: receiver_type) + typing.add_error( + Diagnostic::Ruby::UnexpectedError.new(node: receiver, error: Exception.new("#{receiver_type} is not a subtype of #{type.type}")) + ) + return nil + end + + truthy_type, falsy_type = type_guard_type_case_select(receiver_type, sub_type) + truthy_env, falsy_env = refine_node_type( + env: env, + node: receiver, + truthy_type: truthy_type || receiver_type, + falsy_type: falsy_type || UNTYPED + ) + + truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false) + truthy_result.unreachable! unless truthy_type + + falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false) + falsy_result.unreachable! unless falsy_type + + [truthy_result, falsy_result] + end end end @@ -494,8 +530,8 @@ def literal_var_type_case_select(value_node, arg_type) end end - def type_case_select(type, klass) - truth_types, false_types = type_case_select0(type, klass) + def type_guard_type_case_select(type, guard_type) + truth_types, false_types = type_case_select0(type, guard_type) [ truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types), @@ -503,16 +539,24 @@ def type_case_select(type, klass) ] end - def type_case_select0(type, klass) + def type_case_select(type, klass) instance_type = factory.instance_type(klass) + truth_types, false_types = type_case_select0(type, instance_type) + + [ + truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types), + false_types.empty? ? nil : AST::Types::Union.build(types: false_types) + ] + end + def type_case_select0(type, instance_type) case type when AST::Types::Union truthy_types = [] # :Array[AST::Types::t] falsy_types = [] #: Array[AST::Types::t] type.types.each do |ty| - truths, falses = type_case_select0(ty, klass) + truths, falses = type_case_select0(ty, instance_type) if truths.empty? falsy_types.push(ty) @@ -529,7 +573,7 @@ def type_case_select0(type, klass) if ty == type [[type], [type]] else - type_case_select0(ty, klass) + type_case_select0(ty, instance_type) end when AST::Types::Any, AST::Types::Top, AST::Types::Var diff --git a/sig/steep/ast/types.rbs b/sig/steep/ast/types.rbs index d361f619d..183a1d764 100644 --- a/sig/steep/ast/types.rbs +++ b/sig/steep/ast/types.rbs @@ -6,7 +6,7 @@ module Steep | Intersection | Record | Tuple | Union | Name::Alias | Name::Instance | Name::Interface | Name::Singleton | Proc | Var - | Logic::Not | Logic::ReceiverIsNil | Logic::ReceiverIsNotNil | Logic::ReceiverIsArg | Logic::ArgIsReceiver | Logic::ArgEqualsReceiver | Logic::ArgIsAncestor | Logic::Env + | Logic::Not | Logic::ReceiverIsNil | Logic::ReceiverIsNotNil | Logic::ReceiverIsArg | Logic::ArgIsReceiver | Logic::ArgEqualsReceiver | Logic::ArgIsAncestor | Logic::Guard | Logic::Env # Variables and special types that is subject for substitution # diff --git a/sig/steep/ast/types/logic.rbs b/sig/steep/ast/types/logic.rbs index 6570a4d70..86e38921d 100644 --- a/sig/steep/ast/types/logic.rbs +++ b/sig/steep/ast/types/logic.rbs @@ -50,6 +50,19 @@ module Steep class ArgIsAncestor < Base end + # A type for type guard. + class Guard < Base + PATTERN: Regexp + + attr_reader subject: String + attr_reader operator: String + attr_reader type: RBS::Types::t + + def self.new: (subject: String, operator: String, type: RBS::Types::t) -> Guard + + def initialize: (subject: String, operator: String, type: RBS::Types::t) -> void + end + # A type with truthy/falsy type environment. class Env < Base attr_reader truthy: TypeInference::TypeEnv diff --git a/sig/steep/interface/builder.rbs b/sig/steep/interface/builder.rbs index 0ad58ca17..7c326eb3e 100644 --- a/sig/steep/interface/builder.rbs +++ b/sig/steep/interface/builder.rbs @@ -110,6 +110,10 @@ module Steep def method_name_for: (RBS::Definition::Method::TypeDef, Symbol name) -> method_name + def replace_guard_method: (RBS::Definition, RBS::Definition::Method::TypeDef, MethodType) -> MethodType + + def context_from: (RBS::TypeName) -> RBS::Resolver::context + def replace_primitive_method: (method_name, RBS::Definition::Method::TypeDef, MethodType) -> MethodType def replace_kernel_class: (method_name, RBS::Definition::Method::TypeDef, MethodType) { () -> AST::Types::t } -> MethodType diff --git a/sig/steep/type_inference/logic_type_interpreter.rbs b/sig/steep/type_inference/logic_type_interpreter.rbs index 6e1ddbbe0..d2f0bbc7c 100644 --- a/sig/steep/type_inference/logic_type_interpreter.rbs +++ b/sig/steep/type_inference/logic_type_interpreter.rbs @@ -94,9 +94,11 @@ module Steep # def literal_var_type_case_select: (Parser::AST::Node value_node, AST::Types::t arg_type) -> [Array[AST::Types::t], Array[AST::Types::t]]? + def type_guard_type_case_select: (AST::Types::t `type`, AST::Types::t guard_type) -> [AST::Types::t?, AST::Types::t?] + def type_case_select: (AST::Types::t `type`, RBS::TypeName klass) -> [AST::Types::t?, AST::Types::t?] - def type_case_select0: (AST::Types::t `type`, RBS::TypeName klass) -> [Array[AST::Types::t], Array[AST::Types::t]] + def type_case_select0: (AST::Types::t `type`, AST::Types::t instance_type) -> [Array[AST::Types::t], Array[AST::Types::t]] def try_convert: (AST::Types::t, Symbol) -> AST::Types::t? diff --git a/sig/test/type_check_test.rbs b/sig/test/type_check_test.rbs index 65e74c2a3..004040de4 100644 --- a/sig/test/type_check_test.rbs +++ b/sig/test/type_check_test.rbs @@ -98,6 +98,14 @@ class TypeCheckTest < Minitest::Test def test_type_narrowing__local_variable_safe_navigation_operator: () -> untyped + def test_type_guard__self_is_TYPE: () -> untyped + + def test_type_guard__self_is_TYPE_no_subtyping: () -> untyped + + def test_type_guard__self_is_TYPE_unknown: () -> untyped + + def test_type_guard__self_is_TYPE_singleton: () -> untyped + def test_argument_error__unexpected_unexpected_positional_argument: () -> untyped def test_type_assertion__type_error: () -> untyped diff --git a/test/type_check_test.rb b/test/type_check_test.rb index 4fb6f6731..6f33c4dd2 100644 --- a/test/type_check_test.rb +++ b/test/type_check_test.rb @@ -1500,6 +1500,194 @@ def foo(context) ) end + def test_type_guard__self_is_TYPE + run_type_check_test( + signatures: { + "a.rbs" => <<~RBS + class Object + %a{guard:self is Integer} + def integer?: () -> bool + end + RBS + }, + code: { + "a.rb" => <<~RUBY + # @type var a: Object + a = (_ = nil) + + if a.integer? + a + 1 + else + a.succ + end + + a + 1 + RUBY + }, + expectations: <<~YAML + --- + - file: a.rb + diagnostics: + - range: + start: + line: 7 + character: 4 + end: + line: 7 + character: 8 + severity: ERROR + message: Type `::Object` does not have method `succ` + code: Ruby::NoMethod + - range: + start: + line: 10 + character: 2 + end: + line: 10 + character: 3 + severity: ERROR + message: Type `(::Integer | ::Object)` does not have method `+` + code: Ruby::NoMethod + YAML + ) + end + + def test_type_guard__self_is_TYPE_no_subtyping + run_type_check_test( + signatures: { + "a.rbs" => <<~RBS + class Integer + %a{guard:self is String} + def string?: () -> bool + end + RBS + }, + code: { + "a.rb" => <<~RUBY + a = 1 + + if a.string? + a.reverse + end + RUBY + }, + expectations: <<~YAML + --- + - file: a.rb + diagnostics: + - range: + start: + line: 3 + character: 3 + end: + line: 3 + character: 4 + severity: ERROR + message: 'UnexpectedError: ::Integer is not a subtype of ::String(Exception)' + code: Ruby::UnexpectedError + - range: + start: + line: 4 + character: 4 + end: + line: 4 + character: 11 + severity: ERROR + message: Type `::Integer` does not have method `reverse` + code: Ruby::NoMethod + YAML + ) + end + + def test_type_guard__self_is_TYPE_unknown + run_type_check_test( + signatures: { + "a.rbs" => <<~RBS + class Object + %a{guard:self is Unknown} + def integer?: () -> bool + end + RBS + }, + code: { + "a.rb" => <<~RUBY + # @type var a: Object + a = (_ = nil) + + if a.integer? + a + 1 + end + RUBY + }, + expectations: <<~YAML + --- + - file: a.rb + diagnostics: + - range: + start: + line: 5 + character: 4 + end: + line: 5 + character: 5 + severity: ERROR + message: Type `::Object` does not have method `+` + code: Ruby::NoMethod + YAML + ) + end + + def test_type_guard__self_is_TYPE_singleton + run_type_check_test( + signatures: { + "a.rbs" => <<~RBS + class Object + %a{guard:self is singleton(String)} + def self.string?: () -> bool + end + RBS + }, + code: { + "a.rb" => <<~RUBY + cls = Object + + if cls.string? + cls.new + "" + else + cls.new.succ + end + + cls.new + "" + RUBY + }, + expectations: <<~YAML + --- + - file: a.rb + diagnostics: + - range: + start: + line: 6 + character: 10 + end: + line: 6 + character: 14 + severity: ERROR + message: Type `::Object` does not have method `succ` + code: Ruby::NoMethod + - range: + start: + line: 9 + character: 8 + end: + line: 9 + character: 9 + severity: ERROR + message: Type `(::String | ::Object)` does not have method `+` + code: Ruby::NoMethod + YAML + ) + end + def test_argument_error__unexpected_unexpected_positional_argument run_type_check_test( signatures: {