Trying to understand how LINQ/deferred execution works

I have the following methods, which are part of the logic for performing stratified k-fold cross-validation.

private static IEnumerable<IEnumerable<int>> GenerateFolds(
    IClassificationProblemData problemData, int numberOfFolds)
{
    IRandom random = new MersenneTwister();
    IEnumerable<double> values =
        problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);

    // Pair each training index with its class value.
    var valuesIndices =
        problemData.TrainingIndices.Zip(values, (i, v) => new { Index = i, Value = v });

    // For every class, a lazy sequence of that class's folds.
    IEnumerable<IEnumerable<IEnumerable<int>>> foldsByClass =
        valuesIndices.GroupBy(x => x.Value, x => x.Index)
                     .Select(g => GenerateFolds(g, g.Count(), numberOfFolds));

    // One enumerator per class; each MoveNext() advances that class to its next fold.
    var enumerators = foldsByClass.Select(x => x.GetEnumerator()).ToList();

    while (enumerators.All(e => e.MoveNext()))
    {
        // Concatenate the current fold of every class and shuffle the result.
        var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());
        yield return fold.ToList();
    }
}

Fold generation:

private static IEnumerable<IEnumerable<T>> GenerateFolds<T>(
    IEnumerable<T> values, int valuesCount, int numberOfFolds)
{
    // f = base fold size (integer division), r = remainder still to distribute
    int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds;
    int start = 0, end = f;

    for (int i = 0; i < numberOfFolds; ++i)
    {
        // The first r folds each take one extra item.
        if (r > 0)
        {
            ++end;
            --r;
        }

        yield return values.Skip(start).Take(end - start);
        start = end;
        end += f;
    }
}

The generic GenerateFolds<T> method simply splits an IEnumerable<T> into a sequence of IEnumerable<T>s according to the specified number of folds. For example, with 101 training samples and 10 folds, it generates one fold of size 11 and nine folds of size 10.
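To check the fold-size arithmetic, the splitting rule can be run on a plain integer range. This is a minimal sketch: FoldSplitDemo and FoldSizes are hypothetical names introduced for illustration, and GenerateFolds<T> is copied from the question with only its comments changed.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class FoldSplitDemo
{
    // Same splitting logic as GenerateFolds<T> in the question:
    // f = base fold size, r = remainder; the first r folds get one extra item.
    public static IEnumerable<IEnumerable<T>> GenerateFolds<T>(
        IEnumerable<T> values, int valuesCount, int numberOfFolds)
    {
        int f = valuesCount / numberOfFolds, r = valuesCount % numberOfFolds;
        int start = 0, end = f;

        for (int i = 0; i < numberOfFolds; ++i)
        {
            if (r > 0)
            {
                ++end;
                --r;
            }

            yield return values.Skip(start).Take(end - start);
            start = end;
            end += f;
        }
    }

    // Hypothetical helper: the size of each fold produced for 0..count-1.
    public static List<int> FoldSizes(int count, int numberOfFolds) =>
        GenerateFolds(Enumerable.Range(0, count), count, numberOfFolds)
            .Select(fold => fold.Count())
            .ToList();
}
```

`FoldSplitDemo.FoldSizes(101, 10)` returns 11 followed by nine 10s, and `FoldSplitDemo.FoldSizes(641, 10)` returns 65 followed by nine 64s, the same 65/64 split the question expects for its test case (in the stratified version the exact sizes also depend on how many samples each class has).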

The method above it groups the samples by class value, splits each group into the specified number of folds, and then joins the per-class folds into the final folds, so that every fold keeps (up to rounding) the same distribution of class labels.

My question concerns the line yield return fold.ToList(). As written, the method works correctly; if I remove the ToList(), however, the results are no longer correct. In my test case I have 641 training samples and 10 folds, which means the first fold should have size 65 and the remaining folds size 64. But when I remove ToList(), all the folds have size 64 and the class labels are not correctly distributed. Any ideas why? Thank you.

Let's think about what the fold variable actually is:

var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());

It is not the result of executing a query; it is a query definition, because both SelectMany and OrderBy use deferred execution. It only records how to flatten the current item of every enumerator and return those items in random order. The important word is current: it means the item that is current at the time the query is executed, not at the time it is defined.
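The same effect shows up without any enumerators: a deferred query reads the state it captures when it is executed, not when it is defined. A minimal sketch, where DeferredCaptureDemo and the captured variable factor are hypothetical; factor plays the role that e.Current plays in the question's code.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class DeferredCaptureDemo
{
    public static (List<int> Deferred, List<int> Materialized) Run()
    {
        int factor = 1;
        var source = new[] { 1, 2, 3 };

        // Query definition only: nothing is multiplied yet.
        var deferred = source.Select(x => x * factor);

        // Executed immediately, while factor is still 1.
        var materialized = source.Select(x => x * factor).ToList();

        // Changing the captured variable after the queries were defined...
        factor = 10;

        // ...affects the deferred query, because it only runs here.
        return (deferred.ToList(), materialized);
    }
}
```

`DeferredCaptureDemo.Run()` returns `[10, 20, 30]` for the deferred query and `[1, 2, 3]` for the materialized one: only the query that was turned into a list before `factor` changed kept the old value.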

Now let's think about when this query will be executed. Calling the GenerateFolds method returns an IEnumerable of IEnumerable<int> queries. The following code does not execute any of those queries:

var folds = GenerateFolds(problemData, numberOfFolds);

That is again just a query: the body of the iterator method has not run yet. You can execute it by calling ToList() or by enumerating it:

var f = folds.ToList();

But even now the inner queries are not executed: they have all been returned, but none of them has run. What has run, completely, is the while loop in GenerateFolds, while the queries were being saved into the list f. So e.MoveNext() has been called repeatedly, until the loop exited:

while (enumerators.All(e => e.MoveNext()))
{
    var fold = enumerators.SelectMany(e => e.Current).OrderBy(x => random.Next());
    yield return fold;
}
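The claim that enumerating the outer sequence runs the loop without running the inner queries can be checked with counters. A sketch assuming a hypothetical Outer iterator with the same shape as GenerateFolds (one deferred inner query per loop iteration); ExecutionLog is likewise a hypothetical helper.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical counters used only for this demonstration.
public sealed class ExecutionLog
{
    public int LoopIterations;   // how many times the outer loop body ran
    public int InnerExecutions;  // how many times an inner query's selector ran
}

public static class OuterVsInnerDemo
{
    // Hypothetical stand-in with the same shape as the question's
    // GenerateFolds: every loop iteration yields a deferred inner query.
    public static IEnumerable<IEnumerable<int>> Outer(int rounds, ExecutionLog log)
    {
        for (int i = 0; i < rounds; i++)
        {
            log.LoopIterations++;
            yield return new[] { i }.Select(x =>
            {
                log.InnerExecutions++; // runs only when the inner query is enumerated
                return x;
            });
        }
    }
}
```

Calling `OuterVsInnerDemo.Outer(3, log)` leaves both counters at 0. After `.ToList()` on the result, `log.LoopIterations` is 3 while `log.InnerExecutions` is still 0; it only becomes 1 once one of the inner queries is itself enumerated, for example with `f[0].ToList()`.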

So what does f hold? A list of queries. Because you fetched all of them, the while loop has run to completion, and the current item of each enumerator is now its last fold. (Strictly, IEnumerator.Current is undefined once MoveNext has returned false, but C# iterators in practice keep returning their last item.) Yet none of these queries has been executed! Here you execute the first one:

f[0].Count() 

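You get the number of items returned by the first query (the fold query defined above). But because you have already enumerated all the queries, the current item of every enumerator is its last fold, so you get the number of indices in the last folds. The last fold of each class is one without an extra item, which is why every fold reports size 64; and since every query reads the same e.Current values, every fold contains the same indices, which is why the class labels end up wrongly distributed.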

Now take a look at:

folds.First().Count()

Here you don't enumerate all the queries to save them in a list. First() stops after the first yield, so the while loop body runs only once and the current item is still the first fold of each class. That's why you get the number of indices in the first folds, and that's why the two values differ.
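Both cases can be reproduced with a small self-contained model: two hypothetical classes with 5 and 4 samples, split into 2 folds, so the correct fold sizes are 3 + 2 = 5 and 2 + 2 = 4. DeferredFoldsDemo, Split and Folds are illustrative names, not part of the question's code base; the random OrderBy is left out because it does not change the counts. The result relies on C# iterators, in practice, keeping their last item in Current after MoveNext has returned false.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class DeferredFoldsDemo
{
    // Same splitting logic as GenerateFolds<T> in the question.
    static IEnumerable<IEnumerable<T>> Split<T>(IEnumerable<T> values, int count, int numberOfFolds)
    {
        int f = count / numberOfFolds, r = count % numberOfFolds;
        int start = 0, end = f;
        for (int i = 0; i < numberOfFolds; ++i)
        {
            if (r > 0) { ++end; --r; }
            yield return values.Skip(start).Take(end - start);
            start = end;
            end += f;
        }
    }

    // Mirrors the question's outer loop WITHOUT ToList(): each yielded fold
    // is a query over e.Current, evaluated only when the fold is enumerated.
    public static IEnumerable<IEnumerable<int>> Folds(int[][] classes, int numberOfFolds)
    {
        var enumerators = classes
            .Select(c => Split(c, c.Length, numberOfFolds).GetEnumerator())
            .ToList();

        while (enumerators.All(e => e.MoveNext()))
            yield return enumerators.SelectMany(e => e.Current);
    }
}
```

With `classes = { {0, 1, 2, 3, 4}, {10, 11, 12, 13} }` and 2 folds, `folds.ToList()[0].Count()` gives 4 (the size of the last folds, and every listed query gives the same value) while `folds.First().Count()` gives 5 (the size of the first folds), the same discrepancy described above on a smaller scale.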

Last question: why does everything work when you add ToList() inside your while loop? The answer is very simple: ToList() executes each query immediately, so you yield a list of indices instead of a query definition. Each query is executed during its own iteration, while the current item is still that iteration's fold, so the current item is different every time and your code works correctly.
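Putting the fix together, here is a self-contained sketch of the corrected method. It is an approximation under stated assumptions, not the question's original code: the hypothetical StratifiedFolds.Generate takes parallel arrays of training indices and class labels instead of IClassificationProblemData, and uses System.Random instead of MersenneTwister.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class StratifiedFolds
{
    // Hypothetical stand-in for the question's GenerateFolds(problemData, k).
    public static IEnumerable<List<int>> Generate(
        int[] indices, double[] labels, int numberOfFolds, Random random)
    {
        var enumerators = indices
            .Zip(labels, (i, v) => new { Index = i, Value = v })
            .GroupBy(x => x.Value, x => x.Index)
            .Select(g => Split(g.ToList(), numberOfFolds).GetEnumerator())
            .ToList();

        while (enumerators.All(e => e.MoveNext()))
        {
            // ToList() runs the query NOW, while every e.Current still
            // refers to this round's fold of its class.
            yield return enumerators
                .SelectMany(e => e.Current)
                .OrderBy(_ => random.Next())
                .ToList();
        }
    }

    // Same splitting rule as GenerateFolds<T>: the first (count % k) folds
    // get one extra element.
    static IEnumerable<IEnumerable<T>> Split<T>(IList<T> values, int numberOfFolds)
    {
        int f = values.Count / numberOfFolds, r = values.Count % numberOfFolds;
        int start = 0, end = f;
        for (int i = 0; i < numberOfFolds; ++i)
        {
            if (r > 0) { ++end; --r; }
            yield return values.Skip(start).Take(end - start);
            start = end;
            end += f;
        }
    }
}
```

Each group is copied with `g.ToList()` so that Skip/Take do not re-enumerate the grouping for every fold; the `ToList()` at the end of the loop body is the one that matters for correctness. The order inside each fold is random, but fold sizes and membership are deterministic: with 6 samples of one class and 5 of another, 2 folds come out with sizes 6 and 5.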
