ethosu: Switch to the weight encoder from Regor

We vendor the weight encoder used by the Regor compiler (part of Vela),
replacing the previous one, which came from the Python compiler and
doesn't support U85.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39611>
This commit is contained in:
Tomeu Vizoso 2026-02-17 09:18:57 +01:00 committed by Marge Bot
parent 410d74e078
commit d66d2c05d3
20 changed files with 3328 additions and 1347 deletions

View file

@ -1,2 +1,4 @@
# Vendored code # Vendored code
src/amd/vulkan/radix_sort/* src/amd/vulkan/radix_sort/*
src/gallium/drivers/ethosu/mlw_codec/**/*
src/gallium/drivers/ethosu/ethosu_encode_support.h

View file

@ -5,9 +5,53 @@
#include "util/u_inlines.h" #include "util/u_inlines.h"
#include "mlw_codec/mlw_encode.h" #include <assert.h>
#include "ethosu_ml.h"
#include "ethosu_coefs.h" #include "ethosu_coefs.h"
#include "ethosu_encode.h"
#include "ethosu_ml.h"
#include "mlw_encode.h"
/* Pack one per-channel bias/scale entry in the U65 layout: 40-bit signed
 * bias, 32-bit scale and 6-bit shift, little-endian, 10 bytes total.
 * Relies on arithmetic right shift of negative bias values (universal on
 * supported compilers). */
static void
encode_bias_scale_u65(int64_t bias, int32_t scale, uint32_t shift, uint8_t data[10])
{
   assert(-(1LL << (40 - 1)) <= bias && bias < (1LL << (40 - 1))); /* signed 40-bit range */
   assert(0 <= scale);       /* scale must be non-negative to fit the unsigned field */
   assert(shift < (1 << 6)); /* unsigned 6-bit range; shift is unsigned, no lower bound */

   data[0] = (bias >> (0 * 8)) & 0xFF;
   data[1] = (bias >> (1 * 8)) & 0xFF;
   data[2] = (bias >> (2 * 8)) & 0xFF;
   data[3] = (bias >> (3 * 8)) & 0xFF;
   data[4] = (bias >> (4 * 8)) & 0xFF;
   data[5] = (scale >> (0 * 8)) & 0xFF;
   data[6] = (scale >> (1 * 8)) & 0xFF;
   data[7] = (scale >> (2 * 8)) & 0xFF;
   data[8] = (scale >> (3 * 8)) & 0xFF;
   data[9] = shift & 0x3F;
}
/* Pack one per-channel bias/scale entry in the U85 layout: 32-bit signed
 * bias, 31-bit scale and 6-bit shift, little-endian, 10 bytes total; the
 * last byte is reserved and written as zero. */
