From d66d2c05d3f041156b3a922f0cad53b6e2309a05 Mon Sep 17 00:00:00 2001 From: Tomeu Vizoso Date: Tue, 17 Feb 2026 09:18:57 +0100 Subject: [PATCH] ethosu: Switch to the weight encoder from Regor We vendor the encoder used in the Regor compiler in Vela, and replace the previous one that was used by the Python compiler and doesn't support U85. Part-of: --- .clang-format-ignore | 2 + src/gallium/drivers/ethosu/ethosu_coefs.c | 149 ++- src/gallium/drivers/ethosu/ethosu_encode.cpp | 149 +++ src/gallium/drivers/ethosu/ethosu_encode.h | 27 + .../drivers/ethosu/ethosu_encode_support.h | 672 ++++++++++ src/gallium/drivers/ethosu/ethosu_ml.h | 2 + src/gallium/drivers/ethosu/meson.build | 7 +- .../ethosu/mlw_codec/include/mlw_decode.h | 50 + .../ethosu/mlw_codec/include/mlw_encode.h | 115 ++ .../drivers/ethosu/mlw_codec/meson.build | 17 + .../drivers/ethosu/mlw_codec/mlw_common.h | 29 - .../drivers/ethosu/mlw_codec/mlw_encode.c | 1186 ----------------- .../drivers/ethosu/mlw_codec/mlw_encode.h | 65 - .../ethosu/mlw_codec/source/ml_bit_buffer.hpp | 255 ++++ .../mlw_codec/source/ml_encoder_internal.hpp | 130 ++ .../mlw_codec/source/ml_ethosu_encode.cpp | 120 ++ .../ethosu/mlw_codec/source/ml_raw_buffer.hpp | 162 +++ .../ethosu/mlw_codec/source/mlw_decode.cpp | 362 +++++ .../ethosu/mlw_codec/source/mlw_encode.cpp | 979 ++++++++++++++ .../mlw_codec/source/mlw_encode_fwd.cpp | 197 +++ 20 files changed, 3328 insertions(+), 1347 deletions(-) create mode 100644 src/gallium/drivers/ethosu/ethosu_encode.cpp create mode 100644 src/gallium/drivers/ethosu/ethosu_encode.h create mode 100644 src/gallium/drivers/ethosu/ethosu_encode_support.h create mode 100644 src/gallium/drivers/ethosu/mlw_codec/include/mlw_decode.h create mode 100644 src/gallium/drivers/ethosu/mlw_codec/include/mlw_encode.h create mode 100644 src/gallium/drivers/ethosu/mlw_codec/meson.build delete mode 100644 src/gallium/drivers/ethosu/mlw_codec/mlw_common.h delete mode 100644 src/gallium/drivers/ethosu/mlw_codec/mlw_encode.c delete mode 100644 src/gallium/drivers/ethosu/mlw_codec/mlw_encode.h create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/ml_bit_buffer.hpp create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/ml_encoder_internal.hpp create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/ml_ethosu_encode.cpp create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/ml_raw_buffer.hpp create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/mlw_decode.cpp create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode.cpp create mode 100644 src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode_fwd.cpp diff --git a/.clang-format-ignore b/.clang-format-ignore index b23f076504c..3bdaec92bb1 100644 --- a/.clang-format-ignore +++ b/.clang-format-ignore @@ -1,2 +1,4 @@ # Vendored code src/amd/vulkan/radix_sort/* +src/gallium/drivers/ethosu/mlw_codec/**/* +src/gallium/drivers/ethosu/ethosu_encode_support.h diff --git a/src/gallium/drivers/ethosu/ethosu_coefs.c b/src/gallium/drivers/ethosu/ethosu_coefs.c index 1c309a376e1..9a29647d117 100644 --- a/src/gallium/drivers/ethosu/ethosu_coefs.c +++ b/src/gallium/drivers/ethosu/ethosu_coefs.c @@ -5,9 +5,53 @@ #include "util/u_inlines.h" -#include "mlw_codec/mlw_encode.h" -#include "ethosu_ml.h" +#include #include "ethosu_coefs.h" +#include "ethosu_encode.h" +#include "ethosu_ml.h" +#include "mlw_encode.h" + +static void +encode_bias_scale_u65(int64_t bias, int32_t scale, uint32_t shift, uint8_t data[10]) +{ + assert(-(1LL << (40 - 1)) <= bias && bias < (1LL << (40 - 1))); // signed 40-bit range + assert(0 <= scale); // unsigned 32-bit range + assert(0 <= shift && shift < (1 << 6)); // unsigned 6-bit range + + data[0] = (bias >> (0 * 8)) & 0xFF; + data[1] = (bias >> (1 * 8)) & 0xFF; + data[2] = (bias >> (2 * 8)) & 0xFF; + data[3] = (bias >> (3 * 8)) & 0xFF; + data[4] = (bias >> (4 * 8)) & 0xFF; + + data[5] = (scale >> (0 * 8)) & 0xFF; + data[6] = (scale >> (1 * 8)) & 0xFF; + data[7] = (scale >> (2 * 8)) & 0xFF; + data[8] = (scale >> (3 * 8)) & 0xFF; + + data[9] = shift & 0x3F; +} + +static void +encode_bias_scale_u85(int64_t bias, int32_t scale, uint32_t shift, uint8_t data[10]) +{ + assert(INT32_MIN <= bias && bias <= INT32_MAX); // signed 32-bit range + assert(0 <= scale); // unsigned 31-bit range + assert(0 <= shift && shift < (1 << 6)); // unsigned 6-bit range + + data[0] = (bias >> (0 * 8)) & 0xFF; + data[1] = (bias >> (1 * 8)) & 0xFF; + data[2] = (bias >> (2 * 8)) & 0xFF; + data[3] = (bias >> (3 * 8)) & 0xFF; + + data[4] = (scale >> (0 * 8)) & 0xFF; + data[5] = (scale >> (1 * 8)) & 0xFF; + data[6] = (scale >> (2 * 8)) & 0xFF; + data[7] = (scale >> (3 * 8)) & 0x7F; + + data[8] = shift & 0x3F; + data[9] = 0; +} static void fill_scale_and_biases(struct ethosu_subgraph *subgraph, struct ethosu_operation *operation, uint8_t **scales, long *scales_size, struct pipe_resource *bias_rsrc) @@ -15,32 +59,51 @@ fill_scale_and_biases(struct ethosu_subgraph *subgraph, struct ethosu_operation struct pipe_transfer *transfer_in; int32_t *biases = pipe_buffer_map(subgraph->base.context, bias_rsrc, PIPE_MAP_READ, &transfer_in); + float ifm_scale = operation->ifm.scale; + float ofm_scale = operation->ofm.scale; unsigned idx = 0; - *scales_size = align(operation->ofm.shape.depth * 10, 16); + /* U65 packs 10-byte bias/scale entries contiguously then aligns to 16. + * U85 scales are read in groups of 16 channels, so pad depth to a + * 16-channel boundary first, then multiply by 10 bytes per entry. */ + if (ethosu_is_u65(ethosu_screen(subgraph->base.context->screen))) + *scales_size = align(operation->ofm.shape.depth * 10, 16); + else + *scales_size = align(operation->ofm.shape.depth, 16) * 10; + *scales = malloc(*scales_size); memset(*scales, 0, *scales_size); for (unsigned i = 0; i < operation->ofm.shape.depth; i++) { - uint64_t bias = biases[i]; double kernel_scale = (operation->kernel.scales != NULL) ? operation->kernel.scales[i] : operation->kernel.scale; - double conv_scale = ((double)operation->ifm.scale * kernel_scale) / (double)operation->ofm.scale; + double conv_scale; + + if (!operation->ifm.is_signed) { + /* UInt8 path: multiply as float first, then cast to double */ + conv_scale = (double)(ifm_scale * kernel_scale) / (double)ofm_scale; + } else { + /* Int8 path: cast to double before multiply for higher precision */ + conv_scale = ((double)ifm_scale * (double)kernel_scale) / (double)ofm_scale; + } + uint32_t shift; int scale = ethosu_quantize_scale(conv_scale, &shift); - (*scales)[idx++] = (bias >> (0 * 8)) & 0xFF; - (*scales)[idx++] = (bias >> (1 * 8)) & 0xFF; - (*scales)[idx++] = (bias >> (2 * 8)) & 0xFF; - (*scales)[idx++] = (bias >> (3 * 8)) & 0xFF; - (*scales)[idx++] = (bias >> (4 * 8)) & 0xFF; + if (ethosu_is_u65(ethosu_screen(subgraph->base.context->screen))) + encode_bias_scale_u65( + biases[i], scale, shift, &(*scales)[idx]); + else + encode_bias_scale_u85( + biases[i], scale, shift, &(*scales)[idx]); - (*scales)[idx++] = (scale >> (0 * 8)) & 0xFF; - (*scales)[idx++] = (scale >> (1 * 8)) & 0xFF; - (*scales)[idx++] = (scale >> (2 * 8)) & 0xFF; - (*scales)[idx++] = (scale >> (3 * 8)) & 0xFF; + /* Saved for NPU_SET_OFM_SCALE emission in the command stream. */ + if (i == 0) { + operation->conv.scale = scale; + operation->conv.shift = shift; + } - (*scales)[idx++] = shift & 0x3F; + idx += 10; } pipe_buffer_unmap(subgraph->base.context, transfer_in); @@ -65,60 +128,11 @@ calculate_weights_strides(struct ethosu_operation *operation, int out_strides[4] static void fill_weights(struct ethosu_subgraph *subgraph, struct ethosu_operation *operation, uint8_t **weights, long *weights_size, struct pipe_resource *weight_rsrc) { - struct ethosu_screen *screen = ethosu_screen(subgraph->base.context->screen); - int brick_strides[4] = {0}; - unsigned input_channels = operation->ifm.shape.depth; - - if (operation->kernel.depthwise) - input_channels = 1; - - calculate_weights_strides(operation, brick_strides); - struct pipe_transfer *transfer_in; uint8_t *input_weights_8 = pipe_buffer_map(subgraph->base.context, weight_rsrc, PIPE_MAP_READ, &transfer_in); - int16_t *input_weights = malloc(pipe_buffer_size(weight_rsrc) * sizeof(*input_weights)); - unsigned num_weights = pipe_buffer_size(weight_rsrc); - unsigned output_channels = operation->ofm.shape.depth; - unsigned oc_stride = output_channels > 0 ? num_weights / output_channels : num_weights; - - for (unsigned i = 0; i < num_weights; i++) { - int zp; - if (operation->kernel.zero_points) { - unsigned ch = operation->kernel.depthwise ? i % output_channels : i / oc_stride; - zp = operation->kernel.zero_points[ch]; - } else { - zp = operation->kernel.zero_point; - } - - if (operation->kernel.is_signed) - input_weights[i] = (int8_t)input_weights_8[i] - zp; - else - input_weights[i] = input_weights_8[i] - zp; - } + ml_reorder_encode_weights(subgraph, operation, input_weights_8, pipe_buffer_size(weight_rsrc), weights, weights_size); pipe_buffer_unmap(subgraph->base.context, transfer_in); - - int64_t padded_size = 0; - *weights_size = mlw_reorder_encode( - screen->ifm_ublock.depth, - screen->ofm_ublock.depth, - operation->ofm.shape.depth, - operation->kernel.height, - operation->kernel.width, - input_channels, - brick_strides, - input_weights, - operation->block_config.ofm_block.depth, - operation->kernel.depthwise, - operation->block_config.is_partkernel, - 8 /* ifm_bitdepth */, - 8 /* decomp_h */, - 8 /* decomp_w */, - weights, - &padded_size, - DBG_ENABLED(ETHOSU_DBG_MSGS)); - - free(input_weights); } void @@ -140,6 +154,11 @@ fill_coefs(struct ethosu_subgraph *subgraph, uint8_t *weights = NULL; fill_weights(subgraph, operation, &weights, &operation->conv.weights.size, weight_rsrc); + if (!weights) { + mesa_loge("fill_weights failed"); + return; + } + operation->conv.weights.region = COEFS_REGION; operation->conv.weights.address = subgraph->coefs_used; subgraph->coefs_used += ALIGN_POT(operation->conv.weights.size, 16); diff --git a/src/gallium/drivers/ethosu/ethosu_encode.cpp b/src/gallium/drivers/ethosu/ethosu_encode.cpp new file mode 100644 index 00000000000..f6d1d963fdd --- /dev/null +++ b/src/gallium/drivers/ethosu/ethosu_encode.cpp @@ -0,0 +1,149 @@ +// +// SPDX-FileCopyrightText: Copyright 2021-2025 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright (c) 2025 Tomeu Vizoso +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include +#include "ethosu_encode_support.h" +#include "ethosu_ml.h" +#include "mlw_encode.h" + +extern "C" { + +static int32_t +weight_func(int32_t query, ml_source_state_t *state, int16_t *buffer, int32_t size, void *user_arg) +{ + assert(query == MLW_SOURCE_QUERY_WEIGHTS); + IWeightSource *src = reinterpret_cast(user_arg); + int source_size = src->Get(buffer, size); + state->eos = source_size < size; + return source_size; +} + +static int +apply_zero_point_ihwo(const WeightTransformParam *p, int value) +{ + value = (value - int(p->zeroPoints[p->o % p->zeroCount])); + assert(value >= -255 && value <= 255); + return value; +} + +static int +apply_zero_point_ohwi(const WeightTransformParam *p, int value) +{ + value = (value - int(p->zeroPoints[p->i % p->zeroCount])); + assert(value >= -255 && value <= 255); + return value; +} + +void +ml_reorder_encode_weights(struct ethosu_subgraph *subgraph, + struct ethosu_operation *operation, + const uint8_t *input_weights, + long input_weights_size, + uint8_t **weights, + long *weights_size) +{ + struct ethosu_screen *screen = ethosu_screen(subgraph->base.context->screen); + int bit_depth = 8; + bool is_sparse = false; + EthosUTraversal traversal; + struct WeightTransformParam param; + WeightTransformFunc transform_func = apply_zero_point_ohwi; + ml_encode_result_t res; + Point2i stride(operation->kernel.stride_x, operation->kernel.stride_y); + Point2i dilation(operation->kernel.dilation_x, operation->kernel.dilation_y); + int ret; + + ml_ethosu_encode_params_t params; + params.encoder_flags = MLW_ENCODE_FLAG_NONE; + params.source_buffering_hint = 128 * 1024; + params.realloc_func = NULL; + + if (operation->kernel.depthwise) { + traversal = EthosUTraversal::Depthwise; + transform_func = apply_zero_point_ihwo; + } else if (operation->block_config.is_partkernel) + traversal = EthosUTraversal::PartKernel; + else + traversal = EthosUTraversal::DepthFirst; + + int zero_point = (int)operation->kernel.zero_point; + param.zeroPoints = &zero_point; + param.zeroCount = 1; + + WeightSourceCommon *source; + + if (ethosu_is_u65(screen)) { + if (operation->kernel.is_signed) { + source = new EthosUWeightOrdering(1, dilation, + operation->block_config.ofm_block.depth, bit_depth, screen->ofm_ublock.depth, + screen->ifm_ublock.depth, transform_func, ¶m, traversal); + } else { + source = new EthosUWeightOrdering(1, dilation, + operation->block_config.ofm_block.depth, bit_depth, screen->ofm_ublock.depth, + screen->ifm_ublock.depth, transform_func, ¶m, traversal); + } + } else { + if (operation->kernel.is_signed) { + source = new EthosU85WeightOrdering(1, 256, stride, dilation, + operation->block_config.ofm_block.depth, operation->block_config.ifm_block.depth, bit_depth, operation->block_config.ofm_ublock.depth, + transform_func, ¶m, traversal, is_sparse); + } else { + source = new EthosU85WeightOrdering(1, 256, stride, dilation, + operation->block_config.ofm_block.depth, operation->block_config.ifm_block.depth, bit_depth, operation->block_config.ofm_ublock.depth, + transform_func, ¶m, traversal, is_sparse); + } + } + + Shape ohwi = {static_cast(operation->ofm.shape.depth), + static_cast(operation->kernel.height), + static_cast(operation->kernel.width), + static_cast(operation->ifm.shape.depth)}; + Shape ohwiStrides; + + int v = 1; + for (int i = 4 - 1; i >= 0; --i) { + ohwiStrides[i] = v; + v *= ohwi[i]; + } + + if (operation->kernel.depthwise) + SWAP(ohwiStrides[0], ohwiStrides[3]); /* IHWO */ + + source->SetSource(input_weights, 0, ohwi, ohwiStrides, 0); + + mle_context_t *ctx = nullptr; + ret = ml_encode_ethosu_stream(&res, ¶ms, weight_func, source, &ctx); + mle_destroy_context(ctx); + if (ret < 0) { + mesa_loge("mlw encode failed"); + *weights = NULL; + *weights_size = 0; + mle_free(&res); + delete source; + return; + } + + *weights = res.encoded_data; + res.encoded_data = NULL; + *weights_size = res.encoded_length; + mle_free(&res); + delete source; +} + +} // extern "C" diff --git a/src/gallium/drivers/ethosu/ethosu_encode.h b/src/gallium/drivers/ethosu/ethosu_encode.h new file mode 100644 index 00000000000..7197a109dc7 --- /dev/null +++ b/src/gallium/drivers/ethosu/ethosu_encode.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024 Tomeu Vizoso + * SPDX-License-Identifier: MIT + */ + +#ifndef ETHOSU_ENCODE_H +#define ETHOSU_ENCODE_H + +#include "ethosu_ml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void +ml_reorder_encode_weights(struct ethosu_subgraph *subgraph, + struct ethosu_operation *operation, + const uint8_t *input_weights, + long input_weights_size, + uint8_t **weights, + long *weights_size); + +#ifdef __cplusplus +} +#endif + +#endif /* ETHOSU_ENCODE_H */ diff --git a/src/gallium/drivers/ethosu/ethosu_encode_support.h b/src/gallium/drivers/ethosu/ethosu_encode_support.h new file mode 100644 index 00000000000..2b90907409e --- /dev/null +++ b/src/gallium/drivers/ethosu/ethosu_encode_support.h @@ -0,0 +1,672 @@ +// +// SPDX-FileCopyrightText: Copyright 2021-2025 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright (c) 2025 Tomeu Vizoso +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef ETHOSU_ENCODE_SUPPORT_H +#define ETHOSU_ENCODE_SUPPORT_H + +#include +#include +#include +#include +#include + +struct WeightTransformParam +{ + int o, h, w, i, zeroCount; + int *zeroPoints; + bool is_signed; +}; + +typedef int (*WeightTransformFunc)(const WeightTransformParam *param, int weight); + +enum class EthosUTraversal { + DepthFirst, + PartKernel, + Depthwise +}; + +struct WeightsNotSparse : public std::runtime_error +{ + WeightsNotSparse() : std::runtime_error("weights not sparse") {} +}; + +struct SparsityTracker +{ + int _sparse_zeroes = 4; + int _sparse_index = 0; + uint32_t _sparse_pos = 0xFFFFFFFF; + void Reset() { _sparse_pos = 0xFFFFFFFF; } + + void Check(uint32_t pos, int depth, int weight) + { + if ( _sparse_pos != pos ) + { + _sparse_pos = pos; + _sparse_zeroes = 0; + _sparse_index = 0; + if ( depth & 3 ) throw WeightsNotSparse(); + } + + if ( weight == 0 ) _sparse_zeroes++; + else if ( weight > 127 || weight < -127 ) throw WeightsNotSparse(); + + if ( (_sparse_index & 3) == 3 ) + { + if ( _sparse_zeroes < 2 ) throw WeightsNotSparse(); + _sparse_zeroes = 0; + } + + _sparse_index++; + } +}; + +struct Point2i +{ + int x; + int y; + Point2i(int a, int b) : x(a), y(b) {} +}; + +static inline int RoundAway(int value, int align) +{ + assert(align > 0); + int rem = value % align; + if ( rem == 0 ) + { + return value; + } + else if ( rem < 0 ) + { + return value - (align + rem); + } + return value + (align - rem); +} + +typedef int Shape[4]; + +class IWeightSource +{ +public: + virtual ~IWeightSource() = default; + virtual int Elements() = 0; + virtual int Get(int16_t *buffer, int count) = 0; +}; + +struct IVolumeWeightSource : public IWeightSource +{ + virtual ~IVolumeWeightSource() = default; + virtual void SetSource(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex) = 0; +}; + +class WeightSourceCommon : public IVolumeWeightSource +{ + +protected: + const void *_source; + int16_t _streams = 1; + int16_t _streamIndex = 0; + int _ofmDepth = 0; + int _ifmDepth = 0; + int _kernelH = 0; + int _kernelW = 0; + int _ohwiStrides[4]; + +protected: + void SetSourceCommon(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex, bool separated) + { + assert(streamIndex < _streams); + _streamIndex = streamIndex; + + int streamOffset = depthOffset * ohwiStrides[0]; + _source = reinterpret_cast(buffer) + streamOffset; + _ifmDepth = ohwiShape[3]; + _ofmDepth = separated ? (ohwiShape[0] + _streams - 1 - streamIndex) / _streams : ohwiShape[0]; + _kernelH = ohwiShape[1]; + _kernelW = ohwiShape[2]; + + // Bring in values for better cache locality + _ohwiStrides[0] = ohwiStrides[0] * (separated ? _streams : 1); + _ohwiStrides[1] = ohwiStrides[1]; + _ohwiStrides[2] = ohwiStrides[2]; + _ohwiStrides[3] = ohwiStrides[3]; + } + + int Elements() override { return _ofmDepth * _ifmDepth * _kernelH * _kernelW; } + + inline int WeightIndex(int ofm_z, int wy, int wx, int ifm_z) const + { + return ofm_z * _ohwiStrides[0] + wy * _ohwiStrides[1] + wx * _ohwiStrides[2] + ifm_z * _ohwiStrides[3]; + } +}; + +template +class EthosU85WeightOrdering : public WeightSourceCommon +{ +protected: + static constexpr int InterleaveDepth = 4; + // Transform + WeightTransformParam *_param; + WeightTransformFunc _transform; + // Loop Limits + Point2i _stride = Point2i(0, 0); + int _ofmBlockDepth; + int _ifmBlockDepth; + short _ofmUBlockDepth; + short _ifmUBlockDepth; + short _decompX; + short _decompY; + short _subKernelRound; + short _dwPaddingCount; + // Saved state + int _ofmBlockZ = 0; + int _ifmBlockZ = 0; + int _subKernelX = 0; + int _subKernelY = 0; + int _ifmUBlockOuter = 0; + int _ifmUBlockInner = 0; + int _ofmUBlockZ = 0; + int _ifmUBlockZ = 0; + int _subKernelElements = 0; + int _strideX = 0; + int _strideY = 0; + int _kernelX = 0; + int _kernelY = 0; + int _ofmUBlockInner = 0; + int _ofmUBlockOuter = 0; + int _ifmLoopInc = 0; + int _padding = 0; + EthosUTraversal _traversal; + bool _sparse; + SparsityTracker _sparsity; + +public: + EthosU85WeightOrdering(int cores, int macs, Point2i stride, const Point2i &dilation, int ofmBlockDepth, int ifmBlockDepth, int ifmBitDepth, + int ofmUBlockDepth, WeightTransformFunc func, WeightTransformParam *param, EthosUTraversal traversal, bool sparse) + { + const bool ifm16bit = (ifmBitDepth == 16); + _streams = cores; + _transform = func; + _param = param; + _traversal = traversal; + _sparse = sparse; + _stride = stride; + + _ofmBlockDepth = ofmBlockDepth; + _ifmBlockDepth = ifmBlockDepth; + _ofmUBlockDepth = short(ofmUBlockDepth); + + if ( traversal == EthosUTraversal::PartKernel ) + { + _subKernelRound = (ifm16bit || sparse) ? 10 : 5; + _ifmUBlockDepth = ifm16bit && !sparse ? 8 : _ifmBlockDepth; + } + else + { + if ( traversal == EthosUTraversal::DepthFirst ) + { + _stride = Point2i(1, 1); + _subKernelRound = 1; + _ifmUBlockDepth = _ifmBlockDepth; + } + else if ( traversal == EthosUTraversal::Depthwise ) + { + _subKernelRound = 10; + _ifmUBlockDepth = 1; + } + } + + _decompX = short(8 / dilation.x); + _decompY = short(8 / dilation.y); + _dwPaddingCount = (!ifm16bit && macs <= 256) ? 0 : (macs <= 512) ? 2 : 6; + + _ifmLoopInc = -_ifmBlockDepth; + } + + void SetSource(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex) override + { + SetSourceCommon(buffer, depthOffset, ohwiShape, ohwiStrides, streamIndex, false); + assert(_streamIndex == streamIndex); + _ofmUBlockZ = _streamIndex * InterleaveDepth; + _sparsity.Reset(); + } + +public: + int Get(int16_t *output, int count) override + { + if ( _traversal == EthosUTraversal::Depthwise ) + { + assert(!_sparse); + return GetNext(output, count); + } + else if ( _sparse ) + { + return GetNext(output, count); + } + + return GetNext(output, count); + } + + template + int GetNext(int16_t *output, int count) + { + if ( _ofmBlockZ >= _ofmDepth ) + { + return 0; + } + + int ofmBlockZ, ifmBlockZ; + int ifmUBlockOuter, ifmUBlockInner; + int ifmUBlockZ, ofmUBlockZ, ofmUBlockInner, ofmUBlockOuter; + int subKernelX, subKernelY; + int strideX, strideY; + int kernelX, kernelY; + int padding; + int16_t *write = output; + + const TYPE *buffer = reinterpret_cast(_source); + + for ( ofmBlockZ = _ofmBlockZ; ofmBlockZ < _ofmDepth; ofmBlockZ += _ofmBlockDepth ) + { + _ifmLoopInc = -_ifmLoopInc; + int clippedOfmBlockDepth = std::min(_ofmBlockDepth, _ofmDepth - ofmBlockZ); + // IFM blocks required for the brick + for ( ifmBlockZ = _ifmBlockZ; ifmBlockZ < (IS_DEPTHWISE ? 1 : _ifmDepth) && ifmBlockZ >= 0; ifmBlockZ += _ifmLoopInc ) + { + _ifmBlockZ = ifmBlockZ; + int clippedIfmBlockDepth = std::min(_ifmBlockDepth, _ifmDepth - ifmBlockZ); + + // Weight decomposition + // Subkernel splitting (W) + for ( subKernelX = _subKernelX; subKernelX < _kernelW; subKernelX += _decompX ) + { + int subWidth = std::min(_kernelW - subKernelX, _decompX); + // Subkernel Splitting (H) + for ( subKernelY = _subKernelY; subKernelY < _kernelH; subKernelY += _decompY ) + { + int subHeight = std::min(_kernelH - subKernelY, _decompY); + int ifmBlockDepthOuter = IS_DEPTHWISE ? 1 : clippedIfmBlockDepth; + for ( ifmUBlockOuter = _ifmUBlockOuter; ifmUBlockOuter < ifmBlockDepthOuter; ifmUBlockOuter += _ifmUBlockDepth ) + { + // OFM uBlocks in OFM-block over depth + for ( ofmUBlockOuter = _ofmUBlockOuter; ofmUBlockOuter < clippedOfmBlockDepth; ofmUBlockOuter += _ofmUBlockDepth ) + { + // Part kernel first works across the kernel H/W and needs padding + if ( !_subKernelElements ) + { + int subKernelElements = subWidth * subHeight; + _subKernelElements = RoundAway(subKernelElements, _subKernelRound); + } + for ( strideY = _strideY; strideY < _stride.y; ++strideY ) + { + int stridedKernelH = (subHeight + _stride.y - 1 - strideY) / _stride.y; + for ( strideX = _strideX; strideX < _stride.x; ++strideX ) + { + int stridedKernelW = (subWidth + _stride.x - 1 - strideX) / _stride.x; + for ( kernelY = _kernelY; kernelY < stridedKernelH; ++kernelY ) + { + int y = kernelY; + for ( kernelX = _kernelX; kernelX < stridedKernelW; ++kernelX ) + { + int x = kernelY % 2 == 0 ? kernelX : stridedKernelW - 1 - kernelX; + _subKernelElements--; + int ifmUBlockInnerStep = IS_DEPTHWISE ? 1 : (IS_SPARSE ? 16 : 8); + for ( ifmUBlockInner = _ifmUBlockInner; ifmUBlockInner < _ifmUBlockDepth; ifmUBlockInner += ifmUBlockInnerStep ) + { + // Feed OFM uBlock elements + for ( ofmUBlockZ = _ofmUBlockZ; ofmUBlockZ < _ofmUBlockDepth; ofmUBlockZ += InterleaveDepth * _streams ) + { + for ( ofmUBlockInner = _ofmUBlockInner; ofmUBlockInner < InterleaveDepth; ofmUBlockInner++ ) + { + // Source IFM uBlock elements (only 1 element deep if + // depthwise) + for ( ifmUBlockZ = _ifmUBlockZ; ifmUBlockZ < ifmUBlockInnerStep; ifmUBlockZ++ ) + { + // Source position within the current subkernel + int wx = subKernelX + strideX + x * _stride.x; + int wy = subKernelY + strideY + y * _stride.y; + // Source IFM/OFM slices + int ifm_z = ifmBlockZ + ifmUBlockOuter + ifmUBlockInner + ifmUBlockZ; + int ofm_z = ofmBlockZ + ofmUBlockOuter + ofmUBlockInner + ofmUBlockZ; + int weight = 0; + if ( ifm_z < _ifmDepth && ofm_z < _ofmDepth ) + { + _param->o = ofm_z; + _param->h = wy; + _param->w = wx; + _param->i = ifm_z; + weight = int(buffer[WeightIndex(ofm_z, wy, wx, ifm_z)]); + weight = _transform(_param, weight); + } + + if constexpr ( IS_SPARSE ) + _sparsity.Check((unsigned(wy) << 16) | wx, ifm_z, weight); + + *write++ = int16_t(weight); + + if ( --count == 0 ) + { + // Save state + _subKernelElements++; + _ifmUBlockZ = ifmUBlockZ + 1; + _ofmUBlockInner = ofmUBlockInner; + _ofmUBlockZ = ofmUBlockZ; + _ifmUBlockInner = ifmUBlockInner; + _kernelX = kernelX; + _kernelY = kernelY; + _strideX = strideX; + _strideY = strideY; + _ofmUBlockOuter = ofmUBlockOuter; + _ifmUBlockOuter = ifmUBlockOuter; + _subKernelY = subKernelY; + _subKernelX = subKernelX; + _ofmBlockZ = ofmBlockZ; + _ifmLoopInc = -_ifmLoopInc; + return int(intptr_t(write - output)); + } + } + _ifmUBlockZ = 0; + } + _ofmUBlockInner = 0; + } + _ofmUBlockZ = _streamIndex * InterleaveDepth; + } + // Depthwise padding + if ( IS_DEPTHWISE && _subKernelElements % _subKernelRound == 0 ) + { + int padCount = _dwPaddingCount * _ofmUBlockDepth / _streams; + for ( padding = _padding; padding < padCount; padding++ ) + { + *write++ = 0; + if ( --count == 0 ) + { + // Save state + _subKernelElements++; + _padding = padding + 1; + _ifmUBlockInner = ifmUBlockInner; // Will skip loop above + _kernelX = kernelX; + _kernelY = kernelY; + _strideX = strideX; + _strideY = strideY; + _ofmUBlockOuter = ofmUBlockOuter; + _ifmUBlockOuter = ifmUBlockOuter; + _subKernelY = subKernelY; + _subKernelX = subKernelX; + _ofmBlockZ = ofmBlockZ; + _ifmLoopInc = -_ifmLoopInc; + return int(intptr_t(write - output)); + } + } + _padding = 0; + } + _ifmUBlockInner = 0; + } + _kernelX = 0; + } + _kernelY = 0; + } + _strideX = 0; + } + // Padding + if ( _subKernelElements > 0 ) + { + int padCount = _subKernelElements + (IS_DEPTHWISE ? _dwPaddingCount : 0); + padCount = padCount * _ifmUBlockDepth * _ofmUBlockDepth / _streams; + for ( padding = _padding; padding < padCount; padding++ ) + { + *write++ = 0; + if ( --count == 0 ) + { + // Save state + _padding = padding + 1; + _strideY = strideY; // Will skip loop above + _ofmUBlockOuter = ofmUBlockOuter; + _ifmUBlockOuter = ifmUBlockOuter; + _subKernelY = subKernelY; + _subKernelX = subKernelX; + _ofmBlockZ = ofmBlockZ; + _ifmLoopInc = -_ifmLoopInc; + return int(intptr_t(write - output)); + } + } + _padding = 0; + } + _subKernelElements = 0; + _strideY = 0; + } + _ofmUBlockOuter = 0; + } + _ifmUBlockOuter = 0; + } + _subKernelY = 0; + } + _subKernelX = 0; + } + } + _ifmLoopInc = -_ifmBlockDepth; + _ifmBlockZ = 0; + _ofmBlockZ = 0; + // Return weights generated (less than requested count == EOS) + return int(intptr_t(write - output)); + } +}; + +template +class EthosUWeightOrdering : public WeightSourceCommon +{ +protected: + // Transform + WeightTransformParam *_param; + WeightTransformFunc _transform; + // Loop Limits + int _ofmBlockDepth; + int _ifmBlockDepth; + short _ofmUBlockDepth; + short _ifmUBlockDepth; + short _decompX; + short _decompY; + short _subKernelRound; + // Saved state + int _ofmBlockZ = 0; + int _ifmBlockZ = 0; + int _subKernelX = 0; + int _subKernelY = 0; + int _ifmUBlockOuter = 0; + int _ifmUBlockInner = 0; + int _ofmUBlockZ = 0; + int _ifmUBlockZ = 0; + int _kernelElement = 0; + int _ofmUBlock = 0; + EthosUTraversal _traversal; + +public: + EthosUWeightOrdering(int cores, const Point2i &dilation, int ofmBlockDepth, int ifmBitDepth, int ofmUBlockDepth, + int ifmUBlockDepth, WeightTransformFunc func, WeightTransformParam *param, EthosUTraversal traversal) + { + _streams = cores; + _ofmBlockDepth = ofmBlockDepth; + _ifmBlockDepth = ((traversal == EthosUTraversal::PartKernel) || (ifmBitDepth == 16)) ? 16 : 32; + _ofmUBlockDepth = short(ofmUBlockDepth); + _ifmUBlockDepth = short(ifmUBlockDepth); + _decompX = short(8 / dilation.x); + _decompY = short(8 / dilation.y); + if ( traversal == EthosUTraversal::Depthwise ) + { + _subKernelRound = 4; + } + else if ( traversal == EthosUTraversal::PartKernel ) + { + _subKernelRound = (ifmBitDepth == 16) ? 2 : 4; + } + else + { + _subKernelRound = 1; + } + _transform = func; + _param = param; + _traversal = traversal; + } + + void SetSource(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex) override + { + SetSourceCommon(buffer, depthOffset + streamIndex, ohwiShape, ohwiStrides, streamIndex, true); + } + +public: + int Get(int16_t *output, int count) override + { + if ( _traversal == EthosUTraversal::Depthwise ) return GetNext(output, count); + else if ( _traversal == EthosUTraversal::PartKernel ) return GetNext(output, count); + return GetNext(output, count); + } + + template + int GetNext(int16_t *output, int count) + { + if ( _ofmBlockZ >= _ofmDepth ) + { + return 0; + } + + int ofmBlockZ, ifmBlockZ; + int ifmUBlockOuter, ifmUBlockInner; + int ifmUBlockZ, ofmUBlockZ, ofmUBlock; + int subKernelX, subKernelY; + int kernelElement; + int16_t *write = output; + + const TYPE *buffer = reinterpret_cast(_source); + int streamBlockDepth = (_ofmBlockDepth + _streams - 1 - _streamIndex) / _streams; + + for ( ofmBlockZ = _ofmBlockZ; ofmBlockZ < _ofmDepth; ofmBlockZ += streamBlockDepth ) + { + int clippedOfmBlockDepth = std::min(streamBlockDepth, _ofmDepth - ofmBlockZ); + // IFM blocks required for the brick + for ( ifmBlockZ = _ifmBlockZ; ifmBlockZ < (IS_DEPTHWISE ? 1 : _ifmDepth); ifmBlockZ += _ifmBlockDepth ) + { + int clippedIfmBlockDepth; + if ( IS_DEPTHWISE ) + { + clippedIfmBlockDepth = _ifmUBlockDepth; + } + else + { + clippedIfmBlockDepth = IS_PARTKERNEL ? std::min(_ifmBlockDepth, _ifmDepth - ifmBlockZ) : _ifmBlockDepth; + } + + // Weight decomposition + // Subkernel Splitting (H) + for ( subKernelY = _subKernelY; subKernelY < _kernelH; subKernelY += _decompY ) + { + int subHeight = std::min(_kernelH - subKernelY, _decompY); + // Subkernel splitting (W) + for ( subKernelX = _subKernelX; subKernelX < _kernelW; subKernelX += _decompX ) + { + int subWidth = std::min(_kernelW - subKernelX, _decompX); + int subKernelElements = subWidth * subHeight; + + // Part-kernel first works across the kernel H/W and needs padding + subKernelElements = RoundAway(subKernelElements, _subKernelRound); + + int ifmBlockDepthOuter = IS_PARTKERNEL ? clippedIfmBlockDepth : 1; + int ifmBlockDepthInner = IS_PARTKERNEL ? 1 : clippedIfmBlockDepth; + + for ( ifmUBlockOuter = _ifmUBlockOuter; ifmUBlockOuter < ifmBlockDepthOuter; ifmUBlockOuter += _ifmUBlockDepth ) + { + // OFM uBlocks in OFM-block over depth + for ( ofmUBlock = _ofmUBlock; ofmUBlock < clippedOfmBlockDepth; ofmUBlock += _ofmUBlockDepth ) + { + // HW Kernel element traversal - cannot be a H/W loop due to element + // padding requirement on depthwise/part-kernel configurations + for ( kernelElement = _kernelElement; kernelElement < subKernelElements; kernelElement++ ) + { + int kx = kernelElement % subWidth; + int ky = kernelElement / subWidth; + // IFM uBlocks in IFM-block over depth (only 1 uBlock if depthwise) + // In case of part-kernel-first IFM uBlock traversal have already been handled + // and this loop is ignored. + for ( ifmUBlockInner = _ifmUBlockInner; ifmUBlockInner < ifmBlockDepthInner; ifmUBlockInner += _ifmUBlockDepth ) + { + int ifmUBlock = ifmUBlockInner + ifmUBlockOuter; + // Feed OFM uBlock elements + for ( ofmUBlockZ = _ofmUBlockZ; ofmUBlockZ < _ofmUBlockDepth; ofmUBlockZ++ ) + { + // Source IFM uBlock elements (only 1 element deep if depthwise) + for ( ifmUBlockZ = _ifmUBlockZ; ifmUBlockZ < (IS_DEPTHWISE ? 1 : _ifmUBlockDepth); ifmUBlockZ++ ) + { + // Source position within the current subkernel + int wx = subKernelX + kx; + int wy = subKernelY + ky; + // Source IFM/OFM slices + int ifm_z = ifmBlockZ + ifmUBlock + ifmUBlockZ; + int ofm_z = ofmBlockZ + ofmUBlock + ofmUBlockZ; + if ( (ifm_z < _ifmDepth) && (ofm_z < _ofmDepth) && (ky < subHeight) ) + { + _param->o = ofm_z; + _param->h = wy; + _param->w = wx; + _param->i = ifm_z; + int weight = int(buffer[WeightIndex(ofm_z, wy, wx, ifm_z)]); + *write = int16_t(_transform(_param, weight)); + } + else + { + *write = 0; + } + write++; + if ( --count == 0 ) + { + // Save state + _ifmUBlockZ = ifmUBlockZ + 1; + _ofmUBlockZ = ofmUBlockZ; + _ifmUBlockInner = ifmUBlockInner; + _kernelElement = kernelElement; + _ofmUBlock = ofmUBlock; + _ifmUBlockOuter = ifmUBlockOuter; + _subKernelX = subKernelX; + _subKernelY = subKernelY; + _ifmBlockZ = ifmBlockZ; + _ofmBlockZ = ofmBlockZ; + // Return weights generated (less than requested count == EOS) + return int(intptr_t(write - output)); + } + } + _ifmUBlockZ = 0; + } + _ofmUBlockZ = 0; + } + _ifmUBlockInner = 0; + } + _kernelElement = 0; + } + _ofmUBlock = 0; + } + _ifmUBlockOuter = 0; + } + _subKernelX = 0; + } + _subKernelY = 0; + } + _ifmBlockZ = 0; + } + _ofmBlockZ = 0; + return int(intptr_t(write - output)); + } +}; + +#endif /* ETHOSU_ENCODE_SUPPORT_H */ diff --git a/src/gallium/drivers/ethosu/ethosu_ml.h b/src/gallium/drivers/ethosu/ethosu_ml.h index 8db4adf4ec1..d93567972d4 100644 --- a/src/gallium/drivers/ethosu/ethosu_ml.h +++ b/src/gallium/drivers/ethosu/ethosu_ml.h @@ -137,6 +137,8 @@ struct ethosu_operation { struct ethosu_address_range weights; struct ethosu_address_range scales; bool depthwise; + unsigned scale; + unsigned shift; } conv; struct { diff --git a/src/gallium/drivers/ethosu/meson.build b/src/gallium/drivers/ethosu/meson.build index 28f696a1bf5..8844cec0311 100644 --- a/src/gallium/drivers/ethosu/meson.build +++ b/src/gallium/drivers/ethosu/meson.build @@ -1,6 +1,8 @@ # Copyright 2019 Google, Inc # SPDX-License-Identifier: MIT +subdir('mlw_codec') + ethosu_registers = custom_target( 'ethosu_registers.h', input : ['gen_parser.py', 'gen_header.py', 'registers.xml'], @@ -13,18 +15,19 @@ files_ethosu = files( 'ethosu_cmd.c', 'ethosu_coefs.c', 'ethosu_device.c', + 'ethosu_encode.cpp', 'ethosu_lower.c', 'ethosu_ml.c', 'ethosu_sched.c', - 'mlw_codec/mlw_encode.c', ) libethosu = static_library( 'ethosu', [files_ethosu, ethosu_registers], - include_directories : [inc_gallium_aux, inc_gallium, inc_include, inc_src], + include_directories : [inc_gallium_aux, inc_gallium, inc_include, inc_src, inc_libmlw_codec], gnu_symbol_visibility : 'hidden', dependencies : [idep_mesautil, dep_libdrm], + link_with : [libmlw_codec], ) driver_ethosu = declare_dependency( diff --git a/src/gallium/drivers/ethosu/mlw_codec/include/mlw_decode.h b/src/gallium/drivers/ethosu/mlw_codec/include/mlw_decode.h new file mode 100644 index 00000000000..8aca4ea2cc4 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/include/mlw_decode.h @@ -0,0 +1,50 @@ +// +// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#ifndef __MLW_DECODE_H__ +#define __MLW_DECODE_H__ + +#pragma once + +#include + +// Result of the decode process +typedef struct ml_decode_result_t +{ + int16_t *decoded_data; // decoded weight elements + int32_t decoded_length; // decoded weight length (in elements) + int32_t section_count; // number of sections in stream + int32_t *section_sizes; // section sizes in stream +} ml_decode_result_t; + + +#if defined __cplusplus +extern "C" +{ +#endif + + void ml_decode_ethosu_stream(ml_decode_result_t *result, const uint8_t *buffer, int size_bytes); + + void mld_free(ml_decode_result_t *result); + +#if defined __cplusplus +} // extern "C" +#endif + + +#endif diff --git a/src/gallium/drivers/ethosu/mlw_codec/include/mlw_encode.h b/src/gallium/drivers/ethosu/mlw_codec/include/mlw_encode.h new file mode 100644 index 00000000000..289eff27be2 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/include/mlw_encode.h @@ -0,0 +1,115 @@ +// +// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#if !defined MLW_ENCODE_H +#define MLW_ENCODE_H + +#include +#include + +#if defined _MSC_VER + #define MLW_CODEC_PACKED + #define MLW_CODEC_USE_PACK_PRAGMA (1) +#else // __GNUC__ and clang + #define MLW_CODEC_PACKED __attribute__((packed)) +#endif + +// Encoder input parameters EthosU +typedef struct ml_ethosu_encode_params_t +{ + int32_t source_buffering_hint; // Recommend a buffering size + uint16_t encoder_flags; // Control flags to pass to the encoder + void* (*realloc_func)(void*, size_t, int purpose); // Custom output allocator function +} ml_ethosu_encode_params_t; + +// Resulting encoded section information +typedef struct ml_encode_section_t +{ + int32_t offset; // Byte offset of encoded section + int32_t size; // Byte size of encoded section + int32_t zeroes; // Number of zeroes encoded in a section + int8_t group_start; // Start of group +} ml_encode_section_t; + +// Result of the encode process +typedef struct ml_encode_result_t +{ + uint8_t *encoded_data; // Encoded weight data + int32_t encoded_length; // Encoded weight length (in bytes) + int32_t source_length; // Source elements read + ml_encode_section_t *section_info; // Array of sections in stream + int32_t section_count; // Number of section in stream +} ml_encode_result_t; + +#define MLW_SOURCE_QUERY_WEIGHTS 0 +#define MLW_SOURCE_QUERY_SHIFTS 1 + +// State of the source iterator +typedef struct ml_source_state_t +{ + uint8_t new_dim_mask; // Dimension start mask + uint8_t end_dim_mask; // Dimension end mask + bool eos; // End-of-stream flag +} ml_source_state_t; + +// Stream input callback (encoder will collect input through this function, of the size recommended by the buffering hint) +typedef int32_t (*ml_weight_source_fn)(int32_t query, ml_source_state_t *state, int16_t *buffer, int32_t size, void *user_arg); + +// Internal state context +typedef struct mle_context_t mle_context_t; + +#define MLW_ENCODE_FLAG_NONE (0) // Default encoding flag +#define MLW_ENCODE_NO_BITSTREAM (1) // Do not write any bitstream data (only return the length) +#define MLW_ENCODE_INSERT_PALETTE (2) // Insert a new palette header with this encode +#define MLW_ENCODE_RESET_PALETTE (4) // Clear and recalculate the palette header +#define MLW_ENCODE_PARTIAL_DATA (8) // Frequency analysis and palette will be constructed from incomplete data +#define MLW_ENCODE_NO_PADDING (16) // Disable trailing padding +#define MLW_ENCODE_NO_PALETTE_LUT (32) // Disable palette LUT generation +#define MLW_ENCODE_NO_ZERO_RUNS (64) // Disable zero run generation +#define MLW_ENCODE_DPIC_FORCE_PARAMS (128) // Force debug parameters +#define MLW_ENCODE_NEW_PALETTE (MLW_ENCODE_INSERT_PALETTE|MLW_ENCODE_RESET_PALETTE) + +#define MLW_ENCODE_SYNTAX_ETHOSU (0) // EthosU bitstream encode syntax +#define MLW_ENCODE_SYNTAX_ETHOSU_FWD (2) // EthosU FWD bitstream encode syntax + +#define MLW_ENCODE_ALLOC_GENERAL (0) // General allocations used by the encoder +#define MLW_ENCODE_ALLOC_METADATA (1) // Allocation for codec's metadata output +#define MLW_ENCODE_ALLOC_STREAM0 (2) // Stream 0 allocation for this codec +#define MLW_ENCODE_ALLOC_STREAM1 (3) // Stream 1 allocation for this codec + +#if defined __cplusplus +extern "C" +{ +#endif + // Baseline encode + mle_context_t *mle_create_context(int syntax); + int mle_context_query_zeroes(mle_context_t *ctx); + int mle_context_query_weights_used(mle_context_t *ctx, uint64_t weights_used[512 / 64]); + void mle_context_set_allocator(mle_context_t *ctx, void* (*realloc_func)(void*, size_t, int purpose)); + void mle_destroy_context(mle_context_t *ctx); + int mle_encode(mle_context_t *ctx, ml_encode_result_t *result, const int16_t *inbuf, int inbuf_size, unsigned mlw_encode_flags); + void mle_free(ml_encode_result_t *result); + + int32_t ml_encode_ethosu_stream(ml_encode_result_t *result, const ml_ethosu_encode_params_t *ep, ml_weight_source_fn src, void *user_arg, mle_context_t **ctx_out); + +#if defined __cplusplus +} // extern "C" +#endif + + +#endif // MLW_ENCODE_H diff --git a/src/gallium/drivers/ethosu/mlw_codec/meson.build b/src/gallium/drivers/ethosu/mlw_codec/meson.build new file mode 100644 index 00000000000..10641e3903a --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/meson.build @@ -0,0 +1,17 @@ +# Copyright 2025 Tomeu Vizoso +# SPDX-License-Identifier: MIT + +files_mlw_codec = files( + 'source/mlw_encode.cpp', + 'source/mlw_encode_fwd.cpp', + 'source/ml_ethosu_encode.cpp', + 'source/mlw_decode.cpp', +) + +libmlw_codec = static_library( + 'mlw_codec', + [files_mlw_codec], + gnu_symbol_visibility : 'hidden', +) + +inc_libmlw_codec = [include_directories('include')] \ No newline at end of file diff --git a/src/gallium/drivers/ethosu/mlw_codec/mlw_common.h b/src/gallium/drivers/ethosu/mlw_codec/mlw_common.h deleted file mode 100644 index 4bb38387221..00000000000 --- a/src/gallium/drivers/ethosu/mlw_codec/mlw_common.h +++ /dev/null @@ -1,29 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2020, 2022 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#ifndef MLW_COMMON_H -#define MLW_COMMON_H - -#define ZDIV_DISABLE 6 // not alternating mode -#define ZDIV_EOS 7 // indicates end of stream - -#define WDIV_UNCOMPRESSED 7 // indicates uncompressed weights - -#endif diff --git a/src/gallium/drivers/ethosu/mlw_codec/mlw_encode.c b/src/gallium/drivers/ethosu/mlw_codec/mlw_encode.c deleted file mode 100644 index 47dd132090b..00000000000 --- a/src/gallium/drivers/ethosu/mlw_codec/mlw_encode.c +++ /dev/null @@ -1,1186 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mlw_common.h" -#include "mlw_encode.h" - -#define DPRINTF(...) -//#define DPRINTF(...) printf(__VA_ARGS__) - -#define ZERO_RUN_THRES 4 - -#ifndef min -#define min(a,b) ((a)<(b)?(a):(b)) -#endif -#ifndef max -#define max(a,b) ((a)>(b)?(a):(b)) -#endif - -#define CHECKED_MALLOC(var, size) { if ( !(var = malloc(size)) ) break; } - -typedef struct palette { - int16_t lut[32]; - int16_t inv_lut[512]; - int palsize; // number of palette entries - int palbits; // bit width of palette entries - int use_zero_runs; // zeros are coded separately - int only_palette; // no values outside the palette - int direct_offset; // added to the decoded weight index before direct conversion to sign/mag - int only_zeros; // special case that the section is all zeros -} palette_t; - -static int is_power_of_two( int x ) { - return ((x-1) & x)==0; -} - -static int round_up_divide(int num, int den) -{ - return (num + den - 1) / den; -} - -static int round_up(int num, int den) -{ - return round_up_divide(num, den) * den; -} - -static int get_palette_index_bits( int size ) { - int i; - for(i=7; i>=0; i--) - if (size > (1< (i-last_restart_idx)/4; - - if (got_palette) { - // Check if the next value is not covered by the current palette - if ( prev_idx[ buf[i]+256 ] < last_restart_idx ) { - // New value: increase the palette size - palette_size++; - DPRINTF("Note: at pos %d extend palette to size %d\n", i, palette_size); - if ( is_power_of_two(palette_size-1-exclude_zero) ) { - if ( (i - last_restart_idx - zero_cnt) > 512 || (palette_size-exclude_zero)>32 ) { - // create a new palette because we extend a long lasting palette to require one more index bit - DPRINTF("Note: at pos %d create new palette because previous has to increase one more index bit. last_restart_idx %d n %d zero_cnt %d\n", i, last_restart_idx, i - last_restart_idx, zero_cnt ); - if (restart_i == max_palettes) { - max_palettes = max_palettes*2; - restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) ); - if (!restart_pos) { - return -1; - } - } - DPRINTF("restart %d pos %d\n", restart_i, i); - restart_pos[restart_i++] = i; - last_restart_idx = i; - got_palette=0; - zero_cnt=0; - } - } - } - } - - prev_idx[ buf[i]+256 ] = i; - if (buf[i]==0) - zero_cnt++; - - static const int window_sizes[5][2] = {{32,1}, {64,1}, {128,1}, {256,1}, {512,1}}; - int k; - // loop over window sizes - for(k=0; k<5; k++) { - // Every Nth non-zero value, count what would be the size of a palette covering the last N NZ. - int N = window_sizes[k][0] * (got_palette?2:1); - if ( (i - last_restart_idx - zero_cnt) > 0 && ((i - last_restart_idx - zero_cnt) % N)==0 ) { - // Search backward to the position N nonzero values earlier - int nzcnt=0; - for( j=i; j>last_restart_idx; j--) { - if ( buf[j]!=0 ) { - if (nzcnt==N+1) - break; - nzcnt++; - } - } - int restart_idx = j; - - // Calculate the size of a new palette (starting at restart_idx) - int new_palette_size=0; - for(j=0; j<512; j++) { - if ( prev_idx[j] >= restart_idx ) { - new_palette_size++; - } - } - - int create_new_palette=0; - if (got_palette) { - int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero ); - int old_size_bits = get_palette_index_bits( palette_size - exclude_zero ); - int savings = N*(old_size_bits*15-new_size_bits*15)/16 - new_palette_size*8 - 20; - if ( savings>0 ) { - // Create new palette because it can be smaller than the existing palette - create_new_palette=1; - DPRINTF("Note: at pos %d restart smaller palette\n", restart_idx); - } - } else { - if ( (new_palette_size-exclude_zero) <= 32) { - int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero ); - // estimate if we will make savings by using palette mode - int savings = N*(90-new_size_bits*15)/16 - new_palette_size*8 - 20; - create_new_palette = savings>0; - } - } - if (create_new_palette) { - palette_size=new_palette_size; - got_palette=1; - last_restart_idx = restart_idx; - DPRINTF("Note: at pos %d create palette of size %d\n", last_restart_idx, new_palette_size); - if ( restart_pos[restart_i-1] != last_restart_idx) { - if (restart_i == max_palettes) { - max_palettes = max_palettes*2; - restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) ); - if (!restart_pos) { - return -1; - } - } - restart_pos[restart_i++] = last_restart_idx; - } - zero_cnt=0; - for( j=last_restart_idx; j<=i; j++) - if (buf[j]==0) - zero_cnt++; - } - } - } - } - // Reallocate to actual size - *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) ); - return *palette_restart_positions ? restart_i : -1; -} - -// Calculate frequency table -static void calc_freq( const int16_t *buf, int size, int freq[512] ) { - int i; - memset(freq, 0, 512*sizeof(int)); - for(i=0; ibb ? -1 : aa0) { - all_max_val = max(all_max_val, palval); - } - } - - // Count number of non-used weight values around zero (0, -1, +1, -2, +2 etc) - for(i=0; i<31; i++) { - if ((freq64[i]>>16)!=0) - break; - } - p->direct_offset = i; - - // Sort in descending frequency order - qsort(freq64, 512, sizeof(uint64_t), cmp_uint64); - - // Identify special case that there are no weights to code - // in the weight index stream (i.e. all weights are zeros) - p->only_zeros = (freq64[0]>>16)==0; - if (p->only_zeros) { - p->direct_offset=0; - } - - // Check if all weights fit into the palette (and the palette is not empty) - p->only_palette = (freq64[0]>>16)>0 && (freq64[32]>>16)==0; - - int max_palette_size; - if (p->only_palette) { - max_palette_size = 32; - } else { - // For direct-lut we must make sure that the encoded weight - // index is not > 511. We do that by limiting the palette size - // such that the greatest value can be reached after subtracting - // the palette size. - max_palette_size = min(32, 511-all_max_val); - if (max_palette_size==1) { - max_palette_size=0; // because palette of size 1 is not supported - } - } - - // Setup the 32 entry palette - int16_t palette_max_val = 0, val; - int cnt, pal_cnt=0; - for(i=0; i>16); - val = freq64[i]&0xffff; - if ( cnt==0 ) - break; - p->lut[i] = val; - palette_max_val = max(palette_max_val, val); - pal_cnt+=cnt; - } - if (i==1) - p->lut[i++] = 0; // palette size of 1 is not supported, make it 2 - - // Heuristic for when to use the palette. If more than half of the - // weights are in the palette then we use it. This ensures we don't - // use palette for e.g. rectangular distributions. - int palbits_val; - if (pal_cnt > all_cnt/2) { - p->palsize = i; - palbits_val = palette_max_val; - } else { - // No palette - p->palsize = 0; - // If no palette, then palbits is used to specify the - // number of bits required for uncompressed mode, i.e. - // the number of bits for the greatest weight value - palbits_val = all_max_val; - } - - // the palette entry bit width - // minimum 2bits (because PALBITS is in range 2..9) - int palbits=2; - while( (1<palbits = palbits; - p->use_zero_runs = use_zero_runs; -} - -// Return 1 if zero runs should be used -// If palette_size is 512, then palette is not used (in that case the palette is setup -// with the standard alternating unsigned to signed mapping) -static int find_palette( const int16_t *inbuf, int inbuf_size, palette_t *p) { - int freq[512], i; - - // Calculate frequencies of the given weight stream - calc_freq( inbuf, inbuf_size, freq); - - // Find two most common values - int most_common_freq[2]={0}, most_common_val[2]={0}; - for(i=0; i<512; i++) { - if ( freq[i] > most_common_freq[0] ) { - most_common_freq[1] = most_common_freq[0]; - most_common_val[1] = most_common_val[0]; - most_common_freq[0] = freq[i]; - most_common_val[0] = i-256; - } else if ( freq[i] > most_common_freq[1] ) { - most_common_freq[1] = freq[i]; - most_common_val[1] = i-256; - } - } - - // Decide if zero-runs (alternating mode) should be used: - // * zero should be the most common symbol - // * zero should be sufficiently more common than the second most common symbol - int use_zero_runs = most_common_val[0]==0 && most_common_freq[0] > ZERO_RUN_THRES*most_common_freq[1]; - - // Create the palette - create_palette( freq, use_zero_runs, p); - - return use_zero_runs; -} - -static void create_inverse_palette( palette_t *p) { - int i; - memset( p->inv_lut, 0, sizeof(p->inv_lut)); - for(i=0; i<512; i++) { - int val = i; - int sign = val&1; - int mag = val>>1; - int weight = sign ? -mag : mag; - int index = weight+256; - if (index >= 0 && index < 512) - p->inv_lut[ index ] = i + p->palsize - p->direct_offset; - } - for(i=0; ipalsize; i++) { - int val = p->lut[i]; - int sign = val&1; - int mag = val>>1; - int weight = sign ? -mag : mag; - int index = weight+256; - assert(index >= 0 && index < 512); - if (index >= 0 && index < 512) - p->inv_lut[ index ] = i; - } -} - -#define NWCFG 13 -#define NZCFG 4 // restrict search to ZDIV=0..3 -#define MAX_ZWCFG (max(NWCFG,NZCFG)) - -// search state -typedef struct search_state { - int bitcnt; // number of bits to reach this state - uint8_t prev_cfg; // previous grc parameter config -} search_state_t; - -// (trunc<<4) | div, 0x20 means uncompressed -static const uint8_t w_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20 }; -static const uint8_t z_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04 }; - - - -// An algorithm similar to the Viterbi algorithm is used to search for a -// good GRC parameter sequence for the given input value sequence. -// The inval buffer can contain weights, weight indices or runs. -// The return value is the resulting number of bitstream sections. -static int search_grc_params( const int *inval_buf, - int n_inval, - int zrun_mode, - int uncompressed_bits, - uint8_t *grc_param_cfg, - int *grc_param_pos, - int max_grc_param_cfg, - int *existing_grc_param_pos, - int n_existing_grc_param_pos, - int *bitcnt ) -{ - int n_cfg = zrun_mode ? NZCFG : NWCFG; - const uint8_t *grc_params = zrun_mode ? z_grc_params : w_grc_params; - int i,j; - - search_state_t *state[MAX_ZWCFG]; - for(i=0; i>4; - int q = value>>div; - int bits = trunc ? min(q+1,2) + div : q+1+div; - if (!zrun_mode && ((trunc && q>2) || q>31)) - bits=10000; // it's not possible to code the current value; give it a high cost - if (trunc==2) - bits=uncompressed_bits; - - if ( best_bitcnt + cmd_cost < state[j][i].bitcnt ) { - // Change GRC parameters - state[j][i+1].prev_cfg = best_cfg; - state[j][i+1].bitcnt = best_bitcnt + cmd_cost + bits; - } else { - // Keep same GRC parameters - state[j][i+1].prev_cfg = j; - state[j][i+1].bitcnt = state[j][i].bitcnt + bits; - } - } - } - - - // Best GRC parameter - int best_bitcnt=0x7fffffff, best_cfg=0; - for(j=0; j=0; i--) { - if (state[cfg][i].prev_cfg != cfg || i==0) { - n_cmds++; - cfg = state[cfg][i].prev_cfg; - } - } - - (void)(max_grc_param_cfg); - assert(n_cmds<=max_grc_param_cfg); - - cfg = best_cfg; - j=n_cmds-1; - int endpos=n_inval; - for(i=n_inval; i>=0; i--) { - if (state[cfg][i].prev_cfg != cfg || i==0) { - grc_param_cfg[j] = cfg; - grc_param_pos[j] = endpos; - j--; - cfg = state[cfg][i].prev_cfg; - endpos = i-1; - } - } - assert(j==-1); - - for(i=0; ibuf = buf; - bb->pos = 0; - bb->buf_size = size; - bb->log_symbols = log_symbols; -} - -static void bitbuf_putbit( bitbuf_t *bb, uint8_t bit) { - int byte_pos = bb->pos>>3; - uint8_t bit_pos = bb->pos&7; - assert( byte_pos >= 0 ); - assert( byte_pos < bb->buf_size ); - bb->buf[ byte_pos ] = ((bb->buf[ byte_pos ] & ~(1U<pos += 1; -} - -static void bitbuf_put( bitbuf_t *bb, const char *name, int len, int data) { - int i; - if (len>0) { - if (bb->log_symbols) - printf("bitbuf: pos %3d %7s len %d data %x\n", bb->pos, name, len, data); - for(i=0; i>i)&1)); - } - } -} - -// Return new bitpos -static int encode_slice( const int *w_value, - const int *z_value, - int nvalues, - palette_t *p, - int new_palette, - int uncompressed_bits, - int w_cfg, - int z_cfg, - uint8_t *bitbuf, - int bitbuf_size, - int bitpos, - int verbose ) -{ - int i,j; - bitbuf_t bitbuf_s, *bb=&bitbuf_s; - bitbuf_init( bb, bitbuf, bitbuf_size, verbose&2?1:0 ); - bb->pos = bitpos; - - assert(nvalues<32768); - if (w_cfg < 0 || z_cfg < 0) - return bitpos; - // GRC parameters for this slice - int w_grc_div = w_grc_params[w_cfg] & 15; - int w_grc_trunc = (w_grc_params[w_cfg] >> 4)==1; - int w_uncompressed = (w_grc_params[w_cfg] >> 4)==2; - int z_grc_div = z_grc_params[z_cfg] & 15; - - if (w_uncompressed) { - w_grc_div = uncompressed_bits; - } - - int zdiv = p->use_zero_runs ? z_grc_div : ZDIV_DISABLE; - int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED; - - if (verbose&1) { - printf("slice: bitoffset %7d slicelen %5d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %2d\n", - bb->pos, nvalues, zdiv, wdiv, w_grc_trunc, new_palette, p->palbits, p->palsize); - } - - // Write slice header - bitbuf_put( bb, "ZDIV", 3, zdiv); - bitbuf_put( bb, "SLICELEN", 15, nvalues-1 ); - bitbuf_put( bb, "WDIV", 3, wdiv); - bitbuf_put( bb, "WTRUNC", 1, w_grc_trunc ); - bitbuf_put( bb, "NEWPAL", 1, new_palette ); - if (new_palette) { - bitbuf_put( bb, "DIROFS", 5, p->direct_offset ); - bitbuf_put( bb, "PALSIZE", 5, max(0, p->palsize-1)); - bitbuf_put( bb, "PALBITS", 3, p->palbits-2 ); - for(i=0; ipalsize; i++) { - bitbuf_put( bb, "PALETTE", p->palbits, p->lut[i] ); - } - } - - int z_nvalues = nvalues + (new_palette?1:0); - int w_pos=0, z_pos=0; - int w_unary0=0, w_unary1=0, w_unary1_len=0, w_q=-1, w_r=0; - int z_unary=0, z_q=-1, z_r=0; - int w_nsymbols=0, w_remain[12]={0}; - int w_prev_enable=0, w_prev_nsymbols=0, w_prev_remain[12]={0}; - int z_nsymbols=0, z_remain[12]={0}; - int z_prev_enable=0, z_prev_nsymbols=0, z_prev_remain[12]={0}; - int z_unary_len = z_grc_div<3 ? 12 : 8; - do { - int balance = p->use_zero_runs ? w_pos - z_pos : 0; - int w_enable = balance<8 && w_pos=0 && p->use_zero_runs && z_pos5 ? 8 : 12; - while(j>w_grc_div; - w_r = value&((1<=0 && j0 ? (1<0) { - w_unary1 |= w_q>1 ? (1<=0) { - w_remain[w_nsymbols] = w_r; - w_nsymbols++; - w_pos++; - } - } - } - - if (z_enable) { - // Encode chunk (zrun) - j=0; - z_nsymbols=0; - z_unary=0; - while(j>z_grc_div; - z_r = value&((1<=0 && j0 ? (1<=0) { - z_remain[z_nsymbols] = z_r; - z_nsymbols++; - z_pos++; - } - } - } - - // Write chunk to bitstream - if (w_enable && !w_uncompressed) { - bitbuf_put( bb, "WUNARY0", 12, w_unary0); - } - if (z_enable) { - bitbuf_put( bb, "ZUNARY", z_unary_len, z_unary); - } - if (w_enable && !w_uncompressed) { - bitbuf_put( bb, "WUNARY1", w_unary1_len, w_unary1); - } - if (w_prev_enable) { - for(i=0; ipos; -} - -// return new bitpos -static int encode_section( const int16_t *inbuf, - int size, - palette_t *p, - uint8_t *bitbuf, - int bitbuf_size, - int bitpos, - int verbose ) -{ - int uncompressed_bits; - - // Uncompressed mode can only be used if either all weights - // are in the palette OR if the palette is not used. - if (p->only_palette) { - // Uncompressed bits derived from palette size - uncompressed_bits=0; - while( (1<palsize ) - uncompressed_bits++; - } else if (p->palsize==0) { - // Uncompressed bits is palbits (which is the bitdepth of the greatest weight) - uncompressed_bits = p->palbits; - } else { - // Don't use uncompressed - uncompressed_bits = 100; - } - - uint8_t *w_slice_cfg=0; - uint8_t *z_slice_cfg=0; - int *w_slice_pos=0; - int *z_slice_pos=0; - int *weight_values =0; - int *zrun_values = 0; - do { - CHECKED_MALLOC( weight_values, size*sizeof(int) ); - CHECKED_MALLOC( zrun_values, size*sizeof(int) ); - - // Get weights (or weight indicies) AND zero-runs from the input weight stream. - int i=0, n_weights = 0, zcnt; - while(1) { - if (p->use_zero_runs) { - zcnt=0; - // Count zero run - // Special case: if all weights in the section are zero, we must - // still ensure we have one coded weight so the the slice length - // doesn't become 0. Therefore we skip the first zero run and code - // the zero explicitly as a weight value instead - if (!p->only_zeros || i>0) { - while( iinv_lut[inbuf[i]+256]; - weight_values[n_weights] = value; - n_weights++; - i++; - } - - // Search for good GRC parameters for the weight stream - int n_w_slice, w_bitcnt; - CHECKED_MALLOC( w_slice_cfg, size ); - CHECKED_MALLOC( w_slice_pos, size*sizeof(int) ); - n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt); - if ( n_w_slice < 0 ) { // Memory allocation failed - bitpos = -1; - break; - } - if (n_weights==0) - n_w_slice = 0; - - // Search for good GRC parameters for the zrun stream - int n_z_slice=0, z_bitcnt=0; - if (p->use_zero_runs) { - CHECKED_MALLOC( z_slice_cfg, size ); - CHECKED_MALLOC( z_slice_pos, size*sizeof(int) ); - n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt); - if ( n_z_slice < 0 ) { // Memory allocation failed - bitpos = -1; - break; - } - } - - // Encode bitstream slice - int pos=0, i_w_slice=0, i_z_slice=0, new_palette=1; - while(posuse_zero_runs ? zrun_values+pos+(!new_palette) : 0; - bitpos = encode_slice( weight_values+pos, zrun_buf, len, - p, new_palette, uncompressed_bits, - w_slice_cfg[i_w_slice], p->use_zero_runs ? z_slice_cfg[i_z_slice] : 0, - bitbuf, bitbuf_size, bitpos, verbose ); - new_palette = 0; - - if (i_w_sliceuse_zero_runs) { - free(z_slice_cfg); - free(z_slice_pos); - } - free(weight_values); - free(zrun_values); - - return bitpos; -} - -// Encode the given weight stream -// inbuf uncompressed 9bit signed weights -// inbuf_size number of weights -// outbuf compressed bitstream, buffer is malloced within this function -// verbose if non-zero, printf log -// Return value is the size in bytes of the compressed output -// Return -1 if error -int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { - int i; - // Range check - for(i=0; i255) { - printf("ERROR: weight out of range at index %d, weight value is %d (valid range is -255..255)\n", i, inbuf[i]); - return -1; - } - } - - int bitbuf_size = inbuf_size*2+1024; - assert(*outbuf == NULL); - *outbuf = malloc( bitbuf_size ); - if (!*outbuf) - { // Failed to allocate buffer - return -1; - } - - // Analyse input data to find palette re-programming points - int *palette_restart_pos = NULL; - int n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos); - - // Compress each section (using a single palette) separately - int bitpos = 0; - for ( i = 0; i < n_restarts && bitpos >= 0; i++ ) { - palette_t palette; - int pos, size; - pos = palette_restart_pos[i]; - size = (i= 0 && n_restarts >= 0 ) { // If allocation fails bitpos or n_restarts < 0 - // Add end of stream marker and align to 128bit - bitbuf_t bitbuf_s, *bb=&bitbuf_s; - bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 ); - bb->pos = bitpos; - bitbuf_put( bb, "ZDIV", 3, ZDIV_EOS); - bitbuf_put( bb, "BYTEALIGN", (8-(bb->pos&7))&7, 0xff ); - - // Pad with 0xff until 64bit aligned - while( bb->pos & 127 ) { - bitbuf_put( bb, "PAD", 8, 0xff ); - } - bitpos = bb->pos; - - assert((bitpos&127)==0); - int outbuf_size = bitpos/8; - *outbuf = realloc(*outbuf, outbuf_size); - if ( *outbuf ) { - ret = outbuf_size; - } - } - - free(palette_restart_pos); - - return ret; -} - -void mlw_free_outbuf( uint8_t *outbuf ) { - if (outbuf) - free(outbuf); -} - -struct brick_buf_s -{ - int16_t* buf; - int* strides; -}; -typedef struct brick_buf_s brick_buf_t; - -static int16_t get_brick_weight(brick_buf_t* buf, int ofm_z, int wy, int wx, int ifm_z) -{ - int16_t* p = buf->buf; - - p += ofm_z * buf->strides[0]; - p += wy * buf->strides[1]; - p += wx * buf->strides[2]; - p += ifm_z * buf->strides[3]; - - return *p; -} - -static void reorder_free(int16_t* buf) -{ - if (buf) - { - free(buf); - } -} - -static int16_t* reorder( - int ifm_ublock_depth, - int ofm_ublock_depth, - int ofm_depth, - int kernel_height, - int kernel_width, - int ifm_depth, - int* strides, - int16_t* inbuf, - int ofm_block_depth, - int is_depthwise, - int is_partkernel, - int ifm_bitdepth, - int decomp_h, - int decomp_w, - int64_t* padded_length) -{ - *padded_length = -1; - /* Size unknown. Start with one page at least */ - int64_t length = round_up(max(1, sizeof(int16_t)* - ofm_depth* - kernel_height* - kernel_width* - ifm_depth), - 4*1024) / sizeof(int16_t); - int16_t* weights = (int16_t*)malloc(length * sizeof(int16_t)); - if (!weights) - { // Alloc failed, so exit - return NULL; - } - - brick_buf_t brick_buf; - brick_buf.buf = inbuf; - brick_buf.strides = strides; - - int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32; - int64_t weight_cnt = 0; - for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth) - { - int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z); - // IFM blocks required for the brick - for (int ifm_block_z = 0; ifm_block_z < (is_depthwise ? 1 : ifm_depth); ifm_block_z += ifm_block_depth) - { - int clipped_ifm_block_depth; - if (is_depthwise) - { - clipped_ifm_block_depth = ifm_ublock_depth; - } - else - { - clipped_ifm_block_depth = is_partkernel ? - min(ifm_block_depth, ifm_depth - ifm_block_z) : ifm_block_depth; - } - // Weight decomposition - // Subkernel Splitting (H) - for (int subkernel_y = 0; subkernel_y < kernel_height; subkernel_y += decomp_h) - { - int sub_height = min(kernel_height - subkernel_y, decomp_h); - // Subkernel splitting (W) - for (int subkernel_x = 0; subkernel_x < kernel_width; subkernel_x += decomp_w) - { - int sub_width = min(kernel_width - subkernel_x, decomp_w); - int subkernel_elements = sub_width * sub_height; - // Part kernel first works across the kernel H/W and needs padding - if (is_partkernel) - { - if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0) - { - subkernel_elements = round_up(subkernel_elements, 2); - } - else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0) - { - subkernel_elements = round_up(subkernel_elements, 4); - } - } - else if (is_depthwise) - { - subkernel_elements = round_up(subkernel_elements, 4); - } - int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1; - int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth; - for (int ifm_ublk_outer = 0; ifm_ublk_outer < ifm_block_depth_outer; ifm_ublk_outer += ifm_ublock_depth) - { - // OFM Ublocks in OFM-block over depth - for (int ofm_ublk = 0; ofm_ublk < clipped_ofm_block_depth; ofm_ublk += ofm_ublock_depth) - { - // HW Kernel element traversal - cannot be a H/W loop due to element - // padding requirement on depthwise/part-kernel configurations - for (int element = 0; element < subkernel_elements; element++) - { - int kx = element % sub_width; - int ky = element / sub_width; - // IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise) - // In case of part-kernel-first IFM Ublock traversal have already been handled - // and this loop is ignored. - for (int ifm_ublk_inner = 0; ifm_ublk_inner < ifm_block_depth_inner; ifm_ublk_inner += ifm_ublock_depth) - { - // Feed OFM ublock elements - for (int ofm_ublock_z = 0; ofm_ublock_z < ofm_ublock_depth; ofm_ublock_z++) - { - // Source IFM ublock elements (only 1 element deep if depthwise) - for (int ifm_ublock_z = 0; ifm_ublock_z < (is_depthwise ? 1 : ifm_ublock_depth); ifm_ublock_z++) - { - // Source position within the current subkernel - int wx = subkernel_x + kx; - int wy = subkernel_y + ky; - // Source IFM/OFM slices - int ifm_ublk = ifm_ublk_inner + ifm_ublk_outer; - int ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z; - int ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z; - if ((ifm_z < ifm_depth) && (ofm_z < ofm_depth) && (ky < sub_height)) - { - weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z); - //fprintf(stderr, "weights[%ld] %d ofm_z %d wy %d wx %d ifm_z %d\n", weight_cnt, weights[weight_cnt], ofm_z, wy, wx, ifm_z); - } - else - { - weights[weight_cnt] = 0; - } - weight_cnt++; - if (weight_cnt == length) - { - // Reallocate by doubling the buffer size as needed - length *= 2; - weights = (int16_t*)realloc(weights, length * sizeof(int16_t)); - if (!weights) - { // Realloc failed, so exit - return NULL; - } - } - } - } - } - } - } - } - } - } - } - } - - - weights = (int16_t*)realloc(weights, weight_cnt * sizeof(int16_t)); - if ( weights ) { - *padded_length = weight_cnt; - } - - return weights; -} - -// Reorder and encode the given weight stream -// Return value is the size in bytes of the compressed output -// Return -1 if error -int mlw_reorder_encode( - int ifm_ublock_depth, - int ofm_ublock_depth, - int ofm_depth, - int kernel_height, - int kernel_width, - int ifm_depth, - int* brick_strides, - int16_t* inbuf, - int ofm_block_depth, - int is_depthwise, - int is_partkernel, - int ifm_bitdepth, - int decomp_h, - int decomp_w, - uint8_t **outbuf, // *outbuf must be freed by caller - int64_t* padded_length, - int verbose) -{ - if (verbose) { - fprintf(stderr, "mlw_reorder_encode: %d %d %d %d %d %d (%d %d %d %d) %d %d %d %d %d %d\n", ifm_ublock_depth, - ofm_ublock_depth, - ofm_depth, - kernel_height, - kernel_width, - ifm_depth, - brick_strides[0], - brick_strides[1], - brick_strides[2], - brick_strides[3], - ofm_block_depth, - is_depthwise, - is_partkernel, - ifm_bitdepth, - decomp_h, - decomp_w); - } - /* Reorder weights */ - int16_t* weights = reorder( - ifm_ublock_depth, - ofm_ublock_depth, - ofm_depth, - kernel_height, - kernel_width, - ifm_depth, - brick_strides, - inbuf, - ofm_block_depth, - is_depthwise, - is_partkernel, - ifm_bitdepth, - decomp_h, - decomp_w, - padded_length); - - /* Then encode */ - int output_length = -1; - if (*padded_length > 0 && *padded_length <= INT32_MAX) - { - output_length = mlw_encode(weights, (int)*padded_length, outbuf, verbose); - } - reorder_free(weights); - - return output_length; -} diff --git a/src/gallium/drivers/ethosu/mlw_codec/mlw_encode.h b/src/gallium/drivers/ethosu/mlw_codec/mlw_encode.h deleted file mode 100644 index 3162031e69d..00000000000 --- a/src/gallium/drivers/ethosu/mlw_codec/mlw_encode.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates - * - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the License); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#ifndef MLW_ENCODE_H -#define MLW_ENCODE_H - -#ifdef _MSC_VER - #define MLW_ENCODE_EXPORTED __declspec(dllexport) -#else - #define MLW_ENCODE_EXPORTED __attribute__((visibility("default"))) -#endif - -#if __cplusplus -extern "C" -{ -#endif - -MLW_ENCODE_EXPORTED -int mlw_encode(int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose); - -MLW_ENCODE_EXPORTED -void mlw_free_outbuf(uint8_t *outbuf); - -MLW_ENCODE_EXPORTED -int mlw_reorder_encode( - int ifm_ublock_depth, - int ofm_ublock_depth, - int ofm_depth, - int kernel_height, - int kernel_width, - int ifm_depth, - int* brick_strides, - int16_t* inbuf, - int ofm_block_depth, - int is_depthwise, - int is_partkernel, - int ifm_bitdepth, - int decomp_h, - int decomp_w, - uint8_t **outbuf, - int64_t* padded_length, - int verbose); - -#if __cplusplus -} -#endif - -#endif diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/ml_bit_buffer.hpp b/src/gallium/drivers/ethosu/mlw_codec/source/ml_bit_buffer.hpp new file mode 100644 index 00000000000..61369d75340 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/ml_bit_buffer.hpp @@ -0,0 +1,255 @@ +// +// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#if !defined ML_BIT_BUFFER_HPP +#define ML_BIT_BUFFER_HPP + +#pragma once + +#include "ml_raw_buffer.hpp" + +#include + +struct bitbuf_t +{ +private: + uint32_t *_buf; + uint32_t _next = 0; + int _pos; // bit pos of next bit + int _limit; // in bytes + int _substream_start = 0; // start position for substreams + int _substream_end = 0; // end position for substreams + bool _enabled = false; + raw_buffer_t *_buffer; +public: + // Read constructor + bitbuf_t(const void *buf, int used_bytes) : _buffer(nullptr) + { + _limit = used_bytes & ~3; + _buf = reinterpret_cast(const_cast(buf)); + _pos = 0; + } + + // Write constructor + bitbuf_t(raw_buffer_t &buffer, int reserve_bytes, bool disable_writes) : _buffer(&buffer) + { + _enabled = !disable_writes; + if ( _enabled ) _buffer->reserve(reserve_bytes); + _limit = _buffer->capacity() & ~3; + prime( _buffer->used() * 8 ); // Start at the end of the buffer's used data + } + + // Sub-stream writer + bitbuf_t(bitbuf_t &dest, int bitpos, int bitlen=0) : _buffer(dest._buffer) + { + _limit = _buffer->capacity() & ~3; + _substream_start = bitpos; + bitlen = (bitlen <= 0) ? (_limit * 8) - bitpos : bitlen; // Default to rest of buffer + _substream_end = bitpos + bitlen; + int required = (_substream_end + 7 ) / 8; + assert( required <= _limit ); + _enabled = dest._enabled; + prime( bitpos ); + } + +public: + void put(int len, int32_t data) + { + assert( _buf == reinterpret_cast(_buffer->begin()) && "Buffer resized externally" ); + assert( (data & ((1 << len)-1)) == data && "Data must be pre-masked" ); + assert( ((_substream_end == 0) || (_pos + len <= _substream_end)) && "Write past end of substream section" ); + if ( len > 0 && _enabled ) + { + uint32_t next = _next; + int bitpos = _pos & 0x1F; + next |= uint32_t(data) << bitpos; + + if ( len >= (32 - bitpos) ) + { + // Write won't fit, reserve more output space + if ( (_pos / 8) >= _limit ) + { + extend(); + } + + _buf[_pos >> 5] = next; // Requires little-endian + next = uint32_t(data) >> (32 - bitpos); + } + + _next = next; + } + + _pos += len; + } + + void put_masked(int len, int32_t data) + { + put(len, data & ((1 << len)-1)); + } + + void fill(int len, unsigned bit) + { + const uint32_t mask = 0xFFFFFFFF * bit; + int remain = len; + while ( remain >= 32 ) + { + put(32, mask); + remain -= 32; + } + if (remain > 0) + put(remain, mask & ((1u << remain) - 1) ); + } + + void align(int bits, int fill_byte) + { + // Alignments must be power of 2 + assert( (bits & (bits - 1)) == 0 && bits ); + const int mask = bits-1; + + int distance = (bits - (_pos & mask)) & mask; + + // Byte align first + put_masked( distance & 7, fill_byte ); + distance &= ~7; + while (distance != 0) + { + put(8, fill_byte); + distance -= 8; + } + } + + void reposition(int bitpos) + { + int end = (_substream_end != 0) ? _substream_end : _limit * 8; + assert( (bitpos >= 0 && bitpos <= end) && "Can't reposition out of stream" ); + assert( (_substream_end == 0 || bitpos >= _substream_start) && "Can't reposition before substream"); + if ((_pos != bitpos) && (bitpos > 0) && (bitpos <= end) ) + { + // Reposition in bitstream. Caller must flush if writing. + prime(bitpos); + } + } + + void flush(bool done=true) + { + if ( !_enabled ) return; + // If buffering word is not empty, write it out as-is. + if ( _pos & 0x1F ) + { + // If writing a substream, blend any overlapping words with the parent stream + if ( _substream_end ) + { + int remain = _substream_end - (_pos & ~0x1F); // Remaining word-bits in this substream + if ( remain < 32 ) // Will overlap happen? + { + uint32_t mask = ~0u << (32 - remain); // Mask the parent bits that we want to keep + _next = (_buf[_pos >> 5] & mask) | (_next & ~mask); + } + } + // Otherwise limited by the buffer + else + { + // Only extend by space required to flush remaining word. + if ( (_pos / 8) >= _limit ) + { + extend(true); + } + } + _buf[_pos >> 5] = _next; + } + if ( done && !_substream_end ) + { + _buffer->set_used( _pos / 8 ); + } + } + + void sync(bitbuf_t &substream) + { + flush(false); + substream.flush(false); + prime( std::max(_pos, substream._pos) ); + substream._buffer = nullptr; + } + + int get(int len) + { + if ( len == 0 ) + { + return 0; + } + + const unsigned mask = (1u << len) - 1; + assert( (_pos / 8) < _limit ); + uint32_t next = _buf[_pos >> 5]; + int bitpos = _pos & 0x1F; + // Bits from this word + unsigned value = next >> bitpos; + _pos += len; + + // Some of the bits are in the next word + if ( len > (32 - bitpos) ) + { + assert( (_pos / 8) < _limit ); + next = _buf[_pos >> 5]; + value |= next << (32 - bitpos); + } + + return int(value & mask); + } + + void read_align(int bits) + { + // Alignments must be power of 2 + assert( (bits & (bits - 1)) == 0 && bits ); + const int mask = bits-1; + _pos += (bits - (_pos & mask)) & mask; + } + + bool read_eos() const { return _pos/8 >= _limit; } + + int read_avail() const { return (_limit - (_pos / 8)) * 8 - (_pos & 7); } + int read_avail(int watermark) const { return (watermark - (_pos / 8)) * 8 - (_pos & 7); } + + int pos() const { return _pos; } + int byte_pos() const { return _pos / 8; } + int byte_length() const { return _limit; } + +private: + void prime(int bitpos) + { + assert( (bitpos >= 0) && (bitpos / 8) < _limit ); + // Prime (start up) the bitstream writer at the given bit position + _pos = bitpos; + _buf = reinterpret_cast(_buffer->begin()); + _next = _buf[bitpos >> 5]; + _next &= (1u << (bitpos & 0x1F)) - 1; + } + + void extend(bool exact_resize=false) + { + assert(_enabled); + _buffer->set_used( (_pos / 8) & ~3 ); // Only use whole words + _buffer->reserve( sizeof(uint32_t), exact_resize ); // Buffer implementation must optimise small requests + assert( (_buffer->capacity() & ~3) > _limit ); + assert( _substream_end == 0 ); // Can't extend a substream + _limit = _buffer->capacity() & ~3; + _buf = reinterpret_cast(_buffer->begin()); + } +}; + +#endif // ML_BIT_BUFFER_HPP diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/ml_encoder_internal.hpp b/src/gallium/drivers/ethosu/mlw_codec/source/ml_encoder_internal.hpp new file mode 100644 index 00000000000..f4f3202b4f7 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/ml_encoder_internal.hpp @@ -0,0 +1,130 @@ +// +// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#if !defined ML_ENCODER_INTERNAL_HPP +#define ML_ENCODER_INTERNAL_HPP + +#pragma once + +#include "../include/mlw_encode.h" +#include "ml_bit_buffer.hpp" + +#include +#include + +#if __GNUC__ + #define ML_ENCODER_DLL_EXPORT __attribute__((visibility("default"))) +#elif _WIN32 + #if TARGET_WIN32_DLL + #define ML_ENCODER_DLL_EXPORT __declspec(dllexport) + #else + #define ML_ENCODER_DLL_EXPORT + #endif +#else + #error "undefined export semantics" +#endif + +#if !defined ENABLE_DEBUG_PACKET + #define ENABLE_DEBUG_PACKET (0) +#endif + +#if !defined ENABLE_DEBUG_BITSTREAM + #define ENABLE_DEBUG_BITSTREAM (0) +#endif + +#if ENABLE_DEBUG_PACKET + #include + #define PACKET_LOG(...) printf(__VA_ARGS__) +#else + #define PACKET_LOG(...) +#endif + +#if ENABLE_DEBUG_BITSTREAM + #include + #define BITSTREAM_LOG(...) printf(__VA_ARGS__) +#else + #define BITSTREAM_LOG(...) +#endif + +constexpr int ETHOSU_SLICELEN_BITS = 15; +constexpr int ZDIV_DISABLE = 6; // not alternating mode +constexpr int ZDIV_EOS = 7; // indicates end of stream +constexpr int WDIV_UNCOMPRESSED = 7; // indicates uncompressed weights + +struct palette_t +{ + int16_t lut[32] = {0}; + int16_t inv_lut[512] = {0}; + int freq[512] = {0}; + int palsize; // number of palette entries + int palbits; // bit width of palette entries + int direct_offset; // added to the decoded weight index before direct conversion to sign/mag + bool use_zero_runs; // zeros are coded separately + bool only_palette; // no values outside the palette + bool only_zeros; // special case that the section is all zeros +}; + +struct slice_params_t +{ + uint8_t w_grc_trunc; + bool w_uncompressed; + uint8_t z_grc_div; + uint8_t w_grc_div; +}; + +struct mle_slice_debug_t +{ + slice_params_t params; + palette_t palette; +}; + +struct mle_context_t +{ + palette_t palette; + uint64_t weights_used[512 / 64] = {0}; + int distinct_weights = 0; + int syntax = 0; + int zero_count = 0; + int slicelen_bits = ETHOSU_SLICELEN_BITS; + bool palette_valid = false; + bool single_slice_sections = false; + bool allow_empty_slices = false; + bool eos_required = false; + bool enable_slice_debug = false; + bool disable_lut = false; + int8_t fixed_wgrc = -1; + int8_t fixed_zgrc = -1; + std::vector slice_debug; + void* (*realloc_func)(void*, size_t, int); // Custom output allocator function +}; + +inline int div_round_up(int num, int div) +{ + return (num + div - 1) / div; +} + +int ml_encode_fwd(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int encode_count, unsigned mlw_encode_flags); +int ml_encode_section(mle_context_t *ctx, const int16_t *inbuf, int size, palette_t *p, bitbuf_t *bitbuf); +palette_t *ml_encode_palette(mle_context_t *ctx, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags); +void ml_encode_eos(mle_context_t *ctx, bitbuf_t &bits, unsigned mlw_encode_flags); +int ml_encode_internal(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags); + +#endif // ML_ENCODER_INTERNAL_HPP + + + diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/ml_ethosu_encode.cpp b/src/gallium/drivers/ethosu/mlw_codec/source/ml_ethosu_encode.cpp new file mode 100644 index 00000000000..cd895863787 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/ml_ethosu_encode.cpp @@ -0,0 +1,120 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "../include/mlw_encode.h" + +#include "ml_bit_buffer.hpp" +#include "ml_raw_buffer.hpp" +#include "ml_encoder_internal.hpp" + +#include +#include + +#if defined __cplusplus +extern "C" +{ +#endif + + +ML_ENCODER_DLL_EXPORT int32_t ml_encode_ethosu_stream(ml_encode_result_t *result, const ml_ethosu_encode_params_t *ep, ml_weight_source_fn src, void *user_arg, mle_context_t **ctx_out) +{ + constexpr int BUFFERING_REQUEST_SIZE = 8192; // Initial input buffering + constexpr int INITIAL_OUTPUT_BUFFER = 8192; // Initial size of output buffer (doubles at every overflow) + constexpr unsigned VALID_FLAGS = MLW_ENCODE_NO_BITSTREAM; + + assert(result && ep); + if ( !(result && ep && src) ) + { + return -1; + } + + mle_context_t *ctx = mle_create_context(MLW_ENCODE_SYNTAX_ETHOSU); + // Allow forcing parameters for debug validation - it is expected that + // the caller knows what they're doing here since it accesses the opaque + // internals via the public interface. + + assert( !(ep->encoder_flags & ~(VALID_FLAGS)) ); // Check acceptable flags + unsigned ethosu_encode_flags = (ep->encoder_flags & VALID_FLAGS); + + // Input buffering of data from the source function + assert( ep->source_buffering_hint >= 0 ); + int request_size = std::max(BUFFERING_REQUEST_SIZE, ep->source_buffering_hint & 0x00FFFFFF); + raw_buffer_t buffer( request_size ); + + // The source function will communicate the state to this encoding loop + ml_source_state_t state = {0}; + state.eos = false; + + // Output bitstream allocation + raw_buffer_t output(INITIAL_OUTPUT_BUFFER, MLW_ENCODE_ALLOC_STREAM0, ep->realloc_func); + bitbuf_t bits(output, 8, ethosu_encode_flags & MLW_ENCODE_NO_BITSTREAM); + + result->source_length = 0; + + // Repeatedly ask for values until the source function signals end-of-stream. + while ( !state.eos ) + { + int16_t *buf_write = buffer.reserve(request_size); + if ( !buf_write ) + { + // Memory allocation failed + mle_destroy_context(ctx); + return -1; + } + int received = (*src)(MLW_SOURCE_QUERY_WEIGHTS, &state, buf_write, buffer.capacity() - buffer.used(), user_arg); + buffer.use(received); + + unsigned encode_flags = ethosu_encode_flags; + encode_flags |= MLW_ENCODE_NEW_PALETTE; + + int bytes_written = ml_encode_internal(ctx, bits, buffer.begin(), buffer.used(), buffer.used(), encode_flags); + if ( bytes_written < 0 ) + { + // Encoder errored + mle_destroy_context(ctx); + return -1; + } + result->source_length += buffer.used(); + buffer.clear(); + } + + ml_encode_eos(ctx, bits, ethosu_encode_flags); + + // Populate the return result + assert(bits.byte_pos() == output.used() || (ethosu_encode_flags & MLW_ENCODE_NO_BITSTREAM)); + result->encoded_length = bits.byte_pos(); + result->encoded_data = output.detach(); + result->section_info = nullptr; + result->section_count = 0; + + if (ctx_out != nullptr) + { + assert( *ctx_out == nullptr ); + *ctx_out = ctx; + } + else + { + mle_destroy_context(ctx); + } + return 1; +} + + +#if defined __cplusplus +} // extern "C" +#endif diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/ml_raw_buffer.hpp b/src/gallium/drivers/ethosu/mlw_codec/source/ml_raw_buffer.hpp new file mode 100644 index 00000000000..16b3a8d9d23 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/ml_raw_buffer.hpp @@ -0,0 +1,162 @@ +// +// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#if !defined ML_RAW_BUFFER_HPP +#define ML_RAW_BUFFER_HPP + +#pragma once + +#include +#include +#include +#include +#include +#include + +typedef void* (*realloc_t)(void *ptr, size_t size, int); + +template +struct raw_buffer_t +{ + static_assert(std::is_trivially_copyable::value, "expected simple storage type"); + constexpr static int CAPACITY_ALIGN = 16; + TYPE *_data; + int _used; + int _capacity; + int _reallocArg = 0; + realloc_t _realloc=&realloc_proxy; + +public: + raw_buffer_t(int capacity, int arg=0, realloc_t rfunc=nullptr) + { + assert(capacity > 0); + _realloc = (rfunc != nullptr) ? rfunc : &realloc_proxy; + _capacity = (capacity + CAPACITY_ALIGN - 1) & ~(CAPACITY_ALIGN - 1); + _reallocArg = arg; + _data = reinterpret_cast(_realloc(nullptr, _capacity * sizeof(TYPE), _reallocArg)); + _used = 0; + } + + raw_buffer_t(raw_buffer_t &&other) + { + _capacity = other._capacity; + other._capacity = 0; + _data = other._data; + other._data = nullptr; + _used = other._used; + other._used = 0; + _reallocArg = other._reallocArg; + _realloc = other._realloc; + } + + raw_buffer_t(TYPE *data, int used, int capacity) + { + _data = data; + _used = used; + _capacity = capacity; + } + + ~raw_buffer_t() + { + if (_data) + { + _realloc(_data, 0, _reallocArg); + } + } + + TYPE *begin() { return _data; } + TYPE *end() { return _data + _used; } + int used() const { return _used; } + int capacity() const { return _capacity; } + void clear() { _used = 0; } + + const TYPE &operator[](int index) const { assert(index < _used); return _data[index]; } + + void set_used(int used) + { + assert(used >= _used); + assert(used <= _capacity); + _used = used; + } + + TYPE *reserve(int count, bool exact_resize=false) + { + int req_capacity = _used + count; + if ( req_capacity > _capacity ) + { + if ( !exact_resize ) + { + req_capacity = std::max(req_capacity, _capacity * 2); + } + + auto *p = reinterpret_cast( _realloc(_data, req_capacity * sizeof(TYPE), _reallocArg) ); + if ( !p ) + { + return nullptr; + } + _data = p; + _capacity = req_capacity; + } + int at = _used; + return _data + at; + } + + TYPE *use(int count) + { + int at = _used; + _used += count; + return _data + at; + } + + TYPE *detach() + { + auto tmp = _data; + _data = nullptr; + return tmp; + } + + void align(int align_bytes, TYPE fill) + { + int count = (((_used + align_bytes - 1) / align_bytes) * align_bytes) - _used; + TYPE *p = reserve(count); + assert(p); + use(count); + while (count--) + { + *p++ = fill; + } + } + + void remove_left(int count) + { + int to_move = _used - count; + if (to_move >= 0) + { + memmove(_data, _data + count, to_move * sizeof(TYPE)); + } + _used = to_move; + } + +private: + static void *realloc_proxy(void *ptr, size_t size, int) + { + return realloc(ptr, size); + } +}; + +#endif // ML_RAW_BUFFER_HPP diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/mlw_decode.cpp b/src/gallium/drivers/ethosu/mlw_codec/source/mlw_decode.cpp new file mode 100644 index 00000000000..a121f6b6e7c --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/mlw_decode.cpp @@ -0,0 +1,362 @@ +// +// SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "../include/mlw_decode.h" +#include "ml_encoder_internal.hpp" +#include "ml_bit_buffer.hpp" + +#include +#include +#include + +#if NDEBUG || 1 + // Release build get bits macro + #define bitbuf_get(bb_, name_, len_) bb_.get(len_) +#else + #include + + // Debug build get bits macro + inline int bitbuf_get(bitbuf_t &bb, const char *name, int len) + { + assert(len <= 32); + int tmp = bb.get(len); + printf("%6d %s:%d = %d\n", bb.pos() - len, name, len, tmp); + fflush(stdout); + return tmp; + } +#endif + +// Extract and decode weights from the given bitstream +// +// outbuf - output decoded weights +// bb - input bitstream buffer (wherever it is positioned) +// endpos - bitstream end position (in bytes) +// single_slice - process slices one-at-a-time +// slice_len - slice length in bits +// +// Returns - the number of weights extracted from the bitstream +static int ml_decode_internal(raw_buffer_t &outbuf, bitbuf_t &bb, palette_t &palette, int endpos, bool single_slice, int slice_len) +{ + int start_offset = outbuf.used(); + int w_cnt; + int w_grc_div; + int w_grc_trunc; + int w_uncompressed; + int z_grc_div; + int z_prev_grc_div = -1; + bool new_palette; + int i, j; + endpos = std::min(endpos, bb.byte_length()); + + // Loop over all slices + do { + // Decode slice header + int bits_avail = bb.read_avail(endpos); + if ( bits_avail >= 3 ) + { + z_grc_div = bitbuf_get(bb, "ZDIV", 3); + } + else // Insufficient bits left for a terminator (EOS may be optional) + { + z_grc_div = ZDIV_EOS; + int ones = bb.get(bits_avail); + assert( ones == (1 << bits_avail) - 1 ); + } + + while ( z_grc_div == ZDIV_EOS ) + { + // End of stream + // Byte align + bb.read_align(8); + if ( bb.byte_pos() >= endpos || single_slice ) + { + goto labelExit; + } + z_grc_div = bitbuf_get(bb, "ZDIV", 3); + } + if ( bb.read_avail(endpos) <= 0 ) + { + assert(false); + break; // Unexpectedly reached end of the input stream + } + assert(z_grc_div < 4 || z_grc_div == ZDIV_DISABLE); + bool use_zero_runs = z_grc_div != ZDIV_DISABLE; // alternating grc + w_cnt = bitbuf_get(bb, "SLICELEN", slice_len) + (slice_len == ETHOSU_SLICELEN_BITS ? 1 : 0); + w_grc_div = bitbuf_get(bb, "WDIV", 3); + w_grc_trunc = bitbuf_get(bb, "WTRUNC", 1); + new_palette = bitbuf_get(bb, "NEWPAL", 1); + if ( !new_palette ) + { + // At the moment it is not supported to change between alternating + // and non-alternating without redefining the palette (this is because + // the zero is not included in the palette in case of alternating) + bool prev_use_zero_run = z_prev_grc_div != ZDIV_DISABLE; + (void)(prev_use_zero_run); + assert((z_prev_grc_div == -1) || (use_zero_runs == prev_use_zero_run)); + } + z_prev_grc_div = z_grc_div; + if ( new_palette ) + { + palette.direct_offset = bitbuf_get(bb, "DIROFS", 5); + palette.palsize = bitbuf_get(bb, "PALSIZE", 5); + if ( palette.palsize > 0 ) + { + palette.palsize++; + } + palette.palbits = bitbuf_get(bb, "PALBITS", 3) + 2; + for ( i = 0; i < palette.palsize; i++ ) + { + palette.inv_lut[i] = int16_t(bitbuf_get(bb, "PALETTE", palette.palbits)); + } + } + + if ( w_grc_div == WDIV_UNCOMPRESSED ) + { + // Uncompressed mode + w_uncompressed = 1; + int uncompressed_bits; + if ( palette.palsize > 0 ) + { + // Uncompressed bits is given by palette size. + uncompressed_bits = 0; + while ( (1 << uncompressed_bits) < palette.palsize ) + { + uncompressed_bits++; + } + } + else + { + // No palette. PALBITS is used to specify uncompressed bits. + uncompressed_bits = palette.palbits; + } + // In uncompressed mode there's only a remainder part (no unary) + // This is achieved by setting w_grc_div to index bit width + w_grc_div = uncompressed_bits; + } + else + { + w_uncompressed = 0; + assert(w_grc_div < 6); + } + + // Decode the slice + int z_cnt = w_cnt + (( slice_len != ETHOSU_SLICELEN_BITS || new_palette ) ? 1 : 0); + std::vector w_value(w_cnt); + std::vector z_value(z_cnt); + int w_pos = 0, z_pos = 0; + int w_prev_pos = 0, z_prev_pos = 0; + int w_unary0 = 0, w_unary1 = 0, w_unary1_len = 0, w_q[12] = {0}, wq = 0; + int z_unary = 0, z_q[12] = {0}, zq = 0; + int w_nsymbols = 0; + int w_prev_enable = 0, w_prev_nsymbols = 0, w_prev_q[12] = {0}; + int z_nsymbols = 0; + int z_prev_enable = 0, z_prev_nsymbols = 0, z_prev_q[12] = {0}; + int total_zcnt = 0; + int z_unary_len = z_grc_div < 3 ? 12 : 8; + + // Loop over all chunks in the slice + do + { + // Flow control to possibly throttle either the weights or zero-runs + int balance = use_zero_runs ? w_pos - z_pos : 0; + int w_enable = (balance < 8 || !use_zero_runs) && w_pos < w_cnt; + int z_enable = (balance >= 0 && use_zero_runs) && z_pos < z_cnt; + if ( w_enable ) + { + w_unary0 = w_uncompressed ? 0 : bitbuf_get(bb, "WUNARY0", 12); + } + if ( z_enable ) + { + z_unary = bitbuf_get(bb, "ZUNARY", z_unary_len); + z_nsymbols = 0; + for ( i = 0; i < z_unary_len; i++ ) + { + if ( z_unary & (1 << i) ) + { + zq++; + } + else + { + z_q[z_nsymbols++] = zq; + zq = 0; + } + } + z_pos += z_nsymbols; + } + + if ( w_enable ) + { + w_unary1_len = 0; + int max_symbols = w_uncompressed && w_grc_div > 5 ? 8 : 12; + if ( w_unary0 != 0 ) + { + for ( i = 0; i < max_symbols; i++ ) + { + if ( w_unary0 & (1 << i) ) + { + w_unary1_len++; + } + } + } + w_unary1 = (w_unary1_len > 0) ? bitbuf_get(bb, "WUNARY1", w_unary1_len) : 0; + w_nsymbols = 0; + + for ( i = 0; i < max_symbols && (w_nsymbols < (w_cnt - w_pos)); i++ ) + { + int code = 0; + if ( w_unary0 & (1 << i) ) + { + code = 1 + ( w_unary1 & 1 ); + w_unary1 = w_unary1 >> 1; + } + wq += code; + if ( code < 2 || w_grc_trunc ) + { + w_q[w_nsymbols++] = wq; + wq = 0; + } + } + w_pos += w_nsymbols; + } + + // Remainders corresponding to the quotients in the previous chunk + if ( w_prev_enable ) + { + for ( i = 0; i < w_prev_nsymbols && w_prev_pos < w_cnt; i++, w_prev_pos++ ) + { + int remain = bitbuf_get(bb, "WREMAIN", w_grc_div); + w_value[w_prev_pos] = (w_prev_q[i] << w_grc_div) + remain; + } + } + if ( z_prev_enable ) + { + for ( i = 0; i < z_prev_nsymbols && z_prev_pos < z_cnt; i++, z_prev_pos++ ) + { + int remain = 0; + if ( z_grc_div != 0 ) + { + remain = bitbuf_get(bb, "ZREMAIN", z_grc_div); + } + z_value[z_prev_pos] = (z_prev_q[i] << z_grc_div) + remain; + total_zcnt += z_value[z_prev_pos]; + } + } + w_prev_enable = w_enable; + w_prev_nsymbols = w_nsymbols; + std::copy(std::begin(w_q), std::end(w_q), std::begin(w_prev_q)); + z_prev_enable = z_enable; + z_prev_nsymbols = z_nsymbols; + std::copy(std::begin(z_q), std::end(z_q), std::begin(z_prev_q)); + } while ( w_prev_enable || z_prev_enable ); + + // Interleave non-zero and zeros into the outbuf buffer + // Increase the outbuffer to fit the new slice + int16_t *p = outbuf.reserve(w_cnt + total_zcnt); + assert(p); + + // Insert initial zeros + if ( ( slice_len != ETHOSU_SLICELEN_BITS || new_palette ) && use_zero_runs ) + { + for ( j = 0; j < z_value[0]; j++ ) + { + *p++ = 0; + } + } + + // Loop over all weights and insert zeros in-between + for ( i = 0; i < w_cnt; i++ ) + { + int val; + assert(w_value[i] < 512); // HW supports 9bit + if ( w_value[i] < palette.palsize ) + { + val = palette.inv_lut[w_value[i]]; + } + else + { + val = w_value[i] - palette.palsize + palette.direct_offset; + } + int sign = val & 1; + int mag = val >> 1; + *p++ = sign ? int16_t(-mag) : int16_t(mag); + if ( use_zero_runs ) + { + for ( j = 0; j < z_value[i + (new_palette ? 1 : 0)]; j++ ) + { + *p++ = 0; + } + } + } + + outbuf.use(w_cnt + total_zcnt); + } while (!single_slice); +labelExit: + return outbuf.used() - start_offset; +} + + +constexpr int INITIAL_BLOCKS = 4; + + +#if defined __cplusplus +extern "C" +{ +#endif + +// Decode a stream +// +// result - Resulting data from decode (must be freeed after use) +// buffer - Incoming bitstream buffer +// size_bytes - Size of the bitstream buffer (in bytes) +void ml_decode_ethosu_stream(ml_decode_result_t *result, const uint8_t *buffer, int size_bytes) +{ + assert(result && buffer && size_bytes); + result->decoded_data = nullptr; + result->section_sizes = nullptr; + + bitbuf_t bb(buffer, size_bytes); + raw_buffer_t output(4096, MLW_ENCODE_ALLOC_STREAM0, nullptr); + palette_t palette; + ml_decode_internal(output, bb, palette, size_bytes, false, ETHOSU_SLICELEN_BITS); + + // Populate the results set + result->decoded_length = output.used(); + result->decoded_data = output.detach(); +} + +ML_ENCODER_DLL_EXPORT void mld_free(ml_decode_result_t *result) +{ + if ( result ) + { + if ( result->decoded_data ) + { + result->decoded_data = static_cast(realloc( result->decoded_data, 0)); + } + if ( result->section_sizes ) + { + free( result->section_sizes ); + result->section_sizes = nullptr; + } + } +} + + +#if defined __cplusplus +} // extern "C" +#endif + diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode.cpp b/src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode.cpp new file mode 100644 index 00000000000..6a0004e9a53 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode.cpp @@ -0,0 +1,979 @@ +// +// SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "../include/mlw_encode.h" + +#include "ml_encoder_internal.hpp" +#include "ml_raw_buffer.hpp" +#include "ml_bit_buffer.hpp" + +#include +#include +#include +#include +#include +#include + +constexpr int ZERO_FREQ_THRESHOLD = 5; +constexpr int MIN_ZERO_RUN_LENGTH = 2; + +// Create palette from the given frequencies +// Freq index 0-511 correspond to weights -256..255 +// partial_data - don't make decisions about data that will be encoded +// that wasn't included in the frequency analysis. +static void create_palette(palette_t *p, bool partial_data, bool disable_lut) +{ + uint64_t freq64[512] = {0}; + int i, all_cnt, all_max_val; + + // Pair the frequency with the value so that + // the array can be sorted on frequency while keeping + // track of the corresponding palette value + all_cnt = 0; + all_max_val = 0; + for ( i = -255; i < 256; i++ ) + { + if ( i == 0 && p->use_zero_runs ) continue; + int sign = i < 0; + int mag = abs(i); + int palval = (mag << 1) | sign; + + // Store palette value in 16 LSB bits, which will not affect the sorting + freq64[palval] = ((static_cast(p->freq[i + 256])) << 16) | palval; + all_cnt += p->freq[i + 256]; + + if ( p->freq[i + 256] > 0 ) + { + all_max_val = std::max(all_max_val, palval); + } + } + + // Cannot use direct offset with partial data. + p->only_zeros = !partial_data && (all_cnt == 0); + p->direct_offset = 0; + if ( !partial_data && (all_cnt != 0) ) + { + // Find the first non-used weight value around zero (0, -1, +1, -2, +2 etc) + for ( i = 0; i < 31; i++ ) + { + if ( (freq64[i] >> 16) != 0 ) + { + break; + } + } + p->direct_offset = i; + } + + // Sort in descending frequency order + std::sort(std::begin(freq64), std::end(freq64), [](uint64_t a, uint64_t b) { return b < a; }); + + // Check if all weights fit into the palette (and the palette is not empty) + p->only_palette = !disable_lut && !partial_data && (freq64[0] >> 16) > 0 && (freq64[32] >> 16) == 0; + + int max_palette_size; + if ( p->only_palette ) + { + max_palette_size = 32; + } + else + { + // For direct-lut we must make sure that the encoded weight + // index is not > 511. We do that by limiting the palette size + // such that the greatest value can be reached after subtracting + // the palette size. + max_palette_size = std::min(32, 511 - all_max_val); + if ( max_palette_size == 1 ) + { + max_palette_size = 0; // because palette of size 1 is not supported + } + } + + // Setup the 32 entry palette + int max_lut_val = 0, val, cnt, lut_cnt = 0; + for ( i = 0; i < max_palette_size; i++ ) + { + cnt = static_cast(freq64[i] >> 16); + val = freq64[i] & 0xffff; + // If partial data, all palette entries must be filled (even if they're wrong) + if ( cnt == 0 && !partial_data ) break; + p->lut[i] = int16_t(val); + max_lut_val = std::max(max_lut_val, val); + lut_cnt += cnt; + } + + // When all weights are the same nonzero value a palette size of 1 is possible; but not supported. + // Make the palette 2 entries long and zero the second entry (it's never indexed). + if ( i == 1 ) + { + p->lut[i++] = 0; + } + + // Heuristic for when to use the palette. If more than half of the + // weights are in the palette then we use it. This ensures we don't + // use palette for e.g. rectangular distributions. + int palbits_val; + if ( !disable_lut && (lut_cnt >= all_cnt / 2) ) + { + p->palsize = i; + palbits_val = max_lut_val; + } + else + { + // No palette + p->palsize = 0; + // If no palette, then palbits is used to specify the + // number of bits required for uncompressed mode, i.e. + // the number of bits for the greatest weight value + palbits_val = all_max_val; + } + + // the palette entry bit width + // minimum 2-bits (because PALBITS is in range 2..9) + int palbits = 2; + while ( (1 << palbits) <= palbits_val ) + { + palbits++; + } + assert(palbits <= 9); + p->palbits = palbits; +} + +static void create_inverse_palette(palette_t *p) +{ + int i; + int val = p->palsize - p->direct_offset; + for ( i = 0; i < 256; i++ ) + { + p->inv_lut[256 + i] = int16_t(val); + p->inv_lut[256 - i] = int16_t(val + 1); + val += 2; + } + p->inv_lut[0] = 0; + + for ( i = 0; i < p->palsize; i++ ) + { + val = p->lut[i]; + int sign = val & 1; + int mag = val >> 1; + int weight = sign ? -mag : mag; + assert( ((weight + 256) >= 0) && ((weight + 256) < 512) ); + p->inv_lut[weight + 256] = int16_t(i); + } +} + +// If palette_size is 512, then palette is not used (in that case the palette is setup +// with the standard alternating unsigned to signed mapping) +static void update_palette(mle_context_t *ctx, palette_t *p, const int16_t *weights, int weights_count, bool partial_data, bool disable_lut, bool disable_zruns) +{ + int(&freq)[512] = p->freq; + + int total_zeroes = 0; + int zeroes_in_run = 0; + int zeroes_in_all_runs = 0; + + // Calculate frequencies of the given weight stream + for ( int i = 0; i < weights_count; i++ ) + { + unsigned weight = weights[i] + 256; + freq[weight]++; + + uint64_t &value = ctx->weights_used[weight / 64]; + uint64_t mask = (1ull << (weight % 64)); + if ((value & mask) == 0) + { + ctx->distinct_weights++; + value |= mask; + } + + if ( weights[i] == 0 ) + { + total_zeroes++; + zeroes_in_run++; + } + else + { + if ( zeroes_in_run >= MIN_ZERO_RUN_LENGTH ) + { + zeroes_in_all_runs += zeroes_in_run; + } + zeroes_in_run = 0; + } + } + + // Detect trailing zero runs in compression + if ( zeroes_in_run >= MIN_ZERO_RUN_LENGTH ) + { + zeroes_in_all_runs += zeroes_in_run; + } + + int common_val = 0; + int common_freq = 0; + for ( int i = 0; i < 512; i++ ) + { + // Most common non-zero frequency (because we already have that) + if ( (i != 256) && freq[i] > common_freq ) + { + common_val = i - 256; + common_freq = freq[i]; + } + } + + // Decide if zero-runs (alternating mode) should be used: + // * zero runs must make up at least half of the zeroes + // * zero should be the most common symbol + // * zero should be sufficiently more common than the second most common symbol + bool use_zero_runs = zeroes_in_all_runs >= (total_zeroes / 2); + use_zero_runs &= total_zeroes > (ZERO_FREQ_THRESHOLD * common_freq); + p->use_zero_runs = use_zero_runs && !disable_zruns; + // Create the palette + create_palette(p, partial_data, disable_lut); +} + +#define NWCFG 13 +#define NZCFG 4 // restrict search to ZDIV=0..3 +#define MAX_ZWCFG ((NWCFG > NZCFG) ? NWCFG : NZCFG) + +// (trunc<<4) | div, 0x20 means uncompressed +static constexpr char w_grc_params[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20}; +static constexpr char z_grc_params[] = {0x00, 0x01, 0x02, 0x03}; + +struct grc_param_t +{ + int cfg = 0; + int end_pos = 0; +}; + +template +static int search_grc_params(const TYPE *inval_buf, int n_inval, int zrun_mode, int uncompressed_bits, + std::vector &result, bool single_slice) +{ + assert(uncompressed_bits < 32); + int n_cfg = zrun_mode ? NZCFG : NWCFG; + const char *grc_params = zrun_mode ? z_grc_params : w_grc_params; + + // Greedy forward-only GRC search (with optimisation for avoiding + // unusable GRC parameters). + const int cmd_cost = 40; + + int bit_count[MAX_ZWCFG] = {0}; + int reset_pos[MAX_ZWCFG] = {0}; + bool coded[MAX_ZWCFG] = {false}; + bool any_uncodable[MAX_ZWCFG] = {false}; + int active_bitcount = 0; + int active_cfg = -1; + int add_uncompressed_bits = (uncompressed_bits > 0) ? uncompressed_bits : 100; + + for ( int i = 0; i < n_inval; i++ ) + { + int value = inval_buf[i]; + + int best_bitcount = 0x7FFFFFFF; + int best_cfg = 0; + + // Loop over GRC parameters, calculate bits to code value, and then update the search state + for ( int j = 0; j < n_cfg; j++ ) + { + int div = grc_params[j] & 15; + int trunc = grc_params[j] >> 4; + int q = value >> div; + int bits = trunc ? std::min(q + 1, 2) + div : q + 1 + div; + + bool can_code = !(!zrun_mode && ((trunc && q > 2) || q > 31)); + if ( trunc == 2 ) + { + bits = add_uncompressed_bits; + can_code = true; + } + + if ( can_code ) + { + if ( !coded[j] ) + { + bit_count[j] = active_bitcount; // Reset non-coded to current best + } + bit_count[j] = bit_count[j] + bits; + + if ( bit_count[j] < best_bitcount ) + { + best_bitcount = bit_count[j]; + best_cfg = j; + } + } + else + { + reset_pos[j] = i + 1; + bit_count[j] += cmd_cost; // Would have to change away if used + } + + coded[j] = can_code; + any_uncodable[j] |= !can_code; + } + + // In single-slice mode we can check the bit counts afterwards; otherwise we record + // slice start points by tracking the minimum of the accumulting bit counts for + // different grc parameters. + if ( !single_slice ) + { + bool must_code = (active_cfg == -1) || !coded[active_cfg]; + if ( must_code || ((best_cfg != active_cfg) && (best_bitcount + cmd_cost) < active_bitcount) ) + { + // Commit non-initial changes + if ( active_cfg != -1 ) + { + // Range elision (was the other config better all along?) + if ( (bit_count[best_cfg] < bit_count[active_cfg]) && (reset_pos[best_cfg] <= reset_pos[active_cfg]) ) + { + // If the current BEST config started before the ACTIVE config in time, then switch to using the BEST config. + active_cfg = best_cfg; // Duplicated on both paths for clarity + } + else + { + // Otherwise use the ACTIVE config for this slice before switching to the BEST config. + grc_param_t param; + param.cfg = active_cfg; + param.end_pos = i; + assert((active_cfg != 12) || (uncompressed_bits != 0)); + result.push_back(param); + } + } + active_cfg = best_cfg; + } + } + else if ( active_cfg == -1 ) + { + active_cfg = best_cfg; + } + + active_bitcount = bit_count[active_cfg]; + } + + // terminate the run + if ( result.empty() || (result.back().cfg != active_cfg) ) + { + // If single slice then select the best minimum-bits configuration + if (single_slice) + { + assert( result.empty() ); + active_cfg = -1; + int max_bit_count = std::numeric_limits::max(); + for (int i=0; i < n_cfg; i++) + { + if ( !any_uncodable[i] && (bit_count[i] <= max_bit_count) ) + { + if ( (active_cfg != 12) || (uncompressed_bits != 0) ) + { + active_cfg = i; + max_bit_count = bit_count[i]; + } + } + } + assert(active_cfg != -1); // There isn't a usable grc parameter (fatal) + } + + grc_param_t param; + param.cfg = active_cfg; + assert((active_cfg != 12) || (uncompressed_bits != 0)); + result.push_back(param); + } + result.back().end_pos = n_inval; + + return active_bitcount; +} + +#if !ENABLE_DEBUG_BITSTREAM + // Release build putbits macro + #define bitbuf_put(bb_, name_, len_, data_) bb_->put(len_, data_) + #define bitbuf_align(bb_, name_, len_, data_) bb_->align(len_, data_) +#else + // Debug build putbits macro + inline void bitbuf_put(bitbuf_t *bb, const char *name, int len, int data) + { + assert(len <= 32); + int pre_pos = bb->pos(); + bb->put(len, data); + BITSTREAM_LOG("%6d %s:%d = %d\n", pre_pos, name, bb->pos()-pre_pos, data); + } + + // Debug build putbits macro + inline void bitbuf_align(bitbuf_t *bb, const char *name, int len, int data) + { + assert(len <= 32); + int pre_pos = bb->pos(); + bb->align(len, data); + BITSTREAM_LOG("%6d %s:%d = %d\n", pre_pos, name, bb->pos()-pre_pos, data); + } +#endif + +static slice_params_t encode_slice_header(mle_context_t *ctx, int slicelen, bool new_palette, int uncompressed_bits, int w_cfg, int z_cfg, bitbuf_t *bb) +{ + assert( (ctx->slicelen_bits == 15) || (ctx->slicelen_bits == 17) ); // Currently known formats + assert( slicelen < (1 << ctx->slicelen_bits) ); + assert( w_cfg >= 0 && w_cfg < (sizeof(w_grc_params)/sizeof(w_grc_params[0]))); + + // GRC parameters for this slice + int w_grc_trunc = (w_grc_params[w_cfg] >> 4) == 1; + int w_uncompressed = (w_grc_params[w_cfg] >> 4) == 2; + // Callers can signal a truly empty slice with a negative z_cfg index + assert( ((z_cfg < 0) && (slicelen == 0)) || (ctx->allow_empty_slices && slicelen >= 0) || (slicelen >= 1) ); + int z_grc_div = (z_cfg < 0) ? ZDIV_DISABLE : z_grc_params[z_cfg] & 15; + int w_grc_div = w_uncompressed ? uncompressed_bits : (w_grc_params[w_cfg] & 15); + + int zdiv = ctx->palette.use_zero_runs ? z_grc_div : ZDIV_DISABLE; + int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED; + + if ( ENABLE_DEBUG_BITSTREAM ) + { + BITSTREAM_LOG("slice: bitoffset %d slicelen %d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %d\n", bb->pos(), + slicelen, zdiv, wdiv, w_grc_trunc, new_palette, ctx->palette.palbits, ctx->palette.palsize); + } + + // Write slice header + bitbuf_put(bb, "ZDIV", 3, zdiv); + bitbuf_put(bb, "SLICELEN", ctx->slicelen_bits, ctx->allow_empty_slices ? slicelen : slicelen - 1); + bitbuf_put(bb, "WDIV", 3, wdiv); + bitbuf_put(bb, "WTRUNC", 1, w_grc_trunc); + bitbuf_put(bb, "NEWPAL", 1, new_palette); + if ( new_palette ) + { + bitbuf_put(bb, "DIROFS", 5, ctx->palette.direct_offset); + bitbuf_put(bb, "PALSIZE", 5, std::max(0, ctx->palette.palsize - 1)); + bitbuf_put(bb, "PALBITS", 3, ctx->palette.palbits - 2); + for (int i = 0; i < ctx->palette.palsize; i++ ) + { + bitbuf_put(bb, "PALETTE", ctx->palette.palbits, ctx->palette.lut[i]); + } + } + + slice_params_t header; + header.w_grc_trunc = w_grc_trunc; + header.w_uncompressed = w_uncompressed; + header.z_grc_div = z_grc_div; + header.w_grc_div = w_grc_div; + return header; +} + + +static void encode_slice(mle_context_t *ctx, const int16_t *w_values, int weight_count, const int32_t *z_values, int zero_count, bool new_palette, + int uncompressed_bits, int w_cfg, int z_cfg, bitbuf_t *bb) +{ + int w_cnt = weight_count; + int z_cnt = (z_values && zero_count) ? w_cnt + (new_palette ? 1 : 0) : 0; + + slice_params_t hdr = encode_slice_header(ctx, w_cnt, new_palette, uncompressed_bits, w_cfg, z_cfg, bb); + + assert(z_cfg >= 0 && "slice was signalled as truly empty"); + + // Record slice parameters for HW testbench debugging + if ( ctx->enable_slice_debug ) + { + ctx->slice_debug.push_back( mle_slice_debug_t { hdr, ctx->palette } ); + } + + int w_grc_div = hdr.w_grc_div; + bool w_uncompressed = hdr.w_uncompressed; + int z_grc_div = hdr.z_grc_div; + int w_grc_trunc = hdr.w_grc_trunc; + + int j; + int z_unary_len = z_grc_div < 3 ? 12 : 8; + int w_pos = 0, z_pos = 0; + int w_unary0 = 0, w_unary1 = 0, w_unary1_len = 0, w_q = -1, w_r = 0; + int z_unary = 0, z_q = -1, z_r = 0; + + int w_remain_data[2][12] = {{0}}; + int *w_remain = w_remain_data[0]; + int *w_prev_remain = w_remain_data[1]; + int w_nsymbols = 0; + int w_prev_enable = 0, w_prev_nsymbols = 0; + + int z_remain_data[2][12] = {{0}}; + int *z_remain = z_remain_data[0]; + int *z_prev_remain = z_remain_data[1]; + int z_nsymbols = 0; + int z_prev_enable = 0, z_prev_nsymbols = 0; + bool use_zero_runs = ctx->palette.use_zero_runs; + + do + { + int balance = use_zero_runs ? w_pos - z_pos : 0; + int w_enable = balance < 8 && w_pos < w_cnt; + int z_enable = balance >= 0 && use_zero_runs && z_pos < z_cnt; + if ( w_enable ) + { + // Encode chunk (weights) + j = 0; + w_nsymbols = 0; + w_unary0 = 0; + w_unary1 = 0; + w_unary1_len = 0; + int max_symbols = (w_uncompressed && w_grc_div > 5) ? 8 : 12; + while ( j < max_symbols ) + { + if ( w_q < 0 ) + { + if ( w_pos < w_cnt ) + { + int value = w_values[w_pos]; + assert( value >= 0 && value < 512 ); + w_q = value >> w_grc_div; + w_r = value & ((1 << w_grc_div) - 1); + assert(w_q <= 31 && (!w_grc_trunc || w_q <= 2)); + } + else + { + w_q = 0; + w_r = -1; // don't send remainder + } + } + while ( w_q >= 0 && j < max_symbols ) + { + w_unary0 |= w_q > 0 ? (1 << j) : 0; + if ( w_q > 0 ) + { + w_unary1 |= w_q > 1 ? (1 << w_unary1_len) : 0; + w_unary1_len++; + } + j++; + w_q -= 2; + if ( w_grc_trunc ) w_q--; + } + if ( w_q < 0 && w_r >= 0 ) + { + w_remain[w_nsymbols] = w_r; + w_nsymbols++; + w_pos++; + } + } + } + + if ( z_enable ) + { + // Encode chunk (zrun) + j = 0; + z_nsymbols = 0; + z_unary = 0; + while ( j < z_unary_len ) + { + if ( z_q < 0 ) + { + if ( z_pos < z_cnt ) + { + int value = z_values[z_pos]; + z_q = value >> z_grc_div; + z_r = value & ((1 << z_grc_div) - 1); + assert( z_q >= 0 ); // There are no negative length z-runs + } + else + { + z_q = 0; + z_r = -1; + } + } + while ( z_q >= 0 && j < z_unary_len ) + { + z_unary |= z_q > 0 ? (1 << j) : 0; + j++; + z_q--; + } + if ( z_q < 0 && z_r >= 0 ) + { + assert( z_nsymbols < 12 ); + z_remain[z_nsymbols] = z_r; + z_nsymbols++; + z_pos++; + } + } + } + + // Write chunk to bitstream + if ( w_enable && !w_uncompressed ) + { + bitbuf_put(bb, "WUNARY0", 12, w_unary0); // 12 bits + } + if ( z_enable ) + { + bitbuf_put(bb, "ZUNARY", z_unary_len, z_unary); // 12 or 8 bits + } + if ( w_enable && !w_uncompressed && (w_unary1_len > 0) ) + { + bitbuf_put(bb, "WUNARY1", w_unary1_len, w_unary1); // max 12 bits + } + if ( w_prev_enable ) + { + for (int i = 0; i < w_prev_nsymbols; i++ ) + { + bitbuf_put(bb, "WREMAIN", w_grc_div, w_prev_remain[i]); + } + } + if ( z_prev_enable && (z_grc_div > 0) ) + { + for (int i = 0; i < z_prev_nsymbols; i++ ) + { + bitbuf_put(bb, "ZREMAIN", z_grc_div, z_prev_remain[i]); + } + } + w_prev_enable = w_enable; + w_prev_nsymbols = w_nsymbols; + std::swap(w_prev_remain, w_remain); + z_prev_enable = z_enable; + z_prev_nsymbols = z_nsymbols; + std::swap(z_prev_remain, z_remain); + } while ( w_prev_enable || z_prev_enable ); +} + + +int ml_encode_section(mle_context_t *ctx, const int16_t *inbuf, int size, palette_t *p, bitbuf_t *bitbuf) +{ + bool new_palette = (p != nullptr); + + // Reuse previous if not specified + if ( p == nullptr ) + { + p = &ctx->palette; + } + + // Uncompressed mode can only be used if either all weights + // are in the palette OR if the palette is not used. + int uncompressed_bits = 0; + if ( p->only_palette ) + { + // Uncompressed bits derived from palette size + while ( (1 << uncompressed_bits) < p->palsize ) + { + uncompressed_bits++; + } + } + else if ( p->palsize == 0 ) + { + // Uncompressed bits is palbits (which is the bitdepth of the greatest weight) + uncompressed_bits = p->palbits; + } + + // If there are no weights at all, emit an empty slice header, then exit. + if ( size == 0 ) + { + // Signal a truly empty slice using -ve zgrc to ensure ZDIV_DISABLE is written to the stream. + if ( ctx->allow_empty_slices ) + { + encode_slice_header(ctx, 0, new_palette, uncompressed_bits, 0, -1, bitbuf); + } + return 0; + } + + std::vector weight_values; + weight_values.reserve(size); + + // If zruns was enabled, expect total to be < weight_values/2 + std::vector zrun_values; + if ( p->use_zero_runs ) + { + zrun_values.reserve( size / 4); + } + + // Get weights (or weight indicies) AND zero-runs from the input weight stream. + int i = 0; + bool allow_empty_slices = !p->only_zeros || ctx->allow_empty_slices; + int total_zcnt = 0; + const int max_slice_len = (1 << ctx->slicelen_bits) - 1; + const int max_zero_run_length = ctx->allow_empty_slices ? max_slice_len - 1 : INT32_MAX; + while ( 1 ) + { + if ( p->use_zero_runs ) + { + int zcnt = 0; + // Count zero run + // Special case: if all weights in the section are zero, we must + // still ensure we have one coded weight so the the slice length + // doesn't become 0. Therefore we skip the first zero run and code + // the zero explicitly as a weight value instead + if ( allow_empty_slices || i > 0 ) + { + while ( i < size && inbuf[i] == 0 && zcnt < max_zero_run_length ) + { + zcnt++; + i++; + } + } + total_zcnt += zcnt; + zrun_values.push_back(zcnt); + } + if ( i == size ) break; + int16_t value = p->inv_lut[inbuf[i] + 256]; + weight_values.push_back(value); + i++; + } + + // Search for good GRC parameters for the weight stream + std::vector w_slice_cfg; + int n_weights = int(weight_values.size()); + if ( n_weights ) + { + // Use a fixed grc config index if provided (partial-data mode sets this) + if ( ctx->fixed_wgrc >= 0 ) + { + w_slice_cfg.push_back(grc_param_t{ ctx->fixed_wgrc, n_weights }); + } + else + { + search_grc_params(weight_values.data(), n_weights, 0, uncompressed_bits, w_slice_cfg, ctx->single_slice_sections); + } + } + int n_w_slice = int(w_slice_cfg.size()); + + // Search for good GRC parameters for the zrun stream + std::vector z_slice_cfg; + if ( p->use_zero_runs ) + { + // Use a fixed grc config index if provided (partial-data mode sets this) + if ( ctx->fixed_zgrc >= 0 ) + { + z_slice_cfg.push_back(grc_param_t{ ctx->fixed_zgrc, n_weights + 1 }); + } + else + { + search_grc_params(zrun_values.data(), n_weights + 1, 1, 0, z_slice_cfg, ctx->single_slice_sections); + } + } + int n_z_slice = int(z_slice_cfg.size()); + + int loops = 0; + + // Encode bitstream slice + int pos = 0, i_w_slice = 0, i_z_slice = 0; + bool only_zero_runs_pass = !zrun_values.empty(); + while ( (pos < n_weights) || new_palette || only_zero_runs_pass ) + { + int w_len = 0; + int z_len = 0; + + if ( i_w_slice < n_w_slice ) + { + w_len = w_slice_cfg[i_w_slice].end_pos - pos; + w_len = std::min(w_len, max_slice_len); + } + + if ( i_z_slice < n_z_slice ) + { + z_len = z_slice_cfg[i_z_slice].end_pos - pos; + z_len = std::min(z_len, max_slice_len); + } + + // The first slice (when new_palette is 1) encodes zero runs both at the + // beginning and end (i.e. number of zero runs are len+1). + // The following slices only encode zero runs at the end (there cannot be + // any zeros in the beginning since they are encoded by the previous slice) + const int32_t *zrun_buf = p->use_zero_runs ? zrun_values.data() + pos + !(new_palette || ctx->allow_empty_slices) : nullptr; + const int16_t *w_buf = w_len ? weight_values.data() + pos : nullptr; + int w_cfg = w_len ? w_slice_cfg[i_w_slice].cfg : 0; + int z_cfg = p->use_zero_runs ? z_slice_cfg[i_z_slice].cfg : 0; + + encode_slice(ctx, w_buf, w_len, zrun_buf, z_len, new_palette, uncompressed_bits, w_cfg, z_cfg, bitbuf); + new_palette = 0; + + if ( z_len <= 0 && w_len > 0 ) + pos += w_len; + else if ( w_len <= 0 && z_len > 0 ) + pos += z_len; + else + pos += std::min(z_len, w_len); + + if ( i_w_slice < n_w_slice && w_slice_cfg[i_w_slice].end_pos <= pos ) + { + i_w_slice++; + } + if ( i_z_slice < n_z_slice && z_slice_cfg[i_z_slice].end_pos <= pos ) + { + i_z_slice++; + } + loops++; + only_zero_runs_pass = false; + } + // Single-slice sections can only generate one slice (a single loop) + assert( !ctx->single_slice_sections || (ctx->single_slice_sections && loops == 1) ); + if ( ctx->single_slice_sections && (loops != 1) ) + { + return -1; + } + return total_zcnt; +} + + +palette_t *ml_encode_palette(mle_context_t *ctx, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags) +{ + palette_t *palette = nullptr; + if ( !ctx->palette_valid || (mlw_encode_flags & MLW_ENCODE_INSERT_PALETTE) ) + { + if (mlw_encode_flags & MLW_ENCODE_RESET_PALETTE) + { + memset( ctx->palette.freq, 0, sizeof(ctx->palette.freq) ); + } + + bool partial_data = (mlw_encode_flags & MLW_ENCODE_PARTIAL_DATA) != 0; + bool disable_lut = (mlw_encode_flags & MLW_ENCODE_NO_PALETTE_LUT) != 0; + bool disable_zruns = (mlw_encode_flags & MLW_ENCODE_NO_ZERO_RUNS) != 0; + + assert( analyse_count >= encode_count && "Must analyse at least as much as is encoded"); + update_palette(ctx, &ctx->palette, weights, analyse_count, partial_data, disable_lut, disable_zruns); + ctx->palette_valid = true; + if ( !(mlw_encode_flags & MLW_ENCODE_DPIC_FORCE_PARAMS) ) + { + ctx->fixed_wgrc = (partial_data) ? 5 : -1; + ctx->fixed_zgrc = (partial_data) ? 3 : -1; + } + + create_inverse_palette(&ctx->palette); + palette = &ctx->palette; + } + return palette; +} + +void ml_encode_eos(mle_context_t *ctx, bitbuf_t &bits, unsigned mlw_encode_flags) +{ + // Add end of stream marker and align to 128bit + bitbuf_t *bb = &bits; + if ( ctx->eos_required ) + { + bitbuf_put(bb, "ZDIV", 3, ZDIV_EOS); + } + bitbuf_align(bb, "BYTEALIGN", 8, 0xff); + + if ( !(mlw_encode_flags & MLW_ENCODE_NO_PADDING) ) + { + bb->align( 128, 0xFF ); + } + bb->flush(); +} + +int ml_encode_internal(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags) +{ + palette_t *palette = ml_encode_palette(ctx, weights, encode_count, analyse_count, mlw_encode_flags); + + int zresult = ml_encode_section(ctx, weights, encode_count, palette, &bits); + if ( zresult < 0 ) + { + return -1; + } + ctx->zero_count += zresult; + return 0; +} + +extern "C" +{ + +ML_ENCODER_DLL_EXPORT mle_context_t *mle_create_context(int32_t syntax) +{ + mle_context_t *ctx = new mle_context_t; + ctx->zero_count = 0; + ctx->syntax = syntax; + ctx->realloc_func = nullptr; + if (syntax == MLW_ENCODE_SYNTAX_ETHOSU) + { + ctx->slicelen_bits = ETHOSU_SLICELEN_BITS; + ctx->allow_empty_slices = false; + ctx->single_slice_sections = false; + ctx->eos_required = true; + } + else if (syntax == MLW_ENCODE_SYNTAX_ETHOSU_FWD) + { + ctx->slicelen_bits = 0; + } + else + { + assert(false && "bad syntax"); + delete ctx; + return nullptr; + } + return ctx; +} + +ML_ENCODER_DLL_EXPORT int mle_context_query_zeroes(mle_context_t *ctx) +{ + assert( ctx ); + return ctx->zero_count; +} + +ML_ENCODER_DLL_EXPORT int mle_context_query_weights_used(mle_context_t *ctx, uint64_t weights_used[512 / 64]) +{ + assert( ctx ); + std::copy(std::begin(ctx->weights_used), std::end(ctx->weights_used), weights_used); + return ctx->distinct_weights; +} + +ML_ENCODER_DLL_EXPORT void mle_context_set_allocator(mle_context_t *ctx, void* (*realloc_func)(void*, size_t, int)) +{ + assert( ctx ); + ctx->realloc_func = realloc_func; +} + +ML_ENCODER_DLL_EXPORT void mle_destroy_context(mle_context_t *ctx) +{ + assert(ctx); + delete ctx; +} + +ML_ENCODER_DLL_EXPORT int mle_encode(mle_context_t *ctx, ml_encode_result_t *result, const int16_t *inbuf, int inbuf_size, unsigned mlw_encode_flags) +{ + assert( ctx && result ); + raw_buffer_t output(4096, MLW_ENCODE_ALLOC_STREAM0, ctx->realloc_func); + bitbuf_t bits(output, 4096, mlw_encode_flags & MLW_ENCODE_NO_BITSTREAM); + int written = 0; + + if ( ctx->syntax == MLW_ENCODE_SYNTAX_ETHOSU_FWD ) + { + written = ml_encode_fwd(ctx, bits, inbuf, inbuf_size, mlw_encode_flags); + } + else + { + int start = bits.byte_pos(); + if ( ml_encode_internal(ctx, bits, inbuf, inbuf_size, inbuf_size, mlw_encode_flags) < 0 ) + { + return -1; + } + ml_encode_eos(ctx, bits, mlw_encode_flags); + written = bits.byte_pos() - start; + } + + if ( written >= 0 ) + { + result->encoded_data = output.detach(); + result->source_length = inbuf_size; + result->encoded_length = written; + result->section_info = nullptr; + result->section_count = 0; + } + return written; +} + +ML_ENCODER_DLL_EXPORT void mle_free(ml_encode_result_t *result) +{ + if ( result ) + { + if ( result->encoded_data ) + { + free( result->encoded_data ); + result->encoded_data = nullptr; + } + if ( result->section_info ) + { + free( result->section_info ); + result->section_info = nullptr; + } + } +} + +} diff --git a/src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode_fwd.cpp b/src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode_fwd.cpp new file mode 100644 index 00000000000..c0bcc1015c1 --- /dev/null +++ b/src/gallium/drivers/ethosu/mlw_codec/source/mlw_encode_fwd.cpp @@ -0,0 +1,197 @@ +// +// SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the License); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an AS IS BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "../include/mlw_encode.h" + +#include "ml_encoder_internal.hpp" +#include "ml_raw_buffer.hpp" +#include "ml_bit_buffer.hpp" + +#include +#include +#include +#include +#include +#include + +constexpr static int LUT_MAX = 16; +constexpr static int INV_MAX = 512; + +struct fwd_header_t +{ + bool raw_mode_flag = false; + bool small_lut_flag = false; + int8_t zero_adjust = 0; + int16_t lut[LUT_MAX] = {}; + int8_t inv_index[INV_MAX] = {}; + fwd_header_t() + { + std::fill_n(inv_index, INV_MAX, -1); + } +}; + + +static inline int32_t fold(int32_t value) +{ + // Fold into positive value (sign in lsb) + return (abs(value) << 1) | (uint32_t(value) >> 31); +} + + +void fwd_emit_header(bitbuf_t &bits, const fwd_header_t &hdr) +{ + bits.put( 1, hdr.raw_mode_flag ? 1 : 0 ); + bits.put( 1, hdr.small_lut_flag ? 1 : 0 ); + bits.fill( 102, 0 ); + bits.put_masked( 8, hdr.zero_adjust ); + for (int i = 0; i < LUT_MAX; i++) + { + bits.put( 9, fold(hdr.lut[i]) ); + } +} + + +bool fwd_analyse(fwd_header_t &hdr, const int16_t *weights, int count, bitbuf_t &bits) +{ + int range_min = 1000; + int range_max = -1000; + int8_t *inv = hdr.inv_index; + int16_t *lut = hdr.lut; + int lut_used = 0; + bool use_lut = true; + + // Must check all the zero-point-correct values for full range + for (int i = 0; i < count; i++) + { + int value = weights[i]; + range_min = std::min( range_min, value ); + range_max = std::max( range_max, value ); + + // Update the LUT only while it's still viable (predicts well). + if ( use_lut ) + { + // Map the signed value to the LUT via +ve indexed table + int idx = fold(value); + assert( idx < INV_MAX ); + + // Check if value has already been indexed before adding a + // new lut entry. + if ( inv[idx] < 0 ) + { + if ( lut_used < LUT_MAX ) + { + inv[idx] = lut_used; + lut[lut_used] = value; + lut_used++; + } + else + { + use_lut = false; // LUT was full and is now unusable + } + } + // While lut2 is valid, encode the entries. When we're + // done the bitstream will be ready. + if (lut_used <= 4) + { + bits.put(2, inv[idx]); + } + } + } + + hdr.raw_mode_flag = !use_lut; + hdr.small_lut_flag = (lut_used <= 4); + hdr.zero_adjust = 0; + + // If raw mode, calculate the zero point + if ( hdr.raw_mode_flag ) + { + int full_range = (range_max - range_min); + if (full_range >= 256) + { + return false; // Can't encode this stream + } + else if ( range_min < -128 ) + { + hdr.zero_adjust = -128 - range_min; // Raw values need offsetting +ve by this amount + } + else if ( range_max > 127 ) + { + hdr.zero_adjust = 127 - range_max; // Raw values need offsetting -ve by this amount + } + } + + return (range_min >= -256) && (range_max < 256); +} + +// Encode zero-corrected weight values in the optimal fast-weight format. +int ml_encode_fwd(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int count, unsigned mlw_encode_flags) +{ + fwd_header_t header; + int pos = bits.pos(); + bits.fill(256, 0); // Reserve space for header + + // Encode lut2 weights directly to the main stream while analysing + if ( !fwd_analyse(header, weights, count, bits) ) + { + return -1; // Encoding error + } + + // Check for forced no palette + if ( mlw_encode_flags & MLW_ENCODE_NO_PALETTE_LUT ) + { + header.raw_mode_flag = 1; + header.small_lut_flag = 0; + } + + // Use a substream of the main stream for the header + bitbuf_t hdr_bits(bits, pos); + fwd_emit_header(hdr_bits, header); + bits.sync(hdr_bits); + + // LUT2 + if ( header.small_lut_flag ) + { + assert( !header.raw_mode_flag ); + } + // RAW + else if ( header.raw_mode_flag ) + { + bits.reposition(pos + 256); + for (int i=0; i < count; i++) + { + int value = (weights[i] + header.zero_adjust) & 0xFF; + bits.put(8, value); + } + } + // LUT4 + else + { + bits.reposition(pos + 256); + for (int i=0; i < count; i++) + { + int idx = fold(weights[i]); + bits.put(4, header.inv_index[idx]); + } + } + + bits.align(256, 0); + bits.flush(); + + int written = bits.pos() / 8; + return written; +}