123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- #!/usr/bin/env python2.7
- # Copyright 2015 gRPC authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Starts a local DNS server for use in tests"""
- import argparse
- import os
- import platform
- import signal
- import sys
- import threading
- import time
- import twisted
- import twisted.internet
- import twisted.internet.defer
- import twisted.internet.protocol
- import twisted.internet.reactor
- import twisted.internet.threads
- import twisted.names
- from twisted.names import authority
- from twisted.names import client
- from twisted.names import common
- from twisted.names import dns
- from twisted.names import server
- import twisted.names.client
- import twisted.names.dns
- import twisted.names.server
- import yaml
# Name of the health-check A record that start_local_dns_server always adds.
_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp' # missing end '.' for twisted syntax
# IPv4 address served as the data of the health-check A record.
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
class NoFileAuthority(authority.FileAuthority):
    """FileAuthority whose SOA and records are handed in by the caller."""

    def __init__(self, soa, records):
        """Store the given SOA and record mapping on the instance.

        Args:
            soa: Value stored as ``self.soa``.
            records: Value stored as ``self.records``.
        """
        # Skip FileAuthority.__init__ on purpose; only the resolver base
        # class is initialized.
        common.ResolverBase.__init__(self)
        self.soa, self.records = soa, records
def start_local_dns_server(args):
    """Load the test DNS records and serve them over TCP and UDP.

    Reads the YAML file at ``args.records_config_path``, builds A, AAAA, SRV
    and TXT records under the configured common zone, adds the optional
    ``args.add_a_record`` entry plus the health-check A record, then listens
    on ``args.port`` and runs the twisted reactor. The call to
    ``reactor.run()`` is the last statement, so this function does not return
    before the reactor stops.

    Args:
        args: Parsed command-line namespace providing
            ``records_config_path``, ``add_a_record`` and ``port``.
    """
    all_records = {}

    def _push_record(name, r):
        """Append record ``r`` to the list kept under the ASCII-encoded name."""
        name = name.encode('ascii')
        print('pushing record: |%s|' % name)
        all_records.setdefault(name, []).append(r)

    def _maybe_split_up_txt_data(name, txt_data, r_ttl):
        """Push one TXT record with its data cut into chunks of <= 255 bytes."""
        txt_data = txt_data.encode('ascii')
        txt_data_list = [
            txt_data[start:start + 255]
            for start in range(0, len(txt_data), 255)
        ]
        _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))

    with open(args.records_config_path) as config:
        # yaml.load() without an explicit Loader is deprecated in PyYAML 5.1
        # and rejected by PyYAML 6; safe_load only builds plain data objects,
        # which is all the records config is read for below.
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config['resolver_tests_common_zone_name']
    for group in test_records_config['resolver_component_tests']:
        for name, records in group['records'].items():
            for record in records:
                r_type = record['type']
                r_data = record['data']
                r_ttl = int(record['TTL'])
                record_full_name = '%s.%s' % (name, common_zone_name)
                assert record_full_name[-1] == '.', (
                    'record name must end with a dot: %r' % record_full_name)
                # Records are keyed without the trailing '.'.
                record_full_name = record_full_name[:-1]
                if r_type == 'A':
                    _push_record(record_full_name,
                                 dns.Record_A(r_data, ttl=r_ttl))
                elif r_type == 'AAAA':
                    _push_record(record_full_name,
                                 dns.Record_AAAA(r_data, ttl=r_ttl))
                elif r_type == 'SRV':
                    # SRV data is "<priority> <weight> <port> <target>".
                    p, w, port, target = r_data.split(' ')
                    p = int(p)
                    w = int(w)
                    port = int(port)
                    target_full_name = (
                        '%s.%s' % (target, common_zone_name)).encode('ascii')
                    _push_record(
                        record_full_name,
                        dns.Record_SRV(p, w, port, target_full_name,
                                       ttl=r_ttl))
                elif r_type == 'TXT':
                    _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
    # Add an optional IPv4 record if specified
    if args.add_a_record:
        extra_host, extra_host_ipv4 = args.add_a_record.split(':')
        _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
    # Server health check record
    _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
                 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
    soa_record = dns.Record_SOA(mname=common_zone_name.encode('ascii'))
    test_domain_com = NoFileAuthority(
        soa=(common_zone_name.encode('ascii'), soa_record),
        records=all_records,
    )
    # Named dns_server_factory so the local does not shadow the imported
    # ``server`` module name.
    dns_server_factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2)
    dns_server_factory.noisy = 2
    twisted.internet.reactor.listenTCP(args.port, dns_server_factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(dns_server_factory)
    dns_proto.noisy = 2
    twisted.internet.reactor.listenUDP(args.port, dns_proto)
    print('starting local dns server on 127.0.0.1:%s' % args.port)
    print('starting twisted.internet.reactor')
    twisted.internet.reactor.suggestThreadPoolSize(1)
    twisted.internet.reactor.run()
def _quit_on_signal(signum, _frame):
    """Signal handler: stop the reactor, flush stdout, exit with status 0.

    Args:
        signum: Number of the delivered signal; included in the log line.
        _frame: Second handler argument; unused.
    """
    message = 'Received SIGNAL %d. Quitting with exit code 0' % signum
    print(message)
    twisted.internet.reactor.stop()
    # Flush before exiting so the line above is not left in the buffer.
    sys.stdout.flush()
    raise SystemExit(0)
def flush_stdout_loop():
    """Flush stdout periodically, then send SIGTERM to this process.

    Performs 600 rounds of "flush stdout, sleep one second". Once the rounds
    are used up it prints a notice and signals the current process with
    SIGTERM.
    """
    flush_interval_secs = 1
    # Prevent zombies. Tests that use this server are short-lived.
    total_rounds = 60 * 10
    for _ in range(total_rounds):
        sys.stdout.flush()
        time.sleep(flush_interval_secs)
    print('Process timeout reached, or cancelled. Exitting 0.')
    own_pid = os.getpid()
    os.kill(own_pid, signal.SIGTERM)
def main():
    """Parse flags, install signal handlers and start the local DNS server.

    SIGTERM and SIGINT are routed to ``_quit_on_signal``. A daemon thread
    running ``flush_stdout_loop`` is started before
    ``start_local_dns_server`` is called with the parsed arguments.
    """
    argp = argparse.ArgumentParser(
        description='Local DNS Server for resolver tests')
    argp.add_argument('-p',
                      '--port',
                      default=None,
                      type=int,
                      help='Port for DNS server to listen on for TCP and UDP.')
    argp.add_argument(
        '-r',
        '--records_config_path',
        default=None,
        type=str,
        help=('Directory of resolver_test_record_groups.yaml file. '
              'Defaults to path needed when the test is invoked as part '
              'of run_tests.py.'))
    argp.add_argument(
        '--add_a_record',
        default=None,
        type=str,
        help=('Add an A record via the command line. Useful for when we '
              'need to serve a one-off A record that is under a '
              'different domain then the rest the records configured in '
              '--records_config_path (which all need to be under the '
              'same domain). Format: <name>:<ipv4 address>'))
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    output_flush_thread = threading.Thread(target=flush_stdout_loop)
    # Thread.setDaemon() is deprecated since Python 3.10; the ``daemon``
    # attribute is the supported spelling and also exists on Python 2.6+.
    output_flush_thread.daemon = True
    output_flush_thread.start()
    start_local_dns_server(args)
# Run main() only when this file is executed directly as a script.
if __name__ == '__main__':
    main()
|