static void
encode_bias_scale_u85(int64_t bias, int32_t scale, uint32_t shift, uint8_t data[10])
{
   assert(INT32_MIN <= bias && bias <= INT32_MAX); /* signed 32-bit range */
   assert(0 <= scale);       /* scale must fit the unsigned 31-bit field */
   assert(shift < (1 << 6)); /* unsigned 6-bit range; shift is unsigned, no lower bound */

   data[0] = (bias >> (0 * 8)) & 0xFF;
   data[1] = (bias >> (1 * 8)) & 0xFF;
   data[2] = (bias >> (2 * 8)) & 0xFF;
   data[3] = (bias >> (3 * 8)) & 0xFF;
   data[4] = (scale >> (0 * 8)) & 0xFF;
   data[5] = (scale >> (1 * 8)) & 0xFF;
   data[6] = (scale >> (2 * 8)) & 0xFF;
   data[7] = (scale >> (3 * 8)) & 0x7F; /* top bit masked: scale is 31 bits */
   data[8] = shift & 0x3F;
   data[9] = 0; /* reserved */
}
static void static void
fill_scale_and_biases(struct ethosu_subgraph *subgraph, struct ethosu_operation *operation, uint8_t **scales, long *scales_size, struct pipe_resource *bias_rsrc) fill_scale_and_biases(struct ethosu_subgraph *subgraph, struct ethosu_operation *operation, uint8_t **scales, long *scales_size, struct pipe_resource *bias_rsrc)
@ -15,32 +59,51 @@ fill_scale_and_biases(struct ethosu_subgraph *subgraph, struct ethosu_operation
struct pipe_transfer *transfer_in; struct pipe_transfer *transfer_in;
int32_t *biases = pipe_buffer_map(subgraph->base.context, bias_rsrc, int32_t *biases = pipe_buffer_map(subgraph->base.context, bias_rsrc,
PIPE_MAP_READ, &transfer_in); PIPE_MAP_READ, &transfer_in);
float ifm_scale = operation->ifm.scale;
float ofm_scale = operation->ofm.scale;
unsigned idx = 0; unsigned idx = 0;
*scales_size = align(operation->ofm.shape.depth * 10, 16); /* U65 packs 10-byte bias/scale entries contiguously then aligns to 16.
* U85 scales are read in groups of 16 channels, so pad depth to a
* 16-channel boundary first, then multiply by 10 bytes per entry. */
if (ethosu_is_u65(ethosu_screen(subgraph->base.context->screen)))
*scales_size = align(operation->ofm.shape.depth * 10, 16);
else
*scales_size = align(operation->ofm.shape.depth, 16) * 10;
*scales = malloc(*scales_size); *scales = malloc(*scales_size);
memset(*scales, 0, *scales_size); memset(*scales, 0, *scales_size);
for (unsigned i = 0; i < operation->ofm.shape.depth; i++) { for (unsigned i = 0; i < operation->ofm.shape.depth; i++) {
uint64_t bias = biases[i];
double kernel_scale = (operation->kernel.scales != NULL) ? double kernel_scale = (operation->kernel.scales != NULL) ?
operation->kernel.scales[i] : operation->kernel.scale; operation->kernel.scales[i] : operation->kernel.scale;
double conv_scale = ((double)operation->ifm.scale * kernel_scale) / (double)operation->ofm.scale; double conv_scale;
if (!operation->ifm.is_signed) {
/* UInt8 path: multiply as float first, then cast to double */
conv_scale = (double)(ifm_scale * kernel_scale) / (double)ofm_scale;
} else {
/* Int8 path: cast to double before multiply for higher precision */
conv_scale = ((double)ifm_scale * (double)kernel_scale) / (double)ofm_scale;
}
uint32_t shift; uint32_t shift;
int scale = ethosu_quantize_scale(conv_scale, &shift); int scale = ethosu_quantize_scale(conv_scale, &shift);
(*scales)[idx++] = (bias >> (0 * 8)) & 0xFF; if (ethosu_is_u65(ethosu_screen(subgraph->base.context->screen)))
(*scales)[idx++] = (bias >> (1 * 8)) & 0xFF; encode_bias_scale_u65(
(*scales)[idx++] = (bias >> (2 * 8)) & 0xFF; biases[i], scale, shift, &(*scales)[idx]);
(*scales)[idx++] = (bias >> (3 * 8)) & 0xFF; else
(*scales)[idx++] = (bias >> (4 * 8)) & 0xFF; encode_bias_scale_u85(
biases[i], scale, shift, &(*scales)[idx]);
(*scales)[idx++] = (scale >> (0 * 8)) & 0xFF; /* Saved for NPU_SET_OFM_SCALE emission in the command stream. */
(*scales)[idx++] = (scale >> (1 * 8)) & 0xFF; if (i == 0) {
(*scales)[idx++] = (scale >> (2 * 8)) & 0xFF; operation->conv.scale = scale;
(*scales)[idx++] = (scale >> (3 * 8)) & 0xFF; operation->conv.shift = shift;
}
(*scales)[idx++] = shift & 0x3F; idx += 10;
} }
pipe_buffer_unmap(subgraph->base.context, transfer_in); pipe_buffer_unmap(subgraph->base.context, transfer_in);
@ -65,60 +128,11 @@ calculate_weights_strides(struct ethosu_operation *operation, int out_strides[4]
static void static void
fill_weights(struct ethosu_subgraph *subgraph, struct ethosu_operation *operation, uint8_t **weights, long *weights_size, struct pipe_resource *weight_rsrc) fill_weights(struct ethosu_subgraph *subgraph, struct ethosu_operation *operation, uint8_t **weights, long *weights_size, struct pipe_resource *weight_rsrc)
{ {
struct ethosu_screen *screen = ethosu_screen(subgraph->base.context->screen);
int brick_strides[4] = {0};
unsigned input_channels = operation->ifm.shape.depth;
if (operation->kernel.depthwise)
input_channels = 1;
calculate_weights_strides(operation, brick_strides);
struct pipe_transfer *transfer_in; struct pipe_transfer *transfer_in;
uint8_t *input_weights_8 = pipe_buffer_map(subgraph->base.context, weight_rsrc, uint8_t *input_weights_8 = pipe_buffer_map(subgraph->base.context, weight_rsrc,
PIPE_MAP_READ, &transfer_in); PIPE_MAP_READ, &transfer_in);
int16_t *input_weights = malloc(pipe_buffer_size(weight_rsrc) * sizeof(*input_weights)); ml_reorder_encode_weights(subgraph, operation, input_weights_8, pipe_buffer_size(weight_rsrc), weights, weights_size);
unsigned num_weights = pipe_buffer_size(weight_rsrc);
unsigned output_channels = operation->ofm.shape.depth;
unsigned oc_stride = output_channels > 0 ? num_weights / output_channels : num_weights;
for (unsigned i = 0; i < num_weights; i++) {
int zp;
if (operation->kernel.zero_points) {
unsigned ch = operation->kernel.depthwise ? i % output_channels : i / oc_stride;
zp = operation->kernel.zero_points[ch];
} else {
zp = operation->kernel.zero_point;
}
if (operation->kernel.is_signed)
input_weights[i] = (int8_t)input_weights_8[i] - zp;
else
input_weights[i] = input_weights_8[i] - zp;
}
pipe_buffer_unmap(subgraph->base.context, transfer_in); pipe_buffer_unmap(subgraph->base.context, transfer_in);
int64_t padded_size = 0;
*weights_size = mlw_reorder_encode(
screen->ifm_ublock.depth,
screen->ofm_ublock.depth,
operation->ofm.shape.depth,
operation->kernel.height,
operation->kernel.width,
input_channels,
brick_strides,
input_weights,
operation->block_config.ofm_block.depth,
operation->kernel.depthwise,
operation->block_config.is_partkernel,
8 /* ifm_bitdepth */,
8 /* decomp_h */,
8 /* decomp_w */,
weights,
&padded_size,
DBG_ENABLED(ETHOSU_DBG_MSGS));
free(input_weights);
} }
void void
@ -140,6 +154,11 @@ fill_coefs(struct ethosu_subgraph *subgraph,
uint8_t *weights = NULL; uint8_t *weights = NULL;
fill_weights(subgraph, operation, &weights, &operation->conv.weights.size, weight_rsrc); fill_weights(subgraph, operation, &weights, &operation->conv.weights.size, weight_rsrc);
if (!weights) {
mesa_loge("fill_weights failed");
return;
}
operation->conv.weights.region = COEFS_REGION; operation->conv.weights.region = COEFS_REGION;
operation->conv.weights.address = subgraph->coefs_used; operation->conv.weights.address = subgraph->coefs_used;
subgraph->coefs_used += ALIGN_POT(operation->conv.weights.size, 16); subgraph->coefs_used += ALIGN_POT(operation->conv.weights.size, 16);

View file

@ -0,0 +1,149 @@
//
// SPDX-FileCopyrightText: Copyright 2021-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-FileCopyrightText: Copyright (c) 2025 Tomeu Vizoso <tomeu@tomeuvizoso.net>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include <assert.h>
#include "ethosu_encode_support.h"
#include "ethosu_ml.h"
#include "mlw_encode.h"
extern "C" {
static int32_t
weight_func(int32_t query, ml_source_state_t *state, int16_t *buffer, int32_t size, void *user_arg)
{
assert(query == MLW_SOURCE_QUERY_WEIGHTS);
IWeightSource *src = reinterpret_cast<IWeightSource *>(user_arg);
int source_size = src->Get(buffer, size);
state->eos = source_size < size;
return source_size;
}
/* Zero-point removal for depthwise (IHWO) weight ordering.
 * NOTE(review): indexes the zero-point table by the OFM channel (p->o); the
 * only caller visible here sets zeroCount == 1, so the index is always 0 —
 * confirm o vs. i indexing before supporting per-channel zero points. */
static int
apply_zero_point_ihwo(const WeightTransformParam *p, int value)
{
   const int zp = int(p->zeroPoints[p->o % p->zeroCount]);
   const int adjusted = value - zp;
   assert(adjusted >= -255 && adjusted <= 255);
   return adjusted;
}
/* Zero-point removal for standard (OHWI) weight ordering.
 * NOTE(review): indexes the zero-point table by the IFM channel (p->i); the
 * only caller visible here sets zeroCount == 1, so the index is always 0 —
 * confirm o vs. i indexing before supporting per-channel zero points. */
static int
apply_zero_point_ohwi(const WeightTransformParam *p, int value)
{
   const int zp = int(p->zeroPoints[p->i % p->zeroCount]);
   const int adjusted = value - zp;
   assert(adjusted >= -255 && adjusted <= 255);
   return adjusted;
}
/* Reorder and compress an operation's convolution weights for the NPU.
 *
 * The raw OHWI weight volume in input_weights is streamed through a
 * hardware-specific reordering source (EthosUWeightOrdering for U65,
 * EthosU85WeightOrdering for U85) into the mlw encoder.  On success
 * *weights points to an encoder-allocated compressed stream of
 * *weights_size bytes that the caller owns; on failure *weights is NULL
 * and *weights_size is 0.
 *
 * NOTE(review): input_weights_size is not used — the volume extent comes
 * from the OHWI shape below; confirm callers always pass a buffer that
 * covers ofm_depth * kernel_h * kernel_w * ifm_depth elements.
 * NOTE(review): only the scalar kernel zero point is wired up
 * (zeroCount == 1); per-channel kernel zero points are not handled here.
 */
void
ml_reorder_encode_weights(struct ethosu_subgraph *subgraph,
                          struct ethosu_operation *operation,
                          const uint8_t *input_weights,
                          long input_weights_size,
                          uint8_t **weights,
                          long *weights_size)
{
   struct ethosu_screen *screen = ethosu_screen(subgraph->base.context->screen);
   int bit_depth = 8;
   bool is_sparse = false;
   EthosUTraversal traversal;
   struct WeightTransformParam param;
   WeightTransformFunc transform_func = apply_zero_point_ohwi;
   ml_encode_result_t res;
   Point2i stride(operation->kernel.stride_x, operation->kernel.stride_y);
   Point2i dilation(operation->kernel.dilation_x, operation->kernel.dilation_y);
   int ret;

   ml_ethosu_encode_params_t params;
   params.encoder_flags = MLW_ENCODE_FLAG_NONE;
   params.source_buffering_hint = 128 * 1024; /* encoder buffering hint, in bytes */
   params.realloc_func = NULL;

   /* Pick the hardware traversal order; depthwise also switches to the
    * IHWO zero-point transform. */
   if (operation->kernel.depthwise) {
      traversal = EthosUTraversal::Depthwise;
      transform_func = apply_zero_point_ihwo;
   } else if (operation->block_config.is_partkernel)
      traversal = EthosUTraversal::PartKernel;
   else
      traversal = EthosUTraversal::DepthFirst;

   int zero_point = (int)operation->kernel.zero_point;
   param.zeroPoints = &zero_point;
   param.zeroCount = 1;

   /* Instantiate the reordering source for the target generation and
    * weight signedness; `param` and `zero_point` must outlive `source`. */
   WeightSourceCommon *source;
   if (ethosu_is_u65(screen)) {
      if (operation->kernel.is_signed) {
         source = new EthosUWeightOrdering<int8_t>(1, dilation,
            operation->block_config.ofm_block.depth, bit_depth, screen->ofm_ublock.depth,
            screen->ifm_ublock.depth, transform_func, &param, traversal);
      } else {
         source = new EthosUWeightOrdering<uint8_t>(1, dilation,
            operation->block_config.ofm_block.depth, bit_depth, screen->ofm_ublock.depth,
            screen->ifm_ublock.depth, transform_func, &param, traversal);
      }
   } else {
      if (operation->kernel.is_signed) {
         source = new EthosU85WeightOrdering<int8_t>(1, 256, stride, dilation,
            operation->block_config.ofm_block.depth, operation->block_config.ifm_block.depth, bit_depth, operation->block_config.ofm_ublock.depth,
            transform_func, &param, traversal, is_sparse);
      } else {
         source = new EthosU85WeightOrdering<uint8_t>(1, 256, stride, dilation,
            operation->block_config.ofm_block.depth, operation->block_config.ifm_block.depth, bit_depth, operation->block_config.ofm_ublock.depth,
            transform_func, &param, traversal, is_sparse);
      }
   }

   /* Describe the source volume: OHWI shape plus its dense element strides
    * (computed innermost-first).  Depthwise weights are laid out IHWO, so
    * swap the O and I strides. */
   Shape ohwi = {static_cast<int>(operation->ofm.shape.depth),
                 static_cast<int>(operation->kernel.height),
                 static_cast<int>(operation->kernel.width),
                 static_cast<int>(operation->ifm.shape.depth)};
   Shape ohwiStrides;
   int v = 1;
   for (int i = 4 - 1; i >= 0; --i) {
      ohwiStrides[i] = v;
      v *= ohwi[i];
   }
   if (operation->kernel.depthwise)
      SWAP(ohwiStrides[0], ohwiStrides[3]); /* IHWO */

   source->SetSource(input_weights, 0, ohwi, ohwiStrides, 0);

   mle_context_t *ctx = nullptr;
   ret = ml_encode_ethosu_stream(&res, &params, weight_func, source, &ctx);
   mle_destroy_context(ctx);
   if (ret < 0) {
      mesa_loge("mlw encode failed");
      *weights = NULL;
      *weights_size = 0;
      mle_free(&res);
      delete source;
      return;
   }

   /* Transfer ownership of the encoded stream to the caller; clearing
    * encoded_data keeps mle_free() from releasing it. */
   *weights = res.encoded_data;
   res.encoded_data = NULL;
   *weights_size = res.encoded_length;
   mle_free(&res);
   delete source;
}
} // extern "C"

View file

@ -0,0 +1,27 @@
/*
* Copyright (c) 2024 Tomeu Vizoso <tomeu@tomeuvizoso.net>
* SPDX-License-Identifier: MIT
*/
#ifndef ETHOSU_ENCODE_H
#define ETHOSU_ENCODE_H
#include "ethosu_ml.h"
#ifdef __cplusplus
extern "C" {
#endif
/* Reorder and compress convolution weights for the NPU (implemented in
 * ethosu_encode.cpp).  On success *weights is a caller-owned buffer of
 * *weights_size bytes; on failure *weights is NULL and *weights_size 0. */
void
ml_reorder_encode_weights(struct ethosu_subgraph *subgraph,
                          struct ethosu_operation *operation,
                          const uint8_t *input_weights,
                          long input_weights_size,
                          uint8_t **weights,
                          long *weights_size);
#ifdef __cplusplus
}
#endif
#endif /* ETHOSU_ENCODE_H */

View file

@ -0,0 +1,672 @@
//
// SPDX-FileCopyrightText: Copyright 2021-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-FileCopyrightText: Copyright (c) 2025 Tomeu Vizoso <tomeu@tomeuvizoso.net>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#ifndef ETHOSU_ENCODE_SUPPORT_H
#define ETHOSU_ENCODE_SUPPORT_H
#include <memory>
#include <stdexcept>
#include <vector>
#include <algorithm>
#include <cstdint>
// Per-weight context handed to a WeightTransformFunc.
// o/h/w/i are the OFM-channel, kernel-row, kernel-column and IFM-channel of
// the weight currently being read; zeroPoints/zeroCount describe the
// zero-point table used by the transform.
struct WeightTransformParam
{
    int o, h, w, i, zeroCount;
    int *zeroPoints;
    bool is_signed;
};
// Callback applied to every raw weight before encoding (e.g. zero-point
// removal); returns the transformed value.
typedef int (*WeightTransformFunc)(const WeightTransformParam *param, int weight);

// Order in which the hardware consumes weights.
enum class EthosUTraversal {
    DepthFirst,
    PartKernel,
    Depthwise
};
// Thrown by SparsityTracker when the weight stream violates the 2:4
// structured-sparsity constraint.
struct WeightsNotSparse : public std::runtime_error
{
    WeightsNotSparse() : std::runtime_error("weights not sparse") {}
};
// Validates that a weight stream satisfies a 2:4 structured-sparsity
// constraint: within every group of 4 consecutive weights along the depth
// at a given kernel position, at least 2 must be zero, and every weight
// must fit in [-127, 127].  Throws WeightsNotSparse on violation.
struct SparsityTracker
{
    int _sparse_zeroes = 4;            // zeroes seen in the current group of 4
    int _sparse_index = 0;             // index within the current kernel position
    uint32_t _sparse_pos = 0xFFFFFFFF; // packed kernel position last seen

    void Reset() { _sparse_pos = 0xFFFFFFFF; }

    // pos: packed kernel position, depth: channel index, weight: transformed value.
    void Check(uint32_t pos, int depth, int weight)
    {
        if ( _sparse_pos != pos )
        {
            // New kernel position: restart group tracking.  Groups must
            // start on a 4-aligned depth.
            _sparse_pos = pos;
            _sparse_zeroes = 0;
            _sparse_index = 0;
            if ( depth & 3 ) throw WeightsNotSparse();
        }
        if ( weight == 0 ) _sparse_zeroes++;
        else if ( weight > 127 || weight < -127 ) throw WeightsNotSparse();
        if ( (_sparse_index & 3) == 3 )
        {
            // End of a group of 4: require at least two zeroes.
            if ( _sparse_zeroes < 2 ) throw WeightsNotSparse();
            _sparse_zeroes = 0;
        }
        _sparse_index++;
    }
};
// Minimal 2-D integer point (x, y); used for kernel stride and dilation.
struct Point2i
{
    int x;
    int y;
    Point2i(int px, int py) : x(px), y(py) {}
};
// Round `value` away from zero to the next multiple of `align` (align > 0).
// Multiples are returned unchanged, e.g. RoundAway(7, 4) == 8,
// RoundAway(-3, 4) == -4, RoundAway(8, 4) == 8.
static inline int RoundAway(int value, int align)
{
    assert(align > 0);
    const int rem = value % align;
    if ( rem > 0 ) return value + (align - rem);
    if ( rem < 0 ) return value - (align + rem);
    return value;
}
// OHWI dimensions (or element strides) as a plain 4-element array.
typedef int Shape[4];

// Abstract pull-source of transformed weights for the encoder.
class IWeightSource
{
public:
    virtual ~IWeightSource() = default;
    // Total number of weights this source will produce.
    virtual int Elements() = 0;
    // Fill `buffer` with up to `count` weights; returns how many were written
    // (a short count signals end-of-stream).
    virtual int Get(int16_t *buffer, int count) = 0;
};

// Weight source backed by an in-memory OHWI weight volume.
struct IVolumeWeightSource : public IWeightSource
{
    virtual ~IVolumeWeightSource() = default;
    virtual void SetSource(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex) = 0;
};
// Shared bookkeeping for volume-backed weight sources: caches the source
// pointer, the OHWI extents and the element strides, and resolves a
// (o, h, w, i) coordinate to a flat index.  When `separated` is set the OFM
// depth is split round-robin across `_streams` cores.
class WeightSourceCommon : public IVolumeWeightSource
{
protected:
    const void *_source;
    int16_t _streams = 1;     // number of interleaved streams (cores)
    int16_t _streamIndex = 0; // which stream this source feeds
    int _ofmDepth = 0;
    int _ifmDepth = 0;
    int _kernelH = 0;
    int _kernelW = 0;
    int _ohwiStrides[4];      // element strides for O, H, W, I

protected:
    void SetSourceCommon(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex, bool separated)
    {
        assert(streamIndex < _streams);
        _streamIndex = streamIndex;
        // NOTE(review): streamOffset is an element count but is added to a
        // uint8_t pointer — correct only for 1-byte weight types; confirm if
        // wider types are ever used.
        int streamOffset = depthOffset * ohwiStrides[0];
        _source = reinterpret_cast<const uint8_t *>(buffer) + streamOffset;
        _ifmDepth = ohwiShape[3];
        _ofmDepth = separated ? (ohwiShape[0] + _streams - 1 - streamIndex) / _streams : ohwiShape[0];
        _kernelH = ohwiShape[1];
        _kernelW = ohwiShape[2];
        // Bring in values for better cache locality
        _ohwiStrides[0] = ohwiStrides[0] * (separated ? _streams : 1);
        _ohwiStrides[1] = ohwiStrides[1];
        _ohwiStrides[2] = ohwiStrides[2];
        _ohwiStrides[3] = ohwiStrides[3];
    }

    int Elements() override { return _ofmDepth * _ifmDepth * _kernelH * _kernelW; }

    // Flat element index of the weight at (ofm_z, wy, wx, ifm_z).
    inline int WeightIndex(int ofm_z, int wy, int wx, int ifm_z) const
    {
        return ofm_z * _ohwiStrides[0] + wy * _ohwiStrides[1] + wx * _ohwiStrides[2] + ifm_z * _ohwiStrides[3];
    }
};
// Streams weights in the Ethos-U85 hardware consumption order.
//
// Get()/GetNext() form a resumable generator: every loop index is mirrored
// in a _member so that when the `count` output slots run out mid-volume the
// complete loop state is saved and the next call resumes exactly where this
// one stopped.  Out-of-range positions produce zero padding.  The IFM block
// loop ping-pongs direction between successive OFM blocks via the sign of
// _ifmLoopInc.  When `sparse` is set, a SparsityTracker validates 2:4
// sparsity and throws WeightsNotSparse on violation.
template<typename TYPE>
class EthosU85WeightOrdering : public WeightSourceCommon
{
protected:
    static constexpr int InterleaveDepth = 4; // OFM channels interleaved per stream
    // Transform
    WeightTransformParam *_param;
    WeightTransformFunc _transform;
    // Loop Limits
    Point2i _stride = Point2i(0, 0);
    int _ofmBlockDepth;
    int _ifmBlockDepth;
    short _ofmUBlockDepth;
    short _ifmUBlockDepth;
    short _decompX;         // subkernel decomposition width (8 / dilation.x)
    short _decompY;         // subkernel decomposition height (8 / dilation.y)
    short _subKernelRound;  // subkernel element count is padded to this multiple
    short _dwPaddingCount;  // extra zero padding per depthwise subkernel round
    // Saved state (mirrors the loop indices of GetNext for resumption)
    int _ofmBlockZ = 0;
    int _ifmBlockZ = 0;
    int _subKernelX = 0;
    int _subKernelY = 0;
    int _ifmUBlockOuter = 0;
    int _ifmUBlockInner = 0;
    int _ofmUBlockZ = 0;
    int _ifmUBlockZ = 0;
    int _subKernelElements = 0;
    int _strideX = 0;
    int _strideY = 0;
    int _kernelX = 0;
    int _kernelY = 0;
    int _ofmUBlockInner = 0;
    int _ofmUBlockOuter = 0;
    int _ifmLoopInc = 0;    // signed IFM block step; sign alternates per OFM block
    int _padding = 0;
    EthosUTraversal _traversal;
    bool _sparse;
    SparsityTracker _sparsity;

public:
    EthosU85WeightOrdering(int cores, int macs, Point2i stride, const Point2i &dilation, int ofmBlockDepth, int ifmBlockDepth, int ifmBitDepth,
        int ofmUBlockDepth, WeightTransformFunc func, WeightTransformParam *param, EthosUTraversal traversal, bool sparse)
    {
        const bool ifm16bit = (ifmBitDepth == 16);
        _streams = cores;
        _transform = func;
        _param = param;
        _traversal = traversal;
        _sparse = sparse;
        _stride = stride;
        _ofmBlockDepth = ofmBlockDepth;
        _ifmBlockDepth = ifmBlockDepth;
        _ofmUBlockDepth = short(ofmUBlockDepth);
        // Traversal-specific rounding and IFM micro-block depth.
        if ( traversal == EthosUTraversal::PartKernel )
        {
            _subKernelRound = (ifm16bit || sparse) ? 10 : 5;
            _ifmUBlockDepth = ifm16bit && !sparse ? 8 : _ifmBlockDepth;
        }
        else
        {
            if ( traversal == EthosUTraversal::DepthFirst )
            {
                _stride = Point2i(1, 1);
                _subKernelRound = 1;
                _ifmUBlockDepth = _ifmBlockDepth;
            }
            else if ( traversal == EthosUTraversal::Depthwise )
            {
                _subKernelRound = 10;
                _ifmUBlockDepth = 1;
            }
        }
        _decompX = short(8 / dilation.x);
        _decompY = short(8 / dilation.y);
        _dwPaddingCount = (!ifm16bit && macs <= 256) ? 0 : (macs <= 512) ? 2 : 6;
        _ifmLoopInc = -_ifmBlockDepth;
    }

    void SetSource(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex) override
    {
        SetSourceCommon(buffer, depthOffset, ohwiShape, ohwiStrides, streamIndex, false);
        assert(_streamIndex == streamIndex);
        _ofmUBlockZ = _streamIndex * InterleaveDepth;
        _sparsity.Reset();
    }

public:
    // Dispatch to the right GetNext instantiation for the traversal mode.
    int Get(int16_t *output, int count) override
    {
        if ( _traversal == EthosUTraversal::Depthwise )
        {
            assert(!_sparse);
            return GetNext<false, true>(output, count);
        }
        else if ( _sparse )
        {
            return GetNext<true, false>(output, count);
        }
        return GetNext<false, false>(output, count);
    }

    // Resumable weight generator: fills up to `count` slots of `output`,
    // returns how many were written (less than `count` means end of stream).
    template<bool IS_SPARSE, bool IS_DEPTHWISE>
    int GetNext(int16_t *output, int count)
    {
        if ( _ofmBlockZ >= _ofmDepth )
        {
            return 0;
        }
        int ofmBlockZ, ifmBlockZ;
        int ifmUBlockOuter, ifmUBlockInner;
        int ifmUBlockZ, ofmUBlockZ, ofmUBlockInner, ofmUBlockOuter;
        int subKernelX, subKernelY;
        int strideX, strideY;
        int kernelX, kernelY;
        int padding;
        int16_t *write = output;
        const TYPE *buffer = reinterpret_cast<const TYPE *>(_source);
        for ( ofmBlockZ = _ofmBlockZ; ofmBlockZ < _ofmDepth; ofmBlockZ += _ofmBlockDepth )
        {
            _ifmLoopInc = -_ifmLoopInc;
            int clippedOfmBlockDepth = std::min(_ofmBlockDepth, _ofmDepth - ofmBlockZ);
            // IFM blocks required for the brick
            for ( ifmBlockZ = _ifmBlockZ; ifmBlockZ < (IS_DEPTHWISE ? 1 : _ifmDepth) && ifmBlockZ >= 0; ifmBlockZ += _ifmLoopInc )
            {
                _ifmBlockZ = ifmBlockZ;
                int clippedIfmBlockDepth = std::min(_ifmBlockDepth, _ifmDepth - ifmBlockZ);
                // Weight decomposition
                // Subkernel splitting (W)
                for ( subKernelX = _subKernelX; subKernelX < _kernelW; subKernelX += _decompX )
                {
                    int subWidth = std::min<int>(_kernelW - subKernelX, _decompX);
                    // Subkernel Splitting (H)
                    for ( subKernelY = _subKernelY; subKernelY < _kernelH; subKernelY += _decompY )
                    {
                        int subHeight = std::min<int>(_kernelH - subKernelY, _decompY);
                        int ifmBlockDepthOuter = IS_DEPTHWISE ? 1 : clippedIfmBlockDepth;
                        for ( ifmUBlockOuter = _ifmUBlockOuter; ifmUBlockOuter < ifmBlockDepthOuter; ifmUBlockOuter += _ifmUBlockDepth )
                        {
                            // OFM uBlocks in OFM-block over depth
                            for ( ofmUBlockOuter = _ofmUBlockOuter; ofmUBlockOuter < clippedOfmBlockDepth; ofmUBlockOuter += _ofmUBlockDepth )
                            {
                                // Part kernel first works across the kernel H/W and needs padding
                                if ( !_subKernelElements )
                                {
                                    int subKernelElements = subWidth * subHeight;
                                    _subKernelElements = RoundAway(subKernelElements, _subKernelRound);
                                }
                                for ( strideY = _strideY; strideY < _stride.y; ++strideY )
                                {
                                    int stridedKernelH = (subHeight + _stride.y - 1 - strideY) / _stride.y;
                                    for ( strideX = _strideX; strideX < _stride.x; ++strideX )
                                    {
                                        int stridedKernelW = (subWidth + _stride.x - 1 - strideX) / _stride.x;
                                        for ( kernelY = _kernelY; kernelY < stridedKernelH; ++kernelY )
                                        {
                                            int y = kernelY;
                                            for ( kernelX = _kernelX; kernelX < stridedKernelW; ++kernelX )
                                            {
                                                // Boustrophedon (snake) X order on odd rows.
                                                int x = kernelY % 2 == 0 ? kernelX : stridedKernelW - 1 - kernelX;
                                                _subKernelElements--;
                                                int ifmUBlockInnerStep = IS_DEPTHWISE ? 1 : (IS_SPARSE ? 16 : 8);
                                                for ( ifmUBlockInner = _ifmUBlockInner; ifmUBlockInner < _ifmUBlockDepth; ifmUBlockInner += ifmUBlockInnerStep )
                                                {
                                                    // Feed OFM uBlock elements
                                                    for ( ofmUBlockZ = _ofmUBlockZ; ofmUBlockZ < _ofmUBlockDepth; ofmUBlockZ += InterleaveDepth * _streams )
                                                    {
                                                        for ( ofmUBlockInner = _ofmUBlockInner; ofmUBlockInner < InterleaveDepth; ofmUBlockInner++ )
                                                        {
                                                            // Source IFM uBlock elements (only 1 element deep if
                                                            // depthwise)
                                                            for ( ifmUBlockZ = _ifmUBlockZ; ifmUBlockZ < ifmUBlockInnerStep; ifmUBlockZ++ )
                                                            {
                                                                // Source position within the current subkernel
                                                                int wx = subKernelX + strideX + x * _stride.x;
                                                                int wy = subKernelY + strideY + y * _stride.y;
                                                                // Source IFM/OFM slices
                                                                int ifm_z = ifmBlockZ + ifmUBlockOuter + ifmUBlockInner + ifmUBlockZ;
                                                                int ofm_z = ofmBlockZ + ofmUBlockOuter + ofmUBlockInner + ofmUBlockZ;
                                                                int weight = 0;
                                                                if ( ifm_z < _ifmDepth && ofm_z < _ofmDepth )
                                                                {
                                                                    _param->o = ofm_z;
                                                                    _param->h = wy;
                                                                    _param->w = wx;
                                                                    _param->i = ifm_z;
                                                                    weight = int(buffer[WeightIndex(ofm_z, wy, wx, ifm_z)]);
                                                                    weight = _transform(_param, weight);
                                                                }
                                                                if constexpr ( IS_SPARSE )
                                                                    _sparsity.Check((unsigned(wy) << 16) | wx, ifm_z, weight);
                                                                *write++ = int16_t(weight);
                                                                if ( --count == 0 )
                                                                {
                                                                    // Save state
                                                                    _subKernelElements++;
                                                                    _ifmUBlockZ = ifmUBlockZ + 1;
                                                                    _ofmUBlockInner = ofmUBlockInner;
                                                                    _ofmUBlockZ = ofmUBlockZ;
                                                                    _ifmUBlockInner = ifmUBlockInner;
                                                                    _kernelX = kernelX;
                                                                    _kernelY = kernelY;
                                                                    _strideX = strideX;
                                                                    _strideY = strideY;
                                                                    _ofmUBlockOuter = ofmUBlockOuter;
                                                                    _ifmUBlockOuter = ifmUBlockOuter;
                                                                    _subKernelY = subKernelY;
                                                                    _subKernelX = subKernelX;
                                                                    _ofmBlockZ = ofmBlockZ;
                                                                    _ifmLoopInc = -_ifmLoopInc;
                                                                    return int(intptr_t(write - output));
                                                                }
                                                            }
                                                            _ifmUBlockZ = 0;
                                                        }
                                                        _ofmUBlockInner = 0;
                                                    }
                                                    _ofmUBlockZ = _streamIndex * InterleaveDepth;
                                                }
                                                // Depthwise padding
                                                if ( IS_DEPTHWISE && _subKernelElements % _subKernelRound == 0 )
                                                {
                                                    int padCount = _dwPaddingCount * _ofmUBlockDepth / _streams;
                                                    for ( padding = _padding; padding < padCount; padding++ )
                                                    {
                                                        *write++ = 0;
                                                        if ( --count == 0 )
                                                        {
                                                            // Save state
                                                            _subKernelElements++;
                                                            _padding = padding + 1;
                                                            _ifmUBlockInner = ifmUBlockInner; // Will skip loop above
                                                            _kernelX = kernelX;
                                                            _kernelY = kernelY;
                                                            _strideX = strideX;
                                                            _strideY = strideY;
                                                            _ofmUBlockOuter = ofmUBlockOuter;
                                                            _ifmUBlockOuter = ifmUBlockOuter;
                                                            _subKernelY = subKernelY;
                                                            _subKernelX = subKernelX;
                                                            _ofmBlockZ = ofmBlockZ;
                                                            _ifmLoopInc = -_ifmLoopInc;
                                                            return int(intptr_t(write - output));
                                                        }
                                                    }
                                                    _padding = 0;
                                                }
                                                _ifmUBlockInner = 0;
                                            }
                                            _kernelX = 0;
                                        }
                                        _kernelY = 0;
                                    }
                                    _strideX = 0;
                                }
                                // Padding
                                if ( _subKernelElements > 0 )
                                {
                                    int padCount = _subKernelElements + (IS_DEPTHWISE ? _dwPaddingCount : 0);
                                    padCount = padCount * _ifmUBlockDepth * _ofmUBlockDepth / _streams;
                                    for ( padding = _padding; padding < padCount; padding++ )
                                    {
                                        *write++ = 0;
                                        if ( --count == 0 )
                                        {
                                            // Save state
                                            _padding = padding + 1;
                                            _strideY = strideY; // Will skip loop above
                                            _ofmUBlockOuter = ofmUBlockOuter;
                                            _ifmUBlockOuter = ifmUBlockOuter;
                                            _subKernelY = subKernelY;
                                            _subKernelX = subKernelX;
                                            _ofmBlockZ = ofmBlockZ;
                                            _ifmLoopInc = -_ifmLoopInc;
                                            return int(intptr_t(write - output));
                                        }
                                    }
                                    _padding = 0;
                                }
                                _subKernelElements = 0;
                                _strideY = 0;
                            }
                            _ofmUBlockOuter = 0;
                        }
                        _ifmUBlockOuter = 0;
                    }
                    _subKernelY = 0;
                }
                _subKernelX = 0;
            }
        }
        _ifmLoopInc = -_ifmBlockDepth;
        _ifmBlockZ = 0;
        _ofmBlockZ = 0;
        // Return weights generated (less than requested count == EOS)
        return int(intptr_t(write - output));
    }
};
// Streams weights in the Ethos-U55/U65 hardware consumption order
// (depth-first, part-kernel-first or depthwise traversal).
//
// Like EthosU85WeightOrdering, GetNext() is a resumable generator: every
// loop index is mirrored in a _member so generation can pause when the
// output buffer fills and resume on the next Get() call.  The source is
// registered with `separated = true`, splitting the OFM depth round-robin
// across `cores` streams.  Out-of-range positions produce zero padding.
template<typename TYPE>
class EthosUWeightOrdering : public WeightSourceCommon
{
protected:
    // Transform
    WeightTransformParam *_param;
    WeightTransformFunc _transform;
    // Loop Limits
    int _ofmBlockDepth;
    int _ifmBlockDepth;
    short _ofmUBlockDepth;
    short _ifmUBlockDepth;
    short _decompX;        // subkernel decomposition width (8 / dilation.x)
    short _decompY;        // subkernel decomposition height (8 / dilation.y)
    short _subKernelRound; // subkernel element count is padded to this multiple
    // Saved state (mirrors the loop indices of GetNext for resumption)
    int _ofmBlockZ = 0;
    int _ifmBlockZ = 0;
    int _subKernelX = 0;
    int _subKernelY = 0;
    int _ifmUBlockOuter = 0;
    int _ifmUBlockInner = 0;
    int _ofmUBlockZ = 0;
    int _ifmUBlockZ = 0;
    int _kernelElement = 0;
    int _ofmUBlock = 0;
    EthosUTraversal _traversal;

public:
    EthosUWeightOrdering(int cores, const Point2i &dilation, int ofmBlockDepth, int ifmBitDepth, int ofmUBlockDepth,
        int ifmUBlockDepth, WeightTransformFunc func, WeightTransformParam *param, EthosUTraversal traversal)
    {
        _streams = cores;
        _ofmBlockDepth = ofmBlockDepth;
        _ifmBlockDepth = ((traversal == EthosUTraversal::PartKernel) || (ifmBitDepth == 16)) ? 16 : 32;
        _ofmUBlockDepth = short(ofmUBlockDepth);
        _ifmUBlockDepth = short(ifmUBlockDepth);
        _decompX = short(8 / dilation.x);
        _decompY = short(8 / dilation.y);
        // Traversal-specific rounding of subkernel element counts.
        if ( traversal == EthosUTraversal::Depthwise )
        {
            _subKernelRound = 4;
        }
        else if ( traversal == EthosUTraversal::PartKernel )
        {
            _subKernelRound = (ifmBitDepth == 16) ? 2 : 4;
        }
        else
        {
            _subKernelRound = 1;
        }
        _transform = func;
        _param = param;
        _traversal = traversal;
    }

    void SetSource(const void *buffer, int depthOffset, const Shape &ohwiShape, const Shape &ohwiStrides, int streamIndex) override
    {
        SetSourceCommon(buffer, depthOffset + streamIndex, ohwiShape, ohwiStrides, streamIndex, true);
    }

public:
    // Dispatch to the right GetNext instantiation for the traversal mode.
    int Get(int16_t *output, int count) override
    {
        if ( _traversal == EthosUTraversal::Depthwise ) return GetNext<false, true>(output, count);
        else if ( _traversal == EthosUTraversal::PartKernel ) return GetNext<true, false>(output, count);
        return GetNext<false, false>(output, count);
    }

    // Resumable weight generator: fills up to `count` slots of `output`,
    // returns how many were written (less than `count` means end of stream).
    template<bool IS_PARTKERNEL, bool IS_DEPTHWISE>
    int GetNext(int16_t *output, int count)
    {
        if ( _ofmBlockZ >= _ofmDepth )
        {
            return 0;
        }
        int ofmBlockZ, ifmBlockZ;
        int ifmUBlockOuter, ifmUBlockInner;
        int ifmUBlockZ, ofmUBlockZ, ofmUBlock;
        int subKernelX, subKernelY;
        int kernelElement;
        int16_t *write = output;
        const TYPE *buffer = reinterpret_cast<const TYPE *>(_source);
        // This stream only produces its share of each OFM block.
        int streamBlockDepth = (_ofmBlockDepth + _streams - 1 - _streamIndex) / _streams;
        for ( ofmBlockZ = _ofmBlockZ; ofmBlockZ < _ofmDepth; ofmBlockZ += streamBlockDepth )
        {
            int clippedOfmBlockDepth = std::min(streamBlockDepth, _ofmDepth - ofmBlockZ);
            // IFM blocks required for the brick
            for ( ifmBlockZ = _ifmBlockZ; ifmBlockZ < (IS_DEPTHWISE ? 1 : _ifmDepth); ifmBlockZ += _ifmBlockDepth )
            {
                int clippedIfmBlockDepth;
                if ( IS_DEPTHWISE )
                {
                    clippedIfmBlockDepth = _ifmUBlockDepth;
                }
                else
                {
                    clippedIfmBlockDepth = IS_PARTKERNEL ? std::min(_ifmBlockDepth, _ifmDepth - ifmBlockZ) : _ifmBlockDepth;
                }
                // Weight decomposition
                // Subkernel Splitting (H)
                for ( subKernelY = _subKernelY; subKernelY < _kernelH; subKernelY += _decompY )
                {
                    int subHeight = std::min<int>(_kernelH - subKernelY, _decompY);
                    // Subkernel splitting (W)
                    for ( subKernelX = _subKernelX; subKernelX < _kernelW; subKernelX += _decompX )
                    {
                        int subWidth = std::min<int>(_kernelW - subKernelX, _decompX);
                        int subKernelElements = subWidth * subHeight;
                        // Part-kernel first works across the kernel H/W and needs padding
                        subKernelElements = RoundAway(subKernelElements, _subKernelRound);
                        int ifmBlockDepthOuter = IS_PARTKERNEL ? clippedIfmBlockDepth : 1;
                        int ifmBlockDepthInner = IS_PARTKERNEL ? 1 : clippedIfmBlockDepth;
                        for ( ifmUBlockOuter = _ifmUBlockOuter; ifmUBlockOuter < ifmBlockDepthOuter; ifmUBlockOuter += _ifmUBlockDepth )
                        {
                            // OFM uBlocks in OFM-block over depth
                            for ( ofmUBlock = _ofmUBlock; ofmUBlock < clippedOfmBlockDepth; ofmUBlock += _ofmUBlockDepth )
                            {
                                // HW Kernel element traversal - cannot be a H/W loop due to element
                                // padding requirement on depthwise/part-kernel configurations
                                for ( kernelElement = _kernelElement; kernelElement < subKernelElements; kernelElement++ )
                                {
                                    int kx = kernelElement % subWidth;
                                    int ky = kernelElement / subWidth;
                                    // IFM uBlocks in IFM-block over depth (only 1 uBlock if depthwise)
                                    // In case of part-kernel-first IFM uBlock traversal have already been handled
                                    // and this loop is ignored.
                                    for ( ifmUBlockInner = _ifmUBlockInner; ifmUBlockInner < ifmBlockDepthInner; ifmUBlockInner += _ifmUBlockDepth )
                                    {
                                        int ifmUBlock = ifmUBlockInner + ifmUBlockOuter;
                                        // Feed OFM uBlock elements
                                        for ( ofmUBlockZ = _ofmUBlockZ; ofmUBlockZ < _ofmUBlockDepth; ofmUBlockZ++ )
                                        {
                                            // Source IFM uBlock elements (only 1 element deep if depthwise)
                                            for ( ifmUBlockZ = _ifmUBlockZ; ifmUBlockZ < (IS_DEPTHWISE ? 1 : _ifmUBlockDepth); ifmUBlockZ++ )
                                            {
                                                // Source position within the current subkernel
                                                int wx = subKernelX + kx;
                                                int wy = subKernelY + ky;
                                                // Source IFM/OFM slices
                                                int ifm_z = ifmBlockZ + ifmUBlock + ifmUBlockZ;
                                                int ofm_z = ofmBlockZ + ofmUBlock + ofmUBlockZ;
                                                if ( (ifm_z < _ifmDepth) && (ofm_z < _ofmDepth) && (ky < subHeight) )
                                                {
                                                    _param->o = ofm_z;
                                                    _param->h = wy;
                                                    _param->w = wx;
                                                    _param->i = ifm_z;
                                                    int weight = int(buffer[WeightIndex(ofm_z, wy, wx, ifm_z)]);
                                                    *write = int16_t(_transform(_param, weight));
                                                }
                                                else
                                                {
                                                    *write = 0;
                                                }
                                                write++;
                                                if ( --count == 0 )
                                                {
                                                    // Save state
                                                    _ifmUBlockZ = ifmUBlockZ + 1;
                                                    _ofmUBlockZ = ofmUBlockZ;
                                                    _ifmUBlockInner = ifmUBlockInner;
                                                    _kernelElement = kernelElement;
                                                    _ofmUBlock = ofmUBlock;
                                                    _ifmUBlockOuter = ifmUBlockOuter;
                                                    _subKernelX = subKernelX;
                                                    _subKernelY = subKernelY;
                                                    _ifmBlockZ = ifmBlockZ;
                                                    _ofmBlockZ = ofmBlockZ;
                                                    // Return weights generated (less than requested count == EOS)
                                                    return int(intptr_t(write - output));
                                                }
                                            }
                                            _ifmUBlockZ = 0;
                                        }
                                        _ofmUBlockZ = 0;
                                    }
                                    _ifmUBlockInner = 0;
                                }
                                _kernelElement = 0;
                            }
                            _ofmUBlock = 0;
                        }
                        _ifmUBlockOuter = 0;
                    }
                    _subKernelX = 0;
                }
                _subKernelY = 0;
            }
            _ifmBlockZ = 0;
        }
        _ofmBlockZ = 0;
        return int(intptr_t(write - output));
    }
};
#endif /* ETHOSU_ENCODE_SUPPORT_H */

View file

@ -137,6 +137,8 @@ struct ethosu_operation {
struct ethosu_address_range weights; struct ethosu_address_range weights;
struct ethosu_address_range scales; struct ethosu_address_range scales;
bool depthwise; bool depthwise;
unsigned scale;
unsigned shift;
} conv; } conv;
struct { struct {

View file

@ -1,6 +1,8 @@
# Copyright 2019 Google, Inc # Copyright 2019 Google, Inc
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
subdir('mlw_codec')
ethosu_registers = custom_target( ethosu_registers = custom_target(
'ethosu_registers.h', 'ethosu_registers.h',
input : ['gen_parser.py', 'gen_header.py', 'registers.xml'], input : ['gen_parser.py', 'gen_header.py', 'registers.xml'],
@ -13,18 +15,19 @@ files_ethosu = files(
'ethosu_cmd.c', 'ethosu_cmd.c',
'ethosu_coefs.c', 'ethosu_coefs.c',
'ethosu_device.c', 'ethosu_device.c',
'ethosu_encode.cpp',
'ethosu_lower.c', 'ethosu_lower.c',
'ethosu_ml.c', 'ethosu_ml.c',
'ethosu_sched.c', 'ethosu_sched.c',
'mlw_codec/mlw_encode.c',
) )
libethosu = static_library( libethosu = static_library(
'ethosu', 'ethosu',
[files_ethosu, ethosu_registers], [files_ethosu, ethosu_registers],
include_directories : [inc_gallium_aux, inc_gallium, inc_include, inc_src], include_directories : [inc_gallium_aux, inc_gallium, inc_include, inc_src, inc_libmlw_codec],
gnu_symbol_visibility : 'hidden', gnu_symbol_visibility : 'hidden',
dependencies : [idep_mesautil, dep_libdrm], dependencies : [idep_mesautil, dep_libdrm],
link_with : [libmlw_codec],
) )
driver_ethosu = declare_dependency( driver_ethosu = declare_dependency(

View file

@ -0,0 +1,50 @@
//
// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
/* Guard renamed from __MLW_DECODE_H__: identifiers containing a double
 * underscore are reserved for the implementation (C11 7.1.3 / CERT DCL37-C). */
#ifndef MLW_DECODE_H
#define MLW_DECODE_H
#pragma once
#include <stdint.h>
// Result of the decode process.
// The pointer members are presumably allocated by the decoder and must be
// released with mld_free() — confirm against mlw_decode.cpp.
typedef struct ml_decode_result_t
{
int16_t *decoded_data; // decoded weight elements
int32_t decoded_length; // decoded weight length (in elements)
int32_t section_count; // number of sections in stream
int32_t *section_sizes; // section sizes in stream
} ml_decode_result_t;
#if defined __cplusplus
extern "C"
{
#endif
// Decode an EthosU weight bitstream of size_bytes bytes into *result.
void ml_decode_ethosu_stream(ml_decode_result_t *result, const uint8_t *buffer, int size_bytes);
// Release the buffers held by a decode result.
void mld_free(ml_decode_result_t *result);
#if defined __cplusplus
} // extern "C"
#endif
#endif // MLW_DECODE_H

View file

@ -0,0 +1,115 @@
//
// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#if !defined MLW_ENCODE_H
#define MLW_ENCODE_H
#include <stdint.h>
#include <stddef.h>
#if !defined __cplusplus
// ml_source_state_t below uses 'bool'; when this extern-"C" header is
// included from a C translation unit (it is, by the ethosu driver code),
// 'bool' is undefined without <stdbool.h> prior to C23.
#include <stdbool.h>
#endif
#if defined _MSC_VER
#define MLW_CODEC_PACKED
#define MLW_CODEC_USE_PACK_PRAGMA (1)
#else // __GNUC__ and clang
#define MLW_CODEC_PACKED __attribute__((packed))
#endif
// Encoder input parameters EthosU
typedef struct ml_ethosu_encode_params_t
{
int32_t source_buffering_hint; // Recommend a buffering size
uint16_t encoder_flags; // Control flags to pass to the encoder
void* (*realloc_func)(void*, size_t, int purpose); // Custom output allocator function
} ml_ethosu_encode_params_t;
// Resulting encoded section information
typedef struct ml_encode_section_t
{
int32_t offset; // Byte offset of encoded section
int32_t size; // Byte size of encoded section
int32_t zeroes; // Number of zeroes encoded in a section
int8_t group_start; // Start of group
} ml_encode_section_t;
// Result of the encode process
typedef struct ml_encode_result_t
{
uint8_t *encoded_data; // Encoded weight data
int32_t encoded_length; // Encoded weight length (in bytes)
int32_t source_length; // Source elements read
ml_encode_section_t *section_info; // Array of sections in stream
int32_t section_count; // Number of section in stream
} ml_encode_result_t;
// Query selectors passed as the first argument of ml_weight_source_fn
#define MLW_SOURCE_QUERY_WEIGHTS 0
#define MLW_SOURCE_QUERY_SHIFTS 1
// State of the source iterator
typedef struct ml_source_state_t
{
uint8_t new_dim_mask; // Dimension start mask
uint8_t end_dim_mask; // Dimension end mask
bool eos; // End-of-stream flag
} ml_source_state_t;
// Stream input callback (encoder will collect input through this function, of the size recommended by the buffering hint)
typedef int32_t (*ml_weight_source_fn)(int32_t query, ml_source_state_t *state, int16_t *buffer, int32_t size, void *user_arg);
// Internal state context
typedef struct mle_context_t mle_context_t;
#define MLW_ENCODE_FLAG_NONE (0) // Default encoding flag
#define MLW_ENCODE_NO_BITSTREAM (1) // Do not write any bitstream data (only return the length)
#define MLW_ENCODE_INSERT_PALETTE (2) // Insert a new palette header with this encode
#define MLW_ENCODE_RESET_PALETTE (4) // Clear and recalculate the palette header
#define MLW_ENCODE_PARTIAL_DATA (8) // Frequency analysis and palette will be constructed from incomplete data
#define MLW_ENCODE_NO_PADDING (16) // Disable trailing padding
#define MLW_ENCODE_NO_PALETTE_LUT (32) // Disable palette LUT generation
#define MLW_ENCODE_NO_ZERO_RUNS (64) // Disable zero run generation
#define MLW_ENCODE_DPIC_FORCE_PARAMS (128) // Force debug parameters
#define MLW_ENCODE_NEW_PALETTE (MLW_ENCODE_INSERT_PALETTE|MLW_ENCODE_RESET_PALETTE)
#define MLW_ENCODE_SYNTAX_ETHOSU (0) // EthosU bitstream encode syntax
#define MLW_ENCODE_SYNTAX_ETHOSU_FWD (2) // EthosU FWD bitstream encode syntax
#define MLW_ENCODE_ALLOC_GENERAL (0) // General allocations used by the encoder
#define MLW_ENCODE_ALLOC_METADATA (1) // Allocation for codec's metadata output
#define MLW_ENCODE_ALLOC_STREAM0 (2) // Stream 0 allocation for this codec
#define MLW_ENCODE_ALLOC_STREAM1 (3) // Stream 1 allocation for this codec
#if defined __cplusplus
extern "C"
{
#endif
// Baseline encode
mle_context_t *mle_create_context(int syntax);
int mle_context_query_zeroes(mle_context_t *ctx);
int mle_context_query_weights_used(mle_context_t *ctx, uint64_t weights_used[512 / 64]);
void mle_context_set_allocator(mle_context_t *ctx, void* (*realloc_func)(void*, size_t, int purpose));
void mle_destroy_context(mle_context_t *ctx);
int mle_encode(mle_context_t *ctx, ml_encode_result_t *result, const int16_t *inbuf, int inbuf_size, unsigned mlw_encode_flags);
void mle_free(ml_encode_result_t *result);
int32_t ml_encode_ethosu_stream(ml_encode_result_t *result, const ml_ethosu_encode_params_t *ep, ml_weight_source_fn src, void *user_arg, mle_context_t **ctx_out);
#if defined __cplusplus
} // extern "C"
#endif
#endif // MLW_ENCODE_H

View file

@ -0,0 +1,17 @@
# Copyright 2025 Tomeu Vizoso <tomeu@tomeuvizoso.net>
# SPDX-License-Identifier: MIT
# Vendored Regor weight encoder (mlw_codec), built as a standalone static
# library and linked into the ethosu gallium driver.
files_mlw_codec = files(
'source/mlw_encode.cpp',
'source/mlw_encode_fwd.cpp',
'source/ml_ethosu_encode.cpp',
'source/mlw_decode.cpp',
)
libmlw_codec = static_library(
'mlw_codec',
[files_mlw_codec],
gnu_symbol_visibility : 'hidden',
)
# Public headers for driver code that links against libmlw_codec.
inc_libmlw_codec = [include_directories('include')]

View file

@ -1,29 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright 2020, 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#ifndef MLW_COMMON_H
#define MLW_COMMON_H
#define ZDIV_DISABLE 6 // not alternating mode
#define ZDIV_EOS 7 // indicates end of stream
#define WDIV_UNCOMPRESSED 7 // indicates uncompressed weights
#endif

File diff suppressed because it is too large Load diff

View file

@ -1,65 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the License); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#ifndef MLW_ENCODE_H
#define MLW_ENCODE_H
#ifdef _MSC_VER
#define MLW_ENCODE_EXPORTED __declspec(dllexport)
#else
#define MLW_ENCODE_EXPORTED __attribute__((visibility("default")))
#endif
#if __cplusplus
extern "C"
{
#endif
MLW_ENCODE_EXPORTED
int mlw_encode(int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose);
MLW_ENCODE_EXPORTED
void mlw_free_outbuf(uint8_t *outbuf);
MLW_ENCODE_EXPORTED
int mlw_reorder_encode(
int ifm_ublock_depth,
int ofm_ublock_depth,
int ofm_depth,
int kernel_height,
int kernel_width,
int ifm_depth,
int* brick_strides,
int16_t* inbuf,
int ofm_block_depth,
int is_depthwise,
int is_partkernel,
int ifm_bitdepth,
int decomp_h,
int decomp_w,
uint8_t **outbuf,
int64_t* padded_length,
int verbose);
#if __cplusplus
}
#endif
#endif

View file

@ -0,0 +1,255 @@
//
// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#if !defined ML_BIT_BUFFER_HPP
#define ML_BIT_BUFFER_HPP
#pragma once
#include "ml_raw_buffer.hpp"
#include <cstdint>
// Bit-level reader/writer over a raw_buffer_t<uint8_t>.
// Bits are packed LSB-first into 32-bit words; _next holds the word
// currently being assembled and is only committed to memory when the
// word fills up (or on flush()).  Word stores require a little-endian
// host (see put()).
struct bitbuf_t
{
private:
uint32_t *_buf;
uint32_t _next = 0;
int _pos; // bit pos of next bit
int _limit; // in bytes
int _substream_start = 0; // start position for substreams
int _substream_end = 0; // end position for substreams
bool _enabled = false; // false = dry run: track positions but write nothing
raw_buffer_t<uint8_t> *_buffer;
public:
// Read constructor
bitbuf_t(const void *buf, int used_bytes) : _buffer(nullptr)
{
// Readable length is truncated to whole 32-bit words.
_limit = used_bytes & ~3;
_buf = reinterpret_cast<uint32_t *>(const_cast<void*>(buf));
_pos = 0;
}
// Write constructor
bitbuf_t(raw_buffer_t<uint8_t> &buffer, int reserve_bytes, bool disable_writes) : _buffer(&buffer)
{
_enabled = !disable_writes;
if ( _enabled ) _buffer->reserve(reserve_bytes);
_limit = _buffer->capacity() & ~3;
prime( _buffer->used() * 8 ); // Start at the end of the buffer's used data
}
// Sub-stream writer
bitbuf_t(bitbuf_t &dest, int bitpos, int bitlen=0) : _buffer(dest._buffer)
{
_limit = _buffer->capacity() & ~3;
_substream_start = bitpos;
bitlen = (bitlen <= 0) ? (_limit * 8) - bitpos : bitlen; // Default to rest of buffer
_substream_end = bitpos + bitlen;
int required = (_substream_end + 7 ) / 8;
assert( required <= _limit );
_enabled = dest._enabled;
prime( bitpos );
}
public:
// Append the low 'len' bits of 'data' (which must already be masked).
void put(int len, int32_t data)
{
assert( _buf == reinterpret_cast<uint32_t *>(_buffer->begin()) && "Buffer resized externally" );
assert( (data & ((1 << len)-1)) == data && "Data must be pre-masked" );
assert( ((_substream_end == 0) || (_pos + len <= _substream_end)) && "Write past end of substream section" );
if ( len > 0 && _enabled )
{
uint32_t next = _next;
int bitpos = _pos & 0x1F;
next |= uint32_t(data) << bitpos;
// Crossing a word boundary: commit the full word, keep the spill bits.
if ( len >= (32 - bitpos) )
{
// Write won't fit, reserve more output space
if ( (_pos / 8) >= _limit )
{
extend();
}
_buf[_pos >> 5] = next; // Requires little-endian
next = uint32_t(data) >> (32 - bitpos);
}
_next = next;
}
// Position advances even when writes are disabled (dry-run sizing).
_pos += len;
}
// As put(), but masks 'data' to 'len' bits first.
void put_masked(int len, int32_t data)
{
put(len, data & ((1 << len)-1));
}
// Append 'len' copies of 'bit' (0 or 1).
void fill(int len, unsigned bit)
{
const uint32_t mask = 0xFFFFFFFF * bit;
int remain = len;
while ( remain >= 32 )
{
put(32, mask);
remain -= 32;
}
if (remain > 0)
put(remain, mask & ((1u << remain) - 1) );
}
// Pad with 'fill_byte' until the bit position is a multiple of 'bits'.
void align(int bits, int fill_byte)
{
// Alignments must be power of 2
assert( (bits & (bits - 1)) == 0 && bits );
const int mask = bits-1;
int distance = (bits - (_pos & mask)) & mask;
// Byte align first
put_masked( distance & 7, fill_byte );
distance &= ~7;
while (distance != 0)
{
put(8, fill_byte);
distance -= 8;
}
}
// Move the cursor to an absolute bit position within the stream.
void reposition(int bitpos)
{
int end = (_substream_end != 0) ? _substream_end : _limit * 8;
assert( (bitpos >= 0 && bitpos <= end) && "Can't reposition out of stream" );
assert( (_substream_end == 0 || bitpos >= _substream_start) && "Can't reposition before substream");
if ((_pos != bitpos) && (bitpos > 0) && (bitpos <= end) )
{
// Reposition in bitstream. Caller must flush if writing.
prime(bitpos);
}
}
// Commit the partially-filled word; if 'done', publish the byte length
// to the underlying buffer.
void flush(bool done=true)
{
if ( !_enabled ) return;
// If buffering word is not empty, write it out as-is.
if ( _pos & 0x1F )
{
// If writing a substream, blend any overlapping words with the parent stream
if ( _substream_end )
{
int remain = _substream_end - (_pos & ~0x1F); // Remaining word-bits in this substream
if ( remain < 32 ) // Will overlap happen?
{
uint32_t mask = ~0u << (32 - remain); // Mask the parent bits that we want to keep
_next = (_buf[_pos >> 5] & mask) | (_next & ~mask);
}
}
// Otherwise limited by the buffer
else
{
// Only extend by space required to flush remaining word.
if ( (_pos / 8) >= _limit )
{
extend(true);
}
}
_buf[_pos >> 5] = _next;
}
if ( done && !_substream_end )
{
_buffer->set_used( _pos / 8 );
}
}
// Merge a finished substream back into this stream and resume at the
// furthest write position of the two.
void sync(bitbuf_t &substream)
{
flush(false);
substream.flush(false);
prime( std::max(_pos, substream._pos) );
substream._buffer = nullptr;
}
// Read the next 'len' bits (LSB-first), advancing the cursor.
int get(int len)
{
if ( len == 0 )
{
return 0;
}
const unsigned mask = (1u << len) - 1;
assert( (_pos / 8) < _limit );
uint32_t next = _buf[_pos >> 5];
int bitpos = _pos & 0x1F;
// Bits from this word
unsigned value = next >> bitpos;
_pos += len;
// Some of the bits are in the next word
if ( len > (32 - bitpos) )
{
assert( (_pos / 8) < _limit );
next = _buf[_pos >> 5];
value |= next << (32 - bitpos);
}
return int(value & mask);
}
// Skip forward to the next multiple of 'bits' when reading.
void read_align(int bits)
{
// Alignments must be power of 2
assert( (bits & (bits - 1)) == 0 && bits );
const int mask = bits-1;
_pos += (bits - (_pos & mask)) & mask;
}
bool read_eos() const { return _pos/8 >= _limit; }
int read_avail() const { return (_limit - (_pos / 8)) * 8 - (_pos & 7); }
int read_avail(int watermark) const { return (watermark - (_pos / 8)) * 8 - (_pos & 7); }
int pos() const { return _pos; }
int byte_pos() const { return _pos / 8; }
int byte_length() const { return _limit; }
private:
// (Re)start the writer at 'bitpos': reload the containing word and keep
// only the bits below the cursor so new bits can be OR-ed in.
void prime(int bitpos)
{
assert( (bitpos >= 0) && (bitpos / 8) < _limit );
// Prime (start up) the bitstream writer at the given bit position
_pos = bitpos;
_buf = reinterpret_cast<uint32_t *>(_buffer->begin());
_next = _buf[bitpos >> 5];
_next &= (1u << (bitpos & 0x1F)) - 1;
}
// Grow the backing buffer (not permitted for substreams) and refresh
// the cached pointer/limit, which may have moved on realloc.
void extend(bool exact_resize=false)
{
assert(_enabled);
_buffer->set_used( (_pos / 8) & ~3 ); // Only use whole words
_buffer->reserve( sizeof(uint32_t), exact_resize ); // Buffer implementation must optimise small requests
assert( (_buffer->capacity() & ~3) > _limit );
assert( _substream_end == 0 ); // Can't extend a substream
_limit = _buffer->capacity() & ~3;
_buf = reinterpret_cast<uint32_t *>(_buffer->begin());
}
};
#endif // ML_BIT_BUFFER_HPP

View file

@ -0,0 +1,130 @@
//
// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#if !defined ML_ENCODER_INTERNAL_HPP
#define ML_ENCODER_INTERNAL_HPP
#pragma once
#include "../include/mlw_encode.h"
#include "ml_bit_buffer.hpp"
#include <cstdint>
#include <vector>
#if __GNUC__
#define ML_ENCODER_DLL_EXPORT __attribute__((visibility("default")))
#elif _WIN32
#if TARGET_WIN32_DLL
#define ML_ENCODER_DLL_EXPORT __declspec(dllexport)
#else
#define ML_ENCODER_DLL_EXPORT
#endif
#else
#error "undefined export semantics"
#endif
#if !defined ENABLE_DEBUG_PACKET
#define ENABLE_DEBUG_PACKET (0)
#endif
#if !defined ENABLE_DEBUG_BITSTREAM
#define ENABLE_DEBUG_BITSTREAM (0)
#endif
#if ENABLE_DEBUG_PACKET
#include <cstdio>
#define PACKET_LOG(...) printf(__VA_ARGS__)
#else
#define PACKET_LOG(...)
#endif
#if ENABLE_DEBUG_BITSTREAM
#include <cstdio>
#define BITSTREAM_LOG(...) printf(__VA_ARGS__)
#else
#define BITSTREAM_LOG(...)
#endif
constexpr int ETHOSU_SLICELEN_BITS = 15;
constexpr int ZDIV_DISABLE = 6; // not alternating mode
constexpr int ZDIV_EOS = 7; // indicates end of stream
constexpr int WDIV_UNCOMPRESSED = 7; // indicates uncompressed weights
// Weight palette for one encoded section: frequent weight symbols are
// replaced by short palette indices in the bitstream.
struct palette_t
{
int16_t lut[32] = {0}; // palette entries (encoder side — confirm in mlw_encode.cpp)
int16_t inv_lut[512] = {0}; // decoder parses the raw PALETTE field values into here
int freq[512] = {0}; // per-symbol frequency counts (presumably drives palette selection)
int palsize; // number of palette entries
int palbits; // bit width of palette entries
int direct_offset; // added to the decoded weight index before direct conversion to sign/mag
bool use_zero_runs; // zeros are coded separately
bool only_palette; // no values outside the palette
bool only_zeros; // special case that the section is all zeros
};
// Per-slice Golomb-Rice coding parameters (mirror the WDIV/WTRUNC/ZDIV
// fields of the slice header in the bitstream).
struct slice_params_t
{
uint8_t w_grc_trunc; // WTRUNC: weight GRC truncation flag
bool w_uncompressed; // weights stored uncompressed in this slice
uint8_t z_grc_div; // ZDIV: zero-run GRC divisor
uint8_t w_grc_div; // WDIV: weight GRC divisor
};
// Debug snapshot of the parameters and palette used for one slice.
struct mle_slice_debug_t
{
slice_params_t params;
palette_t palette;
};
// Persistent encoder state carried across sections of one weight stream.
struct mle_context_t
{
palette_t palette; // current palette; meaningful when palette_valid is set
uint64_t weights_used[512 / 64] = {0}; // bitmap over the weight symbol space (see mle_context_query_weights_used)
int distinct_weights = 0; // distinct symbols seen (presumably — confirm in mlw_encode.cpp)
int syntax = 0; // MLW_ENCODE_SYNTAX_* chosen at context creation
int zero_count = 0; // zeroes encoded so far (queried via mle_context_query_zeroes)
int slicelen_bits = ETHOSU_SLICELEN_BITS; // bit width of the SLICELEN header field
bool palette_valid = false;
bool single_slice_sections = false;
bool allow_empty_slices = false;
bool eos_required = false;
bool enable_slice_debug = false; // record per-slice params into slice_debug
bool disable_lut = false; // suppress palette LUT generation (MLW_ENCODE_NO_PALETTE_LUT)
int8_t fixed_wgrc = -1; // forced weight GRC divisor; -1 presumably means auto — confirm
int8_t fixed_zgrc = -1; // forced zero-run GRC divisor; -1 presumably means auto — confirm
std::vector<mle_slice_debug_t> slice_debug;
void* (*realloc_func)(void*, size_t, int); // Custom output allocator function
};
// Ceiling division: smallest integer >= num / div for non-negative num
// and positive div.
inline int div_round_up(int num, int div)
{
    const int biased = num + (div - 1);
    return biased / div;
}
int ml_encode_fwd(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int encode_count, unsigned mlw_encode_flags);
int ml_encode_section(mle_context_t *ctx, const int16_t *inbuf, int size, palette_t *p, bitbuf_t *bitbuf);
palette_t *ml_encode_palette(mle_context_t *ctx, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags);
void ml_encode_eos(mle_context_t *ctx, bitbuf_t &bits, unsigned mlw_encode_flags);
int ml_encode_internal(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags);
#endif // ML_ENCODER_INTERNAL_HPP

View file

@ -0,0 +1,120 @@
//
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "../include/mlw_encode.h"
#include "ml_bit_buffer.hpp"
#include "ml_raw_buffer.hpp"
#include "ml_encoder_internal.hpp"
#include <cassert>
#include <cstdint>
#if defined __cplusplus
extern "C"
{
#endif
// Drive the EthosU weight encoder over a pull-style source callback.
// Repeatedly fetches weights from 'src' into a staging buffer, encodes
// each chunk with a fresh palette, then terminates the stream and hands
// the encoded bytes (and optionally the context) back to the caller.
// Returns 1 on success, -1 on bad arguments / allocation / encode failure.
ML_ENCODER_DLL_EXPORT int32_t ml_encode_ethosu_stream(ml_encode_result_t *result, const ml_ethosu_encode_params_t *ep, ml_weight_source_fn src, void *user_arg, mle_context_t **ctx_out)
{
    constexpr int kRequestSize = 8192;    // Initial input buffering
    constexpr int kInitialOutput = 8192;  // Initial output capacity (doubles on overflow)
    constexpr unsigned kValidFlags = MLW_ENCODE_NO_BITSTREAM;

    assert(result && ep);
    if ( !(result && ep && src) )
    {
        return -1;
    }

    mle_context_t *context = mle_create_context(MLW_ENCODE_SYNTAX_ETHOSU);

    // Only the publicly supported flags may be passed through. Debug callers
    // poking the opaque internals are expected to know what they're doing.
    assert( !(ep->encoder_flags & ~(kValidFlags)) );
    const unsigned base_flags = ep->encoder_flags & kValidFlags;

    // Size the staging buffer from the caller's hint (low 24 bits), never
    // below the default request size.
    assert( ep->source_buffering_hint >= 0 );
    const int chunk_size = std::max(kRequestSize, ep->source_buffering_hint & 0x00FFFFFF);
    raw_buffer_t<int16_t> staging( chunk_size );

    // The source callback reports end-of-stream through this state.
    ml_source_state_t src_state = {0};
    src_state.eos = false;

    // Output bitstream storage and bit-level writer.
    raw_buffer_t<uint8_t> encoded(kInitialOutput, MLW_ENCODE_ALLOC_STREAM0, ep->realloc_func);
    bitbuf_t writer(encoded, 8, base_flags & MLW_ENCODE_NO_BITSTREAM);

    result->source_length = 0;

    // Pull values until the source signals end-of-stream.
    while ( !src_state.eos )
    {
        int16_t *dst = staging.reserve(chunk_size);
        if ( dst == nullptr )
        {
            // Memory allocation failed
            mle_destroy_context(context);
            return -1;
        }
        // NOTE(review): the callback is assumed to return a non-negative
        // element count — confirm against the ml_weight_source_fn contract.
        const int fetched = (*src)(MLW_SOURCE_QUERY_WEIGHTS, &src_state, dst, staging.capacity() - staging.used(), user_arg);
        staging.use(fetched);

        // Every chunk is encoded with a freshly computed palette.
        const unsigned pass_flags = base_flags | MLW_ENCODE_NEW_PALETTE;
        const int written = ml_encode_internal(context, writer, staging.begin(), staging.used(), staging.used(), pass_flags);
        if ( written < 0 )
        {
            // Encoder errored
            mle_destroy_context(context);
            return -1;
        }
        result->source_length += staging.used();
        staging.clear();
    }

    ml_encode_eos(context, writer, base_flags);

    // Populate the return result; the encoded buffer's ownership moves to
    // the caller via detach().
    assert(writer.byte_pos() == encoded.used() || (base_flags & MLW_ENCODE_NO_BITSTREAM));
    result->encoded_length = writer.byte_pos();
    result->encoded_data = encoded.detach();
    result->section_info = nullptr;
    result->section_count = 0;

    // Hand the context back if requested, otherwise release it.
    if ( ctx_out == nullptr )
    {
        mle_destroy_context(context);
    }
    else
    {
        assert( *ctx_out == nullptr );
        *ctx_out = context;
    }
    return 1;
}
#if defined __cplusplus
} // extern "C"
#endif

View file

@ -0,0 +1,162 @@
//
// SPDX-FileCopyrightText: Copyright 2022-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#if !defined ML_RAW_BUFFER_HPP
#define ML_RAW_BUFFER_HPP
#pragma once
#include <memory>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <type_traits>
#include <algorithm>
typedef void* (*realloc_t)(void *ptr, size_t size, int);
// Minimal growable array of trivially-copyable elements, allocated through
// a pluggable realloc-style callback. Tracks a 'used' element count
// separately from capacity; freeing is done by calling the callback with
// size 0 (see the destructor).
// NOTE(review): copy construction/assignment are compiler-generated here and
// would double-free _data if a buffer were ever copied; only the move
// constructor transfers ownership safely — consider deleting the copy
// operations upstream.
template <typename TYPE>
struct raw_buffer_t
{
static_assert(std::is_trivially_copyable<TYPE>::value, "expected simple storage type");
constexpr static int CAPACITY_ALIGN = 16; // capacities are rounded up to this many elements
TYPE *_data;
int _used;     // elements currently in use
int _capacity; // allocated elements
int _reallocArg = 0; // 'purpose' tag forwarded to the allocator callback
realloc_t _realloc=&realloc_proxy;
public:
// Allocating constructor: 'arg' and 'rfunc' customise the allocator.
raw_buffer_t(int capacity, int arg=0, realloc_t rfunc=nullptr)
{
assert(capacity > 0);
_realloc = (rfunc != nullptr) ? rfunc : &realloc_proxy;
_capacity = (capacity + CAPACITY_ALIGN - 1) & ~(CAPACITY_ALIGN - 1);
_reallocArg = arg;
_data = reinterpret_cast<TYPE*>(_realloc(nullptr, _capacity * sizeof(TYPE), _reallocArg));
_used = 0;
}
// Move constructor: takes ownership and empties 'other'.
raw_buffer_t(raw_buffer_t &&other)
{
_capacity = other._capacity;
other._capacity = 0;
_data = other._data;
other._data = nullptr;
_used = other._used;
other._used = 0;
_reallocArg = other._reallocArg;
_realloc = other._realloc;
}
// Adopting constructor: wraps caller-provided storage (freed with the
// default allocator on destruction — the callback is not customised here).
raw_buffer_t(TYPE *data, int used, int capacity)
{
_data = data;
_used = used;
_capacity = capacity;
}
~raw_buffer_t()
{
if (_data)
{
// Size 0 requests deallocation from the callback.
_realloc(_data, 0, _reallocArg);
}
}
TYPE *begin() { return _data; }
TYPE *end() { return _data + _used; }
int used() const { return _used; }
int capacity() const { return _capacity; }
void clear() { _used = 0; }
const TYPE &operator[](int index) const { assert(index < _used); return _data[index]; }
// Directly set the used count (may only grow; see assert).
void set_used(int used)
{
assert(used >= _used);
assert(used <= _capacity);
_used = used;
}
// Ensure room for 'count' more elements and return a pointer to the
// write position. Grows at least 2x unless 'exact_resize'. Returns
// nullptr if the allocator fails (existing data stays valid).
TYPE *reserve(int count, bool exact_resize=false)
{
int req_capacity = _used + count;
if ( req_capacity > _capacity )
{
if ( !exact_resize )
{
req_capacity = std::max(req_capacity, _capacity * 2);
}
auto *p = reinterpret_cast<TYPE*>( _realloc(_data, req_capacity * sizeof(TYPE), _reallocArg) );
if ( !p )
{
return nullptr;
}
_data = p;
_capacity = req_capacity;
}
int at = _used;
return _data + at;
}
// Mark 'count' more elements as used; returns where they start.
// Caller must have reserved the space first.
TYPE *use(int count)
{
int at = _used;
_used += count;
return _data + at;
}
// Release ownership of the storage to the caller.
TYPE *detach()
{
auto tmp = _data;
_data = nullptr;
return tmp;
}
// Pad with 'fill' until 'used' is a multiple of align_bytes elements.
void align(int align_bytes, TYPE fill)
{
int count = (((_used + align_bytes - 1) / align_bytes) * align_bytes) - _used;
TYPE *p = reserve(count);
assert(p);
use(count);
while (count--)
{
*p++ = fill;
}
}
// Drop the first 'count' elements, shifting the rest down.
void remove_left(int count)
{
int to_move = _used - count;
if (to_move >= 0)
{
memmove(_data, _data + count, to_move * sizeof(TYPE));
}
_used = to_move;
}
private:
// Default allocator: plain realloc; the 'purpose' tag is ignored.
static void *realloc_proxy(void *ptr, size_t size, int)
{
return realloc(ptr, size);
}
};
#endif // ML_RAW_BUFFER_HPP

View file

@ -0,0 +1,362 @@
//
// SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "../include/mlw_decode.h"
#include "ml_encoder_internal.hpp"
#include "ml_bit_buffer.hpp"
#include <algorithm>
#include <cassert>
#include <vector>
// NOTE: the '|| 1' makes this branch unconditional, so the printf-tracing
// variant below is dead code unless this condition is edited by hand
// (presumably intentional for the vendored copy — confirm with upstream).
#if NDEBUG || 1
// Release build get bits macro
#define bitbuf_get(bb_, name_, len_) bb_.get(len_)
#else
#include <cstdio>
// Debug build get bits macro: reads like the release macro but logs the
// bit position, field name, length and decoded value of every read.
inline int bitbuf_get(bitbuf_t &bb, const char *name, int len)
{
assert(len <= 32);
int tmp = bb.get(len);
printf("%6d %s:%d = %d\n", bb.pos() - len, name, len, tmp);
fflush(stdout);
return tmp;
}
#endif
// Extract and decode weights from the given bitstream
//
// outbuf - output decoded weights
// bb - input bitstream buffer (wherever it is positioned)
// endpos - bitstream end position (in bytes)
// single_slice - process slices one-at-a-time
// slice_len - slice length in bits
//
// Returns - the number of weights extracted from the bitstream
static int ml_decode_internal(raw_buffer_t<int16_t> &outbuf, bitbuf_t &bb, palette_t &palette, int endpos, bool single_slice, int slice_len)
{
int start_offset = outbuf.used();
int w_cnt;
int w_grc_div;
int w_grc_trunc;
int w_uncompressed;
int z_grc_div;
int z_prev_grc_div = -1;
bool new_palette;
int i, j;
endpos = std::min(endpos, bb.byte_length());
// Loop over all slices
do {
// Decode slice header
int bits_avail = bb.read_avail(endpos);
if ( bits_avail >= 3 )
{
z_grc_div = bitbuf_get(bb, "ZDIV", 3);
}
else // Insufficient bits left for a terminator (EOS may be optional)
{
z_grc_div = ZDIV_EOS;
int ones = bb.get(bits_avail);
assert( ones == (1 << bits_avail) - 1 );
}
while ( z_grc_div == ZDIV_EOS )
{
// End of stream
// Byte align
bb.read_align(8);
if ( bb.byte_pos() >= endpos || single_slice )
{
goto labelExit;
}
z_grc_div = bitbuf_get(bb, "ZDIV", 3);
}
if ( bb.read_avail(endpos) <= 0 )
{
assert(false);
break; // Unexpectedly reached end of the input stream
}
assert(z_grc_div < 4 || z_grc_div == ZDIV_DISABLE);
bool use_zero_runs = z_grc_div != ZDIV_DISABLE; // alternating grc
w_cnt = bitbuf_get(bb, "SLICELEN", slice_len) + (slice_len == ETHOSU_SLICELEN_BITS ? 1 : 0);
w_grc_div = bitbuf_get(bb, "WDIV", 3);
w_grc_trunc = bitbuf_get(bb, "WTRUNC", 1);
new_palette = bitbuf_get(bb, "NEWPAL", 1);
if ( !new_palette )
{
// At the moment it is not supported to change between alternating
// and non-alternating without redefining the palette (this is because
// the zero is not included in the palette in case of alternating)
bool prev_use_zero_run = z_prev_grc_div != ZDIV_DISABLE;
(void)(prev_use_zero_run);
assert((z_prev_grc_div == -1) || (use_zero_runs == prev_use_zero_run));
}
z_prev_grc_div = z_grc_div;
if ( new_palette )
{
palette.direct_offset = bitbuf_get(bb, "DIROFS", 5);
palette.palsize = bitbuf_get(bb, "PALSIZE", 5);
if ( palette.palsize > 0 )
{
palette.palsize++;
}
palette.palbits = bitbuf_get(bb, "PALBITS", 3) + 2;
for ( i = 0; i < palette.palsize; i++ )
{
palette.inv_lut[i] = int16_t(bitbuf_get(bb, "PALETTE", palette.palbits));
}
}
if ( w_grc_div == WDIV_UNCOMPRESSED )
{
// Uncompressed mode
w_uncompressed = 1;
int uncompressed_bits;
if ( palette.palsize > 0 )
{
// Uncompressed bits is given by palette size.
uncompressed_bits = 0;
while ( (1 << uncompressed_bits) < palette.palsize )
{
uncompressed_bits++;
}
}
else
{
// No palette. PALBITS is used to specify uncompressed bits.
uncompressed_bits = palette.palbits;
}
// In uncompressed mode there's only a remainder part (no unary)
// This is achieved by setting w_grc_div to index bit width
w_grc_div = uncompressed_bits;
}
else
{
w_uncompressed = 0;
assert(w_grc_div < 6);
}
// Decode the slice
int z_cnt = w_cnt + (( slice_len != ETHOSU_SLICELEN_BITS || new_palette ) ? 1 : 0);
std::vector<int> w_value(w_cnt);
std::vector<int> z_value(z_cnt);
int w_pos = 0, z_pos = 0;
int w_prev_pos = 0, z_prev_pos = 0;
int w_unary0 = 0, w_unary1 = 0, w_unary1_len = 0, w_q[12] = {0}, wq = 0;
int z_unary = 0, z_q[12] = {0}, zq = 0;
int w_nsymbols = 0;
int w_prev_enable = 0, w_prev_nsymbols = 0, w_prev_q[12] = {0};
int z_nsymbols = 0;
int z_prev_enable = 0, z_prev_nsymbols = 0, z_prev_q[12] = {0};
int total_zcnt = 0;
int z_unary_len = z_grc_div < 3 ? 12 : 8;
// Loop over all chunks in the slice
do
{
// Flow control to possibly throttle either the weights or zero-runs
int balance = use_zero_runs ? w_pos - z_pos : 0;
int w_enable = (balance < 8 || !use_zero_runs) && w_pos < w_cnt;
int z_enable = (balance >= 0 && use_zero_runs) && z_pos < z_cnt;
if ( w_enable )
{
w_unary0 = w_uncompressed ? 0 : bitbuf_get(bb, "WUNARY0", 12);
}
if ( z_enable )
{
z_unary = bitbuf_get(bb, "ZUNARY", z_unary_len);
z_nsymbols = 0;
for ( i = 0; i < z_unary_len; i++ )
{
if ( z_unary & (1 << i) )
{
zq++;
}
else
{
z_q[z_nsymbols++] = zq;
zq = 0;
}
}
z_pos += z_nsymbols;
}
if ( w_enable )
{
w_unary1_len = 0;
int max_symbols = w_uncompressed && w_grc_div > 5 ? 8 : 12;
if ( w_unary0 != 0 )
{
for ( i = 0; i < max_symbols; i++ )
{
if ( w_unary0 & (1 << i) )
{
w_unary1_len++;
}
}
}
w_unary1 = (w_unary1_len > 0) ? bitbuf_get(bb, "WUNARY1", w_unary1_len) : 0;
w_nsymbols = 0;
for ( i = 0; i < max_symbols && (w_nsymbols < (w_cnt - w_pos)); i++ )
{
int code = 0;
if ( w_unary0 & (1 << i) )
{
code = 1 + ( w_unary1 & 1 );
w_unary1 = w_unary1 >> 1;
}
wq += code;
if ( code < 2 || w_grc_trunc )
{
w_q[w_nsymbols++] = wq;
wq = 0;
}
}
w_pos += w_nsymbols;
}
// Remainders corresponding to the quotients in the previous chunk
if ( w_prev_enable )
{
for ( i = 0; i < w_prev_nsymbols && w_prev_pos < w_cnt; i++, w_prev_pos++ )
{
int remain = bitbuf_get(bb, "WREMAIN", w_grc_div);
w_value[w_prev_pos] = (w_prev_q[i] << w_grc_div) + remain;
}
}
if ( z_prev_enable )
{
for ( i = 0; i < z_prev_nsymbols && z_prev_pos < z_cnt; i++, z_prev_pos++ )
{
int remain = 0;
if ( z_grc_div != 0 )
{
remain = bitbuf_get(bb, "ZREMAIN", z_grc_div);
}
z_value[z_prev_pos] = (z_prev_q[i] << z_grc_div) + remain;
total_zcnt += z_value[z_prev_pos];
}
}
w_prev_enable = w_enable;
w_prev_nsymbols = w_nsymbols;
std::copy(std::begin(w_q), std::end(w_q), std::begin(w_prev_q));
z_prev_enable = z_enable;
z_prev_nsymbols = z_nsymbols;
std::copy(std::begin(z_q), std::end(z_q), std::begin(z_prev_q));
} while ( w_prev_enable || z_prev_enable );
// Interleave non-zero and zeros into the outbuf buffer
// Increase the outbuffer to fit the new slice
int16_t *p = outbuf.reserve(w_cnt + total_zcnt);
assert(p);
// Insert initial zeros
if ( ( slice_len != ETHOSU_SLICELEN_BITS || new_palette ) && use_zero_runs )
{
for ( j = 0; j < z_value[0]; j++ )
{
*p++ = 0;
}
}
// Loop over all weights and insert zeros in-between
for ( i = 0; i < w_cnt; i++ )
{
int val;
assert(w_value[i] < 512); // HW supports 9bit
if ( w_value[i] < palette.palsize )
{
val = palette.inv_lut[w_value[i]];
}
else
{
val = w_value[i] - palette.palsize + palette.direct_offset;
}
int sign = val & 1;
int mag = val >> 1;
*p++ = sign ? int16_t(-mag) : int16_t(mag);
if ( use_zero_runs )
{
for ( j = 0; j < z_value[i + (new_palette ? 1 : 0)]; j++ )
{
*p++ = 0;
}
}
}
outbuf.use(w_cnt + total_zcnt);
} while (!single_slice);
labelExit:
return outbuf.used() - start_offset;
}
// NOTE(review): not referenced anywhere in the visible part of this file —
// presumably consumed elsewhere in the decoder; TODO confirm usage.
constexpr int INITIAL_BLOCKS = 4;
#if defined __cplusplus
extern "C"
{
#endif
// Decode a stream
//
// result - Resulting data from decode (must be freeed after use)
// buffer - Incoming bitstream buffer
// size_bytes - Size of the bitstream buffer (in bytes)
void ml_decode_ethosu_stream(ml_decode_result_t *result, const uint8_t *buffer, int size_bytes)
{
    assert(result && buffer && size_bytes);
    // Start from a clean result so the caller never sees stale pointers.
    result->decoded_data = nullptr;
    result->section_sizes = nullptr;
    bitbuf_t bitstream(buffer, size_bytes);
    raw_buffer_t<int16_t> decoded(4096, MLW_ENCODE_ALLOC_STREAM0, nullptr);
    palette_t pal;
    ml_decode_internal(decoded, bitstream, pal, size_bytes, false, ETHOSU_SLICELEN_BITS);
    // Hand ownership of the decoded buffer to the caller (freed via mld_free).
    result->decoded_length = decoded.used();
    result->decoded_data = decoded.detach();
}
// Release the buffers owned by a decode result.
// Safe to call with a null result; pointers are cleared after freeing so a
// second call is harmless.
ML_ENCODER_DLL_EXPORT void mld_free(ml_decode_result_t *result)
{
    if ( result )
    {
        if ( result->decoded_data )
        {
            // Previously this used realloc(ptr, 0) as a "free", which is
            // implementation-defined (deprecated in C23) and may return a
            // non-null pointer — leaking the block while leaving a dangling
            // non-null pointer in the result. Mirror mle_free(): free and clear.
            free( result->decoded_data );
            result->decoded_data = nullptr;
        }
        if ( result->section_sizes )
        {
            free( result->section_sizes );
            result->section_sizes = nullptr;
        }
    }
}
#if defined __cplusplus
} // extern "C"
#endif

View file

@ -0,0 +1,979 @@
//
// SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "../include/mlw_encode.h"
#include "ml_encoder_internal.hpp"
#include "ml_raw_buffer.hpp"
#include "ml_bit_buffer.hpp"
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <tuple>
#include <vector>
#include <limits>
constexpr int ZERO_FREQ_THRESHOLD = 5;
constexpr int MIN_ZERO_RUN_LENGTH = 2;
// Create palette from the given frequencies
// Freq index 0-511 correspond to weights -256..255
// partial_data - don't make decisions about data that will be encoded
// that wasn't included in the frequency analysis.
// Build the (up to 32 entry) palette and related parameters (direct_offset,
// palsize, palbits, only_palette/only_zeros flags) from the frequency table
// in p->freq. See the comment above for the freq index <-> weight mapping.
static void create_palette(palette_t *p, bool partial_data, bool disable_lut)
{
    uint64_t freq64[512] = {0};
    int i, all_cnt, all_max_val;
    // Pair the frequency with the value so that
    // the array can be sorted on frequency while keeping
    // track of the corresponding palette value
    all_cnt = 0;
    all_max_val = 0;
    // Weight -256 is deliberately excluded: its folded code (513) would not
    // fit the 9-bit encoded-weight range asserted elsewhere.
    for ( i = -255; i < 256; i++ )
    {
        // In zero-run (alternating) mode zeros are coded as run lengths,
        // so zero must not occupy a palette slot.
        if ( i == 0 && p->use_zero_runs ) continue;
        int sign = i < 0;
        int mag = abs(i);
        // Fold the sign into bit 0: value -> (|value| << 1) | sign
        int palval = (mag << 1) | sign;
        // Store palette value in 16 LSB bits, which will not affect the sorting
        freq64[palval] = ((static_cast<uint64_t>(p->freq[i + 256])) << 16) | palval;
        all_cnt += p->freq[i + 256];
        if ( p->freq[i + 256] > 0 )
        {
            all_max_val = std::max(all_max_val, palval);
        }
    }
    // Cannot use direct offset with partial data.
    p->only_zeros = !partial_data && (all_cnt == 0);
    p->direct_offset = 0;
    if ( !partial_data && (all_cnt != 0) )
    {
        // Find the first non-used weight value around zero (0, -1, +1, -2, +2 etc)
        for ( i = 0; i < 31; i++ )
        {
            if ( (freq64[i] >> 16) != 0 )
            {
                break;
            }
        }
        p->direct_offset = i;
    }
    // Sort in descending frequency order
    std::sort(std::begin(freq64), std::end(freq64), [](uint64_t a, uint64_t b) { return b < a; });
    // Check if all weights fit into the palette (and the palette is not empty)
    p->only_palette = !disable_lut && !partial_data && (freq64[0] >> 16) > 0 && (freq64[32] >> 16) == 0;
    int max_palette_size;
    if ( p->only_palette )
    {
        max_palette_size = 32;
    }
    else
    {
        // For direct-lut we must make sure that the encoded weight
        // index is not > 511. We do that by limiting the palette size
        // such that the greatest value can be reached after subtracting
        // the palette size.
        max_palette_size = std::min(32, 511 - all_max_val);
        if ( max_palette_size == 1 )
        {
            max_palette_size = 0; // because palette of size 1 is not supported
        }
    }
    // Setup the 32 entry palette
    int max_lut_val = 0, val, cnt, lut_cnt = 0;
    for ( i = 0; i < max_palette_size; i++ )
    {
        cnt = static_cast<int>(freq64[i] >> 16);
        val = freq64[i] & 0xffff;
        // If partial data, all palette entries must be filled (even if they're wrong)
        if ( cnt == 0 && !partial_data ) break;
        p->lut[i] = int16_t(val);
        max_lut_val = std::max(max_lut_val, val);
        lut_cnt += cnt;
    }
    // When all weights are the same nonzero value a palette size of 1 is possible; but not supported.
    // Make the palette 2 entries long and zero the second entry (it's never indexed).
    if ( i == 1 )
    {
        p->lut[i++] = 0;
    }
    // Heuristic for when to use the palette. If more than half of the
    // weights are in the palette then we use it. This ensures we don't
    // use palette for e.g. rectangular distributions.
    int palbits_val;
    if ( !disable_lut && (lut_cnt >= all_cnt / 2) )
    {
        p->palsize = i;
        palbits_val = max_lut_val;
    }
    else
    {
        // No palette
        p->palsize = 0;
        // If no palette, then palbits is used to specify the
        // number of bits required for uncompressed mode, i.e.
        // the number of bits for the greatest weight value
        palbits_val = all_max_val;
    }
    // the palette entry bit width
    // minimum 2-bits (because PALBITS is in range 2..9)
    int palbits = 2;
    while ( (1 << palbits) <= palbits_val )
    {
        palbits++;
    }
    assert(palbits <= 9);
    p->palbits = palbits;
}
// Populate p->inv_lut, mapping (weight + 256) -> encoded symbol.
// First every slot gets its direct-offset (non-palette) code, then the
// entries that actually appear in the palette are overwritten with their
// palette index.
static void create_inverse_palette(palette_t *p)
{
    int i;
    int val = p->palsize - p->direct_offset;
    for ( i = 0; i < 256; i++ )
    {
        // Positive weights get even-folded codes, negatives odd-folded ones.
        // At i == 0 both indices are 256, so the second store (val + 1) wins;
        // zero only keeps that code when zero runs are disabled.
        p->inv_lut[256 + i] = int16_t(val);
        p->inv_lut[256 - i] = int16_t(val + 1);
        val += 2;
    }
    // Slot 0 (weight -256) is never produced by the loop above; force it to 0.
    p->inv_lut[0] = 0;
    for ( i = 0; i < p->palsize; i++ )
    {
        // Unfold the palette entry back to a signed weight and point its
        // inverse slot at the palette index.
        val = p->lut[i];
        int sign = val & 1;
        int mag = val >> 1;
        int weight = sign ? -mag : mag;
        assert( ((weight + 256) >= 0) && ((weight + 256) < 512) );
        p->inv_lut[weight + 256] = int16_t(i);
    }
}
// If palette_size is 512, then palette is not used (in that case the palette is setup
// with the standard alternating unsigned to signed mapping)
// Analyse a weight stream: accumulate value frequencies and context usage
// statistics, decide whether zero-run (alternating) coding pays off, then
// (re)build the palette. Note: p->freq is NOT cleared here — frequencies
// accumulate across calls; callers reset it when required.
static void update_palette(mle_context_t *ctx, palette_t *p, const int16_t *weights, int weights_count, bool partial_data, bool disable_lut, bool disable_zruns)
{
    int(&freq)[512] = p->freq;
    int total_zeroes = 0;
    int zeroes_in_run = 0;
    int zeroes_in_all_runs = 0;
    // Calculate frequencies of the given weight stream
    for ( int i = 0; i < weights_count; i++ )
    {
        unsigned weight = weights[i] + 256;
        freq[weight]++;
        // Track distinct weight values in a 512-bit bitmap on the context.
        uint64_t &value = ctx->weights_used[weight / 64];
        uint64_t mask = (1ull << (weight % 64));
        if ((value & mask) == 0)
        {
            ctx->distinct_weights++;
            value |= mask;
        }
        if ( weights[i] == 0 )
        {
            total_zeroes++;
            zeroes_in_run++;
        }
        else
        {
            // Only runs of at least MIN_ZERO_RUN_LENGTH count towards the
            // zero-run heuristic below.
            if ( zeroes_in_run >= MIN_ZERO_RUN_LENGTH )
            {
                zeroes_in_all_runs += zeroes_in_run;
            }
            zeroes_in_run = 0;
        }
    }
    // Detect trailing zero runs in compression
    if ( zeroes_in_run >= MIN_ZERO_RUN_LENGTH )
    {
        zeroes_in_all_runs += zeroes_in_run;
    }
    int common_val = 0;
    int common_freq = 0;
    for ( int i = 0; i < 512; i++ )
    {
        // Most common non-zero frequency (because we already have that)
        // (common_val is tracked but not consumed below)
        if ( (i != 256) && freq[i] > common_freq )
        {
            common_val = i - 256;
            common_freq = freq[i];
        }
    }
    // Decide if zero-runs (alternating mode) should be used:
    // * zero runs must make up at least half of the zeroes
    // * zero should be the most common symbol
    // * zero should be sufficiently more common than the second most common symbol
    bool use_zero_runs = zeroes_in_all_runs >= (total_zeroes / 2);
    use_zero_runs &= total_zeroes > (ZERO_FREQ_THRESHOLD * common_freq);
    p->use_zero_runs = use_zero_runs && !disable_zruns;
    // Create the palette
    create_palette(p, partial_data, disable_lut);
}
#define NWCFG 13
#define NZCFG 4 // restrict search to ZDIV=0..3
#define MAX_ZWCFG ((NWCFG > NZCFG) ? NWCFG : NZCFG)
// (trunc<<4) | div, 0x20 means uncompressed
static constexpr char w_grc_params[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20};
static constexpr char z_grc_params[] = {0x00, 0x01, 0x02, 0x03};
// One run of symbols coded with a single GRC configuration.
struct grc_param_t
{
    int cfg = 0;      // index into w_grc_params / z_grc_params
    int end_pos = 0;  // exclusive end index of the run using this config
};
// Search for GRC (Golomb-Rice coding) parameter runs over the input symbols.
// Appends the chosen (cfg, end_pos) runs to `result` and returns the bit
// count accumulated for the finally-active configuration. In single_slice
// mode exactly one run covering the whole input is produced.
template <typename TYPE>
static int search_grc_params(const TYPE *inval_buf, int n_inval, int zrun_mode, int uncompressed_bits,
    std::vector<grc_param_t> &result, bool single_slice)
{
    assert(uncompressed_bits < 32);
    int n_cfg = zrun_mode ? NZCFG : NWCFG;
    const char *grc_params = zrun_mode ? z_grc_params : w_grc_params;
    // Greedy forward-only GRC search (with optimisation for avoiding
    // unusable GRC parameters).
    // cmd_cost approximates the bit overhead of starting a new slice.
    const int cmd_cost = 40;
    int bit_count[MAX_ZWCFG] = {0};
    int reset_pos[MAX_ZWCFG] = {0};
    bool coded[MAX_ZWCFG] = {false};
    bool any_uncodable[MAX_ZWCFG] = {false};
    int active_bitcount = 0;
    int active_cfg = -1;
    // Penalise uncompressed mode heavily when it is not available.
    int add_uncompressed_bits = (uncompressed_bits > 0) ? uncompressed_bits : 100;
    for ( int i = 0; i < n_inval; i++ )
    {
        int value = inval_buf[i];
        int best_bitcount = 0x7FFFFFFF;
        int best_cfg = 0;
        // Loop over GRC parameters, calculate bits to code value, and then update the search state
        for ( int j = 0; j < n_cfg; j++ )
        {
            int div = grc_params[j] & 15;
            int trunc = grc_params[j] >> 4;
            int q = value >> div;
            // Cost: unary quotient (possibly truncated) + remainder bits.
            int bits = trunc ? std::min(q + 1, 2) + div : q + 1 + div;
            // Weight symbols cannot be coded when the quotient exceeds the
            // format limits (31, or 2 in truncated mode); zruns always can.
            bool can_code = !(!zrun_mode && ((trunc && q > 2) || q > 31));
            if ( trunc == 2 )
            {
                bits = add_uncompressed_bits;
                can_code = true;
            }
            if ( can_code )
            {
                if ( !coded[j] )
                {
                    bit_count[j] = active_bitcount; // Reset non-coded to current best
                }
                bit_count[j] = bit_count[j] + bits;
                if ( bit_count[j] < best_bitcount )
                {
                    best_bitcount = bit_count[j];
                    best_cfg = j;
                }
            }
            else
            {
                reset_pos[j] = i + 1;
                bit_count[j] += cmd_cost; // Would have to change away if used
            }
            coded[j] = can_code;
            any_uncodable[j] |= !can_code;
        }
        // In single-slice mode we can check the bit counts afterwards; otherwise we record
        // slice start points by tracking the minimum of the accumulting bit counts for
        // different grc parameters.
        if ( !single_slice )
        {
            bool must_code = (active_cfg == -1) || !coded[active_cfg];
            if ( must_code || ((best_cfg != active_cfg) && (best_bitcount + cmd_cost) < active_bitcount) )
            {
                // Commit non-initial changes
                if ( active_cfg != -1 )
                {
                    // Range elision (was the other config better all along?)
                    if ( (bit_count[best_cfg] < bit_count[active_cfg]) && (reset_pos[best_cfg] <= reset_pos[active_cfg]) )
                    {
                        // If the current BEST config started before the ACTIVE config in time, then switch to using the BEST config.
                        active_cfg = best_cfg; // Duplicated on both paths for clarity
                    }
                    else
                    {
                        // Otherwise use the ACTIVE config for this slice before switching to the BEST config.
                        grc_param_t param;
                        param.cfg = active_cfg;
                        param.end_pos = i;
                        assert((active_cfg != 12) || (uncompressed_bits != 0));
                        result.push_back(param);
                    }
                }
                active_cfg = best_cfg;
            }
        }
        else if ( active_cfg == -1 )
        {
            active_cfg = best_cfg;
        }
        active_bitcount = bit_count[active_cfg];
    }
    // terminate the run
    if ( result.empty() || (result.back().cfg != active_cfg) )
    {
        // If single slice then select the best minimum-bits configuration
        if (single_slice)
        {
            assert( result.empty() );
            active_cfg = -1;
            int max_bit_count = std::numeric_limits<int>::max();
            for (int i=0; i < n_cfg; i++)
            {
                // Skip configs that could not code every symbol, and skip
                // uncompressed mode (index 12) when it is unavailable.
                if ( !any_uncodable[i] && (bit_count[i] <= max_bit_count) )
                {
                    if ( (active_cfg != 12) || (uncompressed_bits != 0) )
                    {
                        active_cfg = i;
                        max_bit_count = bit_count[i];
                    }
                }
            }
            assert(active_cfg != -1); // There isn't a usable grc parameter (fatal)
        }
        grc_param_t param;
        param.cfg = active_cfg;
        assert((active_cfg != 12) || (uncompressed_bits != 0));
        result.push_back(param);
    }
    result.back().end_pos = n_inval;
    return active_bitcount;
}
#if !ENABLE_DEBUG_BITSTREAM
// Release build putbits macro
#define bitbuf_put(bb_, name_, len_, data_) bb_->put(len_, data_)
#define bitbuf_align(bb_, name_, len_, data_) bb_->align(len_, data_)
#else
// Debug build putbits macro
// Debug-build variant of bitbuf_put: writes `len` bits of `data` and logs
// the bit offset, field name and value for bitstream tracing.
inline void bitbuf_put(bitbuf_t *bb, const char *name, int len, int data)
{
    assert(len <= 32);
    int pre_pos = bb->pos();
    bb->put(len, data);
    BITSTREAM_LOG("%6d %s:%d = %d\n", pre_pos, name, bb->pos()-pre_pos, data);
}
// Debug build putbits macro
// Debug-build variant of bitbuf_align: pads with `data` bits up to the next
// `len`-bit boundary and logs how many bits were written.
inline void bitbuf_align(bitbuf_t *bb, const char *name, int len, int data)
{
    assert(len <= 32);
    int pre_pos = bb->pos();
    bb->align(len, data);
    BITSTREAM_LOG("%6d %s:%d = %d\n", pre_pos, name, bb->pos()-pre_pos, data);
}
#endif
// Write one slice header (ZDIV, SLICELEN, WDIV, WTRUNC, NEWPAL and, when
// new_palette is set, the palette itself) and return the derived GRC
// parameters for use when encoding the slice body.
static slice_params_t encode_slice_header(mle_context_t *ctx, int slicelen, bool new_palette, int uncompressed_bits, int w_cfg, int z_cfg, bitbuf_t *bb)
{
    assert( (ctx->slicelen_bits == 15) || (ctx->slicelen_bits == 17) ); // Currently known formats
    assert( slicelen < (1 << ctx->slicelen_bits) );
    assert( w_cfg >= 0 && w_cfg < (sizeof(w_grc_params)/sizeof(w_grc_params[0])));
    // GRC parameters for this slice
    int w_grc_trunc = (w_grc_params[w_cfg] >> 4) == 1;
    int w_uncompressed = (w_grc_params[w_cfg] >> 4) == 2;
    // Callers can signal a truly empty slice with a negative z_cfg index
    assert( ((z_cfg < 0) && (slicelen == 0)) || (ctx->allow_empty_slices && slicelen >= 0) || (slicelen >= 1) );
    int z_grc_div = (z_cfg < 0) ? ZDIV_DISABLE : z_grc_params[z_cfg] & 15;
    // In uncompressed mode the "divisor" field carries the raw bit width.
    int w_grc_div = w_uncompressed ? uncompressed_bits : (w_grc_params[w_cfg] & 15);
    int zdiv = ctx->palette.use_zero_runs ? z_grc_div : ZDIV_DISABLE;
    int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED;
    if ( ENABLE_DEBUG_BITSTREAM )
    {
        BITSTREAM_LOG("slice: bitoffset %d slicelen %d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %d\n", bb->pos(),
            slicelen, zdiv, wdiv, w_grc_trunc, new_palette, ctx->palette.palbits, ctx->palette.palsize);
    }
    // Write slice header
    bitbuf_put(bb, "ZDIV", 3, zdiv);
    // Legacy format stores slicelen - 1; empty-slice-capable formats store it as-is.
    bitbuf_put(bb, "SLICELEN", ctx->slicelen_bits, ctx->allow_empty_slices ? slicelen : slicelen - 1);
    bitbuf_put(bb, "WDIV", 3, wdiv);
    bitbuf_put(bb, "WTRUNC", 1, w_grc_trunc);
    bitbuf_put(bb, "NEWPAL", 1, new_palette);
    if ( new_palette )
    {
        bitbuf_put(bb, "DIROFS", 5, ctx->palette.direct_offset);
        bitbuf_put(bb, "PALSIZE", 5, std::max(0, ctx->palette.palsize - 1));
        bitbuf_put(bb, "PALBITS", 3, ctx->palette.palbits - 2);
        for (int i = 0; i < ctx->palette.palsize; i++ )
        {
            bitbuf_put(bb, "PALETTE", ctx->palette.palbits, ctx->palette.lut[i]);
        }
    }
    slice_params_t header;
    header.w_grc_trunc = w_grc_trunc;
    header.w_uncompressed = w_uncompressed;
    header.z_grc_div = z_grc_div;
    header.w_grc_div = w_grc_div;
    return header;
}
// Encode one slice body: interleaved 12-symbol weight chunks and 12/8-symbol
// zero-run chunks (unary quotients first, remainders written one chunk
// later), preceded by the slice header. Mirrors the chunk structure decoded
// by ml_decode_internal.
static void encode_slice(mle_context_t *ctx, const int16_t *w_values, int weight_count, const int32_t *z_values, int zero_count, bool new_palette,
    int uncompressed_bits, int w_cfg, int z_cfg, bitbuf_t *bb)
{
    int w_cnt = weight_count;
    // The first slice after a new palette carries one extra zero-run
    // (leading zeros as well as trailing ones).
    int z_cnt = (z_values && zero_count) ? w_cnt + (new_palette ? 1 : 0) : 0;
    slice_params_t hdr = encode_slice_header(ctx, w_cnt, new_palette, uncompressed_bits, w_cfg, z_cfg, bb);
    assert(z_cfg >= 0 && "slice was signalled as truly empty");
    // Record slice parameters for HW testbench debugging
    if ( ctx->enable_slice_debug )
    {
        ctx->slice_debug.push_back( mle_slice_debug_t { hdr, ctx->palette } );
    }
    int w_grc_div = hdr.w_grc_div;
    bool w_uncompressed = hdr.w_uncompressed;
    int z_grc_div = hdr.z_grc_div;
    int w_grc_trunc = hdr.w_grc_trunc;
    int j;
    int z_unary_len = z_grc_div < 3 ? 12 : 8;
    int w_pos = 0, z_pos = 0;
    // w_q/z_q are the pending unary quotients; -1 means "fetch next symbol".
    int w_unary0 = 0, w_unary1 = 0, w_unary1_len = 0, w_q = -1, w_r = 0;
    int z_unary = 0, z_q = -1, z_r = 0;
    // Double-buffered remainders: values quantised in this chunk are
    // emitted in the next chunk (matching the decoder's pipeline).
    int w_remain_data[2][12] = {{0}};
    int *w_remain = w_remain_data[0];
    int *w_prev_remain = w_remain_data[1];
    int w_nsymbols = 0;
    int w_prev_enable = 0, w_prev_nsymbols = 0;
    int z_remain_data[2][12] = {{0}};
    int *z_remain = z_remain_data[0];
    int *z_prev_remain = z_remain_data[1];
    int z_nsymbols = 0;
    int z_prev_enable = 0, z_prev_nsymbols = 0;
    bool use_zero_runs = ctx->palette.use_zero_runs;
    do
    {
        // Flow control: keep the weight and zero-run streams roughly in step.
        int balance = use_zero_runs ? w_pos - z_pos : 0;
        int w_enable = balance < 8 && w_pos < w_cnt;
        int z_enable = balance >= 0 && use_zero_runs && z_pos < z_cnt;
        if ( w_enable )
        {
            // Encode chunk (weights)
            j = 0;
            w_nsymbols = 0;
            w_unary0 = 0;
            w_unary1 = 0;
            w_unary1_len = 0;
            int max_symbols = (w_uncompressed && w_grc_div > 5) ? 8 : 12;
            while ( j < max_symbols )
            {
                if ( w_q < 0 )
                {
                    if ( w_pos < w_cnt )
                    {
                        int value = w_values[w_pos];
                        assert( value >= 0 && value < 512 );
                        w_q = value >> w_grc_div;
                        w_r = value & ((1 << w_grc_div) - 1);
                        assert(w_q <= 31 && (!w_grc_trunc || w_q <= 2));
                    }
                    else
                    {
                        w_q = 0;
                        w_r = -1; // don't send remainder
                    }
                }
                while ( w_q >= 0 && j < max_symbols )
                {
                    // First-level unary bit, plus a second-level bit when q > 0.
                    w_unary0 |= w_q > 0 ? (1 << j) : 0;
                    if ( w_q > 0 )
                    {
                        w_unary1 |= w_q > 1 ? (1 << w_unary1_len) : 0;
                        w_unary1_len++;
                    }
                    j++;
                    w_q -= 2;
                    if ( w_grc_trunc ) w_q--;
                }
                if ( w_q < 0 && w_r >= 0 )
                {
                    w_remain[w_nsymbols] = w_r;
                    w_nsymbols++;
                    w_pos++;
                }
            }
        }
        if ( z_enable )
        {
            // Encode chunk (zrun)
            j = 0;
            z_nsymbols = 0;
            z_unary = 0;
            while ( j < z_unary_len )
            {
                if ( z_q < 0 )
                {
                    if ( z_pos < z_cnt )
                    {
                        int value = z_values[z_pos];
                        z_q = value >> z_grc_div;
                        z_r = value & ((1 << z_grc_div) - 1);
                        assert( z_q >= 0 ); // There are no negative length z-runs
                    }
                    else
                    {
                        z_q = 0;
                        z_r = -1;
                    }
                }
                while ( z_q >= 0 && j < z_unary_len )
                {
                    z_unary |= z_q > 0 ? (1 << j) : 0;
                    j++;
                    z_q--;
                }
                if ( z_q < 0 && z_r >= 0 )
                {
                    assert( z_nsymbols < 12 );
                    z_remain[z_nsymbols] = z_r;
                    z_nsymbols++;
                    z_pos++;
                }
            }
        }
        // Write chunk to bitstream
        if ( w_enable && !w_uncompressed )
        {
            bitbuf_put(bb, "WUNARY0", 12, w_unary0); // 12 bits
        }
        if ( z_enable )
        {
            bitbuf_put(bb, "ZUNARY", z_unary_len, z_unary); // 12 or 8 bits
        }
        if ( w_enable && !w_uncompressed && (w_unary1_len > 0) )
        {
            bitbuf_put(bb, "WUNARY1", w_unary1_len, w_unary1); // max 12 bits
        }
        // Remainders for the PREVIOUS chunk's symbols.
        if ( w_prev_enable )
        {
            for (int i = 0; i < w_prev_nsymbols; i++ )
            {
                bitbuf_put(bb, "WREMAIN", w_grc_div, w_prev_remain[i]);
            }
        }
        if ( z_prev_enable && (z_grc_div > 0) )
        {
            for (int i = 0; i < z_prev_nsymbols; i++ )
            {
                bitbuf_put(bb, "ZREMAIN", z_grc_div, z_prev_remain[i]);
            }
        }
        w_prev_enable = w_enable;
        w_prev_nsymbols = w_nsymbols;
        std::swap(w_prev_remain, w_remain);
        z_prev_enable = z_enable;
        z_prev_nsymbols = z_nsymbols;
        std::swap(z_prev_remain, z_remain);
    } while ( w_prev_enable || z_prev_enable );
}
// Encode one section of weights as one or more slices. A non-null palette
// pointer marks the first section after a (re)built palette. Returns the
// number of zeros absorbed into zero-runs, or -1 if single-slice mode was
// requested but more than one slice was needed.
int ml_encode_section(mle_context_t *ctx, const int16_t *inbuf, int size, palette_t *p, bitbuf_t *bitbuf)
{
    bool new_palette = (p != nullptr);
    // Reuse previous if not specified
    if ( p == nullptr )
    {
        p = &ctx->palette;
    }
    // Uncompressed mode can only be used if either all weights
    // are in the palette OR if the palette is not used.
    int uncompressed_bits = 0;
    if ( p->only_palette )
    {
        // Uncompressed bits derived from palette size
        while ( (1 << uncompressed_bits) < p->palsize )
        {
            uncompressed_bits++;
        }
    }
    else if ( p->palsize == 0 )
    {
        // Uncompressed bits is palbits (which is the bitdepth of the greatest weight)
        uncompressed_bits = p->palbits;
    }
    // If there are no weights at all, emit an empty slice header, then exit.
    if ( size == 0 )
    {
        // Signal a truly empty slice using -ve zgrc to ensure ZDIV_DISABLE is written to the stream.
        if ( ctx->allow_empty_slices )
        {
            encode_slice_header(ctx, 0, new_palette, uncompressed_bits, 0, -1, bitbuf);
        }
        return 0;
    }
    std::vector<int16_t> weight_values;
    weight_values.reserve(size);
    // If zruns was enabled, expect total to be < weight_values/2
    std::vector<int32_t> zrun_values;
    if ( p->use_zero_runs )
    {
        zrun_values.reserve( size / 4);
    }
    // Get weights (or weight indicies) AND zero-runs from the input weight stream.
    int i = 0;
    bool allow_empty_slices = !p->only_zeros || ctx->allow_empty_slices;
    int total_zcnt = 0;
    const int max_slice_len = (1 << ctx->slicelen_bits) - 1;
    const int max_zero_run_length = ctx->allow_empty_slices ? max_slice_len - 1 : INT32_MAX;
    while ( 1 )
    {
        if ( p->use_zero_runs )
        {
            int zcnt = 0;
            // Count zero run
            // Special case: if all weights in the section are zero, we must
            // still ensure we have one coded weight so the the slice length
            // doesn't become 0. Therefore we skip the first zero run and code
            // the zero explicitly as a weight value instead
            if ( allow_empty_slices || i > 0 )
            {
                while ( i < size && inbuf[i] == 0 && zcnt < max_zero_run_length )
                {
                    zcnt++;
                    i++;
                }
            }
            total_zcnt += zcnt;
            zrun_values.push_back(zcnt);
        }
        if ( i == size ) break;
        // Map the signed weight to its encoded symbol (palette index or
        // direct-offset code) via the inverse LUT.
        int16_t value = p->inv_lut[inbuf[i] + 256];
        weight_values.push_back(value);
        i++;
    }
    // Search for good GRC parameters for the weight stream
    std::vector<grc_param_t> w_slice_cfg;
    int n_weights = int(weight_values.size());
    if ( n_weights )
    {
        // Use a fixed grc config index if provided (partial-data mode sets this)
        if ( ctx->fixed_wgrc >= 0 )
        {
            w_slice_cfg.push_back(grc_param_t{ ctx->fixed_wgrc, n_weights });
        }
        else
        {
            search_grc_params(weight_values.data(), n_weights, 0, uncompressed_bits, w_slice_cfg, ctx->single_slice_sections);
        }
    }
    int n_w_slice = int(w_slice_cfg.size());
    // Search for good GRC parameters for the zrun stream
    std::vector<grc_param_t> z_slice_cfg;
    if ( p->use_zero_runs )
    {
        // Use a fixed grc config index if provided (partial-data mode sets this)
        if ( ctx->fixed_zgrc >= 0 )
        {
            z_slice_cfg.push_back(grc_param_t{ ctx->fixed_zgrc, n_weights + 1 });
        }
        else
        {
            search_grc_params(zrun_values.data(), n_weights + 1, 1, 0, z_slice_cfg, ctx->single_slice_sections);
        }
    }
    int n_z_slice = int(z_slice_cfg.size());
    int loops = 0;
    // Encode bitstream slice
    int pos = 0, i_w_slice = 0, i_z_slice = 0;
    // Force at least one slice when only zero-runs exist (no coded weights).
    bool only_zero_runs_pass = !zrun_values.empty();
    while ( (pos < n_weights) || new_palette || only_zero_runs_pass )
    {
        int w_len = 0;
        int z_len = 0;
        // Bound each slice by both the GRC run boundary and the format's
        // maximum slice length.
        if ( i_w_slice < n_w_slice )
        {
            w_len = w_slice_cfg[i_w_slice].end_pos - pos;
            w_len = std::min(w_len, max_slice_len);
        }
        if ( i_z_slice < n_z_slice )
        {
            z_len = z_slice_cfg[i_z_slice].end_pos - pos;
            z_len = std::min(z_len, max_slice_len);
        }
        // The first slice (when new_palette is 1) encodes zero runs both at the
        // beginning and end (i.e. number of zero runs are len+1).
        // The following slices only encode zero runs at the end (there cannot be
        // any zeros in the beginning since they are encoded by the previous slice)
        const int32_t *zrun_buf = p->use_zero_runs ? zrun_values.data() + pos + !(new_palette || ctx->allow_empty_slices) : nullptr;
        const int16_t *w_buf = w_len ? weight_values.data() + pos : nullptr;
        int w_cfg = w_len ? w_slice_cfg[i_w_slice].cfg : 0;
        int z_cfg = p->use_zero_runs ? z_slice_cfg[i_z_slice].cfg : 0;
        encode_slice(ctx, w_buf, w_len, zrun_buf, z_len, new_palette, uncompressed_bits, w_cfg, z_cfg, bitbuf);
        new_palette = 0;
        if ( z_len <= 0 && w_len > 0 )
            pos += w_len;
        else if ( w_len <= 0 && z_len > 0 )
            pos += z_len;
        else
            pos += std::min(z_len, w_len);
        if ( i_w_slice < n_w_slice && w_slice_cfg[i_w_slice].end_pos <= pos )
        {
            i_w_slice++;
        }
        if ( i_z_slice < n_z_slice && z_slice_cfg[i_z_slice].end_pos <= pos )
        {
            i_z_slice++;
        }
        loops++;
        only_zero_runs_pass = false;
    }
    // Single-slice sections can only generate one slice (a single loop)
    assert( !ctx->single_slice_sections || (ctx->single_slice_sections && loops == 1) );
    if ( ctx->single_slice_sections && (loops != 1) )
    {
        return -1;
    }
    return total_zcnt;
}
palette_t *ml_encode_palette(mle_context_t *ctx, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags)
{
    // Rebuild the palette only when none is valid yet, or when the caller
    // explicitly requests a fresh one; a nullptr return means "reuse previous".
    bool rebuild = !ctx->palette_valid || (mlw_encode_flags & MLW_ENCODE_INSERT_PALETTE);
    if ( !rebuild )
    {
        return nullptr;
    }
    if ( mlw_encode_flags & MLW_ENCODE_RESET_PALETTE )
    {
        memset( ctx->palette.freq, 0, sizeof(ctx->palette.freq) );
    }
    const bool partial_data = (mlw_encode_flags & MLW_ENCODE_PARTIAL_DATA) != 0;
    const bool disable_lut = (mlw_encode_flags & MLW_ENCODE_NO_PALETTE_LUT) != 0;
    const bool disable_zruns = (mlw_encode_flags & MLW_ENCODE_NO_ZERO_RUNS) != 0;
    assert( analyse_count >= encode_count && "Must analyse at least as much as is encoded");
    update_palette(ctx, &ctx->palette, weights, analyse_count, partial_data, disable_lut, disable_zruns);
    ctx->palette_valid = true;
    // Partial-data mode pins the GRC configs unless DPIC forces its own params.
    if ( !(mlw_encode_flags & MLW_ENCODE_DPIC_FORCE_PARAMS) )
    {
        ctx->fixed_wgrc = partial_data ? 5 : -1;
        ctx->fixed_zgrc = partial_data ? 3 : -1;
    }
    create_inverse_palette(&ctx->palette);
    return &ctx->palette;
}
void ml_encode_eos(mle_context_t *ctx, bitbuf_t &bits, unsigned mlw_encode_flags)
{
    // Terminate the stream: optional end-of-stream marker, byte alignment,
    // then (unless suppressed) padding out to a 128-bit boundary.
    bitbuf_t *stream = &bits;
    if ( ctx->eos_required )
    {
        bitbuf_put(stream, "ZDIV", 3, ZDIV_EOS);
    }
    bitbuf_align(stream, "BYTEALIGN", 8, 0xff);
    const bool pad_to_128 = (mlw_encode_flags & MLW_ENCODE_NO_PADDING) == 0;
    if ( pad_to_128 )
    {
        stream->align( 128, 0xFF );
    }
    stream->flush();
}
int ml_encode_internal(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int encode_count, int analyse_count, unsigned mlw_encode_flags)
{
    // Refresh the palette if required, then emit a single encoded section.
    palette_t *fresh_palette = ml_encode_palette(ctx, weights, encode_count, analyse_count, mlw_encode_flags);
    int zeroes = ml_encode_section(ctx, weights, encode_count, fresh_palette, &bits);
    if ( zeroes >= 0 )
    {
        ctx->zero_count += zeroes;
        return 0;
    }
    return -1; // section could not honour the single-slice constraint
}
extern "C"
{
ML_ENCODER_DLL_EXPORT mle_context_t *mle_create_context(int32_t syntax)
{
    // Allocate and configure an encoder context for the requested syntax.
    mle_context_t *ctx = new mle_context_t;
    ctx->zero_count = 0;
    ctx->syntax = syntax;
    ctx->realloc_func = nullptr;
    switch ( syntax )
    {
        case MLW_ENCODE_SYNTAX_ETHOSU:
            ctx->slicelen_bits = ETHOSU_SLICELEN_BITS;
            ctx->allow_empty_slices = false;
            ctx->single_slice_sections = false;
            ctx->eos_required = true;
            break;
        case MLW_ENCODE_SYNTAX_ETHOSU_FWD:
            // Fast-weight syntax does not use slices.
            ctx->slicelen_bits = 0;
            break;
        default:
            assert(false && "bad syntax");
            delete ctx;
            return nullptr;
    }
    return ctx;
}
ML_ENCODER_DLL_EXPORT int mle_context_query_zeroes(mle_context_t *ctx)
{
    // Total zeros absorbed into zero-run coding so far on this context.
    assert( ctx != nullptr );
    return ctx->zero_count;
}
ML_ENCODER_DLL_EXPORT int mle_context_query_weights_used(mle_context_t *ctx, uint64_t weights_used[512 / 64])
{
    // Copy out the 512-bit usage bitmap and report the distinct-value count.
    assert( ctx != nullptr );
    for ( int word = 0; word < 512 / 64; word++ )
    {
        weights_used[word] = ctx->weights_used[word];
    }
    return ctx->distinct_weights;
}
ML_ENCODER_DLL_EXPORT void mle_context_set_allocator(mle_context_t *ctx, void* (*realloc_func)(void*, size_t, int))
{
    // Install a custom reallocation callback for the context's output buffers.
    assert( ctx != nullptr );
    ctx->realloc_func = realloc_func;
}
ML_ENCODER_DLL_EXPORT void mle_destroy_context(mle_context_t *ctx)
{
    // Dispose of a context created by mle_create_context.
    assert( ctx != nullptr );
    delete ctx;
}
ML_ENCODER_DLL_EXPORT int mle_encode(mle_context_t *ctx, ml_encode_result_t *result, const int16_t *inbuf, int inbuf_size, unsigned mlw_encode_flags)
{
    // Encode a whole weight stream and populate `result` on success.
    // Returns the number of bytes written, or -1 on failure.
    assert( ctx && result );
    raw_buffer_t<uint8_t> stream_buf(4096, MLW_ENCODE_ALLOC_STREAM0, ctx->realloc_func);
    bitbuf_t bitstream(stream_buf, 4096, mlw_encode_flags & MLW_ENCODE_NO_BITSTREAM);
    int bytes_written;
    if ( ctx->syntax != MLW_ENCODE_SYNTAX_ETHOSU_FWD )
    {
        // Regular Ethos-U syntax: palette + GRC section, then end-of-stream.
        const int start = bitstream.byte_pos();
        if ( ml_encode_internal(ctx, bitstream, inbuf, inbuf_size, inbuf_size, mlw_encode_flags) < 0 )
        {
            return -1;
        }
        ml_encode_eos(ctx, bitstream, mlw_encode_flags);
        bytes_written = bitstream.byte_pos() - start;
    }
    else
    {
        // Fast-weight syntax.
        bytes_written = ml_encode_fwd(ctx, bitstream, inbuf, inbuf_size, mlw_encode_flags);
    }
    if ( bytes_written >= 0 )
    {
        // Transfer buffer ownership to the caller (freed via mle_free).
        result->encoded_data = stream_buf.detach();
        result->source_length = inbuf_size;
        result->encoded_length = bytes_written;
        result->section_info = nullptr;
        result->section_count = 0;
    }
    return bytes_written;
}
ML_ENCODER_DLL_EXPORT void mle_free(ml_encode_result_t *result)
{
    // Release the buffers owned by an encode result; null-safe and
    // idempotent (pointers are cleared after freeing).
    if ( result == nullptr )
    {
        return;
    }
    if ( result->encoded_data )
    {
        free( result->encoded_data );
        result->encoded_data = nullptr;
    }
    if ( result->section_info )
    {
        free( result->section_info );
        result->section_info = nullptr;
    }
}
}

View file

@ -0,0 +1,197 @@
//
// SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the License); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an AS IS BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "../include/mlw_encode.h"
#include "ml_encoder_internal.hpp"
#include "ml_raw_buffer.hpp"
#include "ml_bit_buffer.hpp"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <tuple>
#include <vector>
#include <limits>
constexpr static int LUT_MAX = 16;
constexpr static int INV_MAX = 512;
// Header/working state for the fast-weight (FWD) encoding.
struct fwd_header_t
{
    bool raw_mode_flag = false;   // set when the LUT overflowed: weights are stored as raw bytes
    bool small_lut_flag = false;  // set when <= 4 distinct values: 2-bit indices are used
    int8_t zero_adjust = 0;       // offset applied in raw mode to fit values into 8 bits
    int16_t lut[LUT_MAX] = {};    // forward LUT: index -> weight value
    int8_t inv_index[INV_MAX] = {};  // inverse LUT: folded weight -> LUT index (-1 = unused)
    fwd_header_t()
    {
        // Mark every inverse entry as unused.
        std::fill_n(inv_index, INV_MAX, -1);
    }
};
static inline int32_t fold(int32_t value)
{
    // Fold a signed value into a non-negative code: magnitude shifted up
    // one bit with the sign flag kept in bit 0 (negative => 1).
    int32_t sign = (value < 0) ? 1 : 0;
    return (abs(value) << 1) | sign;
}
void fwd_emit_header(bitbuf_t &bits, const fwd_header_t &hdr)
{
    // Header layout: two mode flags, 102 reserved zero bits, the 8-bit
    // zero adjustment, then LUT_MAX folded 9-bit LUT entries.
    const int raw_flag = hdr.raw_mode_flag ? 1 : 0;
    const int lut2_flag = hdr.small_lut_flag ? 1 : 0;
    bits.put( 1, raw_flag );
    bits.put( 1, lut2_flag );
    bits.fill( 102, 0 );
    bits.put_masked( 8, hdr.zero_adjust );
    for ( const int16_t entry : hdr.lut )
    {
        bits.put( 9, fold(entry) );
    }
}
// Analyse the weight stream for FWD encoding: track the value range, build
// the (max 16 entry) LUT, speculatively emit 2-bit indices while the LUT is
// small, and select raw mode (with zero adjustment) when the LUT overflows.
// Returns false when the stream cannot be encoded (range too wide).
bool fwd_analyse(fwd_header_t &hdr, const int16_t *weights, int count, bitbuf_t &bits)
{
    // Sentinels outside any representable weight value.
    int range_min = 1000;
    int range_max = -1000;
    int8_t *inv = hdr.inv_index;
    int16_t *lut = hdr.lut;
    int lut_used = 0;
    bool use_lut = true;
    // Must check all the zero-point-correct values for full range
    for (int i = 0; i < count; i++)
    {
        int value = weights[i];
        range_min = std::min( range_min, value );
        range_max = std::max( range_max, value );
        // Update the LUT only while it's still viable (predicts well).
        if ( use_lut )
        {
            // Map the signed value to the LUT via +ve indexed table
            int idx = fold(value);
            assert( idx < INV_MAX );
            // Check if value has already been indexed before adding a
            // new lut entry.
            if ( inv[idx] < 0 )
            {
                if ( lut_used < LUT_MAX )
                {
                    inv[idx] = lut_used;
                    lut[lut_used] = value;
                    lut_used++;
                }
                else
                {
                    use_lut = false; // LUT was full and is now unusable
                }
            }
            // While lut2 is valid, encode the entries. When we're
            // done the bitstream will be ready.
            if (lut_used <= 4)
            {
                bits.put(2, inv[idx]);
            }
        }
    }
    hdr.raw_mode_flag = !use_lut;
    hdr.small_lut_flag = (lut_used <= 4);
    hdr.zero_adjust = 0;
    // If raw mode, calculate the zero point
    if ( hdr.raw_mode_flag )
    {
        int full_range = (range_max - range_min);
        if (full_range >= 256)
        {
            return false; // Can't encode this stream
        }
        else if ( range_min < -128 )
        {
            hdr.zero_adjust = -128 - range_min; // Raw values need offsetting +ve by this amount
        }
        else if ( range_max > 127 )
        {
            hdr.zero_adjust = 127 - range_max; // Raw values need offsetting -ve by this amount
        }
    }
    return (range_min >= -256) && (range_max < 256);
}
// Encode zero-corrected weight values in the optimal fast-weight format.
// Picks one of three payload encodings based on fwd_analyse():
//   LUT2 (2-bit indices, already written during analysis),
//   LUT4 (4-bit indices), or RAW (8-bit offset bytes).
// A 256-bit header is reserved up front and back-patched via a substream.
// 'ctx' is unused in this function. Returns the total bytes written to the
// stream (bits.pos()/8 after 256-bit alignment), or -1 on encoding failure.
int ml_encode_fwd(mle_context_t *ctx, bitbuf_t &bits, const int16_t *weights, int count, unsigned mlw_encode_flags)
{
    fwd_header_t header;
    int pos = bits.pos();
    bits.fill(256, 0); // Reserve space for header
    // Encode lut2 weights directly to the main stream while analysing
    if ( !fwd_analyse(header, weights, count, bits) )
    {
        return -1; // Encoding error
    }
    // Check for forced no palette
    // NOTE(review): when the flag forces raw mode here, zero_adjust keeps
    // whatever fwd_analyse computed (0 unless analyse itself chose raw),
    // so values outside [-128, 127] will simply wrap mod 256 below —
    // presumably the decoder expects that; confirm.
    if ( mlw_encode_flags & MLW_ENCODE_NO_PALETTE_LUT )
    {
        header.raw_mode_flag = 1;
        header.small_lut_flag = 0;
    }
    // Use a substream of the main stream for the header
    bitbuf_t hdr_bits(bits, pos);
    fwd_emit_header(hdr_bits, header);
    bits.sync(hdr_bits);
    // LUT2
    if ( header.small_lut_flag )
    {
        // Payload was already emitted during fwd_analyse(); nothing to do.
        assert( !header.raw_mode_flag );
    }
    // RAW
    else if ( header.raw_mode_flag )
    {
        // Discard any 2-bit indices written during analysis and restart
        // the payload directly after the 256-bit header.
        bits.reposition(pos + 256);
        for (int i=0; i < count; i++)
        {
            int value = (weights[i] + header.zero_adjust) & 0xFF;
            bits.put(8, value);
        }
    }
    // LUT4
    else
    {
        // Same restart as RAW; re-emit each weight as a 4-bit palette index.
        bits.reposition(pos + 256);
        for (int i=0; i < count; i++)
        {
            int idx = fold(weights[i]);
            bits.put(4, header.inv_index[idx]);
        }
    }
    bits.align(256, 0); // pad stream to a 256-bit boundary
    bits.flush();
    int written = bits.pos() / 8;
    return written;
}