AsyncDNS-cr  Artifact [a2ff58132d]

Artifact a2ff58132d04e4590684650e34b4c4d02beaeff79b160d688f01e459114e40e5:


require "socket"

require "./query"
require "./response"
require "./rr"
require "./settings"

# TODO: Cancel timer
# TODO: TCP

module AsyncDNS
  # Asynchronous DNS stub resolver that talks UDP to the name servers listed in
  # its `Settings`.
  #
  # Queries are started with `#resolve`; the outcome is delivered to the block
  # passed there. Sockets are opened lazily on the first query that needs them
  # and are closed by `#stop`.
  class Resolver
    # Size in bytes of the fixed DNS message header (ID, flags and four counts).
    private HEADER_SIZE = 12

    @v6_sock : UDPSocket?
    @v4_sock : UDPSocket?

    # Settings of this resolver. A copy is taken for every query started by
    # `#resolve`.
    getter settings : Settings

    # Failure reasons handed to the block given to `#resolve`.
    enum Error
      # The settings contain no name servers.
      NO_NAME_SERVER
      # The reply was shorter than a DNS header, was not flagged as a response
      # or carried a different opcode than the request.
      INVALID_REPLY
      # The reply had the TC (truncated) bit set.
      TRUNCATED
      # RCODE 1
      SERVER_INVALID_FORMAT
      # RCODE 2
      SERVER_FAILURE
      # RCODE 3
      SERVER_NAME_ERROR
      # RCODE 4
      SERVER_NOT_IMPLEMENTED
      # RCODE 5
      SERVER_REFUSED
      # Any other RCODE
      UNKNOWN
    end

    # State of one in-flight query, including the pre-serialized request
    # packet (`raw_data`) that is sent to each name server that is tried.
    private class Context
      getter query : Query
      getter id : UInt16
      getter settings : Settings
      getter block : Response | Error ->
      # Index into `settings.nameservers` of the server currently being tried.
      property ns_index : Int32
      property attempt : Int32
      # Address the request was last sent to; replies from any other address
      # are ignored.
      property used_ns : Socket::IPAddress | Nil
      getter raw_data : Bytes

      # Builds the context and serializes the request packet.
      #
      # Raises `ArgumentError` if a label of the queried domain is empty or
      # longer than 63 bytes, or if the packet would exceed 512 bytes.
      def initialize(@query : Query, @id : UInt16, @settings : Settings,
                     @block : Response | Error ->)
        @ns_index = 0
        @attempt = 0
        @used_ns = nil

        # Header

        io = IO::Memory.new(512)
        io.write_bytes(@id, IO::ByteFormat::BigEndian)
        # RD
        io.write_bytes(0x100_u16, IO::ByteFormat::BigEndian)
        # QDCOUNT
        io.write_bytes(1_u16, IO::ByteFormat::BigEndian)
        # ANCOUNT, NSCOUNT and ARCOUNT
        3.times { io.write_bytes(0_u16, IO::ByteFormat::BigEndian) }

        # Question

        # QNAME: a sequence of length-prefixed labels terminated by the
        # zero-length root label. A single trailing dot is dropped first so
        # that "example.com" and "example.com." produce the same wire format
        # and the terminator is always written exactly once.
        name = @query.domain.chomp('.')

        unless name.empty?
          name.split('.').each do |label|
            if label.empty?
              # An empty label would be written as a zero byte, which ends the
              # name prematurely.
              raise ArgumentError.new("Domain component must not be empty")
            end

            # The extra byte accounts for the length octet preceding the label.
            if label.bytesize > 63 || io.size + 1 + label.bytesize > 512
              raise ArgumentError.new("Domain component too long")
            end

            io.write_bytes(label.bytesize.to_u8)
            io << label
          end
        end

        # Root label
        io.write_bytes(0_u8)

        # QTYPE
        io.write_bytes(@query.rr_type.to_u16, IO::ByteFormat::BigEndian)
        # QCLASS
        io.write_bytes(@query.dns_class.to_u16, IO::ByteFormat::BigEndian)

        @raw_data = io.to_slice
      end
    end

    # Creates a resolver with default `Settings` and no open sockets.
    def initialize
      @settings = Settings.new
      @queries = Hash(UInt16, Context).new
      @tcp_queries = Hash(Socket, Context).new
    end

    # Starts resolving *query* and returns immediately.
    #
    # *block* is called with an `Error` when the query fails. If no name
    # server is configured it is called right away with
    # `Error::NO_NAME_SERVER`.
    #
    # Raises `ArgumentError` if the queried domain is longer than 253 bytes or
    # cannot be encoded (see `Context#initialize`).
    def resolve(query : Query, &block : Response | Error ->) : Nil
      if query.domain.bytesize > 253
        raise ArgumentError.new("Queried domain is too long")
      end

      if @settings.nameservers.empty?
        block.call(Error::NO_NAME_SERVER)
        return
      end

      # Pick a random ID that is not used by another in-flight query.
      # NOTE: `while true` rather than `loop do` so that `id` stays visible
      # after the loop.
      while true
        id = Random::Secure.rand(UInt16::MIN..UInt16::MAX).to_u16
        break unless @queries.has_key?(id)
      end

      send(Context.new(query.dup, id, settings.dup, block))
    end

    # Sends the request of *context* to the name server selected by
    # `context.ns_index` (port 53) and registers the context so the reply can
    # be matched by ID.
    #
    # If anything fails, the context is unregistered again before the
    # exception is re-raised, so the ID does not stay reserved forever.
    private def send(context : Context) : Nil
      # Register before sending so a reply can never arrive for an unknown ID.
      @queries[context.id] = context

      begin
        ns = context.settings.nameservers[context.ns_index]
        context.used_ns = used_ns = Socket::IPAddress.new(ns, 53)

        socket_for(used_ns.family).send(context.raw_data, used_ns)
      rescue ex
        @queries.delete(context.id)
        raise ex
      end
    end

    # Returns the UDP socket used for *family*, opening it on first use.
    #
    # Raises `ArgumentError` for families other than INET and INET6.
    private def socket_for(family : Socket::Family) : UDPSocket
      case family
      when Socket::Family::INET6
        @v6_sock ||= open_socket(Socket::Family::INET6, "::")
      when Socket::Family::INET
        @v4_sock ||= open_socket(Socket::Family::INET, "0.0.0.0")
      else
        raise ArgumentError.new("Nameserver must be INET or INET6")
      end
    end

    # Opens a UDP socket of the given *family*, binds it to *bind_address* on
    # an ephemeral port and starts its receive loop.
    #
    # The family is passed explicitly because `UDPSocket.new` defaults to
    # INET, which cannot be bound to an IPv6 address.
    private def open_socket(family : Socket::Family,
                            bind_address : String) : UDPSocket
      sock = UDPSocket.new(family)

      begin
        sock.bind bind_address, 0
      rescue ex
        # Do not leak the descriptor of a socket that could not be bound.
        sock.close
        raise ex
      end

      start_recv_loop(sock)
      sock
    end

    # Spawns a fiber that receives datagrams on *sock* and passes each one to
    # `#handle_packet` until the socket is closed.
    def start_recv_loop(sock : UDPSocket) : Nil
      spawn do
        buf = Bytes.new(0x10000)

        until sock.closed?
          begin
            len, addr = sock.receive(buf)
          rescue ex : IO::Error
            # Closing the socket (see `#stop`) interrupts a pending receive;
            # that is the regular way for this loop to end.
            break if sock.closed? && ex.os_error.nil?
            raise ex
          end
          handle_packet(buf[0, len], addr)
        end
      end
    end

    # Processes one received datagram.
    #
    # Packets that are too short to carry an ID, whose ID matches no in-flight
    # query, or that come from an address other than the one the query was
    # sent to are ignored. Otherwise the query is taken out of the in-flight
    # table and either retried on the next name server or reported to its
    # block.
    def handle_packet(packet : Bytes, sender : Socket::IPAddress) : Nil
      # Without a complete ID the packet cannot be attributed to any query.
      return if packet.size < 2

      io = IO::Memory.new(packet, writeable: false)
      id = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)

      context = @queries[id]?
      return if context.nil?
      return if sender != context.used_ns

      @queries.delete(id)

      # From here on the query is no longer registered, so every path must
      # either re-send it or notify its block; silently dropping a short
      # packet would leave the caller waiting forever.
      if packet.size < HEADER_SIZE
        context.block.call(Error::INVALID_REPLY)
        return
      end

      flags = io.read_bytes(UInt8)
      if flags & 0x80 == 0 ||                       # QR
         flags & 0x78 != context.raw_data[2] & 0x78 # Opcode
        context.block.call(Error::INVALID_REPLY)
        return
      end

      if flags & 0x02 != 0 # TC
        # TODO: Switch to TCP
        context.block.call(Error::TRUNCATED)
        return
      end

      case io.read_bytes(UInt8) & 0x0F # RCODE
      when 0
        error = nil
        try_next_ns = false
      when 1
        error = Error::SERVER_INVALID_FORMAT
        try_next_ns = false
      when 2
        error = Error::SERVER_FAILURE
        try_next_ns = true
      when 3
        error = Error::SERVER_NAME_ERROR
        try_next_ns = false
      when 4
        error = Error::SERVER_NOT_IMPLEMENTED
        try_next_ns = true
      when 5
        error = Error::SERVER_REFUSED
        try_next_ns = true
      else
        error = Error::UNKNOWN
        try_next_ns = true
      end

      # Failures that may be specific to one server are retried on the next
      # configured server, if there is one left.
      if try_next_ns &&
         context.ns_index + 1 < context.settings.nameservers.size
        context.ns_index += 1
        send(context)
        return
      end

      if error
        context.block.call(error)
        return
      end

      qdcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      ancount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      nscount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)
      arcount = io.read_bytes(UInt16, IO::ByteFormat::BigEndian)

      # TODO: Parse the question and the resource records and pass a Response
      # to context.block. Until then a successful reply is consumed without
      # notifying the caller.
    end

    # Closes the sockets, which ends their receive loops.
    #
    # The sockets are forgotten afterwards so that a later `#resolve` opens
    # fresh ones instead of sending on a closed socket.
    def stop : Nil
      @v6_sock.try &.close
      @v6_sock = nil

      @v4_sock.try &.close
      @v4_sock = nil
    end
  end
end