AsyncDNS-cr  resolver.cr at [b6cfd2fe24]

File src/resolver.cr artifact d8a14eee35 part of check-in b6cfd2fe24


require "socket"

require "./query"
require "./response"
require "./rr"
require "./settings"

module AsyncDNS
  # Asynchronous DNS stub resolver: encodes queries and sends them over UDP to
  # the nameservers configured in `settings`.
  class Resolver
    # Lazily created UDP sockets, one per address family, shared by all queries.
    @v6_sock : UDPSocket?
    @v4_sock : UDPSocket?

    getter settings : Settings

    enum Error
      NO_NAME_SERVER
    end

    # Per-query state: the query, its transaction ID, a snapshot of the
    # resolver settings, the completion callback, retry bookkeeping
    # (`ns_index`, `attempt`, `used_ns`) and the encoded wire-format request.
    private class Context
      getter query : Query
      getter id : UInt16
      getter settings : Settings
      getter block : Response | Error ->
      property ns_index : Int32
      property attempt : Int32
      property used_ns : Socket::IPAddress | Nil
      getter raw_data : Bytes

      # Maximum DNS message size over UDP without EDNS (RFC 1035 §2.3.4).
      MAX_UDP_SIZE = 512

      # Encodes `query` into `raw_data` as a standard DNS query message with
      # the RD (recursion desired) flag set and a single question.
      #
      # Raises `ArgumentError` if the domain has an empty label (e.g.
      # "a..b"), a label longer than 63 bytes, or does not fit in a UDP
      # message.
      def initialize(@query : Query, @id : UInt16, @settings : Settings,
                     @block : Response | Error ->)
        @ns_index = 0
        @attempt = 0
        @used_ns = nil
        @raw_data = Bytes.new(MAX_UDP_SIZE)

        # Header (12 bytes)
        i = put_u16(0, @id)
        @raw_data[i] = 1 # flags byte 1: RD set; QR, opcode, AA, TC all zero
        i += 2           # flags byte 2 (RA, Z, RCODE) stays zero
        i = put_u16(i, 1) # QDCOUNT
        i += 6            # ANCOUNT, NSCOUNT and ARCOUNT stay zero

        # Question

        # QNAME: length-prefixed labels followed by the zero-length root
        # label. A trailing dot (FQDN form) is dropped so the root label is
        # written exactly once; "" and "." both encode the root name.
        name = @query.domain.chomp('.')
        unless name.empty?
          name.split('.').each do |component|
            raise ArgumentError.new("Empty domain component") if component.empty?

            # Needs 1 length byte + the label, plus room left for the root
            # label (1) and QTYPE/QCLASS (4).
            if component.bytesize > 63 || i + 1 + component.bytesize + 5 > MAX_UDP_SIZE
              raise ArgumentError.new("Domain component too long")
            end

            raw_component = component.to_slice
            @raw_data[i] = raw_component.bytesize.to_u8; i += 1
            @raw_data[i, raw_component.bytesize].copy_from(raw_component)
            i += raw_component.bytesize
          end
        end
        @raw_data[i] = 0; i += 1 # root label terminates QNAME

        # QTYPE and QCLASS
        i = put_u16(i, @query.rr_type.to_i)
        i = put_u16(i, @query.dns_class.to_i)

        # Only the encoded message is sent, not the zero padding behind it.
        @raw_data = @raw_data[0, i]
      end

      # Writes `value` as a big-endian 16-bit integer at `offset`; returns the
      # offset just past it.
      private def put_u16(offset : Int32, value : Int) : Int32
        @raw_data[offset] = (value >> 8).to_u8
        @raw_data[offset + 1] = (value & 0xFF).to_u8
        offset + 2
      end
    end

    def initialize
      @settings = Settings.new
      @queries = Hash(UInt16, Context).new
      @tcp_queries = Hash(Socket, Context).new
    end

    # Starts resolving `query` by sending it to the first configured
    # nameserver.
    #
    # If no nameservers are configured, `block` is called immediately with
    # `Error::NO_NAME_SERVER`. Otherwise the query and a copy of the current
    # settings are captured, so later changes to `settings` do not affect it.
    # NOTE(review): the code that receives replies and eventually invokes
    # `block` with a `Response` is not in this file.
    #
    # Raises `ArgumentError` if the domain is longer than 253 bytes or cannot
    # be encoded as a DNS name.
    def resolve(query : Query, &block : Response | Error ->) : Nil
      if query.domain.bytesize > 253
        raise ArgumentError.new("Queried domain is too long")
      end

      if @settings.nameservers.empty?
        block.call(Error::NO_NAME_SERVER)
        return
      end

      send(Context.new(query.dup, unused_id, settings.dup, block))
    end

    # Picks a random transaction ID that no in-flight query is using.
    # NOTE(review): this loops forever if all 65536 IDs are in flight.
    private def unused_id : UInt16
      id = Random::Secure.rand(UInt16::MIN..UInt16::MAX)
      while @queries.has_key?(id)
        id = Random::Secure.rand(UInt16::MIN..UInt16::MAX)
      end
      id
    end

    # Registers `context` under its ID and sends its request over UDP to
    # port 53 of nameserver number `context.ns_index`.
    #
    # If anything fails (invalid address, unsupported family, socket error),
    # the context is unregistered again and the exception is re-raised.
    private def send(context : Context) : Nil
      # Register before sending so any reply can be matched to this context.
      @queries[context.id] = context

      begin
        ns = context.settings.nameservers[context.ns_index]
        used_ns = Socket::IPAddress.new(ns, 53)
        context.used_ns = used_ns

        sock = case used_ns.family
               when Socket::Family::INET6
                 @v6_sock ||= open_socket(Socket::Family::INET6, "::")
               when Socket::Family::INET
                 @v4_sock ||= open_socket(Socket::Family::INET, "0.0.0.0")
               else
                 raise ArgumentError.new("Nameserver must be INET or INET6")
               end
        sock.send(context.raw_data, used_ns)
      rescue ex
        # The query could not be sent; stop reserving its transaction ID.
        @queries.delete(context.id)
        raise ex
      end
    end

    # Creates a UDP socket of `family` bound to `address` on an ephemeral
    # port. The socket is closed again if binding fails, so a half-initialised
    # socket is never cached.
    private def open_socket(family : Socket::Family, address : String) : UDPSocket
      sock = UDPSocket.new(family)
      begin
        sock.bind(address, 0)
      rescue ex
        sock.close
        raise ex
      end
      sock
    end

    # TODO: not implemented yet; currently does nothing (sockets stay open
    # and pending queries are kept).
    def stop : Nil
      # TODO
    end
  end
end