www.gusucode.com > bigdata 工具箱 matlab源码程序 > bigdata/@tall/splitapply.m

    function varargout = splitapply(fun,varargin)
%SPLITAPPLY Split data into groups and apply function
%   Supported syntaxes for tall arrays:
%   Y = SPLITAPPLY(FUN,X,G)
%   Y = SPLITAPPLY(FUN,X1,X2,...,G)
%   [Y1,Y2,...] = SPLITAPPLY(FUN,...)
%
%   See also FINDGROUPS

%   Copyright 2016 The MathWorks, Inc.

narginchk(3,inf);

% Upper-cased function name, used in every error raised from here.
fcnName = upper(mfilename);

% FUN must be an in-memory function handle.
checkNotTall(fcnName, 0, fun);
if ~isa(fun, 'function_handle')
    error(message('MATLAB:splitapply:InvalidFunction'));
end

% Every remaining argument (the data inputs and the group numbers) must be
% a tall array.
inputIsTall = cellfun(@istall, varargin);
if ~all(inputIsTall)
    error(message('MATLAB:bigdata:array:AllArgsTall', fcnName));
end

% The last argument holds the group numbers; everything before it is data.
dataArgs = varargin(1:end - 1);

% Expand a scalar gnum to the full length here, because bykey and filter do
% not support singleton expansion. The same pass converts gnum to a
% categorical so that grouped operations can read the known groups from the
% categories instead of computing unique on each chunk.
gnum = tall.validateType(varargin{end}, fcnName, {'numeric'}, nargin);
gnum = slicefun(@iValidateInputsAndExtractGnumAsCategorical, dataArgs{:}, gnum);
gnum.Adaptor = setSmallSizes(matlab.bigdata.internal.adaptors.CategoricalAdaptor(), 1);

% Support the syntax splitapply(fun, table(..), gnum) by replacing each
% table input with its individual variables.
dataArgs = iFlattenTableInputs(dataArgs);

[varargout{1:max(1, nargout)}] = iSplitApply(fun, gnum, dataArgs{:});

% The main implementation after input argument parsing.
%
% Inputs:
%   fun      - function handle supplied by the caller of splitapply.
%   gnum     - tall array of group numbers (already converted to
%              categorical by the caller).
%   varargin - the tall data inputs, with table inputs already flattened.
%
% Outputs:
%   One output per requested output of FUN, each passed through
%   iParseOutput.
function varargout = iSplitApply(fun, gnum, varargin)
% Held only for its lifetime (hence the NASGU suppression).
% NOTE(review): presumably marks the frames below as internal for error
% reporting - confirm against InternalStackFrame.
markerFrame = matlab.bigdata.internal.InternalStackFrame(); %#ok<NASGU>

% The session is constructed from FUN and is closed when sessionCleanup is
% destroyed, i.e. when this function exits, whether normally or by error.
session = matlab.bigdata.internal.lazyeval.GroupedPartitionedArraySession(fun);
sessionCleanup = onCleanup(@session.close);

try
    % All data inputs and gnum go through the same tall-size validation.
    [varargin{:}, gnum] = validateSameTallSize(varargin{:}, gnum);
    % The distinct group keys are needed later for outputs of FUN that are
    % not grouped tall arrays (see iParseOutput).
    uniqueGnum = unique(gnum);

    % Replace every data input with its grouped equivalent, then invoke the
    % user function once on the grouped inputs.
    [varargin{:}] = iWrapTallAsGroupedTall(gnum, session, varargin{:});
    [varargout{1:nargout}] = fun(varargin{:});
    % Convert each raw output of FUN into the final form.
    for ii = 1:numel(varargout)
        varargout{ii} = iParseOutput(varargout{ii}, uniqueGnum, session);
    end
catch err
    % Translate the error, then rethrow it through the big-data exception
    % helper rather than as a plain error from this frame.
    err = matlab.bigdata.internal.lazyeval.parseExecutionError(err);
    matlab.bigdata.BigDataException.hThrowAsCallerWithSubmissionStack(err);
end

% Convert one output of the user function into its final form.
%
% Cell outputs are first normalized by iParseCellOutput. After that, a tall
% output backed by a GroupedPartitionedArray is unwrapped; anything else is
% handed to iCreateOutputFromScalar together with the unique group numbers.
function data = iParseOutput(data, uniqueGnum, session)
if iscell(data)
    data = iParseCellOutput(data);
end

