From 0b1c833ed2f70394b8dab48939fa54a1282db956 Mon Sep 17 00:00:00 2001 From: Achille Date: Fri, 4 Dec 2020 11:33:51 -0800 Subject: [PATCH] 0.4 (#438) * add protocol package * add some documentation * fix * make ByteSequence more generic + add more benchmarks * WIP: add support for record batches * finish support for record batches * add support for recort set compression * backward-compatible compression codec imports * fix compress tests * make it possible for the transport to connect to multiple clusters + enhance kafka.Client to expose methods for creating and deleting topics * support responding to metadata requests with cached response * manage proper shutdown of client transport in tests * WIP: test Produce API * WIP: massive cleanup + track down CRC32 validation issue * functional Produce and Fetch implementations * add metadata request/response * add listoffsets API * expose cluster id and controller in metadata response * remove bufio.Writer from the protocol API * remove bufio.Reader from the protocol API * add back deprecated Client methods * fixes for kafka 0.10 * cleanup comment in protocol/record.go * add more comments * reduce size of bufio.Reader buffer on kafka connections * refactor transport internals to support splitting requests and dispatching them across multiple brokers * avoid contention on connection pool mutex in most cases * cleanup * add kafka.(*Client).MultiFetch API * close records in produce request * refactor record batch APIs to fully support streaming * remove io.Closer from protocol.RecordBatch * never return nil record batches * record batch fixes * remove unused variable * fix reading of multiple topic partitions in produce and fetch messages * alias compress.Compression in the kafka package * expose compression constants in the kafka package * exposes kafka.Request and kafka.Response interfaces * simplify the protocol.Bytes interface * simplify error management in protocol package * wait for topic creation to propagate + fix request dispatching in multi-broker clusters * simplify kafka.(*Client).CreateTopics API * improve error handling + wait for metadata propagation after topic creation * revisit connection pool implementation to remove multiplexing * fix panic when referencing truncated page buffer * fix unexpected EOF errors reading kafka messages * revisit record reader API * fix panic type asserting nil response into *metadata.Response * optimize allocation of broker ids in cluster metadata * unify sync.Pool usage * reduce memory footprint of protocol.(*RecordSet).readFromVersion2 * fix panic accessing optimized record reader with a nil headers slice * add APIs for marshaling and unmarshaling kafka values * [skip ci] fix README example * investigate-multi-fetch-issues * remove MultiFetch API * simplify protocol tests * add benchmarks for kafka.Marshal and kafka.Unmarshal * fix crash on cluster layout changes * add more error codes * remove partial support for flexible message format * downgrade metadata test from v9 to v8 * test against kafka 2.5.0 * Update offsetfetch.go Co-authored-by: Jeremy Jackins * Update offsetfetch.go Co-authored-by: Jeremy Jackins * Update offsetfetch.go Co-authored-by: Jeremy Jackins * fix typos * fix more typos * set pprof labels on transport goroutines (#458) * change tests to run against 2.4.1 instead of 2.5.0 * support up to 2.3.1 (TestConn/nettest/PingPong fails with 2.4 and above) * Update README.md Co-authored-by: Steve van Loben Sels * Update client.go Co-authored-by: Steve van Loben Sels * comment on why we 
devide the timeout by 2 * protocol.Reducer => protocol.Merger * cleanup docker-compose.yml * protocol.Mapper => protocol.Splitter * propagate the caller's context to the dial function (#460) * fix backward compatiblity with kafka-go v0.3.x * fix record offsets when fetching messages with version 1 * default record timestamps to current timestamp * revert changes to docker-compose.yml * fix tests * fix tests (2) * 0.4: kafka.Writer (#461) * 0.4: kafka.Writer * update README * disable some parallel tests * disable global parallelism in tests * fix typo * disable parallelism in sub-packages tests * properly seed random sources + delete test topics * cleanup build * run all tests * fix tests * enable more SASL mechanisms on CI * try to fix the CI config * try testing the sasl package with 2.3.1 only * inline configuration for kafka 2.3.1 in CI * fix zookeeper hostname in CI * cleanup CI config * keep the kafka 0.10 configuration separate + test against more kafka versions * fix kafka 0.11 image tag * try caching dependencies * support multiple broker addresses * uncomment max attempt test * fix typos * guard against empty kafka.MultiAddr in kafka.Transport * don't export new APIs for network addresses + adapt to any multi-addr implementation * add comment about the transport caching the metadata responses * 0.4 fix tls address panic (#478) * 0.4: fix panic when TLS is enabled * 0.4: fix panic when establishing TLS connections * cleanup * Update transport_test.go Co-authored-by: Steve van Loben Sels * validate that an error is returned Co-authored-by: Steve van Loben Sels * 0.4: fix short writes (#479) * 0.4: modify protocol.Bytes to expose the number of remaining bytes instead of the full size of the sequence (#485) * modify protocol.Bytes to expose the number of remaining bytes instead of the full size of the sequence * add test for pageRef.ReadByte + fix pageRef.scan * reuse contiguousPages.scan * fix(writer): set correct balancer (#489) Sets the correct balancer as passed through in the config on the writer Co-authored-by: Steve van Loben Sels Co-authored-by: Artur * Fix for panic when RequiredAcks is set to RequireNone (#504) * Fix panic in async wait() method when RequiredAcks is None When RequiredAcks is None, the producer does not wait for a response from the broker, therefore the response is nil. The async wait() method was not handling this case, leading to a panic. * Add regression test for RequiredAcks == RequireNone This new test is required because all the other Writer tests use NewWriter() to create Writers, which sets RequiredAcks to RequireAll when 0 (None) was specified. * fix: writer test for RequiredAcks=None * fix: writer tests for RequiredAcks=None (2) * 0.4 broker resolver (#526) * 0.4: kafka.BrokerResolver * add kafka.Transport.Context * inline network and address fields in conn type * Fix sasl authentication on writer (#541) The authenticateSASL was called before getting api version. 
This resulted incorrect apiversion (0 instead of 1) when calling saslHandshakeRoundTrip request * Remove deprecated function (NewWriter) usages (#528) * fix zstd decoder leak (#543) * fix zstd decoder leak * fix tests * fix panic * fix tests (2) * fix tests (3) * fix tests (4) * move ConnWaitGroup to testing package * fix zstd codec * Update compress/zstd/zstd.go Co-authored-by: Nicholas Sun * PR feedback Co-authored-by: Nicholas Sun * improve custom resolver support by allowing port to be overridden (#545) * 0.4: reduce memory footprint (#547) * Bring over flexible message changes * Add docker-compose config for kafka 2.4.1 * Misc. cleanups * Add protocol tests and fix issues * Misc. fixes; run circleci on v2.4.1 * Skip conntest for v2.4.1 * Disable nettests for kafka 2.4.1 * Revert formatting changes * Misc. fixes * Update comments * Make create topics test more interesting * feat(writer): add support for writing messages to multiple topics (#561) * Add comments on failing nettests * Fix spacing * Update var int sizing * Simplify writeVarInt implementation * Revert encoding change * Simplify varint encoding functions and expand tests * Also test sizeOf functions in protocol test * chore: merge master and resolve conflicts (#570) Co-authored-by: Jeremy Jackins Co-authored-by: Steve van Loben Sels Co-authored-by: Artur Co-authored-by: Neil Cook Co-authored-by: Ahmy Yulrizka Co-authored-by: Turfa Auliarachman Co-authored-by: Nicholas Sun Co-authored-by: Dominic Barnes Co-authored-by: Benjamin Yolken Co-authored-by: Benjamin Yolken <54862872+yolken-segment@users.noreply.github.com> --- .circleci/config.yml | 298 ++-- README.md | 97 +- address.go | 71 + address_test.go | 55 + balancer.go | 19 +- batch.go | 2 +- client.go | 207 ++- client_test.go | 231 ++- compress/compress.go | 74 + .../compress_test.go | 210 ++- compress/gzip/gzip.go | 122 ++ compress/lz4/lz4.go | 68 + compress/snappy/snappy.go | 89 ++ {snappy => compress/snappy}/xerial.go | 0 {snappy => compress/snappy}/xerial_test.go | 0 compress/zstd/zstd.go | 180 +++ compression.go | 59 +- conn.go | 21 +- conn_test.go | 21 +- consumergroup_test.go | 17 +- crc32_test.go | 2 - createtopics.go | 141 +- createtopics_test.go | 78 + deletetopics.go | 61 + deletetopics_test.go | 21 + dialer.go | 66 +- dialer_test.go | 137 +- docker-compose-241.yml | 32 + docker-compose.010.yml | 29 + docker-compose.yml | 17 +- error.go | 97 +- error_test.go | 2 - example_writer_test.go | 8 +- examples/producer-api/go.mod | 4 +- examples/producer-api/go.sum | 25 +- examples/producer-api/main.go | 6 +- examples/producer-random/go.mod | 4 +- examples/producer-random/go.sum | 25 +- examples/producer-random/main.go | 6 +- export_test.go | 9 - fetch.go | 153 ++ fetch_test.go | 302 ++++ go.mod | 2 +- gzip/gzip.go | 124 +- kafka.go | 92 ++ kafka_test.go | 173 +++ listoffset.go | 183 ++- listoffset_test.go | 76 + lz4/lz4.go | 68 +- message.go | 14 +- metadata.go | 107 ++ metadata_test.go | 75 + offsetfetch.go | 111 ++ produce.go | 172 ++- produce_test.go | 102 ++ protocol.go | 22 +- protocol/apiversions/apiversions.go | 27 + protocol/buffer.go | 645 ++++++++ protocol/buffer_test.go | 108 ++ protocol/cluster.go | 143 ++ protocol/conn.go | 96 ++ protocol/createtopics/createtopics.go | 74 + protocol/decode.go | 545 +++++++ protocol/deletetopics/deletetopics.go | 34 + protocol/encode.go | 639 ++++++++ protocol/error.go | 91 ++ protocol/fetch/fetch.go | 126 ++ protocol/fetch/fetch_test.go | 147 ++ protocol/findcoordinator/findcoordinator.go | 25 + 
protocol/listoffsets/listoffsets.go | 230 +++ protocol/listoffsets/listoffsets_test.go | 104 ++ protocol/metadata/metadata.go | 52 + protocol/metadata/metadata_test.go | 199 +++ protocol/offsetfetch/offsetfetch.go | 46 + protocol/produce/produce.go | 147 ++ protocol/produce/produce_test.go | 273 ++++ protocol/protocol.go | 480 ++++++ protocol/protocol_test.go | 281 ++++ protocol/prototest/bytes.go | 15 + protocol/prototest/prototest.go | 188 +++ protocol/prototest/reflect.go | 142 ++ protocol/prototest/request.go | 99 ++ protocol/prototest/response.go | 95 ++ protocol/record.go | 314 ++++ protocol/record_batch.go | 369 +++++ protocol/record_batch_test.go | 200 +++ protocol/record_v1.go | 243 +++ protocol/record_v2.go | 315 ++++ protocol/reflect.go | 101 ++ protocol/reflect_unsafe.go | 138 ++ protocol/request.go | 128 ++ protocol/response.go | 101 ++ protocol/roundtrip.go | 28 + protocol/saslauthenticate/saslauthenticate.go | 22 + protocol/saslhandshake/saslhandshake.go | 20 + protocol/size.go | 93 ++ protocol_test.go | 2 - reader_test.go | 131 +- record.go | 42 + resolver.go | 62 + sasl/sasl_test.go | 3 - snappy/snappy.go | 88 +- testing/conn.go | 32 + testing/version.go | 2 +- testing/version_test.go | 2 +- time.go | 13 +- transport.go | 1298 ++++++++++++++++ transport_test.go | 34 + write_test.go | 14 +- writer.go | 1312 ++++++++++------- writer_test.go | 249 +++- zstd/zstd.go | 132 +- 112 files changed, 13404 insertions(+), 1522 deletions(-) create mode 100644 address.go create mode 100644 address_test.go create mode 100644 compress/compress.go rename compression_test.go => compress/compress_test.go (61%) create mode 100644 compress/gzip/gzip.go create mode 100644 compress/lz4/lz4.go create mode 100644 compress/snappy/snappy.go rename {snappy => compress/snappy}/xerial.go (100%) rename {snappy => compress/snappy}/xerial_test.go (100%) create mode 100644 compress/zstd/zstd.go create mode 100644 docker-compose-241.yml create mode 100644 docker-compose.010.yml delete mode 100644 export_test.go create mode 100644 fetch_test.go create mode 100644 kafka.go create mode 100644 kafka_test.go create mode 100644 listoffset_test.go create mode 100644 metadata_test.go create mode 100644 produce_test.go create mode 100644 protocol/apiversions/apiversions.go create mode 100644 protocol/buffer.go create mode 100644 protocol/buffer_test.go create mode 100644 protocol/cluster.go create mode 100644 protocol/conn.go create mode 100644 protocol/createtopics/createtopics.go create mode 100644 protocol/decode.go create mode 100644 protocol/deletetopics/deletetopics.go create mode 100644 protocol/encode.go create mode 100644 protocol/error.go create mode 100644 protocol/fetch/fetch.go create mode 100644 protocol/fetch/fetch_test.go create mode 100644 protocol/findcoordinator/findcoordinator.go create mode 100644 protocol/listoffsets/listoffsets.go create mode 100644 protocol/listoffsets/listoffsets_test.go create mode 100644 protocol/metadata/metadata.go create mode 100644 protocol/metadata/metadata_test.go create mode 100644 protocol/offsetfetch/offsetfetch.go create mode 100644 protocol/produce/produce.go create mode 100644 protocol/produce/produce_test.go create mode 100644 protocol/protocol.go create mode 100644 protocol/protocol_test.go create mode 100644 protocol/prototest/bytes.go create mode 100644 protocol/prototest/prototest.go create mode 100644 protocol/prototest/reflect.go create mode 100644 protocol/prototest/request.go create mode 100644 protocol/prototest/response.go create mode 100644 
protocol/record.go create mode 100644 protocol/record_batch.go create mode 100644 protocol/record_batch_test.go create mode 100644 protocol/record_v1.go create mode 100644 protocol/record_v2.go create mode 100644 protocol/reflect.go create mode 100644 protocol/reflect_unsafe.go create mode 100644 protocol/request.go create mode 100644 protocol/response.go create mode 100644 protocol/roundtrip.go create mode 100644 protocol/saslauthenticate/saslauthenticate.go create mode 100644 protocol/saslhandshake/saslhandshake.go create mode 100644 protocol/size.go create mode 100644 record.go create mode 100644 resolver.go create mode 100644 testing/conn.go create mode 100644 transport.go create mode 100644 transport_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 5ec8e2bbd..54029c89b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,133 +1,213 @@ version: 2 jobs: + # The kafka 0.10 tests are maintained as a separate configuration because + # kafka only supported plain text SASL in this version. kafka-010: - working_directory: /go/src/github.com/segmentio/kafka-go + working_directory: &working_directory /go/src/github.com/segmentio/kafka-go environment: KAFKA_VERSION: "0.10.1" docker: - - image: circleci/golang - - image: wurstmeister/zookeeper - ports: ['2181:2181'] - - image: wurstmeister/kafka:0.10.1.1 - ports: ['9092:9092'] - environment: - KAFKA_BROKER_ID: '1' - KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' - KAFKA_DELETE_TOPIC_ENABLE: 'true' - KAFKA_ADVERTISED_HOST_NAME: 'localhost' - KAFKA_ADVERTISED_PORT: '9092' - KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' - KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' - KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' - KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' - KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN' - KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" - CUSTOM_INIT_SCRIPT: |- - echo -e 'KafkaServer {\norg.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; - steps: - - checkout - - setup_remote_docker: { reusable: true, docker_layer_caching: true } - - run: go get -v -t . ./gzip ./lz4 ./sasl ./snappy - - run: go test -v -race -cover -timeout 150s . 
./gzip ./lz4 ./sasl ./snappy + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:0.10.1.1 + ports: + - 9092:9092 + - 9093:9093 + environment: + KAFKA_BROKER_ID: '1' + KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' + KAFKA_DELETE_TOPIC_ENABLE: 'true' + KAFKA_ADVERTISED_HOST_NAME: 'localhost' + KAFKA_ADVERTISED_PORT: '9092' + KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' + KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' + KAFKA_MESSAGE_MAX_BYTES: '200000000' + KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' + KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' + KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN' + KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" + CUSTOM_INIT_SCRIPT: |- + echo -e 'KafkaServer {\norg.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; + steps: &steps + - checkout + - restore_cache: + key: kafka-go-mod-{{ checksum "go.sum" }}-1 + - run: go mod download + - save_cache: + key: kafka-go-mod-{{ checksum "go.sum" }}-1 + paths: + - /go/pkg/mod + - run: go test -race -cover ./... + # Starting at version 0.11, the kafka features and configuration remained + # mostly stable, so we can use this CI job configuration as template for other + # versions as well. kafka-011: - working_directory: /go/src/github.com/segmentio/kafka-go + working_directory: *working_directory environment: KAFKA_VERSION: "0.11.0" docker: - - image: circleci/golang - - image: wurstmeister/zookeeper - ports: ['2181:2181'] - - image: wurstmeister/kafka:2.11-0.11.0.3 - ports: ['9092:9092','9093:9093'] - environment: - KAFKA_BROKER_ID: '1' - KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' - KAFKA_DELETE_TOPIC_ENABLE: 'true' - KAFKA_ADVERTISED_HOST_NAME: 'localhost' - KAFKA_ADVERTISED_PORT: '9092' - KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' - KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' - KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' - KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' - KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN,SCRAM-SHA-256,SCRAM-SHA-512' - KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" - CUSTOM_INIT_SCRIPT: |- - echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; - /opt/kafka/bin/kafka-configs.sh --zookeeper localhost:2181 --alter --add-config 'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram - steps: - - checkout - - setup_remote_docker: { reusable: true, docker_layer_caching: true } - - run: go get -v -t . ./gzip ./lz4 ./sasl ./snappy - - run: go test -v -race -cover -timeout 150s . 
./gzip ./lz4 ./sasl ./snappy + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.11-0.11.0.3 + ports: + - 9092:9092 + - 9093:9093 + environment: &environment + KAFKA_BROKER_ID: '1' + KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' + KAFKA_DELETE_TOPIC_ENABLE: 'true' + KAFKA_ADVERTISED_HOST_NAME: 'localhost' + KAFKA_ADVERTISED_PORT: '9092' + KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' + KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' + KAFKA_MESSAGE_MAX_BYTES: '200000000' + KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' + KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' + KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN,SCRAM-SHA-256,SCRAM-SHA-512' + KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" + CUSTOM_INIT_SCRIPT: |- + echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; + /opt/kafka/bin/kafka-configs.sh --zookeeper localhost:2181 --alter --add-config 'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram + steps: *steps + + kafka-101: + working_directory: *working_directory + environment: + KAFKA_VERSION: "1.0.1" + docker: + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.11-1.0.1 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps kafka-111: - working_directory: /go/src/github.com/segmentio/kafka-go + working_directory: *working_directory environment: KAFKA_VERSION: "1.1.1" docker: - - image: circleci/golang - - image: wurstmeister/zookeeper - ports: ['2181:2181'] - - image: wurstmeister/kafka:2.11-1.1.1 - ports: ['9092:9092','9093:9093'] - environment: - KAFKA_BROKER_ID: '1' - KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' - KAFKA_DELETE_TOPIC_ENABLE: 'true' - KAFKA_ADVERTISED_HOST_NAME: 'localhost' - KAFKA_ADVERTISED_PORT: '9092' - KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' - KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' - KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' - KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' - KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN,SCRAM-SHA-256,SCRAM-SHA-512' - KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" - CUSTOM_INIT_SCRIPT: |- - echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; - /opt/kafka/bin/kafka-configs.sh --zookeeper localhost:2181 --alter --add-config 'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram - steps: - - checkout - - setup_remote_docker: { reusable: true, docker_layer_caching: true } - - run: go get -v -t . ./gzip ./lz4 ./sasl ./snappy - - run: go test -v -race -cover -timeout 150s . 
./gzip ./lz4 ./sasl ./snappy + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.11-1.1.1 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps + + kafka-201: + working_directory: *working_directory + environment: + KAFKA_VERSION: "2.0.1" + docker: + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.12-2.0.1 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps + + kafka-211: + working_directory: *working_directory + environment: + KAFKA_VERSION: "2.1.1" + docker: + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.12-2.1.1 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps + + kafka-222: + working_directory: *working_directory + environment: + KAFKA_VERSION: "2.2.2" + docker: + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.12-2.2.2 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps - kafka-210: - working_directory: /go/src/github.com/segmentio/kafka-go + kafka-231: + working_directory: *working_directory environment: - KAFKA_VERSION: "2.1.0" + KAFKA_VERSION: "2.3.1" + docker: + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.12-2.3.1 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps + + kafka-241: + working_directory: *working_directory + environment: + KAFKA_VERSION: "2.4.1" + + # Need to skip nettest to avoid these kinds of errors: + # --- FAIL: TestConn/nettest (17.56s) + # --- FAIL: TestConn/nettest/PingPong (7.40s) + # conntest.go:112: unexpected Read error: [7] Request Timed Out: the request exceeded the user-specified time limit in the request + # conntest.go:118: mismatching value: got 77, want 78 + # conntest.go:118: mismatching value: got 78, want 79 + # ... + # + # TODO: Figure out why these are happening and fix them (they don't appear to be new). 
+ KAFKA_SKIP_NETTEST: "1" docker: - - image: circleci/golang - - image: wurstmeister/zookeeper - ports: ['2181:2181'] - - image: wurstmeister/kafka:2.12-2.1.0 - ports: ['9092:9092','9093:9093'] - environment: - KAFKA_BROKER_ID: '1' - KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' - KAFKA_DELETE_TOPIC_ENABLE: 'true' - KAFKA_ADVERTISED_HOST_NAME: 'localhost' - KAFKA_ADVERTISED_PORT: '9092' - KAFKA_ZOOKEEPER_CONNECT: 'localhost:2181' - KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' - KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' - KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' - KAFKA_SASL_ENABLED_MECHANISMS: SCRAM-SHA-256,SCRAM-SHA-512,PLAIN - KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" - CUSTOM_INIT_SCRIPT: |- - echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; - /opt/kafka/bin/kafka-configs.sh --zookeeper localhost:2181 --alter --add-config 'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram - steps: - - checkout - - setup_remote_docker: { reusable: true, docker_layer_caching: true } - - run: go get -v -t . ./gzip ./lz4 ./sasl ./snappy - - run: go test -v -race -cover -timeout 150s $(go list ./... | grep -v examples) + - image: circleci/golang + - image: wurstmeister/zookeeper + ports: + - 2181:2181 + - image: wurstmeister/kafka:2.12-2.4.1 + ports: + - 9092:9092 + - 9093:9093 + environment: *environment + steps: *steps workflows: version: 2 run: jobs: - - kafka-010 - - kafka-011 - - kafka-111 - - kafka-210 + - kafka-010 + - kafka-011 + - kafka-101 + - kafka-111 + - kafka-201 + - kafka-211 + - kafka-222 + - kafka-231 + - kafka-241 diff --git a/README.md b/README.md index 62e1623a8..c57480641 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,35 @@ APIs for interacting with Kafka, mirroring concepts and implementing interfaces the Go standard library to make it easy to use and integrate with existing software. +## Migrating to 0.4 + +Version 0.4 introduces a few breaking changes to the repository structure which +should have minimal impact on programs and should only manifest at compile time +(the runtime behavior should remain unchanged). + +* Programs do not need to import compression packages anymore in order to read +compressed messages from kafka. All compression codecs are supported by default. + +* Programs that used the compression codecs directly must be adapted. +Compression codecs are now exposed in the `compress` sub-package. + +* The experimental `kafka.Client` API has been updated and slightly modified: +the `kafka.NewClient` function and `kafka.ClientConfig` type were removed. +Programs now configure the client values directly through exported fields. + +* The `kafka.(*Client).ConsumerOffsets` method is now deprecated (along with the +`kafka.TopicAndGroup` type, and will be removed when we release version 1.0. +Programs should use the `kafka.(*Client).OffsetFetch` API instead. 
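+
+As a reference, here is a minimal sketch of the new shape of the API (the broker
+address, group id, topic name, and partition numbers below are placeholders): the
+client is configured directly through its exported fields, and committed offsets
+are read through `OffsetFetch`:
+
+```go
+client := &kafka.Client{
+	Addr:    kafka.TCP("localhost:9092"),
+	Timeout: 10 * time.Second,
+}
+
+res, err := client.OffsetFetch(context.Background(), &kafka.OffsetFetchRequest{
+	GroupID: "my-group",
+	Topics:  map[string][]int{"my-topic": {0, 1, 2}},
+})
+if err != nil {
+	log.Fatal(err)
+}
+
+for _, p := range res.Topics["my-topic"] {
+	log.Printf("partition %d: committed offset %d", p.Partition, p.CommittedOffset)
+}
+```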
+ +With 0.4, we know that we are starting to introduce a bit more complexity in the +code, but the plan is to eventually converge towards a simpler and more effective +API, allowing us to keep up with Kafka's ever-growing feature set, and bringing +a more efficient implementation to programs depending on kafka-go. + +We truly appreciate everyone's input and contributions, which have made this +project way more than what it was when we started it, and we're looking forward +to receiving more feedback on where we should take it. + ## Kafka versions `kafka-go` is currently compatible with Kafka versions from 0.10.1.0 to 2.1.0. While latest versions will be working, @@ -318,11 +347,11 @@ to use in most cases as it provides additional features: ```go // make a writer that produces to topic-A, using the least-bytes distribution -w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"localhost:9092"}, +w := &kafka.Writer{ + Addr: kafka.TCP("localhost:9092"), Topic: "topic-A", Balancer: &kafka.LeastBytes{}, -}) +} err := w.WriteMessages(context.Background(), kafka.Message{ @@ -350,6 +379,28 @@ if err := w.Close(); err != nil { **Note:** Even though kafka.Message contains ```Topic``` and ```Partition``` fields, they **MUST NOT** be set when writing messages. They are intended for read use only. +### Writing to multiple topics + +Each writer is bound to a single topic; to write to multiple topics, a program +must create multiple writers. + +We considered making the `kafka.Writer` type interpret the `Topic` field of +`kafka.Message` values; batching messages per partition or per topic/partition pair +would not introduce much more complexity. However, supporting this means we would +also have to report stats broken down by topic. It would also raise the question +of how we would manage configuration specific to each topic (e.g. different +compression algorithms). Overall, the amount of coupling between the various +properties of `kafka.Writer` suggests that we should not support publishing to +multiple topics from a single writer, and instead encourage the program to +create a writer for each topic it needs to publish to. + +The split of connection management into the `kafka.Transport` in kafka-go 0.4 has +made writers really cheap to create as they barely manage any state; programs can +construct new writers configured to publish to kafka topics when needed. + +Adding new APIs to facilitate the management of writer sets is an option, and we +would welcome contributions in this area. + ### Compatibility with other clients #### Sarama @@ -359,11 +410,11 @@ partitioning, you can use the ```kafka.Hash``` balancer. ```kafka.Hash``` routes messages to the same partitions that Sarama's default partitioner would route to. ```go -w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"localhost:9092"}, +w := &kafka.Writer{ + Addr: kafka.TCP("localhost:9092"), Topic: "topic-A", Balancer: &kafka.Hash{}, -}) +} ``` #### librdkafka and confluent-kafka-go @@ -372,11 +423,11 @@ Use the ```kafka.CRC32Balancer``` balancer to get the same behaviour as librdkafka's default ```consistent_random``` partition strategy. ```go -w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"localhost:9092"}, +w := &kafka.Writer{ + Addr: kafka.TCP("localhost:9092"), Topic: "topic-A", Balancer: kafka.CRC32Balancer{}, -}) +} ``` #### Java @@ -386,34 +437,32 @@ Java client's default partitioner. Note: the Java class allows you to directly specify the partition which is not permitted.
```go -w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"localhost:9092"}, +w := &kafka.Writer{ + Addr: kafka.TCP("localhost:9092"), Topic: "topic-A", Balancer: kafka.Murmur2Balancer{}, -}) +} ``` ### Compression -Compression can be enabled on the `Writer` by configuring the `CompressionCodec`: +Compression can be enabled on the `Writer` by setting the `Compression` field: ```go -w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"localhost:9092"}, - Topic: "topic-A", - CompressionCodec: snappy.NewCompressionCodec(), -}) +w := &kafka.Writer{ + Addr: kafka.TCP("localhost:9092"), + Topic: "topic-A", + Compression: kafka.Snappy, +} ``` The `Reader` will determine if the consumed messages are compressed by examining the message attributes. However, the package(s) for all expected -codecs must be imported so that they get loaded correctly. For example, if you -are going to be receiving messages compressed with Snappy, add the following -import: +codecs must be imported so that they get loaded correctly. -```go -import _ "github.com/segmentio/kafka-go/snappy" -``` +_Note: in versions prior to 0.4 programs had to import compression packages to +install codecs and support reading compressed messages from kafka. This is no +longer the case and imports of the compression packages are now no-ops._ ## TLS Support diff --git a/address.go b/address.go new file mode 100644 index 000000000..c7d8efd9b --- /dev/null +++ b/address.go @@ -0,0 +1,71 @@ +package kafka + +import ( + "net" + "strings" ) + +// TCP constructs an address with the network set to "tcp". +func TCP(address ...string) net.Addr { return makeNetAddr("tcp", address) } + +func makeNetAddr(network string, addresses []string) net.Addr { + switch len(addresses) { + case 0: + return nil // maybe panic instead?
+ case 1: + return makeAddr(network, addresses[0]) + default: + return makeMultiAddr(network, addresses) + } +} + +func makeAddr(network, address string) net.Addr { + host, port, _ := net.SplitHostPort(address) + if port == "" { + port = "9092" + } + if host == "" { + host = address + } + return &networkAddress{ + network: network, + address: net.JoinHostPort(host, port), + } +} + +func makeMultiAddr(network string, addresses []string) net.Addr { + multi := make(multiAddr, len(addresses)) + for i, address := range addresses { + multi[i] = makeAddr(network, address) + } + return multi +} + +type networkAddress struct { + network string + address string +} + +func (a *networkAddress) Network() string { return a.network } + +func (a *networkAddress) String() string { return a.address } + +type multiAddr []net.Addr + +func (m multiAddr) Network() string { return m.join(net.Addr.Network) } + +func (m multiAddr) String() string { return m.join(net.Addr.String) } + +func (m multiAddr) join(f func(net.Addr) string) string { + switch len(m) { + case 0: + return "" + case 1: + return f(m[0]) + } + s := make([]string, len(m)) + for i, a := range m { + s[i] = f(a) + } + return strings.Join(s, ",") +} diff --git a/address_test.go b/address_test.go new file mode 100644 index 000000000..18f0bde33 --- /dev/null +++ b/address_test.go @@ -0,0 +1,55 @@ +package kafka + +import ( + "net" + "testing" +) + +func TestNetworkAddress(t *testing.T) { + tests := []struct { + addr net.Addr + network string + address string + }{ + { + addr: TCP("127.0.0.1"), + network: "tcp", + address: "127.0.0.1:9092", + }, + + { + addr: TCP("::1"), + network: "tcp", + address: "[::1]:9092", + }, + + { + addr: TCP("localhost"), + network: "tcp", + address: "localhost:9092", + }, + + { + addr: TCP("localhost:9092"), + network: "tcp", + address: "localhost:9092", + }, + + { + addr: TCP("localhost", "localhost:9093", "localhost:9094"), + network: "tcp,tcp,tcp", + address: "localhost:9092,localhost:9093,localhost:9094", + }, + } + + for _, test := range tests { + t.Run(test.network+"+"+test.address, func(t *testing.T) { + if s := test.addr.Network(); s != test.network { + t.Errorf("network mismatch: want %q but got %q", test.network, s) + } + if s := test.addr.String(); s != test.address { + t.Errorf("network mismatch: want %q but got %q", test.address, s) + } + }) + } +} diff --git a/balancer.go b/balancer.go index 5ae236aa3..7a50cc1ce 100644 --- a/balancer.go +++ b/balancer.go @@ -7,15 +7,14 @@ import ( "math/rand" "sort" "sync" + "sync/atomic" ) // The Balancer interface provides an abstraction of the message distribution // logic used by Writer instances to route messages to the partitions available // on a kafka cluster. // -// Instances of Balancer do not have to be safe to use concurrently by multiple -// goroutines, the Writer implementation ensures that calls to Balance are -// synchronized. +// Balancers must be safe to use concurrently from multiple goroutines. type Balancer interface { // Balance receives a message and a set of available partitions and // returns the partition number that the message should be routed to. @@ -39,14 +38,19 @@ func (f BalancerFunc) Balance(msg Message, partitions ...int) int { // RoundRobin is an Balancer implementation that equally distributes messages // across all available partitions. type RoundRobin struct { - offset uint64 + // Use a 32 bits integer so RoundRobin values don't need to be aligned to + // apply atomic increments. + offset uint32 } // Balance satisfies the Balancer interface. 
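// It is safe to call concurrently: as of 0.4 the partition index is derived
// from an atomic increment of rr.offset (see balance below) rather than an
// unsynchronized counter.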
func (rr *RoundRobin) Balance(msg Message, partitions ...int) int { - length := uint64(len(partitions)) - offset := rr.offset - rr.offset++ + return rr.balance(partitions) +} + +func (rr *RoundRobin) balance(partitions []int) int { + length := uint32(len(partitions)) + offset := atomic.AddUint32(&rr.offset, 1) - 1 return partitions[offset%length] } @@ -57,6 +61,7 @@ func (rr *RoundRobin) Balance(msg Message, partitions ...int) int { // balancing relies on the fact that each producer using a LeastBytes balancer // should produce well balanced messages. type LeastBytes struct { + mutex sync.Mutex counters []leastBytesCounter } diff --git a/batch.go b/batch.go index ff9fae102..eefb61a60 100644 --- a/batch.go +++ b/batch.go @@ -205,7 +205,7 @@ func (batch *Batch) ReadMessage() (Message, error) { msg.Topic = batch.topic msg.Partition = batch.partition msg.Offset = offset - msg.Time = timestampToTime(timestamp) + msg.Time = makeTime(timestamp) msg.Headers = headers return msg, err diff --git a/client.go b/client.go index e4a8f6474..cdb536e48 100644 --- a/client.go +++ b/client.go @@ -2,100 +2,84 @@ package kafka import ( "context" - "fmt" + "errors" + "net" + "time" + + "github.com/segmentio/kafka-go/protocol" +) + +const ( + defaultCreateTopicsTimeout = 2 * time.Second + defaultDeleteTopicsTimeout = 2 * time.Second + defaultProduceTimeout = 500 * time.Millisecond + defaultMaxWait = 500 * time.Millisecond ) -// Client is a new and experimental API for kafka-go. It is expected that this API will grow over time, -// and offer a new set of "mid-level" capabilities. Specifically, it is expected Client will be a higher level API than Conn, -// yet provide more control and lower level operations than the Reader and Writer APIs. +// Client is a high-level API to interract with kafka brokers. // -// N.B Client is currently experimental! Therefore, it is subject to change, including breaking changes -// between MINOR and PATCH releases. +// All methods of the Client type accept a context as first argument, which may +// be used to asynchronously cancel the requests. +// +// Clients are safe to use concurrently from multiple goroutines, as long as +// their configuration is not changed after first use. type Client struct { - brokers []string - dialer *Dialer + // Address of the kafka cluster (or specific broker) that the client will be + // sending requests to. + // + // This field is optional, the address may be provided in each request + // instead. The request address takes precedence if both were specified. + Addr net.Addr + + // Time limit for requests sent by this client. + // + // If zero, no timeout is applied. + Timeout time.Duration + + // A transport used to communicate with the kafka brokers. + // + // If nil, DefaultTransport is used. + Transport RoundTripper } -// Configuration for Client +// A ConsumerGroup and Topic as these are both strings we define a type for +// clarity when passing to the Client as a function argument // -// N.B ClientConfig is currently experimental! Therefore, it is subject to change, including breaking changes -// between MINOR and PATCH releases. -type ClientConfig struct { - // List of broker strings in the format : - // to use for bootstrap connecting to cluster - Brokers []string - // Dialer used for connecting to the Cluster - Dialer *Dialer -} - -// A ConsumerGroup and Topic as these are both strings -// we define a type for clarity when passing to the Client -// as a function argument +// N.B TopicAndGroup is currently experimental! 
Therefore, it is subject to +// change, including breaking changes between MINOR and PATCH releases. // -// N.B TopicAndGroup is currently experimental! Therefore, it is subject to change, including breaking changes -// between MINOR and PATCH releases. +// DEPRECATED: this type will be removed in version 1.0, programs should +// migrate to use kafka.(*Client).OffsetFetch instead. type TopicAndGroup struct { Topic string GroupId string } -// NewClient creates and returns a *Client taking ...string of bootstrap -// brokers for connecting to the cluster. -func NewClient(brokers ...string) *Client { - return NewClientWith(ClientConfig{Brokers: brokers, Dialer: DefaultDialer}) -} - -// NewClientWith creates and returns a *Client. For safety, it copies the []string of bootstrap -// brokers for connecting to the cluster and uses the user supplied Dialer. -// In the event the Dialer is nil, we use the DefaultDialer. -func NewClientWith(config ClientConfig) *Client { - if len(config.Brokers) == 0 { - panic("must provide at least one broker") - } - - b := make([]string, len(config.Brokers)) - copy(b, config.Brokers) - d := config.Dialer - if d == nil { - d = DefaultDialer - } - - return &Client{ - brokers: b, - dialer: d, - } -} - -// ConsumerOffsets returns a map[int]int64 of partition to committed offset for a consumer group id and topic +// ConsumerOffsets returns a map[int]int64 of partition to committed offset for +// a consumer group id and topic. +// +// DEPRECATED: this method will be removed in version 1.0, programs should +// migrate to use kafka.(*Client).OffsetFetch instead. func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int]int64, error) { - address, err := c.lookupCoordinator(tg.GroupId) - if err != nil { - return nil, err - } + metadata, err := c.Metadata(ctx, &MetadataRequest{ + Topics: []string{tg.Topic}, + }) - conn, err := c.coordinator(ctx, address) if err != nil { return nil, err } - defer conn.Close() - partitions, err := conn.ReadPartitions(tg.Topic) - if err != nil { - return nil, err - } + topic := metadata.Topics[0] + partitions := make([]int, len(topic.Partitions)) - var parts []int32 - for _, p := range partitions { - parts = append(parts, int32(p.ID)) + for i := range topic.Partitions { + partitions[i] = topic.Partitions[i].ID } - offsets, err := conn.offsetFetch(offsetFetchRequestV1{ + offsets, err := c.OffsetFetch(ctx, &OffsetFetchRequest{ GroupID: tg.GroupId, - Topics: []offsetFetchRequestV1Topic{ - { - Topic: tg.Topic, - Partitions: parts, - }, + Topics: map[string][]int{ + tg.Topic: partitions, }, }) @@ -103,63 +87,58 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int return nil, err } - if len(offsets.Responses) != 1 { - return nil, fmt.Errorf("error fetching offsets, no responses received") - } + topicOffsets := offsets.Topics[topic.Name] + partitionOffsets := make(map[int]int64, len(topicOffsets)) - offsetsByPartition := map[int]int64{} - for _, pr := range offsets.Responses[0].PartitionResponses { - offset := pr.Offset - if offset < 0 { - // No offset stored - // -1 indicates that there is no offset saved for the partition. - // If we returned a -1 here the user might interpret that as LastOffset - // so we set to Firstoffset for safety. 
- // See http://kafka.apache.org/protocol.html#The_Messages_OffsetFetch - offset = FirstOffset - } - offsetsByPartition[int(pr.Partition)] = offset + for _, off := range topicOffsets { + partitionOffsets[off.Partition] = off.CommittedOffset } - return offsetsByPartition, nil + return partitionOffsets, nil } -// connect returns a connection to ANY broker -func (c *Client) connect() (conn *Conn, err error) { - for _, broker := range c.brokers { - if conn, err = c.dialer.Dial("tcp", broker); err == nil { - return +func (c *Client) roundTrip(ctx context.Context, addr net.Addr, msg protocol.Message) (protocol.Message, error) { + if c.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.Timeout) + defer cancel() + } + + if addr == nil { + if addr = c.Addr; addr == nil { + return nil, errors.New("no address was given for the kafka cluster in the request or on the client") } } - return // err will be non-nil + + return c.transport().RoundTrip(ctx, addr, msg) } -// coordinator returns a connection to a coordinator -func (c *Client) coordinator(ctx context.Context, address string) (*Conn, error) { - conn, err := c.dialer.DialContext(ctx, "tcp", address) - if err != nil { - return nil, fmt.Errorf("unable to connect to coordinator, %v", address) +func (c *Client) transport() RoundTripper { + if c.Transport != nil { + return c.Transport } - - return conn, nil + return DefaultTransport } -// lookupCoordinator scans the brokers and looks up the address of the -// coordinator for the groupId. -func (c *Client) lookupCoordinator(groupId string) (string, error) { - conn, err := c.connect() - if err != nil { - return "", fmt.Errorf("unable to find coordinator to any connect for group, %v: %v\n", groupId, err) +func (c *Client) timeout(ctx context.Context, defaultTimeout time.Duration) time.Duration { + timeout := c.Timeout + + if deadline, ok := ctx.Deadline(); ok { + if remain := time.Until(deadline); remain < timeout { + timeout = remain + } } - defer conn.Close() - out, err := conn.findCoordinator(findCoordinatorRequestV0{ - CoordinatorKey: groupId, - }) - if err != nil { - return "", fmt.Errorf("unable to find coordinator for group, %v: %v", groupId, err) + if timeout > 0 { + // Half the timeout because it is communicated to kafka in multiple + // requests (e.g. Fetch, Produce, etc...), this adds buffer to account + // for network latency when waiting for the response from kafka. 
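+ // For example, a 10s client timeout results in a 5s time limit being
+ // advertised to the broker, leaving the other half as headroom for the
+ // round trip.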
+ return timeout / 2 } - address := fmt.Sprintf("%v:%v", out.Coordinator.Host, out.Coordinator.Port) - return address, nil + return defaultTimeout +} + +func (c *Client) timeoutMs(ctx context.Context, defaultTimeout time.Duration) int32 { + return milliseconds(c.timeout(ctx, defaultTimeout)) } diff --git a/client_test.go b/client_test.go index 392703e4f..e971d4efb 100644 --- a/client_test.go +++ b/client_test.go @@ -1,13 +1,96 @@ package kafka import ( + "bytes" "context" + "io" + "math/rand" + "net" "testing" "time" + + "github.com/segmentio/kafka-go/compress" + ktesting "github.com/segmentio/kafka-go/testing" ) +func newLocalClientAndTopic() (*Client, string, func()) { + topic := makeTopic() + client, shutdown := newLocalClientWithTopic(topic, 1) + return client, topic, shutdown +} + +func newLocalClientWithTopic(topic string, partitions int) (*Client, func()) { + client, shutdown := newLocalClient() + if err := clientCreateTopic(client, topic, partitions); err != nil { + shutdown() + panic(err) + } + return client, func() { + client.DeleteTopics(context.Background(), &DeleteTopicsRequest{ + Topics: []string{topic}, + }) + shutdown() + } +} + +func clientCreateTopic(client *Client, topic string, partitions int) error { + _, err := client.CreateTopics(context.Background(), &CreateTopicsRequest{ + Topics: []TopicConfig{{ + Topic: topic, + NumPartitions: partitions, + ReplicationFactor: 1, + }}, + }) + if err != nil { + return err + } + + // Topic creation seems to be asynchronous. Metadata for the topic partition + // layout in the cluster is available in the controller before being synced + // with the other brokers, which causes "Error:[3] Unknown Topic Or Partition" + // when sending requests to the partition leaders. + // + // This loop will wait up to 2 seconds polling the cluster until no errors + // are returned. 
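+ // (That is 20 attempts with a 100ms sleep between them.)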
+ for i := 0; i < 20; i++ { + r, err := client.Fetch(context.Background(), &FetchRequest{ + Topic: topic, + Partition: 0, + Offset: 0, + }) + if err == nil && r.Error == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + return nil +} + +func newLocalClient() (*Client, func()) { + return newClient(TCP("localhost")) +} + +func newClient(addr net.Addr) (*Client, func()) { + conns := &ktesting.ConnWaitGroup{ + DialFunc: (&net.Dialer{}).DialContext, + } + + transport := &Transport{ + Dial: conns.Dial, + Resolver: NewBrokerResolver(nil), + } + + client := &Client{ + Addr: addr, + Timeout: 5 * time.Second, + Transport: transport, + } + + return client, func() { transport.CloseIdleConnections(); conns.Wait() } +} + func TestClient(t *testing.T) { - t.Parallel() tests := []struct { scenario string function func(*testing.T, context.Context, *Client) @@ -21,33 +104,37 @@ func TestClient(t *testing.T) { for _, test := range tests { testFunc := test.function t.Run(test.scenario, func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - c := NewClient("localhost:9092") - testFunc(t, ctx, c) + client, shutdown := newLocalClient() + defer shutdown() + + testFunc(t, ctx, client) }) } } -func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client) { +func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, client *Client) { const totalMessages = 144 const partitions = 12 const msgPerPartition = totalMessages / partitions + topic := makeTopic() + if err := clientCreateTopic(client, topic, partitions); err != nil { + t.Fatal(err) + } + groupId := makeGroupID() - createTopic(t, topic, partitions) brokers := []string{"localhost:9092"} - writer := NewWriter(WriterConfig{ - Brokers: brokers, + writer := &Writer{ + Addr: TCP(brokers...), Topic: topic, - Dialer: DefaultDialer, Balancer: &RoundRobin{}, BatchSize: 1, - }) + Transport: client.Transport, + } if err := writer.WriteMessages(ctx, makeTestSequence(totalMessages)...); err != nil { t.Fatalf("bad write messages: %v", err) } @@ -68,12 +155,14 @@ func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client) for i := 0; i < totalMessages; i++ { m, err := r.FetchMessage(ctx) if err != nil { - t.Errorf("error fetching message: %s", err) + t.Fatalf("error fetching message: %s", err) + } + if err := r.CommitMessages(context.Background(), m); err != nil { + t.Fatal(err) } - r.CommitMessages(context.Background(), m) } - offsets, err := c.ConsumerOffsets(ctx, TopicAndGroup{GroupId: groupId, Topic: topic}) + offsets, err := client.ConsumerOffsets(ctx, TopicAndGroup{GroupId: groupId, Topic: topic}) if err != nil { t.Fatal(err) } @@ -85,7 +174,119 @@ func testConsumerGroupFetchOffsets(t *testing.T, ctx context.Context, c *Client) for i := 0; i < partitions; i++ { committedOffset := offsets[i] if committedOffset != msgPerPartition { - t.Fatalf("expected committed offset of %d but received %d", msgPerPartition, committedOffset) + t.Errorf("expected partition %d with committed offset of %d but received %d", i, msgPerPartition, committedOffset) + } + } +} + +func TestClientProduceAndConsume(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + // Tests a typical kafka use case, data is produced to a partition, + // then consumed back sequentially. 
We use snappy compression because + // kafka stream are often compressed, and verify that each record + // produced is exposed to the consumer, and order is preserved. + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + epoch := time.Now() + seed := int64(0) // deterministic + prng := rand.New(rand.NewSource(seed)) + offset := int64(0) + + const numBatches = 100 + const recordsPerBatch = 320 + t.Logf("producing %d batches of %d records...", numBatches, recordsPerBatch) + + for i := 0; i < numBatches; i++ { // produce 100 batches + records := make([]Record, recordsPerBatch) + + for i := range records { + v := make([]byte, prng.Intn(999)+1) + io.ReadFull(prng, v) + records[i].Time = epoch + records[i].Value = NewBytes(v) + } + + res, err := client.Produce(ctx, &ProduceRequest{ + Topic: topic, + Partition: 0, + RequiredAcks: -1, + Records: NewRecordReader(records...), + Compression: compress.Snappy, + }) + if err != nil { + t.Fatal(err) + } + if res.Error != nil { + t.Fatal(res.Error) + } + if res.BaseOffset != offset { + t.Fatalf("records were produced at an unexpected offset, want %d but got %d", offset, res.BaseOffset) } + offset += int64(len(records)) } + + prng.Seed(seed) + offset = 0 // reset + numFetches := 0 + numRecords := 0 + + for numRecords < (numBatches * recordsPerBatch) { + res, err := client.Fetch(ctx, &FetchRequest{ + Topic: topic, + Partition: 0, + Offset: offset, + MinBytes: 1, + MaxBytes: 256 * 1024, + MaxWait: 100 * time.Millisecond, // should only hit on the last fetch + }) + if err != nil { + t.Fatal(err) + } + if res.Error != nil { + t.Fatal(err) + } + + for { + r, err := res.Records.ReadRecord() + if err != nil { + if err != io.EOF { + t.Fatal(err) + } + break + } + + if r.Key != nil { + r.Key.Close() + t.Error("unexpected non-null key on record at offset", r.Offset) + } + + n := prng.Intn(999) + 1 + a := make([]byte, n) + b := make([]byte, n) + io.ReadFull(prng, a) + + _, err = io.ReadFull(r.Value, b) + r.Value.Close() + if err != nil { + t.Fatalf("reading record at offset %d: %v", r.Offset, err) + } + + if !bytes.Equal(a, b) { + t.Fatalf("value of record at offset %d mismatches", r.Offset) + } + + if r.Offset != offset { + t.Fatalf("record at offset %d was expected to have offset %d", r.Offset, offset) + } + + offset = r.Offset + 1 + numRecords++ + } + + numFetches++ + } + + t.Logf("%d records were read in %d fetches", numRecords, numFetches) } diff --git a/compress/compress.go b/compress/compress.go new file mode 100644 index 000000000..52006077e --- /dev/null +++ b/compress/compress.go @@ -0,0 +1,74 @@ +package compress + +import ( + "io" + + "github.com/segmentio/kafka-go/compress/gzip" + "github.com/segmentio/kafka-go/compress/lz4" + "github.com/segmentio/kafka-go/compress/snappy" + "github.com/segmentio/kafka-go/compress/zstd" +) + +// Compression represents the the compression applied to a record set. +type Compression int8 + +const ( + Gzip Compression = 1 + Snappy Compression = 2 + Lz4 Compression = 3 + Zstd Compression = 4 +) + +func (c Compression) Codec() Codec { + if i := int(c); i >= 0 && i < len(Codecs) { + return Codecs[i] + } + return nil +} + +func (c Compression) String() string { + if codec := c.Codec(); codec != nil { + return codec.Name() + } + return "uncompressed" +} + +// Codec represents a compression codec to encode and decode the messages. +// See : https://cwiki.apache.org/confluence/display/KAFKA/Compression +// +// A Codec must be safe for concurrent access by multiple go routines. 
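+// The package-level codec values declared below (GzipCodec, SnappyCodec,
+// Lz4Codec, ZstdCodec) satisfy this interface and are registered in the
+// Codecs table, indexed by their attribute code.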
+type Codec interface { + // Code returns the compression codec code + Code() int8 + + // Human-readable name for the codec. + Name() string + + // Constructs a new reader which decompresses data from r. + NewReader(r io.Reader) io.ReadCloser + + // Constructs a new writer which writes compressed data to w. + NewWriter(w io.Writer) io.WriteCloser +} + +var ( + // The global gzip codec installed on the Codecs table. + GzipCodec gzip.Codec + + // The global snappy codec installed on the Codecs table. + SnappyCodec snappy.Codec + + // The global lz4 codec installed on the Codecs table. + Lz4Codec lz4.Codec + + // The global zstd codec installed on the Codecs table. + ZstdCodec zstd.Codec + + // The global table of compression codecs supported by the kafka protocol. + Codecs = [...]Codec{ + Gzip: &GzipCodec, + Snappy: &SnappyCodec, + Lz4: &Lz4Codec, + Zstd: &ZstdCodec, + } +) diff --git a/compression_test.go b/compress/compress_test.go similarity index 61% rename from compression_test.go rename to compress/compress_test.go index a006e70e4..07880e280 100644 --- a/compression_test.go +++ b/compress/compress_test.go @@ -1,12 +1,14 @@ -package kafka_test +package compress_test import ( "bytes" - compressGzip "compress/gzip" + stdgzip "compress/gzip" "context" "fmt" "io" "io/ioutil" + "math/rand" + "net" "os" "path/filepath" "strconv" @@ -15,27 +17,44 @@ import ( "time" kafka "github.com/segmentio/kafka-go" - "github.com/segmentio/kafka-go/gzip" - "github.com/segmentio/kafka-go/lz4" - "github.com/segmentio/kafka-go/snappy" + pkg "github.com/segmentio/kafka-go/compress" + "github.com/segmentio/kafka-go/compress/gzip" + "github.com/segmentio/kafka-go/compress/lz4" + "github.com/segmentio/kafka-go/compress/snappy" + "github.com/segmentio/kafka-go/compress/zstd" ktesting "github.com/segmentio/kafka-go/testing" - "github.com/segmentio/kafka-go/zstd" ) +func init() { + // Seeding the random source is important to prevent multiple test runs from + // reusing the same topic names. 
+ rand.Seed(time.Now().UnixNano()) +} + +func TestCodecs(t *testing.T) { + for i, c := range pkg.Codecs { + if c != nil { + if code := c.Code(); int8(code) != int8(i) { + t.Fatal("default compression codec table is misconfigured for", c.Name()) + } + } + } +} + func TestCompression(t *testing.T) { msg := kafka.Message{ Value: []byte("message"), } - testEncodeDecode(t, msg, gzip.NewCompressionCodec()) - testEncodeDecode(t, msg, snappy.NewCompressionCodec()) - testEncodeDecode(t, msg, lz4.NewCompressionCodec()) + testEncodeDecode(t, msg, new(gzip.Codec)) + testEncodeDecode(t, msg, new(snappy.Codec)) + testEncodeDecode(t, msg, new(lz4.Codec)) if ktesting.KafkaIsAtLeast("2.1.0") { - testEncodeDecode(t, msg, zstd.NewCompressionCodec()) + testEncodeDecode(t, msg, new(zstd.Codec)) } } -func compress(codec kafka.CompressionCodec, src []byte) ([]byte, error) { +func compress(codec pkg.Codec, src []byte) ([]byte, error) { b := new(bytes.Buffer) r := bytes.NewReader(src) w := codec.NewWriter(b) @@ -49,7 +68,7 @@ func compress(codec kafka.CompressionCodec, src []byte) ([]byte, error) { return b.Bytes(), nil } -func decompress(codec kafka.CompressionCodec, src []byte) ([]byte, error) { +func decompress(codec pkg.Codec, src []byte) ([]byte, error) { b := new(bytes.Buffer) r := codec.NewReader(bytes.NewReader(src)) if _, err := io.Copy(b, r); err != nil { @@ -62,51 +81,57 @@ func decompress(codec kafka.CompressionCodec, src []byte) ([]byte, error) { return b.Bytes(), nil } -func testEncodeDecode(t *testing.T, m kafka.Message, codec kafka.CompressionCodec) { +func testEncodeDecode(t *testing.T, m kafka.Message, codec pkg.Codec) { var r1, r2 []byte var err error t.Run("encode with "+codec.Name(), func(t *testing.T) { r1, err = compress(codec, m.Value) if err != nil { - t.Error(err) + t.Fatal(err) } }) t.Run("decode with "+codec.Name(), func(t *testing.T) { + if r1 == nil { + if r1, err = compress(codec, m.Value); err != nil { + t.Fatal(err) + } + } r2, err = decompress(codec, r1) if err != nil { - t.Error(err) + t.Fatal(err) } if string(r2) != "message" { t.Error("bad message") - t.Log("got: ", string(r2)) - t.Log("expected: ", string(m.Value)) + t.Logf("expected: %q", string(m.Value)) + t.Logf("got: %q", string(r2)) } }) } func TestCompressedMessages(t *testing.T) { - testCompressedMessages(t, gzip.NewCompressionCodec()) - testCompressedMessages(t, snappy.NewCompressionCodec()) - testCompressedMessages(t, lz4.NewCompressionCodec()) + testCompressedMessages(t, new(gzip.Codec)) + testCompressedMessages(t, new(snappy.Codec)) + testCompressedMessages(t, new(lz4.Codec)) if ktesting.KafkaIsAtLeast("2.1.0") { - testCompressedMessages(t, zstd.NewCompressionCodec()) + testCompressedMessages(t, new(zstd.Codec)) } } -func testCompressedMessages(t *testing.T, codec kafka.CompressionCodec) { - t.Run("produce/consume with"+codec.Name(), func(t *testing.T) { - t.Parallel() - - topic := kafka.CreateTopic(t, 1) - w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"127.0.0.1:9092"}, - Topic: topic, - CompressionCodec: codec, - BatchTimeout: 10 * time.Millisecond, - }) +func testCompressedMessages(t *testing.T, codec pkg.Codec) { + t.Run(codec.Name(), func(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + w := &kafka.Writer{ + Addr: kafka.TCP("127.0.0.1:9092"), + Topic: topic, + Compression: kafka.Compression(codec.Code()), + BatchTimeout: 10 * time.Millisecond, + Transport: client.Transport, + } defer w.Close() offset := 0 @@ -149,16 +174,16 @@ func 
testCompressedMessages(t *testing.T, codec kafka.CompressionCodec) { for i := base; i < len(values); i++ { msg, err := r.ReadMessage(ctx) if err != nil { - t.Errorf("error receiving message at loop %d, offset %d, reason: %+v", base, i, err) + t.Fatalf("error receiving message at loop %d, offset %d, reason: %+v", base, i, err) } if msg.Offset != int64(i) { - t.Errorf("wrong offset at loop %d...expected %d but got %d", base, i, msg.Offset) + t.Fatalf("wrong offset at loop %d...expected %d but got %d", base, i, msg.Offset) } if strconv.Itoa(i) != string(msg.Key) { - t.Errorf("wrong message key at loop %d...expected %d but got %s", base, i, string(msg.Key)) + t.Fatalf("wrong message key at loop %d...expected %d but got %s", base, i, string(msg.Key)) } if values[i] != string(msg.Value) { - t.Errorf("wrong message value at loop %d...expected %s but got %s", base, values[i], string(msg.Value)) + t.Fatalf("wrong message value at loop %d...expected %s but got %s", base, values[i], string(msg.Value)) } } } @@ -166,20 +191,23 @@ func testCompressedMessages(t *testing.T, codec kafka.CompressionCodec) { } func TestMixedCompressedMessages(t *testing.T) { - t.Parallel() - - topic := kafka.CreateTopic(t, 1) + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() offset := 0 var values []string - produce := func(n int, codec kafka.CompressionCodec) { - w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"127.0.0.1:9092"}, - Topic: topic, - CompressionCodec: codec, - }) + produce := func(n int, codec pkg.Codec) { + w := &kafka.Writer{ + Addr: kafka.TCP("127.0.0.1:9092"), + Topic: topic, + Transport: client.Transport, + } defer w.Close() + if codec != nil { + w.Compression = kafka.Compression(codec.Code()) + } + msgs := make([]kafka.Message, n) for i := range msgs { value := fmt.Sprintf("Hello World %d!", offset) @@ -199,10 +227,10 @@ func TestMixedCompressedMessages(t *testing.T) { // different compression codecs. reader should be able to properly handle // all of them. 
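The produce helper above configures compression through the new Writer struct fields rather than the old WriterConfig.CompressionCodec option. As a standalone sketch of the same 0.4 API (broker address and topic name are placeholders), a producer with snappy compression looks like this:

    package main

    import (
        "context"
        "log"

        kafka "github.com/segmentio/kafka-go"
    )

    func main() {
        // Compression is now a plain enum value on the writer instead of a
        // CompressionCodec implementation passed through WriterConfig.
        w := &kafka.Writer{
            Addr:        kafka.TCP("localhost:9092"), // placeholder broker
            Topic:       "example-topic",             // placeholder topic
            Balancer:    &kafka.LeastBytes{},
            Compression: kafka.Snappy,
        }
        defer w.Close()

        err := w.WriteMessages(context.Background(),
            kafka.Message{Key: []byte("key"), Value: []byte("compressed payload")},
        )
        if err != nil {
            log.Fatal(err)
        }
    }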
produce(10, nil) - produce(20, gzip.NewCompressionCodec()) + produce(20, new(gzip.Codec)) produce(5, nil) - produce(10, snappy.NewCompressionCodec()) - produce(10, lz4.NewCompressionCodec()) + produce(10, new(snappy.Codec)) + produce(10, new(lz4.Codec)) produce(5, nil) r := kafka.NewReader(kafka.ReaderConfig{ @@ -261,27 +289,27 @@ func (nopWriteCloser) Close() error { return nil } func BenchmarkCompression(b *testing.B) { benchmarks := []struct { - codec kafka.CompressionCodec - function func(*testing.B, kafka.CompressionCodec, *bytes.Buffer, []byte) float64 + codec pkg.Codec + function func(*testing.B, pkg.Codec, *bytes.Buffer, []byte) float64 }{ { codec: &noopCodec{}, function: benchmarkCompression, }, { - codec: gzip.NewCompressionCodec(), + codec: new(gzip.Codec), function: benchmarkCompression, }, { - codec: snappy.NewCompressionCodec(), + codec: new(snappy.Codec), function: benchmarkCompression, }, { - codec: lz4.NewCompressionCodec(), + codec: new(lz4.Codec), function: benchmarkCompression, }, { - codec: zstd.NewCompressionCodec(), + codec: new(zstd.Codec), function: benchmarkCompression, }, } @@ -292,7 +320,7 @@ func BenchmarkCompression(b *testing.B) { } defer f.Close() - z, err := compressGzip.NewReader(f) + z, err := stdgzip.NewReader(f) if err != nil { b.Fatal(err) } @@ -327,7 +355,7 @@ func BenchmarkCompression(b *testing.B) { } } -func benchmarkCompression(b *testing.B, codec kafka.CompressionCodec, buf *bytes.Buffer, payload []byte) float64 { +func benchmarkCompression(b *testing.B, codec pkg.Codec, buf *bytes.Buffer, payload []byte) float64 { // In case only the decompression benchmark are run, we use this flags to // detect whether we have to compress the payload before the decompression // benchmarks. @@ -388,3 +416,73 @@ func benchmarkCompression(b *testing.B, codec kafka.CompressionCodec, buf *bytes return 1 - (float64(buf.Len()) / float64(len(payload))) } + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func makeTopic() string { + return fmt.Sprintf("kafka-go-%016x", rand.Int63()) +} + +func newLocalClientAndTopic() (*kafka.Client, string, func()) { + topic := makeTopic() + client, shutdown := newLocalClient() + + _, err := client.CreateTopics(context.Background(), &kafka.CreateTopicsRequest{ + Topics: []kafka.TopicConfig{{ + Topic: topic, + NumPartitions: 1, + ReplicationFactor: 1, + }}, + }) + if err != nil { + shutdown() + panic(err) + } + + // Topic creation seems to be asynchronous. Metadata for the topic partition + // layout in the cluster is available in the controller before being synced + // with the other brokers, which causes "Error:[3] Unknown Topic Or Partition" + // when sending requests to the partition leaders. 
+ for i := 0; i < 20; i++ { + r, err := client.Fetch(context.Background(), &kafka.FetchRequest{ + Topic: topic, + Partition: 0, + Offset: 0, + }) + if err == nil && r.Error == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + + return client, topic, func() { + client.DeleteTopics(context.Background(), &kafka.DeleteTopicsRequest{ + Topics: []string{topic}, + }) + shutdown() + } +} + +func newLocalClient() (*kafka.Client, func()) { + return newClient(kafka.TCP("127.0.0.1:9092")) +} + +func newClient(addr net.Addr) (*kafka.Client, func()) { + conns := &ktesting.ConnWaitGroup{ + DialFunc: (&net.Dialer{}).DialContext, + } + + transport := &kafka.Transport{ + Dial: conns.Dial, + } + + client := &kafka.Client{ + Addr: addr, + Timeout: 5 * time.Second, + Transport: transport, + } + + return client, func() { transport.CloseIdleConnections(); conns.Wait() } +} diff --git a/compress/gzip/gzip.go b/compress/gzip/gzip.go new file mode 100644 index 000000000..64da3129d --- /dev/null +++ b/compress/gzip/gzip.go @@ -0,0 +1,122 @@ +package gzip + +import ( + "compress/gzip" + "io" + "sync" +) + +var ( + readerPool sync.Pool +) + +// Codec is the implementation of a compress.Codec which supports creating +// readers and writers for kafka messages compressed with gzip. +type Codec struct { + // The compression level to configure on writers created by this codec. + // Acceptable values are defined in the standard gzip package. + // + // Default to gzip.DefaultCompressionLevel. + Level int + + writerPool sync.Pool +} + +// Code implements the compress.Codec interface. +func (c *Codec) Code() int8 { return 1 } + +// Name implements the compress.Codec interface. +func (c *Codec) Name() string { return "gzip" } + +// NewReader implements the compress.Codec interface. +func (c *Codec) NewReader(r io.Reader) io.ReadCloser { + var err error + z, _ := readerPool.Get().(*gzip.Reader) + if z != nil { + err = z.Reset(r) + } else { + z, err = gzip.NewReader(r) + } + if err != nil { + if z != nil { + readerPool.Put(z) + } + return &errorReader{err: err} + } + return &reader{Reader: z} +} + +// NewWriter implements the compress.Codec interface. +func (c *Codec) NewWriter(w io.Writer) io.WriteCloser { + x := c.writerPool.Get() + z, _ := x.(*gzip.Writer) + if z == nil { + x, err := gzip.NewWriterLevel(w, c.level()) + if err != nil { + return &errorWriter{err: err} + } + z = x + } else { + z.Reset(w) + } + return &writer{codec: c, Writer: z} +} + +func (c *Codec) level() int { + if c.Level != 0 { + return c.Level + } + return gzip.DefaultCompression +} + +type reader struct{ *gzip.Reader } + +func (r *reader) Close() (err error) { + if z := r.Reader; z != nil { + r.Reader = nil + err = z.Close() + // Pass it an empty reader, which is a zero-size value implementing the + // flate.Reader interface to avoid the construction of a bufio.Reader in + // the call to Reset. + // + // Note: we could also not reset the reader at all, but that would cause + // the underlying reader to be retained until the gzip.Reader is freed, + // which may not be desirable. 
+ z.Reset(emptyReader{}) + readerPool.Put(z) + } + return +} + +type writer struct { + codec *Codec + *gzip.Writer +} + +func (w *writer) Close() (err error) { + if z := w.Writer; z != nil { + w.Writer = nil + err = z.Close() + z.Reset(nil) + w.codec.writerPool.Put(z) + } + return +} + +type emptyReader struct{} + +func (emptyReader) ReadByte() (byte, error) { return 0, io.EOF } + +func (emptyReader) Read([]byte) (int, error) { return 0, io.EOF } + +type errorReader struct{ err error } + +func (r *errorReader) Close() error { return r.err } + +func (r *errorReader) Read([]byte) (int, error) { return 0, r.err } + +type errorWriter struct{ err error } + +func (w *errorWriter) Close() error { return w.err } + +func (w *errorWriter) Write([]byte) (int, error) { return 0, w.err } diff --git a/compress/lz4/lz4.go b/compress/lz4/lz4.go new file mode 100644 index 000000000..2c892ca07 --- /dev/null +++ b/compress/lz4/lz4.go @@ -0,0 +1,68 @@ +package lz4 + +import ( + "io" + "sync" + + "github.com/pierrec/lz4" +) + +var ( + readerPool sync.Pool + writerPool sync.Pool +) + +// Codec is the implementation of a compress.Codec which supports creating +// readers and writers for kafka messages compressed with lz4. +type Codec struct{} + +// Code implements the compress.Codec interface. +func (c *Codec) Code() int8 { return 3 } + +// Name implements the compress.Codec interface. +func (c *Codec) Name() string { return "lz4" } + +// NewReader implements the compress.Codec interface. +func (c *Codec) NewReader(r io.Reader) io.ReadCloser { + z, _ := readerPool.Get().(*lz4.Reader) + if z != nil { + z.Reset(r) + } else { + z = lz4.NewReader(r) + } + return &reader{Reader: z} +} + +// NewWriter implements the compress.Codec interface. +func (c *Codec) NewWriter(w io.Writer) io.WriteCloser { + z, _ := writerPool.Get().(*lz4.Writer) + if z != nil { + z.Reset(w) + } else { + z = lz4.NewWriter(w) + } + return &writer{Writer: z} +} + +type reader struct{ *lz4.Reader } + +func (r *reader) Close() (err error) { + if z := r.Reader; z != nil { + r.Reader = nil + z.Reset(nil) + readerPool.Put(z) + } + return +} + +type writer struct{ *lz4.Writer } + +func (w *writer) Close() (err error) { + if z := w.Writer; z != nil { + w.Writer = nil + err = z.Close() + z.Reset(nil) + writerPool.Put(z) + } + return +} diff --git a/compress/snappy/snappy.go b/compress/snappy/snappy.go new file mode 100644 index 000000000..fcf72409d --- /dev/null +++ b/compress/snappy/snappy.go @@ -0,0 +1,89 @@ +package snappy + +import ( + "io" + "sync" + + "github.com/golang/snappy" +) + +// Framing is an enumeration type used to enable or disable xerial framing of +// snappy messages. +type Framing int + +const ( + Framed Framing = iota + Unframed +) + +var ( + readerPool sync.Pool + writerPool sync.Pool +) + +// Codec is the implementation of a compress.Codec which supports creating +// readers and writers for kafka messages compressed with snappy. +type Codec struct { + // An optional framing to apply to snappy compression. + // + // Default to Framed. + Framing Framing +} + +// Code implements the compress.Codec interface. +func (c *Codec) Code() int8 { return 2 } + +// Name implements the compress.Codec interface. +func (c *Codec) Name() string { return "snappy" } + +// NewReader implements the compress.Codec interface. 
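Before the reader and writer constructors below, a short sketch of the Framing knob: by default the codec emits xerial-framed snappy (the framing historically used by Kafka's Java clients), while Unframed produces raw snappy blocks. The sketch only exercises the configuration surface added here; the payload is arbitrary.

    package main

    import (
        "bytes"
        "io/ioutil"
        "log"

        "github.com/segmentio/kafka-go/compress/snappy"
    )

    func main() {
        framed := new(snappy.Codec)                         // Framing defaults to snappy.Framed
        unframed := &snappy.Codec{Framing: snappy.Unframed} // raw snappy blocks

        for _, codec := range []*snappy.Codec{framed, unframed} {
            buf := new(bytes.Buffer)
            w := codec.NewWriter(buf)
            if _, err := w.Write([]byte("message")); err != nil {
                log.Fatal(err)
            }
            if err := w.Close(); err != nil {
                log.Fatal(err)
            }

            // Read the data back with the same codec configuration.
            r := codec.NewReader(bytes.NewReader(buf.Bytes()))
            out, err := ioutil.ReadAll(r)
            r.Close()
            if err != nil || string(out) != "message" {
                log.Fatal("round trip failed: ", string(out), " ", err)
            }
        }
    }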
+func (c *Codec) NewReader(r io.Reader) io.ReadCloser { + x, _ := readerPool.Get().(*xerialReader) + if x != nil { + x.Reset(r) + } else { + x = &xerialReader{ + reader: r, + decode: snappy.Decode, + } + } + return &reader{xerialReader: x} +} + +// NewWriter implements the compress.Codec interface. +func (c *Codec) NewWriter(w io.Writer) io.WriteCloser { + x, _ := writerPool.Get().(*xerialWriter) + if x != nil { + x.Reset(w) + } else { + x = &xerialWriter{ + writer: w, + encode: snappy.Encode, + } + } + x.framed = c.Framing == Framed + return &writer{xerialWriter: x} +} + +type reader struct{ *xerialReader } + +func (r *reader) Close() (err error) { + if x := r.xerialReader; x != nil { + r.xerialReader = nil + x.Reset(nil) + readerPool.Put(x) + } + return +} + +type writer struct{ *xerialWriter } + +func (w *writer) Close() (err error) { + if x := w.xerialWriter; x != nil { + w.xerialWriter = nil + err = x.Flush() + x.Reset(nil) + writerPool.Put(x) + } + return +} diff --git a/snappy/xerial.go b/compress/snappy/xerial.go similarity index 100% rename from snappy/xerial.go rename to compress/snappy/xerial.go diff --git a/snappy/xerial_test.go b/compress/snappy/xerial_test.go similarity index 100% rename from snappy/xerial_test.go rename to compress/snappy/xerial_test.go diff --git a/compress/zstd/zstd.go b/compress/zstd/zstd.go new file mode 100644 index 000000000..608b002e9 --- /dev/null +++ b/compress/zstd/zstd.go @@ -0,0 +1,180 @@ +// Package zstd implements Zstandard compression. +package zstd + +import ( + "io" + "runtime" + "sync" + + "github.com/klauspost/compress/zstd" +) + +// Codec is the implementation of a compress.Codec which supports creating +// readers and writers for kafka messages compressed with zstd. +type Codec struct { + // The compression level configured on writers created by the codec. + // + // Default to 3. + Level int + + encoderPool sync.Pool // *encoder +} + +// Code implements the compress.Codec interface. +func (c *Codec) Code() int8 { return 4 } + +// Name implements the compress.Codec interface. +func (c *Codec) Name() string { return "zstd" } + +// NewReader implements the compress.Codec interface. +func (c *Codec) NewReader(r io.Reader) io.ReadCloser { + p := new(reader) + if dec, _ := decoderPool.Get().(*decoder); dec == nil { + z, err := zstd.NewReader(r) + if err != nil { + p.err = err + } else { + p.dec = &decoder{z} + // We need a finalizer because the reader spawns goroutines + // that will only be stopped if the Close method is called. + runtime.SetFinalizer(p.dec, (*decoder).finalize) + } + } else { + p.dec = dec + p.err = dec.Reset(r) + } + return p +} + +func (c *Codec) level() int { + if c.Level != 0 { + return c.Level + } + return 3 +} + +func (c *Codec) zstdLevel() zstd.EncoderLevel { + return zstd.EncoderLevelFromZstd(c.level()) +} + +var decoderPool sync.Pool // *decoder + +type decoder struct { + *zstd.Decoder +} + +func (d *decoder) finalize() { + d.Close() +} + +type reader struct { + dec *decoder + err error +} + +// Close implements the io.Closer interface. +func (r *reader) Close() error { + if r.dec != nil { + r.dec.Reset(devNull{}) // don't retain the underlying reader + decoderPool.Put(r.dec) + r.dec = nil + r.err = io.ErrClosedPipe + } + return nil +} + +// Read implements the io.Reader interface. +func (r *reader) Read(p []byte) (int, error) { + if r.err != nil { + return 0, r.err + } + return r.dec.Read(p) +} + +// WriteTo implements the io.WriterTo interface. 
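As an aside on the Level option defined on the zstd Codec above: it takes standard zstd levels (the codec maps them to the klauspost encoder levels via EncoderLevelFromZstd and falls back to 3 when the field is left at zero). A minimal sketch of configuring it, with error handling elided for brevity:

    package main

    import (
        "bytes"
        "fmt"
        "io/ioutil"

        "github.com/segmentio/kafka-go/compress/zstd"
    )

    func main() {
        // A higher level trades CPU for a better compression ratio; zero keeps
        // the default of 3.
        codec := &zstd.Codec{Level: 9}

        buf := new(bytes.Buffer)
        w := codec.NewWriter(buf)
        w.Write(bytes.Repeat([]byte("kafka "), 1024))
        w.Close()

        r := codec.NewReader(buf)
        out, err := ioutil.ReadAll(r)
        r.Close()

        fmt.Println(len(out), err) // expect 6144 <nil>
    }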
+func (r *reader) WriteTo(w io.Writer) (int64, error) { + if r.err != nil { + return 0, r.err + } + return r.dec.WriteTo(w) +} + +// NewWriter implements the compress.Codec interface. +func (c *Codec) NewWriter(w io.Writer) io.WriteCloser { + p := new(writer) + if enc, _ := c.encoderPool.Get().(*encoder); enc == nil { + z, err := zstd.NewWriter(w, zstd.WithEncoderLevel(c.zstdLevel())) + if err != nil { + p.err = err + } else { + p.enc = &encoder{z} + // We need a finalizer because the writer spawns goroutines + // that will only be stopped if the Close method is called. + runtime.SetFinalizer(p.enc, (*encoder).finalize) + } + } else { + p.enc = enc + p.enc.Reset(w) + } + p.c = c + return p +} + +type encoder struct { + *zstd.Encoder +} + +func (e *encoder) finalize() { + e.Close() +} + +type writer struct { + c *Codec + enc *encoder + err error +} + +// Close implements the io.Closer interface. +func (w *writer) Close() error { + if w.enc != nil { + // Close needs to be called to write the end of stream marker and flush + // the buffers. The zstd package documents that the encoder is re-usable + // after being closed. + err := w.enc.Close() + if err != nil { + w.err = err + } + w.enc.Reset(devNull{}) // don't retain the underyling writer + w.c.encoderPool.Put(w.enc) + w.enc = nil + return err + } + return nil +} + +// WriteTo implements the io.WriterTo interface. +func (w *writer) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + if w.enc == nil { + return 0, io.ErrClosedPipe + } + return w.enc.Write(p) +} + +// ReadFrom implements the io.ReaderFrom interface. +func (w *writer) ReadFrom(r io.Reader) (int64, error) { + if w.err != nil { + return 0, w.err + } + if w.enc == nil { + return 0, io.ErrClosedPipe + } + return w.enc.ReadFrom(r) +} + +type devNull struct{} + +func (devNull) Read([]byte) (int, error) { return 0, io.EOF } +func (devNull) Write([]byte) (int, error) { return 0, nil } diff --git a/compression.go b/compression.go index c415ff436..61e8d20c1 100644 --- a/compression.go +++ b/compression.go @@ -2,57 +2,36 @@ package kafka import ( "errors" - "io" - "sync" + + "github.com/segmentio/kafka-go/compress" +) + +type Compression = compress.Compression + +const ( + Gzip Compression = compress.Gzip + Snappy Compression = compress.Snappy + Lz4 Compression = compress.Lz4 + Zstd Compression = compress.Zstd ) +type CompressionCodec = compress.Codec + +// TODO: this file should probably go away once the internals of the package +// have moved to use the protocol package. const ( compressionCodecMask = 0x07 ) var ( errUnknownCodec = errors.New("the compression code is invalid or its codec has not been imported") - - codecs = make(map[int8]CompressionCodec) - codecsMutex sync.RWMutex ) -// RegisterCompressionCodec registers a compression codec so it can be used by a Writer. -func RegisterCompressionCodec(codec CompressionCodec) { - code := codec.Code() - codecsMutex.Lock() - codecs[code] = codec - codecsMutex.Unlock() -} - // resolveCodec looks up a codec by Code() -func resolveCodec(code int8) (codec CompressionCodec, err error) { - codecsMutex.RLock() - codec = codecs[code] - codecsMutex.RUnlock() - +func resolveCodec(code int8) (CompressionCodec, error) { + codec := compress.Compression(code).Codec() if codec == nil { - err = errUnknownCodec + return nil, errUnknownCodec } - return -} - -// CompressionCodec represents a compression codec to encode and decode -// the messages. 
-// See : https://cwiki.apache.org/confluence/display/KAFKA/Compression -// -// A CompressionCodec must be safe for concurrent access by multiple go -// routines. -type CompressionCodec interface { - // Code returns the compression codec code - Code() int8 - - // Human-readable name for the codec. - Name() string - - // Constructs a new reader which decompresses data from r. - NewReader(r io.Reader) io.ReadCloser - - // Constructs a new writer which writes compressed data to w. - NewWriter(w io.Writer) io.WriteCloser + return codec, nil } diff --git a/conn.go b/conn.go index ff229f76e..8b6c793a9 100644 --- a/conn.go +++ b/conn.go @@ -20,23 +20,6 @@ var ( errInvalidWritePartition = errors.New("writes must NOT set Partition on kafka.Message") ) -// Broker carries the metadata associated with a kafka broker. -type Broker struct { - Host string - Port int - ID int - Rack string -} - -// Partition carries the metadata associated with a kafka partition. -type Partition struct { - Topic string - Leader Broker - Replicas []Broker - Isr []Broker - ID int -} - // Conn represents a connection to a kafka broker. // // Instances of Conn are safe to use concurrently from multiple goroutines. @@ -890,7 +873,7 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch { conn: c, msgs: msgs, deadline: adjustedDeadline, - throttle: duration(throttle), + throttle: makeDuration(throttle), lock: lock, topic: c.topic, // topic is copied to Batch to prevent race with Batch.close partition: int(c.partition), // partition is copied to Batch to prevent race with Batch.close @@ -1404,7 +1387,7 @@ func (c *Conn) ApiVersions() ([]ApiVersion, error) { if deadline.deadline().IsZero() { // ApiVersions is called automatically when API version negotiation - // needs to happen, so we are not garanteed that a read deadline has + // needs to happen, so we are not guaranteed that a read deadline has // been set yet. Fallback to use the write deadline in case it was // set, for example when version negotiation is initiated during a // produce request. diff --git a/conn_test.go b/conn_test.go index 535e4b791..579b13f73 100644 --- a/conn_test.go +++ b/conn_test.go @@ -7,6 +7,7 @@ import ( "io" "math/rand" "net" + "os" "strconv" "testing" "time" @@ -105,8 +106,6 @@ func makeGroupID() string { } func TestConn(t *testing.T) { - t.Parallel() - tests := []struct { scenario string function func(*testing.T, *Conn) @@ -153,7 +152,7 @@ func TestConn(t *testing.T) { }, { - scenario: "unchecked seeks allow the connection to be positionned outside the boundaries of the partition", + scenario: "unchecked seeks allow the connection to be positioned outside the boundaries of the partition", function: testConnSeekDontCheck, }, @@ -295,6 +294,20 @@ func TestConn(t *testing.T) { } t.Run("nettest", func(t *testing.T) { + // Need ability to skip nettest on newer Kafka versions to avoid these kinds of errors: + // --- FAIL: TestConn/nettest (17.56s) + // --- FAIL: TestConn/nettest/PingPong (7.40s) + // conntest.go:112: unexpected Read error: [7] Request Timed Out: the request exceeded the user-specified time limit in the request + // conntest.go:118: mismatching value: got 77, want 78 + // conntest.go:118: mismatching value: got 78, want 79 + // ... + // + // TODO: Figure out why these are happening and fix them (they don't appear to be new). 
+ if _, ok := os.LookupEnv("KAFKA_SKIP_NETTEST"); ok { + t.Log("skipping nettest because KAFKA_SKIP_NETTEST is set") + t.Skip() + } + t.Parallel() nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { @@ -1078,7 +1091,7 @@ func testBrokers(t *testing.T, conn *Conn) { func TestReadPartitionsNoTopic(t *testing.T) { conn, err := Dial("tcp", "127.0.0.1:9092") if err != nil { - t.Error(err) + t.Fatal(err) } defer conn.Close() diff --git a/consumergroup_test.go b/consumergroup_test.go index 390de5b2d..2dbd37d99 100644 --- a/consumergroup_test.go +++ b/consumergroup_test.go @@ -101,11 +101,11 @@ func TestValidateConsumerGroupConfig(t *testing.T) { {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", SessionTimeout: -1}, errorOccured: true}, {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: -1}, errorOccured: true}, {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: -2}, errorOccured: true}, - {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: -2}, errorOccured: true}, - {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, StartOffset: 123}, errorOccured: true}, - {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, PartitionWatchInterval: -1}, errorOccured: true}, - {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, PartitionWatchInterval: 1, JoinGroupBackoff: -1}, errorOccured: true}, - {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, PartitionWatchInterval: 1, JoinGroupBackoff: 1}, errorOccured: false}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, StartOffset: 123}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: 1, JoinGroupBackoff: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: 1, JoinGroupBackoff: 1}, errorOccured: false}, } for _, test := range tests { err := test.config.Validate() @@ -237,8 +237,6 @@ func TestReaderAssignTopicPartitions(t *testing.T) { } func TestConsumerGroup(t 
*testing.T) { - t.Parallel() - tests := []struct { scenario string function func(*testing.T, context.Context, *ConsumerGroup) @@ -314,11 +312,10 @@ func TestConsumerGroup(t *testing.T) { topic := makeTopic() createTopic(t, topic, 1) + defer deleteTopic(t, topic) for _, test := range tests { t.Run(test.scenario, func(t *testing.T) { - t.Parallel() - group, err := NewConsumerGroup(ConsumerGroupConfig{ ID: makeGroupID(), Topics: []string{topic}, @@ -342,8 +339,6 @@ func TestConsumerGroup(t *testing.T) { } func TestConsumerGroupErrors(t *testing.T) { - t.Parallel() - var left []string var lock sync.Mutex mc := mockCoordinator{ diff --git a/crc32_test.go b/crc32_test.go index e0251b6b7..bbc6f4ba6 100644 --- a/crc32_test.go +++ b/crc32_test.go @@ -7,8 +7,6 @@ import ( ) func TestMessageCRC32(t *testing.T) { - t.Parallel() - m := message{ MagicByte: 1, Timestamp: 42, diff --git a/createtopics.go b/createtopics.go index 95619da2d..b0b50a460 100644 --- a/createtopics.go +++ b/createtopics.go @@ -2,9 +2,86 @@ package kafka import ( "bufio" + "context" + "fmt" + "net" "time" + + "github.com/segmentio/kafka-go/protocol/createtopics" ) +// CreateTopicRequests represents a request sent to a kafka broker to create +// new topics. +type CreateTopicsRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // List of topics to create and their configuration. + Topics []TopicConfig + + // When set to true, topics are not created but the configuration is + // validated as if they were. + // + // This field will be ignored if the kafka broker did no support the + // CreateTopics API in version 1 or above. + ValidateOnly bool +} + +// CreateTopicResponse represents a response from a kafka broker to a topic +// creation request. +type CreateTopicsResponse struct { + // The amount of time that the broker throttled the request. + // + // This field will be zero if the kafka broker did no support the + // CreateTopics API in version 2 or above. + Throttle time.Duration + + // Mapping of topic names to errors that occurred while attempting to create + // the topics. + // + // The errors contain the kafka error code. Programs may use the standard + // errors.Is function to test the error against kafka error codes. + Errors map[string]error +} + +// CreateTopics sends a topic creation request to a kafka broker and returns the +// response. 
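A sketch of how a program might call this method (broker address and topic name are placeholders): transport-level failures surface through the returned error, while per-topic results land in the Errors map and can be matched against kafka error codes with errors.Is, as noted above.

    package main

    import (
        "context"
        "errors"
        "log"

        kafka "github.com/segmentio/kafka-go"
    )

    func main() {
        client := &kafka.Client{Addr: kafka.TCP("localhost:9092")}

        res, err := client.CreateTopics(context.Background(), &kafka.CreateTopicsRequest{
            Topics: []kafka.TopicConfig{{
                Topic:             "example-topic",
                NumPartitions:     3,
                ReplicationFactor: 1,
            }},
        })
        if err != nil {
            log.Fatal(err) // request could not be sent or decoded
        }

        // Creating a topic that already exists is often acceptable; any other
        // per-topic error is reported.
        if topicErr := res.Errors["example-topic"]; topicErr != nil && !errors.Is(topicErr, kafka.TopicAlreadyExists) {
            log.Fatal(topicErr)
        }
    }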
+func (c *Client) CreateTopics(ctx context.Context, req *CreateTopicsRequest) (*CreateTopicsResponse, error) { + topics := make([]createtopics.RequestTopic, len(req.Topics)) + + for i, t := range req.Topics { + topics[i] = createtopics.RequestTopic{ + Name: t.Topic, + NumPartitions: int32(t.NumPartitions), + ReplicationFactor: int16(t.ReplicationFactor), + Assignments: t.assignments(), + Configs: t.configs(), + } + } + + m, err := c.roundTrip(ctx, req.Addr, &createtopics.Request{ + Topics: topics, + TimeoutMs: c.timeoutMs(ctx, defaultCreateTopicsTimeout), + ValidateOnly: req.ValidateOnly, + }) + + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).CreateTopics: %w", err) + } + + res := m.(*createtopics.Response) + ret := &CreateTopicsResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Errors: make(map[string]error, len(res.Topics)), + } + + for _, t := range res.Topics { + ret.Errors[t.Name] = makeError(t.ErrorCode, t.ErrorMessage) + } + + return ret, nil +} + type ConfigEntry struct { ConfigName string ConfigValue string @@ -34,29 +111,55 @@ func (t createTopicsRequestV0ConfigEntry) writeTo(wb *writeBuffer) { type ReplicaAssignment struct { Partition int - Replicas int + // The list of brokers where the partition should be allocated. There must + // be as many entries in thie list as there are replicas of the partition. + // The first entry represents the broker that will be the preferred leader + // for the partition. + // + // This field changed in 0.4 from `int` to `[]int`. It was invalid to pass + // a single integer as this is supposed to be a list. While this introduces + // a breaking change, it probably never worked before. + Replicas []int +} + +func (a *ReplicaAssignment) partitionIndex() int32 { + return int32(a.Partition) +} + +func (a *ReplicaAssignment) brokerIDs() []int32 { + if len(a.Replicas) == 0 { + return nil + } + replicas := make([]int32, len(a.Replicas)) + for i, r := range a.Replicas { + replicas[i] = int32(r) + } + return replicas } func (a ReplicaAssignment) toCreateTopicsRequestV0ReplicaAssignment() createTopicsRequestV0ReplicaAssignment { return createTopicsRequestV0ReplicaAssignment{ Partition: int32(a.Partition), - Replicas: int32(a.Replicas), + Replicas: a.brokerIDs(), } } type createTopicsRequestV0ReplicaAssignment struct { Partition int32 - Replicas int32 + Replicas []int32 } func (t createTopicsRequestV0ReplicaAssignment) size() int32 { return sizeofInt32(t.Partition) + - sizeofInt32(t.Replicas) + (int32(len(t.Replicas)+1) * sizeofInt32(0)) // N+1 because the array length is a int32 } func (t createTopicsRequestV0ReplicaAssignment) writeTo(wb *writeBuffer) { wb.writeInt32(t.Partition) - wb.writeInt32(t.Replicas) + wb.writeInt32(int32(len(t.Replicas))) + for _, r := range t.Replicas { + wb.writeInt32(int32(r)) + } } type TopicConfig struct { @@ -77,6 +180,34 @@ type TopicConfig struct { ConfigEntries []ConfigEntry } +func (t *TopicConfig) assignments() []createtopics.RequestAssignment { + if len(t.ReplicaAssignments) == 0 { + return nil + } + assignments := make([]createtopics.RequestAssignment, len(t.ReplicaAssignments)) + for i, a := range t.ReplicaAssignments { + assignments[i] = createtopics.RequestAssignment{ + PartitionIndex: a.partitionIndex(), + BrokerIDs: a.brokerIDs(), + } + } + return assignments +} + +func (t *TopicConfig) configs() []createtopics.RequestConfig { + if len(t.ConfigEntries) == 0 { + return nil + } + configs := make([]createtopics.RequestConfig, len(t.ConfigEntries)) + for i, c := range t.ConfigEntries { + configs[i] 
= createtopics.RequestConfig{ + Name: c.ConfigName, + Value: c.ConfigValue, + } + } + return configs +} + func (t TopicConfig) toCreateTopicsRequestV0Topic() createTopicsRequestV0Topic { var requestV0ReplicaAssignments []createTopicsRequestV0ReplicaAssignment for _, a := range t.ReplicaAssignments { diff --git a/createtopics_test.go b/createtopics_test.go index fae3b2865..71cf456e0 100644 --- a/createtopics_test.go +++ b/createtopics_test.go @@ -3,10 +3,88 @@ package kafka import ( "bufio" "bytes" + "context" "reflect" "testing" ) +func TestClientCreateTopics(t *testing.T) { + const ( + topic1 = "client-topic-1" + topic2 = "client-topic-2" + topic3 = "client-topic-3" + ) + + client, shutdown := newLocalClient() + defer shutdown() + + config := []ConfigEntry{{ + ConfigName: "retention.ms", + ConfigValue: "3600000", + }} + + res, err := client.CreateTopics(context.Background(), &CreateTopicsRequest{ + Topics: []TopicConfig{ + { + Topic: topic1, + NumPartitions: -1, + ReplicationFactor: -1, + ReplicaAssignments: []ReplicaAssignment{ + { + Partition: 0, + Replicas: []int{1}, + }, + { + Partition: 1, + Replicas: []int{1}, + }, + { + Partition: 2, + Replicas: []int{1}, + }, + }, + ConfigEntries: config, + }, + { + Topic: topic2, + NumPartitions: 2, + ReplicationFactor: 1, + ConfigEntries: config, + }, + { + Topic: topic3, + NumPartitions: 1, + ReplicationFactor: 1, + ConfigEntries: config, + }, + }, + }) + + if err != nil { + t.Fatal(err) + } + + defer deleteTopic(t, topic1, topic2, topic3) + + expectTopics := map[string]struct{}{ + topic1: {}, + topic2: {}, + topic3: {}, + } + + for topic, error := range res.Errors { + delete(expectTopics, topic) + + if error != nil { + t.Errorf("%s => %s", topic, error) + } + } + + for topic := range expectTopics { + t.Errorf("topic missing in response: %s", topic) + } +} + func TestCreateTopicsResponseV0(t *testing.T) { item := createTopicsResponseV0{ TopicErrors: []createTopicsResponseV0TopicError{ diff --git a/deletetopics.go b/deletetopics.go index 687c380c3..470f9ef83 100644 --- a/deletetopics.go +++ b/deletetopics.go @@ -2,9 +2,70 @@ package kafka import ( "bufio" + "context" + "fmt" + "net" "time" + + "github.com/segmentio/kafka-go/protocol/deletetopics" ) +// DeleteTopicsRequest represents a request sent to a kafka broker to delete +// topics. +type DeleteTopicsRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // Names of topics to delete. + Topics []string +} + +// DeleteTopicsResponse represents a response from a kafka broker to a topic +// deletion request. +type DeleteTopicsResponse struct { + // The amount of time that the broker throttled the request. + // + // This field will be zero if the kafka broker did no support the + // DeleteTopics API in version 1 or above. + Throttle time.Duration + + // Mapping of topic names to errors that occurred while attempting to delete + // the topics. + // + // The errors contain the kafka error code. Programs may use the standard + // errors.Is function to test the error against kafka error codes. + Errors map[string]error +} + +// DeleteTopics sends a topic deletion request to a kafka broker and returns the +// response. 
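The deletion side mirrors CreateTopics. A compact sketch (same placeholder address and topic name) of the kind of cleanup step this patch's test helpers perform when tearing down temporary topics:

    package main

    import (
        "context"
        "log"

        kafka "github.com/segmentio/kafka-go"
    )

    // deleteTopic asks the broker to delete a single topic and returns the
    // per-topic error, if any.
    func deleteTopic(name string) error {
        client := &kafka.Client{Addr: kafka.TCP("localhost:9092")}

        res, err := client.DeleteTopics(context.Background(), &kafka.DeleteTopicsRequest{
            Topics: []string{name},
        })
        if err != nil {
            return err // transport-level failure
        }
        return res.Errors[name] // nil when the deletion succeeded
    }

    func main() {
        if err := deleteTopic("example-topic"); err != nil {
            log.Fatal(err)
        }
    }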
+func (c *Client) DeleteTopics(ctx context.Context, req *DeleteTopicsRequest) (*DeleteTopicsResponse, error) { + m, err := c.roundTrip(ctx, req.Addr, &deletetopics.Request{ + TopicNames: req.Topics, + TimeoutMs: c.timeoutMs(ctx, defaultDeleteTopicsTimeout), + }) + + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).DeleteTopics: %w", err) + } + + res := m.(*deletetopics.Response) + ret := &DeleteTopicsResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Errors: make(map[string]error, len(res.Responses)), + } + + for _, t := range res.Responses { + if t.ErrorCode == 0 { + ret.Errors[t.Name] = nil + } else { + ret.Errors[t.Name] = Error(t.ErrorCode) + } + } + + return ret, nil +} + // See http://kafka.apache.org/protocol.html#The_Messages_DeleteTopics type deleteTopicsRequestV0 struct { // Topics holds the topic names diff --git a/deletetopics_test.go b/deletetopics_test.go index dcedc4642..3caffe840 100644 --- a/deletetopics_test.go +++ b/deletetopics_test.go @@ -3,10 +3,31 @@ package kafka import ( "bufio" "bytes" + "context" "reflect" "testing" ) +func TestClientDeleteTopics(t *testing.T) { + client, shutdown := newLocalClient() + defer shutdown() + + topic := makeTopic() + createTopic(t, topic, 1) + + res, err := client.DeleteTopics(context.Background(), &DeleteTopicsRequest{ + Topics: []string{topic}, + }) + + if err != nil { + t.Fatal(err) + } + + if err := res.Errors[topic]; err != nil { + t.Error(err) + } +} + func TestDeleteTopicsResponseV1(t *testing.T) { item := deleteTopicsResponseV0{ TopicErrorCodes: []deleteTopicsResponseV0TopicErrorCode{ diff --git a/dialer.go b/dialer.go index 35eb080cc..43e8af194 100644 --- a/dialer.go +++ b/dialer.go @@ -65,7 +65,13 @@ type Dialer struct { // support keep-alives ignore this field. KeepAlive time.Duration - // Resolver optionally specifies an alternate resolver to use. + // Resolver optionally gives a hook to convert the broker address into an + // alternate host or IP address which is useful for custom service discovery. + // If a custom resolver returns any possible hosts, the first one will be + // used and the original discarded. If a port number is included with the + // resolved host, it will only be used if a port number was not previously + // specified. If no port is specified or resolved, the default of 9092 will be + // used. Resolver Resolver // TLS enables Dialer to open secure connections. If nil, standard net.Conn @@ -320,20 +326,10 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { return nil } -func (d *Dialer) dialContext(ctx context.Context, network string, address string) (net.Conn, error) { - if r := d.Resolver; r != nil { - host, port := splitHostPort(address) - addrs, err := r.LookupHost(ctx, host) - if err != nil { - return nil, err - } - if len(addrs) != 0 { - address = addrs[0] - } - if len(port) != 0 { - address, _ = splitHostPort(address) - address = net.JoinHostPort(address, port) - } +func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + address, err := lookupHost(ctx, addr, d.Resolver) + if err != nil { + return nil, err } dial := d.DialFunc @@ -407,14 +403,6 @@ func LookupPartitions(ctx context.Context, network string, address string, topic return DefaultDialer.LookupPartitions(ctx, network, address, topic) } -// The Resolver interface is used as an abstraction to provide service discovery -// of the hosts of a kafka cluster. -type Resolver interface { - // LookupHost looks up the given host using the local resolver. 
- // It returns a slice of that host's addresses. - LookupHost(ctx context.Context, host string) (addrs []string, err error) -} - func sleep(ctx context.Context, duration time.Duration) bool { if duration == 0 { select { @@ -449,3 +437,35 @@ func splitHostPort(s string) (host string, port string) { } return } + +func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) { + host, port := splitHostPort(address) + + if resolver != nil { + resolved, err := resolver.LookupHost(ctx, host) + if err != nil { + return "", err + } + + // if the resolver doesn't return anything, we'll fall back on the provided + // address instead + if len(resolved) > 0 { + resolvedHost, resolvedPort := splitHostPort(resolved[0]) + + // we'll always prefer the resolved host + host = resolvedHost + + // in the case of port though, the provided address takes priority, and we + // only use the resolved address to set the port when not specified + if port == "" { + port = resolvedPort + } + } + } + + if port == "" { + port = "9092" + } + + return net.JoinHostPort(host, port), nil +} diff --git a/dialer_test.go b/dialer_test.go index 66f337255..5aedb777b 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "fmt" "io" "net" "reflect" @@ -26,8 +27,6 @@ func TestDialer(t *testing.T) { for _, test := range tests { testFunc := test.function t.Run(test.scenario, func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -37,15 +36,15 @@ func TestDialer(t *testing.T) { } func testDialerLookupPartitions(t *testing.T, ctx context.Context, d *Dialer) { - const topic = "test-dialer-LookupPartitions" - - createTopic(t, topic, 1) + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() // Write a message to ensure the partition gets created. - w := NewWriter(WriterConfig{ - Brokers: []string{"localhost:9092"}, - Topic: topic, - }) + w := &Writer{ + Addr: TCP("localhost:9092"), + Topic: topic, + Transport: client.Transport, + } w.WriteMessages(ctx, Message{}) w.Close() @@ -61,7 +60,7 @@ func testDialerLookupPartitions(t *testing.T, ctx context.Context, d *Dialer) { want := []Partition{ { - Topic: "test-dialer-LookupPartitions", + Topic: topic, Leader: Broker{Host: "localhost", Port: 9092, ID: 1}, Replicas: []Broker{{Host: "localhost", Port: 9092, ID: 1}}, Isr: []Broker{{Host: "localhost", Port: 9092, ID: 1}}, @@ -170,17 +169,15 @@ wE3YmpC3Q0g9r44nEbz4Bw== } func TestDialerTLS(t *testing.T) { - t.Parallel() - - const topic = "test-dialer-LookupPartitions" - - createTopic(t, topic, 1) + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() // Write a message to ensure the partition gets created. 
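The lookupHost rules above (prefer the resolved host, keep the port from the dialed address unless the resolver supplies one, default to 9092) apply to whatever Resolver implementation is plugged into the Dialer. Below is a sketch of a static resolver in the spirit of the mockResolver used by the tests that follow; the hostname mapping is hypothetical.

    package main

    import (
        "context"
        "log"

        kafka "github.com/segmentio/kafka-go"
    )

    // staticResolver maps logical broker hostnames to concrete addresses.
    type staticResolver struct {
        hosts map[string][]string
    }

    func (r *staticResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
        // Returning no addresses makes the dialer fall back to the original
        // address, per lookupHost above.
        return r.hosts[host], nil
    }

    func main() {
        d := &kafka.Dialer{
            Resolver: &staticResolver{hosts: map[string][]string{
                // No port here: the port from the dialed address (or 9092) is kept.
                "kafka.internal.example.com": {"127.0.0.1"},
            }},
        }

        conn, err := d.DialContext(context.Background(), "tcp", "kafka.internal.example.com:9092")
        if err != nil {
            log.Fatal(err)
        }
        defer conn.Close()
    }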
- w := NewWriter(WriterConfig{ - Brokers: []string{"localhost:9092"}, - Topic: topic, - }) + w := &Writer{ + Addr: TCP("localhost:9092"), + Topic: topic, + Transport: client.Transport, + } w.WriteMessages(context.Background(), Message{}) w.Close() @@ -302,3 +299,105 @@ func TestDialerConnectTLSHonorsContext(t *testing.T) { t.FailNow() } } + +func TestDialerResolver(t *testing.T) { + ctx := context.TODO() + + tests := []struct { + scenario string + address string + resolver map[string][]string + }{ + { + scenario: "resolve domain to ip", + address: "example.com", + resolver: map[string][]string{ + "example.com": {"127.0.0.1"}, + }, + }, + { + scenario: "resolve domain to ip and port", + address: "example.com", + resolver: map[string][]string{ + "example.com": {"127.0.0.1:9092"}, + }, + }, + { + scenario: "resolve domain with port to ip", + address: "example.com:9092", + resolver: map[string][]string{ + "example.com": {"127.0.0.1:9092"}, + }, + }, + { + scenario: "resolve domain with port to ip with different port", + address: "example.com:9092", + resolver: map[string][]string{ + "example.com": {"127.0.0.1:80"}, + }, + }, + { + scenario: "resolve domain with port to ip", + address: "example.com:9092", + resolver: map[string][]string{ + "example.com": {"127.0.0.1"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + topic := makeTopic() + createTopic(t, topic, 1) + defer deleteTopic(t, topic) + + d := Dialer{ + Resolver: &mockResolver{addrs: test.resolver}, + } + + // Write a message to ensure the partition gets created. + w := NewWriter(WriterConfig{ + Brokers: []string{"localhost:9092"}, + Topic: topic, + Dialer: &d, + }) + w.WriteMessages(context.Background(), Message{}) + w.Close() + + partitions, err := d.LookupPartitions(ctx, "tcp", test.address, topic) + if err != nil { + t.Error(err) + return + } + + sort.Slice(partitions, func(i int, j int) bool { + return partitions[i].ID < partitions[j].ID + }) + + want := []Partition{ + { + Topic: topic, + Leader: Broker{Host: "localhost", Port: 9092, ID: 1}, + Replicas: []Broker{{Host: "localhost", Port: 9092, ID: 1}}, + Isr: []Broker{{Host: "localhost", Port: 9092, ID: 1}}, + ID: 0, + }, + } + if !reflect.DeepEqual(partitions, want) { + t.Errorf("bad partitions:\ngot: %+v\nwant: %+v", partitions, want) + } + }) + } +} + +type mockResolver struct { + addrs map[string][]string +} + +func (mr *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + if addrs, ok := mr.addrs[host]; !ok { + return nil, fmt.Errorf("unrecognized host %s", host) + } else { + return addrs, nil + } +} diff --git a/docker-compose-241.yml b/docker-compose-241.yml new file mode 100644 index 000000000..6feb1844b --- /dev/null +++ b/docker-compose-241.yml @@ -0,0 +1,32 @@ +version: "3" +services: + kafka: + image: wurstmeister/kafka:2.12-2.4.1 + restart: on-failure:3 + links: + - zookeeper + ports: + - 9092:9092 + - 9093:9093 + environment: + KAFKA_VERSION: '2.4.1' + KAFKA_BROKER_ID: '1' + KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' + KAFKA_DELETE_TOPIC_ENABLE: 'true' + KAFKA_ADVERTISED_HOST_NAME: 'localhost' + KAFKA_ADVERTISED_PORT: '9092' + KAFKA_ZOOKEEPER_CONNECT: 'zookeeper:2181' + KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' + KAFKA_MESSAGE_MAX_BYTES: '200000000' + KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' + KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' + KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN,SCRAM-SHA-256,SCRAM-SHA-512' + KAFKA_OPTS: 
"-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" + CUSTOM_INIT_SCRIPT: |- + echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; + /opt/kafka/bin/kafka-configs.sh --zookeeper zookeeper:2181 --alter --add-config 'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram + + zookeeper: + image: wurstmeister/zookeeper + ports: + - 2181:2181 diff --git a/docker-compose.010.yml b/docker-compose.010.yml new file mode 100644 index 000000000..56123f85c --- /dev/null +++ b/docker-compose.010.yml @@ -0,0 +1,29 @@ +version: "3" +services: + kafka: + image: wurstmeister/kafka:0.10.1.1 + links: + - zookeeper + ports: + - 9092:9092 + - 9093:9093 + environment: + KAFKA_BROKER_ID: '1' + KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' + KAFKA_DELETE_TOPIC_ENABLE: 'true' + KAFKA_ADVERTISED_HOST_NAME: 'localhost' + KAFKA_ADVERTISED_PORT: '9092' + KAFKA_ZOOKEEPER_CONNECT: 'zookeeper:2181' + KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' + KAFKA_MESSAGE_MAX_BYTES: '200000000' + KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' + KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' + KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN' + KAFKA_OPTS: "-Djava.security.auth.login.config=/opt/kafka/config/kafka_server_jaas.conf" + CUSTOM_INIT_SCRIPT: |- + echo -e 'KafkaServer {\norg.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; + + zookeeper: + image: wurstmeister/zookeeper + ports: + - 2181:2181 diff --git a/docker-compose.yml b/docker-compose.yml index cc393f433..72e7b94b1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,23 +1,23 @@ version: "3" services: kafka: - image: wurstmeister/kafka:2.11-0.11.0.3 + image: wurstmeister/kafka:2.12-2.3.1 restart: on-failure:3 links: - - zookeeper + - zookeeper ports: - - "9092:9092" - - "9093:9093" + - 9092:9092 + - 9093:9093 environment: - KAFKA_VERSION: '0.11.0.1' - KAFKA_BROKER_ID: 1 + KAFKA_VERSION: '2.3.1' + KAFKA_BROKER_ID: '1' KAFKA_CREATE_TOPICS: 'test-writer-0:3:1,test-writer-1:3:1' KAFKA_DELETE_TOPIC_ENABLE: 'true' KAFKA_ADVERTISED_HOST_NAME: 'localhost' KAFKA_ADVERTISED_PORT: '9092' KAFKA_ZOOKEEPER_CONNECT: 'zookeeper:2181' KAFKA_AUTO_CREATE_TOPICS_ENABLE: 'true' - KAFKA_MESSAGE_MAX_BYTES: 200000000 + KAFKA_MESSAGE_MAX_BYTES: '200000000' KAFKA_LISTENERS: 'PLAINTEXT://:9092,SASL_PLAINTEXT://:9093' KAFKA_ADVERTISED_LISTENERS: 'PLAINTEXT://localhost:9092,SASL_PLAINTEXT://localhost:9093' KAFKA_SASL_ENABLED_MECHANISMS: 'PLAIN,SCRAM-SHA-256,SCRAM-SHA-512' @@ -25,7 +25,8 @@ services: CUSTOM_INIT_SCRIPT: |- echo -e 'KafkaServer {\norg.apache.kafka.common.security.scram.ScramLoginModule required\n username="adminscram"\n password="admin-secret";\n org.apache.kafka.common.security.plain.PlainLoginModule required\n username="adminplain"\n password="admin-secret"\n user_adminplain="admin-secret";\n };' > /opt/kafka/config/kafka_server_jaas.conf; /opt/kafka/bin/kafka-configs.sh --zookeeper zookeeper:2181 --alter --add-config 
'SCRAM-SHA-256=[password=admin-secret-256],SCRAM-SHA-512=[password=admin-secret-512]' --entity-type users --entity-name adminscram + zookeeper: image: wurstmeister/zookeeper ports: - - "2181:2181" + - 2181:2181 diff --git a/error.go b/error.go index 744535dcd..47267000c 100644 --- a/error.go +++ b/error.go @@ -93,6 +93,12 @@ const ( PreferredLeaderNotAvailable Error = 80 GroupMaxSizeReached Error = 81 FencedInstanceID Error = 82 + EligibleLeadersNotAvailable Error = 83 + ElectionNotNeeded Error = 84 + NoReassignmentInProgress Error = 85 + GroupSubscribedToTopic Error = 86 + InvalidRecord Error = 87 + UnstableOffsetCommit Error = 88 ) // Error satisfies the error interface. @@ -130,9 +136,13 @@ func (e Error) Temporary() bool { FencedLeaderEpoch, UnknownLeaderEpoch, OffsetNotAvailable, - PreferredLeaderNotAvailable: + PreferredLeaderNotAvailable, + EligibleLeadersNotAvailable, + ElectionNotNeeded, + NoReassignmentInProgress, + GroupSubscribedToTopic, + UnstableOffsetCommit: return true - default: return false } @@ -293,6 +303,18 @@ func (e Error) Title() string { return "Unknown Leader Epoch" case UnsupportedCompressionType: return "Unsupported Compression Type" + case EligibleLeadersNotAvailable: + return "Eligible Leader Not Available" + case ElectionNotNeeded: + return "Election Not Needed" + case NoReassignmentInProgress: + return "No Reassignment In Progress" + case GroupSubscribedToTopic: + return "Group Subscribed To Topic" + case InvalidRecord: + return "Invalid Record" + case UnstableOffsetCommit: + return "Unstable Offset Commit" } return "" } @@ -452,6 +474,18 @@ func (e Error) Description() string { return "the leader epoch in the request is newer than the epoch on the broker" case UnsupportedCompressionType: return "the requesting client does not support the compression type of given partition" + case EligibleLeadersNotAvailable: + return "eligible topic partition leaders are not available" + case ElectionNotNeeded: + return "leader election not needed for topic partition" + case NoReassignmentInProgress: + return "no partition reassignment is in progress" + case GroupSubscribedToTopic: + return "deleting offsets of a topic is forbidden while the consumer group is actively subscribed to it" + case InvalidRecord: + return "this record has failed the validation on broker and hence be rejected" + case UnstableOffsetCommit: + return "there are unstable offsets that need to be cleared" } return "" } @@ -498,6 +532,65 @@ type MessageTooLargeError struct { Remaining []Message } +func messageTooLarge(msgs []Message, i int) MessageTooLargeError { + remain := make([]Message, 0, len(msgs)-1) + remain = append(remain, msgs[:i]...) + remain = append(remain, msgs[i+1:]...) + return MessageTooLargeError{ + Message: msgs[i], + Remaining: remain, + } +} + func (e MessageTooLargeError) Error() string { return MessageSizeTooLarge.Error() } + +func makeError(code int16, message string) error { + if code == 0 { + return nil + } + if message == "" { + return Error(code) + } + return fmt.Errorf("%w: %s", Error(code), message) +} + +// WriteError is returned by kafka.(*Writer).WriteMessages when the writer is +// not configured to write messages asynchronously. WriteError values contain +// a list of errors where each entry matches the position of a message in the +// WriteMessages call. 
The program can determine the status of each message by +// looping over the error: +// +// switch err := w.WriteMessages(ctx, msgs...).(type) { +// case nil: +// case kafka.WriteErrors: +// for i := range msgs { +// if err[i] != nil { +// // handle the error writing msgs[i] +// ... +// } +// } +// default: +// // handle other errors +// ... +// } +// +type WriteErrors []error + +// Count counts the number of non-nil errors in err. +func (err WriteErrors) Count() int { + n := 0 + + for _, e := range err { + if e != nil { + n++ + } + } + + return n +} + +func (err WriteErrors) Error() string { + return fmt.Sprintf("kafka write errors (%d/%d)", err.Count(), len(err)) +} diff --git a/error_test.go b/error_test.go index 96c88776f..3d461154c 100644 --- a/error_test.go +++ b/error_test.go @@ -6,8 +6,6 @@ import ( ) func TestError(t *testing.T) { - t.Parallel() - errorCodes := []Error{ Unknown, OffsetOutOfRange, diff --git a/example_writer_test.go b/example_writer_test.go index dfd1c574c..398b553b4 100644 --- a/example_writer_test.go +++ b/example_writer_test.go @@ -7,10 +7,10 @@ import ( ) func ExampleWriter() { - w := kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{"localhost:9092"}, - Topic: "Topic-1", - }) + w := &kafka.Writer{ + Addr: kafka.TCP("localhost:9092"), + Topic: "Topic-1", + } w.WriteMessages(context.Background(), kafka.Message{ diff --git a/examples/producer-api/go.mod b/examples/producer-api/go.mod index 3bf332c76..525c89032 100644 --- a/examples/producer-api/go.mod +++ b/examples/producer-api/go.mod @@ -1,3 +1,5 @@ module github.com/tarekbadrshalaan/GoKafka/kafka-go -require github.com/segmentio/kafka-go v0.2.2 +go 1.15 + +require github.com/segmentio/kafka-go v0.4.5 diff --git a/examples/producer-api/go.sum b/examples/producer-api/go.sum index e798e5613..156312619 100644 --- a/examples/producer-api/go.sum +++ b/examples/producer-api/go.sum @@ -1,2 +1,23 @@ -github.com/segmentio/kafka-go v0.2.2 h1:KIUln5unPisRL2yyAkZsDR/coiymN9Djunv6JKGQ6JI= -github.com/segmentio/kafka-go v0.2.2/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA= +github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/segmentio/kafka-go v0.4.5 h1:vphUaNc3rt77MlGjGfV6AjGq/piP+04wzDLuIiAE9iE= +github.com/segmentio/kafka-go v0.4.5/go.mod h1:Inh7PqOsxmfgasV8InZYKVXWsdjcCq2d9tFV75GLbuM= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0= +github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 
+golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284 h1:rlLehGeYg6jfoyz/eDqDU1iRXLKfR42nnNh57ytKEWo= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/examples/producer-api/main.go b/examples/producer-api/main.go index dca658b19..ba365f6ea 100644 --- a/examples/producer-api/main.go +++ b/examples/producer-api/main.go @@ -30,11 +30,11 @@ func producerHandler(kafkaWriter *kafka.Writer) func(http.ResponseWriter, *http. } func getKafkaWriter(kafkaURL, topic string) *kafka.Writer { - return kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{kafkaURL}, + return &kafka.Writer{ + Addr: kafka.TCP(kafkaURL), Topic: topic, Balancer: &kafka.LeastBytes{}, - }) + } } func main() { diff --git a/examples/producer-random/go.mod b/examples/producer-random/go.mod index cee2f1d18..9e78ce30e 100644 --- a/examples/producer-random/go.mod +++ b/examples/producer-random/go.mod @@ -1,6 +1,8 @@ module github.com/tarekbadrshalaan/GoKafka/kafka-go +go 1.15 + require ( github.com/google/uuid v1.1.0 - github.com/segmentio/kafka-go v0.2.2 + github.com/segmentio/kafka-go v0.4.5 ) diff --git a/examples/producer-random/go.sum b/examples/producer-random/go.sum index b8d21fa3b..9270c4e0e 100644 --- a/examples/producer-random/go.sum +++ b/examples/producer-random/go.sum @@ -1,4 +1,25 @@ +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/uuid v1.1.0 h1:Jf4mxPC/ziBnoPIdpQdPJ9OeiomAUHLvxmPRSPH9m4s= github.com/google/uuid v1.1.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/segmentio/kafka-go v0.2.2 h1:KIUln5unPisRL2yyAkZsDR/coiymN9Djunv6JKGQ6JI= -github.com/segmentio/kafka-go v0.2.2/go.mod h1:X6itGqS9L4jDletMsxZ7Dz+JFWxM6JHfPOCvTvk+EJo= +github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA= +github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/segmentio/kafka-go v0.4.5 h1:vphUaNc3rt77MlGjGfV6AjGq/piP+04wzDLuIiAE9iE= +github.com/segmentio/kafka-go v0.4.5/go.mod h1:Inh7PqOsxmfgasV8InZYKVXWsdjcCq2d9tFV75GLbuM= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0= 
+github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284 h1:rlLehGeYg6jfoyz/eDqDU1iRXLKfR42nnNh57ytKEWo= +golang.org/x/crypto v0.0.0-20190506204251-e1dfcc566284/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/examples/producer-random/main.go b/examples/producer-random/main.go index 8d68381e2..7b14e3aee 100644 --- a/examples/producer-random/main.go +++ b/examples/producer-random/main.go @@ -11,11 +11,11 @@ import ( ) func newKafkaWriter(kafkaURL, topic string) *kafka.Writer { - return kafka.NewWriter(kafka.WriterConfig{ - Brokers: []string{kafkaURL}, + return &kafka.Writer{ + Addr: kafka.TCP(kafkaURL), Topic: topic, Balancer: &kafka.LeastBytes{}, - }) + } } func main() { diff --git a/export_test.go b/export_test.go deleted file mode 100644 index b2ae1b537..000000000 --- a/export_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package kafka - -import "testing" - -func CreateTopic(t *testing.T, partitions int) string { - topic := makeTopic() - createTopic(t, topic, partitions) - return topic -} diff --git a/fetch.go b/fetch.go index b742fef20..e4775b658 100644 --- a/fetch.go +++ b/fetch.go @@ -1,5 +1,158 @@ package kafka +import ( + "context" + "fmt" + "math" + "net" + "time" + + "github.com/segmentio/kafka-go/protocol" + fetchAPI "github.com/segmentio/kafka-go/protocol/fetch" +) + +// FetchRequest represents a request sent to a kafka broker to retrieve records +// from a topic partition. +type FetchRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // Topic, partition, and offset to retrieve records from. + Topic string + Partition int + Offset int64 + + // Size and time limits of the response returned by the broker. + MinBytes int64 + MaxBytes int64 + MaxWait time.Duration + + // The isolation level for the request. + // + // Defaults to ReadUncommitted. + // + // This field requires the kafka broker to support the Fetch API in version + // 4 or above (otherwise the value is ignored). + IsolationLevel IsolationLevel +} + +// FetchResponse represents a response from a kafka broker to a fetch request. +type FetchResponse struct { + // The amount of time that the broker throttled the request. + Throttle time.Duration + + // The topic and partition that the response came for (will match the values + // in the request). + Topic string + Partition int + + // Informations about the topic partition layout returned from the broker. + // + // LastStableOffset requires the kafka broker to support the Fetch API in + // version 4 or above (otherwise the value is zero). + // + /// LogStartOffset requires the kafka broker to support the Fetch API in + // version 5 or above (otherwise the value is zero). 
+ HighWatermark int64 + LastStableOffset int64 + LogStartOffset int64 + + // An error that may have occurred while attempting to fetch the records. + // + // The error contains both the kafka error code, and an error message + // returned by the kafka broker. Programs may use the standard errors.Is + // function to test the error against kafka error codes. + Error error + + // The set of records returned in the response. + // + // The program is expected to call the RecordSet's Close method when it + // finished reading the records. + // + // Note that kafka may return record batches that start at an offset before + // the one that was requested. It is the program's responsibility to skip + // the offsets that it is not interested in. + Records RecordReader +} + +// Fetch sends a fetch request to a kafka broker and returns the response. +// +// If the broker returned an invalid response with no topics, an error wrapping +// protocol.ErrNoTopic is returned. +// +// If the broker returned an invalid response with no partitions, an error +// wrapping ErrNoPartitions is returned. +func (c *Client) Fetch(ctx context.Context, req *FetchRequest) (*FetchResponse, error) { + timeout := c.timeout(ctx, math.MaxInt64) + maxWait := req.maxWait() + + if maxWait < timeout { + timeout = maxWait + } + + m, err := c.roundTrip(ctx, req.Addr, &fetchAPI.Request{ + ReplicaID: -1, + MaxWaitTime: milliseconds(timeout), + MinBytes: int32(req.MinBytes), + MaxBytes: int32(req.MaxBytes), + IsolationLevel: int8(req.IsolationLevel), + SessionID: -1, + SessionEpoch: -1, + Topics: []fetchAPI.RequestTopic{{ + Topic: req.Topic, + Partitions: []fetchAPI.RequestPartition{{ + Partition: int32(req.Partition), + CurrentLeaderEpoch: -1, + FetchOffset: req.Offset, + LogStartOffset: -1, + PartitionMaxBytes: int32(req.MaxBytes), + }}, + }}, + }) + + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).Fetch: %w", err) + } + + res := m.(*fetchAPI.Response) + if len(res.Topics) == 0 { + return nil, fmt.Errorf("kafka.(*Client).Fetch: %w", protocol.ErrNoTopic) + } + topic := &res.Topics[0] + if len(topic.Partitions) == 0 { + return nil, fmt.Errorf("kafka.(*Client).Fetch: %w", protocol.ErrNoPartition) + } + partition := &topic.Partitions[0] + + ret := &FetchResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Topic: topic.Topic, + Partition: int(partition.Partition), + Error: makeError(res.ErrorCode, ""), + HighWatermark: partition.HighWatermark, + LastStableOffset: partition.LastStableOffset, + LogStartOffset: partition.LogStartOffset, + Records: partition.RecordSet.Records, + } + + if partition.ErrorCode != 0 { + ret.Error = makeError(partition.ErrorCode, "") + } + + if ret.Records == nil { + ret.Records = NewRecordReader() + } + + return ret, nil +} + +func (req *FetchRequest) maxWait() time.Duration { + if req.MaxWait > 0 { + return req.MaxWait + } + return defaultMaxWait +} + type fetchRequestV2 struct { ReplicaID int32 MaxWaitTime int32 diff --git a/fetch_test.go b/fetch_test.go new file mode 100644 index 000000000..950e82278 --- /dev/null +++ b/fetch_test.go @@ -0,0 +1,302 @@ +package kafka + +import ( + "context" + "errors" + "io" + "io/ioutil" + "net" + "reflect" + "testing" + "time" + + "github.com/segmentio/kafka-go/compress" +) + +func copyRecords(records []Record) []Record { + newRecords := make([]Record, len(records)) + + for i := range records { + k, _ := ReadAll(records[i].Key) + v, _ := ReadAll(records[i].Value) + + records[i].Key = NewBytes(k) + records[i].Value = NewBytes(v) + + newRecords[i].Key = 
NewBytes(k) + newRecords[i].Value = NewBytes(v) + } + + return newRecords +} + +func produceRecords(t *testing.T, n int, addr net.Addr, topic string, compression compress.Codec) []Record { + conn, err := (&Dialer{ + Resolver: &net.Resolver{}, + }).DialLeader(context.Background(), addr.Network(), addr.String(), topic, 0) + + if err != nil { + t.Fatal("failed to open a new kafka connection:", err) + } + defer conn.Close() + + msgs := makeTestSequence(n) + if compression == nil { + _, err = conn.WriteMessages(msgs...) + } else { + _, err = conn.WriteCompressedMessages(compression, msgs...) + } + if err != nil { + t.Fatal(err) + } + + records := make([]Record, len(msgs)) + for offset, msg := range msgs { + records[offset] = Record{ + Offset: int64(offset), + Key: NewBytes(msg.Key), + Value: NewBytes(msg.Value), + Headers: msg.Headers, + } + } + + return records +} + +func TestClientFetch(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + records := produceRecords(t, 10, client.Addr, topic, nil) + + res, err := client.Fetch(context.Background(), &FetchRequest{ + Topic: topic, + Partition: 0, + Offset: 0, + MinBytes: 1, + MaxBytes: 64 * 1024, + MaxWait: 100 * time.Millisecond, + }) + + if err != nil { + t.Fatal(err) + } + + assertFetchResponse(t, res, &FetchResponse{ + Topic: topic, + Partition: 0, + HighWatermark: 10, + Records: NewRecordReader(records...), + }) +} + +func TestClientFetchCompressed(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + records := produceRecords(t, 10, client.Addr, topic, &compress.GzipCodec) + + res, err := client.Fetch(context.Background(), &FetchRequest{ + Topic: topic, + Partition: 0, + Offset: 0, + MinBytes: 1, + MaxBytes: 64 * 1024, + MaxWait: 100 * time.Millisecond, + }) + + if err != nil { + t.Fatal(err) + } + + assertFetchResponse(t, res, &FetchResponse{ + Topic: topic, + Partition: 0, + HighWatermark: 10, + Records: NewRecordReader(records...), + }) +} + +func assertFetchResponse(t *testing.T, found, expected *FetchResponse) { + t.Helper() + + if found.Topic != expected.Topic { + t.Error("invalid topic found in response:", found.Topic) + } + + if found.Partition != expected.Partition { + t.Error("invalid partition found in response:", found.Partition) + } + + if found.HighWatermark != expected.HighWatermark { + t.Error("invalid high watermark found in response:", found.HighWatermark) + } + + if found.Error != nil { + t.Error("unexpected error found in response:", found.Error) + } + + records1, err := readRecords(found.Records) + if err != nil { + t.Error("error reading records:", err) + } + + records2, err := readRecords(expected.Records) + if err != nil { + t.Error("error reading records:", err) + } + + assertRecords(t, records1, records2) +} + +type memoryRecord struct { + offset int64 + key []byte + value []byte + headers []Header +} + +func assertRecords(t *testing.T, found, expected []memoryRecord) { + t.Helper() + i := 0 + + for i < len(found) && i < len(expected) { + r1 := found[i] + r2 := expected[i] + + if !reflect.DeepEqual(r1, r2) { + t.Errorf("records at index %d don't match", i) + t.Logf("expected:\n%#v", r2) + t.Logf("found:\n%#v", r1) + } + + i++ + } + + for i < len(found) { + t.Errorf("unexpected record at index %d:\n%+v", i, found[i]) + i++ + } + + for i < len(expected) { + t.Errorf("missing record at index %d:\n%+v", i, expected[i]) + i++ + } +} + +func readRecords(records RecordReader) ([]memoryRecord, error) { + list := []memoryRecord{} + + for { + rec, 
err := records.ReadRecord() + + if err != nil { + if errors.Is(err, io.EOF) { + return list, nil + } + return nil, err + } + + var ( + offset = rec.Offset + key = rec.Key + value = rec.Value + headers = rec.Headers + bytesKey []byte + bytesValues []byte + ) + + if key != nil { + bytesKey, _ = ioutil.ReadAll(key) + } + + if value != nil { + bytesValues, _ = ioutil.ReadAll(value) + } + + list = append(list, memoryRecord{ + offset: offset, + key: bytesKey, + value: bytesValues, + headers: headers, + }) + } +} + +func TestClientPipeline(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + const numBatches = 100 + const recordsPerBatch = 30 + + unixEpoch := time.Unix(0, 0) + records := make([]Record, recordsPerBatch) + content := []byte("1234567890") + + for i := 0; i < numBatches; i++ { + for j := range records { + records[j] = Record{Value: NewBytes(content)} + } + + _, err := client.Produce(context.Background(), &ProduceRequest{ + Topic: topic, + RequiredAcks: -1, + Records: NewRecordReader(records...), + Compression: Snappy, + }) + if err != nil { + t.Fatal(err) + } + } + + offset := int64(0) + + for i := 0; i < (numBatches * recordsPerBatch); { + req := &FetchRequest{ + Topic: topic, + Offset: offset, + MinBytes: 1, + MaxBytes: 8192, + MaxWait: 500 * time.Millisecond, + } + + res, err := client.Fetch(context.Background(), req) + if err != nil { + t.Fatal(err) + } + + if res.Error != nil { + t.Fatal(res.Error) + } + + for { + r, err := res.Records.ReadRecord() + if err != nil { + if err == io.EOF { + break + } + t.Fatal(err) + } + + if r.Key != nil { + r.Key.Close() + } + + if r.Value != nil { + r.Value.Close() + } + + if r.Offset != offset { + t.Errorf("record at index %d has mismatching offset, want %d but got %d", i, offset, r.Offset) + } + + if r.Time.IsZero() || r.Time.Equal(unixEpoch) { + t.Errorf("record at index %d with offset %d has not timestamp", i, r.Offset) + } + + offset = r.Offset + 1 + i++ + } + } +} diff --git a/go.mod b/go.mod index 738546806..0d12f0c31 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/segmentio/kafka-go -go 1.11 +go 1.13 require ( github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 diff --git a/gzip/gzip.go b/gzip/gzip.go index 5743d963e..2ad84b500 100644 --- a/gzip/gzip.go +++ b/gzip/gzip.go @@ -1,131 +1,25 @@ +// Package gzip does nothing, it's kept for backward compatibility to avoid +// breaking the majority of programs that imported it to install the compression +// codec, which is now always included. package gzip import ( - "bytes" - "compress/gzip" - "io" - "io/ioutil" - "sync" + gz "compress/gzip" - kafka "github.com/segmentio/kafka-go" + "github.com/segmentio/kafka-go/compress/gzip" ) -var ( - // emptyGzipBytes is the binary value for an empty file that has been - // gzipped. It is used to initialize gzip.Reader before adding it to the - // readerPool. - emptyGzipBytes = [...]byte{ - 0x1f, 0x8b, 0x08, 0x08, 0x0d, 0x0c, 0x67, 0x5c, 0x00, 0x03, 0x66, 0x6f, - 0x6f, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - } - - readerPool = sync.Pool{ - New: func() interface{} { - // if the reader doesn't get valid gzip at initialization time, - // it will not be valid and will fail on Reset. 
- reader := &gzipReader{} - reader.Reset(nil) - return reader - }, - } -) - -type gzipReader struct { - gzip.Reader - emptyGzipFile bytes.Reader -} - -func (z *gzipReader) Reset(r io.Reader) { - if r == nil { - z.emptyGzipFile.Reset(emptyGzipBytes[:]) - r = &z.emptyGzipFile - } - z.Reader.Reset(r) -} - -func init() { - kafka.RegisterCompressionCodec(NewCompressionCodec()) -} - const ( - Code = 1 - - DefaultCompressionLevel = gzip.DefaultCompression + Code = 1 + DefaultCompressionLevel = gz.DefaultCompression ) -type CompressionCodec struct{ writerPool sync.Pool } +type CompressionCodec = gzip.Codec func NewCompressionCodec() *CompressionCodec { return NewCompressionCodecLevel(DefaultCompressionLevel) } func NewCompressionCodecLevel(level int) *CompressionCodec { - return &CompressionCodec{ - writerPool: sync.Pool{ - New: func() interface{} { - w, err := gzip.NewWriterLevel(ioutil.Discard, level) - if err != nil { - return err - } - return w - }, - }, - } -} - -// Code implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) Code() int8 { return Code } - -// Name implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) Name() string { return "gzip" } - -// NewReader implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser { - z := readerPool.Get().(*gzipReader) - z.Reset(r) - return &reader{z} -} - -// NewWriter implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { - x := c.writerPool.Get() - z, _ := x.(*gzip.Writer) - if z == nil { - return errorWriter{err: x.(error)} - } - z.Reset(w) - return &writer{c, z} + return &CompressionCodec{Level: level} } - -type reader struct{ *gzipReader } - -func (r *reader) Close() (err error) { - if z := r.gzipReader; z != nil { - r.gzipReader = nil - err = z.Close() - z.Reset(nil) - readerPool.Put(z) - } - return -} - -type writer struct { - c *CompressionCodec - *gzip.Writer -} - -func (w *writer) Close() (err error) { - if z := w.Writer; z != nil { - w.Writer = nil - err = z.Close() - z.Reset(nil) - w.c.writerPool.Put(z) - } - return -} - -type errorWriter struct{ err error } - -func (w errorWriter) Close() error { return w.err } - -func (w errorWriter) Write(b []byte) (int, error) { return 0, w.err } diff --git a/kafka.go b/kafka.go new file mode 100644 index 000000000..06eddb878 --- /dev/null +++ b/kafka.go @@ -0,0 +1,92 @@ +package kafka + +import "github.com/segmentio/kafka-go/protocol" + +// Broker represents a kafka broker in a kafka cluster. +type Broker struct { + Host string + Port int + ID int + Rack string +} + +// Topic represents a topic in a kafka cluster. +type Topic struct { + // Name of the topic. + Name string + + // True if the topic is internal. + Internal bool + + // The list of partition currently available on this topic. + Partitions []Partition + + // An error that may have occurred while attempting to read the topic + // metadata. + // + // The error contains both the kafka error code, and an error message + // returned by the kafka broker. Programs may use the standard errors.Is + // function to test the error against kafka error codes. + Error error +} + +// Partition carries the metadata associated with a kafka partition. +type Partition struct { + // Name of the topic that the partition belongs to, and its index in the + // topic. + Topic string + ID int + + // Leader, replicas, and ISR for the partition. 
+ Leader Broker + Replicas []Broker + Isr []Broker + + // An error that may have occurred while attempting to read the partition + // metadata. + // + // The error contains both the kafka error code, and an error message + // returned by the kafka broker. Programs may use the standard errors.Is + // function to test the error against kafka error codes. + Error error +} + +// Marshal encodes v into a binary representation of the value in the kafka data +// format. +// +// If v is a, or contains struct types, the kafka struct fields are interpreted +// and may contain one of these values: +// +// nullable valid on bytes and strings, encodes as a nullable value +// compact valid on strings, encodes as a compact string +// +// The kafka struct tags should not contain min and max versions. If you need to +// encode types based on specific versions of kafka APIs, use the Version type +// instead. +func Marshal(v interface{}) ([]byte, error) { + return protocol.Marshal(-1, v) +} + +// Unmarshal decodes a binary representation from b into v. +// +// See Marshal for details. +func Unmarshal(b []byte, v interface{}) error { + return protocol.Unmarshal(b, -1, v) +} + +// Version represents a version number for kafka APIs. +type Version int16 + +// Marshal is like the top-level Marshal function, but will only encode struct +// fields for which n falls within the min and max versions specified on the +// struct tag. +func (n Version) Marshal(v interface{}) ([]byte, error) { + return protocol.Marshal(int16(n), v) +} + +// Unmarshal is like the top-level Unmarshal function, but will only decode +// struct fields for which n falls within the min and max versions specified on +// the struct tag. +func (n Version) Unmarshal(b []byte, v interface{}) error { + return protocol.Unmarshal(b, int16(n), v) +} diff --git a/kafka_test.go b/kafka_test.go new file mode 100644 index 000000000..f66855cd7 --- /dev/null +++ b/kafka_test.go @@ -0,0 +1,173 @@ +package kafka + +import ( + "fmt" + "math" + "reflect" + "strconv" + "testing" +) + +func TestMarshalUnmarshal(t *testing.T) { + values := []interface{}{ + true, + false, + + int8(0), + int8(1), + int8(math.MinInt8), + int8(math.MaxInt8), + + int16(0), + int16(1), + int16(math.MinInt16), + int16(math.MaxInt16), + + int32(0), + int32(1), + int32(math.MinInt32), + int32(math.MaxInt32), + + int64(0), + int64(1), + int64(math.MinInt64), + int64(math.MaxInt64), + + "", + "hello world!", + + ([]byte)(nil), + []byte(""), + []byte("hello world!"), + + ([]int32)(nil), + []int32{}, + []int32{0, 1, 2, 3, 4}, + + struct{}{}, + struct { + A int32 + B string + C []byte + }{A: 1, B: "42", C: []byte{}}, + } + + for _, v := range values { + t.Run(fmt.Sprintf("%+v", v), func(t *testing.T) { + b, err := Marshal(v) + if err != nil { + t.Fatal("marshal error:", err) + } + + x := reflect.New(reflect.TypeOf(v)) + + if err := Unmarshal(b, x.Interface()); err != nil { + t.Fatal("unmarshal error:", err) + } + + if !reflect.DeepEqual(v, x.Elem().Interface()) { + t.Fatalf("values mismatch:\nexpected: %#v\nfound: %#v\n", v, x.Elem().Interface()) + } + }) + } +} + +func TestVersionMarshalUnmarshal(t *testing.T) { + type T struct { + A int32 `kafka:"min=v0,max=v1"` + B string `kafka:"min=v1,max=v2"` + C []byte `kafka:"min=v2,max=v2,nullable"` + } + + tests := []struct { + out T + ver Version + }{ + { + out: T{A: 42}, + ver: Version(0), + }, + } + + in := T{ + A: 42, + B: "Hello World!", + C: []byte("question?"), + } + + for _, test := range tests { + t.Run(strconv.Itoa(int(test.ver)), func(t 
*testing.T) { + b, err := test.ver.Marshal(in) + if err != nil { + t.Fatal("marshal error:", err) + } + + x1 := test.out + x2 := T{} + + if err := test.ver.Unmarshal(b, &x2); err != nil { + t.Fatal("unmarshal error:", err) + } + + if !reflect.DeepEqual(x1, x2) { + t.Fatalf("values mismatch:\nexpected: %#v\nfound: %#v\n", x1, x2) + } + }) + } + +} + +type Struct struct { + A int32 + B int32 + C int32 +} + +var benchmarkValues = []interface{}{ + true, + int8(1), + int16(1), + int32(1), + int64(1), + "Hello World!", + []byte("Hello World!"), + []int32{1, 2, 3}, + Struct{A: 1, B: 2, C: 3}, +} + +func BenchmarkMarshal(b *testing.B) { + for _, v := range benchmarkValues { + b.Run(fmt.Sprintf("%T", v), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := Marshal(v) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkUnmarshal(b *testing.B) { + for _, v := range benchmarkValues { + b.Run(fmt.Sprintf("%T", v), func(b *testing.B) { + data, err := Marshal(v) + + if err != nil { + b.Fatal(err) + } + + value := reflect.New(reflect.TypeOf(v)) + ptr := value.Interface() + elem := value.Elem() + zero := reflect.Zero(reflect.TypeOf(v)) + + for i := 0; i < b.N; i++ { + if err := Unmarshal(data, ptr); err != nil { + b.Fatal(err) + } + elem.Set(zero) + } + }) + } +} diff --git a/listoffset.go b/listoffset.go index 7988b4c87..11c5d04b4 100644 --- a/listoffset.go +++ b/listoffset.go @@ -1,6 +1,187 @@ package kafka -import "bufio" +import ( + "bufio" + "context" + "fmt" + "net" + "time" + + "github.com/segmentio/kafka-go/protocol/listoffsets" +) + +// OffsetRequest represents a request to retrieve a single partition offset. +type OffsetRequest struct { + Partition int + Timestamp int64 +} + +// FirstOffsetOf constructs an OffsetRequest which asks for the first offset of +// the parition given as argument. +func FirstOffsetOf(partition int) OffsetRequest { + return OffsetRequest{Partition: partition, Timestamp: FirstOffset} +} + +// LastOffsetOf constructs an OffsetRequest which asks for the last offset of +// the partition given as argument. +func LastOffsetOf(partition int) OffsetRequest { + return OffsetRequest{Partition: partition, Timestamp: LastOffset} +} + +// TimeOffsetOf constructs an OffsetRequest which asks for a partition offset +// at a given time. +func TimeOffsetOf(partition int, at time.Time) OffsetRequest { + return OffsetRequest{Partition: partition, Timestamp: timestamp(at)} +} + +// PartitionOffsets carries information about offsets available in a topic +// partition. +type PartitionOffsets struct { + Partition int + FirstOffset int64 + LastOffset int64 + Offsets map[int64]time.Time + Error error +} + +// ListOffsetsRequest represents a request sent to a kafka broker to list of the +// offsets of topic partitions. +type ListOffsetsRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // A mapping of topic names to list of partitions that the program wishes to + // get the offsets for. + Topics map[string][]OffsetRequest + + // The isolation level for the request. + // + // Defaults to ReadUncommitted. + // + // This field requires the kafka broker to support the ListOffsets API in + // version 2 or above (otherwise the value is ignored). + IsolationLevel IsolationLevel +} + +// ListOffsetsResponse represents a response from a kafka broker to a offset +// listing request. +type ListOffsetsResponse struct { + // The amount of time that the broker throttled the request. 
+ Throttle time.Duration + + // Mappings of topics names to partition offsets, there will be one entry + // for each topic in the request. + Topics map[string][]PartitionOffsets +} + +// ListOffsets sends an offset request to a kafka broker and returns the +// response. +func (c *Client) ListOffsets(ctx context.Context, req *ListOffsetsRequest) (*ListOffsetsResponse, error) { + type topicPartition struct { + topic string + partition int + } + + partitionOffsets := make(map[topicPartition]PartitionOffsets) + + for topicName, requests := range req.Topics { + for _, r := range requests { + key := topicPartition{ + topic: topicName, + partition: r.Partition, + } + + partition, ok := partitionOffsets[key] + if !ok { + partition = PartitionOffsets{ + Partition: r.Partition, + FirstOffset: -1, + LastOffset: -1, + Offsets: make(map[int64]time.Time), + } + } + + switch r.Timestamp { + case FirstOffset: + partition.FirstOffset = 0 + case LastOffset: + partition.LastOffset = 0 + } + + partitionOffsets[topicPartition{ + topic: topicName, + partition: r.Partition, + }] = partition + } + } + + topics := make([]listoffsets.RequestTopic, 0, len(req.Topics)) + + for topicName, requests := range req.Topics { + partitions := make([]listoffsets.RequestPartition, len(requests)) + + for i, r := range requests { + partitions[i] = listoffsets.RequestPartition{ + Partition: int32(r.Partition), + CurrentLeaderEpoch: -1, + Timestamp: r.Timestamp, + } + } + + topics = append(topics, listoffsets.RequestTopic{ + Topic: topicName, + Partitions: partitions, + }) + } + + m, err := c.roundTrip(ctx, req.Addr, &listoffsets.Request{ + ReplicaID: -1, + IsolationLevel: int8(req.IsolationLevel), + Topics: topics, + }) + + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).ListOffsets: %w", err) + } + + res := m.(*listoffsets.Response) + ret := &ListOffsetsResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Topics: make(map[string][]PartitionOffsets, len(res.Topics)), + } + + for _, t := range res.Topics { + for _, p := range t.Partitions { + key := topicPartition{ + topic: t.Topic, + partition: int(p.Partition), + } + + partition := partitionOffsets[key] + + switch p.Timestamp { + case FirstOffset: + partition.FirstOffset = p.Offset + case LastOffset: + partition.LastOffset = p.Offset + default: + partition.Offsets[p.Offset] = makeTime(p.Timestamp) + } + + if p.ErrorCode != 0 { + partition.Error = Error(p.ErrorCode) + } + + partitionOffsets[key] = partition + } + } + + for key, partition := range partitionOffsets { + ret.Topics[key.topic] = append(ret.Topics[key.topic], partition) + } + + return ret, nil +} type listOffsetRequestV1 struct { ReplicaID int32 diff --git a/listoffset_test.go b/listoffset_test.go new file mode 100644 index 000000000..7605b7944 --- /dev/null +++ b/listoffset_test.go @@ -0,0 +1,76 @@ +package kafka + +import ( + "context" + "testing" + "time" +) + +func TestClientListOffsets(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + now := time.Now() + + _, err := client.Produce(context.Background(), &ProduceRequest{ + Topic: topic, + Partition: 0, + RequiredAcks: -1, + Records: NewRecordReader( + Record{Time: now, Value: NewBytes([]byte(`hello-1`))}, + Record{Time: now, Value: NewBytes([]byte(`hello-2`))}, + Record{Time: now, Value: NewBytes([]byte(`hello-3`))}, + ), + }) + + if err != nil { + t.Fatal(err) + } + + res, err := client.ListOffsets(context.Background(), &ListOffsetsRequest{ + Topics: map[string][]OffsetRequest{ + topic: {FirstOffsetOf(0), 
LastOffsetOf(0)}, + }, + }) + + if err != nil { + t.Fatal(err) + } + + if len(res.Topics) != 1 { + t.Fatal("invalid number of topics found in list offsets response:", len(res.Topics)) + } + + partitions, ok := res.Topics[topic] + if !ok { + t.Fatal("missing topic in the list offsets response:", topic) + } + if len(partitions) != 1 { + t.Fatal("invalid number of partitions found in list offsets response:", len(partitions)) + } + partition := partitions[0] + + if partition.Partition != 0 { + t.Error("invalid partition id found in list offsets response:", partition.Partition) + } + + if partition.FirstOffset != 0 { + t.Error("invalid first offset found in list offsets response:", partition.FirstOffset) + } + + if partition.LastOffset != 3 { + t.Error("invalid last offset found in list offsets response:", partition.LastOffset) + } + + if firstOffsetTime := partition.Offsets[partition.FirstOffset]; !firstOffsetTime.IsZero() { + t.Error("unexpected first offset time in list offsets response:", partition.Offsets) + } + + if lastOffsetTime := partition.Offsets[partition.LastOffset]; !lastOffsetTime.IsZero() { + t.Error("unexpected last offset time in list offsets response:", partition.Offsets) + } + + if partition.Error != nil { + t.Error("unexpected error in list offsets response:", partition.Error) + } +} diff --git a/lz4/lz4.go b/lz4/lz4.go index 140c7f9cb..53a47677f 100644 --- a/lz4/lz4.go +++ b/lz4/lz4.go @@ -1,74 +1,16 @@ +// Package lz4 does nothing, it's kept for backward compatibility to avoid +// breaking the majority of programs that imported it to install the compression +// codec, which is now always included. package lz4 -import ( - "io" - "sync" - - "github.com/pierrec/lz4" - kafka "github.com/segmentio/kafka-go" -) - -func init() { - kafka.RegisterCompressionCodec(NewCompressionCodec()) -} +import "github.com/segmentio/kafka-go/compress/lz4" const ( Code = 3 ) -type CompressionCodec struct{} +type CompressionCodec = lz4.Codec func NewCompressionCodec() *CompressionCodec { return &CompressionCodec{} } - -// Code implements the kafka.CompressionCodec interface. -func (CompressionCodec) Code() int8 { return Code } - -// Name implements the kafka.CompressionCodec interface. -func (CompressionCodec) Name() string { return "lz4" } - -// NewReader implements the kafka.CompressionCodec interface. -func (CompressionCodec) NewReader(r io.Reader) io.ReadCloser { - z := readerPool.Get().(*lz4.Reader) - z.Reset(r) - return &reader{z} -} - -// NewWriter implements the kafka.CompressionCodec interface. -func (CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { - z := writerPool.Get().(*lz4.Writer) - z.Reset(w) - return &writer{z} -} - -type reader struct{ *lz4.Reader } - -func (r *reader) Close() (err error) { - if z := r.Reader; z != nil { - r.Reader = nil - z.Reset(nil) - readerPool.Put(z) - } - return -} - -type writer struct{ *lz4.Writer } - -func (w *writer) Close() (err error) { - if z := w.Writer; z != nil { - w.Writer = nil - err = z.Close() - z.Reset(nil) - writerPool.Put(z) - } - return -} - -var readerPool = sync.Pool{ - New: func() interface{} { return lz4.NewReader(nil) }, -} - -var writerPool = sync.Pool{ - New: func() interface{} { return lz4.NewWriter(nil) }, -} diff --git a/message.go b/message.go index 32612ecce..dee59b783 100644 --- a/message.go +++ b/message.go @@ -10,10 +10,13 @@ import ( // Message is a data structure representing kafka messages. 
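+//
+// A minimal sketch of building a message for a Writer (topic name and payload
+// are illustrative); when the Writer has no Topic configured, the message can
+// carry its own:
+//
+//	msg := Message{
+//		Topic: "my-topic",
+//		Key:   []byte("key-A"),
+//		Value: []byte("hello"),
+//	}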
type Message struct { - // Topic is reads only and MUST NOT be set when writing messages + // Topic indicates which topic this message was consumed from via Reader. + // + // When being used with Writer, this can be used to configured the topic if + // not already specified on the writer itself. Topic string - // Partition is reads only and MUST NOT be set when writing messages + // Partition is read-only and MUST NOT be set when writing messages Partition int Offset int64 Key []byte @@ -40,7 +43,7 @@ func (msg Message) message(cw *crc32Writer) message { const timestampSize = 8 -func (msg Message) size() int32 { +func (msg *Message) size() int32 { return 4 + 1 + 1 + sizeofBytes(msg.Key) + sizeofBytes(msg.Value) + timestampSize } @@ -353,11 +356,6 @@ func extractOffset(base int64, msgSet []byte) (offset int64, err error) { return } -type Header struct { - Key string - Value []byte -} - type messageSetHeaderV2 struct { firstOffset int64 length int32 diff --git a/metadata.go b/metadata.go index d524b9fd8..4b1309f85 100644 --- a/metadata.go +++ b/metadata.go @@ -1,5 +1,112 @@ package kafka +import ( + "context" + "fmt" + "net" + "time" + + metadataAPI "github.com/segmentio/kafka-go/protocol/metadata" +) + +// MetadataRequest represents a request sent to a kafka broker to retrieve its +// cluster metadata. +type MetadataRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // The list of topics to retrieve metadata for. + Topics []string +} + +// MetadatResponse represents a response from a kafka broker to a metadata +// request. +type MetadataResponse struct { + // The amount of time that the broker throttled the request. + Throttle time.Duration + + // Name of the kafka cluster that client retrieved metadata from. + ClusterID string + + // The broker which is currently the controller for the cluster. + Controller Broker + + // The list of brokers registered to the cluster. + Brokers []Broker + + // The list of topics available on the cluster. + Topics []Topic +} + +// Metadata sends a metadata request to a kafka broker and returns the response. 
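+//
+// A minimal usage sketch; the broker address and topic name are illustrative:
+//
+//	client := &Client{Addr: TCP("localhost:9092")}
+//	res, err := client.Metadata(context.Background(), &MetadataRequest{
+//		Topics: []string{"my-topic"},
+//	})
+//	if err != nil {
+//		// handle the request error
+//	}
+//	for _, t := range res.Topics {
+//		// inspect t.Partitions and t.Error
+//	}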
+func (c *Client) Metadata(ctx context.Context, req *MetadataRequest) (*MetadataResponse, error) { + m, err := c.roundTrip(ctx, req.Addr, &metadataAPI.Request{ + TopicNames: req.Topics, + }) + + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).Metadata: %w", err) + } + + res := m.(*metadataAPI.Response) + ret := &MetadataResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Brokers: make([]Broker, len(res.Brokers)), + Topics: make([]Topic, len(res.Topics)), + ClusterID: res.ClusterID, + } + + brokers := make(map[int32]Broker, len(res.Brokers)) + + for i, b := range res.Brokers { + broker := Broker{ + Host: b.Host, + Port: int(b.Port), + ID: int(b.NodeID), + Rack: b.Rack, + } + + ret.Brokers[i] = broker + brokers[b.NodeID] = broker + + if b.NodeID == res.ControllerID { + ret.Controller = broker + } + } + + for i, t := range res.Topics { + ret.Topics[i] = Topic{ + Name: t.Name, + Internal: t.IsInternal, + Partitions: make([]Partition, len(t.Partitions)), + Error: makeError(t.ErrorCode, ""), + } + + for j, p := range t.Partitions { + partition := Partition{ + Topic: t.Name, + ID: int(p.PartitionIndex), + Leader: brokers[p.LeaderID], + Replicas: make([]Broker, len(p.ReplicaNodes)), + Isr: make([]Broker, len(p.IsrNodes)), + Error: makeError(p.ErrorCode, ""), + } + + for i, id := range p.ReplicaNodes { + partition.Replicas[i] = brokers[id] + } + + for i, id := range p.IsrNodes { + partition.Isr[i] = brokers[id] + } + + ret.Topics[i].Partitions[j] = partition + } + } + + return ret, nil +} + type topicMetadataRequestV1 []string func (r topicMetadataRequestV1) size() int32 { diff --git a/metadata_test.go b/metadata_test.go new file mode 100644 index 000000000..35d8f2ff8 --- /dev/null +++ b/metadata_test.go @@ -0,0 +1,75 @@ +package kafka + +import ( + "context" + "testing" +) + +func TestClientMetadata(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + metadata, err := client.Metadata(context.Background(), &MetadataRequest{ + Topics: []string{topic}, + }) + + if err != nil { + t.Fatal(err) + } + + if len(metadata.Brokers) == 0 { + t.Error("no brokers were returned in the metadata response") + } + + for _, b := range metadata.Brokers { + if b == (Broker{}) { + t.Error("unexpected broker with zero-value in metadata response") + } + } + + if len(metadata.Topics) == 0 { + t.Error("no topics were returned in the metadata response") + } else { + topicMetadata := metadata.Topics[0] + + if topicMetadata.Name != topic { + t.Error("invalid topic name:", topicMetadata.Name) + } + + if len(topicMetadata.Partitions) == 0 { + t.Error("no partitions were returned in the topic metadata response") + } else { + partitionMetadata := topicMetadata.Partitions[0] + + if partitionMetadata.Topic != topic { + t.Error("invalid partition topic name:", partitionMetadata.Topic) + } + + if partitionMetadata.ID != 0 { + t.Error("invalid partition index:", partitionMetadata.ID) + } + + if partitionMetadata.Leader == (Broker{}) { + t.Error("no partition leader was returned in the partition metadata response") + } + + if partitionMetadata.Error != nil { + t.Error("unexpected error found in the partition metadata response:", partitionMetadata.Error) + } + + // assume newLocalClientAndTopic creates the topic with one + // partition + if len(topicMetadata.Partitions) > 1 { + t.Error("too many partitions were returned in the topic metadata response") + } + } + + if topicMetadata.Error != nil { + t.Error("unexpected error found in the topic metadata response:", topicMetadata.Error) 
+ } + + if len(metadata.Topics) > 1 { + t.Error("too many topics were returned in the metadata response") + } + } +} diff --git a/offsetfetch.go b/offsetfetch.go index 91b037e3a..5a87f7cee 100644 --- a/offsetfetch.go +++ b/offsetfetch.go @@ -2,8 +2,119 @@ package kafka import ( "bufio" + "context" + "fmt" + "net" + "time" + + "github.com/segmentio/kafka-go/protocol/offsetfetch" ) +// OffsetFetchRequest represents a request sent to a kafka broker to read the +// currently committed offsets of topic partitions. +type OffsetFetchRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // ID of the consumer group to retrieve the offsets for. + GroupID string + + // Set of topic partitions to retrieve the offsets for. + Topics map[string][]int +} + +// OffsetFetchResponse represents a response from a kafka broker to an offset +// fetch request. +type OffsetFetchResponse struct { + // The amount of time that the broker throttled the request. + Throttle time.Duration + + // Set of topic partitions that the kafka broker has returned offsets for. + Topics map[string][]OffsetFetchPartition + + // An error that may have occurred while attempting to retrieve consumer + // group offsets. + // + // The error contains both the kafka error code, and an error message + // returned by the kafka broker. Programs may use the standard errors.Is + // function to test the error against kafka error codes. + Error error +} + +// OffsetFetchPartition represents the state of a single partition in a consumer +// group. +type OffsetFetchPartition struct { + // ID of the partition. + Partition int + + // Last committed offsets on the partition when the request was served by + // the kafka broker. + CommittedOffset int64 + + // Consumer group metadata for this partition. + Metadata string + + // An error that may have occurred while attempting to retrieve consumer + // group offsets for this partition. + // + // The error contains both the kafka error code, and an error message + // returned by the kafka broker. Programs may use the standard errors.Is + // function to test the error against kafka error codes. + Error error +} + +// OffsetFetch sends an offset fetch request to a kafka broker and returns the +// response. 
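+//
+// A minimal usage sketch, assuming client is a *Client configured with the
+// address of a broker in the cluster (group and topic names are illustrative):
+//
+//	res, err := client.OffsetFetch(context.Background(), &OffsetFetchRequest{
+//		GroupID: "my-group",
+//		Topics:  map[string][]int{"my-topic": {0, 1, 2}},
+//	})
+//	if err != nil {
+//		// handle the request error
+//	}
+//	for _, p := range res.Topics["my-topic"] {
+//		// p.Partition, p.CommittedOffset, p.Error
+//	}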
+func (c *Client) OffsetFetch(ctx context.Context, req *OffsetFetchRequest) (*OffsetFetchResponse, error) { + topics := make([]offsetfetch.RequestTopic, 0, len(req.Topics)) + + for topicName, partitions := range req.Topics { + indexes := make([]int32, len(partitions)) + + for i, p := range partitions { + indexes[i] = int32(p) + } + + topics = append(topics, offsetfetch.RequestTopic{ + Name: topicName, + PartitionIndexes: indexes, + }) + } + + m, err := c.roundTrip(ctx, req.Addr, &offsetfetch.Request{ + GroupID: req.GroupID, + Topics: topics, + }) + + if err != nil { + return nil, fmt.Errorf("kafka.(*Client).OffsetFetch: %w", err) + } + + res := m.(*offsetfetch.Response) + ret := &OffsetFetchResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Topics: make(map[string][]OffsetFetchPartition, len(res.Topics)), + Error: makeError(res.ErrorCode, ""), + } + + for _, t := range res.Topics { + partitions := make([]OffsetFetchPartition, len(t.Partitions)) + + for i, p := range t.Partitions { + partitions[i] = OffsetFetchPartition{ + Partition: int(p.PartitionIndex), + CommittedOffset: p.CommittedOffset, + Metadata: p.Metadata, + Error: makeError(p.ErrorCode, ""), + } + } + + ret.Topics[t.Name] = partitions + } + + return ret, nil +} + type offsetFetchRequestV1Topic struct { // Topic name Topic string diff --git a/produce.go b/produce.go index 449993e99..bbf34b7fa 100644 --- a/produce.go +++ b/produce.go @@ -1,6 +1,176 @@ package kafka -import "bufio" +import ( + "bufio" + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/segmentio/kafka-go/protocol" + produceAPI "github.com/segmentio/kafka-go/protocol/produce" +) + +type RequiredAcks int + +const ( + RequireNone RequiredAcks = 0 + RequireOne RequiredAcks = 1 + RequireAll RequiredAcks = -1 +) + +func (acks RequiredAcks) String() string { + switch acks { + case RequireNone: + return "none" + case RequireOne: + return "one" + case RequireAll: + return "all" + default: + return "unknown" + } +} + +// ProduceRequest represents a request sent to a kafka broker to produce records +// to a topic partition. +type ProduceRequest struct { + // Address of the kafka broker to send the request to. + Addr net.Addr + + // The topic to produce the records to. + Topic string + + // The partition to produce the records to. + Partition int + + // The level of required acknowledgements to ask the kafka broker for. + RequiredAcks RequiredAcks + + // The message format version used when encoding the records. + // + // By default, the client automatically determine which version should be + // used based on the version of the Produce API supported by the server. + MessageVersion int + + // An optional transaction id when producing to the kafka broker is part of + // a transaction. + TransactionalID string + + // The sequence of records to produce to the topic partition. + Records RecordReader + + // An optional compression algorithm to apply to the batch of records sent + // to the kafka broker. + Compression Compression +} + +// ProduceResponse represents a response from a kafka broker to a produce +// request. +type ProduceResponse struct { + // The amount of time that the broker throttled the request. + Throttle time.Duration + + // An error that may have occurred while attempting to produce the records. + // + // The error contains both the kafka error code, and an error message + // returned by the kafka broker. Programs may use the standard errors.Is + // function to test the error against kafka error codes. 
+ Error error + + // Offset of the first record that was written to the topic partition. + // + // This field will be zero if the kafka broker did no support the Produce + // API in version 3 or above. + BaseOffset int64 + + // Time at which the broker wrote the records to the topic partition. + // + // This field will be zero if the kafka broker did no support the Produce + // API in version 2 or above. + LogAppendTime time.Time + + // First offset in the topic partition that the records were written to. + // + // This field will be zero if the kafka broker did no support the Produce + // API in version 5 or above (or if the first offset is zero). + LogStartOffset int64 + + // If errors occurred writing specific records, they will be reported in + // this map. + // + // This field will always be empty if the kafka broker did no support the + // Produce API in version 8 or above. + RecordErrors map[int]error +} + +// Produce sends a produce request to a kafka broker and returns the response. +// +// If the request contained no records, an error wrapping protocol.ErrNoRecord +// is returned. +// +// When the request is configured with RequiredAcks=none, both the response and +// the error will be nil on success. +func (c *Client) Produce(ctx context.Context, req *ProduceRequest) (*ProduceResponse, error) { + attributes := protocol.Attributes(req.Compression) & 0x7 + + m, err := c.roundTrip(ctx, req.Addr, &produceAPI.Request{ + TransactionalID: req.TransactionalID, + Acks: int16(req.RequiredAcks), + Timeout: c.timeoutMs(ctx, defaultProduceTimeout), + Topics: []produceAPI.RequestTopic{{ + Topic: req.Topic, + Partitions: []produceAPI.RequestPartition{{ + Partition: int32(req.Partition), + RecordSet: protocol.RecordSet{ + Attributes: attributes, + Records: req.Records, + }, + }}, + }}, + }) + + switch { + case err == nil: + case errors.Is(err, protocol.ErrNoRecord): + return new(ProduceResponse), nil + default: + return nil, fmt.Errorf("kafka.(*Client).Produce: %w", err) + } + + if req.RequiredAcks == RequireNone { + return nil, nil + } + + res := m.(*produceAPI.Response) + if len(res.Topics) == 0 { + return nil, fmt.Errorf("kafka.(*Client).Produce: %w", protocol.ErrNoTopic) + } + topic := &res.Topics[0] + if len(topic.Partitions) == 0 { + return nil, fmt.Errorf("kafka.(*Client).Produce: %w", protocol.ErrNoPartition) + } + partition := &topic.Partitions[0] + + ret := &ProduceResponse{ + Throttle: makeDuration(res.ThrottleTimeMs), + Error: makeError(partition.ErrorCode, partition.ErrorMessage), + BaseOffset: partition.BaseOffset, + LogAppendTime: makeTime(partition.LogAppendTime), + LogStartOffset: partition.LogStartOffset, + } + + if len(partition.RecordErrors) != 0 { + ret.RecordErrors = make(map[int]error, len(partition.RecordErrors)) + + for _, recErr := range partition.RecordErrors { + ret.RecordErrors[int(recErr.BatchIndex)] = errors.New(recErr.BatchIndexErrorMessage) + } + } + + return ret, nil +} type produceRequestV2 struct { RequiredAcks int16 diff --git a/produce_test.go b/produce_test.go new file mode 100644 index 000000000..68347a1d7 --- /dev/null +++ b/produce_test.go @@ -0,0 +1,102 @@ +package kafka + +import ( + "context" + "testing" + "time" + + "github.com/segmentio/kafka-go/compress" +) + +func TestClientProduce(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + now := time.Now() + + res, err := client.Produce(context.Background(), &ProduceRequest{ + Topic: topic, + Partition: 0, + RequiredAcks: -1, + Records: NewRecordReader( + 
Record{Time: now, Value: NewBytes([]byte(`hello-1`))}, + Record{Time: now, Value: NewBytes([]byte(`hello-2`))}, + Record{Time: now, Value: NewBytes([]byte(`hello-3`))}, + ), + }) + + if err != nil { + t.Fatal(err) + } + + if res.Error != nil { + t.Error(res.Error) + } + + for index, err := range res.RecordErrors { + t.Errorf("record at index %d produced an error: %v", index, err) + } +} + +func TestClientProduceCompressed(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + now := time.Now() + + res, err := client.Produce(context.Background(), &ProduceRequest{ + Topic: topic, + Partition: 0, + RequiredAcks: -1, + Compression: compress.Gzip, + Records: NewRecordReader( + Record{Time: now, Value: NewBytes([]byte(`hello-1`))}, + Record{Time: now, Value: NewBytes([]byte(`hello-2`))}, + Record{Time: now, Value: NewBytes([]byte(`hello-3`))}, + ), + }) + + if err != nil { + t.Fatal(err) + } + + if res.Error != nil { + t.Error(res.Error) + } + + for index, err := range res.RecordErrors { + t.Errorf("record at index %d produced an error: %v", index, err) + } +} + +func TestClientProduceNilRecords(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + _, err := client.Produce(context.Background(), &ProduceRequest{ + Topic: topic, + Partition: 0, + RequiredAcks: -1, + Records: nil, + }) + + if err != nil { + t.Fatal(err) + } +} + +func TestClientProduceEmptyRecords(t *testing.T) { + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + + _, err := client.Produce(context.Background(), &ProduceRequest{ + Topic: topic, + Partition: 0, + RequiredAcks: -1, + Records: NewRecordReader(), + }) + + if err != nil { + t.Fatal(err) + } +} diff --git a/protocol.go b/protocol.go index 727cb2cf8..706946ee3 100644 --- a/protocol.go +++ b/protocol.go @@ -102,17 +102,17 @@ func (k apiKey) String() string { type apiVersion int16 const ( - v0 apiVersion = 0 - v1 apiVersion = 1 - v2 apiVersion = 2 - v3 apiVersion = 3 - v4 apiVersion = 4 - v5 apiVersion = 5 - v6 apiVersion = 6 - v7 apiVersion = 7 - v8 apiVersion = 8 - v9 apiVersion = 9 - v10 apiVersion = 10 + v0 = 0 + v1 = 1 + v2 = 2 + v3 = 3 + v4 = 4 + v5 = 5 + v6 = 6 + v7 = 7 + v8 = 8 + v9 = 9 + v10 = 10 ) var apiKeyStrings = [...]string{ diff --git a/protocol/apiversions/apiversions.go b/protocol/apiversions/apiversions.go new file mode 100644 index 000000000..1c5745582 --- /dev/null +++ b/protocol/apiversions/apiversions.go @@ -0,0 +1,27 @@ +package apiversions + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + _ struct{} `kafka:"min=v0,max=v2"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.ApiVersions } + +type Response struct { + ErrorCode int16 `kafka:"min=v0,max=v2"` + ApiKeys []ApiKeyResponse `kafka:"min=v0,max=v2"` + ThrottleTimeMs int32 `kafka:"min=v1,max=v2"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.ApiVersions } + +type ApiKeyResponse struct { + ApiKey int16 `kafka:"min=v0,max=v2"` + MinVersion int16 `kafka:"min=v0,max=v2"` + MaxVersion int16 `kafka:"min=v0,max=v2"` +} diff --git a/protocol/buffer.go b/protocol/buffer.go new file mode 100644 index 000000000..901fdf3e2 --- /dev/null +++ b/protocol/buffer.go @@ -0,0 +1,645 @@ +package protocol + +import ( + "bytes" + "fmt" + "io" + "math" + "sync" + "sync/atomic" +) + +// Bytes is an interface implemented by types that represent immutable +// sequences of bytes. 
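+//
+// A minimal in-memory sketch using the NewBytes and ReadAll helpers defined
+// below (the payload is illustrative):
+//
+//	b := NewBytes([]byte("hello"))
+//	data, _ := ReadAll(b) // data == []byte("hello")
+//	b.Close()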
+// +// Bytes values are used to abstract the location where record keys and +// values are read from (e.g. in-memory buffers, network sockets, files). +// +// The Close method should be called to release resources held by the object +// when the program is done with it. +// +// Bytes values are generally not safe to use concurrently from multiple +// goroutines. +type Bytes interface { + io.ReadCloser + // Returns the number of bytes remaining to be read from the payload. + Len() int +} + +// NewBytes constructs a Bytes value from b. +// +// The returned value references b, it does not make a copy of the backing +// array. +// +// If b is nil, nil is returned to represent a null BYTES value in the kafka +// protocol. +func NewBytes(b []byte) Bytes { + if b == nil { + return nil + } + r := new(bytesReader) + r.Reset(b) + return r +} + +// ReadAll is similar to ioutil.ReadAll, but it takes advantage of knowing the +// length of b to minimize the memory footprint. +// +// The function returns a nil slice if b is nil. +func ReadAll(b Bytes) ([]byte, error) { + if b == nil { + return nil, nil + } + s := make([]byte, b.Len()) + _, err := io.ReadFull(b, s) + return s, err +} + +type bytesReader struct{ bytes.Reader } + +func (*bytesReader) Close() error { return nil } + +type refCount uintptr + +func (rc *refCount) ref() { atomic.AddUintptr((*uintptr)(rc), 1) } + +func (rc *refCount) unref(onZero func()) { + if atomic.AddUintptr((*uintptr)(rc), ^uintptr(0)) == 0 { + onZero() + } +} + +const ( + // Size of the memory buffer for a single page. We use a farily + // large size here (64 KiB) because batches exchanged with kafka + // tend to be multiple kilobytes in size, sometimes hundreds. + // Using large pages amortizes the overhead of the page metadata + // and algorithms to manage the pages. 
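+	//
+	// The page index math in contiguousPages.indexOf assumes that all pages
+	// share this fixed size and are laid out back to back starting at the
+	// offset of the first page.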
+ pageSize = 65536 +) + +type page struct { + refc refCount + offset int64 + length int + buffer *[pageSize]byte +} + +func newPage(offset int64) *page { + p, _ := pagePool.Get().(*page) + if p != nil { + p.offset = offset + p.length = 0 + p.ref() + } else { + p = &page{ + refc: 1, + offset: offset, + buffer: &[pageSize]byte{}, + } + } + return p +} + +func (p *page) ref() { p.refc.ref() } + +func (p *page) unref() { p.refc.unref(func() { pagePool.Put(p) }) } + +func (p *page) slice(begin, end int64) []byte { + i, j := begin-p.offset, end-p.offset + + if i < 0 { + i = 0 + } else if i > pageSize { + i = pageSize + } + + if j < 0 { + j = 0 + } else if j > pageSize { + j = pageSize + } + + if i < j { + return p.buffer[i:j] + } + + return nil +} + +func (p *page) Cap() int { return pageSize } + +func (p *page) Len() int { return p.length } + +func (p *page) Size() int64 { return int64(p.length) } + +func (p *page) Truncate(n int) { + if n < p.length { + p.length = n + } +} + +func (p *page) ReadAt(b []byte, off int64) (int, error) { + if off -= p.offset; off < 0 || off > pageSize { + panic("offset out of range") + } + if off > int64(p.length) { + return 0, nil + } + return copy(b, p.buffer[off:p.length]), nil +} + +func (p *page) ReadFrom(r io.Reader) (int64, error) { + n, err := io.ReadFull(r, p.buffer[p.length:]) + if err == io.EOF || err == io.ErrUnexpectedEOF { + err = nil + } + p.length += n + return int64(n), err +} + +func (p *page) WriteAt(b []byte, off int64) (int, error) { + if off -= p.offset; off < 0 || off > pageSize { + panic("offset out of range") + } + n := copy(p.buffer[off:], b) + if end := int(off) + n; end > p.length { + p.length = end + } + return n, nil +} + +func (p *page) Write(b []byte) (int, error) { + return p.WriteAt(b, p.offset+int64(p.length)) +} + +var ( + _ io.ReaderAt = (*page)(nil) + _ io.ReaderFrom = (*page)(nil) + _ io.Writer = (*page)(nil) + _ io.WriterAt = (*page)(nil) +) + +type pageBuffer struct { + refc refCount + pages contiguousPages + length int + cursor int +} + +func newPageBuffer() *pageBuffer { + b, _ := pageBufferPool.Get().(*pageBuffer) + if b != nil { + b.cursor = 0 + b.refc.ref() + } else { + b = &pageBuffer{ + refc: 1, + pages: make(contiguousPages, 0, 16), + } + } + return b +} + +func (pb *pageBuffer) refTo(ref *pageRef, begin, end int64) { + length := end - begin + + if length > math.MaxUint32 { + panic("reference to contiguous buffer pages exceeds the maximum size of 4 GB") + } + + ref.pages = append(ref.buffer[:0], pb.pages.slice(begin, end)...) 
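+	// Take a reference on every page spanned by [begin,end) so they are not
+	// recycled to the pool while this pageRef is still in use; the references
+	// are released by pageRef.unref (via Close).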
+ ref.pages.ref() + ref.offset = begin + ref.length = uint32(length) +} + +func (pb *pageBuffer) ref(begin, end int64) *pageRef { + ref := new(pageRef) + pb.refTo(ref, begin, end) + return ref +} + +func (pb *pageBuffer) unref() { + pb.refc.unref(func() { + pb.pages.unref() + pb.pages.clear() + pb.pages = pb.pages[:0] + pb.length = 0 + pageBufferPool.Put(pb) + }) +} + +func (pb *pageBuffer) newPage() *page { + return newPage(int64(pb.length)) +} + +func (pb *pageBuffer) Close() error { + return nil +} + +func (pb *pageBuffer) Len() int { + return pb.length - pb.cursor +} + +func (pb *pageBuffer) Size() int64 { + return int64(pb.length) +} + +func (pb *pageBuffer) Discard(n int) (int, error) { + remain := pb.length - pb.cursor + if remain < n { + n = remain + } + pb.cursor += n + return n, nil +} + +func (pb *pageBuffer) Truncate(n int) { + if n < pb.length { + pb.length = n + + if n < pb.cursor { + pb.cursor = n + } + + for i := range pb.pages { + if p := pb.pages[i]; p.length <= n { + n -= p.length + } else { + if n > 0 { + pb.pages[i].Truncate(n) + i++ + } + pb.pages[i:].unref() + pb.pages[i:].clear() + pb.pages = pb.pages[:i] + break + } + } + } +} + +func (pb *pageBuffer) Seek(offset int64, whence int) (int64, error) { + c, err := seek(int64(pb.cursor), int64(pb.length), offset, whence) + if err != nil { + return -1, err + } + pb.cursor = int(c) + return c, nil +} + +func (pb *pageBuffer) ReadByte() (byte, error) { + b := [1]byte{} + _, err := pb.Read(b[:]) + return b[0], err +} + +func (pb *pageBuffer) Read(b []byte) (int, error) { + if pb.cursor >= pb.length { + return 0, io.EOF + } + n, err := pb.ReadAt(b, int64(pb.cursor)) + pb.cursor += n + return n, err +} + +func (pb *pageBuffer) ReadAt(b []byte, off int64) (int, error) { + return pb.pages.ReadAt(b, off) +} + +func (pb *pageBuffer) ReadFrom(r io.Reader) (int64, error) { + if len(pb.pages) == 0 { + pb.pages = append(pb.pages, pb.newPage()) + } + + rn := int64(0) + + for { + tail := pb.pages[len(pb.pages)-1] + free := tail.Cap() - tail.Len() + + if free == 0 { + tail = pb.newPage() + free = pageSize + pb.pages = append(pb.pages, tail) + } + + n, err := tail.ReadFrom(r) + pb.length += int(n) + rn += n + if n < int64(free) { + return rn, err + } + } +} + +func (pb *pageBuffer) WriteString(s string) (int, error) { + return pb.Write([]byte(s)) +} + +func (pb *pageBuffer) Write(b []byte) (int, error) { + wn := len(b) + if wn == 0 { + return 0, nil + } + + if len(pb.pages) == 0 { + pb.pages = append(pb.pages, pb.newPage()) + } + + for len(b) != 0 { + tail := pb.pages[len(pb.pages)-1] + free := tail.Cap() - tail.Len() + + if len(b) <= free { + tail.Write(b) + pb.length += len(b) + break + } + + tail.Write(b[:free]) + b = b[free:] + + pb.length += free + pb.pages = append(pb.pages, pb.newPage()) + } + + return wn, nil +} + +func (pb *pageBuffer) WriteAt(b []byte, off int64) (int, error) { + n, err := pb.pages.WriteAt(b, off) + if err != nil { + return n, err + } + if n < len(b) { + pb.Write(b[n:]) + } + return len(b), nil +} + +func (pb *pageBuffer) WriteTo(w io.Writer) (int64, error) { + var wn int + var err error + pb.pages.scan(int64(pb.cursor), int64(pb.length), func(b []byte) bool { + var n int + n, err = w.Write(b) + wn += n + return err == nil + }) + pb.cursor += wn + return int64(wn), err +} + +var ( + _ io.ReaderAt = (*pageBuffer)(nil) + _ io.ReaderFrom = (*pageBuffer)(nil) + _ io.StringWriter = (*pageBuffer)(nil) + _ io.Writer = (*pageBuffer)(nil) + _ io.WriterAt = (*pageBuffer)(nil) + _ io.WriterTo = (*pageBuffer)(nil) + + 
pagePool sync.Pool + pageBufferPool sync.Pool +) + +type contiguousPages []*page + +func (pages contiguousPages) ref() { + for _, p := range pages { + p.ref() + } +} + +func (pages contiguousPages) unref() { + for _, p := range pages { + p.unref() + } +} + +func (pages contiguousPages) clear() { + for i := range pages { + pages[i] = nil + } +} + +func (pages contiguousPages) ReadAt(b []byte, off int64) (int, error) { + rn := 0 + + for _, p := range pages.slice(off, off+int64(len(b))) { + n, _ := p.ReadAt(b, off) + b = b[n:] + rn += n + off += int64(n) + } + + return rn, nil +} + +func (pages contiguousPages) WriteAt(b []byte, off int64) (int, error) { + wn := 0 + + for _, p := range pages.slice(off, off+int64(len(b))) { + n, _ := p.WriteAt(b, off) + b = b[n:] + wn += n + off += int64(n) + } + + return wn, nil +} + +func (pages contiguousPages) slice(begin, end int64) contiguousPages { + i := pages.indexOf(begin) + j := pages.indexOf(end) + if j < len(pages) { + j++ + } + return pages[i:j] +} + +func (pages contiguousPages) indexOf(offset int64) int { + if len(pages) == 0 { + return 0 + } + return int((offset - pages[0].offset) / pageSize) +} + +func (pages contiguousPages) scan(begin, end int64, f func([]byte) bool) { + for _, p := range pages.slice(begin, end) { + if !f(p.slice(begin, end)) { + break + } + } +} + +var ( + _ io.ReaderAt = contiguousPages{} + _ io.WriterAt = contiguousPages{} +) + +type pageRef struct { + buffer [2]*page + pages contiguousPages + offset int64 + cursor int64 + length uint32 + once uint32 +} + +func (ref *pageRef) unref() { + if atomic.CompareAndSwapUint32(&ref.once, 0, 1) { + ref.pages.unref() + ref.pages.clear() + ref.pages = nil + ref.offset = 0 + ref.cursor = 0 + ref.length = 0 + } +} + +func (ref *pageRef) Len() int { return int(ref.Size() - ref.cursor) } + +func (ref *pageRef) Size() int64 { return int64(ref.length) } + +func (ref *pageRef) Close() error { ref.unref(); return nil } + +func (ref *pageRef) String() string { + return fmt.Sprintf("[offset=%d cursor=%d length=%d]", ref.offset, ref.cursor, ref.length) +} + +func (ref *pageRef) Seek(offset int64, whence int) (int64, error) { + c, err := seek(ref.cursor, int64(ref.length), offset, whence) + if err != nil { + return -1, err + } + ref.cursor = c + return c, nil +} + +func (ref *pageRef) ReadByte() (byte, error) { + var c byte + var ok bool + ref.scan(ref.cursor, func(b []byte) bool { + c, ok = b[0], true + return false + }) + if ok { + ref.cursor++ + } else { + return 0, io.EOF + } + return c, nil +} + +func (ref *pageRef) Read(b []byte) (int, error) { + if ref.cursor >= int64(ref.length) { + return 0, io.EOF + } + n, err := ref.ReadAt(b, ref.cursor) + ref.cursor += int64(n) + return n, err +} + +func (ref *pageRef) ReadAt(b []byte, off int64) (int, error) { + limit := ref.offset + int64(ref.length) + off += ref.offset + + if off >= limit { + return 0, io.EOF + } + + if off+int64(len(b)) > limit { + b = b[:limit-off] + } + + if len(b) == 0 { + return 0, nil + } + + n, err := ref.pages.ReadAt(b, off) + if n == 0 && err == nil { + err = io.EOF + } + return n, err +} + +func (ref *pageRef) WriteTo(w io.Writer) (wn int64, err error) { + ref.scan(ref.cursor, func(b []byte) bool { + var n int + n, err = w.Write(b) + wn += int64(n) + return err == nil + }) + ref.cursor += wn + return +} + +func (ref *pageRef) scan(off int64, f func([]byte) bool) { + begin := ref.offset + off + end := ref.offset + int64(ref.length) + ref.pages.scan(begin, end, f) +} + +var ( + _ io.Closer = (*pageRef)(nil) + _ io.Seeker 
= (*pageRef)(nil) + _ io.Reader = (*pageRef)(nil) + _ io.ReaderAt = (*pageRef)(nil) + _ io.WriterTo = (*pageRef)(nil) +) + +type pageRefAllocator struct { + refs []pageRef + head int + size int +} + +func (a *pageRefAllocator) newPageRef() *pageRef { + if a.head == len(a.refs) { + a.refs = make([]pageRef, a.size) + a.head = 0 + } + ref := &a.refs[a.head] + a.head++ + return ref +} + +func unref(x interface{}) { + if r, _ := x.(interface{ unref() }); r != nil { + r.unref() + } +} + +func seek(cursor, limit, offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + // absolute offset + case io.SeekCurrent: + offset = cursor + offset + case io.SeekEnd: + offset = limit - offset + default: + return -1, fmt.Errorf("seek: invalid whence value: %d", whence) + } + if offset < 0 { + offset = 0 + } + if offset > limit { + offset = limit + } + return offset, nil +} + +func closeBytes(b Bytes) { + if b != nil { + b.Close() + } +} + +func resetBytes(b Bytes) { + if r, _ := b.(interface{ Reset() }); r != nil { + r.Reset() + } +} diff --git a/protocol/buffer_test.go b/protocol/buffer_test.go new file mode 100644 index 000000000..5a9817dc7 --- /dev/null +++ b/protocol/buffer_test.go @@ -0,0 +1,108 @@ +package protocol + +import ( + "bytes" + "io" + "io/ioutil" + "testing" +) + +func TestPageBufferWriteReadSeek(t *testing.T) { + buffer := newPageBuffer() + defer buffer.unref() + + io.WriteString(buffer, "Hello World!") + + if n := buffer.Size(); n != 12 { + t.Fatal("invalid size:", n) + } + + for i := 0; i < 3; i++ { + if n := buffer.Len(); n != 12 { + t.Fatal("invalid length before read:", n) + } + + b, err := ioutil.ReadAll(buffer) + if err != nil { + t.Fatal(err) + } + + if n := buffer.Len(); n != 0 { + t.Fatal("invalid length after read:", n) + } + + if string(b) != "Hello World!" 
{ + t.Fatalf("invalid content after read #%d: %q", i, b) + } + + offset, err := buffer.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if offset != 0 { + t.Fatalf("invalid offset after seek #%d: %d", i, offset) + } + } +} + +func TestPageRefWriteReadSeek(t *testing.T) { + buffer := newPageBuffer() + defer buffer.unref() + + io.WriteString(buffer, "Hello World!") + + ref := buffer.ref(1, 11) + defer ref.unref() + + if n := ref.Size(); n != 10 { + t.Fatal("invalid size:", n) + } + + for i := 0; i < 3; i++ { + if n := ref.Len(); n != 10 { + t.Fatal("invalid length before read:", n) + } + + b, err := ioutil.ReadAll(ref) + if err != nil { + t.Fatal(err) + } + + if n := ref.Len(); n != 0 { + t.Fatal("invalid length after read:", n) + } + + if string(b) != "ello World" { + t.Fatalf("invalid content after read #%d: %q", i, b) + } + + offset, err := ref.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if offset != 0 { + t.Fatalf("invalid offset after seek #%d: %d", i, offset) + } + } +} + +func TestPageRefReadByte(t *testing.T) { + buffer := newPageBuffer() + defer buffer.unref() + + content := bytes.Repeat([]byte("1234567890"), 10e3) + buffer.Write(content) + + ref := buffer.ref(0, buffer.Size()) + defer ref.unref() + + for i, c := range content { + b, err := ref.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != c { + t.Fatalf("byte at offset %d mismatch, expected '%c' but got '%c'", i, c, b) + } + } +} diff --git a/protocol/cluster.go b/protocol/cluster.go new file mode 100644 index 000000000..5dd3455ad --- /dev/null +++ b/protocol/cluster.go @@ -0,0 +1,143 @@ +package protocol + +import ( + "fmt" + "sort" + "strings" + "text/tabwriter" +) + +type Cluster struct { + ClusterID string + Controller int32 + Brokers map[int32]Broker + Topics map[string]Topic +} + +func (c Cluster) BrokerIDs() []int32 { + brokerIDs := make([]int32, 0, len(c.Brokers)) + for id := range c.Brokers { + brokerIDs = append(brokerIDs, id) + } + sort.Slice(brokerIDs, func(i, j int) bool { + return brokerIDs[i] < brokerIDs[j] + }) + return brokerIDs +} + +func (c Cluster) TopicNames() []string { + topicNames := make([]string, 0, len(c.Topics)) + for name := range c.Topics { + topicNames = append(topicNames, name) + } + sort.Strings(topicNames) + return topicNames +} + +func (c Cluster) IsZero() bool { + return c.ClusterID == "" && c.Controller == 0 && len(c.Brokers) == 0 && len(c.Topics) == 0 +} + +func (c Cluster) Format(w fmt.State, _ rune) { + tw := new(tabwriter.Writer) + fmt.Fprintf(w, "CLUSTER: %q\n\n", c.ClusterID) + + tw.Init(w, 0, 8, 2, ' ', 0) + fmt.Fprint(tw, " BROKER\tHOST\tPORT\tRACK\tCONTROLLER\n") + + for _, id := range c.BrokerIDs() { + broker := c.Brokers[id] + fmt.Fprintf(tw, " %d\t%s\t%d\t%s\t%t\n", broker.ID, broker.Host, broker.Port, broker.Rack, broker.ID == c.Controller) + } + + tw.Flush() + fmt.Fprintln(w) + + tw.Init(w, 0, 8, 2, ' ', 0) + fmt.Fprint(tw, " TOPIC\tPARTITIONS\tBROKERS\n") + topicNames := c.TopicNames() + brokers := make(map[int32]struct{}, len(c.Brokers)) + brokerIDs := make([]int32, 0, len(c.Brokers)) + + for _, name := range topicNames { + topic := c.Topics[name] + + for _, p := range topic.Partitions { + for _, id := range p.Replicas { + brokers[id] = struct{}{} + } + } + + for id := range brokers { + brokerIDs = append(brokerIDs, id) + } + + fmt.Fprintf(tw, " %s\t%d\t%s\n", topic.Name, len(topic.Partitions), formatBrokerIDs(brokerIDs, -1)) + + for id := range brokers { + delete(brokers, id) + } + + brokerIDs = brokerIDs[:0] + } + + tw.Flush() + 
fmt.Fprintln(w) + + if w.Flag('+') { + for _, name := range topicNames { + fmt.Fprintf(w, " TOPIC: %q\n\n", name) + + tw.Init(w, 0, 8, 2, ' ', 0) + fmt.Fprint(tw, " PARTITION\tREPLICAS\tISR\tOFFLINE\n") + + for _, p := range c.Topics[name].Partitions { + fmt.Fprintf(tw, " %d\t%s\t%s\t%s\n", p.ID, + formatBrokerIDs(p.Replicas, -1), + formatBrokerIDs(p.ISR, p.Leader), + formatBrokerIDs(p.Offline, -1), + ) + } + + tw.Flush() + fmt.Fprintln(w) + } + } +} + +func formatBrokerIDs(brokerIDs []int32, leader int32) string { + if len(brokerIDs) == 0 { + return "" + } + + if len(brokerIDs) == 1 { + return itoa(brokerIDs[0]) + } + + sort.Slice(brokerIDs, func(i, j int) bool { + id1 := brokerIDs[i] + id2 := brokerIDs[j] + + if id1 == leader { + return true + } + + if id2 == leader { + return false + } + + return id1 < id2 + }) + + brokerNames := make([]string, len(brokerIDs)) + + for i, id := range brokerIDs { + brokerNames[i] = itoa(id) + } + + return strings.Join(brokerNames, ",") +} + +var ( + _ fmt.Formatter = Cluster{} +) diff --git a/protocol/conn.go b/protocol/conn.go new file mode 100644 index 000000000..a39752fa0 --- /dev/null +++ b/protocol/conn.go @@ -0,0 +1,96 @@ +package protocol + +import ( + "bufio" + "fmt" + "net" + "sync/atomic" + "time" +) + +type Conn struct { + buffer *bufio.Reader + conn net.Conn + clientID string + idgen int32 + versions atomic.Value // map[ApiKey]int16 +} + +func NewConn(conn net.Conn, clientID string) *Conn { + return &Conn{ + buffer: bufio.NewReader(conn), + conn: conn, + clientID: clientID, + } +} + +func (c *Conn) String() string { + return fmt.Sprintf("kafka://%s@%s->%s", c.clientID, c.LocalAddr(), c.RemoteAddr()) +} + +func (c *Conn) Close() error { + return c.conn.Close() +} + +func (c *Conn) Discard(n int) (int, error) { + return c.buffer.Discard(n) +} + +func (c *Conn) Peek(n int) ([]byte, error) { + return c.buffer.Peek(n) +} + +func (c *Conn) Read(b []byte) (int, error) { + return c.buffer.Read(b) +} + +func (c *Conn) Write(b []byte) (int, error) { + return c.conn.Write(b) +} + +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *Conn) SetVersions(versions map[ApiKey]int16) { + connVersions := make(map[ApiKey]int16, len(versions)) + + for k, v := range versions { + connVersions[k] = v + } + + c.versions.Store(connVersions) +} + +func (c *Conn) RoundTrip(msg Message) (Message, error) { + correlationID := atomic.AddInt32(&c.idgen, +1) + versions, _ := c.versions.Load().(map[ApiKey]int16) + apiVersion := versions[msg.ApiKey()] + + if p, _ := msg.(PreparedMessage); p != nil { + p.Prepare(apiVersion) + } + + return RoundTrip(c, apiVersion, correlationID, c.clientID, msg) +} + +var ( + _ net.Conn = (*Conn)(nil) + _ bufferedReader = (*Conn)(nil) +) diff --git a/protocol/createtopics/createtopics.go b/protocol/createtopics/createtopics.go new file mode 100644 index 000000000..62c597fb1 --- /dev/null +++ b/protocol/createtopics/createtopics.go @@ -0,0 +1,74 @@ +package createtopics + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + // We need at least one tagged field to indicate that v5+ uses 
"flexible" + // messages. + _ struct{} `kafka:"min=v5,max=v5,tag"` + + Topics []RequestTopic `kafka:"min=v0,max=v5"` + TimeoutMs int32 `kafka:"min=v0,max=v5"` + ValidateOnly bool `kafka:"min=v1,max=v5"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.CreateTopics } + +func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) { + return cluster.Brokers[cluster.Controller], nil +} + +type RequestTopic struct { + Name string `kafka:"min=v0,max=v5"` + NumPartitions int32 `kafka:"min=v0,max=v5"` + ReplicationFactor int16 `kafka:"min=v0,max=v5"` + Assignments []RequestAssignment `kafka:"min=v0,max=v5"` + Configs []RequestConfig `kafka:"min=v0,max=v5"` +} + +type RequestAssignment struct { + PartitionIndex int32 `kafka:"min=v0,max=v5"` + BrokerIDs []int32 `kafka:"min=v0,max=v5"` +} + +type RequestConfig struct { + Name string `kafka:"min=v0,max=v5"` + Value string `kafka:"min=v0,max=v5,nullable"` +} + +type Response struct { + // We need at least one tagged field to indicate that v5+ uses "flexible" + // messages. + _ struct{} `kafka:"min=v5,max=v5,tag"` + + ThrottleTimeMs int32 `kafka:"min=v2,max=v5"` + Topics []ResponseTopic `kafka:"min=v0,max=v5"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.CreateTopics } + +type ResponseTopic struct { + Name string `kafka:"min=v0,max=v5"` + ErrorCode int16 `kafka:"min=v0,max=v5"` + ErrorMessage string `kafka:"min=v1,max=v5,nullable"` + NumPartitions int32 `kafka:"min=v5,max=v5"` + ReplicationFactor int16 `kafka:"min=v5,max=v5"` + + Configs []ResponseTopicConfig `kafka:"min=v5,max=v5"` +} + +type ResponseTopicConfig struct { + Name string `kafka:"min=v5,max=v5"` + Value string `kafka:"min=v5,max=v5,nullable"` + ReadOnly bool `kafka:"min=v5,max=v5"` + ConfigSource int8 `kafka:"min=v5,max=v5"` + IsSensitive bool `kafka:"min=v5,max=v5"` +} + +var ( + _ protocol.BrokerMessage = (*Request)(nil) +) diff --git a/protocol/decode.go b/protocol/decode.go new file mode 100644 index 000000000..10888c4ef --- /dev/null +++ b/protocol/decode.go @@ -0,0 +1,545 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "hash/crc32" + "io" + "io/ioutil" + "reflect" + "sync" + "sync/atomic" +) + +type discarder interface { + Discard(int) (int, error) +} + +type decoder struct { + reader io.Reader + remain int + buffer [8]byte + err error + table *crc32.Table + crc32 uint32 +} + +func (d *decoder) Reset(r io.Reader, n int) { + d.reader = r + d.remain = n + d.buffer = [8]byte{} + d.err = nil + d.table = nil + d.crc32 = 0 +} + +func (d *decoder) Read(b []byte) (int, error) { + if d.err != nil { + return 0, d.err + } + if d.remain == 0 { + return 0, io.EOF + } + if len(b) > d.remain { + b = b[:d.remain] + } + n, err := d.reader.Read(b) + if n > 0 && d.table != nil { + d.crc32 = crc32.Update(d.crc32, d.table, b[:n]) + } + d.remain -= n + return n, err +} + +func (d *decoder) ReadByte() (byte, error) { + c := d.readByte() + return c, d.err +} + +func (d *decoder) done() bool { + return d.remain == 0 || d.err != nil +} + +func (d *decoder) setCRC(table *crc32.Table) { + d.table, d.crc32 = table, 0 +} + +func (d *decoder) decodeBool(v value) { + v.setBool(d.readBool()) +} + +func (d *decoder) decodeInt8(v value) { + v.setInt8(d.readInt8()) +} + +func (d *decoder) decodeInt16(v value) { + v.setInt16(d.readInt16()) +} + +func (d *decoder) decodeInt32(v value) { + v.setInt32(d.readInt32()) +} + +func (d *decoder) decodeInt64(v value) { + v.setInt64(d.readInt64()) +} + +func (d *decoder) decodeString(v value) { + 
v.setString(d.readString()) +} + +func (d *decoder) decodeCompactString(v value) { + v.setString(d.readCompactString()) +} + +func (d *decoder) decodeBytes(v value) { + v.setBytes(d.readBytes()) +} + +func (d *decoder) decodeCompactBytes(v value) { + v.setBytes(d.readCompactBytes()) +} + +func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) { + if n := d.readInt32(); n < 0 { + v.setArray(array{}) + } else { + a := makeArray(elemType, int(n)) + for i := 0; i < int(n) && d.remain > 0; i++ { + decodeElem(d, a.index(i)) + } + v.setArray(a) + } +} + +func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) { + if n := d.readUnsignedVarInt(); n < 1 { + v.setArray(array{}) + } else { + a := makeArray(elemType, int(n-1)) + for i := 0; i < int(n-1) && d.remain > 0; i++ { + decodeElem(d, a.index(i)) + } + v.setArray(a) + } +} + +func (d *decoder) discardAll() { + d.discard(d.remain) +} + +func (d *decoder) discard(n int) { + if n > d.remain { + n = d.remain + } + var err error + if r, _ := d.reader.(discarder); r != nil { + n, err = r.Discard(n) + d.remain -= n + } else { + _, err = io.Copy(ioutil.Discard, d) + } + d.setError(err) +} + +func (d *decoder) read(n int) []byte { + b := make([]byte, n) + n, err := io.ReadFull(d, b) + b = b[:n] + d.setError(err) + return b +} + +func (d *decoder) writeTo(w io.Writer, n int) { + limit := d.remain + if n < limit { + d.remain = n + } + c, err := io.Copy(w, d) + if int(c) < n && err == nil { + err = io.ErrUnexpectedEOF + } + d.remain = limit - int(c) + d.setError(err) +} + +func (d *decoder) setError(err error) { + if d.err == nil && err != nil { + d.err = err + d.discardAll() + } +} + +func (d *decoder) readFull(b []byte) bool { + n, err := io.ReadFull(d, b) + d.setError(err) + return n == len(b) +} + +func (d *decoder) readByte() byte { + if d.readFull(d.buffer[:1]) { + return d.buffer[0] + } + return 0 +} + +func (d *decoder) readBool() bool { + return d.readByte() != 0 +} + +func (d *decoder) readInt8() int8 { + if d.readFull(d.buffer[:1]) { + return readInt8(d.buffer[:1]) + } + return 0 +} + +func (d *decoder) readInt16() int16 { + if d.readFull(d.buffer[:2]) { + return readInt16(d.buffer[:2]) + } + return 0 +} + +func (d *decoder) readInt32() int32 { + if d.readFull(d.buffer[:4]) { + return readInt32(d.buffer[:4]) + } + return 0 +} + +func (d *decoder) readInt64() int64 { + if d.readFull(d.buffer[:8]) { + return readInt64(d.buffer[:8]) + } + return 0 +} + +func (d *decoder) readString() string { + if n := d.readInt16(); n < 0 { + return "" + } else { + return bytesToString(d.read(int(n))) + } +} + +func (d *decoder) readVarString() string { + if n := d.readVarInt(); n < 0 { + return "" + } else { + return bytesToString(d.read(int(n))) + } +} + +func (d *decoder) readCompactString() string { + if n := d.readUnsignedVarInt(); n < 1 { + return "" + } else { + return bytesToString(d.read(int(n - 1))) + } +} + +func (d *decoder) readBytes() []byte { + if n := d.readInt32(); n < 0 { + return nil + } else { + return d.read(int(n)) + } +} + +func (d *decoder) readBytesTo(w io.Writer) bool { + if n := d.readInt32(); n < 0 { + return false + } else { + d.writeTo(w, int(n)) + return d.err == nil + } +} + +func (d *decoder) readVarBytes() []byte { + if n := d.readVarInt(); n < 0 { + return nil + } else { + return d.read(int(n)) + } +} + +func (d *decoder) readVarBytesTo(w io.Writer) bool { + if n := d.readVarInt(); n < 0 { + return false + } else { + d.writeTo(w, int(n)) + return d.err == nil + } +} 
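The varint reader that follows (readVarInt) decodes the zigzag mapping whose encoding side is (*encoder).writeVarInt in protocol/encode.go. A minimal standalone sketch of that mapping, using only the two expressions that appear in this patch (the helper names below are illustrative, not part of the package):

package main

import "fmt"

// zigzagEncode applies the same mapping as (*encoder).writeVarInt: (i << 1) ^ (i >> 63).
func zigzagEncode(i int64) uint64 { return uint64((i << 1) ^ (i >> 63)) }

// zigzagDecode applies the same mapping as (*decoder).readVarInt: int64(x>>1) ^ -(int64(x) & 1).
func zigzagDecode(x uint64) int64 { return int64(x>>1) ^ -(int64(x) & 1) }

func main() {
	// Small magnitudes map to small unsigned values, so they encode into few varint bytes.
	for _, i := range []int64{0, -1, 1, -2, 2} {
		x := zigzagEncode(i)
		fmt.Printf("%d -> %d -> %d\n", i, x, zigzagDecode(x)) // 0->0, -1->1, 1->2, -2->3, 2->4
	}
}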
+ +func (d *decoder) readCompactBytes() []byte { + if n := d.readUnsignedVarInt(); n < 1 { + return nil + } else { + return d.read(int(n - 1)) + } +} + +func (d *decoder) readCompactBytesTo(w io.Writer) bool { + if n := d.readUnsignedVarInt(); n < 1 { + return false + } else { + d.writeTo(w, int(n-1)) + return d.err == nil + } +} + +func (d *decoder) readVarInt() int64 { + n := 11 // varints are at most 11 bytes + + if n > d.remain { + n = d.remain + } + + x := uint64(0) + s := uint(0) + + for n > 0 { + b := d.readByte() + + if (b & 0x80) == 0 { + x |= uint64(b) << s + return int64(x>>1) ^ -(int64(x) & 1) + } + + x |= uint64(b&0x7f) << s + s += 7 + n-- + } + + d.setError(fmt.Errorf("cannot decode varint from input stream")) + return 0 +} + +func (d *decoder) readUnsignedVarInt() uint64 { + n := 11 // varints are at most 11 bytes + + if n > d.remain { + n = d.remain + } + + x := uint64(0) + s := uint(0) + + for n > 0 { + b := d.readByte() + + if (b & 0x80) == 0 { + x |= uint64(b) << s + return x + } + + x |= uint64(b&0x7f) << s + s += 7 + n-- + } + + d.setError(fmt.Errorf("cannot decode unsigned varint from input stream")) + return 0 +} + +type decodeFunc func(*decoder, value) + +var ( + _ io.Reader = (*decoder)(nil) + _ io.ByteReader = (*decoder)(nil) + + readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem() +) + +func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { + if reflect.PtrTo(typ).Implements(readerFrom) { + return readerDecodeFuncOf(typ) + } + switch typ.Kind() { + case reflect.Bool: + return (*decoder).decodeBool + case reflect.Int8: + return (*decoder).decodeInt8 + case reflect.Int16: + return (*decoder).decodeInt16 + case reflect.Int32: + return (*decoder).decodeInt32 + case reflect.Int64: + return (*decoder).decodeInt64 + case reflect.String: + return stringDecodeFuncOf(flexible, tag) + case reflect.Struct: + return structDecodeFuncOf(typ, version, flexible) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { // []byte + return bytesDecodeFuncOf(flexible, tag) + } + return arrayDecodeFuncOf(typ, version, flexible, tag) + default: + panic("unsupported type: " + typ.String()) + } +} + +func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc { + if flexible { + // In flexible messages, all strings are compact + return (*decoder).decodeCompactString + } + return (*decoder).decodeString +} + +func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc { + if flexible { + // In flexible messages, all arrays are compact + return (*decoder).decodeCompactBytes + } + return (*decoder).decodeBytes +} + +func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc { + type field struct { + decode decodeFunc + index index + tagID int + } + + var fields []field + taggedFields := map[int]*field{} + + forEachStructField(typ, func(typ reflect.Type, index index, tag string) { + forEachStructTag(tag, func(tag structTag) bool { + if tag.MinVersion <= version && version <= tag.MaxVersion { + f := field{ + decode: decodeFuncOf(typ, version, flexible, tag), + index: index, + tagID: tag.TagID, + } + + if tag.TagID < -1 { + // Normal required field + fields = append(fields, f) + } else { + // Optional tagged field (flexible messages only) + taggedFields[tag.TagID] = &f + } + return false + } + return true + }) + }) + + return func(d *decoder, v value) { + for i := range fields { + f := &fields[i] + f.decode(d, v.fieldByIndex(f.index)) + } + + if flexible { + // See 
https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields + // for details of tag buffers in "flexible" messages. + n := int(d.readUnsignedVarInt()) + + for i := 0; i < n; i++ { + tagID := int(d.readUnsignedVarInt()) + size := int(d.readUnsignedVarInt()) + + f, ok := taggedFields[tagID] + if ok { + f.decode(d, v.fieldByIndex(f.index)) + } else { + d.read(size) + } + } + } + } +} + +func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { + elemType := typ.Elem() + elemFunc := decodeFuncOf(elemType, version, flexible, tag) + if flexible { + // In flexible messages, all arrays are compact + return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) } + } + + return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) } +} + +func readerDecodeFuncOf(typ reflect.Type) decodeFunc { + typ = reflect.PtrTo(typ) + return func(d *decoder, v value) { + if d.err == nil { + _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d) + if err != nil { + d.setError(err) + } + } + } +} + +func readInt8(b []byte) int8 { + return int8(b[0]) +} + +func readInt16(b []byte) int16 { + return int16(binary.BigEndian.Uint16(b)) +} + +func readInt32(b []byte) int32 { + return int32(binary.BigEndian.Uint32(b)) +} + +func readInt64(b []byte) int64 { + return int64(binary.BigEndian.Uint64(b)) +} + +func Unmarshal(data []byte, version int16, value interface{}) error { + typ := elemTypeOf(value) + cache, _ := unmarshalers.Load().(map[_type]decodeFunc) + decode := cache[typ] + + if decode == nil { + decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{ + MinVersion: -1, + MaxVersion: -1, + TagID: -2, + Compact: true, + Nullable: true, + }) + + newCache := make(map[_type]decodeFunc, len(cache)+1) + newCache[typ] = decode + + for typ, fun := range cache { + newCache[typ] = fun + } + + unmarshalers.Store(newCache) + } + + d, _ := decoders.Get().(*decoder) + if d == nil { + d = &decoder{reader: bytes.NewReader(nil)} + } + + d.remain = len(data) + r, _ := d.reader.(*bytes.Reader) + r.Reset(data) + + defer func() { + r.Reset(nil) + d.Reset(r, 0) + decoders.Put(d) + }() + + decode(d, valueOf(value)) + return dontExpectEOF(d.err) +} + +var ( + decoders sync.Pool // *decoder + unmarshalers atomic.Value // map[_type]decodeFunc +) diff --git a/protocol/deletetopics/deletetopics.go b/protocol/deletetopics/deletetopics.go new file mode 100644 index 000000000..3af5a0014 --- /dev/null +++ b/protocol/deletetopics/deletetopics.go @@ -0,0 +1,34 @@ +package deletetopics + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + TopicNames []string `kafka:"min=v0,max=v3"` + TimeoutMs int32 `kafka:"min=v0,max=v3"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.DeleteTopics } + +func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) { + return cluster.Brokers[cluster.Controller], nil +} + +type Response struct { + ThrottleTimeMs int32 `kafka:"min=v1,max=v3"` + Responses []ResponseTopic `kafka:"min=v0,max=v3"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.DeleteTopics } + +type ResponseTopic struct { + Name string `kafka:"min=v0,max=v3"` + ErrorCode int16 `kafka:"min=v0,max=v3"` +} + +var ( + _ protocol.BrokerMessage = (*Request)(nil) +) diff --git a/protocol/encode.go b/protocol/encode.go new file mode 100644 index 000000000..2d9d6ff1b --- /dev/null 
+++ b/protocol/encode.go @@ -0,0 +1,639 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "hash/crc32" + "io" + "reflect" + "sync" + "sync/atomic" +) + +type encoder struct { + writer io.Writer + err error + table *crc32.Table + crc32 uint32 + buffer [32]byte +} + +type encoderChecksum struct { + reader io.Reader + encoder *encoder +} + +func (e *encoderChecksum) Read(b []byte) (int, error) { + n, err := e.reader.Read(b) + if n > 0 { + e.encoder.update(b[:n]) + } + return n, err +} + +func (e *encoder) Reset(w io.Writer) { + e.writer = w + e.err = nil + e.table = nil + e.crc32 = 0 + e.buffer = [32]byte{} +} + +func (e *encoder) ReadFrom(r io.Reader) (int64, error) { + if e.table != nil { + r = &encoderChecksum{ + reader: r, + encoder: e, + } + } + return io.Copy(e.writer, r) +} + +func (e *encoder) Write(b []byte) (int, error) { + if e.err != nil { + return 0, e.err + } + n, err := e.writer.Write(b) + if n > 0 { + e.update(b[:n]) + } + if err != nil { + e.err = err + } + return n, err +} + +func (e *encoder) WriteByte(b byte) error { + e.buffer[0] = b + _, err := e.Write(e.buffer[:1]) + return err +} + +func (e *encoder) WriteString(s string) (int, error) { + // This implementation is an optimization to avoid the heap allocation that + // would occur when converting the string to a []byte to call crc32.Update. + // + // Strings are rarely long in the kafka protocol, so the use of a 32 byte + // buffer is a good compromise between keeping the encoder value small and + // limiting the number of calls to Write. + // + // We introduced this optimization because memory profiles on the benchmarks + // showed that most heap allocations were caused by this code path. + n := 0 + + for len(s) != 0 { + c := copy(e.buffer[:], s) + w, err := e.Write(e.buffer[:c]) + n += w + if err != nil { + return n, err + } + s = s[c:] + } + + return n, nil +} + +func (e *encoder) setCRC(table *crc32.Table) { + e.table, e.crc32 = table, 0 +} + +func (e *encoder) update(b []byte) { + if e.table != nil { + e.crc32 = crc32.Update(e.crc32, e.table, b) + } +} + +func (e *encoder) encodeBool(v value) { + b := int8(0) + if v.bool() { + b = 1 + } + e.writeInt8(b) +} + +func (e *encoder) encodeInt8(v value) { + e.writeInt8(v.int8()) +} + +func (e *encoder) encodeInt16(v value) { + e.writeInt16(v.int16()) +} + +func (e *encoder) encodeInt32(v value) { + e.writeInt32(v.int32()) +} + +func (e *encoder) encodeInt64(v value) { + e.writeInt64(v.int64()) +} + +func (e *encoder) encodeString(v value) { + e.writeString(v.string()) +} + +func (e *encoder) encodeVarString(v value) { + e.writeVarString(v.string()) +} + +func (e *encoder) encodeCompactString(v value) { + e.writeCompactString(v.string()) +} + +func (e *encoder) encodeNullString(v value) { + e.writeNullString(v.string()) +} + +func (e *encoder) encodeVarNullString(v value) { + e.writeVarNullString(v.string()) +} + +func (e *encoder) encodeCompactNullString(v value) { + e.writeCompactNullString(v.string()) +} + +func (e *encoder) encodeBytes(v value) { + e.writeBytes(v.bytes()) +} + +func (e *encoder) encodeVarBytes(v value) { + e.writeVarBytes(v.bytes()) +} + +func (e *encoder) encodeCompactBytes(v value) { + e.writeCompactBytes(v.bytes()) +} + +func (e *encoder) encodeNullBytes(v value) { + e.writeNullBytes(v.bytes()) +} + +func (e *encoder) encodeVarNullBytes(v value) { + e.writeVarNullBytes(v.bytes()) +} + +func (e *encoder) encodeCompactNullBytes(v value) { + e.writeCompactNullBytes(v.bytes()) +} + +func (e *encoder) encodeArray(v value, elemType
reflect.Type, encodeElem encodeFunc) { + a := v.array(elemType) + n := a.length() + e.writeInt32(int32(n)) + + for i := 0; i < n; i++ { + encodeElem(e, a.index(i)) + } +} + +func (e *encoder) encodeCompactArray(v value, elemType reflect.Type, encodeElem encodeFunc) { + a := v.array(elemType) + n := a.length() + e.writeUnsignedVarInt(uint64(n + 1)) + + for i := 0; i < n; i++ { + encodeElem(e, a.index(i)) + } +} + +func (e *encoder) encodeNullArray(v value, elemType reflect.Type, encodeElem encodeFunc) { + a := v.array(elemType) + if a.isNil() { + e.writeInt32(-1) + return + } + + n := a.length() + e.writeInt32(int32(n)) + + for i := 0; i < n; i++ { + encodeElem(e, a.index(i)) + } +} + +func (e *encoder) encodeCompactNullArray(v value, elemType reflect.Type, encodeElem encodeFunc) { + a := v.array(elemType) + if a.isNil() { + e.writeUnsignedVarInt(0) + return + } + + n := a.length() + e.writeUnsignedVarInt(uint64(n + 1)) + for i := 0; i < n; i++ { + encodeElem(e, a.index(i)) + } +} + +func (e *encoder) writeInt8(i int8) { + writeInt8(e.buffer[:1], i) + e.Write(e.buffer[:1]) +} + +func (e *encoder) writeInt16(i int16) { + writeInt16(e.buffer[:2], i) + e.Write(e.buffer[:2]) +} + +func (e *encoder) writeInt32(i int32) { + writeInt32(e.buffer[:4], i) + e.Write(e.buffer[:4]) +} + +func (e *encoder) writeInt64(i int64) { + writeInt64(e.buffer[:8], i) + e.Write(e.buffer[:8]) +} + +func (e *encoder) writeString(s string) { + e.writeInt16(int16(len(s))) + e.WriteString(s) +} + +func (e *encoder) writeVarString(s string) { + e.writeVarInt(int64(len(s))) + e.WriteString(s) +} + +func (e *encoder) writeCompactString(s string) { + e.writeUnsignedVarInt(uint64(len(s)) + 1) + e.WriteString(s) +} + +func (e *encoder) writeNullString(s string) { + if s == "" { + e.writeInt16(-1) + } else { + e.writeInt16(int16(len(s))) + e.WriteString(s) + } +} + +func (e *encoder) writeVarNullString(s string) { + if s == "" { + e.writeVarInt(-1) + } else { + e.writeVarInt(int64(len(s))) + e.WriteString(s) + } +} + +func (e *encoder) writeCompactNullString(s string) { + if s == "" { + e.writeUnsignedVarInt(0) + } else { + e.writeUnsignedVarInt(uint64(len(s)) + 1) + e.WriteString(s) + } +} + +func (e *encoder) writeBytes(b []byte) { + e.writeInt32(int32(len(b))) + e.Write(b) +} + +func (e *encoder) writeVarBytes(b []byte) { + e.writeVarInt(int64(len(b))) + e.Write(b) +} + +func (e *encoder) writeCompactBytes(b []byte) { + e.writeUnsignedVarInt(uint64(len(b)) + 1) + e.Write(b) +} + +func (e *encoder) writeNullBytes(b []byte) { + if b == nil { + e.writeInt32(-1) + } else { + e.writeInt32(int32(len(b))) + e.Write(b) + } +} + +func (e *encoder) writeVarNullBytes(b []byte) { + if b == nil { + e.writeVarInt(-1) + } else { + e.writeVarInt(int64(len(b))) + e.Write(b) + } +} + +func (e *encoder) writeCompactNullBytes(b []byte) { + if b == nil { + e.writeUnsignedVarInt(0) + } else { + e.writeUnsignedVarInt(uint64(len(b)) + 1) + e.Write(b) + } +} + +func (e *encoder) writeBytesFrom(b Bytes) error { + size := int64(b.Len()) + e.writeInt32(int32(size)) + n, err := io.Copy(e, b) + if err == nil && n != size { + err = fmt.Errorf("size of bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF) + } + return err +} + +func (e *encoder) writeNullBytesFrom(b Bytes) error { + if b == nil { + e.writeInt32(-1) + return nil + } else { + size := int64(b.Len()) + e.writeInt32(int32(size)) + n, err := io.Copy(e, b) + if err == nil && n != size { + err = fmt.Errorf("size of nullable bytes does 
not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF) + } + return err + } +} + +func (e *encoder) writeVarNullBytesFrom(b Bytes) error { + if b == nil { + e.writeVarInt(-1) + return nil + } else { + size := int64(b.Len()) + e.writeVarInt(size) + n, err := io.Copy(e, b) + if err == nil && n != size { + err = fmt.Errorf("size of nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF) + } + return err + } +} + +func (e *encoder) writeCompactNullBytesFrom(b Bytes) error { + if b == nil { + e.writeUnsignedVarInt(0) + return nil + } else { + size := int64(b.Len()) + e.writeUnsignedVarInt(uint64(size + 1)) + n, err := io.Copy(e, b) + if err == nil && n != size { + err = fmt.Errorf("size of compact nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF) + } + return err + } +} + +func (e *encoder) writeVarInt(i int64) { + e.writeUnsignedVarInt(uint64((i << 1) ^ (i >> 63))) +} + +func (e *encoder) writeUnsignedVarInt(i uint64) { + b := e.buffer[:] + n := 0 + + for i >= 0x80 && n < len(b) { + b[n] = byte(i) | 0x80 + i >>= 7 + n++ + } + + if n < len(b) { + b[n] = byte(i) + n++ + } + + e.Write(b[:n]) +} + +type encodeFunc func(*encoder, value) + +var ( + _ io.ReaderFrom = (*encoder)(nil) + _ io.Writer = (*encoder)(nil) + _ io.ByteWriter = (*encoder)(nil) + _ io.StringWriter = (*encoder)(nil) + + writerTo = reflect.TypeOf((*io.WriterTo)(nil)).Elem() +) + +func encodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) encodeFunc { + if reflect.PtrTo(typ).Implements(writerTo) { + return writerEncodeFuncOf(typ) + } + switch typ.Kind() { + case reflect.Bool: + return (*encoder).encodeBool + case reflect.Int8: + return (*encoder).encodeInt8 + case reflect.Int16: + return (*encoder).encodeInt16 + case reflect.Int32: + return (*encoder).encodeInt32 + case reflect.Int64: + return (*encoder).encodeInt64 + case reflect.String: + return stringEncodeFuncOf(flexible, tag) + case reflect.Struct: + return structEncodeFuncOf(typ, version, flexible) + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { // []byte + return bytesEncodeFuncOf(flexible, tag) + } + return arrayEncodeFuncOf(typ, version, flexible, tag) + default: + panic("unsupported type: " + typ.String()) + } +} + +func stringEncodeFuncOf(flexible bool, tag structTag) encodeFunc { + switch { + case flexible && tag.Nullable: + // In flexible messages, all strings are compact + return (*encoder).encodeCompactNullString + case flexible: + // In flexible messages, all strings are compact + return (*encoder).encodeCompactString + case tag.Nullable: + return (*encoder).encodeNullString + default: + return (*encoder).encodeString + } +} + +func bytesEncodeFuncOf(flexible bool, tag structTag) encodeFunc { + switch { + case flexible && tag.Nullable: + // In flexible messages, all arrays are compact + return (*encoder).encodeCompactNullBytes + case flexible: + // In flexible messages, all arrays are compact + return (*encoder).encodeCompactBytes + case tag.Nullable: + return (*encoder).encodeNullBytes + default: + return (*encoder).encodeBytes + } +} + +func structEncodeFuncOf(typ reflect.Type, version int16, flexible bool) encodeFunc { + type field struct { + encode encodeFunc + index index + tagID int + } + + var fields []field + var taggedFields []field + + forEachStructField(typ, func(typ reflect.Type, index index, tag string) { + if typ.Size() 
!= 0 { // skip struct{} + forEachStructTag(tag, func(tag structTag) bool { + if tag.MinVersion <= version && version <= tag.MaxVersion { + f := field{ + encode: encodeFuncOf(typ, version, flexible, tag), + index: index, + tagID: tag.TagID, + } + + if tag.TagID < -1 { + // Normal required field + fields = append(fields, f) + } else { + // Optional tagged field (flexible messages only) + taggedFields = append(taggedFields, f) + } + return false + } + return true + }) + } + }) + + return func(e *encoder, v value) { + for i := range fields { + f := &fields[i] + f.encode(e, v.fieldByIndex(f.index)) + } + + if flexible { + // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields + // for details of tag buffers in "flexible" messages. + e.writeUnsignedVarInt(uint64(len(taggedFields))) + + for i := range taggedFields { + f := &taggedFields[i] + e.writeUnsignedVarInt(uint64(f.tagID)) + + buf := &bytes.Buffer{} + se := &encoder{writer: buf} + f.encode(se, v.fieldByIndex(f.index)) + e.writeUnsignedVarInt(uint64(buf.Len())) + e.Write(buf.Bytes()) + } + } + } +} + +func arrayEncodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) encodeFunc { + elemType := typ.Elem() + elemFunc := encodeFuncOf(elemType, version, flexible, tag) + switch { + case flexible && tag.Nullable: + // In flexible messages, all arrays are compact + return func(e *encoder, v value) { e.encodeCompactNullArray(v, elemType, elemFunc) } + case flexible: + // In flexible messages, all arrays are compact + return func(e *encoder, v value) { e.encodeCompactArray(v, elemType, elemFunc) } + case tag.Nullable: + return func(e *encoder, v value) { e.encodeNullArray(v, elemType, elemFunc) } + default: + return func(e *encoder, v value) { e.encodeArray(v, elemType, elemFunc) } + } +} + +func writerEncodeFuncOf(typ reflect.Type) encodeFunc { + typ = reflect.PtrTo(typ) + return func(e *encoder, v value) { + // Optimization to write directly into the buffer when the encoder + // does not need to compute a crc32 checksum.
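+ // (encoder.Write both forwards to the underlying writer and feeds the
+ // running checksum through update; when no CRC table is configured, going
+ // through e.writer directly skips that extra bookkeeping.)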
+ w := io.Writer(e) + if e.table == nil { + w = e.writer + } + _, err := v.iface(typ).(io.WriterTo).WriteTo(w) + if err != nil { + e.err = err + } + } +} + +func writeInt8(b []byte, i int8) { + b[0] = byte(i) +} + +func writeInt16(b []byte, i int16) { + binary.BigEndian.PutUint16(b, uint16(i)) +} + +func writeInt32(b []byte, i int32) { + binary.BigEndian.PutUint32(b, uint32(i)) +} + +func writeInt64(b []byte, i int64) { + binary.BigEndian.PutUint64(b, uint64(i)) +} + +func Marshal(version int16, value interface{}) ([]byte, error) { + typ := typeOf(value) + cache, _ := marshalers.Load().(map[_type]encodeFunc) + encode := cache[typ] + + if encode == nil { + encode = encodeFuncOf(reflect.TypeOf(value), version, false, structTag{ + MinVersion: -1, + MaxVersion: -1, + TagID: -2, + Compact: true, + Nullable: true, + }) + + newCache := make(map[_type]encodeFunc, len(cache)+1) + newCache[typ] = encode + + for typ, fun := range cache { + newCache[typ] = fun + } + + marshalers.Store(newCache) + } + + e, _ := encoders.Get().(*encoder) + if e == nil { + e = &encoder{writer: new(bytes.Buffer)} + } + + b, _ := e.writer.(*bytes.Buffer) + defer func() { + b.Reset() + e.Reset(b) + encoders.Put(e) + }() + + encode(e, nonAddressableValueOf(value)) + + if e.err != nil { + return nil, e.err + } + + buf := b.Bytes() + out := make([]byte, len(buf)) + copy(out, buf) + return out, nil +} + +var ( + encoders sync.Pool // *encoder + marshalers atomic.Value // map[_type]encodeFunc +) diff --git a/protocol/error.go b/protocol/error.go new file mode 100644 index 000000000..c8d39c2ef --- /dev/null +++ b/protocol/error.go @@ -0,0 +1,91 @@ +package protocol + +import ( + "fmt" +) + +// Error represents client-side protocol errors. +type Error string + +func (e Error) Error() string { return string(e) } + +func Errorf(msg string, args ...interface{}) Error { + return Error(fmt.Sprintf(msg, args...)) +} + +const ( + // ErrNoTopic is returned when a request needs to be sent to a specific + // topic, but the client did not find it in the cluster metadata. + ErrNoTopic Error = "topic not found" + + // ErrNoPartition is returned when a request needs to be sent to a specific + // partition, but the client did not find it in the cluster metadata. + ErrNoPartition Error = "topic partition not found" + + // ErrNoLeader is returned when a request needs to be sent to a partition + // leader, but the client could not determine what the leader was at this + // time. + ErrNoLeader Error = "topic partition has no leader" + + // ErrNoRecord is returned when attempting to write a message containing an + // empty record set (which kafka forbids). + // + // We handle this case client-side because kafka will close the connection + // that it received an empty produce request on, causing all concurrent + // requests to be aborted. + ErrNoRecord Error = "record set contains no records" + + // ErrNoReset is returned by ResetRecordReader when the record reader does + // not support being reset.
+ ErrNoReset Error = "record sequence does not support reset" +) + +type TopicError struct { + Topic string + Err error +} + +func NewTopicError(topic string, err error) *TopicError { + return &TopicError{Topic: topic, Err: err} +} + +func NewErrNoTopic(topic string) *TopicError { + return NewTopicError(topic, ErrNoTopic) +} + +func (e *TopicError) Error() string { + return fmt.Sprintf("%v (topic=%q)", e.Err, e.Topic) +} + +func (e *TopicError) Unwrap() error { + return e.Err +} + +type TopicPartitionError struct { + Topic string + Partition int32 + Err error +} + +func NewTopicPartitionError(topic string, partition int32, err error) *TopicPartitionError { + return &TopicPartitionError{ + Topic: topic, + Partition: partition, + Err: err, + } +} + +func NewErrNoPartition(topic string, partition int32) *TopicPartitionError { + return NewTopicPartitionError(topic, partition, ErrNoPartition) +} + +func NewErrNoLeader(topic string, partition int32) *TopicPartitionError { + return NewTopicPartitionError(topic, partition, ErrNoLeader) +} + +func (e *TopicPartitionError) Error() string { + return fmt.Sprintf("%v (topic=%q partition=%d)", e.Err, e.Topic, e.Partition) +} + +func (e *TopicPartitionError) Unwrap() error { + return e.Err +} diff --git a/protocol/fetch/fetch.go b/protocol/fetch/fetch.go new file mode 100644 index 000000000..6ce7bae1b --- /dev/null +++ b/protocol/fetch/fetch.go @@ -0,0 +1,126 @@ +package fetch + +import ( + "fmt" + + "github.com/segmentio/kafka-go/protocol" +) + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + ReplicaID int32 `kafka:"min=v0,max=v11"` + MaxWaitTime int32 `kafka:"min=v0,max=v11"` + MinBytes int32 `kafka:"min=v0,max=v11"` + MaxBytes int32 `kafka:"min=v3,max=v11"` + IsolationLevel int8 `kafka:"min=v4,max=v11"` + SessionID int32 `kafka:"min=v7,max=v11"` + SessionEpoch int32 `kafka:"min=v7,max=v11"` + Topics []RequestTopic `kafka:"min=v0,max=v11"` + ForgottenTopics []RequestForgottenTopic `kafka:"min=v7,max=v11"` + RackID string `kafka:"min=v11,max=v11"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.Fetch } + +func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) { + broker := protocol.Broker{ID: -1} + + for i := range r.Topics { + t := &r.Topics[i] + + topic, ok := cluster.Topics[t.Topic] + if !ok { + return broker, NewError(protocol.NewErrNoTopic(t.Topic)) + } + + for j := range t.Partitions { + p := &t.Partitions[j] + + partition, ok := topic.Partitions[p.Partition] + if !ok { + return broker, NewError(protocol.NewErrNoPartition(t.Topic, p.Partition)) + } + + if b, ok := cluster.Brokers[partition.Leader]; !ok { + return broker, NewError(protocol.NewErrNoLeader(t.Topic, p.Partition)) + } else if broker.ID < 0 { + broker = b + } else if b.ID != broker.ID { + return broker, NewError(fmt.Errorf("mismatching leaders (%d!=%d)", b.ID, broker.ID)) + } + } + } + + return broker, nil +} + +type RequestTopic struct { + Topic string `kafka:"min=v0,max=v11"` + Partitions []RequestPartition `kafka:"min=v0,max=v11"` +} + +type RequestPartition struct { + Partition int32 `kafka:"min=v0,max=v11"` + CurrentLeaderEpoch int32 `kafka:"min=v9,max=v11"` + FetchOffset int64 `kafka:"min=v0,max=v11"` + LogStartOffset int64 `kafka:"min=v5,max=v11"` + PartitionMaxBytes int32 `kafka:"min=v0,max=v11"` +} + +type RequestForgottenTopic struct { + Topic string `kafka:"min=v7,max=v11"` + Partitions []int32 `kafka:"min=v7,max=v11"` +} + +type Response struct { + ThrottleTimeMs int32 `kafka:"min=v1,max=v11"` + 
ErrorCode int16 `kafka:"min=v7,max=v11"` + SessionID int32 `kafka:"min=v7,max=v11"` + Topics []ResponseTopic `kafka:"min=v0,max=v11"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.Fetch } + +type ResponseTopic struct { + Topic string `kafka:"min=v0,max=v11"` + Partitions []ResponsePartition `kafka:"min=v0,max=v11"` +} + +type ResponsePartition struct { + Partition int32 `kafka:"min=v0,max=v11"` + ErrorCode int16 `kafka:"min=v0,max=v11"` + HighWatermark int64 `kafka:"min=v0,max=v11"` + LastStableOffset int64 `kafka:"min=v4,max=v11"` + LogStartOffset int64 `kafka:"min=v5,max=v11"` + AbortedTransactions []ResponseTransaction `kafka:"min=v4,max=v11"` + PreferredReadReplica int32 `kafka:"min=v11,max=v11"` + RecordSet protocol.RecordSet `kafka:"min=v0,max=v11"` +} + +type ResponseTransaction struct { + ProducerID int64 `kafka:"min=v4,max=v11"` + FirstOffset int64 `kafka:"min=v4,max=v11"` +} + +var ( + _ protocol.BrokerMessage = (*Request)(nil) +) + +type Error struct { + Err error +} + +func NewError(err error) *Error { + return &Error{Err: err} +} + +func (e *Error) Error() string { + return fmt.Sprintf("fetch request error: %v", e.Err) +} + +func (e *Error) Unwrap() error { + return e.Err +} diff --git a/protocol/fetch/fetch_test.go b/protocol/fetch/fetch_test.go new file mode 100644 index 000000000..3f7f820d9 --- /dev/null +++ b/protocol/fetch/fetch_test.go @@ -0,0 +1,147 @@ +package fetch_test + +import ( + "testing" + "time" + + "github.com/segmentio/kafka-go/protocol" + "github.com/segmentio/kafka-go/protocol/fetch" + "github.com/segmentio/kafka-go/protocol/prototest" +) + +const ( + v0 = 0 + v11 = 11 +) + +func TestFetchRequest(t *testing.T) { + prototest.TestRequest(t, v0, &fetch.Request{ + ReplicaID: -1, + MaxWaitTime: 500, + MinBytes: 1024, + Topics: []fetch.RequestTopic{ + { + Topic: "topic-1", + Partitions: []fetch.RequestPartition{ + { + Partition: 1, + FetchOffset: 2, + PartitionMaxBytes: 1024, + }, + }, + }, + }, + }) +} + +func TestFetchResponse(t *testing.T) { + t0 := time.Now().Truncate(time.Millisecond) + t1 := t0.Add(1 * time.Millisecond) + t2 := t0.Add(2 * time.Millisecond) + + prototest.TestResponse(t, v0, &fetch.Response{ + Topics: []fetch.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []fetch.ResponsePartition{ + { + Partition: 1, + HighWatermark: 1000, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) + + headers := []protocol.Header{ + {Key: "key-1", Value: []byte("value-1")}, + {Key: "key-2", Value: []byte("value-2")}, + {Key: "key-3", Value: []byte("value-3")}, + } + + prototest.TestResponse(t, v11, &fetch.Response{ + Topics: []fetch.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []fetch.ResponsePartition{ + { + Partition: 1, + HighWatermark: 1000, + RecordSet: protocol.RecordSet{ + Version: 2, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) +} + +func BenchmarkFetchResponse(b 
*testing.B) { + t0 := time.Now().Truncate(time.Millisecond) + t1 := t0.Add(1 * time.Millisecond) + t2 := t0.Add(2 * time.Millisecond) + + prototest.BenchmarkResponse(b, v0, &fetch.Response{ + Topics: []fetch.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []fetch.ResponsePartition{ + { + Partition: 1, + HighWatermark: 1000, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) + + headers := []protocol.Header{ + {Key: "key-1", Value: []byte("value-1")}, + {Key: "key-2", Value: []byte("value-2")}, + {Key: "key-3", Value: []byte("value-3")}, + } + + prototest.BenchmarkResponse(b, v11, &fetch.Response{ + Topics: []fetch.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []fetch.ResponsePartition{ + { + Partition: 1, + HighWatermark: 1000, + RecordSet: protocol.RecordSet{ + Version: 2, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) +} diff --git a/protocol/findcoordinator/findcoordinator.go b/protocol/findcoordinator/findcoordinator.go new file mode 100644 index 000000000..0306e206d --- /dev/null +++ b/protocol/findcoordinator/findcoordinator.go @@ -0,0 +1,25 @@ +package findcoordinator + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + Key string `kafka:"min=v0,max=v2"` + KeyType int8 `kafka:"min=v1,max=v2"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.FindCoordinator } + +type Response struct { + ThrottleTimeMs int32 `kafka:"min=v1,max=v2"` + ErrorCode int16 `kafka:"min=v0,max=v2"` + ErrorMessage string `kafka:"min=v1,max=v2,nullable"` + NodeID int32 `kafka:"min=v0,max=v2"` + Host string `kafka:"min=v0,max=v2"` + Port int32 `kafka:"min=v0,max=v2"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.FindCoordinator } diff --git a/protocol/listoffsets/listoffsets.go b/protocol/listoffsets/listoffsets.go new file mode 100644 index 000000000..059662d9a --- /dev/null +++ b/protocol/listoffsets/listoffsets.go @@ -0,0 +1,230 @@ +package listoffsets + +import ( + "sort" + + "github.com/segmentio/kafka-go/protocol" +) + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + ReplicaID int32 `kafka:"min=v1,max=v5"` + IsolationLevel int8 `kafka:"min=v2,max=v5"` + Topics []RequestTopic `kafka:"min=v1,max=v5"` +} + +type RequestTopic struct { + Topic string `kafka:"min=v1,max=v5"` + Partitions []RequestPartition `kafka:"min=v1,max=v5"` +} + +type RequestPartition struct { + Partition int32 `kafka:"min=v1,max=v5"` + CurrentLeaderEpoch int32 `kafka:"min=v4,max=v5"` + Timestamp int64 `kafka:"min=v1,max=v5"` + // v0 of the API predates kafka 0.10, and doesn't make much sense to + // use so we chose not to support it. It had this extra field to limit + // the number of offsets returned, which has been removed in v1. 
+ // + // MaxNumOffsets int32 `kafka:"min=v0,max=v0"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.ListOffsets } + +func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) { + // Expects r to be a request that was returned by Split; this will likely + // panic or produce the wrong result if that's not the case. + partition := r.Topics[0].Partitions[0].Partition + topic := r.Topics[0].Topic + + for _, p := range cluster.Topics[topic].Partitions { + if p.ID == partition { + return cluster.Brokers[p.Leader], nil + } + } + + return protocol.Broker{ID: -1}, nil +} + +func (r *Request) Split(cluster protocol.Cluster) ([]protocol.Message, protocol.Merger, error) { + // Because kafka refuses to answer ListOffsets requests containing multiple + // entries of unique topic/partition pairs, we submit multiple requests on + // the wire and merge their results back. + // + // ListOffsets requests also need to be sent to partition leaders; to keep + // the logic simple we split each offset request into a single message. + // This may cause a few more requests to be sent on the wire, but it keeps + // the code sane, and we can still optimize the aggregation mechanism later + // if it becomes a problem. + // + // Really the idea here is to shield applications from having to deal with + // the limitation of the kafka server, so they can request any combination + // of topics/partitions/offsets. + requests := make([]Request, 0, 2*len(r.Topics)) + + for _, t := range r.Topics { + for _, p := range t.Partitions { + requests = append(requests, Request{ + ReplicaID: r.ReplicaID, + IsolationLevel: r.IsolationLevel, + Topics: []RequestTopic{{ + Topic: t.Topic, + Partitions: []RequestPartition{{ + Partition: p.Partition, + CurrentLeaderEpoch: p.CurrentLeaderEpoch, + Timestamp: p.Timestamp, + }}, + }}, + }) + } + } + + messages := make([]protocol.Message, len(requests)) + + for i := range requests { + messages[i] = &requests[i] + } + + return messages, new(Response), nil +} + +type Response struct { + ThrottleTimeMs int32 `kafka:"min=v2,max=v5"` + Topics []ResponseTopic `kafka:"min=v1,max=v5"` +} + +type ResponseTopic struct { + Topic string `kafka:"min=v1,max=v5"` + Partitions []ResponsePartition `kafka:"min=v1,max=v5"` +} + +type ResponsePartition struct { + Partition int32 `kafka:"min=v1,max=v5"` + ErrorCode int16 `kafka:"min=v1,max=v5"` + Timestamp int64 `kafka:"min=v1,max=v5"` + Offset int64 `kafka:"min=v1,max=v5"` + LeaderEpoch int32 `kafka:"min=v4,max=v5"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.ListOffsets } + +func (r *Response) Merge(requests []protocol.Message, results []interface{}) (protocol.Message, error) { + type topicPartition struct { + topic string + partition int32 + } + + // Kafka doesn't always return the timestamp in the response, for example + // when the request sends -2 (for the first offset) it always returns -1, + // probably to indicate that the timestamp is unknown. This means that we + // can't correlate the requests and responses based on their timestamps; + // the primary key is the topic/partition pair. + // + // To make the API a bit friendlier, we reconstruct an index of topic + // partitions to the timestamps that were requested, and override the + // timestamp value in the response.
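+ // For example, if the i'th split request asked for the first offset of a
+ // partition (Timestamp set to -2), timestamps[i] maps that topic/partition
+ // pair to -2, and the merged response reports -2 instead of the -1 returned
+ // by the broker.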
+ timestamps := make([]map[topicPartition]int64, len(requests)) + + for i, m := range requests { + req := m.(*Request) + ts := make(map[topicPartition]int64, len(req.Topics)) + + for _, t := range req.Topics { + for _, p := range t.Partitions { + ts[topicPartition{ + topic: t.Topic, + partition: p.Partition, + }] = p.Timestamp + } + } + + timestamps[i] = ts + } + + topics := make(map[string][]ResponsePartition) + errors := 0 + + for i, res := range results { + m, err := protocol.Result(res) + if err != nil { + for _, t := range requests[i].(*Request).Topics { + partitions := topics[t.Topic] + + for _, p := range t.Partitions { + partitions = append(partitions, ResponsePartition{ + Partition: p.Partition, + ErrorCode: -1, // UNKNOWN, can we do better? + Timestamp: -1, + Offset: -1, + LeaderEpoch: -1, + }) + } + + topics[t.Topic] = partitions + } + errors++ + continue + } + + response := m.(*Response) + + if r.ThrottleTimeMs < response.ThrottleTimeMs { + r.ThrottleTimeMs = response.ThrottleTimeMs + } + + for _, t := range response.Topics { + for _, p := range t.Partitions { + if timestamp, ok := timestamps[i][topicPartition{ + topic: t.Topic, + partition: p.Partition, + }]; ok { + p.Timestamp = timestamp + } + topics[t.Topic] = append(topics[t.Topic], p) + } + } + + } + + if errors > 0 && errors == len(results) { + _, err := protocol.Result(results[0]) + return nil, err + } + + r.Topics = make([]ResponseTopic, 0, len(topics)) + + for topicName, partitions := range topics { + r.Topics = append(r.Topics, ResponseTopic{ + Topic: topicName, + Partitions: partitions, + }) + } + + sort.Slice(r.Topics, func(i, j int) bool { + return r.Topics[i].Topic < r.Topics[j].Topic + }) + + for _, t := range r.Topics { + sort.Slice(t.Partitions, func(i, j int) bool { + p1 := &t.Partitions[i] + p2 := &t.Partitions[j] + + if p1.Partition != p2.Partition { + return p1.Partition < p2.Partition + } + + return p1.Offset < p2.Offset + }) + } + + return r, nil +} + +var ( + _ protocol.BrokerMessage = (*Request)(nil) + _ protocol.Splitter = (*Request)(nil) + _ protocol.Merger = (*Response)(nil) +) diff --git a/protocol/listoffsets/listoffsets_test.go b/protocol/listoffsets/listoffsets_test.go new file mode 100644 index 000000000..ffeb2ca99 --- /dev/null +++ b/protocol/listoffsets/listoffsets_test.go @@ -0,0 +1,104 @@ +package listoffsets_test + +import ( + "testing" + + "github.com/segmentio/kafka-go/protocol/listoffsets" + "github.com/segmentio/kafka-go/protocol/prototest" +) + +const ( + v1 = 1 + v4 = 4 +) + +func TestListOffsetsRequest(t *testing.T) { + prototest.TestRequest(t, v1, &listoffsets.Request{ + ReplicaID: 1, + Topics: []listoffsets.RequestTopic{ + { + Topic: "topic-1", + Partitions: []listoffsets.RequestPartition{ + {Partition: 0, Timestamp: 1e9}, + {Partition: 1, Timestamp: 1e9}, + {Partition: 2, Timestamp: 1e9}, + }, + }, + }, + }) + + prototest.TestRequest(t, v4, &listoffsets.Request{ + ReplicaID: 1, + IsolationLevel: 2, + Topics: []listoffsets.RequestTopic{ + { + Topic: "topic-1", + Partitions: []listoffsets.RequestPartition{ + {Partition: 0, Timestamp: 1e9}, + {Partition: 1, Timestamp: 1e9}, + {Partition: 2, Timestamp: 1e9}, + }, + }, + { + Topic: "topic-2", + Partitions: []listoffsets.RequestPartition{ + {Partition: 0, CurrentLeaderEpoch: 10, Timestamp: 1e9}, + {Partition: 1, CurrentLeaderEpoch: 11, Timestamp: 1e9}, + {Partition: 2, CurrentLeaderEpoch: 12, Timestamp: 1e9}, + }, + }, + }, + }) +} + +func TestListOffsetsResponse(t *testing.T) { + prototest.TestResponse(t, v1, &listoffsets.Response{ + 
Topics: []listoffsets.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []listoffsets.ResponsePartition{ + { + Partition: 0, + ErrorCode: 0, + Timestamp: 1e9, + Offset: 1234567890, + }, + }, + }, + }, + }) + + prototest.TestResponse(t, v4, &listoffsets.Response{ + ThrottleTimeMs: 1234, + Topics: []listoffsets.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []listoffsets.ResponsePartition{ + { + Partition: 0, + ErrorCode: 0, + Timestamp: 1e9, + Offset: 1234567890, + LeaderEpoch: 10, + }, + }, + }, + { + Topic: "topic-2", + Partitions: []listoffsets.ResponsePartition{ + { + Partition: 0, + ErrorCode: 0, + Timestamp: 1e9, + Offset: 1234567890, + LeaderEpoch: 10, + }, + { + Partition: 1, + ErrorCode: 2, + }, + }, + }, + }, + }) +} diff --git a/protocol/metadata/metadata.go b/protocol/metadata/metadata.go new file mode 100644 index 000000000..ac2031bda --- /dev/null +++ b/protocol/metadata/metadata.go @@ -0,0 +1,52 @@ +package metadata + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + TopicNames []string `kafka:"min=v0,max=v8,nullable"` + AllowAutoTopicCreation bool `kafka:"min=v4,max=v8"` + IncludeClusterAuthorizedOperations bool `kafka:"min=v8,max=v8"` + IncludeTopicAuthorizedOperations bool `kafka:"min=v8,max=v8"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.Metadata } + +type Response struct { + ThrottleTimeMs int32 `kafka:"min=v3,max=v8"` + Brokers []ResponseBroker `kafka:"min=v0,max=v8"` + ClusterID string `kafka:"min=v2,max=v8,nullable"` + ControllerID int32 `kafka:"min=v1,max=v8"` + Topics []ResponseTopic `kafka:"min=v0,max=v8"` + ClusterAuthorizedOperations int32 `kafka:"min=v8,max=v8"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.Metadata } + +type ResponseBroker struct { + NodeID int32 `kafka:"min=v0,max=v8"` + Host string `kafka:"min=v0,max=v8"` + Port int32 `kafka:"min=v0,max=v8"` + Rack string `kafka:"min=v1,max=v8,nullable"` +} + +type ResponseTopic struct { + ErrorCode int16 `kafka:"min=v0,max=v8"` + Name string `kafka:"min=v0,max=v8"` + IsInternal bool `kafka:"min=v1,max=v8"` + Partitions []ResponsePartition `kafka:"min=v0,max=v8"` + TopicAuthorizedOperations int32 `kafka:"min=v8,max=v8"` +} + +type ResponsePartition struct { + ErrorCode int16 `kafka:"min=v0,max=v8"` + PartitionIndex int32 `kafka:"min=v0,max=v8"` + LeaderID int32 `kafka:"min=v0,max=v8"` + LeaderEpoch int32 `kafka:"min=v7,max=v8"` + ReplicaNodes []int32 `kafka:"min=v0,max=v8"` + IsrNodes []int32 `kafka:"min=v0,max=v8"` + OfflineReplicas []int32 `kafka:"min=v5,max=v8"` +} diff --git a/protocol/metadata/metadata_test.go b/protocol/metadata/metadata_test.go new file mode 100644 index 000000000..168fd2785 --- /dev/null +++ b/protocol/metadata/metadata_test.go @@ -0,0 +1,199 @@ +package metadata_test + +import ( + "testing" + + "github.com/segmentio/kafka-go/protocol/metadata" + "github.com/segmentio/kafka-go/protocol/prototest" +) + +const ( + v0 = 0 + v1 = 1 + v4 = 4 + v8 = 8 +) + +func TestMetadataRequest(t *testing.T) { + prototest.TestRequest(t, v0, &metadata.Request{ + TopicNames: nil, + }) + + prototest.TestRequest(t, v4, &metadata.Request{ + TopicNames: []string{"hello", "world"}, + AllowAutoTopicCreation: true, + }) + + prototest.TestRequest(t, v8, &metadata.Request{ + TopicNames: []string{"hello", "world"}, + AllowAutoTopicCreation: true, + IncludeClusterAuthorizedOperations: true, + IncludeTopicAuthorizedOperations: true, + }) +} + +func TestMetadataResponse(t 
*testing.T) { + prototest.TestResponse(t, v0, &metadata.Response{ + Brokers: []metadata.ResponseBroker{ + { + NodeID: 0, + Host: "127.0.0.1", + Port: 9092, + }, + { + NodeID: 1, + Host: "127.0.0.1", + Port: 9093, + }, + }, + Topics: []metadata.ResponseTopic{ + { + Name: "topic-1", + Partitions: []metadata.ResponsePartition{ + { + PartitionIndex: 0, + LeaderID: 1, + ReplicaNodes: []int32{0}, + IsrNodes: []int32{1, 0}, + }, + }, + }, + }, + }) + + prototest.TestResponse(t, v1, &metadata.Response{ + ControllerID: 1, + Brokers: []metadata.ResponseBroker{ + { + NodeID: 0, + Host: "127.0.0.1", + Port: 9092, + Rack: "rack-1", + }, + { + NodeID: 1, + Host: "127.0.0.1", + Port: 9093, + Rack: "rack-2", + }, + }, + Topics: []metadata.ResponseTopic{ + { + Name: "topic-1", + IsInternal: true, + Partitions: []metadata.ResponsePartition{ + { + PartitionIndex: 0, + LeaderID: 1, + ReplicaNodes: []int32{0}, + IsrNodes: []int32{1, 0}, + }, + { + PartitionIndex: 1, + LeaderID: 0, + ReplicaNodes: []int32{1}, + IsrNodes: []int32{0, 1}, + }, + }, + }, + }, + }) + + prototest.TestResponse(t, v8, &metadata.Response{ + ThrottleTimeMs: 123, + ClusterID: "test", + ControllerID: 1, + ClusterAuthorizedOperations: 0x01, + Brokers: []metadata.ResponseBroker{ + { + NodeID: 0, + Host: "127.0.0.1", + Port: 9092, + Rack: "rack-1", + }, + { + NodeID: 1, + Host: "127.0.0.1", + Port: 9093, + Rack: "rack-2", + }, + }, + Topics: []metadata.ResponseTopic{ + { + Name: "topic-1", + Partitions: []metadata.ResponsePartition{ + { + PartitionIndex: 0, + LeaderID: 1, + LeaderEpoch: 1234567890, + ReplicaNodes: []int32{0}, + IsrNodes: []int32{0}, + OfflineReplicas: []int32{1}, + }, + { + ErrorCode: 1, + ReplicaNodes: []int32{}, + IsrNodes: []int32{}, + OfflineReplicas: []int32{}, + }, + }, + TopicAuthorizedOperations: 0x01, + }, + }, + }) +} + +func BenchmarkMetadataRequest(b *testing.B) { + prototest.BenchmarkRequest(b, v8, &metadata.Request{ + TopicNames: []string{"hello", "world"}, + AllowAutoTopicCreation: true, + IncludeClusterAuthorizedOperations: true, + IncludeTopicAuthorizedOperations: true, + }) +} + +func BenchmarkMetadataResponse(b *testing.B) { + prototest.BenchmarkResponse(b, v8, &metadata.Response{ + ThrottleTimeMs: 123, + ClusterID: "test", + ControllerID: 1, + ClusterAuthorizedOperations: 0x01, + Brokers: []metadata.ResponseBroker{ + { + NodeID: 0, + Host: "127.0.0.1", + Port: 9092, + Rack: "rack-1", + }, + { + NodeID: 1, + Host: "127.0.0.1", + Port: 9093, + Rack: "rack-2", + }, + }, + Topics: []metadata.ResponseTopic{ + { + Name: "topic-1", + Partitions: []metadata.ResponsePartition{ + { + PartitionIndex: 0, + LeaderID: 1, + LeaderEpoch: 1234567890, + ReplicaNodes: []int32{0}, + IsrNodes: []int32{0}, + OfflineReplicas: []int32{1}, + }, + { + ErrorCode: 1, + ReplicaNodes: []int32{}, + IsrNodes: []int32{}, + OfflineReplicas: []int32{}, + }, + }, + TopicAuthorizedOperations: 0x01, + }, + }, + }) + +} diff --git a/protocol/offsetfetch/offsetfetch.go b/protocol/offsetfetch/offsetfetch.go new file mode 100644 index 000000000..011003340 --- /dev/null +++ b/protocol/offsetfetch/offsetfetch.go @@ -0,0 +1,46 @@ +package offsetfetch + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + GroupID string `kafka:"min=v0,max=v5"` + Topics []RequestTopic `kafka:"min=v0,max=v5"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.OffsetFetch } + +func (r *Request) Group() string { return r.GroupID } + +type RequestTopic struct { + Name 
string `kafka:"min=v0,max=v5"` + PartitionIndexes []int32 `kafka:"min=v0,max=v5"` +} + +var ( + _ protocol.GroupMessage = (*Request)(nil) +) + +type Response struct { + ThrottleTimeMs int32 `kafka:"min=v3,max=v5"` + Topics []ResponseTopic `kafka:"min=v0,max=v5"` + ErrorCode int16 `kafka:"min=v2,max=v5"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.OffsetFetch } + +type ResponseTopic struct { + Name string `kafka:"min=v0,max=v5"` + Partitions []ResponsePartition `kafka:"min=v0,max=v5"` +} + +type ResponsePartition struct { + PartitionIndex int32 `kafka:"min=v0,max=v5"` + CommittedOffset int64 `kafka:"min=v0,max=v5"` + ComittedLeaderEpoch int32 `kafka:"min=v5,max=v5"` + Metadata string `kafka:"min=v0,max=v5,nullable"` + ErrorCode int16 `kafka:"min=v0,max=v5"` +} diff --git a/protocol/produce/produce.go b/protocol/produce/produce.go new file mode 100644 index 000000000..6d337c3cf --- /dev/null +++ b/protocol/produce/produce.go @@ -0,0 +1,147 @@ +package produce + +import ( + "fmt" + + "github.com/segmentio/kafka-go/protocol" +) + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + TransactionalID string `kafka:"min=v3,max=v8,nullable"` + Acks int16 `kafka:"min=v0,max=v8"` + Timeout int32 `kafka:"min=v0,max=v8"` + Topics []RequestTopic `kafka:"min=v0,max=v8"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.Produce } + +func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) { + broker := protocol.Broker{ID: -1} + + for i := range r.Topics { + t := &r.Topics[i] + + topic, ok := cluster.Topics[t.Topic] + if !ok { + return broker, NewError(protocol.NewErrNoTopic(t.Topic)) + } + + for j := range t.Partitions { + p := &t.Partitions[j] + + partition, ok := topic.Partitions[p.Partition] + if !ok { + return broker, NewError(protocol.NewErrNoPartition(t.Topic, p.Partition)) + } + + if b, ok := cluster.Brokers[partition.Leader]; !ok { + return broker, NewError(protocol.NewErrNoLeader(t.Topic, p.Partition)) + } else if broker.ID < 0 { + broker = b + } else if b.ID != broker.ID { + return broker, NewError(fmt.Errorf("mismatching leaders (%d!=%d)", b.ID, broker.ID)) + } + } + } + + return broker, nil +} + +func (r *Request) Prepare(apiVersion int16) { + // Determine which version of the message should be used, based on which + // version of the Produce API is supported by the server. + // + // In version 0.11, kafka gives this error: + // + // org.apache.kafka.common.record.InvalidRecordException + // Produce requests with version 3 are only allowed to contain record batches with magic version. + // + // In version 2.x, kafka refuses the message claiming that the CRC32 + // checksum is invalid. + var recordVersion int8 + + if apiVersion < 3 { + recordVersion = 1 + } else { + recordVersion = 2 + } + + for i := range r.Topics { + t := &r.Topics[i] + + for j := range t.Partitions { + p := &t.Partitions[j] + + // Allow the program to overload the version if really needed. 
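+ // For example, a caller that knows its brokers accept the v2 record + // format can set RecordSet.Version to 2 before writing the request; + // only a zero Version is defaulted here.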
+ if p.RecordSet.Version == 0 { + p.RecordSet.Version = recordVersion + } + } + } +} + +func (r *Request) HasResponse() bool { + return r.Acks != 0 +} + +type RequestTopic struct { + Topic string `kafka:"min=v0,max=v8"` + Partitions []RequestPartition `kafka:"min=v0,max=v8"` +} + +type RequestPartition struct { + Partition int32 `kafka:"min=v0,max=v8"` + RecordSet protocol.RecordSet `kafka:"min=v0,max=v8"` +} + +type Response struct { + Topics []ResponseTopic `kafka:"min=v0,max=v8"` + ThrottleTimeMs int32 `kafka:"min=v1,max=v8"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.Produce } + +type ResponseTopic struct { + Topic string `kafka:"min=v0,max=v8"` + Partitions []ResponsePartition `kafka:"min=v0,max=v8"` +} + +type ResponsePartition struct { + Partition int32 `kafka:"min=v0,max=v8"` + ErrorCode int16 `kafka:"min=v0,max=v8"` + BaseOffset int64 `kafka:"min=v0,max=v8"` + LogAppendTime int64 `kafka:"min=v2,max=v8"` + LogStartOffset int64 `kafka:"min=v5,max=v8"` + RecordErrors []ResponseError `kafka:"min=v8,max=v8"` + ErrorMessage string `kafka:"min=v8,max=v8,nullable"` +} + +type ResponseError struct { + BatchIndex int32 `kafka:"min=v8,max=v8"` + BatchIndexErrorMessage string `kafka:"min=v8,max=v8,nullable"` +} + +var ( + _ protocol.BrokerMessage = (*Request)(nil) + _ protocol.PreparedMessage = (*Request)(nil) +) + +type Error struct { + Err error +} + +func NewError(err error) *Error { + return &Error{Err: err} +} + +func (e *Error) Error() string { + return fmt.Sprintf("fetch request error: %v", e.Err) +} + +func (e *Error) Unwrap() error { + return e.Err +} diff --git a/protocol/produce/produce_test.go b/protocol/produce/produce_test.go new file mode 100644 index 000000000..276b7d463 --- /dev/null +++ b/protocol/produce/produce_test.go @@ -0,0 +1,273 @@ +package produce_test + +import ( + "testing" + "time" + + "github.com/segmentio/kafka-go/protocol" + "github.com/segmentio/kafka-go/protocol/produce" + "github.com/segmentio/kafka-go/protocol/prototest" +) + +const ( + v0 = 0 + v3 = 3 + v5 = 5 + v8 = 8 +) + +func TestProduceRequest(t *testing.T) { + t0 := time.Now().Truncate(time.Millisecond) + t1 := t0.Add(1 * time.Millisecond) + t2 := t0.Add(2 * time.Millisecond) + + prototest.TestRequest(t, v0, &produce.Request{ + Acks: 1, + Timeout: 500, + Topics: []produce.RequestTopic{ + { + Topic: "topic-1", + Partitions: []produce.RequestPartition{ + { + Partition: 0, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: nil}, + ), + }, + }, + { + Partition: 1, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + + { + Topic: "topic-2", + Partitions: []produce.RequestPartition{ + { + Partition: 0, + RecordSet: protocol.RecordSet{ + Version: 1, + Attributes: protocol.Gzip, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) + + prototest.TestRequest(t, v3, &produce.Request{ 
+ TransactionalID: "1234", + Acks: 1, + Timeout: 500, + Topics: []produce.RequestTopic{ + { + Topic: "topic-1", + Partitions: []produce.RequestPartition{ + { + Partition: 0, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: nil}, + ), + }, + }, + { + Partition: 1, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) + + headers := []protocol.Header{ + {Key: "key-1", Value: []byte("value-1")}, + {Key: "key-2", Value: []byte("value-2")}, + {Key: "key-3", Value: []byte("value-3")}, + } + + prototest.TestRequest(t, v5, &produce.Request{ + TransactionalID: "1234", + Acks: 1, + Timeout: 500, + Topics: []produce.RequestTopic{ + { + Topic: "topic-1", + Partitions: []produce.RequestPartition{ + { + Partition: 1, + RecordSet: protocol.RecordSet{ + Version: 2, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + + { + Topic: "topic-2", + Partitions: []produce.RequestPartition{ + { + Partition: 1, + RecordSet: protocol.RecordSet{ + Version: 2, + Attributes: protocol.Snappy, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) +} + +func TestProduceResponse(t *testing.T) { + prototest.TestResponse(t, v0, &produce.Response{ + Topics: []produce.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []produce.ResponsePartition{ + { + Partition: 0, + ErrorCode: 0, + BaseOffset: 0, + }, + { + Partition: 1, + ErrorCode: 0, + BaseOffset: 42, + }, + }, + }, + }, + }) + + prototest.TestResponse(t, v8, &produce.Response{ + Topics: []produce.ResponseTopic{ + { + Topic: "topic-1", + Partitions: []produce.ResponsePartition{ + { + Partition: 0, + ErrorCode: 0, + BaseOffset: 42, + LogAppendTime: 1e9, + LogStartOffset: 10, + RecordErrors: []produce.ResponseError{}, + }, + { + Partition: 1, + ErrorCode: 1, + RecordErrors: []produce.ResponseError{ + {BatchIndex: 1, BatchIndexErrorMessage: "message-1"}, + {BatchIndex: 2, BatchIndexErrorMessage: "message-2"}, + {BatchIndex: 3, BatchIndexErrorMessage: "message-3"}, + }, + ErrorMessage: "something went wrong", + }, + }, + }, + }, + }) +} + +func BenchmarkProduceRequest(b *testing.B) { + t0 := time.Now().Truncate(time.Millisecond) + t1 := t0.Add(1 * time.Millisecond) + t2 := t0.Add(2 * time.Millisecond) + + prototest.BenchmarkRequest(b, v3, &produce.Request{ + TransactionalID: "1234", + Acks: 1, + Timeout: 500, + Topics: []produce.RequestTopic{ + { + Topic: "topic-1", + Partitions: []produce.RequestPartition{ + { + Partition: 0, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, 
Key: nil, Value: nil}, + ), + }, + }, + { + Partition: 1, + RecordSet: protocol.RecordSet{ + Version: 1, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) + + headers := []protocol.Header{ + {Key: "key-1", Value: []byte("value-1")}, + {Key: "key-2", Value: []byte("value-2")}, + {Key: "key-3", Value: []byte("value-3")}, + } + + prototest.BenchmarkRequest(b, v5, &produce.Request{ + TransactionalID: "1234", + Acks: 1, + Timeout: 500, + Topics: []produce.RequestTopic{ + { + Topic: "topic-1", + Partitions: []produce.RequestPartition{ + { + Partition: 1, + RecordSet: protocol.RecordSet{ + Version: 2, + Records: protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), + }, + }, + }, + }, + }, + }) +} diff --git a/protocol/protocol.go b/protocol/protocol.go new file mode 100644 index 000000000..c32366d09 --- /dev/null +++ b/protocol/protocol.go @@ -0,0 +1,480 @@ +package protocol + +import ( + "fmt" + "io" + "net" + "reflect" + "strconv" + "strings" +) + +// Message is an interface implemented by all request and response types of the +// kafka protocol. +// +// This interface is used mostly as a safe-guard to provide a compile-time check +// for values passed to functions dealing kafka message types. +type Message interface { + ApiKey() ApiKey +} + +type ApiKey int16 + +func (k ApiKey) String() string { + if i := int(k); i >= 0 && i < len(apiNames) { + return apiNames[i] + } + return strconv.Itoa(int(k)) +} + +func (k ApiKey) MinVersion() int16 { return k.apiType().minVersion() } + +func (k ApiKey) MaxVersion() int16 { return k.apiType().maxVersion() } + +func (k ApiKey) SelectVersion(minVersion, maxVersion int16) int16 { + min := k.MinVersion() + max := k.MaxVersion() + switch { + case min > maxVersion: + return min + case max < maxVersion: + return max + default: + return maxVersion + } +} + +func (k ApiKey) apiType() apiType { + if i := int(k); i >= 0 && i < len(apiTypes) { + return apiTypes[i] + } + return apiType{} +} + +const ( + Produce ApiKey = 0 + Fetch ApiKey = 1 + ListOffsets ApiKey = 2 + Metadata ApiKey = 3 + LeaderAndIsr ApiKey = 4 + StopReplica ApiKey = 5 + UpdateMetadata ApiKey = 6 + ControlledShutdown ApiKey = 7 + OffsetCommit ApiKey = 8 + OffsetFetch ApiKey = 9 + FindCoordinator ApiKey = 10 + JoinGroup ApiKey = 11 + Heartbeat ApiKey = 12 + LeaveGroup ApiKey = 13 + SyncGroup ApiKey = 14 + DescribeGroups ApiKey = 15 + ListGroups ApiKey = 16 + SaslHandshake ApiKey = 17 + ApiVersions ApiKey = 18 + CreateTopics ApiKey = 19 + DeleteTopics ApiKey = 20 + DeleteRecords ApiKey = 21 + InitProducerId ApiKey = 22 + OffsetForLeaderEpoch ApiKey = 23 + AddPartitionsToTxn ApiKey = 24 + AddOffsetsToTxn ApiKey = 25 + EndTxn ApiKey = 26 + WriteTxnMarkers ApiKey = 27 + TxnOffsetCommit ApiKey = 28 + DescribeAcls ApiKey = 29 + CreateAcls ApiKey = 30 + DeleteAcls ApiKey = 31 + DescribeConfigs ApiKey = 32 + AlterConfigs ApiKey = 33 + AlterReplicaLogDirs ApiKey = 34 + DescribeLogDirs ApiKey = 35 + SaslAuthenticate ApiKey = 36 + CreatePartitions ApiKey = 
37 + CreateDelegationToken ApiKey = 38 + RenewDelegationToken ApiKey = 39 + ExpireDelegationToken ApiKey = 40 + DescribeDelegationToken ApiKey = 41 + DeleteGroups ApiKey = 42 + ElectLeaders ApiKey = 43 + IncrementalAlterConfigs ApiKey = 44 + AlterPartitionReassignments ApiKey = 45 + ListPartitionReassignments ApiKey = 46 + OffsetDelete ApiKey = 47 + DescribeClientQuotas ApiKey = 48 + AlterClientQuotas ApiKey = 49 + + numApis = 50 +) + +var apiNames = [numApis]string{ + Produce: "Produce", + Fetch: "Fetch", + ListOffsets: "ListOffsets", + Metadata: "Metadata", + LeaderAndIsr: "LeaderAndIsr", + StopReplica: "StopReplica", + UpdateMetadata: "UpdateMetadata", + ControlledShutdown: "ControlledShutdown", + OffsetCommit: "OffsetCommit", + OffsetFetch: "OffsetFetch", + FindCoordinator: "FindCoordinator", + JoinGroup: "JoinGroup", + Heartbeat: "Heartbeat", + LeaveGroup: "LeaveGroup", + SyncGroup: "SyncGroup", + DescribeGroups: "DescribeGroups", + ListGroups: "ListGroups", + SaslHandshake: "SaslHandshake", + ApiVersions: "ApiVersions", + CreateTopics: "CreateTopics", + DeleteTopics: "DeleteTopics", + DeleteRecords: "DeleteRecords", + InitProducerId: "InitProducerId", + OffsetForLeaderEpoch: "OffsetForLeaderEpoch", + AddPartitionsToTxn: "AddPartitionsToTxn", + AddOffsetsToTxn: "AddOffsetsToTxn", + EndTxn: "EndTxn", + WriteTxnMarkers: "WriteTxnMarkers", + TxnOffsetCommit: "TxnOffsetCommit", + DescribeAcls: "DescribeAcls", + CreateAcls: "CreateAcls", + DeleteAcls: "DeleteAcls", + DescribeConfigs: "DescribeConfigs", + AlterConfigs: "AlterConfigs", + AlterReplicaLogDirs: "AlterReplicaLogDirs", + DescribeLogDirs: "DescribeLogDirs", + SaslAuthenticate: "SaslAuthenticate", + CreatePartitions: "CreatePartitions", + CreateDelegationToken: "CreateDelegationToken", + RenewDelegationToken: "RenewDelegationToken", + ExpireDelegationToken: "ExpireDelegationToken", + DescribeDelegationToken: "DescribeDelegationToken", + DeleteGroups: "DeleteGroups", + ElectLeaders: "ElectLeaders", + IncrementalAlterConfigs: "IncrementalAlterConfigs", + AlterPartitionReassignments: "AlterPartitionReassignments", + ListPartitionReassignments: "ListPartitionReassignments", + OffsetDelete: "OffsetDelete", + DescribeClientQuotas: "DescribeClientQuotas", + AlterClientQuotas: "AlterClientQuotas", +} + +type messageType struct { + version int16 + flexible bool + gotype reflect.Type + decode decodeFunc + encode encodeFunc +} + +func (t *messageType) new() Message { + return reflect.New(t.gotype).Interface().(Message) +} + +type apiType struct { + requests []messageType + responses []messageType +} + +func (t apiType) minVersion() int16 { + if len(t.requests) == 0 { + return 0 + } + return t.requests[0].version +} + +func (t apiType) maxVersion() int16 { + if len(t.requests) == 0 { + return 0 + } + return t.requests[len(t.requests)-1].version +} + +var apiTypes [numApis]apiType + +// Register is automatically called by sub-packages are imported to install a +// new pair of request/response message types. 
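+// +// A typical sub-package registers its request/response pair from an init +// function, for example: +// +//	func init() { +//		protocol.Register(&Request{}, &Response{}) +//	}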
+func Register(req, res Message) { + k1 := req.ApiKey() + k2 := res.ApiKey() + + if k1 != k2 { + panic(fmt.Sprintf("[%T/%T]: request and response API keys mismatch: %d != %d", req, res, k1, k2)) + } + + apiTypes[k1] = apiType{ + requests: typesOf(req), + responses: typesOf(res), + } +} + +func typesOf(v interface{}) []messageType { + return makeTypes(reflect.TypeOf(v).Elem()) +} + +func makeTypes(t reflect.Type) []messageType { + minVersion := int16(-1) + maxVersion := int16(-1) + + // All future versions will be flexible (according to spec), so don't need to + // worry about maxes here. + minFlexibleVersion := int16(-1) + + forEachStructField(t, func(_ reflect.Type, _ index, tag string) { + forEachStructTag(tag, func(tag structTag) bool { + if minVersion < 0 || tag.MinVersion < minVersion { + minVersion = tag.MinVersion + } + if maxVersion < 0 || tag.MaxVersion > maxVersion { + maxVersion = tag.MaxVersion + } + if tag.TagID > -2 && (minFlexibleVersion < 0 || tag.MinVersion < minFlexibleVersion) { + minFlexibleVersion = tag.MinVersion + } + return true + }) + }) + + types := make([]messageType, 0, (maxVersion-minVersion)+1) + + for v := minVersion; v <= maxVersion; v++ { + flexible := minFlexibleVersion >= 0 && v >= minFlexibleVersion + + types = append(types, messageType{ + version: v, + gotype: t, + flexible: flexible, + decode: decodeFuncOf(t, v, flexible, structTag{}), + encode: encodeFuncOf(t, v, flexible, structTag{}), + }) + } + + return types +} + +type structTag struct { + MinVersion int16 + MaxVersion int16 + Compact bool + Nullable bool + TagID int +} + +func forEachStructTag(tag string, do func(structTag) bool) { + if tag == "-" { + return // special case to ignore the field + } + + forEach(tag, '|', func(s string) bool { + tag := structTag{ + MinVersion: -1, + MaxVersion: -1, + + // Legitimate tag IDs can start at 0. We use -1 as a placeholder to indicate + // that the message type is flexible, so that leaves -2 as the default for + // indicating that there is no tag ID and the message is not flexible. 
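+ // + // For example, a field tagged `kafka:"min=v3,max=v4,tag=0"` would parse + // to MinVersion=3, MaxVersion=4, and TagID=0.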
+ TagID: -2, + } + + var err error + forEach(s, ',', func(s string) bool { + switch { + case strings.HasPrefix(s, "min="): + tag.MinVersion, err = parseVersion(s[4:]) + case strings.HasPrefix(s, "max="): + tag.MaxVersion, err = parseVersion(s[4:]) + case s == "tag": + tag.TagID = -1 + case strings.HasPrefix(s, "tag="): + tag.TagID, err = strconv.Atoi(s[4:]) + case s == "compact": + tag.Compact = true + case s == "nullable": + tag.Nullable = true + default: + err = fmt.Errorf("unrecognized option: %q", s) + } + return err == nil + }) + + if err != nil { + panic(fmt.Errorf("malformed struct tag: %w", err)) + } + + if tag.MinVersion < 0 && tag.MaxVersion >= 0 { + panic(fmt.Errorf("missing minimum version in struct tag: %q", s)) + } + + if tag.MaxVersion < 0 && tag.MinVersion >= 0 { + panic(fmt.Errorf("missing maximum version in struct tag: %q", s)) + } + + if tag.MinVersion > tag.MaxVersion { + panic(fmt.Errorf("invalid version range in struct tag: %q", s)) + } + + return do(tag) + }) +} + +func forEach(s string, sep byte, do func(string) bool) bool { + for len(s) != 0 { + p := "" + i := strings.IndexByte(s, sep) + if i < 0 { + p, s = s, "" + } else { + p, s = s[:i], s[i+1:] + } + if !do(p) { + return false + } + } + return true +} + +func forEachStructField(t reflect.Type, do func(reflect.Type, index, string)) { + for i, n := 0, t.NumField(); i < n; i++ { + f := t.Field(i) + + if f.PkgPath != "" && f.Name != "_" { + continue + } + + kafkaTag, ok := f.Tag.Lookup("kafka") + if !ok { + kafkaTag = "|" + } + + do(f.Type, indexOf(f), kafkaTag) + } +} + +func parseVersion(s string) (int16, error) { + if !strings.HasPrefix(s, "v") { + return 0, fmt.Errorf("invalid version number: %q", s) + } + i, err := strconv.ParseInt(s[1:], 10, 16) + if err != nil { + return 0, fmt.Errorf("invalid version number: %q: %w", s, err) + } + if i < 0 { + return 0, fmt.Errorf("invalid negative version number: %q", s) + } + return int16(i), nil +} + +func dontExpectEOF(err error) error { + switch err { + case nil: + return nil + case io.EOF: + return io.ErrUnexpectedEOF + default: + return err + } +} + +type Broker struct { + Rack string + Host string + Port int32 + ID int32 +} + +func (b Broker) String() string { + return net.JoinHostPort(b.Host, itoa(b.Port)) +} + +func (b Broker) Format(w fmt.State, v rune) { + switch v { + case 'd': + io.WriteString(w, itoa(b.ID)) + case 's': + io.WriteString(w, b.String()) + case 'v': + io.WriteString(w, itoa(b.ID)) + io.WriteString(w, " ") + io.WriteString(w, b.String()) + if b.Rack != "" { + io.WriteString(w, " ") + io.WriteString(w, b.Rack) + } + } +} + +func itoa(i int32) string { + return strconv.Itoa(int(i)) +} + +type Topic struct { + Name string + Error int16 + Partitions map[int32]Partition +} + +type Partition struct { + ID int32 + Error int16 + Leader int32 + Replicas []int32 + ISR []int32 + Offline []int32 +} + +// BrokerMessage is an extension of the Message interface implemented by some +// request types to customize the broker assignment logic. +type BrokerMessage interface { + // Given a representation of the kafka cluster state as argument, returns + // the broker that the message should be routed to. + Broker(Cluster) (Broker, error) +} + +// GroupMessage is an extension of the Message interface implemented by some +// request types to inform the program that they should be routed to a group +// coordinator. +type GroupMessage interface { + // Returns the group configured on the message. 
+ Group() string +} + +// PreparedMessage is an extension of the Message interface implemented by some +// request types which may need to run some pre-processing on their state before +// being sent. +type PreparedMessage interface { + // Prepares the message before being sent to a kafka broker using the API + // version passed as argument. + Prepare(apiVersion int16) +} + +// Splitter is an interface implemented by messages that can be split into +// multiple requests and have their results merged back by a Merger. +type Splitter interface { + // For a given cluster layout, returns the list of messages constructed + // from the receiver for each requests that should be sent to the cluster. + // The second return value is a Merger which can be used to merge back the + // results of each request into a single message (or an error). + Split(Cluster) ([]Message, Merger, error) +} + +// Merger is an interface implemented by messages which can merge multiple +// results into one response. +type Merger interface { + // Given a list of message and associated results, merge them back into a + // response (or an error). The results must be either Message or error + // values, other types should trigger a panic. + Merge(messages []Message, results []interface{}) (Message, error) +} + +// Result converts r to a Message or and error, or panics if r could be be +// converted to these types. +func Result(r interface{}) (Message, error) { + switch v := r.(type) { + case Message: + return v, nil + case error: + return nil, v + default: + panic(fmt.Errorf("BUG: result must be a message or an error but not %T", v)) + } +} diff --git a/protocol/protocol_test.go b/protocol/protocol_test.go new file mode 100644 index 000000000..b0c5f1ae7 --- /dev/null +++ b/protocol/protocol_test.go @@ -0,0 +1,281 @@ +package protocol + +import ( + "bytes" + "reflect" + "testing" +) + +type testType struct { + Field1 string `kafka:"min=v0,max=v4,nullable"` + Field2 int16 `kafka:"min=v2,max=v4"` + Field3 []byte `kafka:"min=v2,max=v4,nullable"` + SubTypes []testSubType `kafka:"min=v1,max=v4"` + + TaggedField1 int8 `kafka:"min=v3,max=v4,tag=0"` + TaggedField2 string `kafka:"min=v4,max=v4,tag=1"` +} + +type testSubType struct { + SubField1 int8 `kafka:"min=v1,max=v4"` +} + +func TestMakeFlexibleTypes(t *testing.T) { + types := makeTypes(reflect.TypeOf(&testType{}).Elem()) + if len(types) != 5 { + t.Error( + "Wrong number of types", + "expected", 5, + "got", len(types), + ) + } + + fv := []int16{} + + for _, to := range types { + if to.flexible { + fv = append(fv, to.version) + } + } + + if !reflect.DeepEqual([]int16{3, 4}, fv) { + t.Error( + "Unexpected flexible versions", + "expected", []int16{3, 4}, + "got", fv, + ) + } +} + +func TestEncodeDecodeFlexibleType(t *testing.T) { + f := &testType{ + Field1: "value1", + Field2: 15, + Field3: []byte("hello"), + SubTypes: []testSubType{ + { + SubField1: 2, + }, + { + SubField1: 3, + }, + }, + + TaggedField1: 34, + TaggedField2: "taggedValue2", + } + + b := &bytes.Buffer{} + e := &encoder{writer: b} + + types := makeTypes(reflect.TypeOf(&testType{}).Elem()) + ft := types[4] + ft.encode(e, valueOf(f)) + if e.err != nil { + t.Error( + "Error during encoding", + "expected", nil, + "got", e.err, + ) + } + + exp := []byte{ + // size of "value1" + 1 + 7, + // "value1" + 118, 97, 108, 117, 101, 49, + // 15 as 16-bit int + 0, 15, + // size of []byte("hello") + 1 + 6, + // []byte("hello") + 104, 101, 108, 108, 111, + // size of []SubTypes + 1 + 3, + // 2 as 8-bit int + 2, + // tag buffer for 
first SubType struct + 0, + // 3 as 8-bit int + 3, + // tag buffer for second SubType struct + 0, + // number of tagged fields + 2, + // id of first tagged field + 0, + // size of first tagged field + 1, + // 34 as 8-bit int + 34, + // id of second tagged field + 1, + // size of second tagged field + 13, + // size of "taggedValue2" + 1 + 13, + // "taggedValue2" + 116, 97, 103, 103, 101, 100, 86, 97, 108, 117, 101, 50, + } + + if !reflect.DeepEqual(exp, b.Bytes()) { + t.Error( + "Wrong encoded output", + "expected", exp, + "got", b.Bytes(), + ) + } + + b = &bytes.Buffer{} + b.Write(exp) + d := &decoder{reader: b, remain: len(exp)} + + f2 := &testType{} + ft.decode(d, valueOf(f2)) + if d.err != nil { + t.Error( + "Error during decoding", + "expected", nil, + "got", e.err, + ) + } + + if !reflect.DeepEqual(f, f2) { + t.Error( + "Decoded value does not equal encoded one", + "expected", *f, + "got", *f2, + ) + } +} + +func TestVarInts(t *testing.T) { + type tc struct { + input int64 + expVarInt []byte + expUVarInt []byte + } + + tcs := []tc{ + { + input: 12, + expVarInt: []byte{24}, + expUVarInt: []byte{12}, + }, + { + input: 63, + expVarInt: []byte{126}, + expUVarInt: []byte{63}, + }, + { + input: -64, + expVarInt: []byte{127}, + expUVarInt: []byte{192, 255, 255, 255, 255, 255, 255, 255, 255, 1}, + }, + { + input: 64, + expVarInt: []byte{128, 1}, + expUVarInt: []byte{64}, + }, + { + input: 127, + expVarInt: []byte{254, 1}, + expUVarInt: []byte{127}, + }, + { + input: 128, + expVarInt: []byte{128, 2}, + expUVarInt: []byte{128, 1}, + }, + { + input: 129, + expVarInt: []byte{130, 2}, + expUVarInt: []byte{129, 1}, + }, + { + input: 12345, + expVarInt: []byte{242, 192, 1}, + expUVarInt: []byte{185, 96}, + }, + { + input: 123456789101112, + expVarInt: []byte{240, 232, 249, 224, 144, 146, 56}, + expUVarInt: []byte{184, 244, 188, 176, 136, 137, 28}, + }, + } + + for _, tc := range tcs { + b := &bytes.Buffer{} + e := &encoder{writer: b} + e.writeVarInt(tc.input) + if e.err != nil { + t.Errorf( + "Unexpected error encoding %d as varInt: %+v", + tc.input, + e.err, + ) + } + if !reflect.DeepEqual(b.Bytes(), tc.expVarInt) { + t.Error( + "Wrong output encoding value", tc.input, "as varInt", + "expected", tc.expVarInt, + "got", b.Bytes(), + ) + } + expLen := sizeOfVarInt(tc.input) + if expLen != len(b.Bytes()) { + t.Error( + "Wrong sizeOf for", tc.input, "as varInt", + "expected", expLen, + "got", len(b.Bytes()), + ) + } + + d := &decoder{reader: b, remain: len(b.Bytes())} + v := d.readVarInt() + if v != tc.input { + t.Error( + "Decoded varInt value does not equal encoded one", + "expected", tc.input, + "got", v, + ) + } + + b = &bytes.Buffer{} + e = &encoder{writer: b} + e.writeUnsignedVarInt(uint64(tc.input)) + if e.err != nil { + t.Errorf( + "Unexpected error encoding %d as unsignedVarInt: %+v", + tc.input, + e.err, + ) + } + if !reflect.DeepEqual(b.Bytes(), tc.expUVarInt) { + t.Error( + "Wrong output encoding value", tc.input, "as unsignedVarInt", + "expected", tc.expUVarInt, + "got", b.Bytes(), + ) + } + expLen = sizeOfUnsignedVarInt(uint64(tc.input)) + if expLen != len(b.Bytes()) { + t.Error( + "Wrong sizeOf for", tc.input, "as unsignedVarInt", + "expected", expLen, + "got", len(b.Bytes()), + ) + } + + d = &decoder{reader: b, remain: len(b.Bytes())} + v = int64(d.readUnsignedVarInt()) + if v != tc.input { + t.Error( + "Decoded unsignedVarInt value does not equal encoded one", + "expected", tc.input, + "got", v, + ) + } + + } +} diff --git a/protocol/prototest/bytes.go b/protocol/prototest/bytes.go new 
file mode 100644 index 000000000..3938e14d6 --- /dev/null +++ b/protocol/prototest/bytes.go @@ -0,0 +1,15 @@ +package prototest + +import ( + "github.com/segmentio/kafka-go/protocol" +) + +// Bytes constructs a Bytes which exposes the content of b. +func Bytes(b []byte) protocol.Bytes { + return protocol.NewBytes(b) +} + +// String constructs a Bytes which exposes the content of s. +func String(s string) protocol.Bytes { + return protocol.NewBytes([]byte(s)) +} diff --git a/protocol/prototest/prototest.go b/protocol/prototest/prototest.go new file mode 100644 index 000000000..9186a3f46 --- /dev/null +++ b/protocol/prototest/prototest.go @@ -0,0 +1,188 @@ +package prototest + +import ( + "bytes" + "io" + "reflect" + "time" + + "github.com/segmentio/kafka-go/protocol" +) + +func deepEqual(x1, x2 interface{}) bool { + if x1 == nil { + return x2 == nil + } + if r1, ok := x1.(protocol.RecordReader); ok { + if r2, ok := x2.(protocol.RecordReader); ok { + return deepEqualRecords(r1, r2) + } + return false + } + if b1, ok := x1.(protocol.Bytes); ok { + if b2, ok := x2.(protocol.Bytes); ok { + return deepEqualBytes(b1, b2) + } + return false + } + if t1, ok := x1.(time.Time); ok { + if t2, ok := x2.(time.Time); ok { + return t1.Equal(t2) + } + return false + } + return deepEqualValue(reflect.ValueOf(x1), reflect.ValueOf(x2)) +} + +func deepEqualValue(v1, v2 reflect.Value) bool { + t1 := v1.Type() + t2 := v2.Type() + + if t1 != t2 { + return false + } + + switch v1.Kind() { + case reflect.Bool: + return v1.Bool() == v2.Bool() + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v1.Int() == v2.Int() + case reflect.String: + return v1.String() == v2.String() + case reflect.Struct: + return deepEqualStruct(v1, v2) + case reflect.Ptr: + return deepEqualPtr(v1, v2) + case reflect.Slice: + return deepEqualSlice(v1, v2) + default: + panic("comparing values of unsupported type: " + v1.Type().String()) + } +} + +func deepEqualPtr(v1, v2 reflect.Value) bool { + if v1.IsNil() { + return v2.IsNil() + } + return deepEqual(v1.Elem().Interface(), v2.Elem().Interface()) +} + +func deepEqualStruct(v1, v2 reflect.Value) bool { + t := v1.Type() + n := t.NumField() + + for i := 0; i < n; i++ { + f := t.Field(i) + + if f.PkgPath != "" { // ignore unexported fields + continue + } + + f1 := v1.Field(i) + f2 := v2.Field(i) + + if !deepEqual(f1.Interface(), f2.Interface()) { + return false + } + } + + return true +} + +func deepEqualSlice(v1, v2 reflect.Value) bool { + t := v1.Type() + e := t.Elem() + + if e.Kind() == reflect.Uint8 { // []byte + return bytes.Equal(v1.Bytes(), v2.Bytes()) + } + + n1 := v1.Len() + n2 := v2.Len() + + if n1 != n2 { + return false + } + + for i := 0; i < n1; i++ { + f1 := v1.Index(i) + f2 := v2.Index(i) + + if !deepEqual(f1.Interface(), f2.Interface()) { + return false + } + } + + return true +} + +func deepEqualBytes(s1, s2 protocol.Bytes) bool { + if s1 == nil { + return s2 == nil + } + + if s2 == nil { + return false + } + + n1 := s1.Len() + n2 := s2.Len() + + if n1 != n2 { + return false + } + + b1 := make([]byte, n1) + b2 := make([]byte, n2) + + if _, err := s1.(io.ReaderAt).ReadAt(b1, 0); err != nil { + panic(err) + } + + if _, err := s2.(io.ReaderAt).ReadAt(b2, 0); err != nil { + panic(err) + } + + return bytes.Equal(b1, b2) +} + +func deepEqualRecords(r1, r2 protocol.RecordReader) bool { + for { + rec1, err1 := r1.ReadRecord() + rec2, err2 := r2.ReadRecord() + + if err1 != nil || err2 != nil { + return err1 == err2 + } + + if !deepEqualRecord(rec1, rec2) { + return 
false + } + } +} + +func deepEqualRecord(r1, r2 *protocol.Record) bool { + if r1.Offset != r2.Offset { + return false + } + + if !r1.Time.Equal(r2.Time) { + return false + } + + if !deepEqualBytes(r1.Key, r2.Key) { + return false + } + + if !deepEqualBytes(r1.Value, r2.Value) { + return false + } + + return deepEqual(r1.Headers, r2.Headers) +} + +func reset(v interface{}) { + if r, _ := v.(interface{ Reset() }); r != nil { + r.Reset() + } +} diff --git a/protocol/prototest/reflect.go b/protocol/prototest/reflect.go new file mode 100644 index 000000000..5c3d0a1d7 --- /dev/null +++ b/protocol/prototest/reflect.go @@ -0,0 +1,142 @@ +package prototest + +import ( + "errors" + "io" + "reflect" + "time" + + "github.com/segmentio/kafka-go/protocol" +) + +var ( + recordReader = reflect.TypeOf((*protocol.RecordReader)(nil)).Elem() +) + +func closeMessage(m protocol.Message) { + forEachField(reflect.ValueOf(m), func(v reflect.Value) { + if v.Type().Implements(recordReader) { + rr := v.Interface().(protocol.RecordReader) + for { + r, err := rr.ReadRecord() + if err != nil { + break + } + if r.Key != nil { + r.Key.Close() + } + if r.Value != nil { + r.Value.Close() + } + } + } + }) +} + +func load(v interface{}) (reset func()) { + return loadValue(reflect.ValueOf(v)) +} + +func loadValue(v reflect.Value) (reset func()) { + resets := []func(){} + + forEachField(v, func(f reflect.Value) { + switch x := f.Interface().(type) { + case protocol.RecordReader: + records := loadRecords(x) + resetFunc := func() { + f.Set(reflect.ValueOf(protocol.NewRecordReader(makeRecords(records)...))) + } + resetFunc() + resets = append(resets, resetFunc) + } + }) + + return func() { + for _, f := range resets { + f() + } + } +} + +func forEachField(v reflect.Value, do func(reflect.Value)) { + for v.Kind() == reflect.Ptr { + if v.IsNil() { + return + } + v = v.Elem() + } + + switch v.Kind() { + case reflect.Slice: + for i, n := 0, v.Len(); i < n; i++ { + forEachField(v.Index(i), do) + } + + case reflect.Struct: + for i, n := 0, v.NumField(); i < n; i++ { + forEachField(v.Field(i), do) + } + + default: + do(v) + } +} + +type memoryRecord struct { + offset int64 + time time.Time + key []byte + value []byte + headers []protocol.Header +} + +func (m *memoryRecord) Record() protocol.Record { + return protocol.Record{ + Offset: m.offset, + Time: m.time, + Key: protocol.NewBytes(m.key), + Value: protocol.NewBytes(m.value), + Headers: m.headers, + } +} + +func makeRecords(memoryRecords []memoryRecord) []protocol.Record { + records := make([]protocol.Record, len(memoryRecords)) + for i, m := range memoryRecords { + records[i] = m.Record() + } + return records +} + +func loadRecords(r protocol.RecordReader) []memoryRecord { + records := []memoryRecord{} + + for { + rec, err := r.ReadRecord() + if err != nil { + if errors.Is(err, io.EOF) { + return records + } + panic(err) + } + records = append(records, memoryRecord{ + offset: rec.Offset, + time: rec.Time, + key: readAll(rec.Key), + value: readAll(rec.Value), + headers: rec.Headers, + }) + } +} + +func readAll(bytes protocol.Bytes) []byte { + if bytes != nil { + defer bytes.Close() + } + b, err := protocol.ReadAll(bytes) + if err != nil { + panic(err) + } + return b +} diff --git a/protocol/prototest/request.go b/protocol/prototest/request.go new file mode 100644 index 000000000..15c6e79c8 --- /dev/null +++ b/protocol/prototest/request.go @@ -0,0 +1,99 @@ +package prototest + +import ( + "bufio" + "bytes" + "encoding/hex" + "fmt" + "io" + "testing" + + 
"github.com/segmentio/kafka-go/protocol" +) + +func TestRequest(t *testing.T, version int16, msg protocol.Message) { + reset := load(msg) + + t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) { + b := &bytes.Buffer{} + + if err := protocol.WriteRequest(b, version, 1234, "me", msg); err != nil { + t.Fatal(err) + } + + reset() + + t.Logf("\n%s\n", hex.Dump(b.Bytes())) + + apiVersion, correlationID, clientID, req, err := protocol.ReadRequest(b) + if err != nil { + t.Fatal(err) + } + if apiVersion != version { + t.Errorf("api version mismatch: %d != %d", apiVersion, version) + } + if correlationID != 1234 { + t.Errorf("correlation id mismatch: %d != %d", correlationID, 1234) + } + if clientID != "me" { + t.Errorf("client id mismatch: %q != %q", clientID, "me") + } + if !deepEqual(msg, req) { + t.Errorf("request message mismatch:") + t.Logf("expected: %+v", msg) + t.Logf("found: %+v", req) + } + }) +} + +func BenchmarkRequest(b *testing.B, version int16, msg protocol.Message) { + reset := load(msg) + + b.Run(fmt.Sprintf("v%d", version), func(b *testing.B) { + buffer := &bytes.Buffer{} + buffer.Grow(1024) + + b.Run("read", func(b *testing.B) { + w := io.Writer(buffer) + + if err := protocol.WriteRequest(w, version, 1234, "client", msg); err != nil { + b.Fatal(err) + } + + reset() + + p := buffer.Bytes() + x := bytes.NewReader(p) + r := bufio.NewReader(x) + + for i := 0; i < b.N; i++ { + _, _, _, req, err := protocol.ReadRequest(r) + if err != nil { + b.Fatal(err) + } + closeMessage(req) + x.Reset(p) + r.Reset(x) + } + + b.SetBytes(int64(len(p))) + buffer.Reset() + }) + + b.Run("write", func(b *testing.B) { + w := io.Writer(buffer) + n := int64(0) + + for i := 0; i < b.N; i++ { + if err := protocol.WriteRequest(w, version, 1234, "client", msg); err != nil { + b.Fatal(err) + } + reset() + n = int64(buffer.Len()) + buffer.Reset() + } + + b.SetBytes(n) + }) + }) +} diff --git a/protocol/prototest/response.go b/protocol/prototest/response.go new file mode 100644 index 000000000..9a66161f1 --- /dev/null +++ b/protocol/prototest/response.go @@ -0,0 +1,95 @@ +package prototest + +import ( + "bufio" + "bytes" + "encoding/hex" + "fmt" + "io" + "testing" + + "github.com/segmentio/kafka-go/protocol" +) + +func TestResponse(t *testing.T, version int16, msg protocol.Message) { + reset := load(msg) + + t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) { + b := &bytes.Buffer{} + + if err := protocol.WriteResponse(b, version, 1234, msg); err != nil { + t.Fatal(err) + } + + reset() + + t.Logf("\n%s", hex.Dump(b.Bytes())) + + correlationID, res, err := protocol.ReadResponse(b, msg.ApiKey(), version) + if err != nil { + t.Fatal(err) + } + if correlationID != 1234 { + t.Errorf("correlation id mismatch: %d != %d", correlationID, 1234) + } + if !deepEqual(msg, res) { + t.Errorf("response message mismatch:") + t.Logf("expected: %+v", msg) + t.Logf("found: %+v", res) + } + closeMessage(res) + }) +} + +func BenchmarkResponse(b *testing.B, version int16, msg protocol.Message) { + reset := load(msg) + + b.Run(fmt.Sprintf("v%d", version), func(b *testing.B) { + apiKey := msg.ApiKey() + buffer := &bytes.Buffer{} + buffer.Grow(1024) + + b.Run("read", func(b *testing.B) { + w := io.Writer(buffer) + + if err := protocol.WriteResponse(w, version, 1234, msg); err != nil { + b.Fatal(err) + } + + reset() + + p := buffer.Bytes() + x := bytes.NewReader(p) + r := bufio.NewReader(x) + + for i := 0; i < b.N; i++ { + _, res, err := protocol.ReadResponse(r, apiKey, version) + if err != nil { + b.Fatal(err) + } + closeMessage(res) 
+ x.Reset(p) + r.Reset(x) + } + + b.SetBytes(int64(len(p))) + buffer.Reset() + }) + + b.Run("write", func(b *testing.B) { + w := io.Writer(buffer) + n := int64(0) + + for i := 0; i < b.N; i++ { + if err := protocol.WriteResponse(w, version, 1234, msg); err != nil { + b.Fatal(err) + } + reset() + n = int64(buffer.Len()) + buffer.Reset() + } + + b.SetBytes(n) + }) + }) +} diff --git a/protocol/record.go b/protocol/record.go new file mode 100644 index 000000000..84594868b --- /dev/null +++ b/protocol/record.go @@ -0,0 +1,314 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "time" + + "github.com/segmentio/kafka-go/compress" +) + +// Attributes is a bitset representing special attributes set on records. +type Attributes int16 + +const ( + Gzip Attributes = Attributes(compress.Gzip) // 1 + Snappy Attributes = Attributes(compress.Snappy) // 2 + Lz4 Attributes = Attributes(compress.Lz4) // 3 + Zstd Attributes = Attributes(compress.Zstd) // 4 + Transactional Attributes = 1 << 4 + Control Attributes = 1 << 5 +) + +func (a Attributes) Compression() compress.Compression { + return compress.Compression(a & 7) +} + +func (a Attributes) Transactional() bool { + return (a & Transactional) != 0 +} + +func (a Attributes) Control() bool { + return (a & Control) != 0 +} + +func (a Attributes) String() string { + s := a.Compression().String() + if a.Transactional() { + s += "+transactional" + } + if a.Control() { + s += "+control" + } + return s +} + +// Header represents a single entry in a list of record headers. +type Header struct { + Key string + Value []byte +} + +// Record is an interface representing a single kafka record. +// +// Record values are not safe to use concurrently from multiple goroutines. +type Record struct { + // The offset at which the record exists in a topic partition. This value + // is ignored in produce requests. + Offset int64 + + // Returns the time of the record. This value may be omitted in produce + // requests to let kafka set the time when it saves the record. + Time time.Time + + // Returns a byte sequence containing the key of this record. The returned + // sequence may be nil to indicate that the record has no key. If the record + // is part of a RecordSet, the content of the key must remain valid at least + // until the record set is closed (or until the key is closed). + Key Bytes + + // Returns a byte sequence containing the value of this record. The returned + // sequence may be nil to indicate that the record has no value. If the + // record is part of a RecordSet, the content of the value must remain valid + // at least until the record set is closed (or until the value is closed). + Value Bytes + + // Returns the list of headers associated with this record. The returned + // slice may be reused across calls, the program should use it as an + // immutable value. + Headers []Header +} + +// RecordSet represents a sequence of records in Produce requests and Fetch +// responses. All v0, v1, and v2 formats are supported. +type RecordSet struct { + // The message version that this record set will be represented as, valid + // values are 1, or 2. + // + // When reading, this is the value of the highest version used in the + // batches that compose the record set. + // + // When writing, this value dictates the format that the records will be + // encoded in. + Version int8 + + // Attributes set on the record set. + // + // When reading, the attributes are the combination of all attributes in + // the batches that compose the record set. 
+ // + // When writing, the attributes apply to the whole sequence of records in + // the set. + Attributes Attributes + + // A reader exposing the sequence of records. + // + // When reading a RecordSet from an io.Reader, the Records field will be a + // *RecordStream. If the program needs to access the details of each batch + // that composes the stream, it may use type assertions to access the + // underlying types of each batch. + Records RecordReader +} + +// bufferedReader is an interface implemented by types like bufio.Reader, which +// we use to optimize prefix reads by accessing the internal buffer directly +// through calls to Peek. +type bufferedReader interface { + Discard(int) (int, error) + Peek(int) ([]byte, error) +} + +// bytesBuffer is an interface implemented by types like bytes.Buffer, which we +// use to optimize prefix reads by accessing the internal buffer directly +// through calls to Bytes. +type bytesBuffer interface { + Bytes() []byte +} + +// magicByteOffset is the position of the magic byte in all versions of record +// sets in the kafka protocol. +const magicByteOffset = 16 + +// ReadFrom reads the representation of a record set from r into rs, returning +// the number of bytes consumed from r, and a non-nil error if the record set +// could not be read. +func (rs *RecordSet) ReadFrom(r io.Reader) (int64, error) { + d, _ := r.(*decoder) + if d == nil { + d = &decoder{ + reader: r, + remain: 4, + } + } + + *rs = RecordSet{} + limit := d.remain + size := d.readInt32() + + if d.err != nil { + return int64(limit - d.remain), d.err + } + + if size <= 0 { + return 4, nil + } + + stream := &RecordStream{ + Records: make([]RecordReader, 0, 4), + } + + var err error + d.remain = int(size) + + for d.remain > 0 && err == nil { + var version byte + + if d.remain < (magicByteOffset + 1) { + if len(stream.Records) != 0 { + break + } + return 4, fmt.Errorf("impossible record set shorter than %d bytes", magicByteOffset+1) + } + + switch r := d.reader.(type) { + case bufferedReader: + b, err := r.Peek(magicByteOffset + 1) + if err != nil { + n, _ := r.Discard(len(b)) + return 4 + int64(n), dontExpectEOF(err) + } + version = b[magicByteOffset] + case bytesBuffer: + version = r.Bytes()[magicByteOffset] + default: + b := make([]byte, magicByteOffset+1) + if n, err := io.ReadFull(d.reader, b); err != nil { + return 4 + int64(n), dontExpectEOF(err) + } + version = b[magicByteOffset] + // Reconstruct the prefix that we had to read to determine the version + // of the record set from the magic byte. + // + // Technically this may recursively stack readers when consuming all + // items of the batch, which could hurt performance. In practice this + // path should not be taken though, since the decoder would read from a + // *bufio.Reader which implements the bufferedReader interface. + d.reader = io.MultiReader(bytes.NewReader(b), d.reader) + } + + var tmp RecordSet + switch version { + case 0, 1: + err = tmp.readFromVersion1(d) + case 2: + err = tmp.readFromVersion2(d) + default: + err = fmt.Errorf("unsupported message version %d for message of size %d", version, size) + } + + if tmp.Version > rs.Version { + rs.Version = tmp.Version + } + + rs.Attributes |= tmp.Attributes + + if tmp.Records != nil { + stream.Records = append(stream.Records, tmp.Records) + } + } + + if len(stream.Records) != 0 { + rs.Records = stream + // Ignore errors if we've successfully read records, so the + // program can keep making progress.
+ err = nil + } + + d.discardAll() + rn := 4 + (int(size) - d.remain) + d.remain = limit - rn + return int64(rn), err +} + +// WriteTo writes the representation of rs into w. The value of rs.Version +// dictates which format that the record set will be represented as. +// +// The error will be ErrNoRecord if rs contained no records. +// +// Note: since this package is only compatible with kafka 0.10 and above, the +// method never produces messages in version 0. If rs.Version is zero, the +// method defaults to producing messages in version 1. +func (rs *RecordSet) WriteTo(w io.Writer) (int64, error) { + if rs.Records == nil { + return 0, ErrNoRecord + } + + // This optimization avoids rendering the record set in an intermediary + // buffer when the writer is already a pageBuffer, which is a common case + // due to the way WriteRequest and WriteResponse are implemented. + buffer, _ := w.(*pageBuffer) + bufferOffset := int64(0) + + if buffer != nil { + bufferOffset = buffer.Size() + } else { + buffer = newPageBuffer() + defer buffer.unref() + } + + size := packUint32(0) + buffer.Write(size[:]) // size placeholder + + var err error + switch rs.Version { + case 0, 1: + err = rs.writeToVersion1(buffer, bufferOffset+4) + case 2: + err = rs.writeToVersion2(buffer, bufferOffset+4) + default: + err = fmt.Errorf("unsupported record set version %d", rs.Version) + } + if err != nil { + return 0, err + } + + n := buffer.Size() - bufferOffset + if n == 0 { + size = packUint32(^uint32(0)) + } else { + size = packUint32(uint32(n) - 4) + } + buffer.WriteAt(size[:], bufferOffset) + + // This condition indicates that the output writer received by `WriteTo` was + // not a *pageBuffer, in which case we need to flush the buffered records + // data into it. + if buffer != w { + return buffer.WriteTo(w) + } + + return n, nil +} + +func makeTime(t int64) time.Time { + return time.Unix(t/1000, (t%1000)*int64(time.Millisecond)) +} + +func timestamp(t time.Time) int64 { + if t.IsZero() { + return 0 + } + return t.UnixNano() / int64(time.Millisecond) +} + +func packUint32(u uint32) (b [4]byte) { + binary.BigEndian.PutUint32(b[:], u) + return +} + +func packUint64(u uint64) (b [8]byte) { + binary.BigEndian.PutUint64(b[:], u) + return +} diff --git a/protocol/record_batch.go b/protocol/record_batch.go new file mode 100644 index 000000000..af9ddc5ca --- /dev/null +++ b/protocol/record_batch.go @@ -0,0 +1,369 @@ +package protocol + +import ( + "errors" + "io" + "time" +) + +// RecordReader is an interface representing a sequence of records. Record sets +// are used in both produce and fetch requests to represent the sequence of +// records that are sent to or receive from kafka brokers. +// +// RecordSet values are not safe to use concurrently from multiple goroutines. +type RecordReader interface { + // Returns the next record in the set, or io.EOF if the end of the sequence + // has been reached. + // + // The returned Record is guaranteed to be valid until the next call to + // ReadRecord. If the program needs to retain the Record value it must make + // a copy. + ReadRecord() (*Record, error) +} + +// NewRecordReader constructs a reader exposing the records passed as arguments. +func NewRecordReader(records ...Record) RecordReader { + switch len(records) { + case 0: + return emptyRecordReader{} + default: + r := &recordReader{records: make([]Record, len(records))} + copy(r.records, records) + return r + } +} + +// MultiRecordReader merges multiple record batches into one. 
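
To make the ReadFrom/WriteTo pair above concrete, here is a rough usage sketch (not part of the patch) that serializes a v2 record set into an in-memory buffer and reads it back; it relies only on the exported protocol API introduced in this change.

    package main

    import (
        "bytes"
        "fmt"
        "time"

        "github.com/segmentio/kafka-go/protocol"
    )

    func main() {
        in := protocol.RecordSet{
            Version: 2,
            Records: protocol.NewRecordReader(protocol.Record{
                Time:  time.Now(),
                Key:   protocol.NewBytes([]byte("key")),
                Value: protocol.NewBytes([]byte("value")),
            }),
        }

        buf := new(bytes.Buffer)
        if _, err := in.WriteTo(buf); err != nil { // prepends the 4-byte size
            panic(err)
        }

        out := protocol.RecordSet{}
        if _, err := out.ReadFrom(buf); err != nil { // detects the version from the magic byte
            panic(err)
        }

        rec, _ := out.Records.ReadRecord()
        v, _ := protocol.ReadAll(rec.Value)
        fmt.Printf("offset=%d value=%q\n", rec.Offset, v)
    }
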
+func MultiRecordReader(batches ...RecordReader) RecordReader { + switch len(batches) { + case 0: + return emptyRecordReader{} + case 1: + return batches[0] + default: + m := &multiRecordReader{batches: make([]RecordReader, len(batches))} + copy(m.batches, batches) + return m + } +} + +func forEachRecord(r RecordReader, f func(int, *Record) error) error { + for i := 0; ; i++ { + rec, err := r.ReadRecord() + + if err != nil { + if errors.Is(err, io.EOF) { + err = nil + } + return err + } + + if err := handleRecord(i, rec, f); err != nil { + return err + } + } +} + +func handleRecord(i int, r *Record, f func(int, *Record) error) error { + if r.Key != nil { + defer r.Key.Close() + } + if r.Value != nil { + defer r.Value.Close() + } + return f(i, r) +} + +type recordReader struct { + records []Record + index int +} + +func (r *recordReader) ReadRecord() (*Record, error) { + if i := r.index; i >= 0 && i < len(r.records) { + r.index++ + return &r.records[i], nil + } + return nil, io.EOF +} + +type multiRecordReader struct { + batches []RecordReader + index int +} + +func (m *multiRecordReader) ReadRecord() (*Record, error) { + for { + if m.index == len(m.batches) { + return nil, io.EOF + } + r, err := m.batches[m.index].ReadRecord() + if err == nil { + return r, nil + } + if !errors.Is(err, io.EOF) { + return nil, err + } + m.index++ + } +} + +func concatRecordReader(head RecordReader, tail RecordReader) RecordReader { + if head == nil { + return tail + } + if m, _ := head.(*multiRecordReader); m != nil { + m.batches = append(m.batches, tail) + return m + } + return MultiRecordReader(head, tail) +} + +// optimizedRecordReader is an implementation of a RecordReader which exposes a +// sequence +type optimizedRecordReader struct { + records []optimizedRecord + index int + buffer Record + headers [][]Header +} + +func (r *optimizedRecordReader) ReadRecord() (*Record, error) { + if i := r.index; i >= 0 && i < len(r.records) { + rec := &r.records[i] + r.index++ + r.buffer = Record{ + Offset: rec.offset, + Time: rec.time(), + Key: rec.key(), + Value: rec.value(), + } + if i < len(r.headers) { + r.buffer.Headers = r.headers[i] + } + return &r.buffer, nil + } + return nil, io.EOF +} + +type optimizedRecord struct { + offset int64 + timestamp int64 + keyRef *pageRef + valueRef *pageRef +} + +func (r *optimizedRecord) time() time.Time { + return makeTime(r.timestamp) +} + +func (r *optimizedRecord) key() Bytes { + return makeBytes(r.keyRef) +} + +func (r *optimizedRecord) value() Bytes { + return makeBytes(r.valueRef) +} + +func makeBytes(ref *pageRef) Bytes { + if ref == nil { + return nil + } + return ref +} + +type emptyRecordReader struct{} + +func (emptyRecordReader) ReadRecord() (*Record, error) { return nil, io.EOF } + +// ControlRecord represents a record read from a control batch. 
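
A short sketch of how the reader types above are meant to be consumed: MultiRecordReader stitches batches together, ReadRecord signals the end of the sequence with io.EOF, and the returned *Record is only valid until the next call, so it must be copied if retained. Illustrative only, not code from the patch.

    package main

    import (
        "errors"
        "fmt"
        "io"

        "github.com/segmentio/kafka-go/protocol"
    )

    func main() {
        r := protocol.MultiRecordReader(
            protocol.NewRecordReader(protocol.Record{Offset: 1}),
            protocol.NewRecordReader(protocol.Record{Offset: 2}, protocol.Record{Offset: 3}),
        )

        for {
            rec, err := r.ReadRecord()
            if err != nil {
                if !errors.Is(err, io.EOF) {
                    panic(err)
                }
                break // io.EOF marks the end of the merged sequence
            }
            fmt.Println(rec.Offset) // 1, 2, 3
        }
    }
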
+type ControlRecord struct { + Offset int64 + Time time.Time + Version int16 + Type int16 + Data []byte + Headers []Header +} + +func ReadControlRecord(r *Record) (*ControlRecord, error) { + if r.Key != nil { + defer r.Key.Close() + } + if r.Value != nil { + defer r.Value.Close() + } + + k, err := ReadAll(r.Key) + if err != nil { + return nil, err + } + if k == nil { + return nil, Error("invalid control record with nil key") + } + if len(k) != 4 { + return nil, Errorf("invalid control record with key of size %d", len(k)) + } + + v, err := ReadAll(r.Value) + if err != nil { + return nil, err + } + + c := &ControlRecord{ + Offset: r.Offset, + Time: r.Time, + Version: readInt16(k[:2]), + Type: readInt16(k[2:]), + Data: v, + Headers: r.Headers, + } + + return c, nil +} + +func (cr *ControlRecord) Key() Bytes { + k := make([]byte, 4) + writeInt16(k[:2], cr.Version) + writeInt16(k[2:], cr.Type) + return NewBytes(k) +} + +func (cr *ControlRecord) Value() Bytes { + return NewBytes(cr.Data) +} + +func (cr *ControlRecord) Record() Record { + return Record{ + Offset: cr.Offset, + Time: cr.Time, + Key: cr.Key(), + Value: cr.Value(), + Headers: cr.Headers, + } +} + +// ControlBatch is an implementation of the RecordReader interface representing +// control batches returned by kafka brokers. +type ControlBatch struct { + Attributes Attributes + PartitionLeaderEpoch int32 + BaseOffset int64 + ProducerID int64 + ProducerEpoch int16 + BaseSequence int32 + Records RecordReader +} + +// NewControlBatch constructs a control batch from the list of records passed as +// arguments. +func NewControlBatch(records ...ControlRecord) *ControlBatch { + rawRecords := make([]Record, len(records)) + for i, cr := range records { + rawRecords[i] = cr.Record() + } + return &ControlBatch{ + Records: NewRecordReader(rawRecords...), + } +} + +func (c *ControlBatch) ReadRecord() (*Record, error) { + return c.Records.ReadRecord() +} + +func (c *ControlBatch) ReadControlRecord() (*ControlRecord, error) { + r, err := c.ReadRecord() + if err != nil { + return nil, err + } + if r.Key != nil { + defer r.Key.Close() + } + if r.Value != nil { + defer r.Value.Close() + } + return ReadControlRecord(r) +} + +func (c *ControlBatch) Offset() int64 { + return c.BaseOffset +} + +func (c *ControlBatch) Version() int { + return 2 +} + +// RecordBatch is an implementation of the RecordReader interface representing +// regular record batches (v2). +type RecordBatch struct { + Attributes Attributes + PartitionLeaderEpoch int32 + BaseOffset int64 + ProducerID int64 + ProducerEpoch int16 + BaseSequence int32 + Records RecordReader +} + +func (r *RecordBatch) ReadRecord() (*Record, error) { + return r.Records.ReadRecord() +} + +func (r *RecordBatch) Offset() int64 { + return r.BaseOffset +} + +func (r *RecordBatch) Version() int { + return 2 +} + +// MessageSet is an implementation of the RecordReader interface representing +// regular message sets (v1). +type MessageSet struct { + Attributes Attributes + BaseOffset int64 + Records RecordReader +} + +func (m *MessageSet) ReadRecord() (*Record, error) { + return m.Records.ReadRecord() +} + +func (m *MessageSet) Offset() int64 { + return m.BaseOffset +} + +func (m *MessageSet) Version() int { + return 1 +} + +// RecordStream is an implementation of the RecordReader interface which +// combines multiple underlying RecordReader and only expose records that +// are not from control batches. 
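
The control-record helpers above pack the Version and Type fields into a 4-byte big-endian key; the sketch below (not from the patch) round-trips one through Record and ReadControlRecord. Note that RecordStream, defined next, skips ControlBatch readers so that plain consumers only ever see data records.

    package main

    import (
        "fmt"

        "github.com/segmentio/kafka-go/protocol"
    )

    func main() {
        cr := protocol.ControlRecord{Offset: 10, Version: 0, Type: 1}

        rec := cr.Record() // key = int16 version followed by int16 type
        back, err := protocol.ReadControlRecord(&rec)
        if err != nil {
            panic(err)
        }
        fmt.Println(back.Version, back.Type) // 0 1
    }
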
+type RecordStream struct { + Records []RecordReader + index int +} + +func (s *RecordStream) ReadRecord() (*Record, error) { + for { + if s.index < 0 || s.index >= len(s.Records) { + return nil, io.EOF + } + + if _, isControl := s.Records[s.index].(*ControlBatch); isControl { + s.index++ + continue + } + + r, err := s.Records[s.index].ReadRecord() + if err != nil { + if errors.Is(err, io.EOF) { + s.index++ + continue + } + } + + return r, err + } +} diff --git a/protocol/record_batch_test.go b/protocol/record_batch_test.go new file mode 100644 index 000000000..e48e568ea --- /dev/null +++ b/protocol/record_batch_test.go @@ -0,0 +1,200 @@ +package protocol + +import ( + "errors" + "io" + "reflect" + "testing" + "time" +) + +type memoryRecord struct { + offset int64 + time time.Time + key []byte + value []byte + headers []Header +} + +func (m *memoryRecord) Record() Record { + return Record{ + Offset: m.offset, + Time: m.time, + Key: NewBytes(m.key), + Value: NewBytes(m.value), + Headers: m.headers, + } +} + +func makeRecords(memoryRecords []memoryRecord) []Record { + records := make([]Record, len(memoryRecords)) + for i, m := range memoryRecords { + records[i] = m.Record() + } + return records +} + +func TestRecordReader(t *testing.T) { + now := time.Now() + + records := []memoryRecord{ + { + offset: 1, + time: now, + key: []byte("key-1"), + }, + { + offset: 2, + time: now.Add(time.Millisecond), + value: []byte("value-1"), + }, + { + offset: 3, + time: now.Add(time.Second), + key: []byte("key-3"), + value: []byte("value-3"), + headers: []Header{ + {Key: "answer", Value: []byte("42")}, + }, + }, + } + + r1 := NewRecordReader(makeRecords(records)...) + r2 := NewRecordReader(makeRecords(records)...) + assertRecords(t, r1, r2) +} + +func TestMultiRecordReader(t *testing.T) { + now := time.Now() + + records := []memoryRecord{ + { + offset: 1, + time: now, + key: []byte("key-1"), + }, + { + offset: 2, + time: now.Add(time.Millisecond), + value: []byte("value-1"), + }, + { + offset: 3, + time: now.Add(time.Second), + key: []byte("key-3"), + value: []byte("value-3"), + headers: []Header{ + {Key: "answer", Value: []byte("42")}, + }, + }, + } + + r1 := NewRecordReader(makeRecords(records)...) + r2 := MultiRecordReader( + NewRecordReader(makeRecords(records[:1])...), + NewRecordReader(makeRecords(records[1:])...), + ) + assertRecords(t, r1, r2) +} + +func TestControlRecord(t *testing.T) { + now := time.Now() + + records := []ControlRecord{ + { + Offset: 1, + Time: now, + Version: 2, + Type: 3, + }, + { + Offset: 2, + Time: now.Add(time.Second), + Version: 4, + Type: 5, + Data: []byte("Hello World!"), + Headers: []Header{ + {Key: "answer", Value: []byte("42")}, + }, + }, + } + + batch := NewControlBatch(records...) 
+ found := make([]ControlRecord, 0, len(records)) + + for { + r, err := batch.ReadControlRecord() + if err != nil { + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + break + } + found = append(found, *r) + } + + if !reflect.DeepEqual(records, found) { + t.Error("control records mismatch") + } +} + +func assertRecords(t *testing.T, r1, r2 RecordReader) { + t.Helper() + + for { + rec1, err1 := r1.ReadRecord() + rec2, err2 := r2.ReadRecord() + + if err1 != nil || err2 != nil { + if err1 != err2 { + t.Error("errors mismatch:") + t.Log("expected:", err2) + t.Log("found: ", err1) + } + return + } + + if !equalRecords(rec1, rec2) { + t.Error("records mismatch:") + t.Logf("expected: %+v", rec2) + t.Logf("found: %+v", rec1) + } + } +} + +func equalRecords(r1, r2 *Record) bool { + if r1.Offset != r2.Offset { + return false + } + + if !r1.Time.Equal(r2.Time) { + return false + } + + k1 := readAll(r1.Key) + k2 := readAll(r2.Key) + + if !reflect.DeepEqual(k1, k2) { + return false + } + + v1 := readAll(r1.Value) + v2 := readAll(r2.Value) + + if !reflect.DeepEqual(v1, v2) { + return false + } + + return reflect.DeepEqual(r1.Headers, r2.Headers) +} + +func readAll(bytes Bytes) []byte { + if bytes != nil { + defer bytes.Close() + } + b, err := ReadAll(bytes) + if err != nil { + panic(err) + } + return b +} diff --git a/protocol/record_v1.go b/protocol/record_v1.go new file mode 100644 index 000000000..5757e1146 --- /dev/null +++ b/protocol/record_v1.go @@ -0,0 +1,243 @@ +package protocol + +import ( + "errors" + "hash/crc32" + "io" + "math" + "time" +) + +func readMessage(b *pageBuffer, d *decoder) (attributes int8, baseOffset, timestamp int64, key, value Bytes, err error) { + md := decoder{ + reader: d, + remain: 12, + } + + baseOffset = md.readInt64() + md.remain = int(md.readInt32()) + + crc := uint32(md.readInt32()) + md.setCRC(crc32.IEEETable) + magicByte := md.readInt8() + attributes = md.readInt8() + timestamp = int64(0) + + if magicByte != 0 { + timestamp = md.readInt64() + } + + keyOffset := b.Size() + keyLength := int(md.readInt32()) + hasKey := keyLength >= 0 + if hasKey { + md.writeTo(b, keyLength) + key = b.ref(keyOffset, b.Size()) + } + + valueOffset := b.Size() + valueLength := int(md.readInt32()) + hasValue := valueLength >= 0 + if hasValue { + md.writeTo(b, valueLength) + value = b.ref(valueOffset, b.Size()) + } + + if md.crc32 != crc { + err = Errorf("crc32 checksum mismatch (computed=%d found=%d)", md.crc32, crc) + } else { + err = dontExpectEOF(md.err) + } + + return +} + +func (rs *RecordSet) readFromVersion1(d *decoder) error { + var records RecordReader + + b := newPageBuffer() + defer b.unref() + + attributes, baseOffset, timestamp, key, value, err := readMessage(b, d) + if err != nil { + return err + } + + if compression := Attributes(attributes).Compression(); compression == 0 { + records = &message{ + Record: Record{ + Offset: baseOffset, + Time: makeTime(timestamp), + Key: key, + Value: value, + }, + } + } else { + // Can we have a non-nil key when reading a compressed message? 
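
For readers unfamiliar with the legacy format parsed by readMessage above, the sketch below assembles one magic-byte-1 message entry using only the standard library. The field order and the fact that the CRC32 (IEEE) covers everything after the crc field mirror the decoder, but the helper itself is illustrative and not part of the patch.

    package main

    import (
        "bytes"
        "encoding/binary"
        "fmt"
        "hash/crc32"
    )

    // buildV1Message lays out: offset, message size, crc32, magic byte,
    // attributes, timestamp, key and value. A null key or value would be
    // encoded as length -1 instead of the lengths written here.
    func buildV1Message(offset, timestamp int64, key, value []byte) []byte {
        body := new(bytes.Buffer)
        body.WriteByte(1) // magic byte: version 1
        body.WriteByte(0) // attributes: no compression
        binary.Write(body, binary.BigEndian, timestamp)
        binary.Write(body, binary.BigEndian, int32(len(key)))
        body.Write(key)
        binary.Write(body, binary.BigEndian, int32(len(value)))
        body.Write(value)

        msg := new(bytes.Buffer)
        binary.Write(msg, binary.BigEndian, offset)
        binary.Write(msg, binary.BigEndian, int32(4+body.Len()))              // size = crc + body
        binary.Write(msg, binary.BigEndian, crc32.ChecksumIEEE(body.Bytes())) // crc over magic..value
        msg.Write(body.Bytes())
        return msg.Bytes()
    }

    func main() {
        b := buildV1Message(0, 1600000000000, []byte("k"), []byte("v"))
        fmt.Println(len(b), "bytes")
    }
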
+ if key != nil { + key.Close() + } + if value == nil { + records = emptyRecordReader{} + } else { + defer value.Close() + + codec := compression.Codec() + if codec == nil { + return Errorf("unsupported compression codec: %d", compression) + } + decompressor := codec.NewReader(value) + defer decompressor.Close() + + b := newPageBuffer() + defer b.unref() + + d := &decoder{ + reader: decompressor, + remain: math.MaxInt32, + } + + r := &recordReader{ + records: make([]Record, 0, 32), + } + + for !d.done() { + _, offset, timestamp, key, value, err := readMessage(b, d) + if err != nil { + if errors.Is(err, io.ErrUnexpectedEOF) { + break + } + for _, rec := range r.records { + closeBytes(rec.Key) + closeBytes(rec.Value) + } + return err + } + r.records = append(r.records, Record{ + Offset: offset, + Time: makeTime(timestamp), + Key: key, + Value: value, + }) + } + + if baseOffset != 0 { + // https://kafka.apache.org/documentation/#messageset + // + // In version 1, to avoid server side re-compression, only the + // wrapper message will be assigned an offset. The inner messages + // will have relative offsets. The absolute offset can be computed + // using the offset from the outer message, which corresponds to the + // offset assigned to the last inner message. + lastRelativeOffset := int64(len(r.records)) - 1 + + for i := range r.records { + r.records[i].Offset = baseOffset - (lastRelativeOffset - r.records[i].Offset) + } + } + + records = r + } + } + + *rs = RecordSet{ + Version: 1, + Attributes: Attributes(attributes), + Records: records, + } + + return nil +} + +func (rs *RecordSet) writeToVersion1(buffer *pageBuffer, bufferOffset int64) error { + attributes := rs.Attributes + records := rs.Records + + if compression := attributes.Compression(); compression != 0 { + if codec := compression.Codec(); codec != nil { + // In the message format version 1, compression is achieved by + // compressing the value of a message which recursively contains + // the representation of the compressed message set. 
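
The relative-offset rule quoted above is easy to get backwards, so here is a small worked example (not from the patch): the compressed wrapper carries the offset assigned to the last inner message, and the inner messages store offsets relative to the first one.

    package main

    import "fmt"

    func main() {
        baseOffset := int64(42)          // offset assigned to the compressed wrapper message
        relative := []int64{0, 1, 2}     // offsets stored on the inner messages
        last := int64(len(relative)) - 1 // 2, matching lastRelativeOffset above

        for _, rel := range relative {
            fmt.Println(baseOffset - (last - rel)) // 40, 41, 42
        }
    }
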
+ subset := *rs + subset.Attributes &= ^7 // erase compression + + if err := subset.writeToVersion1(buffer, bufferOffset); err != nil { + return err + } + + compressed := newPageBuffer() + defer compressed.unref() + + compressor := codec.NewWriter(compressed) + defer compressor.Close() + + var err error + buffer.pages.scan(bufferOffset, buffer.Size(), func(b []byte) bool { + _, err = compressor.Write(b) + return err == nil + }) + if err != nil { + return err + } + if err := compressor.Close(); err != nil { + return err + } + + buffer.Truncate(int(bufferOffset)) + + records = &message{ + Record: Record{ + Value: compressed, + }, + } + } + } + + e := encoder{writer: buffer} + currentTimestamp := timestamp(time.Now()) + + return forEachRecord(records, func(i int, r *Record) error { + t := timestamp(r.Time) + if t == 0 { + t = currentTimestamp + } + + messageOffset := buffer.Size() + e.writeInt64(int64(i)) + e.writeInt32(0) // message size placeholder + e.writeInt32(0) // crc32 placeholder + e.setCRC(crc32.IEEETable) + e.writeInt8(1) // magic byte: version 1 + e.writeInt8(int8(attributes)) + e.writeInt64(t) + + if err := e.writeNullBytesFrom(r.Key); err != nil { + return err + } + + if err := e.writeNullBytesFrom(r.Value); err != nil { + return err + } + + b0 := packUint32(uint32(buffer.Size() - (messageOffset + 12))) + b1 := packUint32(e.crc32) + + buffer.WriteAt(b0[:], messageOffset+8) + buffer.WriteAt(b1[:], messageOffset+12) + e.setCRC(nil) + return nil + }) +} + +type message struct { + Record Record + read bool +} + +func (m *message) ReadRecord() (*Record, error) { + if m.read { + return nil, io.EOF + } + m.read = true + return &m.Record, nil +} diff --git a/protocol/record_v2.go b/protocol/record_v2.go new file mode 100644 index 000000000..366ec4bff --- /dev/null +++ b/protocol/record_v2.go @@ -0,0 +1,315 @@ +package protocol + +import ( + "fmt" + "hash/crc32" + "io" + "time" +) + +func (rs *RecordSet) readFromVersion2(d *decoder) error { + baseOffset := d.readInt64() + batchLength := d.readInt32() + + if int(batchLength) > d.remain || d.err != nil { + d.discardAll() + return nil + } + + dec := &decoder{ + reader: d, + remain: int(batchLength), + } + + partitionLeaderEpoch := dec.readInt32() + magicByte := dec.readInt8() + crc := dec.readInt32() + + dec.setCRC(crc32.MakeTable(crc32.Castagnoli)) + + attributes := dec.readInt16() + lastOffsetDelta := dec.readInt32() + firstTimestamp := dec.readInt64() + maxTimestamp := dec.readInt64() + producerID := dec.readInt64() + producerEpoch := dec.readInt16() + baseSequence := dec.readInt32() + numRecords := dec.readInt32() + reader := io.Reader(dec) + + // unused + _ = lastOffsetDelta + _ = maxTimestamp + + if compression := Attributes(attributes).Compression(); compression != 0 { + codec := compression.Codec() + if codec == nil { + return fmt.Errorf("unsupported compression codec (%d)", compression) + } + decompressor := codec.NewReader(reader) + defer decompressor.Close() + reader = decompressor + } + + buffer := newPageBuffer() + defer buffer.unref() + + _, err := buffer.ReadFrom(reader) + if err != nil { + return err + } + if dec.crc32 != uint32(crc) { + return fmt.Errorf("crc32 checksum mismatch (computed=%d found=%d)", dec.crc32, uint32(crc)) + } + + recordsLength := buffer.Len() + dec.reader = buffer + dec.remain = recordsLength + + records := make([]optimizedRecord, numRecords) + // These are two lazy allocators that will be used to optimize allocation of + // page references for keys and values. 
+ // + // By default, no memory is allocated and on first use, numRecords page refs + // are allocated in a contiguous memory space, and the allocators return + // pointers into those arrays for each page ref that get requested. + // + // The reasoning is that kafka partitions typically have records of a single + // form, which either have no keys, no values, or both keys and values. + // Using lazy allocators adapts nicely to these patterns to only allocate + // the memory that is needed by the program, while still reducing the number + // of malloc calls made by the program. + // + // Using a single allocator for both keys and values keeps related values + // close by in memory, making access to the records more friendly to CPU + // caches. + alloc := pageRefAllocator{size: int(numRecords)} + // Following the same reasoning that kafka partitions will typically have + // records with repeating formats, we expect to either find records with + // no headers, or records which always contain headers. + // + // To reduce the memory footprint when records have no headers, the Header + // slices are lazily allocated in a separate array. + headers := ([][]Header)(nil) + + for i := range records { + r := &records[i] + _ = dec.readVarInt() // record length (unused) + _ = dec.readInt8() // record attributes (unused) + timestampDelta := dec.readVarInt() + offsetDelta := dec.readVarInt() + + r.offset = baseOffset + offsetDelta + r.timestamp = firstTimestamp + timestampDelta + + keyLength := dec.readVarInt() + keyOffset := int64(recordsLength - dec.remain) + if keyLength > 0 { + dec.discard(int(keyLength)) + } + + valueLength := dec.readVarInt() + valueOffset := int64(recordsLength - dec.remain) + if valueLength > 0 { + dec.discard(int(valueLength)) + } + + if numHeaders := dec.readVarInt(); numHeaders > 0 { + if headers == nil { + headers = make([][]Header, numRecords) + } + + h := make([]Header, numHeaders) + + for i := range h { + h[i] = Header{ + Key: dec.readVarString(), + Value: dec.readVarBytes(), + } + } + + headers[i] = h + } + + if dec.err != nil { + records = records[:i] + break + } + + if keyLength >= 0 { + r.keyRef = alloc.newPageRef() + buffer.refTo(r.keyRef, keyOffset, keyOffset+keyLength) + } + + if valueLength >= 0 { + r.valueRef = alloc.newPageRef() + buffer.refTo(r.valueRef, valueOffset, valueOffset+valueLength) + } + } + + // Note: it's unclear whether kafka 0.11+ still truncates the responses, + // all attempts I made at constructing a test to trigger a truncation have + // failed. I kept this code here as a safeguard but it may never execute. 
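
The per-record loop above reads zig-zag varints, which is the same scheme encoding/binary implements. The sketch below (illustrative, not from the patch) produces one record body in that layout: length prefix, attributes byte, timestamp and offset deltas, key, value, and the header count.

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    // appendV2Record mirrors the field order consumed by readFromVersion2; a
    // null key or value would be written as varint length -1.
    func appendV2Record(b []byte, timestampDelta, offsetDelta int64, key, value []byte) []byte {
        body := []byte{0} // record attributes (unused)
        body = binary.AppendVarint(body, timestampDelta)
        body = binary.AppendVarint(body, offsetDelta)
        body = binary.AppendVarint(body, int64(len(key)))
        body = append(body, key...)
        body = binary.AppendVarint(body, int64(len(value)))
        body = append(body, value...)
        body = binary.AppendVarint(body, 0) // no headers in this sketch

        b = binary.AppendVarint(b, int64(len(body))) // record length prefix
        return append(b, body...)
    }

    func main() {
        rec := appendV2Record(nil, 0, 0, []byte("k"), []byte("v"))
        fmt.Printf("% x\n", rec)
    }
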
+ if dec.err != nil && len(records) == 0 { + return dec.err + } + + *rs = RecordSet{ + Version: magicByte, + Attributes: Attributes(attributes), + Records: &optimizedRecordReader{ + records: records, + headers: headers, + }, + } + + if rs.Attributes.Control() { + rs.Records = &ControlBatch{ + Attributes: rs.Attributes, + PartitionLeaderEpoch: partitionLeaderEpoch, + BaseOffset: baseOffset, + ProducerID: producerID, + ProducerEpoch: producerEpoch, + BaseSequence: baseSequence, + Records: rs.Records, + } + } else { + rs.Records = &RecordBatch{ + Attributes: rs.Attributes, + PartitionLeaderEpoch: partitionLeaderEpoch, + BaseOffset: baseOffset, + ProducerID: producerID, + ProducerEpoch: producerEpoch, + BaseSequence: baseSequence, + Records: rs.Records, + } + } + + return nil +} + +func (rs *RecordSet) writeToVersion2(buffer *pageBuffer, bufferOffset int64) error { + records := rs.Records + numRecords := int32(0) + + e := &encoder{writer: buffer} + e.writeInt64(0) // base offset | 0 +8 + e.writeInt32(0) // placeholder for record batch length | 8 +4 + e.writeInt32(-1) // partition leader epoch | 12 +3 + e.writeInt8(2) // magic byte | 16 +1 + e.writeInt32(0) // placeholder for crc32 checksum | 17 +4 + e.writeInt16(int16(rs.Attributes)) // attributes | 21 +2 + e.writeInt32(0) // placeholder for lastOffsetDelta | 23 +4 + e.writeInt64(0) // placeholder for firstTimestamp | 27 +8 + e.writeInt64(0) // placeholder for maxTimestamp | 35 +8 + e.writeInt64(-1) // producer id | 43 +8 + e.writeInt16(-1) // producer epoch | 51 +2 + e.writeInt32(-1) // base sequence | 53 +4 + e.writeInt32(0) // placeholder for numRecords | 57 +4 + + var compressor io.WriteCloser + if compression := rs.Attributes.Compression(); compression != 0 { + if codec := compression.Codec(); codec != nil { + compressor = codec.NewWriter(buffer) + e.writer = compressor + } + } + + currentTimestamp := timestamp(time.Now()) + lastOffsetDelta := int32(0) + firstTimestamp := int64(0) + maxTimestamp := int64(0) + + err := forEachRecord(records, func(i int, r *Record) error { + t := timestamp(r.Time) + if t == 0 { + t = currentTimestamp + } + if i == 0 { + firstTimestamp = t + } + if t > maxTimestamp { + maxTimestamp = t + } + + timestampDelta := t - firstTimestamp + offsetDelta := int64(i) + lastOffsetDelta = int32(offsetDelta) + + length := 1 + // attributes + sizeOfVarInt(timestampDelta) + + sizeOfVarInt(offsetDelta) + + sizeOfVarNullBytesIface(r.Key) + + sizeOfVarNullBytesIface(r.Value) + + sizeOfVarInt(int64(len(r.Headers))) + + for _, h := range r.Headers { + length += sizeOfVarString(h.Key) + sizeOfVarNullBytes(h.Value) + } + + e.writeVarInt(int64(length)) + e.writeInt8(0) // record attributes (unused) + e.writeVarInt(timestampDelta) + e.writeVarInt(offsetDelta) + + if err := e.writeVarNullBytesFrom(r.Key); err != nil { + return err + } + + if err := e.writeVarNullBytesFrom(r.Value); err != nil { + return err + } + + e.writeVarInt(int64(len(r.Headers))) + + for _, h := range r.Headers { + e.writeVarString(h.Key) + e.writeVarNullBytes(h.Value) + } + + numRecords++ + return nil + }) + + if err != nil { + return err + } + + if compressor != nil { + if err := compressor.Close(); err != nil { + return err + } + } + + if numRecords == 0 { + return ErrNoRecord + } + + b2 := packUint32(uint32(lastOffsetDelta)) + b3 := packUint64(uint64(firstTimestamp)) + b4 := packUint64(uint64(maxTimestamp)) + b5 := packUint32(uint32(numRecords)) + + buffer.WriteAt(b2[:], bufferOffset+23) + buffer.WriteAt(b3[:], bufferOffset+27) + buffer.WriteAt(b4[:], 
bufferOffset+35) + buffer.WriteAt(b5[:], bufferOffset+57) + + totalLength := buffer.Size() - bufferOffset + batchLength := totalLength - 12 + + checksum := uint32(0) + crcTable := crc32.MakeTable(crc32.Castagnoli) + + buffer.pages.scan(bufferOffset+21, bufferOffset+totalLength, func(chunk []byte) bool { + checksum = crc32.Update(checksum, crcTable, chunk) + return true + }) + + b0 := packUint32(uint32(batchLength)) + b1 := packUint32(checksum) + + buffer.WriteAt(b0[:], bufferOffset+8) + buffer.WriteAt(b1[:], bufferOffset+17) + return nil +} diff --git a/protocol/reflect.go b/protocol/reflect.go new file mode 100644 index 000000000..181ac0d72 --- /dev/null +++ b/protocol/reflect.go @@ -0,0 +1,101 @@ +// +build !unsafe + +package protocol + +import ( + "reflect" +) + +type index []int + +type _type struct{ typ reflect.Type } + +func typeOf(x interface{}) _type { + return makeType(reflect.TypeOf(x)) +} + +func elemTypeOf(x interface{}) _type { + return makeType(reflect.TypeOf(x).Elem()) +} + +func makeType(t reflect.Type) _type { + return _type{typ: t} +} + +type value struct { + val reflect.Value +} + +func nonAddressableValueOf(x interface{}) value { + return value{val: reflect.ValueOf(x)} +} + +func valueOf(x interface{}) value { + return value{val: reflect.ValueOf(x).Elem()} +} + +func makeValue(t reflect.Type) value { + return value{val: reflect.New(t).Elem()} +} + +func (v value) bool() bool { return v.val.Bool() } + +func (v value) int8() int8 { return int8(v.int64()) } + +func (v value) int16() int16 { return int16(v.int64()) } + +func (v value) int32() int32 { return int32(v.int64()) } + +func (v value) int64() int64 { return v.val.Int() } + +func (v value) string() string { return v.val.String() } + +func (v value) bytes() []byte { return v.val.Bytes() } + +func (v value) iface(t reflect.Type) interface{} { return v.val.Addr().Interface() } + +func (v value) array(t reflect.Type) array { return array{val: v.val} } + +func (v value) setBool(b bool) { v.val.SetBool(b) } + +func (v value) setInt8(i int8) { v.setInt64(int64(i)) } + +func (v value) setInt16(i int16) { v.setInt64(int64(i)) } + +func (v value) setInt32(i int32) { v.setInt64(int64(i)) } + +func (v value) setInt64(i int64) { v.val.SetInt(i) } + +func (v value) setString(s string) { v.val.SetString(s) } + +func (v value) setBytes(b []byte) { v.val.SetBytes(b) } + +func (v value) setArray(a array) { + if a.val.IsValid() { + v.val.Set(a.val) + } else { + v.val.Set(reflect.Zero(v.val.Type())) + } +} + +func (v value) fieldByIndex(i index) value { + return value{val: v.val.FieldByIndex(i)} +} + +type array struct { + val reflect.Value +} + +func makeArray(t reflect.Type, n int) array { + return array{val: reflect.MakeSlice(reflect.SliceOf(t), n, n)} +} + +func (a array) index(i int) value { return value{val: a.val.Index(i)} } + +func (a array) length() int { return a.val.Len() } + +func (a array) isNil() bool { return a.val.IsNil() } + +func indexOf(s reflect.StructField) index { return index(s.Index) } + +func bytesToString(b []byte) string { return string(b) } diff --git a/protocol/reflect_unsafe.go b/protocol/reflect_unsafe.go new file mode 100644 index 000000000..e6d6511dc --- /dev/null +++ b/protocol/reflect_unsafe.go @@ -0,0 +1,138 @@ +// +build unsafe + +package protocol + +import ( + "reflect" + "unsafe" +) + +type iface struct { + typ unsafe.Pointer + ptr unsafe.Pointer +} + +type slice struct { + ptr unsafe.Pointer + len int + cap int +} + +type index uintptr + +type _type struct { + ptr unsafe.Pointer +} + +func 
typeOf(x interface{}) _type { + return _type{ptr: ((*iface)(unsafe.Pointer(&x))).typ} +} + +func elemTypeOf(x interface{}) _type { + return makeType(reflect.TypeOf(x).Elem()) +} + +func makeType(t reflect.Type) _type { + return _type{ptr: ((*iface)(unsafe.Pointer(&t))).ptr} +} + +type value struct { + ptr unsafe.Pointer +} + +func nonAddressableValueOf(x interface{}) value { + return valueOf(x) +} + +func valueOf(x interface{}) value { + return value{ptr: ((*iface)(unsafe.Pointer(&x))).ptr} +} + +func makeValue(t reflect.Type) value { + return value{ptr: unsafe.Pointer(reflect.New(t).Pointer())} +} + +func (v value) bool() bool { return *(*bool)(v.ptr) } + +func (v value) int8() int8 { return *(*int8)(v.ptr) } + +func (v value) int16() int16 { return *(*int16)(v.ptr) } + +func (v value) int32() int32 { return *(*int32)(v.ptr) } + +func (v value) int64() int64 { return *(*int64)(v.ptr) } + +func (v value) string() string { return *(*string)(v.ptr) } + +func (v value) bytes() []byte { return *(*[]byte)(v.ptr) } + +func (v value) iface(t reflect.Type) interface{} { + return *(*interface{})(unsafe.Pointer(&iface{ + typ: ((*iface)(unsafe.Pointer(&t))).ptr, + ptr: v.ptr, + })) +} + +func (v value) array(t reflect.Type) array { + return array{ + size: uintptr(t.Size()), + elem: ((*slice)(v.ptr)).ptr, + len: ((*slice)(v.ptr)).len, + } +} + +func (v value) setBool(b bool) { *(*bool)(v.ptr) = b } + +func (v value) setInt8(i int8) { *(*int8)(v.ptr) = i } + +func (v value) setInt16(i int16) { *(*int16)(v.ptr) = i } + +func (v value) setInt32(i int32) { *(*int32)(v.ptr) = i } + +func (v value) setInt64(i int64) { *(*int64)(v.ptr) = i } + +func (v value) setString(s string) { *(*string)(v.ptr) = s } + +func (v value) setBytes(b []byte) { *(*[]byte)(v.ptr) = b } + +func (v value) setArray(a array) { *(*slice)(v.ptr) = slice{ptr: a.elem, len: a.len, cap: a.len} } + +func (v value) fieldByIndex(i index) value { + return value{ptr: unsafe.Pointer(uintptr(v.ptr) + uintptr(i))} +} + +type array struct { + elem unsafe.Pointer + size uintptr + len int +} + +var ( + emptyArray struct{} +) + +func makeArray(t reflect.Type, n int) array { + var elem unsafe.Pointer + var size = uintptr(t.Size()) + if n == 0 { + elem = unsafe.Pointer(&emptyArray) + } else { + elem = unsafe_NewArray(((*iface)(unsafe.Pointer(&t))).ptr, n) + } + return array{elem: elem, size: size, len: n} +} + +func (a array) index(i int) value { + return value{ptr: unsafe.Pointer(uintptr(a.elem) + (uintptr(i) * a.size))} +} + +func (a array) length() int { return a.len } + +func (a array) isNil() bool { return a.elem == nil } + +func indexOf(s reflect.StructField) index { return index(s.Offset) } + +func bytesToString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } + +//go:linkname unsafe_NewArray reflect.unsafe_NewArray +func unsafe_NewArray(rtype unsafe.Pointer, length int) unsafe.Pointer diff --git a/protocol/request.go b/protocol/request.go new file mode 100644 index 000000000..8b99e0537 --- /dev/null +++ b/protocol/request.go @@ -0,0 +1,128 @@ +package protocol + +import ( + "fmt" + "io" +) + +func ReadRequest(r io.Reader) (apiVersion int16, correlationID int32, clientID string, msg Message, err error) { + d := &decoder{reader: r, remain: 4} + size := d.readInt32() + + if err = d.err; err != nil { + err = dontExpectEOF(err) + return + } + + d.remain = int(size) + apiKey := ApiKey(d.readInt16()) + apiVersion = d.readInt16() + correlationID = d.readInt32() + clientID = d.readString() + + if i := int(apiKey); i < 0 || i >= len(apiTypes) 
{ + err = fmt.Errorf("unsupported api key: %d", i) + return + } + + if err = d.err; err != nil { + err = dontExpectEOF(err) + return + } + + t := &apiTypes[apiKey] + if t == nil { + err = fmt.Errorf("unsupported api: %s", apiNames[apiKey]) + return + } + + minVersion := t.minVersion() + maxVersion := t.maxVersion() + + if apiVersion < minVersion || apiVersion > maxVersion { + err = fmt.Errorf("unsupported %s version: v%d not in range v%d-v%d", apiKey, apiVersion, minVersion, maxVersion) + return + } + + req := &t.requests[apiVersion-minVersion] + + if req.flexible { + // In the flexible case, there's a tag buffer at the end of the request header + taggedCount := int(d.readUnsignedVarInt()) + for i := 0; i < taggedCount; i++ { + d.readUnsignedVarInt() // tagID + size := d.readUnsignedVarInt() + + // Just throw away the values for now + d.read(int(size)) + } + } + + msg = req.new() + req.decode(d, valueOf(msg)) + d.discardAll() + + if err = d.err; err != nil { + err = dontExpectEOF(err) + } + + return +} + +func WriteRequest(w io.Writer, apiVersion int16, correlationID int32, clientID string, msg Message) error { + apiKey := msg.ApiKey() + + if i := int(apiKey); i < 0 || i >= len(apiTypes) { + return fmt.Errorf("unsupported api key: %d", i) + } + + t := &apiTypes[apiKey] + if t == nil { + return fmt.Errorf("unsupported api: %s", apiNames[apiKey]) + } + + minVersion := t.minVersion() + maxVersion := t.maxVersion() + + if apiVersion < minVersion || apiVersion > maxVersion { + return fmt.Errorf("unsupported %s version: v%d not in range v%d-v%d", apiKey, apiVersion, minVersion, maxVersion) + } + + r := &t.requests[apiVersion-minVersion] + v := valueOf(msg) + b := newPageBuffer() + defer b.unref() + + e := &encoder{writer: b} + e.writeInt32(0) // placeholder for the request size + e.writeInt16(int16(apiKey)) + e.writeInt16(apiVersion) + e.writeInt32(correlationID) + + if r.flexible { + // Flexible messages use a nullable string for the client ID, then extra space for a + // tag buffer, which begins with a size value. Since we're not writing any fields into the + // latter, we can just write zero for now. + // + // See + // https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields + // for details. + e.writeNullString(clientID) + e.writeUnsignedVarInt(0) + } else { + // Technically, recent versions of kafka interpret this field as a nullable + // string, however kafka 0.10 expected a non-nullable string and fails with + // a NullPointerException when it receives a null client id. 
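
A rough sketch of the request framing implemented by WriteRequest and ReadRequest above, using the saslhandshake message type registered later in this patch; the client id and correlation id values are arbitrary.

    package main

    import (
        "bytes"
        "fmt"

        "github.com/segmentio/kafka-go/protocol"
        "github.com/segmentio/kafka-go/protocol/saslhandshake"
    )

    func main() {
        buf := new(bytes.Buffer)

        // size | api key | api version | correlation id | client id | message
        req := &saslhandshake.Request{Mechanism: "PLAIN"}
        if err := protocol.WriteRequest(buf, 0, 42, "example-client", req); err != nil {
            panic(err)
        }

        apiVersion, correlationID, clientID, msg, err := protocol.ReadRequest(buf)
        if err != nil {
            panic(err)
        }
        fmt.Println(apiVersion, correlationID, clientID, msg.(*saslhandshake.Request).Mechanism)
    }
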
+ e.writeString(clientID) + } + r.encode(e, v) + err := e.err + + if err == nil { + size := packUint32(uint32(b.Size()) - 4) + b.WriteAt(size[:], 0) + _, err = b.WriteTo(w) + } + + return err +} diff --git a/protocol/response.go b/protocol/response.go new file mode 100644 index 000000000..ca4218d32 --- /dev/null +++ b/protocol/response.go @@ -0,0 +1,101 @@ +package protocol + +import ( + "fmt" + "io" +) + +func ReadResponse(r io.Reader, apiKey ApiKey, apiVersion int16) (correlationID int32, msg Message, err error) { + if i := int(apiKey); i < 0 || i >= len(apiTypes) { + err = fmt.Errorf("unsupported api key: %d", i) + return + } + + t := &apiTypes[apiKey] + if t == nil { + err = fmt.Errorf("unsupported api: %s", apiNames[apiKey]) + return + } + + minVersion := t.minVersion() + maxVersion := t.maxVersion() + + if apiVersion < minVersion || apiVersion > maxVersion { + err = fmt.Errorf("unsupported %s version: v%d not in range v%d-v%d", apiKey, apiVersion, minVersion, maxVersion) + return + } + + d := &decoder{reader: r, remain: 4} + size := d.readInt32() + + if err = d.err; err != nil { + err = dontExpectEOF(err) + return + } + + d.remain = int(size) + correlationID = d.readInt32() + + res := &t.responses[apiVersion-minVersion] + + if res.flexible { + // In the flexible case, there's a tag buffer at the end of the response header + taggedCount := int(d.readUnsignedVarInt()) + for i := 0; i < taggedCount; i++ { + d.readUnsignedVarInt() // tagID + size := d.readUnsignedVarInt() + + // Just throw away the values for now + d.read(int(size)) + } + } + + msg = res.new() + res.decode(d, valueOf(msg)) + d.discardAll() + + if err = d.err; err != nil { + err = dontExpectEOF(err) + } + + return +} + +func WriteResponse(w io.Writer, apiVersion int16, correlationID int32, msg Message) error { + apiKey := msg.ApiKey() + + if i := int(apiKey); i < 0 || i >= len(apiTypes) { + return fmt.Errorf("unsupported api key: %d", i) + } + + t := &apiTypes[apiKey] + if t == nil { + return fmt.Errorf("unsupported api: %s", apiNames[apiKey]) + } + + minVersion := t.minVersion() + maxVersion := t.maxVersion() + + if apiVersion < minVersion || apiVersion > maxVersion { + return fmt.Errorf("unsupported %s version: v%d not in range v%d-v%d", apiKey, apiVersion, minVersion, maxVersion) + } + + r := &t.responses[apiVersion-minVersion] + v := valueOf(msg) + b := newPageBuffer() + defer b.unref() + + e := &encoder{writer: b} + e.writeInt32(0) // placeholder for the response size + e.writeInt32(correlationID) + r.encode(e, v) + err := e.err + + if err == nil { + size := packUint32(uint32(b.Size()) - 4) + b.WriteAt(size[:], 0) + _, err = b.WriteTo(w) + } + + return err +} diff --git a/protocol/roundtrip.go b/protocol/roundtrip.go new file mode 100644 index 000000000..c23532ca7 --- /dev/null +++ b/protocol/roundtrip.go @@ -0,0 +1,28 @@ +package protocol + +import ( + "io" +) + +// RoundTrip sends a request to a kafka broker and returns the response. 
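
The response path is symmetric: WriteResponse frames the size and correlation id ahead of the encoded message, and ReadResponse needs the api key and version to know which type to decode into. A minimal sketch, again leaning on the saslhandshake types from this patch:

    package main

    import (
        "bytes"
        "fmt"

        "github.com/segmentio/kafka-go/protocol"
        "github.com/segmentio/kafka-go/protocol/saslhandshake"
    )

    func main() {
        buf := new(bytes.Buffer)

        res := &saslhandshake.Response{Mechanisms: []string{"PLAIN", "SCRAM-SHA-512"}}
        if err := protocol.WriteResponse(buf, 0, 42, res); err != nil {
            panic(err)
        }

        correlationID, msg, err := protocol.ReadResponse(buf, protocol.SaslHandshake, 0)
        if err != nil {
            panic(err)
        }
        fmt.Println(correlationID, msg.(*saslhandshake.Response).Mechanisms)
    }
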
+func RoundTrip(rw io.ReadWriter, apiVersion int16, correlationID int32, clientID string, req Message) (Message, error) { + if err := WriteRequest(rw, apiVersion, correlationID, clientID, req); err != nil { + return nil, err + } + if !hasResponse(req) { + return nil, nil + } + id, res, err := ReadResponse(rw, req.ApiKey(), apiVersion) + if err != nil { + return nil, err + } + if id != correlationID { + return nil, Errorf("correlation id mismatch (expected=%d, found=%d)", correlationID, id) + } + return res, nil +} + +func hasResponse(msg Message) bool { + x, _ := msg.(interface{ HasResponse() bool }) + return x == nil || x.HasResponse() +} diff --git a/protocol/saslauthenticate/saslauthenticate.go b/protocol/saslauthenticate/saslauthenticate.go new file mode 100644 index 000000000..994876336 --- /dev/null +++ b/protocol/saslauthenticate/saslauthenticate.go @@ -0,0 +1,22 @@ +package saslauthenticate + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + AuthBytes []byte `kafka:"min=v0,max=v1"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate } + +type Response struct { + ErrorCode int16 `kafka:"min=v0,max=v1"` + ErrorMessage string `kafka:"min=v0,max=v1,nullable"` + AuthBytes []byte `kafka:"min=v0,max=v1"` + SessionLifetimeMs int64 `kafka:"min=v1,max=v1"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate } diff --git a/protocol/saslhandshake/saslhandshake.go b/protocol/saslhandshake/saslhandshake.go new file mode 100644 index 000000000..aa72e8309 --- /dev/null +++ b/protocol/saslhandshake/saslhandshake.go @@ -0,0 +1,20 @@ +package saslhandshake + +import "github.com/segmentio/kafka-go/protocol" + +func init() { + protocol.Register(&Request{}, &Response{}) +} + +type Request struct { + Mechanism string `kafka:"min=v0,max=v1"` +} + +func (r *Request) ApiKey() protocol.ApiKey { return protocol.SaslHandshake } + +type Response struct { + ErrorCode int16 `kafka:"min=v0,max=v1"` + Mechanisms []string `kafka:"min=v0,max=v1"` +} + +func (r *Response) ApiKey() protocol.ApiKey { return protocol.SaslHandshake } diff --git a/protocol/size.go b/protocol/size.go new file mode 100644 index 000000000..66477cd99 --- /dev/null +++ b/protocol/size.go @@ -0,0 +1,93 @@ +package protocol + +import ( + "math/bits" + "reflect" +) + +func sizeOf(typ reflect.Type) int { + switch typ.Kind() { + case reflect.Bool, reflect.Int8: + return 1 + case reflect.Int16: + return 2 + case reflect.Int32: + return 4 + case reflect.Int64: + return 8 + default: + return 0 + } +} + +func sizeOfString(s string) int { + return 2 + len(s) +} + +func sizeOfBytes(b []byte) int { + return 4 + len(b) +} + +func sizeOfVarString(s string) int { + return sizeOfVarInt(int64(len(s))) + len(s) +} + +func sizeOfCompactString(s string) int { + return sizeOfUnsignedVarInt(uint64(len(s)+1)) + len(s) +} + +func sizeOfVarBytes(b []byte) int { + return sizeOfVarInt(int64(len(b))) + len(b) +} + +func sizeOfCompactBytes(b []byte) int { + return sizeOfUnsignedVarInt(uint64(len(b)+1)) + len(b) +} + +func sizeOfVarNullString(s string) int { + n := len(s) + if n == 0 { + return sizeOfVarInt(-1) + } + return sizeOfVarInt(int64(n)) + n +} + +func sizeOfCompactNullString(s string) int { + n := len(s) + if n == 0 { + return sizeOfUnsignedVarInt(0) + } + return sizeOfUnsignedVarInt(uint64(n)+1) + n +} + +func sizeOfVarNullBytes(b []byte) int { + if b == nil { + return sizeOfVarInt(-1) + } + n := len(b) + 
return sizeOfVarInt(int64(n)) + n +} + +func sizeOfVarNullBytesIface(b Bytes) int { + if b == nil { + return sizeOfVarInt(-1) + } + n := b.Len() + return sizeOfVarInt(int64(n)) + n +} + +func sizeOfCompactNullBytes(b []byte) int { + if b == nil { + return sizeOfUnsignedVarInt(0) + } + n := len(b) + return sizeOfUnsignedVarInt(uint64(n)+1) + n +} + +func sizeOfVarInt(i int64) int { + return sizeOfUnsignedVarInt(uint64((i << 1) ^ (i >> 63))) // zig-zag encoding +} + +func sizeOfUnsignedVarInt(i uint64) int { + return (bits.Len64(i|1) + 6) / 7 +} diff --git a/protocol_test.go b/protocol_test.go index 0870c001c..d1f0540fe 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -32,8 +32,6 @@ func TestApiVersionsFormat(t *testing.T) { } func TestProtocol(t *testing.T) { - t.Parallel() - tests := []interface{}{ int8(42), int16(42), diff --git a/reader_test.go b/reader_test.go index 97d67fd6d..1324483d4 100644 --- a/reader_test.go +++ b/reader_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "math/rand" + "net" "reflect" "strconv" "sync" @@ -12,8 +13,6 @@ import ( ) func TestReader(t *testing.T) { - t.Parallel() - tests := []struct { scenario string function func(*testing.T, context.Context, *Reader) @@ -268,11 +267,22 @@ func testReaderOutOfRangeGetsCanceled(t *testing.T, ctx context.Context, r *Read func createTopic(t *testing.T, topic string, partitions int) { conn, err := Dial("tcp", "localhost:9092") if err != nil { - t.Error("bad conn") - return + t.Fatal(err) } defer conn.Close() + controller, err := conn.Controller() + if err != nil { + t.Fatal(err) + } + + conn, err = Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port))) + if err != nil { + t.Fatal(err) + } + + conn.SetDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.createTopics(createTopicsRequestV0{ Topics: []createTopicsRequestV0Topic{ { @@ -281,7 +291,7 @@ func createTopic(t *testing.T, topic string, partitions int) { ReplicationFactor: 1, }, }, - Timeout: int32(30 * time.Second / time.Millisecond), + Timeout: milliseconds(time.Second), }) switch err { case nil: @@ -289,14 +299,36 @@ func createTopic(t *testing.T, topic string, partitions int) { case TopicAlreadyExists: // ok default: - t.Error("bad createTopics", err) + t.Error(err) t.FailNow() } } -func TestReaderOnNonZeroPartition(t *testing.T) { - t.Parallel() +func deleteTopic(t *testing.T, topic ...string) { + conn, err := Dial("tcp", "localhost:9092") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + controller, err := conn.Controller() + if err != nil { + t.Fatal(err) + } + + conn, err = Dial("tcp", net.JoinHostPort(controller.Host, strconv.Itoa(controller.Port))) + if err != nil { + t.Fatal(err) + } + + conn.SetDeadline(time.Now().Add(2 * time.Second)) + if err := conn.DeleteTopics(topic...); err != nil { + t.Fatal(err) + } +} + +func TestReaderOnNonZeroPartition(t *testing.T) { tests := []struct { scenario string function func(*testing.T, context.Context, *Reader) @@ -314,6 +346,7 @@ func TestReaderOnNonZeroPartition(t *testing.T) { topic := makeTopic() createTopic(t, topic, 2) + defer deleteTopic(t, topic) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -370,7 +403,6 @@ func TestReadTruncatedMessages(t *testing.T) { // include it in CI unit tests. 
t.Skip() - t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() r := NewReader(ReaderConfig{ @@ -482,8 +514,11 @@ func BenchmarkReader(b *testing.B) { func TestCloseLeavesGroup(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + topic := makeTopic() createTopic(t, topic, 1) + defer deleteTopic(t, topic) + groupID := makeGroupID() r := NewReader(ReaderConfig{ Brokers: []string{"localhost:9092"}, @@ -675,8 +710,6 @@ func TestExtractTopics(t *testing.T) { } func TestReaderConsumerGroup(t *testing.T) { - t.Parallel() - tests := []struct { scenario string partitions int @@ -746,10 +779,14 @@ func TestReaderConsumerGroup(t *testing.T) { for _, test := range tests { t.Run(test.scenario, func(t *testing.T) { + // It appears that some of the tests depend on all these tests being + // run concurrently to pass... this is brittle and should be fixed + // at some point. t.Parallel() topic := makeTopic() createTopic(t, topic, test.partitions) + defer deleteTopic(t, topic) groupID := makeGroupID() r := NewReader(ReaderConfig{ @@ -882,13 +919,16 @@ func testReaderConsumerGroupVerifyCommitsOnClose(t *testing.T, ctx context.Conte func testReaderConsumerGroupReadContentAcrossPartitions(t *testing.T, ctx context.Context, r *Reader) { const N = 12 - writer := NewWriter(WriterConfig{ - Brokers: r.config.Brokers, + client, shutdown := newLocalClient() + defer shutdown() + + writer := &Writer{ + Addr: TCP(r.config.Brokers...), Topic: r.config.Topic, - Dialer: r.config.Dialer, Balancer: &RoundRobin{}, BatchSize: 1, - }) + Transport: client.Transport, + } if err := writer.WriteMessages(ctx, makeTestSequence(N)...); err != nil { t.Fatalf("bad write messages: %v", err) } @@ -919,14 +959,17 @@ func testReaderConsumerGroupRebalance(t *testing.T, ctx context.Context, r *Read partitions = 2 ) + client, shutdown := newLocalClient() + defer shutdown() + // rebalance should result in 12 message in each of the partitions - writer := NewWriter(WriterConfig{ - Brokers: r.config.Brokers, + writer := &Writer{ + Addr: TCP(r.config.Brokers...), Topic: r.config.Topic, - Dialer: r.config.Dialer, Balancer: &RoundRobin{}, BatchSize: 1, - }) + Transport: client.Transport, + } if err := writer.WriteMessages(ctx, makeTestSequence(N*partitions)...); err != nil { t.Fatalf("bad write messages: %v", err) } @@ -947,8 +990,8 @@ func testReaderConsumerGroupRebalance(t *testing.T, ctx context.Context, r *Read func testReaderConsumerGroupRebalanceAcrossTopics(t *testing.T, ctx context.Context, r *Reader) { // create a second reader that shares the groupID, but reads from a different topic - topic2 := makeTopic() - createTopic(t, topic2, 1) + client, topic2, shutdown := newLocalClientAndTopic() + defer shutdown() r2 := NewReader(ReaderConfig{ Brokers: r.config.Brokers, @@ -969,13 +1012,13 @@ func testReaderConsumerGroupRebalanceAcrossTopics(t *testing.T, ctx context.Cont ) // write messages across both partitions - writer := NewWriter(WriterConfig{ - Brokers: r.config.Brokers, + writer := &Writer{ + Addr: TCP(r.config.Brokers...), Topic: r.config.Topic, - Dialer: r.config.Dialer, Balancer: &RoundRobin{}, BatchSize: 1, - }) + Transport: client.Transport, + } if err := writer.WriteMessages(ctx, makeTestSequence(N)...); err != nil { t.Fatalf("bad write messages: %v", err) } @@ -1021,14 +1064,17 @@ func testReaderConsumerGroupRebalanceAcrossManyPartitionsAndConsumers(t *testing } }() + client, shutdown := newLocalClient() + defer shutdown() + // write 
messages across both partitions - writer := NewWriter(WriterConfig{ - Brokers: r.config.Brokers, + writer := &Writer{ + Addr: TCP(r.config.Brokers...), Topic: r.config.Topic, - Dialer: r.config.Dialer, Balancer: &RoundRobin{}, BatchSize: 1, - }) + Transport: client.Transport, + } if err := writer.WriteMessages(ctx, makeTestSequence(N*3)...); err != nil { t.Fatalf("bad write messages: %v", err) } @@ -1217,8 +1263,6 @@ func TestCommitOffsetsWithRetry(t *testing.T) { // than partitions in a group. // https://github.com/segmentio/kafka-go/issues/200 func TestRebalanceTooManyConsumers(t *testing.T) { - t.Parallel() - ctx := context.Background() conf := ReaderConfig{ Brokers: []string{"localhost:9092"}, @@ -1249,7 +1293,6 @@ func TestRebalanceTooManyConsumers(t *testing.T) { } func TestConsumerGroupWithMissingTopic(t *testing.T) { - t.Parallel() t.Skip("this test doesn't work when the cluster is configured to auto-create topics") ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -1274,14 +1317,16 @@ func TestConsumerGroupWithMissingTopic(t *testing.T) { }() time.Sleep(time.Second) - createTopic(t, conf.Topic, 1) + client, shutdown := newLocalClientWithTopic(conf.Topic, 1) + defer shutdown() - w := NewWriter(WriterConfig{ - Brokers: r.config.Brokers, + w := &Writer{ + Addr: TCP(r.config.Brokers...), Topic: r.config.Topic, BatchTimeout: 10 * time.Millisecond, BatchSize: 1, - }) + Transport: client.Transport, + } defer w.Close() if err := w.WriteMessages(ctx, Message{}); err != nil { t.Fatalf("write error: %+v", err) @@ -1297,7 +1342,7 @@ func TestConsumerGroupWithMissingTopic(t *testing.T) { } } -func getOffsets(t *testing.T, config ReaderConfig) offsetFetchResponseV1 { +func getOffsets(t *testing.T, config ReaderConfig) map[int]int64 { // minimal config required to lookup coordinator cg := ConsumerGroup{ config: ConsumerGroupConfig{ @@ -1324,7 +1369,17 @@ func getOffsets(t *testing.T, config ReaderConfig) offsetFetchResponseV1 { t.Errorf("bad fetchOffsets: %v", err) } - return offsets + m := map[int]int64{} + + for _, r := range offsets.Responses { + if r.Topic == config.Topic { + for _, p := range r.PartitionResponses { + m[int(p.Partition)] = p.Offset + } + } + } + + return m } const ( diff --git a/record.go b/record.go new file mode 100644 index 000000000..c674e1891 --- /dev/null +++ b/record.go @@ -0,0 +1,42 @@ +package kafka + +import ( + "github.com/segmentio/kafka-go/protocol" +) + +// Header is a key/value pair type representing headers set on records. +type Header = protocol.Header + +// Bytes is an interface representing a sequence of bytes. This abstraction +// makes it possible for programs to inject data into produce requests without +// having to load in into an intermediary buffer, or read record keys and values +// from a fetch response directly from internal buffers. +// +// Bytes are not safe to use concurrently from multiple goroutines. +type Bytes = protocol.Bytes + +// NewBytes constructs a Bytes value from a byte slice. +// +// If b is nil, nil is returned. +func NewBytes(b []byte) Bytes { return protocol.NewBytes(b) } + +// ReadAll reads b into a byte slice. +func ReadAll(b Bytes) ([]byte, error) { return protocol.ReadAll(b) } + +// Record is an interface representing a single kafka record. +// +// Record values are not safe to use concurrently from multiple goroutines. +type Record = protocol.Record + +// RecordReader is an interface representing a sequence of records. 
Record sets +// are used in both produce and fetch requests to represent the sequence of +// records that are sent to or receive from kafka brokers. +// +// RecordReader values are not safe to use concurrently from multiple goroutines. +type RecordReader = protocol.RecordReader + +// NewRecordReade rconstructs a RecordSet which exposes the sequence of records +// passed as arguments. +func NewRecordReader(records ...Record) RecordReader { + return protocol.NewRecordReader(records...) +} diff --git a/resolver.go b/resolver.go new file mode 100644 index 000000000..5a61eee5e --- /dev/null +++ b/resolver.go @@ -0,0 +1,62 @@ +package kafka + +import ( + "context" + "net" +) + +// The Resolver interface is used as an abstraction to provide service discovery +// of the hosts of a kafka cluster. +type Resolver interface { + // LookupHost looks up the given host using the local resolver. + // It returns a slice of that host's addresses. + LookupHost(ctx context.Context, host string) (addrs []string, err error) +} + +// BrokerResolver is an interface implemented by types that translate host +// names into a network address. +// +// This resolver is not intended to be a general purpose interface. Instead, +// it is tailored to the particular needs of the kafka protocol, with the goal +// being to provide a flexible mechanism for extending broker name resolution +// while retaining context that is specific to interacting with a kafka cluster. +// +// Resolvers must be safe to use from multiple goroutines. +type BrokerResolver interface { + // Returns the IP addresses of the broker passed as argument. + LookupBrokerIPAddr(ctx context.Context, broker Broker) ([]net.IPAddr, error) +} + +// NewBrokerResolver constructs a Resolver from r. +// +// If r is nil, net.DefaultResolver is used instead. +func NewBrokerResolver(r *net.Resolver) BrokerResolver { + return brokerResolver{r} +} + +type brokerResolver struct { + *net.Resolver +} + +func (r brokerResolver) LookupBrokerIPAddr(ctx context.Context, broker Broker) ([]net.IPAddr, error) { + rslv := r.Resolver + if rslv == nil { + rslv = net.DefaultResolver + } + + ipAddrs, err := r.LookupIPAddr(ctx, broker.Host) + if err != nil { + return nil, err + } + + if len(ipAddrs) == 0 { + return nil, &net.DNSError{ + Err: "no addresses were returned by the resolver", + Name: broker.Host, + IsTemporary: true, + IsNotFound: true, + } + } + + return ipAddrs, nil +} diff --git a/sasl/sasl_test.go b/sasl/sasl_test.go index 7a6307c45..a4101391a 100644 --- a/sasl/sasl_test.go +++ b/sasl/sasl_test.go @@ -18,9 +18,6 @@ const ( ) func TestSASL(t *testing.T) { - - t.Parallel() - tests := []struct { valid func() sasl.Mechanism invalid func() sasl.Mechanism diff --git a/snappy/snappy.go b/snappy/snappy.go index ff77cb228..0ffddafbc 100644 --- a/snappy/snappy.go +++ b/snappy/snappy.go @@ -1,92 +1,24 @@ +// Package snappy does nothing, it's kept for backward compatibility to avoid +// breaking the majority of programs that imported it to install the compression +// codec, which is now always included. package snappy -import ( - "io" - "sync" +import "github.com/segmentio/kafka-go/compress/snappy" - "github.com/golang/snappy" - kafka "github.com/segmentio/kafka-go" -) - -func init() { - kafka.RegisterCompressionCodec(NewCompressionCodec()) -} +type CompressionCodec = snappy.Codec -// Framing is an enumeration type used to enable or disable xerial framing of -// snappy messages. 
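
Stepping back to the resolver API introduced above in resolver.go: the sketch below wraps the standard library resolver and looks up a broker's IP addresses. It is not part of the patch, and the localhost broker is just a placeholder.

    package main

    import (
        "context"
        "fmt"
        "net"

        kafka "github.com/segmentio/kafka-go"
    )

    func main() {
        resolver := kafka.NewBrokerResolver(net.DefaultResolver)

        addrs, err := resolver.LookupBrokerIPAddr(context.Background(), kafka.Broker{
            Host: "localhost",
            Port: 9092,
        })
        if err != nil {
            panic(err)
        }
        fmt.Println(addrs)
    }
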
-type Framing int - -const ( - Framed Framing = iota - Unframed -) +type Framing = snappy.Framing const ( - Code = 2 + Code = 2 + Framed = snappy.Framed + Unframed = snappy.Unframed ) -type CompressionCodec struct{ framing Framing } - func NewCompressionCodec() *CompressionCodec { return NewCompressionCodecFraming(Framed) } func NewCompressionCodecFraming(framing Framing) *CompressionCodec { - return &CompressionCodec{framing} -} - -// Code implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) Code() int8 { return Code } - -// Name implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) Name() string { return "snappy" } - -// NewReader implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser { - x := readerPool.Get().(*xerialReader) - x.Reset(r) - return &reader{x} -} - -// NewWriter implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { - x := writerPool.Get().(*xerialWriter) - x.Reset(w) - x.framed = c.framing == Framed - return &writer{x} -} - -type reader struct{ *xerialReader } - -func (r *reader) Close() (err error) { - if x := r.xerialReader; x != nil { - r.xerialReader = nil - x.Reset(nil) - readerPool.Put(x) - } - return -} - -type writer struct{ *xerialWriter } - -func (w *writer) Close() (err error) { - if x := w.xerialWriter; x != nil { - w.xerialWriter = nil - err = x.Flush() - x.Reset(nil) - writerPool.Put(x) - } - return -} - -var readerPool = sync.Pool{ - New: func() interface{} { - return &xerialReader{decode: snappy.Decode} - }, -} - -var writerPool = sync.Pool{ - New: func() interface{} { - return &xerialWriter{encode: snappy.Encode} - }, + return &CompressionCodec{Framing: framing} } diff --git a/testing/conn.go b/testing/conn.go new file mode 100644 index 000000000..e45840b7e --- /dev/null +++ b/testing/conn.go @@ -0,0 +1,32 @@ +package testing + +import ( + "context" + "net" + "sync" +) + +type ConnWaitGroup struct { + DialFunc func(context.Context, string, string) (net.Conn, error) + sync.WaitGroup +} + +func (g *ConnWaitGroup) Dial(ctx context.Context, network, address string) (net.Conn, error) { + c, err := g.DialFunc(ctx, network, address) + if err != nil { + return nil, err + } + g.Add(1) + return &groupConn{Conn: c, group: g}, nil +} + +type groupConn struct { + net.Conn + group *ConnWaitGroup + once sync.Once +} + +func (c *groupConn) Close() error { + defer c.once.Do(c.group.Done) + return c.Conn.Close() +} diff --git a/testing/version.go b/testing/version.go index 354ac4a7e..c52cd106a 100644 --- a/testing/version.go +++ b/testing/version.go @@ -1,4 +1,4 @@ -package kafka +package testing import ( "os" diff --git a/testing/version_test.go b/testing/version_test.go index 9cf3730ee..52af989db 100644 --- a/testing/version_test.go +++ b/testing/version_test.go @@ -1,4 +1,4 @@ -package kafka +package testing import ( "testing" diff --git a/time.go b/time.go index 26f33afd0..544d84207 100644 --- a/time.go +++ b/time.go @@ -11,6 +11,13 @@ const ( defaultRTT = 1 * time.Second ) +func makeTime(t int64) time.Time { + if t <= 0 { + return time.Time{} + } + return time.Unix(t/1000, (t%1000)*int64(time.Millisecond)).UTC() +} + func timestamp(t time.Time) int64 { if t.IsZero() { return 0 @@ -18,11 +25,7 @@ func timestamp(t time.Time) int64 { return t.UnixNano() / int64(time.Millisecond) } -func timestampToTime(t int64) time.Time { - return time.Unix(t/1000, (t%1000)*int64(time.Millisecond)) -} - -func duration(ms 
int32) time.Duration { +func makeDuration(ms int32) time.Duration { return time.Duration(ms) * time.Millisecond } diff --git a/transport.go b/transport.go new file mode 100644 index 000000000..b8311ad5a --- /dev/null +++ b/transport.go @@ -0,0 +1,1298 @@ +package kafka + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "math/rand" + "net" + "runtime/pprof" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/segmentio/kafka-go/protocol" + "github.com/segmentio/kafka-go/protocol/apiversions" + "github.com/segmentio/kafka-go/protocol/createtopics" + "github.com/segmentio/kafka-go/protocol/findcoordinator" + meta "github.com/segmentio/kafka-go/protocol/metadata" + "github.com/segmentio/kafka-go/protocol/saslauthenticate" + "github.com/segmentio/kafka-go/protocol/saslhandshake" + "github.com/segmentio/kafka-go/sasl" +) + +// Request is an interface implemented by types that represent messages sent +// from kafka clients to brokers. +type Request = protocol.Message + +// Response is an interface implemented by types that represent messages sent +// from kafka brokers in response to client requests. +type Response = protocol.Message + +// RoundTripper is an interface implemented by types which support interacting +// with kafka brokers. +type RoundTripper interface { + // RoundTrip sends a request to a kafka broker and returns the response that + // was received, or a non-nil error. + // + // The context passed as first argument can be used to asynchronously abort + // the call if needed. + RoundTrip(context.Context, net.Addr, Request) (Response, error) } + +// Transport is an implementation of the RoundTripper interface. +// +// Transport values manage a pool of connections and automatically discover the +// cluster's layout to route requests to the appropriate brokers. +// +// Transport values are safe to use concurrently from multiple goroutines. +// +// Note: The intent is for the Transport to become the underlying layer of the +// kafka.Reader and kafka.Writer types. +type Transport struct { + // A function used to establish connections to the kafka cluster. + Dial func(context.Context, string, string) (net.Conn, error) + + // Time limit set for establishing connections to the kafka cluster. This + // limit includes all round trips done to establish the connections (TLS + // handshake, SASL negotiation, etc...). + // + // Defaults to 5s. + DialTimeout time.Duration + + // Maximum amount of time that connections will remain open and unused. + // The transport will automatically close connections that have + // been idle for too long, and re-open them on demand when the transport is + // used again. + // + // Defaults to 30s. + IdleTimeout time.Duration + + // TTL for the metadata cached by this transport. Note that the value + // configured here is an upper bound, the transport randomizes the TTLs to + // avoid getting into states where multiple clients end up synchronized and + // cause bursts of requests to the kafka broker. + // + // Defaults to 6s. + MetadataTTL time.Duration + + // Unique identifier that the transport communicates to the brokers when it + // sends requests. + ClientID string + + // An optional configuration for TLS connections established by this + // transport. + // + // If the Server + TLS *tls.Config + + // SASL configures the Transport to use SASL authentication. + SASL sasl.Mechanism + + // An optional resolver used to translate broker host names into network + // addresses.
+ // + // The resolver will be called for every request (not every connection), + // making it possible to implement ACL policies by validating that the + // program is allowed to connect to the kafka broker. This also means that + // the resolver should probably provide a caching layer to avoid storming + // the service discovery backend with requests. + // + // When set, the Dial function is not responsible for performing name + // resolution, and is always called with a pre-resolved address. + Resolver BrokerResolver + + // The background context used to control goroutines started internally by + // the transport. + // + // If nil, context.Background() is used instead. + Context context.Context + + mutex sync.RWMutex + pools map[networkAddress]*connPool +} + +// DefaultTransport is the default transport used by kafka clients in this +// package. +var DefaultTransport RoundTripper = &Transport{ + Dial: (&net.Dialer{ + Timeout: 3 * time.Second, + DualStack: true, + }).DialContext, +} + +// CloseIdleConnections closes all idle connections immediately, and marks all +// connections that are in use to be closed when they become idle again. +func (t *Transport) CloseIdleConnections() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for _, pool := range t.pools { + pool.unref() + } + + for k := range t.pools { + delete(t.pools, k) + } +} + +// RoundTrip sends a request to a kafka cluster and returns the response, or an +// error if no responses were received. +// +// Message types are available in sub-packages of the protocol package. Each +// kafka API is implemented in a different sub-package. For example, the request +// and response types for the Fetch API are available in the protocol/fetch +// package. +// +// The type of the response message will match the type of the request. For +// exmple, if RoundTrip was called with a *fetch.Request as argument, the value +// returned will be of type *fetch.Response. It is safe for the program to do a +// type assertion after checking that no error was returned. +// +// This example illustrates the way this method is expected to be used: +// +// r, err := transport.RoundTrip(ctx, addr, &fetch.Request{ ... }) +// if err != nil { +// ... +// } else { +// res := r.(*fetch.Response) +// ... +// } +// +// The transport automatically selects the highest version of the API that is +// supported by both the kafka-go package and the kafka broker. The negotiation +// happens transparently once when connections are established. +// +// This API was introduced in version 0.4 as a way to leverage the lower-level +// features of the kafka protocol, but also provide a more efficient way of +// managing connections to kafka brokers. 
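// A minimal sketch of the usage described above, assuming a locally reachable
// broker at localhost:9092 and a topic named "topic-A" (both placeholders), and
// where metadata refers to the github.com/segmentio/kafka-go/protocol/metadata
// package. The transport is configured once and reused; per the connection pool
// logic further down, the metadata request is answered from the transport's
// cached cluster state.
//
//	transport := &kafka.Transport{
//		DialTimeout: 5 * time.Second,
//		IdleTimeout: 30 * time.Second,
//		MetadataTTL: 6 * time.Second,
//		ClientID:    "example-client",
//	}
//	defer transport.CloseIdleConnections()
//
//	r, err := transport.RoundTrip(context.Background(), kafka.TCP("localhost:9092"), &metadata.Request{
//		TopicNames: []string{"topic-A"},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	res := r.(*metadata.Response)
//	for _, t := range res.Topics {
//		log.Printf("topic %s: error code %d", t.Name, t.ErrorCode)
//	}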
+func (t *Transport) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) { + p := t.grabPool(addr) + defer p.unref() + return p.roundTrip(ctx, req) +} + +func (t *Transport) dial() func(context.Context, string, string) (net.Conn, error) { + if t.Dial != nil { + return t.Dial + } + return defaultDialer.DialContext +} + +func (t *Transport) dialTimeout() time.Duration { + if t.DialTimeout > 0 { + return t.DialTimeout + } + return 5 * time.Second +} + +func (t *Transport) idleTimeout() time.Duration { + if t.IdleTimeout > 0 { + return t.IdleTimeout + } + return 30 * time.Second +} + +func (t *Transport) metadataTTL() time.Duration { + if t.MetadataTTL > 0 { + return t.MetadataTTL + } + return 6 * time.Second +} + +func (t *Transport) grabPool(addr net.Addr) *connPool { + k := networkAddress{ + network: addr.Network(), + address: addr.String(), + } + + t.mutex.RLock() + p := t.pools[k] + if p != nil { + p.ref() + } + t.mutex.RUnlock() + + if p != nil { + return p + } + + t.mutex.Lock() + defer t.mutex.Unlock() + + if p := t.pools[k]; p != nil { + p.ref() + return p + } + + ctx, cancel := context.WithCancel(t.context()) + + p = &connPool{ + refc: 2, + + dial: t.dial(), + dialTimeout: t.dialTimeout(), + idleTimeout: t.idleTimeout(), + metadataTTL: t.metadataTTL(), + clientID: t.ClientID, + tls: t.TLS, + sasl: t.SASL, + resolver: t.Resolver, + + ready: make(event), + wake: make(chan event), + conns: make(map[int32]*connGroup), + cancel: cancel, + } + + p.ctrl = p.newConnGroup(addr) + go p.discover(ctx, p.wake) + + if t.pools == nil { + t.pools = make(map[networkAddress]*connPool) + } + t.pools[k] = p + return p +} + +func (t *Transport) context() context.Context { + if t.Context != nil { + return t.Context + } + return context.Background() +} + +type event chan struct{} + +func (e event) trigger() { close(e) } + +type connPool struct { + refc uintptr + // Immutable fields of the connection pool. Connections access these field + // on their parent pool in a ready-only fashion, so no synchronization is + // required. + dial func(context.Context, string, string) (net.Conn, error) + dialTimeout time.Duration + idleTimeout time.Duration + metadataTTL time.Duration + clientID string + tls *tls.Config + sasl sasl.Mechanism + resolver BrokerResolver + // Signaling mechanisms to orchestrate communications between the pool and + // the rest of the program. + once sync.Once // ensure that `ready` is triggered only once + ready event // triggered after the first metadata update + wake chan event // used to force metadata updates + cancel context.CancelFunc + // Mutable fields of the connection pool, access must be synchronized. + mutex sync.RWMutex + conns map[int32]*connGroup // data connections used for produce/fetch/etc... 
+ ctrl *connGroup // control connections used for metadata requests + state atomic.Value // cached cluster state +} + +type connPoolState struct { + metadata *meta.Response // last metadata response seen by the pool + err error // last error from metadata requests + layout protocol.Cluster // cluster layout built from metadata response +} + +func (p *connPool) grabState() connPoolState { + state, _ := p.state.Load().(connPoolState) + return state +} + +func (p *connPool) setState(state connPoolState) { + p.state.Store(state) +} + +func (p *connPool) ref() { + atomic.AddUintptr(&p.refc, +1) +} + +func (p *connPool) unref() { + if atomic.AddUintptr(&p.refc, ^uintptr(0)) == 0 { + p.mutex.Lock() + defer p.mutex.Unlock() + + for _, conns := range p.conns { + conns.closeIdleConns() + } + + p.ctrl.closeIdleConns() + p.cancel() + } +} + +func (p *connPool) roundTrip(ctx context.Context, req Request) (Response, error) { + // This first select should never block after the first metadata response + // that would mark the pool as `ready`. + select { + case <-p.ready: + case <-ctx.Done(): + return nil, ctx.Err() + } + + var expectTopics []string + defer func() { + if len(expectTopics) != 0 { + p.refreshMetadata(ctx, expectTopics) + } + }() + + var state = p.grabState() + var response promise + + switch m := req.(type) { + case *meta.Request: + // We serve metadata requests directly from the transport cache. + // + // This reduces the number of round trips to kafka brokers while keeping + // the logic simple when applying partitioning strategies. + if state.err != nil { + return nil, state.err + } + return filterMetadataResponse(m, state.metadata), nil + + case *createtopics.Request: + // Force an update of the metadata when adding topics, + // otherwise the cached state would get out of sync. + expectTopics = make([]string, len(m.Topics)) + for i := range m.Topics { + expectTopics[i] = m.Topics[i].Name + } + + case protocol.Splitter: + // Messages that implement the Splitter interface trigger the creation of + // multiple requests that are all merged back into a single results by + // a merger. + messages, merger, err := m.Split(state.layout) + if err != nil { + return nil, err + } + promises := make([]promise, len(messages)) + for i, m := range messages { + promises[i] = p.sendRequest(ctx, m, state) + } + response = join(promises, messages, merger) + } + + if response == nil { + response = p.sendRequest(ctx, req, state) + } + + return response.await(ctx) +} + +// refreshMetadata forces an update of the cached cluster metadata, and waits +// for the given list of topics to appear. This waiting mechanism is necessary +// to account for the fact that topic creation is asynchronous in kafka, and +// causes subsequent requests to fail while the cluster state is propagated to +// all the brokers. 
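// A minimal sketch of the topic-creation flow described above: a
// createtopics.Request issued through the same transport. The field names other
// than Name are assumptions based on the CreateTopics API, and the topic name is
// a placeholder. Because roundTrip defers a call to refreshMetadata (defined
// next), which waits for the new topic to appear in the cluster layout, requests
// sent right after RoundTrip returns should already see the topic.
//
//	_, err := transport.RoundTrip(context.Background(), kafka.TCP("localhost:9092"), &createtopics.Request{
//		Topics: []createtopics.RequestTopic{{
//			Name:              "topic-B",
//			NumPartitions:     3,
//			ReplicationFactor: 1,
//		}},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	// At this point the transport has refreshed its cached metadata and waited
//	// for "topic-B" to show up in the layout (or the context expired).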
+func (p *connPool) refreshMetadata(ctx context.Context, expectTopics []string) { + minBackoff := 100 * time.Millisecond + maxBackoff := 2 * time.Second + cancel := ctx.Done() + + for ctx.Err() == nil { + notify := make(event) + select { + case <-cancel: + return + case p.wake <- notify: + select { + case <-notify: + case <-cancel: + return + } + } + + state := p.grabState() + found := 0 + + for _, topic := range expectTopics { + if _, ok := state.layout.Topics[topic]; ok { + found++ + } + } + + if found == len(expectTopics) { + return + } + + if delay := time.Duration(rand.Int63n(int64(minBackoff))); delay > 0 { + timer := time.NewTimer(minBackoff) + select { + case <-cancel: + case <-timer.C: + } + timer.Stop() + + if minBackoff *= 2; minBackoff > maxBackoff { + minBackoff = maxBackoff + } + } + } +} + +func (p *connPool) setReady() { + p.once.Do(p.ready.trigger) +} + +// update is called periodically by the goroutine running the discover method +// to refresh the cluster layout information used by the transport to route +// requests to brokers. +func (p *connPool) update(ctx context.Context, metadata *meta.Response, err error) { + var layout protocol.Cluster + + if metadata != nil { + metadata.ThrottleTimeMs = 0 + + // Normalize the lists so we can apply binary search on them. + sortMetadataBrokers(metadata.Brokers) + sortMetadataTopics(metadata.Topics) + + for i := range metadata.Topics { + t := &metadata.Topics[i] + sortMetadataPartitions(t.Partitions) + } + + layout = makeLayout(metadata) + } + + state := p.grabState() + addBrokers := make(map[int32]struct{}) + delBrokers := make(map[int32]struct{}) + + if err != nil { + // Only update the error on the transport if the cluster layout was + // unknown. This ensures that we prioritize a previously known state + // of the cluster to reduce the impact of transient failures. + if state.metadata != nil { + return + } + state.err = err + } else { + for id, b2 := range layout.Brokers { + if b1, ok := state.layout.Brokers[id]; !ok { + addBrokers[id] = struct{}{} + } else if b1 != b2 { + addBrokers[id] = struct{}{} + delBrokers[id] = struct{}{} + } + } + + for id := range state.layout.Brokers { + if _, ok := layout.Brokers[id]; !ok { + delBrokers[id] = struct{}{} + } + } + + state.metadata, state.layout = metadata, layout + } + + defer p.setReady() + defer p.setState(state) + + if len(addBrokers) != 0 || len(delBrokers) != 0 { + // Only acquire the lock when there is a change of layout. This is an + // infrequent event so we don't risk introducing regular contention on + // the mutex if we were to lock it on every update. + p.mutex.Lock() + defer p.mutex.Unlock() + + if ctx.Err() != nil { + return // the pool has been closed, no need to update + } + + for id := range delBrokers { + if broker := p.conns[id]; broker != nil { + broker.closeIdleConns() + delete(p.conns, id) + } + } + + for id := range addBrokers { + broker := layout.Brokers[id] + p.conns[id] = p.newBrokerConnGroup(Broker{ + Rack: broker.Rack, + Host: broker.Host, + Port: int(broker.Port), + ID: int(broker.ID), + }) + } + } +} + +// discover is the entry point of an internal goroutine for the transport which +// periodically requests updates of the cluster metadata and refreshes the +// transport cached cluster layout. 
+func (p *connPool) discover(ctx context.Context, wake <-chan event) { + prng := rand.New(rand.NewSource(time.Now().UnixNano())) + metadataTTL := func() time.Duration { + return time.Duration(prng.Int63n(int64(p.metadataTTL))) + } + + timer := time.NewTimer(metadataTTL()) + defer timer.Stop() + + var notify event + var done = ctx.Done() + + for { + c, err := p.grabClusterConn(ctx) + if err != nil { + p.update(ctx, nil, err) + } else { + res := make(async, 1) + req := &meta.Request{ + IncludeClusterAuthorizedOperations: true, + IncludeTopicAuthorizedOperations: true, + } + deadline, cancel := context.WithTimeout(ctx, p.metadataTTL) + c.reqs <- connRequest{ + ctx: deadline, + req: req, + res: res, + } + r, err := res.await(deadline) + cancel() + if err != nil && err == ctx.Err() { + return + } + ret, _ := r.(*meta.Response) + p.update(ctx, ret, err) + } + + if notify != nil { + notify.trigger() + notify = nil + } + + select { + case <-timer.C: + timer.Reset(metadataTTL()) + case <-done: + return + case notify = <-wake: + } + } +} + +// grabBrokerConn returns a connection to a specific broker represented by the +// broker id passed as argument. If the broker id was not known, an error is +// returned. +func (p *connPool) grabBrokerConn(ctx context.Context, brokerID int32) (*conn, error) { + p.mutex.RLock() + g := p.conns[brokerID] + p.mutex.RUnlock() + if g == nil { + return nil, BrokerNotAvailable + } + return g.grabConnOrConnect(ctx) +} + +// grabClusterConn returns the connection to the kafka cluster that the pool is +// configured to connect to. +// +// The transport uses a shared `control` connection to the cluster for any +// requests that aren't supposed to be sent to specific brokers (e.g. Fetch or +// Produce requests). Requests intended to be routed to specific brokers are +// dispatched on a separate pool of connections that the transport maintains. +// This split help avoid head-of-line blocking situations where control requests +// like Metadata would be queued behind large responses from Fetch requests for +// example. +// +// In either cases, the requests are multiplexed so we can keep a minimal number +// of connections open (N+1, where N is the number of brokers in the cluster). +func (p *connPool) grabClusterConn(ctx context.Context) (*conn, error) { + return p.ctrl.grabConnOrConnect(ctx) +} + +func (p *connPool) sendRequest(ctx context.Context, req Request, state connPoolState) promise { + brokerID := int32(-1) + + switch m := req.(type) { + case protocol.BrokerMessage: + // Some requests are supposed to be sent to specific brokers (e.g. the + // partition leaders). They implement the BrokerMessage interface to + // delegate the routing decision to each message type. + broker, err := m.Broker(state.layout) + if err != nil { + return reject(err) + } + brokerID = broker.ID + + case protocol.GroupMessage: + // Some requests are supposed to be sent to a group coordinator, + // look up which broker is currently the coordinator for the group + // so we can get a connection to that broker. + // + // TODO: should we cache the coordinator info? 
+ p := p.sendRequest(ctx, &findcoordinator.Request{Key: m.Group()}, state) + r, err := p.await(ctx) + if err != nil { + return reject(err) + } + brokerID = r.(*findcoordinator.Response).NodeID + } + + var c *conn + var err error + if brokerID >= 0 { + c, err = p.grabBrokerConn(ctx, brokerID) + } else { + c, err = p.grabClusterConn(ctx) + } + if err != nil { + return reject(err) + } + + res := make(async, 1) + + c.reqs <- connRequest{ + ctx: ctx, + req: req, + res: res, + } + + return res +} + +func filterMetadataResponse(req *meta.Request, res *meta.Response) *meta.Response { + ret := *res + + if req.TopicNames != nil { + ret.Topics = make([]meta.ResponseTopic, len(req.TopicNames)) + + for i, topicName := range req.TopicNames { + j, ok := findMetadataTopic(res.Topics, topicName) + if ok { + ret.Topics[i] = res.Topics[j] + } else { + ret.Topics[i] = meta.ResponseTopic{ + ErrorCode: int16(UnknownTopicOrPartition), + Name: topicName, + } + } + } + } + + return &ret +} + +func findMetadataTopic(topics []meta.ResponseTopic, topicName string) (int, bool) { + i := sort.Search(len(topics), func(i int) bool { + return topics[i].Name >= topicName + }) + return i, i >= 0 && i < len(topics) && topics[i].Name == topicName +} + +func sortMetadataBrokers(brokers []meta.ResponseBroker) { + sort.Slice(brokers, func(i, j int) bool { + return brokers[i].NodeID < brokers[j].NodeID + }) +} + +func sortMetadataTopics(topics []meta.ResponseTopic) { + sort.Slice(topics, func(i, j int) bool { + return topics[i].Name < topics[j].Name + }) +} + +func sortMetadataPartitions(partitions []meta.ResponsePartition) { + sort.Slice(partitions, func(i, j int) bool { + return partitions[i].PartitionIndex < partitions[j].PartitionIndex + }) +} + +func makeLayout(metadataResponse *meta.Response) protocol.Cluster { + layout := protocol.Cluster{ + Controller: metadataResponse.ControllerID, + Brokers: make(map[int32]protocol.Broker), + Topics: make(map[string]protocol.Topic), + } + + for _, broker := range metadataResponse.Brokers { + layout.Brokers[broker.NodeID] = protocol.Broker{ + Rack: broker.Rack, + Host: broker.Host, + Port: broker.Port, + ID: broker.NodeID, + } + } + + for _, topic := range metadataResponse.Topics { + if topic.IsInternal { + continue // TODO: do we need to expose those? + } + layout.Topics[topic.Name] = protocol.Topic{ + Name: topic.Name, + Error: topic.ErrorCode, + Partitions: makePartitions(topic.Partitions), + } + } + + return layout +} + +func makePartitions(metadataPartitions []meta.ResponsePartition) map[int32]protocol.Partition { + protocolPartitions := make(map[int32]protocol.Partition, len(metadataPartitions)) + numBrokerIDs := 0 + + for _, p := range metadataPartitions { + numBrokerIDs += len(p.ReplicaNodes) + len(p.IsrNodes) + len(p.OfflineReplicas) + } + + // Reduce the memory footprint a bit by allocating a single buffer to write + // all broker ids. + brokerIDs := make([]int32, 0, numBrokerIDs) + + for _, p := range metadataPartitions { + var rep, isr, off []int32 + brokerIDs, rep = appendBrokerIDs(brokerIDs, p.ReplicaNodes) + brokerIDs, isr = appendBrokerIDs(brokerIDs, p.IsrNodes) + brokerIDs, off = appendBrokerIDs(brokerIDs, p.OfflineReplicas) + + protocolPartitions[p.PartitionIndex] = protocol.Partition{ + ID: p.PartitionIndex, + Error: p.ErrorCode, + Leader: p.LeaderID, + Replicas: rep, + ISR: isr, + Offline: off, + } + } + + return protocolPartitions +} + +func appendBrokerIDs(ids, brokers []int32) ([]int32, []int32) { + i := len(ids) + ids = append(ids, brokers...) 
+ return ids, ids[i:len(ids):len(ids)] +} + +func (p *connPool) newConnGroup(a net.Addr) *connGroup { + return &connGroup{ + addr: a, + pool: p, + broker: Broker{ + ID: -1, + }, + } +} + +func (p *connPool) newBrokerConnGroup(broker Broker) *connGroup { + return &connGroup{ + addr: &networkAddress{ + network: "tcp", + address: net.JoinHostPort(broker.Host, strconv.Itoa(broker.Port)), + }, + pool: p, + broker: broker, + } +} + +type connRequest struct { + ctx context.Context + req Request + res async +} + +// The promise interface is used as a message passing abstraction to coordinate +// between goroutines that handle requests and responses. +type promise interface { + // Waits until the promise is resolved, rejected, or the context canceled. + await(context.Context) (Response, error) +} + +// async is an implementation of the promise interface which supports resolving +// or rejecting the await call asynchronously. +type async chan interface{} + +func (p async) await(ctx context.Context) (Response, error) { + select { + case x := <-p: + switch v := x.(type) { + case nil: + return nil, nil // A nil response is ok (e.g. when RequiredAcks is None) + case Response: + return v, nil + case error: + return nil, v + default: + panic(fmt.Errorf("BUG: promise resolved with impossible value of type %T", v)) + } + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (p async) resolve(res Response) { p <- res } + +func (p async) reject(err error) { p <- err } + +// rejected is an implementation of the promise interface which is always +// returns an error. Values of this type are constructed using the reject +// function. +type rejected struct{ err error } + +func reject(err error) promise { return &rejected{err: err} } + +func (p *rejected) await(ctx context.Context) (Response, error) { + return nil, p.err +} + +// joined is an implementation of the promise interface which merges results +// from multiple promises into one await call using a merger. +type joined struct { + promises []promise + requests []Request + merger protocol.Merger +} + +func join(promises []promise, requests []Request, merger protocol.Merger) promise { + return &joined{ + promises: promises, + requests: requests, + merger: merger, + } +} + +func (p *joined) await(ctx context.Context) (Response, error) { + results := make([]interface{}, len(p.promises)) + + for i, sub := range p.promises { + m, err := sub.await(ctx) + if err != nil { + results[i] = err + } else { + results[i] = m + } + } + + return p.merger.Merge(p.requests, results) +} + +// Default dialer used by the transport connections when no Dial function +// was configured by the program. +var defaultDialer = net.Dialer{ + Timeout: 3 * time.Second, + DualStack: true, +} + +// connGroup represents a logical connection group to a kafka broker. The +// actual network connections are lazily open before sending requests, and +// closed if they are unused for longer than the idle timeout. +type connGroup struct { + addr net.Addr + broker Broker + // Immutable state of the connection. + pool *connPool + // Shared state of the connection, this is synchronized on the mutex through + // calls to the synchronized method. Both goroutines of the connection share + // the state maintained in these fields. 
+ mutex sync.Mutex + closed bool + idleConns []*conn // stack of idle connections +} + +func (g *connGroup) closeIdleConns() { + g.mutex.Lock() + conns := g.idleConns + g.idleConns = nil + g.closed = true + g.mutex.Unlock() + + for _, c := range conns { + c.close() + } +} + +func (g *connGroup) grabConnOrConnect(ctx context.Context) (*conn, error) { + var rslv = g.pool.resolver + var addr = g.addr + var c *conn + + if rslv == nil { + c = g.grabConn() + } else { + var err error + var broker = g.broker + + if broker.ID < 0 { + host, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil, fmt.Errorf("%s: %w", addr, err) + } + portNumber, err := strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("%s: %w", addr, err) + } + broker.Host = host + broker.Port = portNumber + } + + ipAddrs, err := rslv.LookupBrokerIPAddr(ctx, broker) + if err != nil { + return nil, err + } + + for _, ipAddr := range ipAddrs { + network := addr.Network() + address := net.JoinHostPort(ipAddr.String(), strconv.Itoa(broker.Port)) + + if c = g.grabConnTo(network, address); c != nil { + break + } + } + } + + if c == nil { + connChan := make(chan *conn) + errChan := make(chan error) + + go func() { + c, err := g.connect(ctx, addr) + if err != nil { + select { + case errChan <- err: + case <-ctx.Done(): + } + } else { + select { + case connChan <- c: + case <-ctx.Done(): + if !g.releaseConn(c) { + c.close() + } + } + } + }() + + select { + case c = <-connChan: + case err := <-errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + return c, nil +} + +func (g *connGroup) grabConnTo(network, address string) *conn { + g.mutex.Lock() + defer g.mutex.Unlock() + + for i := len(g.idleConns) - 1; i >= 0; i-- { + c := g.idleConns[i] + + if c.network == network && c.address == address { + copy(g.idleConns[i:], g.idleConns[i+1:]) + n := len(g.idleConns) - 1 + g.idleConns[n] = nil + g.idleConns = g.idleConns[:n] + + if c.timer != nil { + c.timer.Stop() + } + + return c + } + } + + return nil +} + +func (g *connGroup) grabConn() *conn { + g.mutex.Lock() + defer g.mutex.Unlock() + + if len(g.idleConns) == 0 { + return nil + } + + n := len(g.idleConns) - 1 + c := g.idleConns[n] + g.idleConns[n] = nil + g.idleConns = g.idleConns[:n] + + if c.timer != nil { + c.timer.Stop() + } + + return c +} + +func (g *connGroup) removeConn(c *conn) bool { + g.mutex.Lock() + defer g.mutex.Unlock() + + if c.timer != nil { + c.timer.Stop() + } + + for i, x := range g.idleConns { + if x == c { + copy(g.idleConns[i:], g.idleConns[i+1:]) + n := len(g.idleConns) - 1 + g.idleConns[n] = nil + g.idleConns = g.idleConns[:n] + return true + } + } + + return false +} + +func (g *connGroup) releaseConn(c *conn) bool { + idleTimeout := g.pool.idleTimeout + + g.mutex.Lock() + defer g.mutex.Unlock() + + if g.closed { + return false + } + + if c.timer != nil { + c.timer.Reset(idleTimeout) + } else { + c.timer = time.AfterFunc(idleTimeout, func() { + if g.removeConn(c) { + c.close() + } + }) + } + + g.idleConns = append(g.idleConns, c) + return true +} + +func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) { + deadline := time.Now().Add(g.pool.dialTimeout) + + ctx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + + var network = strings.Split(addr.Network(), ",") + var address = strings.Split(addr.String(), ",") + var netConn net.Conn + var netAddr net.Addr + var err error + + if len(address) > 1 { + // Shuffle the list of addresses to randomize the order in which + // 
connections are attempted. This prevents routing all connections + // to the first broker (which will usually succeed). + rand.Shuffle(len(address), func(i, j int) { + network[i], network[j] = network[j], network[i] + address[i], address[j] = address[j], address[i] + }) + } + + for i := range address { + netConn, err = g.pool.dial(ctx, network[i], address[i]) + if err == nil { + netAddr = &networkAddress{ + network: network[i], + address: address[i], + } + break + } + } + + if err != nil { + return nil, err + } + + defer func() { + if netConn != nil { + netConn.Close() + } + }() + + if tlsConfig := g.pool.tls; tlsConfig != nil { + if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { + host, _, _ := net.SplitHostPort(netAddr.String()) + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = host + } + netConn = tls.Client(netConn, tlsConfig) + } + + pc := protocol.NewConn(netConn, g.pool.clientID) + pc.SetDeadline(deadline) + + r, err := pc.RoundTrip(new(apiversions.Request)) + if err != nil { + return nil, err + } + res := r.(*apiversions.Response) + ver := make(map[protocol.ApiKey]int16, len(res.ApiKeys)) + + if res.ErrorCode != 0 { + return nil, fmt.Errorf("negotating API versions with kafka broker at %s: %w", g.addr, Error(res.ErrorCode)) + } + + for _, r := range res.ApiKeys { + apiKey := protocol.ApiKey(r.ApiKey) + ver[apiKey] = apiKey.SelectVersion(r.MinVersion, r.MaxVersion) + } + + pc.SetVersions(ver) + pc.SetDeadline(time.Time{}) + + if g.pool.sasl != nil { + if err := authenticateSASL(ctx, pc, g.pool.sasl); err != nil { + return nil, err + } + } + + reqs := make(chan connRequest) + c := &conn{ + network: netAddr.Network(), + address: netAddr.String(), + reqs: reqs, + group: g, + } + go c.run(pc, reqs) + + netConn = nil + return c, nil +} + +type conn struct { + reqs chan<- connRequest + network string + address string + once sync.Once + group *connGroup + timer *time.Timer +} + +func (c *conn) close() { + c.once.Do(func() { close(c.reqs) }) +} + +func (c *conn) run(pc *protocol.Conn, reqs <-chan connRequest) { + defer pc.Close() + + for cr := range reqs { + r, err := c.roundTrip(cr.ctx, pc, cr.req) + if err != nil { + cr.res.reject(err) + if !errors.Is(err, protocol.ErrNoRecord) { + break + } + } else { + cr.res.resolve(r) + } + if !c.group.releaseConn(c) { + break + } + } +} + +func (c *conn) roundTrip(ctx context.Context, pc *protocol.Conn, req Request) (Response, error) { + pprof.SetGoroutineLabels(ctx) + defer pprof.SetGoroutineLabels(context.Background()) + + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + pc.SetDeadline(deadline) + defer pc.SetDeadline(time.Time{}) + } + + return pc.RoundTrip(req) +} + +// authenticateSASL performs all of the required requests to authenticate this +// connection. If any step fails, this function returns with an error. A nil +// error indicates successful authentication. +func authenticateSASL(ctx context.Context, pc *protocol.Conn, mechanism sasl.Mechanism) error { + if err := saslHandshakeRoundTrip(pc, mechanism.Name()); err != nil { + return err + } + + sess, state, err := mechanism.Start(ctx) + if err != nil { + return err + } + + for completed := false; !completed; { + challenge, err := saslAuthenticateRoundTrip(pc, state) + switch err { + case nil: + case io.EOF: + // the broker may communicate a failed exchange by closing the + // connection (esp. in the case where we're passing opaque sasl + // data over the wire since there's no protocol info). 
+ return SASLAuthenticationFailed + default: + return err + } + + completed, state, err = sess.Next(ctx, challenge) + if err != nil { + return err + } + } + + return nil +} + +// saslHandshake sends the SASL handshake message. This will determine whether +// the Mechanism is supported by the cluster. If it's not, this function will +// error out with UnsupportedSASLMechanism. +// +// If the mechanism is unsupported, the handshake request will reply with the +// list of the cluster's configured mechanisms, which could potentially be used +// to facilitate negotiation. At the moment, we are not negotiating the +// mechanism as we believe that brokers are usually known to the client, and +// therefore the client should already know which mechanisms are supported. +// +// See http://kafka.apache.org/protocol.html#The_Messages_SaslHandshake +func saslHandshakeRoundTrip(pc *protocol.Conn, mechanism string) error { + msg, err := pc.RoundTrip(&saslhandshake.Request{ + Mechanism: mechanism, + }) + if err != nil { + return err + } + res := msg.(*saslhandshake.Response) + if res.ErrorCode != 0 { + err = Error(res.ErrorCode) + } + return err +} + +// saslAuthenticate sends the SASL authenticate message. This function must +// be immediately preceded by a successful saslHandshake. +// +// See http://kafka.apache.org/protocol.html#The_Messages_SaslAuthenticate +func saslAuthenticateRoundTrip(pc *protocol.Conn, data []byte) ([]byte, error) { + msg, err := pc.RoundTrip(&saslauthenticate.Request{ + AuthBytes: data, + }) + if err != nil { + return nil, err + } + res := msg.(*saslauthenticate.Response) + if res.ErrorCode != 0 { + err = makeError(res.ErrorCode, res.ErrorMessage) + } + return res.AuthBytes, err +} + +var ( + _ RoundTripper = (*Transport)(nil) +) diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 000000000..71806611a --- /dev/null +++ b/transport_test.go @@ -0,0 +1,34 @@ +package kafka + +import ( + "context" + "crypto/tls" + "net" + "testing" +) + +func TestIssue477(t *testing.T) { + // This test verifies that a connection attempt with a minimal TLS + // configuration does not panic. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + cg := connGroup{ + addr: l.Addr(), + pool: &connPool{ + dial: defaultDialer.DialContext, + tls: &tls.Config{}, + }, + } + + if _, err := cg.connect(context.Background(), cg.addr); err != nil { + // An error is expected here because we are not actually establishing + // a TLS connection to a kafka broker. 
+ t.Log(err) + } else { + t.Error("no error was reported when attempting to establish a TLS connection to a non-TLS endpoint") + } +} diff --git a/write_test.go b/write_test.go index 20f7b2563..80ffd9cbf 100644 --- a/write_test.go +++ b/write_test.go @@ -47,7 +47,6 @@ func TestWriteVarInt(t *testing.T) { } func TestWriteOptimizations(t *testing.T) { - t.Parallel() t.Run("writeFetchRequestV2", testWriteFetchRequestV2) t.Run("writeListOffsetRequestV1", testWriteListOffsetRequestV1) t.Run("writeProduceRequestV2", testWriteProduceRequestV2) @@ -193,24 +192,27 @@ func testWriteOptimization(t *testing.T, h requestHeader, r request, f func(*wri } func TestWriteV2RecordBatch(t *testing.T) { - if !ktesting.KafkaIsAtLeast("0.11.0") { t.Skip("RecordBatch was added in kafka 0.11.0") return } - topic := CreateTopic(t, 1) + client, topic, shutdown := newLocalClientAndTopic() + defer shutdown() + msgs := make([]Message, 15) for i := range msgs { value := fmt.Sprintf("Sample message content: %d!", i) msgs[i] = Message{Key: []byte("Key"), Value: []byte(value), Headers: []Header{Header{Key: "hk", Value: []byte("hv")}}} } - w := NewWriter(WriterConfig{ - Brokers: []string{"localhost:9092"}, + + w := &Writer{ + Addr: TCP("localhost:9092"), Topic: topic, BatchTimeout: 100 * time.Millisecond, BatchSize: 5, - }) + Transport: client.Transport, + } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() diff --git a/writer.go b/writer.go index 6d618c0a1..0f9412b42 100644 --- a/writer.go +++ b/writer.go @@ -1,37 +1,216 @@ package kafka import ( + "bytes" "context" "errors" - "fmt" "io" - "math/rand" - "sort" + "net" "sync" + "sync/atomic" "time" + + metadataAPI "github.com/segmentio/kafka-go/protocol/metadata" ) // The Writer type provides the implementation of a producer of kafka messages // that automatically distributes messages across partitions of a single topic // using a configurable balancing policy. // -// Instances of Writer are safe to use concurrently from multiple goroutines. +// Writers manage the dispatch of messages across partitions of the topic they +// are configured to write to using a Balancer, and aggregate batches to +// optimize the writes to kafka. +// +// Writers may be configured to be used synchronously or asynchronously. When +// used synchronously, calls to WriteMessages block until the messages have been +// written to kafka. In this mode, the program should inspect the error returned +// by the function and test if it is an instance of kafka.WriteErrors in order to +// identify which messages have succeeded or failed, for example: +// +// // Construct a synchronous writer (the default mode). +// w := &kafka.Writer{ +// Addr: kafka.TCP("localhost:9092"), +// Topic: "topic-A", +// RequiredAcks: kafka.RequireAll, +// } +// +// ... +// +// // Passing a context can prevent the operation from blocking indefinitely. +// switch err := w.WriteMessages(ctx, msgs...).(type) { +// case nil: +// case kafka.WriteErrors: +// for i := range msgs { +// if err[i] != nil { +// // handle the error writing msgs[i] +// ... +// } +// } +// default: +// // handle other errors +// ... +// } +// +// In asynchronous mode, the program may configure a completion handler on the +// writer to receive notifications of messages being written to kafka: +// +// w := &kafka.Writer{ +// Addr: kafka.TCP("localhost:9092"), +// Topic: "topic-A", +// RequiredAcks: kafka.RequireAll, +// Async: true, // make the writer asynchronous +// Completion: func(messages []kafka.Message, err error) { +// ...
+// }, +// } +// +// ... +// +// // Because the writer is asynchronous, there is no need for the context to +// // be cancelled, the call will never block. +// if err := w.WriteMessages(context.Background(), msgs...); err != nil { +// // Only validation errors would be reported in this case. +// ... +// } +// +// Methods of Writer are safe to use concurrently from multiple goroutines, +// however the writer configuration should not be modified after first use. type Writer struct { - config WriterConfig + // Address of the kafka cluster that this writer is configured to send + // messages to. + // + // This feild is required, attempting to write messages to a writer with a + // nil address will error. + Addr net.Addr + + // Topic is the name of the topic that the writer will produce messages to. + // + // Setting this field or not is a mutually exclusive option. If you set Topic + // here, you must not set Topic for any produced Message. Otherwise, if you do + // not set Topic, every Message must have Topic specified. + Topic string + + // The balancer used to distribute messages across partitions. + // + // The default is to use a round-robin distribution. + Balancer Balancer + + // Limit on how many attempts will be made to deliver a message. + // + // The default is to try at most 10 times. + MaxAttempts int + + // Limit on how many messages will be buffered before being sent to a + // partition. + // + // The default is to use a target batch size of 100 messages. + BatchSize int + + // Limit the maximum size of a request in bytes before being sent to + // a partition. + // + // The default is to use a kafka default value of 1048576. + BatchBytes int64 - mutex sync.RWMutex - closed bool + // Time limit on how often incomplete message batches will be flushed to + // kafka. + // + // The default is to flush at least every second. + BatchTimeout time.Duration + + // Timeout for read operations performed by the Writer. + // + // Defaults to 10 seconds. + ReadTimeout time.Duration - join sync.WaitGroup - msgs chan writerMessage - done chan struct{} + // Timeout for write operation performed by the Writer. + // + // Defaults to 10 seconds. + WriteTimeout time.Duration + + // Number of acknowledges from partition replicas required before receiving + // a response to a produce request, the following values are supported: + // + // RequireNone (0) fire-and-forget, do not wait for acknowledgements from the + // RequireOne (1) wait for the leader to acknowledge the writes + // RequireAll (-1) wait for the full ISR to acknowledge the writes + // + RequiredAcks RequiredAcks + + // Setting this flag to true causes the WriteMessages method to never block. + // It also means that errors are ignored since the caller will not receive + // the returned value. Use this only if you don't care about guarantees of + // whether the messages were written to kafka. + Async bool + + // An optional function called when the writer succeeds or fails the + // delivery of messages to a kafka partition. When writing the messages + // fails, the `err` parameter will be non-nil. + // + // The messages that the Completion function is called with have their + // topic, partition, offset, and time set based on the Produce responses + // received from kafka. All messages passed to a call to the function have + // been written to the same partition. The keys and values of messages are + // referencing the original byte slices carried by messages in the calls to + // WriteMessages. 
+ // + // The function is called from goroutines started by the writer. Calls to + // Close will block on the Completion function calls. When the Writer is + // not writing asynchronously, the WriteMessages call will also block on + // Completion function, which is a useful guarantee if the byte slices + // for the message keys and values are intended to be reused after the + // WriteMessages call returned. + // + // If a completion function panics, the program terminates because the + // panic is not recovered by the writer and bubbles up to the top of the + // goroutine's call stack. + Completion func(messages []Message, err error) + + // Compression set the compression codec to be used to compress messages. + Compression Compression + + // If not nil, specifies a logger used to report internal changes within the + // writer. + Logger Logger + + // ErrorLogger is the logger used to report errors. If nil, the writer falls + // back to using Logger instead. + ErrorLogger Logger + + // A transport used to send messages to kafka clusters. + // + // If nil, DefaultTransport is used. + Transport RoundTripper + + // Atomic flag indicating whether the writer has been closed. + closed uint32 + group sync.WaitGroup + + // Manages the current batch being aggregated on the writer. + mutex sync.Mutex + batches map[topicPartition]*writeBatch // writer stats are all made of atomic values, no need for synchronization. - // Use a pointer to ensure 64-bit alignment of the values. - stats *writerStats + // Use a pointer to ensure 64-bit alignment of the values. The once value is + // used to lazily create the value when first used, allowing programs to use + // the zero-value value of Writer. + once sync.Once + *writerStats + + // If no balancer is configured, the writer uses this one. RoundRobin values + // are safe to use concurrently from multiple goroutines, there is no need + // for extra synchronization to access this field. + roundRobin RoundRobin + + // non-nil when a transport was created by NewWriter, remove in 1.0. + transport *Transport } // WriterConfig is a configuration type used to create new instances of Writer. +// +// DEPRECATED: writer values should be configured directly by assigning their +// exported fields. This type is kept for backward compatibility, and will be +// removed in version 1.0. type WriterConfig struct { // The list of brokers used to discover the partitions available on the // kafka cluster. @@ -42,8 +221,9 @@ type WriterConfig struct { // The topic that the writer will produce messages to. // - // This field is required, attempting to create a writer with an empty topic - // will panic. + // If provided, this will be used to set the topic for all produced messages. + // If not provided, each Message must specify a topic for itself. This must be + // mutually exclusive, otherwise the Writer will return an error. Topic string // The dialer used by the writer to establish connections to the kafka @@ -62,9 +242,10 @@ type WriterConfig struct { // The default is to try at most 10 times. MaxAttempts int - // A hint on the capacity of the writer's internal message queue. - // - // The default is to use a queue capacity of 100 messages. + // DEPRECATED: in versions prior to 0.4, the writer used channels internally + // to dispatch messages to partitions. This has been replaced by an in-memory + // aggregation of batches which uses shared state instead of message passing, + // making this option unnecessary. 
QueueCapacity int // Limit on how many messages will be buffered before being sent to a @@ -95,16 +276,15 @@ type WriterConfig struct { // Defaults to 10 seconds. WriteTimeout time.Duration - // This interval defines how often the list of partitions is refreshed from - // kafka. It allows the writer to automatically handle when new partitions - // are added to a topic. - // - // The default is to refresh partitions every 15 seconds. + // DEPRECATED: in versions prior to 0.4, the writer used to maintain a cache + // of the topic layout. With the change to use a transport to manage connections, + // the responsibility of syncing the cluster layout has been delegated to the + // transport. RebalanceInterval time.Duration - // Connections that were idle for this duration will not be reused. - // - // Defaults to 9 minutes. + // DEPRECATED: in versions prior to 0.4, the writer used to manage connections + // to the kafka cluster directly. With the change to use a transport to manage + // connections, the writer has no connections to manage directly anymore. IdleConnTimeout time.Duration // Number of acknowledges from partition replicas required before receiving @@ -124,7 +304,6 @@ type WriterConfig struct { Async bool // CompressionCodec set the codec to be used to compress Kafka messages. - // Note that messages are allowed to overwrite the compression codec individually. CompressionCodec // If not nil, specifies a logger used to report internal changes within the @@ -134,40 +313,62 @@ type WriterConfig struct { // ErrorLogger is the logger used to report errors. If nil, the writer falls // back to using Logger instead. ErrorLogger Logger +} - newPartitionWriter func(partition int, config WriterConfig, stats *writerStats) partitionWriter +type topicPartition struct { + topic string + partition int32 +} + +// Validate method validates WriterConfig properties. +func (config *WriterConfig) Validate() error { + if len(config.Brokers) == 0 { + return errors.New("cannot create a kafka writer with an empty list of brokers") + } + return nil } // WriterStats is a data structure returned by a call to Writer.Stats that // exposes details about the behavior of the writer.
type WriterStats struct { - Dials int64 `metric:"kafka.writer.dial.count" type:"counter"` - Writes int64 `metric:"kafka.writer.write.count" type:"counter"` - Messages int64 `metric:"kafka.writer.message.count" type:"counter"` - Bytes int64 `metric:"kafka.writer.message.bytes" type:"counter"` - Rebalances int64 `metric:"kafka.writer.rebalance.count" type:"counter"` - Errors int64 `metric:"kafka.writer.error.count" type:"counter"` - - DialTime DurationStats `metric:"kafka.writer.dial.seconds"` + Writes int64 `metric:"kafka.writer.write.count" type:"counter"` + Messages int64 `metric:"kafka.writer.message.count" type:"counter"` + Bytes int64 `metric:"kafka.writer.message.bytes" type:"counter"` + Errors int64 `metric:"kafka.writer.error.count" type:"counter"` + + BatchTime DurationStats `metric:"kafka.writer.batch.seconds"` WriteTime DurationStats `metric:"kafka.writer.write.seconds"` WaitTime DurationStats `metric:"kafka.writer.wait.seconds"` Retries SummaryStats `metric:"kafka.writer.retries.count"` BatchSize SummaryStats `metric:"kafka.writer.batch.size"` BatchBytes SummaryStats `metric:"kafka.writer.batch.bytes"` - MaxAttempts int64 `metric:"kafka.writer.attempts.max" type:"gauge"` - MaxBatchSize int64 `metric:"kafka.writer.batch.max" type:"gauge"` - BatchTimeout time.Duration `metric:"kafka.writer.batch.timeout" type:"gauge"` - ReadTimeout time.Duration `metric:"kafka.writer.read.timeout" type:"gauge"` - WriteTimeout time.Duration `metric:"kafka.writer.write.timeout" type:"gauge"` - RebalanceInterval time.Duration `metric:"kafka.writer.rebalance.interval" type:"gauge"` - RequiredAcks int64 `metric:"kafka.writer.acks.required" type:"gauge"` - Async bool `metric:"kafka.writer.async" type:"gauge"` - QueueLength int64 `metric:"kafka.writer.queue.length" type:"gauge"` - QueueCapacity int64 `metric:"kafka.writer.queue.capacity" type:"gauge"` + MaxAttempts int64 `metric:"kafka.writer.attempts.max" type:"gauge"` + MaxBatchSize int64 `metric:"kafka.writer.batch.max" type:"gauge"` + BatchTimeout time.Duration `metric:"kafka.writer.batch.timeout" type:"gauge"` + ReadTimeout time.Duration `metric:"kafka.writer.read.timeout" type:"gauge"` + WriteTimeout time.Duration `metric:"kafka.writer.write.timeout" type:"gauge"` + RequiredAcks int64 `metric:"kafka.writer.acks.required" type:"gauge"` + Async bool `metric:"kafka.writer.async" type:"gauge"` + + Topic string `tag:"topic"` - ClientID string `tag:"client_id"` - Topic string `tag:"topic"` + // DEPRECATED: these fields will only be reported for backward compatibility + // if the Writer was constructed with NewWriter. + Dials int64 `metric:"kafka.writer.dial.count" type:"counter"` + DialTime DurationStats `metric:"kafka.writer.dial.seconds"` + + // DEPRECATED: these fields were meaningful prior to kafka-go 0.4, changes + // to the internal implementation and the introduction of the transport type + // made them unnecessary. + // + // The values will be zero but are left for backward compatibility to avoid + // breaking programs that used these fields. + Rebalances int64 + RebalanceInterval time.Duration + QueueLength int64 + QueueCapacity int64 + ClientID string } // writerStats is a struct that contains statistics on a writer. 
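// A minimal sketch of reading these metrics from a live writer, assuming the
// Stats method carried over from earlier releases of the package (it is not part
// of this hunk) and a writer w that has already produced some messages:
//
//	stats := w.Stats()
//	log.Printf("writes=%d messages=%d bytes=%d errors=%d",
//		stats.Writes, stats.Messages, stats.Bytes, stats.Errors)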
@@ -180,9 +381,9 @@ type writerStats struct { writes counter messages counter bytes counter - rebalances counter errors counter dialTime summary + batchTime summary writeTime summary waitTime summary retries summary @@ -190,23 +391,12 @@ type writerStats struct { batchSizeBytes summary } -// Validate method validates WriterConfig properties. -func (config *WriterConfig) Validate() error { - - if len(config.Brokers) == 0 { - return errors.New("cannot create a kafka writer with an empty list of brokers") - } - - if len(config.Topic) == 0 { - return errors.New("cannot create a kafka writer with an empty topic") - } - - return nil -} - // NewWriter creates and returns a new Writer configured with config. +// +// DEPRECATED: Writer value can be instantiated and configured directly, +// this function is retained for backward compatibility and will be removed +// in version 1.0. func NewWriter(config WriterConfig) *Writer { - if err := config.Validate(); err != nil { panic(err) } @@ -219,63 +409,124 @@ func NewWriter(config WriterConfig) *Writer { config.Balancer = &RoundRobin{} } - if config.newPartitionWriter == nil { - config.newPartitionWriter = func(partition int, config WriterConfig, stats *writerStats) partitionWriter { - return newWriter(partition, config, stats) - } + // Converts the pre-0.4 Dialer API into a Transport. + kafkaDialer := DefaultDialer + if config.Dialer != nil { + kafkaDialer = config.Dialer } - if config.MaxAttempts == 0 { - config.MaxAttempts = 10 + dialer := (&net.Dialer{ + Timeout: kafkaDialer.Timeout, + Deadline: kafkaDialer.Deadline, + LocalAddr: kafkaDialer.LocalAddr, + DualStack: kafkaDialer.DualStack, + FallbackDelay: kafkaDialer.FallbackDelay, + KeepAlive: kafkaDialer.KeepAlive, + }) + + var resolver Resolver + if r, ok := kafkaDialer.Resolver.(*net.Resolver); ok { + dialer.Resolver = r + } else { + resolver = kafkaDialer.Resolver } - if config.QueueCapacity == 0 { - config.QueueCapacity = 100 + stats := new(writerStats) + // For backward compatibility with the pre-0.4 APIs, support custom + // resolvers by wrapping the dial function. + dial := func(ctx context.Context, network, addr string) (net.Conn, error) { + start := time.Now() + defer func() { + stats.dials.observe(1) + stats.dialTime.observe(int64(time.Since(start))) + }() + address, err := lookupHost(ctx, addr, resolver) + if err != nil { + return nil, err + } + return dialer.DialContext(ctx, network, address) } - if config.BatchSize == 0 { - config.BatchSize = 100 + idleTimeout := config.IdleConnTimeout + if idleTimeout == 0 { + // Historical default value of WriterConfig.IdleTimeout, 9 minutes seems + // like it is way too long when there is no ping mechanism in the kafka + // protocol. + idleTimeout = 9 * time.Minute } - if config.BatchBytes == 0 { - // 1048576 == 1MB which is the Kafka default. - config.BatchBytes = 1048576 + metadataTTL := config.RebalanceInterval + if metadataTTL == 0 { + // Historical default value of WriterConfig.RebalanceInterval. 
+ metadataTTL = 15 * time.Second } - if config.BatchTimeout == 0 { - config.BatchTimeout = 1 * time.Second + transport := &Transport{ + Dial: dial, + SASL: kafkaDialer.SASLMechanism, + TLS: kafkaDialer.TLS, + ClientID: kafkaDialer.ClientID, + IdleTimeout: idleTimeout, + MetadataTTL: metadataTTL, } - if config.ReadTimeout == 0 { - config.ReadTimeout = 10 * time.Second + w := &Writer{ + Addr: TCP(config.Brokers...), + Topic: config.Topic, + MaxAttempts: config.MaxAttempts, + BatchSize: config.BatchSize, + Balancer: config.Balancer, + BatchBytes: int64(config.BatchBytes), + BatchTimeout: config.BatchTimeout, + ReadTimeout: config.ReadTimeout, + WriteTimeout: config.WriteTimeout, + RequiredAcks: RequiredAcks(config.RequiredAcks), + Async: config.Async, + Logger: config.Logger, + ErrorLogger: config.ErrorLogger, + Transport: transport, + transport: transport, + writerStats: stats, } - if config.WriteTimeout == 0 { - config.WriteTimeout = 10 * time.Second + if config.RequiredAcks == 0 { + // Historically the writers created by NewWriter have used "all" as the + // default value when 0 was specified. + w.RequiredAcks = RequireAll } - if config.RebalanceInterval == 0 { - config.RebalanceInterval = 15 * time.Second + if config.CompressionCodec != nil { + w.Compression = Compression(config.CompressionCodec.Code()) } - if config.IdleConnTimeout == 0 { - config.IdleConnTimeout = 9 * time.Minute + + return w +} + +// Close flushes pending writes, and waits for all writes to complete before +// returning. Calling Close also prevents new writes from being submitted to +// the writer; further calls to WriteMessages and the like will fail with +// io.ErrClosedPipe. +func (w *Writer) Close() error { + w.markClosed() + // If batches are pending, trigger them so messages get sent. + w.mutex.Lock() + + for _, batch := range w.batches { + batch.trigger() } - w := &Writer{ - config: config, - msgs: make(chan writerMessage, config.QueueCapacity), - done: make(chan struct{}), - stats: &writerStats{ - dialTime: makeSummary(), - writeTime: makeSummary(), - waitTime: makeSummary(), - retries: makeSummary(), - }, + for partition := range w.batches { + delete(w.batches, partition) } - w.join.Add(1) - go w.run() - return w + w.mutex.Unlock() + w.group.Wait() + + if w.transport != nil { + w.transport.CloseIdleConnections() + } + + return nil } // WriteMessages writes a batch of messages to the kafka topic configured on this @@ -294,8 +545,8 @@ func NewWriter(config WriterConfig) *Writer { // best way to achieve good batching behavior is to share one Writer amongst // multiple go routines. // -// When the method returns an error, there's no way to know yet which messages -// have succeeded of failed. +// When the method returns an error, it may be of type kafka.WriteErrors to allow +// the caller to determine the status of each message. // // The context passed as first argument may also be used to asynchronously // cancel the operation. Note that in this case there are no guarantees made on @@ -303,536 +554,545 @@ func NewWriter(config WriterConfig) *Writer { // whole batch failed and re-write the messages later (which could then cause // duplicates).
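Before the implementation below, a hedged usage sketch of the 0.4 Writer from the caller's side: the broker address and topic are placeholders, direct construction replaces the deprecated NewWriter path, and the error handling assumes the WriteErrors slice type returned when only part of a batch fails:

package main

import (
    "context"
    "log"

    kafka "github.com/segmentio/kafka-go"
)

func main() {
    // Direct construction replaces NewWriter; zero values of the optional
    // fields fall back to the defaults applied by the Writer's methods.
    w := &kafka.Writer{
        Addr:         kafka.TCP("localhost:9092"), // placeholder broker
        Topic:        "events",                    // placeholder topic
        Balancer:     &kafka.RoundRobin{},
        RequiredAcks: kafka.RequireAll,
    }
    defer w.Close()

    err := w.WriteMessages(context.Background(),
        kafka.Message{Value: []byte("one")},
        kafka.Message{Value: []byte("two")},
    )

    switch e := err.(type) {
    case nil:
        // All messages were written.
    case kafka.WriteErrors:
        // One entry per input message; nil entries were written successfully.
        for i, werr := range e {
            if werr != nil {
                log.Printf("message %d failed: %v", i, werr)
            }
        }
    default:
        log.Fatal(err)
    }
}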
func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error { + if w.Addr == nil { + return errors.New("kafka.(*Writer).WriteMessages: cannot create a kafka writer with a nil address") + } + + w.group.Add(1) + defer w.group.Done() + + if w.isClosed() { + return io.ErrClosedPipe + } + if len(msgs) == 0 { return nil } - var err error - var res chan error - if !w.config.Async { - res = make(chan error, len(msgs)) + balancer := w.balancer() + batchBytes := w.batchBytes() + + for i := range msgs { + n := int64(msgs[i].size()) + if n > batchBytes { + // This error is left for backward compatibility with historical + // behavior, but it can yield O(N^2) behaviors. The expectations + // are that the program will check if WriteMessages returned a + // MessageTooLargeError, discard the message that was exceeding + // the maximum size, and try again. + return messageTooLarge(msgs, i) + } } - t0 := time.Now() - for attempt := 0; attempt < w.config.MaxAttempts; attempt++ { - w.mutex.RLock() + // We use int32 here to half the memory footprint (compared to using int + // on 64 bits architectures). We map lists of the message indexes instead + // of the message values for the same reason, int32 is 4 bytes, vs a full + // Message value which is 100+ bytes and contains pointers and contributes + // to increasing GC work. + assignments := make(map[topicPartition][]int32) - if w.closed { - w.mutex.RUnlock() - return io.ErrClosedPipe + for i, msg := range msgs { + topic, err := w.chooseTopic(msg) + if err != nil { + return err } - for i, msg := range msgs { - if int(msg.size()) > w.config.BatchBytes { - err := MessageTooLargeError{ - Message: msg, - Remaining: msgs[i+1:], - } - w.mutex.RUnlock() - return err - } - select { - case w.msgs <- writerMessage{ - msg: msg, - res: res, - }: - case <-ctx.Done(): - w.mutex.RUnlock() - return ctx.Err() - } + numPartitions, err := w.partitions(ctx, topic) + if err != nil { + return err } - w.mutex.RUnlock() + partition := balancer.Balance(msg, loadCachedPartitions(numPartitions)...) - if w.config.Async { - break + key := topicPartition{ + topic: topic, + partition: int32(partition), } - var retry []Message - - for i := 0; i != len(msgs); i++ { - select { - case e := <-res: - if e != nil { - if we, ok := e.(*writerError); ok { - w.stats.retries.observe(1) - retry, err = append(retry, we.msg), we.err - } else { - err = e - } - } - case <-ctx.Done(): - return ctx.Err() - } - } + assignments[key] = append(assignments[key], int32(i)) + } - if msgs = retry; len(msgs) == 0 { - break - } + batches := w.batchMessages(msgs, assignments) + if w.Async { + return nil + } - timer := time.NewTimer(backoff(attempt+1, 100*time.Millisecond, 1*time.Second)) + done := ctx.Done() + hasErrors := false + for batch := range batches { select { - case <-timer.C: - // Only clear the error (so we retry the loop) if we have more retries, otherwise - // we risk silencing the error. 
- if attempt < w.config.MaxAttempts-1 { - err = nil + case <-done: + return ctx.Err() + case <-batch.done: + if batch.err != nil { + hasErrors = true } - case <-ctx.Done(): - err = ctx.Err() - case <-w.done: - err = io.ErrClosedPipe } - timer.Stop() + } - if err != nil { - break + if !hasErrors { + return nil + } + + werr := make(WriteErrors, len(msgs)) + + for batch, indexes := range batches { + for _, i := range indexes { + werr[i] = batch.err } } - w.stats.writeTime.observeDuration(time.Since(t0)) - return err + return werr } -// Stats returns a snapshot of the writer stats since the last time the method -// was called, or since the writer was created if it is called for the first -// time. -// -// A typical use of this method is to spawn a goroutine that will periodically -// call Stats on a kafka writer and report the metrics to a stats collection -// system. -func (w *Writer) Stats() WriterStats { - return WriterStats{ - Dials: w.stats.dials.snapshot(), - Writes: w.stats.writes.snapshot(), - Messages: w.stats.messages.snapshot(), - Bytes: w.stats.bytes.snapshot(), - Rebalances: w.stats.rebalances.snapshot(), - Errors: w.stats.errors.snapshot(), - DialTime: w.stats.dialTime.snapshotDuration(), - WriteTime: w.stats.writeTime.snapshotDuration(), - WaitTime: w.stats.waitTime.snapshotDuration(), - Retries: w.stats.retries.snapshot(), - BatchSize: w.stats.batchSize.snapshot(), - BatchBytes: w.stats.batchSizeBytes.snapshot(), - MaxAttempts: int64(w.config.MaxAttempts), - MaxBatchSize: int64(w.config.BatchSize), - BatchTimeout: w.config.BatchTimeout, - ReadTimeout: w.config.ReadTimeout, - WriteTimeout: w.config.WriteTimeout, - RebalanceInterval: w.config.RebalanceInterval, - RequiredAcks: int64(w.config.RequiredAcks), - Async: w.config.Async, - QueueLength: int64(len(w.msgs)), - QueueCapacity: int64(cap(w.msgs)), - ClientID: w.config.Dialer.ClientID, - Topic: w.config.Topic, - } -} - -// Close flushes all buffered messages and closes the writer. The call to Close -// aborts any concurrent calls to WriteMessages, which then return with the -// io.ErrClosedPipe error. 
-func (w *Writer) Close() (err error) { +func (w *Writer) batchMessages(messages []Message, assignments map[topicPartition][]int32) map[*writeBatch][]int32 { + var batches map[*writeBatch][]int32 + if !w.Async { + batches = make(map[*writeBatch][]int32, len(assignments)) + } + + batchSize := w.batchSize() + batchBytes := w.batchBytes() + w.mutex.Lock() + defer w.mutex.Unlock() - if !w.closed { - w.closed = true - close(w.msgs) - close(w.done) + if w.batches == nil { + w.batches = map[topicPartition]*writeBatch{} } - w.mutex.Unlock() - w.join.Wait() - return + for key, indexes := range assignments { + for _, i := range indexes { + assignMessage: + batch := w.batches[key] + if batch == nil { + batch = w.newWriteBatch(key) + w.batches[key] = batch + } + if !batch.add(messages[i], batchSize, batchBytes) { + batch.trigger() + delete(w.batches, key) + goto assignMessage + } + if batch.full(batchSize, batchBytes) { + batch.trigger() + delete(w.batches, key) + } + if !w.Async { + batches[batch] = append(batches[batch], i) + } + } + } + + return batches +} + +func (w *Writer) newWriteBatch(key topicPartition) *writeBatch { + batch := newWriteBatch(time.Now(), w.batchTimeout()) + w.group.Add(1) + go func() { + defer w.group.Done() + w.writeBatch(key, batch) + }() + return batch } -func (w *Writer) run() { - defer w.join.Done() +func (w *Writer) writeBatch(key topicPartition, batch *writeBatch) { + // This goroutine has taken ownership of the batch, it is responsible + // for waiting for the batch to be ready (because it became full), or + // to timeout. + select { + case <-batch.timer.C: + // The batch timed out, we want to detach it from the writer to + // prevent more messages from being added. + w.mutex.Lock() + if batch == w.batches[key] { + delete(w.batches, key) + } + w.mutex.Unlock() - ticker := time.NewTicker(w.config.RebalanceInterval) - defer ticker.Stop() + case <-batch.ready: + // The batch became full, it was removed from the writer and its + // ready channel was closed. We need to close the timer to avoid + // having it leak until it expires. + batch.timer.Stop() + } - var rebalance = true - var writers = make(map[int]partitionWriter) - var partitions []int + stats := w.stats() + stats.batchTime.observe(int64(time.Since(batch.time))) + stats.batchSize.observe(int64(len(batch.msgs))) + stats.batchSizeBytes.observe(batch.bytes) + + var res *ProduceResponse var err error - for { - if rebalance { - w.stats.rebalances.observe(1) - rebalance = false + for attempt, maxAttempts := 0, w.maxAttempts(); attempt < maxAttempts; attempt++ { + if attempt != 0 { + stats.retries.observe(1) + // TODO: should there be a way to asynchronously cancel this + // operation? + // + // * If all goroutines that added message to this batch have stopped + // waiting for it, should we abort? + // + // * If the writer has been closed? It reduces the durability + // guarantees to abort, but may be better to avoid long wait times + // on close. 
+ // + delay := backoff(attempt, 100*time.Millisecond, 1*time.Second) + w.withLogger(func(log Logger) { + log.Printf("backing off %s writing %d messages to %s (partition: %d)", delay, len(batch.msgs), key.topic, key.partition) + }) + time.Sleep(delay) + } + + w.withLogger(func(log Logger) { + log.Printf("writing %d messages to %s (partition: %d)", len(batch.msgs), key.topic, key.partition) + }) - var newPartitions []int - var oldPartitions = partitions + start := time.Now() + res, err = w.produce(key, batch) + + stats.writes.observe(1) + stats.messages.observe(int64(len(batch.msgs))) + stats.bytes.observe(batch.bytes) + // stats.writeTime used to report the duration of WriteMessages, but the + // implementation was broken and reporting values in the nanoseconds + // range. In kafka-go 0.4, we recycled this value to instead report the + // duration of produce requests, and changed the stats.waitTime value to + // report the time that kafka has throttled the requests for. + stats.writeTime.observe(int64(time.Since(start))) + + if res != nil { + err = res.Error + stats.waitTime.observe(int64(res.Throttle)) + } - if newPartitions, err = w.partitions(); err == nil { - for _, partition := range diffp(oldPartitions, newPartitions) { - w.close(writers[partition]) - delete(writers, partition) - } + if err == nil { + break + } - for _, partition := range diffp(newPartitions, oldPartitions) { - writers[partition] = w.open(partition) - } + stats.errors.observe(1) - partitions = newPartitions - } + w.withErrorLogger(func(log Logger) { + log.Printf("error writing messages to %s (partition %d): %s", key.topic, key.partition, err) + }) + + if !isTemporary(err) { + break + } } - select { - case wm, ok := <-w.msgs: - if !ok { - for _, writer := range writers { - w.close(writer) - } - return - } + if res != nil { + for i := range batch.msgs { + m := &batch.msgs[i] + m.Topic = key.topic + m.Partition = int(key.partition) + m.Offset = res.BaseOffset + int64(i) - if len(partitions) != 0 { - selectedPartition := w.config.Balancer.Balance(wm.msg, partitions...) - if pw, ok := writers[selectedPartition]; !ok { - err = fmt.Errorf("write balancer chose nonexistant partition %d for topic %s", selectedPartition, w.config.Topic) - if wm.res != nil { - wm.res <- &writerError{msg: wm.msg, err: err} - } - } else { - pw.messages() <- wm - } - } else { - // No partitions were found because the topic doesn't exist.
- if err == nil { - err = fmt.Errorf("failed to find any partitions for topic %s", w.config.Topic) - } - if wm.res != nil { - wm.res <- &writerError{msg: wm.msg, err: err} - } + if m.Time.IsZero() { + m.Time = res.LogAppendTime } - - case <-ticker.C: - rebalance = true } } + + if w.Completion != nil { + w.Completion(batch.msgs, err) + } + + batch.complete(err) } -func (w *Writer) partitions() (partitions []int, err error) { - for _, broker := range shuffledStrings(w.config.Brokers) { - var conn *Conn - var plist []Partition +func (w *Writer) produce(key topicPartition, batch *writeBatch) (*ProduceResponse, error) { + timeout := w.writeTimeout() - if conn, err = w.config.Dialer.Dial("tcp", broker); err != nil { - continue - } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() - conn.SetReadDeadline(time.Now().Add(w.config.ReadTimeout)) - plist, err = conn.ReadPartitions(w.config.Topic) - conn.Close() + return w.client(timeout).Produce(ctx, &ProduceRequest{ + Partition: int(key.partition), + Topic: key.topic, + RequiredAcks: w.RequiredAcks, + Compression: w.Compression, + Records: &writerRecords{ + msgs: batch.msgs, + }, + }) +} - if err == nil { - partitions = make([]int, len(plist)) - for i, p := range plist { - partitions[i] = p.ID +func (w *Writer) partitions(ctx context.Context, topic string) (int, error) { + client := w.client(w.readTimeout()) + // Here we use the transport directly as an optimization to avoid the + // construction of temporary request and response objects made by the + // (*Client).Metadata API. + // + // It is expected that the transport will optimize this request by + // caching recent results (the kafka.Transport types does). + r, err := client.transport().RoundTrip(ctx, client.Addr, &metadataAPI.Request{ + TopicNames: []string{topic}, + }) + if err != nil { + return 0, err + } + for _, t := range r.(*metadataAPI.Response).Topics { + if t.Name == topic { + // This should always hit, unless kafka has a bug. 
+ if t.ErrorCode != 0 { + return 0, Error(t.ErrorCode) } - break + return len(t.Partitions), nil } } + return 0, UnknownTopicOrPartition +} - sort.Ints(partitions) - return +func (w *Writer) markClosed() { + atomic.StoreUint32(&w.closed, 1) } -func (w *Writer) open(partition int) partitionWriter { - return w.config.newPartitionWriter(partition, w.config, w.stats) +func (w *Writer) isClosed() bool { + return atomic.LoadUint32(&w.closed) != 0 } -func (w *Writer) close(writer partitionWriter) { - w.join.Add(1) - go func() { - writer.close() - w.join.Done() - }() +func (w *Writer) client(timeout time.Duration) *Client { + return &Client{ + Addr: w.Addr, + Transport: w.Transport, + Timeout: timeout, + } } -func diffp(new []int, old []int) (diff []int) { - for _, p := range new { - if i := sort.SearchInts(old, p); i == len(old) || old[i] != p { - diff = append(diff, p) - } +func (w *Writer) balancer() Balancer { + if w.Balancer != nil { + return w.Balancer } - return -} - -type partitionWriter interface { - messages() chan<- writerMessage - close() -} - -type writer struct { - brokers []string - topic string - partition int - requiredAcks int - batchSize int - maxMessageBytes int - batchTimeout time.Duration - writeTimeout time.Duration - idleConnTimeout time.Duration - dialer *Dialer - msgs chan writerMessage - join sync.WaitGroup - stats *writerStats - codec CompressionCodec - logger Logger - errorLogger Logger -} - -func newWriter(partition int, config WriterConfig, stats *writerStats) *writer { - w := &writer{ - brokers: config.Brokers, - topic: config.Topic, - partition: partition, - requiredAcks: config.RequiredAcks, - batchSize: config.BatchSize, - maxMessageBytes: config.BatchBytes, - batchTimeout: config.BatchTimeout, - writeTimeout: config.WriteTimeout, - idleConnTimeout: config.IdleConnTimeout, - dialer: config.Dialer, - msgs: make(chan writerMessage, config.QueueCapacity), - stats: stats, - codec: config.CompressionCodec, - logger: config.Logger, - errorLogger: config.ErrorLogger, - } - w.join.Add(1) - go w.run() - return w + return &w.roundRobin } -func (w *writer) close() { - close(w.msgs) - w.join.Wait() +func (w *Writer) maxAttempts() int { + if w.MaxAttempts > 0 { + return w.MaxAttempts + } + // TODO: this is a very high default, if something has failed 9 times it + // seems unlikely it will succeed on the 10th attempt. However, it does + // carry the risk to greatly increase the volume of requests sent to the + // kafka cluster. We should consider reducing this default (3?). 
+ return 10 } -func (w *writer) messages() chan<- writerMessage { - return w.msgs +func (w *Writer) batchSize() int { + if w.BatchSize > 0 { + return w.BatchSize + } + return 100 } -func (w *writer) withLogger(do func(Logger)) { - if w.logger != nil { - do(w.logger) +func (w *Writer) batchBytes() int64 { + if w.BatchBytes > 0 { + return w.BatchBytes } + return 1048576 } -func (w *writer) withErrorLogger(do func(Logger)) { - if w.errorLogger != nil { - do(w.errorLogger) - } else { - w.withLogger(do) +func (w *Writer) batchTimeout() time.Duration { + if w.BatchTimeout > 0 { + return w.BatchTimeout } + return 1 * time.Second } -func (w *writer) run() { - defer w.join.Done() +func (w *Writer) readTimeout() time.Duration { + if w.ReadTimeout > 0 { + return w.ReadTimeout + } + return 10 * time.Second +} - batchTimer := time.NewTimer(0) - <-batchTimer.C - batchTimerRunning := false - defer batchTimer.Stop() +func (w *Writer) writeTimeout() time.Duration { + if w.WriteTimeout > 0 { + return w.WriteTimeout + } + return 10 * time.Second +} - var conn *Conn - var done bool - var batch = make([]Message, 0, w.batchSize) - var resch = make([](chan<- error), 0, w.batchSize) - var lastMsg writerMessage - var batchSizeBytes int - var idleConnDeadline time.Time +func (w *Writer) withLogger(do func(Logger)) { + if w.Logger != nil { + do(w.Logger) + } +} - defer func() { - if conn != nil { - conn.Close() - } - }() +func (w *Writer) withErrorLogger(do func(Logger)) { + if w.ErrorLogger != nil { + do(w.ErrorLogger) + } else { + w.withLogger(do) + } +} - for !done { - var mustFlush bool - // lstMsg gets set when the next message would put the maxMessageBytes over the limit. - // If a lstMsg exists we need to add it to the batch so we don't lose it. - if len(lastMsg.msg.Value) != 0 { - batch = append(batch, lastMsg.msg) - if lastMsg.res != nil { - resch = append(resch, lastMsg.res) - } - batchSizeBytes += int(lastMsg.msg.size()) - lastMsg = writerMessage{} - if !batchTimerRunning { - batchTimer.Reset(w.batchTimeout) - batchTimerRunning = true - } +func (w *Writer) stats() *writerStats { + w.once.Do(func() { + // This field is not nil when the writer was constructed with NewWriter + // to share the value with the dial function and count dials. + if w.writerStats == nil { + w.writerStats = new(writerStats) } - select { - case wm, ok := <-w.msgs: - if !ok { - done, mustFlush = true, true - } else { - if int(wm.msg.size())+batchSizeBytes > w.maxMessageBytes { - // If the size of the current message puts us over the maxMessageBytes limit, - // store the message but don't send it in this batch. - mustFlush = true - lastMsg = wm - break - } - batch = append(batch, wm.msg) - if wm.res != nil { - resch = append(resch, wm.res) - } - batchSizeBytes += int(wm.msg.size()) - mustFlush = len(batch) >= w.batchSize || batchSizeBytes >= w.maxMessageBytes - } - if !batchTimerRunning { - batchTimer.Reset(w.batchTimeout) - batchTimerRunning = true - } + }) + return w.writerStats +} - case <-batchTimer.C: - mustFlush = true - batchTimerRunning = false - } +// Stats returns a snapshot of the writer stats since the last time the method +// was called, or since the writer was created if it is called for the first +// time. +// +// A typical use of this method is to spawn a goroutine that will periodically +// call Stats on a kafka writer and report the metrics to a stats collection +// system. 
+func (w *Writer) Stats() WriterStats { + stats := w.stats() + return WriterStats{ + Dials: stats.dials.snapshot(), + Writes: stats.writes.snapshot(), + Messages: stats.messages.snapshot(), + Bytes: stats.bytes.snapshot(), + Errors: stats.errors.snapshot(), + DialTime: stats.dialTime.snapshotDuration(), + BatchTime: stats.batchTime.snapshotDuration(), + WriteTime: stats.writeTime.snapshotDuration(), + WaitTime: stats.waitTime.snapshotDuration(), + Retries: stats.retries.snapshot(), + BatchSize: stats.batchSize.snapshot(), + BatchBytes: stats.batchSizeBytes.snapshot(), + MaxAttempts: int64(w.MaxAttempts), + MaxBatchSize: int64(w.BatchSize), + BatchTimeout: w.BatchTimeout, + ReadTimeout: w.ReadTimeout, + WriteTimeout: w.WriteTimeout, + RequiredAcks: int64(w.RequiredAcks), + Async: w.Async, + Topic: w.Topic, + } +} - if mustFlush { - w.stats.batchSizeBytes.observe(int64(batchSizeBytes)) - if batchTimerRunning { - if stopped := batchTimer.Stop(); !stopped { - <-batchTimer.C - } - batchTimerRunning = false - } - if conn != nil && time.Now().After(idleConnDeadline) { - conn.Close() - conn = nil - } - if len(batch) == 0 { - continue - } - var err error - if conn, err = w.write(conn, batch, resch); err != nil { - if conn != nil { - conn.Close() - conn = nil - } - } - idleConnDeadline = time.Now().Add(w.idleConnTimeout) - for i := range batch { - batch[i] = Message{} - } +func (w *Writer) chooseTopic(msg Message) (string, error) { + // w.Topic and msg.Topic are mutually exclusive, meaning only 1 must be set + // otherwise we will return an error. + if (w.Topic != "" && msg.Topic != "") || (w.Topic == "" && msg.Topic == "") { + return "", InvalidMessage + } - for i := range resch { - resch[i] = nil - } - batch = batch[:0] - resch = resch[:0] - batchSizeBytes = 0 - } + // now we choose the topic, depending on which one is not empty + if msg.Topic != "" { + return msg.Topic, nil } + + return w.Topic, nil } -func (w *writer) dial() (conn *Conn, err error) { - for _, broker := range shuffledStrings(w.brokers) { - t0 := time.Now() - if conn, err = w.dialer.DialLeader(context.Background(), "tcp", broker, w.topic, w.partition); err == nil { - t1 := time.Now() - w.stats.dials.observe(1) - w.stats.dialTime.observeDuration(t1.Sub(t0)) - conn.SetRequiredAcks(w.requiredAcks) - break - } +type writeBatch struct { + time time.Time + msgs []Message + size int + bytes int64 + ready chan struct{} + done chan struct{} + timer *time.Timer + err error // result of the batch completion +} + +func newWriteBatch(now time.Time, timeout time.Duration) *writeBatch { + return &writeBatch{ + time: now, + ready: make(chan struct{}), + done: make(chan struct{}), + timer: time.NewTimer(timeout), } - return } -func (w *writer) write(conn *Conn, batch []Message, resch [](chan<- error)) (ret *Conn, err error) { - w.stats.writes.observe(1) - if conn == nil { - if conn, err = w.dial(); err != nil { - w.stats.errors.observe(1) - w.withErrorLogger(func(logger Logger) { - logger.Printf("error dialing kafka brokers for topic %s (partition %d): %s", w.topic, w.partition, err) - }) - for i, res := range resch { - res <- &writerError{msg: batch[i], err: err} - } - return - } +func (b *writeBatch) add(msg Message, maxSize int, maxBytes int64) bool { + bytes := int64(msg.size()) + + if b.size > 0 && (b.bytes+bytes) > maxBytes { + return false } - t0 := time.Now() - conn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) - if _, err = conn.WriteCompressedMessages(w.codec, batch...); err != nil { - w.stats.errors.observe(1) - 
w.withErrorLogger(func(logger Logger) { - logger.Printf("error writing messages to %s (partition %d): %s", w.topic, w.partition, err) - }) - for i, res := range resch { - res <- &writerError{msg: batch[i], err: err} - } - } else { - for _, m := range batch { - w.stats.messages.observe(1) - w.stats.bytes.observe(int64(len(m.Key) + len(m.Value))) - } - for _, res := range resch { - res <- nil - } + if cap(b.msgs) == 0 { + b.msgs = make([]Message, 0, maxSize) } - t1 := time.Now() - w.stats.waitTime.observeDuration(t1.Sub(t0)) - w.stats.batchSize.observe(int64(len(batch))) - ret = conn - return + b.msgs = append(b.msgs, msg) + b.size++ + b.bytes += bytes + return true } -type writerMessage struct { - msg Message - res chan<- error +func (b *writeBatch) full(maxSize int, maxBytes int64) bool { + return b.size >= maxSize || b.bytes >= maxBytes } -type writerError struct { - msg Message - err error +func (b *writeBatch) trigger() { + close(b.ready) } -func (e *writerError) Cause() error { - return e.err +func (b *writeBatch) complete(err error) { + b.err = err + close(b.done) } -func (e *writerError) Error() string { - return e.err.Error() +type writerRecords struct { + msgs []Message + index int + record Record + key bytesReadCloser + value bytesReadCloser } -func (e *writerError) Temporary() bool { - return isTemporary(e.err) +func (r *writerRecords) ReadRecord() (*Record, error) { + if r.index >= 0 && r.index < len(r.msgs) { + m := &r.msgs[r.index] + r.index++ + r.record = Record{ + Time: m.Time, + Headers: m.Headers, + } + if m.Key != nil { + r.key.Reset(m.Key) + r.record.Key = &r.key + } + if m.Value != nil { + r.value.Reset(m.Value) + r.record.Value = &r.value + } + return &r.record, nil + } + return nil, io.EOF } -func (e *writerError) Timeout() bool { - return isTimeout(e.err) -} +type bytesReadCloser struct{ bytes.Reader } + +func (*bytesReadCloser) Close() error { return nil } -func shuffledStrings(list []string) []string { - shuffledList := make([]string, len(list)) - copy(shuffledList, list) +// A cache of []int values passed to balancers of writers, used to amortize the +// heap allocation of the partition index lists. +// +// With hindsight, the use of `...int` to pass the partition list to Balancers +// was not the best design choice: kafka partition numbers are monotonically +// increasing, we could have simply passed the number of partitions instead. +// If we ever revisit this API, we can hopefully remove this cache. 
+var partitionsCache atomic.Value + +func loadCachedPartitions(numPartitions int) []int { + partitions, ok := partitionsCache.Load().([]int) + if ok && len(partitions) >= numPartitions { + return partitions[:numPartitions] + } - shufflerMutex.Lock() + const alignment = 128 + n := ((numPartitions / alignment) + 1) * alignment - for i := range shuffledList { - j := shuffler.Intn(i + 1) - shuffledList[i], shuffledList[j] = shuffledList[j], shuffledList[i] + partitions = make([]int, n) + for i := range partitions { + partitions[i] = i } - shufflerMutex.Unlock() - return shuffledList + partitionsCache.Store(partitions) + return partitions[:numPartitions] } - -var ( - shufflerMutex = sync.Mutex{} - shuffler = rand.New(rand.NewSource(time.Now().Unix())) -) diff --git a/writer_test.go b/writer_test.go index c5228fa2f..2d2e05167 100644 --- a/writer_test.go +++ b/writer_test.go @@ -2,17 +2,13 @@ package kafka import ( "context" - "errors" "io" "math" - "strings" "testing" "time" ) func TestWriter(t *testing.T) { - t.Parallel() - tests := []struct { scenario string function func(*testing.T) @@ -31,22 +27,46 @@ func TestWriter(t *testing.T) { scenario: "running out of max attempts should return an error", function: testWriterMaxAttemptsErr, }, + { scenario: "writing a message larger then the max bytes should return an error", function: testWriterMaxBytes, }, + { scenario: "writing a batch of message based on batch byte size", function: testWriterBatchBytes, }, + { scenario: "writing a batch of messages", function: testWriterBatchSize, }, + { scenario: "writing messsages with a small batch byte size", function: testWriterSmallBatchBytes, }, + { + scenario: "setting a non default balancer on the writer", + function: testWriterSetsRightBalancer, + }, + { + scenario: "setting RequiredAcks to None in Writer does not cause a panic", + function: testWriterRequiredAcksNone, + }, + { + scenario: "writing messages to multiple topics", + function: testWriterMultipleTopics, + }, + { + scenario: "writing messages without specifying a topic", + function: testWriterMissingTopic, + }, + { + scenario: "specifying topic for message when already set for writer", + function: testWriterUnexpectedMessageTopic, + }, { scenario: "writing a message to an invalid partition", function: testWriterInvalidPartition, @@ -71,8 +91,9 @@ func newTestWriter(config WriterConfig) *Writer { func testWriterClose(t *testing.T) { const topic = "test-writer-0" - createTopic(t, topic, 1) + defer deleteTopic(t, topic) + w := newTestWriter(WriterConfig{ Topic: topic, }) @@ -82,10 +103,52 @@ func testWriterClose(t *testing.T) { } } -func testWriterRoundRobin1(t *testing.T) { +func testWriterRequiredAcksNone(t *testing.T) { + topic := makeTopic() + createTopic(t, topic, 1) + defer deleteTopic(t, topic) + + transport := &Transport{} + defer transport.CloseIdleConnections() + + writer := &Writer{ + Addr: TCP("localhost:9092"), + Topic: topic, + Balancer: &RoundRobin{}, + RequiredAcks: RequireNone, + Transport: transport, + } + defer writer.Close() + + msg := Message{ + Key: []byte("ThisIsAKey"), + Value: []byte("Test message for required acks test")} + + err := writer.WriteMessages(context.Background(), msg) + if err != nil { + t.Fatal(err) + } +} + +func testWriterSetsRightBalancer(t *testing.T) { const topic = "test-writer-1" + balancer := &CRC32Balancer{} + w := newTestWriter(WriterConfig{ + Topic: topic, + Balancer: balancer, + }) + defer w.Close() + + if w.Balancer != balancer { + t.Errorf("Balancer not set correctly") + } +} +func 
testWriterRoundRobin1(t *testing.T) { + const topic = "test-writer-1" createTopic(t, topic, 1) + defer deleteTopic(t, topic) + offset, err := readOffset(topic, 0) if err != nil { t.Fatal(err) @@ -130,7 +193,7 @@ func TestValidateWriter(t *testing.T) { errorOccured bool }{ {config: WriterConfig{}, errorOccured: true}, - {config: WriterConfig{Brokers: []string{"broker1", "broker2"}}, errorOccured: true}, + {config: WriterConfig{Brokers: []string{"broker1", "broker2"}}, errorOccured: false}, {config: WriterConfig{Brokers: []string{"broker1"}, Topic: "topic1"}, errorOccured: false}, } for _, test := range tests { @@ -144,42 +207,16 @@ func TestValidateWriter(t *testing.T) { } } -type fakeWriter struct { - attempts int -} - -func (f *fakeWriter) messages() chan<- writerMessage { - ch := make(chan writerMessage, 1) - - go func() { - for { - msg := <-ch - f.attempts++ - msg.res <- &writerError{ - err: errors.New("bad attempt"), - } - } - }() - - return ch -} - -func (f *fakeWriter) close() {} - func testWriterMaxAttemptsErr(t *testing.T) { - const topic = "test-writer-2" - const maxAttempts = 3 - - fw := &fakeWriter{} - + topic := makeTopic() createTopic(t, topic, 1) + defer deleteTopic(t, topic) + w := newTestWriter(WriterConfig{ + Brokers: []string{"localhost:9999"}, // nothing is listening here Topic: topic, - MaxAttempts: maxAttempts, + MaxAttempts: 3, Balancer: &RoundRobin{}, - newPartitionWriter: func(p int, config WriterConfig, stats *writerStats) partitionWriter { - return fw - }, }) defer w.Close() @@ -188,22 +225,14 @@ func testWriterMaxAttemptsErr(t *testing.T) { }); err == nil { t.Error("expected error") return - } else if err != nil { - if !strings.Contains(err.Error(), "bad attempt") { - t.Errorf("unexpected error: %s", err) - return - } - } - - if fw.attempts != maxAttempts { - t.Errorf("got %d attempts, want %d", fw.attempts, maxAttempts) } } func testWriterMaxBytes(t *testing.T) { topic := makeTopic() - createTopic(t, topic, 1) + defer deleteTopic(t, topic) + w := newTestWriter(WriterConfig{ Topic: topic, BatchBytes: 25, @@ -294,9 +323,11 @@ func readPartition(topic string, partition int, offset int64) (msgs []Message, e func testWriterBatchBytes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - const topic = "test-writer-1-bytes" + const topic = "test-writer-1-bytes" createTopic(t, topic, 1) + defer deleteTopic(t, topic) + offset, err := readOffset(topic, 0) if err != nil { t.Fatal(err) @@ -348,6 +379,8 @@ func testWriterBatchSize(t *testing.T) { topic := makeTopic() createTopic(t, topic, 1) + defer deleteTopic(t, topic) + offset, err := readOffset(topic, 0) if err != nil { t.Fatal(err) @@ -399,6 +432,8 @@ func testWriterSmallBatchBytes(t *testing.T) { topic := makeTopic() createTopic(t, topic, 1) + defer deleteTopic(t, topic) + offset, err := readOffset(topic, 0) if err != nil { t.Fatal(err) @@ -444,12 +479,89 @@ func testWriterSmallBatchBytes(t *testing.T) { } } -type staticBalancer struct { - partition int +func testWriterMultipleTopics(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + topic1 := makeTopic() + createTopic(t, topic1, 1) + defer deleteTopic(t, topic1) + + offset1, err := readOffset(topic1, 0) + if err != nil { + t.Fatal(err) + } + + topic2 := makeTopic() + createTopic(t, topic2, 1) + defer deleteTopic(t, topic2) + + offset2, err := readOffset(topic2, 0) + if err != nil { + t.Fatal(err) + } + + w := newTestWriter(WriterConfig{ + Balancer: 
&RoundRobin{}, + }) + defer w.Close() + + msg1 := Message{Topic: topic1, Value: []byte("Hello")} + msg2 := Message{Topic: topic2, Value: []byte("World")} + + if err := w.WriteMessages(ctx, msg1, msg2); err != nil { + t.Error(err) + return + } + ws := w.Stats() + if ws.Writes != 2 { + t.Error("didn't batch messages; Writes: ", ws.Writes) + return + } + + msgs1, err := readPartition(topic1, 0, offset1) + if err != nil { + t.Error("error reading partition", err) + return + } + if len(msgs1) != 1 { + t.Error("bad messages in partition", msgs1) + return + } + if string(msgs1[0].Value) != "Hello" { + t.Error("bad message in partition", msgs1) + } + + msgs2, err := readPartition(topic2, 0, offset2) + if err != nil { + t.Error("error reading partition", err) + return + } + if len(msgs2) != 1 { + t.Error("bad messages in partition", msgs2) + return + } + if string(msgs2[0].Value) != "World" { + t.Error("bad message in partition", msgs2) + } } -func (b *staticBalancer) Balance(_ Message, partitions ...int) int { - return b.partition +func testWriterMissingTopic(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + w := newTestWriter(WriterConfig{ + // no topic + Balancer: &RoundRobin{}, + }) + defer w.Close() + + msg := Message{Value: []byte("Hello World")} // no topic + + if err := w.WriteMessages(ctx, msg); err == nil { + t.Error("expected error") + return + } } func testWriterInvalidPartition(t *testing.T) { @@ -458,6 +570,7 @@ func testWriterInvalidPartition(t *testing.T) { topic := makeTopic() createTopic(t, topic, 1) + defer deleteTopic(t, topic) w := newTestWriter(WriterConfig{ Topic: topic, @@ -475,3 +588,33 @@ func testWriterInvalidPartition(t *testing.T) { t.Fatal("expected error attempting to write message") } } + +func testWriterUnexpectedMessageTopic(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + topic := makeTopic() + createTopic(t, topic, 1) + defer deleteTopic(t, topic) + + w := newTestWriter(WriterConfig{ + Topic: topic, + Balancer: &RoundRobin{}, + }) + defer w.Close() + + msg := Message{Topic: "should-fail", Value: []byte("Hello World")} + + if err := w.WriteMessages(ctx, msg); err == nil { + t.Error("expected error") + return + } +} + +type staticBalancer struct { + partition int +} + +func (b *staticBalancer) Balance(_ Message, partitions ...int) int { + return b.partition +} diff --git a/zstd/zstd.go b/zstd/zstd.go index d4a6e85f4..c8bded1c3 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -1,135 +1,21 @@ -// Package zstd implements Zstandard compression. +// Package zstd does nothing, it's kept for backward compatibility to avoid +// breaking the majority of programs that imported it to install the compression +// codec, which is now always included. 
package zstd -import ( - "io" - "runtime" - "sync" +import "github.com/segmentio/kafka-go/compress/zstd" - zstdlib "github.com/klauspost/compress/zstd" - kafka "github.com/segmentio/kafka-go" +const ( + Code = 4 + DefaultCompressionLevel = 3 ) -func init() { - kafka.RegisterCompressionCodec(NewCompressionCodec()) -} - -const Code = 4 - -const DefaultCompressionLevel = 3 - -type CompressionCodec struct{ level zstdlib.EncoderLevel } +type CompressionCodec = zstd.Codec func NewCompressionCodec() *CompressionCodec { return NewCompressionCodecWith(DefaultCompressionLevel) } func NewCompressionCodecWith(level int) *CompressionCodec { - return &CompressionCodec{zstdlib.EncoderLevelFromZstd(level)} -} - -// Code implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) Code() int8 { return Code } - -// Name implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) Name() string { return "zstd" } - -// NewReader implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser { - p := new(reader) - if cached := decPool.Get(); cached == nil { - p.dec, p.err = zstdlib.NewReader(r) - runtime.SetFinalizer(p, finalizeReader) - } else { - p = cached.(*reader) - p.err = p.dec.Reset(r) - } - return p -} - -var decPool sync.Pool - -type reader struct { - dec *zstdlib.Decoder - err error -} - -// Close implements the io.Closer interface. -func (r *reader) Close() error { - if r.dec != nil { - r.err = io.ErrClosedPipe - decPool.Put(r) - } - return nil -} - -// Read implements the io.Reader interface. -func (r *reader) Read(p []byte) (n int, err error) { - if r.err != nil { - return 0, r.err - } - return r.dec.Read(p) -} - -// WriteTo implements the io.WriterTo interface. -func (r *reader) WriteTo(w io.Writer) (n int64, err error) { - if r.err != nil { - return 0, r.err - } - return r.dec.WriteTo(w) -} - -// NewWriter implements the kafka.CompressionCodec interface. -func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { - p := new(writer) - if cached := encPool.Get(); cached == nil { - p.enc, p.err = zstdlib.NewWriter(w, - zstdlib.WithEncoderLevel(c.level)) - } else { - p.enc = cached.(*zstdlib.Encoder) - p.enc.Reset(w) - } - return p -} - -var encPool sync.Pool - -type writer struct { - enc *zstdlib.Encoder - err error -} - -// Close implements the io.Closer interface. -func (w *writer) Close() error { - if w.enc == nil { - return nil // already closed - } - err := w.enc.Close() - encPool.Put(w.enc) - w.enc = nil - w.err = io.ErrClosedPipe - return err -} - -// WriteTo implements the io.WriterTo interface. -func (w *writer) Write(p []byte) (n int, err error) { - if w.err != nil { - return 0, w.err - } - return w.enc.Write(p) -} - -// ReadFrom implements the io.ReaderFrom interface. -func (w *writer) ReadFrom(r io.Reader) (n int64, err error) { - if w.err != nil { - return 0, w.err - } - return w.enc.ReadFrom(r) -} - -// finalizeReader closes underlying resources managed by a reader. -func finalizeReader(r *reader) { - if r.dec != nil { - r.dec.Close() - } + return &CompressionCodec{Level: level} }
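With the codecs now compiled in unconditionally, the blank import of this package is no longer needed; the sketch below selects zstd on the writer instead. The broker and topic are placeholders, and kafka.Zstd is assumed to be one of the compression constants exposed by the kafka package in 0.4 (converting compress.Zstd to kafka.Compression would work the same way):

package main

import (
    "context"
    "log"

    kafka "github.com/segmentio/kafka-go"
    // Previously: _ "github.com/segmentio/kafka-go/zstd" to register the codec.
)

func main() {
    w := &kafka.Writer{
        Addr:        kafka.TCP("localhost:9092"), // placeholder broker
        Topic:       "events",                    // placeholder topic
        Compression: kafka.Zstd,                  // replaces registering zstd.NewCompressionCodec()
    }
    defer w.Close()

    if err := w.WriteMessages(context.Background(), kafka.Message{Value: []byte("compressed payload")}); err != nil {
        log.Fatal(err)
    }
}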