AsyncDNS-cr  resolver.cr at [e71ccf3e48]

File src/resolver.cr artifact b25ecb3e05 part of check-in e71ccf3e48


     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
   100
   101
   102
   103
   104
   105
   106
   107
   108
   109
   110
   111
   112
   113
   114
   115
   116
   117
   118
   119
   120
   121
   122
   123
   124
   125
   126
   127
   128
   129
   130
   131
   132
   133
   134
   135
   136
   137
   138
   139
   140
   141
   142
   143
   144
   145
   146
   147
   148
   149
   150
   151
   152
   153
   154
   155
   156
   157
   158
   159
   160
   161
   162
   163
   164
   165
   166
   167
   168
   169
   170
   171
   172
   173
   174
   175
   176
   177
   178
   179
   180
   181
   182
   183
   184
   185
   186
   187
   188
   189
   190
   191
   192
   193
   194
   195
   196
   197
   198
   199
   200
   201
   202
   203
   204
   205
   206
   207
   208
   209
   210
   211
   212
   213
   214
   215
   216
   217
   218
   219
   220
   221
   222
   223
require "socket"

require "./query"
require "./response"
require "./rr"
require "./settings"

# TODO: Cancel timer
# TODO: TCP

module AsyncDNS
  # Asynchronous DNS stub resolver.
  #
  # Queries are sent over UDP to the nameservers listed in `settings`; the
  # outcome is reported by invoking the block given to `#resolve`. One UDP
  # socket per address family is opened lazily on first use, each served by
  # its own receiving fiber that dispatches replies by transaction ID.
  class Resolver
    @v6_sock : UDPSocket?
    @v4_sock : UDPSocket?

    # Settings applied to new queries; every query works on its own `dup`.
    getter settings : Settings

    # Failure reasons passed to a query's callback in place of a `Response`.
    enum Error
      NO_NAME_SERVER
      INVALID_REPLY
      TRUNCATED
      SERVER_INVALID_FORMAT
      SERVER_FAILURE
      SERVER_NAME_ERROR
      SERVER_NOT_IMPLEMENTED
      SERVER_REFUSED
      UNKNOWN
    end

    # Per-query state: the encoded request packet, the callback and the index
    # of the nameserver currently being tried.
    private class Context
      getter query : Query
      getter id : UInt16
      getter settings : Settings
      getter block : Response | Error ->
      property ns_index : Int32
      property attempt : Int32
      property used_ns : Socket::IPAddress | Nil
      # Wire-format query packet (header followed by a single question).
      getter raw_data : Bytes

      # Encodes the query packet for *query* under transaction ID *id*.
      #
      # Raises `ArgumentError` if the domain contains an empty label, a label
      # longer than 63 bytes, or would not fit into a 512-byte packet.
      def initialize(@query : Query, @id : UInt16, @settings : Settings,
                     @block : Response | Error ->)
        @ns_index = 0
        @attempt = 0
        @used_ns = nil

        io = IO::Memory.new(512)

        # Header

        io.write_bytes(@id, IO::ByteFormat::BigEndian)
        # Flags: standard query (opcode 0) with RD (recursion desired) set
        io.write_bytes(0x100_u16, IO::ByteFormat::BigEndian)
        # QDCOUNT
        io.write_bytes(1_u16, IO::ByteFormat::BigEndian)
        # ANCOUNT, NSCOUNT and ARCOUNT
        3.times { io.write_bytes(0_u16, IO::ByteFormat::BigEndian) }

        # Question

        # QNAME: length-prefixed labels followed by the zero-length root label.
        # A trailing dot only denotes the root, which is always written below,
        # so "example.com" and "example.com." encode identically.
        name = @query.domain.chomp('.')
        unless name.empty?
          name.split('.').each do |label|
            # An empty label would be encoded as a premature root terminator.
            raise ArgumentError.new("Empty domain component") if label.empty?
            if label.bytesize > 63 || io.size + 1 + label.bytesize > 512
              raise ArgumentError.new("Domain component too long")
            end

            io.write_byte(label.bytesize.to_u8)
            io << label
          end
        end
        io.write_byte(0_u8)

        # QTYPE
        io.write_bytes(@query.rr_type.to_u16, IO::ByteFormat::BigEndian)
        # QCLASS
        io.write_bytes(@query.dns_class.to_u16, IO::ByteFormat::BigEndian)

        @raw_data = io.to_slice
      end
    end

    def initialize
      @settings = Settings.new
      # In-flight UDP queries keyed by transaction ID.
      @queries = Hash(UInt16, Context).new
      # Reserved for TCP fallback (see TODO above); not used yet.
      @tcp_queries = Hash(Socket, Context).new
    end

    # Starts resolving *query*; the outcome is reported through *block*.
    #
    # Calls *block* with `Error::NO_NAME_SERVER` right away if no nameserver is
    # configured. Raises `ArgumentError` if the domain is longer than 253 bytes
    # (not counting a trailing root dot) or cannot be encoded.
    # NOTE(review): successful replies are not yet turned into a `Response`;
    # see the TODO in `#handle_packet`.
    def resolve(query : Query, &block : Response | Error ->) : Nil
      if query.domain.chomp('.').bytesize > 253
        raise ArgumentError.new("Queried domain is too long")
      end

      if @settings.nameservers.empty?
        block.call(Error::NO_NAME_SERVER)
        return
      end

      # Pick a random transaction ID that no in-flight query is using.
      id = Random::Secure.rand(UInt16::MIN..UInt16::MAX).to_u16
      while @queries.has_key?(id)
        id = Random::Secure.rand(UInt16::MIN..UInt16::MAX).to_u16
      end

      send(Context.new(query.dup, id, settings.dup, block))
    end

    # Sends *context*'s packet to its current nameserver (port 53), opening
    # the socket for that address family on first use, and registers the
    # query so its reply can be matched.
    private def send(context : Context) : Nil
      ns = context.settings.nameservers[context.ns_index]
      used_ns = Socket::IPAddress.new(ns, 53)
      context.used_ns = used_ns

      sock = case used_ns.family
             when Socket::Family::INET6
               @v6_sock ||= open_socket(Socket::Family::INET6, "::")
             when Socket::Family::INET
               @v4_sock ||= open_socket(Socket::Family::INET, "0.0.0.0")
             else
               raise ArgumentError.new("Nameserver must be INET or INET6")
             end

      # Register only once a socket exists, so a failure above does not leave
      # a stale entry in @queries.
      @queries[context.id] = context
      sock.send(context.raw_data, used_ns)
    end

    # Creates a UDP socket of *family* bound to an ephemeral port on
    # *bind_address* and starts its receive loop.
    private def open_socket(family : Socket::Family, bind_address : String) : UDPSocket
      sock = UDPSocket.new(family)
      sock.bind bind_address, 0
      start_recv_loop(sock)
      sock
    end

    # Spawns a fiber that receives datagrams on *sock* and hands each one to
    # `#handle_packet` until the socket is closed.
    def start_recv_loop(sock : UDPSocket) : Nil
      spawn do
        buf = Bytes.new(0x10000)

        until sock.closed?
          begin
            len, addr = sock.receive(buf)
          rescue ex : IO::Error
            # Presumably the error raised when `#stop` closes the socket while
            # a receive is pending; that ends the loop normally. Other errors
            # are re-raised.
            break if sock.closed? && ex.os_error.nil?
            raise ex
          end
          handle_packet(buf[0, len], addr)
        end
      end
    end

    # Matches a received datagram to its pending query and reports the
    # outcome through the query's callback.
    #
    # Datagrams too short to carry an ID, with an unknown ID, or from an
    # address other than the one the query was sent to are ignored. Once a
    # query has been matched, a malformed reply is reported as
    # `Error::INVALID_REPLY` instead of being dropped silently.
    def handle_packet(packet : Bytes, sender : Socket::IPAddress) : Nil
      return if packet.size < 2

      id = packet[0].to_u16 << 8 | packet[1]
      context = @queries[id]?
      return unless context
      return if sender != context.used_ns

      @queries.delete(id)

      # Fixed 12-byte header: ID, flags and the four section counts.
      if packet.size < 12
        context.block.call(Error::INVALID_REPLY)
        return
      end

      if packet[2] & 0x80 == 0 ||                       # QR: must be a response
         packet[2] & 0x78 != context.raw_data[2] & 0x78 # Opcode must echo ours
        context.block.call(Error::INVALID_REPLY)
        return
      end

      if packet[2] & 0x02 != 0 # TC
        # TODO: Switch to TCP
        context.block.call(Error::TRUNCATED)
        return
      end

      # RCODE: map to an error and decide whether another server may help.
      case packet[3] & 0x0F
      when 0
        error = nil
        try_next_ns = false
      when 1
        error = Error::SERVER_INVALID_FORMAT
        try_next_ns = false
      when 2
        error = Error::SERVER_FAILURE
        try_next_ns = true
      when 3
        error = Error::SERVER_NAME_ERROR
        try_next_ns = false
      when 4
        error = Error::SERVER_NOT_IMPLEMENTED
        try_next_ns = true
      when 5
        error = Error::SERVER_REFUSED
        try_next_ns = true
      else
        error = Error::UNKNOWN
        try_next_ns = true
      end

      if try_next_ns && context.ns_index + 1 < context.settings.nameservers.size
        context.ns_index += 1
        send(context)
        return
      end

      if error
        context.block.call(error)
        return
      end

      begin
        io = IO::Memory.new(packet[4, packet.size - 4], writeable: false)
        qdcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
        ancount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
        nscount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
        arcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
        # TODO: parse the question/answer/authority/additional sections and
        # deliver a Response to context.block.
      rescue IndexError | IO::EOFError
        context.block.call(Error::INVALID_REPLY)
      end
    end

    # Closes both UDP sockets, which makes their receive loops exit. The
    # sockets are forgotten so that a later `#resolve` opens fresh ones rather
    # than sending on a closed socket. Queries still in flight are not
    # notified.
    def stop : Nil
      @v6_sock.try &.close
      @v4_sock.try &.close
      @v6_sock = nil
      @v4_sock = nil
    end
  end
end