groupedClass = 'matlab.bigdata.internal.lazyeval.GroupedPartitionedArray';
if istall(data) && isa(hGetValueImpl(data), groupedClass)
    data = iUnWrapGroupedTall(data);
    return;
end

% Non-grouped output: build the result on the client, passing the function
% text along for use in error messages.
fcnText = func2str(session.FunctionHandle);
funStr = matlab.bigdata.internal.broadcast(fcnText);
data = clientfun(@iCreateOutputFromScalar, uniqueGnum, data, funStr);

% Normalize a (possibly nested) cell array returned by the user function.
%
% Non-cell input is returned untouched. Each element of a cell input is
% processed recursively; if any element is then tall, the cell is rebuilt
% through clientfun so that the result itself is tall.
function data = iParseCellOutput(data)
if iscell(data)
    data = cellfun(@iParseCellOutput, data, 'UniformOutput', false);
    elementIsTall = cellfun(@istall, data);
    if any(elementIsTall)
        % Lift the tall attribute up one level so that the cell array is
        % per group, keeping the original cell shape.
        originalShape = size(data);
        data = clientfun(@(sz, varargin) reshape(varargin, sz), originalShape, data{:});
    end
end

% Wrap each tall input as a grouped tall array keyed by KEYS.
%
% The underlying implementations of the keys and of every input are passed
% to GroupedPartitionedArray.create. Each result is wrapped in a tall array
% whose adaptor is the matching input's adaptor with the tall size reset.
function varargout = iWrapTallAsGroupedTall(keys, session, varargin)
import matlab.bigdata.internal.lazyeval.GroupedPartitionedArray;
keyImpl = hGetValueImpl(keys);
valueImpls = cellfun(@hGetValueImpl, varargin, 'UniformOutput', false);
[varargout{1:nargout}] = GroupedPartitionedArray.create(keyImpl, session, valueImpls{:});
for outIdx = 1:numel(varargout)
    wrapped = tall(varargout{outIdx});
    wrapped.Adaptor = resetTallSize(varargin{outIdx}.Adaptor);
    varargout{outIdx} = wrapped;
end

% Turn a grouped tall array back into an ordinary tall array.
%
% The keys and values of the underlying grouped implementation are passed
% to iCreateOutput on the client. The result takes the input's adaptor with
% the tall size reset.
function out = iUnWrapGroupedTall(in)
groupedImpl = hGetValueImpl(in);
fcnText = func2str(groupedImpl.Session.FunctionHandle);
funStr = matlab.bigdata.internal.broadcast(fcnText);
combined = clientfun(@iCreateOutput, groupedImpl.Keys, groupedImpl.Values, funStr);
out = tall(combined);
out.Adaptor = resetTallSize(in.Adaptor);

% Validate and extract the gnum input. This also ensures all non-gnum inputs
% are not singleton in the tall dimension.
%
% Inputs:
%   varargin - the data inputs followed by gnum as the last argument.
%
% Output:
%   gnum - categorical column with one entry per row of the first input.
%          NaN group numbers are mapped to the hidden group 0.
%
% Errors:
%   MATLAB:bigdata:array:IncompatibleTallStrictSize   - a data input or a
%       non-scalar gnum has a different height than the first input.
%   MATLAB:bigdata:array:SplitApplyUnsupportedGroupNums - gnum is not a
%       numeric column of positive integers and/or NaNs.
function gnum = iValidateInputsAndExtractGnumAsCategorical(varargin)
% Height (size in the first dimension) of the first input. Every other data
% input must have exactly this height.
sz = size(varargin{1}, 1);
for ii = 2 : numel(varargin) - 1
    % Compare heights only. Comparing the whole size vector against the
    % scalar height would raise the error only when every dimension differs
    % from sz, so a mismatched input could slip through.
    if size(varargin{ii}, 1) ~= sz
        error(message('MATLAB:bigdata:array:IncompatibleTallStrictSize'));
    end
end

gnum = varargin{end};
% Each entry must be NaN or a positive integer. The parentheses spell out
% the precedence of & over |.
gnumIsValid = isnumeric(gnum) && iscolumn(gnum) ...
    && all(isnan(gnum) | (mod(gnum, 1) == 0 & gnum > 0));
if ~gnumIsValid
    % Depending on the partitioning and order of execution, gnum could be invalid
    % either by being a non-scalar row, or a matrix. Here we throw a single
    % error that covers both cases to ensure a consistent error is returned.
    error(message('MATLAB:bigdata:array:SplitApplyUnsupportedGroupNums'));
end

% We set NaN values to hidden group gnum 0. This is so at the end we know
% whether NaN values existed in the input. This happens after validation so
% that an invalid gnum always produces the error above.
gnum(isnan(gnum)) = 0;

if size(gnum, 1) == 1
    % Scalar gnum: expand it to one entry per row of the data.
    gnum = gnum .* ones(sz, 1);
elseif size(gnum, 1) ~= sz
    error(message('MATLAB:bigdata:array:IncompatibleTallStrictSize'));
end

gnum = categorical(gnum);

% Expand every table input into its individual variables.
%
% Non-table inputs are kept as they are. The result is a single cell row
% holding, in order, each non-table input and each variable of each table.
function flattenedInputs = iFlattenTableInputs(inputs)
perInput = cell(size(inputs));
for inputIndex = 1:numel(inputs)
    thisInput = inputs{inputIndex};
    if ~strcmp(tall.getClass(thisInput), 'table')
        % Keep the input as a one-element cell so it concatenates below.
        perInput{inputIndex} = {thisInput};
        continue;
    end

    % One entry per table variable, extracted by name.
    names = thisInput.Adaptor.VariableNames;
    columns = cell(1, numel(names));
    for varIndex = 1:numel(names)
        columns{varIndex} = subsref(thisInput, substruct('.', names{varIndex}));
    end
    perInput{inputIndex} = columns;
end
flattenedInputs = [perInput{:}];

% Build the output when the user function did not return a grouped tall
% array. The value must have exactly one row; it is replicated once per
% entry of gnum and then checked by iCreateOutput.
function values = iCreateOutputFromScalar(gnum, values, funStr)
hasSingleRow = (size(values, 1) == 1);
if ~hasSingleRow
    idx = matlab.internal.tableUtils.ordinalString(1);
    error(message('MATLAB:bigdata:array:SplitApplyOutputNotUniform', funStr, idx));
end

numGroups = numel(gnum);
replicated = repmat(values, numGroups, 1);
values = iCreateOutput(gnum, replicated, funStr);

% Create the output. Apart from ordering the slices by group number and
% dropping the hidden NaN group, this only performs error checking.
%
% Inputs:
%   gnum   - group key for each slice of VALUES; converted with
%            double(string(...)), so categorical keys are expected.
%   values - output slices, one per entry of gnum.
%   funStr - text of the user function, used in error messages.
%
% Output:
%   values - slices ordered by ascending group number, with the slices of
%            the hidden group 0 (NaN group numbers) removed.
%
% Errors:
%   MATLAB:splitapply:InvalidGroupNums  - there are no groups at all.
%   MATLAB:splitapply:MissingGroupNums  - some group in 1:max is absent.
%   MATLAB:bigdata:array:SplitApplyOutputNotUniform - a group has more
%       than one output slice.
function values = iCreateOutput(gnum, values, funStr)
gnum = double(string(gnum));
[gnum, idx] = sort(gnum);
values = matlab.bigdata.internal.util.indexSlices(values, idx);

% Group 0 is the hidden group that stands for NaN group numbers; it never
% appears in the output.
isNanGroup = (gnum == 0);
hadNanGroup = any(isNanGroup);
if hadNanGroup
    gnum(isNanGroup) = [];
    values = matlab.bigdata.internal.util.indexSlices(values, ~isNanGroup);
end

if isempty(gnum)
    if hadNanGroup
        % Every group number was NaN: an empty result is valid.
        return;
    else
        error(message('MATLAB:splitapply:InvalidGroupNums'));
    end
end

% A valid result has exactly one slice for each of the groups 1..max.
expectedGnum = (1:max(gnum))';
if ~isequal(gnum, expectedGnum)
    if ~isempty(setdiff(expectedGnum, gnum))
        error(message('MATLAB:splitapply:MissingGroupNums'));
    else
        % gnum is sorted and contains every group, so the mismatch must be
        % a repeated key. Report the first repeated group. Comparing a
        % prefix of gnum with expectedGnum would find nothing when only the
        % last group is repeated (e.g. [1;2;2]).
        idx = gnum(find(diff(gnum) == 0, 1, 'first'));
        idx = matlab.internal.tableUtils.ordinalString(idx);
        error(message('MATLAB:bigdata:array:SplitApplyOutputNotUniform', funStr, idx));
    end
end